pyRDDLGym-jax 2.4.tar.gz → 2.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/PKG-INFO +17 -30
  2. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/README.md +13 -27
  3. pyrddlgym_jax-2.6/pyRDDLGym_jax/__init__.py +1 -0
  4. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/compiler.py +23 -10
  5. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/logic.py +6 -8
  6. pyrddlgym_jax-2.6/pyRDDLGym_jax/core/model.py +595 -0
  7. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/planner.py +317 -99
  8. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/simulator.py +37 -13
  9. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/tuning.py +25 -10
  10. pyrddlgym_jax-2.6/pyRDDLGym_jax/entry_point.py +59 -0
  11. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
  12. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
  13. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
  14. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/run_plan.py +1 -1
  15. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/run_tune.py +8 -2
  16. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/PKG-INFO +17 -30
  17. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/SOURCES.txt +1 -0
  18. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/requires.txt +1 -1
  19. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/setup.py +2 -2
  20. pyrddlgym_jax-2.4/pyRDDLGym_jax/__init__.py +0 -1
  21. pyrddlgym_jax-2.4/pyRDDLGym_jax/entry_point.py +0 -27
  22. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/LICENSE +0 -0
  23. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/__init__.py +0 -0
  24. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
  25. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
  26. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/visualization.py +0 -0
  27. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/__init__.py +0 -0
  28. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  29. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
  30. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  31. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  32. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  33. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  34. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  35. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  36. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  37. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  38. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  39. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  40. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  41. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  42. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  43. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  44. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  45. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
  46. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  47. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  48. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  49. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  50. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  51. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  52. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  53. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  54. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  55. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
  56. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  57. {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: pyRDDLGym-jax
- Version: 2.4
+ Version: 2.6
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: pyRDDLGym>=2.0
+ Requires-Dist: pyRDDLGym>=2.3
  Requires-Dist: tqdm>=4.66
  Requires-Dist: jax>=0.4.12
  Requires-Dist: optax>=0.1.9
@@ -39,6 +39,7 @@ Dynamic: description
  Dynamic: description-content-type
  Dynamic: home-page
  Dynamic: license
+ Dynamic: license-file
  Dynamic: provides-extra
  Dynamic: requires-dist
  Dynamic: requires-python
@@ -54,7 +55,7 @@ Dynamic: summary

  [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)

- **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
+ **pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**

  Purpose:

@@ -83,7 +84,7 @@ and was moved to the individual logic components which have their own unique wei

  > [!NOTE]
  > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
- > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
+ > If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).

  ## Installation

@@ -116,7 +117,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
  A basic run script is provided to train JaxPlan on any RDDL problem:

  ```shell
- jaxplan plan <domain> <instance> <method> <episodes>
+ jaxplan plan <domain> <instance> <method> --episodes <episodes>
  ```

  where:
@@ -219,13 +220,7 @@ controller = JaxOfflineController(planner, **train_args)
  ## JaxPlan Dashboard

  Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
- and visualization of the policy or model, and other useful debugging features.
-
- <p align="middle">
- <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
- </p>
-
- To run the dashboard, add the following entry to your config file:
+ and visualization of the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:

  ```ini
  ...
@@ -234,14 +229,12 @@ dashboard=True
  ...
  ```

- More documentation about this and other new features will be coming soon.
-
  ## Tuning the Planner

  A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:

  ```shell
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+ jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
  ```

  where:
@@ -251,7 +244,8 @@ where:
  - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
  - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
  - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application
+ - ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.

  It is easy to tune a custom range of the planner's hyper-parameters efficiently.
  First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -291,23 +285,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
  with open('path/to/config.cfg', 'r') as file:
      config_template = file.read()

- # map parameters in the config that will be tuned
+ # tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
  def power_10(x):
-     return 10.0 ** x
-
- hyperparams = [
-     Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
-     Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
- ]
+     return 10.0 ** x
+ hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+                Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]

  # build the tuner and tune
  tuning = JaxParameterTuning(env=env,
-                             config_template=config_template,
-                             hyperparams=hyperparams,
-                             online=False,
-                             eval_trials=trials,
-                             num_workers=workers,
-                             gp_iters=iters)
+                             config_template=config_template, hyperparams=hyperparams,
+                             online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
  tuning.tune(key=42, log_file='path/to/log.csv')
  ```

@@ -8,7 +8,7 @@

  [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)

- **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
+ **pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**

  Purpose:

@@ -37,7 +37,7 @@ and was moved to the individual logic components which have their own unique wei

  > [!NOTE]
  > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
- > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
+ > If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).

  ## Installation

@@ -70,7 +70,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
  A basic run script is provided to train JaxPlan on any RDDL problem:

  ```shell
- jaxplan plan <domain> <instance> <method> <episodes>
+ jaxplan plan <domain> <instance> <method> --episodes <episodes>
  ```

  where:
@@ -173,13 +173,7 @@ controller = JaxOfflineController(planner, **train_args)
  ## JaxPlan Dashboard

  Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
- and visualization of the policy or model, and other useful debugging features.
-
- <p align="middle">
- <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
- </p>
-
- To run the dashboard, add the following entry to your config file:
+ and visualization of the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:

  ```ini
  ...
@@ -188,14 +182,12 @@ dashboard=True
  ...
  ```

- More documentation about this and other new features will be coming soon.
-
  ## Tuning the Planner

  A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:

  ```shell
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+ jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
  ```

  where:
@@ -205,7 +197,8 @@ where:
  - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
  - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
  - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application
+ - ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.

  It is easy to tune a custom range of the planner's hyper-parameters efficiently.
  First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -245,23 +238,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
  with open('path/to/config.cfg', 'r') as file:
      config_template = file.read()

- # map parameters in the config that will be tuned
+ # tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
  def power_10(x):
-     return 10.0 ** x
-
- hyperparams = [
-     Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
-     Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
- ]
+     return 10.0 ** x
+ hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+                Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]

  # build the tuner and tune
  tuning = JaxParameterTuning(env=env,
-                             config_template=config_template,
-                             hyperparams=hyperparams,
-                             online=False,
-                             eval_trials=trials,
-                             num_workers=workers,
-                             gp_iters=iters)
+                             config_template=config_template, hyperparams=hyperparams,
+                             online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
  tuning.tune(key=42, log_file='path/to/log.csv')
  ```

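For reference, a self-contained version of the tuning snippet from the README diff above. The import paths and the concrete domain, instance, and budget values are assumptions added for illustration and are not part of this diff; only the `JaxParameterTuning` and `Hyperparameter` calls mirror the new README text.

```python
# Sketch under the assumptions noted above: end-to-end use of the compact
# tuning example introduced in 2.6.
import pyRDDLGym
from pyRDDLGym_jax.core.tuning import Hyperparameter, JaxParameterTuning

# hypothetical domain/instance and budgets, purely for illustration
env = pyRDDLGym.make('HVAC', '1', vectorized=True)
with open('path/to/config.cfg', 'r') as file:
    config_template = file.read()

# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
def power_10(x):
    return 10.0 ** x

hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
               Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]

# build the tuner and tune
tuning = JaxParameterTuning(env=env,
                            config_template=config_template, hyperparams=hyperparams,
                            online=False, eval_trials=5, num_workers=4, gp_iters=20)
tuning.tune(key=42, log_file='path/to/log.csv')
```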
@@ -0,0 +1 @@
+ __version__ = '2.6'
@@ -237,7 +237,8 @@ class JaxRDDLCompiler:

      def compile_transition(self, check_constraints: bool=False,
                             constraint_func: bool=False,
-                            init_params_constr: Dict[str, Any]={}) -> Callable:
+                            init_params_constr: Dict[str, Any]={},
+                            cache_path_info: bool=False) -> Callable:
          '''Compiles the current RDDL model into a JAX transition function that
          samples the next state.

@@ -274,6 +275,7 @@
          returned log and does not raise an exception
          :param constraint_func: produces the h(s, a) function described above
              in addition to the usual outputs
+         :param cache_path_info: whether to save full path traces as part of the log
          '''
          NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
          rddl = self.rddl
@@ -322,8 +324,11 @@
              errors |= err

              # calculate fluent values
-             fluents = {name: values for (name, values) in subs.items()
-                        if name not in rddl.non_fluents}
+             if cache_path_info:
+                 fluents = {name: values for (name, values) in subs.items()
+                            if name not in rddl.non_fluents}
+             else:
+                 fluents = {}

              # set the next state to the current state
              for (state, next_state) in rddl.next_state.items():
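The new `cache_path_info` flag shown above gates whether fluent values are copied into the rollout log at each step. A minimal sketch of toggling it when compiling a transition sampler; only the `cache_path_info` keyword comes from this diff, while constructing the compiler from a lifted model (`model` below) and calling `compile()` are assumed usage of the surrounding pyRDDLGym-jax API.

```python
# Hedged sketch: the cache_path_info keyword is from the diff above; the
# compiler construction and compile() call are assumed API usage.
from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler

compiler = JaxRDDLCompiler(rddl=model)   # `model` is a hypothetical lifted RDDL model
compiler.compile()

step_with_traces = compiler.compile_transition(cache_path_info=True)   # keep fluent traces in the log
step_lean = compiler.compile_transition(cache_path_info=False)         # default: smaller log, less memory
```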
@@ -367,7 +372,9 @@
                           n_batch: int,
                           check_constraints: bool=False,
                           constraint_func: bool=False,
-                          init_params_constr: Dict[str, Any]={}) -> Callable:
+                          init_params_constr: Dict[str, Any]={},
+                          model_params_reduction: Callable=lambda x: x[0],
+                          cache_path_info: bool=False) -> Callable:
          '''Compiles the current RDDL model into a JAX transition function that
          samples trajectories with a fixed horizon from a policy.

@@ -399,10 +406,13 @@
          returned log and does not raise an exception
          :param constraint_func: produces the h(s, a) constraint function
              in addition to the usual outputs
+         :param model_params_reduction: how to aggregate updated model_params across runs
+             in the batch (defaults to selecting the first element's parameters in the batch)
+         :param cache_path_info: whether to save full path traces as part of the log
          '''
          rddl = self.rddl
          jax_step_fn = self.compile_transition(
-             check_constraints, constraint_func, init_params_constr)
+             check_constraints, constraint_func, init_params_constr, cache_path_info)

          # for POMDP only observ-fluents are assumed visible to the policy
          if rddl.observ_fluents:
@@ -421,7 +431,6 @@
              return jax_step_fn(subkey, actions, subs, model_params)

          # do a batched step update from the policy
-         # TODO: come up with a better way to reduce the model_param batch dim
          def _jax_wrapped_batched_step_policy(carry, step):
              key, policy_params, hyperparams, subs, model_params = carry
              key, *subkeys = random.split(key, num=1 + n_batch)
@@ -430,7 +439,7 @@
                  _jax_wrapped_single_step_policy,
                  in_axes=(0, None, None, None, 0, None)
              )(keys, policy_params, hyperparams, step, subs, model_params)
-             model_params = jax.tree_map(partial(jnp.mean, axis=0), model_params)
+             model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
              carry = (key, policy_params, hyperparams, subs, model_params)
              return carry, log

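The replaced line also changes behaviour: 2.4 averaged the per-run `model_params` across the batch with the deprecated `jax.tree_map`, while 2.6 applies a user-supplied `model_params_reduction` (first batch element by default) via `jax.tree_util.tree_map`. A standalone sketch of the two reductions on a toy parameter pytree (the dictionary below is illustrative, not the planner's actual parameter structure):

```python
from functools import partial

import jax
import jax.numpy as jnp

# toy pytree of "model parameters" with a leading batch dimension of size 3
batched_params = {'weight': jnp.array([1.0, 2.0, 3.0]),
                  'temperature': jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])}

# new default in 2.6: keep the first batch element's parameters
first = jax.tree_util.tree_map(lambda x: x[0], batched_params)

# old 2.4 behaviour, now opt-in via model_params_reduction: average over the batch axis
averaged = jax.tree_util.tree_map(partial(jnp.mean, axis=0), batched_params)

print(first['weight'])     # 1.0
print(averaged['weight'])  # 2.0
```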
@@ -440,7 +449,7 @@
              start = (key, policy_params, hyperparams, subs, model_params)
              steps = jnp.arange(n_steps)
              end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
-             log = jax.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
+             log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
              model_params = end[-1]
              return log, model_params

@@ -707,7 +716,10 @@ class JaxRDDLCompiler:
              sample = jnp.asarray(value, dtype=self._fix_dtype(value))
              new_slices = [None] * len(jax_nested_expr)
              for (i, jax_expr) in enumerate(jax_nested_expr):
-                 new_slices[i], key, err, params = jax_expr(x, params, key)
+                 new_slice, key, err, params = jax_expr(x, params, key)
+                 if not jnp.issubdtype(jnp.result_type(new_slice), jnp.integer):
+                     new_slice = jnp.asarray(new_slice, dtype=self.INT)
+                 new_slices[i] = new_slice
                  error |= err
              new_slices = tuple(new_slices)
              sample = sample[new_slices]
@@ -986,7 +998,8 @@
              sample_cases = [None] * len(jax_cases)
              for (i, jax_case) in enumerate(jax_cases):
                  sample_cases[i], key, err_case, params = jax_case(x, params, key)
-                 err |= err_case
+                 err |= err_case
+             sample_cases = jnp.asarray(sample_cases)
              sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))

              # predicate (enum) is an integer - use it to extract from case array
@@ -1056,15 +1056,13 @@ class ExactLogic(Logic):
      def control_if(self, id, init_params):
          return self._jax_wrapped_calc_if_then_else_exact

-     @staticmethod
-     def _jax_wrapped_calc_switch_exact(pred, cases, params):
-         pred = pred[jnp.newaxis, ...]
-         sample = jnp.take_along_axis(cases, pred, axis=0)
-         assert sample.shape[0] == 1
-         return sample[0, ...], params
-
      def control_switch(self, id, init_params):
-         return self._jax_wrapped_calc_switch_exact
+         def _jax_wrapped_calc_switch_exact(pred, cases, params):
+             pred = jnp.asarray(pred[jnp.newaxis, ...], dtype=self.INT)
+             sample = jnp.take_along_axis(cases, pred, axis=0)
+             assert sample.shape[0] == 1
+             return sample[0, ...], params
+         return _jax_wrapped_calc_switch_exact

      # ===========================================================================
      # random variables
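The rewritten `control_switch` now casts the predicate to an integer type before indexing, which `jnp.take_along_axis` requires. A small standalone illustration of the same pattern on toy arrays (the shapes and values below are chosen for illustration only, independent of the planner's fluent shapes):

```python
import jax.numpy as jnp

# cases stacked along a leading axis, one row per switch branch
cases = jnp.array([[10., 20., 30.],
                   [40., 50., 60.]])

# a predicate coming out of the compiled logic may be floating point
pred = jnp.array([0., 1., 1.])

# cast to an integer dtype before indexing, as the new control_switch does
idx = jnp.asarray(pred[jnp.newaxis, ...], dtype=jnp.int32)
sample = jnp.take_along_axis(cases, idx, axis=0)
assert sample.shape[0] == 1
print(sample[0, ...])  # [10. 50. 60.]
```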