pyRDDLGym-jax 2.3__py3-none-any.whl → 2.4__py3-none-any.whl

@@ -3,7 +3,7 @@
3
3
  #
4
4
  # Author: Michael Gimelfarb
5
5
  #
6
- # RELEVANT SOURCES:
6
+ # REFERENCES:
7
7
  #
8
8
  # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
9
9
  # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
@@ -18,16 +18,21 @@
18
18
  # reactive policies for planning in stochastic nonlinear domains." In Proceedings of the
19
19
  # AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 7530-7537. 2019.
20
20
  #
21
- # [4] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
21
+ # [4] Cui, Hao, Thomas Keller, and Roni Khardon. "Stochastic planning with lifted symbolic
22
+ # trajectory optimization." In Proceedings of the International Conference on Automated
23
+ # Planning and Scheduling, vol. 29, pp. 119-127. 2019.
24
+ #
25
+ # [5] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
22
26
  # nonlinear domains." Advances in Neural Information Processing Systems 30 (2017).
23
27
  #
24
- # [5] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
28
+ # [6] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
25
29
  # policy gradients: Super symmetric pgpe." Artificial Neural Networks: Methods and
26
30
  # Applications in Bio-/Neuroinformatics. Springer International Publishing, 2015.
27
31
  #
28
32
  # ***********************************************************************
29
33
 
30
34
 
35
+ from abc import ABCMeta, abstractmethod
31
36
  from ast import literal_eval
32
37
  from collections import deque
33
38
  import configparser
@@ -37,7 +42,8 @@ import os
37
42
  import sys
38
43
  import time
39
44
  import traceback
40
- from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, Union
45
+ from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, \
46
+ Union
41
47
 
42
48
  import haiku as hk
43
49
  import jax
@@ -51,6 +57,7 @@ from tqdm import tqdm, TqdmWarning
51
57
  import warnings
52
58
  warnings.filterwarnings("ignore", category=TqdmWarning)
53
59
 
60
+ from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
54
61
  from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
55
62
  from pyRDDLGym.core.debug.logger import Logger
56
63
  from pyRDDLGym.core.debug.exception import (
@@ -157,25 +164,20 @@ def _load_config(config, args):
157
164
  initializer = _getattr_any(
158
165
  packages=[initializers, hk.initializers], item=plan_initializer)
159
166
  if initializer is None:
160
- raise_warning(f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
161
- del plan_kwargs['initializer']
167
+ raise ValueError(f'Invalid initializer <{plan_initializer}>.')
162
168
  else:
163
169
  init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
164
170
  try:
165
171
  plan_kwargs['initializer'] = initializer(**init_kwargs)
166
172
  except Exception as _:
167
- raise_warning(
168
- f'Ignoring invalid initializer_kwargs <{init_kwargs}>.', 'red')
169
- plan_kwargs['initializer'] = initializer
173
+ raise ValueError(f'Invalid initializer kwargs <{init_kwargs}>.')
170
174
 
171
175
  # policy activation
172
176
  plan_activation = plan_kwargs.get('activation', None)
173
177
  if plan_activation is not None:
174
- activation = _getattr_any(
175
- packages=[jax.nn, jax.numpy], item=plan_activation)
178
+ activation = _getattr_any(packages=[jax.nn, jax.numpy], item=plan_activation)
176
179
  if activation is None:
177
- raise_warning(f'Ignoring invalid activation <{plan_activation}>.', 'red')
178
- del plan_kwargs['activation']
180
+ raise ValueError(f'Invalid activation <{plan_activation}>.')
179
181
  else:
180
182
  plan_kwargs['activation'] = activation
181
183
 
@@ -188,8 +190,7 @@ def _load_config(config, args):
188
190
  if planner_optimizer is not None:
189
191
  optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
190
192
  if optimizer is None:
191
- raise_warning(f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
192
- del planner_args['optimizer']
193
+ raise ValueError(f'Invalid optimizer <{planner_optimizer}>.')
193
194
  else:
194
195
  planner_args['optimizer'] = optimizer
195
196
 
@@ -200,8 +201,7 @@ def _load_config(config, args):
200
201
  if 'optimizer' in pgpe_kwargs:
201
202
  pgpe_optimizer = _getattr_any(packages=[optax], item=pgpe_kwargs['optimizer'])
202
203
  if pgpe_optimizer is None:
203
- raise_warning(f'Ignoring invalid optimizer <{pgpe_optimizer}>.', 'red')
204
- del pgpe_kwargs['optimizer']
204
+ raise ValueError(f'Invalid optimizer <{pgpe_optimizer}>.')
205
205
  else:
206
206
  pgpe_kwargs['optimizer'] = pgpe_optimizer
207
207
  planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
@@ -260,8 +260,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
260
260
  cpfs_without_grad: Optional[Set[str]]=None,
261
261
  **kwargs) -> None:
262
262
  '''Creates a new RDDL to Jax compiler, where operations that are not
263
- differentiable are converted to approximate forms that have defined
264
- gradients.
263
+ differentiable are converted to approximate forms that have defined gradients.
265
264
 
266
265
  :param *args: arguments to pass to base compiler
267
266
  :param logic: Fuzzy logic object that specifies how exact operations
@@ -286,8 +285,10 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
286
285
  if not np.issubdtype(np.result_type(values), np.floating):
287
286
  pvars_cast.add(var)
288
287
  if pvars_cast:
289
- raise_warning(f'JAX gradient compiler requires that initial values '
290
- f'of p-variables {pvars_cast} be cast to float.')
288
+ message = termcolor.colored(
289
+ f'[INFO] JAX gradient compiler will cast p-vars {pvars_cast} to float.',
290
+ 'green')
291
+ print(message)
291
292
 
292
293
  # overwrite basic operations with fuzzy ones
293
294
  self.OPS = logic.get_operator_dicts()
@@ -300,6 +301,8 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
300
301
  return _jax_wrapped_stop_grad
301
302
 
302
303
  def _compile_cpfs(self, init_params):
304
+
305
+ # cpfs will all be cast to float
303
306
  cpfs_cast = set()
304
307
  jax_cpfs = {}
305
308
  for (_, cpfs) in self.levels.items():
@@ -312,11 +315,15 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
312
315
  jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
313
316
 
314
317
  if cpfs_cast:
315
- raise_warning(f'JAX gradient compiler requires that outputs of CPFs '
316
- f'{cpfs_cast} be cast to float.')
318
+ message = termcolor.colored(
319
+ f'[INFO] JAX gradient compiler will cast CPFs {cpfs_cast} to float.',
320
+ 'green')
321
+ print(message)
317
322
  if self.cpfs_without_grad:
318
- raise_warning(f'User requested that gradients not flow '
319
- f'through CPFs {self.cpfs_without_grad}.')
323
+ message = termcolor.colored(
324
+ f'[INFO] Gradients will not flow through CPFs {self.cpfs_without_grad}.',
325
+ 'green')
326
+ print(message)
320
327
 
321
328
  return jax_cpfs
322
329
 
@@ -335,7 +342,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
335
342
  # ***********************************************************************
336
343
 
337
344
 
338
- class JaxPlan:
345
+ class JaxPlan(metaclass=ABCMeta):
339
346
  '''Base class for all JAX policy representations.'''
340
347
 
341
348
  def __init__(self) -> None:
@@ -345,16 +352,18 @@ class JaxPlan:
345
352
  self._projection = None
346
353
  self.bounds = None
347
354
 
348
- def summarize_hyperparameters(self) -> None:
349
- print(self.__str__())
350
-
355
+ def summarize_hyperparameters(self) -> str:
356
+ return self.__str__()
357
+
358
+ @abstractmethod
351
359
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
352
360
  _bounds: Bounds,
353
361
  horizon: int) -> None:
354
- raise NotImplementedError
362
+ pass
355
363
 
364
+ @abstractmethod
356
365
  def guess_next_epoch(self, params: Pytree) -> Pytree:
357
- raise NotImplementedError
366
+ pass
358
367
 
359
368
  @property
360
369
  def initializer(self):
@@ -397,10 +406,11 @@ class JaxPlan:
397
406
  continue
398
407
 
399
408
  # check invalid type
400
- if prange not in compiled.JAX_TYPES:
409
+ if prange not in compiled.JAX_TYPES and prange not in compiled.rddl.enum_types:
410
+ keys = list(compiled.JAX_TYPES.keys()) + list(compiled.rddl.enum_types)
401
411
  raise RDDLTypeError(
402
412
  f'Invalid range <{prange}> of action-fluent <{name}>, '
403
- f'must be one of {set(compiled.JAX_TYPES.keys())}.')
413
+ f'must be one of {keys}.')
404
414
 
405
415
  # clip boolean to (0, 1), otherwise use the RDDL action bounds
406
416
  # or the user defined action bounds if provided
@@ -408,7 +418,12 @@ class JaxPlan:
408
418
  if prange == 'bool':
409
419
  lower, upper = None, None
410
420
  else:
411
- lower, upper = compiled.constraints.bounds[name]
421
+ if prange in compiled.rddl.enum_types:
422
+ lower = np.zeros(shape=shapes[name][1:])
423
+ upper = len(compiled.rddl.type_to_objects[prange]) - 1
424
+ upper = np.ones(shape=shapes[name][1:]) * upper
425
+ else:
426
+ lower, upper = compiled.constraints.bounds[name]
412
427
  lower, upper = user_bounds.get(name, (lower, upper))
413
428
  lower = np.asarray(lower, dtype=compiled.REAL)
414
429
  upper = np.asarray(upper, dtype=compiled.REAL)
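A small numeric illustration of the enum-valued bounds computed above (illustrative only; the object count and fluent shape here are made up): the relaxed box bounds are just index bounds from 0 to num_objects - 1 broadcast over the fluent shape.

import numpy as np

num_objects = 4                                   # hypothetical len(type_to_objects[prange])
fluent_shape = (2,)                               # hypothetical shapes[name][1:]
lower = np.zeros(shape=fluent_shape)
upper = np.ones(shape=fluent_shape) * (num_objects - 1)
print(lower, upper)                               # [0. 0.] [3. 3.]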
@@ -421,7 +436,10 @@ class JaxPlan:
421
436
  ~lower_finite & upper_finite,
422
437
  ~lower_finite & ~upper_finite]
423
438
  bounds[name] = (lower, upper)
424
- raise_warning(f'Bounds of action-fluent <{name}> set to {bounds[name]}.')
439
+ message = termcolor.colored(
440
+ f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
441
+ 'green')
442
+ print(message)
425
443
  return shapes, bounds, bounds_safe, cond_lists
426
444
 
427
445
  def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -502,10 +520,11 @@ class JaxStraightLinePlan(JaxPlan):
502
520
  bool_action_count, allowed_actions = self._count_bool_actions(rddl)
503
521
  use_constraint_satisfaction = allowed_actions < bool_action_count
504
522
  if use_constraint_satisfaction:
505
- raise_warning(f'Using projected gradient trick to satisfy '
506
- f'max_nondef_actions: total boolean actions '
507
- f'{bool_action_count} > max_nondef_actions '
508
- f'{allowed_actions}.')
523
+ message = termcolor.colored(
524
+ f'[INFO] SLP will use projected gradient to satisfy '
525
+ f'max_nondef_actions since total boolean actions '
526
+ f'{bool_action_count} > max_nondef_actions {allowed_actions}.', 'green')
527
+ print(message)
509
528
 
510
529
  noop = {var: (values[0] if isinstance(values, list) else values)
511
530
  for (var, values) in rddl.action_fluents.items()}
@@ -623,7 +642,7 @@ class JaxStraightLinePlan(JaxPlan):
623
642
  else:
624
643
  action = _jax_non_bool_param_to_action(var, action, hyperparams)
625
644
  action = jnp.clip(action, *bounds[var])
626
- if ranges[var] == 'int':
645
+ if ranges[var] == 'int' or ranges[var] in rddl.enum_types:
627
646
  action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
628
647
  actions[var] = action
629
648
  return actions
@@ -642,7 +661,7 @@ class JaxStraightLinePlan(JaxPlan):
642
661
  # only allow one action non-noop for now
643
662
  if 1 < allowed_actions < bool_action_count:
644
663
  raise RDDLNotImplementedError(
645
- f'Straight-line plans with wrap_softmax currently '
664
+ f'SLPs with wrap_softmax currently '
646
665
  f'do not support max-nondef-actions {allowed_actions} > 1.')
647
666
 
648
667
  # potentially apply projection but to non-bool actions only
@@ -764,7 +783,8 @@ class JaxStraightLinePlan(JaxPlan):
764
783
  for (var, action) in actions.items():
765
784
  if ranges[var] == 'bool':
766
785
  action = jnp.clip(action, min_action, max_action)
767
- new_params[var] = _jax_bool_action_to_param(var, action, hyperparams)
786
+ param = _jax_bool_action_to_param(var, action, hyperparams)
787
+ new_params[var] = param
768
788
  else:
769
789
  new_params[var] = action
770
790
  return new_params, converged
@@ -890,8 +910,7 @@ class JaxDeepReactivePolicy(JaxPlan):
890
910
  bool_action_count, allowed_actions = self._count_bool_actions(rddl)
891
911
  if 1 < allowed_actions < bool_action_count:
892
912
  raise RDDLNotImplementedError(
893
- f'Deep reactive policies currently do not support '
894
- f'max-nondef-actions {allowed_actions} > 1.')
913
+ f'DRPs currently do not support max-nondef-actions {allowed_actions} > 1.')
895
914
  use_constraint_satisfaction = allowed_actions < bool_action_count
896
915
 
897
916
  noop = {var: (values[0] if isinstance(values, list) else values)
@@ -927,15 +946,17 @@ class JaxDeepReactivePolicy(JaxPlan):
927
946
  if ranges[var] != 'bool':
928
947
  value_size = np.size(values)
929
948
  if normalize_per_layer and value_size == 1:
930
- raise_warning(
931
- f'Cannot apply layer norm to state-fluent <{var}> '
932
- f'of size 1: setting normalize_per_layer = False.', 'red')
949
+ message = termcolor.colored(
950
+ f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
951
+ f'of size 1: setting normalize_per_layer = False.', 'yellow')
952
+ print(message)
933
953
  normalize_per_layer = False
934
954
  non_bool_dims += value_size
935
955
  if not normalize_per_layer and non_bool_dims == 1:
936
- raise_warning(
937
- 'Cannot apply layer norm to state-fluents of total size 1: '
938
- 'setting normalize = False.', 'red')
956
+ message = termcolor.colored(
957
+ '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
958
+ 'setting normalize = False.', 'yellow')
959
+ print(message)
939
960
  normalize = False
940
961
 
941
962
  # convert subs dictionary into a state vector to feed to the MLP
@@ -1061,7 +1082,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1061
1082
  prange = ranges[var]
1062
1083
  if prange == 'bool':
1063
1084
  new_action = action > 0.5
1064
- elif prange == 'int':
1085
+ elif prange == 'int' or prange in rddl.enum_types:
1065
1086
  action = jnp.clip(action, *bounds[var])
1066
1087
  new_action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
1067
1088
  else:
@@ -1112,19 +1133,18 @@ class JaxDeepReactivePolicy(JaxPlan):
1112
1133
 
1113
1134
 
1114
1135
  class RollingMean:
1115
- '''Maintains an estimate of the rolling mean of a stream of real-valued
1116
- observations.'''
1136
+ '''Maintains the rolling mean of a stream of real-valued observations.'''
1117
1137
 
1118
1138
  def __init__(self, window_size: int) -> None:
1119
1139
  self._window_size = window_size
1120
1140
  self._memory = deque(maxlen=window_size)
1121
1141
  self._total = 0
1122
1142
 
1123
- def update(self, x: float) -> float:
1143
+ def update(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
1124
1144
  memory = self._memory
1125
- self._total += x
1145
+ self._total = self._total + x
1126
1146
  if len(memory) == self._window_size:
1127
- self._total -= memory.popleft()
1147
+ self._total = self._total - memory.popleft()
1128
1148
  memory.append(x)
1129
1149
  return self._total / len(memory)
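A brief usage sketch of the RollingMean class above (illustrative only, assuming numpy is imported as np): since update now accepts arrays, one instance tracks a separate rolling mean per parallel optimizer run, elementwise.

rm = RollingMean(window_size=3)
print(rm.update(np.array([1.0, 10.0])))   # [ 1. 10.]
print(rm.update(np.array([2.0, 20.0])))   # [ 1.5 15. ]
print(rm.update(np.array([3.0, 30.0])))   # [ 2. 20.]
print(rm.update(np.array([4.0, 40.0])))   # [ 3. 30.]  (oldest entry dropped)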
1130
1150
 
@@ -1147,14 +1167,16 @@ class JaxPlannerStatus(Enum):
1147
1167
  return self.value == 1 or self.value >= 4
1148
1168
 
1149
1169
 
1150
- class JaxPlannerStoppingRule:
1170
+ class JaxPlannerStoppingRule(metaclass=ABCMeta):
1151
1171
  '''The base class of all planner stopping rules.'''
1152
1172
 
1173
+ @abstractmethod
1153
1174
  def reset(self) -> None:
1154
- raise NotImplementedError
1155
-
1175
+ pass
1176
+
1177
+ @abstractmethod
1156
1178
  def monitor(self, callback: Dict[str, Any]) -> bool:
1157
- raise NotImplementedError
1179
+ pass
1158
1180
 
1159
1181
 
1160
1182
  class NoImprovementStoppingRule(JaxPlannerStoppingRule):
@@ -1168,8 +1190,7 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
1168
1190
  self.iters_since_last_update = 0
1169
1191
 
1170
1192
  def monitor(self, callback: Dict[str, Any]) -> bool:
1171
- if self.callback is None \
1172
- or callback['best_return'] > self.callback['best_return']:
1193
+ if self.callback is None or callback['best_return'] > self.callback['best_return']:
1173
1194
  self.callback = callback
1174
1195
  self.iters_since_last_update = 0
1175
1196
  else:
@@ -1188,7 +1209,7 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
1188
1209
  # ***********************************************************************
1189
1210
 
1190
1211
 
1191
- class PGPE:
1212
+ class PGPE(metaclass=ABCMeta):
1192
1213
  """Base class for all PGPE strategies."""
1193
1214
 
1194
1215
  def __init__(self) -> None:
@@ -1203,8 +1224,10 @@ class PGPE:
1203
1224
  def update(self):
1204
1225
  return self._update
1205
1226
 
1206
- def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1207
- raise NotImplementedError
1227
+ @abstractmethod
1228
+ def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
1229
+ parallel_updates: Optional[int]=None) -> None:
1230
+ pass
1208
1231
 
1209
1232
 
1210
1233
  class GaussianPGPE(PGPE):
@@ -1268,10 +1291,11 @@ class GaussianPGPE(PGPE):
1268
1291
  mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
1269
1292
  sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
1270
1293
  except Exception as _:
1271
- raise_warning(
1272
- f'Failed to inject hyperparameters into optax optimizer for PGPE, '
1273
- 'rolling back to safer method: please note that kl-divergence '
1274
- 'constraints will be disabled.', 'red')
1294
+ message = termcolor.colored(
1295
+ '[FAIL] Failed to inject hyperparameters into PGPE optimizer, '
1296
+ 'rolling back to safer method: '
1297
+ 'kl-divergence constraint will be disabled.', 'red')
1298
+ print(message)
1275
1299
  mu_optimizer = optimizer(**optimizer_kwargs_mu)
1276
1300
  sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
1277
1301
  max_kl_update = None
@@ -1297,15 +1321,16 @@ class GaussianPGPE(PGPE):
1297
1321
  f' max_kl_update ={self.max_kl}\n'
1298
1322
  )
1299
1323
 
1300
- def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1324
+ def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
1325
+ parallel_updates: Optional[int]=None) -> None:
1301
1326
  sigma0 = self.init_sigma
1302
- sigma_range = self.sigma_range
1327
+ sigma_lo, sigma_hi = self.sigma_range
1303
1328
  scale_reward = self.scale_reward
1304
1329
  min_reward_scale = self.min_reward_scale
1305
1330
  super_symmetric = self.super_symmetric
1306
1331
  super_symmetric_accurate = self.super_symmetric_accurate
1307
1332
  batch_size = self.batch_size
1308
- optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
1333
+ mu_optimizer, sigma_optimizer = self.optimizers
1309
1334
  max_kl = self.max_kl
1310
1335
 
1311
1336
  # entropy regularization penalty is decayed exponentially by elapsed budget
@@ -1322,13 +1347,22 @@ class GaussianPGPE(PGPE):
1322
1347
 
1323
1348
  def _jax_wrapped_pgpe_init(key, policy_params):
1324
1349
  mu = policy_params
1325
- sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
1350
+ sigma = jax.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
1326
1351
  pgpe_params = (mu, sigma)
1327
- pgpe_opt_state = tuple(opt.init(param)
1328
- for (opt, param) in zip(optimizers, pgpe_params))
1329
- return pgpe_params, pgpe_opt_state
1352
+ pgpe_opt_state = (mu_optimizer.init(mu), sigma_optimizer.init(sigma))
1353
+ r_max = -jnp.inf
1354
+ return pgpe_params, pgpe_opt_state, r_max
1330
1355
 
1331
- self._initializer = jax.jit(_jax_wrapped_pgpe_init)
1356
+ if parallel_updates is None:
1357
+ self._initializer = jax.jit(_jax_wrapped_pgpe_init)
1358
+ else:
1359
+
1360
+ # for parallel policy update
1361
+ def _jax_wrapped_pgpe_inits(key, policy_params):
1362
+ keys = jnp.asarray(random.split(key, num=parallel_updates))
1363
+ return jax.vmap(_jax_wrapped_pgpe_init, in_axes=0)(keys, policy_params)
1364
+
1365
+ self._initializer = jax.jit(_jax_wrapped_pgpe_inits)
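The parallel_updates path above (and the analogous planner code later in this diff) relies on the same vmap-over-split-keys pattern. A minimal, self-contained sketch, with init_single and init_parallel as purely illustrative names:

import jax
import jax.numpy as jnp
from jax import random

def init_single(key):
    # stand-in for a single-run initializer returning one parameter vector
    return random.normal(key, shape=(3,))

def init_parallel(key, num_parallel):
    keys = jnp.asarray(random.split(key, num=num_parallel))
    return jax.vmap(init_single)(keys)    # leading axis indexes the parallel runs

params = init_parallel(random.PRNGKey(0), 4)
print(params.shape)                       # (4, 3)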
1332
1366
 
1333
1367
  # ***********************************************************************
1334
1368
  # PARAMETER SAMPLING FUNCTIONS
@@ -1338,6 +1372,8 @@ class GaussianPGPE(PGPE):
1338
1372
  def _jax_wrapped_mu_noise(key, sigma):
1339
1373
  return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
1340
1374
 
1375
+ # this samples a noise variable epsilon* from epsilon with the N(0, 1) density
1376
+ # according to the super-symmetric sampling paper [6]
1341
1377
  def _jax_wrapped_epsilon_star(sigma, epsilon):
1342
1378
  c1, c2, c3 = -0.06655, -0.9706, 0.124
1343
1379
  phi = 0.67449 * sigma
@@ -1354,6 +1390,7 @@ class GaussianPGPE(PGPE):
1354
1390
  epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
1355
1391
  return epsilon_star
1356
1392
 
1393
+ # implements baseline-free super-symmetric sampling to generate 4 trajectories
1357
1394
  def _jax_wrapped_sample_params(key, mu, sigma):
1358
1395
  treedef = jax.tree_util.tree_structure(sigma)
1359
1396
  keys = random.split(key, num=treedef.num_leaves)
@@ -1374,6 +1411,7 @@ class GaussianPGPE(PGPE):
1374
1411
  #
1375
1412
  # ***********************************************************************
1376
1413
 
1414
+ # gradient with respect to mean
1377
1415
  def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
1378
1416
  if super_symmetric:
1379
1417
  if scale_reward:
@@ -1393,6 +1431,7 @@ class GaussianPGPE(PGPE):
1393
1431
  grad = -r_mu * epsilon
1394
1432
  return grad
1395
1433
 
1434
+ # gradient with respect to std. deviation
1396
1435
  def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
1397
1436
  if super_symmetric:
1398
1437
  mask = r1 + r2 >= r3 + r4
@@ -1413,6 +1452,7 @@ class GaussianPGPE(PGPE):
1413
1452
  grad = -(r_sigma * s + ent / sigma)
1414
1453
  return grad
1415
1454
 
1455
+ # calculate the policy gradients
1416
1456
  def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
1417
1457
  policy_hyperparams, subs, model_params):
1418
1458
  key, subkey = random.split(key)
@@ -1462,11 +1502,24 @@ class GaussianPGPE(PGPE):
1462
1502
  #
1463
1503
  # ***********************************************************************
1464
1504
 
1505
+ # estimate KL divergence between two updates
1465
1506
  def _jax_wrapped_pgpe_kl_term(mu, sigma, old_mu, old_sigma):
1466
1507
  return 0.5 * jnp.sum(2 * jnp.log(sigma / old_sigma) +
1467
1508
  jnp.square(old_sigma / sigma) +
1468
1509
  jnp.square((mu - old_mu) / sigma) - 1)
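For reference, the expression above is the closed-form KL divergence between the old and new diagonal-Gaussian search distributions (old distribution first), summed over parameters:

$$\mathrm{KL}\big(\mathcal{N}(\mu_{\mathrm{old}}, \sigma_{\mathrm{old}}^2) \,\big\|\, \mathcal{N}(\mu, \sigma^2)\big) = \sum_i \left[ \ln\frac{\sigma_i}{\sigma_{\mathrm{old},i}} + \frac{\sigma_{\mathrm{old},i}^2 + (\mu_i - \mu_{\mathrm{old},i})^2}{2\sigma_i^2} - \frac{1}{2} \right]$$

which equals the implemented $\tfrac{1}{2}\sum_i \big( 2\ln(\sigma_i/\sigma_{\mathrm{old},i}) + \sigma_{\mathrm{old},i}^2/\sigma_i^2 + (\mu_i - \mu_{\mathrm{old},i})^2/\sigma_i^2 - 1 \big)$. When this exceeds max_kl, the update logic below rescales both learning rates by $\min(1, \sqrt{\mathrm{max\_kl}/\mathrm{KL}})$ and redoes the step.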
1469
1510
 
1511
+ # update mean and std. deviation with a gradient step
1512
+ def _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
1513
+ mu_state, sigma_state):
1514
+ mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1515
+ sigma_updates, new_sigma_state = sigma_optimizer.update(
1516
+ sigma_grad, sigma_state, params=sigma)
1517
+ new_mu = optax.apply_updates(mu, mu_updates)
1518
+ new_sigma = optax.apply_updates(sigma, sigma_updates)
1519
+ new_sigma = jax.tree_map(
1520
+ partial(jnp.clip, min=sigma_lo, max=sigma_hi), new_sigma)
1521
+ return new_mu, new_sigma, new_mu_state, new_sigma_state
1522
+
1470
1523
  def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
1471
1524
  policy_hyperparams, subs, model_params,
1472
1525
  pgpe_opt_state):
@@ -1476,12 +1529,9 @@ class GaussianPGPE(PGPE):
1476
1529
  ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
1477
1530
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
1478
1531
  key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
1479
- mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1480
- sigma_updates, new_sigma_state = sigma_optimizer.update(
1481
- sigma_grad, sigma_state, params=sigma)
1482
- new_mu = optax.apply_updates(mu, mu_updates)
1483
- new_sigma = optax.apply_updates(sigma, sigma_updates)
1484
- new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
1532
+ new_mu, new_sigma, new_mu_state, new_sigma_state = \
1533
+ _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
1534
+ mu_state, sigma_state)
1485
1535
 
1486
1536
  # respect KL divergence constraint with old parameters
1487
1537
  if max_kl is not None:
@@ -1493,12 +1543,9 @@ class GaussianPGPE(PGPE):
1493
1543
  kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
1494
1544
  mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
1495
1545
  sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
1496
- mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1497
- sigma_updates, new_sigma_state = sigma_optimizer.update(
1498
- sigma_grad, sigma_state, params=sigma)
1499
- new_mu = optax.apply_updates(mu, mu_updates)
1500
- new_sigma = optax.apply_updates(sigma, sigma_updates)
1501
- new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
1546
+ new_mu, new_sigma, new_mu_state, new_sigma_state = \
1547
+ _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
1548
+ mu_state, sigma_state)
1502
1549
  new_mu_state.hyperparams['learning_rate'] = old_mu_lr
1503
1550
  new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr
1504
1551
 
@@ -1509,7 +1556,21 @@ class GaussianPGPE(PGPE):
1509
1556
  policy_params = new_mu
1510
1557
  return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged
1511
1558
 
1512
- self._update = jax.jit(_jax_wrapped_pgpe_update)
1559
+ if parallel_updates is None:
1560
+ self._update = jax.jit(_jax_wrapped_pgpe_update)
1561
+ else:
1562
+
1563
+ # for parallel policy update
1564
+ def _jax_wrapped_pgpe_updates(key, pgpe_params, r_max, progress,
1565
+ policy_hyperparams, subs, model_params,
1566
+ pgpe_opt_state):
1567
+ keys = jnp.asarray(random.split(key, num=parallel_updates))
1568
+ return jax.vmap(
1569
+ _jax_wrapped_pgpe_update, in_axes=(0, 0, 0, None, None, None, 0, 0)
1570
+ )(keys, pgpe_params, r_max, progress, policy_hyperparams, subs,
1571
+ model_params, pgpe_opt_state)
1572
+
1573
+ self._update = jax.jit(_jax_wrapped_pgpe_updates)
1513
1574
 
1514
1575
 
1515
1576
  # ***********************************************************************
@@ -1565,6 +1626,7 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
1565
1626
  return jnp.sum(returns * weights)
1566
1627
 
1567
1628
 
1629
+ # lookup table of all currently valid built-in utility functions
1568
1630
  UTILITY_LOOKUP = {
1569
1631
  'mean': jnp.mean,
1570
1632
  'mean_var': mean_variance_utility,
@@ -1609,7 +1671,8 @@ class JaxBackpropPlanner:
1609
1671
  cpfs_without_grad: Optional[Set[str]]=None,
1610
1672
  compile_non_fluent_exact: bool=True,
1611
1673
  logger: Optional[Logger]=None,
1612
- dashboard_viz: Optional[Any]=None) -> None:
1674
+ dashboard_viz: Optional[Any]=None,
1675
+ parallel_updates: Optional[int]=None) -> None:
1613
1676
  '''Creates a new gradient-based algorithm for optimizing action sequences
1614
1677
  (plan) in the given RDDL. Some operations will be converted to their
1615
1678
  differentiable counterparts; the specific operations can be customized
@@ -1649,6 +1712,7 @@ class JaxBackpropPlanner:
1649
1712
  :param logger: to log information about compilation to file
1650
1713
  :param dashboard_viz: optional visualizer object from the environment
1651
1714
  to pass to the dashboard to visualize the policy
1715
+ :param parallel_updates: how many optimizers to run independently in parallel
1652
1716
  '''
1653
1717
  self.rddl = rddl
1654
1718
  self.plan = plan
@@ -1656,6 +1720,7 @@ class JaxBackpropPlanner:
1656
1720
  if batch_size_test is None:
1657
1721
  batch_size_test = batch_size_train
1658
1722
  self.batch_size_test = batch_size_test
1723
+ self.parallel_updates = parallel_updates
1659
1724
  if rollout_horizon is None:
1660
1725
  rollout_horizon = rddl.horizon
1661
1726
  self.horizon = rollout_horizon
@@ -1677,10 +1742,11 @@ class JaxBackpropPlanner:
1677
1742
  try:
1678
1743
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
1679
1744
  except Exception as _:
1680
- raise_warning(
1681
- f'Failed to inject hyperparameters into optax optimizer {optimizer}, '
1682
- 'rolling back to safer method: please note that modification of '
1683
- 'optimizer hyperparameters will not work.', 'red')
1745
+ message = termcolor.colored(
1746
+ '[FAIL] Failed to inject hyperparameters into JaxPlan optimizer, '
1747
+ 'rolling back to safer method: please note that runtime modification of '
1748
+ 'hyperparameters will be disabled.', 'red')
1749
+ print(message)
1684
1750
  optimizer = optimizer(**optimizer_kwargs)
1685
1751
 
1686
1752
  # apply optimizer chain of transformations
@@ -1700,7 +1766,7 @@ class JaxBackpropPlanner:
1700
1766
  utility_fn = UTILITY_LOOKUP.get(utility, None)
1701
1767
  if utility_fn is None:
1702
1768
  raise RDDLNotImplementedError(
1703
- f'Utility <{utility}> is not supported, '
1769
+ f'Utility function <{utility}> is not supported, '
1704
1770
  f'must be one of {list(UTILITY_LOOKUP.keys())}.')
1705
1771
  else:
1706
1772
  utility_fn = utility
@@ -1742,7 +1808,7 @@ r"""
1742
1808
  \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
1743
1809
  """
1744
1810
 
1745
- return ('\n'
1811
+ return (f'\n'
1746
1812
  f'{LOGO}\n'
1747
1813
  f'Version {__version__}\n'
1748
1814
  f'Python {sys.version}\n'
@@ -1751,7 +1817,23 @@ r"""
1751
1817
  f'numpy {np.__version__}\n'
1752
1818
  f'devices: {devices_short}\n')
1753
1819
 
1754
- def __str__(self) -> str:
1820
+ def summarize_relaxations(self) -> str:
1821
+ result = ''
1822
+ if self.compiled.model_params:
1823
+ result += ('Some RDDL operations are non-differentiable '
1824
+ 'and will be approximated as follows:' + '\n')
1825
+ exprs_by_rddl_op, values_by_rddl_op = {}, {}
1826
+ for info in self.compiled.model_parameter_info().values():
1827
+ rddl_op = info['rddl_op']
1828
+ exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1829
+ values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1830
+ for rddl_op in sorted(exprs_by_rddl_op.keys()):
1831
+ result += (f' {rddl_op}:\n'
1832
+ f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1833
+ f' init_values={values_by_rddl_op[rddl_op]}\n')
1834
+ return result
1835
+
1836
+ def summarize_hyperparameters(self) -> str:
1755
1837
  result = (f'objective hyper-parameters:\n'
1756
1838
  f' utility_fn ={self.utility.__name__}\n'
1757
1839
  f' utility args ={self.utility_kwargs}\n'
@@ -1769,30 +1851,14 @@ r"""
1769
1851
  f' line_search_kwargs={self.line_search_kwargs}\n'
1770
1852
  f' noise_kwargs ={self.noise_kwargs}\n'
1771
1853
  f' batch_size_train ={self.batch_size_train}\n'
1772
- f' batch_size_test ={self.batch_size_test}\n')
1854
+ f' batch_size_test ={self.batch_size_test}\n'
1855
+ f' parallel_updates ={self.parallel_updates}\n')
1773
1856
  result += str(self.plan)
1774
1857
  if self.use_pgpe:
1775
1858
  result += str(self.pgpe)
1776
1859
  result += str(self.logic)
1777
-
1778
- # print model relaxation information
1779
- if self.compiled.model_params:
1780
- result += ('Some RDDL operations are non-differentiable '
1781
- 'and will be approximated as follows:' + '\n')
1782
- exprs_by_rddl_op, values_by_rddl_op = {}, {}
1783
- for info in self.compiled.model_parameter_info().values():
1784
- rddl_op = info['rddl_op']
1785
- exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1786
- values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1787
- for rddl_op in sorted(exprs_by_rddl_op.keys()):
1788
- result += (f' {rddl_op}:\n'
1789
- f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1790
- f' init_values={values_by_rddl_op[rddl_op]}\n')
1791
1860
  return result
1792
1861
 
1793
- def summarize_hyperparameters(self) -> None:
1794
- print(self.__str__())
1795
-
1796
1862
  # ===========================================================================
1797
1863
  # COMPILATION SUBROUTINES
1798
1864
  # ===========================================================================
@@ -1844,23 +1910,31 @@ r"""
1844
1910
  self.test_rollouts = jax.jit(test_rollouts)
1845
1911
 
1846
1912
  # initialization
1847
- self.initialize = jax.jit(self._jax_init())
1913
+ self.initialize, self.init_optimizer = self._jax_init()
1848
1914
 
1849
1915
  # losses
1850
1916
  train_loss = self._jax_loss(train_rollouts, use_symlog=self.use_symlog_reward)
1851
- self.test_loss = jax.jit(self._jax_loss(test_rollouts, use_symlog=False))
1917
+ test_loss = self._jax_loss(test_rollouts, use_symlog=False)
1918
+ if self.parallel_updates is None:
1919
+ self.test_loss = jax.jit(test_loss)
1920
+ else:
1921
+ self.test_loss = jax.jit(jax.vmap(test_loss, in_axes=(None, 0, None, None, 0)))
1852
1922
 
1853
1923
  # optimization
1854
1924
  self.update = self._jax_update(train_loss)
1925
+ self.pytree_at = jax.jit(lambda tree, i: jax.tree_map(lambda x: x[i], tree))
1855
1926
 
1856
1927
  # pgpe option
1857
1928
  if self.use_pgpe:
1858
- loss_fn = self._jax_loss(rollouts=test_rollouts)
1859
1929
  self.pgpe.compile(
1860
- loss_fn=loss_fn,
1930
+ loss_fn=test_loss,
1861
1931
  projection=self.plan.projection,
1862
- real_dtype=self.test_compiled.REAL
1932
+ real_dtype=self.test_compiled.REAL,
1933
+ parallel_updates=self.parallel_updates
1863
1934
  )
1935
+ self.merge_pgpe = self._jax_merge_pgpe_jaxplan()
1936
+ else:
1937
+ self.merge_pgpe = None
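The pytree_at helper defined earlier in this hunk selects a single parallel run's slice from every leaf of a pytree. A short illustrative sketch:

import jax
import jax.numpy as jnp

pytree_at = jax.jit(lambda tree, i: jax.tree_map(lambda x: x[i], tree))
tree = {'w': jnp.arange(6).reshape(3, 2), 'b': jnp.arange(3.0)}
print(pytree_at(tree, 1))                 # run index 1 from each leaf: b -> 1.0, w -> [2 3]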
1864
1938
 
1865
1939
  def _jax_return(self, use_symlog):
1866
1940
  gamma = self.rddl.discount
@@ -1900,24 +1974,43 @@ r"""
1900
1974
  def _jax_init(self):
1901
1975
  init = self.plan.initializer
1902
1976
  optimizer = self.optimizer
1977
+ num_parallel = self.parallel_updates
1903
1978
 
1904
1979
  # initialize both the policy and its optimizer
1905
1980
  def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
1906
1981
  policy_params = init(key, policy_hyperparams, subs)
1907
1982
  opt_state = optimizer.init(policy_params)
1908
- return policy_params, opt_state, {}
1983
+ return policy_params, opt_state, {}
1909
1984
 
1910
- return _jax_wrapped_init_policy
1985
+ # initialize just the optimizer from the policy
1986
+ def _jax_wrapped_init_opt(policy_params):
1987
+ if num_parallel is None:
1988
+ opt_state = optimizer.init(policy_params)
1989
+ else:
1990
+ opt_state = jax.vmap(optimizer.init, in_axes=0)(policy_params)
1991
+ return opt_state, {}
1992
+
1993
+ if num_parallel is None:
1994
+ return jax.jit(_jax_wrapped_init_policy), jax.jit(_jax_wrapped_init_opt)
1995
+
1996
+ # for parallel policy update
1997
+ def _jax_wrapped_init_policies(key, policy_hyperparams, subs):
1998
+ keys = jnp.asarray(random.split(key, num=num_parallel))
1999
+ return jax.vmap(_jax_wrapped_init_policy, in_axes=(0, None, None))(
2000
+ keys, policy_hyperparams, subs)
2001
+
2002
+ return jax.jit(_jax_wrapped_init_policies), jax.jit(_jax_wrapped_init_opt)
1911
2003
 
1912
2004
  def _jax_update(self, loss):
1913
2005
  optimizer = self.optimizer
1914
2006
  projection = self.plan.projection
1915
2007
  use_ls = self.line_search_kwargs is not None
2008
+ num_parallel = self.parallel_updates
1916
2009
 
1917
2010
  # check if the gradients are all zeros
1918
2011
  def _jax_wrapped_zero_gradients(grad):
1919
2012
  leaves, _ = jax.tree_util.tree_flatten(
1920
- jax.tree_map(lambda g: jnp.allclose(g, 0), grad))
2013
+ jax.tree_map(partial(jnp.allclose, b=0), grad))
1921
2014
  return jnp.all(jnp.asarray(leaves))
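An illustrative check mirroring _jax_wrapped_zero_gradients above (not part of the diff): partial(jnp.allclose, b=0) compares every gradient leaf against zero within the default tolerance.

import jax
import jax.numpy as jnp
from functools import partial

grad = {'w': jnp.zeros(3), 'b': jnp.array([0.0, 1e-12])}
flags = jax.tree_map(partial(jnp.allclose, b=0), grad)
leaves, _ = jax.tree_util.tree_flatten(flags)
print(jnp.all(jnp.asarray(leaves)))       # True: 1e-12 is within tolerance of zero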
1922
2015
 
1923
2016
  # calculate the plan gradient w.r.t. return loss and update optimizer
@@ -1948,8 +2041,43 @@ r"""
1948
2041
  return policy_params, converged, opt_state, opt_aux, \
1949
2042
  loss_val, log, model_params, zero_grads
1950
2043
 
1951
- return jax.jit(_jax_wrapped_plan_update)
2044
+ if num_parallel is None:
2045
+ return jax.jit(_jax_wrapped_plan_update)
2046
+
2047
+ # for parallel policy update
2048
+ def _jax_wrapped_plan_updates(key, policy_params, policy_hyperparams,
2049
+ subs, model_params, opt_state, opt_aux):
2050
+ keys = jnp.asarray(random.split(key, num=num_parallel))
2051
+ return jax.vmap(
2052
+ _jax_wrapped_plan_update, in_axes=(0, 0, None, None, 0, 0, 0)
2053
+ )(keys, policy_params, policy_hyperparams, subs, model_params,
2054
+ opt_state, opt_aux)
2055
+
2056
+ return jax.jit(_jax_wrapped_plan_updates)
1952
2057
 
2058
+ def _jax_merge_pgpe_jaxplan(self):
2059
+ if self.parallel_updates is None:
2060
+ return None
2061
+
2062
+ # for parallel policy update
2063
+ # currently implements a hard replacement where the jaxplan parameter
2064
+ # is replaced by the PGPE parameter if the latter is an improvement
2065
+ def _jax_wrapped_pgpe_jaxplan_merge(pgpe_mask, pgpe_param, policy_params,
2066
+ pgpe_loss, test_loss,
2067
+ pgpe_loss_smooth, test_loss_smooth,
2068
+ pgpe_converged, converged):
2069
+ def select_fn(leaf1, leaf2):
2070
+ expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf1) - 1)]
2071
+ return jnp.where(expanded_mask, leaf1, leaf2)
2072
+ policy_params = jax.tree_map(select_fn, pgpe_param, policy_params)
2073
+ test_loss = jnp.where(pgpe_mask, pgpe_loss, test_loss)
2074
+ test_loss_smooth = jnp.where(pgpe_mask, pgpe_loss_smooth, test_loss_smooth)
2075
+ expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(converged) - 1)]
2076
+ converged = jnp.where(expanded_mask, pgpe_converged, converged)
2077
+ return policy_params, test_loss, test_loss_smooth, converged
2078
+
2079
+ return jax.jit(_jax_wrapped_pgpe_jaxplan_merge)
2080
+
1953
2081
  def _batched_init_subs(self, subs):
1954
2082
  rddl = self.rddl
1955
2083
  n_train, n_test = self.batch_size_train, self.batch_size_test
@@ -1968,6 +2096,13 @@ r"""
1968
2096
  train_value = np.asarray(train_value, dtype=self.compiled.REAL)
1969
2097
  init_train[name] = train_value
1970
2098
  init_test[name] = np.repeat(value, repeats=n_test, axis=0)
2099
+
2100
+ # safely cast test subs variable to required type in case the type is wrong
2101
+ if name in rddl.variable_ranges:
2102
+ required_type = RDDLValueInitializer.NUMPY_TYPES.get(
2103
+ rddl.variable_ranges[name], RDDLValueInitializer.INT)
2104
+ if np.result_type(init_test[name]) != required_type:
2105
+ init_test[name] = np.asarray(init_test[name], dtype=required_type)
1971
2106
 
1972
2107
  # make sure next-state fluents are also set
1973
2108
  for (state, next_state) in rddl.next_state.items():
@@ -1975,6 +2110,19 @@ r"""
1975
2110
  init_test[next_state] = init_test[state]
1976
2111
  return init_train, init_test
1977
2112
 
2113
+ def _broadcast_pytree(self, pytree):
2114
+ if self.parallel_updates is None:
2115
+ return pytree
2116
+
2117
+ # for parallel policy update
2118
+ def make_batched(x):
2119
+ x = np.asarray(x)
2120
+ x = np.broadcast_to(
2121
+ x[np.newaxis, ...], shape=(self.parallel_updates,) + np.shape(x))
2122
+ return x
2123
+
2124
+ return jax.tree_map(make_batched, pytree)
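A minimal sketch of what _broadcast_pytree does for, say, parallel_updates = 3 (illustrative parameter names): every leaf gains a leading run axis.

import numpy as np
import jax

params = {'w': np.ones((2, 4)), 'b': np.zeros(4)}

def make_batched(x):
    return np.broadcast_to(
        np.asarray(x)[np.newaxis, ...], shape=(3,) + np.shape(x))

batched = jax.tree_map(make_batched, params)
print(batched['w'].shape, batched['b'].shape)   # (3, 2, 4) (3, 4)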
2125
+
1978
2126
  def as_optimization_problem(
1979
2127
  self, key: Optional[random.PRNGKey]=None,
1980
2128
  policy_hyperparams: Optional[Pytree]=None,
@@ -2002,6 +2150,11 @@ r"""
2002
2150
  :param grad_function_updates_key: if True, the gradient function
2003
2151
  updates the PRNG key internally independently of the loss function.
2004
2152
  '''
2153
+
2154
+ # make sure parallel updates are disabled
2155
+ if self.parallel_updates is not None:
2156
+ raise ValueError('Cannot compile static optimization problem '
2157
+ 'when parallel_updates is not None.')
2005
2158
 
2006
2159
  # if PRNG key is not provided
2007
2160
  if key is None:
@@ -2012,8 +2165,10 @@ r"""
2012
2165
  train_subs, _ = self._batched_init_subs(subs)
2013
2166
  model_params = self.compiled.model_params
2014
2167
  if policy_hyperparams is None:
2015
- raise_warning('policy_hyperparams is not set, setting 1.0 for '
2016
- 'all action-fluents which could be suboptimal.')
2168
+ message = termcolor.colored(
2169
+ '[WARN] policy_hyperparams is not set, setting 1.0 for '
2170
+ 'all action-fluents which could be suboptimal.', 'yellow')
2171
+ print(message)
2017
2172
  policy_hyperparams = {action: 1.0
2018
2173
  for action in self.rddl.action_fluents}
2019
2174
 
@@ -2084,10 +2239,12 @@ r"""
2084
2239
  their values: if None initializes all variables from the RDDL instance
2085
2240
  :param guess: initial policy parameters: if None will use the initializer
2086
2241
  specified in this instance
2087
- :param print_summary: whether to print planner header, parameter
2088
- summary, and diagnosis
2242
+ :param print_summary: whether to print planner header and diagnosis
2089
2243
  :param print_progress: whether to print the progress bar during training
2244
+ :param print_hyperparams: whether to print list of hyper-parameter settings
2090
2245
  :param stopping_rule: stopping criterion
2246
+ :param restart_epochs: restart the optimizer from a random policy configuration
2247
+ if there is no progress for this many consecutive iterations
2091
2248
  :param test_rolling_window: the test return is averaged on a rolling
2092
2249
  window of the past test_rolling_window returns when updating the best
2093
2250
  parameters found so far
@@ -2120,7 +2277,9 @@ r"""
2120
2277
  guess: Optional[Pytree]=None,
2121
2278
  print_summary: bool=True,
2122
2279
  print_progress: bool=True,
2280
+ print_hyperparams: bool=False,
2123
2281
  stopping_rule: Optional[JaxPlannerStoppingRule]=None,
2282
+ restart_epochs: int=999999,
2124
2283
  test_rolling_window: int=10,
2125
2284
  tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
2126
2285
  '''Returns a generator for computing an optimal policy or plan.
@@ -2139,10 +2298,12 @@ r"""
2139
2298
  their values: if None initializes all variables from the RDDL instance
2140
2299
  :param guess: initial policy parameters: if None will use the initializer
2141
2300
  specified in this instance
2142
- :param print_summary: whether to print planner header, parameter
2143
- summary, and diagnosis
2301
+ :param print_summary: whether to print planner header and diagnosis
2144
2302
  :param print_progress: whether to print the progress bar during training
2303
+ :param print_hyperparams: whether to print list of hyper-parameter settings
2145
2304
  :param stopping_rule: stopping criterion
2305
+ :param restart_epochs: restart the optimizer from a random policy configuration
2306
+ if there is no progress for this many consecutive iterations
2146
2307
  :param test_rolling_window: the test return is averaged on a rolling
2147
2308
  window of the past test_rolling_window returns when updating the best
2148
2309
  parameters found so far
@@ -2155,6 +2316,14 @@ r"""
2155
2316
  # INITIALIZATION OF HYPER-PARAMETERS
2156
2317
  # ======================================================================
2157
2318
 
2319
+ # cannot run dashboard with parallel updates
2320
+ if dashboard is not None and self.parallel_updates is not None:
2321
+ message = termcolor.colored(
2322
+ '[WARN] Dashboard is unavailable if parallel_updates is not None: '
2323
+ 'setting dashboard to None.', 'yellow')
2324
+ print(message)
2325
+ dashboard = None
2326
+
2158
2327
  # if PRNG key is not provided
2159
2328
  if key is None:
2160
2329
  key = random.PRNGKey(round(time.time() * 1000))
@@ -2162,15 +2331,19 @@ r"""
2162
2331
 
2163
2332
  # if policy_hyperparams is not provided
2164
2333
  if policy_hyperparams is None:
2165
- raise_warning('policy_hyperparams is not set, setting 1.0 for '
2166
- 'all action-fluents which could be suboptimal.')
2334
+ message = termcolor.colored(
2335
+ '[WARN] policy_hyperparams is not set, setting 1.0 for '
2336
+ 'all action-fluents which could be suboptimal.', 'yellow')
2337
+ print(message)
2167
2338
  policy_hyperparams = {action: 1.0
2168
2339
  for action in self.rddl.action_fluents}
2169
2340
 
2170
2341
  # if policy_hyperparams is a scalar
2171
2342
  elif isinstance(policy_hyperparams, (int, float, np.number)):
2172
- raise_warning(f'policy_hyperparams is {policy_hyperparams}, '
2173
- 'setting this value for all action-fluents.')
2343
+ message = termcolor.colored(
2344
+ f'[INFO] policy_hyperparams is {policy_hyperparams}, '
2345
+ f'setting this value for all action-fluents.', 'green')
2346
+ print(message)
2174
2347
  hyperparam_value = float(policy_hyperparams)
2175
2348
  policy_hyperparams = {action: hyperparam_value
2176
2349
  for action in self.rddl.action_fluents}
@@ -2179,14 +2352,19 @@ r"""
2179
2352
  elif isinstance(policy_hyperparams, dict):
2180
2353
  for action in self.rddl.action_fluents:
2181
2354
  if action not in policy_hyperparams:
2182
- raise_warning(f'policy_hyperparams[{action}] is not set, '
2183
- 'setting 1.0 which could be suboptimal.')
2355
+ message = termcolor.colored(
2356
+ f'[WARN] policy_hyperparams[{action}] is not set, '
2357
+ f'setting 1.0 for missing action-fluents '
2358
+ f'which could be suboptimal.', 'yellow')
2359
+ print(message)
2184
2360
  policy_hyperparams[action] = 1.0
2185
2361
 
2186
2362
  # print summary of parameters:
2187
2363
  if print_summary:
2188
2364
  print(self.summarize_system())
2189
- self.summarize_hyperparameters()
2365
+ print(self.summarize_relaxations())
2366
+ if print_hyperparams:
2367
+ print(self.summarize_hyperparameters())
2190
2368
  print(f'optimize() call hyper-parameters:\n'
2191
2369
  f' PRNG key ={key}\n'
2192
2370
  f' max_iterations ={epochs}\n'
@@ -2200,7 +2378,8 @@ r"""
2200
2378
  f' dashboard_id ={dashboard_id}\n'
2201
2379
  f' print_summary ={print_summary}\n'
2202
2380
  f' print_progress ={print_progress}\n'
2203
- f' stopping_rule ={stopping_rule}\n')
2381
+ f' stopping_rule ={stopping_rule}\n'
2382
+ f' restart_epochs ={restart_epochs}\n')
2204
2383
 
2205
2384
  # ======================================================================
2206
2385
  # INITIALIZATION OF STATE AND POLICY
@@ -2218,15 +2397,17 @@ r"""
2218
2397
  subs[var] = value
2219
2398
  added_pvars_to_subs.append(var)
2220
2399
  if added_pvars_to_subs:
2221
- raise_warning(f'p-variables {added_pvars_to_subs} not in '
2222
- 'provided subs, using their initial values '
2223
- 'from the RDDL files.')
2400
+ message = termcolor.colored(
2401
+ f'[INFO] p-variables {added_pvars_to_subs} are not in '
2402
+ f'provided subs, using their initial values.', 'green')
2403
+ print(message)
2224
2404
  train_subs, test_subs = self._batched_init_subs(subs)
2225
2405
 
2226
2406
  # initialize model parameters
2227
2407
  if model_params is None:
2228
2408
  model_params = self.compiled.model_params
2229
- model_params_test = self.test_compiled.model_params
2409
+ model_params = self._broadcast_pytree(model_params)
2410
+ model_params_test = self._broadcast_pytree(self.test_compiled.model_params)
2230
2411
 
2231
2412
  # initialize policy parameters
2232
2413
  if guess is None:
@@ -2234,29 +2415,31 @@ r"""
2234
2415
  policy_params, opt_state, opt_aux = self.initialize(
2235
2416
  subkey, policy_hyperparams, train_subs)
2236
2417
  else:
2237
- policy_params = guess
2238
- opt_state = self.optimizer.init(policy_params)
2239
- opt_aux = {}
2418
+ policy_params = self._broadcast_pytree(guess)
2419
+ opt_state, opt_aux = self.init_optimizer(policy_params)
2240
2420
 
2241
2421
  # initialize pgpe parameters
2242
2422
  if self.use_pgpe:
2243
- pgpe_params, pgpe_opt_state = self.pgpe.initialize(key, policy_params)
2423
+ pgpe_params, pgpe_opt_state, r_max = self.pgpe.initialize(key, policy_params)
2244
2424
  rolling_pgpe_loss = RollingMean(test_rolling_window)
2245
2425
  else:
2246
- pgpe_params, pgpe_opt_state = None, None
2426
+ pgpe_params, pgpe_opt_state, r_max = None, None, None
2247
2427
  rolling_pgpe_loss = None
2248
2428
  total_pgpe_it = 0
2249
- r_max = -jnp.inf
2250
2429
 
2251
2430
  # ======================================================================
2252
2431
  # INITIALIZATION OF RUNNING STATISTICS
2253
2432
  # ======================================================================
2254
2433
 
2255
2434
  # initialize running statistics
2256
- best_params, best_loss, best_grad = policy_params, jnp.inf, None
2435
+ if self.parallel_updates is None:
2436
+ best_params = policy_params
2437
+ else:
2438
+ best_params = self.pytree_at(policy_params, 0)
2439
+ best_loss, pbest_loss, best_grad = np.inf, np.inf, None
2257
2440
  last_iter_improve = 0
2441
+ no_progress_count = 0
2258
2442
  rolling_test_loss = RollingMean(test_rolling_window)
2259
- log = {}
2260
2443
  status = JaxPlannerStatus.NORMAL
2261
2444
  progress_percent = 0
2262
2445
 
@@ -2277,6 +2460,11 @@ r"""
2277
2460
  else:
2278
2461
  progress_bar = None
2279
2462
  position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
2463
+
2464
+ # error handlers (to avoid spam messaging)
2465
+ policy_constraint_msg_shown = False
2466
+ jax_train_msg_shown = False
2467
+ jax_test_msg_shown = False
2280
2468
 
2281
2469
  # ======================================================================
2282
2470
  # MAIN TRAINING LOOP BEGINS
@@ -2296,8 +2484,13 @@ r"""
2296
2484
  model_params, zero_grads) = self.update(
2297
2485
  subkey, policy_params, policy_hyperparams, train_subs, model_params,
2298
2486
  opt_state, opt_aux)
2487
+
2488
+ # evaluate
2299
2489
  test_loss, (test_log, model_params_test) = self.test_loss(
2300
2490
  subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
2491
+ if self.parallel_updates:
2492
+ train_loss = np.asarray(train_loss)
2493
+ test_loss = np.asarray(test_loss)
2301
2494
  test_loss_smooth = rolling_test_loss.update(test_loss)
2302
2495
 
2303
2496
  # pgpe update of the plan
@@ -2308,52 +2501,112 @@ r"""
2308
2501
  self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
2309
2502
  policy_hyperparams, test_subs, model_params_test,
2310
2503
  pgpe_opt_state)
2504
+
2505
+ # evaluate
2311
2506
  pgpe_loss, _ = self.test_loss(
2312
2507
  subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
2508
+ if self.parallel_updates:
2509
+ pgpe_loss = np.asarray(pgpe_loss)
2313
2510
  pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
2314
2511
  pgpe_return = -pgpe_loss_smooth
2315
2512
 
2316
- # replace with PGPE if it reaches a new minimum or train loss invalid
2317
- if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
2318
- policy_params = pgpe_param
2319
- test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
2320
- converged = pgpe_converged
2321
- pgpe_improve = True
2322
- total_pgpe_it += 1
2513
+ # replace JaxPlan with PGPE if new minimum reached or train loss invalid
2514
+ if self.parallel_updates is None:
2515
+ if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
2516
+ policy_params = pgpe_param
2517
+ test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
2518
+ converged = pgpe_converged
2519
+ pgpe_improve = True
2520
+ total_pgpe_it += 1
2521
+ else:
2522
+ pgpe_mask = (pgpe_loss_smooth < pbest_loss) | ~np.isfinite(train_loss)
2523
+ if np.any(pgpe_mask):
2524
+ policy_params, test_loss, test_loss_smooth, converged = \
2525
+ self.merge_pgpe(pgpe_mask, pgpe_param, policy_params,
2526
+ pgpe_loss, test_loss,
2527
+ pgpe_loss_smooth, test_loss_smooth,
2528
+ pgpe_converged, converged)
2529
+ pgpe_improve = True
2530
+ total_pgpe_it += 1
2323
2531
  else:
2324
2532
  pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
2325
2533
 
2326
- # evaluate test losses and record best plan so far
2327
- if test_loss_smooth < best_loss:
2328
- best_params, best_loss, best_grad = \
2329
- policy_params, test_loss_smooth, train_log['grad']
2330
- last_iter_improve = it
2534
+ # evaluate test losses and record best parameters so far
2535
+ if self.parallel_updates is None:
2536
+ if test_loss_smooth < best_loss:
2537
+ best_params, best_loss, best_grad = \
2538
+ policy_params, test_loss_smooth, train_log['grad']
2539
+ pbest_loss = best_loss
2540
+ else:
2541
+ best_index = np.argmin(test_loss_smooth)
2542
+ if test_loss_smooth[best_index] < best_loss:
2543
+ best_params = self.pytree_at(policy_params, best_index)
2544
+ best_grad = self.pytree_at(train_log['grad'], best_index)
2545
+ best_loss = test_loss_smooth[best_index]
2546
+ pbest_loss = np.minimum(pbest_loss, test_loss_smooth)
2331
2547
 
2332
2548
  # ==================================================================
2333
2549
  # STATUS CHECKS AND LOGGING
2334
2550
  # ==================================================================
2335
2551
 
2336
2552
  # no progress
2337
- if (not pgpe_improve) and zero_grads:
2553
+ no_progress_flag = (not pgpe_improve) and np.all(zero_grads)
2554
+ if no_progress_flag:
2338
2555
  status = JaxPlannerStatus.NO_PROGRESS
2339
-
2556
+
2340
2557
  # constraint satisfaction problem
2341
- if not np.all(converged):
2342
- raise_warning(
2343
- 'Projected gradient method for satisfying action concurrency '
2344
- 'constraints reached the iteration limit: plan is possibly '
2345
- 'invalid for the current instance.', 'red')
2558
+ if not np.all(converged):
2559
+ if progress_bar is not None and not policy_constraint_msg_shown:
2560
+ message = termcolor.colored(
2561
+ '[FAIL] Policy update failed to satisfy action constraints.',
2562
+ 'red')
2563
+ progress_bar.write(message)
2564
+ policy_constraint_msg_shown = True
2346
2565
  status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
2347
2566
 
2348
2567
  # numerical error
2349
2568
  if self.use_pgpe:
2350
- invalid_loss = not (np.isfinite(train_loss) or np.isfinite(pgpe_loss))
2569
+ invalid_loss = not (np.any(np.isfinite(train_loss)) or
2570
+ np.any(np.isfinite(pgpe_loss)))
2351
2571
  else:
2352
- invalid_loss = not np.isfinite(train_loss)
2572
+ invalid_loss = not np.any(np.isfinite(train_loss))
2353
2573
  if invalid_loss:
2354
- raise_warning(f'Planner aborted due to invalid loss {train_loss}.', 'red')
2574
+ if progress_bar is not None:
2575
+ message = termcolor.colored(
2576
+ f'[FAIL] Planner aborted due to invalid train loss {train_loss}.',
2577
+ 'red')
2578
+ progress_bar.write(message)
2355
2579
  status = JaxPlannerStatus.INVALID_GRADIENT
2356
2580
 
2581
+ # problem in the model compilation
2582
+ if progress_bar is not None:
2583
+
2584
+ # train model
2585
+ if not jax_train_msg_shown:
2586
+ messages = set()
2587
+ for error_code in np.unique(train_log['error']):
2588
+ messages.update(JaxRDDLCompiler.get_error_messages(error_code))
2589
+ if messages:
2590
+ messages = '\n '.join(messages)
2591
+ message = termcolor.colored(
2592
+ f'[FAIL] Compiler encountered the following '
2593
+ f'error(s) in the training model:\n {messages}', 'red')
2594
+ progress_bar.write(message)
2595
+ jax_train_msg_shown = True
2596
+
2597
+ # test model
2598
+ if not jax_test_msg_shown:
2599
+ messages = set()
2600
+ for error_code in np.unique(test_log['error']):
2601
+ messages.update(JaxRDDLCompiler.get_error_messages(error_code))
2602
+ if messages:
2603
+ messages = '\n '.join(messages)
2604
+ message = termcolor.colored(
2605
+ f'[FAIL] Compiler encountered the following '
2606
+ f'error(s) in the testing model:\n {messages}', 'red')
2607
+ progress_bar.write(message)
2608
+ jax_test_msg_shown = True
2609
+
2357
2610
  # reached computation budget
2358
2611
  elapsed = time.time() - start_time - elapsed_outside_loop
2359
2612
  if elapsed >= train_seconds:
@@ -2387,20 +2640,39 @@ r"""
2387
2640
  **test_log
2388
2641
  }
2389
2642
 
2643
+ # hard restart
2644
+ if guess is None and no_progress_flag:
2645
+ no_progress_count += 1
2646
+ if no_progress_count > restart_epochs:
2647
+ key, subkey = random.split(key)
2648
+ policy_params, opt_state, opt_aux = self.initialize(
2649
+ subkey, policy_hyperparams, train_subs)
2650
+ no_progress_count = 0
2651
+ if progress_bar is not None:
2652
+ message = termcolor.colored(
2653
+ f'[INFO] Optimizer restarted at iteration {it} '
2654
+ f'due to lack of progress.', 'green')
2655
+ progress_bar.write(message)
2656
+ else:
2657
+ no_progress_count = 0
2658
+
2390
2659
  # stopping condition reached
2391
2660
  if stopping_rule is not None and stopping_rule.monitor(callback):
2661
+ if progress_bar is not None:
2662
+ message = termcolor.colored(
2663
+ '[SUCC] Stopping rule has been reached.', 'green')
2664
+ progress_bar.write(message)
2392
2665
  callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
2393
2666
 
2394
2667
  # if the progress bar is used
2395
2668
  if print_progress:
2396
2669
  progress_bar.set_description(
2397
- f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
2398
- f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
2670
+ f'{position_str} {it:6} it / {-np.min(train_loss):14.5f} train / '
2671
+ f'{-np.min(test_loss_smooth):14.5f} test / {-best_loss:14.5f} best / '
2399
2672
  f'{status.value} status / {total_pgpe_it:6} pgpe',
2400
- refresh=False
2401
- )
2673
+ refresh=False)
2402
2674
  progress_bar.set_postfix_str(
2403
- f"{(it + 1) / (elapsed + 1e-6):.2f}it/s", refresh=False)
2675
+ f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False)
2404
2676
  progress_bar.update(progress_percent - progress_bar.n)
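Since the losses are now arrays over parallel optimizer runs, the progress description reports the best (minimum-loss) entry via np.min. The reporting calls can be reproduced with a bare tqdm bar and dummy numbers:

    import numpy as np
    from tqdm import tqdm

    progress_bar = tqdm(total=100)
    for it in range(100):
        train_loss = np.array([1.0, 0.5, 2.0]) / (it + 1)   # dummy per-chain losses
        progress_bar.set_description(
            f'{it:6} it / {-np.min(train_loss):14.5f} train', refresh=False)
        progress_bar.set_postfix_str(f'{it + 1:.2f}it/s', refresh=False)
        progress_bar.update(1)
    progress_bar.close()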
2405
2677
 
2406
2678
  # dash-board
@@ -2423,24 +2695,15 @@ r"""
2423
2695
  # release resources
2424
2696
  if print_progress:
2425
2697
  progress_bar.close()
2426
-
2427
- # validate the test return
2428
- if log:
2429
- messages = set()
2430
- for error_code in np.unique(log['error']):
2431
- messages.update(JaxRDDLCompiler.get_error_messages(error_code))
2432
- if messages:
2433
- messages = '\n'.join(messages)
2434
- raise_warning('JAX compiler encountered the following '
2435
- 'error(s) in the original RDDL formulation '
2436
- f'during test evaluation:\n{messages}', 'red')
2698
+ print()
2437
2699
 
2438
2700
  # summarize and test for convergence
2439
2701
  if print_summary:
2440
2702
  grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
2441
2703
  diagnosis = self._perform_diagnosis(
2442
- last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
2443
- print(f'summary of optimization:\n'
2704
+ last_iter_improve, -np.min(train_loss), -np.min(test_loss_smooth),
2705
+ -best_loss, grad_norm)
2706
+ print(f'Summary of optimization:\n'
2444
2707
  f' status ={status}\n'
2445
2708
  f' time ={elapsed:.3f} sec.\n'
2446
2709
  f' iterations ={it}\n'
@@ -2453,12 +2716,9 @@ r"""
2453
2716
  max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
2454
2717
  grad_is_zero = np.allclose(max_grad_norm, 0)
2455
2718
 
2456
- validation_error = 100 * abs(test_return - train_return) / \
2457
- max(abs(train_return), abs(test_return))
2458
-
2459
2719
  # divergence if the solution is not finite
2460
2720
  if not np.isfinite(train_return):
2461
- return termcolor.colored('[FAILURE] training loss diverged.', 'red')
2721
+ return termcolor.colored('[FAIL] Training loss diverged.', 'red')
2462
2722
 
2463
2723
  # hit a plateau is likely IF:
2464
2724
  # 1. planner does not improve at all
@@ -2466,23 +2726,25 @@ r"""
2466
2726
  if last_iter_improve <= 1:
2467
2727
  if grad_is_zero:
2468
2728
  return termcolor.colored(
2469
- '[FAILURE] no progress was made '
2729
+ f'[FAIL] No progress was made '
2470
2730
  f'and max grad norm {max_grad_norm:.6f} was zero: '
2471
- 'solver likely stuck in a plateau.', 'red')
2731
+ f'solver likely stuck in a plateau.', 'red')
2472
2732
  else:
2473
2733
  return termcolor.colored(
2474
- '[FAILURE] no progress was made '
2734
+ f'[FAIL] No progress was made '
2475
2735
  f'but max grad norm {max_grad_norm:.6f} was non-zero: '
2476
- 'learning rate or other hyper-parameters likely suboptimal.',
2736
+ f'learning rate or other hyper-parameters could be suboptimal.',
2477
2737
  'red')
2478
2738
 
2479
2739
  # model is likely poor IF:
2480
2740
  # 1. the train and test return disagree
2741
+ validation_error = 100 * abs(test_return - train_return) / \
2742
+ max(abs(train_return), abs(test_return))
2481
2743
  if not (validation_error < 20):
2482
2744
  return termcolor.colored(
2483
- '[WARNING] progress was made '
2745
+ f'[WARN] Progress was made '
2484
2746
  f'but relative train-test error {validation_error:.6f} was high: '
2485
- 'poor model relaxation around solution or batch size too small.',
2747
+ f'poor model relaxation around solution or batch size too small.',
2486
2748
  'yellow')
2487
2749
 
2488
2750
  # model likely did not converge IF:
@@ -2491,24 +2753,22 @@ r"""
2491
2753
  return_to_grad_norm = abs(best_return) / max_grad_norm
2492
2754
  if not (return_to_grad_norm > 1):
2493
2755
  return termcolor.colored(
2494
- '[WARNING] progress was made '
2756
+ f'[WARN] Progress was made '
2495
2757
  f'but max grad norm {max_grad_norm:.6f} was high: '
2496
- 'solution locally suboptimal '
2497
- 'or relaxed model not smooth around solution '
2498
- 'or batch size too small.', 'yellow')
2758
+ f'solution locally suboptimal, relaxed model nonsmooth around solution, '
2759
+ f'or batch size too small.', 'yellow')
2499
2760
 
2500
2761
  # likely successful
2501
2762
  return termcolor.colored(
2502
- '[SUCCESS] solver converged successfully '
2503
- '(note: not all potential problems can be ruled out).', 'green')
2763
+ '[SUCC] Planner converged successfully '
2764
+ '(note: not all problems can be ruled out).', 'green')
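The relabelled diagnosis keeps the same decision chain: divergence, plateau with zero gradient, stalled progress with non-zero gradient, large train-test disagreement, and a small return-to-gradient-norm ratio. The two numeric tests can be checked in isolation; the 20 percent and 1.0 thresholds below mirror the comparisons visible in this hunk:

    import numpy as np

    def diagnose(train_return, test_return, best_return, max_grad_norm):
        # relative disagreement between train and test returns, in percent
        validation_error = (100 * abs(test_return - train_return)
                            / max(abs(train_return), abs(test_return)))
        if not (validation_error < 20):
            return '[WARN] relative train-test error too high'
        # convergence proxy: return magnitude relative to the gradient norm
        return_to_grad_norm = abs(best_return) / max_grad_norm
        if not (return_to_grad_norm > 1):
            return '[WARN] max grad norm still high relative to the return'
        return '[SUCC] no obvious problems detected'

    print(diagnose(train_return=105.0, test_return=98.0,
                   best_return=110.0, max_grad_norm=0.5))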
2504
2765
 
2505
2766
  def get_action(self, key: random.PRNGKey,
2506
2767
  params: Pytree,
2507
2768
  step: int,
2508
2769
  subs: Dict[str, Any],
2509
2770
  policy_hyperparams: Optional[Dict[str, Any]]=None) -> Dict[str, Any]:
2510
- '''Returns an action dictionary from the policy or plan with the given
2511
- parameters.
2771
+ '''Returns an action dictionary from the policy or plan with the given parameters.
2512
2772
 
2513
2773
  :param key: the JAX PRNG key
2514
2774
  :param params: the trainable parameter PyTree of the policy
@@ -2612,8 +2872,7 @@ class JaxOfflineController(BaseAgent):
2612
2872
 
2613
2873
 
2614
2874
  class JaxOnlineController(BaseAgent):
2615
- '''A container class for a Jax controller continuously updated using state
2616
- feedback.'''
2875
+ '''A container class for a Jax controller continuously updated using state feedback.'''
2617
2876
 
2618
2877
  use_tensor_obs = True
2619
2878
 
@@ -2621,17 +2880,19 @@ class JaxOnlineController(BaseAgent):
2621
2880
  key: Optional[random.PRNGKey]=None,
2622
2881
  eval_hyperparams: Optional[Dict[str, Any]]=None,
2623
2882
  warm_start: bool=True,
2883
+ max_attempts: int=3,
2624
2884
  **train_kwargs) -> None:
2625
2885
  '''Creates a new JAX control policy that is trained online in a closed-
2626
2886
  loop fashion.
2627
2887
 
2628
2888
  :param planner: underlying planning algorithm for optimizing actions
2629
- :param key: the RNG key to seed randomness (derives from clock if not
2630
- provided)
2889
+ :param key: the RNG key to seed randomness (derives from clock if not provided)
2631
2890
  :param eval_hyperparams: policy hyperparameters to apply for evaluation
2632
2891
  or whenever sample_action is called
2633
2892
  :param warm_start: whether to use the previous decision epoch final
2634
2893
  policy parameters to warm the next decision epoch
2894
+ :param max_attempts: maximum number of optimizer restarts when the total
2895
+ iteration count is at most 1 (i.e. execution time was dominated by JIT compilation)
2635
2896
  :param **train_kwargs: any keyword arguments to be passed to the planner
2636
2897
  for optimization
2637
2898
  '''
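With the added max_attempts argument, the online controller re-runs the optimizer up to that many times whenever a single call finishes after at most one iteration. A construction sketch following the project's documented controller workflow; the domain, instance, and config file names are placeholders and may need adaptation:

    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import (
        JaxBackpropPlanner, JaxOnlineController, load_config)

    # placeholder domain/instance/config; any vectorized RDDL environment works
    env = pyRDDLGym.make('MyDomain', 'MyInstance', vectorized=True)
    planner_args, _, train_args = load_config('my_config.cfg')

    planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
    controller = JaxOnlineController(
        planner, warm_start=True, max_attempts=3, **train_args)
    controller.evaluate(env, episodes=1, verbose=True, render=False)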
@@ -2642,16 +2903,26 @@ class JaxOnlineController(BaseAgent):
2642
2903
  self.eval_hyperparams = eval_hyperparams
2643
2904
  self.warm_start = warm_start
2644
2905
  self.train_kwargs = train_kwargs
2906
+ self.max_attempts = max_attempts
2645
2907
  self.reset()
2646
2908
 
2647
2909
  def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
2648
2910
  planner = self.planner
2649
2911
  callback = planner.optimize(
2650
- key=self.key,
2651
- guess=self.guess,
2652
- subs=state,
2653
- **self.train_kwargs
2654
- )
2912
+ key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
2913
+
2914
+ # optimize again if jit compilation takes up the entire time budget
2915
+ attempts = 0
2916
+ while attempts < self.max_attempts and callback['iteration'] <= 1:
2917
+ attempts += 1
2918
+ message = termcolor.colored(
2919
+ f'[WARN] JIT compilation dominated the execution time: '
2920
+ f'executing the optimizer again on the traced model [attempt {attempts}].',
2921
+ 'yellow')
2922
+ print(message)
2923
+ callback = planner.optimize(
2924
+ key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
2925
+
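The retry loop above exists because the first optimize call pays the one-off JAX tracing and compilation cost, which on a tight wall-clock budget can leave time for at most one optimizer iteration; subsequent calls reuse the compiled functions. The effect is easy to reproduce with plain JAX:

    import time
    import jax
    import jax.numpy as jnp

    @jax.jit
    def step(x):
        # stand-in for one planner update: moderately heavy jitted math
        return jnp.tanh(x @ x.T).sum()

    x = jnp.ones((500, 500))

    start = time.time()
    step(x).block_until_ready()          # first call: trace + compile + run
    print(f'first call : {time.time() - start:.3f} s')

    start = time.time()
    step(x).block_until_ready()          # second call: cached executable only
    print(f'second call: {time.time() - start:.3f} s')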
2655
2926
  self.callback = callback
2656
2927
  params = callback['best_params']
2657
2928
  self.key, subkey = random.split(self.key)