pyRDDLGym-jax 0.3__py3-none-any.whl → 0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,54 +2,50 @@ from ast import literal_eval
  from collections import deque
  import configparser
  from enum import Enum
+ import os
+ import sys
+ import time
+ import traceback
+ from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
+
  import haiku as hk
  import jax
+ import jax.nn.initializers as initializers
  import jax.numpy as jnp
  import jax.random as random
- import jax.nn.initializers as initializers
  import numpy as np
  import optax
- import os
- import sys
  import termcolor
- import time
- import traceback
  from tqdm import tqdm
- from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union

- Activation = Callable[[jnp.ndarray], jnp.ndarray]
- Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
- Kwargs = Dict[str, Any]
- Pytree = Any
-
- from pyRDDLGym.core.debug.exception import raise_warning
-
- from pyRDDLGym_jax import __version__
-
- # try to import matplotlib, if failed then skip plotting
- try:
-     import matplotlib
-     import matplotlib.pyplot as plt
-     matplotlib.use('TkAgg')
- except Exception:
-     raise_warning('failed to import matplotlib: '
-                   'plotting functionality will be disabled.', 'red')
-     traceback.print_exc()
-     plt = None
-
  from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
  from pyRDDLGym.core.debug.logger import Logger
  from pyRDDLGym.core.debug.exception import (
+     raise_warning,
      RDDLNotImplementedError,
      RDDLUndefinedVariableError,
      RDDLTypeError
  )
  from pyRDDLGym.core.policy import BaseAgent

- from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
+ from pyRDDLGym_jax import __version__
  from pyRDDLGym_jax.core import logic
+ from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
  from pyRDDLGym_jax.core.logic import FuzzyLogic

+ # try to import matplotlib, if failed then skip plotting
+ try:
+     import matplotlib.pyplot as plt
+ except Exception:
+     raise_warning('failed to import matplotlib: '
+                   'plotting functionality will be disabled.', 'red')
+     traceback.print_exc()
+     plt = None
+
+ Activation = Callable[[jnp.ndarray], jnp.ndarray]
+ Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
+ Kwargs = Dict[str, Any]
+ Pytree = Any

  # ***********************************************************************
  # CONFIG FILE MANAGEMENT
@@ -60,6 +56,7 @@ from pyRDDLGym_jax.core.logic import FuzzyLogic
  #
  # ***********************************************************************

+
  def _parse_config_file(path: str):
      if not os.path.isfile(path):
          raise FileNotFoundError(f'File {path} does not exist.')
@@ -104,9 +101,15 @@ def _load_config(config, args):
      comp_kwargs = model_args.get('complement_kwargs', {})
      compare_name = model_args.get('comparison', 'SigmoidComparison')
      compare_kwargs = model_args.get('comparison_kwargs', {})
+     sampling_name = model_args.get('sampling', 'GumbelSoftmax')
+     sampling_kwargs = model_args.get('sampling_kwargs', {})
+     rounding_name = model_args.get('rounding', 'SoftRounding')
+     rounding_kwargs = model_args.get('rounding_kwargs', {})
      logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
      logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
      logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
+     logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
+     logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)

      # read the policy settings
      plan_method = planner_args.pop('method')
@@ -157,11 +160,18 @@ def _load_config(config, args):
      else:
          planner_args['optimizer'] = optimizer

-     # read the optimize call settings
+     # optimize call RNG key
      planner_key = train_args.get('key', None)
      if planner_key is not None:
          train_args['key'] = random.PRNGKey(planner_key)

+     # optimize call stopping rule
+     stopping_rule = train_args.get('stopping_rule', None)
+     if stopping_rule is not None:
+         stopping_rule_kwargs = train_args.pop('stopping_rule_kwargs', {})
+         train_args['stopping_rule'] = getattr(
+             sys.modules[__name__], stopping_rule)(**stopping_rule_kwargs)
+
      return planner_args, plan_kwargs, train_args


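The new `sampling`, `rounding`, `stopping_rule` and `stopping_rule_kwargs` keys above are read from the parsed config sections. A minimal sketch of how they might be supplied through `load_config_from_string`; the section names, the import path and the surrounding entries are illustrative assumptions rather than something shown in this diff, so a real config will still need whatever other keys the loader requires:

```python
from pyRDDLGym_jax.core.planner import load_config_from_string  # module path assumed

# hypothetical config fragment: only the quoted key names appear in the diff
CONFIG = """
[Model]
sampling='GumbelSoftmax'
rounding='SoftRounding'

[Optimizer]
method='JaxStraightLinePlan'

[Training]
key=42
stopping_rule='NoImprovementStoppingRule'
stopping_rule_kwargs={'patience': 200}
"""

planner_args, plan_kwargs, train_args = load_config_from_string(CONFIG)
```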
@@ -175,7 +185,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
      '''Loads config file contents specified explicitly as a string value.'''
      config, args = _parse_config_string(value)
      return _load_config(config, args)
-

  # ***********************************************************************
  # MODEL RELAXATIONS
@@ -184,18 +193,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
  #
  # ***********************************************************************

- def _function_discrete_approx_named(logic):
-     jax_discrete, jax_param = logic.discrete()
-
-     def _jax_wrapped_discrete_calc_approx(key, prob, params):
-         sample = jax_discrete(key, prob, params)
-         out_of_bounds = jnp.logical_not(jnp.logical_and(
-             jnp.all(prob >= 0),
-             jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
-         return sample, out_of_bounds
-
-     return _jax_wrapped_discrete_calc_approx, jax_param
-

  class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
      '''Compiles a RDDL AST representation to an equivalent JAX representation.
@@ -271,7 +268,9 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
          self.IF_HELPER = logic.control_if()
          self.SWITCH_HELPER = logic.control_switch()
          self.BERNOULLI_HELPER = logic.bernoulli()
-         self.DISCRETE_HELPER = _function_discrete_approx_named(logic)
+         self.DISCRETE_HELPER = logic.discrete()
+         self.POISSON_HELPER = logic.poisson()
+         self.GEOMETRIC_HELPER = logic.geometric()

      def _jax_stop_grad(self, jax_expr):

@@ -309,7 +308,6 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
          arg, = expr.args
          arg = self._jax(arg, info)
          return arg
-

  # ***********************************************************************
  # ALL VERSIONS OF JAX PLANS
@@ -319,6 +317,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
  #
  # ***********************************************************************

+
  class JaxPlan:
      '''Base class for all JAX policy representations.'''

@@ -373,7 +372,7 @@ class JaxPlan:
          self._projection = value

      def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
-                                user_bounds: Bounds,
+                                user_bounds: Bounds,
                                 horizon: int):
          shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
          for (name, prange) in compiled.rddl.variable_ranges.items():
@@ -469,10 +468,11 @@ class JaxStraightLinePlan(JaxPlan):
                f' wrap_non_bool ={self._wrap_non_bool}\n'
                f'constraint-sat strategy (complex):\n'
                f' wrap_softmax ={self._wrap_softmax}\n'
-               f' use_new_projection ={self._use_new_projection}')
+               f' use_new_projection ={self._use_new_projection}\n'
+               f' max_projection_iters ={self._max_constraint_iter}')

      def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                 _bounds: Bounds,
+                 _bounds: Bounds,
                  horizon: int) -> None:
          rddl = compiled.rddl

@@ -513,7 +513,7 @@ class JaxStraightLinePlan(JaxPlan):
          def _jax_bool_action_to_param(var, action, hyperparams):
              if wrap_sigmoid:
                  weight = hyperparams[var]
-                 return (-1.0 / weight) * jnp.log(1.0 / action - 1.0)
+                 return jax.scipy.special.logit(action) / weight
              else:
                  return action

@@ -522,14 +522,13 @@ class JaxStraightLinePlan(JaxPlan):
          def _jax_non_bool_param_to_action(var, param, hyperparams):
              if wrap_non_bool:
                  lower, upper = bounds_safe[var]
-                 action = jnp.select(
-                     condlist=cond_lists[var],
-                     choicelist=[
-                         lower + (upper - lower) * jax.nn.sigmoid(param),
-                         lower + (jax.nn.elu(param) + 1.0),
-                         upper - (jax.nn.elu(-param) + 1.0),
-                         param
-                     ]
+                 mb, ml, mu, mn = [mask.astype(compiled.REAL)
+                                   for mask in cond_lists[var]]
+                 action = (
+                     mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
+                     ml * (lower + (jax.nn.elu(param) + 1.0)) +
+                     mu * (upper - (jax.nn.elu(-param) + 1.0)) +
+                     mn * param
                  )
              else:
                  action = param
@@ -789,7 +788,7 @@ class JaxDeepReactivePolicy(JaxPlan):
      def __init__(self, topology: Optional[Sequence[int]]=None,
                   activation: Activation=jnp.tanh,
                   initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
-                  normalize: bool=False,
+                  normalize: bool=False,
                   normalize_per_layer: bool=False,
                   normalizer_kwargs: Optional[Kwargs]=None,
                   wrap_non_bool: bool=False) -> None:
@@ -837,7 +836,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                f' wrap_non_bool ={self._wrap_non_bool}')

      def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                 _bounds: Bounds,
+                 _bounds: Bounds,
                  horizon: int) -> None:
          rddl = compiled.rddl

@@ -890,7 +889,7 @@ class JaxDeepReactivePolicy(JaxPlan):
              if normalize_per_layer and value_size == 1:
                  raise_warning(
                      f'Cannot apply layer norm to state-fluent <{var}> '
-                     f'of size 1: setting normalize_per_layer = False.',
+                     f'of size 1: setting normalize_per_layer = False.',
                      'red')
                  normalize_per_layer = False
              non_bool_dims += value_size
@@ -915,8 +914,8 @@ class JaxDeepReactivePolicy(JaxPlan):
                  else:
                      if normalize and normalize_per_layer:
                          normalizer = hk.LayerNorm(
-                             axis=-1, param_axis=-1,
-                             name=f'input_norm_{input_names[var]}',
+                             axis=-1, param_axis=-1,
+                             name=f'input_norm_{input_names[var]}',
                              **self._normalizer_kwargs)
                          state = normalizer(state)
                      states_non_bool.append(state)
@@ -926,7 +925,7 @@ class JaxDeepReactivePolicy(JaxPlan):
              # optionally perform layer normalization on the non-bool inputs
              if normalize and not normalize_per_layer and non_bool_dims:
                  normalizer = hk.LayerNorm(
-                     axis=-1, param_axis=-1, name='input_norm',
+                     axis=-1, param_axis=-1, name='input_norm',
                      **self._normalizer_kwargs)
                  normalized = normalizer(state[:non_bool_dims])
                  state = state.at[:non_bool_dims].set(normalized)
@@ -959,14 +958,13 @@ class JaxDeepReactivePolicy(JaxPlan):
              else:
                  if wrap_non_bool:
                      lower, upper = bounds_safe[var]
-                     action = jnp.select(
-                         condlist=cond_lists[var],
-                         choicelist=[
-                             lower + (upper - lower) * jax.nn.sigmoid(output),
-                             lower + (jax.nn.elu(output) + 1.0),
-                             upper - (jax.nn.elu(-output) + 1.0),
-                             output
-                         ]
+                     mb, ml, mu, mn = [mask.astype(compiled.REAL)
+                                       for mask in cond_lists[var]]
+                     action = (
+                         mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
+                         ml * (lower + (jax.nn.elu(output) + 1.0)) +
+                         mu * (upper - (jax.nn.elu(-output) + 1.0)) +
+                         mn * output
                      )
                  else:
                      action = output
@@ -1058,7 +1056,6 @@ class JaxDeepReactivePolicy(JaxPlan):

      def guess_next_epoch(self, params: Pytree) -> Pytree:
          return params
-

  # ***********************************************************************
  # ALL VERSIONS OF JAX PLANNER
@@ -1068,6 +1065,7 @@ class JaxDeepReactivePolicy(JaxPlan):
  #
  # ***********************************************************************

+
  class RollingMean:
      '''Maintains an estimate of the rolling mean of a stream of real-valued
      observations.'''
@@ -1089,7 +1087,7 @@ class RollingMean:
  class JaxPlannerPlot:
      '''Supports plotting and visualization of a JAX policy in real time.'''

-     def __init__(self, rddl: RDDLPlanningModel, horizon: int,
+     def __init__(self, rddl: RDDLPlanningModel, horizon: int,
                   show_violin: bool=True, show_action: bool=True) -> None:
          '''Creates a new planner visualizer.

@@ -1137,7 +1135,7 @@ class JaxPlannerPlot:
                  for dim in rddl.object_counts(rddl.variable_params[name]):
                      action_dim *= dim
                  action_plot = ax.pcolormesh(
-                     np.zeros((action_dim, horizon)),
+                     np.zeros((action_dim, horizon)),
                      cmap='seismic', vmin=vmin, vmax=vmax)
                  ax.set_aspect('auto')
                  ax.set_xlabel('decision epoch')
@@ -1210,6 +1208,39 @@ class JaxPlannerStatus(Enum):
          return self.value >= 3


+ class JaxPlannerStoppingRule:
+     '''The base class of all planner stopping rules.'''
+
+     def reset(self) -> None:
+         raise NotImplementedError
+
+     def monitor(self, callback: Dict[str, Any]) -> bool:
+         raise NotImplementedError
+
+
+ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
+     '''Stopping rule based on no improvement for a fixed number of iterations.'''
+
+     def __init__(self, patience: int) -> None:
+         self.patience = patience
+
+     def reset(self) -> None:
+         self.callback = None
+         self.iters_since_last_update = 0
+
+     def monitor(self, callback: Dict[str, Any]) -> bool:
+         if self.callback is None \
+         or callback['best_return'] > self.callback['best_return']:
+             self.callback = callback
+             self.iters_since_last_update = 0
+         else:
+             self.iters_since_last_update += 1
+         return self.iters_since_last_update >= self.patience
+
+     def __str__(self) -> str:
+         return f'No improvement for {self.patience} iterations'
+
+
  class JaxBackpropPlanner:
      '''A class for optimizing an action sequence in the given RDDL MDP using
      gradient descent.'''
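A quick sanity sketch of the `monitor` contract added above: the rule stores the first callback, counts iterations whose `best_return` fails to improve on the stored one, and fires once that count reaches `patience`. The import path is an assumption, and the callback dicts are stripped down to the single field the rule reads:

```python
from pyRDDLGym_jax.core.planner import NoImprovementStoppingRule  # module path assumed

rule = NoImprovementStoppingRule(patience=2)
rule.reset()
print(rule.monitor({'best_return': -5.0}))   # False: first callback is recorded
print(rule.monitor({'best_return': -5.0}))   # False: 1 iteration without improvement
print(rule.monitor({'best_return': -5.0}))   # True: 2 >= patience, signal stop
print(rule)                                  # No improvement for 2 iterations
```

A rule built this way can be passed as the new `stopping_rule` argument of the optimize call shown further below, or named through the `stopping_rule` / `stopping_rule_kwargs` config keys handled in `_load_config`.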
@@ -1224,6 +1255,8 @@ class JaxBackpropPlanner:
                   optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
                   optimizer_kwargs: Optional[Kwargs]=None,
                   clip_grad: Optional[float]=None,
+                  noise_grad_eta: float=0.0,
+                  noise_grad_gamma: float=1.0,
                   logic: FuzzyLogic=FuzzyLogic(),
                   use_symlog_reward: bool=False,
                   utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1250,6 +1283,8 @@ class JaxBackpropPlanner:
          :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
          factory (e.g. which parameters are controllable externally)
          :param clip_grad: maximum magnitude of gradient updates
+         :param noise_grad_eta: scale of the gradient noise variance
+         :param noise_grad_gamma: decay rate of the gradient noise variance
          :param logic: a subclass of FuzzyLogic for mapping exact mathematical
          operations to their differentiable counterparts
          :param use_symlog_reward: whether to use the symlog transform on the
@@ -1284,6 +1319,8 @@ class JaxBackpropPlanner:
              optimizer_kwargs = {'learning_rate': 0.1}
          self._optimizer_kwargs = optimizer_kwargs
          self.clip_grad = clip_grad
+         self.noise_grad_eta = noise_grad_eta
+         self.noise_grad_gamma = noise_grad_gamma

          # set optimizer
          try:
@@ -1348,8 +1385,18 @@ class JaxBackpropPlanner:
                  map(str, jax._src.xla_bridge.devices())).replace('\n', '')
          except Exception as _:
              devices_short = 'N/A'
+         LOGO = \
+         r"""
+ __ ______ __ __ ______ __ ______ __ __
+ /\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
+ _\_\ \\ \ __ \\/_/\_\/_\ \ _-/\ \ \____\ \ __ \\ \ \-. \
+ /\_____\\ \_\ \_\ /\_\/\_\\ \_\ \ \_____\\ \_\ \_\\ \_\\"\_\
+ \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
+ """
+
          print('\n'
-               f'JAX Planner version {__version__}\n'
+               f'{LOGO}\n'
+               f'Version {__version__}\n'
                f'Python {sys.version}\n'
                f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
                f'optax {optax.__version__}, haiku {hk.__version__}, '
@@ -1371,6 +1418,8 @@ class JaxBackpropPlanner:
                f' optimizer ={self._optimizer_name.__name__}\n'
                f' optimizer args ={self._optimizer_kwargs}\n'
                f' clip_gradient ={self.clip_grad}\n'
+               f' noise_grad_eta ={self.noise_grad_eta}\n'
+               f' noise_grad_gamma ={self.noise_grad_gamma}\n'
                f' batch_size_train ={self.batch_size_train}\n'
                f' batch_size_test ={self.batch_size_test}')
          self.plan.summarize_hyperparameters()
@@ -1395,7 +1444,7 @@ class JaxBackpropPlanner:

          # Jax compilation of the exact RDDL for testing
          self.test_compiled = JaxRDDLCompiler(
-             rddl=rddl,
+             rddl=rddl,
              logger=self.logger,
              use64bit=self.use64bit)
          self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
@@ -1472,7 +1521,7 @@ class JaxBackpropPlanner:
          def _jax_wrapped_init_policy(key, hyperparams, subs):
              policy_params = init(key, hyperparams, subs)
              opt_state = optimizer.init(policy_params)
-             return policy_params, opt_state, None
+             return policy_params, opt_state, {}

          return _jax_wrapped_init_policy

@@ -1480,6 +1529,19 @@ class JaxBackpropPlanner:
          optimizer = self.optimizer
          projection = self.plan.projection

+         # add Gaussian gradient noise per Neelakantan et al., 2016.
+         def _jax_wrapped_gaussian_param_noise(key, grads, sigma):
+             treedef = jax.tree_util.tree_structure(grads)
+             keys_flat = random.split(key, num=treedef.num_leaves)
+             keys_tree = jax.tree_util.tree_unflatten(treedef, keys_flat)
+             new_grads = jax.tree_map(
+                 lambda g, k: g + sigma * random.normal(
+                     key=k, shape=g.shape, dtype=g.dtype),
+                 grads,
+                 keys_tree
+             )
+             return new_grads
+
          # calculate the plan gradient w.r.t. return loss and update optimizer
          # also perform a projection step to satisfy constraints on actions
          def _jax_wrapped_plan_update(key, policy_params, hyperparams,
@@ -1487,12 +1549,14 @@ class JaxBackpropPlanner:
              grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
              (loss_val, log), grad = grad_fn(
                  key, policy_params, hyperparams, subs, model_params)
+             sigma = opt_aux.get('noise_sigma', 0.0)
+             grad = _jax_wrapped_gaussian_param_noise(key, grad, sigma)
              updates, opt_state = optimizer.update(grad, opt_state)
              policy_params = optax.apply_updates(policy_params, updates)
              policy_params, converged = projection(policy_params, hyperparams)
              log['grad'] = grad
              log['updates'] = updates
-             return policy_params, converged, opt_state, None, loss_val, log
+             return policy_params, converged, opt_state, opt_aux, loss_val, log

          return jax.jit(_jax_wrapped_plan_update)

@@ -1523,7 +1587,7 @@ class JaxBackpropPlanner:
          return init_train, init_test

      def as_optimization_problem(
-             self, key: Optional[random.PRNGKey]=None,
+             self, key: Optional[random.PRNGKey]=None,
              policy_hyperparams: Optional[Pytree]=None,
              loss_function_updates_key: bool=True,
              grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
@@ -1575,7 +1639,7 @@ class JaxBackpropPlanner:
          @jax.jit
          def _loss_with_key(key, params_1d):
              policy_params = unravel_fn(params_1d)
-             loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
+             loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
                                    train_subs, model_params)
              return loss_val

@@ -1583,7 +1647,7 @@ class JaxBackpropPlanner:
          def _grad_with_key(key, params_1d):
              policy_params = unravel_fn(params_1d)
              grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
-             grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
+             grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
                                    train_subs, model_params)
              grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
              return grad_1d
@@ -1632,6 +1696,7 @@ class JaxBackpropPlanner:
          :param print_summary: whether to print planner header, parameter
          summary, and diagnosis
          :param print_progress: whether to print the progress bar during training
+         :param stopping_rule: stopping criterion
          :param test_rolling_window: the test return is averaged on a rolling
          window of the past test_rolling_window returns when updating the best
          parameters found so far
@@ -1657,13 +1722,14 @@ class JaxBackpropPlanner:
                         epochs: int=999999,
                         train_seconds: float=120.,
                         plot_step: Optional[int]=None,
-                        plot_kwargs: Optional[Dict[str, Any]]=None,
+                        plot_kwargs: Optional[Kwargs]=None,
                         model_params: Optional[Dict[str, Any]]=None,
                         policy_hyperparams: Optional[Dict[str, Any]]=None,
                         subs: Optional[Dict[str, Any]]=None,
                         guess: Optional[Pytree]=None,
                         print_summary: bool=True,
                         print_progress: bool=True,
+                        stopping_rule: Optional[JaxPlannerStoppingRule]=None,
                         test_rolling_window: int=10,
                         tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
          '''Returns a generator for computing an optimal policy or plan.
@@ -1685,6 +1751,7 @@ class JaxBackpropPlanner:
          :param print_summary: whether to print planner header, parameter
          summary, and diagnosis
          :param print_progress: whether to print the progress bar during training
+         :param stopping_rule: stopping criterion
          :param test_rolling_window: the test return is averaged on a rolling
          window of the past test_rolling_window returns when updating the best
          parameters found so far
@@ -1711,6 +1778,14 @@ class JaxBackpropPlanner:
              hyperparam_value = float(policy_hyperparams)
              policy_hyperparams = {action: hyperparam_value
                                    for action in self.rddl.action_fluents}
+
+         # fill in missing entries
+         elif isinstance(policy_hyperparams, dict):
+             for action in self.rddl.action_fluents:
+                 if action not in policy_hyperparams:
+                     raise_warning(f'policy_hyperparams[{action}] is not set, '
+                                   'setting 1.0 which could be suboptimal.')
+                     policy_hyperparams[action] = 1.0

          # print summary of parameters:
          if print_summary:
@@ -1728,10 +1803,11 @@ class JaxBackpropPlanner:
                    f' plot_frequency ={plot_step}\n'
                    f' plot_kwargs ={plot_kwargs}\n'
                    f' print_summary ={print_summary}\n'
-                   f' print_progress ={print_progress}\n')
+                   f' print_progress ={print_progress}\n'
+                   f' stopping_rule ={stopping_rule}\n')
          if self.compiled.relaxations:
              print('Some RDDL operations are non-differentiable, '
-                   'replacing them with differentiable relaxations:')
+                   'they will be approximated as follows:')
              print(self.compiled.summarize_model_relaxations())

          # compute a batched version of the initial values
@@ -1764,7 +1840,7 @@ class JaxBackpropPlanner:
          else:
              policy_params = guess
              opt_state = self.optimizer.init(policy_params)
-             opt_aux = None
+             opt_aux = {}

          # initialize running statistics
          best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
@@ -1772,7 +1848,12 @@ class JaxBackpropPlanner:
          rolling_test_loss = RollingMean(test_rolling_window)
          log = {}
          status = JaxPlannerStatus.NORMAL
+         is_all_zero_fn = lambda x: np.allclose(x, 0)

+         # initialize stopping criterion
+         if stopping_rule is not None:
+             stopping_rule.reset()
+
          # initialize plot area
          if plot_step is None or plot_step <= 0 or plt is None:
              plot = None
@@ -1786,10 +1867,16 @@ class JaxBackpropPlanner:
          iters = range(epochs)
          if print_progress:
              iters = tqdm(iters, total=100, position=tqdm_position)
+         position_str = '' if tqdm_position is None else f'[{tqdm_position}]'

          for it in iters:
              status = JaxPlannerStatus.NORMAL

+             # gradient noise schedule
+             noise_var = self.noise_grad_eta / (1. + it) ** self.noise_grad_gamma
+             noise_sigma = np.sqrt(noise_var)
+             opt_aux['noise_sigma'] = noise_sigma
+
              # update the parameters of the plan
              key, subkey = random.split(key)
              policy_params, converged, opt_state, opt_aux, \
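The schedule added above anneals the per-iteration noise scale as sigma_t = sqrt(noise_grad_eta / (1 + t) ** noise_grad_gamma), following the Neelakantan et al. (2016) recipe cited in the code; with the new constructor defaults (noise_grad_eta=0.0, noise_grad_gamma=1.0) the noise is disabled. A standalone sketch of the decay, using example values rather than the defaults:

```python
import numpy as np

noise_grad_eta, noise_grad_gamma = 0.01, 0.55   # example values, not the package defaults
for it in (0, 10, 100, 1000):
    noise_var = noise_grad_eta / (1. + it) ** noise_grad_gamma
    print(it, np.sqrt(noise_var))               # sigma shrinks as training progresses
```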
@@ -1799,7 +1886,7 @@ class JaxBackpropPlanner:

              # no progress
              grad_norm_zero, _ = jax.tree_util.tree_flatten(
-                 jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
+                 jax.tree_map(is_all_zero_fn, train_log['grad']))
              if np.all(grad_norm_zero):
                  status = JaxPlannerStatus.NO_PROGRESS

@@ -1843,8 +1930,9 @@ class JaxBackpropPlanner:
              if print_progress:
                  iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
                  iters.set_description(
-                     f'[{tqdm_position}] {it:6} it / {-train_loss:14.6f} train / '
-                     f'{-test_loss:14.6f} test / {-best_loss:14.6f} best')
+                     f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
+                     f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
+                     f'{status.value} status')

              # reached computation budget
              if elapsed >= train_seconds:
@@ -1853,8 +1941,7 @@ class JaxBackpropPlanner:
                  status = JaxPlannerStatus.ITER_BUDGET_REACHED

              # return a callback
-             start_time_outside = time.time()
-             yield {
+             callback = {
                  'status': status,
                  'iteration': it,
                  'train_return':-train_loss,
@@ -1865,16 +1952,23 @@ class JaxBackpropPlanner:
                  'last_iteration_improved': last_iter_improve,
                  'grad': train_log['grad'],
                  'best_grad': best_grad,
+                 'noise_sigma': noise_sigma,
                  'updates': train_log['updates'],
                  'elapsed_time': elapsed,
                  'key': key,
                  **log
              }
+             start_time_outside = time.time()
+             yield callback
              elapsed_outside_loop += (time.time() - start_time_outside)

              # abortion check
              if status.is_failure():
                  break
+
+             # stopping condition reached
+             if stopping_rule is not None and stopping_rule.monitor(callback):
+                 break

          # release resources
          if print_progress:
@@ -1904,9 +1998,9 @@ class JaxBackpropPlanner:
                f' iterations ={it}\n'
                f' best_objective={-best_loss}\n'
                f' best_grad_norm={grad_norm}\n'
-               f'diagnosis: {diagnosis}\n')
+               f' diagnosis: {diagnosis}\n')

-     def _perform_diagnosis(self, last_iter_improve,
+     def _perform_diagnosis(self, last_iter_improve,
                             train_return, test_return, best_return, grad_norm):
          max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
          grad_is_zero = np.allclose(max_grad_norm, 0)
@@ -2085,7 +2179,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
                  trials += 1
                  step *= decay
                  f_step, new_params, new_state = _jax_wrapped_line_search_trial(
-                     step, grad, key, policy_params, hyperparams, subs,
+                     step, grad, key, policy_params, hyperparams, subs,
                      model_params, opt_state)
                  if f_step < best_f:
                      best_f, best_step, best_params, best_state = \
@@ -2094,11 +2188,11 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
              log['updates'] = None
              log['line_search_iters'] = trials
              log['learning_rate'] = best_step
-             return best_params, True, best_state, best_step, best_f, log
+             opt_aux['best_step'] = best_step
+             return best_params, True, best_state, opt_aux, best_f, log

          return _jax_wrapped_plan_update

-
  # ***********************************************************************
  # ALL VERSIONS OF RISK FUNCTIONS
  #
@@ -2116,7 +2210,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
  @jax.jit
  def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
      return (-1.0 / beta) * jax.scipy.special.logsumexp(
-         -beta * returns, b=1.0 / returns.size)
+         -beta * returns, b=1.0 / returns.size)


  @jax.jit
@@ -2129,7 +2223,6 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
      alpha_mask = jax.lax.stop_gradient(
          returns <= jnp.percentile(returns, q=100 * alpha))
      return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
-

  # ***********************************************************************
  # ALL VERSIONS OF CONTROLLERS
@@ -2139,12 +2232,13 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
  #
  # ***********************************************************************

+
  class JaxOfflineController(BaseAgent):
      '''A container class for a Jax policy trained offline.'''

      use_tensor_obs = True

-     def __init__(self, planner: JaxBackpropPlanner,
+     def __init__(self, planner: JaxBackpropPlanner,
                   key: Optional[random.PRNGKey]=None,
                   eval_hyperparams: Optional[Dict[str, Any]]=None,
                   params: Optional[Pytree]=None,
@@ -2199,7 +2293,7 @@ class JaxOnlineController(BaseAgent):

      use_tensor_obs = True

-     def __init__(self, planner: JaxBackpropPlanner,
+     def __init__(self, planner: JaxBackpropPlanner,
                   key: Optional[random.PRNGKey]=None,
                   eval_hyperparams: Optional[Dict[str, Any]]=None,
                   warm_start: bool=True,