pyRDDLGym-jax 0.4__py3-none-any.whl → 0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -47,7 +47,6 @@ Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
 Kwargs = Dict[str, Any]
 Pytree = Any
 
-
 # ***********************************************************************
 # CONFIG FILE MANAGEMENT
 #
@@ -57,6 +56,7 @@ Pytree = Any
 #
 # ***********************************************************************
 
+
 def _parse_config_file(path: str):
     if not os.path.isfile(path):
         raise FileNotFoundError(f'File {path} does not exist.')
@@ -103,10 +103,13 @@ def _load_config(config, args):
     compare_kwargs = model_args.get('comparison_kwargs', {})
     sampling_name = model_args.get('sampling', 'GumbelSoftmax')
     sampling_kwargs = model_args.get('sampling_kwargs', {})
+    rounding_name = model_args.get('rounding', 'SoftRounding')
+    rounding_kwargs = model_args.get('rounding_kwargs', {})
     logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
     logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
     logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
     logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
+    logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
 
     # read the policy settings
     plan_method = planner_args.pop('method')
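Note on the new 'rounding' keys: _load_config now resolves a differentiable rounding operator from the logic module in the same way as tnorm, comparison and sampling, defaulting to SoftRounding. A sketch of a config fragment that exercises it, fed through the load_config_from_string helper defined in this module (every key except rounding and rounding_kwargs is illustrative boilerplate, not taken from this diff):

config_string = """
[Model]
logic='FuzzyLogic'
tnorm='ProductTNorm'
sampling='GumbelSoftmax'
rounding='SoftRounding'
rounding_kwargs={}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}

[Training]
key=42
epochs=1000
"""
planner_args, plan_kwargs, train_args = load_config_from_string(config_string)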
@@ -157,11 +160,18 @@ def _load_config(config, args):
     else:
         planner_args['optimizer'] = optimizer
 
-    # read the optimize call settings
+    # optimize call RNG key
     planner_key = train_args.get('key', None)
     if planner_key is not None:
         train_args['key'] = random.PRNGKey(planner_key)
 
+    # optimize call stopping rule
+    stopping_rule = train_args.get('stopping_rule', None)
+    if stopping_rule is not None:
+        stopping_rule_kwargs = train_args.pop('stopping_rule_kwargs', {})
+        train_args['stopping_rule'] = getattr(
+            sys.modules[__name__], stopping_rule)(**stopping_rule_kwargs)
+
     return planner_args, plan_kwargs, train_args
 
 
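Note on the new stopping-rule keys: the training section of a config can now name any stopping rule class defined in this module (see JaxPlannerStoppingRule further below), and _load_config instantiates it from stopping_rule_kwargs via sys.modules[__name__]. A sketch of the corresponding fragment, where only stopping_rule and stopping_rule_kwargs come from this diff and the remaining keys are illustrative:

[Training]
key=42
epochs=10000
train_seconds=60
stopping_rule='NoImprovementStoppingRule'
stopping_rule_kwargs={'patience': 500}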
@@ -175,7 +185,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
     '''Loads config file contents specified explicitly as a string value.'''
     config, args = _parse_config_string(value)
     return _load_config(config, args)
-
 
 # ***********************************************************************
 # MODEL RELAXATIONS
@@ -299,7 +308,6 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         arg, = expr.args
         arg = self._jax(arg, info)
         return arg
-
 
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
@@ -309,6 +317,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
 #
 # ***********************************************************************
 
+
 class JaxPlan:
     '''Base class for all JAX policy representations.'''
 
@@ -363,7 +372,7 @@ class JaxPlan:
         self._projection = value
 
     def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
-                               user_bounds: Bounds,
+                               user_bounds: Bounds,
                                horizon: int):
         shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
         for (name, prange) in compiled.rddl.variable_ranges.items():
@@ -463,7 +472,7 @@ class JaxStraightLinePlan(JaxPlan):
                       f' max_projection_iters ={self._max_constraint_iter}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds: Bounds,
+                _bounds: Bounds,
                 horizon: int) -> None:
         rddl = compiled.rddl
 
@@ -504,7 +513,7 @@ class JaxStraightLinePlan(JaxPlan):
        def _jax_bool_action_to_param(var, action, hyperparams):
            if wrap_sigmoid:
                weight = hyperparams[var]
-               return (-1.0 / weight) * jnp.log(1.0 / action - 1.0)
+               return jax.scipy.special.logit(action) / weight
            else:
                return action
 
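Note on the logit rewrite above: the inverse of the weighted sigmoid is unchanged, since logit(a) = log(a / (1 - a)) = -log(1/a - 1), so jax.scipy.special.logit(action) / weight equals the old (-1.0 / weight) * jnp.log(1.0 / action - 1.0). A small standalone check of the identity (illustrative only, not the package's code):

import jax
import jax.numpy as jnp

def old_inverse(action, weight):
    # 0.4 formulation: invert sigmoid(weight * param) by hand
    return (-1.0 / weight) * jnp.log(1.0 / action - 1.0)

def new_inverse(action, weight):
    # 0.5 formulation: same inverse written with the logit helper
    return jax.scipy.special.logit(action) / weight

a = jnp.linspace(0.01, 0.99, 5)
print(jnp.allclose(old_inverse(a, 3.0), new_inverse(a, 3.0)))  # True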
@@ -513,14 +522,13 @@ class JaxStraightLinePlan(JaxPlan):
        def _jax_non_bool_param_to_action(var, param, hyperparams):
            if wrap_non_bool:
                lower, upper = bounds_safe[var]
-               action = jnp.select(
-                   condlist=cond_lists[var],
-                   choicelist=[
-                       lower + (upper - lower) * jax.nn.sigmoid(param),
-                       lower + (jax.nn.elu(param) + 1.0),
-                       upper - (jax.nn.elu(-param) + 1.0),
-                       param
-                   ]
+               mb, ml, mu, mn = [mask.astype(compiled.REAL)
+                                 for mask in cond_lists[var]]
+               action = (
+                   mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
+                   ml * (lower + (jax.nn.elu(param) + 1.0)) +
+                   mu * (upper - (jax.nn.elu(-param) + 1.0)) +
+                   mn * param
                )
            else:
                action = param
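Note on the masked sum above (the same change appears in JaxDeepReactivePolicy below): assuming the four masks in cond_lists mark the mutually exclusive bound cases (both bounds finite, lower only, upper only, unbounded), multiplying each branch by its mask cast to real and summing selects the same branch that jnp.select picked before. A toy check of that equivalence under the one-hot assumption (not the package's code):

import jax.numpy as jnp

param = jnp.array([-2.0, 0.5, 3.0])
# one mask per bound case; exactly one is True per entry (assumed structure)
masks = [jnp.array([True, False, False]),
         jnp.array([False, True, False]),
         jnp.array([False, False, True]),
         jnp.array([False, False, False])]
choices = [jnp.tanh(param), param + 1.0, param - 1.0, param]

selected = jnp.select(masks, choices)
weighted = sum(m.astype(jnp.float32) * c for m, c in zip(masks, choices))
print(jnp.allclose(selected, weighted))  # True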
@@ -780,7 +788,7 @@ class JaxDeepReactivePolicy(JaxPlan):
     def __init__(self, topology: Optional[Sequence[int]]=None,
                  activation: Activation=jnp.tanh,
                  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
-                 normalize: bool=False,
+                 normalize: bool=False,
                  normalize_per_layer: bool=False,
                  normalizer_kwargs: Optional[Kwargs]=None,
                  wrap_non_bool: bool=False) -> None:
@@ -828,7 +836,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                       f' wrap_non_bool ={self._wrap_non_bool}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds: Bounds,
+                _bounds: Bounds,
                 horizon: int) -> None:
         rddl = compiled.rddl
 
@@ -881,7 +889,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                if normalize_per_layer and value_size == 1:
                    raise_warning(
                        f'Cannot apply layer norm to state-fluent <{var}> '
-                       f'of size 1: setting normalize_per_layer = False.',
+                       f'of size 1: setting normalize_per_layer = False.',
                        'red')
                    normalize_per_layer = False
                non_bool_dims += value_size
@@ -906,8 +914,8 @@ class JaxDeepReactivePolicy(JaxPlan):
                else:
                    if normalize and normalize_per_layer:
                        normalizer = hk.LayerNorm(
-                           axis=-1, param_axis=-1,
-                           name=f'input_norm_{input_names[var]}',
+                           axis=-1, param_axis=-1,
+                           name=f'input_norm_{input_names[var]}',
                            **self._normalizer_kwargs)
                        state = normalizer(state)
                    states_non_bool.append(state)
@@ -917,7 +925,7 @@ class JaxDeepReactivePolicy(JaxPlan):
            # optionally perform layer normalization on the non-bool inputs
            if normalize and not normalize_per_layer and non_bool_dims:
                normalizer = hk.LayerNorm(
-                   axis=-1, param_axis=-1, name='input_norm',
+                   axis=-1, param_axis=-1, name='input_norm',
                    **self._normalizer_kwargs)
                normalized = normalizer(state[:non_bool_dims])
                state = state.at[:non_bool_dims].set(normalized)
@@ -950,14 +958,13 @@ class JaxDeepReactivePolicy(JaxPlan):
            else:
                if wrap_non_bool:
                    lower, upper = bounds_safe[var]
-                   action = jnp.select(
-                       condlist=cond_lists[var],
-                       choicelist=[
-                           lower + (upper - lower) * jax.nn.sigmoid(output),
-                           lower + (jax.nn.elu(output) + 1.0),
-                           upper - (jax.nn.elu(-output) + 1.0),
-                           output
-                       ]
+                   mb, ml, mu, mn = [mask.astype(compiled.REAL)
+                                     for mask in cond_lists[var]]
+                   action = (
+                       mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
+                       ml * (lower + (jax.nn.elu(output) + 1.0)) +
+                       mu * (upper - (jax.nn.elu(-output) + 1.0)) +
+                       mn * output
                    )
                else:
                    action = output
@@ -1049,7 +1056,6 @@ class JaxDeepReactivePolicy(JaxPlan):
 
     def guess_next_epoch(self, params: Pytree) -> Pytree:
         return params
-
 
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANNER
@@ -1059,6 +1065,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 #
 # ***********************************************************************
 
+
 class RollingMean:
     '''Maintains an estimate of the rolling mean of a stream of real-valued
     observations.'''
@@ -1080,7 +1087,7 @@ class RollingMean:
 class JaxPlannerPlot:
     '''Supports plotting and visualization of a JAX policy in real time.'''
 
-    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
+    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
                  show_violin: bool=True, show_action: bool=True) -> None:
         '''Creates a new planner visualizer.
 
@@ -1128,7 +1135,7 @@ class JaxPlannerPlot:
                for dim in rddl.object_counts(rddl.variable_params[name]):
                    action_dim *= dim
                action_plot = ax.pcolormesh(
-                   np.zeros((action_dim, horizon)),
+                   np.zeros((action_dim, horizon)),
                    cmap='seismic', vmin=vmin, vmax=vmax)
                ax.set_aspect('auto')
                ax.set_xlabel('decision epoch')
@@ -1201,6 +1208,39 @@ class JaxPlannerStatus(Enum):
        return self.value >= 3
 
 
+class JaxPlannerStoppingRule:
+    '''The base class of all planner stopping rules.'''
+
+    def reset(self) -> None:
+        raise NotImplementedError
+
+    def monitor(self, callback: Dict[str, Any]) -> bool:
+        raise NotImplementedError
+
+
+class NoImprovementStoppingRule(JaxPlannerStoppingRule):
+    '''Stopping rule based on no improvement for a fixed number of iterations.'''
+
+    def __init__(self, patience: int) -> None:
+        self.patience = patience
+
+    def reset(self) -> None:
+        self.callback = None
+        self.iters_since_last_update = 0
+
+    def monitor(self, callback: Dict[str, Any]) -> bool:
+        if self.callback is None \
+        or callback['best_return'] > self.callback['best_return']:
+            self.callback = callback
+            self.iters_since_last_update = 0
+        else:
+            self.iters_since_last_update += 1
+        return self.iters_since_last_update >= self.patience
+
+    def __str__(self) -> str:
+        return f'No improvement for {self.patience} iterations'
+
+
 class JaxBackpropPlanner:
     '''A class for optimizing an action sequence in the given RDDL MDP using
     gradient descent.'''
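Usage note for the stopping rules added above: the optimize call resets the rule once before training and calls monitor on every yielded callback, terminating as soon as it returns True. A minimal sketch, assuming the import path pyRDDLGym_jax.core.planner and the usual planner construction around an existing planning model (only the stopping_rule argument and the NoImprovementStoppingRule class come from this diff; the rest is assumed boilerplate):

import jax
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxStraightLinePlan, NoImprovementStoppingRule)

# 'model' stands for a pyRDDLGym planning model obtained elsewhere (placeholder)
planner = JaxBackpropPlanner(rddl=model, plan=JaxStraightLinePlan())
stop = NoImprovementStoppingRule(patience=200)

# keyword names key/epochs/train_seconds/stopping_rule appear in the signature below
callback = planner.optimize(key=jax.random.PRNGKey(42),
                            epochs=10000,
                            train_seconds=60.0,
                            stopping_rule=stop)
print(callback['best_return'])  # final callback, as built in the training loop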
@@ -1215,6 +1255,8 @@ class JaxBackpropPlanner:
                 optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
                 optimizer_kwargs: Optional[Kwargs]=None,
                 clip_grad: Optional[float]=None,
+                 noise_grad_eta: float=0.0,
+                 noise_grad_gamma: float=1.0,
                 logic: FuzzyLogic=FuzzyLogic(),
                 use_symlog_reward: bool=False,
                 utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1241,6 +1283,8 @@ class JaxBackpropPlanner:
        :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
        factory (e.g. which parameters are controllable externally)
        :param clip_grad: maximum magnitude of gradient updates
+        :param noise_grad_eta: scale of the gradient noise variance
+        :param noise_grad_gamma: decay rate of the gradient noise variance
        :param logic: a subclass of FuzzyLogic for mapping exact mathematical
        operations to their differentiable counterparts
        :param use_symlog_reward: whether to use the symlog transform on the
@@ -1275,6 +1319,8 @@ class JaxBackpropPlanner:
            optimizer_kwargs = {'learning_rate': 0.1}
        self._optimizer_kwargs = optimizer_kwargs
        self.clip_grad = clip_grad
+        self.noise_grad_eta = noise_grad_eta
+        self.noise_grad_gamma = noise_grad_gamma
 
        # set optimizer
        try:
@@ -1340,14 +1386,14 @@ class JaxBackpropPlanner:
        except Exception as _:
            devices_short = 'N/A'
        LOGO = \
+r"""
+ __ ______ __ __ ______ __ ______ __ __
+/\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
+_\_\ \\ \ __ \\/_/\_\/_\ \ _-/\ \ \____\ \ __ \\ \ \-. \
+/\_____\\ \_\ \_\ /\_\/\_\\ \_\ \ \_____\\ \_\ \_\\ \_\\"\_\
+\/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
 """
- __ ______ __ __ ______ __ ______ __ __
-/\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
-_\_\ \ \ \ __ \ \/_/\_\/_ \ \ _-/\ \ \____ \ \ __ \ \ \ \-. \
-/\_____\ \ \_\ \_\ /\_\/\_\ \ \_\ \ \_____\ \ \_\ \_\ \ \_\\"\_\
-\/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
-"""
-
+
        print('\n'
              f'{LOGO}\n'
              f'Version {__version__}\n'
@@ -1372,6 +1418,8 @@ class JaxBackpropPlanner:
              f' optimizer ={self._optimizer_name.__name__}\n'
              f' optimizer args ={self._optimizer_kwargs}\n'
              f' clip_gradient ={self.clip_grad}\n'
+              f' noise_grad_eta ={self.noise_grad_eta}\n'
+              f' noise_grad_gamma ={self.noise_grad_gamma}\n'
              f' batch_size_train ={self.batch_size_train}\n'
              f' batch_size_test ={self.batch_size_test}')
        self.plan.summarize_hyperparameters()
@@ -1396,7 +1444,7 @@ class JaxBackpropPlanner:
 
        # Jax compilation of the exact RDDL for testing
        self.test_compiled = JaxRDDLCompiler(
-            rddl=rddl,
+            rddl=rddl,
            logger=self.logger,
            use64bit=self.use64bit)
        self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
@@ -1473,7 +1521,7 @@ class JaxBackpropPlanner:
        def _jax_wrapped_init_policy(key, hyperparams, subs):
            policy_params = init(key, hyperparams, subs)
            opt_state = optimizer.init(policy_params)
-            return policy_params, opt_state, None
+            return policy_params, opt_state, {}
 
        return _jax_wrapped_init_policy
 
@@ -1481,6 +1529,19 @@
        optimizer = self.optimizer
        projection = self.plan.projection
 
+        # add Gaussian gradient noise per Neelakantan et al., 2016.
+        def _jax_wrapped_gaussian_param_noise(key, grads, sigma):
+            treedef = jax.tree_util.tree_structure(grads)
+            keys_flat = random.split(key, num=treedef.num_leaves)
+            keys_tree = jax.tree_util.tree_unflatten(treedef, keys_flat)
+            new_grads = jax.tree_map(
+                lambda g, k: g + sigma * random.normal(
+                    key=k, shape=g.shape, dtype=g.dtype),
+                grads,
+                keys_tree
+            )
+            return new_grads
+
        # calculate the plan gradient w.r.t. return loss and update optimizer
        # also perform a projection step to satisfy constraints on actions
        def _jax_wrapped_plan_update(key, policy_params, hyperparams,
@@ -1488,12 +1549,14 @@
            grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
            (loss_val, log), grad = grad_fn(
                key, policy_params, hyperparams, subs, model_params)
+            sigma = opt_aux.get('noise_sigma', 0.0)
+            grad = _jax_wrapped_gaussian_param_noise(key, grad, sigma)
            updates, opt_state = optimizer.update(grad, opt_state)
            policy_params = optax.apply_updates(policy_params, updates)
            policy_params, converged = projection(policy_params, hyperparams)
            log['grad'] = grad
            log['updates'] = updates
-            return policy_params, converged, opt_state, None, loss_val, log
+            return policy_params, converged, opt_state, opt_aux, loss_val, log
 
        return jax.jit(_jax_wrapped_plan_update)
 
@@ -1524,7 +1587,7 @@ class JaxBackpropPlanner:
        return init_train, init_test
 
    def as_optimization_problem(
-            self, key: Optional[random.PRNGKey]=None,
+            self, key: Optional[random.PRNGKey]=None,
            policy_hyperparams: Optional[Pytree]=None,
            loss_function_updates_key: bool=True,
            grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
@@ -1576,7 +1639,7 @@ class JaxBackpropPlanner:
        @jax.jit
        def _loss_with_key(key, params_1d):
            policy_params = unravel_fn(params_1d)
-            loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
+            loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
                                  train_subs, model_params)
            return loss_val
 
@@ -1584,7 +1647,7 @@ class JaxBackpropPlanner:
        def _grad_with_key(key, params_1d):
            policy_params = unravel_fn(params_1d)
            grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
-            grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
+            grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
                                  train_subs, model_params)
            grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
            return grad_1d
@@ -1633,6 +1696,7 @@ class JaxBackpropPlanner:
        :param print_summary: whether to print planner header, parameter
        summary, and diagnosis
        :param print_progress: whether to print the progress bar during training
+        :param stopping_rule: stopping criterion
        :param test_rolling_window: the test return is averaged on a rolling
        window of the past test_rolling_window returns when updating the best
        parameters found so far
@@ -1658,13 +1722,14 @@ class JaxBackpropPlanner:
                 epochs: int=999999,
                 train_seconds: float=120.,
                 plot_step: Optional[int]=None,
-                 plot_kwargs: Optional[Dict[str, Any]]=None,
+                 plot_kwargs: Optional[Kwargs]=None,
                 model_params: Optional[Dict[str, Any]]=None,
                 policy_hyperparams: Optional[Dict[str, Any]]=None,
                 subs: Optional[Dict[str, Any]]=None,
                 guess: Optional[Pytree]=None,
                 print_summary: bool=True,
                 print_progress: bool=True,
+                 stopping_rule: Optional[JaxPlannerStoppingRule]=None,
                 test_rolling_window: int=10,
                 tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
        '''Returns a generator for computing an optimal policy or plan.
@@ -1686,6 +1751,7 @@ class JaxBackpropPlanner:
        :param print_summary: whether to print planner header, parameter
        summary, and diagnosis
        :param print_progress: whether to print the progress bar during training
+        :param stopping_rule: stopping criterion
        :param test_rolling_window: the test return is averaged on a rolling
        window of the past test_rolling_window returns when updating the best
        parameters found so far
@@ -1737,10 +1803,11 @@ class JaxBackpropPlanner:
                  f' plot_frequency ={plot_step}\n'
                  f' plot_kwargs ={plot_kwargs}\n'
                  f' print_summary ={print_summary}\n'
-                  f' print_progress ={print_progress}\n')
+                  f' print_progress ={print_progress}\n'
+                  f' stopping_rule ={stopping_rule}\n')
            if self.compiled.relaxations:
                print('Some RDDL operations are non-differentiable, '
-                      'replacing them with differentiable relaxations:')
+                      'they will be approximated as follows:')
                print(self.compiled.summarize_model_relaxations())
 
        # compute a batched version of the initial values
@@ -1773,7 +1840,7 @@ class JaxBackpropPlanner:
        else:
            policy_params = guess
            opt_state = self.optimizer.init(policy_params)
-            opt_aux = None
+            opt_aux = {}
 
        # initialize running statistics
        best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
@@ -1783,6 +1850,10 @@ class JaxBackpropPlanner:
        status = JaxPlannerStatus.NORMAL
        is_all_zero_fn = lambda x: np.allclose(x, 0)
 
+        # initialize stopping criterion
+        if stopping_rule is not None:
+            stopping_rule.reset()
+
        # initialize plot area
        if plot_step is None or plot_step <= 0 or plt is None:
            plot = None
@@ -1801,6 +1872,11 @@ class JaxBackpropPlanner:
        for it in iters:
            status = JaxPlannerStatus.NORMAL
 
+            # gradient noise schedule
+            noise_var = self.noise_grad_eta / (1. + it) ** self.noise_grad_gamma
+            noise_sigma = np.sqrt(noise_var)
+            opt_aux['noise_sigma'] = noise_sigma
+
            # update the parameters of the plan
            key, subkey = random.split(key)
            policy_params, converged, opt_state, opt_aux, \
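Note on the gradient-noise schedule above: following Neelakantan et al. (2016), iteration t draws zero-mean Gaussian noise with variance noise_grad_eta / (1 + t) ** noise_grad_gamma and adds it to every leaf of the gradient pytree before the optax update. A self-contained sketch of the same idea on a toy problem (the names and constants here are illustrative, not the planner's internals):

import jax
import jax.numpy as jnp
import optax

def noisy_grad(key, grads, sigma):
    # add independent N(0, sigma^2) noise to every leaf of the gradient pytree
    leaves, treedef = jax.tree_util.tree_flatten(grads)
    keys = jax.random.split(key, len(leaves))
    noisy = [g + sigma * jax.random.normal(k, g.shape, g.dtype)
             for g, k in zip(leaves, keys)]
    return jax.tree_util.tree_unflatten(treedef, noisy)

eta, gamma = 0.01, 0.55              # analogues of noise_grad_eta, noise_grad_gamma
params = {'w': jnp.zeros(3)}
opt = optax.rmsprop(0.1)
opt_state = opt.init(params)
loss = lambda p: jnp.sum((p['w'] - 1.0) ** 2)

key = jax.random.PRNGKey(0)
for t in range(100):
    key, subkey = jax.random.split(key)
    sigma = jnp.sqrt(eta / (1.0 + t) ** gamma)   # annealed noise scale
    grads = noisy_grad(subkey, jax.grad(loss)(params), sigma)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)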
@@ -1865,8 +1941,7 @@ class JaxBackpropPlanner:
                status = JaxPlannerStatus.ITER_BUDGET_REACHED
 
            # return a callback
-            start_time_outside = time.time()
-            yield {
+            callback = {
                'status': status,
                'iteration': it,
                'train_return':-train_loss,
@@ -1877,16 +1952,23 @@ class JaxBackpropPlanner:
                'last_iteration_improved': last_iter_improve,
                'grad': train_log['grad'],
                'best_grad': best_grad,
+                'noise_sigma': noise_sigma,
                'updates': train_log['updates'],
                'elapsed_time': elapsed,
                'key': key,
                **log
            }
+            start_time_outside = time.time()
+            yield callback
            elapsed_outside_loop += (time.time() - start_time_outside)
 
            # abortion check
            if status.is_failure():
                break
+
+            # stopping condition reached
+            if stopping_rule is not None and stopping_rule.monitor(callback):
+                break
 
        # release resources
        if print_progress:
@@ -1918,7 +2000,7 @@ class JaxBackpropPlanner:
                  f' best_grad_norm={grad_norm}\n'
                  f' diagnosis: {diagnosis}\n')
 
-    def _perform_diagnosis(self, last_iter_improve,
+    def _perform_diagnosis(self, last_iter_improve,
                           train_return, test_return, best_return, grad_norm):
        max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
        grad_is_zero = np.allclose(max_grad_norm, 0)
@@ -2097,7 +2179,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
                trials += 1
                step *= decay
                f_step, new_params, new_state = _jax_wrapped_line_search_trial(
-                    step, grad, key, policy_params, hyperparams, subs,
+                    step, grad, key, policy_params, hyperparams, subs,
                    model_params, opt_state)
                if f_step < best_f:
                    best_f, best_step, best_params, best_state = \
@@ -2106,11 +2188,11 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
            log['updates'] = None
            log['line_search_iters'] = trials
            log['learning_rate'] = best_step
-            return best_params, True, best_state, best_step, best_f, log
+            opt_aux['best_step'] = best_step
+            return best_params, True, best_state, opt_aux, best_f, log
 
        return _jax_wrapped_plan_update
 
-
 # ***********************************************************************
 # ALL VERSIONS OF RISK FUNCTIONS
 #
@@ -2141,7 +2223,6 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
    alpha_mask = jax.lax.stop_gradient(
        returns <= jnp.percentile(returns, q=100 * alpha))
    return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
-
 
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
@@ -2151,12 +2232,13 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
 #
 # ***********************************************************************
 
+
 class JaxOfflineController(BaseAgent):
     '''A container class for a Jax policy trained offline.'''
 
     use_tensor_obs = True
 
-    def __init__(self, planner: JaxBackpropPlanner,
+    def __init__(self, planner: JaxBackpropPlanner,
                 key: Optional[random.PRNGKey]=None,
                 eval_hyperparams: Optional[Dict[str, Any]]=None,
                 params: Optional[Pytree]=None,
@@ -2211,7 +2293,7 @@ class JaxOnlineController(BaseAgent):
 
     use_tensor_obs = True
 
-    def __init__(self, planner: JaxBackpropPlanner,
+    def __init__(self, planner: JaxBackpropPlanner,
                 key: Optional[random.PRNGKey]=None,
                 eval_hyperparams: Optional[Dict[str, Any]]=None,
                 warm_start: bool=True,