Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
P
planning_llh_bgc
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Larkin Heintzman
planning_llh_bgc
Commits
9463e157
Commit
9463e157
authored
Feb 26, 2024
by
Bryson Howell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
code working with Torch=1.13 now
parent
1af694d9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
6 deletions
+9
-6
robotgp.py
gp/robotgp.py
+9
-6
No files found.
gp/robotgp.py
View file @
9463e157
...
...
@@ -18,6 +18,7 @@ torch.set_printoptions(precision=4, threshold=5)
import
pyro
import
pdb
import
inspect
class
RobotGP
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
mc_handle
,
planner
,
meas_std
=
1e0
,
lsZ
=
5e0
,
_stime
=
None
,
parameters
=
{}):
...
...
@@ -35,7 +36,7 @@ class RobotGP(torch.nn.Module):
self
.
update_searcher_paths
=
True
self
.
grad_clamp_value
=
1e50
self
.
grad_clamp
=
lambda
grad
:
torch
.
clamp
(
grad
,
-
self
.
grad_clamp_value
,
self
.
grad_clamp_value
)
#
self.grad_clamp = lambda grad: torch.clamp(grad, -self.grad_clamp_value, self.grad_clamp_value)
# self.X_batch.tregister_hook(lambda grad: torch.clamp(grad, -self.grad_clamp, self.grad_clamp))
self
.
init_Xstar
()
...
...
@@ -83,13 +84,15 @@ class RobotGP(torch.nn.Module):
self
.
_stime
=
_stime
if
_stime
is
not
None
else
time
.
time
()
#!!Memory problem is here @_@.
Hook is triggered by backward(forward too?)
#!!Memory problem is here @_@.
This registers backwards hooks
#What is this hook doing though. Why do we need it
def
_hook
(
self
,
tensor
,
_min
=-
1e10
,
_max
=
1e10
):
print
(
tensor
)
curframe
=
inspect
.
currentframe
()
calframe
=
inspect
.
getouterframes
(
curframe
,
2
)
print
(
calframe
[
1
][
3
])
tensor
.
requires_grad_
(
True
)
#tensor.retain_grad()
tensor
.
register_hook
(
lambda
grad
:
grad
.
clamp
_
(
min
=
_min
,
max
=
_max
))
tensor
.
register_hook
(
lambda
grad
:
grad
.
clamp
(
min
=
_min
,
max
=
_max
))
# tensor.register_hook(lambda grad: grad.clamp_(-self.grad_clamp_value, self.grad_clamp_value))
return
tensor
...
...
@@ -161,7 +164,6 @@ class RobotGP(torch.nn.Module):
# self.robot_values = self._hook(self.robot_values)
self
.
num_train
=
self
.
fixed_points
.
shape
[
0
]
+
self
.
robot_points
.
shape
[
0
]
self
.
update_trainXY
()
def
update_Xstar
(
self
):
...
...
@@ -361,7 +363,8 @@ class RobotGP(torch.nn.Module):
self
.
scaled_risk_cost
=
self
.
omega
.
sum
()
*
self
.
riskcost_scaling
self
.
scaled_path_length_cost
=
self
.
lengthcost_scaling
*
self
.
path_len_cost
self
.
risk_cost
=
self
.
scaled_risk_cost
+
self
.
scaled_path_length_cost
#self.risk_cost = self.scaled_risk_cost + self.scaled_path_length_cost
self
.
risk_cost
=
torch
.
add
(
self
.
scaled_risk_cost
,
self
.
scaled_path_length_cost
)
self
.
risk_cost
=
self
.
_hook
(
self
.
risk_cost
)
print
(
"!!Ran hook on Risk cost, why are we doing it again...
\n\n
"
)
print
(
self
.
risk_cost
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment