pyRDDLGym-jax: pyRDDLGym_jax-0.3-py3-none-any.whl → pyRDDLGym_jax-0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +90 -67
- pyRDDLGym_jax/core/logic.py +286 -82
- pyRDDLGym_jax/core/planner.py +191 -97
- pyRDDLGym_jax/core/simulator.py +2 -1
- pyRDDLGym_jax/core/tuning.py +58 -63
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +4 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +2 -1
- pyRDDLGym_jax/examples/run_tune.py +1 -3
- pyRDDLGym_jax-0.5.dist-info/METADATA +278 -0
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/RECORD +17 -17
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax-0.3.dist-info/METADATA +0 -26
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -2,54 +2,50 @@ from ast import literal_eval
 from collections import deque
 import configparser
 from enum import Enum
+import os
+import sys
+import time
+import traceback
+from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
+
 import haiku as hk
 import jax
+import jax.nn.initializers as initializers
 import jax.numpy as jnp
 import jax.random as random
-import jax.nn.initializers as initializers
 import numpy as np
 import optax
-import os
-import sys
 import termcolor
-import time
-import traceback
 from tqdm import tqdm
-from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
 
-Activation = Callable[[jnp.ndarray], jnp.ndarray]
-Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
-Kwargs = Dict[str, Any]
-Pytree = Any
-
-from pyRDDLGym.core.debug.exception import raise_warning
-
-from pyRDDLGym_jax import __version__
-
-# try to import matplotlib, if failed then skip plotting
-try:
-    import matplotlib
-    import matplotlib.pyplot as plt
-    matplotlib.use('TkAgg')
-except Exception:
-    raise_warning('failed to import matplotlib: '
-                  'plotting functionality will be disabled.', 'red')
-    traceback.print_exc()
-    plt = None
-
 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
 from pyRDDLGym.core.debug.logger import Logger
 from pyRDDLGym.core.debug.exception import (
+    raise_warning,
     RDDLNotImplementedError,
     RDDLUndefinedVariableError,
     RDDLTypeError
 )
 from pyRDDLGym.core.policy import BaseAgent
 
-from pyRDDLGym_jax
+from pyRDDLGym_jax import __version__
 from pyRDDLGym_jax.core import logic
+from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
 from pyRDDLGym_jax.core.logic import FuzzyLogic
 
+# try to import matplotlib, if failed then skip plotting
+try:
+    import matplotlib.pyplot as plt
+except Exception:
+    raise_warning('failed to import matplotlib: '
+                  'plotting functionality will be disabled.', 'red')
+    traceback.print_exc()
+    plt = None
+
+Activation = Callable[[jnp.ndarray], jnp.ndarray]
+Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
+Kwargs = Dict[str, Any]
+Pytree = Any
 
 # ***********************************************************************
 # CONFIG FILE MANAGEMENT
@@ -60,6 +56,7 @@ from pyRDDLGym_jax.core.logic import FuzzyLogic
 #
 # ***********************************************************************
 
+
 def _parse_config_file(path: str):
     if not os.path.isfile(path):
         raise FileNotFoundError(f'File {path} does not exist.')
@@ -104,9 +101,15 @@ def _load_config(config, args):
     comp_kwargs = model_args.get('complement_kwargs', {})
     compare_name = model_args.get('comparison', 'SigmoidComparison')
     compare_kwargs = model_args.get('comparison_kwargs', {})
+    sampling_name = model_args.get('sampling', 'GumbelSoftmax')
+    sampling_kwargs = model_args.get('sampling_kwargs', {})
+    rounding_name = model_args.get('rounding', 'SoftRounding')
+    rounding_kwargs = model_args.get('rounding_kwargs', {})
     logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
     logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
     logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
+    logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
+    logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
 
     # read the policy settings
     plan_method = planner_args.pop('method')
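The two new [Model] keys follow the same pattern as the existing tnorm/complement/comparison settings: a class name resolved against pyRDDLGym_jax.core.logic plus an optional kwargs dictionary, with GumbelSoftmax and SoftRounding as the defaults. A minimal sketch of a config exercising them through load_config_from_string from this module (section and key names follow the shipped example .cfg files; all values are illustrative only):

    from pyRDDLGym_jax.core.planner import load_config_from_string

    # hypothetical config: the [Model] sampling/rounding keys are the ones
    # added in this diff; the numeric hyperparameters are placeholders
    CONFIG = """
    [Model]
    comparison='SigmoidComparison'
    comparison_kwargs={'weight': 20}
    sampling='GumbelSoftmax'
    sampling_kwargs={}
    rounding='SoftRounding'
    rounding_kwargs={}

    [Optimizer]
    method='JaxStraightLinePlan'
    method_kwargs={}
    optimizer_kwargs={'learning_rate': 0.01}

    [Training]
    key=42
    train_seconds=30
    """

    planner_args, plan_kwargs, train_args = load_config_from_string(CONFIG)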
@@ -157,11 +160,18 @@ def _load_config(config, args):
     else:
         planner_args['optimizer'] = optimizer
 
-    #
+    # optimize call RNG key
     planner_key = train_args.get('key', None)
     if planner_key is not None:
         train_args['key'] = random.PRNGKey(planner_key)
 
+    # optimize call stopping rule
+    stopping_rule = train_args.get('stopping_rule', None)
+    if stopping_rule is not None:
+        stopping_rule_kwargs = train_args.pop('stopping_rule_kwargs', {})
+        train_args['stopping_rule'] = getattr(
+            sys.modules[__name__], stopping_rule)(**stopping_rule_kwargs)
+
     return planner_args, plan_kwargs, train_args
 
 
@@ -175,7 +185,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
     '''Loads config file contents specified explicitly as a string value.'''
     config, args = _parse_config_string(value)
     return _load_config(config, args)
-
 
 # ***********************************************************************
 # MODEL RELAXATIONS
@@ -184,18 +193,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
 #
 # ***********************************************************************
 
-def _function_discrete_approx_named(logic):
-    jax_discrete, jax_param = logic.discrete()
-
-    def _jax_wrapped_discrete_calc_approx(key, prob, params):
-        sample = jax_discrete(key, prob, params)
-        out_of_bounds = jnp.logical_not(jnp.logical_and(
-            jnp.all(prob >= 0),
-            jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
-        return sample, out_of_bounds
-
-    return _jax_wrapped_discrete_calc_approx, jax_param
-
 
 class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
     '''Compiles a RDDL AST representation to an equivalent JAX representation.
@@ -271,7 +268,9 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         self.IF_HELPER = logic.control_if()
         self.SWITCH_HELPER = logic.control_switch()
         self.BERNOULLI_HELPER = logic.bernoulli()
-        self.DISCRETE_HELPER =
+        self.DISCRETE_HELPER = logic.discrete()
+        self.POISSON_HELPER = logic.poisson()
+        self.GEOMETRIC_HELPER = logic.geometric()
 
     def _jax_stop_grad(self, jax_expr):
 
@@ -309,7 +308,6 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         arg, = expr.args
         arg = self._jax(arg, info)
         return arg
-
 
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
@@ -319,6 +317,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
 #
 # ***********************************************************************
 
+
 class JaxPlan:
     '''Base class for all JAX policy representations.'''
 
@@ -373,7 +372,7 @@ class JaxPlan:
         self._projection = value
 
     def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
-                               user_bounds: Bounds,
+                               user_bounds: Bounds,
                                horizon: int):
         shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
         for (name, prange) in compiled.rddl.variable_ranges.items():
@@ -469,10 +468,11 @@ class JaxStraightLinePlan(JaxPlan):
               f' wrap_non_bool ={self._wrap_non_bool}\n'
               f'constraint-sat strategy (complex):\n'
               f' wrap_softmax ={self._wrap_softmax}\n'
-              f' use_new_projection ={self._use_new_projection}')
+              f' use_new_projection ={self._use_new_projection}\n'
+              f' max_projection_iters ={self._max_constraint_iter}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds: Bounds,
+                _bounds: Bounds,
                 horizon: int) -> None:
         rddl = compiled.rddl
 
@@ -513,7 +513,7 @@ class JaxStraightLinePlan(JaxPlan):
         def _jax_bool_action_to_param(var, action, hyperparams):
             if wrap_sigmoid:
                 weight = hyperparams[var]
-                return
+                return jax.scipy.special.logit(action) / weight
             else:
                 return action
 
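The replacement line is the exact inverse of the sigmoid wrapping used elsewhere in this class: if a boolean action a in (0, 1) is represented by a parameter p with a = sigmoid(weight * p), then p = logit(a) / weight. A quick standalone check (the weight and action values here are arbitrary):

    import jax

    weight, action = 10.0, 0.9
    param = jax.scipy.special.logit(action) / weight   # action -> parameter
    recovered = jax.nn.sigmoid(weight * param)         # parameter -> action
    assert abs(float(recovered) - action) < 1e-6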
@@ -522,14 +522,13 @@
         def _jax_non_bool_param_to_action(var, param, hyperparams):
             if wrap_non_bool:
                 lower, upper = bounds_safe[var]
-                action = jnp.select(
-                    condlist=cond_lists[var],
-                    choicelist=[
-                        lower + (upper - lower) * jax.nn.sigmoid(param),
-                        lower + (jax.nn.elu(param) + 1.0),
-                        upper - (jax.nn.elu(-param) + 1.0),
-                        param
-                    ]
+                mb, ml, mu, mn = [mask.astype(compiled.REAL)
+                                  for mask in cond_lists[var]]
+                action = (
+                    mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
+                    ml * (lower + (jax.nn.elu(param) + 1.0)) +
+                    mu * (upper - (jax.nn.elu(-param) + 1.0)) +
+                    mn * param
                 )
             else:
                 action = param
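The four masks select one differentiable reparameterization per action according to which of its box bounds are finite: both finite (scaled sigmoid), lower only (shifted ELU), upper only (reflected ELU), and unbounded (identity). A self-contained sketch of the same mapping (param_to_action is a hypothetical helper, not a library function; the jnp.where calls emulate the NaN-safe bounds_safe values the plan precomputes):

    import jax
    import jax.numpy as jnp

    def param_to_action(param, lower, upper):
        # masks in the same order as mb, ml, mu, mn above
        fin_lo, fin_hi = jnp.isfinite(lower), jnp.isfinite(upper)
        mb = (fin_lo & fin_hi).astype(param.dtype)
        ml = (fin_lo & ~fin_hi).astype(param.dtype)
        mu = (~fin_lo & fin_hi).astype(param.dtype)
        mn = (~fin_lo & ~fin_hi).astype(param.dtype)
        # finite stand-ins for infinite bounds keep masked-out terms NaN-free
        lo = jnp.where(fin_lo, lower, 0.0)
        hi = jnp.where(fin_hi, upper, 0.0)
        return (mb * (lo + (hi - lo) * jax.nn.sigmoid(param)) +
                ml * (lo + (jax.nn.elu(param) + 1.0)) +
                mu * (hi - (jax.nn.elu(-param) + 1.0)) +
                mn * param)

    # a parameter of 0.0 maps to the midpoint of a [0, 10] action range
    print(param_to_action(jnp.array(0.0), jnp.array(0.0), jnp.array(10.0)))  # 5.0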
@@ -789,7 +788,7 @@ class JaxDeepReactivePolicy(JaxPlan):
     def __init__(self, topology: Optional[Sequence[int]]=None,
                  activation: Activation=jnp.tanh,
                  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
-                 normalize: bool=False,
+                 normalize: bool=False,
                  normalize_per_layer: bool=False,
                  normalizer_kwargs: Optional[Kwargs]=None,
                  wrap_non_bool: bool=False) -> None:
@@ -837,7 +836,7 @@ class JaxDeepReactivePolicy(JaxPlan):
               f' wrap_non_bool ={self._wrap_non_bool}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds: Bounds,
+                _bounds: Bounds,
                 horizon: int) -> None:
         rddl = compiled.rddl
 
@@ -890,7 +889,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                 if normalize_per_layer and value_size == 1:
                     raise_warning(
                         f'Cannot apply layer norm to state-fluent <{var}> '
-                        f'of size 1: setting normalize_per_layer = False.',
+                        f'of size 1: setting normalize_per_layer = False.',
                         'red')
                     normalize_per_layer = False
                 non_bool_dims += value_size
@@ -915,8 +914,8 @@ class JaxDeepReactivePolicy(JaxPlan):
                 else:
                     if normalize and normalize_per_layer:
                         normalizer = hk.LayerNorm(
-                            axis=-1, param_axis=-1,
-                            name=f'input_norm_{input_names[var]}',
+                            axis=-1, param_axis=-1,
+                            name=f'input_norm_{input_names[var]}',
                             **self._normalizer_kwargs)
                         state = normalizer(state)
                 states_non_bool.append(state)
@@ -926,7 +925,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             # optionally perform layer normalization on the non-bool inputs
             if normalize and not normalize_per_layer and non_bool_dims:
                 normalizer = hk.LayerNorm(
-                    axis=-1, param_axis=-1, name='input_norm',
+                    axis=-1, param_axis=-1, name='input_norm',
                     **self._normalizer_kwargs)
                 normalized = normalizer(state[:non_bool_dims])
                 state = state.at[:non_bool_dims].set(normalized)
@@ -959,14 +958,13 @@ class JaxDeepReactivePolicy(JaxPlan):
             else:
                 if wrap_non_bool:
                     lower, upper = bounds_safe[var]
-                    action = jnp.select(
-                        condlist=cond_lists[var],
-                        choicelist=[
-                            lower + (upper - lower) * jax.nn.sigmoid(output),
-                            lower + (jax.nn.elu(output) + 1.0),
-                            upper - (jax.nn.elu(-output) + 1.0),
-                            output
-                        ]
+                    mb, ml, mu, mn = [mask.astype(compiled.REAL)
+                                      for mask in cond_lists[var]]
+                    action = (
+                        mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
+                        ml * (lower + (jax.nn.elu(output) + 1.0)) +
+                        mu * (upper - (jax.nn.elu(-output) + 1.0)) +
+                        mn * output
                     )
                 else:
                     action = output
@@ -1058,7 +1056,6 @@ class JaxDeepReactivePolicy(JaxPlan):
 
     def guess_next_epoch(self, params: Pytree) -> Pytree:
         return params
-
 
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANNER
@@ -1068,6 +1065,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 #
 # ***********************************************************************
 
+
 class RollingMean:
     '''Maintains an estimate of the rolling mean of a stream of real-valued
     observations.'''
@@ -1089,7 +1087,7 @@ class RollingMean:
 class JaxPlannerPlot:
     '''Supports plotting and visualization of a JAX policy in real time.'''
 
-    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
+    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
                  show_violin: bool=True, show_action: bool=True) -> None:
         '''Creates a new planner visualizer.
 
@@ -1137,7 +1135,7 @@ class JaxPlannerPlot:
             for dim in rddl.object_counts(rddl.variable_params[name]):
                 action_dim *= dim
             action_plot = ax.pcolormesh(
-                np.zeros((action_dim, horizon)),
+                np.zeros((action_dim, horizon)),
                 cmap='seismic', vmin=vmin, vmax=vmax)
             ax.set_aspect('auto')
             ax.set_xlabel('decision epoch')
@@ -1210,6 +1208,39 @@ class JaxPlannerStatus(Enum):
         return self.value >= 3
 
 
+class JaxPlannerStoppingRule:
+    '''The base class of all planner stopping rules.'''
+
+    def reset(self) -> None:
+        raise NotImplementedError
+
+    def monitor(self, callback: Dict[str, Any]) -> bool:
+        raise NotImplementedError
+
+
+class NoImprovementStoppingRule(JaxPlannerStoppingRule):
+    '''Stopping rule based on no improvement for a fixed number of iterations.'''
+
+    def __init__(self, patience: int) -> None:
+        self.patience = patience
+
+    def reset(self) -> None:
+        self.callback = None
+        self.iters_since_last_update = 0
+
+    def monitor(self, callback: Dict[str, Any]) -> bool:
+        if self.callback is None \
+        or callback['best_return'] > self.callback['best_return']:
+            self.callback = callback
+            self.iters_since_last_update = 0
+        else:
+            self.iters_since_last_update += 1
+        return self.iters_since_last_update >= self.patience
+
+    def __str__(self) -> str:
+        return f'No improvement for {self.patience} iterations'
+
+
 class JaxBackpropPlanner:
     '''A class for optimizing an action sequence in the given RDDL MDP using
     gradient descent.'''
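Any object implementing the same reset/monitor protocol can be plugged in: monitor receives the per-iteration callback dictionary yielded by the planner and returns True to stop. A sketch of a custom rule (TargetReturnStoppingRule is hypothetical; 'best_return' is the same callback field NoImprovementStoppingRule reads above):

    class TargetReturnStoppingRule(JaxPlannerStoppingRule):
        '''Stops once the best return reaches a user-chosen target.'''

        def __init__(self, target: float) -> None:
            self.target = target

        def reset(self) -> None:
            pass

        def monitor(self, callback: Dict[str, Any]) -> bool:
            return callback['best_return'] >= self.target

        def __str__(self) -> str:
            return f'Best return at least {self.target}'

Because _load_config resolves stopping_rule names against this module, the built-in rule can also be requested from the [Training] section of a config file via stopping_rule and stopping_rule_kwargs, as the config-handling hunk earlier shows.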
@@ -1224,6 +1255,8 @@ class JaxBackpropPlanner:
                  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
                  optimizer_kwargs: Optional[Kwargs]=None,
                  clip_grad: Optional[float]=None,
+                 noise_grad_eta: float=0.0,
+                 noise_grad_gamma: float=1.0,
                  logic: FuzzyLogic=FuzzyLogic(),
                  use_symlog_reward: bool=False,
                  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1250,6 +1283,8 @@ class JaxBackpropPlanner:
         :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
         factory (e.g. which parameters are controllable externally)
         :param clip_grad: maximum magnitude of gradient updates
+        :param noise_grad_eta: scale of the gradient noise variance
+        :param noise_grad_gamma: decay rate of the gradient noise variance
         :param logic: a subclass of FuzzyLogic for mapping exact mathematical
         operations to their differentiable counterparts
         :param use_symlog_reward: whether to use the symlog transform on the
@@ -1284,6 +1319,8 @@ class JaxBackpropPlanner:
             optimizer_kwargs = {'learning_rate': 0.1}
         self._optimizer_kwargs = optimizer_kwargs
         self.clip_grad = clip_grad
+        self.noise_grad_eta = noise_grad_eta
+        self.noise_grad_gamma = noise_grad_gamma
 
         # set optimizer
         try:
@@ -1348,8 +1385,18 @@ class JaxBackpropPlanner:
                 map(str, jax._src.xla_bridge.devices())).replace('\n', '')
         except Exception as _:
             devices_short = 'N/A'
+        LOGO = \
+        r"""
+     __   ______   __  __   ______   __       ______   __   __
+    /\ \ /\  __ \ /\_\_\_\ /\  == \ /\ \     /\  __ \ /\ "-.\ \
+   _\_\ \\ \  __ \\/_/\_\/_\ \  _-/ \ \ \____\ \  __ \\ \ \-.  \
+  /\_____\\ \_\ \_\ /\_\/\_\\ \_\    \ \_____\\ \_\ \_\\ \_\\"\_\
+  \/_____/ \/_/\/_/ \/_/\/_/ \/_/     \/_____/ \/_/\/_/ \/_/ \/_/
+        """
+
         print('\n'
-              f'
+              f'{LOGO}\n'
+              f'Version {__version__}\n'
              f'Python {sys.version}\n'
              f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
              f'optax {optax.__version__}, haiku {hk.__version__}, '
@@ -1371,6 +1418,8 @@ class JaxBackpropPlanner:
               f' optimizer ={self._optimizer_name.__name__}\n'
               f' optimizer args ={self._optimizer_kwargs}\n'
               f' clip_gradient ={self.clip_grad}\n'
+              f' noise_grad_eta ={self.noise_grad_eta}\n'
+              f' noise_grad_gamma ={self.noise_grad_gamma}\n'
               f' batch_size_train ={self.batch_size_train}\n'
               f' batch_size_test ={self.batch_size_test}')
         self.plan.summarize_hyperparameters()
@@ -1395,7 +1444,7 @@ class JaxBackpropPlanner:
 
         # Jax compilation of the exact RDDL for testing
         self.test_compiled = JaxRDDLCompiler(
-            rddl=rddl,
+            rddl=rddl,
             logger=self.logger,
             use64bit=self.use64bit)
         self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
@@ -1472,7 +1521,7 @@ class JaxBackpropPlanner:
         def _jax_wrapped_init_policy(key, hyperparams, subs):
             policy_params = init(key, hyperparams, subs)
             opt_state = optimizer.init(policy_params)
-            return policy_params, opt_state,
+            return policy_params, opt_state, {}
 
         return _jax_wrapped_init_policy
 
@@ -1480,6 +1529,19 @@ class JaxBackpropPlanner:
         optimizer = self.optimizer
         projection = self.plan.projection
 
+        # add Gaussian gradient noise per Neelakantan et al., 2016.
+        def _jax_wrapped_gaussian_param_noise(key, grads, sigma):
+            treedef = jax.tree_util.tree_structure(grads)
+            keys_flat = random.split(key, num=treedef.num_leaves)
+            keys_tree = jax.tree_util.tree_unflatten(treedef, keys_flat)
+            new_grads = jax.tree_map(
+                lambda g, k: g + sigma * random.normal(
+                    key=k, shape=g.shape, dtype=g.dtype),
+                grads,
+                keys_tree
+            )
+            return new_grads
+
         # calculate the plan gradient w.r.t. return loss and update optimizer
         # also perform a projection step to satisfy constraints on actions
         def _jax_wrapped_plan_update(key, policy_params, hyperparams,
@@ -1487,12 +1549,14 @@ class JaxBackpropPlanner:
             grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
             (loss_val, log), grad = grad_fn(
                 key, policy_params, hyperparams, subs, model_params)
+            sigma = opt_aux.get('noise_sigma', 0.0)
+            grad = _jax_wrapped_gaussian_param_noise(key, grad, sigma)
             updates, opt_state = optimizer.update(grad, opt_state)
             policy_params = optax.apply_updates(policy_params, updates)
             policy_params, converged = projection(policy_params, hyperparams)
             log['grad'] = grad
             log['updates'] = updates
-            return policy_params, converged, opt_state,
+            return policy_params, converged, opt_state, opt_aux, loss_val, log
 
         return jax.jit(_jax_wrapped_plan_update)
 
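The sigma read from opt_aux comes from the annealing schedule in the optimize loop further below: per Neelakantan et al. (2016), the noise variance at iteration t is eta / (1 + t)^gamma, so it starts at eta and decays polynomially. A standalone sketch of the same per-leaf trick on a toy gradient pytree (eta = 0.01 and gamma = 0.55 are values reported in that paper, not defaults of this planner):

    import jax
    import jax.numpy as jnp
    import jax.random as random

    eta, gamma, t = 0.01, 0.55, 0               # schedule at iteration t
    sigma = jnp.sqrt(eta / (1.0 + t) ** gamma)

    grads = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}   # toy gradient pytree
    treedef = jax.tree_util.tree_structure(grads)
    keys = jax.tree_util.tree_unflatten(
        treedef, random.split(random.PRNGKey(0), treedef.num_leaves))
    noisy_grads = jax.tree_util.tree_map(
        lambda g, k: g + sigma * random.normal(k, g.shape, g.dtype),
        grads, keys)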
@@ -1523,7 +1587,7 @@ class JaxBackpropPlanner:
         return init_train, init_test
 
     def as_optimization_problem(
-            self, key: Optional[random.PRNGKey]=None,
+            self, key: Optional[random.PRNGKey]=None,
             policy_hyperparams: Optional[Pytree]=None,
             loss_function_updates_key: bool=True,
             grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
@@ -1575,7 +1639,7 @@ class JaxBackpropPlanner:
         @jax.jit
         def _loss_with_key(key, params_1d):
             policy_params = unravel_fn(params_1d)
-            loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
+            loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
                                   train_subs, model_params)
             return loss_val
 
@@ -1583,7 +1647,7 @@ class JaxBackpropPlanner:
         def _grad_with_key(key, params_1d):
             policy_params = unravel_fn(params_1d)
             grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
-            grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
+            grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
                                   train_subs, model_params)
             grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
             return grad_1d
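Because the returned loss and gradient act on a flat 1-D parameter vector, they can be handed to a black-box optimizer. A hypothetical sketch, assuming the four return values follow the Tuple[Callable, Callable, np.ndarray, Callable] annotation above (loss, gradient, initial 1-D parameters, unravel function) and that the callables manage the RNG key internally, as the loss_function_updates_key flag suggests:

    from scipy.optimize import minimize

    loss_fn, grad_fn, x0, unravel_fn = planner.as_optimization_problem()
    result = minimize(loss_fn, x0=x0, jac=grad_fn, method='L-BFGS-B')
    best_policy_params = unravel_fn(result.x)   # back to a policy pytree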
@@ -1632,6 +1696,7 @@ class JaxBackpropPlanner:
         :param print_summary: whether to print planner header, parameter
         summary, and diagnosis
         :param print_progress: whether to print the progress bar during training
+        :param stopping_rule: stopping criterion
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
@@ -1657,13 +1722,14 @@ class JaxBackpropPlanner:
                            epochs: int=999999,
                            train_seconds: float=120.,
                            plot_step: Optional[int]=None,
-                           plot_kwargs: Optional[
+                           plot_kwargs: Optional[Kwargs]=None,
                            model_params: Optional[Dict[str, Any]]=None,
                            policy_hyperparams: Optional[Dict[str, Any]]=None,
                            subs: Optional[Dict[str, Any]]=None,
                            guess: Optional[Pytree]=None,
                            print_summary: bool=True,
                            print_progress: bool=True,
+                           stopping_rule: Optional[JaxPlannerStoppingRule]=None,
                            test_rolling_window: int=10,
                            tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
         '''Returns a generator for computing an optimal policy or plan.
@@ -1685,6 +1751,7 @@ class JaxBackpropPlanner:
         :param print_summary: whether to print planner header, parameter
         summary, and diagnosis
         :param print_progress: whether to print the progress bar during training
+        :param stopping_rule: stopping criterion
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
@@ -1711,6 +1778,14 @@ class JaxBackpropPlanner:
             hyperparam_value = float(policy_hyperparams)
             policy_hyperparams = {action: hyperparam_value
                                   for action in self.rddl.action_fluents}
+
+        # fill in missing entries
+        elif isinstance(policy_hyperparams, dict):
+            for action in self.rddl.action_fluents:
+                if action not in policy_hyperparams:
+                    raise_warning(f'policy_hyperparams[{action}] is not set, '
+                                  'setting 1.0 which could be suboptimal.')
+                    policy_hyperparams[action] = 1.0
 
         # print summary of parameters:
         if print_summary:
@@ -1728,10 +1803,11 @@ class JaxBackpropPlanner:
                   f' plot_frequency ={plot_step}\n'
                   f' plot_kwargs ={plot_kwargs}\n'
                   f' print_summary ={print_summary}\n'
-                  f' print_progress ={print_progress}\n')
+                  f' print_progress ={print_progress}\n'
+                  f' stopping_rule ={stopping_rule}\n')
             if self.compiled.relaxations:
                 print('Some RDDL operations are non-differentiable, '
-                      '
+                      'they will be approximated as follows:')
                 print(self.compiled.summarize_model_relaxations())
 
         # compute a batched version of the initial values
@@ -1764,7 +1840,7 @@ class JaxBackpropPlanner:
         else:
             policy_params = guess
             opt_state = self.optimizer.init(policy_params)
-            opt_aux =
+            opt_aux = {}
 
         # initialize running statistics
         best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
@@ -1772,7 +1848,12 @@ class JaxBackpropPlanner:
         rolling_test_loss = RollingMean(test_rolling_window)
         log = {}
         status = JaxPlannerStatus.NORMAL
+        is_all_zero_fn = lambda x: np.allclose(x, 0)
 
+        # initialize stopping criterion
+        if stopping_rule is not None:
+            stopping_rule.reset()
+
         # initialize plot area
         if plot_step is None or plot_step <= 0 or plt is None:
             plot = None
@@ -1786,10 +1867,16 @@ class JaxBackpropPlanner:
         iters = range(epochs)
         if print_progress:
             iters = tqdm(iters, total=100, position=tqdm_position)
+        position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
 
         for it in iters:
             status = JaxPlannerStatus.NORMAL
 
+            # gradient noise schedule
+            noise_var = self.noise_grad_eta / (1. + it) ** self.noise_grad_gamma
+            noise_sigma = np.sqrt(noise_var)
+            opt_aux['noise_sigma'] = noise_sigma
+
             # update the parameters of the plan
             key, subkey = random.split(key)
             policy_params, converged, opt_state, opt_aux, \
@@ -1799,7 +1886,7 @@ class JaxBackpropPlanner:
 
             # no progress
             grad_norm_zero, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(
+                jax.tree_map(is_all_zero_fn, train_log['grad']))
             if np.all(grad_norm_zero):
                 status = JaxPlannerStatus.NO_PROGRESS
 
@@ -1843,8 +1930,9 @@ class JaxBackpropPlanner:
             if print_progress:
                 iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
                 iters.set_description(
-                    f'
-                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best')
+                    f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
+                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
+                    f'{status.value} status')
 
             # reached computation budget
             if elapsed >= train_seconds:
@@ -1853,8 +1941,7 @@ class JaxBackpropPlanner:
                 status = JaxPlannerStatus.ITER_BUDGET_REACHED
 
             # return a callback
-
-            yield {
+            callback = {
                 'status': status,
                 'iteration': it,
                 'train_return':-train_loss,
@@ -1865,16 +1952,23 @@ class JaxBackpropPlanner:
                 'last_iteration_improved': last_iter_improve,
                 'grad': train_log['grad'],
                 'best_grad': best_grad,
+                'noise_sigma': noise_sigma,
                 'updates': train_log['updates'],
                 'elapsed_time': elapsed,
                 'key': key,
                 **log
             }
+            start_time_outside = time.time()
+            yield callback
             elapsed_outside_loop += (time.time() - start_time_outside)
 
             # abortion check
             if status.is_failure():
                 break
+
+            # stopping condition reached
+            if stopping_rule is not None and stopping_rule.monitor(callback):
+                break
 
             # release resources
             if print_progress:
@@ -1904,9 +1998,9 @@ class JaxBackpropPlanner:
               f' iterations ={it}\n'
               f' best_objective={-best_loss}\n'
               f' best_grad_norm={grad_norm}\n'
-              f'diagnosis: {diagnosis}\n')
+              f' diagnosis: {diagnosis}\n')
 
-    def _perform_diagnosis(self, last_iter_improve,
+    def _perform_diagnosis(self, last_iter_improve,
                            train_return, test_return, best_return, grad_norm):
         max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
         grad_is_zero = np.allclose(max_grad_norm, 0)
@@ -2085,7 +2179,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
                 trials += 1
                 step *= decay
                 f_step, new_params, new_state = _jax_wrapped_line_search_trial(
-                    step, grad, key, policy_params, hyperparams, subs,
+                    step, grad, key, policy_params, hyperparams, subs,
                     model_params, opt_state)
                 if f_step < best_f:
                     best_f, best_step, best_params, best_state = \
@@ -2094,11 +2188,11 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
             log['updates'] = None
             log['line_search_iters'] = trials
             log['learning_rate'] = best_step
-
+            opt_aux['best_step'] = best_step
+            return best_params, True, best_state, opt_aux, best_f, log
 
         return _jax_wrapped_plan_update
 
-
 # ***********************************************************************
 # ALL VERSIONS OF RISK FUNCTIONS
 #
@@ -2116,7 +2210,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
 @jax.jit
 def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
     return (-1.0 / beta) * jax.scipy.special.logsumexp(
-
+        -beta * returns, b=1.0 / returns.size)
 
 
 @jax.jit
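The reconstructed argument makes this the standard entropic (exponential) utility U_beta(R) = -(1/beta) * log E[exp(-beta * R)]: passing b=1.0/returns.size turns logsumexp into a numerically stable log-mean-exp. A quick equivalence check against the direct formula (toy returns):

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    returns, beta = jnp.array([1.0, 2.0, 4.0]), 0.5
    direct = (-1.0 / beta) * jnp.log(jnp.mean(jnp.exp(-beta * returns)))
    stable = (-1.0 / beta) * logsumexp(-beta * returns, b=1.0 / returns.size)
    assert jnp.allclose(direct, stable)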
@@ -2129,7 +2223,6 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
     alpha_mask = jax.lax.stop_gradient(
         returns <= jnp.percentile(returns, q=100 * alpha))
     return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
-
 
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
@@ -2139,12 +2232,13 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
 #
 # ***********************************************************************
 
+
 class JaxOfflineController(BaseAgent):
     '''A container class for a Jax policy trained offline.'''
 
     use_tensor_obs = True
 
-    def __init__(self, planner: JaxBackpropPlanner,
+    def __init__(self, planner: JaxBackpropPlanner,
                  key: Optional[random.PRNGKey]=None,
                  eval_hyperparams: Optional[Dict[str, Any]]=None,
                  params: Optional[Pytree]=None,
@@ -2199,7 +2293,7 @@ class JaxOnlineController(BaseAgent):
 
     use_tensor_obs = True
 
-    def __init__(self, planner: JaxBackpropPlanner,
+    def __init__(self, planner: JaxBackpropPlanner,
                  key: Optional[random.PRNGKey]=None,
                  eval_hyperparams: Optional[Dict[str, Any]]=None,
                  warm_start: bool=True,