pyRDDLGym-jax 2.5__py3-none-any.whl → 2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +107 -11
- pyRDDLGym_jax/core/logic.py +6 -8
- pyRDDLGym_jax/core/model.py +595 -0
- pyRDDLGym_jax/core/planner.py +183 -24
- pyRDDLGym_jax/core/simulator.py +12 -4
- pyRDDLGym_jax/examples/run_plan.py +31 -0
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/METADATA +5 -13
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/RECORD +13 -12
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/licenses/LICENSE +1 -1
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/WHEEL +0 -0
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -207,6 +207,13 @@ def _load_config(config, args):
         pgpe_kwargs['optimizer'] = pgpe_optimizer
         planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
 
+    # preprocessor settings
+    preproc_method = planner_args.get('preprocessor', None)
+    preproc_kwargs = planner_args.pop('preprocessor_kwargs', {})
+    if preproc_method is not None:
+        planner_args['preprocessor'] = getattr(
+            sys.modules[__name__], preproc_method)(**preproc_kwargs)
+
     # optimize call RNG key
     planner_key = train_args.get('key', None)
    if planner_key is not None:
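The `_load_config` change above means a preprocessor can be selected by class name directly from the planner config file, with constructor arguments passed through `preprocessor_kwargs`. A minimal sketch of such an entry, assuming the standard JaxPlan config layout (the `[Optimizer]` section name and the other keys shown are assumptions based on the existing config format, not part of this diff):

```ini
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [64, 64]}
# assumed usage: the value must name a Preprocessor subclass defined in
# pyRDDLGym_jax.core.planner, e.g. the new StaticNormalizer
preprocessor='StaticNormalizer'
preprocessor_kwargs={}
```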
@@ -343,6 +350,100 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         return arg
 
 
+# ***********************************************************************
+# ALL VERSIONS OF STATE PREPROCESSING FOR DRP
+#
+# - static normalization
+#
+# ***********************************************************************
+
+
+class Preprocessor(metaclass=ABCMeta):
+    '''Base class for all state preprocessors.'''
+
+    HYPERPARAMS_KEY = 'preprocessor__'
+
+    def __init__(self) -> None:
+        self._initializer = None
+        self._update = None
+        self._transform = None
+
+    @property
+    def initialize(self):
+        return self._initializer
+
+    @property
+    def update(self):
+        return self._update
+
+    @property
+    def transform(self):
+        return self._transform
+
+    @abstractmethod
+    def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
+        pass
+
+
+class StaticNormalizer(Preprocessor):
+    '''Normalize values by box constraints on fluents computed from the RDDL domain.'''
+
+    def __init__(self, fluent_bounds: Dict[str, Tuple[np.ndarray, np.ndarray]]={}) -> None:
+        '''Create a new instance of the static normalizer.
+
+        :param fluent_bounds: optional bounds on fluents to overwrite default values.
+        '''
+        self.fluent_bounds = fluent_bounds
+
+    def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
+
+        # adjust for partial observability
+        rddl = compiled.rddl
+        if rddl.observ_fluents:
+            observed_vars = rddl.observ_fluents
+        else:
+            observed_vars = rddl.state_fluents
+
+        # ignore boolean fluents and infinite bounds
+        bounded_vars = {}
+        for var in observed_vars:
+            if rddl.variable_ranges[var] != 'bool':
+                lower, upper = compiled.constraints.bounds[var]
+                if np.all(np.isfinite(lower) & np.isfinite(upper) & (lower < upper)):
+                    bounded_vars[var] = (lower, upper)
+                user_bounds = self.fluent_bounds.get(var, None)
+                if user_bounds is not None:
+                    bounded_vars[var] = tuple(user_bounds)
+
+        # initialize to ranges computed by the constraint parser
+        def _jax_wrapped_normalizer_init():
+            return bounded_vars
+        self._initializer = jax.jit(_jax_wrapped_normalizer_init)
+
+        # static bounds
+        def _jax_wrapped_normalizer_update(subs, stats):
+            stats = {var: (jnp.asarray(lower, dtype=compiled.REAL),
+                           jnp.asarray(upper, dtype=compiled.REAL))
+                     for (var, (lower, upper)) in bounded_vars.items()}
+            return stats
+        self._update = jax.jit(_jax_wrapped_normalizer_update)
+
+        # apply min max scaling
+        def _jax_wrapped_normalizer_transform(subs, stats):
+            new_subs = {}
+            for (var, values) in subs.items():
+                if var in stats:
+                    lower, upper = stats[var]
+                    new_dims = jnp.ndim(values) - jnp.ndim(lower)
+                    lower = lower[(jnp.newaxis,) * new_dims + (...,)]
+                    upper = upper[(jnp.newaxis,) * new_dims + (...,)]
+                    new_subs[var] = (values - lower) / (upper - lower)
+                else:
+                    new_subs[var] = values
+            return new_subs
+        self._transform = jax.jit(_jax_wrapped_normalizer_transform)
+
+
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
 #
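The new `StaticNormalizer` min-max scales every bounded, non-boolean observed fluent using `(value - lower) / (upper - lower)`, with bounds taken from the RDDL box constraints or overridden per fluent via `fluent_bounds`. A rough usage sketch in Python, assuming a vectorized pyRDDLGym environment; the domain, instance, and training keyword arguments below are placeholders rather than values taken from this diff:

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxDeepReactivePolicy, JaxOfflineController, StaticNormalizer
)

# placeholder domain/instance; any vectorized pyRDDLGym environment should work
env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)

# scale observed fluents by their box constraints before they reach the policy network
planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxDeepReactivePolicy(topology=[64, 64]),
    preprocessor=StaticNormalizer())
controller = JaxOfflineController(planner, epochs=1000, train_seconds=30)
controller.evaluate(env, episodes=1)
```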
@@ -368,7 +469,8 @@ class JaxPlan(metaclass=ABCMeta):
     @abstractmethod
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
-                horizon: int
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
         pass
 
     @abstractmethod
@@ -519,7 +621,8 @@ class JaxStraightLinePlan(JaxPlan):
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
-                horizon: int
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
         rddl = compiled.rddl
 
         # calculate the correct action box bounds
@@ -607,7 +710,7 @@ class JaxStraightLinePlan(JaxPlan):
             return new_params, True
 
         # convert softmax action back to action dict
-        action_sizes = {var: np.prod(shape[1:], dtype=
+        action_sizes = {var: np.prod(shape[1:], dtype=np.int64)
                         for (var, shape) in shapes.items()
                         if ranges[var] == 'bool'}
 
@@ -691,7 +794,7 @@ class JaxStraightLinePlan(JaxPlan):
             scores = []
             for (var, param) in params.items():
                 if ranges[var] == 'bool':
-                    param_flat = jnp.ravel(param)
+                    param_flat = jnp.ravel(param, order='C')
                     if noop[var]:
                         if wrap_sigmoid:
                             param_flat = -param_flat
@@ -908,7 +1011,8 @@ class JaxDeepReactivePolicy(JaxPlan):
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
-                horizon: int
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
         rddl = compiled.rddl
 
         # calculate the correct action box bounds
@@ -939,7 +1043,7 @@ class JaxDeepReactivePolicy(JaxPlan):
         wrap_non_bool = self._wrap_non_bool
         init = self._initializer
         layers = list(enumerate(zip(self._topology, self._activations)))
-        layer_sizes = {var: np.prod(shape, dtype=
+        layer_sizes = {var: np.prod(shape, dtype=np.int64)
                        for (var, shape) in shapes.items()}
         layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}
 
@@ -973,7 +1077,12 @@ class JaxDeepReactivePolicy(JaxPlan):
            normalize = False
 
         # convert subs dictionary into a state vector to feed to the MLP
-        def _jax_wrapped_policy_input(subs):
+        def _jax_wrapped_policy_input(subs, hyperparams):
+
+            # optional state preprocessing
+            if preprocessor is not None:
+                stats = hyperparams[preprocessor.HYPERPARAMS_KEY]
+                subs = preprocessor.transform(subs, stats)
 
             # concatenate all state variables into a single vector
             # optionally apply layer norm to each input tensor
@@ -981,7 +1090,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             non_bool_dims = 0
             for (var, value) in subs.items():
                 if var in observed_vars:
-                    state = jnp.ravel(value)
+                    state = jnp.ravel(value, order='C')
                     if ranges[var] == 'bool':
                         states_bool.append(state)
                     else:
@@ -1010,8 +1119,8 @@ class JaxDeepReactivePolicy(JaxPlan):
             return state
 
         # predict actions from the policy network for current state
-        def _jax_wrapped_policy_network_predict(subs):
-            state = _jax_wrapped_policy_input(subs)
+        def _jax_wrapped_policy_network_predict(subs, hyperparams):
+            state = _jax_wrapped_policy_input(subs, hyperparams)
 
             # feed state vector through hidden layers
             hidden = state
@@ -1076,7 +1185,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
         # train action prediction
         def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
-            actions = predict_fn.apply(params, subs)
+            actions = predict_fn.apply(params, subs, hyperparams)
             if not wrap_non_bool:
                 for (var, action) in actions.items():
                     if var != bool_key and ranges[var] != 'bool':
@@ -1126,7 +1235,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             subs = {var: value[0, ...]
                     for (var, value) in subs.items()
                     if var in observed_vars}
-            params = predict_fn.init(key, subs)
+            params = predict_fn.init(key, subs, hyperparams)
             return params
 
         self.initializer = _jax_wrapped_drp_init
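The hunks above thread an extra `hyperparams` argument through the haiku-transformed policy so that preprocessing statistics can be applied while building the network input. A self-contained sketch of that pattern (not the package's actual code; the function name `policy` and the fluent keys are illustrative only):

```python
import haiku as hk
import jax
import jax.numpy as jnp

PREPROC_KEY = 'preprocessor__'   # mirrors Preprocessor.HYPERPARAMS_KEY

def policy(subs, hyperparams):
    # apply min-max scaling to fluents that have preprocessing statistics
    stats = hyperparams.get(PREPROC_KEY, {})
    scaled = {}
    for (var, value) in subs.items():
        if var in stats:
            lower, upper = stats[var]
            scaled[var] = (value - lower) / (upper - lower)
        else:
            scaled[var] = value
    # flatten and concatenate into a single state vector, then feed a small MLP
    state = jnp.concatenate([jnp.ravel(v) for v in scaled.values()])
    return hk.nets.MLP([32, 2])(state)

policy_fn = hk.transform(policy)
subs = {'pos': jnp.array([0.5, 1.5]), 'vel': jnp.array([0.1, -0.2])}
hyperparams = {PREPROC_KEY: {'pos': (jnp.zeros(2), jnp.full(2, 2.0))}}
params = policy_fn.init(jax.random.PRNGKey(0), subs, hyperparams)
actions = policy_fn.apply(params, None, subs, hyperparams)   # rng not needed here
print(actions.shape)
```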
@@ -1634,12 +1743,21 @@ def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
     return mu - 0.5 * beta * msv
 
 
+@jax.jit
+def sharpe_utility(returns: jnp.ndarray, risk_free: float) -> float:
+    return (jnp.mean(returns) - risk_free) / (jnp.std(returns) + 1e-10)
+
+
+@jax.jit
+def var_utility(returns: jnp.ndarray, alpha: float) -> float:
+    return jnp.percentile(returns, q=100 * alpha)
+
+
 @jax.jit
 def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
     var = jnp.percentile(returns, q=100 * alpha)
     mask = returns <= var
-
-    return jnp.sum(returns * weights)
+    return jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))
 
 
 # set of all currently valid built-in utility functions
@@ -1649,8 +1767,10 @@ UTILITY_LOOKUP = {
     'mean_std': mean_deviation_utility,
     'mean_semivar': mean_semivariance_utility,
     'mean_semidev': mean_semideviation_utility,
+    'sharpe': sharpe_utility,
     'entropic': entropic_utility,
     'exponential': entropic_utility,
+    'var': var_utility,
     'cvar': cvar_utility
 }
 
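The new `sharpe` and `var` entries extend the set of risk-aware utilities that can be selected by name. A hedged sketch of selecting one, assuming the planner's existing `utility` and `utility_kwargs` arguments (those argument names are not shown in this diff):

```python
# select the Sharpe-ratio utility by its UTILITY_LOOKUP key; 'risk_free' is the
# keyword argument of sharpe_utility as defined above
planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxStraightLinePlan(),
    utility='sharpe',
    utility_kwargs={'risk_free': 0.0})
```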
@@ -1689,7 +1809,9 @@ class JaxBackpropPlanner:
                  logger: Optional[Logger]=None,
                  dashboard_viz: Optional[Any]=None,
                  print_warnings: bool=True,
-                 parallel_updates: Optional[int]=None
+                 parallel_updates: Optional[int]=None,
+                 preprocessor: Optional[Preprocessor]=None,
+                 python_functions: Optional[Dict[str, Callable]]=None) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
         (plan) in the given RDDL. Some operations will be converted to their
         differentiable counterparts; the specific operations can be customized
@@ -1731,6 +1853,8 @@ class JaxBackpropPlanner:
             to pass to the dashboard to visualize the policy
         :param print_warnings: whether to print warnings
         :param parallel_updates: how many optimizers to run independently in parallel
+        :param preprocessor: optional preprocessor for state inputs to plan
+        :param python_functions: dictionary of external Python functions to call from RDDL
         '''
         self.rddl = rddl
         self.plan = plan
@@ -1756,7 +1880,11 @@ class JaxBackpropPlanner:
         self.pgpe = pgpe
         self.use_pgpe = pgpe is not None
         self.print_warnings = print_warnings
-
+        self.preprocessor = preprocessor
+        if python_functions is None:
+            python_functions = {}
+        self.python_functions = python_functions
+
         # set optimizer
         try:
             optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
@@ -1881,7 +2009,8 @@ r"""
             f' noise_kwargs ={self.noise_kwargs}\n'
             f' batch_size_train ={self.batch_size_train}\n'
             f' batch_size_test ={self.batch_size_test}\n'
-            f' parallel_updates ={self.parallel_updates}\n'
+            f' parallel_updates ={self.parallel_updates}\n'
+            f' preprocessor ={self.preprocessor}\n')
         result += str(self.plan)
         if self.use_pgpe:
             result += str(self.pgpe)
@@ -1903,7 +2032,8 @@ r"""
             use64bit=self.use64bit,
             cpfs_without_grad=self.cpfs_without_grad,
             compile_non_fluent_exact=self.compile_non_fluent_exact,
-            print_warnings=self.print_warnings
+            print_warnings=self.print_warnings,
+            python_functions=self.python_functions
         )
         self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
 
@@ -1911,16 +2041,22 @@ r"""
         self.test_compiled = JaxRDDLCompiler(
             rddl=rddl,
             logger=self.logger,
-            use64bit=self.use64bit
+            use64bit=self.use64bit,
+            python_functions=self.python_functions
         )
         self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
 
     def _jax_compile_optimizer(self):
 
+        # preprocessor
+        if self.preprocessor is not None:
+            self.preprocessor.compile(self.compiled)
+
         # policy
         self.plan.compile(self.compiled,
                           _bounds=self._action_bounds,
-                          horizon=self.horizon
+                          horizon=self.horizon,
+                          preprocessor=self.preprocessor)
         self.train_policy = jax.jit(self.plan.train_policy)
         self.test_policy = jax.jit(self.plan.test_policy)
 
@@ -1928,14 +2064,16 @@ r"""
         train_rollouts = self.compiled.compile_rollouts(
             policy=self.plan.train_policy,
             n_steps=self.horizon,
-            n_batch=self.batch_size_train
+            n_batch=self.batch_size_train,
+            cache_path_info=self.preprocessor is not None
         )
         self.train_rollouts = train_rollouts
 
         test_rollouts = self.test_compiled.compile_rollouts(
             policy=self.plan.test_policy,
             n_steps=self.horizon,
-            n_batch=self.batch_size_test
+            n_batch=self.batch_size_test,
+            cache_path_info=False
         )
         self.test_rollouts = jax.jit(test_rollouts)
 
@@ -2397,7 +2535,13 @@ r"""
                               f'which could be suboptimal.', 'yellow')
                 print(message)
                 policy_hyperparams[action] = 1.0
-
+
+        # initialize preprocessor
+        preproc_key = None
+        if self.preprocessor is not None:
+            preproc_key = self.preprocessor.HYPERPARAMS_KEY
+            policy_hyperparams[preproc_key] = self.preprocessor.initialize()
+
         # print summary of parameters:
         if print_summary:
             print(self.summarize_system())
@@ -2524,6 +2668,11 @@ r"""
                 subkey, policy_params, policy_hyperparams, train_subs, model_params,
                 opt_state, opt_aux)
 
+            # update the preprocessor
+            if self.preprocessor is not None:
+                policy_hyperparams[preproc_key] = self.preprocessor.update(
+                    train_log['fluents'], policy_hyperparams[preproc_key])
+
             # evaluate
             test_loss, (test_log, model_params_test) = self.test_loss(
                 subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
@@ -2676,6 +2825,7 @@ r"""
                 'model_params': model_params,
                 'progress': progress_percent,
                 'train_log': train_log,
+                'policy_hyperparams': policy_hyperparams,
                 **test_log
             }
 
@@ -2753,7 +2903,8 @@ r"""
 
     def _perform_diagnosis(self, last_iter_improve,
                            train_return, test_return, best_return, grad_norm):
-
+        grad_norms = jax.tree_util.tree_leaves(grad_norm)
+        max_grad_norm = max(grad_norms) if grad_norms else np.nan
         grad_is_zero = np.allclose(max_grad_norm, 0)
 
         # divergence if the solution is not finite
@@ -2895,6 +3046,7 @@ class JaxOfflineController(BaseAgent):
         self.train_on_reset = train_on_reset
         self.train_kwargs = train_kwargs
         self.params_given = params is not None
+        self.hyperparams_given = eval_hyperparams is not None
 
         # load the policy from file
         if not self.train_on_reset and params is not None and isinstance(params, str):
@@ -2908,6 +3060,8 @@ class JaxOfflineController(BaseAgent):
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
             self.callback = callback
             params = callback['best_params']
+            if not self.hyperparams_given:
+                self.eval_hyperparams = callback['policy_hyperparams']
 
             # save the policy
             if save_path is not None:
@@ -2931,6 +3085,8 @@ class JaxOfflineController(BaseAgent):
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
             self.callback = callback
             self.params = callback['best_params']
+            if not self.hyperparams_given:
+                self.eval_hyperparams = callback['policy_hyperparams']
 
 
 class JaxOnlineController(BaseAgent):
@@ -2963,6 +3119,7 @@ class JaxOnlineController(BaseAgent):
             key = random.PRNGKey(round(time.time() * 1000))
         self.key = key
         self.eval_hyperparams = eval_hyperparams
+        self.hyperparams_given = eval_hyperparams is not None
         self.warm_start = warm_start
         self.train_kwargs = train_kwargs
         self.max_attempts = max_attempts
@@ -2987,6 +3144,8 @@ class JaxOnlineController(BaseAgent):
                 key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
             self.callback = callback
             params = callback['best_params']
+            if not self.hyperparams_given:
+                self.eval_hyperparams = callback['policy_hyperparams']
 
             # get the action from the parameters for the current state
             self.key, subkey = random.split(self.key)
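With `policy_hyperparams` now included in the optimizer callback, the controllers can reuse the preprocessor statistics produced during training whenever no `eval_hyperparams` are supplied. A rough sketch of the intended flow; the callback keys come from the diff, while the training keyword arguments are assumptions:

```python
from jax import random

callback = planner.optimize(key=random.PRNGKey(42), epochs=500)
params = callback['best_params']
hyperparams = callback['policy_hyperparams']   # new key added in this release

# passing eval_hyperparams explicitly is optional: if omitted, the controller now
# falls back to the callback's policy_hyperparams automatically
controller = JaxOfflineController(planner, params=params, eval_hyperparams=hyperparams)
controller.evaluate(env, episodes=1)
```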
pyRDDLGym_jax/core/simulator.py
CHANGED
@@ -20,7 +20,7 @@
 
 import time
 import numpy as np
-from typing import Dict, Optional, Union
+from typing import Callable, Dict, Optional, Union
 
 import jax
 
@@ -48,6 +48,7 @@ class JaxRDDLSimulator(RDDLSimulator):
                  logger: Optional[Logger]=None,
                  keep_tensors: bool=False,
                  objects_as_strings: bool=True,
+                 python_functions: Optional[Dict[str, Callable]]=None,
                  **compiler_args) -> None:
         '''Creates a new simulator for the given RDDL model with Jax as a backend.
 
@@ -60,8 +61,9 @@ class JaxRDDLSimulator(RDDLSimulator):
         :param logger: to log information about compilation to file
         :param keep_tensors: whether the sampler takes actions and
             returns state in numpy array form
-        param objects_as_strings: whether to return object values as strings (defaults
+        :param objects_as_strings: whether to return object values as strings (defaults
            to integer indices if False)
+        :param python_functions: dictionary of external Python functions to call from RDDL
         :param **compiler_args: keyword arguments to pass to the Jax compiler
         '''
         if key is None:
@@ -73,7 +75,8 @@ class JaxRDDLSimulator(RDDLSimulator):
         # generate direct sampling with default numpy RNG and operations
         super(JaxRDDLSimulator, self).__init__(
             rddl, logger=logger,
-            keep_tensors=keep_tensors, objects_as_strings=objects_as_strings
+            keep_tensors=keep_tensors, objects_as_strings=objects_as_strings,
+            python_functions=python_functions)
 
     def seed(self, seed: int) -> None:
         super(JaxRDDLSimulator, self).seed(seed)
@@ -83,7 +86,12 @@ class JaxRDDLSimulator(RDDLSimulator):
         rddl = self.rddl
 
         # compilation
-        compiled = JaxRDDLCompiler(
+        compiled = JaxRDDLCompiler(
+            rddl,
+            logger=self.logger,
+            python_functions=self.python_functions,
+            **self.compiler_args
+        )
         compiled.compile(log_jax_expr=True, heading='SIMULATION MODEL')
 
         self.init_values = compiled.init_values
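Both the planner and the simulator now forward a `python_functions` dictionary to `JaxRDDLCompiler`. How such functions are referenced from the RDDL description is not shown in this diff; the sketch below only illustrates the Python-side plumbing with a hypothetical function name:

```python
import jax.numpy as jnp
import pyRDDLGym
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

# hypothetical external function; the RDDL domain must refer to it by whatever
# convention the compiler expects (not documented in this diff)
def my_external_fn(x):
    return jnp.tanh(x)

env = pyRDDLGym.make('CartPole_Continuous_gym', '0')   # placeholder domain/instance
sim = JaxRDDLSimulator(env.model, python_functions={'my_external_fn': my_external_fn})
```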
pyRDDLGym_jax/examples/run_plan.py
CHANGED
@@ -25,6 +25,36 @@ from pyRDDLGym_jax.core.planner import (
     load_config, JaxBackpropPlanner, JaxOfflineController, JaxOnlineController
 )
 
+
+def run_cnn1d():
+    import haiku as hk
+    import jax
+    import jax.numpy as jnp
+
+    class CNN(hk.Module):
+        def __init__(self, name=None):
+            super().__init__(name=name)
+            self.conv1d_layer = hk.Conv1D(
+                output_channels=4,
+                kernel_shape=6,  # Kernel size for 1D convolution
+                padding="SAME",
+                name="conv"
+            )
+
+        def __call__(self, x):
+            return self.conv1d_layer(x)
+
+    # Example usage:
+    key = jax.random.PRNGKey(42)
+    input_data = jnp.ones([1, 4])  # Batch size 1, sequence length 10, 1 input channel
+
+    # Transform the Haiku module into a pure function
+    f = hk.transform(lambda x: CNN()(x))
+    params = f.init(key, input_data)
+    print(params['cnn/~/conv']['w'].shape)
+    print(params['cnn/~/conv']['b'].shape)
+    print(f.apply(params, key, input_data).shape)
+
 
 def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
 
@@ -63,6 +93,7 @@ def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
 
 
 def run_from_args(args):
+    run_cnn1d()
     if len(args) < 3:
         print('python run_plan.py <domain> <instance> <method> [<episodes>]')
         exit(1)
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyRDDLGym-jax
-Version: 2.5
+Version: 2.7
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: pyRDDLGym>=2.
+Requires-Dist: pyRDDLGym>=2.5
 Requires-Dist: tqdm>=4.66
 Requires-Dist: jax>=0.4.12
 Requires-Dist: optax>=0.1.9
@@ -55,7 +55,7 @@ Dynamic: summary
 
 [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
 
-**pyRDDLGym-jax (
+**pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**
 
 Purpose:
 
@@ -84,7 +84,7 @@ and was moved to the individual logic components which have their own unique wei
 
 > [!NOTE]
 > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
-> If you find it is not making
+> If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
 
 ## Installation
 
@@ -220,13 +220,7 @@ controller = JaxOfflineController(planner, **train_args)
 ## JaxPlan Dashboard
 
 Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
-and visualization of the policy or model, and other useful debugging features.
-
-<p align="middle">
-<img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
-</p>
-
-To run the dashboard, add the following entry to your config file:
+and visualization of the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:
 
 ```ini
 ...
@@ -235,8 +229,6 @@ dashboard=True
 ...
 ```
 
-More documentation about this and other new features will be coming soon.
-
 ## Tuning the Planner
 
 A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/RECORD
CHANGED
@@ -1,10 +1,11 @@
-pyRDDLGym_jax/__init__.py,sha256=
+pyRDDLGym_jax/__init__.py,sha256=nHQztRWlKCpxZgvKkxsGQax5-clS2XguHhAvmBZt0sA,19
 pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
 pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/compiler.py,sha256=
-pyRDDLGym_jax/core/logic.py,sha256=
-pyRDDLGym_jax/core/
-pyRDDLGym_jax/core/
+pyRDDLGym_jax/core/compiler.py,sha256=DS4G5f5U83cOUQsUe6RsyyJnLPDuHaqjxM7bHSWMCtM,88040
+pyRDDLGym_jax/core/logic.py,sha256=9rRpKJCx4Us_2c6BiSWRN9k2sM_iYsAK1B7zcgwu3ZA,56290
+pyRDDLGym_jax/core/model.py,sha256=4WfmtUVN1EKCD-7eWeQByWk8_zKyDcMABAMdlxN1LOU,27215
+pyRDDLGym_jax/core/planner.py,sha256=cvl3JS1tLQqj8KJ5ATkHUfIzCzcYJWOCoWJYwLxMDSg,146835
+pyRDDLGym_jax/core/simulator.py,sha256=D-yLxDFw67DvFHdb_kJjZHujSBSmiFA1J3osel-KOvY,10799
 pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
 pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
 pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -12,7 +13,7 @@ pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2G
 pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
 pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
-pyRDDLGym_jax/examples/run_plan.py,sha256=
+pyRDDLGym_jax/examples/run_plan.py,sha256=uScTTUSdwohhaqvmSf9zvOjQn4xZ97qU1xYezZTIIHg,3745
 pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
 pyRDDLGym_jax/examples/run_tune.py,sha256=F5KWgtoCPbf7XHB6HW9LjxarD57U2LvuGdTz67OL1DY,4114
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
@@ -41,9 +42,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
 pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
 pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
 pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
+pyrddlgym_jax-2.7.dist-info/licenses/LICENSE,sha256=2a-BZEY7aEZW-DkmmOQsuUDU0pc6ovQy3QnYFZ4baq4,1095
+pyrddlgym_jax-2.7.dist-info/METADATA,sha256=xN_SB6x-qiC9cj8O0VvF9HIEDpK79i7FQgn8D3og2xQ,16770
+pyrddlgym_jax-2.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+pyrddlgym_jax-2.7.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+pyrddlgym_jax-2.7.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyrddlgym_jax-2.7.dist-info/RECORD,,
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/WHEEL
File without changes
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/entry_points.txt
File without changes
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.7.dist-info}/top_level.txt
File without changes