pyRDDLGym-jax 2.4 → 2.5 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyRDDLGym_jax/__init__.py CHANGED
@@ -1 +1 @@
- __version__ = '2.4'
+ __version__ = '2.5'
@@ -430,7 +430,7 @@ class JaxRDDLCompiler:
  _jax_wrapped_single_step_policy,
  in_axes=(0, None, None, None, 0, None)
  )(keys, policy_params, hyperparams, step, subs, model_params)
- model_params = jax.tree_map(partial(jnp.mean, axis=0), model_params)
+ model_params = jax.tree_util.tree_map(partial(jnp.mean, axis=0), model_params)
  carry = (key, policy_params, hyperparams, subs, model_params)
  return carry, log
@@ -440,7 +440,7 @@ class JaxRDDLCompiler:
  start = (key, policy_params, hyperparams, subs, model_params)
  steps = jnp.arange(n_steps)
  end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
- log = jax.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
+ log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
  model_params = end[-1]
  return log, model_params
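Note: a large share of the hunks in this release replace the deprecated alias `jax.tree_map` with the fully qualified `jax.tree_util.tree_map`, which is the stable spelling in newer JAX releases. A minimal sketch of the equivalent call on an illustrative pytree (the `params` dictionary below is not taken from the package):

```python
from functools import partial

import jax
import jax.numpy as jnp

# illustrative pytree of parameters
params = {'w': jnp.ones((4, 3)), 'b': jnp.zeros((4,))}

# old spelling, deprecated/removed in newer JAX: jax.tree_map(partial(jnp.mean, axis=0), params)
means = jax.tree_util.tree_map(partial(jnp.mean, axis=0), params)  # stable spelling used in 2.5
```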
@@ -707,7 +707,10 @@ class JaxRDDLCompiler:
  sample = jnp.asarray(value, dtype=self._fix_dtype(value))
  new_slices = [None] * len(jax_nested_expr)
  for (i, jax_expr) in enumerate(jax_nested_expr):
- new_slices[i], key, err, params = jax_expr(x, params, key)
+ new_slice, key, err, params = jax_expr(x, params, key)
+ if not jnp.issubdtype(jnp.result_type(new_slice), jnp.integer):
+ new_slice = jnp.asarray(new_slice, dtype=self.INT)
+ new_slices[i] = new_slice
  error |= err
  new_slices = tuple(new_slices)
  sample = sample[new_slices]
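The change above also coerces each computed slice to the compiler's integer dtype before it is used to index `sample`, since JAX only accepts integer (or boolean) array indices and a relaxed index expression may come back as floating point. A hedged standalone sketch of the same guard, using `jnp.int32` in place of the compiler's `self.INT`:

```python
import jax.numpy as jnp

sample = jnp.arange(10.0)
idx = jnp.asarray(2.0)  # a relaxed index expression may come back as a float
if not jnp.issubdtype(jnp.result_type(idx), jnp.integer):
    idx = jnp.asarray(idx, dtype=jnp.int32)  # cast before indexing, mirroring the hunk
value = sample[idx]  # indexing with a float array would raise a TypeError
```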
@@ -986,7 +989,8 @@ class JaxRDDLCompiler:
  sample_cases = [None] * len(jax_cases)
  for (i, jax_case) in enumerate(jax_cases):
  sample_cases[i], key, err_case, params = jax_case(x, params, key)
- err |= err_case
+ err |= err_case
+ sample_cases = jnp.asarray(sample_cases)
  sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))

  # predicate (enum) is an integer - use it to extract from case array
@@ -39,6 +39,7 @@ import configparser
  from enum import Enum
  from functools import partial
  import os
+ import pickle
  import sys
  import time
  import traceback
@@ -229,13 +230,19 @@ def _load_config(config, args):


  def load_config(path: str) -> Tuple[Kwargs, ...]:
- '''Loads a config file at the specified file path.'''
+ '''Loads a config file at the specified file path.
+
+ :param path: the path of the config file to load and parse
+ '''
  config, args = _parse_config_file(path)
  return _load_config(config, args)


  def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
- '''Loads config file contents specified explicitly as a string value.'''
+ '''Loads config file contents specified explicitly as a string value.
+
+ :param value: the string in json format containing the config contents to parse
+ '''
  config, args = _parse_config_string(value)
  return _load_config(config, args)
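The expanded docstrings describe the two config entry points. A brief usage sketch; the unpacked names and the config path are illustrative, following the three-tuple pattern used later in run_tune.py:

```python
from pyRDDLGym_jax.core.planner import load_config, load_config_from_string

# from a config file on disk (illustrative path)
planner_args, plan_args, train_args = load_config('path/to/config.cfg')

# from config text already in memory, e.g. a filled-in template
with open('path/to/config.cfg') as file:
    planner_args, plan_args, train_args = load_config_from_string(file.read())
```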
@@ -258,6 +265,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  def __init__(self, *args,
  logic: Logic=FuzzyLogic(),
  cpfs_without_grad: Optional[Set[str]]=None,
+ print_warnings: bool=True,
  **kwargs) -> None:
  '''Creates a new RDDL to Jax compiler, where operations that are not
  differentiable are converted to approximate forms that have defined gradients.
@@ -268,6 +276,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  to customize these operations
  :param cpfs_without_grad: which CPFs do not have gradients (use straight
  through gradient trick)
+ :param print_warnings: whether to print warnings
  :param *kwargs: keyword arguments to pass to base compiler
  '''
  super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
@@ -277,6 +286,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  if cpfs_without_grad is None:
  cpfs_without_grad = set()
  self.cpfs_without_grad = cpfs_without_grad
+ self.print_warnings = print_warnings

  # actions and CPFs must be continuous
  pvars_cast = set()
@@ -284,7 +294,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  self.init_values[var] = np.asarray(values, dtype=self.REAL)
  if not np.issubdtype(np.result_type(values), np.floating):
  pvars_cast.add(var)
- if pvars_cast:
+ if self.print_warnings and pvars_cast:
  message = termcolor.colored(
  f'[INFO] JAX gradient compiler will cast p-vars {pvars_cast} to float.',
  'green')
@@ -314,12 +324,12 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  if cpf in self.cpfs_without_grad:
  jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])

- if cpfs_cast:
+ if self.print_warnings and cpfs_cast:
  message = termcolor.colored(
  f'[INFO] JAX gradient compiler will cast CPFs {cpfs_cast} to float.',
  'green')
  print(message)
- if self.cpfs_without_grad:
+ if self.print_warnings and self.cpfs_without_grad:
  message = termcolor.colored(
  f'[INFO] Gradients will not flow through CPFs {self.cpfs_without_grad}.',
  'green')
@@ -436,10 +446,11 @@ class JaxPlan(metaclass=ABCMeta):
  ~lower_finite & upper_finite,
  ~lower_finite & ~upper_finite]
  bounds[name] = (lower, upper)
- message = termcolor.colored(
- f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
- 'green')
- print(message)
+ if compiled.print_warnings:
+ message = termcolor.colored(
+ f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
+ 'green')
+ print(message)
  return shapes, bounds, bounds_safe, cond_lists

  def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -519,7 +530,7 @@ class JaxStraightLinePlan(JaxPlan):
  # action concurrency check
  bool_action_count, allowed_actions = self._count_bool_actions(rddl)
  use_constraint_satisfaction = allowed_actions < bool_action_count
- if use_constraint_satisfaction:
+ if compiled.print_warnings and use_constraint_satisfaction:
  message = termcolor.colored(
  f'[INFO] SLP will use projected gradient to satisfy '
  f'max_nondef_actions since total boolean actions '
@@ -605,7 +616,7 @@ class JaxStraightLinePlan(JaxPlan):
  start = 0
  for (name, size) in action_sizes.items():
  action = output[..., start:start + size]
- action = jnp.reshape(action, newshape=shapes[name][1:])
+ action = jnp.reshape(action, shapes[name][1:])
  if noop[name]:
  action = 1.0 - action
  actions[name] = action
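This hunk, along with the matching ones in the deep reactive policy and the planner below, drops the `newshape=` keyword from `reshape` calls; that keyword has been deprecated in favour of the positional `shape` argument in recent NumPy/JAX releases. A minimal sketch of the equivalent call:

```python
import jax.numpy as jnp

action = jnp.arange(6.0)
# previously: jnp.reshape(action, newshape=(2, 3))  -- keyword deprecated/removed in newer JAX
action = jnp.reshape(action, (2, 3))
```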
@@ -838,7 +849,7 @@ class JaxStraightLinePlan(JaxPlan):

  def guess_next_epoch(self, params: Pytree) -> Pytree:
  next_fn = JaxStraightLinePlan._guess_next_epoch
- return jax.tree_map(next_fn, params)
+ return jax.tree_util.tree_map(next_fn, params)


  class JaxDeepReactivePolicy(JaxPlan):
@@ -946,17 +957,19 @@ class JaxDeepReactivePolicy(JaxPlan):
  if ranges[var] != 'bool':
  value_size = np.size(values)
  if normalize_per_layer and value_size == 1:
- message = termcolor.colored(
- f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
- f'of size 1: setting normalize_per_layer = False.', 'yellow')
- print(message)
+ if compiled.print_warnings:
+ message = termcolor.colored(
+ f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
+ f'of size 1: setting normalize_per_layer = False.', 'yellow')
+ print(message)
  normalize_per_layer = False
  non_bool_dims += value_size
  if not normalize_per_layer and non_bool_dims == 1:
- message = termcolor.colored(
- '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
- 'setting normalize = False.', 'yellow')
- print(message)
+ if compiled.print_warnings:
+ message = termcolor.colored(
+ '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
+ 'setting normalize = False.', 'yellow')
+ print(message)
  normalize = False

  # convert subs dictionary into a state vector to feed to the MLP
@@ -1054,7 +1067,7 @@ class JaxDeepReactivePolicy(JaxPlan):
  for (name, size) in layer_sizes.items():
  if ranges[name] == 'bool':
  action = output[..., start:start + size]
- action = jnp.reshape(action, newshape=shapes[name])
+ action = jnp.reshape(action, shapes[name])
  if noop[name]:
  action = 1.0 - action
  actions[name] = action
@@ -1226,6 +1239,7 @@ class PGPE(metaclass=ABCMeta):

  @abstractmethod
  def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+ print_warnings: bool,
  parallel_updates: Optional[int]=None) -> None:
  pass

@@ -1322,6 +1336,7 @@ class GaussianPGPE(PGPE):
  )

  def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+ print_warnings: bool,
  parallel_updates: Optional[int]=None) -> None:
  sigma0 = self.init_sigma
  sigma_lo, sigma_hi = self.sigma_range
@@ -1347,7 +1362,7 @@

  def _jax_wrapped_pgpe_init(key, policy_params):
  mu = policy_params
- sigma = jax.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
+ sigma = jax.tree_util.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
  pgpe_params = (mu, sigma)
  pgpe_opt_state = (mu_optimizer.init(mu), sigma_optimizer.init(sigma))
  r_max = -jnp.inf
@@ -1395,13 +1410,14 @@
  treedef = jax.tree_util.tree_structure(sigma)
  keys = random.split(key, num=treedef.num_leaves)
  keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
- epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
- p1 = jax.tree_map(jnp.add, mu, epsilon)
- p2 = jax.tree_map(jnp.subtract, mu, epsilon)
+ epsilon = jax.tree_util.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
+ p1 = jax.tree_util.tree_map(jnp.add, mu, epsilon)
+ p2 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon)
  if super_symmetric:
- epsilon_star = jax.tree_map(_jax_wrapped_epsilon_star, sigma, epsilon)
- p3 = jax.tree_map(jnp.add, mu, epsilon_star)
- p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
+ epsilon_star = jax.tree_util.tree_map(
+ _jax_wrapped_epsilon_star, sigma, epsilon)
+ p3 = jax.tree_util.tree_map(jnp.add, mu, epsilon_star)
+ p4 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon_star)
  else:
  epsilon_star, p3, p4 = epsilon, p1, p2
  return p1, p2, p3, p4, epsilon, epsilon_star
@@ -1469,11 +1485,11 @@
  r_max = jnp.maximum(r_max, r4)
  else:
  r3, r4 = r1, r2
- grad_mu = jax.tree_map(
+ grad_mu = jax.tree_util.tree_map(
  partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
  epsilon, epsilon_star
  )
- grad_sigma = jax.tree_map(
+ grad_sigma = jax.tree_util.tree_map(
  partial(_jax_wrapped_sigma_grad,
  r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
  epsilon, epsilon_star, sigma
@@ -1492,7 +1508,7 @@
  _jax_wrapped_pgpe_grad,
  in_axes=(0, None, None, None, None, None, None, None)
  )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
- mu_grad, sigma_grad = jax.tree_map(
+ mu_grad, sigma_grad = jax.tree_util.tree_map(
  partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
  new_r_max = jnp.max(r_maxs)
  return mu_grad, sigma_grad, new_r_max
@@ -1516,7 +1532,7 @@
  sigma_grad, sigma_state, params=sigma)
  new_mu = optax.apply_updates(mu, mu_updates)
  new_sigma = optax.apply_updates(sigma, sigma_updates)
- new_sigma = jax.tree_map(
+ new_sigma = jax.tree_util.tree_map(
  partial(jnp.clip, min=sigma_lo, max=sigma_hi), new_sigma)
  return new_mu, new_sigma, new_mu_state, new_sigma_state

@@ -1537,7 +1553,7 @@
  if max_kl is not None:
  old_mu_lr = new_mu_state.hyperparams['learning_rate']
  old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
- kl_terms = jax.tree_map(
+ kl_terms = jax.tree_util.tree_map(
  _jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
  total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
  kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
@@ -1672,6 +1688,7 @@ class JaxBackpropPlanner:
  compile_non_fluent_exact: bool=True,
  logger: Optional[Logger]=None,
  dashboard_viz: Optional[Any]=None,
+ print_warnings: bool=True,
  parallel_updates: Optional[int]=None) -> None:
  '''Creates a new gradient-based algorithm for optimizing action sequences
  (plan) in the given RDDL. Some operations will be converted to their
@@ -1712,6 +1729,7 @@
  :param logger: to log information about compilation to file
  :param dashboard_viz: optional visualizer object from the environment
  to pass to the dashboard to visualize the policy
+ :param print_warnings: whether to print warnings
  :param parallel_updates: how many optimizers to run independently in parallel
  '''
  self.rddl = rddl
@@ -1737,6 +1755,7 @@
  self.noise_kwargs = noise_kwargs
  self.pgpe = pgpe
  self.use_pgpe = pgpe is not None
+ self.print_warnings = print_warnings

  # set optimizer
  try:
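A recurring theme in this release is the new `print_warnings` flag, threaded from `JaxBackpropPlanner` down to the relaxed compiler, the plans and the PGPE optimizer so that the coloured [INFO]/[WARN] messages can be silenced. A hedged construction sketch; `model` stands for a lifted RDDL model from pyRDDLGym, and the other keyword arguments are placeholders for the planner's usual settings:

```python
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxStraightLinePlan

planner = JaxBackpropPlanner(rddl=model,                 # placeholder RDDLLiftedModel
                             plan=JaxStraightLinePlan(),
                             batch_size_train=32,
                             print_warnings=False)       # new in 2.5: silence [INFO]/[WARN] output
```

The bundled tuning config files set the same option as `print_warnings=False`, as shown in the config hunks further down.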
@@ -1789,7 +1808,11 @@
  self._jax_compile_rddl()
  self._jax_compile_optimizer()

- def summarize_system(self) -> str:
+ @staticmethod
+ def summarize_system() -> str:
+ '''Returns a string containing information about the system, Python version
+ and jax-related packages that are relevant to the current planner.
+ '''
  try:
  jaxlib_version = jax._src.lib.version_str
  except Exception as _:
@@ -1818,6 +1841,9 @@ r"""
  f'devices: {devices_short}\n')

  def summarize_relaxations(self) -> str:
+ '''Returns a summary table containing all non-differentiable operators
+ and their relaxations.
+ '''
  result = ''
  if self.compiled.model_params:
  result += ('Some RDDL operations are non-differentiable '
@@ -1834,6 +1860,9 @@ r"""
  return result

  def summarize_hyperparameters(self) -> str:
+ '''Returns a string summarizing the hyper-parameters of the current planner
+ instance.
+ '''
  result = (f'objective hyper-parameters:\n'
  f' utility_fn ={self.utility.__name__}\n'
  f' utility args ={self.utility_kwargs}\n'
@@ -1873,7 +1902,8 @@ r"""
  logger=self.logger,
  use64bit=self.use64bit,
  cpfs_without_grad=self.cpfs_without_grad,
- compile_non_fluent_exact=self.compile_non_fluent_exact
+ compile_non_fluent_exact=self.compile_non_fluent_exact,
+ print_warnings=self.print_warnings
  )
  self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')

@@ -1922,7 +1952,8 @@ r"""

  # optimization
  self.update = self._jax_update(train_loss)
- self.pytree_at = jax.jit(lambda tree, i: jax.tree_map(lambda x: x[i], tree))
+ self.pytree_at = jax.jit(
+ lambda tree, i: jax.tree_util.tree_map(lambda x: x[i], tree))

  # pgpe option
  if self.use_pgpe:
@@ -1930,6 +1961,7 @@ r"""
  loss_fn=test_loss,
  projection=self.plan.projection,
  real_dtype=self.test_compiled.REAL,
+ print_warnings=self.print_warnings,
  parallel_updates=self.parallel_updates
  )
  self.merge_pgpe = self._jax_merge_pgpe_jaxplan()
@@ -2010,7 +2042,7 @@ r"""
  # check if the gradients are all zeros
  def _jax_wrapped_zero_gradients(grad):
  leaves, _ = jax.tree_util.tree_flatten(
- jax.tree_map(partial(jnp.allclose, b=0), grad))
+ jax.tree_util.tree_map(partial(jnp.allclose, b=0), grad))
  return jnp.all(jnp.asarray(leaves))

  # calculate the plan gradient w.r.t. return loss and update optimizer
@@ -2069,7 +2101,7 @@ r"""
  def select_fn(leaf1, leaf2):
  expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf1) - 1)]
  return jnp.where(expanded_mask, leaf1, leaf2)
- policy_params = jax.tree_map(select_fn, pgpe_param, policy_params)
+ policy_params = jax.tree_util.tree_map(select_fn, pgpe_param, policy_params)
  test_loss = jnp.where(pgpe_mask, pgpe_loss, test_loss)
  test_loss_smooth = jnp.where(pgpe_mask, pgpe_loss_smooth, test_loss_smooth)
  expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(converged) - 1)]
@@ -2091,7 +2123,9 @@ r"""
  f'Variable <{name}> in subs argument is not a '
  f'valid p-variable, must be one of '
  f'{set(self.test_compiled.init_values.keys())}.')
- value = np.reshape(value, newshape=np.shape(init_value))[np.newaxis, ...]
+ value = np.reshape(value, np.shape(init_value))[np.newaxis, ...]
+ if value.dtype.type is np.str_:
+ value = rddl.object_string_to_index_array(rddl.variable_ranges[name], value)
  train_value = np.repeat(value, repeats=n_train, axis=0)
  train_value = np.asarray(train_value, dtype=self.compiled.REAL)
  init_train[name] = train_value
@@ -2121,7 +2155,7 @@ r"""
  x[np.newaxis, ...], shape=(self.parallel_updates,) + np.shape(x))
  return x

- return jax.tree_map(make_batched, pytree)
+ return jax.tree_util.tree_map(make_batched, pytree)

  def as_optimization_problem(
  self, key: Optional[random.PRNGKey]=None,
@@ -2165,10 +2199,11 @@ r"""
  train_subs, _ = self._batched_init_subs(subs)
  model_params = self.compiled.model_params
  if policy_hyperparams is None:
- message = termcolor.colored(
- '[WARN] policy_hyperparams is not set, setting 1.0 for '
- 'all action-fluents which could be suboptimal.', 'yellow')
- print(message)
+ if self.print_warnings:
+ message = termcolor.colored(
+ '[WARN] policy_hyperparams is not set, setting 1.0 for '
+ 'all action-fluents which could be suboptimal.', 'yellow')
+ print(message)
  policy_hyperparams = {action: 1.0
  for action in self.rddl.action_fluents}
@@ -2318,10 +2353,11 @@ r"""

  # cannot run dashboard with parallel updates
  if dashboard is not None and self.parallel_updates is not None:
- message = termcolor.colored(
- '[WARN] Dashboard is unavailable if parallel_updates is not None: '
- 'setting dashboard to None.', 'yellow')
- print(message)
+ if self.print_warnings:
+ message = termcolor.colored(
+ '[WARN] Dashboard is unavailable if parallel_updates is not None: '
+ 'setting dashboard to None.', 'yellow')
+ print(message)
  dashboard = None

  # if PRNG key is not provided
@@ -2331,19 +2367,21 @@ r"""

  # if policy_hyperparams is not provided
  if policy_hyperparams is None:
- message = termcolor.colored(
- '[WARN] policy_hyperparams is not set, setting 1.0 for '
- 'all action-fluents which could be suboptimal.', 'yellow')
- print(message)
+ if self.print_warnings:
+ message = termcolor.colored(
+ '[WARN] policy_hyperparams is not set, setting 1.0 for '
+ 'all action-fluents which could be suboptimal.', 'yellow')
+ print(message)
  policy_hyperparams = {action: 1.0
  for action in self.rddl.action_fluents}

  # if policy_hyperparams is a scalar
  elif isinstance(policy_hyperparams, (int, float, np.number)):
- message = termcolor.colored(
- f'[INFO] policy_hyperparams is {policy_hyperparams}, '
- f'setting this value for all action-fluents.', 'green')
- print(message)
+ if self.print_warnings:
+ message = termcolor.colored(
+ f'[INFO] policy_hyperparams is {policy_hyperparams}, '
+ f'setting this value for all action-fluents.', 'green')
+ print(message)
  hyperparam_value = float(policy_hyperparams)
  policy_hyperparams = {action: hyperparam_value
  for action in self.rddl.action_fluents}
@@ -2352,11 +2390,12 @@ r"""
  elif isinstance(policy_hyperparams, dict):
  for action in self.rddl.action_fluents:
  if action not in policy_hyperparams:
- message = termcolor.colored(
- f'[WARN] policy_hyperparams[{action}] is not set, '
- f'setting 1.0 for missing action-fluents '
- f'which could be suboptimal.', 'yellow')
- print(message)
+ if self.print_warnings:
+ message = termcolor.colored(
+ f'[WARN] policy_hyperparams[{action}] is not set, '
+ f'setting 1.0 for missing action-fluents '
+ f'which could be suboptimal.', 'yellow')
+ print(message)
  policy_hyperparams[action] = 1.0

  # print summary of parameters:
@@ -2396,7 +2435,7 @@ r"""
  if var not in subs:
  subs[var] = value
  added_pvars_to_subs.append(var)
- if added_pvars_to_subs:
+ if self.print_warnings and added_pvars_to_subs:
  message = termcolor.colored(
  f'[INFO] p-variables {added_pvars_to_subs} is not in '
  f'provided subs, using their initial values.', 'green')
@@ -2648,7 +2687,7 @@ r"""
  policy_params, opt_state, opt_aux = self.initialize(
  subkey, policy_hyperparams, train_subs)
  no_progress_count = 0
- if progress_bar is not None:
+ if self.print_warnings and progress_bar is not None:
  message = termcolor.colored(
  f'[INFO] Optimizer restarted at iteration {it} '
  f'due to lack of progress.', 'green')
@@ -2658,7 +2697,7 @@ r"""

  # stopping condition reached
  if stopping_rule is not None and stopping_rule.monitor(callback):
- if progress_bar is not None:
+ if self.print_warnings and progress_bar is not None:
  message = termcolor.colored(
  '[SUCC] Stopping rule has been reached.', 'green')
  progress_bar.write(message)
@@ -2699,7 +2738,8 @@ r"""

  # summarize and test for convergence
  if print_summary:
- grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
+ grad_norm = jax.tree_util.tree_map(
+ lambda x: np.linalg.norm(x).item(), best_grad)
  diagnosis = self._perform_diagnosis(
  last_iter_improve, -np.min(train_loss), -np.min(test_loss_smooth),
  -best_loss, grad_norm)
@@ -2777,6 +2817,7 @@ r"""
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
  weights for sigmoid wrapping boolean actions (optional)
  '''
+ subs = subs.copy()

  # check compatibility of the subs dictionary
  for (var, values) in subs.items():
@@ -2795,13 +2836,17 @@ r"""
  if step == 0 and var in self.rddl.observ_fluents:
  subs[var] = self.test_compiled.init_values[var]
  else:
- raise ValueError(
- f'Values {values} assigned to p-variable <{var}> are '
- f'non-numeric of type {dtype}.')
+ if dtype.type is np.str_:
+ prange = self.rddl.variable_ranges[var]
+ subs[var] = self.rddl.object_string_to_index_array(prange, subs[var])
+ else:
+ raise ValueError(
+ f'Values {values} assigned to p-variable <{var}> are '
+ f'non-numeric of type {dtype}.')

  # cast device arrays to numpy
  actions = self.test_policy(key, params, policy_hyperparams, step, subs)
- actions = jax.tree_map(np.asarray, actions)
+ actions = jax.tree_util.tree_map(np.asarray, actions)
  return actions

@@ -2822,8 +2867,9 @@ class JaxOfflineController(BaseAgent):
  def __init__(self, planner: JaxBackpropPlanner,
  key: Optional[random.PRNGKey]=None,
  eval_hyperparams: Optional[Dict[str, Any]]=None,
- params: Optional[Pytree]=None,
+ params: Optional[Union[str, Pytree]]=None,
  train_on_reset: bool=False,
+ save_path: Optional[str]=None,
  **train_kwargs) -> None:
  '''Creates a new JAX offline control policy that is trained once, then
  deployed later.
@@ -2834,8 +2880,10 @@ class JaxOfflineController(BaseAgent):
  :param eval_hyperparams: policy hyperparameters to apply for evaluation
  or whenever sample_action is called
  :param params: use the specified policy parameters instead of calling
- planner.optimize()
+ planner.optimize(); can be a string pointing to a valid file path where params
+ have been saved, or a pytree of parameters
  :param train_on_reset: retrain policy parameters on every episode reset
+ :param save_path: optional path to save parameters to
  :param **train_kwargs: any keyword arguments to be passed to the planner
  for optimization
  '''
@@ -2848,12 +2896,24 @@ class JaxOfflineController(BaseAgent):
  self.train_kwargs = train_kwargs
  self.params_given = params is not None

+ # load the policy from file
+ if not self.train_on_reset and params is not None and isinstance(params, str):
+ with open(params, 'rb') as file:
+ params = pickle.load(file)
+
+ # train the policy
  self.step = 0
  self.callback = None
  if not self.train_on_reset and not self.params_given:
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
  self.callback = callback
  params = callback['best_params']
+
+ # save the policy
+ if save_path is not None:
+ with open(save_path, 'wb') as file:
+ pickle.dump(params, file)
+
  self.params = params

  def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
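With `params` now accepting a file path and the new `save_path` argument, a trained policy can be pickled once and reloaded on later runs without re-optimizing. A hedged sketch; `planner` and `train_args` are placeholders for an already-constructed planner and its training keyword arguments:

```python
from pyRDDLGym_jax.core.planner import JaxOfflineController

# first run: optimize, then pickle the best parameters found
controller = JaxOfflineController(planner, save_path='policy_params.pkl', **train_args)

# later run: skip optimization entirely and load the pickled parameters
controller = JaxOfflineController(planner, params='policy_params.pkl')
```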
@@ -2865,6 +2925,8 @@

  def reset(self) -> None:
  self.step = 0
+
+ # train the policy if required to reset at the start of every episode
  if self.train_on_reset and not self.params_given:
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
  self.callback = callback
@@ -2915,18 +2977,22 @@ class JaxOnlineController(BaseAgent):
  attempts = 0
  while attempts < self.max_attempts and callback['iteration'] <= 1:
  attempts += 1
- message = termcolor.colored(
- f'[WARN] JIT compilation dominated the execution time: '
- f'executing the optimizer again on the traced model [attempt {attempts}].',
- 'yellow')
- print(message)
+ if self.planner.print_warnings:
+ message = termcolor.colored(
+ f'[WARN] JIT compilation dominated the execution time: '
+ f'executing the optimizer again on the traced model '
+ f'[attempt {attempts}].', 'yellow')
+ print(message)
  callback = planner.optimize(
- key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
-
+ key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
  self.callback = callback
  params = callback['best_params']
+
+ # get the action from the parameters for the current state
  self.key, subkey = random.split(self.key)
  actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
+
+ # apply warm start for the next epoch
  if self.warm_start:
  self.guess = planner.plan.guess_next_epoch(params)
  return actions
@@ -19,10 +19,12 @@


  import time
- from typing import Dict, Optional
+ import numpy as np
+ from typing import Dict, Optional, Union

  import jax

+ from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
  from pyRDDLGym.core.compiler.model import RDDLLiftedModel
  from pyRDDLGym.core.debug.exception import (
  RDDLActionPreconditionNotSatisfiedError,
@@ -35,7 +37,7 @@ from pyRDDLGym.core.simulator import RDDLSimulator

  from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler

- Args = Dict[str, Value]
+ Args = Dict[str, Union[np.ndarray, Value]]


  class JaxRDDLSimulator(RDDLSimulator):
@@ -45,6 +47,7 @@ class JaxRDDLSimulator(RDDLSimulator):
  raise_error: bool=True,
  logger: Optional[Logger]=None,
  keep_tensors: bool=False,
+ objects_as_strings: bool=True,
  **compiler_args) -> None:
  '''Creates a new simulator for the given RDDL model with Jax as a backend.

@@ -57,6 +60,8 @@
  :param logger: to log information about compilation to file
  :param keep_tensors: whether the sampler takes actions and
  returns state in numpy array form
+ param objects_as_strings: whether to return object values as strings (defaults
+ to integer indices if False)
  :param **compiler_args: keyword arguments to pass to the Jax compiler
  '''
  if key is None:
@@ -67,7 +72,8 @@

  # generate direct sampling with default numpy RNG and operations
  super(JaxRDDLSimulator, self).__init__(
- rddl, logger=logger, keep_tensors=keep_tensors)
+ rddl, logger=logger,
+ keep_tensors=keep_tensors, objects_as_strings=objects_as_strings)

  def seed(self, seed: int) -> None:
  super(JaxRDDLSimulator, self).seed(seed)
@@ -84,11 +90,11 @@
  self.levels = compiled.levels
  self.traced = compiled.traced

- self.invariants = jax.tree_map(jax.jit, compiled.invariants)
- self.preconds = jax.tree_map(jax.jit, compiled.preconditions)
- self.terminals = jax.tree_map(jax.jit, compiled.terminations)
+ self.invariants = jax.tree_util.tree_map(jax.jit, compiled.invariants)
+ self.preconds = jax.tree_util.tree_map(jax.jit, compiled.preconditions)
+ self.terminals = jax.tree_util.tree_map(jax.jit, compiled.terminations)
  self.reward = jax.jit(compiled.reward)
- jax_cpfs = jax.tree_map(jax.jit, compiled.cpfs)
+ jax_cpfs = jax.tree_util.tree_map(jax.jit, compiled.cpfs)
  self.model_params = compiled.model_params

  # level analysis
@@ -139,7 +145,6 @@

  def check_action_preconditions(self, actions: Args, silent: bool=False) -> bool:
  '''Throws an exception if the action preconditions are not satisfied.'''
- actions = self._process_actions(actions)
  subs = self.subs
  subs.update(actions)

@@ -180,7 +185,6 @@
  '''
  rddl = self.rddl
  keep_tensors = self.keep_tensors
- actions = self._process_actions(actions)
  subs = self.subs
  subs.update(actions)

@@ -196,20 +200,40 @@
  # update state
  self.state = {}
  for (state, next_state) in rddl.next_state.items():
+
+ # set state = state' for the next epoch
  subs[state] = subs[next_state]
+
+ # convert object integer to string representation
+ state_values = subs[state]
+ if self.objects_as_strings:
+ ptype = rddl.variable_ranges[state]
+ if ptype not in RDDLValueInitializer.NUMPY_TYPES:
+ state_values = rddl.index_to_object_string_array(ptype, state_values)
+
+ # optional grounding of state dictionary
  if keep_tensors:
- self.state[state] = subs[state]
+ self.state[state] = state_values
  else:
- self.state.update(rddl.ground_var_with_values(state, subs[state]))
+ self.state.update(rddl.ground_var_with_values(state, state_values))

  # update observation
  if self._pomdp:
  obs = {}
  for var in rddl.observ_fluents:
+
+ # convert object integer to string representation
+ obs_values = subs[var]
+ if self.objects_as_strings:
+ ptype = rddl.variable_ranges[var]
+ if ptype not in RDDLValueInitializer.NUMPY_TYPES:
+ obs_values = rddl.index_to_object_string_array(ptype, obs_values)
+
+ # optional grounding of observ-fluent dictionary
  if keep_tensors:
- obs[var] = subs[var]
+ obs[var] = obs_values
  else:
- obs.update(rddl.ground_var_with_values(var, subs[var]))
+ obs.update(rddl.ground_var_with_values(var, obs_values))
  else:
  obs = self.state
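With the new `objects_as_strings` flag (on by default), the simulator converts object-valued state and observation fluents from their internal integer indices back to object names before returning them, using the pyRDDLGym model helpers referenced in this hunk. A hedged round-trip sketch; `rddl` stands for a compiled lifted model, and the object type and names are hypothetical:

```python
import numpy as np

# for a hypothetical object type 'robot' with objects ['r1', 'r2'],
# the two helpers referenced above invert each other
indices = rddl.object_string_to_index_array('robot', np.array(['r1', 'r2']))
names = rddl.index_to_object_string_array('robot', indices)  # back to ['r1', 'r2']
```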
@@ -371,16 +371,30 @@ class JaxParameterTuning:
  '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
  print(f'Kernel: {repr(optimizer._gp.kernel_)}.')

- def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
- '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
+ def tune(self, key: int,
+ log_file: Optional[str]=None,
+ show_dashboard: bool=False,
+ print_hyperparams: bool=False) -> ParameterValues:
+ '''Tunes the hyper-parameters for Jax planner, returns the best found.

- print(self.summarize_hyperparameters())
+ :param key: RNG key to seed the hyper-parameter optimizer
+ :param log_file: optional path to file where tuning progress will be saved
+ :param show_dashboard: whether to display tuning results in a dashboard
+ :param print_hyperparams: whether to print a hyper-parameter summary of the
+ optimizer
+ '''

- # clear and prepare output file
- with open(log_file, 'w', newline='') as file:
- writer = csv.writer(file)
- writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
+ if self.verbose:
+ print(JaxBackpropPlanner.summarize_system())
+ if print_hyperparams:
+ print(self.summarize_hyperparameters())

+ # clear and prepare output file
+ if log_file is not None:
+ with open(log_file, 'w', newline='') as file:
+ writer = csv.writer(file)
+ writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
+
  # create a dash-board for visualizing experiment runs
  if show_dashboard and JaxPlannerDashboard is not None:
  dashboard = JaxPlannerDashboard()
@@ -519,9 +533,10 @@ class JaxParameterTuning:
  self.tune_optimizer(optimizer)

  # write results of all processes in current iteration to file
- with open(log_file, 'a', newline='') as file:
- writer = csv.writer(file)
- writer.writerows(rows)
+ if log_file is not None:
+ with open(log_file, 'a', newline='') as file:
+ writer = csv.writer(file)
+ writer.writerows(rows)

  # update the dashboard tuning
  if show_dashboard:
@@ -2,24 +2,56 @@ import argparse

  from pyRDDLGym_jax.examples import run_plan, run_tune

+ EPILOG = 'For complete documentation, see https://pyrddlgym.readthedocs.io/en/latest/jax.html.'
+
  def main():
- parser = argparse.ArgumentParser(description="Command line parser for the JaxPlan planner.")
+ parser = argparse.ArgumentParser(prog='jaxplan',
+ description="command line parser for the jaxplan planner",
+ epilog=EPILOG)
  subparsers = parser.add_subparsers(dest="jaxplan", required=True)

  # planning
- parser_plan = subparsers.add_parser("plan", help="Executes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
- parser_plan.add_argument('args', nargs=argparse.REMAINDER)
+ parser_plan = subparsers.add_parser("plan",
+ help="execute jaxplan on a specified RDDL problem",
+ epilog=EPILOG)
+ parser_plan.add_argument('domain', type=str,
+ help='name of domain in rddlrepository or a valid file path')
+ parser_plan.add_argument('instance', type=str,
+ help='name of instance in rddlrepository or a valid file path')
+ parser_plan.add_argument('method', type=str,
+ help='training method to apply: [slp, drp] are offline methods, and [replan] are online')
+ parser_plan.add_argument('-e', '--episodes', type=int, required=False, default=1,
+ help='number of training or evaluation episodes')

  # tuning
- parser_tune = subparsers.add_parser("tune", help="Tunes JaxPlan on a specified RDDL problem and method (slp, drp, or replan).")
- parser_tune.add_argument('args', nargs=argparse.REMAINDER)
+ parser_tune = subparsers.add_parser("tune",
+ help="tune jaxplan on a specified RDDL problem",
+ epilog=EPILOG)
+ parser_tune.add_argument('domain', type=str,
+ help='name of domain in rddlrepository or a valid file path')
+ parser_tune.add_argument('instance', type=str,
+ help='name of instance in rddlrepository or a valid file path')
+ parser_tune.add_argument('method', type=str,
+ help='training method to apply: [slp, drp] are offline methods, and [replan] are online')
+ parser_tune.add_argument('-t', '--trials', type=int, required=False, default=5,
+ help='number of evaluation rollouts per hyper-parameter choice')
+ parser_tune.add_argument('-i', '--iters', type=int, required=False, default=20,
+ help='number of iterations of bayesian optimization')
+ parser_tune.add_argument('-w', '--workers', type=int, required=False, default=4,
+ help='number of parallel hyper-parameters to evaluate per iteration')
+ parser_tune.add_argument('-d', '--dashboard', type=bool, required=False, default=False,
+ help='show the dashboard')
+ parser_tune.add_argument('-f', '--filepath', type=str, required=False, default='',
+ help='where to save the config file of the best hyper-parameters')

  # dispatch
  args = parser.parse_args()
  if args.jaxplan == "plan":
- run_plan.run_from_args(args.args)
+ run_plan.main(args.domain, args.instance, args.method, args.episodes)
  elif args.jaxplan == "tune":
- run_tune.run_from_args(args.args)
+ run_tune.main(args.domain, args.instance, args.method,
+ args.trials, args.iters, args.workers, args.dashboard,
+ args.filepath)
  else:
  parser.print_help()
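The entry point now declares positional `domain`, `instance` and `method` arguments plus optional flags, and dispatches directly to the example modules' `main` functions. The equivalent programmatic calls look like the following sketch; the domain and instance names are illustrative:

```python
from pyRDDLGym_jax.examples import run_plan, run_tune

# same as: jaxplan plan Wildfire_MDP_ippc2014 1 slp --episodes 2
run_plan.main('Wildfire_MDP_ippc2014', '1', 'slp', episodes=2)

# same as: jaxplan tune Wildfire_MDP_ippc2014 1 drp --trials 5 --iters 20 --workers 4
run_tune.main('Wildfire_MDP_ippc2014', '1', 'drp', trials=5, iters=20, workers=4)
```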
@@ -11,6 +11,7 @@ optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
  batch_size_train=32
  batch_size_test=32
+ print_warnings=False

  [Training]
  train_seconds=30
@@ -12,6 +12,7 @@ optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
  batch_size_train=32
  batch_size_test=32
  rollout_horizon=ROLLOUT_HORIZON_TUNE
+ print_warnings=False

  [Training]
  train_seconds=1
@@ -11,6 +11,7 @@ optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
  batch_size_train=32
  batch_size_test=32
+ print_warnings=False

  [Training]
  train_seconds=30
@@ -26,7 +26,7 @@ from pyRDDLGym_jax.core.planner import (
  )


- def main(domain, instance, method, episodes=1):
+ def main(domain: str, instance: str, method: str, episodes: int=1) -> None:

  # set up the environment
  env = pyRDDLGym.make(domain, instance, vectorized=True)
@@ -36,7 +36,9 @@ def power_10(x):
  return 10.0 ** x


- def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=False):
+ def main(domain: str, instance: str, method: str,
+ trials: int=5, iters: int=20, workers: int=4, dashboard: bool=False,
+ filepath: str='') -> None:

  # set up the environment
  env = pyRDDLGym.make(domain, instance, vectorized=True)
@@ -68,6 +70,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=Fals
  tuning.tune(key=42,
  log_file=f'gp_{method}_{domain}_{instance}.csv',
  show_dashboard=dashboard)
+ if filepath is not None and filepath:
+ with open(filepath, "w") as file:
+ file.write(tuning.best_config)

  # evaluate the agent on the best parameters
  planner_args, _, train_args = load_config_from_string(tuning.best_config)
@@ -80,7 +85,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=Fals

  def run_from_args(args):
  if len(args) < 3:
- print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]')
+ print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>] [<filepath>]')
  exit(1)
  if args[2] not in ['drp', 'slp', 'replan']:
  print('<method> in [drp, slp, replan]')
@@ -90,6 +95,7 @@ def run_from_args(args):
  if len(args) >= 5: kwargs['iters'] = int(args[4])
  if len(args) >= 6: kwargs['workers'] = int(args[5])
  if len(args) >= 7: kwargs['dashboard'] = bool(args[6])
+ if len(args) >= 8: kwargs['filepath'] = bool(args[7])
  main(**kwargs)

@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: pyRDDLGym-jax
- Version: 2.4
+ Version: 2.5
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -39,6 +39,7 @@ Dynamic: description
  Dynamic: description-content-type
  Dynamic: home-page
  Dynamic: license
+ Dynamic: license-file
  Dynamic: provides-extra
  Dynamic: requires-dist
  Dynamic: requires-python
@@ -116,7 +117,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
  A basic run script is provided to train JaxPlan on any RDDL problem:

  ```shell
- jaxplan plan <domain> <instance> <method> <episodes>
+ jaxplan plan <domain> <instance> <method> --episodes <episodes>
  ```

  where:
@@ -241,7 +242,7 @@ More documentation about this and other new features will be coming soon.
  A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:

  ```shell
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+ jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
  ```

  where:
@@ -251,7 +252,8 @@ where:
  - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
  - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
  - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application
+ - ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.

  It is easy to tune a custom range of the planner's hyper-parameters efficiently.
  First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
@@ -291,23 +293,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
  with open('path/to/config.cfg', 'r') as file:
  config_template = file.read()

- # map parameters in the config that will be tuned
+ # tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
  def power_10(x):
- return 10.0 ** x
-
- hyperparams = [
- Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10), # tune weight from 10^-1 ... 10^5
- Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10), # tune lr from 10^-5 ... 10^1
- ]
+ return 10.0 ** x
+ hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+ Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]

  # build the tuner and tune
  tuning = JaxParameterTuning(env=env,
- config_template=config_template,
- hyperparams=hyperparams,
- online=False,
- eval_trials=trials,
- num_workers=workers,
- gp_iters=iters)
+ config_template=config_template, hyperparams=hyperparams,
+ online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
  tuning.tune(key=42, log_file='path/to/log.csv')
  ```

@@ -1,20 +1,20 @@
- pyRDDLGym_jax/__init__.py,sha256=6Bd43-94X_2dH_ErGLQ0_DvlhX5cLWkVPvn31JBzFkY,19
- pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
+ pyRDDLGym_jax/__init__.py,sha256=VoxLo_sy8RlJIIyu7szqL-cdMGBJdQPg-aSeyOVVIkY,19
+ pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
  pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pyRDDLGym_jax/core/compiler.py,sha256=NFWfTHtGf7F-t7Qhn6X-VpSAJkTVHm-oRjujFw4O1HA,82605
+ pyRDDLGym_jax/core/compiler.py,sha256=uFCtoipsIa3MM9nGgT3X8iCViPl2XSPNXh0jMdzN0ko,82895
  pyRDDLGym_jax/core/logic.py,sha256=lfc2ak_ap_ajMEFlB5EHCRNgJym31dNyA-5d-7N4CZA,56271
- pyRDDLGym_jax/core/planner.py,sha256=wZJiZHV0Qxi9DS3AQ9Rx1doBvsKQXc1HYziY6GXTu_A,136965
- pyRDDLGym_jax/core/simulator.py,sha256=DnPL93WVCMZqtqMUoiJdfWcH9pEvNgGfDfO4NV0wIS0,9271
- pyRDDLGym_jax/core/tuning.py,sha256=Gm3YJF84_2vDIIJpOj0tK0-4rlJoEjYwxRt_JpUKAOA,24482
+ pyRDDLGym_jax/core/planner.py,sha256=M6GKzN7Ml57B4ZrFZhhkpsQCvReKaCQNzer7zeHCM9E,140275
+ pyRDDLGym_jax/core/simulator.py,sha256=ayCATTUL3clLaZPQ5OUg2bI_c26KKCTq6TbrxbMsVdc,10470
+ pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
  pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
  pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
  pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
  pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
- pyRDDLGym_jax/examples/run_plan.py,sha256=TVfziHHaEC56wxwRw9llZ5iqSHe3m6yy8HxiR2TyvXE,2778
+ pyRDDLGym_jax/examples/run_plan.py,sha256=4y7JHqTxY5O1ltP6N7rar0jMiw7u9w1nuAIOcmDaAuE,2806
  pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
- pyRDDLGym_jax/examples/run_tune.py,sha256=WbGO8RudIK-cPMAMKvI8NbFQAqkG-Blbnta3Efsep6c,3828
+ pyRDDLGym_jax/examples/run_tune.py,sha256=F5KWgtoCPbf7XHB6HW9LjxarD57U2LvuGdTz67OL1DY,4114
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=nFFYHCKQUMn8x-OpJwu2pwe1tycNSJ8iAIwSkCBn33E,370
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=eJ3HvHjODoKdtX7u-AM51xQaHJnYgzEy2t3omNG2oCs,340
@@ -38,12 +38,12 @@ pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5
  pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=XeMWAAG_OFZo7JAMxS5-XXroZaeVMzfM0NswmEobIns,373
  pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=CK4cEz8ReXyAZPLaLG9clIIRXAqM3IplUCxbLt_V2lY,407
  pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qGIJDIw73XCe6pyIPtg,369
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
- pyrddlgym_jax-2.4.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
- pyrddlgym_jax-2.4.dist-info/METADATA,sha256=98Nl3EnEk-fRLeoy9orDScaikCT9M8X4zOfYtiS-WXI,17021
- pyrddlgym_jax-2.4.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
- pyrddlgym_jax-2.4.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
- pyrddlgym_jax-2.4.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
- pyrddlgym_jax-2.4.dist-info/RECORD,,
+ pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
+ pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
+ pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
+ pyrddlgym_jax-2.5.dist-info/licenses/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+ pyrddlgym_jax-2.5.dist-info/METADATA,sha256=XAaEJfbsYW-txxZhFZ6o_HmvqxkIMTqBF9LbV-KdTzI,17058
+ pyrddlgym_jax-2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ pyrddlgym_jax-2.5.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+ pyrddlgym_jax-2.5.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+ pyrddlgym_jax-2.5.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (76.0.0)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any