pyRDDLGym-jax 2.0__py3-none-any.whl → 2.1__py3-none-any.whl

This diff shows the changes between publicly available package versions as released to one of the supported registries, and is provided for informational purposes only.
@@ -69,8 +69,7 @@ try:
  from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
  except Exception:
  raise_warning('Failed to load the dashboard visualization tool: '
- 'please make sure you have installed the required packages.',
- 'red')
+ 'please make sure you have installed the required packages.', 'red')
  traceback.print_exc()
  JaxPlannerDashboard = None
 
@@ -133,7 +132,7 @@ def _load_config(config, args):
  comp_kwargs = model_args.get('complement_kwargs', {})
  compare_name = model_args.get('comparison', 'SigmoidComparison')
  compare_kwargs = model_args.get('comparison_kwargs', {})
- sampling_name = model_args.get('sampling', 'GumbelSoftmax')
+ sampling_name = model_args.get('sampling', 'SoftRandomSampling')
  sampling_kwargs = model_args.get('sampling_kwargs', {})
  rounding_name = model_args.get('rounding', 'SoftRounding')
  rounding_kwargs = model_args.get('rounding_kwargs', {})
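
Note on the hunk above: when a config omits the 'sampling' key, the default relaxation for discrete sampling is now SoftRandomSampling rather than GumbelSoftmax. A minimal sketch of how the loader resolves it (the dict contents are hypothetical; only the key name and defaults come from the code above):

    # illustration only: default resolution inside _load_config
    model_args = {'comparison': 'SigmoidComparison'}               # hypothetical section without a 'sampling' key
    sampling_name = model_args.get('sampling', 'SoftRandomSampling')
    assert sampling_name == 'SoftRandomSampling'                   # 2.0 defaulted to 'GumbelSoftmax'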
@@ -156,8 +155,7 @@ def _load_config(config, args):
  initializer = _getattr_any(
  packages=[initializers, hk.initializers], item=plan_initializer)
  if initializer is None:
- raise_warning(
- f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
+ raise_warning(f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
  del plan_kwargs['initializer']
  else:
  init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
@@ -174,8 +172,7 @@ def _load_config(config, args):
  activation = _getattr_any(
  packages=[jax.nn, jax.numpy], item=plan_activation)
  if activation is None:
- raise_warning(
- f'Ignoring invalid activation <{plan_activation}>.', 'red')
+ raise_warning(f'Ignoring invalid activation <{plan_activation}>.', 'red')
  del plan_kwargs['activation']
  else:
  plan_kwargs['activation'] = activation
@@ -189,8 +186,7 @@ def _load_config(config, args):
  if planner_optimizer is not None:
  optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
  if optimizer is None:
- raise_warning(
- f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
+ raise_warning(f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
  del planner_args['optimizer']
  else:
  planner_args['optimizer'] = optimizer
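
The three hunks above only reflow warnings around _getattr_any, whose body is not part of this diff. As a rough sketch of the name-resolution pattern it appears to implement (an assumption, for illustration only): return the first attribute with the requested name found across the given modules, else None.

    # hypothetical sketch of _getattr_any; the real implementation is not shown in this diff
    def _getattr_any(packages, item):
        for package in packages:
            value = getattr(package, item, None)    # e.g. getattr(optax, 'adam', None)
            if value is not None:
                return value
        return None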
@@ -285,48 +281,14 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  pvars_cast = set()
  for (var, values) in self.init_values.items():
  self.init_values[var] = np.asarray(values, dtype=self.REAL)
- if not np.issubdtype(np.atleast_1d(values).dtype, np.floating):
+ if not np.issubdtype(np.result_type(values), np.floating):
  pvars_cast.add(var)
  if pvars_cast:
  raise_warning(f'JAX gradient compiler requires that initial values '
  f'of p-variables {pvars_cast} be cast to float.')
 
  # overwrite basic operations with fuzzy ones
- self.RELATIONAL_OPS = {
- '>=': logic.greater_equal,
- '<=': logic.less_equal,
- '<': logic.less,
- '>': logic.greater,
- '==': logic.equal,
- '~=': logic.not_equal
- }
- self.LOGICAL_NOT = logic.logical_not
- self.LOGICAL_OPS = {
- '^': logic.logical_and,
- '&': logic.logical_and,
- '|': logic.logical_or,
- '~': logic.xor,
- '=>': logic.implies,
- '<=>': logic.equiv
- }
- self.AGGREGATION_OPS['forall'] = logic.forall
- self.AGGREGATION_OPS['exists'] = logic.exists
- self.AGGREGATION_OPS['argmin'] = logic.argmin
- self.AGGREGATION_OPS['argmax'] = logic.argmax
- self.KNOWN_UNARY['sgn'] = logic.sgn
- self.KNOWN_UNARY['floor'] = logic.floor
- self.KNOWN_UNARY['ceil'] = logic.ceil
- self.KNOWN_UNARY['round'] = logic.round
- self.KNOWN_UNARY['sqrt'] = logic.sqrt
- self.KNOWN_BINARY['div'] = logic.div
- self.KNOWN_BINARY['mod'] = logic.mod
- self.KNOWN_BINARY['fmod'] = logic.mod
- self.IF_HELPER = logic.control_if
- self.SWITCH_HELPER = logic.control_switch
- self.BERNOULLI_HELPER = logic.bernoulli
- self.DISCRETE_HELPER = logic.discrete
- self.POISSON_HELPER = logic.poisson
- self.GEOMETRIC_HELPER = logic.geometric
+ self.OPS = logic.get_operator_dicts()
 
  def _jax_stop_grad(self, jax_expr):
  def _jax_wrapped_stop_grad(x, params, key):
@@ -575,7 +537,7 @@ class JaxStraightLinePlan(JaxPlan):
  def _jax_non_bool_param_to_action(var, param, hyperparams):
  if wrap_non_bool:
  lower, upper = bounds_safe[var]
- mb, ml, mu, mn = [mask.astype(compiled.REAL)
+ mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
  for mask in cond_lists[var]]
  action = (
  mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
@@ -660,7 +622,7 @@ class JaxStraightLinePlan(JaxPlan):
  action = _jax_non_bool_param_to_action(var, action, hyperparams)
  action = jnp.clip(action, *bounds[var])
  if ranges[var] == 'int':
- action = jnp.round(action).astype(compiled.INT)
+ action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
  actions[var] = action
  return actions
 
@@ -961,12 +923,11 @@ class JaxDeepReactivePolicy(JaxPlan):
  non_bool_dims = 0
  for (var, values) in observed_vars.items():
  if ranges[var] != 'bool':
- value_size = np.atleast_1d(values).size
+ value_size = np.size(values)
  if normalize_per_layer and value_size == 1:
  raise_warning(
  f'Cannot apply layer norm to state-fluent <{var}> '
- f'of size 1: setting normalize_per_layer = False.',
- 'red')
+ f'of size 1: setting normalize_per_layer = False.', 'red')
  normalize_per_layer = False
  non_bool_dims += value_size
  if not normalize_per_layer and non_bool_dims == 1:
@@ -990,9 +951,11 @@ class JaxDeepReactivePolicy(JaxPlan):
  else:
  if normalize and normalize_per_layer:
  normalizer = hk.LayerNorm(
- axis=-1, param_axis=-1,
+ axis=-1,
+ param_axis=-1,
  name=f'input_norm_{input_names[var]}',
- **self._normalizer_kwargs)
+ **self._normalizer_kwargs
+ )
  state = normalizer(state)
  states_non_bool.append(state)
  non_bool_dims += state.size
@@ -1001,8 +964,11 @@ class JaxDeepReactivePolicy(JaxPlan):
  # optionally perform layer normalization on the non-bool inputs
  if normalize and not normalize_per_layer and non_bool_dims:
  normalizer = hk.LayerNorm(
- axis=-1, param_axis=-1, name='input_norm',
- **self._normalizer_kwargs)
+ axis=-1,
+ param_axis=-1,
+ name='input_norm',
+ **self._normalizer_kwargs
+ )
  normalized = normalizer(state[:non_bool_dims])
  state = state.at[:non_bool_dims].set(normalized)
  return state
@@ -1021,7 +987,8 @@ class JaxDeepReactivePolicy(JaxPlan):
  actions = {}
  for (var, size) in layer_sizes.items():
  linear = hk.Linear(size, name=layer_names[var], w_init=init)
- reshape = hk.Reshape(output_shape=shapes[var], preserve_dims=-1,
+ reshape = hk.Reshape(output_shape=shapes[var],
+ preserve_dims=-1,
  name=f'reshape_{layer_names[var]}')
  output = reshape(linear(hidden))
  if not shapes[var]:
@@ -1034,7 +1001,7 @@ class JaxDeepReactivePolicy(JaxPlan):
  else:
  if wrap_non_bool:
  lower, upper = bounds_safe[var]
- mb, ml, mu, mn = [mask.astype(compiled.REAL)
+ mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
  for mask in cond_lists[var]]
  action = (
  mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
@@ -1048,8 +1015,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
  # for constraint satisfaction wrap bool actions with softmax
  if use_constraint_satisfaction:
- linear = hk.Linear(
- bool_action_count, name='output_bool', w_init=init)
+ linear = hk.Linear(bool_action_count, name='output_bool', w_init=init)
  output = jax.nn.softmax(linear(hidden))
  actions[bool_key] = output
 
@@ -1087,8 +1053,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
  # test action prediction
  def _jax_wrapped_drp_predict_test(key, params, hyperparams, step, subs):
- actions = _jax_wrapped_drp_predict_train(
- key, params, hyperparams, step, subs)
+ actions = _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs)
  new_actions = {}
  for (var, action) in actions.items():
  prange = ranges[var]
@@ -1096,7 +1061,7 @@ class JaxDeepReactivePolicy(JaxPlan):
  new_action = action > 0.5
  elif prange == 'int':
  action = jnp.clip(action, *bounds[var])
- new_action = jnp.round(action).astype(compiled.INT)
+ new_action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
  else:
  new_action = jnp.clip(action, *bounds[var])
  new_actions[var] = new_action
@@ -1436,8 +1401,8 @@ class GaussianPGPE(PGPE):
  _jax_wrapped_pgpe_grad,
  in_axes=(0, None, None, None, None, None, None)
  )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
- mu_grad = jax.tree_map(partial(jnp.mean, axis=0), mu_grads)
- sigma_grad = jax.tree_map(partial(jnp.mean, axis=0), sigma_grads)
+ mu_grad, sigma_grad = jax.tree_map(
+ partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
  new_r_max = jnp.max(r_maxs)
  return mu_grad, sigma_grad, new_r_max
 
@@ -1463,6 +1428,71 @@ class GaussianPGPE(PGPE):
  self._update = jax.jit(_jax_wrapped_pgpe_update)
 
 
+ # ***********************************************************************
+ # ALL VERSIONS OF RISK FUNCTIONS
+ #
+ # Based on the original paper "A Distributional Framework for Risk-Sensitive
+ # End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
+ #
+ # Original risk functions:
+ # - entropic utility
+ # - mean-variance
+ # - mean-semideviation
+ # - conditional value at risk with straight-through gradient trick
+ #
+ # ***********************************************************************
+
+
+ @jax.jit
+ def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
+ return (-1.0 / beta) * jax.scipy.special.logsumexp(
+ -beta * returns, b=1.0 / returns.size)
+
+
+ @jax.jit
+ def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
+ return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
+
+
+ @jax.jit
+ def mean_deviation_utility(returns: jnp.ndarray, beta: float) -> float:
+ return jnp.mean(returns) - 0.5 * beta * jnp.std(returns)
+
+
+ @jax.jit
+ def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> float:
+ mu = jnp.mean(returns)
+ msd = jnp.sqrt(jnp.mean(jnp.minimum(0.0, returns - mu) ** 2))
+ return mu - 0.5 * beta * msd
+
+
+ @jax.jit
+ def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
+ mu = jnp.mean(returns)
+ msv = jnp.mean(jnp.minimum(0.0, returns - mu) ** 2)
+ return mu - 0.5 * beta * msv
+
+
+ @jax.jit
+ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
+ var = jnp.percentile(returns, q=100 * alpha)
+ mask = returns <= var
+ weights = mask / jnp.maximum(1, jnp.sum(mask))
+ return jnp.sum(returns * weights)
+
+
+ UTILITY_LOOKUP = {
+ 'mean': jnp.mean,
+ 'mean_var': mean_variance_utility,
+ 'mean_std': mean_deviation_utility,
+ 'mean_semivar': mean_semivariance_utility,
+ 'mean_semidev': mean_semideviation_utility,
+ 'entropic': entropic_utility,
+ 'exponential': entropic_utility,
+ 'cvar': cvar_utility
+ }
+
+
  # ***********************************************************************
  # ALL VERSIONS OF JAX PLANNER
  #
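
Version 2.1 moves the risk utilities to module level, adds mean-deviation, mean-semivariance and mean-semideviation variants, and exposes everything through a UTILITY_LOOKUP table. A small usage sketch of the lookup (the import location is assumed from the surrounding planner module, not stated in this diff):

    import jax.numpy as jnp
    # assumed import path: from pyRDDLGym_jax.core.planner import UTILITY_LOOKUP
    returns = jnp.asarray([1.0, -2.0, 0.5, 3.0])           # sampled rollout returns
    risk_neutral = UTILITY_LOOKUP['mean'](returns)          # plain expected return
    risk_averse = UTILITY_LOOKUP['cvar'](returns, 0.05)     # average return over the worst 5% tail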
@@ -1525,8 +1555,7 @@ class JaxBackpropPlanner:
  reward as a form of normalization
  :param utility: how to aggregate return observations to compute utility
  of a policy or plan; must be either a function mapping jax array to a
- scalar, or a a string identifying the utility function by name
- ("mean", "mean_var", "entropic", or "cvar" are currently supported)
+ scalar, or a a string identifying the utility function by name
  :param utility_kwargs: additional keyword arguments to pass hyper-
  parameters to the utility function call
  :param cpfs_without_grad: which CPFs do not have gradients (use straight
@@ -1584,18 +1613,11 @@ class JaxBackpropPlanner:
  # set utility
  if isinstance(utility, str):
  utility = utility.lower()
- if utility == 'mean':
- utility_fn = jnp.mean
- elif utility == 'mean_var':
- utility_fn = mean_variance_utility
- elif utility == 'entropic':
- utility_fn = entropic_utility
- elif utility == 'cvar':
- utility_fn = cvar_utility
- else:
+ utility_fn = UTILITY_LOOKUP.get(utility, None)
+ if utility_fn is None:
  raise RDDLNotImplementedError(
- f'Utility function <{utility}> is not supported: '
- 'must be one of ["mean", "mean_var", "entropic", "cvar"].')
+ f'Utility <{utility}> is not supported, '
+ f'must be one of {list(UTILITY_LOOKUP.keys())}.')
  else:
  utility_fn = utility
  self.utility = utility_fn
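
With the lookup in place, the constructor accepts any of the names listed in UTILITY_LOOKUP and raises RDDLNotImplementedError (now listing the full key set) for anything else. A hedged construction sketch; arguments other than utility and utility_kwargs are assumptions about the usual call pattern, not taken from this diff:

    # hypothetical usage; only utility and utility_kwargs appear in this hunk
    planner = JaxBackpropPlanner(
        rddl=model,                          # assumed: a lifted RDDL planning model
        plan=JaxStraightLinePlan(),
        utility='mean_semidev',              # resolved via UTILITY_LOOKUP
        utility_kwargs={'beta': 0.1})        # forwarded to mean_semideviation_utility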
@@ -1865,7 +1887,7 @@ r"""
  f'{set(self.test_compiled.init_values.keys())}.')
  value = np.reshape(value, newshape=np.shape(init_value))[np.newaxis, ...]
  train_value = np.repeat(value, repeats=n_train, axis=0)
- train_value = train_value.astype(self.compiled.REAL)
+ train_value = np.asarray(train_value, dtype=self.compiled.REAL)
  init_train[name] = train_value
  init_test[name] = np.repeat(value, repeats=n_test, axis=0)
 
@@ -2175,7 +2197,9 @@ r"""
 
  iters = range(epochs)
  if print_progress:
- iters = tqdm(iters, total=100, position=tqdm_position)
+ iters = tqdm(iters, total=100,
+ bar_format='{l_bar}{bar}| {elapsed} {postfix}',
+ position=tqdm_position)
  position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
 
  for it in iters:
@@ -2256,7 +2280,8 @@ r"""
  status = JaxPlannerStatus.ITER_BUDGET_REACHED
 
  # build a callback
- progress_percent = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
+ progress_percent = 100 * min(
+ 1, max(0, elapsed / train_seconds, it / (epochs - 1)))
  callback = {
  'status': status,
  'iteration': it,
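
The progress value reported in the callback is now a float in [0, 100], clamped below at zero and divided by epochs - 1 so the last iteration reports exactly 100 (a later hunk shows the dashboard casting it back to int). A quick numeric check with made-up inputs:

    # illustrative numbers only
    train_seconds, elapsed, epochs, it = 120.0, 30.0, 101, 40
    progress_percent = 100 * min(1, max(0, elapsed / train_seconds, it / (epochs - 1)))
    # = 100 * max(0.25, 0.40) = 40.0    (2.0 truncated this to an int)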
@@ -2279,7 +2304,7 @@ r"""
  'train_log': train_log,
  **test_log
  }
-
+
  # stopping condition reached
  if stopping_rule is not None and stopping_rule.monitor(callback):
  callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
@@ -2290,8 +2315,10 @@ r"""
  iters.set_description(
  f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
  f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
- f'{status.value} status / {total_pgpe_it:6} pgpe'
+ f'{status.value} status / {total_pgpe_it:6} pgpe',
+ refresh=False
  )
+ iters.set_postfix_str(f"{(it + 1) / elapsed:.2f}it/s", refresh=True)
 
  # dash-board
  if dashboard is not None:
@@ -2332,7 +2359,7 @@ r"""
  last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
  print(f'summary of optimization:\n'
  f' status ={status}\n'
- f' time ={elapsed:.6f} sec.\n'
+ f' time ={elapsed:.3f} sec.\n'
  f' iterations ={it}\n'
  f' best objective={-best_loss:.6f}\n'
  f' best grad norm={grad_norm}\n'
@@ -2358,12 +2385,12 @@ r"""
  return termcolor.colored(
  '[FAILURE] no progress was made '
  f'and max grad norm {max_grad_norm:.6f} was zero: '
- 'the solver was likely stuck in a plateau.', 'red')
+ 'solver likely stuck in a plateau.', 'red')
  else:
  return termcolor.colored(
  '[FAILURE] no progress was made '
  f'but max grad norm {max_grad_norm:.6f} was non-zero: '
- 'the learning rate or other hyper-parameters were likely suboptimal.',
+ 'learning rate or other hyper-parameters likely suboptimal.',
  'red')
 
  # model is likely poor IF:
@@ -2372,8 +2399,8 @@ r"""
  return termcolor.colored(
  '[WARNING] progress was made '
  f'but relative train-test error {validation_error:.6f} was high: '
- 'model relaxation around the solution was poor '
- 'or the batch size was too small.', 'yellow')
+ 'poor model relaxation around solution or batch size too small.',
+ 'yellow')
 
  # model likely did not converge IF:
  # 1. the max grad relative to the return is high
@@ -2383,9 +2410,9 @@ r"""
  return termcolor.colored(
  '[WARNING] progress was made '
  f'but max grad norm {max_grad_norm:.6f} was high: '
- 'the solution was likely locally suboptimal, '
- 'or the relaxed model was not smooth around the solution, '
- 'or the batch size was too small.', 'yellow')
+ 'solution locally suboptimal '
+ 'or relaxed model not smooth around solution '
+ 'or batch size too small.', 'yellow')
 
  # likely successful
  return termcolor.colored(
@@ -2412,8 +2439,7 @@ r"""
  for (var, values) in subs.items():
 
  # must not be grounded
- if RDDLPlanningModel.FLUENT_SEP in var \
- or RDDLPlanningModel.OBJECT_SEP in var:
+ if RDDLPlanningModel.FLUENT_SEP in var or RDDLPlanningModel.OBJECT_SEP in var:
  raise ValueError(f'State dictionary passed to the JAX policy is '
  f'grounded, since it contains the key <{var}>, '
  f'but a vectorized environment is required: '
@@ -2421,9 +2447,8 @@ r"""
 
  # must be numeric array
  # exception is for POMDPs at 1st epoch when observ-fluents are None
- dtype = np.atleast_1d(values).dtype
- if not np.issubdtype(dtype, np.number) \
- and not np.issubdtype(dtype, np.bool_):
+ dtype = np.result_type(values)
+ if not np.issubdtype(dtype, np.number) and not np.issubdtype(dtype, np.bool_):
  if step == 0 and var in self.rddl.observ_fluents:
  subs[var] = self.test_compiled.init_values[var]
  else:
@@ -2435,40 +2460,7 @@ r"""
  actions = self.test_policy(key, params, policy_hyperparams, step, subs)
  actions = jax.tree_map(np.asarray, actions)
  return actions
-
-
- # ***********************************************************************
- # ALL VERSIONS OF RISK FUNCTIONS
- #
- # Based on the original paper "A Distributional Framework for Risk-Sensitive
- # End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
- #
- # Original risk functions:
- # - entropic utility
- # - mean-variance approximation
- # - conditional value at risk with straight-through gradient trick
- #
- # ***********************************************************************
-
-
- @jax.jit
- def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
- return (-1.0 / beta) * jax.scipy.special.logsumexp(
- -beta * returns, b=1.0 / returns.size)
-
-
- @jax.jit
- def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
- return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
-
-
- @jax.jit
- def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
- var = jnp.percentile(returns, q=100 * alpha)
- mask = returns <= var
- weights = mask / jnp.maximum(1, jnp.sum(mask))
- return jnp.sum(returns * weights)
-
+
 
  # ***********************************************************************
  # ALL VERSIONS OF CONTROLLERS
@@ -2580,8 +2572,7 @@ class JaxOnlineController(BaseAgent):
  self.callback = callback
  params = callback['best_params']
  self.key, subkey = random.split(self.key)
- actions = planner.get_action(
- subkey, params, 0, state, self.eval_hyperparams)
+ actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
  if self.warm_start:
  self.guess = planner.plan.guess_next_epoch(params)
  return actions
@@ -20,8 +20,7 @@ import math
  import numpy as np
  import time
  import threading
- from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
- import warnings
+ from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
  import webbrowser
 
  # prevent endless console prints
@@ -32,7 +31,7 @@ log.setLevel(logging.ERROR)
  import dash
  from dash.dcc import Interval, Graph, Store
  from dash.dependencies import Input, Output, State, ALL
- from dash.html import Div, B, H4, P, Img, Hr
+ from dash.html import Div, B, H4, P, Hr
  import dash_bootstrap_components as dbc
 
  import plotly.colors as pc
@@ -53,6 +52,7 @@ REWARD_ERROR_DIST_SUBPLOTS = 20
  MODEL_STATE_ERROR_HEIGHT = 300
  POLICY_STATE_VIZ_MAX_HEIGHT = 800
  GP_POSTERIOR_MAX_HEIGHT = 800
+ GP_POSTERIOR_PIXELS = 100
 
  PLOT_AXES_FONT_SIZE = 11
  EXPERIMENT_ENTRY_FONT_SIZE = 14
@@ -1417,7 +1417,7 @@ class JaxPlannerDashboard:
  self.pgpe_return[experiment_id].append(callback['pgpe_return'])
 
  # data for return distributions
- progress = callback['progress']
+ progress = int(callback['progress'])
  if progress - self.return_dist_last_progress[experiment_id] \
  >= PROGRESS_FOR_NEXT_RETURN_DIST:
  self.return_dist_ticks[experiment_id].append(iteration)
@@ -1486,8 +1486,8 @@ class JaxPlannerDashboard:
  if i2 > i1:
 
  # Generate a grid for visualization
- p1_values = np.linspace(*bounds[param1], 100)
- p2_values = np.linspace(*bounds[param2], 100)
+ p1_values = np.linspace(*bounds[param1], GP_POSTERIOR_PIXELS)
+ p2_values = np.linspace(*bounds[param2], GP_POSTERIOR_PIXELS)
  P1, P2 = np.meshgrid(p1_values, p2_values)
 
  # Predict the mean and deviation of the surrogate model
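
The hard-coded grid resolution is replaced by the GP_POSTERIOR_PIXELS constant introduced earlier in this file; the resulting grid is unchanged, as a standalone check with illustrative bounds shows:

    import numpy as np
    GP_POSTERIOR_PIXELS = 100
    p1_values = np.linspace(0.0, 1.0, GP_POSTERIOR_PIXELS)        # illustrative bounds
    p2_values = np.linspace(-5.0, 0.0, GP_POSTERIOR_PIXELS)
    P1, P2 = np.meshgrid(p1_values, p2_values)
    assert P1.shape == (GP_POSTERIOR_PIXELS, GP_POSTERIOR_PIXELS)  # 10,000 GP evaluations per panel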
@@ -1500,8 +1500,7 @@ class JaxPlannerDashboard:
  for p1, p2 in zip(np.ravel(P1), np.ravel(P2)):
  params = {param1: p1, param2: p2}
  params.update(fixed_params)
- param_grid.append(
- [params[key] for key in optimizer.space.keys])
+ param_grid.append([params[key] for key in optimizer.space.keys])
  param_grid = np.asarray(param_grid)
  mean, std = optimizer._gp.predict(param_grid, return_std=True)
  mean = mean.reshape(P1.shape)
@@ -3,7 +3,7 @@ is performed using a batched parallelized Bayesian optimization.
 
  The syntax is:
 
- python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>]
+ python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]
 
  where:
  <domain> is the name of a domain located in the /Examples directory
@@ -15,6 +15,7 @@ where:
  (defaults to 20)
  <workers> is the number of parallel workers (i.e. batch size), which must
  not exceed the number of cores available on the machine (defaults to 4)
+ <dashboard> is whether the dashboard is displayed
  '''
  import os
  import sys
@@ -35,7 +36,7 @@ def power_10(x):
  return 10.0 ** x
 
 
- def main(domain, instance, method, trials=5, iters=20, workers=4):
+ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=False):
 
  # set up the environment
  env = pyRDDLGym.make(domain, instance, vectorized=True)
@@ -48,9 +49,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
 
  # map parameters in the config that will be tuned
  hyperparams = [
- Hyperparameter('MODEL_WEIGHT_TUNE', -1., 5., power_10),
+ Hyperparameter('MODEL_WEIGHT_TUNE', -1., 4., power_10),
  Hyperparameter('POLICY_WEIGHT_TUNE', -2., 2., power_10),
- Hyperparameter('LEARNING_RATE_TUNE', -5., 1., power_10),
+ Hyperparameter('LEARNING_RATE_TUNE', -5., 0., power_10),
  Hyperparameter('LAYER1_TUNE', 1, 8, power_2),
  Hyperparameter('LAYER2_TUNE', 1, 8, power_2),
  Hyperparameter('ROLLOUT_HORIZON_TUNE', 1, min(env.horizon, 100), int)
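
The tuned bounds are given in log10 space and mapped through power_10, so the narrowed ranges correspond to the following actual search intervals (plain arithmetic on the values above):

    power_10 = lambda x: 10.0 ** x                          # same helper defined earlier in this script
    model_weight_range = (power_10(-1.), power_10(4.))      # (0.1, 10000.0); 2.0 searched up to 100000.0
    learning_rate_range = (power_10(-5.), power_10(0.))     # (1e-05, 1.0);   2.0 searched up to 10.0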
@@ -64,7 +65,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
  eval_trials=trials,
  num_workers=workers,
  gp_iters=iters)
- tuning.tune(key=42, log_file=f'gp_{method}_{domain}_{instance}.csv')
+ tuning.tune(key=42,
+ log_file=f'gp_{method}_{domain}_{instance}.csv',
+ show_dashboard=dashboard)
 
  # evaluate the agent on the best parameters
  planner_args, _, train_args = load_config_from_string(tuning.best_config)
@@ -77,7 +80,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
 
  def run_from_args(args):
  if len(args) < 3:
- print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>]')
+ print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]')
  exit(1)
  if args[2] not in ['drp', 'slp', 'replan']:
  print('<method> in [drp, slp, replan]')
@@ -86,6 +89,7 @@ def run_from_args(args):
  if len(args) >= 4: kwargs['trials'] = int(args[3])
  if len(args) >= 5: kwargs['iters'] = int(args[4])
  if len(args) >= 6: kwargs['workers'] = int(args[5])
+ if len(args) >= 7: kwargs['dashboard'] = bool(args[6])
  main(**kwargs)
 
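
One caveat with the new optional argument: bool() of any non-empty string is True, so supplying a seventh command-line argument always enables the dashboard, even if that argument is the literal text False; omit it entirely to keep the default dashboard=False.

    # python run_tune.py <domain> <instance> drp 5 20 4 False
    bool('False')   # -> True, dashboard is still shown
    bool('1')       # -> True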