{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np\n",
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "\n",
    "from matplotlib import cm, pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "\n",
    "import sys, os, json, time\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SARKernel(gpytorch.kernels.RBFKernel):\n",
    "    def __init__(\n",
    "        self,\n",
    "        ard_num_dims=None,\n",
    "        batch_size=1,\n",
    "        active_dims=(0, 1),\n",
    "        lengthscale_prior=None,\n",
    "        eps=1e-6,\n",
    "        min_altitude = 1.0,\n",
    "        **kwargs\n",
    "    ):\n",
    "        super(SARKernel, self).__init__(\n",
    "            has_lengthscale=True,\n",
    "            ard_num_dims=ard_num_dims,\n",
    "            batch_size=batch_size,\n",
    "            active_dims=active_dims,\n",
    "            lengthscale_prior=lengthscale_prior,\n",
    "            eps=eps\n",
    "        )\n",
    "        self.min_altitude = min_altitude\n",
    "\n",
    "        def forward(self, x1, x2, **params):\n",
    "            if x1.shape[0] == 3: # is X1 has the altitude\n",
    "                raise Exception('X1 has altitude')\n",
    "                altitude = x1[2]\n",
    "                x1 = x1[:2]\n",
    "            elif x2.shape[0] == 3: # is X2 has the altitude\n",
    "                raise Exception('X2 has altitude')\n",
    "                altitude = x2[2]\n",
    "                x2 = x2[:2]\n",
    "            else:\n",
    "                raise Exception('Default altitude')\n",
    "                altitude = self.min_altitude\n",
    "                \n",
    "            x1_ = x1.div(self.lengthscale * altitude)\n",
    "            x2_ = x2.div(self.lengthscale * altitude)\n",
    "            diff = self._covar_dist(x1_, x2_, square_dist=True, diag=diag, \\\n",
    "                                    dist_postprocess_func=postprocess_rbf, **params)\n",
    "            return diff"
   ]
  },
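  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference check (illustrative, independent of any gpytorch API): the covariance the\n",
    "# SARKernel above is meant to produce is a plain RBF whose lengthscale is stretched by\n",
    "# altitude, k(p, q) = exp(-||p - q||^2 / (2 * (lengthscale * altitude)^2)).\n",
    "# The lengthscale, altitude and test points below are arbitrary example values.\n",
    "def sar_rbf_reference(p, q, lengthscale=1.0, altitude=1.0):\n",
    "    d2 = np.sum((p - q) ** 2)\n",
    "    return np.exp(-d2 / (2.0 * (lengthscale * altitude) ** 2))\n",
    "\n",
    "pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])\n",
    "K = np.array([[sar_rbf_reference(p, q, altitude=2.0) for q in pts] for p in pts])\n",
    "print(np.round(K, 4))  # symmetric, ones on the diagonal, decaying with distance"
   ]
  },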
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 500 random locations in [(-10,-10), (10, 10)] square\n",
    "# and probabilities to be associated with these locations\n",
    "xylim = 10\n",
    "num_points = 500\n",
    "std_max = 0.0\n",
    "\n",
    "locs = ( np.random.rand(num_points, 2) * 2*xylim ) - xylim\n",
    "locs = np.hstack(( locs, 10 * np.ones( (num_points, 1) ) ))\n",
    "\n",
    "# locs = np.vstack( ( np.linspace( -xylim, xylim, num_points ), \\\n",
    "#                     np.linspace( -xylim, xylim, num_points ) ) ).transpose()\n",
    "locs.sort(axis=0)\n",
    "locs = np.asarray(locs, dtype=np.float32)\n",
    "\n",
    "probs = np.sin(locs[:,1]) #np.random.rand(num_points,)\n",
    "\n",
    "prob_std = ( np.random.rand(num_points,) * 2*std_max ) - std_max\n",
    "# prob_std = std_max * np.ones_like(probs)\n",
    "\n",
    "probs_new = (probs * (1-std_max)) + prob_std\n",
    "probs_new = np.asarray(probs_new, dtype=np.float32)\n",
    "\n",
    "prob_var = np.asarray(prob_std**2, dtype=np.float32)"
   ]
  },
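  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at the generated training set (purely illustrative): locs has columns\n",
    "# (x, y, altitude) with a constant altitude of 10, and probs_new is the sin(y) target\n",
    "# plus optional noise (no noise while std_max = 0).\n",
    "print('locs:', locs.shape, locs.dtype)\n",
    "print('altitude column range:', locs[:, 2].min(), '-', locs[:, 2].max())\n",
    "print('targets:', probs_new.shape, 'range: [%.3f, %.3f]' % (probs_new.min(), probs_new.max()))"
   ]
  },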
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPRegressionModel(gpytorch.models.ExactGP):\n",
    "        def __init__(self, train_x, train_y, likelihood):\n",
    "            super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)\n",
    "            self.mean_module = gpytorch.means.ConstantMean().cuda()\n",
    "            self.covar_module = SARKernel().cuda()\n",
    "#             self.covar_module = gpytorch.kernels.PeriodicKernel().cuda()\n",
    "\n",
    "        def forward(self, x):\n",
    "            mean_x = self.mean_module(x)\n",
    "            covar_x = self.covar_module(x)\n",
    "            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "cuda runtime error (30) : unknown error at /opt/conda/conda-bld/pytorch_1549630534704/work/aten/src/THC/THCGeneral.cpp:51",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-15e6da4b74ae>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrain_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlocs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mtrain_y\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprobs_new\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mtrain_y_var\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprob_var\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m log_noise_model = GPRegressionModel(\n",
      "\u001b[0;32m~/anaconda3/envs/sar/lib/python3.6/site-packages/torch/cuda/__init__.py\u001b[0m in \u001b[0;36m_lazy_init\u001b[0;34m()\u001b[0m\n\u001b[1;32m    160\u001b[0m             \"Cannot re-initialize CUDA in forked subprocess. \" + msg)\n\u001b[1;32m    161\u001b[0m     \u001b[0m_check_driver\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m     \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cuda_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    163\u001b[0m     \u001b[0m_cudart\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_load_cudart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    164\u001b[0m     \u001b[0m_cudart\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcudaGetErrorName\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_char_p\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: cuda runtime error (30) : unknown error at /opt/conda/conda-bld/pytorch_1549630534704/work/aten/src/THC/THCGeneral.cpp:51"
     ]
    }
   ],
   "source": [
    "train_x = torch.from_numpy(locs).cuda()\n",
    "train_y = torch.from_numpy(probs_new).cuda()\n",
    "train_y_var = torch.from_numpy(prob_var).cuda()\n",
    "\n",
    "log_noise_model = GPRegressionModel(\n",
    "    train_x,\n",
    "    train_y_var,\n",
    "    gpytorch.likelihoods.GaussianLikelihood().cuda()\n",
    ")\n",
    "\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()\n",
    "model = GPRegressionModel( train_x, train_y, likelihood ).cuda()"
   ]
  },
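  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The cell above needs a working CUDA device (its stored output shows a CUDA runtime\n",
    "# error on the machine this notebook last ran on). A device-agnostic sketch of the same\n",
    "# tensor setup: pick the device once and move everything with .to(device). On a CPU-only\n",
    "# machine this only helps end-to-end if the hard-coded .cuda() calls inside\n",
    "# GPRegressionModel are relaxed as well; the *_dev names are illustrative and are not\n",
    "# used by the later cells.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "train_x_dev = torch.from_numpy(locs).to(device)\n",
    "train_y_dev = torch.from_numpy(probs_new).to(device)\n",
    "train_y_var_dev = torch.from_numpy(prob_var).to(device)\n",
    "\n",
    "likelihood_dev = gpytorch.likelihoods.GaussianLikelihood().to(device)\n",
    "print('selected device:', device)"
   ]
  },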
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10, 10))\n",
    "ax = fig.gca(projection='3d')\n",
    "ax.scatter( locs[:, 1], locs[:, 0], probs_new, c=probs_new, cmap='cool' )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10, 10))\n",
    "ax = fig.gca(projection='3d')\n",
    "ax.plot( locs[:, 1], locs[:, 0], probs_new )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find optimal model hyperparameters\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "# Use the adam optimizer\n",
    "optimizer = torch.optim.Adam([\n",
    "    {'params': model.parameters()},  # Includes GaussianLikelihood parameters\n",
    "], lr=0.1)\n",
    "\n",
    "# \"Loss\" for GPs - the marginal log likelihood\n",
    "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
    "\n",
    "n_iter = 75\n",
    "tic = time.time()\n",
    "for i in range(n_iter):\n",
    "    optimizer.zero_grad()\n",
    "    output = model(train_x)\n",
    "    loss = -mll(output, train_y, train_x)\n",
    "    loss.backward()\n",
    "    if (i+1) % 10 == 0:\n",
    "        print('Iter %d/%d - Loss: %.3f' % (i + 1, n_iter, loss.item()))\n",
    "    optimizer.step()\n",
    "    \n",
    "print(f\"wall time: {time.time() - tic:.2f} sec\")"
   ]
  },
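  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# After training it is handy to inspect the fitted hyperparameters. named_parameters()\n",
    "# lists the raw (unconstrained / log-space) values; the last line uses the transformed\n",
    "# lengthscale property that gpytorch exposes on kernels built with has_lengthscale=True.\n",
    "with torch.no_grad():\n",
    "    for name, param in model.named_parameters():\n",
    "        print('%-45s %s' % (name, param.detach().cpu().numpy().ravel()))\n",
    "    print('kernel lengthscale:', model.covar_module.lengthscale.detach().cpu().numpy().ravel())"
   ]
  },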
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "likelihood.eval()\n",
    "\n",
    "test_xylim = 15\n",
    "test_num_points = 150\n",
    "\n",
    "test_locs = np.vstack( (  np.linspace(-test_xylim, test_xylim, test_num_points), \\\n",
    "                          np.linspace(test_xylim, -test_xylim, test_num_points),\\\n",
    "                          10 * np.ones((1, test_num_points))\n",
    "                       ) ).transpose()\n",
    "test_locs.sort(axis=0)\n",
    "test_locs = np.asarray(test_locs, dtype=np.float32)\n",
    "\n",
    "with torch.no_grad(), gpytorch.fast_pred_var():\n",
    "    test_x = torch.from_numpy(test_locs).cuda()\n",
    "    post_f = model(test_x)\n",
    "    post_obs = likelihood(post_f, test_x)"
   ]
  },
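  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A quick quantitative check (illustrative): compare the posterior mean with the\n",
    "# generating function sin(y) on the test grid. Inside the training box the error should\n",
    "# be small; beyond [-10, 10] the GP is extrapolating, so larger errors are expected.\n",
    "with torch.no_grad():\n",
    "    mean_np = post_f.mean.cpu().numpy()\n",
    "    true_np = np.sin(test_locs[:, 1])\n",
    "    rmse = np.sqrt(np.mean((mean_np - true_np) ** 2))\n",
    "    print('RMSE of posterior mean vs sin(y) on the test grid: %.4f' % rmse)"
   ]
  },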
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    ax = fig.gca()\n",
    "\n",
    "    lower_f, upper_f = post_f.confidence_region()\n",
    "    lower_f, upper_f = lower_f.cpu(), upper_f.cpu()\n",
    "    \n",
    "    lower_obs, upper_obs = post_obs.confidence_region()\n",
    "    lower_obs, upper_obs = lower_obs.cpu(), upper_obs.cpu()\n",
    "\n",
    "    post_f__mean_cpu = post_f.mean.cpu().numpy()\n",
    "    ax.scatter( test_locs[:, 1], test_locs[:, 0], c=post_f__mean_cpu, cmap='cool' )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    ax = plt.gca(projection='3d')\n",
    "    \n",
    "    ax.scatter( test_locs[:, 1], test_locs[:, 0], post_f__mean_cpu, c=post_f__mean_cpu, cmap='winter' )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    ax = plt.gca()\n",
    "    \n",
    "    ax.scatter( test_locs[:, 1], post_f__mean_cpu, c=post_f__mean_cpu, cmap='winter' )\n",
    "    ax.plot(locs[:, 1], probs_new, 'rx')\n",
    "\n",
    "    ax.plot( test_locs[:, 1], post_f__mean_cpu, 'b')\n",
    "    \n",
    "    ax.fill_between(test_locs[:, 1], lower_f.numpy(), upper_f.numpy(), alpha=0.5)#, color='k')\n",
    "    ax.fill_between(test_locs[:, 1], lower_obs.numpy(), upper_obs.numpy(), alpha=0.25, color='r')\n",
    "    \n",
    "    ax.plot( locs[:, 1], probs, 'gx')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}