pyRDDLGym-jax 2.5__py3-none-any.whl → 2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +107 -11
- pyRDDLGym_jax/core/logic.py +6 -8
- pyRDDLGym_jax/core/model.py +595 -0
- pyRDDLGym_jax/core/planner.py +183 -24
- pyRDDLGym_jax/core/simulator.py +12 -4
- pyRDDLGym_jax/examples/run_plan.py +31 -0
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/METADATA +5 -13
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/RECORD +13 -12
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/licenses/LICENSE +1 -1
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/WHEEL +0 -0
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.5'
|
|
1
|
+
__version__ = '2.7'
|
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -30,7 +30,8 @@ from pyRDDLGym.core.debug.exception import (
|
|
|
30
30
|
print_stack_trace,
|
|
31
31
|
raise_warning,
|
|
32
32
|
RDDLInvalidNumberOfArgumentsError,
|
|
33
|
-
RDDLNotImplementedError
|
|
33
|
+
RDDLNotImplementedError,
|
|
34
|
+
RDDLUndefinedVariableError
|
|
34
35
|
)
|
|
35
36
|
from pyRDDLGym.core.debug.logger import Logger
|
|
36
37
|
from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
|
|
@@ -56,7 +57,8 @@ class JaxRDDLCompiler:
|
|
|
56
57
|
allow_synchronous_state: bool=True,
|
|
57
58
|
logger: Optional[Logger]=None,
|
|
58
59
|
use64bit: bool=False,
|
|
59
|
-
compile_non_fluent_exact: bool=True
|
|
60
|
+
compile_non_fluent_exact: bool=True,
|
|
61
|
+
python_functions: Optional[Dict[str, Callable]]=None) -> None:
|
|
60
62
|
'''Creates a new RDDL to Jax compiler.
|
|
61
63
|
|
|
62
64
|
:param rddl: the RDDL model to compile into Jax
|
|
@@ -65,7 +67,8 @@ class JaxRDDLCompiler:
|
|
|
65
67
|
:param logger: to log information about compilation to file
|
|
66
68
|
:param use64bit: whether to use 64 bit arithmetic
|
|
67
69
|
:param compile_non_fluent_exact: whether non-fluent expressions
|
|
68
|
-
are always compiled using exact JAX expressions
|
|
70
|
+
are always compiled using exact JAX expressions
|
|
71
|
+
:param python_functions: dictionary of external Python functions to call from RDDL
|
|
69
72
|
'''
|
|
70
73
|
self.rddl = rddl
|
|
71
74
|
self.logger = logger
|
|
@@ -99,11 +102,15 @@ class JaxRDDLCompiler:
|
|
|
99
102
|
self.traced = tracer.trace()
|
|
100
103
|
|
|
101
104
|
# extract the box constraints on actions
|
|
105
|
+
if python_functions is None:
|
|
106
|
+
python_functions = {}
|
|
107
|
+
self.python_functions = python_functions
|
|
102
108
|
simulator = RDDLSimulatorPrecompiled(
|
|
103
109
|
rddl=self.rddl,
|
|
104
110
|
init_values=self.init_values,
|
|
105
111
|
levels=self.levels,
|
|
106
|
-
trace_info=self.traced
|
|
112
|
+
trace_info=self.traced,
|
|
113
|
+
python_functions=python_functions
|
|
107
114
|
)
|
|
108
115
|
constraints = RDDLConstraints(simulator, vectorized=True)
|
|
109
116
|
self.constraints = constraints
|
|
@@ -237,7 +244,8 @@ class JaxRDDLCompiler:
|
|
|
237
244
|
|
|
238
245
|
def compile_transition(self, check_constraints: bool=False,
|
|
239
246
|
constraint_func: bool=False,
|
|
240
|
-
init_params_constr: Dict[str, Any]={}
|
|
247
|
+
init_params_constr: Dict[str, Any]={},
|
|
248
|
+
cache_path_info: bool=False) -> Callable:
|
|
241
249
|
'''Compiles the current RDDL model into a JAX transition function that
|
|
242
250
|
samples the next state.
|
|
243
251
|
|
|
@@ -274,6 +282,7 @@ class JaxRDDLCompiler:
|
|
|
274
282
|
returned log and does not raise an exception
|
|
275
283
|
:param constraint_func: produces the h(s, a) function described above
|
|
276
284
|
in addition to the usual outputs
|
|
285
|
+
:param cache_path_info: whether to save full path traces as part of the log
|
|
277
286
|
'''
|
|
278
287
|
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
279
288
|
rddl = self.rddl
|
|
@@ -322,8 +331,11 @@ class JaxRDDLCompiler:
|
|
|
322
331
|
errors |= err
|
|
323
332
|
|
|
324
333
|
# calculate fluent values
|
|
325
|
-
|
|
326
|
-
|
|
334
|
+
if cache_path_info:
|
|
335
|
+
fluents = {name: values for (name, values) in subs.items()
|
|
336
|
+
if name not in rddl.non_fluents}
|
|
337
|
+
else:
|
|
338
|
+
fluents = {}
|
|
327
339
|
|
|
328
340
|
# set the next state to the current state
|
|
329
341
|
for (state, next_state) in rddl.next_state.items():
|
|
@@ -367,7 +379,9 @@ class JaxRDDLCompiler:
|
|
|
367
379
|
n_batch: int,
|
|
368
380
|
check_constraints: bool=False,
|
|
369
381
|
constraint_func: bool=False,
|
|
370
|
-
init_params_constr: Dict[str, Any]={}
|
|
382
|
+
init_params_constr: Dict[str, Any]={},
|
|
383
|
+
model_params_reduction: Callable=lambda x: x[0],
|
|
384
|
+
cache_path_info: bool=False) -> Callable:
|
|
371
385
|
'''Compiles the current RDDL model into a JAX transition function that
|
|
372
386
|
samples trajectories with a fixed horizon from a policy.
|
|
373
387
|
|
|
@@ -399,10 +413,13 @@ class JaxRDDLCompiler:
|
|
|
399
413
|
returned log and does not raise an exception
|
|
400
414
|
:param constraint_func: produces the h(s, a) constraint function
|
|
401
415
|
in addition to the usual outputs
|
|
416
|
+
:param model_params_reduction: how to aggregate updated model_params across runs
|
|
417
|
+
in the batch (defaults to selecting the first element's parameters in the batch)
|
|
418
|
+
:param cache_path_info: whether to save full path traces as part of the log
|
|
402
419
|
'''
|
|
403
420
|
rddl = self.rddl
|
|
404
421
|
jax_step_fn = self.compile_transition(
|
|
405
|
-
check_constraints, constraint_func, init_params_constr)
|
|
422
|
+
check_constraints, constraint_func, init_params_constr, cache_path_info)
|
|
406
423
|
|
|
407
424
|
# for POMDP only observ-fluents are assumed visible to the policy
|
|
408
425
|
if rddl.observ_fluents:
|
|
@@ -421,7 +438,6 @@ class JaxRDDLCompiler:
|
|
|
421
438
|
return jax_step_fn(subkey, actions, subs, model_params)
|
|
422
439
|
|
|
423
440
|
# do a batched step update from the policy
|
|
424
|
-
# TODO: come up with a better way to reduce the model_param batch dim
|
|
425
441
|
def _jax_wrapped_batched_step_policy(carry, step):
|
|
426
442
|
key, policy_params, hyperparams, subs, model_params = carry
|
|
427
443
|
key, *subkeys = random.split(key, num=1 + n_batch)
|
|
@@ -430,7 +446,7 @@ class JaxRDDLCompiler:
|
|
|
430
446
|
_jax_wrapped_single_step_policy,
|
|
431
447
|
in_axes=(0, None, None, None, 0, None)
|
|
432
448
|
)(keys, policy_params, hyperparams, step, subs, model_params)
|
|
433
|
-
model_params = jax.tree_util.tree_map(
|
|
449
|
+
model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
|
|
434
450
|
carry = (key, policy_params, hyperparams, subs, model_params)
|
|
435
451
|
return carry, log
|
|
436
452
|
|
|
@@ -596,6 +612,8 @@ class JaxRDDLCompiler:
|
|
|
596
612
|
jax_expr = self._jax_aggregation(expr, init_params)
|
|
597
613
|
elif etype == 'func':
|
|
598
614
|
jax_expr = self._jax_functional(expr, init_params)
|
|
615
|
+
elif etype == 'pyfunc':
|
|
616
|
+
jax_expr = self._jax_pyfunc(expr, init_params)
|
|
599
617
|
elif etype == 'control':
|
|
600
618
|
jax_expr = self._jax_control(expr, init_params)
|
|
601
619
|
elif etype == 'randomvar':
|
|
@@ -917,6 +935,84 @@ class JaxRDDLCompiler:
|
|
|
917
935
|
raise RDDLNotImplementedError(
|
|
918
936
|
f'Function {op} is not supported.\n' + print_stack_trace(expr))
|
|
919
937
|
|
|
938
|
+
def _jax_pyfunc(self, expr, init_params):
|
|
939
|
+
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
940
|
+
|
|
941
|
+
# get the Python function by name
|
|
942
|
+
_, pyfunc_name = expr.etype
|
|
943
|
+
pyfunc = self.python_functions.get(pyfunc_name)
|
|
944
|
+
if pyfunc is None:
|
|
945
|
+
raise RDDLUndefinedVariableError(
|
|
946
|
+
f'Undefined external Python function <{pyfunc_name}>, '
|
|
947
|
+
f'must be one of {list(self.python_functions.keys())}.\n' +
|
|
948
|
+
print_stack_trace(expr))
|
|
949
|
+
|
|
950
|
+
captured_vars, args = expr.args
|
|
951
|
+
scope_vars = self.traced.cached_objects_in_scope(expr)
|
|
952
|
+
dest_indices = self.traced.cached_sim_info(expr)
|
|
953
|
+
free_vars = [p for p in scope_vars if p[0] not in captured_vars]
|
|
954
|
+
free_dims = self.rddl.object_counts(p for (_, p) in free_vars)
|
|
955
|
+
num_free_vars = len(free_vars)
|
|
956
|
+
captured_types = [t for (p, t) in scope_vars if p in captured_vars]
|
|
957
|
+
require_dims = self.rddl.object_counts(captured_types)
|
|
958
|
+
|
|
959
|
+
# compile the inputs to the function
|
|
960
|
+
jax_inputs = [self._jax(arg, init_params) for arg in args]
|
|
961
|
+
|
|
962
|
+
# compile the function evaluation function
|
|
963
|
+
def _jax_wrapped_external_function(x, params, key):
|
|
964
|
+
|
|
965
|
+
# evaluate inputs to the function
|
|
966
|
+
# first dimensions are non-captured vars in outer scope followed by all the _
|
|
967
|
+
error = NORMAL
|
|
968
|
+
flat_samples = []
|
|
969
|
+
for jax_expr in jax_inputs:
|
|
970
|
+
sample, key, err, params = jax_expr(x, params, key)
|
|
971
|
+
shape = jnp.shape(sample)
|
|
972
|
+
first_dim = 1
|
|
973
|
+
for dim in shape[:num_free_vars]:
|
|
974
|
+
first_dim *= dim
|
|
975
|
+
new_shape = (first_dim,) + shape[num_free_vars:]
|
|
976
|
+
flat_sample = jnp.reshape(sample, new_shape)
|
|
977
|
+
flat_samples.append(flat_sample)
|
|
978
|
+
error |= err
|
|
979
|
+
|
|
980
|
+
# now all the inputs have dimensions equal to (k,) + the number of _ occurences
|
|
981
|
+
# k is the number of possible non-captured object combinations
|
|
982
|
+
# evaluate the function independently for each combination
|
|
983
|
+
# output dimension for each combination is captured variables (n1, n2, ...)
|
|
984
|
+
# so the total dimension of the output array is (k, n1, n2, ...)
|
|
985
|
+
sample = jax.vmap(pyfunc, in_axes=0)(*flat_samples)
|
|
986
|
+
if not isinstance(sample, jnp.ndarray):
|
|
987
|
+
raise ValueError(
|
|
988
|
+
f'Output of external Python function <{pyfunc_name}> '
|
|
989
|
+
f'is not a JAX array.\n' + print_stack_trace(expr))
|
|
990
|
+
|
|
991
|
+
pyfunc_dims = jnp.shape(sample)[1:]
|
|
992
|
+
if len(require_dims) != len(pyfunc_dims):
|
|
993
|
+
raise ValueError(
|
|
994
|
+
f'External Python function <{pyfunc_name}> returned array with '
|
|
995
|
+
f'{len(pyfunc_dims)} dimensions, which does not match the '
|
|
996
|
+
f'number of captured parameter(s) {len(require_dims)}.\n' +
|
|
997
|
+
print_stack_trace(expr))
|
|
998
|
+
for (param, require_dim, actual_dim) in zip(captured_vars, require_dims, pyfunc_dims):
|
|
999
|
+
if require_dim != actual_dim:
|
|
1000
|
+
raise ValueError(
|
|
1001
|
+
f'External Python function <{pyfunc_name}> returned array with '
|
|
1002
|
+
f'{actual_dim} elements for captured parameter <{param}>, '
|
|
1003
|
+
f'which does not match the number of objects {require_dim}.\n' +
|
|
1004
|
+
print_stack_trace(expr))
|
|
1005
|
+
|
|
1006
|
+
# unravel the combinations k back into their original dimensions
|
|
1007
|
+
sample = jnp.reshape(sample, free_dims + pyfunc_dims)
|
|
1008
|
+
|
|
1009
|
+
# rearrange the output dimensions to match the outer scope
|
|
1010
|
+
source_indices = [num_free_vars + i for i in range(len(pyfunc_dims))]
|
|
1011
|
+
sample = jnp.moveaxis(sample, source=source_indices, destination=dest_indices)
|
|
1012
|
+
return sample, key, error, params
|
|
1013
|
+
|
|
1014
|
+
return _jax_wrapped_external_function
|
|
1015
|
+
|
|
920
1016
|
# ===========================================================================
|
|
921
1017
|
# control flow
|
|
922
1018
|
# ===========================================================================
|
pyRDDLGym_jax/core/logic.py
CHANGED
|
@@ -1056,15 +1056,13 @@ class ExactLogic(Logic):
|
|
|
1056
1056
|
def control_if(self, id, init_params):
|
|
1057
1057
|
return self._jax_wrapped_calc_if_then_else_exact
|
|
1058
1058
|
|
|
1059
|
-
@staticmethod
|
|
1060
|
-
def _jax_wrapped_calc_switch_exact(pred, cases, params):
|
|
1061
|
-
pred = pred[jnp.newaxis, ...]
|
|
1062
|
-
sample = jnp.take_along_axis(cases, pred, axis=0)
|
|
1063
|
-
assert sample.shape[0] == 1
|
|
1064
|
-
return sample[0, ...], params
|
|
1065
|
-
|
|
1066
1059
|
def control_switch(self, id, init_params):
|
|
1067
|
-
|
|
1060
|
+
def _jax_wrapped_calc_switch_exact(pred, cases, params):
|
|
1061
|
+
pred = jnp.asarray(pred[jnp.newaxis, ...], dtype=self.INT)
|
|
1062
|
+
sample = jnp.take_along_axis(cases, pred, axis=0)
|
|
1063
|
+
assert sample.shape[0] == 1
|
|
1064
|
+
return sample[0, ...], params
|
|
1065
|
+
return _jax_wrapped_calc_switch_exact
|
|
1068
1066
|
|
|
1069
1067
|
# ===========================================================================
|
|
1070
1068
|
# random variables
|