pyRDDLGym-jax 0.2-py3-none-any.whl → 0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -0
- pyRDDLGym_jax/core/compiler.py +90 -68
- pyRDDLGym_jax/core/logic.py +188 -46
- pyRDDLGym_jax/core/planner.py +411 -195
- pyRDDLGym_jax/core/simulator.py +2 -1
- pyRDDLGym_jax/core/tuning.py +13 -10
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_drp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_slp.cfg +1 -1
- pyRDDLGym_jax/examples/run_gym.py +2 -5
- pyRDDLGym_jax/examples/run_plan.py +6 -8
- pyRDDLGym_jax/examples/run_scipy.py +61 -0
- pyRDDLGym_jax/examples/run_tune.py +5 -6
- pyRDDLGym_jax-0.4.dist-info/METADATA +276 -0
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/RECORD +20 -22
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/Pong_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +0 -20
- pyRDDLGym_jax-0.2.dist-info/METADATA +0 -26
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.2.dist-info → pyRDDLGym_jax-0.4.dist-info}/top_level.txt +0 -0
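The planner.py diff below extends `_load_config` to read a fourth relaxation component, `sampling` (with `sampling_kwargs`), alongside `tnorm`, `complement`, and `comparison`, defaulting to `GumbelSoftmax`; the example configs listed above pick this key up. A minimal sketch of the equivalent programmatic construction, assuming the collected `logic_kwargs` are forwarded to `FuzzyLogic` and that `ProductTNorm` and `StandardComplement` are the stock tnorm/complement classes (those two names are not shown in this diff):

from pyRDDLGym_jax.core import logic
from pyRDDLGym_jax.core.logic import FuzzyLogic

# mirrors what the updated _load_config builds from a planner .cfg:
# each entry names a class in pyRDDLGym_jax.core.logic plus its kwargs
fuzzy_logic = FuzzyLogic(
    tnorm=logic.ProductTNorm(),             # 'tnorm' / 'tnorm_kwargs'
    complement=logic.StandardComplement(),  # 'complement' / 'complement_kwargs'
    comparison=logic.SigmoidComparison(),   # 'comparison' / 'comparison_kwargs'
    sampling=logic.GumbelSoftmax(),         # new in 0.4: 'sampling' / 'sampling_kwargs'
)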
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -1,53 +1,52 @@
-__version__ = '0.2'
-
 from ast import literal_eval
 from collections import deque
 import configparser
 from enum import Enum
+import os
+import sys
+import time
+import traceback
+from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
+
 import haiku as hk
 import jax
+import jax.nn.initializers as initializers
 import jax.numpy as jnp
 import jax.random as random
-import jax.nn.initializers as initializers
 import numpy as np
 import optax
-import os
-import sys
 import termcolor
-import time
 from tqdm import tqdm
-from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
-
-Activation = Callable[[jnp.ndarray], jnp.ndarray]
-Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
-Kwargs = Dict[str, Any]
-Pytree = Any
 
-from pyRDDLGym.core.debug.exception import raise_warning
-
-# try to import matplotlib, if failed then skip plotting
-try:
-    import matplotlib
-    import matplotlib.pyplot as plt
-    matplotlib.use('TkAgg')
-except Exception:
-    raise_warning('matplotlib is not installed, '
-                  'plotting functionality is disabled.', 'red')
-    plt = None
-
 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
 from pyRDDLGym.core.debug.logger import Logger
 from pyRDDLGym.core.debug.exception import (
+    raise_warning,
     RDDLNotImplementedError,
     RDDLUndefinedVariableError,
     RDDLTypeError
 )
 from pyRDDLGym.core.policy import BaseAgent
 
-from pyRDDLGym_jax
+from pyRDDLGym_jax import __version__
 from pyRDDLGym_jax.core import logic
+from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
 from pyRDDLGym_jax.core.logic import FuzzyLogic
 
+# try to import matplotlib, if failed then skip plotting
+try:
+    import matplotlib.pyplot as plt
+except Exception:
+    raise_warning('failed to import matplotlib: '
+                  'plotting functionality will be disabled.', 'red')
+    traceback.print_exc()
+    plt = None
+
+Activation = Callable[[jnp.ndarray], jnp.ndarray]
+Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
+Kwargs = Dict[str, Any]
+Pytree = Any
+
 
 # ***********************************************************************
 # CONFIG FILE MANAGEMENT
@@ -102,9 +101,12 @@ def _load_config(config, args):
     comp_kwargs = model_args.get('complement_kwargs', {})
     compare_name = model_args.get('comparison', 'SigmoidComparison')
     compare_kwargs = model_args.get('comparison_kwargs', {})
+    sampling_name = model_args.get('sampling', 'GumbelSoftmax')
+    sampling_kwargs = model_args.get('sampling_kwargs', {})
     logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
     logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
     logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
+    logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
 
     # read the policy settings
     plan_method = planner_args.pop('method')
@@ -113,7 +115,8 @@ def _load_config(config, args):
     # policy initialization
     plan_initializer = plan_kwargs.get('initializer', None)
     if plan_initializer is not None:
-        initializer = _getattr_any(
+        initializer = _getattr_any(
+            packages=[initializers, hk.initializers], item=plan_initializer)
         if initializer is None:
             raise_warning(
                 f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
@@ -130,7 +133,8 @@ def _load_config(config, args):
     # policy activation
     plan_activation = plan_kwargs.get('activation', None)
     if plan_activation is not None:
-        activation = _getattr_any(
+        activation = _getattr_any(
+            packages=[jax.nn, jax.numpy], item=plan_activation)
         if activation is None:
             raise_warning(
                 f'Ignoring invalid activation <{plan_activation}>.', 'red')
@@ -180,18 +184,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
 #
 # ***********************************************************************
 
-def _function_discrete_approx_named(logic):
-    jax_discrete, jax_param = logic.discrete()
-
-    def _jax_wrapped_discrete_calc_approx(key, prob, params):
-        sample = jax_discrete(key, prob, params)
-        out_of_bounds = jnp.logical_not(jnp.logical_and(
-            jnp.all(prob >= 0),
-            jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
-        return sample, out_of_bounds
-
-    return _jax_wrapped_discrete_calc_approx, jax_param
-
 
 class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
     '''Compiles a RDDL AST representation to an equivalent JAX representation.
@@ -217,6 +209,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         :param *kwargs: keyword arguments to pass to base compiler
         '''
         super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
+
         self.logic = logic
         self.logic.set_use64bit(self.use64bit)
         if cpfs_without_grad is None:
@@ -224,9 +217,14 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         self.cpfs_without_grad = cpfs_without_grad
 
         # actions and CPFs must be continuous
-
+        pvars_cast = set()
         for (var, values) in self.init_values.items():
             self.init_values[var] = np.asarray(values, dtype=self.REAL)
+            if not np.issubdtype(np.atleast_1d(values).dtype, np.floating):
+                pvars_cast.add(var)
+        if pvars_cast:
+            raise_warning(f'JAX gradient compiler requires that initial values '
+                          f'of p-variables {pvars_cast} be cast to float.')
 
         # overwrite basic operations with fuzzy ones
         self.RELATIONAL_OPS = {
@@ -261,7 +259,9 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         self.IF_HELPER = logic.control_if()
         self.SWITCH_HELPER = logic.control_switch()
         self.BERNOULLI_HELPER = logic.bernoulli()
-        self.DISCRETE_HELPER =
+        self.DISCRETE_HELPER = logic.discrete()
+        self.POISSON_HELPER = logic.poisson()
+        self.GEOMETRIC_HELPER = logic.geometric()
 
     def _jax_stop_grad(self, jax_expr):
 
@@ -273,20 +273,29 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         return _jax_wrapped_stop_grad
 
     def _compile_cpfs(self, info):
-
+        cpfs_cast = set()
         jax_cpfs = {}
         for (_, cpfs) in self.levels.items():
             for cpf in cpfs:
                 _, expr = self.rddl.cpfs[cpf]
                 jax_cpfs[cpf] = self._jax(expr, info, dtype=self.REAL)
+                if self.rddl.variable_ranges[cpf] != 'real':
+                    cpfs_cast.add(cpf)
                 if cpf in self.cpfs_without_grad:
-                    raise_warning(f'CPF <{cpf}> stops gradient.')
                     jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
+
+        if cpfs_cast:
+            raise_warning(f'JAX gradient compiler requires that outputs of CPFs '
+                          f'{cpfs_cast} be cast to float.')
+        if self.cpfs_without_grad:
+            raise_warning(f'User requested that gradients not flow '
+                          f'through CPFs {self.cpfs_without_grad}.')
         return jax_cpfs
 
     def _jax_kron(self, expr, info):
         if self.logic.verbose:
-            raise_warning('
+            raise_warning('JAX gradient compiler ignores KronDelta '
+                          'during compilation.')
         arg, = expr.args
         arg = self._jax(arg, info)
         return arg
@@ -308,7 +317,8 @@ class JaxPlan:
         self._train_policy = None
         self._test_policy = None
         self._projection = None
-
+        self.bounds = None
+
     def summarize_hyperparameters(self) -> None:
         pass
 
@@ -363,7 +373,7 @@ class JaxPlan:
             # check invalid type
             if prange not in compiled.JAX_TYPES:
                 raise RDDLTypeError(
-                    f'Invalid range <{prange}
+                    f'Invalid range <{prange}> of action-fluent <{name}>, '
                     f'must be one of {set(compiled.JAX_TYPES.keys())}.')
 
             # clip boolean to (0, 1), otherwise use the RDDL action bounds
@@ -385,7 +395,7 @@ class JaxPlan:
                           ~lower_finite & upper_finite,
                           ~lower_finite & ~upper_finite]
             bounds[name] = (lower, upper)
-            raise_warning(f'Bounds of action
+            raise_warning(f'Bounds of action-fluent <{name}> set to {bounds[name]}.')
         return shapes, bounds, bounds_safe, cond_lists
 
     def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -427,6 +437,7 @@ class JaxStraightLinePlan(JaxPlan):
         use_new_projection = True
         '''
         super(JaxStraightLinePlan, self).__init__()
+
         self._initializer_base = initializer
         self._initializer = initializer
         self._wrap_sigmoid = wrap_sigmoid
@@ -437,15 +448,19 @@ class JaxStraightLinePlan(JaxPlan):
         self._max_constraint_iter = max_constraint_iter
 
     def summarize_hyperparameters(self) -> None:
+        bounds = '\n '.join(
+            map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         print(f'policy hyper-parameters:\n'
-              f' initializer ={
+              f' initializer ={self._initializer_base}\n'
               f'constraint-sat strategy (simple):\n'
+              f' parsed_action_bounds =\n {bounds}\n'
              f' wrap_sigmoid ={self._wrap_sigmoid}\n'
              f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
              f' wrap_non_bool ={self._wrap_non_bool}\n'
              f'constraint-sat strategy (complex):\n'
              f' wrap_softmax ={self._wrap_softmax}\n'
-              f' use_new_projection ={self._use_new_projection}'
+              f' use_new_projection ={self._use_new_projection}\n'
+              f' max_projection_iters ={self._max_constraint_iter}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -603,7 +618,7 @@ class JaxStraightLinePlan(JaxPlan):
             if 1 < allowed_actions < bool_action_count:
                 raise RDDLNotImplementedError(
                     f'Straight-line plans with wrap_softmax currently '
-                    f'do not support max-nondef-actions
+                    f'do not support max-nondef-actions {allowed_actions} > 1.')
 
             # potentially apply projection but to non-bool actions only
             self.projection = _jax_wrapped_slp_project_to_box
@@ -734,14 +749,14 @@ class JaxStraightLinePlan(JaxPlan):
             for (var, shape) in shapes.items():
                 if ranges[var] != 'bool' or not stack_bool_params:
                     key, subkey = random.split(key)
-                    param = init(subkey, shape, dtype=compiled.REAL)
+                    param = init(key=subkey, shape=shape, dtype=compiled.REAL)
                     if ranges[var] == 'bool':
                         param += bool_threshold
                     params[var] = param
             if stack_bool_params:
                 key, subkey = random.split(key)
                 bool_shape = (horizon, bool_action_count)
-                bool_param = init(subkey, bool_shape, dtype=compiled.REAL)
+                bool_param = init(key=subkey, shape=bool_shape, dtype=compiled.REAL)
                 params[bool_key] = bool_param
             params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
             return params
@@ -765,7 +780,8 @@ class JaxDeepReactivePolicy(JaxPlan):
     def __init__(self, topology: Optional[Sequence[int]]=None,
                  activation: Activation=jnp.tanh,
                  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
-                 normalize: bool=
+                 normalize: bool=False,
+                 normalize_per_layer: bool=False,
                  normalizer_kwargs: Optional[Kwargs]=None,
                  wrap_non_bool: bool=False) -> None:
         '''Creates a new deep reactive policy in JAX.
@@ -775,12 +791,15 @@ class JaxDeepReactivePolicy(JaxPlan):
         :param activation: function to apply after each layer of the policy
         :param initializer: weight initialization
         :param normalize: whether to apply layer norm to the inputs
+        :param normalize_per_layer: whether to apply layer norm to each input
+        individually (only active if normalize is True)
         :param normalizer_kwargs: if normalize is True, apply additional arguments
         to layer norm
         :param wrap_non_bool: whether to wrap real or int action fluent parameters
         with non-linearity (e.g. sigmoid or ELU) to satisfy box constraints
         '''
         super(JaxDeepReactivePolicy, self).__init__()
+
         if topology is None:
             topology = [128, 64]
         self._topology = topology
@@ -788,22 +807,25 @@ class JaxDeepReactivePolicy(JaxPlan):
         self._initializer_base = initializer
         self._initializer = initializer
         self._normalize = normalize
+        self._normalize_per_layer = normalize_per_layer
         if normalizer_kwargs is None:
-            normalizer_kwargs = {
-                'create_offset': True, 'create_scale': True,
-                'name': 'input_norm'
-            }
+            normalizer_kwargs = {'create_offset': True, 'create_scale': True}
         self._normalizer_kwargs = normalizer_kwargs
         self._wrap_non_bool = wrap_non_bool
 
     def summarize_hyperparameters(self) -> None:
+        bounds = '\n '.join(
+            map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         print(f'policy hyper-parameters:\n'
-              f' topology
-              f' activation_fn
-              f' initializer
-              f'
-              f'
-              f'
+              f' topology ={self._topology}\n'
+              f' activation_fn ={self._activations[0].__name__}\n'
+              f' initializer ={type(self._initializer_base).__name__}\n'
+              f' apply_input_norm ={self._normalize}\n'
+              f' input_norm_layerwise={self._normalize_per_layer}\n'
+              f' input_norm_args ={self._normalizer_kwargs}\n'
+              f'constraint-sat strategy:\n'
+              f' parsed_action_bounds=\n {bounds}\n'
+              f' wrap_non_bool ={self._wrap_non_bool}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -821,7 +843,7 @@ class JaxDeepReactivePolicy(JaxPlan):
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
                 f'Deep reactive policies currently do not support '
-                f'max-nondef-actions
+                f'max-nondef-actions {allowed_actions} > 1.')
         use_constraint_satisfaction = allowed_actions < bool_action_count
 
         noop = {var: (values[0] if isinstance(values, list) else values)
@@ -835,6 +857,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
         ranges = rddl.variable_ranges
         normalize = self._normalize
+        normalize_per_layer = self._normalize_per_layer
         wrap_non_bool = self._wrap_non_bool
         init = self._initializer
         layers = list(enumerate(zip(self._topology, self._activations)))
@@ -842,14 +865,67 @@ class JaxDeepReactivePolicy(JaxPlan):
                   for (var, shape) in shapes.items()}
         layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}
 
-        #
-
+        # inputs for the policy network
+        if rddl.observ_fluents:
+            observed_vars = rddl.observ_fluents
+        else:
+            observed_vars = rddl.state_fluents
+        input_names = {var: f'{var}'.replace('-', '_') for var in observed_vars}
+
+        # catch if input norm is applied to size 1 tensor
+        if normalize:
+            non_bool_dims = 0
+            for (var, values) in observed_vars.items():
+                if ranges[var] != 'bool':
+                    value_size = np.atleast_1d(values).size
+                    if normalize_per_layer and value_size == 1:
+                        raise_warning(
+                            f'Cannot apply layer norm to state-fluent <{var}> '
+                            f'of size 1: setting normalize_per_layer = False.',
+                            'red')
+                        normalize_per_layer = False
+                    non_bool_dims += value_size
+            if not normalize_per_layer and non_bool_dims == 1:
+                raise_warning(
+                    'Cannot apply layer norm to state-fluents of total size 1: '
+                    'setting normalize = False.', 'red')
+                normalize = False
+
+        # convert subs dictionary into a state vector to feed to the MLP
+        def _jax_wrapped_policy_input(subs):
 
-            #
-
+            # concatenate all state variables into a single vector
+            # optionally apply layer norm to each input tensor
+            states_bool, states_non_bool = [], []
+            non_bool_dims = 0
+            for (var, value) in subs.items():
+                if var in observed_vars:
+                    state = jnp.ravel(value)
+                    if ranges[var] == 'bool':
+                        states_bool.append(state)
+                    else:
+                        if normalize and normalize_per_layer:
+                            normalizer = hk.LayerNorm(
+                                axis=-1, param_axis=-1,
+                                name=f'input_norm_{input_names[var]}',
+                                **self._normalizer_kwargs)
+                            state = normalizer(state)
+                        states_non_bool.append(state)
+                        non_bool_dims += state.size
+            state = jnp.concatenate(states_non_bool + states_bool)
+
+            # optionally perform layer normalization on the non-bool inputs
+            if normalize and not normalize_per_layer and non_bool_dims:
                 normalizer = hk.LayerNorm(
-                    axis=-1, param_axis=-1,
-
+                    axis=-1, param_axis=-1, name='input_norm',
+                    **self._normalizer_kwargs)
+                normalized = normalizer(state[:non_bool_dims])
+                state = state.at[:non_bool_dims].set(normalized)
+            return state
+
+        # predict actions from the policy network for current state
+        def _jax_wrapped_policy_network_predict(subs):
+            state = _jax_wrapped_policy_input(subs)
 
             # feed state vector through hidden layers
             hidden = state
@@ -913,25 +989,9 @@ class JaxDeepReactivePolicy(JaxPlan):
                     start += size
             return actions
 
-        if rddl.observ_fluents:
-            observed_vars = rddl.observ_fluents
-        else:
-            observed_vars = rddl.state_fluents
-
-        # state is concatenated into single tensor
-        def _jax_wrapped_subs_to_state(subs):
-            subs = {var: value
-                    for (var, value) in subs.items()
-                    if var in observed_vars}
-            flat_subs = jax.tree_map(jnp.ravel, subs)
-            states = list(flat_subs.values())
-            state = jnp.concatenate(states)
-            return state
-
         # train action prediction
         def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
-
-            actions = predict_fn.apply(params, state)
+            actions = predict_fn.apply(params, subs)
             if not wrap_non_bool:
                 for (var, action) in actions.items():
                     if var != bool_key and ranges[var] != 'bool':
@@ -982,8 +1042,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             subs = {var: value[0, ...]
                     for (var, value) in subs.items()
                     if var in observed_vars}
-
-            params = predict_fn.init(key, state)
+            params = predict_fn.init(key, subs)
             return params
 
         self.initializer = _jax_wrapped_drp_init
@@ -1021,46 +1080,72 @@ class RollingMean:
 class JaxPlannerPlot:
     '''Supports plotting and visualization of a JAX policy in real time.'''
 
-    def __init__(self, rddl: RDDLPlanningModel, horizon: int
-
+    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
+                 show_violin: bool=True, show_action: bool=True) -> None:
+        '''Creates a new planner visualizer.
+
+        :param rddl: the planning model to optimize
+        :param horizon: the lookahead or planning horizon
+        :param show_violin: whether to show the distribution of batch losses
+        :param show_action: whether to show heatmaps of the action fluents
+        '''
+        num_plots = 1
+        if show_violin:
+            num_plots += 1
+        if show_action:
+            num_plots += len(rddl.action_fluents)
+        self._fig, axes = plt.subplots(num_plots)
+        if num_plots == 1:
+            axes = [axes]
 
         # prepare the loss plot
         self._loss_ax = axes[0]
         self._loss_ax.autoscale(enable=True)
-        self._loss_ax.set_xlabel('
+        self._loss_ax.set_xlabel('training time')
         self._loss_ax.set_ylabel('loss value')
         self._loss_plot = self._loss_ax.plot(
             [], [], linestyle=':', marker='o', markersize=2)[0]
         self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
 
+        # prepare the violin plot
+        if show_violin:
+            self._hist_ax = axes[1]
+        else:
+            self._hist_ax = None
+
         # prepare the action plots
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if show_action:
+            self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
+                               for (idx, name) in enumerate(rddl.action_fluents)}
+            self._action_plots = {}
+            for name in rddl.action_fluents:
+                ax = self._action_ax[name]
+                if rddl.variable_ranges[name] == 'bool':
+                    vmin, vmax = 0.0, 1.0
+                else:
+                    vmin, vmax = None, None
+                action_dim = 1
+                for dim in rddl.object_counts(rddl.variable_params[name]):
+                    action_dim *= dim
+                action_plot = ax.pcolormesh(
+                    np.zeros((action_dim, horizon)),
+                    cmap='seismic', vmin=vmin, vmax=vmax)
+                ax.set_aspect('auto')
+                ax.set_xlabel('decision epoch')
+                ax.set_ylabel(name)
+                plt.colorbar(action_plot, ax=ax)
+                self._action_plots[name] = action_plot
+            self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
+                                 for (name, ax) in self._action_ax.items()}
+        else:
+            self._action_ax = None
+            self._action_plots = None
+            self._action_back = None
+
         plt.tight_layout()
         plt.show(block=False)
 
-    def redraw(self, xticks, losses, actions) -> None:
+    def redraw(self, xticks, losses, actions, returns) -> None:
 
         # draw the loss curve
         self._fig.canvas.restore_region(self._loss_back)
@@ -1071,21 +1156,30 @@ class JaxPlannerPlot:
         self._loss_ax.draw_artist(self._loss_plot)
         self._fig.canvas.blit(self._loss_ax.bbox)
 
+        # draw the violin plot
+        if self._hist_ax is not None:
+            self._hist_ax.clear()
+            self._hist_ax.set_xlabel('loss value')
+            self._hist_ax.set_ylabel('density')
+            self._hist_ax.violinplot(returns, vert=False, showmeans=True)
+
         # draw the actions
-
-        values
-
-
-
-
-
-
+        if self._action_ax is not None:
+            for (name, values) in actions.items():
+                values = np.mean(values, axis=0, dtype=float)
+                values = np.reshape(values, newshape=(values.shape[0], -1)).T
+                self._fig.canvas.restore_region(self._action_back[name])
+                self._action_plots[name].set_array(values)
+                self._action_ax[name].draw_artist(self._action_plots[name])
+                self._fig.canvas.blit(self._action_ax[name].bbox)
+                self._action_plots[name].set_clim([np.min(values), np.max(values)])
+
         self._fig.canvas.draw()
         self._fig.canvas.flush_events()
 
     def close(self) -> None:
         plt.close(self._fig)
-        del self._loss_ax, self._action_ax, \
+        del self._loss_ax, self._hist_ax, self._action_ax, \
             self._loss_plot, self._action_plots, self._fig, \
             self._loss_back, self._action_back
 
@@ -1099,9 +1193,9 @@ class JaxPlannerStatus(Enum):
     NORMAL = 0
     NO_PROGRESS = 1
     PRECONDITION_POSSIBLY_UNSATISFIED = 2
-
-
-
+    INVALID_GRADIENT = 3
+    TIME_BUDGET_REACHED = 4
+    ITER_BUDGET_REACHED = 5
 
     def is_failure(self) -> bool:
         return self.value >= 3
@@ -1245,30 +1339,41 @@ class JaxBackpropPlanner:
                 map(str, jax._src.xla_bridge.devices())).replace('\n', '')
         except Exception as _:
             devices_short = 'N/A'
+        LOGO = \
+        """
+ __ ______ __ __ ______ __ ______ __ __
+/\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
+_\_\ \ \ \ __ \ \/_/\_\/_ \ \ _-/\ \ \____ \ \ __ \ \ \ \-. \
+/\_____\ \ \_\ \_\ /\_\/\_\ \ \_\ \ \_____\ \ \_\ \_\ \ \_\\"\_\
+\/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
+        """
+
         print('\n'
-              f'
+              f'{LOGO}\n'
+              f'Version {__version__}\n'
              f'Python {sys.version}\n'
              f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
+              f'optax {optax.__version__}, haiku {hk.__version__}, '
              f'numpy {np.__version__}\n'
              f'devices: {devices_short}\n')
 
     def summarize_hyperparameters(self) -> None:
         print(f'objective hyper-parameters:\n'
-              f' utility_fn
-              f' utility args
-              f' use_symlog
-              f' lookahead
-              f'
-              f' fuzzy logic type={type(self.logic).__name__}\n'
-              f' nonfluents exact={self.compile_non_fluent_exact}\n'
-              f' cpfs_no_gradient={self.cpfs_without_grad}\n'
+              f' utility_fn ={self.utility.__name__}\n'
+              f' utility args ={self.utility_kwargs}\n'
+              f' use_symlog ={self.use_symlog_reward}\n'
+              f' lookahead ={self.horizon}\n'
+              f' user_action_bounds={self._action_bounds}\n'
+              f' fuzzy logic type ={type(self.logic).__name__}\n'
+              f' nonfluents exact ={self.compile_non_fluent_exact}\n'
+              f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
              f'optimizer hyper-parameters:\n'
-              f' use_64_bit
-              f' optimizer
-              f' optimizer args
-              f' clip_gradient
-              f' batch_size_train={self.batch_size_train}\n'
-              f' batch_size_test
+              f' use_64_bit ={self.use64bit}\n'
+              f' optimizer ={self._optimizer_name.__name__}\n'
+              f' optimizer args ={self._optimizer_kwargs}\n'
+              f' clip_gradient ={self.clip_grad}\n'
+              f' batch_size_train ={self.batch_size_train}\n'
+              f' batch_size_test ={self.batch_size_test}')
         self.plan.summarize_hyperparameters()
         self.logic.summarize_hyperparameters()
 
@@ -1310,6 +1415,7 @@ class JaxBackpropPlanner:
             policy=self.plan.train_policy,
             n_steps=self.horizon,
             n_batch=self.batch_size_train)
+        self.train_rollouts = train_rollouts
 
         test_rollouts = self.test_compiled.compile_rollouts(
             policy=self.plan.test_policy,
@@ -1417,17 +1523,106 @@ class JaxBackpropPlanner:
 
         return init_train, init_test
 
+    def as_optimization_problem(
+            self, key: Optional[random.PRNGKey]=None,
+            policy_hyperparams: Optional[Pytree]=None,
+            loss_function_updates_key: bool=True,
+            grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
+        '''Returns a function that computes the loss and a function that
+        computes gradient of the return as a 1D vector given a 1D representation
+        of policy parameters. These functions are designed to be compatible with
+        off-the-shelf optimizers such as scipy.
+
+        Also returns the initial parameter vector to seed an optimizer,
+        as well as a mapping that recovers the parameter pytree from the vector.
+        The PRNG key is updated internally starting from the optional given key.
+
+        Constraints on actions, if they are required, cannot be constructed
+        automatically in the general case. The user should build constraints
+        for each problem in the format required by the downstream optimizer.
+
+        :param key: JAX PRNG key (derived from clock if not provided)
+        :param policy_hyperparameters: hyper-parameters for the policy/plan,
+        such as weights for sigmoid wrapping boolean actions (defaults to 1
+        for all action-fluents if not provided)
+        :param loss_function_updates_key: if True, the loss function
+        updates the PRNG key internally independently of the grad function
+        :param grad_function_updates_key: if True, the gradient function
+        updates the PRNG key internally independently of the loss function.
+        '''
+
+        # if PRNG key is not provided
+        if key is None:
+            key = random.PRNGKey(round(time.time() * 1000))
+
+        # initialize the initial fluents, model parameters, policy hyper-params
+        subs = self.test_compiled.init_values
+        train_subs, _ = self._batched_init_subs(subs)
+        model_params = self.compiled.model_params
+        if policy_hyperparams is None:
+            raise_warning('policy_hyperparams is not set, setting 1.0 for '
+                          'all action-fluents which could be suboptimal.')
+            policy_hyperparams = {action: 1.0
+                                  for action in self.rddl.action_fluents}
+
+        # initialize the policy parameters
+        params_guess, *_ = self.initialize(key, policy_hyperparams, train_subs)
+        guess_1d, unravel_fn = jax.flatten_util.ravel_pytree(params_guess)
+        guess_1d = np.asarray(guess_1d)
+
+        # computes the training loss function and its 1D gradient
+        loss_fn = self._jax_loss(self.train_rollouts)
+
+        @jax.jit
+        def _loss_with_key(key, params_1d):
+            policy_params = unravel_fn(params_1d)
+            loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
+                                  train_subs, model_params)
+            return loss_val
+
+        @jax.jit
+        def _grad_with_key(key, params_1d):
+            policy_params = unravel_fn(params_1d)
+            grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
+            grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
+                                  train_subs, model_params)
+            grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
+            return grad_1d
+
+        def _loss_function(params_1d):
+            nonlocal key
+            if loss_function_updates_key:
+                key, subkey = random.split(key)
+            else:
+                subkey = key
+            loss_val = _loss_with_key(subkey, params_1d)
+            loss_val = float(loss_val)
+            return loss_val
+
+        def _grad_function(params_1d):
+            nonlocal key
+            if grad_function_updates_key:
+                key, subkey = random.split(key)
+            else:
+                subkey = key
+            grad = _grad_with_key(subkey, params_1d)
+            grad = np.asarray(grad)
+            return grad
+
+        return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
+
     # ===========================================================================
     # OPTIMIZE API
     # ===========================================================================
 
     def optimize(self, *args, **kwargs) -> Dict[str, Any]:
-        '''
+        '''Compute an optimal policy or plan. Return the callback from training.
 
         :param key: JAX PRNG key (derived from clock if not provided)
         :param epochs: the maximum number of steps of gradient descent
         :param train_seconds: total time allocated for gradient descent
         :param plot_step: frequency to plot the plan and save result to disk
+        :param plot_kwargs: additional arguments to pass to the plotter
        :param model_params: optional model-parameters to override default
        :param policy_hyperparams: hyper-parameters for the policy/plan, such as
        weights for sigmoid wrapping boolean actions
@@ -1435,7 +1630,9 @@ class JaxBackpropPlanner:
        their values: if None initializes all variables from the RDDL instance
        :param guess: initial policy parameters: if None will use the initializer
        specified in this instance
-        :param
+        :param print_summary: whether to print planner header, parameter
+        summary, and diagnosis
+        :param print_progress: whether to print the progress bar during training
        :param test_rolling_window: the test return is averaged on a rolling
        window of the past test_rolling_window returns when updating the best
        parameters found so far
@@ -1461,11 +1658,13 @@ class JaxBackpropPlanner:
                  epochs: int=999999,
                  train_seconds: float=120.,
                  plot_step: Optional[int]=None,
+                 plot_kwargs: Optional[Dict[str, Any]]=None,
                  model_params: Optional[Dict[str, Any]]=None,
                  policy_hyperparams: Optional[Dict[str, Any]]=None,
                  subs: Optional[Dict[str, Any]]=None,
                  guess: Optional[Pytree]=None,
-
+                 print_summary: bool=True,
+                 print_progress: bool=True,
                  test_rolling_window: int=10,
                  tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
        '''Returns a generator for computing an optimal policy or plan.
@@ -1476,20 +1675,22 @@ class JaxBackpropPlanner:
        :param epochs: the maximum number of steps of gradient descent
        :param train_seconds: total time allocated for gradient descent
        :param plot_step: frequency to plot the plan and save result to disk
+        :param plot_kwargs: additional arguments to pass to the plotter
        :param model_params: optional model-parameters to override default
        :param policy_hyperparams: hyper-parameters for the policy/plan, such as
        weights for sigmoid wrapping boolean actions
        :param subs: dictionary mapping initial state and non-fluents to
        their values: if None initializes all variables from the RDDL instance
        :param guess: initial policy parameters: if None will use the initializer
-        specified in this instance
-        :param
+        specified in this instance
+        :param print_summary: whether to print planner header, parameter
+        summary, and diagnosis
+        :param print_progress: whether to print the progress bar during training
        :param test_rolling_window: the test return is averaged on a rolling
        window of the past test_rolling_window returns when updating the best
        parameters found so far
        :param tqdm_position: position of tqdm progress bar (for multiprocessing)
        '''
-        verbose = int(verbose)
        start_time = time.time()
        elapsed_outside_loop = 0
 
@@ -1511,9 +1712,17 @@ class JaxBackpropPlanner:
            hyperparam_value = float(policy_hyperparams)
            policy_hyperparams = {action: hyperparam_value
                                  for action in self.rddl.action_fluents}
+
+        # fill in missing entries
+        elif isinstance(policy_hyperparams, dict):
+            for action in self.rddl.action_fluents:
+                if action not in policy_hyperparams:
+                    raise_warning(f'policy_hyperparams[{action}] is not set, '
+                                  'setting 1.0 which could be suboptimal.')
+                    policy_hyperparams[action] = 1.0
 
        # print summary of parameters:
-        if
+        if print_summary:
            self._summarize_system()
            self.summarize_hyperparameters()
            print(f'optimize() call hyper-parameters:\n'
@@ -1526,8 +1735,10 @@ class JaxBackpropPlanner:
                  f' provide_param_guess={guess is not None}\n'
                  f' test_rolling_window={test_rolling_window}\n'
                  f' plot_frequency ={plot_step}\n'
-                  f'
-
+                  f' plot_kwargs ={plot_kwargs}\n'
+                  f' print_summary ={print_summary}\n'
+                  f' print_progress ={print_progress}\n')
+        if self.compiled.relaxations:
            print('Some RDDL operations are non-differentiable, '
                  'replacing them with differentiable relaxations:')
            print(self.compiled.summarize_model_relaxations())
@@ -1549,7 +1760,7 @@ class JaxBackpropPlanner:
                          'from the RDDL files.')
        train_subs, test_subs = self._batched_init_subs(subs)
 
-        # initialize
+        # initialize model parameters
        if model_params is None:
            model_params = self.compiled.model_params
        model_params_test = self.test_compiled.model_params
@@ -1570,27 +1781,40 @@ class JaxBackpropPlanner:
        rolling_test_loss = RollingMean(test_rolling_window)
        log = {}
        status = JaxPlannerStatus.NORMAL
+        is_all_zero_fn = lambda x: np.allclose(x, 0)
 
        # initialize plot area
        if plot_step is None or plot_step <= 0 or plt is None:
            plot = None
        else:
-
+            if plot_kwargs is None:
+                plot_kwargs = {}
+            plot = JaxPlannerPlot(self.rddl, self.horizon, **plot_kwargs)
        xticks, loss_values = [], []
 
        # training loop
        iters = range(epochs)
-        if
+        if print_progress:
            iters = tqdm(iters, total=100, position=tqdm_position)
+        position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
 
        for it in iters:
            status = JaxPlannerStatus.NORMAL
 
            # update the parameters of the plan
            key, subkey = random.split(key)
-            policy_params, converged, opt_state, opt_aux,
+            policy_params, converged, opt_state, opt_aux, \
+            train_loss, train_log = \
                self.update(subkey, policy_params, policy_hyperparams,
                            train_subs, model_params, opt_state, opt_aux)
+
+            # no progress
+            grad_norm_zero, _ = jax.tree_util.tree_flatten(
+                jax.tree_map(is_all_zero_fn, train_log['grad']))
+            if np.all(grad_norm_zero):
+                status = JaxPlannerStatus.NO_PROGRESS
+
+            # constraint satisfaction problem
            if not np.all(converged):
                raise_warning(
                    'Projected gradient method for satisfying action concurrency '
@@ -1598,13 +1822,18 @@ class JaxBackpropPlanner:
                    'invalid for the current instance.', 'red')
                status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
 
-            #
+            # numerical error
+            if not np.isfinite(train_loss):
+                raise_warning(
+                    f'Aborting JAX planner due to invalid train loss {train_loss}.',
+                    'red')
+                status = JaxPlannerStatus.INVALID_GRADIENT
+
+            # evaluate test losses and record best plan so far
            test_loss, log = self.test_loss(
                subkey, policy_params, policy_hyperparams,
                test_subs, model_params_test)
            test_loss = rolling_test_loss.update(test_loss)
-
-            # record the best plan so far
            if test_loss < best_loss:
                best_params, best_loss, best_grad = \
                    policy_params, test_loss, train_log['grad']
@@ -1617,15 +1846,17 @@ class JaxBackpropPlanner:
                action_values = {name: values
                                 for (name, values) in log['fluents'].items()
                                 if name in self.rddl.action_fluents}
-
+                returns = -np.sum(np.asarray(log['reward']), axis=1)
+                plot.redraw(xticks, loss_values, action_values, returns)
 
            # if the progress bar is used
            elapsed = time.time() - start_time - elapsed_outside_loop
-            if
+            if print_progress:
                iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
                iters.set_description(
-                    f'
-                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best'
+                    f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
+                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
+                    f'{status.value} status')
 
            # reached computation budget
            if elapsed >= train_seconds:
@@ -1633,19 +1864,6 @@ class JaxBackpropPlanner:
            if it >= epochs - 1:
                status = JaxPlannerStatus.ITER_BUDGET_REACHED
 
-            # numerical error
-            if not np.isfinite(train_loss):
-                raise_warning(
-                    f'Aborting JAX planner due to invalid train loss {train_loss}.',
-                    'red')
-                status = JaxPlannerStatus.INVALID_GRADIENT
-
-            # no progress
-            grad_norm_zero, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
-            if np.all(grad_norm_zero):
-                status = JaxPlannerStatus.NO_PROGRESS
-
            # return a callback
            start_time_outside = time.time()
            yield {
@@ -1671,7 +1889,7 @@ class JaxBackpropPlanner:
                break
 
        # release resources
-        if
+        if print_progress:
            iters.close()
        if plot is not None:
            plot.close()
@@ -1688,7 +1906,7 @@ class JaxBackpropPlanner:
                          f'during test evaluation:\n{messages}', 'red')
 
        # summarize and test for convergence
-        if
+        if print_summary:
            grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
            diagnosis = self._perform_diagnosis(
                last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
@@ -1698,7 +1916,7 @@ class JaxBackpropPlanner:
                  f' iterations ={it}\n'
                  f' best_objective={-best_loss}\n'
                  f' best_grad_norm={grad_norm}\n'
-                  f'diagnosis: {diagnosis}\n')
+                  f' diagnosis: {diagnosis}\n')
 
    def _perform_diagnosis(self, last_iter_improve,
                           train_return, test_return, best_return, grad_norm):
@@ -1778,17 +1996,19 @@ class JaxBackpropPlanner:
                raise ValueError(f'State dictionary passed to the JAX policy is '
                                 f'grounded, since it contains the key <{var}>, '
                                 f'but a vectorized environment is required: '
-                                 f'
+                                 f'make sure vectorized = True in the RDDLEnv.')
 
            # must be numeric array
            # exception is for POMDPs at 1st epoch when observ-fluents are None
-
-
+            dtype = np.atleast_1d(values).dtype
+            if not jnp.issubdtype(dtype, jnp.number) \
+            and not jnp.issubdtype(dtype, jnp.bool_):
                if step == 0 and var in self.rddl.observ_fluents:
                    subs[var] = self.test_compiled.init_values[var]
                else:
-                    raise ValueError(
-
+                    raise ValueError(
+                        f'Values {values} assigned to p-variable <{var}> are '
+                        f'non-numeric of type {dtype}.')
 
        # cast device arrays to numpy
        actions = self.test_policy(key, params, policy_hyperparams, step, subs)
@@ -1801,8 +2021,6 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
    linear search gradient descent, with the Armijo condition.'''
 
    def __init__(self, *args,
-                 optimizer: Callable[..., optax.GradientTransformation]=optax.sgd,
-                 optimizer_kwargs: Kwargs={'learning_rate': 1.0},
                 decay: float=0.8,
                 c: float=0.1,
                 step_max: float=1.0,
@@ -1825,11 +2043,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
            raise_warning('clip_grad parameter conflicts with '
                          'line search planner and will be ignored.', 'red')
            del kwargs['clip_grad']
-        super(JaxLineSearchPlanner, self).__init__(
-            *args,
-            optimizer=optimizer,
-            optimizer_kwargs=optimizer_kwargs,
-            **kwargs)
+        super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
 
    def summarize_hyperparameters(self) -> None:
        super(JaxLineSearchPlanner, self).summarize_hyperparameters()
@@ -1878,7 +2092,8 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
        step = lrmax / decay
        f_step = np.inf
        best_f, best_step, best_params, best_state = np.inf, None, None, None
-        while f_step > f - c * step * gnorm2 and step * decay >= lrmin
+        while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
+        or not trials:
            trials += 1
            step *= decay
            f_step, new_params, new_state = _jax_wrapped_line_search_trial(
@@ -1913,12 +2128,12 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
 @jax.jit
 def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
     return (-1.0 / beta) * jax.scipy.special.logsumexp(
-
+        -beta * returns, b=1.0 / returns.size)
 
 
 @jax.jit
 def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
-    return jnp.mean(returns) -
+    return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
 
 
 @jax.jit
@@ -1986,7 +2201,8 @@ class JaxOfflineController(BaseAgent):
     def reset(self) -> None:
         self.step = 0
         if self.train_on_reset and not self.params_given:
-
+            callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.params = callback['best_params']
 
 
 class JaxOnlineController(BaseAgent):