pyRDDLGym-jax 2.6__tar.gz → 2.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/LICENSE +1 -1
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/PKG-INFO +2 -2
- pyrddlgym_jax-2.8/pyRDDLGym_jax/__init__.py +1 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/core/compiler.py +91 -4
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/core/planner.py +11 -4
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/core/simulator.py +12 -4
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax.egg-info/PKG-INFO +2 -2
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax.egg-info/requires.txt +1 -1
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/setup.py +2 -2
- pyrddlgym_jax-2.6/pyRDDLGym_jax/__init__.py +0 -1
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/README.md +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/core/logic.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/core/model.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/core/tuning.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/core/visualization.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/entry_point.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/__init__.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/run_plan.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/run_tune.py +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyRDDLGym-jax
|
|
3
|
-
Version: 2.6
|
|
3
|
+
Version: 2.8
|
|
4
4
|
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
|
|
5
5
|
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
|
|
6
6
|
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
|
|
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
20
20
|
Requires-Python: >=3.9
|
|
21
21
|
Description-Content-Type: text/markdown
|
|
22
22
|
License-File: LICENSE
|
|
23
|
-
Requires-Dist: pyRDDLGym>=2.
|
|
23
|
+
Requires-Dist: pyRDDLGym>=2.5
|
|
24
24
|
Requires-Dist: tqdm>=4.66
|
|
25
25
|
Requires-Dist: jax>=0.4.12
|
|
26
26
|
Requires-Dist: optax>=0.1.9
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = '2.8'
|
|
@@ -30,7 +30,8 @@ from pyRDDLGym.core.debug.exception import (
|
|
|
30
30
|
print_stack_trace,
|
|
31
31
|
raise_warning,
|
|
32
32
|
RDDLInvalidNumberOfArgumentsError,
|
|
33
|
-
RDDLNotImplementedError
|
|
33
|
+
RDDLNotImplementedError,
|
|
34
|
+
RDDLUndefinedVariableError
|
|
34
35
|
)
|
|
35
36
|
from pyRDDLGym.core.debug.logger import Logger
|
|
36
37
|
from pyRDDLGym.core.simulator import RDDLSimulatorPrecompiled
|
|
@@ -56,7 +57,8 @@ class JaxRDDLCompiler:
|
|
|
56
57
|
allow_synchronous_state: bool=True,
|
|
57
58
|
logger: Optional[Logger]=None,
|
|
58
59
|
use64bit: bool=False,
|
|
59
|
-
compile_non_fluent_exact: bool=True) -> None:
|
|
60
|
+
compile_non_fluent_exact: bool=True,
|
|
61
|
+
python_functions: Optional[Dict[str, Callable]]=None) -> None:
|
|
60
62
|
'''Creates a new RDDL to Jax compiler.
|
|
61
63
|
|
|
62
64
|
:param rddl: the RDDL model to compile into Jax
|
|
@@ -65,7 +67,8 @@ class JaxRDDLCompiler:
|
|
|
65
67
|
:param logger: to log information about compilation to file
|
|
66
68
|
:param use64bit: whether to use 64 bit arithmetic
|
|
67
69
|
:param compile_non_fluent_exact: whether non-fluent expressions
|
|
68
|
-
are always compiled using exact JAX expressions
|
|
70
|
+
are always compiled using exact JAX expressions
|
|
71
|
+
:param python_functions: dictionary of external Python functions to call from RDDL
|
|
69
72
|
'''
|
|
70
73
|
self.rddl = rddl
|
|
71
74
|
self.logger = logger
|
|
@@ -99,11 +102,15 @@ class JaxRDDLCompiler:
|
|
|
99
102
|
self.traced = tracer.trace()
|
|
100
103
|
|
|
101
104
|
# extract the box constraints on actions
|
|
105
|
+
if python_functions is None:
|
|
106
|
+
python_functions = {}
|
|
107
|
+
self.python_functions = python_functions
|
|
102
108
|
simulator = RDDLSimulatorPrecompiled(
|
|
103
109
|
rddl=self.rddl,
|
|
104
110
|
init_values=self.init_values,
|
|
105
111
|
levels=self.levels,
|
|
106
|
-
trace_info=self.traced
|
|
112
|
+
trace_info=self.traced,
|
|
113
|
+
python_functions=python_functions
|
|
107
114
|
)
|
|
108
115
|
constraints = RDDLConstraints(simulator, vectorized=True)
|
|
109
116
|
self.constraints = constraints
|
|
@@ -605,6 +612,8 @@ class JaxRDDLCompiler:
|
|
|
605
612
|
jax_expr = self._jax_aggregation(expr, init_params)
|
|
606
613
|
elif etype == 'func':
|
|
607
614
|
jax_expr = self._jax_functional(expr, init_params)
|
|
615
|
+
elif etype == 'pyfunc':
|
|
616
|
+
jax_expr = self._jax_pyfunc(expr, init_params)
|
|
608
617
|
elif etype == 'control':
|
|
609
618
|
jax_expr = self._jax_control(expr, init_params)
|
|
610
619
|
elif etype == 'randomvar':
|
|
@@ -926,6 +935,84 @@ class JaxRDDLCompiler:
|
|
|
926
935
|
raise RDDLNotImplementedError(
|
|
927
936
|
f'Function {op} is not supported.\n' + print_stack_trace(expr))
|
|
928
937
|
|
|
938
|
+
def _jax_pyfunc(self, expr, init_params):
|
|
939
|
+
NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
|
|
940
|
+
|
|
941
|
+
# get the Python function by name
|
|
942
|
+
_, pyfunc_name = expr.etype
|
|
943
|
+
pyfunc = self.python_functions.get(pyfunc_name)
|
|
944
|
+
if pyfunc is None:
|
|
945
|
+
raise RDDLUndefinedVariableError(
|
|
946
|
+
f'Undefined external Python function <{pyfunc_name}>, '
|
|
947
|
+
f'must be one of {list(self.python_functions.keys())}.\n' +
|
|
948
|
+
print_stack_trace(expr))
|
|
949
|
+
|
|
950
|
+
captured_vars, args = expr.args
|
|
951
|
+
scope_vars = self.traced.cached_objects_in_scope(expr)
|
|
952
|
+
dest_indices = self.traced.cached_sim_info(expr)
|
|
953
|
+
free_vars = [p for p in scope_vars if p[0] not in captured_vars]
|
|
954
|
+
free_dims = self.rddl.object_counts(p for (_, p) in free_vars)
|
|
955
|
+
num_free_vars = len(free_vars)
|
|
956
|
+
captured_types = [t for (p, t) in scope_vars if p in captured_vars]
|
|
957
|
+
require_dims = self.rddl.object_counts(captured_types)
|
|
958
|
+
|
|
959
|
+
# compile the inputs to the function
|
|
960
|
+
jax_inputs = [self._jax(arg, init_params) for arg in args]
|
|
961
|
+
|
|
962
|
+
# compile the function evaluation function
|
|
963
|
+
def _jax_wrapped_external_function(x, params, key):
|
|
964
|
+
|
|
965
|
+
# evaluate inputs to the function
|
|
966
|
+
# first dimensions are non-captured vars in outer scope followed by all the _
|
|
967
|
+
error = NORMAL
|
|
968
|
+
flat_samples = []
|
|
969
|
+
for jax_expr in jax_inputs:
|
|
970
|
+
sample, key, err, params = jax_expr(x, params, key)
|
|
971
|
+
shape = jnp.shape(sample)
|
|
972
|
+
first_dim = 1
|
|
973
|
+
for dim in shape[:num_free_vars]:
|
|
974
|
+
first_dim *= dim
|
|
975
|
+
new_shape = (first_dim,) + shape[num_free_vars:]
|
|
976
|
+
flat_sample = jnp.reshape(sample, new_shape)
|
|
977
|
+
flat_samples.append(flat_sample)
|
|
978
|
+
error |= err
|
|
979
|
+
|
|
980
|
+
# now all the inputs have dimensions equal to (k,) + the number of _ occurences
|
|
981
|
+
# k is the number of possible non-captured object combinations
|
|
982
|
+
# evaluate the function independently for each combination
|
|
983
|
+
# output dimension for each combination is captured variables (n1, n2, ...)
|
|
984
|
+
# so the total dimension of the output array is (k, n1, n2, ...)
|
|
985
|
+
sample = jax.vmap(pyfunc, in_axes=0)(*flat_samples)
|
|
986
|
+
if not isinstance(sample, jnp.ndarray):
|
|
987
|
+
raise ValueError(
|
|
988
|
+
f'Output of external Python function <{pyfunc_name}> '
|
|
989
|
+
f'is not a JAX array.\n' + print_stack_trace(expr))
|
|
990
|
+
|
|
991
|
+
pyfunc_dims = jnp.shape(sample)[1:]
|
|
992
|
+
if len(require_dims) != len(pyfunc_dims):
|
|
993
|
+
raise ValueError(
|
|
994
|
+
f'External Python function <{pyfunc_name}> returned array with '
|
|
995
|
+
f'{len(pyfunc_dims)} dimensions, which does not match the '
|
|
996
|
+
f'number of captured parameter(s) {len(require_dims)}.\n' +
|
|
997
|
+
print_stack_trace(expr))
|
|
998
|
+
for (param, require_dim, actual_dim) in zip(captured_vars, require_dims, pyfunc_dims):
|
|
999
|
+
if require_dim != actual_dim:
|
|
1000
|
+
raise ValueError(
|
|
1001
|
+
f'External Python function <{pyfunc_name}> returned array with '
|
|
1002
|
+
f'{actual_dim} elements for captured parameter <{param}>, '
|
|
1003
|
+
f'which does not match the number of objects {require_dim}.\n' +
|
|
1004
|
+
print_stack_trace(expr))
|
|
1005
|
+
|
|
1006
|
+
# unravel the combinations k back into their original dimensions
|
|
1007
|
+
sample = jnp.reshape(sample, free_dims + pyfunc_dims)
|
|
1008
|
+
|
|
1009
|
+
# rearrange the output dimensions to match the outer scope
|
|
1010
|
+
source_indices = [num_free_vars + i for i in range(len(pyfunc_dims))]
|
|
1011
|
+
sample = jnp.moveaxis(sample, source=source_indices, destination=dest_indices)
|
|
1012
|
+
return sample, key, error, params
|
|
1013
|
+
|
|
1014
|
+
return _jax_wrapped_external_function
|
|
1015
|
+
|
|
929
1016
|
# ===========================================================================
|
|
930
1017
|
# control flow
|
|
931
1018
|
# ===========================================================================
|
|
@@ -1810,7 +1810,8 @@ class JaxBackpropPlanner:
|
|
|
1810
1810
|
dashboard_viz: Optional[Any]=None,
|
|
1811
1811
|
print_warnings: bool=True,
|
|
1812
1812
|
parallel_updates: Optional[int]=None,
|
|
1813
|
-
preprocessor: Optional[Preprocessor]=None) -> None:
|
|
1813
|
+
preprocessor: Optional[Preprocessor]=None,
|
|
1814
|
+
python_functions: Optional[Dict[str, Callable]]=None) -> None:
|
|
1814
1815
|
'''Creates a new gradient-based algorithm for optimizing action sequences
|
|
1815
1816
|
(plan) in the given RDDL. Some operations will be converted to their
|
|
1816
1817
|
differentiable counterparts; the specific operations can be customized
|
|
@@ -1853,6 +1854,7 @@ class JaxBackpropPlanner:
|
|
|
1853
1854
|
:param print_warnings: whether to print warnings
|
|
1854
1855
|
:param parallel_updates: how many optimizers to run independently in parallel
|
|
1855
1856
|
:param preprocessor: optional preprocessor for state inputs to plan
|
|
1857
|
+
:param python_functions: dictionary of external Python functions to call from RDDL
|
|
1856
1858
|
'''
|
|
1857
1859
|
self.rddl = rddl
|
|
1858
1860
|
self.plan = plan
|
|
@@ -1879,7 +1881,10 @@ class JaxBackpropPlanner:
|
|
|
1879
1881
|
self.use_pgpe = pgpe is not None
|
|
1880
1882
|
self.print_warnings = print_warnings
|
|
1881
1883
|
self.preprocessor = preprocessor
|
|
1882
|
-
|
|
1884
|
+
if python_functions is None:
|
|
1885
|
+
python_functions = {}
|
|
1886
|
+
self.python_functions = python_functions
|
|
1887
|
+
|
|
1883
1888
|
# set optimizer
|
|
1884
1889
|
try:
|
|
1885
1890
|
optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
|
|
@@ -2027,7 +2032,8 @@ r"""
|
|
|
2027
2032
|
use64bit=self.use64bit,
|
|
2028
2033
|
cpfs_without_grad=self.cpfs_without_grad,
|
|
2029
2034
|
compile_non_fluent_exact=self.compile_non_fluent_exact,
|
|
2030
|
-
print_warnings=self.print_warnings
|
|
2035
|
+
print_warnings=self.print_warnings,
|
|
2036
|
+
python_functions=self.python_functions
|
|
2031
2037
|
)
|
|
2032
2038
|
self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
|
|
2033
2039
|
|
|
@@ -2035,7 +2041,8 @@ r"""
|
|
|
2035
2041
|
self.test_compiled = JaxRDDLCompiler(
|
|
2036
2042
|
rddl=rddl,
|
|
2037
2043
|
logger=self.logger,
|
|
2038
|
-
use64bit=self.use64bit
|
|
2044
|
+
use64bit=self.use64bit,
|
|
2045
|
+
python_functions=self.python_functions
|
|
2039
2046
|
)
|
|
2040
2047
|
self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
|
|
2041
2048
|
|
|
@@ -20,7 +20,7 @@
|
|
|
20
20
|
|
|
21
21
|
import time
|
|
22
22
|
import numpy as np
|
|
23
|
-
from typing import Dict, Optional, Union
|
|
23
|
+
from typing import Callable, Dict, Optional, Union
|
|
24
24
|
|
|
25
25
|
import jax
|
|
26
26
|
|
|
@@ -48,6 +48,7 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
48
48
|
logger: Optional[Logger]=None,
|
|
49
49
|
keep_tensors: bool=False,
|
|
50
50
|
objects_as_strings: bool=True,
|
|
51
|
+
python_functions: Optional[Dict[str, Callable]]=None,
|
|
51
52
|
**compiler_args) -> None:
|
|
52
53
|
'''Creates a new simulator for the given RDDL model with Jax as a backend.
|
|
53
54
|
|
|
@@ -60,8 +61,9 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
60
61
|
:param logger: to log information about compilation to file
|
|
61
62
|
:param keep_tensors: whether the sampler takes actions and
|
|
62
63
|
returns state in numpy array form
|
|
63
|
-
param objects_as_strings: whether to return object values as strings (defaults
|
|
64
|
+
:param objects_as_strings: whether to return object values as strings (defaults
|
|
64
65
|
to integer indices if False)
|
|
66
|
+
:param python_functions: dictionary of external Python functions to call from RDDL
|
|
65
67
|
:param **compiler_args: keyword arguments to pass to the Jax compiler
|
|
66
68
|
'''
|
|
67
69
|
if key is None:
|
|
@@ -73,7 +75,8 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
73
75
|
# generate direct sampling with default numpy RNG and operations
|
|
74
76
|
super(JaxRDDLSimulator, self).__init__(
|
|
75
77
|
rddl, logger=logger,
|
|
76
|
-
keep_tensors=keep_tensors, objects_as_strings=objects_as_strings)
|
|
78
|
+
keep_tensors=keep_tensors, objects_as_strings=objects_as_strings,
|
|
79
|
+
python_functions=python_functions)
|
|
77
80
|
|
|
78
81
|
def seed(self, seed: int) -> None:
|
|
79
82
|
super(JaxRDDLSimulator, self).seed(seed)
|
|
@@ -83,7 +86,12 @@ class JaxRDDLSimulator(RDDLSimulator):
|
|
|
83
86
|
rddl = self.rddl
|
|
84
87
|
|
|
85
88
|
# compilation
|
|
86
|
-
compiled = JaxRDDLCompiler(
|
|
89
|
+
compiled = JaxRDDLCompiler(
|
|
90
|
+
rddl,
|
|
91
|
+
logger=self.logger,
|
|
92
|
+
python_functions=self.python_functions,
|
|
93
|
+
**self.compiler_args
|
|
94
|
+
)
|
|
87
95
|
compiled.compile(log_jax_expr=True, heading='SIMULATION MODEL')
|
|
88
96
|
|
|
89
97
|
self.init_values = compiled.init_values
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyRDDLGym-jax
|
|
3
|
-
Version: 2.6
|
|
3
|
+
Version: 2.8
|
|
4
4
|
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
|
|
5
5
|
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
|
|
6
6
|
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
|
|
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
20
20
|
Requires-Python: >=3.9
|
|
21
21
|
Description-Content-Type: text/markdown
|
|
22
22
|
License-File: LICENSE
|
|
23
|
-
Requires-Dist: pyRDDLGym>=2.
|
|
23
|
+
Requires-Dist: pyRDDLGym>=2.5
|
|
24
24
|
Requires-Dist: tqdm>=4.66
|
|
25
25
|
Requires-Dist: jax>=0.4.12
|
|
26
26
|
Requires-Dist: optax>=0.1.9
|
|
@@ -19,7 +19,7 @@ long_description = (Path(__file__).parent / "README.md").read_text()
|
|
|
19
19
|
|
|
20
20
|
setup(
|
|
21
21
|
name='pyRDDLGym-jax',
|
|
22
|
-
version='2.6',
|
|
22
|
+
version='2.8',
|
|
23
23
|
author="Michael Gimelfarb, Ayal Taitler, Scott Sanner",
|
|
24
24
|
author_email="mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca",
|
|
25
25
|
description="pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.",
|
|
@@ -29,7 +29,7 @@ setup(
|
|
|
29
29
|
url="https://github.com/pyrddlgym-project/pyRDDLGym-jax",
|
|
30
30
|
packages=find_packages(),
|
|
31
31
|
install_requires=[
|
|
32
|
-
'pyRDDLGym>=2.
|
|
32
|
+
'pyRDDLGym>=2.5',
|
|
33
33
|
'tqdm>=4.66',
|
|
34
34
|
'jax>=0.4.12',
|
|
35
35
|
'optax>=0.1.9',
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = '2.6'
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg
RENAMED
|
File without changes
|
{pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg
RENAMED
|
File without changes
|
{pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg
RENAMED
|
File without changes
|
{pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg
RENAMED
|
File without changes
|
{pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-2.6 → pyrddlgym_jax-2.8}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|