pyRDDLGym-jax 0.2-py3-none-any.whl → 0.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -0
- pyRDDLGym_jax/core/compiler.py +1 -2
- pyRDDLGym_jax/core/planner.py +359 -155
- pyRDDLGym_jax/core/tuning.py +6 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_drp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_slp.cfg +1 -1
- pyRDDLGym_jax/examples/run_gym.py +2 -5
- pyRDDLGym_jax/examples/run_plan.py +6 -8
- pyRDDLGym_jax/examples/run_scipy.py +61 -0
- pyRDDLGym_jax/examples/run_tune.py +5 -6
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.3.dist-info}/METADATA +1 -1
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.3.dist-info}/RECORD +18 -20
- pyRDDLGym_jax/examples/configs/Pong_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +0 -20
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.3.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.3.dist-info}/WHEEL +0 -0
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.3.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -1,5 +1,3 @@
-__version__ = '0.2'
-
 from ast import literal_eval
 from collections import deque
 import configparser
@@ -15,6 +13,7 @@ import os
 import sys
 import termcolor
 import time
+import traceback
 from tqdm import tqdm
 from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
 
@@ -25,14 +24,17 @@ Pytree = Any
 
 from pyRDDLGym.core.debug.exception import raise_warning
 
+from pyRDDLGym_jax import __version__
+
 # try to import matplotlib, if failed then skip plotting
 try:
     import matplotlib
     import matplotlib.pyplot as plt
     matplotlib.use('TkAgg')
 except Exception:
-    raise_warning('
-                  'plotting functionality
+    raise_warning('failed to import matplotlib: '
+                  'plotting functionality will be disabled.', 'red')
+    traceback.print_exc()
     plt = None
 
 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
@@ -113,7 +115,8 @@ def _load_config(config, args):
     # policy initialization
     plan_initializer = plan_kwargs.get('initializer', None)
     if plan_initializer is not None:
-        initializer = _getattr_any(
+        initializer = _getattr_any(
+            packages=[initializers, hk.initializers], item=plan_initializer)
         if initializer is None:
             raise_warning(
                 f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
@@ -130,7 +133,8 @@ def _load_config(config, args):
     # policy activation
     plan_activation = plan_kwargs.get('activation', None)
    if plan_activation is not None:
-        activation = _getattr_any(
+        activation = _getattr_any(
+            packages=[jax.nn, jax.numpy], item=plan_activation)
         if activation is None:
             raise_warning(
                 f'Ignoring invalid activation <{plan_activation}>.', 'red')
@@ -217,6 +221,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         :param *kwargs: keyword arguments to pass to base compiler
         '''
         super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
+
         self.logic = logic
         self.logic.set_use64bit(self.use64bit)
         if cpfs_without_grad is None:
@@ -224,9 +229,14 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         self.cpfs_without_grad = cpfs_without_grad
 
         # actions and CPFs must be continuous
-
+        pvars_cast = set()
         for (var, values) in self.init_values.items():
             self.init_values[var] = np.asarray(values, dtype=self.REAL)
+            if not np.issubdtype(np.atleast_1d(values).dtype, np.floating):
+                pvars_cast.add(var)
+        if pvars_cast:
+            raise_warning(f'JAX gradient compiler requires that initial values '
+                          f'of p-variables {pvars_cast} be cast to float.')
 
         # overwrite basic operations with fuzzy ones
         self.RELATIONAL_OPS = {
@@ -273,20 +283,29 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         return _jax_wrapped_stop_grad
 
     def _compile_cpfs(self, info):
-
+        cpfs_cast = set()
         jax_cpfs = {}
         for (_, cpfs) in self.levels.items():
             for cpf in cpfs:
                 _, expr = self.rddl.cpfs[cpf]
                 jax_cpfs[cpf] = self._jax(expr, info, dtype=self.REAL)
+                if self.rddl.variable_ranges[cpf] != 'real':
+                    cpfs_cast.add(cpf)
                 if cpf in self.cpfs_without_grad:
-                    raise_warning(f'CPF <{cpf}> stops gradient.')
                     jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
+
+        if cpfs_cast:
+            raise_warning(f'JAX gradient compiler requires that outputs of CPFs '
+                          f'{cpfs_cast} be cast to float.')
+        if self.cpfs_without_grad:
+            raise_warning(f'User requested that gradients not flow '
+                          f'through CPFs {self.cpfs_without_grad}.')
         return jax_cpfs
 
     def _jax_kron(self, expr, info):
         if self.logic.verbose:
-            raise_warning('
+            raise_warning('JAX gradient compiler ignores KronDelta '
+                          'during compilation.')
         arg, = expr.args
         arg = self._jax(arg, info)
         return arg
@@ -308,7 +327,8 @@ class JaxPlan:
         self._train_policy = None
         self._test_policy = None
         self._projection = None
-
+        self.bounds = None
+
     def summarize_hyperparameters(self) -> None:
         pass
 
@@ -363,7 +383,7 @@ class JaxPlan:
             # check invalid type
             if prange not in compiled.JAX_TYPES:
                 raise RDDLTypeError(
-                    f'Invalid range <{prange}
+                    f'Invalid range <{prange}> of action-fluent <{name}>, '
                     f'must be one of {set(compiled.JAX_TYPES.keys())}.')
 
             # clip boolean to (0, 1), otherwise use the RDDL action bounds
@@ -385,7 +405,7 @@ class JaxPlan:
                           ~lower_finite & upper_finite,
                           ~lower_finite & ~upper_finite]
             bounds[name] = (lower, upper)
-            raise_warning(f'Bounds of action
+            raise_warning(f'Bounds of action-fluent <{name}> set to {bounds[name]}.')
         return shapes, bounds, bounds_safe, cond_lists
 
     def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -427,6 +447,7 @@ class JaxStraightLinePlan(JaxPlan):
         use_new_projection = True
         '''
         super(JaxStraightLinePlan, self).__init__()
+
         self._initializer_base = initializer
         self._initializer = initializer
         self._wrap_sigmoid = wrap_sigmoid
@@ -437,9 +458,12 @@ class JaxStraightLinePlan(JaxPlan):
         self._max_constraint_iter = max_constraint_iter
 
     def summarize_hyperparameters(self) -> None:
+        bounds = '\n '.join(
+            map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         print(f'policy hyper-parameters:\n'
-              f' initializer ={
+              f' initializer ={self._initializer_base}\n'
               f'constraint-sat strategy (simple):\n'
+              f' parsed_action_bounds =\n {bounds}\n'
               f' wrap_sigmoid ={self._wrap_sigmoid}\n'
               f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
               f' wrap_non_bool ={self._wrap_non_bool}\n'
@@ -603,7 +627,7 @@ class JaxStraightLinePlan(JaxPlan):
             if 1 < allowed_actions < bool_action_count:
                 raise RDDLNotImplementedError(
                     f'Straight-line plans with wrap_softmax currently '
-                    f'do not support max-nondef-actions
+                    f'do not support max-nondef-actions {allowed_actions} > 1.')
 
             # potentially apply projection but to non-bool actions only
             self.projection = _jax_wrapped_slp_project_to_box
@@ -734,14 +758,14 @@ class JaxStraightLinePlan(JaxPlan):
             for (var, shape) in shapes.items():
                 if ranges[var] != 'bool' or not stack_bool_params:
                     key, subkey = random.split(key)
-                    param = init(subkey, shape, dtype=compiled.REAL)
+                    param = init(key=subkey, shape=shape, dtype=compiled.REAL)
                     if ranges[var] == 'bool':
                         param += bool_threshold
                     params[var] = param
             if stack_bool_params:
                 key, subkey = random.split(key)
                 bool_shape = (horizon, bool_action_count)
-                bool_param = init(subkey, bool_shape, dtype=compiled.REAL)
+                bool_param = init(key=subkey, shape=bool_shape, dtype=compiled.REAL)
                 params[bool_key] = bool_param
             params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
             return params
@@ -765,7 +789,8 @@ class JaxDeepReactivePolicy(JaxPlan):
     def __init__(self, topology: Optional[Sequence[int]]=None,
                  activation: Activation=jnp.tanh,
                  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
-                 normalize: bool=
+                 normalize: bool=False,
+                 normalize_per_layer: bool=False,
                  normalizer_kwargs: Optional[Kwargs]=None,
                  wrap_non_bool: bool=False) -> None:
         '''Creates a new deep reactive policy in JAX.
@@ -775,12 +800,15 @@ class JaxDeepReactivePolicy(JaxPlan):
         :param activation: function to apply after each layer of the policy
         :param initializer: weight initialization
         :param normalize: whether to apply layer norm to the inputs
+        :param normalize_per_layer: whether to apply layer norm to each input
+        individually (only active if normalize is True)
         :param normalizer_kwargs: if normalize is True, apply additional arguments
         to layer norm
         :param wrap_non_bool: whether to wrap real or int action fluent parameters
         with non-linearity (e.g. sigmoid or ELU) to satisfy box constraints
         '''
         super(JaxDeepReactivePolicy, self).__init__()
+
         if topology is None:
             topology = [128, 64]
         self._topology = topology
@@ -788,22 +816,25 @@ class JaxDeepReactivePolicy(JaxPlan):
         self._initializer_base = initializer
         self._initializer = initializer
         self._normalize = normalize
+        self._normalize_per_layer = normalize_per_layer
         if normalizer_kwargs is None:
-            normalizer_kwargs = {
-                'create_offset': True, 'create_scale': True,
-                'name': 'input_norm'
-            }
+            normalizer_kwargs = {'create_offset': True, 'create_scale': True}
         self._normalizer_kwargs = normalizer_kwargs
         self._wrap_non_bool = wrap_non_bool
 
     def summarize_hyperparameters(self) -> None:
+        bounds = '\n '.join(
+            map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         print(f'policy hyper-parameters:\n'
-              f' topology
-              f' activation_fn
-              f' initializer
-              f'
-              f'
-              f'
+              f' topology ={self._topology}\n'
+              f' activation_fn ={self._activations[0].__name__}\n'
+              f' initializer ={type(self._initializer_base).__name__}\n'
+              f' apply_input_norm ={self._normalize}\n'
+              f' input_norm_layerwise={self._normalize_per_layer}\n'
+              f' input_norm_args ={self._normalizer_kwargs}\n'
+              f'constraint-sat strategy:\n'
+              f' parsed_action_bounds=\n {bounds}\n'
+              f' wrap_non_bool ={self._wrap_non_bool}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
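The new normalize and normalize_per_layer switches above control whether layer normalization is applied to the concatenated non-boolean inputs as a whole or to each observed input tensor separately. A minimal sketch of constructing a policy with these options follows; the domain name, topology, and planner wiring are illustrative assumptions rather than part of this diff:

    import jax.numpy as jnp
    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxDeepReactivePolicy

    # assumed example domain/instance; any vectorized RDDLEnv works here
    env = pyRDDLGym.make('HVAC', '1', vectorized=True)

    policy = JaxDeepReactivePolicy(
        topology=[128, 64],
        activation=jnp.tanh,
        normalize=True,               # layer norm over the concatenated non-bool inputs
        normalize_per_layer=False)    # True would normalize each input tensor separately
    planner = JaxBackpropPlanner(rddl=env.model, plan=policy)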
@@ -821,7 +852,7 @@ class JaxDeepReactivePolicy(JaxPlan):
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
                 f'Deep reactive policies currently do not support '
-                f'max-nondef-actions
+                f'max-nondef-actions {allowed_actions} > 1.')
         use_constraint_satisfaction = allowed_actions < bool_action_count
 
         noop = {var: (values[0] if isinstance(values, list) else values)
@@ -835,6 +866,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
         ranges = rddl.variable_ranges
         normalize = self._normalize
+        normalize_per_layer = self._normalize_per_layer
         wrap_non_bool = self._wrap_non_bool
         init = self._initializer
         layers = list(enumerate(zip(self._topology, self._activations)))
@@ -842,14 +874,67 @@ class JaxDeepReactivePolicy(JaxPlan):
                   for (var, shape) in shapes.items()}
         layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}
 
-        #
-
+        # inputs for the policy network
+        if rddl.observ_fluents:
+            observed_vars = rddl.observ_fluents
+        else:
+            observed_vars = rddl.state_fluents
+        input_names = {var: f'{var}'.replace('-', '_') for var in observed_vars}
+
+        # catch if input norm is applied to size 1 tensor
+        if normalize:
+            non_bool_dims = 0
+            for (var, values) in observed_vars.items():
+                if ranges[var] != 'bool':
+                    value_size = np.atleast_1d(values).size
+                    if normalize_per_layer and value_size == 1:
+                        raise_warning(
+                            f'Cannot apply layer norm to state-fluent <{var}> '
+                            f'of size 1: setting normalize_per_layer = False.',
+                            'red')
+                        normalize_per_layer = False
+                    non_bool_dims += value_size
+            if not normalize_per_layer and non_bool_dims == 1:
+                raise_warning(
+                    'Cannot apply layer norm to state-fluents of total size 1: '
+                    'setting normalize = False.', 'red')
+                normalize = False
+
+        # convert subs dictionary into a state vector to feed to the MLP
+        def _jax_wrapped_policy_input(subs):
+
+            # concatenate all state variables into a single vector
+            # optionally apply layer norm to each input tensor
+            states_bool, states_non_bool = [], []
+            non_bool_dims = 0
+            for (var, value) in subs.items():
+                if var in observed_vars:
+                    state = jnp.ravel(value)
+                    if ranges[var] == 'bool':
+                        states_bool.append(state)
+                    else:
+                        if normalize and normalize_per_layer:
+                            normalizer = hk.LayerNorm(
+                                axis=-1, param_axis=-1,
+                                name=f'input_norm_{input_names[var]}',
+                                **self._normalizer_kwargs)
+                            state = normalizer(state)
+                        states_non_bool.append(state)
+                        non_bool_dims += state.size
+            state = jnp.concatenate(states_non_bool + states_bool)
 
-            #
-            if normalize:
+            # optionally perform layer normalization on the non-bool inputs
+            if normalize and not normalize_per_layer and non_bool_dims:
                 normalizer = hk.LayerNorm(
-                    axis=-1, param_axis=-1,
-
+                    axis=-1, param_axis=-1, name='input_norm',
+                    **self._normalizer_kwargs)
+                normalized = normalizer(state[:non_bool_dims])
+                state = state.at[:non_bool_dims].set(normalized)
+            return state
+
+        # predict actions from the policy network for current state
+        def _jax_wrapped_policy_network_predict(subs):
+            state = _jax_wrapped_policy_input(subs)
 
             # feed state vector through hidden layers
             hidden = state
@@ -913,25 +998,9 @@ class JaxDeepReactivePolicy(JaxPlan):
                 start += size
             return actions
 
-        if rddl.observ_fluents:
-            observed_vars = rddl.observ_fluents
-        else:
-            observed_vars = rddl.state_fluents
-
-        # state is concatenated into single tensor
-        def _jax_wrapped_subs_to_state(subs):
-            subs = {var: value
-                    for (var, value) in subs.items()
-                    if var in observed_vars}
-            flat_subs = jax.tree_map(jnp.ravel, subs)
-            states = list(flat_subs.values())
-            state = jnp.concatenate(states)
-            return state
-
         # train action prediction
         def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
-
-            actions = predict_fn.apply(params, state)
+            actions = predict_fn.apply(params, subs)
             if not wrap_non_bool:
                 for (var, action) in actions.items():
                     if var != bool_key and ranges[var] != 'bool':
@@ -982,8 +1051,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             subs = {var: value[0, ...]
                     for (var, value) in subs.items()
                     if var in observed_vars}
-
-            params = predict_fn.init(key, state)
+            params = predict_fn.init(key, subs)
             return params
 
         self.initializer = _jax_wrapped_drp_init
@@ -1021,46 +1089,72 @@ class RollingMean:
 class JaxPlannerPlot:
     '''Supports plotting and visualization of a JAX policy in real time.'''
 
-    def __init__(self, rddl: RDDLPlanningModel, horizon: int
-
+    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
+                 show_violin: bool=True, show_action: bool=True) -> None:
+        '''Creates a new planner visualizer.
+
+        :param rddl: the planning model to optimize
+        :param horizon: the lookahead or planning horizon
+        :param show_violin: whether to show the distribution of batch losses
+        :param show_action: whether to show heatmaps of the action fluents
+        '''
+        num_plots = 1
+        if show_violin:
+            num_plots += 1
+        if show_action:
+            num_plots += len(rddl.action_fluents)
+        self._fig, axes = plt.subplots(num_plots)
+        if num_plots == 1:
+            axes = [axes]
 
         # prepare the loss plot
         self._loss_ax = axes[0]
         self._loss_ax.autoscale(enable=True)
-        self._loss_ax.set_xlabel('
+        self._loss_ax.set_xlabel('training time')
         self._loss_ax.set_ylabel('loss value')
         self._loss_plot = self._loss_ax.plot(
             [], [], linestyle=':', marker='o', markersize=2)[0]
         self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
 
+        # prepare the violin plot
+        if show_violin:
+            self._hist_ax = axes[1]
+        else:
+            self._hist_ax = None
+
         # prepare the action plots
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if show_action:
+            self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
+                               for (idx, name) in enumerate(rddl.action_fluents)}
+            self._action_plots = {}
+            for name in rddl.action_fluents:
+                ax = self._action_ax[name]
+                if rddl.variable_ranges[name] == 'bool':
+                    vmin, vmax = 0.0, 1.0
+                else:
+                    vmin, vmax = None, None
+                action_dim = 1
+                for dim in rddl.object_counts(rddl.variable_params[name]):
+                    action_dim *= dim
+                action_plot = ax.pcolormesh(
+                    np.zeros((action_dim, horizon)),
+                    cmap='seismic', vmin=vmin, vmax=vmax)
+                ax.set_aspect('auto')
+                ax.set_xlabel('decision epoch')
+                ax.set_ylabel(name)
+                plt.colorbar(action_plot, ax=ax)
+                self._action_plots[name] = action_plot
+            self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
+                                 for (name, ax) in self._action_ax.items()}
+        else:
+            self._action_ax = None
+            self._action_plots = None
+            self._action_back = None
+
         plt.tight_layout()
         plt.show(block=False)
 
-    def redraw(self, xticks, losses, actions) -> None:
+    def redraw(self, xticks, losses, actions, returns) -> None:
 
         # draw the loss curve
         self._fig.canvas.restore_region(self._loss_back)
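The visualizer is normally built for optimize() from the new plot_step/plot_kwargs arguments (see further below), so direct construction is rarely needed; the following one-liner is only an illustrative sketch and assumes a pyRDDLGym environment named env:

    # loss curve plus per-action heatmaps, but no violin plot of batch losses
    plot = JaxPlannerPlot(env.model, horizon=env.horizon,
                          show_violin=False, show_action=True)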
@@ -1071,21 +1165,30 @@ class JaxPlannerPlot:
         self._loss_ax.draw_artist(self._loss_plot)
         self._fig.canvas.blit(self._loss_ax.bbox)
 
+        # draw the violin plot
+        if self._hist_ax is not None:
+            self._hist_ax.clear()
+            self._hist_ax.set_xlabel('loss value')
+            self._hist_ax.set_ylabel('density')
+            self._hist_ax.violinplot(returns, vert=False, showmeans=True)
+
         # draw the actions
-
-        values
-
-
-
-
-
-
+        if self._action_ax is not None:
+            for (name, values) in actions.items():
+                values = np.mean(values, axis=0, dtype=float)
+                values = np.reshape(values, newshape=(values.shape[0], -1)).T
+                self._fig.canvas.restore_region(self._action_back[name])
+                self._action_plots[name].set_array(values)
+                self._action_ax[name].draw_artist(self._action_plots[name])
+                self._fig.canvas.blit(self._action_ax[name].bbox)
+                self._action_plots[name].set_clim([np.min(values), np.max(values)])
+
         self._fig.canvas.draw()
         self._fig.canvas.flush_events()
 
     def close(self) -> None:
         plt.close(self._fig)
-        del self._loss_ax, self._action_ax, \
+        del self._loss_ax, self._hist_ax, self._action_ax, \
             self._loss_plot, self._action_plots, self._fig, \
             self._loss_back, self._action_back
 
@@ -1099,9 +1202,9 @@ class JaxPlannerStatus(Enum):
     NORMAL = 0
     NO_PROGRESS = 1
     PRECONDITION_POSSIBLY_UNSATISFIED = 2
-
-
-
+    INVALID_GRADIENT = 3
+    TIME_BUDGET_REACHED = 4
+    ITER_BUDGET_REACHED = 5
 
     def is_failure(self) -> bool:
         return self.value >= 3
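With INVALID_GRADIENT, TIME_BUDGET_REACHED, and ITER_BUDGET_REACHED now explicit members, a caller can branch on the status reported by each training callback. A hedged sketch; the generator method name optimize_generator and the 'status' callback key are assumptions based on the surrounding code and are not shown in this diff:

    # stop early as soon as the planner reports a failure status (value >= 3)
    for callback in planner.optimize_generator(epochs=1000, train_seconds=60.0):
        status = callback['status']
        if status.is_failure():
            print('stopping early:', status)
            break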
@@ -1249,26 +1352,27 @@ class JaxBackpropPlanner:
               f'JAX Planner version {__version__}\n'
               f'Python {sys.version}\n'
               f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
+              f'optax {optax.__version__}, haiku {hk.__version__}, '
               f'numpy {np.__version__}\n'
               f'devices: {devices_short}\n')
 
     def summarize_hyperparameters(self) -> None:
         print(f'objective hyper-parameters:\n'
-              f' utility_fn
-              f' utility args
-              f' use_symlog
-              f' lookahead
-              f'
-              f' fuzzy logic type={type(self.logic).__name__}\n'
-              f' nonfluents exact={self.compile_non_fluent_exact}\n'
-              f' cpfs_no_gradient={self.cpfs_without_grad}\n'
+              f' utility_fn ={self.utility.__name__}\n'
+              f' utility args ={self.utility_kwargs}\n'
+              f' use_symlog ={self.use_symlog_reward}\n'
+              f' lookahead ={self.horizon}\n'
+              f' user_action_bounds={self._action_bounds}\n'
+              f' fuzzy logic type ={type(self.logic).__name__}\n'
+              f' nonfluents exact ={self.compile_non_fluent_exact}\n'
+              f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
               f'optimizer hyper-parameters:\n'
-              f' use_64_bit
-              f' optimizer
-              f' optimizer args
-              f' clip_gradient
-              f' batch_size_train={self.batch_size_train}\n'
-              f' batch_size_test
+              f' use_64_bit ={self.use64bit}\n'
+              f' optimizer ={self._optimizer_name.__name__}\n'
+              f' optimizer args ={self._optimizer_kwargs}\n'
+              f' clip_gradient ={self.clip_grad}\n'
+              f' batch_size_train ={self.batch_size_train}\n'
+              f' batch_size_test ={self.batch_size_test}')
         self.plan.summarize_hyperparameters()
         self.logic.summarize_hyperparameters()
 
@@ -1310,6 +1414,7 @@ class JaxBackpropPlanner:
             policy=self.plan.train_policy,
             n_steps=self.horizon,
             n_batch=self.batch_size_train)
+        self.train_rollouts = train_rollouts
 
         test_rollouts = self.test_compiled.compile_rollouts(
             policy=self.plan.test_policy,
@@ -1417,17 +1522,106 @@ class JaxBackpropPlanner:
 
         return init_train, init_test
 
+    def as_optimization_problem(
+            self, key: Optional[random.PRNGKey]=None,
+            policy_hyperparams: Optional[Pytree]=None,
+            loss_function_updates_key: bool=True,
+            grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
+        '''Returns a function that computes the loss and a function that
+        computes gradient of the return as a 1D vector given a 1D representation
+        of policy parameters. These functions are designed to be compatible with
+        off-the-shelf optimizers such as scipy.
+
+        Also returns the initial parameter vector to seed an optimizer,
+        as well as a mapping that recovers the parameter pytree from the vector.
+        The PRNG key is updated internally starting from the optional given key.
+
+        Constraints on actions, if they are required, cannot be constructed
+        automatically in the general case. The user should build constraints
+        for each problem in the format required by the downstream optimizer.
+
+        :param key: JAX PRNG key (derived from clock if not provided)
+        :param policy_hyperparameters: hyper-parameters for the policy/plan,
+        such as weights for sigmoid wrapping boolean actions (defaults to 1
+        for all action-fluents if not provided)
+        :param loss_function_updates_key: if True, the loss function
+        updates the PRNG key internally independently of the grad function
+        :param grad_function_updates_key: if True, the gradient function
+        updates the PRNG key internally independently of the loss function.
+        '''
+
+        # if PRNG key is not provided
+        if key is None:
+            key = random.PRNGKey(round(time.time() * 1000))
+
+        # initialize the initial fluents, model parameters, policy hyper-params
+        subs = self.test_compiled.init_values
+        train_subs, _ = self._batched_init_subs(subs)
+        model_params = self.compiled.model_params
+        if policy_hyperparams is None:
+            raise_warning('policy_hyperparams is not set, setting 1.0 for '
+                          'all action-fluents which could be suboptimal.')
+            policy_hyperparams = {action: 1.0
+                                  for action in self.rddl.action_fluents}
+
+        # initialize the policy parameters
+        params_guess, *_ = self.initialize(key, policy_hyperparams, train_subs)
+        guess_1d, unravel_fn = jax.flatten_util.ravel_pytree(params_guess)
+        guess_1d = np.asarray(guess_1d)
+
+        # computes the training loss function and its 1D gradient
+        loss_fn = self._jax_loss(self.train_rollouts)
+
+        @jax.jit
+        def _loss_with_key(key, params_1d):
+            policy_params = unravel_fn(params_1d)
+            loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
+                                  train_subs, model_params)
+            return loss_val
+
+        @jax.jit
+        def _grad_with_key(key, params_1d):
+            policy_params = unravel_fn(params_1d)
+            grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
+            grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
+                                  train_subs, model_params)
+            grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
+            return grad_1d
+
+        def _loss_function(params_1d):
+            nonlocal key
+            if loss_function_updates_key:
+                key, subkey = random.split(key)
+            else:
+                subkey = key
+            loss_val = _loss_with_key(subkey, params_1d)
+            loss_val = float(loss_val)
+            return loss_val
+
+        def _grad_function(params_1d):
+            nonlocal key
+            if grad_function_updates_key:
+                key, subkey = random.split(key)
+            else:
+                subkey = key
+            grad = _grad_with_key(subkey, params_1d)
+            grad = np.asarray(grad)
+            return grad
+
+        return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
+
     # ===========================================================================
     # OPTIMIZE API
     # ===========================================================================
 
     def optimize(self, *args, **kwargs) -> Dict[str, Any]:
-        '''
+        '''Compute an optimal policy or plan. Return the callback from training.
 
         :param key: JAX PRNG key (derived from clock if not provided)
         :param epochs: the maximum number of steps of gradient descent
         :param train_seconds: total time allocated for gradient descent
         :param plot_step: frequency to plot the plan and save result to disk
+        :param plot_kwargs: additional arguments to pass to the plotter
         :param model_params: optional model-parameters to override default
         :param policy_hyperparams: hyper-parameters for the policy/plan, such as
         weights for sigmoid wrapping boolean actions
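The new as_optimization_problem() method above (exercised by the run_scipy.py example added in this release) exposes the planning objective to external optimizers. A minimal sketch of how the returned pieces fit together; the domain, optimizer choice, and iteration budget are illustrative assumptions:

    import pyRDDLGym
    from scipy.optimize import minimize
    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxStraightLinePlan

    env = pyRDDLGym.make('HVAC', '1', vectorized=True)    # assumed example domain
    planner = JaxBackpropPlanner(rddl=env.model, plan=JaxStraightLinePlan())

    # 1-D loss, 1-D gradient, initial parameter vector, and pytree recovery map
    loss_fn, grad_fn, x0, unravel_fn = planner.as_optimization_problem()

    # hand the problem to an off-the-shelf optimizer such as L-BFGS-B
    result = minimize(loss_fn, x0=x0, jac=grad_fn, method='L-BFGS-B',
                      options={'maxiter': 100})
    policy_params = unravel_fn(result.x)    # recover the policy-parameter pytree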
@@ -1435,7 +1629,9 @@ class JaxBackpropPlanner:
         their values: if None initializes all variables from the RDDL instance
         :param guess: initial policy parameters: if None will use the initializer
         specified in this instance
-        :param
+        :param print_summary: whether to print planner header, parameter
+        summary, and diagnosis
+        :param print_progress: whether to print the progress bar during training
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
@@ -1461,11 +1657,13 @@ class JaxBackpropPlanner:
                  epochs: int=999999,
                  train_seconds: float=120.,
                  plot_step: Optional[int]=None,
+                 plot_kwargs: Optional[Dict[str, Any]]=None,
                  model_params: Optional[Dict[str, Any]]=None,
                  policy_hyperparams: Optional[Dict[str, Any]]=None,
                  subs: Optional[Dict[str, Any]]=None,
                  guess: Optional[Pytree]=None,
-
+                 print_summary: bool=True,
+                 print_progress: bool=True,
                  test_rolling_window: int=10,
                  tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
         '''Returns a generator for computing an optimal policy or plan.
@@ -1476,20 +1674,22 @@ class JaxBackpropPlanner:
         :param epochs: the maximum number of steps of gradient descent
         :param train_seconds: total time allocated for gradient descent
         :param plot_step: frequency to plot the plan and save result to disk
+        :param plot_kwargs: additional arguments to pass to the plotter
         :param model_params: optional model-parameters to override default
         :param policy_hyperparams: hyper-parameters for the policy/plan, such as
         weights for sigmoid wrapping boolean actions
         :param subs: dictionary mapping initial state and non-fluents to
         their values: if None initializes all variables from the RDDL instance
         :param guess: initial policy parameters: if None will use the initializer
-        specified in this instance
-        :param
+        specified in this instance
+        :param print_summary: whether to print planner header, parameter
+        summary, and diagnosis
+        :param print_progress: whether to print the progress bar during training
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
         :param tqdm_position: position of tqdm progress bar (for multiprocessing)
         '''
-        verbose = int(verbose)
         start_time = time.time()
         elapsed_outside_loop = 0
 
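Together, the keyword arguments documented above let a caller silence console output and configure the live plot without subclassing. A hedged usage sketch; the planner object and the numeric settings are assumptions:

    callback = planner.optimize(
        epochs=2000,
        train_seconds=120.0,
        plot_step=50,                    # redraw the live plot every 50 iterations
        plot_kwargs={'show_violin': True, 'show_action': True},
        print_summary=False,             # suppress the header and diagnosis report
        print_progress=False)            # suppress the tqdm progress bar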
@@ -1513,7 +1713,7 @@ class JaxBackpropPlanner:
                                      for action in self.rddl.action_fluents}
 
         # print summary of parameters:
-        if
+        if print_summary:
             self._summarize_system()
             self.summarize_hyperparameters()
             print(f'optimize() call hyper-parameters:\n'
@@ -1526,8 +1726,10 @@ class JaxBackpropPlanner:
                   f' provide_param_guess={guess is not None}\n'
                   f' test_rolling_window={test_rolling_window}\n'
                   f' plot_frequency ={plot_step}\n'
-                  f'
-
+                  f' plot_kwargs ={plot_kwargs}\n'
+                  f' print_summary ={print_summary}\n'
+                  f' print_progress ={print_progress}\n')
+            if self.compiled.relaxations:
                 print('Some RDDL operations are non-differentiable, '
                       'replacing them with differentiable relaxations:')
                 print(self.compiled.summarize_model_relaxations())
@@ -1549,7 +1751,7 @@ class JaxBackpropPlanner:
                           'from the RDDL files.')
         train_subs, test_subs = self._batched_init_subs(subs)
 
-        # initialize
+        # initialize model parameters
         if model_params is None:
             model_params = self.compiled.model_params
         model_params_test = self.test_compiled.model_params
@@ -1575,12 +1777,14 @@ class JaxBackpropPlanner:
         if plot_step is None or plot_step <= 0 or plt is None:
             plot = None
         else:
-
+            if plot_kwargs is None:
+                plot_kwargs = {}
+            plot = JaxPlannerPlot(self.rddl, self.horizon, **plot_kwargs)
         xticks, loss_values = [], []
 
         # training loop
         iters = range(epochs)
-        if
+        if print_progress:
             iters = tqdm(iters, total=100, position=tqdm_position)
 
         for it in iters:
@@ -1588,9 +1792,18 @@ class JaxBackpropPlanner:
 
             # update the parameters of the plan
             key, subkey = random.split(key)
-            policy_params, converged, opt_state, opt_aux,
+            policy_params, converged, opt_state, opt_aux, \
+            train_loss, train_log = \
                 self.update(subkey, policy_params, policy_hyperparams,
                             train_subs, model_params, opt_state, opt_aux)
+
+            # no progress
+            grad_norm_zero, _ = jax.tree_util.tree_flatten(
+                jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
+            if np.all(grad_norm_zero):
+                status = JaxPlannerStatus.NO_PROGRESS
+
+            # constraint satisfaction problem
             if not np.all(converged):
                 raise_warning(
                     'Projected gradient method for satisfying action concurrency '
@@ -1598,13 +1811,18 @@ class JaxBackpropPlanner:
                     'invalid for the current instance.', 'red')
                 status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
 
-            #
+            # numerical error
+            if not np.isfinite(train_loss):
+                raise_warning(
+                    f'Aborting JAX planner due to invalid train loss {train_loss}.',
+                    'red')
+                status = JaxPlannerStatus.INVALID_GRADIENT
+
+            # evaluate test losses and record best plan so far
             test_loss, log = self.test_loss(
                 subkey, policy_params, policy_hyperparams,
                 test_subs, model_params_test)
             test_loss = rolling_test_loss.update(test_loss)
-
-            # record the best plan so far
             if test_loss < best_loss:
                 best_params, best_loss, best_grad = \
                     policy_params, test_loss, train_log['grad']
@@ -1617,11 +1835,12 @@ class JaxBackpropPlanner:
                 action_values = {name: values
                                  for (name, values) in log['fluents'].items()
                                  if name in self.rddl.action_fluents}
-
+                returns = -np.sum(np.asarray(log['reward']), axis=1)
+                plot.redraw(xticks, loss_values, action_values, returns)
 
             # if the progress bar is used
             elapsed = time.time() - start_time - elapsed_outside_loop
-            if
+            if print_progress:
                 iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
                 iters.set_description(
                     f'[{tqdm_position}] {it:6} it / {-train_loss:14.6f} train / '
@@ -1633,19 +1852,6 @@ class JaxBackpropPlanner:
             if it >= epochs - 1:
                 status = JaxPlannerStatus.ITER_BUDGET_REACHED
 
-            # numerical error
-            if not np.isfinite(train_loss):
-                raise_warning(
-                    f'Aborting JAX planner due to invalid train loss {train_loss}.',
-                    'red')
-                status = JaxPlannerStatus.INVALID_GRADIENT
-
-            # no progress
-            grad_norm_zero, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
-            if np.all(grad_norm_zero):
-                status = JaxPlannerStatus.NO_PROGRESS
-
             # return a callback
             start_time_outside = time.time()
             yield {
@@ -1671,7 +1877,7 @@ class JaxBackpropPlanner:
                 break
 
         # release resources
-        if
+        if print_progress:
             iters.close()
         if plot is not None:
             plot.close()
@@ -1688,7 +1894,7 @@ class JaxBackpropPlanner:
                           f'during test evaluation:\n{messages}', 'red')
 
         # summarize and test for convergence
-        if
+        if print_summary:
             grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
             diagnosis = self._perform_diagnosis(
                 last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
@@ -1778,17 +1984,19 @@ class JaxBackpropPlanner:
                 raise ValueError(f'State dictionary passed to the JAX policy is '
                                  f'grounded, since it contains the key <{var}>, '
                                  f'but a vectorized environment is required: '
-                                 f'
+                                 f'make sure vectorized = True in the RDDLEnv.')
 
             # must be numeric array
             # exception is for POMDPs at 1st epoch when observ-fluents are None
-
-
+            dtype = np.atleast_1d(values).dtype
+            if not jnp.issubdtype(dtype, jnp.number) \
+            and not jnp.issubdtype(dtype, jnp.bool_):
                 if step == 0 and var in self.rddl.observ_fluents:
                     subs[var] = self.test_compiled.init_values[var]
                 else:
-                    raise ValueError(
-
+                    raise ValueError(
+                        f'Values {values} assigned to p-variable <{var}> are '
+                        f'non-numeric of type {dtype}.')
 
         # cast device arrays to numpy
         actions = self.test_policy(key, params, policy_hyperparams, step, subs)
@@ -1801,8 +2009,6 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
     linear search gradient descent, with the Armijo condition.'''
 
     def __init__(self, *args,
-                 optimizer: Callable[..., optax.GradientTransformation]=optax.sgd,
-                 optimizer_kwargs: Kwargs={'learning_rate': 1.0},
                  decay: float=0.8,
                  c: float=0.1,
                  step_max: float=1.0,
@@ -1825,11 +2031,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
             raise_warning('clip_grad parameter conflicts with '
                           'line search planner and will be ignored.', 'red')
             del kwargs['clip_grad']
-        super(JaxLineSearchPlanner, self).__init__(
-            *args,
-            optimizer=optimizer,
-            optimizer_kwargs=optimizer_kwargs,
-            **kwargs)
+        super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
 
     def summarize_hyperparameters(self) -> None:
         super(JaxLineSearchPlanner, self).summarize_hyperparameters()
@@ -1878,7 +2080,8 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
             step = lrmax / decay
             f_step = np.inf
             best_f, best_step, best_params, best_state = np.inf, None, None, None
-            while f_step > f - c * step * gnorm2 and step * decay >= lrmin
+            while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
+            or not trials:
                 trials += 1
                 step *= decay
                 f_step, new_params, new_state = _jax_wrapped_line_search_trial(
@@ -1918,7 +2121,7 @@ def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
 
 @jax.jit
 def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
-    return jnp.mean(returns) -
+    return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
 
 
 @jax.jit
@@ -1986,7 +2189,8 @@ class JaxOfflineController(BaseAgent):
     def reset(self) -> None:
         self.step = 0
         if self.train_on_reset and not self.params_given:
-
+            callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.params = callback['best_params']
 
 
 class JaxOnlineController(BaseAgent):
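As the JaxOfflineController change above shows, optimize() now returns the final training callback rather than bare parameters, and downstream code reads the best parameters out of that dictionary. A short closing sketch; apart from the 'best_params' key used by the controller code above, the constructor and evaluate() signatures here are assumptions:

    from jax import random
    from pyRDDLGym_jax.core.planner import JaxOfflineController

    # train once up front (keyword arguments are forwarded to optimize()),
    # then the controller reads callback['best_params'] internally, as in reset() above
    agent = JaxOfflineController(planner, key=random.PRNGKey(42), epochs=500)
    agent.evaluate(env, episodes=1)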