pyRDDLGym-jax 2.5__tar.gz → 2.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/LICENSE +1 -1
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/PKG-INFO +5 -13
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/README.md +3 -11
- pyrddlgym_jax-2.7/pyRDDLGym_jax/__init__.py +1 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/compiler.py +107 -11
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/logic.py +6 -8
- pyrddlgym_jax-2.7/pyRDDLGym_jax/core/model.py +595 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/planner.py +183 -24
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/simulator.py +12 -4
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/run_plan.py +31 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax.egg-info/PKG-INFO +5 -13
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax.egg-info/SOURCES.txt +1 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax.egg-info/requires.txt +1 -1
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/setup.py +2 -2
- pyrddlgym_jax-2.5/pyRDDLGym_jax/__init__.py +0 -1
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/tuning.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/visualization.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/entry_point.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/__init__.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/examples/run_tune.py +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/setup.cfg +0 -0
{pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyRDDLGym-jax
-Version: 2.5
+Version: 2.7
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: pyRDDLGym>=2.
+Requires-Dist: pyRDDLGym>=2.5
 Requires-Dist: tqdm>=4.66
 Requires-Dist: jax>=0.4.12
 Requires-Dist: optax>=0.1.9
@@ -55,7 +55,7 @@ Dynamic: summary
 
 [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
 
-**pyRDDLGym-jax (
+**pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**
 
 Purpose:
 
@@ -84,7 +84,7 @@ and was moved to the individual logic components which have their own unique wei
 
 > [!NOTE]
 > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
-> If you find it is not making
+> If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
 
 ## Installation
 
@@ -220,13 +220,7 @@ controller = JaxOfflineController(planner, **train_args)
 ## JaxPlan Dashboard
 
 Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
-and visualization of the policy or model, and other useful debugging features.
-
-<p align="middle">
-<img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
-</p>
-
-To run the dashboard, add the following entry to your config file:
+and visualization of the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:
 
 ```ini
 ...
@@ -235,8 +229,6 @@ dashboard=True
 ...
 ```
 
-More documentation about this and other new features will be coming soon.
-
 ## Tuning the Planner
 
 A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
{pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/README.md

@@ -8,7 +8,7 @@
 
 [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
 
-**pyRDDLGym-jax (
+**pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**
 
 Purpose:
 
@@ -37,7 +37,7 @@ and was moved to the individual logic components which have their own unique wei
 
 > [!NOTE]
 > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
-> If you find it is not making
+> If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
 
 ## Installation
 
@@ -173,13 +173,7 @@ controller = JaxOfflineController(planner, **train_args)
 ## JaxPlan Dashboard
 
 Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
-and visualization of the policy or model, and other useful debugging features.
-
-<p align="middle">
-<img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
-</p>
-
-To run the dashboard, add the following entry to your config file:
+and visualization of the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:
 
 ```ini
 ...
@@ -188,8 +182,6 @@ dashboard=True
 ...
 ```
 
-More documentation about this and other new features will be coming soon.
-
 ## Tuning the Planner
 
 A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
pyrddlgym_jax-2.7/pyRDDLGym_jax/__init__.py

@@ -0,0 +1 @@
+__version__ = '2.7'
{pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/compiler.py

@@ -30,7 +30,8 @@ from pyRDDLGym.core.debug.exception import (
     print_stack_trace,
     raise_warning,
     RDDLInvalidNumberOfArgumentsError,
-    RDDLNotImplementedError
+    RDDLNotImplementedError,
+    RDDLUndefinedVariableError
 )
 from pyRDDLGym.core.debug.logger import Logger
 from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
@@ -56,7 +57,8 @@ class JaxRDDLCompiler:
                  allow_synchronous_state: bool=True,
                  logger: Optional[Logger]=None,
                  use64bit: bool=False,
-                 compile_non_fluent_exact: bool=True) -> None:
+                 compile_non_fluent_exact: bool=True,
+                 python_functions: Optional[Dict[str, Callable]]=None) -> None:
         '''Creates a new RDDL to Jax compiler.
         
         :param rddl: the RDDL model to compile into Jax
@@ -65,7 +67,8 @@ class JaxRDDLCompiler:
         :param logger: to log information about compilation to file
         :param use64bit: whether to use 64 bit arithmetic
         :param compile_non_fluent_exact: whether non-fluent expressions
-        are always compiled using exact JAX expressions
+            are always compiled using exact JAX expressions
+        :param python_functions: dictionary of external Python functions to call from RDDL
         '''
         self.rddl = rddl
         self.logger = logger
@@ -99,11 +102,15 @@ class JaxRDDLCompiler:
         self.traced = tracer.trace()
         
         # extract the box constraints on actions
+        if python_functions is None:
+            python_functions = {}
+        self.python_functions = python_functions
         simulator = RDDLSimulatorPrecompiled(
             rddl=self.rddl,
             init_values=self.init_values,
             levels=self.levels,
-            trace_info=self.traced
+            trace_info=self.traced,
+            python_functions=python_functions
         )
         constraints = RDDLConstraints(simulator, vectorized=True)
         self.constraints = constraints
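The headline change in compiler.py for 2.7 is support for external Python functions callable from RDDL. A minimal registration sketch (hedged: the domain/instance names and the `wind_speed` function are hypothetical, and `pyRDDLGym.make`, `env.model`, and `compiler.compile()` are assumed from the library's usual API):

```python
import jax.numpy as jnp
import pyRDDLGym
from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler

# hypothetical external function: must be JAX-traceable and return a JAX array
def wind_speed(vx, vy):
    return jnp.sqrt(vx ** 2 + vy ** 2)

# placeholders for a real RDDL domain/instance that references wind_speed
env = pyRDDLGym.make('MyDomain', 'instance0', vectorized=True)

compiler = JaxRDDLCompiler(
    rddl=env.model,
    python_functions={'wind_speed': wind_speed})
compiler.compile()
```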
@@ -237,7 +244,8 @@ class JaxRDDLCompiler:
 
     def compile_transition(self, check_constraints: bool=False,
                            constraint_func: bool=False,
-                           init_params_constr: Dict[str, Any]={}) -> Callable:
+                           init_params_constr: Dict[str, Any]={},
+                           cache_path_info: bool=False) -> Callable:
         '''Compiles the current RDDL model into a JAX transition function that
         samples the next state.
         
@@ -274,6 +282,7 @@ class JaxRDDLCompiler:
             returned log and does not raise an exception
         :param constraint_func: produces the h(s, a) function described above
             in addition to the usual outputs
+        :param cache_path_info: whether to save full path traces as part of the log
         '''
         NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
         rddl = self.rddl
@@ -322,8 +331,11 @@ class JaxRDDLCompiler:
                 errors |= err
             
             # calculate fluent values
-            fluents = {name: values for (name, values) in subs.items()
-                       if name not in rddl.non_fluents}
+            if cache_path_info:
+                fluents = {name: values for (name, values) in subs.items()
+                           if name not in rddl.non_fluents}
+            else:
+                fluents = {}
 
             # set the next state to the current state
             for (state, next_state) in rddl.next_state.items():
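A hedged usage sketch of the new flag, reusing `compiler` from the earlier example: with `cache_path_info=True` the returned log carries the full fluent trace rather than an empty dict. The argument order `(key, actions, subs, model_params)` is taken from the rollout code further down; `compiler.model_params` and the empty `actions` dict are assumptions:

```python
import jax.random as random

# compile a one-step transition whose log also records all fluent values
step_fn = compiler.compile_transition(cache_path_info=True)

key = random.PRNGKey(42)
subs = compiler.init_values     # initial fluent values (attribute shown in this diff)
actions = {}                    # hypothetical: fall back on default action values
log = step_fn(key, actions, subs, compiler.model_params)
```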
@@ -367,7 +379,9 @@ class JaxRDDLCompiler:
                          n_batch: int,
                          check_constraints: bool=False,
                          constraint_func: bool=False,
-                         init_params_constr: Dict[str, Any]={}) -> Callable:
+                         init_params_constr: Dict[str, Any]={},
+                         model_params_reduction: Callable=lambda x: x[0],
+                         cache_path_info: bool=False) -> Callable:
         '''Compiles the current RDDL model into a JAX transition function that
         samples trajectories with a fixed horizon from a policy.
         
@@ -399,10 +413,13 @@ class JaxRDDLCompiler:
             returned log and does not raise an exception
         :param constraint_func: produces the h(s, a) constraint function
             in addition to the usual outputs
+        :param model_params_reduction: how to aggregate updated model_params across runs
+            in the batch (defaults to selecting the first element's parameters in the batch)
+        :param cache_path_info: whether to save full path traces as part of the log
         '''
         rddl = self.rddl
         jax_step_fn = self.compile_transition(
-            check_constraints, constraint_func, init_params_constr)
+            check_constraints, constraint_func, init_params_constr, cache_path_info)
         
         # for POMDP only observ-fluents are assumed visible to the policy
         if rddl.observ_fluents:
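Because the batched scan vmaps the single-step update over the batch axis, every leaf of `model_params` comes back with a leading batch dimension, and `model_params_reduction` collapses it (default: keep element 0). A sketch that averages instead; the method name `compile_rollouts` and its `policy`/`n_steps` arguments are assumptions, since the diff truncates the signature:

```python
from functools import partial
import jax.numpy as jnp

# average the updated relaxation parameters over the batch instead of
# keeping only the first rollout's values (the default lambda x: x[0])
rollout_fn = compiler.compile_rollouts(
    policy=policy,        # assumed: a policy callable
    n_steps=horizon,      # assumed: rollout horizon
    n_batch=32,
    model_params_reduction=partial(jnp.mean, axis=0))
```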
@@ -421,7 +438,6 @@ class JaxRDDLCompiler:
             return jax_step_fn(subkey, actions, subs, model_params)
         
         # do a batched step update from the policy
-        # TODO: come up with a better way to reduce the model_param batch dim
         def _jax_wrapped_batched_step_policy(carry, step):
             key, policy_params, hyperparams, subs, model_params = carry
             key, *subkeys = random.split(key, num=1 + n_batch)
@@ -430,7 +446,7 @@ class JaxRDDLCompiler:
                 _jax_wrapped_single_step_policy,
                 in_axes=(0, None, None, None, 0, None)
             )(keys, policy_params, hyperparams, step, subs, model_params)
-            model_params = jax.tree_util.tree_map(
+            model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
             carry = (key, policy_params, hyperparams, subs, model_params)
             return carry, log
 
@@ -596,6 +612,8 @@ class JaxRDDLCompiler:
             jax_expr = self._jax_aggregation(expr, init_params)
         elif etype == 'func':
             jax_expr = self._jax_functional(expr, init_params)
+        elif etype == 'pyfunc':
+            jax_expr = self._jax_pyfunc(expr, init_params)
         elif etype == 'control':
             jax_expr = self._jax_control(expr, init_params)
         elif etype == 'randomvar':
@@ -917,6 +935,84 @@ class JaxRDDLCompiler:
             raise RDDLNotImplementedError(
                 f'Function {op} is not supported.\n' + print_stack_trace(expr))
     
+    def _jax_pyfunc(self, expr, init_params):
+        NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
+        
+        # get the Python function by name
+        _, pyfunc_name = expr.etype
+        pyfunc = self.python_functions.get(pyfunc_name)
+        if pyfunc is None:
+            raise RDDLUndefinedVariableError(
+                f'Undefined external Python function <{pyfunc_name}>, '
+                f'must be one of {list(self.python_functions.keys())}.\n' +
+                print_stack_trace(expr))
+        
+        captured_vars, args = expr.args
+        scope_vars = self.traced.cached_objects_in_scope(expr)
+        dest_indices = self.traced.cached_sim_info(expr)
+        free_vars = [p for p in scope_vars if p[0] not in captured_vars]
+        free_dims = self.rddl.object_counts(p for (_, p) in free_vars)
+        num_free_vars = len(free_vars)
+        captured_types = [t for (p, t) in scope_vars if p in captured_vars]
+        require_dims = self.rddl.object_counts(captured_types)
+        
+        # compile the inputs to the function
+        jax_inputs = [self._jax(arg, init_params) for arg in args]
+        
+        # compile the function evaluation function
+        def _jax_wrapped_external_function(x, params, key):
+            
+            # evaluate inputs to the function
+            # first dimensions are non-captured vars in outer scope followed by all the _
+            error = NORMAL
+            flat_samples = []
+            for jax_expr in jax_inputs:
+                sample, key, err, params = jax_expr(x, params, key)
+                shape = jnp.shape(sample)
+                first_dim = 1
+                for dim in shape[:num_free_vars]:
+                    first_dim *= dim
+                new_shape = (first_dim,) + shape[num_free_vars:]
+                flat_sample = jnp.reshape(sample, new_shape)
+                flat_samples.append(flat_sample)
+                error |= err
+            
+            # now all the inputs have dimensions equal to (k,) + the number of _ occurrences
+            # k is the number of possible non-captured object combinations
+            # evaluate the function independently for each combination
+            # output dimension for each combination is captured variables (n1, n2, ...)
+            # so the total dimension of the output array is (k, n1, n2, ...)
+            sample = jax.vmap(pyfunc, in_axes=0)(*flat_samples)
+            if not isinstance(sample, jnp.ndarray):
+                raise ValueError(
+                    f'Output of external Python function <{pyfunc_name}> '
+                    f'is not a JAX array.\n' + print_stack_trace(expr))
+            
+            pyfunc_dims = jnp.shape(sample)[1:]
+            if len(require_dims) != len(pyfunc_dims):
+                raise ValueError(
+                    f'External Python function <{pyfunc_name}> returned array with '
+                    f'{len(pyfunc_dims)} dimensions, which does not match the '
+                    f'number of captured parameter(s) {len(require_dims)}.\n' +
+                    print_stack_trace(expr))
+            for (param, require_dim, actual_dim) in zip(captured_vars, require_dims, pyfunc_dims):
+                if require_dim != actual_dim:
+                    raise ValueError(
+                        f'External Python function <{pyfunc_name}> returned array with '
+                        f'{actual_dim} elements for captured parameter <{param}>, '
+                        f'which does not match the number of objects {require_dim}.\n' +
+                        print_stack_trace(expr))
+            
+            # unravel the combinations k back into their original dimensions
+            sample = jnp.reshape(sample, free_dims + pyfunc_dims)
+            
+            # rearrange the output dimensions to match the outer scope
+            source_indices = [num_free_vars + i for i in range(len(pyfunc_dims))]
+            sample = jnp.moveaxis(sample, source=source_indices, destination=dest_indices)
+            return sample, key, error, params
+        
+        return _jax_wrapped_external_function
+    
     # ===========================================================================
     # control flow
     # ===========================================================================
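The shape plumbing in `_jax_pyfunc` is easier to see in isolation: flatten the leading free-variable axes into one batch axis k, vmap the user function over it, then unravel k and move the function's output axes into place. A standalone toy (not library code):

```python
import jax
import jax.numpy as jnp

# toy external function: takes one captured axis of size 3, returns size 3
def f(v):
    return jnp.cumsum(v)

free_dims = (4, 5)                  # two non-captured object axes in outer scope
x = jnp.ones(free_dims + (3,))      # an already-evaluated input expression

flat = jnp.reshape(x, (-1,) + x.shape[len(free_dims):])   # (20, 3)
out = jax.vmap(f, in_axes=0)(flat)                        # (20, 3)
out = jnp.reshape(out, free_dims + out.shape[1:])         # (4, 5, 3)
# jnp.moveaxis would then reorder the output axes to the destination
# positions that the outer expression expects
```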
{pyrddlgym_jax-2.5 → pyrddlgym_jax-2.7}/pyRDDLGym_jax/core/logic.py

@@ -1056,15 +1056,13 @@ class ExactLogic(Logic):
     def control_if(self, id, init_params):
         return self._jax_wrapped_calc_if_then_else_exact
     
-    @staticmethod
-    def _jax_wrapped_calc_switch_exact(pred, cases, params):
-        pred = pred[jnp.newaxis, ...]
-        sample = jnp.take_along_axis(cases, pred, axis=0)
-        assert sample.shape[0] == 1
-        return sample[0, ...], params
-    
     def control_switch(self, id, init_params):
-        return self._jax_wrapped_calc_switch_exact
+        def _jax_wrapped_calc_switch_exact(pred, cases, params):
+            pred = jnp.asarray(pred[jnp.newaxis, ...], dtype=self.INT)
+            sample = jnp.take_along_axis(cases, pred, axis=0)
+            assert sample.shape[0] == 1
+            return sample[0, ...], params
+        return _jax_wrapped_calc_switch_exact
 
     # ===========================================================================
     # random variables