pyRDDLGym-jax 2.3-py3-none-any.whl → 2.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +10 -7
- pyRDDLGym_jax/core/logic.py +117 -66
- pyRDDLGym_jax/core/planner.py +585 -248
- pyRDDLGym_jax/core/simulator.py +37 -13
- pyRDDLGym_jax/core/tuning.py +52 -31
- pyRDDLGym_jax/entry_point.py +39 -7
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
- pyRDDLGym_jax/examples/run_plan.py +3 -3
- pyRDDLGym_jax/examples/run_scipy.py +2 -2
- pyRDDLGym_jax/examples/run_tune.py +8 -2
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info}/METADATA +13 -18
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info}/RECORD +19 -19
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info/licenses}/LICENSE +0 -0
- {pyrddlgym_jax-2.3.dist-info → pyrddlgym_jax-2.5.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -3,7 +3,7 @@
 #
 # Author: Michael Gimelfarb
 #
-#
+# REFERENCES:
 #
 # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
 # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
@@ -18,26 +18,33 @@
 # reactive policies for planning in stochastic nonlinear domains." In Proceedings of the
 # AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 7530-7537. 2019.
 #
-# [4] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
+# [4] Cui, Hao, Thomas Keller, and Roni Khardon. "Stochastic planning with lifted symbolic
+# trajectory optimization." In Proceedings of the International Conference on Automated
+# Planning and Scheduling, vol. 29, pp. 119-127. 2019.
+#
+# [5] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
 # nonlinear domains." Advances in Neural Information Processing Systems 30 (2017).
 #
-# [5] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
+# [6] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
 # policy gradients: Super symmetric pgpe." Artificial Neural Networks: Methods and
 # Applications in Bio-/Neuroinformatics. Springer International Publishing, 2015.
 #
 # ***********************************************************************


+from abc import ABCMeta, abstractmethod
 from ast import literal_eval
 from collections import deque
 import configparser
 from enum import Enum
 from functools import partial
 import os
+import pickle
 import sys
 import time
 import traceback
-from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple,
+from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, \
+    Union

 import haiku as hk
 import jax
@@ -51,6 +58,7 @@ from tqdm import tqdm, TqdmWarning
 import warnings
 warnings.filterwarnings("ignore", category=TqdmWarning)

+from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
 from pyRDDLGym.core.debug.logger import Logger
 from pyRDDLGym.core.debug.exception import (
@@ -157,25 +165,20 @@ def _load_config(config, args):
         initializer = _getattr_any(
             packages=[initializers, hk.initializers], item=plan_initializer)
         if initializer is None:
-
-            del plan_kwargs['initializer']
+            raise ValueError(f'Invalid initializer <{plan_initializer}>.')
         else:
             init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
             try:
                 plan_kwargs['initializer'] = initializer(**init_kwargs)
             except Exception as _:
-
-                f'Ignoring invalid initializer_kwargs <{init_kwargs}>.', 'red')
-                plan_kwargs['initializer'] = initializer
+                raise ValueError(f'Invalid initializer kwargs <{init_kwargs}>.')

     # policy activation
     plan_activation = plan_kwargs.get('activation', None)
     if plan_activation is not None:
-        activation = _getattr_any(
-            packages=[jax.nn, jax.numpy], item=plan_activation)
+        activation = _getattr_any(packages=[jax.nn, jax.numpy], item=plan_activation)
         if activation is None:
-
-            del plan_kwargs['activation']
+            raise ValueError(f'Invalid activation <{plan_activation}>.')
         else:
             plan_kwargs['activation'] = activation

@@ -188,8 +191,7 @@ def _load_config(config, args):
     if planner_optimizer is not None:
         optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
         if optimizer is None:
-
-            del planner_args['optimizer']
+            raise ValueError(f'Invalid optimizer <{planner_optimizer}>.')
         else:
             planner_args['optimizer'] = optimizer

@@ -200,8 +202,7 @@ def _load_config(config, args):
         if 'optimizer' in pgpe_kwargs:
             pgpe_optimizer = _getattr_any(packages=[optax], item=pgpe_kwargs['optimizer'])
             if pgpe_optimizer is None:
-
-                del pgpe_kwargs['optimizer']
+                raise ValueError(f'Invalid optimizer <{pgpe_optimizer}>.')
             else:
                 pgpe_kwargs['optimizer'] = pgpe_optimizer
         planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
@@ -229,13 +230,19 @@ def _load_config(config, args):


 def load_config(path: str) -> Tuple[Kwargs, ...]:
-    '''Loads a config file at the specified file path.'''
+    '''Loads a config file at the specified file path.
+
+    :param path: the path of the config file to load and parse
+    '''
     config, args = _parse_config_file(path)
     return _load_config(config, args)


 def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
-    '''Loads config file contents specified explicitly as a string value.'''
+    '''Loads config file contents specified explicitly as a string value.
+
+    :param value: the string in json format containing the config contents to parse
+    '''
     config, args = _parse_config_string(value)
     return _load_config(config, args)

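As a hedged usage sketch of the two loaders above (only the `Tuple[Kwargs, ...]` return signatures appear in the diff; the config path and the exact number and meaning of the returned dictionaries are illustrative assumptions):

```python
from pyRDDLGym_jax.core.planner import load_config, load_config_from_string

# parse a planner configuration shipped with the package (hypothetical path)
kwargs_tuple = load_config('pyRDDLGym_jax/examples/configs/tuning_slp.cfg')
for kwargs in kwargs_tuple:
    print(sorted(kwargs.keys()))  # each element is a dict of keyword arguments

# or parse the same configuration supplied directly as a string
with open('pyRDDLGym_jax/examples/configs/tuning_slp.cfg') as f:
    kwargs_tuple = load_config_from_string(f.read())
```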
@@ -258,10 +265,10 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
     def __init__(self, *args,
                  logic: Logic=FuzzyLogic(),
                  cpfs_without_grad: Optional[Set[str]]=None,
+                 print_warnings: bool=True,
                  **kwargs) -> None:
         '''Creates a new RDDL to Jax compiler, where operations that are not
-        differentiable are converted to approximate forms that have defined
-        gradients.
+        differentiable are converted to approximate forms that have defined gradients.

         :param *args: arguments to pass to base compiler
         :param logic: Fuzzy logic object that specifies how exact operations
@@ -269,6 +276,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         to customize these operations
         :param cpfs_without_grad: which CPFs do not have gradients (use straight
         through gradient trick)
+        :param print_warnings: whether to print warnings
         :param *kwargs: keyword arguments to pass to base compiler
         '''
         super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
@@ -278,6 +286,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         if cpfs_without_grad is None:
             cpfs_without_grad = set()
         self.cpfs_without_grad = cpfs_without_grad
+        self.print_warnings = print_warnings

         # actions and CPFs must be continuous
         pvars_cast = set()
@@ -285,9 +294,11 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
             self.init_values[var] = np.asarray(values, dtype=self.REAL)
             if not np.issubdtype(np.result_type(values), np.floating):
                 pvars_cast.add(var)
-        if pvars_cast:
-
-
+        if self.print_warnings and pvars_cast:
+            message = termcolor.colored(
+                f'[INFO] JAX gradient compiler will cast p-vars {pvars_cast} to float.',
+                'green')
+            print(message)

         # overwrite basic operations with fuzzy ones
         self.OPS = logic.get_operator_dicts()
@@ -300,6 +311,8 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         return _jax_wrapped_stop_grad

     def _compile_cpfs(self, init_params):
+
+        # cpfs will all be cast to float
         cpfs_cast = set()
         jax_cpfs = {}
         for (_, cpfs) in self.levels.items():
@@ -311,12 +324,16 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
                 if cpf in self.cpfs_without_grad:
                     jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])

-        if cpfs_cast:
-
-
-
-
-
+        if self.print_warnings and cpfs_cast:
+            message = termcolor.colored(
+                f'[INFO] JAX gradient compiler will cast CPFs {cpfs_cast} to float.',
+                'green')
+            print(message)
+        if self.print_warnings and self.cpfs_without_grad:
+            message = termcolor.colored(
+                f'[INFO] Gradients will not flow through CPFs {self.cpfs_without_grad}.',
+                'green')
+            print(message)

         return jax_cpfs

@@ -335,7 +352,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
 # ***********************************************************************


-class JaxPlan:
+class JaxPlan(metaclass=ABCMeta):
     '''Base class for all JAX policy representations.'''

     def __init__(self) -> None:
@@ -345,16 +362,18 @@ class JaxPlan:
         self._projection = None
         self.bounds = None

-    def summarize_hyperparameters(self) -> None:
-
-
+    def summarize_hyperparameters(self) -> str:
+        return self.__str__()
+
+    @abstractmethod
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
                 horizon: int) -> None:
-
+        pass

+    @abstractmethod
     def guess_next_epoch(self, params: Pytree) -> Pytree:
-
+        pass

     @property
     def initializer(self):
@@ -397,10 +416,11 @@ class JaxPlan:
                 continue

             # check invalid type
-            if prange not in compiled.JAX_TYPES:
+            if prange not in compiled.JAX_TYPES and prange not in compiled.rddl.enum_types:
+                keys = list(compiled.JAX_TYPES.keys()) + list(compiled.rddl.enum_types)
                 raise RDDLTypeError(
                     f'Invalid range <{prange}> of action-fluent <{name}>, '
-                    f'must be one of {
+                    f'must be one of {keys}.')

             # clip boolean to (0, 1), otherwise use the RDDL action bounds
             # or the user defined action bounds if provided
@@ -408,7 +428,12 @@ class JaxPlan:
             if prange == 'bool':
                 lower, upper = None, None
             else:
-
+                if prange in compiled.rddl.enum_types:
+                    lower = np.zeros(shape=shapes[name][1:])
+                    upper = len(compiled.rddl.type_to_objects[prange]) - 1
+                    upper = np.ones(shape=shapes[name][1:]) * upper
+                else:
+                    lower, upper = compiled.constraints.bounds[name]
             lower, upper = user_bounds.get(name, (lower, upper))
             lower = np.asarray(lower, dtype=compiled.REAL)
             upper = np.asarray(upper, dtype=compiled.REAL)
@@ -421,7 +446,11 @@ class JaxPlan:
                           ~lower_finite & upper_finite,
                           ~lower_finite & ~upper_finite]
             bounds[name] = (lower, upper)
-
+            if compiled.print_warnings:
+                message = termcolor.colored(
+                    f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
+                    'green')
+                print(message)
         return shapes, bounds, bounds_safe, cond_lists

     def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -501,11 +530,12 @@ class JaxStraightLinePlan(JaxPlan):
         # action concurrency check
         bool_action_count, allowed_actions = self._count_bool_actions(rddl)
         use_constraint_satisfaction = allowed_actions < bool_action_count
-        if use_constraint_satisfaction:
-
-
-
-
+        if compiled.print_warnings and use_constraint_satisfaction:
+            message = termcolor.colored(
+                f'[INFO] SLP will use projected gradient to satisfy '
+                f'max_nondef_actions since total boolean actions '
+                f'{bool_action_count} > max_nondef_actions {allowed_actions}.', 'green')
+            print(message)

         noop = {var: (values[0] if isinstance(values, list) else values)
                 for (var, values) in rddl.action_fluents.items()}
@@ -586,7 +616,7 @@ class JaxStraightLinePlan(JaxPlan):
             start = 0
             for (name, size) in action_sizes.items():
                 action = output[..., start:start + size]
-                action = jnp.reshape(action,
+                action = jnp.reshape(action, shapes[name][1:])
                 if noop[name]:
                     action = 1.0 - action
                 actions[name] = action
@@ -623,7 +653,7 @@ class JaxStraightLinePlan(JaxPlan):
                 else:
                     action = _jax_non_bool_param_to_action(var, action, hyperparams)
                     action = jnp.clip(action, *bounds[var])
-                    if ranges[var] == 'int':
+                    if ranges[var] == 'int' or ranges[var] in rddl.enum_types:
                         action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
                 actions[var] = action
             return actions
@@ -642,7 +672,7 @@ class JaxStraightLinePlan(JaxPlan):
         # only allow one action non-noop for now
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
-                f'
+                f'SLPs with wrap_softmax currently '
                 f'do not support max-nondef-actions {allowed_actions} > 1.')

         # potentially apply projection but to non-bool actions only
@@ -764,7 +794,8 @@ class JaxStraightLinePlan(JaxPlan):
             for (var, action) in actions.items():
                 if ranges[var] == 'bool':
                     action = jnp.clip(action, min_action, max_action)
-
+                    param = _jax_bool_action_to_param(var, action, hyperparams)
+                    new_params[var] = param
                 else:
                     new_params[var] = action
             return new_params, converged
@@ -818,7 +849,7 @@ class JaxStraightLinePlan(JaxPlan):

     def guess_next_epoch(self, params: Pytree) -> Pytree:
         next_fn = JaxStraightLinePlan._guess_next_epoch
-        return jax.tree_map(next_fn, params)
+        return jax.tree_util.tree_map(next_fn, params)


 class JaxDeepReactivePolicy(JaxPlan):
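Many of the remaining changes in this file simply migrate `jax.tree_map` calls to `jax.tree_util.tree_map`, since the short alias is deprecated in recent JAX releases. A minimal illustration of the replacement call (standard JAX API, shown here only for context):

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
# applies the function leaf-by-leaf, same behaviour as the old jax.tree_map
halved = jax.tree_util.tree_map(lambda x: 0.5 * x, params)
```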
@@ -890,8 +921,7 @@ class JaxDeepReactivePolicy(JaxPlan):
         bool_action_count, allowed_actions = self._count_bool_actions(rddl)
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
-                f'
-                f'max-nondef-actions {allowed_actions} > 1.')
+                f'DRPs currently do not support max-nondef-actions {allowed_actions} > 1.')
         use_constraint_satisfaction = allowed_actions < bool_action_count

         noop = {var: (values[0] if isinstance(values, list) else values)
@@ -927,15 +957,19 @@ class JaxDeepReactivePolicy(JaxPlan):
             if ranges[var] != 'bool':
                 value_size = np.size(values)
                 if normalize_per_layer and value_size == 1:
-
-
-
+                    if compiled.print_warnings:
+                        message = termcolor.colored(
+                            f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
+                            f'of size 1: setting normalize_per_layer = False.', 'yellow')
+                        print(message)
                     normalize_per_layer = False
                 non_bool_dims += value_size
         if not normalize_per_layer and non_bool_dims == 1:
-
-
-
+            if compiled.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
+                    'setting normalize = False.', 'yellow')
+                print(message)
             normalize = False

         # convert subs dictionary into a state vector to feed to the MLP
@@ -1033,7 +1067,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             for (name, size) in layer_sizes.items():
                 if ranges[name] == 'bool':
                     action = output[..., start:start + size]
-                    action = jnp.reshape(action,
+                    action = jnp.reshape(action, shapes[name])
                     if noop[name]:
                         action = 1.0 - action
                     actions[name] = action
@@ -1061,7 +1095,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                 prange = ranges[var]
                 if prange == 'bool':
                     new_action = action > 0.5
-                elif prange == 'int':
+                elif prange == 'int' or prange in rddl.enum_types:
                     action = jnp.clip(action, *bounds[var])
                     new_action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
                 else:
@@ -1112,19 +1146,18 @@ class JaxDeepReactivePolicy(JaxPlan):


 class RollingMean:
-    '''Maintains
-    observations.'''
+    '''Maintains the rolling mean of a stream of real-valued observations.'''

     def __init__(self, window_size: int) -> None:
         self._window_size = window_size
         self._memory = deque(maxlen=window_size)
         self._total = 0

-    def update(self, x: float) -> float:
+    def update(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
         memory = self._memory
-        self._total += x
+        self._total = self._total + x
         if len(memory) == self._window_size:
-            self._total -= memory.popleft()
+            self._total = self._total - memory.popleft()
         memory.append(x)
         return self._total / len(memory)

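The updated `RollingMean` now avoids in-place accumulation, and its `Union[float, np.ndarray]` hint allows array-valued observations. A self-contained sketch mirroring the same logic, for intuition only:

```python
from collections import deque

class RollingMeanSketch:
    '''Mirrors the RollingMean logic shown in the diff above.'''

    def __init__(self, window_size: int) -> None:
        self._window_size = window_size
        self._memory = deque(maxlen=window_size)
        self._total = 0

    def update(self, x):
        # add the newest observation and evict the oldest once the window is full
        self._total = self._total + x
        if len(self._memory) == self._window_size:
            self._total = self._total - self._memory.popleft()
        self._memory.append(x)
        return self._total / len(self._memory)

rm = RollingMeanSketch(3)
print([rm.update(v) for v in [1.0, 2.0, 3.0, 4.0]])  # [1.0, 1.5, 2.0, 3.0]
```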
@@ -1147,14 +1180,16 @@ class JaxPlannerStatus(Enum):
         return self.value == 1 or self.value >= 4


-class JaxPlannerStoppingRule:
+class JaxPlannerStoppingRule(metaclass=ABCMeta):
     '''The base class of all planner stopping rules.'''

+    @abstractmethod
     def reset(self) -> None:
-
-
+        pass
+
+    @abstractmethod
     def monitor(self, callback: Dict[str, Any]) -> bool:
-
+        pass


 class NoImprovementStoppingRule(JaxPlannerStoppingRule):
@@ -1168,8 +1203,7 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
         self.iters_since_last_update = 0

     def monitor(self, callback: Dict[str, Any]) -> bool:
-        if self.callback is None
-            or callback['best_return'] > self.callback['best_return']:
+        if self.callback is None or callback['best_return'] > self.callback['best_return']:
             self.callback = callback
             self.iters_since_last_update = 0
         else:
@@ -1188,7 +1222,7 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
 # ***********************************************************************


-class PGPE:
+class PGPE(metaclass=ABCMeta):
     """Base class for all PGPE strategies."""

     def __init__(self) -> None:
@@ -1203,8 +1237,11 @@ class PGPE:
     def update(self):
         return self._update

-
-
+    @abstractmethod
+    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+                print_warnings: bool,
+                parallel_updates: Optional[int]=None) -> None:
+        pass


 class GaussianPGPE(PGPE):
@@ -1268,10 +1305,11 @@ class GaussianPGPE(PGPE):
             mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
             sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
         except Exception as _:
-
-
-            'rolling back to safer method:
-            '
+            message = termcolor.colored(
+                '[FAIL] Failed to inject hyperparameters into PGPE optimizer, '
+                'rolling back to safer method: '
+                'kl-divergence constraint will be disabled.', 'red')
+            print(message)
             mu_optimizer = optimizer(**optimizer_kwargs_mu)
             sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
             max_kl_update = None
@@ -1297,15 +1335,17 @@ class GaussianPGPE(PGPE):
                 f' max_kl_update ={self.max_kl}\n'
                 )

-    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
+    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+                print_warnings: bool,
+                parallel_updates: Optional[int]=None) -> None:
         sigma0 = self.init_sigma
-
+        sigma_lo, sigma_hi = self.sigma_range
         scale_reward = self.scale_reward
         min_reward_scale = self.min_reward_scale
         super_symmetric = self.super_symmetric
         super_symmetric_accurate = self.super_symmetric_accurate
         batch_size = self.batch_size
-
+        mu_optimizer, sigma_optimizer = self.optimizers
         max_kl = self.max_kl

         # entropy regularization penalty is decayed exponentially by elapsed budget
@@ -1322,13 +1362,22 @@ class GaussianPGPE(PGPE):

         def _jax_wrapped_pgpe_init(key, policy_params):
             mu = policy_params
-            sigma = jax.tree_map(
+            sigma = jax.tree_util.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
             pgpe_params = (mu, sigma)
-            pgpe_opt_state =
-
-            return pgpe_params, pgpe_opt_state
+            pgpe_opt_state = (mu_optimizer.init(mu), sigma_optimizer.init(sigma))
+            r_max = -jnp.inf
+            return pgpe_params, pgpe_opt_state, r_max

-
+        if parallel_updates is None:
+            self._initializer = jax.jit(_jax_wrapped_pgpe_init)
+        else:
+
+            # for parallel policy update
+            def _jax_wrapped_pgpe_inits(key, policy_params):
+                keys = jnp.asarray(random.split(key, num=parallel_updates))
+                return jax.vmap(_jax_wrapped_pgpe_init, in_axes=0)(keys, policy_params)
+
+            self._initializer = jax.jit(_jax_wrapped_pgpe_inits)

         # ***********************************************************************
         # PARAMETER SAMPLING FUNCTIONS
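The parallel branch above follows a split-then-vmap pattern: one PRNG key per parallel optimizer, with the single-optimizer initializer mapped over the leading axis. A tiny standalone illustration of that pattern (standard JAX, not taken from the package):

```python
import jax
import jax.numpy as jnp
from jax import random

def init_one(key):
    # stand-in for _jax_wrapped_pgpe_init: build one state from one key
    return random.normal(key, shape=(2,))

parallel_updates = 4
keys = jnp.asarray(random.split(random.PRNGKey(0), num=parallel_updates))
states = jax.vmap(init_one, in_axes=0)(keys)
print(states.shape)  # (4, 2): one state per parallel optimizer
```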
@@ -1338,6 +1387,8 @@ class GaussianPGPE(PGPE):
         def _jax_wrapped_mu_noise(key, sigma):
             return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)

+        # this samples a noise variable epsilon* from epsilon with the N(0, 1) density
+        # according to super-symmetric sampling paper
         def _jax_wrapped_epsilon_star(sigma, epsilon):
             c1, c2, c3 = -0.06655, -0.9706, 0.124
             phi = 0.67449 * sigma
@@ -1354,17 +1405,19 @@ class GaussianPGPE(PGPE):
             epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
             return epsilon_star

+        # implements baseline-free super-symmetric sampling to generate 4 trajectories
         def _jax_wrapped_sample_params(key, mu, sigma):
             treedef = jax.tree_util.tree_structure(sigma)
             keys = random.split(key, num=treedef.num_leaves)
             keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
-            epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
-            p1 = jax.tree_map(jnp.add, mu, epsilon)
-            p2 = jax.tree_map(jnp.subtract, mu, epsilon)
+            epsilon = jax.tree_util.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
+            p1 = jax.tree_util.tree_map(jnp.add, mu, epsilon)
+            p2 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon)
             if super_symmetric:
-                epsilon_star = jax.tree_map(
-
-
+                epsilon_star = jax.tree_util.tree_map(
+                    _jax_wrapped_epsilon_star, sigma, epsilon)
+                p3 = jax.tree_util.tree_map(jnp.add, mu, epsilon_star)
+                p4 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon_star)
             else:
                 epsilon_star, p3, p4 = epsilon, p1, p2
             return p1, p2, p3, p4, epsilon, epsilon_star
@@ -1374,6 +1427,7 @@ class GaussianPGPE(PGPE):
         #
         # ***********************************************************************

+        # gradient with respect to mean
         def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
             if super_symmetric:
                 if scale_reward:
@@ -1393,6 +1447,7 @@ class GaussianPGPE(PGPE):
                 grad = -r_mu * epsilon
             return grad

+        # gradient with respect to std. deviation
         def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
             if super_symmetric:
                 mask = r1 + r2 >= r3 + r4
@@ -1413,6 +1468,7 @@ class GaussianPGPE(PGPE):
             grad = -(r_sigma * s + ent / sigma)
             return grad

+        # calculate the policy gradients
         def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
                                    policy_hyperparams, subs, model_params):
             key, subkey = random.split(key)
@@ -1429,11 +1485,11 @@ class GaussianPGPE(PGPE):
                 r_max = jnp.maximum(r_max, r4)
             else:
                 r3, r4 = r1, r2
-            grad_mu = jax.tree_map(
+            grad_mu = jax.tree_util.tree_map(
                 partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
                 epsilon, epsilon_star
             )
-            grad_sigma = jax.tree_map(
+            grad_sigma = jax.tree_util.tree_map(
                 partial(_jax_wrapped_sigma_grad,
                         r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
                 epsilon, epsilon_star, sigma
@@ -1452,7 +1508,7 @@ class GaussianPGPE(PGPE):
                 _jax_wrapped_pgpe_grad,
                 in_axes=(0, None, None, None, None, None, None, None)
             )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
-            mu_grad, sigma_grad = jax.tree_map(
+            mu_grad, sigma_grad = jax.tree_util.tree_map(
                 partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
             new_r_max = jnp.max(r_maxs)
             return mu_grad, sigma_grad, new_r_max
@@ -1462,11 +1518,24 @@ class GaussianPGPE(PGPE):
         #
         # ***********************************************************************

+        # estimate KL divergence between two updates
         def _jax_wrapped_pgpe_kl_term(mu, sigma, old_mu, old_sigma):
             return 0.5 * jnp.sum(2 * jnp.log(sigma / old_sigma) +
                                  jnp.square(old_sigma / sigma) +
                                  jnp.square((mu - old_mu) / sigma) - 1)

+        # update mean and std. deviation with a gradient step
+        def _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
+                                            mu_state, sigma_state):
+            mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
+            sigma_updates, new_sigma_state = sigma_optimizer.update(
+                sigma_grad, sigma_state, params=sigma)
+            new_mu = optax.apply_updates(mu, mu_updates)
+            new_sigma = optax.apply_updates(sigma, sigma_updates)
+            new_sigma = jax.tree_util.tree_map(
+                partial(jnp.clip, min=sigma_lo, max=sigma_hi), new_sigma)
+            return new_mu, new_sigma, new_mu_state, new_sigma_state
+
         def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
                                      policy_hyperparams, subs, model_params,
                                      pgpe_opt_state):
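For reference, `_jax_wrapped_pgpe_kl_term` above implements the standard closed-form KL divergence between the old and new diagonal Gaussian search distributions, summed over parameters:

```latex
\mathrm{KL}\!\left(\mathcal{N}(\mu_{\mathrm{old}}, \sigma_{\mathrm{old}}^2)\,\middle\|\,\mathcal{N}(\mu, \sigma^2)\right)
  = \frac{1}{2}\sum_i\left[\,2\ln\frac{\sigma_i}{\sigma_{\mathrm{old},i}}
  + \frac{\sigma_{\mathrm{old},i}^2}{\sigma_i^2}
  + \frac{(\mu_i-\mu_{\mathrm{old},i})^2}{\sigma_i^2} - 1\right]
```

which matches the `2*log(sigma/old_sigma) + (old_sigma/sigma)**2 + ((mu - old_mu)/sigma)**2 - 1` expression in the code; the total is then used further below to shrink both learning rates by `min(1, sqrt(max_kl / total_kl))`.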
@@ -1476,29 +1545,23 @@ class GaussianPGPE(PGPE):
             ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
             mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
                 key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
-
-
-
-            new_mu = optax.apply_updates(mu, mu_updates)
-            new_sigma = optax.apply_updates(sigma, sigma_updates)
-            new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+            new_mu, new_sigma, new_mu_state, new_sigma_state = \
+                _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
+                                                mu_state, sigma_state)

             # respect KL divergence contraint with old parameters
             if max_kl is not None:
                 old_mu_lr = new_mu_state.hyperparams['learning_rate']
                 old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
-                kl_terms = jax.tree_map(
+                kl_terms = jax.tree_util.tree_map(
                     _jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
                 total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
                 kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
                 mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
                 sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
-
-
-
-                new_mu = optax.apply_updates(mu, mu_updates)
-                new_sigma = optax.apply_updates(sigma, sigma_updates)
-                new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+                new_mu, new_sigma, new_mu_state, new_sigma_state = \
+                    _jax_wrapped_pgpe_update_helper(mu, sigma, mu_grad, sigma_grad,
+                                                    mu_state, sigma_state)
                 new_mu_state.hyperparams['learning_rate'] = old_mu_lr
                 new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr

@@ -1509,7 +1572,21 @@ class GaussianPGPE(PGPE):
             policy_params = new_mu
             return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged

-
+        if parallel_updates is None:
+            self._update = jax.jit(_jax_wrapped_pgpe_update)
+        else:
+
+            # for parallel policy update
+            def _jax_wrapped_pgpe_updates(key, pgpe_params, r_max, progress,
+                                          policy_hyperparams, subs, model_params,
+                                          pgpe_opt_state):
+                keys = jnp.asarray(random.split(key, num=parallel_updates))
+                return jax.vmap(
+                    _jax_wrapped_pgpe_update, in_axes=(0, 0, 0, None, None, None, 0, 0)
+                )(keys, pgpe_params, r_max, progress, policy_hyperparams, subs,
+                  model_params, pgpe_opt_state)
+
+            self._update = jax.jit(_jax_wrapped_pgpe_updates)


         # ***********************************************************************
@@ -1565,6 +1642,7 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
     return jnp.sum(returns * weights)


+# set of all currently valid built-in utility functions
 UTILITY_LOOKUP = {
     'mean': jnp.mean,
     'mean_var': mean_variance_utility,
@@ -1609,7 +1687,9 @@ class JaxBackpropPlanner:
                  cpfs_without_grad: Optional[Set[str]]=None,
                  compile_non_fluent_exact: bool=True,
                  logger: Optional[Logger]=None,
-                 dashboard_viz: Optional[Any]=None) -> None:
+                 dashboard_viz: Optional[Any]=None,
+                 print_warnings: bool=True,
+                 parallel_updates: Optional[int]=None) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
         (plan) in the given RDDL. Some operations will be converted to their
         differentiable counterparts; the specific operations can be customized
@@ -1649,6 +1729,8 @@ class JaxBackpropPlanner:
         :param logger: to log information about compilation to file
         :param dashboard_viz: optional visualizer object from the environment
         to pass to the dashboard to visualize the policy
+        :param print_warnings: whether to print warnings
+        :param parallel_updates: how many optimizers to run independently in parallel
         '''
         self.rddl = rddl
         self.plan = plan
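A hedged construction sketch using the two new keyword arguments (the names `print_warnings` and `parallel_updates` come from the diff; the remaining arguments and the `model` variable are illustrative assumptions):

```python
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxStraightLinePlan

planner = JaxBackpropPlanner(
    rddl=model,                  # a lifted RDDL model, assumed to be prepared elsewhere
    plan=JaxStraightLinePlan(),
    print_warnings=False,        # new in 2.5: silence the [INFO]/[WARN] console messages
    parallel_updates=4,          # new in 2.5: run 4 independent optimizers side by side
)
```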
@@ -1656,6 +1738,7 @@ class JaxBackpropPlanner:
         if batch_size_test is None:
             batch_size_test = batch_size_train
         self.batch_size_test = batch_size_test
+        self.parallel_updates = parallel_updates
         if rollout_horizon is None:
             rollout_horizon = rddl.horizon
         self.horizon = rollout_horizon
@@ -1672,15 +1755,17 @@ class JaxBackpropPlanner:
         self.noise_kwargs = noise_kwargs
         self.pgpe = pgpe
         self.use_pgpe = pgpe is not None
+        self.print_warnings = print_warnings

         # set optimizer
         try:
             optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
         except Exception as _:
-
-
-            'rolling back to safer method: please note that modification of '
-            '
+            message = termcolor.colored(
+                '[FAIL] Failed to inject hyperparameters into JaxPlan optimizer, '
+                'rolling back to safer method: please note that runtime modification of '
+                'hyperparameters will be disabled.', 'red')
+            print(message)
             optimizer = optimizer(**optimizer_kwargs)

         # apply optimizer chain of transformations
@@ -1700,7 +1785,7 @@ class JaxBackpropPlanner:
             utility_fn = UTILITY_LOOKUP.get(utility, None)
             if utility_fn is None:
                 raise RDDLNotImplementedError(
-                    f'Utility <{utility}> is not supported, '
+                    f'Utility function <{utility}> is not supported, '
                     f'must be one of {list(UTILITY_LOOKUP.keys())}.')
         else:
             utility_fn = utility
@@ -1723,7 +1808,11 @@ class JaxBackpropPlanner:
         self._jax_compile_rddl()
         self._jax_compile_optimizer()

-
+    @staticmethod
+    def summarize_system() -> str:
+        '''Returns a string containing information about the system, Python version
+        and jax-related packages that are relevant to the current planner.
+        '''
         try:
             jaxlib_version = jax._src.lib.version_str
         except Exception as _:
@@ -1742,7 +1831,7 @@ r"""
 \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
 """

-        return ('\n'
+        return (f'\n'
                 f'{LOGO}\n'
                 f'Version {__version__}\n'
                 f'Python {sys.version}\n'
@@ -1751,7 +1840,29 @@ r"""
                 f'numpy {np.__version__}\n'
                 f'devices: {devices_short}\n')

-    def __str__(self) -> str:
+    def summarize_relaxations(self) -> str:
+        '''Returns a summary table containing all non-differentiable operators
+        and their relaxations.
+        '''
+        result = ''
+        if self.compiled.model_params:
+            result += ('Some RDDL operations are non-differentiable '
+                       'and will be approximated as follows:' + '\n')
+            exprs_by_rddl_op, values_by_rddl_op = {}, {}
+            for info in self.compiled.model_parameter_info().values():
+                rddl_op = info['rddl_op']
+                exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
+                values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
+            for rddl_op in sorted(exprs_by_rddl_op.keys()):
+                result += (f' {rddl_op}:\n'
+                           f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
+                           f' init_values={values_by_rddl_op[rddl_op]}\n')
+        return result
+
+    def summarize_hyperparameters(self) -> str:
+        '''Returns a string summarizing the hyper-parameters of the current planner
+        instance.
+        '''
         result = (f'objective hyper-parameters:\n'
                   f' utility_fn ={self.utility.__name__}\n'
                   f' utility args ={self.utility_kwargs}\n'
@@ -1769,30 +1880,14 @@ r"""
                   f' line_search_kwargs={self.line_search_kwargs}\n'
                   f' noise_kwargs ={self.noise_kwargs}\n'
                   f' batch_size_train ={self.batch_size_train}\n'
-                  f' batch_size_test ={self.batch_size_test}\n')
+                  f' batch_size_test ={self.batch_size_test}\n'
+                  f' parallel_updates ={self.parallel_updates}\n')
         result += str(self.plan)
         if self.use_pgpe:
             result += str(self.pgpe)
         result += str(self.logic)
-
-        # print model relaxation information
-        if self.compiled.model_params:
-            result += ('Some RDDL operations are non-differentiable '
-                       'and will be approximated as follows:' + '\n')
-            exprs_by_rddl_op, values_by_rddl_op = {}, {}
-            for info in self.compiled.model_parameter_info().values():
-                rddl_op = info['rddl_op']
-                exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
-                values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
-            for rddl_op in sorted(exprs_by_rddl_op.keys()):
-                result += (f' {rddl_op}:\n'
-                           f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
-                           f' init_values={values_by_rddl_op[rddl_op]}\n')
         return result

-    def summarize_hyperparameters(self) -> None:
-        print(self.__str__())
-
     # ===========================================================================
     # COMPILATION SUBROUTINES
     # ===========================================================================
@@ -1807,7 +1902,8 @@ r"""
             logger=self.logger,
             use64bit=self.use64bit,
             cpfs_without_grad=self.cpfs_without_grad,
-            compile_non_fluent_exact=self.compile_non_fluent_exact
+            compile_non_fluent_exact=self.compile_non_fluent_exact,
+            print_warnings=self.print_warnings
         )
         self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')

@@ -1844,23 +1940,33 @@ r"""
         self.test_rollouts = jax.jit(test_rollouts)

         # initialization
-        self.initialize =
+        self.initialize, self.init_optimizer = self._jax_init()

         # losses
         train_loss = self._jax_loss(train_rollouts, use_symlog=self.use_symlog_reward)
-
+        test_loss = self._jax_loss(test_rollouts, use_symlog=False)
+        if self.parallel_updates is None:
+            self.test_loss = jax.jit(test_loss)
+        else:
+            self.test_loss = jax.jit(jax.vmap(test_loss, in_axes=(None, 0, None, None, 0)))

         # optimization
         self.update = self._jax_update(train_loss)
+        self.pytree_at = jax.jit(
+            lambda tree, i: jax.tree_util.tree_map(lambda x: x[i], tree))

         # pgpe option
         if self.use_pgpe:
-            loss_fn = self._jax_loss(rollouts=test_rollouts)
             self.pgpe.compile(
-                loss_fn=
+                loss_fn=test_loss,
                 projection=self.plan.projection,
-                real_dtype=self.test_compiled.REAL
+                real_dtype=self.test_compiled.REAL,
+                print_warnings=self.print_warnings,
+                parallel_updates=self.parallel_updates
             )
+            self.merge_pgpe = self._jax_merge_pgpe_jaxplan()
+        else:
+            self.merge_pgpe = None

     def _jax_return(self, use_symlog):
         gamma = self.rddl.discount
@@ -1900,24 +2006,43 @@ r"""
     def _jax_init(self):
         init = self.plan.initializer
         optimizer = self.optimizer
+        num_parallel = self.parallel_updates

         # initialize both the policy and its optimizer
         def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
             policy_params = init(key, policy_hyperparams, subs)
             opt_state = optimizer.init(policy_params)
-            return policy_params, opt_state, {}
+            return policy_params, opt_state, {}
+
+        # initialize just the optimizer from the policy
+        def _jax_wrapped_init_opt(policy_params):
+            if num_parallel is None:
+                opt_state = optimizer.init(policy_params)
+            else:
+                opt_state = jax.vmap(optimizer.init, in_axes=0)(policy_params)
+            return opt_state, {}
+
+        if num_parallel is None:
+            return jax.jit(_jax_wrapped_init_policy), jax.jit(_jax_wrapped_init_opt)

-
+        # for parallel policy update
+        def _jax_wrapped_init_policies(key, policy_hyperparams, subs):
+            keys = jnp.asarray(random.split(key, num=num_parallel))
+            return jax.vmap(_jax_wrapped_init_policy, in_axes=(0, None, None))(
+                keys, policy_hyperparams, subs)
+
+        return jax.jit(_jax_wrapped_init_policies), jax.jit(_jax_wrapped_init_opt)

     def _jax_update(self, loss):
         optimizer = self.optimizer
         projection = self.plan.projection
         use_ls = self.line_search_kwargs is not None
+        num_parallel = self.parallel_updates

         # check if the gradients are all zeros
         def _jax_wrapped_zero_gradients(grad):
             leaves, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(
+                jax.tree_util.tree_map(partial(jnp.allclose, b=0), grad))
             return jnp.all(jnp.asarray(leaves))

         # calculate the plan gradient w.r.t. return loss and update optimizer
@@ -1948,8 +2073,43 @@ r"""
             return policy_params, converged, opt_state, opt_aux, \
                 loss_val, log, model_params, zero_grads

-
+        if num_parallel is None:
+            return jax.jit(_jax_wrapped_plan_update)
+
+        # for parallel policy update
+        def _jax_wrapped_plan_updates(key, policy_params, policy_hyperparams,
+                                      subs, model_params, opt_state, opt_aux):
+            keys = jnp.asarray(random.split(key, num=num_parallel))
+            return jax.vmap(
+                _jax_wrapped_plan_update, in_axes=(0, 0, None, None, 0, 0, 0)
+            )(keys, policy_params, policy_hyperparams, subs, model_params,
+              opt_state, opt_aux)
+
+        return jax.jit(_jax_wrapped_plan_updates)

+    def _jax_merge_pgpe_jaxplan(self):
+        if self.parallel_updates is None:
+            return None
+
+        # for parallel policy update
+        # currently implements a hard replacement where the jaxplan parameter
+        # is replaced by the PGPE parameter if the latter is an improvement
+        def _jax_wrapped_pgpe_jaxplan_merge(pgpe_mask, pgpe_param, policy_params,
+                                            pgpe_loss, test_loss,
+                                            pgpe_loss_smooth, test_loss_smooth,
+                                            pgpe_converged, converged):
+            def select_fn(leaf1, leaf2):
+                expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf1) - 1)]
+                return jnp.where(expanded_mask, leaf1, leaf2)
+            policy_params = jax.tree_util.tree_map(select_fn, pgpe_param, policy_params)
+            test_loss = jnp.where(pgpe_mask, pgpe_loss, test_loss)
+            test_loss_smooth = jnp.where(pgpe_mask, pgpe_loss_smooth, test_loss_smooth)
+            expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(converged) - 1)]
+            converged = jnp.where(expanded_mask, pgpe_converged, converged)
+            return policy_params, test_loss, test_loss_smooth, converged
+
+        return jax.jit(_jax_wrapped_pgpe_jaxplan_merge)
+
     def _batched_init_subs(self, subs):
         rddl = self.rddl
         n_train, n_test = self.batch_size_train, self.batch_size_test
@@ -1963,11 +2123,20 @@ r"""
                     f'Variable <{name}> in subs argument is not a '
                     f'valid p-variable, must be one of '
                     f'{set(self.test_compiled.init_values.keys())}.')
-            value = np.reshape(value,
+            value = np.reshape(value, np.shape(init_value))[np.newaxis, ...]
+            if value.dtype.type is np.str_:
+                value = rddl.object_string_to_index_array(rddl.variable_ranges[name], value)
             train_value = np.repeat(value, repeats=n_train, axis=0)
             train_value = np.asarray(train_value, dtype=self.compiled.REAL)
             init_train[name] = train_value
             init_test[name] = np.repeat(value, repeats=n_test, axis=0)
+
+            # safely cast test subs variable to required type in case the type is wrong
+            if name in rddl.variable_ranges:
+                required_type = RDDLValueInitializer.NUMPY_TYPES.get(
+                    rddl.variable_ranges[name], RDDLValueInitializer.INT)
+                if np.result_type(init_test[name]) != required_type:
+                    init_test[name] = np.asarray(init_test[name], dtype=required_type)

         # make sure next-state fluents are also set
         for (state, next_state) in rddl.next_state.items():
@@ -1975,6 +2144,19 @@ r"""
                 init_test[next_state] = init_test[state]
         return init_train, init_test

+    def _broadcast_pytree(self, pytree):
+        if self.parallel_updates is None:
+            return pytree
+
+        # for parallel policy update
+        def make_batched(x):
+            x = np.asarray(x)
+            x = np.broadcast_to(
+                x[np.newaxis, ...], shape=(self.parallel_updates,) + np.shape(x))
+            return x
+
+        return jax.tree_util.tree_map(make_batched, pytree)
+
     def as_optimization_problem(
             self, key: Optional[random.PRNGKey]=None,
             policy_hyperparams: Optional[Pytree]=None,
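The `_broadcast_pytree` helper above gives every leaf a leading axis of size `parallel_updates`, so one set of inputs can be shared by all parallel optimizers. An equivalent standalone illustration of the same broadcasting pattern:

```python
import numpy as np
import jax

parallel_updates = 3
pytree = {'lr': np.float32(0.01), 'w': np.zeros((2, 4))}

batched = jax.tree_util.tree_map(
    lambda x: np.broadcast_to(np.asarray(x)[np.newaxis, ...],
                              (parallel_updates,) + np.shape(x)),
    pytree)
print(batched['w'].shape)  # (3, 2, 4): one copy per parallel optimizer
```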
@@ -2002,6 +2184,11 @@ r"""
         :param grad_function_updates_key: if True, the gradient function
         updates the PRNG key internally independently of the loss function.
         '''
+
+        # make sure parallel updates are disabled
+        if self.parallel_updates is not None:
+            raise ValueError('Cannot compile static optimization problem '
+                             'when parallel_updates is not None.')

         # if PRNG key is not provided
         if key is None:
@@ -2012,8 +2199,11 @@ r"""
         train_subs, _ = self._batched_init_subs(subs)
         model_params = self.compiled.model_params
         if policy_hyperparams is None:
-
-
+            if self.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] policy_hyperparams is not set, setting 1.0 for '
+                    'all action-fluents which could be suboptimal.', 'yellow')
+                print(message)
             policy_hyperparams = {action: 1.0
                                   for action in self.rddl.action_fluents}

@@ -2084,10 +2274,12 @@ r"""
         their values: if None initializes all variables from the RDDL instance
         :param guess: initial policy parameters: if None will use the initializer
         specified in this instance
-        :param print_summary: whether to print planner header
-        summary, and diagnosis
+        :param print_summary: whether to print planner header and diagnosis
         :param print_progress: whether to print the progress bar during training
+        :param print_hyperparams: whether to print list of hyper-parameter settings
         :param stopping_rule: stopping criterion
+        :param restart_epochs: restart the optimizer from a random policy configuration
+        if there is no progress for this many consecutive iterations
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
@@ -2120,7 +2312,9 @@ r"""
                  guess: Optional[Pytree]=None,
                  print_summary: bool=True,
                  print_progress: bool=True,
+                 print_hyperparams: bool=False,
                  stopping_rule: Optional[JaxPlannerStoppingRule]=None,
+                 restart_epochs: int=999999,
                  test_rolling_window: int=10,
                  tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
         '''Returns a generator for computing an optimal policy or plan.
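A hedged call sketch for the two new arguments `print_hyperparams` and `restart_epochs`: the entry-point name `optimize` is taken from the printed summary further below ("optimize() call hyper-parameters"), while the other keyword values and the treatment of the return value are assumptions for illustration.

```python
from jax import random

result = planner.optimize(
    key=random.PRNGKey(42),
    epochs=1000,
    print_summary=True,
    print_hyperparams=True,  # new in 2.5: also print the full hyper-parameter table
    restart_epochs=200,      # new in 2.5: restart from a random policy after 200 stale iterations
)
```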
@@ -2139,10 +2333,12 @@ r"""
|
|
|
2139
2333
|
their values: if None initializes all variables from the RDDL instance
|
|
2140
2334
|
:param guess: initial policy parameters: if None will use the initializer
|
|
2141
2335
|
specified in this instance
|
|
2142
|
-
:param print_summary: whether to print planner header
|
|
2143
|
-
summary, and diagnosis
|
|
2336
|
+
:param print_summary: whether to print planner header and diagnosis
|
|
2144
2337
|
:param print_progress: whether to print the progress bar during training
|
|
2338
|
+
:param print_hyperparams: whether to print list of hyper-parameter settings
|
|
2145
2339
|
:param stopping_rule: stopping criterion
|
|
2340
|
+
:param restart_epochs: restart the optimizer from a random policy configuration
|
|
2341
|
+
if there is no progress for this many consecutive iterations
|
|
2146
2342
|
:param test_rolling_window: the test return is averaged on a rolling
|
|
2147
2343
|
window of the past test_rolling_window returns when updating the best
|
|
2148
2344
|
parameters found so far
|
|
@@ -2155,6 +2351,15 @@ r"""
|
|
|
2155
2351
|
# INITIALIZATION OF HYPER-PARAMETERS
|
|
2156
2352
|
# ======================================================================
|
|
2157
2353
|
|
|
2354
|
+
# cannot run dashboard with parallel updates
|
|
2355
|
+
if dashboard is not None and self.parallel_updates is not None:
|
|
2356
|
+
if self.print_warnings:
|
|
2357
|
+
message = termcolor.colored(
|
|
2358
|
+
'[WARN] Dashboard is unavailable if parallel_updates is not None: '
|
|
2359
|
+
'setting dashboard to None.', 'yellow')
|
|
2360
|
+
print(message)
|
|
2361
|
+
dashboard = None
|
|
2362
|
+
|
|
2158
2363
|
# if PRNG key is not provided
|
|
2159
2364
|
if key is None:
|
|
2160
2365
|
key = random.PRNGKey(round(time.time() * 1000))
|
|
@@ -2162,15 +2367,21 @@ r"""
|
|
|
2162
2367
|
|
|
2163
2368
|
# if policy_hyperparams is not provided
|
|
2164
2369
|
if policy_hyperparams is None:
|
|
2165
|
-
|
|
2166
|
-
|
|
2370
|
+
if self.print_warnings:
|
|
2371
|
+
message = termcolor.colored(
|
|
2372
|
+
'[WARN] policy_hyperparams is not set, setting 1.0 for '
|
|
2373
|
+
'all action-fluents which could be suboptimal.', 'yellow')
|
|
2374
|
+
print(message)
|
|
2167
2375
|
policy_hyperparams = {action: 1.0
|
|
2168
2376
|
for action in self.rddl.action_fluents}
|
|
2169
2377
|
|
|
2170
2378
|
# if policy_hyperparams is a scalar
|
|
2171
2379
|
elif isinstance(policy_hyperparams, (int, float, np.number)):
|
|
2172
|
-
|
|
2173
|
-
|
|
2380
|
+
if self.print_warnings:
|
|
2381
|
+
message = termcolor.colored(
|
|
2382
|
+
f'[INFO] policy_hyperparams is {policy_hyperparams}, '
|
|
2383
|
+
f'setting this value for all action-fluents.', 'green')
|
|
2384
|
+
print(message)
|
|
2174
2385
|
hyperparam_value = float(policy_hyperparams)
|
|
2175
2386
|
policy_hyperparams = {action: hyperparam_value
|
|
2176
2387
|
for action in self.rddl.action_fluents}
|
|
@@ -2179,14 +2390,20 @@ r"""
|
|
|
2179
2390
|
elif isinstance(policy_hyperparams, dict):
|
|
2180
2391
|
for action in self.rddl.action_fluents:
|
|
2181
2392
|
if action not in policy_hyperparams:
|
|
2182
|
-
|
|
2183
|
-
|
|
2393
|
+
if self.print_warnings:
|
|
2394
|
+
message = termcolor.colored(
|
|
2395
|
+
f'[WARN] policy_hyperparams[{action}] is not set, '
|
|
2396
|
+
f'setting 1.0 for missing action-fluents '
|
|
2397
|
+
f'which could be suboptimal.', 'yellow')
|
|
2398
|
+
print(message)
|
|
2184
2399
|
policy_hyperparams[action] = 1.0
|
|
2185
2400
|
|
|
2186
2401
|
# print summary of parameters:
|
|
2187
2402
|
if print_summary:
|
|
2188
2403
|
print(self.summarize_system())
|
|
2189
|
-
self.
|
|
2404
|
+
print(self.summarize_relaxations())
|
|
2405
|
+
if print_hyperparams:
|
|
2406
|
+
print(self.summarize_hyperparameters())
|
|
2190
2407
|
print(f'optimize() call hyper-parameters:\n'
|
|
2191
2408
|
f' PRNG key ={key}\n'
|
|
2192
2409
|
f' max_iterations ={epochs}\n'
|
|
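As a standalone illustration (not part of the package) of the policy_hyperparams handling above: a missing value defaults to 1.0 for every action-fluent, a scalar is broadcast to every action-fluent, and missing dictionary entries fall back to 1.0.

def broadcast_hyperparams(policy_hyperparams, action_fluents):
    # mirrors the three cases handled in optimize() above
    if policy_hyperparams is None:
        return {a: 1.0 for a in action_fluents}
    if isinstance(policy_hyperparams, (int, float)):
        return {a: float(policy_hyperparams) for a in action_fluents}
    return {a: policy_hyperparams.get(a, 1.0) for a in action_fluents}

fluents = ['move', 'put-out']                        # hypothetical action-fluent names
print(broadcast_hyperparams(None, fluents))          # {'move': 1.0, 'put-out': 1.0}
print(broadcast_hyperparams(5.0, fluents))           # {'move': 5.0, 'put-out': 5.0}
print(broadcast_hyperparams({'move': 2.0}, fluents)) # {'move': 2.0, 'put-out': 1.0}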
@@ -2200,7 +2417,8 @@ r"""
|
|
|
2200
2417
|
f' dashboard_id ={dashboard_id}\n'
|
|
2201
2418
|
f' print_summary ={print_summary}\n'
|
|
2202
2419
|
f' print_progress ={print_progress}\n'
|
|
2203
|
-
f' stopping_rule ={stopping_rule}\n'
|
|
2420
|
+
f' stopping_rule ={stopping_rule}\n'
|
|
2421
|
+
f' restart_epochs ={restart_epochs}\n')
|
|
2204
2422
|
|
|
2205
2423
|
# ======================================================================
|
|
2206
2424
|
# INITIALIZATION OF STATE AND POLICY
|
|
@@ -2217,16 +2435,18 @@ r"""
|
|
|
2217
2435
|
if var not in subs:
|
|
2218
2436
|
subs[var] = value
|
|
2219
2437
|
added_pvars_to_subs.append(var)
|
|
2220
|
-
if added_pvars_to_subs:
|
|
2221
|
-
|
|
2222
|
-
|
|
2223
|
-
|
|
2438
|
+
if self.print_warnings and added_pvars_to_subs:
|
|
2439
|
+
message = termcolor.colored(
|
|
2440
|
+
f'[INFO] p-variables {added_pvars_to_subs} are not in '
|
|
2441
|
+
f'the provided subs, using their initial values.', 'green')
|
|
2442
|
+
print(message)
|
|
2224
2443
|
train_subs, test_subs = self._batched_init_subs(subs)
|
|
2225
2444
|
|
|
2226
2445
|
# initialize model parameters
|
|
2227
2446
|
if model_params is None:
|
|
2228
2447
|
model_params = self.compiled.model_params
|
|
2229
|
-
|
|
2448
|
+
model_params = self._broadcast_pytree(model_params)
|
|
2449
|
+
model_params_test = self._broadcast_pytree(self.test_compiled.model_params)
|
|
2230
2450
|
|
|
2231
2451
|
# initialize policy parameters
|
|
2232
2452
|
if guess is None:
|
|
@@ -2234,29 +2454,31 @@ r"""
|
|
|
2234
2454
|
policy_params, opt_state, opt_aux = self.initialize(
|
|
2235
2455
|
subkey, policy_hyperparams, train_subs)
|
|
2236
2456
|
else:
|
|
2237
|
-
policy_params = guess
|
|
2238
|
-
opt_state = self.
|
|
2239
|
-
opt_aux = {}
|
|
2457
|
+
policy_params = self._broadcast_pytree(guess)
|
|
2458
|
+
opt_state, opt_aux = self.init_optimizer(policy_params)
|
|
2240
2459
|
|
|
2241
2460
|
# initialize pgpe parameters
|
|
2242
2461
|
if self.use_pgpe:
|
|
2243
|
-
pgpe_params, pgpe_opt_state = self.pgpe.initialize(key, policy_params)
|
|
2462
|
+
pgpe_params, pgpe_opt_state, r_max = self.pgpe.initialize(key, policy_params)
|
|
2244
2463
|
rolling_pgpe_loss = RollingMean(test_rolling_window)
|
|
2245
2464
|
else:
|
|
2246
|
-
pgpe_params, pgpe_opt_state = None, None
|
|
2465
|
+
pgpe_params, pgpe_opt_state, r_max = None, None, None
|
|
2247
2466
|
rolling_pgpe_loss = None
|
|
2248
2467
|
total_pgpe_it = 0
|
|
2249
|
-
r_max = -jnp.inf
|
|
2250
2468
|
|
|
2251
2469
|
# ======================================================================
|
|
2252
2470
|
# INITIALIZATION OF RUNNING STATISTICS
|
|
2253
2471
|
# ======================================================================
|
|
2254
2472
|
|
|
2255
2473
|
# initialize running statistics
|
|
2256
|
-
|
|
2474
|
+
if self.parallel_updates is None:
|
|
2475
|
+
best_params = policy_params
|
|
2476
|
+
else:
|
|
2477
|
+
best_params = self.pytree_at(policy_params, 0)
|
|
2478
|
+
best_loss, pbest_loss, best_grad = np.inf, np.inf, None
|
|
2257
2479
|
last_iter_improve = 0
|
|
2480
|
+
no_progress_count = 0
|
|
2258
2481
|
rolling_test_loss = RollingMean(test_rolling_window)
|
|
2259
|
-
log = {}
|
|
2260
2482
|
status = JaxPlannerStatus.NORMAL
|
|
2261
2483
|
progress_percent = 0
|
|
2262
2484
|
|
|
@@ -2277,6 +2499,11 @@ r"""
|
|
|
2277
2499
|
else:
|
|
2278
2500
|
progress_bar = None
|
|
2279
2501
|
position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
|
|
2502
|
+
|
|
2503
|
+
# flags to avoid printing the same error message repeatedly
|
|
2504
|
+
policy_constraint_msg_shown = False
|
|
2505
|
+
jax_train_msg_shown = False
|
|
2506
|
+
jax_test_msg_shown = False
|
|
2280
2507
|
|
|
2281
2508
|
# ======================================================================
|
|
2282
2509
|
# MAIN TRAINING LOOP BEGINS
|
|
@@ -2296,8 +2523,13 @@ r"""
|
|
|
2296
2523
|
model_params, zero_grads) = self.update(
|
|
2297
2524
|
subkey, policy_params, policy_hyperparams, train_subs, model_params,
|
|
2298
2525
|
opt_state, opt_aux)
|
|
2526
|
+
|
|
2527
|
+
# evaluate
|
|
2299
2528
|
test_loss, (test_log, model_params_test) = self.test_loss(
|
|
2300
2529
|
subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
|
|
2530
|
+
if self.parallel_updates:
|
|
2531
|
+
train_loss = np.asarray(train_loss)
|
|
2532
|
+
test_loss = np.asarray(test_loss)
|
|
2301
2533
|
test_loss_smooth = rolling_test_loss.update(test_loss)
|
|
2302
2534
|
|
|
2303
2535
|
# pgpe update of the plan
|
|
@@ -2308,52 +2540,112 @@ r"""
|
|
|
2308
2540
|
self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
|
|
2309
2541
|
policy_hyperparams, test_subs, model_params_test,
|
|
2310
2542
|
pgpe_opt_state)
|
|
2543
|
+
|
|
2544
|
+
# evaluate
|
|
2311
2545
|
pgpe_loss, _ = self.test_loss(
|
|
2312
2546
|
subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
|
|
2547
|
+
if self.parallel_updates:
|
|
2548
|
+
pgpe_loss = np.asarray(pgpe_loss)
|
|
2313
2549
|
pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
|
|
2314
2550
|
pgpe_return = -pgpe_loss_smooth
|
|
2315
2551
|
|
|
2316
|
-
# replace with PGPE if
|
|
2317
|
-
if
|
|
2318
|
-
|
|
2319
|
-
|
|
2320
|
-
|
|
2321
|
-
|
|
2322
|
-
|
|
2552
|
+
# replace JaxPlan with PGPE if new minimum reached or train loss invalid
|
|
2553
|
+
if self.parallel_updates is None:
|
|
2554
|
+
if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
|
|
2555
|
+
policy_params = pgpe_param
|
|
2556
|
+
test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
|
|
2557
|
+
converged = pgpe_converged
|
|
2558
|
+
pgpe_improve = True
|
|
2559
|
+
total_pgpe_it += 1
|
|
2560
|
+
else:
|
|
2561
|
+
pgpe_mask = (pgpe_loss_smooth < pbest_loss) | ~np.isfinite(train_loss)
|
|
2562
|
+
if np.any(pgpe_mask):
|
|
2563
|
+
policy_params, test_loss, test_loss_smooth, converged = \
|
|
2564
|
+
self.merge_pgpe(pgpe_mask, pgpe_param, policy_params,
|
|
2565
|
+
pgpe_loss, test_loss,
|
|
2566
|
+
pgpe_loss_smooth, test_loss_smooth,
|
|
2567
|
+
pgpe_converged, converged)
|
|
2568
|
+
pgpe_improve = True
|
|
2569
|
+
total_pgpe_it += 1
|
|
2323
2570
|
else:
|
|
2324
2571
|
pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
|
|
2325
2572
|
|
|
2326
|
-
# evaluate test losses and record best
|
|
2327
|
-
if
|
|
2328
|
-
|
|
2329
|
-
|
|
2330
|
-
|
|
2573
|
+
# evaluate test losses and record best parameters so far
|
|
2574
|
+
if self.parallel_updates is None:
|
|
2575
|
+
if test_loss_smooth < best_loss:
|
|
2576
|
+
best_params, best_loss, best_grad = \
|
|
2577
|
+
policy_params, test_loss_smooth, train_log['grad']
|
|
2578
|
+
pbest_loss = best_loss
|
|
2579
|
+
else:
|
|
2580
|
+
best_index = np.argmin(test_loss_smooth)
|
|
2581
|
+
if test_loss_smooth[best_index] < best_loss:
|
|
2582
|
+
best_params = self.pytree_at(policy_params, best_index)
|
|
2583
|
+
best_grad = self.pytree_at(train_log['grad'], best_index)
|
|
2584
|
+
best_loss = test_loss_smooth[best_index]
|
|
2585
|
+
pbest_loss = np.minimum(pbest_loss, test_loss_smooth)
|
|
2331
2586
|
|
|
2332
2587
|
# ==================================================================
|
|
2333
2588
|
# STATUS CHECKS AND LOGGING
|
|
2334
2589
|
# ==================================================================
|
|
2335
2590
|
|
|
2336
2591
|
# no progress
|
|
2337
|
-
|
|
2592
|
+
no_progress_flag = (not pgpe_improve) and np.all(zero_grads)
|
|
2593
|
+
if no_progress_flag:
|
|
2338
2594
|
status = JaxPlannerStatus.NO_PROGRESS
|
|
2339
|
-
|
|
2595
|
+
|
|
2340
2596
|
# constraint satisfaction problem
|
|
2341
|
-
if not np.all(converged):
|
|
2342
|
-
|
|
2343
|
-
|
|
2344
|
-
|
|
2345
|
-
|
|
2597
|
+
if not np.all(converged):
|
|
2598
|
+
if progress_bar is not None and not policy_constraint_msg_shown:
|
|
2599
|
+
message = termcolor.colored(
|
|
2600
|
+
'[FAIL] Policy update failed to satisfy action constraints.',
|
|
2601
|
+
'red')
|
|
2602
|
+
progress_bar.write(message)
|
|
2603
|
+
policy_constraint_msg_shown = True
|
|
2346
2604
|
status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
|
|
2347
2605
|
|
|
2348
2606
|
# numerical error
|
|
2349
2607
|
if self.use_pgpe:
|
|
2350
|
-
invalid_loss = not (np.isfinite(train_loss) or
|
|
2608
|
+
invalid_loss = not (np.any(np.isfinite(train_loss)) or
|
|
2609
|
+
np.any(np.isfinite(pgpe_loss)))
|
|
2351
2610
|
else:
|
|
2352
|
-
invalid_loss = not np.isfinite(train_loss)
|
|
2611
|
+
invalid_loss = not np.any(np.isfinite(train_loss))
|
|
2353
2612
|
if invalid_loss:
|
|
2354
|
-
|
|
2613
|
+
if progress_bar is not None:
|
|
2614
|
+
message = termcolor.colored(
|
|
2615
|
+
f'[FAIL] Planner aborted due to invalid train loss {train_loss}.',
|
|
2616
|
+
'red')
|
|
2617
|
+
progress_bar.write(message)
|
|
2355
2618
|
status = JaxPlannerStatus.INVALID_GRADIENT
|
|
2356
2619
|
|
|
2620
|
+
# problem in the model compilation
|
|
2621
|
+
if progress_bar is not None:
|
|
2622
|
+
|
|
2623
|
+
# train model
|
|
2624
|
+
if not jax_train_msg_shown:
|
|
2625
|
+
messages = set()
|
|
2626
|
+
for error_code in np.unique(train_log['error']):
|
|
2627
|
+
messages.update(JaxRDDLCompiler.get_error_messages(error_code))
|
|
2628
|
+
if messages:
|
|
2629
|
+
messages = '\n '.join(messages)
|
|
2630
|
+
message = termcolor.colored(
|
|
2631
|
+
f'[FAIL] Compiler encountered the following '
|
|
2632
|
+
f'error(s) in the training model:\n {messages}', 'red')
|
|
2633
|
+
progress_bar.write(message)
|
|
2634
|
+
jax_train_msg_shown = True
|
|
2635
|
+
|
|
2636
|
+
# test model
|
|
2637
|
+
if not jax_test_msg_shown:
|
|
2638
|
+
messages = set()
|
|
2639
|
+
for error_code in np.unique(test_log['error']):
|
|
2640
|
+
messages.update(JaxRDDLCompiler.get_error_messages(error_code))
|
|
2641
|
+
if messages:
|
|
2642
|
+
messages = '\n '.join(messages)
|
|
2643
|
+
message = termcolor.colored(
|
|
2644
|
+
f'[FAIL] Compiler encountered the following '
|
|
2645
|
+
f'error(s) in the testing model:\n {messages}', 'red')
|
|
2646
|
+
progress_bar.write(message)
|
|
2647
|
+
jax_test_msg_shown = True
|
|
2648
|
+
|
|
2357
2649
|
# reached computation budget
|
|
2358
2650
|
elapsed = time.time() - start_time - elapsed_outside_loop
|
|
2359
2651
|
if elapsed >= train_seconds:
|
|
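A standalone sketch (not from the package) of the parallel best-parameter bookkeeping introduced above: when several optimizer instances run side by side, the smoothed test losses form a vector, and only the best slice of the parameter pytree is promoted to best_params when it improves on the best loss seen so far.

import numpy as np
import jax

def pytree_at(tree, i):
    # stand-in for the planner's helper: take the i-th slice of every leaf
    return jax.tree_util.tree_map(lambda leaf: leaf[i], tree)

policy_params = {'w': np.arange(6.0).reshape(3, 2)}   # 3 parallel candidates
test_loss_smooth = np.array([0.7, 0.2, 0.9])
best_loss = 0.5

best_index = np.argmin(test_loss_smooth)
if test_loss_smooth[best_index] < best_loss:
    best_params = pytree_at(policy_params, best_index)
    best_loss = test_loss_smooth[best_index]
print(best_loss, best_params)   # 0.2 {'w': array([2., 3.])}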
@@ -2387,20 +2679,39 @@ r"""
|
|
|
2387
2679
|
**test_log
|
|
2388
2680
|
}
|
|
2389
2681
|
|
|
2682
|
+
# hard restart
|
|
2683
|
+
if guess is None and no_progress_flag:
|
|
2684
|
+
no_progress_count += 1
|
|
2685
|
+
if no_progress_count > restart_epochs:
|
|
2686
|
+
key, subkey = random.split(key)
|
|
2687
|
+
policy_params, opt_state, opt_aux = self.initialize(
|
|
2688
|
+
subkey, policy_hyperparams, train_subs)
|
|
2689
|
+
no_progress_count = 0
|
|
2690
|
+
if self.print_warnings and progress_bar is not None:
|
|
2691
|
+
message = termcolor.colored(
|
|
2692
|
+
f'[INFO] Optimizer restarted at iteration {it} '
|
|
2693
|
+
f'due to lack of progress.', 'green')
|
|
2694
|
+
progress_bar.write(message)
|
|
2695
|
+
else:
|
|
2696
|
+
no_progress_count = 0
|
|
2697
|
+
|
|
2390
2698
|
# stopping condition reached
|
|
2391
2699
|
if stopping_rule is not None and stopping_rule.monitor(callback):
|
|
2700
|
+
if self.print_warnings and progress_bar is not None:
|
|
2701
|
+
message = termcolor.colored(
|
|
2702
|
+
'[SUCC] Stopping rule has been reached.', 'green')
|
|
2703
|
+
progress_bar.write(message)
|
|
2392
2704
|
callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
|
|
2393
2705
|
|
|
2394
2706
|
# if the progress bar is used
|
|
2395
2707
|
if print_progress:
|
|
2396
2708
|
progress_bar.set_description(
|
|
2397
|
-
f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
|
|
2398
|
-
f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
|
|
2709
|
+
f'{position_str} {it:6} it / {-np.min(train_loss):14.5f} train / '
|
|
2710
|
+
f'{-np.min(test_loss_smooth):14.5f} test / {-best_loss:14.5f} best / '
|
|
2399
2711
|
f'{status.value} status / {total_pgpe_it:6} pgpe',
|
|
2400
|
-
refresh=False
|
|
2401
|
-
)
|
|
2712
|
+
refresh=False)
|
|
2402
2713
|
progress_bar.set_postfix_str(
|
|
2403
|
-
f
|
|
2714
|
+
f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False)
|
|
2404
2715
|
progress_bar.update(progress_percent - progress_bar.n)
|
|
2405
2716
|
|
|
2406
2717
|
# dash-board
|
|
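A standalone sketch (not from the package) of the hard-restart rule added above: a counter of consecutive no-progress iterations triggers a fresh random initialization once it exceeds restart_epochs, and any progress resets the counter.

restart_epochs = 3
no_progress_count = 0
for it, made_progress in enumerate([False, False, False, False, True]):
    if not made_progress:
        no_progress_count += 1
        if no_progress_count > restart_epochs:
            print(f'restart from a random policy at iteration {it}')
            no_progress_count = 0
    else:
        no_progress_count = 0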
@@ -2423,24 +2734,16 @@ r"""
|
|
|
2423
2734
|
# release resources
|
|
2424
2735
|
if print_progress:
|
|
2425
2736
|
progress_bar.close()
|
|
2426
|
-
|
|
2427
|
-
# validate the test return
|
|
2428
|
-
if log:
|
|
2429
|
-
messages = set()
|
|
2430
|
-
for error_code in np.unique(log['error']):
|
|
2431
|
-
messages.update(JaxRDDLCompiler.get_error_messages(error_code))
|
|
2432
|
-
if messages:
|
|
2433
|
-
messages = '\n'.join(messages)
|
|
2434
|
-
raise_warning('JAX compiler encountered the following '
|
|
2435
|
-
'error(s) in the original RDDL formulation '
|
|
2436
|
-
f'during test evaluation:\n{messages}', 'red')
|
|
2737
|
+
print()
|
|
2437
2738
|
|
|
2438
2739
|
# summarize and test for convergence
|
|
2439
2740
|
if print_summary:
|
|
2440
|
-
grad_norm = jax.tree_map(
|
|
2741
|
+
grad_norm = jax.tree_util.tree_map(
|
|
2742
|
+
lambda x: np.linalg.norm(x).item(), best_grad)
|
|
2441
2743
|
diagnosis = self._perform_diagnosis(
|
|
2442
|
-
last_iter_improve, -train_loss, -test_loss_smooth,
|
|
2443
|
-
|
|
2744
|
+
last_iter_improve, -np.min(train_loss), -np.min(test_loss_smooth),
|
|
2745
|
+
-best_loss, grad_norm)
|
|
2746
|
+
print(f'Summary of optimization:\n'
|
|
2444
2747
|
f' status ={status}\n'
|
|
2445
2748
|
f' time ={elapsed:.3f} sec.\n'
|
|
2446
2749
|
f' iterations ={it}\n'
|
|
@@ -2453,12 +2756,9 @@ r"""
|
|
|
2453
2756
|
max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
|
|
2454
2757
|
grad_is_zero = np.allclose(max_grad_norm, 0)
|
|
2455
2758
|
|
|
2456
|
-
validation_error = 100 * abs(test_return - train_return) / \
|
|
2457
|
-
max(abs(train_return), abs(test_return))
|
|
2458
|
-
|
|
2459
2759
|
# divergence if the solution is not finite
|
|
2460
2760
|
if not np.isfinite(train_return):
|
|
2461
|
-
return termcolor.colored('[
|
|
2761
|
+
return termcolor.colored('[FAIL] Training loss diverged.', 'red')
|
|
2462
2762
|
|
|
2463
2763
|
# hit a plateau is likely IF:
|
|
2464
2764
|
# 1. planner does not improve at all
|
|
@@ -2466,23 +2766,25 @@ r"""
|
|
|
2466
2766
|
if last_iter_improve <= 1:
|
|
2467
2767
|
if grad_is_zero:
|
|
2468
2768
|
return termcolor.colored(
|
|
2469
|
-
'[
|
|
2769
|
+
f'[FAIL] No progress was made '
|
|
2470
2770
|
f'and max grad norm {max_grad_norm:.6f} was zero: '
|
|
2471
|
-
'solver likely stuck in a plateau.', 'red')
|
|
2771
|
+
f'solver likely stuck in a plateau.', 'red')
|
|
2472
2772
|
else:
|
|
2473
2773
|
return termcolor.colored(
|
|
2474
|
-
'[
|
|
2774
|
+
f'[FAIL] No progress was made '
|
|
2475
2775
|
f'but max grad norm {max_grad_norm:.6f} was non-zero: '
|
|
2476
|
-
'learning rate or other hyper-parameters
|
|
2776
|
+
f'learning rate or other hyper-parameters could be suboptimal.',
|
|
2477
2777
|
'red')
|
|
2478
2778
|
|
|
2479
2779
|
# model is likely poor IF:
|
|
2480
2780
|
# 1. the train and test return disagree
|
|
2781
|
+
validation_error = 100 * abs(test_return - train_return) / \
|
|
2782
|
+
max(abs(train_return), abs(test_return))
|
|
2481
2783
|
if not (validation_error < 20):
|
|
2482
2784
|
return termcolor.colored(
|
|
2483
|
-
'[
|
|
2785
|
+
f'[WARN] Progress was made '
|
|
2484
2786
|
f'but relative train-test error {validation_error:.6f} was high: '
|
|
2485
|
-
'poor model relaxation around solution or batch size too small.',
|
|
2787
|
+
f'poor model relaxation around solution or batch size too small.',
|
|
2486
2788
|
'yellow')
|
|
2487
2789
|
|
|
2488
2790
|
# model likely did not converge IF:
|
|
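A standalone numeric check (not from the package) of the relative train-test error used by the diagnosis above: the gap is expressed as a percentage of the larger return in magnitude, and values of 20% or more trigger the poor-model-relaxation warning.

train_return, test_return = 105.0, 80.0
validation_error = 100 * abs(test_return - train_return) / \
    max(abs(train_return), abs(test_return))
print(round(validation_error, 2))   # 23.81, above the 20% threshold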
@@ -2491,24 +2793,22 @@ r"""
|
|
|
2491
2793
|
return_to_grad_norm = abs(best_return) / max_grad_norm
|
|
2492
2794
|
if not (return_to_grad_norm > 1):
|
|
2493
2795
|
return termcolor.colored(
|
|
2494
|
-
'[
|
|
2796
|
+
f'[WARN] Progress was made '
|
|
2495
2797
|
f'but max grad norm {max_grad_norm:.6f} was high: '
|
|
2496
|
-
'solution locally suboptimal '
|
|
2497
|
-
'or
|
|
2498
|
-
'or batch size too small.', 'yellow')
|
|
2798
|
+
f'solution locally suboptimal, relaxed model nonsmooth around solution, '
|
|
2799
|
+
f'or batch size too small.', 'yellow')
|
|
2499
2800
|
|
|
2500
2801
|
# likely successful
|
|
2501
2802
|
return termcolor.colored(
|
|
2502
|
-
'[
|
|
2503
|
-
'(note: not all
|
|
2803
|
+
'[SUCC] Planner converged successfully '
|
|
2804
|
+
'(note: not all problems can be ruled out).', 'green')
|
|
2504
2805
|
|
|
2505
2806
|
def get_action(self, key: random.PRNGKey,
|
|
2506
2807
|
params: Pytree,
|
|
2507
2808
|
step: int,
|
|
2508
2809
|
subs: Dict[str, Any],
|
|
2509
2810
|
policy_hyperparams: Optional[Dict[str, Any]]=None) -> Dict[str, Any]:
|
|
2510
|
-
'''Returns an action dictionary from the policy or plan with the given
|
|
2511
|
-
parameters.
|
|
2811
|
+
'''Returns an action dictionary from the policy or plan with the given parameters.
|
|
2512
2812
|
|
|
2513
2813
|
:param key: the JAX PRNG key
|
|
2514
2814
|
:param params: the trainable parameter PyTree of the policy
|
|
@@ -2517,6 +2817,7 @@ r"""
|
|
|
2517
2817
|
:param policy_hyperparams: hyper-parameters for the policy/plan, such as
|
|
2518
2818
|
weights for sigmoid wrapping boolean actions (optional)
|
|
2519
2819
|
'''
|
|
2820
|
+
subs = subs.copy()
|
|
2520
2821
|
|
|
2521
2822
|
# check compatibility of the subs dictionary
|
|
2522
2823
|
for (var, values) in subs.items():
|
|
@@ -2535,13 +2836,17 @@ r"""
|
|
|
2535
2836
|
if step == 0 and var in self.rddl.observ_fluents:
|
|
2536
2837
|
subs[var] = self.test_compiled.init_values[var]
|
|
2537
2838
|
else:
|
|
2538
|
-
|
|
2539
|
-
|
|
2540
|
-
|
|
2839
|
+
if dtype.type is np.str_:
|
|
2840
|
+
prange = self.rddl.variable_ranges[var]
|
|
2841
|
+
subs[var] = self.rddl.object_string_to_index_array(prange, subs[var])
|
|
2842
|
+
else:
|
|
2843
|
+
raise ValueError(
|
|
2844
|
+
f'Values {values} assigned to p-variable <{var}> are '
|
|
2845
|
+
f'non-numeric of type {dtype}.')
|
|
2541
2846
|
|
|
2542
2847
|
# cast device arrays to numpy
|
|
2543
2848
|
actions = self.test_policy(key, params, policy_hyperparams, step, subs)
|
|
2544
|
-
actions = jax.tree_map(np.asarray, actions)
|
|
2849
|
+
actions = jax.tree_util.tree_map(np.asarray, actions)
|
|
2545
2850
|
return actions
|
|
2546
2851
|
|
|
2547
2852
|
|
|
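A minimal sketch of querying an action through get_action() as changed above; planner, the trained params, and the state dictionary are assumed to come from an earlier optimize() call and an environment reset, and the fluent name in policy_hyperparams is hypothetical.

import jax.random as random

key = random.PRNGKey(0)
actions = planner.get_action(key, params, step=0, subs=state,
                             policy_hyperparams={'move': 1.0})  # hypothetical fluent
print(actions)   # plain numpy arrays, one entry per action-fluent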
@@ -2562,8 +2867,9 @@ class JaxOfflineController(BaseAgent):
|
|
|
2562
2867
|
def __init__(self, planner: JaxBackpropPlanner,
|
|
2563
2868
|
key: Optional[random.PRNGKey]=None,
|
|
2564
2869
|
eval_hyperparams: Optional[Dict[str, Any]]=None,
|
|
2565
|
-
params: Optional[Pytree]=None,
|
|
2870
|
+
params: Optional[Union[str, Pytree]]=None,
|
|
2566
2871
|
train_on_reset: bool=False,
|
|
2872
|
+
save_path: Optional[str]=None,
|
|
2567
2873
|
**train_kwargs) -> None:
|
|
2568
2874
|
'''Creates a new JAX offline control policy that is trained once, then
|
|
2569
2875
|
deployed later.
|
|
@@ -2574,8 +2880,10 @@ class JaxOfflineController(BaseAgent):
|
|
|
2574
2880
|
:param eval_hyperparams: policy hyperparameters to apply for evaluation
|
|
2575
2881
|
or whenever sample_action is called
|
|
2576
2882
|
:param params: use the specified policy parameters instead of calling
|
|
2577
|
-
planner.optimize()
|
|
2883
|
+
planner.optimize(); can be a string pointing to a valid file path where params
|
|
2884
|
+
have been saved, or a pytree of parameters
|
|
2578
2885
|
:param train_on_reset: retrain policy parameters on every episode reset
|
|
2886
|
+
:param save_path: optional path to save parameters to
|
|
2579
2887
|
:param **train_kwargs: any keyword arguments to be passed to the planner
|
|
2580
2888
|
for optimization
|
|
2581
2889
|
'''
|
|
@@ -2588,12 +2896,24 @@ class JaxOfflineController(BaseAgent):
|
|
|
2588
2896
|
self.train_kwargs = train_kwargs
|
|
2589
2897
|
self.params_given = params is not None
|
|
2590
2898
|
|
|
2899
|
+
# load the policy from file
|
|
2900
|
+
if not self.train_on_reset and params is not None and isinstance(params, str):
|
|
2901
|
+
with open(params, 'rb') as file:
|
|
2902
|
+
params = pickle.load(file)
|
|
2903
|
+
|
|
2904
|
+
# train the policy
|
|
2591
2905
|
self.step = 0
|
|
2592
2906
|
self.callback = None
|
|
2593
2907
|
if not self.train_on_reset and not self.params_given:
|
|
2594
2908
|
callback = self.planner.optimize(key=self.key, **self.train_kwargs)
|
|
2595
2909
|
self.callback = callback
|
|
2596
2910
|
params = callback['best_params']
|
|
2911
|
+
|
|
2912
|
+
# save the policy
|
|
2913
|
+
if save_path is not None:
|
|
2914
|
+
with open(save_path, 'wb') as file:
|
|
2915
|
+
pickle.dump(params, file)
|
|
2916
|
+
|
|
2597
2917
|
self.params = params
|
|
2598
2918
|
|
|
2599
2919
|
def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
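A minimal sketch (assuming a planner instance already exists) of the new save/load round trip above: the first controller trains and pickles its best parameters to disk via save_path, while a later controller reloads them by passing the file path as params and skips optimization.

from pyRDDLGym_jax.core.planner import JaxOfflineController

controller = JaxOfflineController(planner, save_path='my_policy.pkl')  # trains, then saves
# ... later, possibly in another process ...
reloaded = JaxOfflineController(planner, params='my_policy.pkl')       # loads, no training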
@@ -2605,6 +2925,8 @@ class JaxOfflineController(BaseAgent):
|
|
|
2605
2925
|
|
|
2606
2926
|
def reset(self) -> None:
|
|
2607
2927
|
self.step = 0
|
|
2928
|
+
|
|
2929
|
+
# retrain the policy at the start of every episode if train_on_reset is set
|
|
2608
2930
|
if self.train_on_reset and not self.params_given:
|
|
2609
2931
|
callback = self.planner.optimize(key=self.key, **self.train_kwargs)
|
|
2610
2932
|
self.callback = callback
|
|
@@ -2612,8 +2934,7 @@ class JaxOfflineController(BaseAgent):
|
|
|
2612
2934
|
|
|
2613
2935
|
|
|
2614
2936
|
class JaxOnlineController(BaseAgent):
|
|
2615
|
-
'''A container class for a Jax controller continuously updated using state
|
|
2616
|
-
feedback.'''
|
|
2937
|
+
'''A container class for a Jax controller continuously updated using state feedback.'''
|
|
2617
2938
|
|
|
2618
2939
|
use_tensor_obs = True
|
|
2619
2940
|
|
|
@@ -2621,17 +2942,19 @@ class JaxOnlineController(BaseAgent):
|
|
|
2621
2942
|
key: Optional[random.PRNGKey]=None,
|
|
2622
2943
|
eval_hyperparams: Optional[Dict[str, Any]]=None,
|
|
2623
2944
|
warm_start: bool=True,
|
|
2945
|
+
max_attempts: int=3,
|
|
2624
2946
|
**train_kwargs) -> None:
|
|
2625
2947
|
'''Creates a new JAX control policy that is trained online in a closed-
|
|
2626
2948
|
loop fashion.
|
|
2627
2949
|
|
|
2628
2950
|
:param planner: underlying planning algorithm for optimizing actions
|
|
2629
|
-
:param key: the RNG key to seed randomness (derives from clock if not
|
|
2630
|
-
provided)
|
|
2951
|
+
:param key: the RNG key to seed randomness (derives from clock if not provided)
|
|
2631
2952
|
:param eval_hyperparams: policy hyperparameters to apply for evaluation
|
|
2632
2953
|
or whenever sample_action is called
|
|
2633
2954
|
:param warm_start: whether to use the previous decision epoch final
|
|
2634
2955
|
policy parameters to warm the next decision epoch
|
|
2956
|
+
:param max_attempts: maximum attempted restarts of the optimizer when the total
|
|
2957
|
+
iteration count is 1 (i.e. the execution time is dominated by the jit compilation)
|
|
2635
2958
|
:param **train_kwargs: any keyword arguments to be passed to the planner
|
|
2636
2959
|
for optimization
|
|
2637
2960
|
'''
|
|
@@ -2642,20 +2965,34 @@ class JaxOnlineController(BaseAgent):
|
|
|
2642
2965
|
self.eval_hyperparams = eval_hyperparams
|
|
2643
2966
|
self.warm_start = warm_start
|
|
2644
2967
|
self.train_kwargs = train_kwargs
|
|
2968
|
+
self.max_attempts = max_attempts
|
|
2645
2969
|
self.reset()
|
|
2646
2970
|
|
|
2647
2971
|
def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
2648
2972
|
planner = self.planner
|
|
2649
2973
|
callback = planner.optimize(
|
|
2650
|
-
key=self.key,
|
|
2651
|
-
|
|
2652
|
-
|
|
2653
|
-
|
|
2654
|
-
|
|
2974
|
+
key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
|
|
2975
|
+
|
|
2976
|
+
# optimize again if jit compilation takes up the entire time budget
|
|
2977
|
+
attempts = 0
|
|
2978
|
+
while attempts < self.max_attempts and callback['iteration'] <= 1:
|
|
2979
|
+
attempts += 1
|
|
2980
|
+
if self.planner.print_warnings:
|
|
2981
|
+
message = termcolor.colored(
|
|
2982
|
+
f'[WARN] JIT compilation dominated the execution time: '
|
|
2983
|
+
f'executing the optimizer again on the traced model '
|
|
2984
|
+
f'[attempt {attempts}].', 'yellow')
|
|
2985
|
+
print(message)
|
|
2986
|
+
callback = planner.optimize(
|
|
2987
|
+
key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
|
|
2655
2988
|
self.callback = callback
|
|
2656
2989
|
params = callback['best_params']
|
|
2990
|
+
|
|
2991
|
+
# get the action from the parameters for the current state
|
|
2657
2992
|
self.key, subkey = random.split(self.key)
|
|
2658
2993
|
actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
|
|
2994
|
+
|
|
2995
|
+
# apply warm start for the next epoch
|
|
2659
2996
|
if self.warm_start:
|
|
2660
2997
|
self.guess = planner.plan.guess_next_epoch(params)
|
|
2661
2998
|
return actions
|
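A minimal sketch (assuming a planner and a gymnasium-style pyRDDLGym environment env already exist) of the online controller changed above: max_attempts re-runs the optimizer when the first call spends its entire budget on jit compilation, and warm_start reuses the previous epoch's solution as the next initial guess.

from pyRDDLGym_jax.core.planner import JaxOnlineController

agent = JaxOnlineController(planner, warm_start=True, max_attempts=3,
                            train_seconds=1.0)   # train_seconds is forwarded to optimize()
state, _ = env.reset()
action = agent.sample_action(state)
next_state, reward, terminated, truncated, _ = env.step(action)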