pyRDDLGym-jax 2.3__py3-none-any.whl → 2.5__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -3,7 +3,7 @@
3
3
  #
4
4
  # Author: Michael Gimelfarb
5
5
  #
6
- # RELEVANT SOURCES:
6
+ # REFERENCES:
7
7
  #
8
8
  # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
9
9
  # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
@@ -18,26 +18,33 @@
18
18
  # reactive policies for planning in stochastic nonlinear domains." In Proceedings of the
19
19
  # AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 7530-7537. 2019.
20
20
  #
21
- # [4] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
21
+ # [4] Cui, Hao, Thomas Keller, and Roni Khardon. "Stochastic planning with lifted symbolic
22
+ # trajectory optimization." In Proceedings of the International Conference on Automated
23
+ # Planning and Scheduling, vol. 29, pp. 119-127. 2019.
24
+ #
25
+ # [5] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
22
26
  # nonlinear domains." Advances in Neural Information Processing Systems 30 (2017).
23
27
  #
24
- # [5] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
28
+ # [6] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
25
29
  # policy gradients: Super symmetric pgpe." Artificial Neural Networks: Methods and
26
30
  # Applications in Bio-/Neuroinformatics. Springer International Publishing, 2015.
27
31
  #
28
32
  # ***********************************************************************
29
33
 
30
34
 
35
+ from abc import ABCMeta, abstractmethod
31
36
  from ast import literal_eval
32
37
  from collections import deque
33
38
  import configparser
34
39
  from enum import Enum
35
40
  from functools import partial
36
41
  import os
42
+ import pickle
37
43
  import sys
38
44
  import time
39
45
  import traceback
40
- from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, Union
46
+ from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, \
47
+ Union
41
48
 
42
49
  import haiku as hk
43
50
  import jax
@@ -51,6 +58,7 @@ from tqdm import tqdm, TqdmWarning
51
58
  import warnings
52
59
  warnings.filterwarnings("ignore", category=TqdmWarning)
53
60
 
61
+ from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
54
62
  from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
55
63
  from pyRDDLGym.core.debug.logger import Logger
56
64
  from pyRDDLGym.core.debug.exception import (
@@ -157,25 +165,20 @@ def _load_config(config, args):
157
165
  initializer = _getattr_any(
158
166
  packages=[initializers, hk.initializers], item=plan_initializer)
159
167
  if initializer is None:
160
- raise_warning(f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
161
- del plan_kwargs['initializer']
168
+ raise ValueError(f'Invalid initializer <{plan_initializer}>.')
162
169
  else:
163
170
  init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
164
171
  try:
165
172
  plan_kwargs['initializer'] = initializer(**init_kwargs)
166
173
  except Exception as _:
167
- raise_warning(
168
- f'Ignoring invalid initializer_kwargs <{init_kwargs}>.', 'red')
169
- plan_kwargs['initializer'] = initializer
174
+ raise ValueError(f'Invalid initializer kwargs <{init_kwargs}>.')
170
175
 
171
176
  # policy activation
172
177
  plan_activation = plan_kwargs.get('activation', None)
173
178
  if plan_activation is not None:
174
- activation = _getattr_any(
175
- packages=[jax.nn, jax.numpy], item=plan_activation)
179
+ activation = _getattr_any(packages=[jax.nn, jax.numpy], item=plan_activation)
176
180
  if activation is None:
177
- raise_warning(f'Ignoring invalid activation <{plan_activation}>.', 'red')
178
- del plan_kwargs['activation']
181
+ raise ValueError(f'Invalid activation <{plan_activation}>.')
179
182
  else:
180
183
  plan_kwargs['activation'] = activation
181
184
 
@@ -188,8 +191,7 @@ def _load_config(config, args):
188
191
  if planner_optimizer is not None:
189
192
  optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
190
193
  if optimizer is None:
191
- raise_warning(f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
192
- del planner_args['optimizer']
194
+ raise ValueError(f'Invalid optimizer <{planner_optimizer}>.')
193
195
  else:
194
196
  planner_args['optimizer'] = optimizer
195
197
 
@@ -200,8 +202,7 @@ def _load_config(config, args):
200
202
  if 'optimizer' in pgpe_kwargs:
201
203
  pgpe_optimizer = _getattr_any(packages=[optax], item=pgpe_kwargs['optimizer'])
202
204
  if pgpe_optimizer is None:
203
- raise_warning(f'Ignoring invalid optimizer <{pgpe_optimizer}>.', 'red')
204
- del pgpe_kwargs['optimizer']
205
+ raise ValueError(f'Invalid optimizer <{pgpe_optimizer}>.')
205
206
  else:
206
207
  pgpe_kwargs['optimizer'] = pgpe_optimizer
207
208
  planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
@@ -229,13 +230,19 @@ def _load_config(config, args):
229
230
 
230
231
 
231
232
  def load_config(path: str) -> Tuple[Kwargs, ...]:
232
- '''Loads a config file at the specified file path.'''
233
+ '''Loads a config file at the specified file path.
234
+
235
+ :param path: the path of the config file to load and parse
236
+ '''
233
237
  config, args = _parse_config_file(path)
234
238
  return _load_config(config, args)
235
239
 
236
240
 
237
241
  def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
238
- '''Loads config file contents specified explicitly as a string value.'''
242
+ '''Loads config file contents specified explicitly as a string value.
243
+
244
+ :param value: the string in json format containing the config contents to parse
245
+ '''
239
246
  config, args = _parse_config_string(value)
240
247
  return _load_config(config, args)
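For context, a minimal sketch of how these loaders are typically wired into a planner. The import paths, the three-element layout of the returned tuple, and the environment construction follow the project README and are assumptions, not something this diff establishes:

```python
# hedged usage sketch; module paths, domain/instance names and the config path are assumptions
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, load_config

env = pyRDDLGym.make('Cartpole_Continuous_gym', '0')      # hypothetical registered names
planner_args, _, train_args = load_config('slp.cfg')      # hypothetical config path
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
result = planner.optimize(**train_args)   # assumed: returns the final callback/summary dict
```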
241
248
 
@@ -258,10 +265,10 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
258
265
  def __init__(self, *args,
259
266
  logic: Logic=FuzzyLogic(),
260
267
  cpfs_without_grad: Optional[Set[str]]=None,
268
+ print_warnings: bool=True,
261
269
  **kwargs) -> None:
262
270
  '''Creates a new RDDL to Jax compiler, where operations that are not
263
- differentiable are converted to approximate forms that have defined
264
- gradients.
271
+ differentiable are converted to approximate forms that have defined gradients.
265
272
 
266
273
  :param *args: arguments to pass to base compiler
267
274
  :param logic: Fuzzy logic object that specifies how exact operations
@@ -269,6 +276,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
269
276
  to customize these operations
270
277
  :param cpfs_without_grad: which CPFs do not have gradients (use straight
271
278
  through gradient trick)
279
+ :param print_warnings: whether to print warnings
272
280
  :param *kwargs: keyword arguments to pass to base compiler
273
281
  '''
274
282
  super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
@@ -278,6 +286,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
278
286
  if cpfs_without_grad is None:
279
287
  cpfs_without_grad = set()
280
288
  self.cpfs_without_grad = cpfs_without_grad
289
+ self.print_warnings = print_warnings
281
290
 
282
291
  # actions and CPFs must be continuous
283
292
  pvars_cast = set()
@@ -285,9 +294,11 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
285
294
  self.init_values[var] = np.asarray(values, dtype=self.REAL)
286
295
  if not np.issubdtype(np.result_type(values), np.floating):
287
296
  pvars_cast.add(var)
288
- if pvars_cast:
289
- raise_warning(f'JAX gradient compiler requires that initial values '
290
- f'of p-variables {pvars_cast} be cast to float.')
297
+ if self.print_warnings and pvars_cast:
298
+ message = termcolor.colored(
299
+ f'[INFO] JAX gradient compiler will cast p-vars {pvars_cast} to float.',
300
+ 'green')
301
+ print(message)
291
302
 
292
303
  # overwrite basic operations with fuzzy ones
293
304
  self.OPS = logic.get_operator_dicts()
@@ -300,6 +311,8 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
300
311
  return _jax_wrapped_stop_grad
301
312
 
302
313
  def _compile_cpfs(self, init_params):
314
+
315
+ # cpfs will all be cast to float
303
316
  cpfs_cast = set()
304
317
  jax_cpfs = {}
305
318
  for (_, cpfs) in self.levels.items():
@@ -311,12 +324,16 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
311
324
  if cpf in self.cpfs_without_grad:
312
325
  jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
313
326
 
314
- if cpfs_cast:
315
- raise_warning(f'JAX gradient compiler requires that outputs of CPFs '
316
- f'{cpfs_cast} be cast to float.')
317
- if self.cpfs_without_grad:
318
- raise_warning(f'User requested that gradients not flow '
319
- f'through CPFs {self.cpfs_without_grad}.')
327
+ if self.print_warnings and cpfs_cast:
328
+ message = termcolor.colored(
329
+ f'[INFO] JAX gradient compiler will cast CPFs {cpfs_cast} to float.',
330
+ 'green')
331
+ print(message)
332
+ if self.print_warnings and self.cpfs_without_grad:
333
+ message = termcolor.colored(
334
+ f'[INFO] Gradients will not flow through CPFs {self.cpfs_without_grad}.',
335
+ 'green')
336
+ print(message)
320
337
 
321
338
  return jax_cpfs
322
339
 
@@ -335,7 +352,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
335
352
  # ***********************************************************************
336
353
 
337
354
 
338
- class JaxPlan:
355
+ class JaxPlan(metaclass=ABCMeta):
339
356
  '''Base class for all JAX policy representations.'''
340
357
 
341
358
  def __init__(self) -> None:
@@ -345,16 +362,18 @@ class JaxPlan:
345
362
  self._projection = None
346
363
  self.bounds = None
347
364
 
348
- def summarize_hyperparameters(self) -> None:
349
- print(self.__str__())
350
-
365
+ def summarize_hyperparameters(self) -> str:
366
+ return self.__str__()
367
+
368
+ @abstractmethod
351
369
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
352
370
  _bounds: Bounds,
353
371
  horizon: int) -> None:
354
- raise NotImplementedError
372
+ pass
355
373
 
374
+ @abstractmethod
356
375
  def guess_next_epoch(self, params: Pytree) -> Pytree:
357
- raise NotImplementedError
376
+ pass
358
377
 
359
378
  @property
360
379
  def initializer(self):
@@ -397,10 +416,11 @@ class JaxPlan:
397
416
  continue
398
417
 
399
418
  # check invalid type
400
- if prange not in compiled.JAX_TYPES:
419
+ if prange not in compiled.JAX_TYPES and prange not in compiled.rddl.enum_types:
420
+ keys = list(compiled.JAX_TYPES.keys()) + list(compiled.rddl.enum_types)
401
421
  raise RDDLTypeError(
402
422
  f'Invalid range <{prange}> of action-fluent <{name}>, '
403
- f'must be one of {set(compiled.JAX_TYPES.keys())}.')
423
+ f'must be one of {keys}.')
404
424
 
405
425
  # clip boolean to (0, 1), otherwise use the RDDL action bounds
406
426
  # or the user defined action bounds if provided
@@ -408,7 +428,12 @@ class JaxPlan:
408
428
  if prange == 'bool':
409
429
  lower, upper = None, None
410
430
  else:
411
- lower, upper = compiled.constraints.bounds[name]
431
+ if prange in compiled.rddl.enum_types:
432
+ lower = np.zeros(shape=shapes[name][1:])
433
+ upper = len(compiled.rddl.type_to_objects[prange]) - 1
434
+ upper = np.ones(shape=shapes[name][1:]) * upper
435
+ else:
436
+ lower, upper = compiled.constraints.bounds[name]
412
437
  lower, upper = user_bounds.get(name, (lower, upper))
413
438
  lower = np.asarray(lower, dtype=compiled.REAL)
414
439
  upper = np.asarray(upper, dtype=compiled.REAL)
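The new branch above derives box bounds for enum-valued action-fluents from the number of objects of the enum type; at evaluation time (further below) such actions are clipped, rounded and cast back to integer object indices. A small numpy illustration of that mapping, with hypothetical objects and shapes:

```python
import numpy as np

# hypothetical enum type with three objects and an action-fluent of shape (4,)
objects = ['@low', '@medium', '@high']
shape = (4,)

lower = np.zeros(shape)                        # enum indices start at 0
upper = np.ones(shape) * (len(objects) - 1)    # ... and end at num_objects - 1

# the relaxed (real-valued) plan output is clipped to the bounds, rounded,
# and cast to an integer index into the enum's object list
relaxed = np.array([0.2, 1.7, 2.9, -0.4])
index = np.round(np.clip(relaxed, lower, upper)).astype(np.int64)
print([objects[i] for i in index])             # ['@low', '@high', '@high', '@low']
```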
@@ -421,7 +446,11 @@ class JaxPlan:
421
446
  ~lower_finite & upper_finite,
422
447
  ~lower_finite & ~upper_finite]
423
448
  bounds[name] = (lower, upper)
424
- raise_warning(f'Bounds of action-fluent <{name}> set to {bounds[name]}.')
449
+ if compiled.print_warnings:
450
+ message = termcolor.colored(
451
+ f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
452
+ 'green')
453
+ print(message)
425
454
  return shapes, bounds, bounds_safe, cond_lists
426
455
 
427
456
  def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -501,11 +530,12 @@ class JaxStraightLinePlan(JaxPlan):
501
530
  # action concurrency check
502
531
  bool_action_count, allowed_actions = self._count_bool_actions(rddl)
503
532
  use_constraint_satisfaction = allowed_actions < bool_action_count
504
- if use_constraint_satisfaction:
505
- raise_warning(f'Using projected gradient trick to satisfy '
506
- f'max_nondef_actions: total boolean actions '
507
- f'{bool_action_count} > max_nondef_actions '
508
- f'{allowed_actions}.')
533
+ if compiled.print_warnings and use_constraint_satisfaction:
534
+ message = termcolor.colored(
535
+ f'[INFO] SLP will use projected gradient to satisfy '
536
+ f'max_nondef_actions since total boolean actions '
537
+ f'{bool_action_count} > max_nondef_actions {allowed_actions}.', 'green')
538
+ print(message)
509
539
 
510
540
  noop = {var: (values[0] if isinstance(values, list) else values)
511
541
  for (var, values) in rddl.action_fluents.items()}
@@ -586,7 +616,7 @@ class JaxStraightLinePlan(JaxPlan):
586
616
  start = 0
587
617
  for (name, size) in action_sizes.items():
588
618
  action = output[..., start:start + size]
589
- action = jnp.reshape(action, newshape=shapes[name][1:])
619
+ action = jnp.reshape(action, shapes[name][1:])
590
620
  if noop[name]:
591
621
  action = 1.0 - action
592
622
  actions[name] = action
@@ -623,7 +653,7 @@ class JaxStraightLinePlan(JaxPlan):
623
653
  else:
624
654
  action = _jax_non_bool_param_to_action(var, action, hyperparams)
625
655
  action = jnp.clip(action, *bounds[var])
626
- if ranges[var] == 'int':
656
+ if ranges[var] == 'int' or ranges[var] in rddl.enum_types:
627
657
  action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
628
658
  actions[var] = action
629
659
  return actions
@@ -642,7 +672,7 @@ class JaxStraightLinePlan(JaxPlan):
642
672
  # only allow one action non-noop for now
643
673
  if 1 < allowed_actions < bool_action_count:
644
674
  raise RDDLNotImplementedError(
645
- f'Straight-line plans with wrap_softmax currently '
675
+ f'SLPs with wrap_softmax currently '
646
676
  f'do not support max-nondef-actions {allowed_actions} > 1.')
647
677
 
648
678
  # potentially apply projection but to non-bool actions only
@@ -764,7 +794,8 @@ class JaxStraightLinePlan(JaxPlan):
764
794
  for (var, action) in actions.items():
765
795
  if ranges[var] == 'bool':
766
796
  action = jnp.clip(action, min_action, max_action)
767
- new_params[var] = _jax_bool_action_to_param(var, action, hyperparams)
797
+ param = _jax_bool_action_to_param(var, action, hyperparams)
798
+ new_params[var] = param
768
799
  else:
769
800
  new_params[var] = action
770
801
  return new_params, converged
@@ -818,7 +849,7 @@ class JaxStraightLinePlan(JaxPlan):
818
849
 
819
850
  def guess_next_epoch(self, params: Pytree) -> Pytree:
820
851
  next_fn = JaxStraightLinePlan._guess_next_epoch
821
- return jax.tree_map(next_fn, params)
852
+ return jax.tree_util.tree_map(next_fn, params)
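The change above (repeated throughout this release) replaces jax.tree_map with jax.tree_util.tree_map, tracking the deprecation of the top-level alias in recent JAX versions; the two calls are otherwise equivalent:

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
# same behaviour as the removed jax.tree_map(...) call; the fully qualified
# name avoids the deprecated (and later removed) top-level alias
shifted = jax.tree_util.tree_map(lambda x: x + 1.0, params)
```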
822
853
 
823
854
 
824
855
  class JaxDeepReactivePolicy(JaxPlan):
@@ -890,8 +921,7 @@ class JaxDeepReactivePolicy(JaxPlan):
890
921
  bool_action_count, allowed_actions = self._count_bool_actions(rddl)
891
922
  if 1 < allowed_actions < bool_action_count:
892
923
  raise RDDLNotImplementedError(
893
- f'Deep reactive policies currently do not support '
894
- f'max-nondef-actions {allowed_actions} > 1.')
924
+ f'DRPs currently do not support max-nondef-actions {allowed_actions} > 1.')
895
925
  use_constraint_satisfaction = allowed_actions < bool_action_count
896
926
 
897
927
  noop = {var: (values[0] if isinstance(values, list) else values)
@@ -927,15 +957,19 @@ class JaxDeepReactivePolicy(JaxPlan):
927
957
  if ranges[var] != 'bool':
928
958
  value_size = np.size(values)
929
959
  if normalize_per_layer and value_size == 1:
930
- raise_warning(
931
- f'Cannot apply layer norm to state-fluent <{var}> '
932
- f'of size 1: setting normalize_per_layer = False.', 'red')
960
+ if compiled.print_warnings:
961
+ message = termcolor.colored(
962
+ f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
963
+ f'of size 1: setting normalize_per_layer = False.', 'yellow')
964
+ print(message)
933
965
  normalize_per_layer = False
934
966
  non_bool_dims += value_size
935
967
  if not normalize_per_layer and non_bool_dims == 1:
936
- raise_warning(
937
- 'Cannot apply layer norm to state-fluents of total size 1: '
938
- 'setting normalize = False.', 'red')
968
+ if compiled.print_warnings:
969
+ message = termcolor.colored(
970
+ '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
971
+ 'setting normalize = False.', 'yellow')
972
+ print(message)
939
973
  normalize = False
940
974
 
941
975
  # convert subs dictionary into a state vector to feed to the MLP
@@ -1033,7 +1067,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1033
1067
  for (name, size) in layer_sizes.items():
1034
1068
  if ranges[name] == 'bool':
1035
1069
  action = output[..., start:start + size]
1036
- action = jnp.reshape(action, newshape=shapes[name])
1070
+ action = jnp.reshape(action, shapes[name])
1037
1071
  if noop[name]:
1038
1072
  action = 1.0 - action
1039
1073
  actions[name] = action
@@ -1061,7 +1095,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1061
1095
  prange = ranges[var]
1062
1096
  if prange == 'bool':
1063
1097
  new_action = action > 0.5
1064
- elif prange == 'int':
1098
+ elif prange == 'int' or prange in rddl.enum_types:
1065
1099
  action = jnp.clip(action, *bounds[var])
1066
1100
  new_action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
1067
1101
  else:
@@ -1112,19 +1146,18 @@ class JaxDeepReactivePolicy(JaxPlan):
1112
1146
 
1113
1147
 
1114
1148
  class RollingMean:
1115
- '''Maintains an estimate of the rolling mean of a stream of real-valued
1116
- observations.'''
1149
+ '''Maintains the rolling mean of a stream of real-valued observations.'''
1117
1150
 
1118
1151
  def __init__(self, window_size: int) -> None:
1119
1152
  self._window_size = window_size
1120
1153
  self._memory = deque(maxlen=window_size)
1121
1154
  self._total = 0
1122
1155
 
1123
- def update(self, x: float) -> float:
1156
+ def update(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
1124
1157
  memory = self._memory
1125
- self._total += x
1158
+ self._total = self._total + x
1126
1159
  if len(memory) == self._window_size:
1127
- self._total -= memory.popleft()
1160
+ self._total = self._total - memory.popleft()
1128
1161
  memory.append(x)
1129
1162
  return self._total / len(memory)
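The rewritten update above avoids in-place accumulation and widens the type hint to ndarrays, so one RollingMean instance can track a vector of per-run losses under parallel updates. A standalone mirror of the behaviour:

```python
import numpy as np
from collections import deque

# standalone mirror of the RollingMean logic above, shown with a window of 3
class Rolling:
    def __init__(self, window_size):
        self._window_size = window_size
        self._memory = deque(maxlen=window_size)
        self._total = 0

    def update(self, x):
        self._total = self._total + x
        if len(self._memory) == self._window_size:
            self._total = self._total - self._memory.popleft()
        self._memory.append(x)
        return self._total / len(self._memory)

scalar = Rolling(3)
print([scalar.update(x) for x in [1.0, 2.0, 3.0, 4.0]])   # [1.0, 1.5, 2.0, 3.0]

# with parallel updates the planner feeds a vector of per-run losses;
# the same code then returns the elementwise rolling mean for each run
parallel = Rolling(3)
print(parallel.update(np.array([1.0, 10.0])))              # [ 1. 10.]
```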
1130
1163
 
@@ -1147,14 +1180,16 @@ class JaxPlannerStatus(Enum):
1147
1180
  return self.value == 1 or self.value >= 4
1148
1181
 
1149
1182
 
1150
- class JaxPlannerStoppingRule:
1183
+ class JaxPlannerStoppingRule(metaclass=ABCMeta):
1151
1184
  '''The base class of all planner stopping rules.'''
1152
1185
 
1186
+ @abstractmethod
1153
1187
  def reset(self) -> None:
1154
- raise NotImplementedError
1155
-
1188
+ pass
1189
+
1190
+ @abstractmethod
1156
1191
  def monitor(self, callback: Dict[str, Any]) -> bool:
1157
- raise NotImplementedError
1192
+ pass
1158
1193
 
1159
1194
 
1160
1195
  class NoImprovementStoppingRule(JaxPlannerStoppingRule):
@@ -1168,8 +1203,7 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
1168
1203
  self.iters_since_last_update = 0
1169
1204
 
1170
1205
  def monitor(self, callback: Dict[str, Any]) -> bool:
1171
- if self.callback is None \
1172
- or callback['best_return'] > self.callback['best_return']:
1206
+ if self.callback is None or callback['best_return'] > self.callback['best_return']:
1173
1207
  self.callback = callback
1174
1208
  self.iters_since_last_update = 0
1175
1209
  else:
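With JaxPlannerStoppingRule now an abstract base class, custom rules must implement both reset and monitor. A hedged sketch of a user-defined rule against that interface; only the 'best_return' callback key, already relied on by NoImprovementStoppingRule above, is assumed:

```python
from typing import Any, Dict

# assumes JaxPlannerStoppingRule from this module is in scope
class TargetReturnStoppingRule(JaxPlannerStoppingRule):
    '''Stops optimization once the best return reaches a user-given target.'''

    def __init__(self, target: float) -> None:
        self.target = target

    def reset(self) -> None:
        pass

    def monitor(self, callback: Dict[str, Any]) -> bool:
        return callback['best_return'] >= self.target
```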
@@ -1188,7 +1222,7 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
1188
1222
  # ***********************************************************************
1189
1223
 
1190
1224
 
1191
- class PGPE:
1225
+ class PGPE(metaclass=ABCMeta):
1192
1226
  """Base class for all PGPE strategies."""
1193
1227
 
1194
1228
  def __init__(self) -> None:
@@ -1203,8 +1237,11 @@ class PGPE:
1203
1237
  def update(self):
1204
1238
  return self._update
1205
1239
 
1206
- def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1207
- raise NotImplementedError
1240
+ @abstractmethod
1241
+ def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
1242
+ print_warnings: bool,
1243
+ parallel_updates: Optional[int]=None) -> None:
1244
+ pass
1208
1245
 
1209
1246
 
1210
1247
  class GaussianPGPE(PGPE):
@@ -1268,10 +1305,11 @@ class GaussianPGPE(PGPE):
1268
1305
  mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
1269
1306
  sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
1270
1307
  except Exception as _:
1271
- raise_warning(
1272
- f'Failed to inject hyperparameters into optax optimizer for PGPE, '
1273
- 'rolling back to safer method: please note that kl-divergence '
1274
- 'constraints will be disabled.', 'red')
1308
+ message = termcolor.colored(
1309
+ '[FAIL] Failed to inject hyperparameters into PGPE optimizer, '
1310
+ 'rolling back to safer method: '
1311
+ 'kl-divergence constraint will be disabled.', 'red')
1312
+ print(message)
1275
1313
  mu_optimizer = optimizer(**optimizer_kwargs_mu)
1276
1314
  sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
1277
1315
  max_kl_update = None
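The fallback above matters because optax.inject_hyperparams is what exposes the learning rate inside the optimizer state; the KL-constrained update further below rescales the step through state.hyperparams['learning_rate']. A minimal sketch of that mechanism:

```python
import jax.numpy as jnp
import optax

# wrapping the optimizer factory exposes its hyper-parameters in the state
opt = optax.inject_hyperparams(optax.rmsprop)(learning_rate=0.1)
state = opt.init({'w': jnp.zeros(3)})
print(state.hyperparams['learning_rate'])     # 0.1
# the KL constraint below shrinks the step by mutating this entry in place
state.hyperparams['learning_rate'] = 0.05
```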
@@ -1297,15 +1335,17 @@ class GaussianPGPE(PGPE):
1297
1335
  f' max_kl_update ={self.max_kl}\n'
1298
1336
  )
1299
1337
 
1300
- def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1338
+ def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
1339
+ print_warnings: bool,
1340
+ parallel_updates: Optional[int]=None) -> None:
1301
1341
  sigma0 = self.init_sigma
1302
- sigma_range = self.sigma_range
1342
+ sigma_lo, sigma_hi = self.sigma_range
1303
1343
  scale_reward = self.scale_reward
1304
1344
  min_reward_scale = self.min_reward_scale
1305
1345
  super_symmetric = self.super_symmetric
1306
1346
  super_symmetric_accurate = self.super_symmetric_accurate
1307
1347
  batch_size = self.batch_size
1308
- optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
1348
+ mu_optimizer, sigma_optimizer = self.optimizers
1309
1349
  max_kl = self.max_kl
1310
1350
 
1311
1351
  # entropy regularization penalty is decayed exponentially by elapsed budget
@@ -1322,13 +1362,22 @@ class GaussianPGPE(PGPE):
1322
1362
 
1323
1363
  def _jax_wrapped_pgpe_init(key, policy_params):
1324
1364
  mu = policy_params
1325
- sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
1365
+ sigma = jax.tree_util.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
1326
1366
  pgpe_params = (mu, sigma)
1327
- pgpe_opt_state = tuple(opt.init(param)
1328
- for (opt, param) in zip(optimizers, pgpe_params))
1329
- return pgpe_params, pgpe_opt_state
1367
+ pgpe_opt_state = (mu_optimizer.init(mu), sigma_optimizer.init(sigma))
1368
+ r_max = -jnp.inf
1369
+ return pgpe_params, pgpe_opt_state, r_max
1330
1370
 
1331
- self._initializer = jax.jit(_jax_wrapped_pgpe_init)
1371
+ if parallel_updates is None:
1372
+ self._initializer = jax.jit(_jax_wrapped_pgpe_init)
1373
+ else:
1374
+
1375
+ # for parallel policy update
1376
+ def _jax_wrapped_pgpe_inits(key, policy_params):
1377
+ keys = jnp.asarray(random.split(key, num=parallel_updates))
1378
+ return jax.vmap(_jax_wrapped_pgpe_init, in_axes=0)(keys, policy_params)
1379
+
1380
+ self._initializer = jax.jit(_jax_wrapped_pgpe_inits)
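When parallel_updates is set, the single-run initializer is wrapped in jax.vmap over a batch of split PRNG keys so that every parallel run draws its own randomness. A self-contained sketch of the pattern (the initializer body here is hypothetical):

```python
import jax
import jax.numpy as jnp
from jax import random

def init_single(key, params):
    # hypothetical per-run initialization: perturb a shared starting point
    return params + 0.01 * random.normal(key, shape=jnp.shape(params))

def init_parallel(key, params, parallel_updates=4):
    keys = jnp.asarray(random.split(key, num=parallel_updates))
    return jax.vmap(init_single, in_axes=(0, None))(keys, params)

runs = init_parallel(random.PRNGKey(42), jnp.zeros(3))
print(runs.shape)    # (4, 3): one independent initialization per parallel run
```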
1332
1381
 
1333
1382
  # ***********************************************************************
1334
1383
  # PARAMETER SAMPLING FUNCTIONS
@@ -1338,6 +1387,8 @@ class GaussianPGPE(PGPE):
1338
1387
  def _jax_wrapped_mu_noise(key, sigma):
1339
1388
  return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
1340
1389
 
1390
+ # this samples a noise variable epsilon* from epsilon with the N(0, 1) density
1391
+ # according to super-symmetric sampling paper
1341
1392
  def _jax_wrapped_epsilon_star(sigma, epsilon):
1342
1393
  c1, c2, c3 = -0.06655, -0.9706, 0.124
1343
1394
  phi = 0.67449 * sigma
@@ -1354,17 +1405,19 @@ class GaussianPGPE(PGPE):
1354
1405
  epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
1355
1406
  return epsilon_star
1356
1407
 
1408
+ # implements baseline-free super-symmetric sampling to generate 4 trajectories
1357
1409
  def _jax_wrapped_sample_params(key, mu, sigma):
1358
1410
  treedef = jax.tree_util.tree_structure(sigma)
1359
1411
  keys = random.split(key, num=treedef.num_leaves)
1360
1412
  keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
1361
- epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
1362
- p1 = jax.tree_map(jnp.add, mu, epsilon)
1363
- p2 = jax.tree_map(jnp.subtract, mu, epsilon)
1413
+ epsilon = jax.tree_util.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
1414
+ p1 = jax.tree_util.tree_map(jnp.add, mu, epsilon)
1415
+ p2 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon)
1364
1416
  if super_symmetric:
1365
- epsilon_star = jax.tree_map(_jax_wrapped_epsilon_star, sigma, epsilon)
1366
- p3 = jax.tree_map(jnp.add, mu, epsilon_star)
1367
- p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
1417
+ epsilon_star = jax.tree_util.tree_map(
1418
+ _jax_wrapped_epsilon_star, sigma, epsilon)
1419
+ p3 = jax.tree_util.tree_map(jnp.add, mu, epsilon_star)
1420
+ p4 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon_star)
1368
1421
  else:
1369
1422
  epsilon_star, p3, p4 = epsilon, p1, p2
1370
1423
  return p1, p2, p3, p4, epsilon, epsilon_star
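A simplified, array-valued sketch of the sampling scheme above (reference [6]): one Gaussian draw yields a mirrored pair of candidates, and the resulting mean-gradient estimate needs no baseline. The exact scaling and the closed-form reflection used for epsilon* are in the code above and are not reproduced here:

```python
import jax.numpy as jnp
from jax import random

mu = jnp.array([0.0, 1.0])
sigma = jnp.array([0.5, 0.5])

# one noise draw gives the mirrored candidate pair around the mean
epsilon = sigma * random.normal(random.PRNGKey(0), shape=sigma.shape)
p1, p2 = mu + epsilon, mu - epsilon

loss = lambda p: jnp.sum((p - 0.7) ** 2)       # hypothetical objective
r1, r2 = -loss(p1), -loss(p2)                  # returns of the mirrored pair
mu_ascent = 0.5 * (r1 - r2) * epsilon          # baseline-free direction, up to scaling

# the super-symmetric variant additionally reflects epsilon into epsilon*
# (the closed form in _jax_wrapped_epsilon_star) and evaluates p3, p4 as well
```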
@@ -1374,6 +1427,7 @@ class GaussianPGPE(PGPE):
1374
1427
  #
1375
1428
  # ***********************************************************************
1376
1429
 
1430
+ # gradient with respect to mean
1377
1431
  def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
1378
1432
  if super_symmetric:
1379
1433
  if scale_reward:
@@ -1393,6 +1447,7 @@ class GaussianPGPE(PGPE):
1393
1447
  grad = -r_mu * epsilon
1394
1448
  return grad
1395
1449
 
1450
+ # gradient with respect to std. deviation
1396
1451
  def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
1397
1452
  if super_symmetric:
1398
1453
  mask = r1 + r2 >= r3 + r4
@@ -1413,6 +1468,7 @@ class GaussianPGPE(PGPE):
1413
1468
  grad = -(r_sigma * s + ent / sigma)
1414
1469
  return grad
1415
1470
 
1471
+ # calculate the policy gradients
1416
1472
  def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
1417
1473
  policy_hyperparams, subs, model_params):
1418
1474
  key, subkey = random.split(key)
@@ -1429,11 +1485,11 @@ class GaussianPGPE(PGPE):
1429
1485
  r_max = jnp.maximum(r_max, r4)
1430
1486
  else:
1431
1487
  r3, r4 = r1, r2
1432
- grad_mu = jax.tree_map(
1488
+ grad_mu = jax.tree_util.tree_map(
1433
1489
  partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
1434
1490
  epsilon, epsilon_star
1435
1491
  )
1436
- grad_sigma = jax.tree_map(
1492
+ grad_sigma = jax.tree_util.tree_map(
1437
1493
  partial(_jax_wrapped_sigma_grad,
1438
1494
  r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
1439
1495
  epsilon, epsilon_star, sigma
@@ -1452,7 +1508,7 @@ class GaussianPGPE(PGPE):
1452
1508
  _jax_wrapped_pgpe_grad,
1453
1509
  in_axes=(0, None, None, None, None, None, None, None)
1454
1510
  )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
1455
- mu_grad, sigma_grad = jax.tree_map(
1511
+ mu_grad, sigma_grad = jax.tree_util.tree_map(
1456
1512
  partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
1457
1513
  new_r_max = jnp.max(r_maxs)
1458
1514
  return mu_grad, sigma_grad, new_r_max
@@ -1462,11 +1518,24 @@ class GaussianPGPE(PGPE):
1462
1518
  #
1463
1519
  # ***********************************************************************
1464
1520
 
1521
+ # estimate KL divergence between two updates
1465
1522
  def _jax_wrapped_pgpe_kl_term(mu, sigma, old_mu, old_sigma):
1466
1523
  return 0.5 * jnp.sum(2 * jnp.log(sigma / old_sigma) +
1467
1524
  jnp.square(old_sigma / sigma) +
1468
1525
  jnp.square((mu - old_mu) / sigma) - 1)
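The closed form above is the KL divergence between the previous and the updated diagonal Gaussian search distributions, KL(N(old_mu, old_sigma^2) || N(mu, sigma^2)), summed over parameters. A quick numpy Monte Carlo check of the per-dimension term:

```python
import numpy as np

rng = np.random.default_rng(0)
old_mu, old_sigma = 0.3, 0.5     # hypothetical previous parameters
mu, sigma = 0.1, 0.8             # hypothetical updated parameters

# per-dimension closed form used in _jax_wrapped_pgpe_kl_term
closed = 0.5 * (2 * np.log(sigma / old_sigma)
                + (old_sigma / sigma) ** 2
                + ((mu - old_mu) / sigma) ** 2 - 1)

# Monte Carlo estimate of KL(N(old_mu, old_sigma^2) || N(mu, sigma^2))
x = rng.normal(old_mu, old_sigma, size=1_000_000)
log_pdf = lambda x, m, s: -0.5 * ((x - m) / s) ** 2 - np.log(s) - 0.5 * np.log(2 * np.pi)
mc = np.mean(log_pdf(x, old_mu, old_sigma) - log_pdf(x, mu, sigma))

print(closed, mc)                # the two values agree to about three decimals
```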
1469
1526
 
1527
+ # update mean and std. deviation with a gradient step
1528
+ def _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
1529
+ mu_state, sigma_state):
1530
+ mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1531
+ sigma_updates, new_sigma_state = sigma_optimizer.update(
1532
+ sigma_grad, sigma_state, params=sigma)
1533
+ new_mu = optax.apply_updates(mu, mu_updates)
1534
+ new_sigma = optax.apply_updates(sigma, sigma_updates)
1535
+ new_sigma = jax.tree_util.tree_map(
1536
+ partial(jnp.clip, min=sigma_lo, max=sigma_hi), new_sigma)
1537
+ return new_mu, new_sigma, new_mu_state, new_sigma_state
1538
+
1470
1539
  def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
1471
1540
  policy_hyperparams, subs, model_params,
1472
1541
  pgpe_opt_state):
@@ -1476,29 +1545,23 @@ class GaussianPGPE(PGPE):
1476
1545
  ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
1477
1546
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
1478
1547
  key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
1479
- mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1480
- sigma_updates, new_sigma_state = sigma_optimizer.update(
1481
- sigma_grad, sigma_state, params=sigma)
1482
- new_mu = optax.apply_updates(mu, mu_updates)
1483
- new_sigma = optax.apply_updates(sigma, sigma_updates)
1484
- new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
1548
+ new_mu, new_sigma, new_mu_state, new_sigma_state = \
1549
+ _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
1550
+ mu_state, sigma_state)
1485
1551
 
1486
1552
  # respect KL divergence constraint with old parameters
1487
1553
  if max_kl is not None:
1488
1554
  old_mu_lr = new_mu_state.hyperparams['learning_rate']
1489
1555
  old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
1490
- kl_terms = jax.tree_map(
1556
+ kl_terms = jax.tree_util.tree_map(
1491
1557
  _jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
1492
1558
  total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
1493
1559
  kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
1494
1560
  mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
1495
1561
  sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
1496
- mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1497
- sigma_updates, new_sigma_state = sigma_optimizer.update(
1498
- sigma_grad, sigma_state, params=sigma)
1499
- new_mu = optax.apply_updates(mu, mu_updates)
1500
- new_sigma = optax.apply_updates(sigma, sigma_updates)
1501
- new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
1562
+ new_mu, new_sigma, new_mu_state, new_sigma_state = \
1563
+ _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
1564
+ mu_state, sigma_state)
1502
1565
  new_mu_state.hyperparams['learning_rate'] = old_mu_lr
1503
1566
  new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr
1504
1567
 
@@ -1509,7 +1572,21 @@ class GaussianPGPE(PGPE):
1509
1572
  policy_params = new_mu
1510
1573
  return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged
1511
1574
 
1512
- self._update = jax.jit(_jax_wrapped_pgpe_update)
1575
+ if parallel_updates is None:
1576
+ self._update = jax.jit(_jax_wrapped_pgpe_update)
1577
+ else:
1578
+
1579
+ # for parallel policy update
1580
+ def _jax_wrapped_pgpe_updates(key, pgpe_params, r_max, progress,
1581
+ policy_hyperparams, subs, model_params,
1582
+ pgpe_opt_state):
1583
+ keys = jnp.asarray(random.split(key, num=parallel_updates))
1584
+ return jax.vmap(
1585
+ _jax_wrapped_pgpe_update, in_axes=(0, 0, 0, None, None, None, 0, 0)
1586
+ )(keys, pgpe_params, r_max, progress, policy_hyperparams, subs,
1587
+ model_params, pgpe_opt_state)
1588
+
1589
+ self._update = jax.jit(_jax_wrapped_pgpe_updates)
1513
1590
 
1514
1591
 
1515
1592
  # ***********************************************************************
@@ -1565,6 +1642,7 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
1565
1642
  return jnp.sum(returns * weights)
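Besides the names registered in UTILITY_LOOKUP just below, the planner constructor further down also accepts the utility as a callable. A hedged sketch of a custom risk-averse utility; the assumption that the callable receives the batch of sampled returns mirrors the built-ins above:

```python
import jax.numpy as jnp

def worst_case_utility(returns: jnp.ndarray) -> float:
    # maximally risk-averse: score a plan by its worst sampled rollout return
    return jnp.min(returns)

# hypothetical usage (other constructor arguments omitted):
# planner = JaxBackpropPlanner(rddl=model, plan=JaxStraightLinePlan(),
#                              utility=worst_case_utility)
```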
1566
1643
 
1567
1644
 
1645
+ # set of all currently valid built-in utility functions
1568
1646
  UTILITY_LOOKUP = {
1569
1647
  'mean': jnp.mean,
1570
1648
  'mean_var': mean_variance_utility,
@@ -1609,7 +1687,9 @@ class JaxBackpropPlanner:
1609
1687
  cpfs_without_grad: Optional[Set[str]]=None,
1610
1688
  compile_non_fluent_exact: bool=True,
1611
1689
  logger: Optional[Logger]=None,
1612
- dashboard_viz: Optional[Any]=None) -> None:
1690
+ dashboard_viz: Optional[Any]=None,
1691
+ print_warnings: bool=True,
1692
+ parallel_updates: Optional[int]=None) -> None:
1613
1693
  '''Creates a new gradient-based algorithm for optimizing action sequences
1614
1694
  (plan) in the given RDDL. Some operations will be converted to their
1615
1695
  differentiable counterparts; the specific operations can be customized
@@ -1649,6 +1729,8 @@ class JaxBackpropPlanner:
1649
1729
  :param logger: to log information about compilation to file
1650
1730
  :param dashboard_viz: optional visualizer object from the environment
1651
1731
  to pass to the dashboard to visualize the policy
1732
+ :param print_warnings: whether to print warnings
1733
+ :param parallel_updates: how many optimizers to run independently in parallel
1652
1734
  '''
1653
1735
  self.rddl = rddl
1654
1736
  self.plan = plan
@@ -1656,6 +1738,7 @@ class JaxBackpropPlanner:
1656
1738
  if batch_size_test is None:
1657
1739
  batch_size_test = batch_size_train
1658
1740
  self.batch_size_test = batch_size_test
1741
+ self.parallel_updates = parallel_updates
1659
1742
  if rollout_horizon is None:
1660
1743
  rollout_horizon = rddl.horizon
1661
1744
  self.horizon = rollout_horizon
@@ -1672,15 +1755,17 @@ class JaxBackpropPlanner:
1672
1755
  self.noise_kwargs = noise_kwargs
1673
1756
  self.pgpe = pgpe
1674
1757
  self.use_pgpe = pgpe is not None
1758
+ self.print_warnings = print_warnings
1675
1759
 
1676
1760
  # set optimizer
1677
1761
  try:
1678
1762
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
1679
1763
  except Exception as _:
1680
- raise_warning(
1681
- f'Failed to inject hyperparameters into optax optimizer {optimizer}, '
1682
- 'rolling back to safer method: please note that modification of '
1683
- 'optimizer hyperparameters will not work.', 'red')
1764
+ message = termcolor.colored(
1765
+ '[FAIL] Failed to inject hyperparameters into JaxPlan optimizer, '
1766
+ 'rolling back to safer method: please note that runtime modification of '
1767
+ 'hyperparameters will be disabled.', 'red')
1768
+ print(message)
1684
1769
  optimizer = optimizer(**optimizer_kwargs)
1685
1770
 
1686
1771
  # apply optimizer chain of transformations
@@ -1700,7 +1785,7 @@ class JaxBackpropPlanner:
1700
1785
  utility_fn = UTILITY_LOOKUP.get(utility, None)
1701
1786
  if utility_fn is None:
1702
1787
  raise RDDLNotImplementedError(
1703
- f'Utility <{utility}> is not supported, '
1788
+ f'Utility function <{utility}> is not supported, '
1704
1789
  f'must be one of {list(UTILITY_LOOKUP.keys())}.')
1705
1790
  else:
1706
1791
  utility_fn = utility
@@ -1723,7 +1808,11 @@ class JaxBackpropPlanner:
1723
1808
  self._jax_compile_rddl()
1724
1809
  self._jax_compile_optimizer()
1725
1810
 
1726
- def summarize_system(self) -> str:
1811
+ @staticmethod
1812
+ def summarize_system() -> str:
1813
+ '''Returns a string containing information about the system, Python version
1814
+ and jax-related packages that are relevant to the current planner.
1815
+ '''
1727
1816
  try:
1728
1817
  jaxlib_version = jax._src.lib.version_str
1729
1818
  except Exception as _:
@@ -1742,7 +1831,7 @@ r"""
1742
1831
  \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
1743
1832
  """
1744
1833
 
1745
- return ('\n'
1834
+ return (f'\n'
1746
1835
  f'{LOGO}\n'
1747
1836
  f'Version {__version__}\n'
1748
1837
  f'Python {sys.version}\n'
@@ -1751,7 +1840,29 @@ r"""
1751
1840
  f'numpy {np.__version__}\n'
1752
1841
  f'devices: {devices_short}\n')
1753
1842
 
1754
- def __str__(self) -> str:
1843
+ def summarize_relaxations(self) -> str:
1844
+ '''Returns a summary table containing all non-differentiable operators
1845
+ and their relaxations.
1846
+ '''
1847
+ result = ''
1848
+ if self.compiled.model_params:
1849
+ result += ('Some RDDL operations are non-differentiable '
1850
+ 'and will be approximated as follows:' + '\n')
1851
+ exprs_by_rddl_op, values_by_rddl_op = {}, {}
1852
+ for info in self.compiled.model_parameter_info().values():
1853
+ rddl_op = info['rddl_op']
1854
+ exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1855
+ values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1856
+ for rddl_op in sorted(exprs_by_rddl_op.keys()):
1857
+ result += (f' {rddl_op}:\n'
1858
+ f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1859
+ f' init_values={values_by_rddl_op[rddl_op]}\n')
1860
+ return result
1861
+
1862
+ def summarize_hyperparameters(self) -> str:
1863
+ '''Returns a string summarizing the hyper-parameters of the current planner
1864
+ instance.
1865
+ '''
1755
1866
  result = (f'objective hyper-parameters:\n'
1756
1867
  f' utility_fn ={self.utility.__name__}\n'
1757
1868
  f' utility args ={self.utility_kwargs}\n'
@@ -1769,30 +1880,14 @@ r"""
1769
1880
  f' line_search_kwargs={self.line_search_kwargs}\n'
1770
1881
  f' noise_kwargs ={self.noise_kwargs}\n'
1771
1882
  f' batch_size_train ={self.batch_size_train}\n'
1772
- f' batch_size_test ={self.batch_size_test}\n')
1883
+ f' batch_size_test ={self.batch_size_test}\n'
1884
+ f' parallel_updates ={self.parallel_updates}\n')
1773
1885
  result += str(self.plan)
1774
1886
  if self.use_pgpe:
1775
1887
  result += str(self.pgpe)
1776
1888
  result += str(self.logic)
1777
-
1778
- # print model relaxation information
1779
- if self.compiled.model_params:
1780
- result += ('Some RDDL operations are non-differentiable '
1781
- 'and will be approximated as follows:' + '\n')
1782
- exprs_by_rddl_op, values_by_rddl_op = {}, {}
1783
- for info in self.compiled.model_parameter_info().values():
1784
- rddl_op = info['rddl_op']
1785
- exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
1786
- values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
1787
- for rddl_op in sorted(exprs_by_rddl_op.keys()):
1788
- result += (f' {rddl_op}:\n'
1789
- f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
1790
- f' init_values={values_by_rddl_op[rddl_op]}\n')
1791
1889
  return result
1792
1890
 
1793
- def summarize_hyperparameters(self) -> None:
1794
- print(self.__str__())
1795
-
1796
1891
  # ===========================================================================
1797
1892
  # COMPILATION SUBROUTINES
1798
1893
  # ===========================================================================
@@ -1807,7 +1902,8 @@ r"""
1807
1902
  logger=self.logger,
1808
1903
  use64bit=self.use64bit,
1809
1904
  cpfs_without_grad=self.cpfs_without_grad,
1810
- compile_non_fluent_exact=self.compile_non_fluent_exact
1905
+ compile_non_fluent_exact=self.compile_non_fluent_exact,
1906
+ print_warnings=self.print_warnings
1811
1907
  )
1812
1908
  self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
1813
1909
 
@@ -1844,23 +1940,33 @@ r"""
1844
1940
  self.test_rollouts = jax.jit(test_rollouts)
1845
1941
 
1846
1942
  # initialization
1847
- self.initialize = jax.jit(self._jax_init())
1943
+ self.initialize, self.init_optimizer = self._jax_init()
1848
1944
 
1849
1945
  # losses
1850
1946
  train_loss = self._jax_loss(train_rollouts, use_symlog=self.use_symlog_reward)
1851
- self.test_loss = jax.jit(self._jax_loss(test_rollouts, use_symlog=False))
1947
+ test_loss = self._jax_loss(test_rollouts, use_symlog=False)
1948
+ if self.parallel_updates is None:
1949
+ self.test_loss = jax.jit(test_loss)
1950
+ else:
1951
+ self.test_loss = jax.jit(jax.vmap(test_loss, in_axes=(None, 0, None, None, 0)))
1852
1952
 
1853
1953
  # optimization
1854
1954
  self.update = self._jax_update(train_loss)
1955
+ self.pytree_at = jax.jit(
1956
+ lambda tree, i: jax.tree_util.tree_map(lambda x: x[i], tree))
1855
1957
 
1856
1958
  # pgpe option
1857
1959
  if self.use_pgpe:
1858
- loss_fn = self._jax_loss(rollouts=test_rollouts)
1859
1960
  self.pgpe.compile(
1860
- loss_fn=loss_fn,
1961
+ loss_fn=test_loss,
1861
1962
  projection=self.plan.projection,
1862
- real_dtype=self.test_compiled.REAL
1963
+ real_dtype=self.test_compiled.REAL,
1964
+ print_warnings=self.print_warnings,
1965
+ parallel_updates=self.parallel_updates
1863
1966
  )
1967
+ self.merge_pgpe = self._jax_merge_pgpe_jaxplan()
1968
+ else:
1969
+ self.merge_pgpe = None
1864
1970
 
1865
1971
  def _jax_return(self, use_symlog):
1866
1972
  gamma = self.rddl.discount
@@ -1900,24 +2006,43 @@ r"""
1900
2006
  def _jax_init(self):
1901
2007
  init = self.plan.initializer
1902
2008
  optimizer = self.optimizer
2009
+ num_parallel = self.parallel_updates
1903
2010
 
1904
2011
  # initialize both the policy and its optimizer
1905
2012
  def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
1906
2013
  policy_params = init(key, policy_hyperparams, subs)
1907
2014
  opt_state = optimizer.init(policy_params)
1908
- return policy_params, opt_state, {}
2015
+ return policy_params, opt_state, {}
2016
+
2017
+ # initialize just the optimizer from the policy
2018
+ def _jax_wrapped_init_opt(policy_params):
2019
+ if num_parallel is None:
2020
+ opt_state = optimizer.init(policy_params)
2021
+ else:
2022
+ opt_state = jax.vmap(optimizer.init, in_axes=0)(policy_params)
2023
+ return opt_state, {}
2024
+
2025
+ if num_parallel is None:
2026
+ return jax.jit(_jax_wrapped_init_policy), jax.jit(_jax_wrapped_init_opt)
1909
2027
 
1910
- return _jax_wrapped_init_policy
2028
+ # for parallel policy update
2029
+ def _jax_wrapped_init_policies(key, policy_hyperparams, subs):
2030
+ keys = jnp.asarray(random.split(key, num=num_parallel))
2031
+ return jax.vmap(_jax_wrapped_init_policy, in_axes=(0, None, None))(
2032
+ keys, policy_hyperparams, subs)
2033
+
2034
+ return jax.jit(_jax_wrapped_init_policies), jax.jit(_jax_wrapped_init_opt)
1911
2035
 
1912
2036
  def _jax_update(self, loss):
1913
2037
  optimizer = self.optimizer
1914
2038
  projection = self.plan.projection
1915
2039
  use_ls = self.line_search_kwargs is not None
2040
+ num_parallel = self.parallel_updates
1916
2041
 
1917
2042
  # check if the gradients are all zeros
1918
2043
  def _jax_wrapped_zero_gradients(grad):
1919
2044
  leaves, _ = jax.tree_util.tree_flatten(
1920
- jax.tree_map(lambda g: jnp.allclose(g, 0), grad))
2045
+ jax.tree_util.tree_map(partial(jnp.allclose, b=0), grad))
1921
2046
  return jnp.all(jnp.asarray(leaves))
1922
2047
 
1923
2048
  # calculate the plan gradient w.r.t. return loss and update optimizer
@@ -1948,8 +2073,43 @@ r"""
1948
2073
  return policy_params, converged, opt_state, opt_aux, \
1949
2074
  loss_val, log, model_params, zero_grads
1950
2075
 
1951
- return jax.jit(_jax_wrapped_plan_update)
2076
+ if num_parallel is None:
2077
+ return jax.jit(_jax_wrapped_plan_update)
2078
+
2079
+ # for parallel policy update
2080
+ def _jax_wrapped_plan_updates(key, policy_params, policy_hyperparams,
2081
+ subs, model_params, opt_state, opt_aux):
2082
+ keys = jnp.asarray(random.split(key, num=num_parallel))
2083
+ return jax.vmap(
2084
+ _jax_wrapped_plan_update, in_axes=(0, 0, None, None, 0, 0, 0)
2085
+ )(keys, policy_params, policy_hyperparams, subs, model_params,
2086
+ opt_state, opt_aux)
2087
+
2088
+ return jax.jit(_jax_wrapped_plan_updates)
1952
2089
 
2090
+ def _jax_merge_pgpe_jaxplan(self):
2091
+ if self.parallel_updates is None:
2092
+ return None
2093
+
2094
+ # for parallel policy update
2095
+ # currently implements a hard replacement where the jaxplan parameter
2096
+ # is replaced by the PGPE parameter if the latter is an improvement
2097
+ def _jax_wrapped_pgpe_jaxplan_merge(pgpe_mask, pgpe_param, policy_params,
2098
+ pgpe_loss, test_loss,
2099
+ pgpe_loss_smooth, test_loss_smooth,
2100
+ pgpe_converged, converged):
2101
+ def select_fn(leaf1, leaf2):
2102
+ expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf1) - 1)]
2103
+ return jnp.where(expanded_mask, leaf1, leaf2)
2104
+ policy_params = jax.tree_util.tree_map(select_fn, pgpe_param, policy_params)
2105
+ test_loss = jnp.where(pgpe_mask, pgpe_loss, test_loss)
2106
+ test_loss_smooth = jnp.where(pgpe_mask, pgpe_loss_smooth, test_loss_smooth)
2107
+ expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(converged) - 1)]
2108
+ converged = jnp.where(expanded_mask, pgpe_converged, converged)
2109
+ return policy_params, test_loss, test_loss_smooth, converged
2110
+
2111
+ return jax.jit(_jax_wrapped_pgpe_jaxplan_merge)
2112
+
1953
2113
  def _batched_init_subs(self, subs):
1954
2114
  rddl = self.rddl
1955
2115
  n_train, n_test = self.batch_size_train, self.batch_size_test
@@ -1963,11 +2123,20 @@ r"""
1963
2123
  f'Variable <{name}> in subs argument is not a '
1964
2124
  f'valid p-variable, must be one of '
1965
2125
  f'{set(self.test_compiled.init_values.keys())}.')
1966
- value = np.reshape(value, newshape=np.shape(init_value))[np.newaxis, ...]
2126
+ value = np.reshape(value, np.shape(init_value))[np.newaxis, ...]
2127
+ if value.dtype.type is np.str_:
2128
+ value = rddl.object_string_to_index_array(rddl.variable_ranges[name], value)
1967
2129
  train_value = np.repeat(value, repeats=n_train, axis=0)
1968
2130
  train_value = np.asarray(train_value, dtype=self.compiled.REAL)
1969
2131
  init_train[name] = train_value
1970
2132
  init_test[name] = np.repeat(value, repeats=n_test, axis=0)
2133
+
2134
+ # safely cast test subs variable to required type in case the type is wrong
2135
+ if name in rddl.variable_ranges:
2136
+ required_type = RDDLValueInitializer.NUMPY_TYPES.get(
2137
+ rddl.variable_ranges[name], RDDLValueInitializer.INT)
2138
+ if np.result_type(init_test[name]) != required_type:
2139
+ init_test[name] = np.asarray(init_test[name], dtype=required_type)
1971
2140
 
1972
2141
  # make sure next-state fluents are also set
1973
2142
  for (state, next_state) in rddl.next_state.items():
@@ -1975,6 +2144,19 @@ r"""
1975
2144
  init_test[next_state] = init_test[state]
1976
2145
  return init_train, init_test
1977
2146
 
2147
+ def _broadcast_pytree(self, pytree):
2148
+ if self.parallel_updates is None:
2149
+ return pytree
2150
+
2151
+ # for parallel policy update
2152
+ def make_batched(x):
2153
+ x = np.asarray(x)
2154
+ x = np.broadcast_to(
2155
+ x[np.newaxis, ...], shape=(self.parallel_updates,) + np.shape(x))
2156
+ return x
2157
+
2158
+ return jax.tree_util.tree_map(make_batched, pytree)
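_broadcast_pytree above and the jitted pytree_at defined earlier are the two halves of the parallel-update bookkeeping: replicate every leaf along a new leading axis of size parallel_updates, then slice a single replica back out. A small self-contained sketch:

```python
import jax
import numpy as np

params = {'w': np.ones((2, 3)), 'b': np.zeros(3)}
K = 4   # hypothetical number of parallel updates

batched = jax.tree_util.tree_map(
    lambda x: np.broadcast_to(np.asarray(x)[np.newaxis, ...], (K,) + np.shape(x)),
    params)
print(batched['w'].shape)    # (4, 2, 3)

# same role as pytree_at(tree, i): extract the parameters of run i
replica0 = jax.tree_util.tree_map(lambda x: x[0], batched)
print(replica0['w'].shape)   # (2, 3)
```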
2159
+
1978
2160
  def as_optimization_problem(
1979
2161
  self, key: Optional[random.PRNGKey]=None,
1980
2162
  policy_hyperparams: Optional[Pytree]=None,
@@ -2002,6 +2184,11 @@ r"""
2002
2184
  :param grad_function_updates_key: if True, the gradient function
2003
2185
  updates the PRNG key internally independently of the loss function.
2004
2186
  '''
2187
+
2188
+ # make sure parallel updates are disabled
2189
+ if self.parallel_updates is not None:
2190
+ raise ValueError('Cannot compile static optimization problem '
2191
+ 'when parallel_updates is not None.')
2005
2192
 
2006
2193
  # if PRNG key is not provided
2007
2194
  if key is None:
@@ -2012,8 +2199,11 @@ r"""
2012
2199
  train_subs, _ = self._batched_init_subs(subs)
2013
2200
  model_params = self.compiled.model_params
2014
2201
  if policy_hyperparams is None:
2015
- raise_warning('policy_hyperparams is not set, setting 1.0 for '
2016
- 'all action-fluents which could be suboptimal.')
2202
+ if self.print_warnings:
2203
+ message = termcolor.colored(
2204
+ '[WARN] policy_hyperparams is not set, setting 1.0 for '
2205
+ 'all action-fluents which could be suboptimal.', 'yellow')
2206
+ print(message)
2017
2207
  policy_hyperparams = {action: 1.0
2018
2208
  for action in self.rddl.action_fluents}
2019
2209
 
@@ -2084,10 +2274,12 @@ r"""
2084
2274
  their values: if None initializes all variables from the RDDL instance
2085
2275
  :param guess: initial policy parameters: if None will use the initializer
2086
2276
  specified in this instance
2087
- :param print_summary: whether to print planner header, parameter
2088
- summary, and diagnosis
2277
+ :param print_summary: whether to print planner header and diagnosis
2089
2278
  :param print_progress: whether to print the progress bar during training
2279
+ :param print_hyperparams: whether to print list of hyper-parameter settings
2090
2280
  :param stopping_rule: stopping criterion
2281
+ :param restart_epochs: restart the optimizer from a random policy configuration
2282
+ if there is no progress for this many consecutive iterations
2091
2283
  :param test_rolling_window: the test return is averaged on a rolling
2092
2284
  window of the past test_rolling_window returns when updating the best
2093
2285
  parameters found so far
@@ -2120,7 +2312,9 @@ r"""
2120
2312
  guess: Optional[Pytree]=None,
2121
2313
  print_summary: bool=True,
2122
2314
  print_progress: bool=True,
2315
+ print_hyperparams: bool=False,
2123
2316
  stopping_rule: Optional[JaxPlannerStoppingRule]=None,
2317
+ restart_epochs: int=999999,
2124
2318
  test_rolling_window: int=10,
2125
2319
  tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
2126
2320
  '''Returns a generator for computing an optimal policy or plan.
@@ -2139,10 +2333,12 @@ r"""
2139
2333
  their values: if None initializes all variables from the RDDL instance
2140
2334
  :param guess: initial policy parameters: if None will use the initializer
2141
2335
  specified in this instance
2142
- :param print_summary: whether to print planner header, parameter
2143
- summary, and diagnosis
2336
+ :param print_summary: whether to print planner header and diagnosis
2144
2337
  :param print_progress: whether to print the progress bar during training
2338
+ :param print_hyperparams: whether to print list of hyper-parameter settings
2145
2339
  :param stopping_rule: stopping criterion
2340
+ :param restart_epochs: restart the optimizer from a random policy configuration
2341
+ if there is no progress for this many consecutive iterations
2146
2342
  :param test_rolling_window: the test return is averaged on a rolling
2147
2343
  window of the past test_rolling_window returns when updating the best
2148
2344
  parameters found so far
@@ -2155,6 +2351,15 @@ r"""
2155
2351
  # INITIALIZATION OF HYPER-PARAMETERS
2156
2352
  # ======================================================================
2157
2353
 
2354
+ # cannot run dashboard with parallel updates
2355
+ if dashboard is not None and self.parallel_updates is not None:
2356
+ if self.print_warnings:
2357
+ message = termcolor.colored(
2358
+ '[WARN] Dashboard is unavailable if parallel_updates is not None: '
2359
+ 'setting dashboard to None.', 'yellow')
2360
+ print(message)
2361
+ dashboard = None
2362
+
2158
2363
  # if PRNG key is not provided
2159
2364
  if key is None:
2160
2365
  key = random.PRNGKey(round(time.time() * 1000))
@@ -2162,15 +2367,21 @@ r"""
2162
2367
 
2163
2368
  # if policy_hyperparams is not provided
2164
2369
  if policy_hyperparams is None:
2165
- raise_warning('policy_hyperparams is not set, setting 1.0 for '
2166
- 'all action-fluents which could be suboptimal.')
2370
+ if self.print_warnings:
2371
+ message = termcolor.colored(
2372
+ '[WARN] policy_hyperparams is not set, setting 1.0 for '
2373
+ 'all action-fluents which could be suboptimal.', 'yellow')
2374
+ print(message)
2167
2375
  policy_hyperparams = {action: 1.0
2168
2376
  for action in self.rddl.action_fluents}
2169
2377
 
2170
2378
  # if policy_hyperparams is a scalar
2171
2379
  elif isinstance(policy_hyperparams, (int, float, np.number)):
2172
- raise_warning(f'policy_hyperparams is {policy_hyperparams}, '
2173
- 'setting this value for all action-fluents.')
2380
+ if self.print_warnings:
2381
+ message = termcolor.colored(
2382
+ f'[INFO] policy_hyperparams is {policy_hyperparams}, '
2383
+ f'setting this value for all action-fluents.', 'green')
2384
+ print(message)
2174
2385
  hyperparam_value = float(policy_hyperparams)
2175
2386
  policy_hyperparams = {action: hyperparam_value
2176
2387
  for action in self.rddl.action_fluents}
@@ -2179,14 +2390,20 @@ r"""
2179
2390
  elif isinstance(policy_hyperparams, dict):
2180
2391
  for action in self.rddl.action_fluents:
2181
2392
  if action not in policy_hyperparams:
2182
- raise_warning(f'policy_hyperparams[{action}] is not set, '
2183
- 'setting 1.0 which could be suboptimal.')
2393
+ if self.print_warnings:
2394
+ message = termcolor.colored(
2395
+ f'[WARN] policy_hyperparams[{action}] is not set, '
2396
+ f'setting 1.0 for missing action-fluents '
2397
+ f'which could be suboptimal.', 'yellow')
2398
+ print(message)
2184
2399
  policy_hyperparams[action] = 1.0
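The branches above accept policy_hyperparams as None, a single scalar, or a per-fluent dictionary; for boolean action-fluents the value typically acts as the sigmoid sharpness of the relaxed action. A sketch of the three accepted forms (fluent names are hypothetical):

```python
# all three forms are accepted by optimize(); missing entries default to 1.0
policy_hyperparams = None                   # warns, then uses 1.0 for every action-fluent
policy_hyperparams = 5.0                    # one scalar applied to every action-fluent
policy_hyperparams = {'put-out': 10.0,      # hypothetical boolean action-fluents,
                      'cut-out': 10.0}      # one weight per fluent
```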
2185
2400
 
2186
2401
  # print summary of parameters:
2187
2402
  if print_summary:
2188
2403
  print(self.summarize_system())
2189
- self.summarize_hyperparameters()
2404
+ print(self.summarize_relaxations())
2405
+ if print_hyperparams:
2406
+ print(self.summarize_hyperparameters())
2190
2407
  print(f'optimize() call hyper-parameters:\n'
2191
2408
  f' PRNG key ={key}\n'
2192
2409
  f' max_iterations ={epochs}\n'
@@ -2200,7 +2417,8 @@ r"""
2200
2417
  f' dashboard_id ={dashboard_id}\n'
2201
2418
  f' print_summary ={print_summary}\n'
2202
2419
  f' print_progress ={print_progress}\n'
2203
- f' stopping_rule ={stopping_rule}\n')
2420
+ f' stopping_rule ={stopping_rule}\n'
2421
+ f' restart_epochs ={restart_epochs}\n')
2204
2422
 
2205
2423
  # ======================================================================
2206
2424
  # INITIALIZATION OF STATE AND POLICY
@@ -2217,16 +2435,18 @@ r"""
2217
2435
  if var not in subs:
2218
2436
  subs[var] = value
2219
2437
  added_pvars_to_subs.append(var)
2220
- if added_pvars_to_subs:
2221
- raise_warning(f'p-variables {added_pvars_to_subs} not in '
2222
- 'provided subs, using their initial values '
2223
- 'from the RDDL files.')
2438
+ if self.print_warnings and added_pvars_to_subs:
2439
+ message = termcolor.colored(
2440
+ f'[INFO] p-variables {added_pvars_to_subs} is not in '
2441
+ f'provided subs, using their initial values.', 'green')
2442
+ print(message)
2224
2443
  train_subs, test_subs = self._batched_init_subs(subs)
2225
2444
 
2226
2445
  # initialize model parameters
2227
2446
  if model_params is None:
2228
2447
  model_params = self.compiled.model_params
2229
- model_params_test = self.test_compiled.model_params
2448
+ model_params = self._broadcast_pytree(model_params)
2449
+ model_params_test = self._broadcast_pytree(self.test_compiled.model_params)
2230
2450
 
2231
2451
  # initialize policy parameters
2232
2452
  if guess is None:
@@ -2234,29 +2454,31 @@ r"""
2234
2454
  policy_params, opt_state, opt_aux = self.initialize(
2235
2455
  subkey, policy_hyperparams, train_subs)
2236
2456
  else:
2237
- policy_params = guess
2238
- opt_state = self.optimizer.init(policy_params)
2239
- opt_aux = {}
2457
+ policy_params = self._broadcast_pytree(guess)
2458
+ opt_state, opt_aux = self.init_optimizer(policy_params)
2240
2459
 
2241
2460
  # initialize pgpe parameters
2242
2461
  if self.use_pgpe:
2243
- pgpe_params, pgpe_opt_state = self.pgpe.initialize(key, policy_params)
2462
+ pgpe_params, pgpe_opt_state, r_max = self.pgpe.initialize(key, policy_params)
2244
2463
  rolling_pgpe_loss = RollingMean(test_rolling_window)
2245
2464
  else:
2246
- pgpe_params, pgpe_opt_state = None, None
2465
+ pgpe_params, pgpe_opt_state, r_max = None, None, None
2247
2466
  rolling_pgpe_loss = None
2248
2467
  total_pgpe_it = 0
2249
- r_max = -jnp.inf
2250
2468
 
2251
2469
  # ======================================================================
2252
2470
  # INITIALIZATION OF RUNNING STATISTICS
2253
2471
  # ======================================================================
2254
2472
 
2255
2473
  # initialize running statistics
2256
- best_params, best_loss, best_grad = policy_params, jnp.inf, None
2474
+ if self.parallel_updates is None:
2475
+ best_params = policy_params
2476
+ else:
2477
+ best_params = self.pytree_at(policy_params, 0)
2478
+ best_loss, pbest_loss, best_grad = np.inf, np.inf, None
2257
2479
  last_iter_improve = 0
2480
+ no_progress_count = 0
2258
2481
  rolling_test_loss = RollingMean(test_rolling_window)
2259
- log = {}
2260
2482
  status = JaxPlannerStatus.NORMAL
2261
2483
  progress_percent = 0
2262
2484
 
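
Editor's note: the parallel-updates branch above keeps a vector of running statistics (one per optimizer copy) and pulls a single candidate back out with self.pytree_at(policy_params, 0), mirroring the earlier self._broadcast_pytree(...) calls. Those helpers are internal to the planner; the sketch below is only a plausible reading of what they do (replicate every leaf along a new leading axis, then slice one copy back out), not the package's actual implementation.

    import jax
    import jax.numpy as jnp

    def broadcast_pytree(tree, n_copies):
        # Illustrative sketch: tile every leaf along a new leading "parallel" axis.
        return jax.tree_util.tree_map(
            lambda leaf: jnp.broadcast_to(leaf, (n_copies,) + jnp.shape(leaf)), tree)

    def pytree_at(tree, index):
        # Illustrative sketch: extract one optimizer copy from a batched pytree.
        return jax.tree_util.tree_map(lambda leaf: leaf[index], tree)
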
@@ -2277,6 +2499,11 @@ r"""
2277
2499
  else:
2278
2500
  progress_bar = None
2279
2501
  position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
2502
+
2503
+ # error handlers (to avoid spam messaging)
2504
+ policy_constraint_msg_shown = False
2505
+ jax_train_msg_shown = False
2506
+ jax_test_msg_shown = False
2280
2507
 
2281
2508
  # ======================================================================
2282
2509
  # MAIN TRAINING LOOP BEGINS
@@ -2296,8 +2523,13 @@ r"""
2296
2523
  model_params, zero_grads) = self.update(
2297
2524
  subkey, policy_params, policy_hyperparams, train_subs, model_params,
2298
2525
  opt_state, opt_aux)
2526
+
2527
+ # evaluate
2299
2528
  test_loss, (test_log, model_params_test) = self.test_loss(
2300
2529
  subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
2530
+ if self.parallel_updates:
2531
+ train_loss = np.asarray(train_loss)
2532
+ test_loss = np.asarray(test_loss)
2301
2533
  test_loss_smooth = rolling_test_loss.update(test_loss)
2302
2534
 
2303
2535
  # pgpe update of the plan
@@ -2308,52 +2540,112 @@ r"""
2308
2540
  self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
2309
2541
  policy_hyperparams, test_subs, model_params_test,
2310
2542
  pgpe_opt_state)
2543
+
2544
+ # evaluate
2311
2545
  pgpe_loss, _ = self.test_loss(
2312
2546
  subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
2547
+ if self.parallel_updates:
2548
+ pgpe_loss = np.asarray(pgpe_loss)
2313
2549
  pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
2314
2550
  pgpe_return = -pgpe_loss_smooth
2315
2551
 
2316
- # replace with PGPE if it reaches a new minimum or train loss invalid
2317
- if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
2318
- policy_params = pgpe_param
2319
- test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
2320
- converged = pgpe_converged
2321
- pgpe_improve = True
2322
- total_pgpe_it += 1
2552
+ # replace JaxPlan with PGPE if new minimum reached or train loss invalid
2553
+ if self.parallel_updates is None:
2554
+ if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
2555
+ policy_params = pgpe_param
2556
+ test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
2557
+ converged = pgpe_converged
2558
+ pgpe_improve = True
2559
+ total_pgpe_it += 1
2560
+ else:
2561
+ pgpe_mask = (pgpe_loss_smooth < pbest_loss) | ~np.isfinite(train_loss)
2562
+ if np.any(pgpe_mask):
2563
+ policy_params, test_loss, test_loss_smooth, converged = \
2564
+ self.merge_pgpe(pgpe_mask, pgpe_param, policy_params,
2565
+ pgpe_loss, test_loss,
2566
+ pgpe_loss_smooth, test_loss_smooth,
2567
+ pgpe_converged, converged)
2568
+ pgpe_improve = True
2569
+ total_pgpe_it += 1
2323
2570
  else:
2324
2571
  pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
2325
2572
 
2326
- # evaluate test losses and record best plan so far
2327
- if test_loss_smooth < best_loss:
2328
- best_params, best_loss, best_grad = \
2329
- policy_params, test_loss_smooth, train_log['grad']
2330
- last_iter_improve = it
2573
+ # evaluate test losses and record best parameters so far
2574
+ if self.parallel_updates is None:
2575
+ if test_loss_smooth < best_loss:
2576
+ best_params, best_loss, best_grad = \
2577
+ policy_params, test_loss_smooth, train_log['grad']
2578
+ pbest_loss = best_loss
2579
+ else:
2580
+ best_index = np.argmin(test_loss_smooth)
2581
+ if test_loss_smooth[best_index] < best_loss:
2582
+ best_params = self.pytree_at(policy_params, best_index)
2583
+ best_grad = self.pytree_at(train_log['grad'], best_index)
2584
+ best_loss = test_loss_smooth[best_index]
2585
+ pbest_loss = np.minimum(pbest_loss, test_loss_smooth)
2331
2586
 
2332
2587
  # ==================================================================
2333
2588
  # STATUS CHECKS AND LOGGING
2334
2589
  # ==================================================================
2335
2590
 
2336
2591
  # no progress
2337
- if (not pgpe_improve) and zero_grads:
2592
+ no_progress_flag = (not pgpe_improve) and np.all(zero_grads)
2593
+ if no_progress_flag:
2338
2594
  status = JaxPlannerStatus.NO_PROGRESS
2339
-
2595
+
2340
2596
  # constraint satisfaction problem
2341
- if not np.all(converged):
2342
- raise_warning(
2343
- 'Projected gradient method for satisfying action concurrency '
2344
- 'constraints reached the iteration limit: plan is possibly '
2345
- 'invalid for the current instance.', 'red')
2597
+ if not np.all(converged):
2598
+ if progress_bar is not None and not policy_constraint_msg_shown:
2599
+ message = termcolor.colored(
2600
+ '[FAIL] Policy update failed to satisfy action constraints.',
2601
+ 'red')
2602
+ progress_bar.write(message)
2603
+ policy_constraint_msg_shown = True
2346
2604
  status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
2347
2605
 
2348
2606
  # numerical error
2349
2607
  if self.use_pgpe:
2350
- invalid_loss = not (np.isfinite(train_loss) or np.isfinite(pgpe_loss))
2608
+ invalid_loss = not (np.any(np.isfinite(train_loss)) or
2609
+ np.any(np.isfinite(pgpe_loss)))
2351
2610
  else:
2352
- invalid_loss = not np.isfinite(train_loss)
2611
+ invalid_loss = not np.any(np.isfinite(train_loss))
2353
2612
  if invalid_loss:
2354
- raise_warning(f'Planner aborted due to invalid loss {train_loss}.', 'red')
2613
+ if progress_bar is not None:
2614
+ message = termcolor.colored(
2615
+ f'[FAIL] Planner aborted due to invalid train loss {train_loss}.',
2616
+ 'red')
2617
+ progress_bar.write(message)
2355
2618
  status = JaxPlannerStatus.INVALID_GRADIENT
2356
2619
 
2620
+ # problem in the model compilation
2621
+ if progress_bar is not None:
2622
+
2623
+ # train model
2624
+ if not jax_train_msg_shown:
2625
+ messages = set()
2626
+ for error_code in np.unique(train_log['error']):
2627
+ messages.update(JaxRDDLCompiler.get_error_messages(error_code))
2628
+ if messages:
2629
+ messages = '\n '.join(messages)
2630
+ message = termcolor.colored(
2631
+ f'[FAIL] Compiler encountered the following '
2632
+ f'error(s) in the training model:\n {messages}', 'red')
2633
+ progress_bar.write(message)
2634
+ jax_train_msg_shown = True
2635
+
2636
+ # test model
2637
+ if not jax_test_msg_shown:
2638
+ messages = set()
2639
+ for error_code in np.unique(test_log['error']):
2640
+ messages.update(JaxRDDLCompiler.get_error_messages(error_code))
2641
+ if messages:
2642
+ messages = '\n '.join(messages)
2643
+ message = termcolor.colored(
2644
+ f'[FAIL] Compiler encountered the following '
2645
+ f'error(s) in the testing model:\n {messages}', 'red')
2646
+ progress_bar.write(message)
2647
+ jax_test_msg_shown = True
2648
+
2357
2649
  # reached computation budget
2358
2650
  elapsed = time.time() - start_time - elapsed_outside_loop
2359
2651
  if elapsed >= train_seconds:
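
Editor's note: when parallel updates are enabled, merge_pgpe above overwrites only those optimizer copies whose smoothed PGPE loss beats their personal best (or whose train loss went non-finite). The helper below is a hypothetical sketch of that masked selection, not the planner's actual merge_pgpe; it assumes the boolean mask indexes the leading parallel axis of every leaf.

    import numpy as np
    import jax

    def merge_by_mask(mask, new_tree, old_tree):
        # Hypothetical sketch: per leaf, keep the PGPE candidate where mask is True.
        def select(new_leaf, old_leaf):
            m = np.reshape(mask, (-1,) + (1,) * (np.ndim(new_leaf) - 1))
            return np.where(m, new_leaf, old_leaf)
        return jax.tree_util.tree_map(select, new_tree, old_tree)
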
@@ -2387,20 +2679,39 @@ r"""
2387
2679
  **test_log
2388
2680
  }
2389
2681
 
2682
+ # hard restart
2683
+ if guess is None and no_progress_flag:
2684
+ no_progress_count += 1
2685
+ if no_progress_count > restart_epochs:
2686
+ key, subkey = random.split(key)
2687
+ policy_params, opt_state, opt_aux = self.initialize(
2688
+ subkey, policy_hyperparams, train_subs)
2689
+ no_progress_count = 0
2690
+ if self.print_warnings and progress_bar is not None:
2691
+ message = termcolor.colored(
2692
+ f'[INFO] Optimizer restarted at iteration {it} '
2693
+ f'due to lack of progress.', 'green')
2694
+ progress_bar.write(message)
2695
+ else:
2696
+ no_progress_count = 0
2697
+
2390
2698
  # stopping condition reached
2391
2699
  if stopping_rule is not None and stopping_rule.monitor(callback):
2700
+ if self.print_warnings and progress_bar is not None:
2701
+ message = termcolor.colored(
2702
+ '[SUCC] Stopping rule has been reached.', 'green')
2703
+ progress_bar.write(message)
2392
2704
  callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
2393
2705
 
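
Editor's note: a stopping rule only needs to expose monitor(callback), returning True when optimization should halt; the planner then records STOPPING_RULE_REACHED as above. Below is a minimal illustrative rule that stops once a target return is reached; the 'best_return' key it reads from the callback is a hypothetical name used here for illustration, not a documented field.

    class TargetReturnStoppingRule:
        '''Illustrative stopping rule: halt once the best return reaches a target.
        The 'best_return' callback key is assumed for illustration only.'''

        def __init__(self, target: float) -> None:
            self.target = target

        def monitor(self, callback) -> bool:
            return callback.get('best_return', -float('inf')) >= self.target
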
2394
2706
  # if the progress bar is used
2395
2707
  if print_progress:
2396
2708
  progress_bar.set_description(
2397
- f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
2398
- f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
2709
+ f'{position_str} {it:6} it / {-np.min(train_loss):14.5f} train / '
2710
+ f'{-np.min(test_loss_smooth):14.5f} test / {-best_loss:14.5f} best / '
2399
2711
  f'{status.value} status / {total_pgpe_it:6} pgpe',
2400
- refresh=False
2401
- )
2712
+ refresh=False)
2402
2713
  progress_bar.set_postfix_str(
2403
- f"{(it + 1) / (elapsed + 1e-6):.2f}it/s", refresh=False)
2714
+ f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False)
2404
2715
  progress_bar.update(progress_percent - progress_bar.n)
2405
2716
 
2406
2717
  # dash-board
@@ -2423,24 +2734,16 @@ r"""
2423
2734
  # release resources
2424
2735
  if print_progress:
2425
2736
  progress_bar.close()
2426
-
2427
- # validate the test return
2428
- if log:
2429
- messages = set()
2430
- for error_code in np.unique(log['error']):
2431
- messages.update(JaxRDDLCompiler.get_error_messages(error_code))
2432
- if messages:
2433
- messages = '\n'.join(messages)
2434
- raise_warning('JAX compiler encountered the following '
2435
- 'error(s) in the original RDDL formulation '
2436
- f'during test evaluation:\n{messages}', 'red')
2737
+ print()
2437
2738
 
2438
2739
  # summarize and test for convergence
2439
2740
  if print_summary:
2440
- grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
2741
+ grad_norm = jax.tree_util.tree_map(
2742
+ lambda x: np.linalg.norm(x).item(), best_grad)
2441
2743
  diagnosis = self._perform_diagnosis(
2442
- last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
2443
- print(f'summary of optimization:\n'
2744
+ last_iter_improve, -np.min(train_loss), -np.min(test_loss_smooth),
2745
+ -best_loss, grad_norm)
2746
+ print(f'Summary of optimization:\n'
2444
2747
  f' status ={status}\n'
2445
2748
  f' time ={elapsed:.3f} sec.\n'
2446
2749
  f' iterations ={it}\n'
@@ -2453,12 +2756,9 @@ r"""
2453
2756
  max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
2454
2757
  grad_is_zero = np.allclose(max_grad_norm, 0)
2455
2758
 
2456
- validation_error = 100 * abs(test_return - train_return) / \
2457
- max(abs(train_return), abs(test_return))
2458
-
2459
2759
  # divergence if the solution is not finite
2460
2760
  if not np.isfinite(train_return):
2461
- return termcolor.colored('[FAILURE] training loss diverged.', 'red')
2761
+ return termcolor.colored('[FAIL] Training loss diverged.', 'red')
2462
2762
 
2463
2763
  # hit a plateau is likely IF:
2464
2764
  # 1. planner does not improve at all
@@ -2466,23 +2766,25 @@ r"""
2466
2766
  if last_iter_improve <= 1:
2467
2767
  if grad_is_zero:
2468
2768
  return termcolor.colored(
2469
- '[FAILURE] no progress was made '
2769
+ f'[FAIL] No progress was made '
2470
2770
  f'and max grad norm {max_grad_norm:.6f} was zero: '
2471
- 'solver likely stuck in a plateau.', 'red')
2771
+ f'solver likely stuck in a plateau.', 'red')
2472
2772
  else:
2473
2773
  return termcolor.colored(
2474
- '[FAILURE] no progress was made '
2774
+ f'[FAIL] No progress was made '
2475
2775
  f'but max grad norm {max_grad_norm:.6f} was non-zero: '
2476
- 'learning rate or other hyper-parameters likely suboptimal.',
2776
+ f'learning rate or other hyper-parameters could be suboptimal.',
2477
2777
  'red')
2478
2778
 
2479
2779
  # model is likely poor IF:
2480
2780
  # 1. the train and test return disagree
2781
+ validation_error = 100 * abs(test_return - train_return) / \
2782
+ max(abs(train_return), abs(test_return))
2481
2783
  if not (validation_error < 20):
2482
2784
  return termcolor.colored(
2483
- '[WARNING] progress was made '
2785
+ f'[WARN] Progress was made '
2484
2786
  f'but relative train-test error {validation_error:.6f} was high: '
2485
- 'poor model relaxation around solution or batch size too small.',
2787
+ f'poor model relaxation around solution or batch size too small.',
2486
2788
  'yellow')
2487
2789
 
2488
2790
  # model likely did not converge IF:
@@ -2491,24 +2793,22 @@ r"""
2491
2793
  return_to_grad_norm = abs(best_return) / max_grad_norm
2492
2794
  if not (return_to_grad_norm > 1):
2493
2795
  return termcolor.colored(
2494
- '[WARNING] progress was made '
2796
+ f'[WARN] Progress was made '
2495
2797
  f'but max grad norm {max_grad_norm:.6f} was high: '
2496
- 'solution locally suboptimal '
2497
- 'or relaxed model not smooth around solution '
2498
- 'or batch size too small.', 'yellow')
2798
+ f'solution locally suboptimal, relaxed model nonsmooth around solution, '
2799
+ f'or batch size too small.', 'yellow')
2499
2800
 
2500
2801
  # likely successful
2501
2802
  return termcolor.colored(
2502
- '[SUCCESS] solver converged successfully '
2503
- '(note: not all potential problems can be ruled out).', 'green')
2803
+ '[SUCC] Planner converged successfully '
2804
+ '(note: not all problems can be ruled out).', 'green')
2504
2805
 
2505
2806
  def get_action(self, key: random.PRNGKey,
2506
2807
  params: Pytree,
2507
2808
  step: int,
2508
2809
  subs: Dict[str, Any],
2509
2810
  policy_hyperparams: Optional[Dict[str, Any]]=None) -> Dict[str, Any]:
2510
- '''Returns an action dictionary from the policy or plan with the given
2511
- parameters.
2811
+ '''Returns an action dictionary from the policy or plan with the given parameters.
2512
2812
 
2513
2813
  :param key: the JAX PRNG key
2514
2814
  :param params: the trainable parameter PyTree of the policy
@@ -2517,6 +2817,7 @@ r"""
2517
2817
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
2518
2818
  weights for sigmoid wrapping boolean actions (optional)
2519
2819
  '''
2820
+ subs = subs.copy()
2520
2821
 
2521
2822
  # check compatibility of the subs dictionary
2522
2823
  for (var, values) in subs.items():
@@ -2535,13 +2836,17 @@ r"""
2535
2836
  if step == 0 and var in self.rddl.observ_fluents:
2536
2837
  subs[var] = self.test_compiled.init_values[var]
2537
2838
  else:
2538
- raise ValueError(
2539
- f'Values {values} assigned to p-variable <{var}> are '
2540
- f'non-numeric of type {dtype}.')
2839
+ if dtype.type is np.str_:
2840
+ prange = self.rddl.variable_ranges[var]
2841
+ subs[var] = self.rddl.object_string_to_index_array(prange, subs[var])
2842
+ else:
2843
+ raise ValueError(
2844
+ f'Values {values} assigned to p-variable <{var}> are '
2845
+ f'non-numeric of type {dtype}.')
2541
2846
 
2542
2847
  # cast device arrays to numpy
2543
2848
  actions = self.test_policy(key, params, policy_hyperparams, step, subs)
2544
- actions = jax.tree_map(np.asarray, actions)
2849
+ actions = jax.tree_util.tree_map(np.asarray, actions)
2545
2850
  return actions
2546
2851
 
2547
2852
 
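
Editor's note: get_action turns trained parameters into a concrete action dictionary for one decision step. A rough rollout fragment is sketched below; it assumes a JaxBackpropPlanner instance planner with trained params (e.g. callback['best_params']) and an environment env exposing the usual pyRDDLGym horizon attribute and gymnasium-style reset()/step() signatures.

    import jax.random as random

    key = random.PRNGKey(42)
    state, _ = env.reset()
    total_reward = 0.0
    for step in range(env.horizon):
        key, subkey = random.split(key)
        actions = planner.get_action(subkey, params, step, state)
        state, reward, terminated, truncated, _ = env.step(actions)
        total_reward += reward
        if terminated or truncated:
            break
    print(f'episode return: {total_reward}')
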
@@ -2562,8 +2867,9 @@ class JaxOfflineController(BaseAgent):
2562
2867
  def __init__(self, planner: JaxBackpropPlanner,
2563
2868
  key: Optional[random.PRNGKey]=None,
2564
2869
  eval_hyperparams: Optional[Dict[str, Any]]=None,
2565
- params: Optional[Pytree]=None,
2870
+ params: Optional[Union[str, Pytree]]=None,
2566
2871
  train_on_reset: bool=False,
2872
+ save_path: Optional[str]=None,
2567
2873
  **train_kwargs) -> None:
2568
2874
  '''Creates a new JAX offline control policy that is trained once, then
2569
2875
  deployed later.
@@ -2574,8 +2880,10 @@ class JaxOfflineController(BaseAgent):
2574
2880
  :param eval_hyperparams: policy hyperparameters to apply for evaluation
2575
2881
  or whenever sample_action is called
2576
2882
  :param params: use the specified policy parameters instead of calling
2577
- planner.optimize()
2883
+ planner.optimize(); can be a string path to a file of previously pickled
2884
+ parameters, or a pytree of parameters
2578
2885
  :param train_on_reset: retrain policy parameters on every episode reset
2886
+ :param save_path: optional path to save parameters to
2579
2887
  :param **train_kwargs: any keyword arguments to be passed to the planner
2580
2888
  for optimization
2581
2889
  '''
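
Editor's note: together, save_path and the string form of params give a simple persist-and-reload workflow: train once and pickle the best parameters, then build a later controller straight from that file without re-optimizing. A rough sketch (planner construction is elided; extra keyword arguments are forwarded to planner.optimize()):

    # First run: optimize, then pickle the resulting parameters to disk.
    controller = JaxOfflineController(
        planner, save_path='trained_policy.pkl', print_summary=True)

    # Later run: skip optimization by loading the pickled parameters.
    controller = JaxOfflineController(planner, params='trained_policy.pkl')
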
@@ -2588,12 +2896,24 @@ class JaxOfflineController(BaseAgent):
2588
2896
  self.train_kwargs = train_kwargs
2589
2897
  self.params_given = params is not None
2590
2898
 
2899
+ # load the policy from file
2900
+ if not self.train_on_reset and params is not None and isinstance(params, str):
2901
+ with open(params, 'rb') as file:
2902
+ params = pickle.load(file)
2903
+
2904
+ # train the policy
2591
2905
  self.step = 0
2592
2906
  self.callback = None
2593
2907
  if not self.train_on_reset and not self.params_given:
2594
2908
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2595
2909
  self.callback = callback
2596
2910
  params = callback['best_params']
2911
+
2912
+ # save the policy
2913
+ if save_path is not None:
2914
+ with open(save_path, 'wb') as file:
2915
+ pickle.dump(params, file)
2916
+
2597
2917
  self.params = params
2598
2918
 
2599
2919
  def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
@@ -2605,6 +2925,8 @@ class JaxOfflineController(BaseAgent):
2605
2925
 
2606
2926
  def reset(self) -> None:
2607
2927
  self.step = 0
2928
+
2929
+ # train the policy if required to reset at the start of every episode
2608
2930
  if self.train_on_reset and not self.params_given:
2609
2931
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2610
2932
  self.callback = callback
@@ -2612,8 +2934,7 @@ class JaxOfflineController(BaseAgent):
2612
2934
 
2613
2935
 
2614
2936
  class JaxOnlineController(BaseAgent):
2615
- '''A container class for a Jax controller continuously updated using state
2616
- feedback.'''
2937
+ '''A container class for a Jax controller continuously updated using state feedback.'''
2617
2938
 
2618
2939
  use_tensor_obs = True
2619
2940
 
@@ -2621,17 +2942,19 @@ class JaxOnlineController(BaseAgent):
2621
2942
  key: Optional[random.PRNGKey]=None,
2622
2943
  eval_hyperparams: Optional[Dict[str, Any]]=None,
2623
2944
  warm_start: bool=True,
2945
+ max_attempts: int=3,
2624
2946
  **train_kwargs) -> None:
2625
2947
  '''Creates a new JAX control policy that is trained online in a closed-
2626
2948
  loop fashion.
2627
2949
 
2628
2950
  :param planner: underlying planning algorithm for optimizing actions
2629
- :param key: the RNG key to seed randomness (derives from clock if not
2630
- provided)
2951
+ :param key: the RNG key to seed randomness (derives from clock if not provided)
2631
2952
  :param eval_hyperparams: policy hyperparameters to apply for evaluation
2632
2953
  or whenever sample_action is called
2633
2954
  :param warm_start: whether to use the previous decision epoch final
2634
2955
  policy parameters to warm the next decision epoch
2956
+ :param max_attempts: maximum number of optimizer restarts attempted when the total
2957
+ iteration count is at most 1 (i.e. execution time is dominated by JIT compilation)
2635
2958
  :param **train_kwargs: any keyword arguments to be passed to the planner
2636
2959
  for optimization
2637
2960
  '''
@@ -2642,20 +2965,34 @@ class JaxOnlineController(BaseAgent):
2642
2965
  self.eval_hyperparams = eval_hyperparams
2643
2966
  self.warm_start = warm_start
2644
2967
  self.train_kwargs = train_kwargs
2968
+ self.max_attempts = max_attempts
2645
2969
  self.reset()
2646
2970
 
2647
2971
  def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
2648
2972
  planner = self.planner
2649
2973
  callback = planner.optimize(
2650
- key=self.key,
2651
- guess=self.guess,
2652
- subs=state,
2653
- **self.train_kwargs
2654
- )
2974
+ key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
2975
+
2976
+ # optimize again if jit compilation takes up the entire time budget
2977
+ attempts = 0
2978
+ while attempts < self.max_attempts and callback['iteration'] <= 1:
2979
+ attempts += 1
2980
+ if self.planner.print_warnings:
2981
+ message = termcolor.colored(
2982
+ f'[WARN] JIT compilation dominated the execution time: '
2983
+ f'executing the optimizer again on the traced model '
2984
+ f'[attempt {attempts}].', 'yellow')
2985
+ print(message)
2986
+ callback = planner.optimize(
2987
+ key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
2655
2988
  self.callback = callback
2656
2989
  params = callback['best_params']
2990
+
2991
+ # get the action from the parameters for the current state
2657
2992
  self.key, subkey = random.split(self.key)
2658
2993
  actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
2994
+
2995
+ # apply warm start for the next epoch
2659
2996
  if self.warm_start:
2660
2997
  self.guess = planner.plan.guess_next_epoch(params)
2661
2998
  return actions
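
Editor's note: in closed-loop use, every call to sample_action re-optimizes from the current state, re-running the optimizer up to max_attempts times when the first call spends its whole budget on JIT compilation, and (with warm_start) seeds the next decision epoch from the previous plan via guess_next_epoch. A rough usage fragment, under the same environment-API assumptions as the earlier rollout sketch; train_seconds names the per-call time budget visible in the training loop and is assumed here to be exposed as an optimize() keyword argument forwarded through **train_kwargs.

    agent = JaxOnlineController(planner, warm_start=True, max_attempts=3,
                                train_seconds=1.0)

    state, _ = env.reset()
    for _ in range(env.horizon):
        actions = agent.sample_action(state)
        state, reward, terminated, truncated, _ = env.step(actions)
        if terminated or truncated:
            break
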