pyRDDLGym-jax 2.5-py3-none-any.whl → 2.7-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry, and is provided for informational purposes only.
@@ -207,6 +207,13 @@ def _load_config(config, args):
207
207
  pgpe_kwargs['optimizer'] = pgpe_optimizer
208
208
  planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
209
209
 
210
+ # preprocessor settings
211
+ preproc_method = planner_args.get('preprocessor', None)
212
+ preproc_kwargs = planner_args.pop('preprocessor_kwargs', {})
213
+ if preproc_method is not None:
214
+ planner_args['preprocessor'] = getattr(
215
+ sys.modules[__name__], preproc_method)(**preproc_kwargs)
216
+
210
217
  # optimize call RNG key
211
218
  planner_key = train_args.get('key', None)
212
219
  if planner_key is not None:
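These keys mirror the existing `pgpe`/`pgpe_kwargs` handling: the config names a preprocessor class defined in the planner module plus a kwargs dictionary used to construct it. A minimal sketch of what that resolves to follows; the key spellings come from the code above, while the assumption that they belong in the `[Optimizer]` section of a .cfg file is mine.

```python
# Sketch only: a planner .cfg would plausibly gain, next to the other
# optimizer keys (section name assumed):
#
#   preprocessor='StaticNormalizer'
#   preprocessor_kwargs={'fluent_bounds': {}}
#
# which _load_config resolves into roughly the following (import path
# assumed valid for 2.7, where StaticNormalizer is defined):
from pyRDDLGym_jax.core.planner import StaticNormalizer

planner_args = {}   # stand-in for the parsed [Optimizer] section
planner_args['preprocessor'] = StaticNormalizer(fluent_bounds={})
```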
@@ -343,6 +350,100 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
343
350
  return arg
344
351
 
345
352
 
353
+ # ***********************************************************************
354
+ # ALL VERSIONS OF STATE PREPROCESSING FOR DRP
355
+ #
356
+ # - static normalization
357
+ #
358
+ # ***********************************************************************
359
+
360
+
361
+ class Preprocessor(metaclass=ABCMeta):
362
+ '''Base class for all state preprocessors.'''
363
+
364
+ HYPERPARAMS_KEY = 'preprocessor__'
365
+
366
+ def __init__(self) -> None:
367
+ self._initializer = None
368
+ self._update = None
369
+ self._transform = None
370
+
371
+ @property
372
+ def initialize(self):
373
+ return self._initializer
374
+
375
+ @property
376
+ def update(self):
377
+ return self._update
378
+
379
+ @property
380
+ def transform(self):
381
+ return self._transform
382
+
383
+ @abstractmethod
384
+ def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
385
+ pass
386
+
387
+
388
+ class StaticNormalizer(Preprocessor):
389
+ '''Normalize values by box constraints on fluents computed from the RDDL domain.'''
390
+
391
+ def __init__(self, fluent_bounds: Dict[str, Tuple[np.ndarray, np.ndarray]]={}) -> None:
392
+ '''Create a new instance of the static normalizer.
393
+
394
+ :param fluent_bounds: optional bounds on fluents to overwrite default values.
395
+ '''
396
+ self.fluent_bounds = fluent_bounds
397
+
398
+ def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
399
+
400
+ # adjust for partial observability
401
+ rddl = compiled.rddl
402
+ if rddl.observ_fluents:
403
+ observed_vars = rddl.observ_fluents
404
+ else:
405
+ observed_vars = rddl.state_fluents
406
+
407
+ # ignore boolean fluents and infinite bounds
408
+ bounded_vars = {}
409
+ for var in observed_vars:
410
+ if rddl.variable_ranges[var] != 'bool':
411
+ lower, upper = compiled.constraints.bounds[var]
412
+ if np.all(np.isfinite(lower) & np.isfinite(upper) & (lower < upper)):
413
+ bounded_vars[var] = (lower, upper)
414
+ user_bounds = self.fluent_bounds.get(var, None)
415
+ if user_bounds is not None:
416
+ bounded_vars[var] = tuple(user_bounds)
417
+
418
+ # initialize to ranges computed by the constraint parser
419
+ def _jax_wrapped_normalizer_init():
420
+ return bounded_vars
421
+ self._initializer = jax.jit(_jax_wrapped_normalizer_init)
422
+
423
+ # static bounds
424
+ def _jax_wrapped_normalizer_update(subs, stats):
425
+ stats = {var: (jnp.asarray(lower, dtype=compiled.REAL),
426
+ jnp.asarray(upper, dtype=compiled.REAL))
427
+ for (var, (lower, upper)) in bounded_vars.items()}
428
+ return stats
429
+ self._update = jax.jit(_jax_wrapped_normalizer_update)
430
+
431
+ # apply min max scaling
432
+ def _jax_wrapped_normalizer_transform(subs, stats):
433
+ new_subs = {}
434
+ for (var, values) in subs.items():
435
+ if var in stats:
436
+ lower, upper = stats[var]
437
+ new_dims = jnp.ndim(values) - jnp.ndim(lower)
438
+ lower = lower[(jnp.newaxis,) * new_dims + (...,)]
439
+ upper = upper[(jnp.newaxis,) * new_dims + (...,)]
440
+ new_subs[var] = (values - lower) / (upper - lower)
441
+ else:
442
+ new_subs[var] = values
443
+ return new_subs
444
+ self._transform = jax.jit(_jax_wrapped_normalizer_transform)
445
+
446
+
346
447
  # ***********************************************************************
347
448
  # ALL VERSIONS OF JAX PLANS
348
449
  #
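Besides the config route handled in `_load_config` above, the normalizer can be wired in directly from Python via the planner's new `preprocessor` argument introduced later in this diff. A minimal sketch, assuming the import paths below; the domain and instance names are borrowed from the example configs shipped with the package, and the other argument values are illustrative rather than defaults:

```python
# Minimal usage sketch (not a verbatim excerpt from the package examples).
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxDeepReactivePolicy, StaticNormalizer
)

env = pyRDDLGym.make('Cartpole_Continuous_gym', '0', vectorized=True)

# bounds default to those computed by the constraint parser; pass
# fluent_bounds={...} to overwrite them for specific fluents
preproc = StaticNormalizer()

planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxDeepReactivePolicy(topology=[64, 64]),
    preprocessor=preproc       # new in 2.7: min-max scales the DRP inputs
)
```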
@@ -368,7 +469,8 @@ class JaxPlan(metaclass=ABCMeta):
368
469
  @abstractmethod
369
470
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
370
471
  _bounds: Bounds,
371
- horizon: int) -> None:
472
+ horizon: int,
473
+ preprocessor: Optional[Preprocessor]=None) -> None:
372
474
  pass
373
475
 
374
476
  @abstractmethod
@@ -519,7 +621,8 @@ class JaxStraightLinePlan(JaxPlan):
519
621
 
520
622
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
521
623
  _bounds: Bounds,
522
- horizon: int) -> None:
624
+ horizon: int,
625
+ preprocessor: Optional[Preprocessor]=None) -> None:
523
626
  rddl = compiled.rddl
524
627
 
525
628
  # calculate the correct action box bounds
@@ -607,7 +710,7 @@ class JaxStraightLinePlan(JaxPlan):
607
710
  return new_params, True
608
711
 
609
712
  # convert softmax action back to action dict
610
- action_sizes = {var: np.prod(shape[1:], dtype=int)
713
+ action_sizes = {var: np.prod(shape[1:], dtype=np.int64)
611
714
  for (var, shape) in shapes.items()
612
715
  if ranges[var] == 'bool'}
613
716
 
@@ -691,7 +794,7 @@ class JaxStraightLinePlan(JaxPlan):
691
794
  scores = []
692
795
  for (var, param) in params.items():
693
796
  if ranges[var] == 'bool':
694
- param_flat = jnp.ravel(param)
797
+ param_flat = jnp.ravel(param, order='C')
695
798
  if noop[var]:
696
799
  if wrap_sigmoid:
697
800
  param_flat = -param_flat
@@ -908,7 +1011,8 @@ class JaxDeepReactivePolicy(JaxPlan):
908
1011
 
909
1012
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
910
1013
  _bounds: Bounds,
911
- horizon: int) -> None:
1014
+ horizon: int,
1015
+ preprocessor: Optional[Preprocessor]=None) -> None:
912
1016
  rddl = compiled.rddl
913
1017
 
914
1018
  # calculate the correct action box bounds
@@ -939,7 +1043,7 @@ class JaxDeepReactivePolicy(JaxPlan):
939
1043
  wrap_non_bool = self._wrap_non_bool
940
1044
  init = self._initializer
941
1045
  layers = list(enumerate(zip(self._topology, self._activations)))
942
- layer_sizes = {var: np.prod(shape, dtype=int)
1046
+ layer_sizes = {var: np.prod(shape, dtype=np.int64)
943
1047
  for (var, shape) in shapes.items()}
944
1048
  layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}
945
1049
 
@@ -973,7 +1077,12 @@ class JaxDeepReactivePolicy(JaxPlan):
973
1077
  normalize = False
974
1078
 
975
1079
  # convert subs dictionary into a state vector to feed to the MLP
976
- def _jax_wrapped_policy_input(subs):
1080
+ def _jax_wrapped_policy_input(subs, hyperparams):
1081
+
1082
+ # optional state preprocessing
1083
+ if preprocessor is not None:
1084
+ stats = hyperparams[preprocessor.HYPERPARAMS_KEY]
1085
+ subs = preprocessor.transform(subs, stats)
977
1086
 
978
1087
  # concatenate all state variables into a single vector
979
1088
  # optionally apply layer norm to each input tensor
@@ -981,7 +1090,7 @@ class JaxDeepReactivePolicy(JaxPlan):
981
1090
  non_bool_dims = 0
982
1091
  for (var, value) in subs.items():
983
1092
  if var in observed_vars:
984
- state = jnp.ravel(value)
1093
+ state = jnp.ravel(value, order='C')
985
1094
  if ranges[var] == 'bool':
986
1095
  states_bool.append(state)
987
1096
  else:
@@ -1010,8 +1119,8 @@ class JaxDeepReactivePolicy(JaxPlan):
1010
1119
  return state
1011
1120
 
1012
1121
  # predict actions from the policy network for current state
1013
- def _jax_wrapped_policy_network_predict(subs):
1014
- state = _jax_wrapped_policy_input(subs)
1122
+ def _jax_wrapped_policy_network_predict(subs, hyperparams):
1123
+ state = _jax_wrapped_policy_input(subs, hyperparams)
1015
1124
 
1016
1125
  # feed state vector through hidden layers
1017
1126
  hidden = state
@@ -1076,7 +1185,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1076
1185
 
1077
1186
  # train action prediction
1078
1187
  def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
1079
- actions = predict_fn.apply(params, subs)
1188
+ actions = predict_fn.apply(params, subs, hyperparams)
1080
1189
  if not wrap_non_bool:
1081
1190
  for (var, action) in actions.items():
1082
1191
  if var != bool_key and ranges[var] != 'bool':
@@ -1126,7 +1235,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1126
1235
  subs = {var: value[0, ...]
1127
1236
  for (var, value) in subs.items()
1128
1237
  if var in observed_vars}
1129
- params = predict_fn.init(key, subs)
1238
+ params = predict_fn.init(key, subs, hyperparams)
1130
1239
  return params
1131
1240
 
1132
1241
  self.initializer = _jax_wrapped_drp_init
@@ -1634,12 +1743,21 @@ def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
1634
1743
  return mu - 0.5 * beta * msv
1635
1744
 
1636
1745
 
1746
+ @jax.jit
1747
+ def sharpe_utility(returns: jnp.ndarray, risk_free: float) -> float:
1748
+ return (jnp.mean(returns) - risk_free) / (jnp.std(returns) + 1e-10)
1749
+
1750
+
1751
+ @jax.jit
1752
+ def var_utility(returns: jnp.ndarray, alpha: float) -> float:
1753
+ return jnp.percentile(returns, q=100 * alpha)
1754
+
1755
+
1637
1756
  @jax.jit
1638
1757
  def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
1639
1758
  var = jnp.percentile(returns, q=100 * alpha)
1640
1759
  mask = returns <= var
1641
- weights = mask / jnp.maximum(1, jnp.sum(mask))
1642
- return jnp.sum(returns * weights)
1760
+ return jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))
1643
1761
 
1644
1762
 
1645
1763
  # set of all currently valid built-in utility functions
@@ -1649,8 +1767,10 @@ UTILITY_LOOKUP = {
1649
1767
  'mean_std': mean_deviation_utility,
1650
1768
  'mean_semivar': mean_semivariance_utility,
1651
1769
  'mean_semidev': mean_semideviation_utility,
1770
+ 'sharpe': sharpe_utility,
1652
1771
  'entropic': entropic_utility,
1653
1772
  'exponential': entropic_utility,
1773
+ 'var': var_utility,
1654
1774
  'cvar': cvar_utility
1655
1775
  }
1656
1776
 
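The two new utilities are standard risk measures: `sharpe` is the mean return minus a risk-free rate, divided by the return standard deviation, and `var` is the empirical alpha-quantile of the return distribution (the rewritten `cvar` is algebraically unchanged, it just avoids the intermediate weights array). A quick numeric sanity check, re-declaring the functions exactly as in this diff:

```python
# Self-contained sanity check of the new utility functions (declarations
# copied from the diff above; the printed values are approximate).
import jax
import jax.numpy as jnp

@jax.jit
def sharpe_utility(returns: jnp.ndarray, risk_free: float) -> float:
    return (jnp.mean(returns) - risk_free) / (jnp.std(returns) + 1e-10)

@jax.jit
def var_utility(returns: jnp.ndarray, alpha: float) -> float:
    return jnp.percentile(returns, q=100 * alpha)

returns = jnp.array([-3.0, -1.0, 0.0, 2.0, 4.0])
print(sharpe_utility(returns, 0.0))   # mean 0.4 / std 2.42 ~= 0.17
print(var_utility(returns, 0.2))      # 20th percentile ~= -1.4
```

Assuming the planner's utility option accepts the lookup keys above, these would be selected as `'sharpe'` or `'var'`; that wiring is outside this diff.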
@@ -1689,7 +1809,9 @@ class JaxBackpropPlanner:
1689
1809
  logger: Optional[Logger]=None,
1690
1810
  dashboard_viz: Optional[Any]=None,
1691
1811
  print_warnings: bool=True,
1692
- parallel_updates: Optional[int]=None) -> None:
1812
+ parallel_updates: Optional[int]=None,
1813
+ preprocessor: Optional[Preprocessor]=None,
1814
+ python_functions: Optional[Dict[str, Callable]]=None) -> None:
1693
1815
  '''Creates a new gradient-based algorithm for optimizing action sequences
1694
1816
  (plan) in the given RDDL. Some operations will be converted to their
1695
1817
  differentiable counterparts; the specific operations can be customized
@@ -1731,6 +1853,8 @@ class JaxBackpropPlanner:
1731
1853
  to pass to the dashboard to visualize the policy
1732
1854
  :param print_warnings: whether to print warnings
1733
1855
  :param parallel_updates: how many optimizers to run independently in parallel
1856
+ :param preprocessor: optional preprocessor for state inputs to plan
1857
+ :param python_functions: dictionary of external Python functions to call from RDDL
1734
1858
  '''
1735
1859
  self.rddl = rddl
1736
1860
  self.plan = plan
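The `python_functions` argument is only stored here and forwarded to the compilers further down; how such functions are referenced from RDDL expressions lives in `compiler.py`, which is not part of this excerpt. A construction sketch, with the function name and body purely illustrative:

```python
# Sketch only: registering an external Python function with the planner.
# The key and callable are hypothetical; presumably the callable should
# use jax.numpy ops so the relaxed model remains differentiable.
import jax.numpy as jnp
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxStraightLinePlan

def clip_demand(x):
    return jnp.clip(x, 0.0, 100.0)

env = pyRDDLGym.make('Cartpole_Continuous_gym', '0', vectorized=True)
planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxStraightLinePlan(),
    python_functions={'clip_demand': clip_demand}   # forwarded to both compilers
)
```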
@@ -1756,7 +1880,11 @@ class JaxBackpropPlanner:
1756
1880
  self.pgpe = pgpe
1757
1881
  self.use_pgpe = pgpe is not None
1758
1882
  self.print_warnings = print_warnings
1759
-
1883
+ self.preprocessor = preprocessor
1884
+ if python_functions is None:
1885
+ python_functions = {}
1886
+ self.python_functions = python_functions
1887
+
1760
1888
  # set optimizer
1761
1889
  try:
1762
1890
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
@@ -1881,7 +2009,8 @@ r"""
1881
2009
  f' noise_kwargs ={self.noise_kwargs}\n'
1882
2010
  f' batch_size_train ={self.batch_size_train}\n'
1883
2011
  f' batch_size_test ={self.batch_size_test}\n'
1884
- f' parallel_updates ={self.parallel_updates}\n')
2012
+ f' parallel_updates ={self.parallel_updates}\n'
2013
+ f' preprocessor ={self.preprocessor}\n')
1885
2014
  result += str(self.plan)
1886
2015
  if self.use_pgpe:
1887
2016
  result += str(self.pgpe)
@@ -1903,7 +2032,8 @@ r"""
1903
2032
  use64bit=self.use64bit,
1904
2033
  cpfs_without_grad=self.cpfs_without_grad,
1905
2034
  compile_non_fluent_exact=self.compile_non_fluent_exact,
1906
- print_warnings=self.print_warnings
2035
+ print_warnings=self.print_warnings,
2036
+ python_functions=self.python_functions
1907
2037
  )
1908
2038
  self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
1909
2039
 
@@ -1911,16 +2041,22 @@ r"""
1911
2041
  self.test_compiled = JaxRDDLCompiler(
1912
2042
  rddl=rddl,
1913
2043
  logger=self.logger,
1914
- use64bit=self.use64bit
2044
+ use64bit=self.use64bit,
2045
+ python_functions=self.python_functions
1915
2046
  )
1916
2047
  self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
1917
2048
 
1918
2049
  def _jax_compile_optimizer(self):
1919
2050
 
2051
+ # preprocessor
2052
+ if self.preprocessor is not None:
2053
+ self.preprocessor.compile(self.compiled)
2054
+
1920
2055
  # policy
1921
2056
  self.plan.compile(self.compiled,
1922
2057
  _bounds=self._action_bounds,
1923
- horizon=self.horizon)
2058
+ horizon=self.horizon,
2059
+ preprocessor=self.preprocessor)
1924
2060
  self.train_policy = jax.jit(self.plan.train_policy)
1925
2061
  self.test_policy = jax.jit(self.plan.test_policy)
1926
2062
 
@@ -1928,14 +2064,16 @@ r"""
1928
2064
  train_rollouts = self.compiled.compile_rollouts(
1929
2065
  policy=self.plan.train_policy,
1930
2066
  n_steps=self.horizon,
1931
- n_batch=self.batch_size_train
2067
+ n_batch=self.batch_size_train,
2068
+ cache_path_info=self.preprocessor is not None
1932
2069
  )
1933
2070
  self.train_rollouts = train_rollouts
1934
2071
 
1935
2072
  test_rollouts = self.test_compiled.compile_rollouts(
1936
2073
  policy=self.plan.test_policy,
1937
2074
  n_steps=self.horizon,
1938
- n_batch=self.batch_size_test
2075
+ n_batch=self.batch_size_test,
2076
+ cache_path_info=False
1939
2077
  )
1940
2078
  self.test_rollouts = jax.jit(test_rollouts)
1941
2079
 
@@ -2397,7 +2535,13 @@ r"""
2397
2535
  f'which could be suboptimal.', 'yellow')
2398
2536
  print(message)
2399
2537
  policy_hyperparams[action] = 1.0
2400
-
2538
+
2539
+ # initialize preprocessor
2540
+ preproc_key = None
2541
+ if self.preprocessor is not None:
2542
+ preproc_key = self.preprocessor.HYPERPARAMS_KEY
2543
+ policy_hyperparams[preproc_key] = self.preprocessor.initialize()
2544
+
2401
2545
  # print summary of parameters:
2402
2546
  if print_summary:
2403
2547
  print(self.summarize_system())
@@ -2524,6 +2668,11 @@ r"""
2524
2668
  subkey, policy_params, policy_hyperparams, train_subs, model_params,
2525
2669
  opt_state, opt_aux)
2526
2670
 
2671
+ # update the preprocessor
2672
+ if self.preprocessor is not None:
2673
+ policy_hyperparams[preproc_key] = self.preprocessor.update(
2674
+ train_log['fluents'], policy_hyperparams[preproc_key])
2675
+
2527
2676
  # evaluate
2528
2677
  test_loss, (test_log, model_params_test) = self.test_loss(
2529
2678
  subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
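Together with the initialization hunk above, the preprocessor statistics travel through `policy_hyperparams` under the `'preprocessor__'` key: `initialize()` seeds them before training, `update()` refreshes them each iteration from the cached rollout fluents, and the DRP's input wrapper applies `transform()`. A stand-alone toy reproduction of that protocol (no RDDL model is compiled; the fluent name and bounds are made up):

```python
# Toy reproduction of StaticNormalizer's three callables on hand-made
# statistics, showing how they compose through policy_hyperparams.
import jax
import jax.numpy as jnp

bounded_vars = {'temp': (jnp.array([10.0]), jnp.array([30.0]))}   # hypothetical fluent

initialize = jax.jit(lambda: bounded_vars)
update = jax.jit(lambda subs, stats: stats)        # static bounds: refresh is a no-op
def transform(subs, stats):
    lower, upper = stats['temp']
    return {'temp': (subs['temp'] - lower) / (upper - lower)}

policy_hyperparams = {'preprocessor__': initialize()}
policy_hyperparams['preprocessor__'] = update({}, policy_hyperparams['preprocessor__'])
print(transform({'temp': jnp.array([20.0])}, policy_hyperparams['preprocessor__']))
# -> {'temp': Array([0.5], ...)}, i.e. min-max scaled into [0, 1]
```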
@@ -2676,6 +2825,7 @@ r"""
2676
2825
  'model_params': model_params,
2677
2826
  'progress': progress_percent,
2678
2827
  'train_log': train_log,
2828
+ 'policy_hyperparams': policy_hyperparams,
2679
2829
  **test_log
2680
2830
  }
2681
2831
 
@@ -2753,7 +2903,8 @@ r"""
2753
2903
 
2754
2904
  def _perform_diagnosis(self, last_iter_improve,
2755
2905
  train_return, test_return, best_return, grad_norm):
2756
- max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
2906
+ grad_norms = jax.tree_util.tree_leaves(grad_norm)
2907
+ max_grad_norm = max(grad_norms) if grad_norms else np.nan
2757
2908
  grad_is_zero = np.allclose(max_grad_norm, 0)
2758
2909
 
2759
2910
  # divergence if the solution is not finite
@@ -2895,6 +3046,7 @@ class JaxOfflineController(BaseAgent):
2895
3046
  self.train_on_reset = train_on_reset
2896
3047
  self.train_kwargs = train_kwargs
2897
3048
  self.params_given = params is not None
3049
+ self.hyperparams_given = eval_hyperparams is not None
2898
3050
 
2899
3051
  # load the policy from file
2900
3052
  if not self.train_on_reset and params is not None and isinstance(params, str):
@@ -2908,6 +3060,8 @@ class JaxOfflineController(BaseAgent):
2908
3060
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2909
3061
  self.callback = callback
2910
3062
  params = callback['best_params']
3063
+ if not self.hyperparams_given:
3064
+ self.eval_hyperparams = callback['policy_hyperparams']
2911
3065
 
2912
3066
  # save the policy
2913
3067
  if save_path is not None:
@@ -2931,6 +3085,8 @@ class JaxOfflineController(BaseAgent):
2931
3085
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2932
3086
  self.callback = callback
2933
3087
  self.params = callback['best_params']
3088
+ if not self.hyperparams_given:
3089
+ self.eval_hyperparams = callback['policy_hyperparams']
2934
3090
 
2935
3091
 
2936
3092
  class JaxOnlineController(BaseAgent):
@@ -2963,6 +3119,7 @@ class JaxOnlineController(BaseAgent):
2963
3119
  key = random.PRNGKey(round(time.time() * 1000))
2964
3120
  self.key = key
2965
3121
  self.eval_hyperparams = eval_hyperparams
3122
+ self.hyperparams_given = eval_hyperparams is not None
2966
3123
  self.warm_start = warm_start
2967
3124
  self.train_kwargs = train_kwargs
2968
3125
  self.max_attempts = max_attempts
@@ -2987,6 +3144,8 @@ class JaxOnlineController(BaseAgent):
2987
3144
  key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
2988
3145
  self.callback = callback
2989
3146
  params = callback['best_params']
3147
+ if not self.hyperparams_given:
3148
+ self.eval_hyperparams = callback['policy_hyperparams']
2990
3149
 
2991
3150
  # get the action from the parameters for the current state
2992
3151
  self.key, subkey = random.split(self.key)
@@ -20,7 +20,7 @@
20
20
 
21
21
  import time
22
22
  import numpy as np
23
- from typing import Dict, Optional, Union
23
+ from typing import Callable, Dict, Optional, Union
24
24
 
25
25
  import jax
26
26
 
@@ -48,6 +48,7 @@ class JaxRDDLSimulator(RDDLSimulator):
48
48
  logger: Optional[Logger]=None,
49
49
  keep_tensors: bool=False,
50
50
  objects_as_strings: bool=True,
51
+ python_functions: Optional[Dict[str, Callable]]=None,
51
52
  **compiler_args) -> None:
52
53
  '''Creates a new simulator for the given RDDL model with Jax as a backend.
53
54
 
@@ -60,8 +61,9 @@ class JaxRDDLSimulator(RDDLSimulator):
60
61
  :param logger: to log information about compilation to file
61
62
  :param keep_tensors: whether the sampler takes actions and
62
63
  returns state in numpy array form
63
- param objects_as_strings: whether to return object values as strings (defaults
64
+ :param objects_as_strings: whether to return object values as strings (defaults
64
65
  to integer indices if False)
66
+ :param python_functions: dictionary of external Python functions to call from RDDL
65
67
  :param **compiler_args: keyword arguments to pass to the Jax compiler
66
68
  '''
67
69
  if key is None:
@@ -73,7 +75,8 @@ class JaxRDDLSimulator(RDDLSimulator):
73
75
  # generate direct sampling with default numpy RNG and operations
74
76
  super(JaxRDDLSimulator, self).__init__(
75
77
  rddl, logger=logger,
76
- keep_tensors=keep_tensors, objects_as_strings=objects_as_strings)
78
+ keep_tensors=keep_tensors, objects_as_strings=objects_as_strings,
79
+ python_functions=python_functions)
77
80
 
78
81
  def seed(self, seed: int) -> None:
79
82
  super(JaxRDDLSimulator, self).seed(seed)
@@ -83,7 +86,12 @@ class JaxRDDLSimulator(RDDLSimulator):
83
86
  rddl = self.rddl
84
87
 
85
88
  # compilation
86
- compiled = JaxRDDLCompiler(rddl, logger=self.logger, **self.compiler_args)
89
+ compiled = JaxRDDLCompiler(
90
+ rddl,
91
+ logger=self.logger,
92
+ python_functions=self.python_functions,
93
+ **self.compiler_args
94
+ )
87
95
  compiled.compile(log_jax_expr=True, heading='SIMULATION MODEL')
88
96
 
89
97
  self.init_values = compiled.init_values
@@ -25,6 +25,36 @@ from pyRDDLGym_jax.core.planner import (
25
25
  load_config, JaxBackpropPlanner, JaxOfflineController, JaxOnlineController
26
26
  )
27
27
 
28
+
29
+ def run_cnn1d():
30
+ import haiku as hk
31
+ import jax
32
+ import jax.numpy as jnp
33
+
34
+ class CNN(hk.Module):
35
+ def __init__(self, name=None):
36
+ super().__init__(name=name)
37
+ self.conv1d_layer = hk.Conv1D(
38
+ output_channels=4,
39
+ kernel_shape=6, # Kernel size for 1D convolution
40
+ padding="SAME",
41
+ name="conv"
42
+ )
43
+
44
+ def __call__(self, x):
45
+ return self.conv1d_layer(x)
46
+
47
+ # Example usage:
48
+ key = jax.random.PRNGKey(42)
49
+ input_data = jnp.ones([1, 4]) # shape [1, 4]: an unbatched input of length 1 with 4 channels
50
+
51
+ # Transform the Haiku module into a pure function
52
+ f = hk.transform(lambda x: CNN()(x))
53
+ params = f.init(key, input_data)
54
+ print(params['cnn/~/conv']['w'].shape)
55
+ print(params['cnn/~/conv']['b'].shape)
56
+ print(f.apply(params, key, input_data).shape)
57
+
28
58
 
29
59
  def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
30
60
 
@@ -63,6 +93,7 @@ def main(domain: str, instance: str, method: str, episodes: int=1) -> None:
63
93
 
64
94
 
65
95
  def run_from_args(args):
96
+ run_cnn1d()
66
97
  if len(args) < 3:
67
98
  print('python run_plan.py <domain> <instance> <method> [<episodes>]')
68
99
  exit(1)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyRDDLGym-jax
3
- Version: 2.5
3
+ Version: 2.7
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
20
  Requires-Python: >=3.9
21
21
  Description-Content-Type: text/markdown
22
22
  License-File: LICENSE
23
- Requires-Dist: pyRDDLGym>=2.0
23
+ Requires-Dist: pyRDDLGym>=2.5
24
24
  Requires-Dist: tqdm>=4.66
25
25
  Requires-Dist: jax>=0.4.12
26
26
  Requires-Dist: optax>=0.1.9
@@ -55,7 +55,7 @@ Dynamic: summary
55
55
 
56
56
  [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
57
57
 
58
- **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
58
+ **pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**
59
59
 
60
60
  Purpose:
61
61
 
@@ -84,7 +84,7 @@ and was moved to the individual logic components which have their own unique wei
84
84
 
85
85
  > [!NOTE]
86
86
  > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
87
- > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
87
+ > If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
88
88
 
89
89
  ## Installation
90
90
 
@@ -220,13 +220,7 @@ controller = JaxOfflineController(planner, **train_args)
220
220
  ## JaxPlan Dashboard
221
221
 
222
222
  Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
223
- and visualization of the policy or model, and other useful debugging features.
224
-
225
- <p align="middle">
226
- <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
227
- </p>
228
-
229
- To run the dashboard, add the following entry to your config file:
223
+ and visualization of the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:
230
224
 
231
225
  ```ini
232
226
  ...
@@ -235,8 +229,6 @@ dashboard=True
235
229
  ...
236
230
  ```
237
231
 
238
- More documentation about this and other new features will be coming soon.
239
-
240
232
  ## Tuning the Planner
241
233
 
242
234
  A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
@@ -1,10 +1,11 @@
1
- pyRDDLGym_jax/__init__.py,sha256=VoxLo_sy8RlJIIyu7szqL-cdMGBJdQPg-aSeyOVVIkY,19
1
+ pyRDDLGym_jax/__init__.py,sha256=nHQztRWlKCpxZgvKkxsGQax5-clS2XguHhAvmBZt0sA,19
2
2
  pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
3
3
  pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- pyRDDLGym_jax/core/compiler.py,sha256=uFCtoipsIa3MM9nGgT3X8iCViPl2XSPNXh0jMdzN0ko,82895
5
- pyRDDLGym_jax/core/logic.py,sha256=lfc2ak_ap_ajMEFlB5EHCRNgJym31dNyA-5d-7N4CZA,56271
6
- pyRDDLGym_jax/core/planner.py,sha256=M6GKzN7Ml57B4ZrFZhhkpsQCvReKaCQNzer7zeHCM9E,140275
7
- pyRDDLGym_jax/core/simulator.py,sha256=ayCATTUL3clLaZPQ5OUg2bI_c26KKCTq6TbrxbMsVdc,10470
4
+ pyRDDLGym_jax/core/compiler.py,sha256=DS4G5f5U83cOUQsUe6RsyyJnLPDuHaqjxM7bHSWMCtM,88040
5
+ pyRDDLGym_jax/core/logic.py,sha256=9rRpKJCx4Us_2c6BiSWRN9k2sM_iYsAK1B7zcgwu3ZA,56290
6
+ pyRDDLGym_jax/core/model.py,sha256=4WfmtUVN1EKCD-7eWeQByWk8_zKyDcMABAMdlxN1LOU,27215
7
+ pyRDDLGym_jax/core/planner.py,sha256=cvl3JS1tLQqj8KJ5ATkHUfIzCzcYJWOCoWJYwLxMDSg,146835
8
+ pyRDDLGym_jax/core/simulator.py,sha256=D-yLxDFw67DvFHdb_kJjZHujSBSmiFA1J3osel-KOvY,10799
8
9
  pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
9
10
  pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
10
11
  pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -12,7 +13,7 @@ pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2G
12
13
  pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
14
  pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
14
15
  pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
15
- pyRDDLGym_jax/examples/run_plan.py,sha256=4y7JHqTxY5O1ltP6N7rar0jMiw7u9w1nuAIOcmDaAuE,2806
16
+ pyRDDLGym_jax/examples/run_plan.py,sha256=uScTTUSdwohhaqvmSf9zvOjQn4xZ97qU1xYezZTIIHg,3745
16
17
  pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
17
18
  pyRDDLGym_jax/examples/run_tune.py,sha256=F5KWgtoCPbf7XHB6HW9LjxarD57U2LvuGdTz67OL1DY,4114
18
19
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
@@ -41,9 +42,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
41
42
  pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
42
43
  pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
43
44
  pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
44
- pyrddlgym_jax-2.5.dist-info/licenses/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
45
- pyrddlgym_jax-2.5.dist-info/METADATA,sha256=XAaEJfbsYW-txxZhFZ6o_HmvqxkIMTqBF9LbV-KdTzI,17058
46
- pyrddlgym_jax-2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
47
- pyrddlgym_jax-2.5.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
48
- pyrddlgym_jax-2.5.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
49
- pyrddlgym_jax-2.5.dist-info/RECORD,,
45
+ pyrddlgym_jax-2.7.dist-info/licenses/LICENSE,sha256=2a-BZEY7aEZW-DkmmOQsuUDU0pc6ovQy3QnYFZ4baq4,1095
46
+ pyrddlgym_jax-2.7.dist-info/METADATA,sha256=xN_SB6x-qiC9cj8O0VvF9HIEDpK79i7FQgn8D3og2xQ,16770
47
+ pyrddlgym_jax-2.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
48
+ pyrddlgym_jax-2.7.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
49
+ pyrddlgym_jax-2.7.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
50
+ pyrddlgym_jax-2.7.dist-info/RECORD,,
@@ -1,6 +1,6 @@
1
1
  MIT License
2
2
 
3
- Copyright (c) 2024 pyrddlgym-project
3
+ Copyright (c) 2025 pyrddlgym-project
4
4
 
5
5
  Permission is hereby granted, free of charge, to any person obtaining a copy
6
6
  of this software and associated documentation files (the "Software"), to deal