pyRDDLGym-jax 2.4.tar.gz → 2.6.tar.gz
This diff compares the contents of publicly available package versions as released to their public registries. It is provided for informational purposes only.
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/PKG-INFO +17 -30
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/README.md +13 -27
- pyrddlgym_jax-2.6/pyRDDLGym_jax/__init__.py +1 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/compiler.py +23 -10
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/logic.py +6 -8
- pyrddlgym_jax-2.6/pyRDDLGym_jax/core/model.py +595 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/planner.py +317 -99
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/simulator.py +37 -13
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/tuning.py +25 -10
- pyrddlgym_jax-2.6/pyRDDLGym_jax/entry_point.py +59 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/run_plan.py +1 -1
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/run_tune.py +8 -2
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/PKG-INFO +17 -30
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/SOURCES.txt +1 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/requires.txt +1 -1
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/setup.py +2 -2
- pyrddlgym_jax-2.4/pyRDDLGym_jax/__init__.py +0 -1
- pyrddlgym_jax-2.4/pyRDDLGym_jax/entry_point.py +0 -27
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/LICENSE +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/core/visualization.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/__init__.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-2.4 → pyrddlgym_jax-2.6}/setup.cfg +0 -0
````diff
--- pyrddlgym_jax-2.4/PKG-INFO
+++ pyrddlgym_jax-2.6/PKG-INFO
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: pyRDDLGym-jax
-Version: 2.4
+Version: 2.6
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: pyRDDLGym>=2.
+Requires-Dist: pyRDDLGym>=2.3
 Requires-Dist: tqdm>=4.66
 Requires-Dist: jax>=0.4.12
 Requires-Dist: optax>=0.1.9
@@ -39,6 +39,7 @@ Dynamic: description
 Dynamic: description-content-type
 Dynamic: home-page
 Dynamic: license
+Dynamic: license-file
 Dynamic: provides-extra
 Dynamic: requires-dist
 Dynamic: requires-python
@@ -54,7 +55,7 @@ Dynamic: summary
 
 [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
 
-**pyRDDLGym-jax (
+**pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**
 
 Purpose:
 
@@ -83,7 +84,7 @@ and was moved to the individual logic components which have their own unique wei
 
 > [!NOTE]
 > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
-> If you find it is not making
+> If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
 
 ## Installation
 
@@ -116,7 +117,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
 A basic run script is provided to train JaxPlan on any RDDL problem:
 
 ```shell
-jaxplan plan <domain> <instance> <method> <episodes>
+jaxplan plan <domain> <instance> <method> --episodes <episodes>
 ```
 
 where:
@@ -219,13 +220,7 @@ controller = JaxOfflineController(planner, **train_args)
 ## JaxPlan Dashboard
 
 Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
-and visualization of the policy or model, and other useful debugging features.
-
-<p align="middle">
-<img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
-</p>
-
-To run the dashboard, add the following entry to your config file:
+and visualization of the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:
 
 ```ini
 ...
@@ -234,14 +229,12 @@ dashboard=True
 ...
 ```
 
-More documentation about this and other new features will be coming soon.
-
 ## Tuning the Planner
 
 A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
 
 ```shell
-jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
 ```
 
 where:
@@ -251,7 +244,8 @@ where:
 - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
 - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
 - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
-- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.
 
 It is easy to tune a custom range of the planner's hyper-parameters efficiently.
 First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -291,23 +285,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
 with open('path/to/config.cfg', 'r') as file:
     config_template = file.read()
 
-#
+# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
 def power_10(x):
-    return 10.0 ** x
-
-hyperparams = [
-    Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
-    Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
-]
+    return 10.0 ** x
+hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+               Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]
 
 # build the tuner and tune
 tuning = JaxParameterTuning(env=env,
-                            config_template=config_template,
-                            hyperparams=hyperparams,
-                            online=False,
-                            eval_trials=trials,
-                            num_workers=workers,
-                            gp_iters=iters)
+                            config_template=config_template, hyperparams=hyperparams,
+                            online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
 tuning.tune(key=42, log_file='path/to/log.csv')
 ```
````
````diff
--- pyrddlgym_jax-2.4/README.md
+++ pyrddlgym_jax-2.6/README.md
@@ -8,7 +8,7 @@
 
 [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
 
-**pyRDDLGym-jax (
+**pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**
 
 Purpose:
 
@@ -37,7 +37,7 @@ and was moved to the individual logic components which have their own unique wei
 
 > [!NOTE]
 > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
-> If you find it is not making
+> If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
 
 ## Installation
 
@@ -70,7 +70,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
 A basic run script is provided to train JaxPlan on any RDDL problem:
 
 ```shell
-jaxplan plan <domain> <instance> <method> <episodes>
+jaxplan plan <domain> <instance> <method> --episodes <episodes>
 ```
 
 where:
@@ -173,13 +173,7 @@ controller = JaxOfflineController(planner, **train_args)
 ## JaxPlan Dashboard
 
 Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
-and visualization of the policy or model, and other useful debugging features.
-
-<p align="middle">
-<img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
-</p>
-
-To run the dashboard, add the following entry to your config file:
+and visualization of the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:
 
 ```ini
 ...
@@ -188,14 +182,12 @@ dashboard=True
 ...
 ```
 
-More documentation about this and other new features will be coming soon.
-
 ## Tuning the Planner
 
 A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
 
 ```shell
-jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
 ```
 
 where:
@@ -205,7 +197,8 @@ where:
 - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
 - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
 - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
-- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.
 
 It is easy to tune a custom range of the planner's hyper-parameters efficiently.
 First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -245,23 +238,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
 with open('path/to/config.cfg', 'r') as file:
     config_template = file.read()
 
-#
+# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
 def power_10(x):
-    return 10.0 ** x
-
-hyperparams = [
-    Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
-    Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
-]
+    return 10.0 ** x
+hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+               Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]
 
 # build the tuner and tune
 tuning = JaxParameterTuning(env=env,
-                            config_template=config_template,
-                            hyperparams=hyperparams,
-                            online=False,
-                            eval_trials=trials,
-                            num_workers=workers,
-                            gp_iters=iters)
+                            config_template=config_template, hyperparams=hyperparams,
+                            online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
 tuning.tune(key=42, log_file='path/to/log.csv')
 ```
````
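Putting the README fragments above together, a complete offline tuning run under the 2.6 API might look like the following sketch. The import path `pyRDDLGym_jax.core.tuning` and the concrete domain, instance, and budget values are assumptions for illustration; the `JaxParameterTuning` keywords and the `Hyperparameter` signature are taken verbatim from the diff.

```python
import pyRDDLGym
# assumed import path, based on the changed core/tuning.py module
from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter

# create the vectorized environment (domain/instance chosen for illustration)
env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)

# read a config template whose TUNABLE_* placeholders are filled during tuning
with open('path/to/config.cfg', 'r') as file:
    config_template = file.read()

# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
def power_10(x):
    return 10.0 ** x

hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
               Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]

# build the tuner and tune (trial/iteration budgets are placeholders)
tuning = JaxParameterTuning(env=env,
                            config_template=config_template, hyperparams=hyperparams,
                            online=False, eval_trials=5, num_workers=4, gp_iters=20)
tuning.tune(key=42, log_file='path/to/log.csv')
```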
```diff
--- /dev/null
+++ pyrddlgym_jax-2.6/pyRDDLGym_jax/__init__.py
@@ -0,0 +1 @@
+__version__ = '2.6'
```
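Since the sdist now ships a top-level `__init__.py` defining `__version__`, the installed release can be checked at runtime; a trivial sketch, assuming the package is importable under its distribution name:

```python
import pyRDDLGym_jax

# the 2.6 sdist sets __version__ in the new top-level __init__.py
print(pyRDDLGym_jax.__version__)  # '2.6' for this release
```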
````diff
--- pyrddlgym_jax-2.4/pyRDDLGym_jax/core/compiler.py
+++ pyrddlgym_jax-2.6/pyRDDLGym_jax/core/compiler.py
@@ -237,7 +237,8 @@ class JaxRDDLCompiler:
 
     def compile_transition(self, check_constraints: bool=False,
                            constraint_func: bool=False,
-                           init_params_constr: Dict[str, Any]={}) -> Callable:
+                           init_params_constr: Dict[str, Any]={},
+                           cache_path_info: bool=False) -> Callable:
         '''Compiles the current RDDL model into a JAX transition function that
         samples the next state.
 
@@ -274,6 +275,7 @@ class JaxRDDLCompiler:
         returned log and does not raise an exception
         :param constraint_func: produces the h(s, a) function described above
         in addition to the usual outputs
+        :param cache_path_info: whether to save full path traces as part of the log
         '''
         NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
         rddl = self.rddl
@@ -322,8 +324,11 @@ class JaxRDDLCompiler:
             errors |= err
 
             # calculate fluent values
-            fluents = {name: values for (name, values) in subs.items()
-                       if name not in rddl.non_fluents}
+            if cache_path_info:
+                fluents = {name: values for (name, values) in subs.items()
+                           if name not in rddl.non_fluents}
+            else:
+                fluents = {}
 
             # set the next state to the current state
             for (state, next_state) in rddl.next_state.items():
@@ -367,7 +372,9 @@ class JaxRDDLCompiler:
                          n_batch: int,
                          check_constraints: bool=False,
                          constraint_func: bool=False,
-                         init_params_constr: Dict[str, Any]={}) -> Callable:
+                         init_params_constr: Dict[str, Any]={},
+                         model_params_reduction: Callable=lambda x: x[0],
+                         cache_path_info: bool=False) -> Callable:
         '''Compiles the current RDDL model into a JAX transition function that
         samples trajectories with a fixed horizon from a policy.
 
@@ -399,10 +406,13 @@ class JaxRDDLCompiler:
         returned log and does not raise an exception
         :param constraint_func: produces the h(s, a) constraint function
         in addition to the usual outputs
+        :param model_params_reduction: how to aggregate updated model_params across runs
+        in the batch (defaults to selecting the first element's parameters in the batch)
+        :param cache_path_info: whether to save full path traces as part of the log
         '''
         rddl = self.rddl
         jax_step_fn = self.compile_transition(
-            check_constraints, constraint_func, init_params_constr)
+            check_constraints, constraint_func, init_params_constr, cache_path_info)
 
         # for POMDP only observ-fluents are assumed visible to the policy
         if rddl.observ_fluents:
@@ -421,7 +431,6 @@ class JaxRDDLCompiler:
             return jax_step_fn(subkey, actions, subs, model_params)
 
         # do a batched step update from the policy
-        # TODO: come up with a better way to reduce the model_param batch dim
        def _jax_wrapped_batched_step_policy(carry, step):
             key, policy_params, hyperparams, subs, model_params = carry
             key, *subkeys = random.split(key, num=1 + n_batch)
@@ -430,7 +439,7 @@ class JaxRDDLCompiler:
                 _jax_wrapped_single_step_policy,
                 in_axes=(0, None, None, None, 0, None)
             )(keys, policy_params, hyperparams, step, subs, model_params)
-            model_params = jax.tree_map(lambda x: x[0], model_params)
+            model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
             carry = (key, policy_params, hyperparams, subs, model_params)
             return carry, log
 
@@ -440,7 +449,7 @@ class JaxRDDLCompiler:
             start = (key, policy_params, hyperparams, subs, model_params)
             steps = jnp.arange(n_steps)
             end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
-            log = jax.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
+            log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
             model_params = end[-1]
             return log, model_params
 
@@ -707,7 +716,10 @@ class JaxRDDLCompiler:
             sample = jnp.asarray(value, dtype=self._fix_dtype(value))
             new_slices = [None] * len(jax_nested_expr)
             for (i, jax_expr) in enumerate(jax_nested_expr):
-                new_slices[i], key, err, params = jax_expr(x, params, key)
+                new_slice, key, err, params = jax_expr(x, params, key)
+                if not jnp.issubdtype(jnp.result_type(new_slice), jnp.integer):
+                    new_slice = jnp.asarray(new_slice, dtype=self.INT)
+                new_slices[i] = new_slice
                 error |= err
             new_slices = tuple(new_slices)
             sample = sample[new_slices]
@@ -986,7 +998,8 @@ class JaxRDDLCompiler:
         sample_cases = [None] * len(jax_cases)
         for (i, jax_case) in enumerate(jax_cases):
             sample_cases[i], key, err_case, params = jax_case(x, params, key)
-            err |= err_case
+            err |= err_case
+        sample_cases = jnp.asarray(sample_cases)
         sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))
 
         # predicate (enum) is an integer - use it to extract from case array
````
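The new `model_params_reduction` argument resolves the removed TODO: after `jax.vmap` maps the single-step function over the batch, every leaf of the returned `model_params` pytree gains a leading batch axis, and the reduction decides how to collapse it back. The toy pytree below is illustrative, not the library's actual parameter structure; it just demonstrates the `jax.tree_util.tree_map` pattern used at the call site, with batch averaging as one alternative a caller could pass.

```python
import jax
import jax.numpy as jnp

# toy stand-in for model_params after vmap: every leaf has a leading batch axis (size 3)
batched_params = {'weight': jnp.array([1.0, 2.0, 3.0]),
                  'slope': jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])}

# the 2.6 default reduction: keep the first element of the batch (lambda x: x[0])
first = jax.tree_util.tree_map(lambda x: x[0], batched_params)

# a caller-supplied alternative: average the updated parameters over the batch
mean = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), batched_params)

print(first['weight'], first['slope'])  # 1.0 [0.1 0.2]
print(mean['weight'], mean['slope'])    # 2.0 [0.3 0.4]
```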
```diff
--- pyrddlgym_jax-2.4/pyRDDLGym_jax/core/logic.py
+++ pyrddlgym_jax-2.6/pyRDDLGym_jax/core/logic.py
@@ -1056,15 +1056,13 @@ class ExactLogic(Logic):
     def control_if(self, id, init_params):
         return self._jax_wrapped_calc_if_then_else_exact
 
-    @staticmethod
-    def _jax_wrapped_calc_switch_exact(pred, cases, params):
-        pred = pred[jnp.newaxis, ...]
-        sample = jnp.take_along_axis(cases, pred, axis=0)
-        assert sample.shape[0] == 1
-        return sample[0, ...], params
-
     def control_switch(self, id, init_params):
-        return self._jax_wrapped_calc_switch_exact
+        def _jax_wrapped_calc_switch_exact(pred, cases, params):
+            pred = jnp.asarray(pred[jnp.newaxis, ...], dtype=self.INT)
+            sample = jnp.take_along_axis(cases, pred, axis=0)
+            assert sample.shape[0] == 1
+            return sample[0, ...], params
+        return _jax_wrapped_calc_switch_exact
 
     # ===========================================================================
     # random variables
```