Commit 9463e157 authored by Bryson Howell's avatar Bryson Howell

code working with Torch=1.13 now

parent 1af694d9
......@@ -18,6 +18,7 @@ torch.set_printoptions(precision=4, threshold=5)
import pyro
import pdb
import inspect
class RobotGP(torch.nn.Module):
def __init__(self, mc_handle, planner, meas_std=1e0, lsZ=5e0, _stime=None, parameters = {}):
......@@ -35,7 +36,7 @@ class RobotGP(torch.nn.Module):
self.update_searcher_paths = True
self.grad_clamp_value = 1e50
self.grad_clamp = lambda grad: torch.clamp(grad, -self.grad_clamp_value, self.grad_clamp_value)
#self.grad_clamp = lambda grad: torch.clamp(grad, -self.grad_clamp_value, self.grad_clamp_value)
# self.X_batch.tregister_hook(lambda grad: torch.clamp(grad, -self.grad_clamp, self.grad_clamp))
self.init_Xstar()
......@@ -83,13 +84,15 @@ class RobotGP(torch.nn.Module):
self._stime = _stime if _stime is not None else time.time()
#!!Memory problem is here @_@. Hook is triggered by backward(forward too?)
#!!Memory problem is here @_@. This registers backwards hooks
#What is this hook doing though. Why do we need it
def _hook(self, tensor, _min=-1e10, _max=1e10):
print(tensor)
curframe = inspect.currentframe()
calframe = inspect.getouterframes(curframe,2)
print(calframe[1][3])
tensor.requires_grad_(True)
#tensor.retain_grad()
tensor.register_hook(lambda grad: grad.clamp_(min=_min, max=_max))
tensor.register_hook(lambda grad: grad.clamp(min=_min, max=_max))
# tensor.register_hook(lambda grad: grad.clamp_(-self.grad_clamp_value, self.grad_clamp_value))
return tensor
......@@ -161,7 +164,6 @@ class RobotGP(torch.nn.Module):
# self.robot_values = self._hook(self.robot_values)
self.num_train = self.fixed_points.shape[0] + self.robot_points.shape[0]
self.update_trainXY()
def update_Xstar(self):
......@@ -361,7 +363,8 @@ class RobotGP(torch.nn.Module):
self.scaled_risk_cost = self.omega.sum() * self.riskcost_scaling
self.scaled_path_length_cost = self.lengthcost_scaling * self.path_len_cost
self.risk_cost = self.scaled_risk_cost + self.scaled_path_length_cost
#self.risk_cost = self.scaled_risk_cost + self.scaled_path_length_cost
self.risk_cost = torch.add(self.scaled_risk_cost, self.scaled_path_length_cost)
self.risk_cost = self._hook(self.risk_cost)
print("!!Ran hook on Risk cost, why are we doing it again...\n\n")
print(self.risk_cost)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment