pyRDDLGym-jax 2.5__py3-none-any.whl → 2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyRDDLGym_jax/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '2.5'
1
+ __version__ = '2.7'
@@ -30,7 +30,8 @@ from pyRDDLGym.core.debug.exception import (
30
30
  print_stack_trace,
31
31
  raise_warning,
32
32
  RDDLInvalidNumberOfArgumentsError,
33
- RDDLNotImplementedError
33
+ RDDLNotImplementedError,
34
+ RDDLUndefinedVariableError
34
35
  )
35
36
  from pyRDDLGym.core.debug.logger import Logger
36
37
  from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
@@ -56,7 +57,8 @@ class JaxRDDLCompiler:
56
57
  allow_synchronous_state: bool=True,
57
58
  logger: Optional[Logger]=None,
58
59
  use64bit: bool=False,
59
- compile_non_fluent_exact: bool=True) -> None:
60
+ compile_non_fluent_exact: bool=True,
61
+ python_functions: Optional[Dict[str, Callable]]=None) -> None:
60
62
  '''Creates a new RDDL to Jax compiler.
61
63
 
62
64
  :param rddl: the RDDL model to compile into Jax
@@ -65,7 +67,8 @@ class JaxRDDLCompiler:
65
67
  :param logger: to log information about compilation to file
66
68
  :param use64bit: whether to use 64 bit arithmetic
67
69
  :param compile_non_fluent_exact: whether non-fluent expressions
68
- are always compiled using exact JAX expressions.
70
+ are always compiled using exact JAX expressions
71
+ :param python_functions: dictionary of external Python functions to call from RDDL
69
72
  '''
70
73
  self.rddl = rddl
71
74
  self.logger = logger
@@ -99,11 +102,15 @@ class JaxRDDLCompiler:
99
102
  self.traced = tracer.trace()
100
103
 
101
104
  # extract the box constraints on actions
105
+ if python_functions is None:
106
+ python_functions = {}
107
+ self.python_functions = python_functions
102
108
  simulator = RDDLSimulatorPrecompiled(
103
109
  rddl=self.rddl,
104
110
  init_values=self.init_values,
105
111
  levels=self.levels,
106
- trace_info=self.traced
112
+ trace_info=self.traced,
113
+ python_functions=python_functions
107
114
  )
108
115
  constraints = RDDLConstraints(simulator, vectorized=True)
109
116
  self.constraints = constraints
@@ -237,7 +244,8 @@ class JaxRDDLCompiler:
237
244
 
238
245
  def compile_transition(self, check_constraints: bool=False,
239
246
  constraint_func: bool=False,
240
- init_params_constr: Dict[str, Any]={}) -> Callable:
247
+ init_params_constr: Dict[str, Any]={},
248
+ cache_path_info: bool=False) -> Callable:
241
249
  '''Compiles the current RDDL model into a JAX transition function that
242
250
  samples the next state.
243
251
 
@@ -274,6 +282,7 @@ class JaxRDDLCompiler:
274
282
  returned log and does not raise an exception
275
283
  :param constraint_func: produces the h(s, a) function described above
276
284
  in addition to the usual outputs
285
+ :param cache_path_info: whether to save full path traces as part of the log
277
286
  '''
278
287
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
279
288
  rddl = self.rddl
@@ -322,8 +331,11 @@ class JaxRDDLCompiler:
322
331
  errors |= err
323
332
 
324
333
  # calculate fluent values
325
- fluents = {name: values for (name, values) in subs.items()
326
- if name not in rddl.non_fluents}
334
+ if cache_path_info:
335
+ fluents = {name: values for (name, values) in subs.items()
336
+ if name not in rddl.non_fluents}
337
+ else:
338
+ fluents = {}
327
339
 
328
340
  # set the next state to the current state
329
341
  for (state, next_state) in rddl.next_state.items():
@@ -367,7 +379,9 @@ class JaxRDDLCompiler:
367
379
  n_batch: int,
368
380
  check_constraints: bool=False,
369
381
  constraint_func: bool=False,
370
- init_params_constr: Dict[str, Any]={}) -> Callable:
382
+ init_params_constr: Dict[str, Any]={},
383
+ model_params_reduction: Callable=lambda x: x[0],
384
+ cache_path_info: bool=False) -> Callable:
371
385
  '''Compiles the current RDDL model into a JAX transition function that
372
386
  samples trajectories with a fixed horizon from a policy.
373
387
 
@@ -399,10 +413,13 @@ class JaxRDDLCompiler:
399
413
  returned log and does not raise an exception
400
414
  :param constraint_func: produces the h(s, a) constraint function
401
415
  in addition to the usual outputs
416
+ :param model_params_reduction: how to aggregate updated model_params across runs
417
+ in the batch (defaults to selecting the first element's parameters in the batch)
418
+ :param cache_path_info: whether to save full path traces as part of the log
402
419
  '''
403
420
  rddl = self.rddl
404
421
  jax_step_fn = self.compile_transition(
405
- check_constraints, constraint_func, init_params_constr)
422
+ check_constraints, constraint_func, init_params_constr, cache_path_info)
406
423
 
407
424
  # for POMDP only observ-fluents are assumed visible to the policy
408
425
  if rddl.observ_fluents:
@@ -421,7 +438,6 @@ class JaxRDDLCompiler:
421
438
  return jax_step_fn(subkey, actions, subs, model_params)
422
439
 
423
440
  # do a batched step update from the policy
424
- # TODO: come up with a better way to reduce the model_param batch dim
425
441
  def _jax_wrapped_batched_step_policy(carry, step):
426
442
  key, policy_params, hyperparams, subs, model_params = carry
427
443
  key, *subkeys = random.split(key, num=1 + n_batch)
@@ -430,7 +446,7 @@ class JaxRDDLCompiler:
430
446
  _jax_wrapped_single_step_policy,
431
447
  in_axes=(0, None, None, None, 0, None)
432
448
  )(keys, policy_params, hyperparams, step, subs, model_params)
433
- model_params = jax.tree_util.tree_map(partial(jnp.mean, axis=0), model_params)
449
+ model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
434
450
  carry = (key, policy_params, hyperparams, subs, model_params)
435
451
  return carry, log
436
452
 
@@ -596,6 +612,8 @@ class JaxRDDLCompiler:
596
612
  jax_expr = self._jax_aggregation(expr, init_params)
597
613
  elif etype == 'func':
598
614
  jax_expr = self._jax_functional(expr, init_params)
615
+ elif etype == 'pyfunc':
616
+ jax_expr = self._jax_pyfunc(expr, init_params)
599
617
  elif etype == 'control':
600
618
  jax_expr = self._jax_control(expr, init_params)
601
619
  elif etype == 'randomvar':
@@ -917,6 +935,84 @@ class JaxRDDLCompiler:
917
935
  raise RDDLNotImplementedError(
918
936
  f'Function {op} is not supported.\n' + print_stack_trace(expr))
919
937
 
938
+ def _jax_pyfunc(self, expr, init_params):
939
+ NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
940
+
941
+ # get the Python function by name
942
+ _, pyfunc_name = expr.etype
943
+ pyfunc = self.python_functions.get(pyfunc_name)
944
+ if pyfunc is None:
945
+ raise RDDLUndefinedVariableError(
946
+ f'Undefined external Python function <{pyfunc_name}>, '
947
+ f'must be one of {list(self.python_functions.keys())}.\n' +
948
+ print_stack_trace(expr))
949
+
950
+ captured_vars, args = expr.args
951
+ scope_vars = self.traced.cached_objects_in_scope(expr)
952
+ dest_indices = self.traced.cached_sim_info(expr)
953
+ free_vars = [p for p in scope_vars if p[0] not in captured_vars]
954
+ free_dims = self.rddl.object_counts(p for (_, p) in free_vars)
955
+ num_free_vars = len(free_vars)
956
+ captured_types = [t for (p, t) in scope_vars if p in captured_vars]
957
+ require_dims = self.rddl.object_counts(captured_types)
958
+
959
+ # compile the inputs to the function
960
+ jax_inputs = [self._jax(arg, init_params) for arg in args]
961
+
962
+ # compile the function evaluation function
963
+ def _jax_wrapped_external_function(x, params, key):
964
+
965
+ # evaluate inputs to the function
966
+ # the leading dimensions range over the non-captured vars in the outer scope, followed by one dimension for each _ occurrence
967
+ error = NORMAL
968
+ flat_samples = []
969
+ for jax_expr in jax_inputs:
970
+ sample, key, err, params = jax_expr(x, params, key)
971
+ shape = jnp.shape(sample)
972
+ first_dim = 1
973
+ for dim in shape[:num_free_vars]:
974
+ first_dim *= dim
975
+ new_shape = (first_dim,) + shape[num_free_vars:]
976
+ flat_sample = jnp.reshape(sample, new_shape)
977
+ flat_samples.append(flat_sample)
978
+ error |= err
979
+
980
+ # now every input has shape (k,) followed by one dimension per _ occurrence
981
+ # k is the number of possible non-captured object combinations
982
+ # evaluate the function independently for each combination
983
+ # output dimension for each combination is captured variables (n1, n2, ...)
984
+ # so the total dimension of the output array is (k, n1, n2, ...)
985
+ sample = jax.vmap(pyfunc, in_axes=0)(*flat_samples)
986
+ if not isinstance(sample, jnp.ndarray):
987
+ raise ValueError(
988
+ f'Output of external Python function <{pyfunc_name}> '
989
+ f'is not a JAX array.\n' + print_stack_trace(expr))
990
+
991
+ pyfunc_dims = jnp.shape(sample)[1:]
992
+ if len(require_dims) != len(pyfunc_dims):
993
+ raise ValueError(
994
+ f'External Python function <{pyfunc_name}> returned array with '
995
+ f'{len(pyfunc_dims)} dimensions, which does not match the '
996
+ f'number of captured parameter(s) {len(require_dims)}.\n' +
997
+ print_stack_trace(expr))
998
+ for (param, require_dim, actual_dim) in zip(captured_vars, require_dims, pyfunc_dims):
999
+ if require_dim != actual_dim:
1000
+ raise ValueError(
1001
+ f'External Python function <{pyfunc_name}> returned array with '
1002
+ f'{actual_dim} elements for captured parameter <{param}>, '
1003
+ f'which does not match the number of objects {require_dim}.\n' +
1004
+ print_stack_trace(expr))
1005
+
1006
+ # unravel the combinations k back into their original dimensions
1007
+ sample = jnp.reshape(sample, free_dims + pyfunc_dims)
1008
+
1009
+ # rearrange the output dimensions to match the outer scope
1010
+ source_indices = [num_free_vars + i for i in range(len(pyfunc_dims))]
1011
+ sample = jnp.moveaxis(sample, source=source_indices, destination=dest_indices)
1012
+ return sample, key, error, params
1013
+
1014
+ return _jax_wrapped_external_function
1015
+
920
1016
  # ===========================================================================
921
1017
  # control flow
922
1018
  # ===========================================================================
@@ -1056,15 +1056,13 @@ class ExactLogic(Logic):
1056
1056
  def control_if(self, id, init_params):
1057
1057
  return self._jax_wrapped_calc_if_then_else_exact
1058
1058
 
1059
- @staticmethod
1060
- def _jax_wrapped_calc_switch_exact(pred, cases, params):
1061
- pred = pred[jnp.newaxis, ...]
1062
- sample = jnp.take_along_axis(cases, pred, axis=0)
1063
- assert sample.shape[0] == 1
1064
- return sample[0, ...], params
1065
-
1066
1059
  def control_switch(self, id, init_params):
1067
- return self._jax_wrapped_calc_switch_exact
1060
+ def _jax_wrapped_calc_switch_exact(pred, cases, params):
1061
+ pred = jnp.asarray(pred[jnp.newaxis, ...], dtype=self.INT)
1062
+ sample = jnp.take_along_axis(cases, pred, axis=0)
1063
+ assert sample.shape[0] == 1
1064
+ return sample[0, ...], params
1065
+ return _jax_wrapped_calc_switch_exact
1068
1066
 
1069
1067
  # ===========================================================================
1070
1068
  # random variables