pyRDDLGym-jax 0.5__py3-none-any.whl → 1.0__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +784 -544
  4. pyRDDLGym_jax/core/planner.py +329 -463
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +379 -568
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -27
  35. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
  36. pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
@@ -107,7 +107,7 @@ class JaxRDDLSimulator(RDDLSimulator):
107
107
  '''Throws an exception if the state invariants are not satisfied.'''
108
108
  for (i, invariant) in enumerate(self.invariants):
109
109
  loc = self.invariant_names[i]
110
- sample, self.key, error = invariant(
110
+ sample, self.key, error, self.model_params = invariant(
111
111
  self.subs, self.model_params, self.key)
112
112
  self.handle_error_code(error, loc)
113
113
  if not bool(sample):
@@ -125,7 +125,8 @@ class JaxRDDLSimulator(RDDLSimulator):
125
125
 
126
126
  for (i, precond) in enumerate(self.preconds):
127
127
  loc = self.precond_names[i]
128
- sample, self.key, error = precond(subs, self.model_params, self.key)
128
+ sample, self.key, error, self.model_params = precond(
129
+ subs, self.model_params, self.key)
129
130
  self.handle_error_code(error, loc)
130
131
  if not bool(sample):
131
132
  if not silent:
@@ -138,7 +139,7 @@ class JaxRDDLSimulator(RDDLSimulator):
138
139
  '''return True if a terminal state has been reached.'''
139
140
  for (i, terminal) in enumerate(self.terminals):
140
141
  loc = self.terminal_names[i]
141
- sample, self.key, error = terminal(
142
+ sample, self.key, error, self.model_params = terminal(
142
143
  self.subs, self.model_params, self.key)
143
144
  self.handle_error_code(error, loc)
144
145
  if bool(sample):
@@ -147,7 +148,7 @@ class JaxRDDLSimulator(RDDLSimulator):
147
148
 
148
149
  def sample_reward(self) -> float:
149
150
  '''Samples the current reward given the current state and action.'''
150
- reward, self.key, error = self.reward(
151
+ reward, self.key, error, self.model_params = self.reward(
151
152
  self.subs, self.model_params, self.key)
152
153
  self.handle_error_code(error, 'reward function')
153
154
  return float(reward)
@@ -165,7 +166,8 @@ class JaxRDDLSimulator(RDDLSimulator):
165
166
 
166
167
  # compute CPFs in topological order
167
168
  for (cpf, expr, _) in self.cpfs:
168
- subs[cpf], self.key, error = expr(subs, self.model_params, self.key)
169
+ subs[cpf], self.key, error, self.model_params = expr(
170
+ subs, self.model_params, self.key)
169
171
  self.handle_error_code(error, f'CPF <{cpf}>')
170
172
 
171
173
  # sample reward