pyRDDLGym-jax 2.4__py3-none-any.whl → 2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +23 -10
- pyRDDLGym_jax/core/logic.py +6 -8
- pyRDDLGym_jax/core/model.py +595 -0
- pyRDDLGym_jax/core/planner.py +317 -99
- pyRDDLGym_jax/core/simulator.py +37 -13
- pyRDDLGym_jax/core/tuning.py +25 -10
- pyRDDLGym_jax/entry_point.py +39 -7
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
- pyRDDLGym_jax/examples/run_plan.py +1 -1
- pyRDDLGym_jax/examples/run_tune.py +8 -2
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/METADATA +17 -30
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/RECORD +19 -18
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info/licenses}/LICENSE +0 -0
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -39,6 +39,7 @@ import configparser
 from enum import Enum
 from functools import partial
 import os
+import pickle
 import sys
 import time
 import traceback
@@ -206,6 +207,13 @@ def _load_config(config, args):
        pgpe_kwargs['optimizer'] = pgpe_optimizer
        planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)

+    # preprocessor settings
+    preproc_method = planner_args.get('preprocessor', None)
+    preproc_kwargs = planner_args.pop('preprocessor_kwargs', {})
+    if preproc_method is not None:
+        planner_args['preprocessor'] = getattr(
+            sys.modules[__name__], preproc_method)(**preproc_kwargs)
+
    # optimize call RNG key
    planner_key = train_args.get('key', None)
    if planner_key is not None:
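As a usage note: `preprocessor` is resolved by class name from this module, mirroring how the `pgpe` option is resolved just above. A minimal sketch of enabling it programmatically (the domain, instance, and topology below are placeholders, not taken from this diff):

import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxDeepReactivePolicy, StaticNormalizer)

env = pyRDDLGym.make('Wildfire', '0', vectorized=True)   # placeholder domain/instance
planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxDeepReactivePolicy(topology=[64, 64]),
    preprocessor=StaticNormalizer())                      # new keyword in 2.6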
@@ -229,13 +237,19 @@ def _load_config(config, args):


def load_config(path: str) -> Tuple[Kwargs, ...]:
-    '''Loads a config file at the specified file path.
+    '''Loads a config file at the specified file path.
+
+    :param path: the path of the config file to load and parse
+    '''
    config, args = _parse_config_file(path)
    return _load_config(config, args)


def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
-    '''Loads config file contents specified explicitly as a string value.
+    '''Loads config file contents specified explicitly as a string value.
+
+    :param value: the string in json format containing the config contents to parse
+    '''
    config, args = _parse_config_string(value)
    return _load_config(config, args)

@@ -258,6 +272,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
    def __init__(self, *args,
                 logic: Logic=FuzzyLogic(),
                 cpfs_without_grad: Optional[Set[str]]=None,
+                 print_warnings: bool=True,
                 **kwargs) -> None:
        '''Creates a new RDDL to Jax compiler, where operations that are not
        differentiable are converted to approximate forms that have defined gradients.
@@ -268,6 +283,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
            to customize these operations
        :param cpfs_without_grad: which CPFs do not have gradients (use straight
            through gradient trick)
+        :param print_warnings: whether to print warnings
        :param *kwargs: keyword arguments to pass to base compiler
        '''
        super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
@@ -277,6 +293,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
        if cpfs_without_grad is None:
            cpfs_without_grad = set()
        self.cpfs_without_grad = cpfs_without_grad
+        self.print_warnings = print_warnings

        # actions and CPFs must be continuous
        pvars_cast = set()
@@ -284,7 +301,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
            self.init_values[var] = np.asarray(values, dtype=self.REAL)
            if not np.issubdtype(np.result_type(values), np.floating):
                pvars_cast.add(var)
-        if pvars_cast:
+        if self.print_warnings and pvars_cast:
            message = termcolor.colored(
                f'[INFO] JAX gradient compiler will cast p-vars {pvars_cast} to float.',
                'green')
@@ -314,12 +331,12 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
            if cpf in self.cpfs_without_grad:
                jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])

-        if cpfs_cast:
+        if self.print_warnings and cpfs_cast:
            message = termcolor.colored(
                f'[INFO] JAX gradient compiler will cast CPFs {cpfs_cast} to float.',
                'green')
            print(message)
-        if self.cpfs_without_grad:
+        if self.print_warnings and self.cpfs_without_grad:
            message = termcolor.colored(
                f'[INFO] Gradients will not flow through CPFs {self.cpfs_without_grad}.',
                'green')
@@ -333,6 +350,100 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
        return arg


+# ***********************************************************************
+# ALL VERSIONS OF STATE PREPROCESSING FOR DRP
+#
+# - static normalization
+#
+# ***********************************************************************
+
+
+class Preprocessor(metaclass=ABCMeta):
+    '''Base class for all state preprocessors.'''
+
+    HYPERPARAMS_KEY = 'preprocessor__'
+
+    def __init__(self) -> None:
+        self._initializer = None
+        self._update = None
+        self._transform = None
+
+    @property
+    def initialize(self):
+        return self._initializer
+
+    @property
+    def update(self):
+        return self._update
+
+    @property
+    def transform(self):
+        return self._transform
+
+    @abstractmethod
+    def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
+        pass
+
+
+class StaticNormalizer(Preprocessor):
+    '''Normalize values by box constraints on fluents computed from the RDDL domain.'''
+
+    def __init__(self, fluent_bounds: Dict[str, Tuple[np.ndarray, np.ndarray]]={}) -> None:
+        '''Create a new instance of the static normalizer.
+
+        :param fluent_bounds: optional bounds on fluents to overwrite default values.
+        '''
+        self.fluent_bounds = fluent_bounds
+
+    def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
+
+        # adjust for partial observability
+        rddl = compiled.rddl
+        if rddl.observ_fluents:
+            observed_vars = rddl.observ_fluents
+        else:
+            observed_vars = rddl.state_fluents
+
+        # ignore boolean fluents and infinite bounds
+        bounded_vars = {}
+        for var in observed_vars:
+            if rddl.variable_ranges[var] != 'bool':
+                lower, upper = compiled.constraints.bounds[var]
+                if np.all(np.isfinite(lower) & np.isfinite(upper) & (lower < upper)):
+                    bounded_vars[var] = (lower, upper)
+                user_bounds = self.fluent_bounds.get(var, None)
+                if user_bounds is not None:
+                    bounded_vars[var] = tuple(user_bounds)
+
+        # initialize to ranges computed by the constraint parser
+        def _jax_wrapped_normalizer_init():
+            return bounded_vars
+        self._initializer = jax.jit(_jax_wrapped_normalizer_init)
+
+        # static bounds
+        def _jax_wrapped_normalizer_update(subs, stats):
+            stats = {var: (jnp.asarray(lower, dtype=compiled.REAL),
+                           jnp.asarray(upper, dtype=compiled.REAL))
+                     for (var, (lower, upper)) in bounded_vars.items()}
+            return stats
+        self._update = jax.jit(_jax_wrapped_normalizer_update)
+
+        # apply min max scaling
+        def _jax_wrapped_normalizer_transform(subs, stats):
+            new_subs = {}
+            for (var, values) in subs.items():
+                if var in stats:
+                    lower, upper = stats[var]
+                    new_dims = jnp.ndim(values) - jnp.ndim(lower)
+                    lower = lower[(jnp.newaxis,) * new_dims + (...,)]
+                    upper = upper[(jnp.newaxis,) * new_dims + (...,)]
+                    new_subs[var] = (values - lower) / (upper - lower)
+                else:
+                    new_subs[var] = values
+            return new_subs
+        self._transform = jax.jit(_jax_wrapped_normalizer_transform)
+
+
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
 #
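For intuition, the `transform` above is ordinary min-max scaling, with the static bounds broadcast across any leading batch axes. A self-contained sketch with illustrative shapes:

import jax.numpy as jnp

lower = jnp.array([0.0, -1.0, 10.0])
upper = jnp.array([5.0, 1.0, 20.0])
values = jnp.ones((32, 3))                            # (batch, fluent dims)

new_dims = jnp.ndim(values) - jnp.ndim(lower)
lower_b = lower[(jnp.newaxis,) * new_dims + (...,)]   # shape (1, 3)
upper_b = upper[(jnp.newaxis,) * new_dims + (...,)]
scaled = (values - lower_b) / (upper_b - lower_b)     # every entry lands in [0, 1]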
@@ -358,7 +469,8 @@ class JaxPlan(metaclass=ABCMeta):
    @abstractmethod
    def compile(self, compiled: JaxRDDLCompilerWithGrad,
                _bounds: Bounds,
-                horizon: int
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
        pass

    @abstractmethod
@@ -436,10 +548,11 @@ class JaxPlan(metaclass=ABCMeta):
                ~lower_finite & upper_finite,
                ~lower_finite & ~upper_finite]
            bounds[name] = (lower, upper)
-            message = termcolor.colored(
-                f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
-                'green')
-            print(message)
+            if compiled.print_warnings:
+                message = termcolor.colored(
+                    f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
+                    'green')
+                print(message)
        return shapes, bounds, bounds_safe, cond_lists

    def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -508,7 +621,8 @@ class JaxStraightLinePlan(JaxPlan):

    def compile(self, compiled: JaxRDDLCompilerWithGrad,
                _bounds: Bounds,
-                horizon: int
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
        rddl = compiled.rddl

        # calculate the correct action box bounds
@@ -519,7 +633,7 @@ class JaxStraightLinePlan(JaxPlan):
        # action concurrency check
        bool_action_count, allowed_actions = self._count_bool_actions(rddl)
        use_constraint_satisfaction = allowed_actions < bool_action_count
-        if use_constraint_satisfaction:
+        if compiled.print_warnings and use_constraint_satisfaction:
            message = termcolor.colored(
                f'[INFO] SLP will use projected gradient to satisfy '
                f'max_nondef_actions since total boolean actions '
@@ -596,7 +710,7 @@ class JaxStraightLinePlan(JaxPlan):
            return new_params, True

        # convert softmax action back to action dict
-        action_sizes = {var: np.prod(shape[1:], dtype=
+        action_sizes = {var: np.prod(shape[1:], dtype=np.int64)
                        for (var, shape) in shapes.items()
                        if ranges[var] == 'bool'}
@@ -605,7 +719,7 @@ class JaxStraightLinePlan(JaxPlan):
            start = 0
            for (name, size) in action_sizes.items():
                action = output[..., start:start + size]
-                action = jnp.reshape(action,
+                action = jnp.reshape(action, shapes[name][1:])
                if noop[name]:
                    action = 1.0 - action
                actions[name] = action
@@ -680,7 +794,7 @@ class JaxStraightLinePlan(JaxPlan):
            scores = []
            for (var, param) in params.items():
                if ranges[var] == 'bool':
-                    param_flat = jnp.ravel(param)
+                    param_flat = jnp.ravel(param, order='C')
                    if noop[var]:
                        if wrap_sigmoid:
                            param_flat = -param_flat
@@ -838,7 +952,7 @@ class JaxStraightLinePlan(JaxPlan):

    def guess_next_epoch(self, params: Pytree) -> Pytree:
        next_fn = JaxStraightLinePlan._guess_next_epoch
-        return jax.tree_map(next_fn, params)
+        return jax.tree_util.tree_map(next_fn, params)


class JaxDeepReactivePolicy(JaxPlan):
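The `jax.tree_map` to `jax.tree_util.tree_map` substitution recurs throughout this file: `jax.tree_map` was deprecated and has since been removed from recent JAX releases, while `jax.tree_util.tree_map` keeps identical semantics. A quick check of the replacement call:

import jax

params = {'w': [1.0, 2.0], 'b': 3.0}
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
print(doubled)   # {'b': 6.0, 'w': [2.0, 4.0]}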
@@ -897,7 +1011,8 @@ class JaxDeepReactivePolicy(JaxPlan):

    def compile(self, compiled: JaxRDDLCompilerWithGrad,
                _bounds: Bounds,
-                horizon: int
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
        rddl = compiled.rddl

        # calculate the correct action box bounds
@@ -928,7 +1043,7 @@ class JaxDeepReactivePolicy(JaxPlan):
        wrap_non_bool = self._wrap_non_bool
        init = self._initializer
        layers = list(enumerate(zip(self._topology, self._activations)))
-        layer_sizes = {var: np.prod(shape, dtype=
+        layer_sizes = {var: np.prod(shape, dtype=np.int64)
                       for (var, shape) in shapes.items()}
        layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}

@@ -946,21 +1061,28 @@ class JaxDeepReactivePolicy(JaxPlan):
            if ranges[var] != 'bool':
                value_size = np.size(values)
                if normalize_per_layer and value_size == 1:
-                    message = termcolor.colored(
-                        f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
-                        f'of size 1: setting normalize_per_layer = False.', 'yellow')
-                    print(message)
+                    if compiled.print_warnings:
+                        message = termcolor.colored(
+                            f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
+                            f'of size 1: setting normalize_per_layer = False.', 'yellow')
+                        print(message)
                    normalize_per_layer = False
                non_bool_dims += value_size
        if not normalize_per_layer and non_bool_dims == 1:
-            message = termcolor.colored(
-                '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
-                'setting normalize = False.', 'yellow')
-            print(message)
+            if compiled.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
+                    'setting normalize = False.', 'yellow')
+                print(message)
            normalize = False

        # convert subs dictionary into a state vector to feed to the MLP
-        def _jax_wrapped_policy_input(subs):
+        def _jax_wrapped_policy_input(subs, hyperparams):
+
+            # optional state preprocessing
+            if preprocessor is not None:
+                stats = hyperparams[preprocessor.HYPERPARAMS_KEY]
+                subs = preprocessor.transform(subs, stats)

            # concatenate all state variables into a single vector
            # optionally apply layer norm to each input tensor
@@ -968,7 +1090,7 @@ class JaxDeepReactivePolicy(JaxPlan):
            non_bool_dims = 0
            for (var, value) in subs.items():
                if var in observed_vars:
-                    state = jnp.ravel(value)
+                    state = jnp.ravel(value, order='C')
                    if ranges[var] == 'bool':
                        states_bool.append(state)
                    else:
@@ -997,8 +1119,8 @@ class JaxDeepReactivePolicy(JaxPlan):
            return state

        # predict actions from the policy network for current state
-        def _jax_wrapped_policy_network_predict(subs):
-            state = _jax_wrapped_policy_input(subs)
+        def _jax_wrapped_policy_network_predict(subs, hyperparams):
+            state = _jax_wrapped_policy_input(subs, hyperparams)

            # feed state vector through hidden layers
            hidden = state
@@ -1054,7 +1176,7 @@ class JaxDeepReactivePolicy(JaxPlan):
            for (name, size) in layer_sizes.items():
                if ranges[name] == 'bool':
                    action = output[..., start:start + size]
-                    action = jnp.reshape(action,
+                    action = jnp.reshape(action, shapes[name])
                    if noop[name]:
                        action = 1.0 - action
                    actions[name] = action
@@ -1063,7 +1185,7 @@ class JaxDeepReactivePolicy(JaxPlan):

        # train action prediction
        def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
-            actions = predict_fn.apply(params, subs)
+            actions = predict_fn.apply(params, subs, hyperparams)
            if not wrap_non_bool:
                for (var, action) in actions.items():
                    if var != bool_key and ranges[var] != 'bool':
@@ -1113,7 +1235,7 @@ class JaxDeepReactivePolicy(JaxPlan):
            subs = {var: value[0, ...]
                    for (var, value) in subs.items()
                    if var in observed_vars}
-            params = predict_fn.init(key, subs)
+            params = predict_fn.init(key, subs, hyperparams)
            return params

        self.initializer = _jax_wrapped_drp_init
@@ -1226,6 +1348,7 @@ class PGPE(metaclass=ABCMeta):

    @abstractmethod
    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+                print_warnings: bool,
                parallel_updates: Optional[int]=None) -> None:
        pass

@@ -1322,6 +1445,7 @@ class GaussianPGPE(PGPE):
        )

    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+                print_warnings: bool,
                parallel_updates: Optional[int]=None) -> None:
        sigma0 = self.init_sigma
        sigma_lo, sigma_hi = self.sigma_range
@@ -1347,7 +1471,7 @@ class GaussianPGPE(PGPE):

        def _jax_wrapped_pgpe_init(key, policy_params):
            mu = policy_params
-            sigma = jax.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
+            sigma = jax.tree_util.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
            pgpe_params = (mu, sigma)
            pgpe_opt_state = (mu_optimizer.init(mu), sigma_optimizer.init(sigma))
            r_max = -jnp.inf
@@ -1395,13 +1519,14 @@ class GaussianPGPE(PGPE):
            treedef = jax.tree_util.tree_structure(sigma)
            keys = random.split(key, num=treedef.num_leaves)
            keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
-            epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
-            p1 = jax.tree_map(jnp.add, mu, epsilon)
-            p2 = jax.tree_map(jnp.subtract, mu, epsilon)
+            epsilon = jax.tree_util.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
+            p1 = jax.tree_util.tree_map(jnp.add, mu, epsilon)
+            p2 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon)
            if super_symmetric:
-                epsilon_star = jax.tree_map(
-
-
+                epsilon_star = jax.tree_util.tree_map(
+                    _jax_wrapped_epsilon_star, sigma, epsilon)
+                p3 = jax.tree_util.tree_map(jnp.add, mu, epsilon_star)
+                p4 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon_star)
            else:
                epsilon_star, p3, p4 = epsilon, p1, p2
            return p1, p2, p3, p4, epsilon, epsilon_star
@@ -1469,11 +1594,11 @@ class GaussianPGPE(PGPE):
                    r_max = jnp.maximum(r_max, r4)
                else:
                    r3, r4 = r1, r2
-                grad_mu = jax.tree_map(
+                grad_mu = jax.tree_util.tree_map(
                    partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
                    epsilon, epsilon_star
                )
-                grad_sigma = jax.tree_map(
+                grad_sigma = jax.tree_util.tree_map(
                    partial(_jax_wrapped_sigma_grad,
                            r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
                    epsilon, epsilon_star, sigma
@@ -1492,7 +1617,7 @@ class GaussianPGPE(PGPE):
                _jax_wrapped_pgpe_grad,
                in_axes=(0, None, None, None, None, None, None, None)
            )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
-            mu_grad, sigma_grad = jax.tree_map(
+            mu_grad, sigma_grad = jax.tree_util.tree_map(
                partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
            new_r_max = jnp.max(r_maxs)
            return mu_grad, sigma_grad, new_r_max
@@ -1516,7 +1641,7 @@ class GaussianPGPE(PGPE):
                sigma_grad, sigma_state, params=sigma)
            new_mu = optax.apply_updates(mu, mu_updates)
            new_sigma = optax.apply_updates(sigma, sigma_updates)
-            new_sigma = jax.tree_map(
+            new_sigma = jax.tree_util.tree_map(
                partial(jnp.clip, min=sigma_lo, max=sigma_hi), new_sigma)
            return new_mu, new_sigma, new_mu_state, new_sigma_state

@@ -1537,7 +1662,7 @@ class GaussianPGPE(PGPE):
            if max_kl is not None:
                old_mu_lr = new_mu_state.hyperparams['learning_rate']
                old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
-                kl_terms = jax.tree_map(
+                kl_terms = jax.tree_util.tree_map(
                    _jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
                total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
                kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
@@ -1618,12 +1743,21 @@ def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
    return mu - 0.5 * beta * msv


+@jax.jit
+def sharpe_utility(returns: jnp.ndarray, risk_free: float) -> float:
+    return (jnp.mean(returns) - risk_free) / (jnp.std(returns) + 1e-10)
+
+
+@jax.jit
+def var_utility(returns: jnp.ndarray, alpha: float) -> float:
+    return jnp.percentile(returns, q=100 * alpha)
+
+
@jax.jit
def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
    var = jnp.percentile(returns, q=100 * alpha)
    mask = returns <= var
-
-    return jnp.sum(returns * weights)
+    return jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))


# set of all currently valid built-in utility functions
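The new `sharpe_utility` and `var_utility` join the reworked `cvar_utility`. A quick numeric check of what each computes, with made-up returns:

import jax.numpy as jnp

returns = jnp.array([-2.0, 0.5, 1.0, 3.0])

sharpe = (jnp.mean(returns) - 0.0) / (jnp.std(returns) + 1e-10)    # risk_free=0, approx 0.35
var50 = jnp.percentile(returns, q=100 * 0.5)                       # alpha=0.5, gives 0.75
mask = returns <= var50
cvar50 = jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))   # mean of worst half: -0.75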
@@ -1633,8 +1767,10 @@ UTILITY_LOOKUP = {
    'mean_std': mean_deviation_utility,
    'mean_semivar': mean_semivariance_utility,
    'mean_semidev': mean_semideviation_utility,
+    'sharpe': sharpe_utility,
    'entropic': entropic_utility,
    'exponential': entropic_utility,
+    'var': var_utility,
    'cvar': cvar_utility
}

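These names are what the planner's `utility` option accepts. A hedged sketch of selecting one of the new utilities (constructor arguments other than the utility settings are elided):

planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxStraightLinePlan(),
    utility='var',                    # new in 2.6, alongside 'sharpe'
    utility_kwargs={'alpha': 0.05})   # matches var_utility(returns, alpha)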
@@ -1672,7 +1808,9 @@ class JaxBackpropPlanner:
                 compile_non_fluent_exact: bool=True,
                 logger: Optional[Logger]=None,
                 dashboard_viz: Optional[Any]=None,
-                 parallel_updates: Optional[int]=None) -> None:
+                 print_warnings: bool=True,
+                 parallel_updates: Optional[int]=None,
+                 preprocessor: Optional[Preprocessor]=None) -> None:
        '''Creates a new gradient-based algorithm for optimizing action sequences
        (plan) in the given RDDL. Some operations will be converted to their
        differentiable counterparts; the specific operations can be customized
@@ -1712,7 +1850,9 @@ class JaxBackpropPlanner:
        :param logger: to log information about compilation to file
        :param dashboard_viz: optional visualizer object from the environment
            to pass to the dashboard to visualize the policy
+        :param print_warnings: whether to print warnings
        :param parallel_updates: how many optimizers to run independently in parallel
+        :param preprocessor: optional preprocessor for state inputs to plan
        '''
        self.rddl = rddl
        self.plan = plan
@@ -1737,6 +1877,8 @@ class JaxBackpropPlanner:
        self.noise_kwargs = noise_kwargs
        self.pgpe = pgpe
        self.use_pgpe = pgpe is not None
+        self.print_warnings = print_warnings
+        self.preprocessor = preprocessor

        # set optimizer
        try:
@@ -1789,7 +1931,11 @@ class JaxBackpropPlanner:
        self._jax_compile_rddl()
        self._jax_compile_optimizer()

-    def summarize_system(self) -> str:
+    @staticmethod
+    def summarize_system() -> str:
+        '''Returns a string containing information about the system, Python version
+        and jax-related packages that are relevant to the current planner.
+        '''
        try:
            jaxlib_version = jax._src.lib.version_str
        except Exception as _:
@@ -1818,6 +1964,9 @@ r"""
                f'devices: {devices_short}\n')

    def summarize_relaxations(self) -> str:
+        '''Returns a summary table containing all non-differentiable operators
+        and their relaxations.
+        '''
        result = ''
        if self.compiled.model_params:
            result += ('Some RDDL operations are non-differentiable '
@@ -1834,6 +1983,9 @@ r"""
        return result

    def summarize_hyperparameters(self) -> str:
+        '''Returns a string summarizing the hyper-parameters of the current planner
+        instance.
+        '''
        result = (f'objective hyper-parameters:\n'
                  f'    utility_fn       ={self.utility.__name__}\n'
                  f'    utility args     ={self.utility_kwargs}\n'
@@ -1852,7 +2004,8 @@ r"""
                  f'    noise_kwargs     ={self.noise_kwargs}\n'
                  f'    batch_size_train ={self.batch_size_train}\n'
                  f'    batch_size_test  ={self.batch_size_test}\n'
-                  f'    parallel_updates ={self.parallel_updates}\n')
+                  f'    parallel_updates ={self.parallel_updates}\n'
+                  f'    preprocessor     ={self.preprocessor}\n')
        result += str(self.plan)
        if self.use_pgpe:
            result += str(self.pgpe)
@@ -1873,7 +2026,8 @@ r"""
            logger=self.logger,
            use64bit=self.use64bit,
            cpfs_without_grad=self.cpfs_without_grad,
-            compile_non_fluent_exact=self.compile_non_fluent_exact
+            compile_non_fluent_exact=self.compile_non_fluent_exact,
+            print_warnings=self.print_warnings
        )
        self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')

@@ -1887,10 +2041,15 @@ r"""

    def _jax_compile_optimizer(self):

+        # preprocessor
+        if self.preprocessor is not None:
+            self.preprocessor.compile(self.compiled)
+
        # policy
        self.plan.compile(self.compiled,
                          _bounds=self._action_bounds,
-                          horizon=self.horizon
+                          horizon=self.horizon,
+                          preprocessor=self.preprocessor)
        self.train_policy = jax.jit(self.plan.train_policy)
        self.test_policy = jax.jit(self.plan.test_policy)

@@ -1898,14 +2057,16 @@ r"""
        train_rollouts = self.compiled.compile_rollouts(
            policy=self.plan.train_policy,
            n_steps=self.horizon,
-            n_batch=self.batch_size_train
+            n_batch=self.batch_size_train,
+            cache_path_info=self.preprocessor is not None
        )
        self.train_rollouts = train_rollouts

        test_rollouts = self.test_compiled.compile_rollouts(
            policy=self.plan.test_policy,
            n_steps=self.horizon,
-            n_batch=self.batch_size_test
+            n_batch=self.batch_size_test,
+            cache_path_info=False
        )
        self.test_rollouts = jax.jit(test_rollouts)

@@ -1922,7 +2083,8 @@ r"""

        # optimization
        self.update = self._jax_update(train_loss)
-        self.pytree_at = jax.jit(
+        self.pytree_at = jax.jit(
+            lambda tree, i: jax.tree_util.tree_map(lambda x: x[i], tree))

        # pgpe option
        if self.use_pgpe:
@@ -1930,6 +2092,7 @@ r"""
                loss_fn=test_loss,
                projection=self.plan.projection,
                real_dtype=self.test_compiled.REAL,
+                print_warnings=self.print_warnings,
                parallel_updates=self.parallel_updates
            )
            self.merge_pgpe = self._jax_merge_pgpe_jaxplan()
@@ -2010,7 +2173,7 @@ r"""
        # check if the gradients are all zeros
        def _jax_wrapped_zero_gradients(grad):
            leaves, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(partial(jnp.allclose, b=0), grad))
+                jax.tree_util.tree_map(partial(jnp.allclose, b=0), grad))
            return jnp.all(jnp.asarray(leaves))

        # calculate the plan gradient w.r.t. return loss and update optimizer
@@ -2069,7 +2232,7 @@ r"""
            def select_fn(leaf1, leaf2):
                expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf1) - 1)]
                return jnp.where(expanded_mask, leaf1, leaf2)
-            policy_params = jax.tree_map(select_fn, pgpe_param, policy_params)
+            policy_params = jax.tree_util.tree_map(select_fn, pgpe_param, policy_params)
            test_loss = jnp.where(pgpe_mask, pgpe_loss, test_loss)
            test_loss_smooth = jnp.where(pgpe_mask, pgpe_loss_smooth, test_loss_smooth)
            expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(converged) - 1)]
@@ -2091,7 +2254,9 @@ r"""
                    f'Variable <{name}> in subs argument is not a '
                    f'valid p-variable, must be one of '
                    f'{set(self.test_compiled.init_values.keys())}.')
-            value = np.reshape(value,
+            value = np.reshape(value, np.shape(init_value))[np.newaxis, ...]
+            if value.dtype.type is np.str_:
+                value = rddl.object_string_to_index_array(rddl.variable_ranges[name], value)
            train_value = np.repeat(value, repeats=n_train, axis=0)
            train_value = np.asarray(train_value, dtype=self.compiled.REAL)
            init_train[name] = train_value
@@ -2121,7 +2286,7 @@ r"""
                    x[np.newaxis, ...], shape=(self.parallel_updates,) + np.shape(x))
            return x

-        return jax.tree_map(make_batched, pytree)
+        return jax.tree_util.tree_map(make_batched, pytree)

    def as_optimization_problem(
            self, key: Optional[random.PRNGKey]=None,
@@ -2165,10 +2330,11 @@ r"""
        train_subs, _ = self._batched_init_subs(subs)
        model_params = self.compiled.model_params
        if policy_hyperparams is None:
-            message = termcolor.colored(
-                '[WARN] policy_hyperparams is not set, setting 1.0 for '
-                'all action-fluents which could be suboptimal.', 'yellow')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] policy_hyperparams is not set, setting 1.0 for '
+                    'all action-fluents which could be suboptimal.', 'yellow')
+                print(message)
            policy_hyperparams = {action: 1.0
                                  for action in self.rddl.action_fluents}

@@ -2318,10 +2484,11 @@ r"""

        # cannot run dashboard with parallel updates
        if dashboard is not None and self.parallel_updates is not None:
-            message = termcolor.colored(
-                '[WARN] Dashboard is unavailable if parallel_updates is not None: '
-                'setting dashboard to None.', 'yellow')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] Dashboard is unavailable if parallel_updates is not None: '
+                    'setting dashboard to None.', 'yellow')
+                print(message)
            dashboard = None

        # if PRNG key is not provided
@@ -2331,19 +2498,21 @@ r"""

        # if policy_hyperparams is not provided
        if policy_hyperparams is None:
-            message = termcolor.colored(
-                '[WARN] policy_hyperparams is not set, setting 1.0 for '
-                'all action-fluents which could be suboptimal.', 'yellow')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] policy_hyperparams is not set, setting 1.0 for '
+                    'all action-fluents which could be suboptimal.', 'yellow')
+                print(message)
            policy_hyperparams = {action: 1.0
                                  for action in self.rddl.action_fluents}

        # if policy_hyperparams is a scalar
        elif isinstance(policy_hyperparams, (int, float, np.number)):
-            message = termcolor.colored(
-                f'[INFO] policy_hyperparams is {policy_hyperparams}, '
-                f'setting this value for all action-fluents.', 'green')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    f'[INFO] policy_hyperparams is {policy_hyperparams}, '
+                    f'setting this value for all action-fluents.', 'green')
+                print(message)
            hyperparam_value = float(policy_hyperparams)
            policy_hyperparams = {action: hyperparam_value
                                  for action in self.rddl.action_fluents}
@@ -2352,13 +2521,20 @@ r"""
        elif isinstance(policy_hyperparams, dict):
            for action in self.rddl.action_fluents:
                if action not in policy_hyperparams:
-                    message = termcolor.colored(
-                        f'[WARN] policy_hyperparams[{action}] is not set, '
-                        f'setting 1.0 for missing action-fluents '
-                        f'which could be suboptimal.', 'yellow')
-                    print(message)
+                    if self.print_warnings:
+                        message = termcolor.colored(
+                            f'[WARN] policy_hyperparams[{action}] is not set, '
+                            f'setting 1.0 for missing action-fluents '
+                            f'which could be suboptimal.', 'yellow')
+                        print(message)
                    policy_hyperparams[action] = 1.0
-
+
+        # initialize preprocessor
+        preproc_key = None
+        if self.preprocessor is not None:
+            preproc_key = self.preprocessor.HYPERPARAMS_KEY
+            policy_hyperparams[preproc_key] = self.preprocessor.initialize()
+
        # print summary of parameters:
        if print_summary:
            print(self.summarize_system())
@@ -2396,7 +2572,7 @@ r"""
            if var not in subs:
                subs[var] = value
                added_pvars_to_subs.append(var)
-        if added_pvars_to_subs:
+        if self.print_warnings and added_pvars_to_subs:
            message = termcolor.colored(
                f'[INFO] p-variables {added_pvars_to_subs} is not in '
                f'provided subs, using their initial values.', 'green')
@@ -2485,6 +2661,11 @@ r"""
                subkey, policy_params, policy_hyperparams, train_subs, model_params,
                opt_state, opt_aux)

+            # update the preprocessor
+            if self.preprocessor is not None:
+                policy_hyperparams[preproc_key] = self.preprocessor.update(
+                    train_log['fluents'], policy_hyperparams[preproc_key])
+
            # evaluate
            test_loss, (test_log, model_params_test) = self.test_loss(
                subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
@@ -2637,6 +2818,7 @@ r"""
                'model_params': model_params,
                'progress': progress_percent,
                'train_log': train_log,
+                'policy_hyperparams': policy_hyperparams,
                **test_log
            }

@@ -2648,7 +2830,7 @@ r"""
                policy_params, opt_state, opt_aux = self.initialize(
                    subkey, policy_hyperparams, train_subs)
                no_progress_count = 0
-                if progress_bar is not None:
+                if self.print_warnings and progress_bar is not None:
                    message = termcolor.colored(
                        f'[INFO] Optimizer restarted at iteration {it} '
                        f'due to lack of progress.', 'green')
@@ -2658,7 +2840,7 @@ r"""

            # stopping condition reached
            if stopping_rule is not None and stopping_rule.monitor(callback):
-                if progress_bar is not None:
+                if self.print_warnings and progress_bar is not None:
                    message = termcolor.colored(
                        '[SUCC] Stopping rule has been reached.', 'green')
                    progress_bar.write(message)
@@ -2699,7 +2881,8 @@ r"""

        # summarize and test for convergence
        if print_summary:
-            grad_norm = jax.tree_map(
+            grad_norm = jax.tree_util.tree_map(
+                lambda x: np.linalg.norm(x).item(), best_grad)
            diagnosis = self._perform_diagnosis(
                last_iter_improve, -np.min(train_loss), -np.min(test_loss_smooth),
                -best_loss, grad_norm)
@@ -2713,7 +2896,8 @@ r"""

    def _perform_diagnosis(self, last_iter_improve,
                           train_return, test_return, best_return, grad_norm):
-
+        grad_norms = jax.tree_util.tree_leaves(grad_norm)
+        max_grad_norm = max(grad_norms) if grad_norms else np.nan
        grad_is_zero = np.allclose(max_grad_norm, 0)

        # divergence if the solution is not finite
@@ -2777,6 +2961,7 @@ r"""
        :param policy_hyperparams: hyper-parameters for the policy/plan, such as
            weights for sigmoid wrapping boolean actions (optional)
        '''
+        subs = subs.copy()

        # check compatibility of the subs dictionary
        for (var, values) in subs.items():
@@ -2795,13 +2980,17 @@ r"""
            if step == 0 and var in self.rddl.observ_fluents:
                subs[var] = self.test_compiled.init_values[var]
            else:
-                raise ValueError(
-                    f'Values {values} assigned to p-variable <{var}> are '
-                    f'non-numeric of type {dtype}.')
+                if dtype.type is np.str_:
+                    prange = self.rddl.variable_ranges[var]
+                    subs[var] = self.rddl.object_string_to_index_array(prange, subs[var])
+                else:
+                    raise ValueError(
+                        f'Values {values} assigned to p-variable <{var}> are '
+                        f'non-numeric of type {dtype}.')

        # cast device arrays to numpy
        actions = self.test_policy(key, params, policy_hyperparams, step, subs)
-        actions = jax.tree_map(np.asarray, actions)
+        actions = jax.tree_util.tree_map(np.asarray, actions)
        return actions

@@ -2822,8 +3011,9 @@ class JaxOfflineController(BaseAgent):
    def __init__(self, planner: JaxBackpropPlanner,
                 key: Optional[random.PRNGKey]=None,
                 eval_hyperparams: Optional[Dict[str, Any]]=None,
-                 params: Optional[Pytree]=None,
+                 params: Optional[Union[str, Pytree]]=None,
                 train_on_reset: bool=False,
+                 save_path: Optional[str]=None,
                 **train_kwargs) -> None:
        '''Creates a new JAX offline control policy that is trained once, then
        deployed later.
@@ -2834,8 +3024,10 @@ class JaxOfflineController(BaseAgent):
        :param eval_hyperparams: policy hyperparameters to apply for evaluation
            or whenever sample_action is called
        :param params: use the specified policy parameters instead of calling
-            planner.optimize()
+            planner.optimize(); can be a string pointing to a valid file path where params
+            have been saved, or a pytree of parameters
        :param train_on_reset: retrain policy parameters on every episode reset
+        :param save_path: optional path to save parameters to
        :param **train_kwargs: any keyword arguments to be passed to the planner
            for optimization
        '''
@@ -2847,13 +3039,28 @@ class JaxOfflineController(BaseAgent):
        self.train_on_reset = train_on_reset
        self.train_kwargs = train_kwargs
        self.params_given = params is not None
+        self.hyperparams_given = eval_hyperparams is not None

+        # load the policy from file
+        if not self.train_on_reset and params is not None and isinstance(params, str):
+            with open(params, 'rb') as file:
+                params = pickle.load(file)
+
+        # train the policy
        self.step = 0
        self.callback = None
        if not self.train_on_reset and not self.params_given:
            callback = self.planner.optimize(key=self.key, **self.train_kwargs)
            self.callback = callback
            params = callback['best_params']
+            if not self.hyperparams_given:
+                self.eval_hyperparams = callback['policy_hyperparams']
+
+        # save the policy
+        if save_path is not None:
+            with open(save_path, 'wb') as file:
+                pickle.dump(params, file)
+
        self.params = params

    def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
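The new string-valued `params` and `save_path` arguments give a simple pickle round trip for trained policies. A sketch (the file name and planner are placeholders):

from pyRDDLGym_jax.core.planner import JaxOfflineController

# first run: trains via planner.optimize() and pickles best_params to disk
agent = JaxOfflineController(planner, save_path='policy.pkl')

# later run: skips training by unpickling the saved parameters
agent = JaxOfflineController(planner, params='policy.pkl')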
@@ -2865,10 +3072,14 @@ class JaxOfflineController(BaseAgent):

    def reset(self) -> None:
        self.step = 0
+
+        # train the policy if required to reset at the start of every episode
        if self.train_on_reset and not self.params_given:
            callback = self.planner.optimize(key=self.key, **self.train_kwargs)
            self.callback = callback
            self.params = callback['best_params']
+            if not self.hyperparams_given:
+                self.eval_hyperparams = callback['policy_hyperparams']


class JaxOnlineController(BaseAgent):
@@ -2901,6 +3112,7 @@ class JaxOnlineController(BaseAgent):
            key = random.PRNGKey(round(time.time() * 1000))
        self.key = key
        self.eval_hyperparams = eval_hyperparams
+        self.hyperparams_given = eval_hyperparams is not None
        self.warm_start = warm_start
        self.train_kwargs = train_kwargs
        self.max_attempts = max_attempts
@@ -2915,18 +3127,24 @@ class JaxOnlineController(BaseAgent):
        attempts = 0
        while attempts < self.max_attempts and callback['iteration'] <= 1:
            attempts += 1
-            message = termcolor.colored(
-                f'[WARN] JIT compilation dominated the execution time: '
-                f'executing the optimizer again on the traced model '
-                f'[attempt {attempts}].', 'yellow')
-            print(message)
+            if self.planner.print_warnings:
+                message = termcolor.colored(
+                    f'[WARN] JIT compilation dominated the execution time: '
+                    f'executing the optimizer again on the traced model '
+                    f'[attempt {attempts}].', 'yellow')
+                print(message)
            callback = planner.optimize(
-                key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
-
+                key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
        self.callback = callback
        params = callback['best_params']
+        if not self.hyperparams_given:
+            self.eval_hyperparams = callback['policy_hyperparams']
+
+        # get the action from the parameters for the current state
        self.key, subkey = random.split(self.key)
        actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
+
+        # apply warm start for the next epoch
        if self.warm_start:
            self.guess = planner.plan.guess_next_epoch(params)
        return actions
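Taken together, `sample_action` now retries past JIT-dominated first runs, inherits `policy_hyperparams` from the optimizer callback, and warm-starts the next decision epoch. A sketch of the replanning loop it is built for (domain, horizon, and budget are placeholders; `evaluate` is assumed to be the pyRDDLGym `BaseAgent` helper):

import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxStraightLinePlan, JaxOnlineController)

env = pyRDDLGym.make('Wildfire', '0', vectorized=True)
planner = JaxBackpropPlanner(
    rddl=env.model, plan=JaxStraightLinePlan(), rollout_horizon=5)
agent = JaxOnlineController(planner, warm_start=True, train_seconds=1)
agent.evaluate(env, episodes=1)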