Commit 872db6e3 authored by Bryson Howell

Tests with trust scaling for constraint

parent b8fe4105
absl-py==2.1.0
anyio==3.6.2
arcgis==2.0.0
argon2-cffi==21.3.0
......@@ -14,6 +15,7 @@ certifi==2022.9.24
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
cloudpickle==3.0.0
contourpy==1.0.6
cryptography==38.0.4
cycler==0.11.0
......@@ -26,7 +28,10 @@ executing==1.2.0
fastjsonschema==2.16.2
fonttools==4.38.0
geomet==1.0.0
grpcio==1.64.1
gssapi==1.8.2
gym==0.26.2
gym-notices==0.0.8
idna==3.4
ipykernel==6.17.1
ipython==8.7.0
......@@ -35,6 +40,7 @@ ipywidgets==8.0.2
jedi==0.18.2
jeepney==0.8.0
Jinja2==3.1.2
joblib==1.4.2
json5==0.9.10
jsonschema==4.17.3
jupyter-server==1.23.3
......@@ -47,11 +53,14 @@ jupyterlab_server==2.16.3
keyring==21.7.0
kiwisolver==1.4.4
lerc==0.1.0
libtiff==0.4.2
lxml==4.9.1
Markdown==3.6
MarkupSafe==2.1.1
matplotlib==3.6.2
matplotlib-inline==0.1.6
mistune==2.0.4
mpi4py==3.1.6
nbclassic==0.4.8
nbclient==0.7.2
nbconvert==7.2.6
......@@ -66,6 +75,7 @@ nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
oauthlib==3.2.2
opencv-python==4.7.0.72
opt-einsum==3.3.0
packaging==21.3
pandas==1.5.2
......@@ -78,10 +88,12 @@ platformdirs==2.5.4
plotly==5.11.0
prometheus-client==0.15.0
prompt-toolkit==3.0.33
protobuf==4.25.3
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
pygame==2.5.2
Pygments==2.13.0
pyparsing==3.0.9
pyproj==3.4.0
......@@ -108,6 +120,8 @@ sniffio==1.3.0
soupsieve==2.3.2.post1
stack-data==0.6.2
tenacity==8.1.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
terminado==0.17.1
tinycss2==1.2.1
tomli==2.0.1
......@@ -121,5 +135,6 @@ urllib3==1.25.11
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.4.2
Werkzeug==3.0.3
widgetsnbextension==4.0.3
wrapt==1.14.1
......@@ -14,11 +14,14 @@ FORWARD_SCALE = 0.1
class GridWorldSAR(gym.Env):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
def __init__(self, render_mode=None, map_size=48):
def __init__(self, render_mode=None, map_size=48, trust=0.5):
#Gym parameters
self.size = map_size # The size of the square grid
self.window_size = 512 # The size of the PyGame window
#Trust value
self.trust = trust
#Values the agent will need
self.pos_x = 0.0 #Current position, update w/ velocity (Might need to make this the pos in 1200x1200 path coordinates?)
self.pos_y = 0.0
......@@ -101,6 +104,8 @@ class GridWorldSAR(gym.Env):
if(self.cell[1] >= self.size):
self.cell[1] = self.size-1
map_diff = abs(self.heatmap[self.cell[0],self.cell[1]] - self.ringmap[self.cell[0],self.cell[1]])
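#Trust scaling experiment: scale the constraint cost by (1 - trust), so trust=1.0 zeroes the cost and trust=0.0 leaves it unchanged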
map_diff = map_diff * (1.0 - self.trust)
return {'loc': [self.pos_x,self.pos_y], 'cost': map_diff}
......
......@@ -210,9 +210,11 @@ def train():
#Run this function to create data
#create_data()
trustval = 1.0
#Training CPO
J = cpo.cpo(
lambda: sar_gym.GridWorldSAR(render_mode=None),
lambda: sar_gym.GridWorldSAR(render_mode=None,trust=trustval),
actor_critic=core.MLPActorCriticTD3trust,
max_ep_len=MAX_STEPS_PER_EPISODE,
cost_lim=COST_THRESHOLD,
......@@ -237,7 +239,7 @@ def train():
print(s.getvalue())
#Save policy
torch.save(J.pi, './deeprl_data/trained_model4.pt')
torch.save(J.pi, './trust_data/trained_model_100.pt')
#Testing
......@@ -283,12 +285,14 @@ def train():
def test_policy():
#Testing
policy = torch.load('./deeprl_data/trained_model2.pt')
policy = torch.load('./trust_data/trained_model_100.pt')
heatmap = np.load('./deeprl_data/lpm/lpm4_prob.npy')
ringmap = np.load('./deeprl_data/ring/ringlowres_prob.npy')
trustval = 1.0
#Make Gym Environment
env = sar_gym.GridWorldSAR(render_mode=None)
env = sar_gym.GridWorldSAR(render_mode=None, trust=trustval)
#Test one episode
......@@ -334,7 +338,7 @@ def test_policy():
#x2.set_title('Rewards')
plt.tight_layout()
plt.suptitle("LPM Reward Function with Ring Model Constraint")
plt.suptitle("LPM Reward Function with Ring Model Constraint (Trust = %.2f)" % trustval)
plt.show()
......@@ -342,8 +346,8 @@ def test_policy():
def main():
#test_policy()
train()
test_policy()
#train()
return
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment