Commit 872db6e3 authored by Bryson Howell

Tests with trust scaling for constraint

parent b8fe4105
absl-py==2.1.0
anyio==3.6.2 anyio==3.6.2
arcgis==2.0.0 arcgis==2.0.0
argon2-cffi==21.3.0 argon2-cffi==21.3.0
...@@ -14,6 +15,7 @@ certifi==2022.9.24 ...@@ -14,6 +15,7 @@ certifi==2022.9.24
cffi==1.15.1 cffi==1.15.1
charset-normalizer==2.1.1 charset-normalizer==2.1.1
click==8.1.3 click==8.1.3
cloudpickle==3.0.0
contourpy==1.0.6 contourpy==1.0.6
cryptography==38.0.4 cryptography==38.0.4
cycler==0.11.0 cycler==0.11.0
...@@ -26,7 +28,10 @@ executing==1.2.0 ...@@ -26,7 +28,10 @@ executing==1.2.0
fastjsonschema==2.16.2 fastjsonschema==2.16.2
fonttools==4.38.0 fonttools==4.38.0
geomet==1.0.0 geomet==1.0.0
grpcio==1.64.1
gssapi==1.8.2 gssapi==1.8.2
gym==0.26.2
gym-notices==0.0.8
idna==3.4 idna==3.4
ipykernel==6.17.1 ipykernel==6.17.1
ipython==8.7.0 ipython==8.7.0
...@@ -35,6 +40,7 @@ ipywidgets==8.0.2 ...@@ -35,6 +40,7 @@ ipywidgets==8.0.2
jedi==0.18.2 jedi==0.18.2
jeepney==0.8.0 jeepney==0.8.0
Jinja2==3.1.2 Jinja2==3.1.2
joblib==1.4.2
json5==0.9.10 json5==0.9.10
jsonschema==4.17.3 jsonschema==4.17.3
jupyter-server==1.23.3 jupyter-server==1.23.3
...@@ -47,11 +53,14 @@ jupyterlab_server==2.16.3 ...@@ -47,11 +53,14 @@ jupyterlab_server==2.16.3
keyring==21.7.0 keyring==21.7.0
kiwisolver==1.4.4 kiwisolver==1.4.4
lerc==0.1.0 lerc==0.1.0
libtiff==0.4.2
lxml==4.9.1 lxml==4.9.1
Markdown==3.6
MarkupSafe==2.1.1 MarkupSafe==2.1.1
matplotlib==3.6.2 matplotlib==3.6.2
matplotlib-inline==0.1.6 matplotlib-inline==0.1.6
mistune==2.0.4 mistune==2.0.4
mpi4py==3.1.6
nbclassic==0.4.8 nbclassic==0.4.8
nbclient==0.7.2 nbclient==0.7.2
nbconvert==7.2.6 nbconvert==7.2.6
...@@ -66,6 +75,7 @@ nvidia-cuda-nvrtc-cu11==11.7.99 ...@@ -66,6 +75,7 @@ nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99 nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96 nvidia-cudnn-cu11==8.5.0.96
oauthlib==3.2.2 oauthlib==3.2.2
opencv-python==4.7.0.72
opt-einsum==3.3.0 opt-einsum==3.3.0
packaging==21.3 packaging==21.3
pandas==1.5.2 pandas==1.5.2
...@@ -78,10 +88,12 @@ platformdirs==2.5.4 ...@@ -78,10 +88,12 @@ platformdirs==2.5.4
plotly==5.11.0 plotly==5.11.0
prometheus-client==0.15.0 prometheus-client==0.15.0
prompt-toolkit==3.0.33 prompt-toolkit==3.0.33
protobuf==4.25.3
psutil==5.9.4 psutil==5.9.4
ptyprocess==0.7.0 ptyprocess==0.7.0
pure-eval==0.2.2 pure-eval==0.2.2
pycparser==2.21 pycparser==2.21
pygame==2.5.2
Pygments==2.13.0 Pygments==2.13.0
pyparsing==3.0.9 pyparsing==3.0.9
pyproj==3.4.0 pyproj==3.4.0
...@@ -108,6 +120,8 @@ sniffio==1.3.0 ...@@ -108,6 +120,8 @@ sniffio==1.3.0
soupsieve==2.3.2.post1 soupsieve==2.3.2.post1
stack-data==0.6.2 stack-data==0.6.2
tenacity==8.1.0 tenacity==8.1.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
terminado==0.17.1 terminado==0.17.1
tinycss2==1.2.1 tinycss2==1.2.1
tomli==2.0.1 tomli==2.0.1
...@@ -121,5 +135,6 @@ urllib3==1.25.11 ...@@ -121,5 +135,6 @@ urllib3==1.25.11
wcwidth==0.2.5 wcwidth==0.2.5
webencodings==0.5.1 webencodings==0.5.1
websocket-client==1.4.2 websocket-client==1.4.2
Werkzeug==3.0.3
widgetsnbextension==4.0.3 widgetsnbextension==4.0.3
wrapt==1.14.1 wrapt==1.14.1
...@@ -14,11 +14,14 @@ FORWARD_SCALE = 0.1 ...@@ -14,11 +14,14 @@ FORWARD_SCALE = 0.1
class GridWorldSAR(gym.Env): class GridWorldSAR(gym.Env):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4} metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
def __init__(self, render_mode=None, map_size=48): def __init__(self, render_mode=None, map_size=48, trust=0.5):
#Gym parameters #Gym parameters
self.size = map_size # The size of the square grid self.size = map_size # The size of the square grid
self.window_size = 512 # The size of the PyGame window self.window_size = 512 # The size of the PyGame window
#Trust value
self.trust = trust
#Values the agent will need #Values the agent will need
self.pos_x = 0.0 #Current position, update w/ velocity (Might need to make this the pos in 1200x1200 path coordinates?) self.pos_x = 0.0 #Current position, update w/ velocity (Might need to make this the pos in 1200x1200 path coordinates?)
self.pos_y = 0.0 self.pos_y = 0.0
...@@ -101,6 +104,8 @@ class GridWorldSAR(gym.Env): ...@@ -101,6 +104,8 @@ class GridWorldSAR(gym.Env):
if(self.cell[1] >= self.size): if(self.cell[1] >= self.size):
self.cell[1] = self.size-1 self.cell[1] = self.size-1
map_diff = abs(self.heatmap[self.cell[0],self.cell[1]] - self.ringmap[self.cell[0],self.cell[1]]) map_diff = abs(self.heatmap[self.cell[0],self.cell[1]] - self.ringmap[self.cell[0],self.cell[1]])
#Trust scaling experiment
map_diff = map_diff * (1.0 - self.trust)
return {'loc': [self.pos_x,self.pos_y], 'cost': map_diff} return {'loc': [self.pos_x,self.pos_y], 'cost': map_diff}
......
...@@ -210,9 +210,11 @@ def train(): ...@@ -210,9 +210,11 @@ def train():
#Run this function to create data #Run this function to create data
#create_data() #create_data()
trustval = 1.0
#Training CPO #Training CPO
J = cpo.cpo( J = cpo.cpo(
lambda: sar_gym.GridWorldSAR(render_mode=None), lambda: sar_gym.GridWorldSAR(render_mode=None,trust=trustval),
actor_critic=core.MLPActorCriticTD3trust, actor_critic=core.MLPActorCriticTD3trust,
max_ep_len=MAX_STEPS_PER_EPISODE, max_ep_len=MAX_STEPS_PER_EPISODE,
cost_lim=COST_THRESHOLD, cost_lim=COST_THRESHOLD,
...@@ -237,7 +239,7 @@ def train(): ...@@ -237,7 +239,7 @@ def train():
print(s.getvalue()) print(s.getvalue())
#Save policy #Save policy
torch.save(J.pi, './deeprl_data/trained_model4.pt') torch.save(J.pi, './trust_data/trained_model_100.pt')
#Testing #Testing
...@@ -283,12 +285,14 @@ def train(): ...@@ -283,12 +285,14 @@ def train():
def test_policy(): def test_policy():
#Testing #Testing
policy = torch.load('./deeprl_data/trained_model2.pt') policy = torch.load('./trust_data/trained_model_100.pt')
heatmap = np.load('./deeprl_data/lpm/lpm4_prob.npy') heatmap = np.load('./deeprl_data/lpm/lpm4_prob.npy')
ringmap = np.load('./deeprl_data/ring/ringlowres_prob.npy') ringmap = np.load('./deeprl_data/ring/ringlowres_prob.npy')
trustval = 1.0
#Make Gym Environment #Make Gym Environment
env = sar_gym.GridWorldSAR(render_mode=None) env = sar_gym.GridWorldSAR(render_mode=None, trust=trustval)
#Test one episode #Test one episode
...@@ -334,7 +338,7 @@ def test_policy(): ...@@ -334,7 +338,7 @@ def test_policy():
#x2.set_title('Rewards') #x2.set_title('Rewards')
plt.tight_layout() plt.tight_layout()
plt.suptitle("LPM Reward Function with Ring Model Constraint") plt.suptitle("LPM Reward Function with Ring Model Constraint (Trust = %.2f)" % trustval)
plt.show() plt.show()
...@@ -342,8 +346,8 @@ def test_policy(): ...@@ -342,8 +346,8 @@ def test_policy():
def main(): def main():
#test_policy() test_policy()
train() #train()
return return
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment