pyRDDLGym-jax 2.0__py3-none-any.whl → 2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +85 -190
- pyRDDLGym_jax/core/logic.py +313 -56
- pyRDDLGym_jax/core/planner.py +274 -200
- pyRDDLGym_jax/core/visualization.py +7 -8
- pyRDDLGym_jax/examples/run_tune.py +10 -6
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/METADATA +43 -30
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/RECORD +12 -12
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/WHEEL +1 -1
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/entry_points.txt +0 -0
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -47,7 +47,9 @@ import jax.random as random
 import numpy as np
 import optax
 import termcolor
-from tqdm import tqdm
+from tqdm import tqdm, TqdmWarning
+import warnings
+warnings.filterwarnings("ignore", category=TqdmWarning)

 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
 from pyRDDLGym.core.debug.logger import Logger

@@ -69,8 +71,7 @@ try:
 from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
 except Exception:
 raise_warning('Failed to load the dashboard visualization tool: '
-'please make sure you have installed the required packages.',
-'red')
+'please make sure you have installed the required packages.', 'red')
 traceback.print_exc()
 JaxPlannerDashboard = None

@@ -133,7 +134,7 @@ def _load_config(config, args):
 comp_kwargs = model_args.get('complement_kwargs', {})
 compare_name = model_args.get('comparison', 'SigmoidComparison')
 compare_kwargs = model_args.get('comparison_kwargs', {})
-sampling_name = model_args.get('sampling', '
+sampling_name = model_args.get('sampling', 'SoftRandomSampling')
 sampling_kwargs = model_args.get('sampling_kwargs', {})
 rounding_name = model_args.get('rounding', 'SoftRounding')
 rounding_kwargs = model_args.get('rounding_kwargs', {})

@@ -156,8 +157,7 @@ def _load_config(config, args):
 initializer = _getattr_any(
 packages=[initializers, hk.initializers], item=plan_initializer)
 if initializer is None:
-raise_warning(
-f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
+raise_warning(f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
 del plan_kwargs['initializer']
 else:
 init_kwargs = plan_kwargs.pop('initializer_kwargs', {})

@@ -174,8 +174,7 @@ def _load_config(config, args):
 activation = _getattr_any(
 packages=[jax.nn, jax.numpy], item=plan_activation)
 if activation is None:
-raise_warning(
-f'Ignoring invalid activation <{plan_activation}>.', 'red')
+raise_warning(f'Ignoring invalid activation <{plan_activation}>.', 'red')
 del plan_kwargs['activation']
 else:
 plan_kwargs['activation'] = activation

@@ -189,8 +188,7 @@ def _load_config(config, args):
 if planner_optimizer is not None:
 optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
 if optimizer is None:
-raise_warning(
-f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
+raise_warning(f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
 del planner_args['optimizer']
 else:
 planner_args['optimizer'] = optimizer
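In each of the blocks above, `_getattr_any` resolves a name from the config file (an initializer, activation, or optax optimizer) by searching a list of candidate modules, and the loader warns and falls back when the name is not found. A hypothetical re-implementation of that lookup, just to illustrate the pattern (the real helper in planner.py may differ):

    import optax
    import jax

    def getattr_any(packages, item):
        # return the first attribute named `item` found in the listed modules, else None
        for package in packages:
            fn = getattr(package, item, None)
            if fn is not None:
                return fn
        return None

    print(getattr_any([optax], 'adam'))               # resolves to optax.adam
    print(getattr_any([jax.nn, jax.numpy], 'relu'))   # resolves to jax.nn.relu
    print(getattr_any([optax], 'no_such_optimizer'))  # None -> the loader warns and ignores it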
@@ -285,48 +283,14 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
 pvars_cast = set()
 for (var, values) in self.init_values.items():
 self.init_values[var] = np.asarray(values, dtype=self.REAL)
-if not np.issubdtype(np.
+if not np.issubdtype(np.result_type(values), np.floating):
 pvars_cast.add(var)
 if pvars_cast:
 raise_warning(f'JAX gradient compiler requires that initial values '
 f'of p-variables {pvars_cast} be cast to float.')

 # overwrite basic operations with fuzzy ones
-self.
-'>=': logic.greater_equal,
-'<=': logic.less_equal,
-'<': logic.less,
-'>': logic.greater,
-'==': logic.equal,
-'~=': logic.not_equal
-}
-self.LOGICAL_NOT = logic.logical_not
-self.LOGICAL_OPS = {
-'^': logic.logical_and,
-'&': logic.logical_and,
-'|': logic.logical_or,
-'~': logic.xor,
-'=>': logic.implies,
-'<=>': logic.equiv
-}
-self.AGGREGATION_OPS['forall'] = logic.forall
-self.AGGREGATION_OPS['exists'] = logic.exists
-self.AGGREGATION_OPS['argmin'] = logic.argmin
-self.AGGREGATION_OPS['argmax'] = logic.argmax
-self.KNOWN_UNARY['sgn'] = logic.sgn
-self.KNOWN_UNARY['floor'] = logic.floor
-self.KNOWN_UNARY['ceil'] = logic.ceil
-self.KNOWN_UNARY['round'] = logic.round
-self.KNOWN_UNARY['sqrt'] = logic.sqrt
-self.KNOWN_BINARY['div'] = logic.div
-self.KNOWN_BINARY['mod'] = logic.mod
-self.KNOWN_BINARY['fmod'] = logic.mod
-self.IF_HELPER = logic.control_if
-self.SWITCH_HELPER = logic.control_switch
-self.BERNOULLI_HELPER = logic.bernoulli
-self.DISCRETE_HELPER = logic.discrete
-self.POISSON_HELPER = logic.poisson
-self.GEOMETRIC_HELPER = logic.geometric
+self.OPS = logic.get_operator_dicts()

 def _jax_stop_grad(self, jax_expr):
 def _jax_wrapped_stop_grad(x, params, key):
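The long block of per-operator assignments removed above is replaced by a single call, `logic.get_operator_dicts()`, so the fuzzy-logic module owns the complete mapping from RDDL operators to their differentiable relaxations (its definition lives in core/logic.py, which also changed in this release). The structure it returns is not shown in this diff; as a purely hypothetical sketch of the idea, a sigmoid-style relaxation of the relational operators could be bundled like this:

    import jax.numpy as jnp

    def make_relational_ops(weight=10.0):
        # soft, differentiable stand-ins for exact comparisons; weight controls sharpness
        soft = lambda x: 1.0 / (1.0 + jnp.exp(-weight * x))
        return {
            '>=': lambda a, b: soft(a - b),
            '<=': lambda a, b: soft(b - a),
            '>':  lambda a, b: soft(a - b),
            '<':  lambda a, b: soft(b - a),
            '==': lambda a, b: 1.0 - jnp.tanh(weight * jnp.abs(a - b)),
            '~=': lambda a, b: jnp.tanh(weight * jnp.abs(a - b)),
        }

    ops = make_relational_ops()
    print(ops['>='](jnp.array(2.0), jnp.array(1.0)))  # close to 1: the comparison "holds"
    print(ops['>='](jnp.array(1.0), jnp.array(2.0)))  # close to 0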
@@ -575,7 +539,7 @@ class JaxStraightLinePlan(JaxPlan):
 def _jax_non_bool_param_to_action(var, param, hyperparams):
 if wrap_non_bool:
 lower, upper = bounds_safe[var]
-mb, ml, mu, mn = [
+mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
 for mask in cond_lists[var]]
 action = (
 mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
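Only the case masks change here: they are now materialized as arrays of the compiler's real dtype before being multiplied into the action. The surrounding logic is the usual bounded-sigmoid reparameterization that maps an unconstrained trainable parameter into a box [lower, upper]. A standalone sketch of that map and its inverse (illustrative, not the planner's exact code):

    import jax
    import jax.numpy as jnp

    def param_to_action(param, lower, upper):
        # squash an unconstrained parameter into [lower, upper]
        return lower + (upper - lower) * jax.nn.sigmoid(param)

    def action_to_param(action, lower, upper, eps=1e-6):
        # inverse map (logit), useful for initializing the parameter from a desired action
        ratio = jnp.clip((action - lower) / (upper - lower), eps, 1.0 - eps)
        return jnp.log(ratio) - jnp.log1p(-ratio)

    a = param_to_action(jnp.array(0.3), -2.0, 5.0)
    print(a, param_to_action(action_to_param(a, -2.0, 5.0), -2.0, 5.0))  # round trip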
@@ -660,7 +624,7 @@ class JaxStraightLinePlan(JaxPlan):
 action = _jax_non_bool_param_to_action(var, action, hyperparams)
 action = jnp.clip(action, *bounds[var])
 if ranges[var] == 'int':
-action = jnp.round(action)
+action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
 actions[var] = action
 return actions

@@ -961,12 +925,11 @@ class JaxDeepReactivePolicy(JaxPlan):
 non_bool_dims = 0
 for (var, values) in observed_vars.items():
 if ranges[var] != 'bool':
-value_size = np.
+value_size = np.size(values)
 if normalize_per_layer and value_size == 1:
 raise_warning(
 f'Cannot apply layer norm to state-fluent <{var}> '
-f'of size 1: setting normalize_per_layer = False.',
-'red')
+f'of size 1: setting normalize_per_layer = False.', 'red')
 normalize_per_layer = False
 non_bool_dims += value_size
 if not normalize_per_layer and non_bool_dims == 1:

@@ -990,9 +953,11 @@ class JaxDeepReactivePolicy(JaxPlan):
 else:
 if normalize and normalize_per_layer:
 normalizer = hk.LayerNorm(
-axis=-1,
+axis=-1,
+param_axis=-1,
 name=f'input_norm_{input_names[var]}',
-**self._normalizer_kwargs
+**self._normalizer_kwargs
+)
 state = normalizer(state)
 states_non_bool.append(state)
 non_bool_dims += state.size

@@ -1001,8 +966,11 @@ class JaxDeepReactivePolicy(JaxPlan):
 # optionally perform layer normalization on the non-bool inputs
 if normalize and not normalize_per_layer and non_bool_dims:
 normalizer = hk.LayerNorm(
-axis=-1,
-
+axis=-1,
+param_axis=-1,
+name='input_norm',
+**self._normalizer_kwargs
+)
 normalized = normalizer(state[:non_bool_dims])
 state = state.at[:non_bool_dims].set(normalized)
 return state
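Both LayerNorm call sites gain `param_axis=-1`, making explicit that the learned scale and offset live on the last (feature) axis, while the remaining keyword arguments still come from `self._normalizer_kwargs`. A self-contained sketch of the Haiku pattern (the `create_scale`/`create_offset` flags here are generic Haiku arguments, assumed rather than taken from this diff):

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def forward(x):
        norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True,
                            param_axis=-1, name='input_norm')
        return norm(x)

    model = hk.transform(forward)
    x = jnp.arange(12.0).reshape(3, 4)
    params = model.init(jax.random.PRNGKey(0), x)
    y = model.apply(params, None, x)
    print(y.mean(axis=-1), y.std(axis=-1))   # roughly 0 mean and unit std per row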
@@ -1021,7 +989,8 @@ class JaxDeepReactivePolicy(JaxPlan):
 actions = {}
 for (var, size) in layer_sizes.items():
 linear = hk.Linear(size, name=layer_names[var], w_init=init)
-reshape = hk.Reshape(output_shape=shapes[var],
+reshape = hk.Reshape(output_shape=shapes[var],
+preserve_dims=-1,
 name=f'reshape_{layer_names[var]}')
 output = reshape(linear(hidden))
 if not shapes[var]:

@@ -1034,7 +1003,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 else:
 if wrap_non_bool:
 lower, upper = bounds_safe[var]
-mb, ml, mu, mn = [
+mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
 for mask in cond_lists[var]]
 action = (
 mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +

@@ -1048,8 +1017,7 @@ class JaxDeepReactivePolicy(JaxPlan):

 # for constraint satisfaction wrap bool actions with softmax
 if use_constraint_satisfaction:
-linear = hk.Linear(
-bool_action_count, name='output_bool', w_init=init)
+linear = hk.Linear(bool_action_count, name='output_bool', w_init=init)
 output = jax.nn.softmax(linear(hidden))
 actions[bool_key] = output

@@ -1087,8 +1055,7 @@ class JaxDeepReactivePolicy(JaxPlan):

 # test action prediction
 def _jax_wrapped_drp_predict_test(key, params, hyperparams, step, subs):
-actions = _jax_wrapped_drp_predict_train(
-key, params, hyperparams, step, subs)
+actions = _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs)
 new_actions = {}
 for (var, action) in actions.items():
 prange = ranges[var]

@@ -1096,7 +1063,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 new_action = action > 0.5
 elif prange == 'int':
 action = jnp.clip(action, *bounds[var])
-new_action = jnp.round(action)
+new_action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
 else:
 new_action = jnp.clip(action, *bounds[var])
 new_actions[var] = new_action

@@ -1247,17 +1214,22 @@ class GaussianPGPE(PGPE):
 init_sigma: float=1.0,
 sigma_range: Tuple[float, float]=(1e-5, 1e5),
 scale_reward: bool=True,
+min_reward_scale: float=1e-5,
 super_symmetric: bool=True,
 super_symmetric_accurate: bool=True,
 optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
 optimizer_kwargs_mu: Optional[Kwargs]=None,
-optimizer_kwargs_sigma: Optional[Kwargs]=None
+optimizer_kwargs_sigma: Optional[Kwargs]=None,
+start_entropy_coeff: float=1e-3,
+end_entropy_coeff: float=1e-8,
+max_kl_update: Optional[float]=None) -> None:
 '''Creates a new Gaussian PGPE planner.

 :param batch_size: how many policy parameters to sample per optimization step
 :param init_sigma: initial standard deviation of Gaussian
 :param sigma_range: bounds to constrain standard deviation
 :param scale_reward: whether to apply reward scaling as in the paper
+:param min_reward_scale: minimum reward scaling to avoid underflow
 :param super_symmetric: whether to use super-symmetric sampling as in the paper
 :param super_symmetric_accurate: whether to use the accurate formula for super-
 symmetric sampling or the simplified but biased formula

@@ -1266,6 +1238,9 @@ class GaussianPGPE(PGPE):
 factory for the mean optimizer
 :param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
 factory for the standard deviation optimizer
+:param start_entropy_coeff: starting entropy regularization coeffient for Gaussian
+:param end_entropy_coeff: ending entropy regularization coeffient for Gaussian
+:param max_kl_update: bound on kl-divergence between parameter updates
 '''
 super().__init__()

@@ -1273,8 +1248,13 @@ class GaussianPGPE(PGPE):
 self.init_sigma = init_sigma
 self.sigma_range = sigma_range
 self.scale_reward = scale_reward
+self.min_reward_scale = min_reward_scale
 self.super_symmetric = super_symmetric
 self.super_symmetric_accurate = super_symmetric_accurate
+
+# entropy regularization penalty is decayed exponentially between these values
+self.start_entropy_coeff = start_entropy_coeff
+self.end_entropy_coeff = end_entropy_coeff

 # set optimizers
 if optimizer_kwargs_mu is None:

@@ -1284,36 +1264,62 @@ class GaussianPGPE(PGPE):
 optimizer_kwargs_sigma = {'learning_rate': 0.1}
 self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
 self.optimizer_name = optimizer
-
-
+try:
+mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
+sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
+except Exception as _:
+raise_warning(
+f'Failed to inject hyperparameters into optax optimizer for PGPE, '
+'rolling back to safer method: please note that kl-divergence '
+'constraints will be disabled.', 'red')
+mu_optimizer = optimizer(**optimizer_kwargs_mu)
+sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
+max_kl_update = None
 self.optimizers = (mu_optimizer, sigma_optimizer)
+self.max_kl = max_kl_update

 def __str__(self) -> str:
 return (f'PGPE hyper-parameters:\n'
-f' method
-f' batch_size
-f' init_sigma
-f' sigma_range
-f' scale_reward
-f'
-f'
-f'
+f' method ={self.__class__.__name__}\n'
+f' batch_size ={self.batch_size}\n'
+f' init_sigma ={self.init_sigma}\n'
+f' sigma_range ={self.sigma_range}\n'
+f' scale_reward ={self.scale_reward}\n'
+f' min_reward_scale ={self.min_reward_scale}\n'
+f' super_symmetric ={self.super_symmetric}\n'
+f' accurate ={self.super_symmetric_accurate}\n'
+f' optimizer ={self.optimizer_name}\n'
 f' optimizer_kwargs:\n'
 f' mu ={self.optimizer_kwargs_mu}\n'
 f' sigma={self.optimizer_kwargs_sigma}\n'
+f' start_entropy_coeff={self.start_entropy_coeff}\n'
+f' end_entropy_coeff ={self.end_entropy_coeff}\n'
+f' max_kl_update ={self.max_kl}\n'
 )

 def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
-MIN_NORM = 1e-5
 sigma0 = self.init_sigma
 sigma_range = self.sigma_range
 scale_reward = self.scale_reward
+min_reward_scale = self.min_reward_scale
 super_symmetric = self.super_symmetric
 super_symmetric_accurate = self.super_symmetric_accurate
 batch_size = self.batch_size
 optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
-
-
+max_kl = self.max_kl
+
+# entropy regularization penalty is decayed exponentially by elapsed budget
+start_entropy_coeff = self.start_entropy_coeff
+if start_entropy_coeff == 0:
+entropy_coeff_decay = 0
+else:
+entropy_coeff_decay = (self.end_entropy_coeff / start_entropy_coeff) ** 0.01
+
+# ***********************************************************************
+# INITIALIZATION OF POLICY
+#
+# ***********************************************************************
+
 def _jax_wrapped_pgpe_init(key, policy_params):
 mu = policy_params
 sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
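The decay factor is chosen so that the coefficient sweeps geometrically from `start_entropy_coeff` to `end_entropy_coeff` over the training budget: the planner's progress signal runs from 0 to 100, and ((end / start) ** 0.01) ** 100 equals end / start. A quick standalone check of the schedule with the defaults shown above:

    import jax.numpy as jnp

    start, end = 1e-3, 1e-8
    decay = (end / start) ** 0.01            # per-percent decay factor
    coeff = lambda progress: start * jnp.power(decay, progress)
    print(coeff(0.0))      # 1e-3 at the start of training
    print(coeff(50.0))     # geometric midpoint, about 3.16e-6
    print(coeff(100.0))    # about 1e-8 once the budget is exhausted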
@@ -1324,7 +1330,11 @@ class GaussianPGPE(PGPE):

 self._initializer = jax.jit(_jax_wrapped_pgpe_init)

-#
+# ***********************************************************************
+# PARAMETER SAMPLING FUNCTIONS
+#
+# ***********************************************************************
+
 def _jax_wrapped_mu_noise(key, sigma):
 return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)

@@ -1334,19 +1344,20 @@ class GaussianPGPE(PGPE):
 a = (sigma - jnp.abs(epsilon)) / sigma
 if super_symmetric_accurate:
 aa = jnp.abs(a)
+aa3 = jnp.power(aa, 3)
 epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
 a <= 0,
-jnp.exp(c1 *
-jnp.exp(aa - c3 * aa * jnp.log(1.0 -
+jnp.exp(c1 * (aa3 - aa) / jnp.log(aa + 1e-10) + c2 * aa),
+jnp.exp(aa - c3 * aa * jnp.log(1.0 - aa3 + 1e-10))
 )
 else:
 epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
 return epsilon_star

 def _jax_wrapped_sample_params(key, mu, sigma):
-
-
-
+treedef = jax.tree_util.tree_structure(sigma)
+keys = random.split(key, num=treedef.num_leaves)
+keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
 epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
 p1 = jax.tree_map(jnp.add, mu, epsilon)
 p2 = jax.tree_map(jnp.subtract, mu, epsilon)
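Sampling one Gaussian perturbation per parameter tensor needs an independent PRNG key per pytree leaf, which is what the three new lines in `_jax_wrapped_sample_params` construct. A standalone sketch of the same pattern:

    import jax
    import jax.numpy as jnp

    sigma = {'w': jnp.full((2, 3), 0.5), 'b': jnp.full((3,), 0.5)}

    treedef = jax.tree_util.tree_structure(sigma)
    keys = jax.random.split(jax.random.PRNGKey(0), num=treedef.num_leaves)
    keys_pytree = jax.tree_util.tree_unflatten(treedef, list(keys))

    # one independent Gaussian noise sample per leaf, scaled by that leaf's sigma
    noise = jax.tree_util.tree_map(
        lambda k, s: s * jax.random.normal(k, jnp.shape(s)), keys_pytree, sigma)
    print(jax.tree_util.tree_map(jnp.shape, noise))   # {'b': (3,), 'w': (2, 3)}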
@@ -1356,14 +1367,18 @@ class GaussianPGPE(PGPE):
 p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
 else:
 epsilon_star, p3, p4 = epsilon, p1, p2
-return
+return p1, p2, p3, p4, epsilon, epsilon_star

-#
+# ***********************************************************************
+# POLICY GRADIENT CALCULATION
+#
+# ***********************************************************************
+
 def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
 if super_symmetric:
 if scale_reward:
-scale1 = jnp.maximum(
-scale2 = jnp.maximum(
+scale1 = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
+scale2 = jnp.maximum(min_reward_scale, m - (r3 + r4) / 2)
 else:
 scale1 = scale2 = 1.0
 r_mu1 = (r1 - r2) / (2 * scale1)

@@ -1371,37 +1386,37 @@ class GaussianPGPE(PGPE):
 grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
 else:
 if scale_reward:
-scale = jnp.maximum(
+scale = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
 else:
 scale = 1.0
 r_mu = (r1 - r2) / (2 * scale)
 grad = -r_mu * epsilon
 return grad

-def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
+def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
 if super_symmetric:
 mask = r1 + r2 >= r3 + r4
 epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
-s = epsilon_tau
+s = jnp.square(epsilon_tau) / sigma - sigma
 if scale_reward:
-scale = jnp.maximum(
+scale = jnp.maximum(min_reward_scale, m - (r1 + r2 + r3 + r4) / 4)
 else:
 scale = 1.0
 r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
 else:
-s = epsilon
+s = jnp.square(epsilon) / sigma - sigma
 if scale_reward:
-scale = jnp.maximum(
+scale = jnp.maximum(min_reward_scale, jnp.abs(m))
 else:
 scale = 1.0
 r_sigma = (r1 + r2) / (2 * scale)
-grad = -r_sigma * s
+grad = -(r_sigma * s + ent / sigma)
 return grad

-def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
+def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
 policy_hyperparams, subs, model_params):
 key, subkey = random.split(key)
-
+p1, p2, p3, p4, epsilon, epsilon_star = _jax_wrapped_sample_params(
 key, mu, sigma)
 r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
 r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
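The new `ent / sigma` term is an entropy bonus: the entropy of a univariate Gaussian is 0.5 * log(2 * pi * e * sigma^2), whose derivative with respect to sigma is 1 / sigma, so including it in the negated gradient pushes sigma upward while the decayed coefficient `ent` is still large. A two-line check with jax.grad:

    import jax
    import jax.numpy as jnp

    gaussian_entropy = lambda sigma: 0.5 * jnp.log(2 * jnp.pi * jnp.e * sigma ** 2)
    # d entropy / d sigma == 1 / sigma
    print(jax.grad(gaussian_entropy)(2.0), 1.0 / 2.0)   # both 0.5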
@@ -1419,42 +1434,76 @@ class GaussianPGPE(PGPE):
 epsilon, epsilon_star
 )
 grad_sigma = jax.tree_map(
-partial(_jax_wrapped_sigma_grad,
+partial(_jax_wrapped_sigma_grad,
+r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
 epsilon, epsilon_star, sigma
 )
 return grad_mu, grad_sigma, r_max

-def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
+def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max, ent,
 policy_hyperparams, subs, model_params):
 mu, sigma = pgpe_params
 if batch_size == 1:
 mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
-key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+key, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
 else:
 keys = random.split(key, num=batch_size)
 mu_grads, sigma_grads, r_maxs = jax.vmap(
 _jax_wrapped_pgpe_grad,
-in_axes=(0, None, None, None, None, None, None)
-)(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
-mu_grad = jax.tree_map(
-
+in_axes=(0, None, None, None, None, None, None, None)
+)(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
+mu_grad, sigma_grad = jax.tree_map(
+partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
 new_r_max = jnp.max(r_maxs)
 return mu_grad, sigma_grad, new_r_max
+
+# ***********************************************************************
+# PARAMETER UPDATE
+#
+# ***********************************************************************

-def
+def _jax_wrapped_pgpe_kl_term(mu, sigma, old_mu, old_sigma):
+return 0.5 * jnp.sum(2 * jnp.log(sigma / old_sigma) +
+jnp.square(old_sigma / sigma) +
+jnp.square((mu - old_mu) / sigma) - 1)
+
+def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
 policy_hyperparams, subs, model_params,
 pgpe_opt_state):
+# regular update
 mu, sigma = pgpe_params
 mu_state, sigma_state = pgpe_opt_state
+ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
 mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
-key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
+key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
 mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
 sigma_updates, new_sigma_state = sigma_optimizer.update(
 sigma_grad, sigma_state, params=sigma)
 new_mu = optax.apply_updates(mu, mu_updates)
-new_mu, converged = projection(new_mu, policy_hyperparams)
 new_sigma = optax.apply_updates(sigma, sigma_updates)
 new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+
+# respect KL divergence contraint with old parameters
+if max_kl is not None:
+old_mu_lr = new_mu_state.hyperparams['learning_rate']
+old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
+kl_terms = jax.tree_map(
+_jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
+total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
+kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
+mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
+sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
+mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
+sigma_updates, new_sigma_state = sigma_optimizer.update(
+sigma_grad, sigma_state, params=sigma)
+new_mu = optax.apply_updates(mu, mu_updates)
+new_sigma = optax.apply_updates(sigma, sigma_updates)
+new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+new_mu_state.hyperparams['learning_rate'] = old_mu_lr
+new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr
+
+# apply projection step and finalize results
+new_mu, converged = projection(new_mu, policy_hyperparams)
 new_pgpe_params = (new_mu, new_sigma)
 new_pgpe_opt_state = (new_mu_state, new_sigma_state)
 policy_params = new_mu
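Two pieces make the optional trust region work. First, `_jax_wrapped_pgpe_kl_term` is the closed-form KL divergence between the old and new factorized Gaussian search distributions, KL(old || new) = sum over parameters of log(sigma_new / sigma_old) + (sigma_old^2 + (mu_new - mu_old)^2) / (2 * sigma_new^2) - 1/2. Second, `optax.inject_hyperparams` exposes the learning rate inside the optimizer state so it can be shrunk for one step and restored afterwards, which is also why the constructor disables the KL constraint when injection fails. A minimal standalone sketch of the optax side (not the planner's exact update):

    import optax
    import jax.numpy as jnp

    opt = optax.inject_hyperparams(optax.sgd)(learning_rate=0.1)
    params = {'w': jnp.ones(3)}
    state = opt.init(params)

    grads = {'w': jnp.full(3, 2.0)}
    state.hyperparams['learning_rate'] = 0.1 * 0.5    # e.g. shrink the step to respect a KL budget
    updates, state = opt.update(grads, state, params)
    print(updates['w'])                               # -0.1 each, i.e. -lr * grad with lr = 0.05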
@@ -1463,6 +1512,71 @@ class GaussianPGPE(PGPE):
 self._update = jax.jit(_jax_wrapped_pgpe_update)


+# ***********************************************************************
+# ALL VERSIONS OF RISK FUNCTIONS
+#
+# Based on the original paper "A Distributional Framework for Risk-Sensitive
+# End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
+#
+# Original risk functions:
+# - entropic utility
+# - mean-variance
+# - mean-semideviation
+# - conditional value at risk with straight-through gradient trick
+#
+# ***********************************************************************
+
+
+@jax.jit
+def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
+return (-1.0 / beta) * jax.scipy.special.logsumexp(
+-beta * returns, b=1.0 / returns.size)
+
+
+@jax.jit
+def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
+return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
+
+
+@jax.jit
+def mean_deviation_utility(returns: jnp.ndarray, beta: float) -> float:
+return jnp.mean(returns) - 0.5 * beta * jnp.std(returns)
+
+
+@jax.jit
+def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> float:
+mu = jnp.mean(returns)
+msd = jnp.sqrt(jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu))))
+return mu - 0.5 * beta * msd
+
+
+@jax.jit
+def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
+mu = jnp.mean(returns)
+msv = jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu)))
+return mu - 0.5 * beta * msv
+
+
+@jax.jit
+def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
+var = jnp.percentile(returns, q=100 * alpha)
+mask = returns <= var
+weights = mask / jnp.maximum(1, jnp.sum(mask))
+return jnp.sum(returns * weights)
+
+
+UTILITY_LOOKUP = {
+'mean': jnp.mean,
+'mean_var': mean_variance_utility,
+'mean_std': mean_deviation_utility,
+'mean_semivar': mean_semivariance_utility,
+'mean_semidev': mean_semideviation_utility,
+'entropic': entropic_utility,
+'exponential': entropic_utility,
+'cvar': cvar_utility
+}
+
+
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANNER
 #
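The risk measures move from their old spot near the controllers (the block removed further down) to module level, and gain mean-standard-deviation and mean-semivariance variants plus the `UTILITY_LOOKUP` table. A small usage sketch, assuming the definitions above are in scope; note that the CVaR utility averages only the worst alpha-fraction of sampled returns:

    import jax.numpy as jnp

    returns = jnp.array([-5.0, -1.0, 0.0, 2.0, 10.0])
    print(entropic_utility(returns, beta=0.5))        # risk-averse: well below the mean of 1.2
    print(mean_variance_utility(returns, beta=0.1))   # mean minus 0.5 * beta * variance
    print(cvar_utility(returns, alpha=0.2))           # mean of the worst 20% of returns: -5.0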
@@ -1525,8 +1639,7 @@ class JaxBackpropPlanner:
 reward as a form of normalization
 :param utility: how to aggregate return observations to compute utility
 of a policy or plan; must be either a function mapping jax array to a
-scalar, or a a string identifying the utility function by name
-("mean", "mean_var", "entropic", or "cvar" are currently supported)
+scalar, or a a string identifying the utility function by name
 :param utility_kwargs: additional keyword arguments to pass hyper-
 parameters to the utility function call
 :param cpfs_without_grad: which CPFs do not have gradients (use straight

@@ -1584,18 +1697,11 @@ class JaxBackpropPlanner:
 # set utility
 if isinstance(utility, str):
 utility = utility.lower()
-
-
-elif utility == 'mean_var':
-utility_fn = mean_variance_utility
-elif utility == 'entropic':
-utility_fn = entropic_utility
-elif utility == 'cvar':
-utility_fn = cvar_utility
-else:
+utility_fn = UTILITY_LOOKUP.get(utility, None)
+if utility_fn is None:
 raise RDDLNotImplementedError(
-f'Utility
-'must be one of
+f'Utility <{utility}> is not supported, '
+f'must be one of {list(UTILITY_LOOKUP.keys())}.')
 else:
 utility_fn = utility
 self.utility = utility_fn
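With the if/elif chain replaced by a dictionary lookup, adding a new risk measure only requires another `UTILITY_LOOKUP` entry, and the planner forwards `utility_kwargs` to whichever function is selected (per the docstring above). A hypothetical illustration of the dispatch, reusing the table and utilities defined earlier in the file:

    import jax.numpy as jnp

    utility, utility_kwargs = 'CVaR', {'alpha': 0.1}

    utility_fn = UTILITY_LOOKUP.get(utility.lower(), None)
    if utility_fn is None:
        raise ValueError(f'must be one of {list(UTILITY_LOOKUP.keys())}')

    sampled_returns = jnp.array([3.0, -2.0, 1.5, 0.0])
    print(utility_fn(sampled_returns, **utility_kwargs))   # CVaR at the 10% level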
@@ -1746,7 +1852,6 @@ r"""

 # optimization
 self.update = self._jax_update(train_loss)
-self.check_zero_grad = self._jax_check_zero_gradients()

 # pgpe option
 if self.use_pgpe:

@@ -1809,6 +1914,12 @@ r"""
 projection = self.plan.projection
 use_ls = self.line_search_kwargs is not None

+# check if the gradients are all zeros
+def _jax_wrapped_zero_gradients(grad):
+leaves, _ = jax.tree_util.tree_flatten(
+jax.tree_map(lambda g: jnp.allclose(g, 0), grad))
+return jnp.all(jnp.asarray(leaves))
+
 # calculate the plan gradient w.r.t. return loss and update optimizer
 # also perform a projection step to satisfy constraints on actions
 def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,

@@ -1833,23 +1944,12 @@ r"""
 policy_params, converged = projection(policy_params, policy_hyperparams)
 log['grad'] = grad
 log['updates'] = updates
+zero_grads = _jax_wrapped_zero_gradients(grad)
 return policy_params, converged, opt_state, opt_aux, \
-loss_val, log, model_params
+loss_val, log, model_params, zero_grads

 return jax.jit(_jax_wrapped_plan_update)

-def _jax_check_zero_gradients(self):
-
-def _jax_wrapped_zero_gradient(grad):
-return jnp.allclose(grad, 0)
-
-def _jax_wrapped_zero_gradients(grad):
-leaves, _ = jax.tree_util.tree_flatten(
-jax.tree_map(_jax_wrapped_zero_gradient, grad))
-return jnp.all(jnp.asarray(leaves))
-
-return jax.jit(_jax_wrapped_zero_gradients)
-
 def _batched_init_subs(self, subs):
 rddl = self.rddl
 n_train, n_test = self.batch_size_train, self.batch_size_test
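Folding the zero-gradient test into the jitted update removes the separate `check_zero_grad` function and the extra call per iteration; the test itself is an all-close-to-zero reduction over every leaf of the gradient pytree. A standalone sketch:

    import jax
    import jax.numpy as jnp

    def all_grads_zero(grad):
        # True only if every leaf of the gradient pytree is (numerically) all zeros
        flags = jax.tree_util.tree_map(lambda g: jnp.allclose(g, 0), grad)
        return jnp.all(jnp.asarray(jax.tree_util.tree_leaves(flags)))

    print(all_grads_zero({'w': jnp.zeros((2, 2)), 'b': jnp.zeros(2)}))   # True
    print(all_grads_zero({'w': jnp.zeros((2, 2)), 'b': jnp.ones(2)}))    # False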
@@ -1865,7 +1965,7 @@ r"""
 f'{set(self.test_compiled.init_values.keys())}.')
 value = np.reshape(value, newshape=np.shape(init_value))[np.newaxis, ...]
 train_value = np.repeat(value, repeats=n_train, axis=0)
-train_value =
+train_value = np.asarray(train_value, dtype=self.compiled.REAL)
 init_train[name] = train_value
 init_test[name] = np.repeat(value, repeats=n_test, axis=0)

@@ -2153,11 +2253,12 @@ r"""
 # ======================================================================

 # initialize running statistics
-best_params, best_loss, best_grad = policy_params, jnp.inf,
+best_params, best_loss, best_grad = policy_params, jnp.inf, None
 last_iter_improve = 0
 rolling_test_loss = RollingMean(test_rolling_window)
 log = {}
 status = JaxPlannerStatus.NORMAL
+progress_percent = 0

 # initialize stopping criterion
 if stopping_rule is not None:

@@ -2169,16 +2270,19 @@ r"""
 dashboard_id, dashboard.get_planner_info(self),
 key=dash_key, viz=self.dashboard_viz)

+# progress bar
+if print_progress:
+progress_bar = tqdm(None, total=100, position=tqdm_position,
+bar_format='{l_bar}{bar}| {elapsed} {postfix}')
+else:
+progress_bar = None
+position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
+
 # ======================================================================
 # MAIN TRAINING LOOP BEGINS
 # ======================================================================

-
-if print_progress:
-iters = tqdm(iters, total=100, position=tqdm_position)
-position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
-
-for it in iters:
+for it in range(epochs):

 # ==================================================================
 # NEXT GRADIENT DESCENT STEP
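The progress bar is no longer an iterator wrapper around the epoch loop; it is created once with `total=100` and advanced manually, so it can track the combined time/iteration budget rather than raw iterations. A minimal sketch of the manual-update pattern with tqdm (values are illustrative):

    import time
    from tqdm import tqdm

    progress_bar = tqdm(None, total=100, bar_format='{l_bar}{bar}| {elapsed} {postfix}')
    for it in range(50):
        time.sleep(0.01)
        progress_percent = 100 * min(1.0, (it + 1) / 50)
        progress_bar.set_description(f'iteration {it}', refresh=False)
        progress_bar.set_postfix_str(f'{it + 1} steps', refresh=False)
        progress_bar.update(progress_percent - progress_bar.n)   # advance by the delta
    progress_bar.close()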
@@ -2189,8 +2293,9 @@ r"""
 # update the parameters of the plan
 key, subkey = random.split(key)
 (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
-model_params) = self.update(
-
+model_params, zero_grads) = self.update(
+subkey, policy_params, policy_hyperparams, train_subs, model_params,
+opt_state, opt_aux)
 test_loss, (test_log, model_params_test) = self.test_loss(
 subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
 test_loss_smooth = rolling_test_loss.update(test_loss)

@@ -2200,8 +2305,9 @@ r"""
 if self.use_pgpe:
 key, subkey = random.split(key)
 pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
-self.pgpe.update(subkey, pgpe_params, r_max,
-test_subs,
+self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
+policy_hyperparams, test_subs, model_params_test,
+pgpe_opt_state)
 pgpe_loss, _ = self.test_loss(
 subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
 pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)

@@ -2228,7 +2334,7 @@ r"""
 # ==================================================================

 # no progress
-if (not pgpe_improve) and
+if (not pgpe_improve) and zero_grads:
 status = JaxPlannerStatus.NO_PROGRESS

 # constraint satisfaction problem

@@ -2256,7 +2362,8 @@ r"""
 status = JaxPlannerStatus.ITER_BUDGET_REACHED

 # build a callback
-progress_percent =
+progress_percent = 100 * min(
+1, max(0, elapsed / train_seconds, it / (epochs - 1)))
 callback = {
 'status': status,
 'iteration': it,
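Progress is the larger of the elapsed-time fraction and the iteration fraction, clipped to [0, 1] and scaled to a percentage, so whichever budget (wall-clock seconds or epochs) is closer to exhaustion drives both the progress bar and the PGPE entropy decay. A quick numeric check of the formula:

    elapsed, train_seconds = 30.0, 120.0       # 25% of the time budget used
    it, epochs = 450, 1001                     # 45% of the iteration budget used
    progress_percent = 100 * min(1, max(0, elapsed / train_seconds, it / (epochs - 1)))
    print(progress_percent)                    # 45.0: the iteration budget dominates here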
@@ -2279,19 +2386,22 @@ r"""
 'train_log': train_log,
 **test_log
 }
-
+
 # stopping condition reached
 if stopping_rule is not None and stopping_rule.monitor(callback):
 callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED

 # if the progress bar is used
 if print_progress:
-
-iters.set_description(
+progress_bar.set_description(
 f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
 f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
-f'{status.value} status / {total_pgpe_it:6} pgpe'
+f'{status.value} status / {total_pgpe_it:6} pgpe',
+refresh=False
 )
+progress_bar.set_postfix_str(
+f"{(it + 1) / (elapsed + 1e-6):.2f}it/s", refresh=False)
+progress_bar.update(progress_percent - progress_bar.n)

 # dash-board
 if dashboard is not None:

@@ -2312,7 +2422,7 @@ r"""

 # release resources
 if print_progress:
-
+progress_bar.close()

 # validate the test return
 if log:

@@ -2332,7 +2442,7 @@ r"""
 last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
 print(f'summary of optimization:\n'
 f' status ={status}\n'
-f' time ={elapsed:.
+f' time ={elapsed:.3f} sec.\n'
 f' iterations ={it}\n'
 f' best objective={-best_loss:.6f}\n'
 f' best grad norm={grad_norm}\n'

@@ -2358,12 +2468,12 @@ r"""
 return termcolor.colored(
 '[FAILURE] no progress was made '
 f'and max grad norm {max_grad_norm:.6f} was zero: '
-'
+'solver likely stuck in a plateau.', 'red')
 else:
 return termcolor.colored(
 '[FAILURE] no progress was made '
 f'but max grad norm {max_grad_norm:.6f} was non-zero: '
-'
+'learning rate or other hyper-parameters likely suboptimal.',
 'red')

 # model is likely poor IF:

@@ -2372,8 +2482,8 @@ r"""
 return termcolor.colored(
 '[WARNING] progress was made '
 f'but relative train-test error {validation_error:.6f} was high: '
-'model relaxation around
-'
+'poor model relaxation around solution or batch size too small.',
+'yellow')

 # model likely did not converge IF:
 # 1. the max grad relative to the return is high

@@ -2383,9 +2493,9 @@ r"""
 return termcolor.colored(
 '[WARNING] progress was made '
 f'but max grad norm {max_grad_norm:.6f} was high: '
-'
-'or
-'or
+'solution locally suboptimal '
+'or relaxed model not smooth around solution '
+'or batch size too small.', 'yellow')

 # likely successful
 return termcolor.colored(

@@ -2412,8 +2522,7 @@ r"""
 for (var, values) in subs.items():

 # must not be grounded
-if RDDLPlanningModel.FLUENT_SEP in var
-or RDDLPlanningModel.OBJECT_SEP in var:
+if RDDLPlanningModel.FLUENT_SEP in var or RDDLPlanningModel.OBJECT_SEP in var:
 raise ValueError(f'State dictionary passed to the JAX policy is '
 f'grounded, since it contains the key <{var}>, '
 f'but a vectorized environment is required: '

@@ -2421,9 +2530,8 @@ r"""

 # must be numeric array
 # exception is for POMDPs at 1st epoch when observ-fluents are None
-dtype = np.
-if not np.issubdtype(dtype, np.number)
-and not np.issubdtype(dtype, np.bool_):
+dtype = np.result_type(values)
+if not np.issubdtype(dtype, np.number) and not np.issubdtype(dtype, np.bool_):
 if step == 0 and var in self.rddl.observ_fluents:
 subs[var] = self.test_compiled.init_values[var]
 else:
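`np.result_type` is used as the dtype probe in both rewritten checks; it handles scalars, lists, and arrays uniformly, and collapses non-numeric inputs (such as the `None` observations a POMDP produces before the first decision epoch) to the object dtype, which then fails the `issubdtype` tests. A short illustration:

    import numpy as np

    print(np.issubdtype(np.result_type([1, 2, 3]), np.number))          # True (integer array)
    print(np.issubdtype(np.result_type(np.float32(1.0)), np.number))    # True
    print(np.issubdtype(np.result_type([True, False]), np.bool_))       # True
    print(np.issubdtype(np.result_type(np.asarray(None)), np.number))   # False: object dtype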
@@ -2435,40 +2543,7 @@ r"""
 actions = self.test_policy(key, params, policy_hyperparams, step, subs)
 actions = jax.tree_map(np.asarray, actions)
 return actions
-
-
-# ***********************************************************************
-# ALL VERSIONS OF RISK FUNCTIONS
-#
-# Based on the original paper "A Distributional Framework for Risk-Sensitive
-# End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
-#
-# Original risk functions:
-# - entropic utility
-# - mean-variance approximation
-# - conditional value at risk with straight-through gradient trick
-#
-# ***********************************************************************
-
-
-@jax.jit
-def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
-return (-1.0 / beta) * jax.scipy.special.logsumexp(
--beta * returns, b=1.0 / returns.size)
-
-
-@jax.jit
-def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
-return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
-
-
-@jax.jit
-def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
-var = jnp.percentile(returns, q=100 * alpha)
-mask = returns <= var
-weights = mask / jnp.maximum(1, jnp.sum(mask))
-return jnp.sum(returns * weights)
-
+

 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS

@@ -2580,8 +2655,7 @@ class JaxOnlineController(BaseAgent):
 self.callback = callback
 params = callback['best_params']
 self.key, subkey = random.split(self.key)
-actions = planner.get_action(
-subkey, params, 0, state, self.eval_hyperparams)
+actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
 if self.warm_start:
 self.guess = planner.plan.guess_next_epoch(params)
 return actions