pyRDDLGym-jax 2.3-py3-none-any.whl → 2.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +2 -3
- pyRDDLGym_jax/core/logic.py +117 -66
- pyRDDLGym_jax/core/planner.py +489 -218
- pyRDDLGym_jax/core/tuning.py +28 -22
- pyRDDLGym_jax/examples/run_plan.py +2 -2
- pyRDDLGym_jax/examples/run_scipy.py +2 -2
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/METADATA +1 -1
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/RECORD +13 -13
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/LICENSE +0 -0
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/WHEEL +0 -0
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.4.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -3,7 +3,7 @@
 #
 # Author: Michael Gimelfarb
 #
-#
+# REFERENCES:
 #
 # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
 # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
@@ -18,16 +18,21 @@
 # reactive policies for planning in stochastic nonlinear domains." In Proceedings of the
 # AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 7530-7537. 2019.
 #
-# [4]
+# [4] Cui, Hao, Thomas Keller, and Roni Khardon. "Stochastic planning with lifted symbolic
+# trajectory optimization." In Proceedings of the International Conference on Automated
+# Planning and Scheduling, vol. 29, pp. 119-127. 2019.
+#
+# [5] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
 # nonlinear domains." Advances in Neural Information Processing Systems 30 (2017).
 #
-# [
+# [6] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
 # policy gradients: Super symmetric pgpe." Artificial Neural Networks: Methods and
 # Applications in Bio-/Neuroinformatics. Springer International Publishing, 2015.
 #
 # ***********************************************************************
 
 
+from abc import ABCMeta, abstractmethod
 from ast import literal_eval
 from collections import deque
 import configparser
@@ -37,7 +42,8 @@ import os
 import sys
 import time
 import traceback
-from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple,
+from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, \
+    Union
 
 import haiku as hk
 import jax
@@ -51,6 +57,7 @@ from tqdm import tqdm, TqdmWarning
 import warnings
 warnings.filterwarnings("ignore", category=TqdmWarning)
 
+from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
 from pyRDDLGym.core.debug.logger import Logger
 from pyRDDLGym.core.debug.exception import (
@@ -157,25 +164,20 @@ def _load_config(config, args):
         initializer = _getattr_any(
             packages=[initializers, hk.initializers], item=plan_initializer)
         if initializer is None:
-
-            del plan_kwargs['initializer']
+            raise ValueError(f'Invalid initializer <{plan_initializer}>.')
         else:
             init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
             try:
                 plan_kwargs['initializer'] = initializer(**init_kwargs)
             except Exception as _:
-
-                f'Ignoring invalid initializer_kwargs <{init_kwargs}>.', 'red')
-                plan_kwargs['initializer'] = initializer
+                raise ValueError(f'Invalid initializer kwargs <{init_kwargs}>.')
 
     # policy activation
     plan_activation = plan_kwargs.get('activation', None)
     if plan_activation is not None:
-        activation = _getattr_any(
-            packages=[jax.nn, jax.numpy], item=plan_activation)
+        activation = _getattr_any(packages=[jax.nn, jax.numpy], item=plan_activation)
         if activation is None:
-
-            del plan_kwargs['activation']
+            raise ValueError(f'Invalid activation <{plan_activation}>.')
         else:
             plan_kwargs['activation'] = activation
 
@@ -188,8 +190,7 @@ def _load_config(config, args):
     if planner_optimizer is not None:
         optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
         if optimizer is None:
-
-            del planner_args['optimizer']
+            raise ValueError(f'Invalid optimizer <{planner_optimizer}>.')
         else:
             planner_args['optimizer'] = optimizer
 
@@ -200,8 +201,7 @@ def _load_config(config, args):
         if 'optimizer' in pgpe_kwargs:
             pgpe_optimizer = _getattr_any(packages=[optax], item=pgpe_kwargs['optimizer'])
             if pgpe_optimizer is None:
-
-                del pgpe_kwargs['optimizer']
+                raise ValueError(f'Invalid optimizer <{pgpe_optimizer}>.')
             else:
                 pgpe_kwargs['optimizer'] = pgpe_optimizer
         planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
@@ -260,8 +260,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
                  cpfs_without_grad: Optional[Set[str]]=None,
                  **kwargs) -> None:
         '''Creates a new RDDL to Jax compiler, where operations that are not
-        differentiable are converted to approximate forms that have defined
-        gradients.
+        differentiable are converted to approximate forms that have defined gradients.
 
         :param *args: arguments to pass to base compiler
         :param logic: Fuzzy logic object that specifies how exact operations
@@ -286,8 +285,10 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
             if not np.issubdtype(np.result_type(values), np.floating):
                 pvars_cast.add(var)
         if pvars_cast:
-
-
+            message = termcolor.colored(
+                f'[INFO] JAX gradient compiler will cast p-vars {pvars_cast} to float.',
+                'green')
+            print(message)
 
         # overwrite basic operations with fuzzy ones
         self.OPS = logic.get_operator_dicts()
@@ -300,6 +301,8 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         return _jax_wrapped_stop_grad
 
     def _compile_cpfs(self, init_params):
+
+        # cpfs will all be cast to float
         cpfs_cast = set()
         jax_cpfs = {}
         for (_, cpfs) in self.levels.items():
@@ -312,11 +315,15 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
                 jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
 
         if cpfs_cast:
-
-
+            message = termcolor.colored(
+                f'[INFO] JAX gradient compiler will cast CPFs {cpfs_cast} to float.',
+                'green')
+            print(message)
         if self.cpfs_without_grad:
-
-
+            message = termcolor.colored(
+                f'[INFO] Gradients will not flow through CPFs {self.cpfs_without_grad}.',
+                'green')
+            print(message)
 
         return jax_cpfs
 
@@ -335,7 +342,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
 # ***********************************************************************
 
 
-class JaxPlan:
+class JaxPlan(metaclass=ABCMeta):
     '''Base class for all JAX policy representations.'''
 
     def __init__(self) -> None:
@@ -345,16 +352,18 @@ class JaxPlan:
         self._projection = None
         self.bounds = None
 
-    def summarize_hyperparameters(self) ->
-
-
+    def summarize_hyperparameters(self) -> str:
+        return self.__str__()
+
+    @abstractmethod
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
                 horizon: int) -> None:
-
+        pass
 
+    @abstractmethod
     def guess_next_epoch(self, params: Pytree) -> Pytree:
-
+        pass
 
     @property
     def initializer(self):
@@ -397,10 +406,11 @@ class JaxPlan:
                 continue
 
             # check invalid type
-            if prange not in compiled.JAX_TYPES:
+            if prange not in compiled.JAX_TYPES and prange not in compiled.rddl.enum_types:
+                keys = list(compiled.JAX_TYPES.keys()) + list(compiled.rddl.enum_types)
                 raise RDDLTypeError(
                     f'Invalid range <{prange}> of action-fluent <{name}>, '
-                    f'must be one of {
+                    f'must be one of {keys}.')
 
             # clip boolean to (0, 1), otherwise use the RDDL action bounds
             # or the user defined action bounds if provided
@@ -408,7 +418,12 @@ class JaxPlan:
             if prange == 'bool':
                 lower, upper = None, None
             else:
-
+                if prange in compiled.rddl.enum_types:
+                    lower = np.zeros(shape=shapes[name][1:])
+                    upper = len(compiled.rddl.type_to_objects[prange]) - 1
+                    upper = np.ones(shape=shapes[name][1:]) * upper
+                else:
+                    lower, upper = compiled.constraints.bounds[name]
             lower, upper = user_bounds.get(name, (lower, upper))
             lower = np.asarray(lower, dtype=compiled.REAL)
             upper = np.asarray(upper, dtype=compiled.REAL)
@@ -421,7 +436,10 @@ class JaxPlan:
                           ~lower_finite & upper_finite,
                           ~lower_finite & ~upper_finite]
             bounds[name] = (lower, upper)
-
+            message = termcolor.colored(
+                f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
+                'green')
+            print(message)
         return shapes, bounds, bounds_safe, cond_lists
 
     def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -502,10 +520,11 @@ class JaxStraightLinePlan(JaxPlan):
         bool_action_count, allowed_actions = self._count_bool_actions(rddl)
         use_constraint_satisfaction = allowed_actions < bool_action_count
         if use_constraint_satisfaction:
-
-
-
-
+            message = termcolor.colored(
+                f'[INFO] SLP will use projected gradient to satisfy '
+                f'max_nondef_actions since total boolean actions '
+                f'{bool_action_count} > max_nondef_actions {allowed_actions}.', 'green')
+            print(message)
 
         noop = {var: (values[0] if isinstance(values, list) else values)
                 for (var, values) in rddl.action_fluents.items()}
@@ -623,7 +642,7 @@ class JaxStraightLinePlan(JaxPlan):
                 else:
                     action = _jax_non_bool_param_to_action(var, action, hyperparams)
                     action = jnp.clip(action, *bounds[var])
-                    if ranges[var] == 'int':
+                    if ranges[var] == 'int' or ranges[var] in rddl.enum_types:
                         action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
                 actions[var] = action
             return actions
@@ -642,7 +661,7 @@ class JaxStraightLinePlan(JaxPlan):
         # only allow one action non-noop for now
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
-                f'
+                f'SLPs with wrap_softmax currently '
                 f'do not support max-nondef-actions {allowed_actions} > 1.')
 
         # potentially apply projection but to non-bool actions only
@@ -764,7 +783,8 @@ class JaxStraightLinePlan(JaxPlan):
             for (var, action) in actions.items():
                 if ranges[var] == 'bool':
                     action = jnp.clip(action, min_action, max_action)
-
+                    param = _jax_bool_action_to_param(var, action, hyperparams)
+                    new_params[var] = param
                 else:
                     new_params[var] = action
             return new_params, converged
@@ -890,8 +910,7 @@ class JaxDeepReactivePolicy(JaxPlan):
         bool_action_count, allowed_actions = self._count_bool_actions(rddl)
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
-                f'
-                f'max-nondef-actions {allowed_actions} > 1.')
+                f'DRPs currently do not support max-nondef-actions {allowed_actions} > 1.')
         use_constraint_satisfaction = allowed_actions < bool_action_count
 
         noop = {var: (values[0] if isinstance(values, list) else values)
@@ -927,15 +946,17 @@ class JaxDeepReactivePolicy(JaxPlan):
             if ranges[var] != 'bool':
                 value_size = np.size(values)
                 if normalize_per_layer and value_size == 1:
-
-                    f'Cannot apply layer norm to state-fluent <{var}> '
-                    f'of size 1: setting normalize_per_layer = False.', '
+                    message = termcolor.colored(
+                        f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
+                        f'of size 1: setting normalize_per_layer = False.', 'yellow')
+                    print(message)
                     normalize_per_layer = False
                 non_bool_dims += value_size
         if not normalize_per_layer and non_bool_dims == 1:
-
-            'Cannot apply layer norm to state-fluents of total size 1: '
-            'setting normalize = False.', '
+            message = termcolor.colored(
+                '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
+                'setting normalize = False.', 'yellow')
+            print(message)
             normalize = False
 
         # convert subs dictionary into a state vector to feed to the MLP
@@ -1061,7 +1082,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                 prange = ranges[var]
                 if prange == 'bool':
                     new_action = action > 0.5
-                elif prange == 'int':
+                elif prange == 'int' or prange in rddl.enum_types:
                     action = jnp.clip(action, *bounds[var])
                     new_action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
                 else:
@@ -1112,19 +1133,18 @@ class JaxDeepReactivePolicy(JaxPlan):
 
 
 class RollingMean:
-    '''Maintains
-    observations.'''
+    '''Maintains the rolling mean of a stream of real-valued observations.'''
 
     def __init__(self, window_size: int) -> None:
         self._window_size = window_size
         self._memory = deque(maxlen=window_size)
         self._total = 0
 
-    def update(self, x: float) -> float:
+    def update(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
         memory = self._memory
-        self._total
+        self._total = self._total + x
         if len(memory) == self._window_size:
-            self._total
+            self._total = self._total - memory.popleft()
         memory.append(x)
         return self._total / len(memory)
 
@@ -1147,14 +1167,16 @@ class JaxPlannerStatus(Enum):
         return self.value == 1 or self.value >= 4
 
 
-class JaxPlannerStoppingRule:
+class JaxPlannerStoppingRule(metaclass=ABCMeta):
     '''The base class of all planner stopping rules.'''
 
+    @abstractmethod
     def reset(self) -> None:
-
-
+        pass
+
+    @abstractmethod
     def monitor(self, callback: Dict[str, Any]) -> bool:
-
+        pass
 
 
 class NoImprovementStoppingRule(JaxPlannerStoppingRule):
@@ -1168,8 +1190,7 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
         self.iters_since_last_update = 0
 
     def monitor(self, callback: Dict[str, Any]) -> bool:
-        if self.callback is None
-        or callback['best_return'] > self.callback['best_return']:
+        if self.callback is None or callback['best_return'] > self.callback['best_return']:
             self.callback = callback
             self.iters_since_last_update = 0
         else:
@@ -1188,7 +1209,7 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
 # ***********************************************************************
 
 
-class PGPE:
+class PGPE(metaclass=ABCMeta):
     """Base class for all PGPE strategies."""
 
     def __init__(self) -> None:
@@ -1203,8 +1224,10 @@ class PGPE:
     def update(self):
         return self._update
 
-
-
+    @abstractmethod
+    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+                parallel_updates: Optional[int]=None) -> None:
+        pass
 
 
 class GaussianPGPE(PGPE):
@@ -1268,10 +1291,11 @@ class GaussianPGPE(PGPE):
             mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
             sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
         except Exception as _:
-
-
-            'rolling back to safer method:
-            '
+            message = termcolor.colored(
+                '[FAIL] Failed to inject hyperparameters into PGPE optimizer, '
+                'rolling back to safer method: '
+                'kl-divergence constraint will be disabled.', 'red')
+            print(message)
             mu_optimizer = optimizer(**optimizer_kwargs_mu)
             sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
             max_kl_update = None
@@ -1297,15 +1321,16 @@ class GaussianPGPE(PGPE):
                 f' max_kl_update ={self.max_kl}\n'
         )
 
-    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type
+    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+                parallel_updates: Optional[int]=None) -> None:
         sigma0 = self.init_sigma
-
+        sigma_lo, sigma_hi = self.sigma_range
         scale_reward = self.scale_reward
         min_reward_scale = self.min_reward_scale
         super_symmetric = self.super_symmetric
         super_symmetric_accurate = self.super_symmetric_accurate
         batch_size = self.batch_size
-
+        mu_optimizer, sigma_optimizer = self.optimizers
         max_kl = self.max_kl
 
         # entropy regularization penalty is decayed exponentially by elapsed budget
@@ -1322,13 +1347,22 @@ class GaussianPGPE(PGPE):
 
         def _jax_wrapped_pgpe_init(key, policy_params):
             mu = policy_params
-            sigma = jax.tree_map(
+            sigma = jax.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
             pgpe_params = (mu, sigma)
-            pgpe_opt_state =
-
-            return pgpe_params, pgpe_opt_state
+            pgpe_opt_state = (mu_optimizer.init(mu), sigma_optimizer.init(sigma))
+            r_max = -jnp.inf
+            return pgpe_params, pgpe_opt_state, r_max
 
-
+        if parallel_updates is None:
+            self._initializer = jax.jit(_jax_wrapped_pgpe_init)
+        else:
+
+            # for parallel policy update
+            def _jax_wrapped_pgpe_inits(key, policy_params):
+                keys = jnp.asarray(random.split(key, num=parallel_updates))
+                return jax.vmap(_jax_wrapped_pgpe_init, in_axes=0)(keys, policy_params)
+
+            self._initializer = jax.jit(_jax_wrapped_pgpe_inits)
 
         # ***********************************************************************
         # PARAMETER SAMPLING FUNCTIONS
@@ -1338,6 +1372,8 @@ class GaussianPGPE(PGPE):
         def _jax_wrapped_mu_noise(key, sigma):
             return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
 
+        # this samples a noise variable epsilon* from epsilon with the N(0, 1) density
+        # according to super-symmetric sampling paper
         def _jax_wrapped_epsilon_star(sigma, epsilon):
             c1, c2, c3 = -0.06655, -0.9706, 0.124
             phi = 0.67449 * sigma
@@ -1354,6 +1390,7 @@ class GaussianPGPE(PGPE):
             epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
             return epsilon_star
 
+        # implements baseline-free super-symmetric sampling to generate 4 trajectories
         def _jax_wrapped_sample_params(key, mu, sigma):
             treedef = jax.tree_util.tree_structure(sigma)
             keys = random.split(key, num=treedef.num_leaves)
@@ -1374,6 +1411,7 @@ class GaussianPGPE(PGPE):
         #
         # ***********************************************************************
 
+        # gradient with respect to mean
         def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
             if super_symmetric:
                 if scale_reward:
@@ -1393,6 +1431,7 @@ class GaussianPGPE(PGPE):
             grad = -r_mu * epsilon
             return grad
 
+        # gradient with respect to std. deviation
        def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
             if super_symmetric:
                 mask = r1 + r2 >= r3 + r4
@@ -1413,6 +1452,7 @@ class GaussianPGPE(PGPE):
             grad = -(r_sigma * s + ent / sigma)
             return grad
 
+        # calculate the policy gradients
         def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
                                    policy_hyperparams, subs, model_params):
             key, subkey = random.split(key)
@@ -1462,11 +1502,24 @@ class GaussianPGPE(PGPE):
         #
         # ***********************************************************************
 
+        # estimate KL divergence between two updates
         def _jax_wrapped_pgpe_kl_term(mu, sigma, old_mu, old_sigma):
             return 0.5 * jnp.sum(2 * jnp.log(sigma / old_sigma) +
                                  jnp.square(old_sigma / sigma) +
                                  jnp.square((mu - old_mu) / sigma) - 1)
 
+        # update mean and std. deviation with a gradient step
+        def _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
+                                            mu_state, sigma_state):
+            mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
+            sigma_updates, new_sigma_state = sigma_optimizer.update(
+                sigma_grad, sigma_state, params=sigma)
+            new_mu = optax.apply_updates(mu, mu_updates)
+            new_sigma = optax.apply_updates(sigma, sigma_updates)
+            new_sigma = jax.tree_map(
+                partial(jnp.clip, min=sigma_lo, max=sigma_hi), new_sigma)
+            return new_mu, new_sigma, new_mu_state, new_sigma_state
+
         def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
                                      policy_hyperparams, subs, model_params,
                                      pgpe_opt_state):
@@ -1476,12 +1529,9 @@ class GaussianPGPE(PGPE):
             ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
             mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
                 key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
-
-
-
-            new_mu = optax.apply_updates(mu, mu_updates)
-            new_sigma = optax.apply_updates(sigma, sigma_updates)
-            new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+            new_mu, new_sigma, new_mu_state, new_sigma_state = \
+                _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
+                                                mu_state, sigma_state)
 
             # respect KL divergence contraint with old parameters
             if max_kl is not None:
@@ -1493,12 +1543,9 @@ class GaussianPGPE(PGPE):
                 kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
                 mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
                 sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
-
-
-
-                new_mu = optax.apply_updates(mu, mu_updates)
-                new_sigma = optax.apply_updates(sigma, sigma_updates)
-                new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+                new_mu, new_sigma, new_mu_state, new_sigma_state = \
+                    _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
+                                                    mu_state, sigma_state)
                 new_mu_state.hyperparams['learning_rate'] = old_mu_lr
                 new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr
 
@@ -1509,7 +1556,21 @@ class GaussianPGPE(PGPE):
             policy_params = new_mu
             return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged
 
-
+        if parallel_updates is None:
+            self._update = jax.jit(_jax_wrapped_pgpe_update)
+        else:
+
+            # for parallel policy update
+            def _jax_wrapped_pgpe_updates(key, pgpe_params, r_max, progress,
+                                          policy_hyperparams, subs, model_params,
+                                          pgpe_opt_state):
+                keys = jnp.asarray(random.split(key, num=parallel_updates))
+                return jax.vmap(
+                    _jax_wrapped_pgpe_update, in_axes=(0, 0, 0, None, None, None, 0, 0)
+                )(keys, pgpe_params, r_max, progress, policy_hyperparams, subs,
+                  model_params, pgpe_opt_state)
+
+            self._update = jax.jit(_jax_wrapped_pgpe_updates)
 
 
         # ***********************************************************************
@@ -1565,6 +1626,7 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
     return jnp.sum(returns * weights)
 
 
+# set of all currently valid built-in utility functions
 UTILITY_LOOKUP = {
     'mean': jnp.mean,
     'mean_var': mean_variance_utility,
@@ -1609,7 +1671,8 @@ class JaxBackpropPlanner:
                  cpfs_without_grad: Optional[Set[str]]=None,
                  compile_non_fluent_exact: bool=True,
                  logger: Optional[Logger]=None,
-                 dashboard_viz: Optional[Any]=None
+                 dashboard_viz: Optional[Any]=None,
+                 parallel_updates: Optional[int]=None) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
         (plan) in the given RDDL. Some operations will be converted to their
         differentiable counterparts; the specific operations can be customized
@@ -1649,6 +1712,7 @@ class JaxBackpropPlanner:
         :param logger: to log information about compilation to file
         :param dashboard_viz: optional visualizer object from the environment
         to pass to the dashboard to visualize the policy
+        :param parallel_updates: how many optimizers to run independently in parallel
         '''
         self.rddl = rddl
         self.plan = plan
@@ -1656,6 +1720,7 @@ class JaxBackpropPlanner:
         if batch_size_test is None:
             batch_size_test = batch_size_train
         self.batch_size_test = batch_size_test
+        self.parallel_updates = parallel_updates
         if rollout_horizon is None:
             rollout_horizon = rddl.horizon
         self.horizon = rollout_horizon
@@ -1677,10 +1742,11 @@ class JaxBackpropPlanner:
         try:
             optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
         except Exception as _:
-
-
-            'rolling back to safer method: please note that modification of '
-            '
+            message = termcolor.colored(
+                '[FAIL] Failed to inject hyperparameters into JaxPlan optimizer, '
+                'rolling back to safer method: please note that runtime modification of '
+                'hyperparameters will be disabled.', 'red')
+            print(message)
             optimizer = optimizer(**optimizer_kwargs)
 
         # apply optimizer chain of transformations
@@ -1700,7 +1766,7 @@ class JaxBackpropPlanner:
             utility_fn = UTILITY_LOOKUP.get(utility, None)
             if utility_fn is None:
                 raise RDDLNotImplementedError(
-                    f'Utility <{utility}> is not supported, '
+                    f'Utility function <{utility}> is not supported, '
                     f'must be one of {list(UTILITY_LOOKUP.keys())}.')
         else:
             utility_fn = utility
@@ -1742,7 +1808,7 @@ r"""
 \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
 """
 
-        return ('\n'
+        return (f'\n'
                 f'{LOGO}\n'
                 f'Version {__version__}\n'
                 f'Python {sys.version}\n'
@@ -1751,7 +1817,23 @@ r"""
                 f'numpy {np.__version__}\n'
                 f'devices: {devices_short}\n')
 
-    def
+    def summarize_relaxations(self) -> str:
+        result = ''
+        if self.compiled.model_params:
+            result += ('Some RDDL operations are non-differentiable '
+                       'and will be approximated as follows:' + '\n')
+            exprs_by_rddl_op, values_by_rddl_op = {}, {}
+            for info in self.compiled.model_parameter_info().values():
+                rddl_op = info['rddl_op']
+                exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
+                values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
+            for rddl_op in sorted(exprs_by_rddl_op.keys()):
+                result += (f' {rddl_op}:\n'
+                           f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
+                           f' init_values={values_by_rddl_op[rddl_op]}\n')
+        return result
+
+    def summarize_hyperparameters(self) -> str:
         result = (f'objective hyper-parameters:\n'
                   f' utility_fn ={self.utility.__name__}\n'
                   f' utility args ={self.utility_kwargs}\n'
@@ -1769,30 +1851,14 @@ r"""
                   f' line_search_kwargs={self.line_search_kwargs}\n'
                   f' noise_kwargs ={self.noise_kwargs}\n'
                   f' batch_size_train ={self.batch_size_train}\n'
-                  f' batch_size_test ={self.batch_size_test}\n'
+                  f' batch_size_test ={self.batch_size_test}\n'
+                  f' parallel_updates ={self.parallel_updates}\n')
         result += str(self.plan)
         if self.use_pgpe:
             result += str(self.pgpe)
         result += str(self.logic)
-
-        # print model relaxation information
-        if self.compiled.model_params:
-            result += ('Some RDDL operations are non-differentiable '
-                       'and will be approximated as follows:' + '\n')
-            exprs_by_rddl_op, values_by_rddl_op = {}, {}
-            for info in self.compiled.model_parameter_info().values():
-                rddl_op = info['rddl_op']
-                exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
-                values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
-            for rddl_op in sorted(exprs_by_rddl_op.keys()):
-                result += (f' {rddl_op}:\n'
-                           f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
-                           f' init_values={values_by_rddl_op[rddl_op]}\n')
         return result
 
-    def summarize_hyperparameters(self) -> None:
-        print(self.__str__())
-
     # ===========================================================================
     # COMPILATION SUBROUTINES
     # ===========================================================================
@@ -1844,23 +1910,31 @@ r"""
         self.test_rollouts = jax.jit(test_rollouts)
 
         # initialization
-        self.initialize =
+        self.initialize, self.init_optimizer = self._jax_init()
 
         # losses
         train_loss = self._jax_loss(train_rollouts, use_symlog=self.use_symlog_reward)
-
+        test_loss = self._jax_loss(test_rollouts, use_symlog=False)
+        if self.parallel_updates is None:
+            self.test_loss = jax.jit(test_loss)
+        else:
+            self.test_loss = jax.jit(jax.vmap(test_loss, in_axes=(None, 0, None, None, 0)))
 
         # optimization
         self.update = self._jax_update(train_loss)
+        self.pytree_at = jax.jit(lambda tree, i: jax.tree_map(lambda x: x[i], tree))
 
         # pgpe option
         if self.use_pgpe:
-            loss_fn = self._jax_loss(rollouts=test_rollouts)
             self.pgpe.compile(
-                loss_fn=
+                loss_fn=test_loss,
                 projection=self.plan.projection,
-                real_dtype=self.test_compiled.REAL
+                real_dtype=self.test_compiled.REAL,
+                parallel_updates=self.parallel_updates
             )
+            self.merge_pgpe = self._jax_merge_pgpe_jaxplan()
+        else:
+            self.merge_pgpe = None
 
     def _jax_return(self, use_symlog):
         gamma = self.rddl.discount
@@ -1900,24 +1974,43 @@ r"""
    def _jax_init(self):
         init = self.plan.initializer
         optimizer = self.optimizer
+        num_parallel = self.parallel_updates
 
         # initialize both the policy and its optimizer
         def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
             policy_params = init(key, policy_hyperparams, subs)
             opt_state = optimizer.init(policy_params)
-            return policy_params, opt_state, {}
+            return policy_params, opt_state, {}
 
-
+        # initialize just the optimizer from the policy
+        def _jax_wrapped_init_opt(policy_params):
+            if num_parallel is None:
+                opt_state = optimizer.init(policy_params)
+            else:
+                opt_state = jax.vmap(optimizer.init, in_axes=0)(policy_params)
+            return opt_state, {}
+
+        if num_parallel is None:
+            return jax.jit(_jax_wrapped_init_policy), jax.jit(_jax_wrapped_init_opt)
+
+        # for parallel policy update
+        def _jax_wrapped_init_policies(key, policy_hyperparams, subs):
+            keys = jnp.asarray(random.split(key, num=num_parallel))
+            return jax.vmap(_jax_wrapped_init_policy, in_axes=(0, None, None))(
+                keys, policy_hyperparams, subs)
+
+        return jax.jit(_jax_wrapped_init_policies), jax.jit(_jax_wrapped_init_opt)
 
     def _jax_update(self, loss):
         optimizer = self.optimizer
         projection = self.plan.projection
         use_ls = self.line_search_kwargs is not None
+        num_parallel = self.parallel_updates
 
         # check if the gradients are all zeros
         def _jax_wrapped_zero_gradients(grad):
             leaves, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(
+                jax.tree_map(partial(jnp.allclose, b=0), grad))
             return jnp.all(jnp.asarray(leaves))
 
         # calculate the plan gradient w.r.t. return loss and update optimizer
@@ -1948,8 +2041,43 @@ r"""
             return policy_params, converged, opt_state, opt_aux, \
                 loss_val, log, model_params, zero_grads
 
-
+        if num_parallel is None:
+            return jax.jit(_jax_wrapped_plan_update)
+
+        # for parallel policy update
+        def _jax_wrapped_plan_updates(key, policy_params, policy_hyperparams,
+                                      subs, model_params, opt_state, opt_aux):
+            keys = jnp.asarray(random.split(key, num=num_parallel))
+            return jax.vmap(
+                _jax_wrapped_plan_update, in_axes=(0, 0, None, None, 0, 0, 0)
+            )(keys, policy_params, policy_hyperparams, subs, model_params,
+              opt_state, opt_aux)
+
+        return jax.jit(_jax_wrapped_plan_updates)
 
+    def _jax_merge_pgpe_jaxplan(self):
+        if self.parallel_updates is None:
+            return None
+
+        # for parallel policy update
+        # currently implements a hard replacement where the jaxplan parameter
+        # is replaced by the PGPE parameter if the latter is an improvement
+        def _jax_wrapped_pgpe_jaxplan_merge(pgpe_mask, pgpe_param, policy_params,
+                                            pgpe_loss, test_loss,
+                                            pgpe_loss_smooth, test_loss_smooth,
+                                            pgpe_converged, converged):
+            def select_fn(leaf1, leaf2):
+                expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf1) - 1)]
+                return jnp.where(expanded_mask, leaf1, leaf2)
+            policy_params = jax.tree_map(select_fn, pgpe_param, policy_params)
+            test_loss = jnp.where(pgpe_mask, pgpe_loss, test_loss)
+            test_loss_smooth = jnp.where(pgpe_mask, pgpe_loss_smooth, test_loss_smooth)
+            expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(converged) - 1)]
+            converged = jnp.where(expanded_mask, pgpe_converged, converged)
+            return policy_params, test_loss, test_loss_smooth, converged
+
+        return jax.jit(_jax_wrapped_pgpe_jaxplan_merge)
+
     def _batched_init_subs(self, subs):
         rddl = self.rddl
         n_train, n_test = self.batch_size_train, self.batch_size_test
@@ -1968,6 +2096,13 @@ r"""
             train_value = np.asarray(train_value, dtype=self.compiled.REAL)
             init_train[name] = train_value
             init_test[name] = np.repeat(value, repeats=n_test, axis=0)
+
+            # safely cast test subs variable to required type in case the type is wrong
+            if name in rddl.variable_ranges:
+                required_type = RDDLValueInitializer.NUMPY_TYPES.get(
+                    rddl.variable_ranges[name], RDDLValueInitializer.INT)
+                if np.result_type(init_test[name]) != required_type:
+                    init_test[name] = np.asarray(init_test[name], dtype=required_type)
 
         # make sure next-state fluents are also set
         for (state, next_state) in rddl.next_state.items():
@@ -1975,6 +2110,19 @@ r"""
             init_test[next_state] = init_test[state]
         return init_train, init_test
 
+    def _broadcast_pytree(self, pytree):
+        if self.parallel_updates is None:
+            return pytree
+
+        # for parallel policy update
+        def make_batched(x):
+            x = np.asarray(x)
+            x = np.broadcast_to(
+                x[np.newaxis, ...], shape=(self.parallel_updates,) + np.shape(x))
+            return x
+
+        return jax.tree_map(make_batched, pytree)
+
     def as_optimization_problem(
             self, key: Optional[random.PRNGKey]=None,
             policy_hyperparams: Optional[Pytree]=None,
@@ -2002,6 +2150,11 @@ r"""
         :param grad_function_updates_key: if True, the gradient function
         updates the PRNG key internally independently of the loss function.
         '''
+
+        # make sure parallel updates are disabled
+        if self.parallel_updates is not None:
+            raise ValueError('Cannot compile static optimization problem '
+                             'when parallel_updates is not None.')
 
         # if PRNG key is not provided
         if key is None:
@@ -2012,8 +2165,10 @@ r"""
         train_subs, _ = self._batched_init_subs(subs)
         model_params = self.compiled.model_params
         if policy_hyperparams is None:
-
-
+            message = termcolor.colored(
+                '[WARN] policy_hyperparams is not set, setting 1.0 for '
+                'all action-fluents which could be suboptimal.', 'yellow')
+            print(message)
             policy_hyperparams = {action: 1.0
                                   for action in self.rddl.action_fluents}
 
@@ -2084,10 +2239,12 @@ r"""
         their values: if None initializes all variables from the RDDL instance
         :param guess: initial policy parameters: if None will use the initializer
         specified in this instance
-        :param print_summary: whether to print planner header
-        summary, and diagnosis
+        :param print_summary: whether to print planner header and diagnosis
         :param print_progress: whether to print the progress bar during training
+        :param print_hyperparams: whether to print list of hyper-parameter settings
         :param stopping_rule: stopping criterion
+        :param restart_epochs: restart the optimizer from a random policy configuration
+        if there is no progress for this many consecutive iterations
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
@@ -2120,7 +2277,9 @@ r"""
                  guess: Optional[Pytree]=None,
                  print_summary: bool=True,
                  print_progress: bool=True,
+                 print_hyperparams: bool=False,
                  stopping_rule: Optional[JaxPlannerStoppingRule]=None,
+                 restart_epochs: int=999999,
                  test_rolling_window: int=10,
                  tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
         '''Returns a generator for computing an optimal policy or plan.
@@ -2139,10 +2298,12 @@ r"""
         their values: if None initializes all variables from the RDDL instance
         :param guess: initial policy parameters: if None will use the initializer
         specified in this instance
-        :param print_summary: whether to print planner header
-        summary, and diagnosis
+        :param print_summary: whether to print planner header and diagnosis
         :param print_progress: whether to print the progress bar during training
+        :param print_hyperparams: whether to print list of hyper-parameter settings
         :param stopping_rule: stopping criterion
+        :param restart_epochs: restart the optimizer from a random policy configuration
+        if there is no progress for this many consecutive iterations
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
@@ -2155,6 +2316,14 @@ r"""
         # INITIALIZATION OF HYPER-PARAMETERS
         # ======================================================================
 
+        # cannot run dashboard with parallel updates
+        if dashboard is not None and self.parallel_updates is not None:
+            message = termcolor.colored(
+                '[WARN] Dashboard is unavailable if parallel_updates is not None: '
+                'setting dashboard to None.', 'yellow')
+            print(message)
+            dashboard = None
+
         # if PRNG key is not provided
         if key is None:
             key = random.PRNGKey(round(time.time() * 1000))
@@ -2162,15 +2331,19 @@ r"""
 
         # if policy_hyperparams is not provided
         if policy_hyperparams is None:
-
-
+            message = termcolor.colored(
+                '[WARN] policy_hyperparams is not set, setting 1.0 for '
+                'all action-fluents which could be suboptimal.', 'yellow')
+            print(message)
             policy_hyperparams = {action: 1.0
                                   for action in self.rddl.action_fluents}
 
         # if policy_hyperparams is a scalar
         elif isinstance(policy_hyperparams, (int, float, np.number)):
-
-
+            message = termcolor.colored(
+                f'[INFO] policy_hyperparams is {policy_hyperparams}, '
+                f'setting this value for all action-fluents.', 'green')
+            print(message)
             hyperparam_value = float(policy_hyperparams)
             policy_hyperparams = {action: hyperparam_value
                                   for action in self.rddl.action_fluents}
@@ -2179,14 +2352,19 @@ r"""
         elif isinstance(policy_hyperparams, dict):
             for action in self.rddl.action_fluents:
                 if action not in policy_hyperparams:
-
-
+                    message = termcolor.colored(
+                        f'[WARN] policy_hyperparams[{action}] is not set, '
+                        f'setting 1.0 for missing action-fluents '
+                        f'which could be suboptimal.', 'yellow')
+                    print(message)
                     policy_hyperparams[action] = 1.0
 
         # print summary of parameters:
         if print_summary:
             print(self.summarize_system())
-            self.
+            print(self.summarize_relaxations())
+            if print_hyperparams:
+                print(self.summarize_hyperparameters())
             print(f'optimize() call hyper-parameters:\n'
                   f' PRNG key ={key}\n'
                   f' max_iterations ={epochs}\n'
@@ -2200,7 +2378,8 @@ r"""
                  f' dashboard_id ={dashboard_id}\n'
                  f' print_summary ={print_summary}\n'
                  f' print_progress ={print_progress}\n'
-                 f' stopping_rule ={stopping_rule}\n'
+                 f' stopping_rule ={stopping_rule}\n'
+                 f' restart_epochs ={restart_epochs}\n')
 
         # ======================================================================
         # INITIALIZATION OF STATE AND POLICY
@@ -2218,15 +2397,17 @@ r"""
                     subs[var] = value
                     added_pvars_to_subs.append(var)
             if added_pvars_to_subs:
-
-
-
+                message = termcolor.colored(
+                    f'[INFO] p-variables {added_pvars_to_subs} is not in '
+                    f'provided subs, using their initial values.', 'green')
+                print(message)
         train_subs, test_subs = self._batched_init_subs(subs)
 
         # initialize model parameters
         if model_params is None:
             model_params = self.compiled.model_params
-
+        model_params = self._broadcast_pytree(model_params)
+        model_params_test = self._broadcast_pytree(self.test_compiled.model_params)
 
         # initialize policy parameters
         if guess is None:
@@ -2234,29 +2415,31 @@ r"""
             policy_params, opt_state, opt_aux = self.initialize(
                 subkey, policy_hyperparams, train_subs)
         else:
-            policy_params = guess
-            opt_state = self.
-            opt_aux = {}
+            policy_params = self._broadcast_pytree(guess)
+            opt_state, opt_aux = self.init_optimizer(policy_params)
 
         # initialize pgpe parameters
         if self.use_pgpe:
-            pgpe_params, pgpe_opt_state = self.pgpe.initialize(key, policy_params)
+            pgpe_params, pgpe_opt_state, r_max = self.pgpe.initialize(key, policy_params)
             rolling_pgpe_loss = RollingMean(test_rolling_window)
         else:
-            pgpe_params, pgpe_opt_state = None, None
+            pgpe_params, pgpe_opt_state, r_max = None, None, None
             rolling_pgpe_loss = None
         total_pgpe_it = 0
-        r_max = -jnp.inf
 
         # ======================================================================
         # INITIALIZATION OF RUNNING STATISTICS
         # ======================================================================
 
         # initialize running statistics
-
+        if self.parallel_updates is None:
+            best_params = policy_params
+        else:
+            best_params = self.pytree_at(policy_params, 0)
+        best_loss, pbest_loss, best_grad = np.inf, np.inf, None
         last_iter_improve = 0
+        no_progress_count = 0
         rolling_test_loss = RollingMean(test_rolling_window)
-        log = {}
         status = JaxPlannerStatus.NORMAL
         progress_percent = 0
 
@@ -2277,6 +2460,11 @@ r"""
         else:
             progress_bar = None
         position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
+
+        # error handlers (to avoid spam messaging)
+        policy_constraint_msg_shown = False
+        jax_train_msg_shown = False
+        jax_test_msg_shown = False
 
         # ======================================================================
         # MAIN TRAINING LOOP BEGINS
@@ -2296,8 +2484,13 @@ r"""
              model_params, zero_grads) = self.update(
                 subkey, policy_params, policy_hyperparams, train_subs, model_params,
                 opt_state, opt_aux)
+
+            # evaluate
             test_loss, (test_log, model_params_test) = self.test_loss(
                 subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
+            if self.parallel_updates:
+                train_loss = np.asarray(train_loss)
+                test_loss = np.asarray(test_loss)
             test_loss_smooth = rolling_test_loss.update(test_loss)
 
             # pgpe update of the plan
@@ -2308,52 +2501,112 @@ r"""
|
|
|
2308
2501
|
self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
|
|
2309
2502
|
policy_hyperparams, test_subs, model_params_test,
|
|
2310
2503
|
pgpe_opt_state)
|
|
2504
|
+
|
|
2505
|
+
# evaluate
|
|
2311
2506
|
pgpe_loss, _ = self.test_loss(
|
|
2312
2507
|
subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
|
|
2508
|
+
if self.parallel_updates:
|
|
2509
|
+
pgpe_loss = np.asarray(pgpe_loss)
|
|
2313
2510
|
pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
|
|
2314
2511
|
pgpe_return = -pgpe_loss_smooth
|
|
2315
2512
|
|
|
2316
|
-
# replace with PGPE if
|
|
2317
|
-
if
|
|
2318
|
-
|
|
2319
|
-
|
|
2320
|
-
|
|
2321
|
-
|
|
2322
|
-
|
|
2513
|
+
# replace JaxPlan with PGPE if new minimum reached or train loss invalid
|
|
2514
|
+
if self.parallel_updates is None:
|
|
2515
|
+
if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
|
|
2516
|
+
policy_params = pgpe_param
|
|
2517
|
+
test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
|
|
2518
|
+
converged = pgpe_converged
|
|
2519
|
+
pgpe_improve = True
|
|
2520
|
+
total_pgpe_it += 1
|
|
2521
|
+
else:
|
|
2522
|
+
pgpe_mask = (pgpe_loss_smooth < pbest_loss) | ~np.isfinite(train_loss)
|
|
2523
|
+
if np.any(pgpe_mask):
|
|
2524
|
+
policy_params, test_loss, test_loss_smooth, converged = \
|
|
2525
|
+
self.merge_pgpe(pgpe_mask, pgpe_param, policy_params,
|
|
2526
|
+
pgpe_loss, test_loss,
|
|
2527
|
+
pgpe_loss_smooth, test_loss_smooth,
|
|
2528
|
+
pgpe_converged, converged)
|
|
2529
|
+
pgpe_improve = True
|
|
2530
|
+
total_pgpe_it += 1
|
|
2323
2531
|
else:
|
|
2324
2532
|
pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
|
|
2325
2533
|
|
|
2326
|
-
# evaluate test losses and record best
|
|
2327
|
-
if
|
|
2328
|
-
|
|
2329
|
-
|
|
2330
|
-
|
|
2534
|
+
# evaluate test losses and record best parameters so far
|
|
2535
|
+
if self.parallel_updates is None:
|
|
2536
|
+
if test_loss_smooth < best_loss:
|
|
2537
|
+
best_params, best_loss, best_grad = \
|
|
2538
|
+
policy_params, test_loss_smooth, train_log['grad']
|
|
2539
|
+
pbest_loss = best_loss
|
|
2540
|
+
else:
|
|
2541
|
+
best_index = np.argmin(test_loss_smooth)
|
|
2542
|
+
if test_loss_smooth[best_index] < best_loss:
|
|
2543
|
+
best_params = self.pytree_at(policy_params, best_index)
|
|
2544
|
+
best_grad = self.pytree_at(train_log['grad'], best_index)
|
|
2545
|
+
best_loss = test_loss_smooth[best_index]
|
|
2546
|
+
pbest_loss = np.minimum(pbest_loss, test_loss_smooth)
|
|
 
             # ==================================================================
             # STATUS CHECKS AND LOGGING
             # ==================================================================
 
             # no progress
-
+            no_progress_flag = (not pgpe_improve) and np.all(zero_grads)
+            if no_progress_flag:
                 status = JaxPlannerStatus.NO_PROGRESS
-
+
             # constraint satisfaction problem
-            if not np.all(converged):
-
-
-
-
+            if not np.all(converged):
+                if progress_bar is not None and not policy_constraint_msg_shown:
+                    message = termcolor.colored(
+                        '[FAIL] Policy update failed to satisfy action constraints.',
+                        'red')
+                    progress_bar.write(message)
+                    policy_constraint_msg_shown = True
                 status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
 
             # numerical error
             if self.use_pgpe:
-                invalid_loss = not (np.isfinite(train_loss) or
+                invalid_loss = not (np.any(np.isfinite(train_loss)) or
+                                    np.any(np.isfinite(pgpe_loss)))
             else:
-                invalid_loss = not np.isfinite(train_loss)
+                invalid_loss = not np.any(np.isfinite(train_loss))
             if invalid_loss:
-
+                if progress_bar is not None:
+                    message = termcolor.colored(
+                        f'[FAIL] Planner aborted due to invalid train loss {train_loss}.',
+                        'red')
+                    progress_bar.write(message)
                 status = JaxPlannerStatus.INVALID_GRADIENT
 
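
The invalid-loss test above switches from a scalar check to np.any(np.isfinite(...)) because, with parallel updates, train_loss (and pgpe_loss) hold one value per run; the planner now aborts only when no run produced a finite loss. A small NumPy illustration:

    import numpy as np

    train_loss = np.array([np.nan, 3.2, np.inf])          # one loss per parallel run
    invalid_loss = not np.any(np.isfinite(train_loss))
    print(invalid_loss)   # False: run 1 is still usable, so training continues

    invalid_loss = not np.any(np.isfinite(np.array([np.nan, np.inf])))
    print(invalid_loss)   # True: every run failed, status becomes INVALID_GRADIENT
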
+            # problem in the model compilation
+            if progress_bar is not None:
+
+                # train model
+                if not jax_train_msg_shown:
+                    messages = set()
+                    for error_code in np.unique(train_log['error']):
+                        messages.update(JaxRDDLCompiler.get_error_messages(error_code))
+                    if messages:
+                        messages = '\n '.join(messages)
+                        message = termcolor.colored(
+                            f'[FAIL] Compiler encountered the following '
+                            f'error(s) in the training model:\n {messages}', 'red')
+                        progress_bar.write(message)
+                        jax_train_msg_shown = True
+
+                # test model
+                if not jax_test_msg_shown:
+                    messages = set()
+                    for error_code in np.unique(test_log['error']):
+                        messages.update(JaxRDDLCompiler.get_error_messages(error_code))
+                    if messages:
+                        messages = '\n '.join(messages)
+                        message = termcolor.colored(
+                            f'[FAIL] Compiler encountered the following '
+                            f'error(s) in the testing model:\n {messages}', 'red')
+                        progress_bar.write(message)
+                        jax_test_msg_shown = True
+
             # reached computation budget
             elapsed = time.time() - start_time - elapsed_outside_loop
             if elapsed >= train_seconds:
@@ -2387,20 +2640,39 @@ r"""
                 **test_log
             }
 
+            # hard restart
+            if guess is None and no_progress_flag:
+                no_progress_count += 1
+                if no_progress_count > restart_epochs:
+                    key, subkey = random.split(key)
+                    policy_params, opt_state, opt_aux = self.initialize(
+                        subkey, policy_hyperparams, train_subs)
+                    no_progress_count = 0
+                    if progress_bar is not None:
+                        message = termcolor.colored(
+                            f'[INFO] Optimizer restarted at iteration {it} '
+                            f'due to lack of progress.', 'green')
+                        progress_bar.write(message)
+            else:
+                no_progress_count = 0
+
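
The hard restart above applies only to open-loop optimization (guess is None) after restart_epochs consecutive no-progress iterations; note that the PRNG key is split so each re-initialization draws fresh parameters while the main key stream keeps advancing. A minimal sketch of that split-and-reinitialize pattern (the normal draw below is a stand-in for self.initialize, which is not reproduced here):

    import jax

    key = jax.random.PRNGKey(42)
    for restart in range(3):
        key, subkey = jax.random.split(key)
        init_params = jax.random.normal(subkey, (4,))   # stand-in for self.initialize(...)
        print(restart, init_params[:2])                 # a different draw on every restart
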
             # stopping condition reached
             if stopping_rule is not None and stopping_rule.monitor(callback):
+                if progress_bar is not None:
+                    message = termcolor.colored(
+                        '[SUCC] Stopping rule has been reached.', 'green')
+                    progress_bar.write(message)
                 callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
 
             # if the progress bar is used
             if print_progress:
                 progress_bar.set_description(
-                    f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
-                    f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
+                    f'{position_str} {it:6} it / {-np.min(train_loss):14.5f} train / '
+                    f'{-np.min(test_loss_smooth):14.5f} test / {-best_loss:14.5f} best / '
                     f'{status.value} status / {total_pgpe_it:6} pgpe',
-                    refresh=False
-                )
+                    refresh=False)
                 progress_bar.set_postfix_str(
-                    f
+                    f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False)
                 progress_bar.update(progress_percent - progress_bar.n)
 
             # dash-board
@@ -2423,24 +2695,15 @@ r"""
         # release resources
         if print_progress:
             progress_bar.close()
-
-        # validate the test return
-        if log:
-            messages = set()
-            for error_code in np.unique(log['error']):
-                messages.update(JaxRDDLCompiler.get_error_messages(error_code))
-            if messages:
-                messages = '\n'.join(messages)
-                raise_warning('JAX compiler encountered the following '
-                              'error(s) in the original RDDL formulation '
-                              f'during test evaluation:\n{messages}', 'red')
+            print()
 
         # summarize and test for convergence
         if print_summary:
             grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
             diagnosis = self._perform_diagnosis(
-                last_iter_improve, -train_loss, -test_loss_smooth,
-
+                last_iter_improve, -np.min(train_loss), -np.min(test_loss_smooth),
+                -best_loss, grad_norm)
+            print(f'Summary of optimization:\n'
                   f' status ={status}\n'
                   f' time ={elapsed:.3f} sec.\n'
                   f' iterations ={it}\n'
@@ -2453,12 +2716,9 @@ r"""
         max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
         grad_is_zero = np.allclose(max_grad_norm, 0)
 
-        validation_error = 100 * abs(test_return - train_return) / \
-            max(abs(train_return), abs(test_return))
-
         # divergence if the solution is not finite
         if not np.isfinite(train_return):
-            return termcolor.colored('[
+            return termcolor.colored('[FAIL] Training loss diverged.', 'red')
 
         # hit a plateau is likely IF:
         # 1. planner does not improve at all
@@ -2466,23 +2726,25 @@ r"""
         if last_iter_improve <= 1:
             if grad_is_zero:
                 return termcolor.colored(
-                    '[
+                    f'[FAIL] No progress was made '
                     f'and max grad norm {max_grad_norm:.6f} was zero: '
-                    'solver likely stuck in a plateau.', 'red')
+                    f'solver likely stuck in a plateau.', 'red')
             else:
                 return termcolor.colored(
-                    '[
+                    f'[FAIL] No progress was made '
                     f'but max grad norm {max_grad_norm:.6f} was non-zero: '
-                    'learning rate or other hyper-parameters
+                    f'learning rate or other hyper-parameters could be suboptimal.',
                     'red')
 
         # model is likely poor IF:
         # 1. the train and test return disagree
+        validation_error = 100 * abs(test_return - train_return) / \
+            max(abs(train_return), abs(test_return))
         if not (validation_error < 20):
             return termcolor.colored(
-                '[
+                f'[WARN] Progress was made '
                 f'but relative train-test error {validation_error:.6f} was high: '
-                'poor model relaxation around solution or batch size too small.',
+                f'poor model relaxation around solution or batch size too small.',
                 'yellow')
 
         # model likely did not converge IF:
@@ -2491,24 +2753,22 @@ r"""
         return_to_grad_norm = abs(best_return) / max_grad_norm
         if not (return_to_grad_norm > 1):
             return termcolor.colored(
-                '[
+                f'[WARN] Progress was made '
                 f'but max grad norm {max_grad_norm:.6f} was high: '
-                'solution locally suboptimal '
-                'or
-                'or batch size too small.', 'yellow')
+                f'solution locally suboptimal, relaxed model nonsmooth around solution, '
+                f'or batch size too small.', 'yellow')
 
         # likely successful
         return termcolor.colored(
-            '[
-            '(note: not all
+            '[SUCC] Planner converged successfully '
+            '(note: not all problems can be ruled out).', 'green')
 
     def get_action(self, key: random.PRNGKey,
                    params: Pytree,
                    step: int,
                    subs: Dict[str, Any],
                    policy_hyperparams: Optional[Dict[str, Any]]=None) -> Dict[str, Any]:
-        '''Returns an action dictionary from the policy or plan with the given
-        parameters.
+        '''Returns an action dictionary from the policy or plan with the given parameters.
 
         :param key: the JAX PRNG key
         :param params: the trainable parameter PyTree of the policy
@@ -2612,8 +2872,7 @@ class JaxOfflineController(BaseAgent):
 
 
 class JaxOnlineController(BaseAgent):
-    '''A container class for a Jax controller continuously updated using state
-    feedback.'''
+    '''A container class for a Jax controller continuously updated using state feedback.'''
 
     use_tensor_obs = True
 
|
|
|
2621
2880
|
key: Optional[random.PRNGKey]=None,
|
|
2622
2881
|
eval_hyperparams: Optional[Dict[str, Any]]=None,
|
|
2623
2882
|
warm_start: bool=True,
|
|
2883
|
+
max_attempts: int=3,
|
|
2624
2884
|
**train_kwargs) -> None:
|
|
2625
2885
|
'''Creates a new JAX control policy that is trained online in a closed-
|
|
2626
2886
|
loop fashion.
|
|
2627
2887
|
|
|
2628
2888
|
:param planner: underlying planning algorithm for optimizing actions
|
|
2629
|
-
:param key: the RNG key to seed randomness (derives from clock if not
|
|
2630
|
-
provided)
|
|
2889
|
+
:param key: the RNG key to seed randomness (derives from clock if not provided)
|
|
2631
2890
|
:param eval_hyperparams: policy hyperparameters to apply for evaluation
|
|
2632
2891
|
or whenever sample_action is called
|
|
2633
2892
|
:param warm_start: whether to use the previous decision epoch final
|
|
2634
2893
|
policy parameters to warm the next decision epoch
|
|
2894
|
+
:param max_attempts: maximum attempted restarts of the optimizer when the total
|
|
2895
|
+
iteration count is 1 (i.e. the execution time is dominated by the jit compilation)
|
|
2635
2896
|
:param **train_kwargs: any keyword arguments to be passed to the planner
|
|
2636
2897
|
for optimization
|
|
2637
2898
|
'''
|
|
@@ -2642,16 +2903,26 @@ class JaxOnlineController(BaseAgent):
|
|
|
2642
2903
|
self.eval_hyperparams = eval_hyperparams
|
|
2643
2904
|
self.warm_start = warm_start
|
|
2644
2905
|
self.train_kwargs = train_kwargs
|
|
2906
|
+
self.max_attempts = max_attempts
|
|
2645
2907
|
self.reset()
|
|
2646
2908
|
|
|
2647
2909
|
def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
2648
2910
|
planner = self.planner
|
|
2649
2911
|
callback = planner.optimize(
|
|
2650
|
-
key=self.key,
|
|
2651
|
-
|
|
2652
|
-
|
|
2653
|
-
|
|
2654
|
-
|
|
2912
|
+
key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
|
|
2913
|
+
|
|
2914
|
+
# optimize again if jit compilation takes up the entire time budget
|
|
2915
|
+
attempts = 0
|
|
2916
|
+
while attempts < self.max_attempts and callback['iteration'] <= 1:
|
|
2917
|
+
attempts += 1
|
|
2918
|
+
message = termcolor.colored(
|
|
2919
|
+
f'[WARN] JIT compilation dominated the execution time: '
|
|
2920
|
+
f'executing the optimizer again on the traced model [attempt {attempts}].',
|
|
2921
|
+
'yellow')
|
|
2922
|
+
print(message)
|
|
2923
|
+
callback = planner.optimize(
|
|
2924
|
+
key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
|
|
2925
|
+
|
|
2655
2926
|
self.callback = callback
|
|
2656
2927
|
params = callback['best_params']
|
|
2657
2928
|
self.key, subkey = random.split(self.key)
|