pyRDDLGym-jax 2.6__py3-none-any.whl → 2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyRDDLGym_jax/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '2.6'
1
+ __version__ = '2.8'
@@ -30,7 +30,8 @@ from pyRDDLGym.core.debug.exception import (
30
30
  print_stack_trace,
31
31
  raise_warning,
32
32
  RDDLInvalidNumberOfArgumentsError,
33
- RDDLNotImplementedError
33
+ RDDLNotImplementedError,
34
+ RDDLUndefinedVariableError
34
35
  )
35
36
  from pyRDDLGym.core.debug.logger import Logger
36
37
  from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
@@ -56,7 +57,8 @@ class JaxRDDLCompiler:
56
57
  allow_synchronous_state: bool=True,
57
58
  logger: Optional[Logger]=None,
58
59
  use64bit: bool=False,
59
- compile_non_fluent_exact: bool=True) -> None:
60
+ compile_non_fluent_exact: bool=True,
61
+ python_functions: Optional[Dict[str, Callable]]=None) -> None:
60
62
  '''Creates a new RDDL to Jax compiler.
61
63
 
62
64
  :param rddl: the RDDL model to compile into Jax
@@ -65,7 +67,8 @@ class JaxRDDLCompiler:
65
67
  :param logger: to log information about compilation to file
66
68
  :param use64bit: whether to use 64 bit arithmetic
67
69
  :param compile_non_fluent_exact: whether non-fluent expressions
68
- are always compiled using exact JAX expressions.
70
+ are always compiled using exact JAX expressions
71
+ :param python_functions: dictionary of external Python functions to call from RDDL
69
72
  '''
70
73
  self.rddl = rddl
71
74
  self.logger = logger
@@ -99,11 +102,15 @@ class JaxRDDLCompiler:
99
102
  self.traced = tracer.trace()
100
103
 
101
104
  # extract the box constraints on actions
105
+ if python_functions is None:
106
+ python_functions = {}
107
+ self.python_functions = python_functions
102
108
  simulator = RDDLSimulatorPrecompiled(
103
109
  rddl=self.rddl,
104
110
  init_values=self.init_values,
105
111
  levels=self.levels,
106
- trace_info=self.traced
112
+ trace_info=self.traced,
113
+ python_functions=python_functions
107
114
  )
108
115
  constraints = RDDLConstraints(simulator, vectorized=True)
109
116
  self.constraints = constraints
@@ -605,6 +612,8 @@ class JaxRDDLCompiler:
605
612
  jax_expr = self._jax_aggregation(expr, init_params)
606
613
  elif etype == 'func':
607
614
  jax_expr = self._jax_functional(expr, init_params)
615
+ elif etype == 'pyfunc':
616
+ jax_expr = self._jax_pyfunc(expr, init_params)
608
617
  elif etype == 'control':
609
618
  jax_expr = self._jax_control(expr, init_params)
610
619
  elif etype == 'randomvar':
@@ -926,6 +935,84 @@ class JaxRDDLCompiler:
926
935
  raise RDDLNotImplementedError(
927
936
  f'Function {op} is not supported.\n' + print_stack_trace(expr))
928
937
 
938
+ def _jax_pyfunc(self, expr, init_params):
939
+ NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
940
+
941
+ # get the Python function by name
942
+ _, pyfunc_name = expr.etype
943
+ pyfunc = self.python_functions.get(pyfunc_name)
944
+ if pyfunc is None:
945
+ raise RDDLUndefinedVariableError(
946
+ f'Undefined external Python function <{pyfunc_name}>, '
947
+ f'must be one of {list(self.python_functions.keys())}.\n' +
948
+ print_stack_trace(expr))
949
+
950
+ captured_vars, args = expr.args
951
+ scope_vars = self.traced.cached_objects_in_scope(expr)
952
+ dest_indices = self.traced.cached_sim_info(expr)
953
+ free_vars = [p for p in scope_vars if p[0] not in captured_vars]
954
+ free_dims = self.rddl.object_counts(p for (_, p) in free_vars)
955
+ num_free_vars = len(free_vars)
956
+ captured_types = [t for (p, t) in scope_vars if p in captured_vars]
957
+ require_dims = self.rddl.object_counts(captured_types)
958
+
959
+ # compile the inputs to the function
960
+ jax_inputs = [self._jax(arg, init_params) for arg in args]
961
+
962
+ # compile the function evaluation function
963
+ def _jax_wrapped_external_function(x, params, key):
964
+
965
+ # evaluate inputs to the function
966
+ # first dimensions are non-captured vars in outer scope followed by all the _
967
+ error = NORMAL
968
+ flat_samples = []
969
+ for jax_expr in jax_inputs:
970
+ sample, key, err, params = jax_expr(x, params, key)
971
+ shape = jnp.shape(sample)
972
+ first_dim = 1
973
+ for dim in shape[:num_free_vars]:
974
+ first_dim *= dim
975
+ new_shape = (first_dim,) + shape[num_free_vars:]
976
+ flat_sample = jnp.reshape(sample, new_shape)
977
+ flat_samples.append(flat_sample)
978
+ error |= err
979
+
980
+ # now all the inputs have dimensions equal to (k,) + the number of _ occurrences
981
+ # k is the number of possible non-captured object combinations
982
+ # evaluate the function independently for each combination
983
+ # output dimension for each combination is captured variables (n1, n2, ...)
984
+ # so the total dimension of the output array is (k, n1, n2, ...)
985
+ sample = jax.vmap(pyfunc, in_axes=0)(*flat_samples)
986
+ if not isinstance(sample, jnp.ndarray):
987
+ raise ValueError(
988
+ f'Output of external Python function <{pyfunc_name}> '
989
+ f'is not a JAX array.\n' + print_stack_trace(expr))
990
+
991
+ pyfunc_dims = jnp.shape(sample)[1:]
992
+ if len(require_dims) != len(pyfunc_dims):
993
+ raise ValueError(
994
+ f'External Python function <{pyfunc_name}> returned array with '
995
+ f'{len(pyfunc_dims)} dimensions, which does not match the '
996
+ f'number of captured parameter(s) {len(require_dims)}.\n' +
997
+ print_stack_trace(expr))
998
+ for (param, require_dim, actual_dim) in zip(captured_vars, require_dims, pyfunc_dims):
999
+ if require_dim != actual_dim:
1000
+ raise ValueError(
1001
+ f'External Python function <{pyfunc_name}> returned array with '
1002
+ f'{actual_dim} elements for captured parameter <{param}>, '
1003
+ f'which does not match the number of objects {require_dim}.\n' +
1004
+ print_stack_trace(expr))
1005
+
1006
+ # unravel the combinations k back into their original dimensions
1007
+ sample = jnp.reshape(sample, free_dims + pyfunc_dims)
1008
+
1009
+ # rearrange the output dimensions to match the outer scope
1010
+ source_indices = [num_free_vars + i for i in range(len(pyfunc_dims))]
1011
+ sample = jnp.moveaxis(sample, source=source_indices, destination=dest_indices)
1012
+ return sample, key, error, params
1013
+
1014
+ return _jax_wrapped_external_function
1015
+
929
1016
  # ===========================================================================
930
1017
  # control flow
931
1018
  # ===========================================================================
@@ -1810,7 +1810,8 @@ class JaxBackpropPlanner:
1810
1810
  dashboard_viz: Optional[Any]=None,
1811
1811
  print_warnings: bool=True,
1812
1812
  parallel_updates: Optional[int]=None,
1813
- preprocessor: Optional[Preprocessor]=None) -> None:
1813
+ preprocessor: Optional[Preprocessor]=None,
1814
+ python_functions: Optional[Dict[str, Callable]]=None) -> None:
1814
1815
  '''Creates a new gradient-based algorithm for optimizing action sequences
1815
1816
  (plan) in the given RDDL. Some operations will be converted to their
1816
1817
  differentiable counterparts; the specific operations can be customized
@@ -1853,6 +1854,7 @@ class JaxBackpropPlanner:
1853
1854
  :param print_warnings: whether to print warnings
1854
1855
  :param parallel_updates: how many optimizers to run independently in parallel
1855
1856
  :param preprocessor: optional preprocessor for state inputs to plan
1857
+ :param python_functions: dictionary of external Python functions to call from RDDL
1856
1858
  '''
1857
1859
  self.rddl = rddl
1858
1860
  self.plan = plan
@@ -1879,7 +1881,10 @@ class JaxBackpropPlanner:
1879
1881
  self.use_pgpe = pgpe is not None
1880
1882
  self.print_warnings = print_warnings
1881
1883
  self.preprocessor = preprocessor
1882
-
1884
+ if python_functions is None:
1885
+ python_functions = {}
1886
+ self.python_functions = python_functions
1887
+
1883
1888
  # set optimizer
1884
1889
  try:
1885
1890
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
@@ -2027,7 +2032,8 @@ r"""
2027
2032
  use64bit=self.use64bit,
2028
2033
  cpfs_without_grad=self.cpfs_without_grad,
2029
2034
  compile_non_fluent_exact=self.compile_non_fluent_exact,
2030
- print_warnings=self.print_warnings
2035
+ print_warnings=self.print_warnings,
2036
+ python_functions=self.python_functions
2031
2037
  )
2032
2038
  self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
2033
2039
 
@@ -2035,7 +2041,8 @@ r"""
2035
2041
  self.test_compiled = JaxRDDLCompiler(
2036
2042
  rddl=rddl,
2037
2043
  logger=self.logger,
2038
- use64bit=self.use64bit
2044
+ use64bit=self.use64bit,
2045
+ python_functions=self.python_functions
2039
2046
  )
2040
2047
  self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
2041
2048
 
@@ -20,7 +20,7 @@
20
20
 
21
21
  import time
22
22
  import numpy as np
23
- from typing import Dict, Optional, Union
23
+ from typing import Callable, Dict, Optional, Union
24
24
 
25
25
  import jax
26
26
 
@@ -48,6 +48,7 @@ class JaxRDDLSimulator(RDDLSimulator):
48
48
  logger: Optional[Logger]=None,
49
49
  keep_tensors: bool=False,
50
50
  objects_as_strings: bool=True,
51
+ python_functions: Optional[Dict[str, Callable]]=None,
51
52
  **compiler_args) -> None:
52
53
  '''Creates a new simulator for the given RDDL model with Jax as a backend.
53
54
 
@@ -60,8 +61,9 @@ class JaxRDDLSimulator(RDDLSimulator):
60
61
  :param logger: to log information about compilation to file
61
62
  :param keep_tensors: whether the sampler takes actions and
62
63
  returns state in numpy array form
63
- param objects_as_strings: whether to return object values as strings (defaults
64
+ :param objects_as_strings: whether to return object values as strings (defaults
64
65
  to integer indices if False)
66
+ :param python_functions: dictionary of external Python functions to call from RDDL
65
67
  :param **compiler_args: keyword arguments to pass to the Jax compiler
66
68
  '''
67
69
  if key is None:
@@ -73,7 +75,8 @@ class JaxRDDLSimulator(RDDLSimulator):
73
75
  # generate direct sampling with default numpy RNG and operations
74
76
  super(JaxRDDLSimulator, self).__init__(
75
77
  rddl, logger=logger,
76
- keep_tensors=keep_tensors, objects_as_strings=objects_as_strings)
78
+ keep_tensors=keep_tensors, objects_as_strings=objects_as_strings,
79
+ python_functions=python_functions)
77
80
 
78
81
  def seed(self, seed: int) -> None:
79
82
  super(JaxRDDLSimulator, self).seed(seed)
@@ -83,7 +86,12 @@ class JaxRDDLSimulator(RDDLSimulator):
83
86
  rddl = self.rddl
84
87
 
85
88
  # compilation
86
- compiled = JaxRDDLCompiler(rddl, logger=self.logger, **self.compiler_args)
89
+ compiled = JaxRDDLCompiler(
90
+ rddl,
91
+ logger=self.logger,
92
+ python_functions=self.python_functions,
93
+ **self.compiler_args
94
+ )
87
95
  compiled.compile(log_jax_expr=True, heading='SIMULATION MODEL')
88
96
 
89
97
  self.init_values = compiled.init_values
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyRDDLGym-jax
3
- Version: 2.6
3
+ Version: 2.8
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
20
  Requires-Python: >=3.9
21
21
  Description-Content-Type: text/markdown
22
22
  License-File: LICENSE
23
- Requires-Dist: pyRDDLGym>=2.3
23
+ Requires-Dist: pyRDDLGym>=2.5
24
24
  Requires-Dist: tqdm>=4.66
25
25
  Requires-Dist: jax>=0.4.12
26
26
  Requires-Dist: optax>=0.1.9
@@ -1,11 +1,11 @@
1
- pyRDDLGym_jax/__init__.py,sha256=VUmQViJtwUg1JGcgXlmNm0fE3Njyruyt_76c16R-LTo,19
1
+ pyRDDLGym_jax/__init__.py,sha256=wFUjCk0MO0Yrqazz3Sl7bIWJNarnZ6DdSmyPNcX43ek,19
2
2
  pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
3
3
  pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- pyRDDLGym_jax/core/compiler.py,sha256=Bpgfw4nqRFqiTju7ioR0B0Dhp3wMvk-9LmTRpMmLIOc,83457
4
+ pyRDDLGym_jax/core/compiler.py,sha256=DS4G5f5U83cOUQsUe6RsyyJnLPDuHaqjxM7bHSWMCtM,88040
5
5
  pyRDDLGym_jax/core/logic.py,sha256=9rRpKJCx4Us_2c6BiSWRN9k2sM_iYsAK1B7zcgwu3ZA,56290
6
6
  pyRDDLGym_jax/core/model.py,sha256=4WfmtUVN1EKCD-7eWeQByWk8_zKyDcMABAMdlxN1LOU,27215
7
- pyRDDLGym_jax/core/planner.py,sha256=a684ss5TAkJ-P2SEbZA90FSpDwFxHwRoaLtbRIBspAA,146450
8
- pyRDDLGym_jax/core/simulator.py,sha256=ayCATTUL3clLaZPQ5OUg2bI_c26KKCTq6TbrxbMsVdc,10470
7
+ pyRDDLGym_jax/core/planner.py,sha256=cvl3JS1tLQqj8KJ5ATkHUfIzCzcYJWOCoWJYwLxMDSg,146835
8
+ pyRDDLGym_jax/core/simulator.py,sha256=D-yLxDFw67DvFHdb_kJjZHujSBSmiFA1J3osel-KOvY,10799
9
9
  pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
10
10
  pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
11
11
  pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -42,9 +42,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
42
42
  pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
43
43
  pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
44
44
  pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
45
- pyrddlgym_jax-2.6.dist-info/licenses/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
46
- pyrddlgym_jax-2.6.dist-info/METADATA,sha256=1gY3EPRHKMVeZYYgq4DCqWvw3Q1Ak5XVYRaIO2UlQXc,16770
47
- pyrddlgym_jax-2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
48
- pyrddlgym_jax-2.6.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
49
- pyrddlgym_jax-2.6.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
50
- pyrddlgym_jax-2.6.dist-info/RECORD,,
45
+ pyrddlgym_jax-2.8.dist-info/licenses/LICENSE,sha256=2a-BZEY7aEZW-DkmmOQsuUDU0pc6ovQy3QnYFZ4baq4,1095
46
+ pyrddlgym_jax-2.8.dist-info/METADATA,sha256=WKXqbnUZX508HqaTJ1LVAhENd3A1zFsZFYnSk2dONFo,16770
47
+ pyrddlgym_jax-2.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
48
+ pyrddlgym_jax-2.8.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
49
+ pyrddlgym_jax-2.8.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
50
+ pyrddlgym_jax-2.8.dist-info/RECORD,,
@@ -1,6 +1,6 @@
1
1
  MIT License
2
2
 
3
- Copyright (c) 2024 pyrddlgym-project
3
+ Copyright (c) 2025 pyrddlgym-project
4
4
 
5
5
  Permission is hereby granted, free of charge, to any person obtaining a copy
6
6
  of this software and associated documentation files (the "Software"), to deal