pyRDDLGym-jax 0.4-py3-none-any.whl → 0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/core/logic.py +115 -53
- pyRDDLGym_jax/core/planner.py +140 -58
- pyRDDLGym_jax/core/tuning.py +53 -58
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +4 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +2 -1
- pyRDDLGym_jax/examples/run_tune.py +1 -3
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-0.5.dist-info}/METADATA +11 -9
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-0.5.dist-info}/RECORD +14 -14
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-0.5.dist-info}/WHEEL +1 -1
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-0.5.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-0.5.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -47,7 +47,6 @@ Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
 Kwargs = Dict[str, Any]
 Pytree = Any
 
-
 # ***********************************************************************
 # CONFIG FILE MANAGEMENT
 #
@@ -57,6 +56,7 @@ Pytree = Any
 #
 # ***********************************************************************
 
+
 def _parse_config_file(path: str):
     if not os.path.isfile(path):
         raise FileNotFoundError(f'File {path} does not exist.')
@@ -103,10 +103,13 @@ def _load_config(config, args):
     compare_kwargs = model_args.get('comparison_kwargs', {})
     sampling_name = model_args.get('sampling', 'GumbelSoftmax')
     sampling_kwargs = model_args.get('sampling_kwargs', {})
+    rounding_name = model_args.get('rounding', 'SoftRounding')
+    rounding_kwargs = model_args.get('rounding_kwargs', {})
     logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
     logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
     logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
     logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
+    logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
 
     # read the policy settings
     plan_method = planner_args.pop('method')
@@ -157,11 +160,18 @@
     else:
         planner_args['optimizer'] = optimizer
 
-    #
+    # optimize call RNG key
     planner_key = train_args.get('key', None)
     if planner_key is not None:
         train_args['key'] = random.PRNGKey(planner_key)
 
+    # optimize call stopping rule
+    stopping_rule = train_args.get('stopping_rule', None)
+    if stopping_rule is not None:
+        stopping_rule_kwargs = train_args.pop('stopping_rule_kwargs', {})
+        train_args['stopping_rule'] = getattr(
+            sys.modules[__name__], stopping_rule)(**stopping_rule_kwargs)
+
     return planner_args, plan_kwargs, train_args
 
 
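The hunks above extend _load_config with two pairs of optional keys: rounding/rounding_kwargs, resolved with getattr against the logic module like the existing tnorm, comparison and sampling settings, and stopping_rule/stopping_rule_kwargs, resolved by name against this module so that the stopping rule classes added further down can be referenced from a config file. A minimal sketch of a config fragment using the new keys; the section names follow the bundled example configs and the patience value is purely illustrative:

[Model]
rounding='SoftRounding'
rounding_kwargs={}

[Training]
stopping_rule='NoImprovementStoppingRule'
stopping_rule_kwargs={'patience': 200}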
@@ -175,7 +185,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
     '''Loads config file contents specified explicitly as a string value.'''
     config, args = _parse_config_string(value)
     return _load_config(config, args)
-
 
 # ***********************************************************************
 # MODEL RELAXATIONS
@@ -299,7 +308,6 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         arg, = expr.args
         arg = self._jax(arg, info)
         return arg
-
 
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
@@ -309,6 +317,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
 #
 # ***********************************************************************
 
+
 class JaxPlan:
     '''Base class for all JAX policy representations.'''
 
@@ -363,7 +372,7 @@ class JaxPlan:
         self._projection = value
 
     def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
-                               user_bounds: Bounds,
+                               user_bounds: Bounds,
                                horizon: int):
         shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
         for (name, prange) in compiled.rddl.variable_ranges.items():
@@ -463,7 +472,7 @@ class JaxStraightLinePlan(JaxPlan):
             f' max_projection_iters ={self._max_constraint_iter}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds: Bounds,
+                _bounds: Bounds,
                 horizon: int) -> None:
         rddl = compiled.rddl
 
@@ -504,7 +513,7 @@ class JaxStraightLinePlan(JaxPlan):
         def _jax_bool_action_to_param(var, action, hyperparams):
             if wrap_sigmoid:
                 weight = hyperparams[var]
-                return
+                return jax.scipy.special.logit(action) / weight
             else:
                 return action
 
@@ -513,14 +522,13 @@ class JaxStraightLinePlan(JaxPlan):
         def _jax_non_bool_param_to_action(var, param, hyperparams):
             if wrap_non_bool:
                 lower, upper = bounds_safe[var]
-
-
-
-
-
-
-
-                ]
+                mb, ml, mu, mn = [mask.astype(compiled.REAL)
+                                  for mask in cond_lists[var]]
+                action = (
+                    mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
+                    ml * (lower + (jax.nn.elu(param) + 1.0)) +
+                    mu * (upper - (jax.nn.elu(-param) + 1.0)) +
+                    mn * param
                 )
             else:
                 action = param
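The last hunk above, and the matching hunk in JaxDeepReactivePolicy further below, replaces the previous selection-based wrapping with an explicit mask-weighted sum: the four masks in cond_lists mark, per action dimension, whether its box is bounded on both sides, below only, above only, or not at all, and the trainable parameter is squashed with a sigmoid, a shifted ELU, or left untouched accordingly. The following standalone sketch (not library code) reproduces the arithmetic on a toy example; the mask construction and the finite safe_lower/safe_upper stand-ins for the library's cond_lists and bounds_safe are assumptions made only for illustration:

# Standalone sketch of the masked action wrapping introduced above.
import jax
import jax.numpy as jnp

param = jnp.array([0.3, -1.2, 2.0, 0.7])            # unconstrained trainable values
lower = jnp.array([0.0, 0.0, -jnp.inf, -jnp.inf])   # hypothetical box bounds
upper = jnp.array([1.0, jnp.inf, 5.0, jnp.inf])

mb = (jnp.isfinite(lower) & jnp.isfinite(upper)).astype(jnp.float32)    # bounded both sides
ml = (jnp.isfinite(lower) & ~jnp.isfinite(upper)).astype(jnp.float32)   # bounded below only
mu = (~jnp.isfinite(lower) & jnp.isfinite(upper)).astype(jnp.float32)   # bounded above only
mn = (~jnp.isfinite(lower) & ~jnp.isfinite(upper)).astype(jnp.float32)  # unbounded

safe_lower = jnp.where(jnp.isfinite(lower), lower, 0.0)  # avoid inf * 0 = nan in unused branches
safe_upper = jnp.where(jnp.isfinite(upper), upper, 0.0)

action = (mb * (safe_lower + (safe_upper - safe_lower) * jax.nn.sigmoid(param)) +
          ml * (safe_lower + (jax.nn.elu(param) + 1.0)) +
          mu * (safe_upper - (jax.nn.elu(-param) + 1.0)) +
          mn * param)
print(action)  # every entry lies inside its respective box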
@@ -780,7 +788,7 @@ class JaxDeepReactivePolicy(JaxPlan):
     def __init__(self, topology: Optional[Sequence[int]]=None,
                  activation: Activation=jnp.tanh,
                  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
-                 normalize: bool=False,
+                 normalize: bool=False,
                  normalize_per_layer: bool=False,
                  normalizer_kwargs: Optional[Kwargs]=None,
                  wrap_non_bool: bool=False) -> None:
@@ -828,7 +836,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             f' wrap_non_bool ={self._wrap_non_bool}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds: Bounds,
+                _bounds: Bounds,
                 horizon: int) -> None:
         rddl = compiled.rddl
 
@@ -881,7 +889,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                 if normalize_per_layer and value_size == 1:
                     raise_warning(
                         f'Cannot apply layer norm to state-fluent <{var}> '
-                        f'of size 1: setting normalize_per_layer = False.',
+                        f'of size 1: setting normalize_per_layer = False.',
                         'red')
                     normalize_per_layer = False
                 non_bool_dims += value_size
@@ -906,8 +914,8 @@ class JaxDeepReactivePolicy(JaxPlan):
                 else:
                     if normalize and normalize_per_layer:
                         normalizer = hk.LayerNorm(
-                            axis=-1, param_axis=-1,
-                            name=f'input_norm_{input_names[var]}',
+                            axis=-1, param_axis=-1,
+                            name=f'input_norm_{input_names[var]}',
                             **self._normalizer_kwargs)
                         state = normalizer(state)
                     states_non_bool.append(state)
@@ -917,7 +925,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             # optionally perform layer normalization on the non-bool inputs
             if normalize and not normalize_per_layer and non_bool_dims:
                 normalizer = hk.LayerNorm(
-                    axis=-1, param_axis=-1, name='input_norm',
+                    axis=-1, param_axis=-1, name='input_norm',
                     **self._normalizer_kwargs)
                 normalized = normalizer(state[:non_bool_dims])
                 state = state.at[:non_bool_dims].set(normalized)
@@ -950,14 +958,13 @@ class JaxDeepReactivePolicy(JaxPlan):
             else:
                 if wrap_non_bool:
                     lower, upper = bounds_safe[var]
-
-
-
-
-
-
-
-                    ]
+                    mb, ml, mu, mn = [mask.astype(compiled.REAL)
+                                      for mask in cond_lists[var]]
+                    action = (
+                        mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
+                        ml * (lower + (jax.nn.elu(output) + 1.0)) +
+                        mu * (upper - (jax.nn.elu(-output) + 1.0)) +
+                        mn * output
                     )
                 else:
                     action = output
@@ -1049,7 +1056,6 @@ class JaxDeepReactivePolicy(JaxPlan):
 
     def guess_next_epoch(self, params: Pytree) -> Pytree:
         return params
-
 
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANNER
@@ -1059,6 +1065,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 #
 # ***********************************************************************
 
+
 class RollingMean:
     '''Maintains an estimate of the rolling mean of a stream of real-valued
     observations.'''
@@ -1080,7 +1087,7 @@ class RollingMean:
 class JaxPlannerPlot:
     '''Supports plotting and visualization of a JAX policy in real time.'''
 
-    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
+    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
                  show_violin: bool=True, show_action: bool=True) -> None:
         '''Creates a new planner visualizer.
 
@@ -1128,7 +1135,7 @@ class JaxPlannerPlot:
                 for dim in rddl.object_counts(rddl.variable_params[name]):
                     action_dim *= dim
                 action_plot = ax.pcolormesh(
-                    np.zeros((action_dim, horizon)),
+                    np.zeros((action_dim, horizon)),
                     cmap='seismic', vmin=vmin, vmax=vmax)
                 ax.set_aspect('auto')
                 ax.set_xlabel('decision epoch')
@@ -1201,6 +1208,39 @@ class JaxPlannerStatus(Enum):
         return self.value >= 3
 
 
+class JaxPlannerStoppingRule:
+    '''The base class of all planner stopping rules.'''
+
+    def reset(self) -> None:
+        raise NotImplementedError
+
+    def monitor(self, callback: Dict[str, Any]) -> bool:
+        raise NotImplementedError
+
+
+class NoImprovementStoppingRule(JaxPlannerStoppingRule):
+    '''Stopping rule based on no improvement for a fixed number of iterations.'''
+
+    def __init__(self, patience: int) -> None:
+        self.patience = patience
+
+    def reset(self) -> None:
+        self.callback = None
+        self.iters_since_last_update = 0
+
+    def monitor(self, callback: Dict[str, Any]) -> bool:
+        if self.callback is None \
+        or callback['best_return'] > self.callback['best_return']:
+            self.callback = callback
+            self.iters_since_last_update = 0
+        else:
+            self.iters_since_last_update += 1
+        return self.iters_since_last_update >= self.patience
+
+    def __str__(self) -> str:
+        return f'No improvement for {self.patience} iterations'
+
+
 class JaxBackpropPlanner:
     '''A class for optimizing an action sequence in the given RDDL MDP using
     gradient descent.'''
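A small demonstration of the new rule's behaviour, driven by hand-made callback dictionaries rather than a real optimization run (the import assumes the 0.5 wheel is installed; monitor only reads the 'best_return' entry, so the toy dictionaries omit the other keys the planner normally includes):

from pyRDDLGym_jax.core.planner import NoImprovementStoppingRule

rule = NoImprovementStoppingRule(patience=3)
rule.reset()
best_returns = [1.0, 2.0, 2.0, 2.0, 2.0]   # improvement stops after the second step
for it, best in enumerate(best_returns):
    stop = rule.monitor({'best_return': best})
    print(it, best, stop)
# prints False for iterations 0-3 and True at iteration 4,
# i.e. after three consecutive callbacks without an improvement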
@@ -1215,6 +1255,8 @@ class JaxBackpropPlanner:
                  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
                  optimizer_kwargs: Optional[Kwargs]=None,
                  clip_grad: Optional[float]=None,
+                 noise_grad_eta: float=0.0,
+                 noise_grad_gamma: float=1.0,
                  logic: FuzzyLogic=FuzzyLogic(),
                  use_symlog_reward: bool=False,
                  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1241,6 +1283,8 @@ class JaxBackpropPlanner:
         :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
         factory (e.g. which parameters are controllable externally)
         :param clip_grad: maximum magnitude of gradient updates
+        :param noise_grad_eta: scale of the gradient noise variance
+        :param noise_grad_gamma: decay rate of the gradient noise variance
         :param logic: a subclass of FuzzyLogic for mapping exact mathematical
         operations to their differentiable counterparts
         :param use_symlog_reward: whether to use the symlog transform on the
@@ -1275,6 +1319,8 @@ class JaxBackpropPlanner:
             optimizer_kwargs = {'learning_rate': 0.1}
         self._optimizer_kwargs = optimizer_kwargs
         self.clip_grad = clip_grad
+        self.noise_grad_eta = noise_grad_eta
+        self.noise_grad_gamma = noise_grad_gamma
 
         # set optimizer
         try:
@@ -1340,14 +1386,14 @@ class JaxBackpropPlanner:
         except Exception as _:
             devices_short = 'N/A'
         LOGO = \
+r"""
+ __ ______ __ __ ______ __ ______ __ __
+/\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
+_\_\ \\ \ __ \\/_/\_\/_\ \ _-/\ \ \____\ \ __ \\ \ \-. \
+/\_____\\ \_\ \_\ /\_\/\_\\ \_\ \ \_____\\ \_\ \_\\ \_\\"\_\
+\/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
 """
-
-/\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
-_\_\ \ \ \ __ \ \/_/\_\/_ \ \ _-/\ \ \____ \ \ __ \ \ \ \-. \
-/\_____\ \ \_\ \_\ /\_\/\_\ \ \_\ \ \_____\ \ \_\ \_\ \ \_\\"\_\
-\/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
-"""
-
+
         print('\n'
               f'{LOGO}\n'
               f'Version {__version__}\n'
@@ -1372,6 +1418,8 @@ class JaxBackpropPlanner:
               f' optimizer ={self._optimizer_name.__name__}\n'
              f' optimizer args ={self._optimizer_kwargs}\n'
               f' clip_gradient ={self.clip_grad}\n'
+              f' noise_grad_eta ={self.noise_grad_eta}\n'
+              f' noise_grad_gamma ={self.noise_grad_gamma}\n'
               f' batch_size_train ={self.batch_size_train}\n'
               f' batch_size_test ={self.batch_size_test}')
         self.plan.summarize_hyperparameters()
@@ -1396,7 +1444,7 @@ class JaxBackpropPlanner:
 
         # Jax compilation of the exact RDDL for testing
         self.test_compiled = JaxRDDLCompiler(
-            rddl=rddl,
+            rddl=rddl,
             logger=self.logger,
             use64bit=self.use64bit)
         self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
@@ -1473,7 +1521,7 @@ class JaxBackpropPlanner:
         def _jax_wrapped_init_policy(key, hyperparams, subs):
             policy_params = init(key, hyperparams, subs)
             opt_state = optimizer.init(policy_params)
-            return policy_params, opt_state,
+            return policy_params, opt_state, {}
 
         return _jax_wrapped_init_policy
 
@@ -1481,6 +1529,19 @@ class JaxBackpropPlanner:
         optimizer = self.optimizer
         projection = self.plan.projection
 
+        # add Gaussian gradient noise per Neelakantan et al., 2016.
+        def _jax_wrapped_gaussian_param_noise(key, grads, sigma):
+            treedef = jax.tree_util.tree_structure(grads)
+            keys_flat = random.split(key, num=treedef.num_leaves)
+            keys_tree = jax.tree_util.tree_unflatten(treedef, keys_flat)
+            new_grads = jax.tree_map(
+                lambda g, k: g + sigma * random.normal(
+                    key=k, shape=g.shape, dtype=g.dtype),
+                grads,
+                keys_tree
+            )
+            return new_grads
+
         # calculate the plan gradient w.r.t. return loss and update optimizer
         # also perform a projection step to satisfy constraints on actions
         def _jax_wrapped_plan_update(key, policy_params, hyperparams,
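For reference, the per-leaf noise pattern above (one fresh PRNG key per gradient leaf, noise drawn with the leaf's own shape and dtype) can be reproduced outside the planner. A minimal standalone sketch on a toy gradient pytree, spelled with jax.tree_util.tree_map:

import jax
from jax import random
import jax.numpy as jnp

def add_gaussian_noise(key, grads, sigma):
    # one key per leaf so that different leaves receive independent noise
    treedef = jax.tree_util.tree_structure(grads)
    keys = jax.tree_util.tree_unflatten(treedef, random.split(key, treedef.num_leaves))
    return jax.tree_util.tree_map(
        lambda g, k: g + sigma * random.normal(k, g.shape, g.dtype), grads, keys)

grads = {'b': jnp.zeros(3), 'w': jnp.zeros((2, 3))}    # toy gradient pytree
noisy = add_gaussian_noise(random.PRNGKey(0), grads, sigma=0.1)
print(jax.tree_util.tree_map(jnp.shape, noisy))        # structure and shapes preserved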
@@ -1488,12 +1549,14 @@ class JaxBackpropPlanner:
             grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
             (loss_val, log), grad = grad_fn(
                 key, policy_params, hyperparams, subs, model_params)
+            sigma = opt_aux.get('noise_sigma', 0.0)
+            grad = _jax_wrapped_gaussian_param_noise(key, grad, sigma)
             updates, opt_state = optimizer.update(grad, opt_state)
             policy_params = optax.apply_updates(policy_params, updates)
             policy_params, converged = projection(policy_params, hyperparams)
             log['grad'] = grad
             log['updates'] = updates
-            return policy_params, converged, opt_state,
+            return policy_params, converged, opt_state, opt_aux, loss_val, log
 
         return jax.jit(_jax_wrapped_plan_update)
 
@@ -1524,7 +1587,7 @@ class JaxBackpropPlanner:
         return init_train, init_test
 
     def as_optimization_problem(
-            self, key: Optional[random.PRNGKey]=None,
+            self, key: Optional[random.PRNGKey]=None,
             policy_hyperparams: Optional[Pytree]=None,
             loss_function_updates_key: bool=True,
             grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
@@ -1576,7 +1639,7 @@ class JaxBackpropPlanner:
         @jax.jit
         def _loss_with_key(key, params_1d):
             policy_params = unravel_fn(params_1d)
-            loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
+            loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
                                   train_subs, model_params)
             return loss_val
 
@@ -1584,7 +1647,7 @@ class JaxBackpropPlanner:
         def _grad_with_key(key, params_1d):
             policy_params = unravel_fn(params_1d)
             grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
-            grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
+            grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
                                   train_subs, model_params)
             grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
             return grad_1d
@@ -1633,6 +1696,7 @@ class JaxBackpropPlanner:
         :param print_summary: whether to print planner header, parameter
         summary, and diagnosis
         :param print_progress: whether to print the progress bar during training
+        :param stopping_rule: stopping criterion
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
@@ -1658,13 +1722,14 @@ class JaxBackpropPlanner:
                 epochs: int=999999,
                 train_seconds: float=120.,
                 plot_step: Optional[int]=None,
-                plot_kwargs: Optional[
+                plot_kwargs: Optional[Kwargs]=None,
                 model_params: Optional[Dict[str, Any]]=None,
                 policy_hyperparams: Optional[Dict[str, Any]]=None,
                 subs: Optional[Dict[str, Any]]=None,
                 guess: Optional[Pytree]=None,
                 print_summary: bool=True,
                 print_progress: bool=True,
+                stopping_rule: Optional[JaxPlannerStoppingRule]=None,
                 test_rolling_window: int=10,
                 tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
         '''Returns a generator for computing an optimal policy or plan.
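Because the new stopping_rule argument accepts any JaxPlannerStoppingRule, a custom criterion only needs to implement the reset/monitor contract shown earlier. A hypothetical example (not part of the package) that halts once the best return reaches a target value; the class name and target threshold are inventions for illustration, while the 'best_return' key is part of the callback dictionary the planner yields:

from typing import Any, Dict

from pyRDDLGym_jax.core.planner import JaxPlannerStoppingRule


class TargetReturnStoppingRule(JaxPlannerStoppingRule):
    '''Stops as soon as the best return reaches a user-chosen target.'''

    def __init__(self, target: float) -> None:
        self.target = target

    def reset(self) -> None:
        pass  # stateless between optimize calls

    def monitor(self, callback: Dict[str, Any]) -> bool:
        # the planner calls monitor() with the same dictionary it yields,
        # so the running best return is available here
        return callback['best_return'] >= self.target

    def __str__(self) -> str:
        return f'Best return reached {self.target}'

A rule defined outside planner.py would be passed directly through the stopping_rule keyword; config files can only name rules defined in that module, since _load_config resolves the name against sys.modules[__name__].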
@@ -1686,6 +1751,7 @@ class JaxBackpropPlanner:
         :param print_summary: whether to print planner header, parameter
         summary, and diagnosis
         :param print_progress: whether to print the progress bar during training
+        :param stopping_rule: stopping criterion
         :param test_rolling_window: the test return is averaged on a rolling
         window of the past test_rolling_window returns when updating the best
         parameters found so far
@@ -1737,10 +1803,11 @@ class JaxBackpropPlanner:
                   f' plot_frequency ={plot_step}\n'
                   f' plot_kwargs ={plot_kwargs}\n'
                   f' print_summary ={print_summary}\n'
-                  f' print_progress ={print_progress}\n'
+                  f' print_progress ={print_progress}\n'
+                  f' stopping_rule ={stopping_rule}\n')
             if self.compiled.relaxations:
                 print('Some RDDL operations are non-differentiable, '
-                      '
+                      'they will be approximated as follows:')
                 print(self.compiled.summarize_model_relaxations())
 
         # compute a batched version of the initial values
@@ -1773,7 +1840,7 @@ class JaxBackpropPlanner:
         else:
             policy_params = guess
         opt_state = self.optimizer.init(policy_params)
-        opt_aux =
+        opt_aux = {}
 
         # initialize running statistics
         best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
@@ -1783,6 +1850,10 @@ class JaxBackpropPlanner:
         status = JaxPlannerStatus.NORMAL
         is_all_zero_fn = lambda x: np.allclose(x, 0)
 
+        # initialize stopping criterion
+        if stopping_rule is not None:
+            stopping_rule.reset()
+
         # initialize plot area
         if plot_step is None or plot_step <= 0 or plt is None:
             plot = None
@@ -1801,6 +1872,11 @@ class JaxBackpropPlanner:
         for it in iters:
             status = JaxPlannerStatus.NORMAL
 
+            # gradient noise schedule
+            noise_var = self.noise_grad_eta / (1. + it) ** self.noise_grad_gamma
+            noise_sigma = np.sqrt(noise_var)
+            opt_aux['noise_sigma'] = noise_sigma
+
             # update the parameters of the plan
             key, subkey = random.split(key)
             policy_params, converged, opt_state, opt_aux, \
@@ -1865,8 +1941,7 @@ class JaxBackpropPlanner:
                 status = JaxPlannerStatus.ITER_BUDGET_REACHED
 
             # return a callback
-
-            yield {
+            callback = {
                 'status': status,
                 'iteration': it,
                 'train_return':-train_loss,
@@ -1877,16 +1952,23 @@ class JaxBackpropPlanner:
                 'last_iteration_improved': last_iter_improve,
                 'grad': train_log['grad'],
                 'best_grad': best_grad,
+                'noise_sigma': noise_sigma,
                 'updates': train_log['updates'],
                 'elapsed_time': elapsed,
                 'key': key,
                 **log
             }
+            start_time_outside = time.time()
+            yield callback
             elapsed_outside_loop += (time.time() - start_time_outside)
 
             # abortion check
             if status.is_failure():
                 break
+
+            # stopping condition reached
+            if stopping_rule is not None and stopping_rule.monitor(callback):
+                break
 
         # release resources
         if print_progress:
@@ -1918,7 +2000,7 @@ class JaxBackpropPlanner:
               f' best_grad_norm={grad_norm}\n'
               f' diagnosis: {diagnosis}\n')
 
-    def _perform_diagnosis(self, last_iter_improve,
+    def _perform_diagnosis(self, last_iter_improve,
                            train_return, test_return, best_return, grad_norm):
         max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
         grad_is_zero = np.allclose(max_grad_norm, 0)
@@ -2097,7 +2179,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
                 trials += 1
                 step *= decay
                 f_step, new_params, new_state = _jax_wrapped_line_search_trial(
-                    step, grad, key, policy_params, hyperparams, subs,
+                    step, grad, key, policy_params, hyperparams, subs,
                     model_params, opt_state)
                 if f_step < best_f:
                     best_f, best_step, best_params, best_state = \
@@ -2106,11 +2188,11 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
             log['updates'] = None
             log['line_search_iters'] = trials
             log['learning_rate'] = best_step
-
+            opt_aux['best_step'] = best_step
+            return best_params, True, best_state, opt_aux, best_f, log
 
         return _jax_wrapped_plan_update
 
-
 # ***********************************************************************
 # ALL VERSIONS OF RISK FUNCTIONS
 #
@@ -2141,7 +2223,6 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
     alpha_mask = jax.lax.stop_gradient(
         returns <= jnp.percentile(returns, q=100 * alpha))
     return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
-
 
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
@@ -2151,12 +2232,13 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
 #
 # ***********************************************************************
 
+
 class JaxOfflineController(BaseAgent):
     '''A container class for a Jax policy trained offline.'''
 
     use_tensor_obs = True
 
-    def __init__(self, planner: JaxBackpropPlanner,
+    def __init__(self, planner: JaxBackpropPlanner,
                  key: Optional[random.PRNGKey]=None,
                  eval_hyperparams: Optional[Dict[str, Any]]=None,
                  params: Optional[Pytree]=None,
@@ -2211,7 +2293,7 @@ class JaxOnlineController(BaseAgent):
 
     use_tensor_obs = True
 
-    def __init__(self, planner: JaxBackpropPlanner,
+    def __init__(self, planner: JaxBackpropPlanner,
                  key: Optional[random.PRNGKey]=None,
                  eval_hyperparams: Optional[Dict[str, Any]]=None,
                  warm_start: bool=True,