pyRDDLGym-jax 0.5__py3-none-any.whl → 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +784 -544
- pyRDDLGym_jax/core/planner.py +329 -463
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +379 -568
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -27
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
- pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/simulator.py
CHANGED
|
@@ -107,7 +107,7 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
107
107
|
'''Throws an exception if the state invariants are not satisfied.'''
|
|
108
108
|
for (i, invariant) in enumerate(self.invariants):
|
|
109
109
|
loc = self.invariant_names[i]
|
|
110
|
-
sample, self.key, error = invariant(
|
|
110
|
+
sample, self.key, error, self.model_params = invariant(
|
|
111
111
|
self.subs, self.model_params, self.key)
|
|
112
112
|
self.handle_error_code(error, loc)
|
|
113
113
|
if not bool(sample):
|
|
@@ -125,7 +125,8 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
125
125
|
|
|
126
126
|
for (i, precond) in enumerate(self.preconds):
|
|
127
127
|
loc = self.precond_names[i]
|
|
128
|
-
sample, self.key, error = precond(subs, self.model_params, self.key)
|
|
128
|
+
sample, self.key, error, self.model_params = precond(
|
|
129
|
+
subs, self.model_params, self.key)
|
|
129
130
|
self.handle_error_code(error, loc)
|
|
130
131
|
if not bool(sample):
|
|
131
132
|
if not silent:
|
|
@@ -138,7 +139,7 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
138
139
|
'''return True if a terminal state has been reached.'''
|
|
139
140
|
for (i, terminal) in enumerate(self.terminals):
|
|
140
141
|
loc = self.terminal_names[i]
|
|
141
|
-
sample, self.key, error = terminal(
|
|
142
|
+
sample, self.key, error, self.model_params = terminal(
|
|
142
143
|
self.subs, self.model_params, self.key)
|
|
143
144
|
self.handle_error_code(error, loc)
|
|
144
145
|
if bool(sample):
|
|
@@ -147,7 +148,7 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
147
148
|
|
|
148
149
|
def sample_reward(self) -> float:
|
|
149
150
|
'''Samples the current reward given the current state and action.'''
|
|
150
|
-
reward, self.key, error = self.reward(
|
|
151
|
+
reward, self.key, error, self.model_params = self.reward(
|
|
151
152
|
self.subs, self.model_params, self.key)
|
|
152
153
|
self.handle_error_code(error, 'reward function')
|
|
153
154
|
return float(reward)
|
|
@@ -165,7 +166,8 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
165
166
|
|
|
166
167
|
# compute CPFs in topological order
|
|
167
168
|
for (cpf, expr, _) in self.cpfs:
|
|
168
|
-
subs[cpf], self.key, error = expr(subs, self.model_params, self.key)
|
|
169
|
+
subs[cpf], self.key, error, self.model_params = expr(
|
|
170
|
+
subs, self.model_params, self.key)
|
|
169
171
|
self.handle_error_code(error, f'CPF <{cpf}>')
|
|
170
172
|
|
|
171
173
|
# sample reward
|