pyRDDLGym-jax 2.0-py3-none-any.whl → 2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +85 -190
- pyRDDLGym_jax/core/logic.py +313 -56
- pyRDDLGym_jax/core/planner.py +121 -130
- pyRDDLGym_jax/core/visualization.py +7 -8
- pyRDDLGym_jax/examples/run_tune.py +10 -6
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/METADATA +22 -12
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/RECORD +12 -12
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/WHEEL +1 -1
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/entry_points.txt +0 -0
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED

@@ -69,8 +69,7 @@ try:
     from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
 except Exception:
     raise_warning('Failed to load the dashboard visualization tool: '
-                  'please make sure you have installed the required packages.',
-                  'red')
+                  'please make sure you have installed the required packages.', 'red')
     traceback.print_exc()
     JaxPlannerDashboard = None
 
@@ -133,7 +132,7 @@ def _load_config(config, args):
     comp_kwargs = model_args.get('complement_kwargs', {})
     compare_name = model_args.get('comparison', 'SigmoidComparison')
     compare_kwargs = model_args.get('comparison_kwargs', {})
-    sampling_name = model_args.get('sampling', '
+    sampling_name = model_args.get('sampling', 'SoftRandomSampling')
     sampling_kwargs = model_args.get('sampling_kwargs', {})
     rounding_name = model_args.get('rounding', 'SoftRounding')
     rounding_kwargs = model_args.get('rounding_kwargs', {})
@@ -156,8 +155,7 @@ def _load_config(config, args):
         initializer = _getattr_any(
             packages=[initializers, hk.initializers], item=plan_initializer)
         if initializer is None:
-            raise_warning(
-                f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
+            raise_warning(f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
             del plan_kwargs['initializer']
         else:
             init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
@@ -174,8 +172,7 @@ def _load_config(config, args):
         activation = _getattr_any(
             packages=[jax.nn, jax.numpy], item=plan_activation)
         if activation is None:
-            raise_warning(
-                f'Ignoring invalid activation <{plan_activation}>.', 'red')
+            raise_warning(f'Ignoring invalid activation <{plan_activation}>.', 'red')
             del plan_kwargs['activation']
         else:
             plan_kwargs['activation'] = activation
@@ -189,8 +186,7 @@ def _load_config(config, args):
     if planner_optimizer is not None:
         optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
         if optimizer is None:
-            raise_warning(
-                f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
+            raise_warning(f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
             del planner_args['optimizer']
         else:
             planner_args['optimizer'] = optimizer
@@ -285,48 +281,14 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         pvars_cast = set()
         for (var, values) in self.init_values.items():
             self.init_values[var] = np.asarray(values, dtype=self.REAL)
-            if not np.issubdtype(np.
+            if not np.issubdtype(np.result_type(values), np.floating):
                 pvars_cast.add(var)
         if pvars_cast:
             raise_warning(f'JAX gradient compiler requires that initial values '
                           f'of p-variables {pvars_cast} be cast to float.')
 
         # overwrite basic operations with fuzzy ones
-        self.
-            '>=': logic.greater_equal,
-            '<=': logic.less_equal,
-            '<': logic.less,
-            '>': logic.greater,
-            '==': logic.equal,
-            '~=': logic.not_equal
-        }
-        self.LOGICAL_NOT = logic.logical_not
-        self.LOGICAL_OPS = {
-            '^': logic.logical_and,
-            '&': logic.logical_and,
-            '|': logic.logical_or,
-            '~': logic.xor,
-            '=>': logic.implies,
-            '<=>': logic.equiv
-        }
-        self.AGGREGATION_OPS['forall'] = logic.forall
-        self.AGGREGATION_OPS['exists'] = logic.exists
-        self.AGGREGATION_OPS['argmin'] = logic.argmin
-        self.AGGREGATION_OPS['argmax'] = logic.argmax
-        self.KNOWN_UNARY['sgn'] = logic.sgn
-        self.KNOWN_UNARY['floor'] = logic.floor
-        self.KNOWN_UNARY['ceil'] = logic.ceil
-        self.KNOWN_UNARY['round'] = logic.round
-        self.KNOWN_UNARY['sqrt'] = logic.sqrt
-        self.KNOWN_BINARY['div'] = logic.div
-        self.KNOWN_BINARY['mod'] = logic.mod
-        self.KNOWN_BINARY['fmod'] = logic.mod
-        self.IF_HELPER = logic.control_if
-        self.SWITCH_HELPER = logic.control_switch
-        self.BERNOULLI_HELPER = logic.bernoulli
-        self.DISCRETE_HELPER = logic.discrete
-        self.POISSON_HELPER = logic.poisson
-        self.GEOMETRIC_HELPER = logic.geometric
+        self.OPS = logic.get_operator_dicts()
 
     def _jax_stop_grad(self, jax_expr):
         def _jax_wrapped_stop_grad(x, params, key):
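The new cast check relies on NumPy's dtype promotion rules. A minimal standalone sketch of what np.result_type and np.issubdtype decide here (the values are illustrative):

    import numpy as np

    # integer initial values promote to an integer dtype, so the fluent gets flagged
    int_values = np.array([0, 1, 1])
    print(np.issubdtype(np.result_type(int_values), np.floating))   # False -> added to pvars_cast

    # float initial values already satisfy the check
    real_values = np.array([0.5, 1.0])
    print(np.issubdtype(np.result_type(real_values), np.floating))  # True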
@@ -575,7 +537,7 @@ class JaxStraightLinePlan(JaxPlan):
         def _jax_non_bool_param_to_action(var, param, hyperparams):
             if wrap_non_bool:
                 lower, upper = bounds_safe[var]
-                mb, ml, mu, mn = [
+                mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
                                   for mask in cond_lists[var]]
                 action = (
                     mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
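The wrapper above maps an unbounded parameter into a box-constrained action with a scaled sigmoid. The transform in isolation (bounds and parameter values are illustrative):

    import jax
    import jax.numpy as jnp

    lower, upper = 0.0, 10.0
    param = jnp.array([-3.0, 0.0, 3.0])
    # lower + (upper - lower) * sigmoid(param) maps the reals smoothly into (lower, upper)
    action = lower + (upper - lower) * jax.nn.sigmoid(param)
    print(action)  # approximately [0.474, 5.0, 9.526], always inside the box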
@@ -660,7 +622,7 @@ class JaxStraightLinePlan(JaxPlan):
                 action = _jax_non_bool_param_to_action(var, action, hyperparams)
                 action = jnp.clip(action, *bounds[var])
                 if ranges[var] == 'int':
-                    action = jnp.round(action)
+                    action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
                 actions[var] = action
             return actions
 
@@ -961,12 +923,11 @@ class JaxDeepReactivePolicy(JaxPlan):
         non_bool_dims = 0
         for (var, values) in observed_vars.items():
             if ranges[var] != 'bool':
-                value_size = np.
+                value_size = np.size(values)
                 if normalize_per_layer and value_size == 1:
                     raise_warning(
                         f'Cannot apply layer norm to state-fluent <{var}> '
-                        f'of size 1: setting normalize_per_layer = False.',
-                        'red')
+                        f'of size 1: setting normalize_per_layer = False.', 'red')
                     normalize_per_layer = False
                 non_bool_dims += value_size
         if not normalize_per_layer and non_bool_dims == 1:
@@ -990,9 +951,11 @@ class JaxDeepReactivePolicy(JaxPlan):
             else:
                 if normalize and normalize_per_layer:
                     normalizer = hk.LayerNorm(
-                        axis=-1,
+                        axis=-1,
+                        param_axis=-1,
                         name=f'input_norm_{input_names[var]}',
-                        **self._normalizer_kwargs
+                        **self._normalizer_kwargs
+                    )
                     state = normalizer(state)
                 states_non_bool.append(state)
                 non_bool_dims += state.size
@@ -1001,8 +964,11 @@ class JaxDeepReactivePolicy(JaxPlan):
             # optionally perform layer normalization on the non-bool inputs
             if normalize and not normalize_per_layer and non_bool_dims:
                 normalizer = hk.LayerNorm(
-                    axis=-1,
-
+                    axis=-1,
+                    param_axis=-1,
+                    name='input_norm',
+                    **self._normalizer_kwargs
+                )
                 normalized = normalizer(state[:non_bool_dims])
                 state = state.at[:non_bool_dims].set(normalized)
             return state
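Pinning param_axis=-1 makes Haiku create the per-feature scale and offset parameters over the last axis explicitly rather than relying on a default. A minimal sketch, assuming create_scale and create_offset are supplied (here directly; in the diff they arrive through self._normalizer_kwargs):

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def forward(x):
        norm = hk.LayerNorm(axis=-1, param_axis=-1,
                            create_scale=True, create_offset=True,
                            name='input_norm')
        return norm(x)

    f = hk.transform(forward)
    x = jnp.ones((4, 8))
    params = f.init(jax.random.PRNGKey(0), x)
    y = f.apply(params, None, x)  # normalized over the last axis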
@@ -1021,7 +987,8 @@ class JaxDeepReactivePolicy(JaxPlan):
             actions = {}
             for (var, size) in layer_sizes.items():
                 linear = hk.Linear(size, name=layer_names[var], w_init=init)
-                reshape = hk.Reshape(output_shape=shapes[var],
+                reshape = hk.Reshape(output_shape=shapes[var],
+                                     preserve_dims=-1,
                                      name=f'reshape_{layer_names[var]}')
                 output = reshape(linear(hidden))
                 if not shapes[var]:
@@ -1034,7 +1001,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                 else:
                     if wrap_non_bool:
                         lower, upper = bounds_safe[var]
-                        mb, ml, mu, mn = [
+                        mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
                                           for mask in cond_lists[var]]
                         action = (
                             mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
@@ -1048,8 +1015,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
             # for constraint satisfaction wrap bool actions with softmax
             if use_constraint_satisfaction:
-                linear = hk.Linear(
-                    bool_action_count, name='output_bool', w_init=init)
+                linear = hk.Linear(bool_action_count, name='output_bool', w_init=init)
                 output = jax.nn.softmax(linear(hidden))
                 actions[bool_key] = output
 
@@ -1087,8 +1053,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
         # test action prediction
         def _jax_wrapped_drp_predict_test(key, params, hyperparams, step, subs):
-            actions = _jax_wrapped_drp_predict_train(
-                key, params, hyperparams, step, subs)
+            actions = _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs)
             new_actions = {}
             for (var, action) in actions.items():
                 prange = ranges[var]
@@ -1096,7 +1061,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                     new_action = action > 0.5
                 elif prange == 'int':
                     action = jnp.clip(action, *bounds[var])
-                    new_action = jnp.round(action)
+                    new_action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
                 else:
                     new_action = jnp.clip(action, *bounds[var])
                 new_actions[var] = new_action
@@ -1436,8 +1401,8 @@ class GaussianPGPE(PGPE):
                 _jax_wrapped_pgpe_grad,
                 in_axes=(0, None, None, None, None, None, None)
             )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
-            mu_grad = jax.tree_map(
-
+            mu_grad, sigma_grad = jax.tree_map(
+                partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
             new_r_max = jnp.max(r_maxs)
             return mu_grad, sigma_grad, new_r_max
 
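The corrected line averages both per-sample gradient pytrees in a single call by mapping over a tuple. A standalone sketch with toy arrays standing in for the vmapped gradients:

    from functools import partial
    import jax
    import jax.numpy as jnp

    mu_grads = jnp.arange(6.0).reshape(3, 2)  # 3 samples of a (2,)-shaped gradient
    sigma_grads = jnp.ones((3, 2))
    # tree_map applies the mean over the sample axis to every leaf of the tuple
    mu_grad, sigma_grad = jax.tree_map(partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
    print(mu_grad)     # [2. 3.]
    print(sigma_grad)  # [1. 1.]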
@@ -1463,6 +1428,71 @@ class GaussianPGPE(PGPE):
         self._update = jax.jit(_jax_wrapped_pgpe_update)
 
 
+# ***********************************************************************
+# ALL VERSIONS OF RISK FUNCTIONS
+#
+# Based on the original paper "A Distributional Framework for Risk-Sensitive
+# End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
+#
+# Original risk functions:
+# - entropic utility
+# - mean-variance
+# - mean-semideviation
+# - conditional value at risk with straight-through gradient trick
+#
+# ***********************************************************************
+
+
+@jax.jit
+def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
+    return (-1.0 / beta) * jax.scipy.special.logsumexp(
+        -beta * returns, b=1.0 / returns.size)
+
+
+@jax.jit
+def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
+    return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
+
+
+@jax.jit
+def mean_deviation_utility(returns: jnp.ndarray, beta: float) -> float:
+    return jnp.mean(returns) - 0.5 * beta * jnp.std(returns)
+
+
+@jax.jit
+def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> float:
+    mu = jnp.mean(returns)
+    msd = jnp.sqrt(jnp.mean(jnp.minimum(0.0, returns - mu) ** 2))
+    return mu - 0.5 * beta * msd
+
+
+@jax.jit
+def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
+    mu = jnp.mean(returns)
+    msv = jnp.mean(jnp.minimum(0.0, returns - mu) ** 2)
+    return mu - 0.5 * beta * msv
+
+
+@jax.jit
+def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
+    var = jnp.percentile(returns, q=100 * alpha)
+    mask = returns <= var
+    weights = mask / jnp.maximum(1, jnp.sum(mask))
+    return jnp.sum(returns * weights)
+
+
+UTILITY_LOOKUP = {
+    'mean': jnp.mean,
+    'mean_var': mean_variance_utility,
+    'mean_std': mean_deviation_utility,
+    'mean_semivar': mean_semivariance_utility,
+    'mean_semidev': mean_semideviation_utility,
+    'entropic': entropic_utility,
+    'exponential': entropic_utility,
+    'cvar': cvar_utility
+}
+
+
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANNER
 #
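Every entry in UTILITY_LOOKUP maps a vector of sampled returns to a scalar objective. A quick standalone check of how the utilities rank a return sample with one bad outlier, assuming the 2.1 module layout shown above (the numbers are illustrative):

    import jax.numpy as jnp
    from pyRDDLGym_jax.core.planner import UTILITY_LOOKUP

    returns = jnp.array([10.0, 12.0, -50.0, 11.0, 9.0])   # one catastrophic episode
    print(UTILITY_LOOKUP['mean'](returns))                 # -1.6, dragged down by the outlier
    print(UTILITY_LOOKUP['cvar'](returns, alpha=0.2))      # about -50, scores the worst 20% tail
    print(UTILITY_LOOKUP['entropic'](returns, beta=0.1))   # about -34, between mean and CVaR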
@@ -1525,8 +1555,7 @@ class JaxBackpropPlanner:
         reward as a form of normalization
     :param utility: how to aggregate return observations to compute utility
         of a policy or plan; must be either a function mapping jax array to a
-        scalar, or a a string identifying the utility function by name
-        ("mean", "mean_var", "entropic", or "cvar" are currently supported)
+        scalar, or a a string identifying the utility function by name
     :param utility_kwargs: additional keyword arguments to pass hyper-
         parameters to the utility function call
     :param cpfs_without_grad: which CPFs do not have gradients (use straight
@@ -1584,18 +1613,11 @@ class JaxBackpropPlanner:
         # set utility
         if isinstance(utility, str):
             utility = utility.lower()
-
-
-            elif utility == 'mean_var':
-                utility_fn = mean_variance_utility
-            elif utility == 'entropic':
-                utility_fn = entropic_utility
-            elif utility == 'cvar':
-                utility_fn = cvar_utility
-            else:
+            utility_fn = UTILITY_LOOKUP.get(utility, None)
+            if utility_fn is None:
                 raise RDDLNotImplementedError(
-                    f'Utility 
-                    'must be one of 
+                    f'Utility <{utility}> is not supported, '
+                    f'must be one of {list(UTILITY_LOOKUP.keys())}.')
         else:
             utility_fn = utility
         self.utility = utility_fn
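The table-driven dispatch also makes the valid names and their hyper-parameters easy to exercise directly; a sketch of the same lookup outside the constructor, assuming the 2.1 module layout:

    import jax.numpy as jnp
    from pyRDDLGym_jax.core.planner import UTILITY_LOOKUP

    utility, utility_kwargs = 'mean_semidev', {'beta': 0.5}
    utility_fn = UTILITY_LOOKUP.get(utility.lower(), None)
    if utility_fn is None:
        raise ValueError(f'must be one of {list(UTILITY_LOOKUP.keys())}')
    returns = jnp.array([1.0, 2.0, 3.0])
    print(utility_fn(returns, **utility_kwargs))  # mean minus 0.5 * beta * semideviation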
@@ -1865,7 +1887,7 @@ r"""
                     f'{set(self.test_compiled.init_values.keys())}.')
             value = np.reshape(value, newshape=np.shape(init_value))[np.newaxis, ...]
             train_value = np.repeat(value, repeats=n_train, axis=0)
-            train_value = 
+            train_value = np.asarray(train_value, dtype=self.compiled.REAL)
             init_train[name] = train_value
             init_test[name] = np.repeat(value, repeats=n_test, axis=0)
 
@@ -2175,7 +2197,9 @@ r"""
 
         iters = range(epochs)
         if print_progress:
-            iters = tqdm(iters, total=100,
+            iters = tqdm(iters, total=100,
+                         bar_format='{l_bar}{bar}| {elapsed} {postfix}',
+                         position=tqdm_position)
         position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
 
         for it in iters:
@@ -2256,7 +2280,8 @@ r"""
                 status = JaxPlannerStatus.ITER_BUDGET_REACHED
 
             # build a callback
-            progress_percent = 
+            progress_percent = 100 * min(
+                1, max(0, elapsed / train_seconds, it / (epochs - 1)))
             callback = {
                 'status': status,
                 'iteration': it,
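The reconstructed progress measure reports whichever budget, wall-clock time or iterations, is closer to exhaustion, clamped to [0, 100]. A worked example with hypothetical values:

    elapsed, train_seconds = 30.0, 120.0   # 25% of the time budget used
    it, epochs = 450, 1001                 # 45% of the iteration budget used
    progress_percent = 100 * min(1, max(0, elapsed / train_seconds, it / (epochs - 1)))
    print(progress_percent)  # 45.0, the iteration budget dominates here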
@@ -2279,7 +2304,7 @@ r"""
                 'train_log': train_log,
                 **test_log
             }
-
+
             # stopping condition reached
             if stopping_rule is not None and stopping_rule.monitor(callback):
                 callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
@@ -2290,8 +2315,10 @@ r"""
                 iters.set_description(
                     f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
                     f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
-                    f'{status.value} status / {total_pgpe_it:6} pgpe'
+                    f'{status.value} status / {total_pgpe_it:6} pgpe',
+                    refresh=False
                 )
+                iters.set_postfix_str(f"{(it + 1) / elapsed:.2f}it/s", refresh=True)
 
             # dash-board
             if dashboard is not None:
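The custom bar_format drops tqdm's built-in rate field, and the new set_postfix_str call re-adds an iterations-per-second readout computed from wall-clock time. A runnable sketch of the same pattern:

    import time
    from tqdm import tqdm

    start = time.time()
    iters = tqdm(range(100), total=100, bar_format='{l_bar}{bar}| {elapsed} {postfix}')
    for it in iters:
        time.sleep(0.01)  # stand-in for one optimization step
        iters.set_description(f'{it:6} it', refresh=False)
        iters.set_postfix_str(f'{(it + 1) / (time.time() - start):.2f}it/s', refresh=True)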
@@ -2332,7 +2359,7 @@ r"""
                 last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
             print(f'summary of optimization:\n'
                   f'    status        ={status}\n'
-                  f'    time          ={elapsed:.
+                  f'    time          ={elapsed:.3f} sec.\n'
                   f'    iterations    ={it}\n'
                   f'    best objective={-best_loss:.6f}\n'
                   f'    best grad norm={grad_norm}\n'
@@ -2358,12 +2385,12 @@ r"""
             return termcolor.colored(
                 '[FAILURE] no progress was made '
                 f'and max grad norm {max_grad_norm:.6f} was zero: '
-                '
+                'solver likely stuck in a plateau.', 'red')
         else:
             return termcolor.colored(
                 '[FAILURE] no progress was made '
                 f'but max grad norm {max_grad_norm:.6f} was non-zero: '
-                '
+                'learning rate or other hyper-parameters likely suboptimal.',
                 'red')
 
         # model is likely poor IF:
@@ -2372,8 +2399,8 @@ r"""
             return termcolor.colored(
                 '[WARNING] progress was made '
                 f'but relative train-test error {validation_error:.6f} was high: '
-                'model relaxation around 
-                '
+                'poor model relaxation around solution or batch size too small.',
+                'yellow')
 
         # model likely did not converge IF:
         # 1. the max grad relative to the return is high
@@ -2383,9 +2410,9 @@ r"""
             return termcolor.colored(
                 '[WARNING] progress was made '
                 f'but max grad norm {max_grad_norm:.6f} was high: '
-                '
-                'or 
-                'or 
+                'solution locally suboptimal '
+                'or relaxed model not smooth around solution '
+                'or batch size too small.', 'yellow')
 
         # likely successful
         return termcolor.colored(
@@ -2412,8 +2439,7 @@ r"""
         for (var, values) in subs.items():
 
             # must not be grounded
-            if RDDLPlanningModel.FLUENT_SEP in var 
-                    or RDDLPlanningModel.OBJECT_SEP in var:
+            if RDDLPlanningModel.FLUENT_SEP in var or RDDLPlanningModel.OBJECT_SEP in var:
                 raise ValueError(f'State dictionary passed to the JAX policy is '
                                  f'grounded, since it contains the key <{var}>, '
                                  f'but a vectorized environment is required: '
@@ -2421,9 +2447,8 @@ r"""
 
             # must be numeric array
             # exception is for POMDPs at 1st epoch when observ-fluents are None
-            dtype = np.
-            if not np.issubdtype(dtype, np.number) 
-                    and not np.issubdtype(dtype, np.bool_):
+            dtype = np.result_type(values)
+            if not np.issubdtype(dtype, np.number) and not np.issubdtype(dtype, np.bool_):
                 if step == 0 and var in self.rddl.observ_fluents:
                     subs[var] = self.test_compiled.init_values[var]
                 else:
@@ -2435,40 +2460,7 @@ r"""
         actions = self.test_policy(key, params, policy_hyperparams, step, subs)
         actions = jax.tree_map(np.asarray, actions)
         return actions
-
-
-# ***********************************************************************
-# ALL VERSIONS OF RISK FUNCTIONS
-#
-# Based on the original paper "A Distributional Framework for Risk-Sensitive
-# End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
-#
-# Original risk functions:
-# - entropic utility
-# - mean-variance approximation
-# - conditional value at risk with straight-through gradient trick
-#
-# ***********************************************************************
-
-
-@jax.jit
-def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
-    return (-1.0 / beta) * jax.scipy.special.logsumexp(
-        -beta * returns, b=1.0 / returns.size)
-
-
-@jax.jit
-def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
-    return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
-
-
-@jax.jit
-def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
-    var = jnp.percentile(returns, q=100 * alpha)
-    mask = returns <= var
-    weights = mask / jnp.maximum(1, jnp.sum(mask))
-    return jnp.sum(returns * weights)
-
+
 
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
@@ -2580,8 +2572,7 @@ class JaxOnlineController(BaseAgent):
         self.callback = callback
         params = callback['best_params']
         self.key, subkey = random.split(self.key)
-        actions = planner.get_action(
-            subkey, params, 0, state, self.eval_hyperparams)
+        actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
         if self.warm_start:
             self.guess = planner.plan.guess_next_epoch(params)
         return actions
pyRDDLGym_jax/core/visualization.py
CHANGED

@@ -20,8 +20,7 @@ import math
 import numpy as np
 import time
 import threading
-from typing import Any, Dict, 
-import warnings
+from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
 import webbrowser
 
 # prevent endless console prints
@@ -32,7 +31,7 @@ log.setLevel(logging.ERROR)
 import dash
 from dash.dcc import Interval, Graph, Store
 from dash.dependencies import Input, Output, State, ALL
-from dash.html import Div, B, H4, P, 
+from dash.html import Div, B, H4, P, Hr
 import dash_bootstrap_components as dbc
 
 import plotly.colors as pc
@@ -53,6 +52,7 @@ REWARD_ERROR_DIST_SUBPLOTS = 20
 MODEL_STATE_ERROR_HEIGHT = 300
 POLICY_STATE_VIZ_MAX_HEIGHT = 800
 GP_POSTERIOR_MAX_HEIGHT = 800
+GP_POSTERIOR_PIXELS = 100
 
 PLOT_AXES_FONT_SIZE = 11
 EXPERIMENT_ENTRY_FONT_SIZE = 14
@@ -1417,7 +1417,7 @@ class JaxPlannerDashboard:
             self.pgpe_return[experiment_id].append(callback['pgpe_return'])
 
             # data for return distributions
-            progress = callback['progress']
+            progress = int(callback['progress'])
             if progress - self.return_dist_last_progress[experiment_id] \
                 >= PROGRESS_FOR_NEXT_RETURN_DIST:
                 self.return_dist_ticks[experiment_id].append(iteration)
@@ -1486,8 +1486,8 @@ class JaxPlannerDashboard:
             if i2 > i1:
 
                 # Generate a grid for visualization
-                p1_values = np.linspace(*bounds[param1], 
-                p2_values = np.linspace(*bounds[param2], 
+                p1_values = np.linspace(*bounds[param1], GP_POSTERIOR_PIXELS)
+                p2_values = np.linspace(*bounds[param2], GP_POSTERIOR_PIXELS)
                 P1, P2 = np.meshgrid(p1_values, p2_values)
 
                 # Predict the mean and deviation of the surrogate model
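GP_POSTERIOR_PIXELS fixes the resolution of the posterior heatmap grid that was previously hard-coded at the call sites. The grid construction in isolation (the bounds are illustrative, e.g. log10 hyper-parameter ranges):

    import numpy as np

    GP_POSTERIOR_PIXELS = 100
    p1_values = np.linspace(-1.0, 4.0, GP_POSTERIOR_PIXELS)
    p2_values = np.linspace(-5.0, 0.0, GP_POSTERIOR_PIXELS)
    P1, P2 = np.meshgrid(p1_values, p2_values)  # two (100, 100) coordinate grids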
@@ -1500,8 +1500,7 @@ class JaxPlannerDashboard:
                 for p1, p2 in zip(np.ravel(P1), np.ravel(P2)):
                     params = {param1: p1, param2: p2}
                     params.update(fixed_params)
-                    param_grid.append(
-                        [params[key] for key in optimizer.space.keys])
+                    param_grid.append([params[key] for key in optimizer.space.keys])
                 param_grid = np.asarray(param_grid)
                 mean, std = optimizer._gp.predict(param_grid, return_std=True)
                 mean = mean.reshape(P1.shape)
pyRDDLGym_jax/examples/run_tune.py
CHANGED

@@ -3,7 +3,7 @@ is performed using a batched parallelized Bayesian optimization.
 
 The syntax is:
 
-    python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>]
+    python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]
 
 where:
     <domain> is the name of a domain located in the /Examples directory
@@ -15,6 +15,7 @@ where:
         (defaults to 20)
     <workers> is the number of parallel workers (i.e. batch size), which must
         not exceed the number of cores available on the machine (defaults to 4)
+    <dashboard> is whether the dashboard is displayed
 '''
 import os
 import sys
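With the new optional argument, a full invocation might look like the following (the domain and instance names are placeholders):

    python run_tune.py HVAC 1 slp 5 20 4 True

Note that the flag is parsed with Python's bool(), so any non-empty seventh argument, including the string 'False', enables the dashboard.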
@@ -35,7 +36,7 @@ def power_10(x):
     return 10.0 ** x
 
 
-def main(domain, instance, method, trials=5, iters=20, workers=4):
+def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=False):
 
     # set up the environment
     env = pyRDDLGym.make(domain, instance, vectorized=True)
@@ -48,9 +49,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
 
     # map parameters in the config that will be tuned
     hyperparams = [
-        Hyperparameter('MODEL_WEIGHT_TUNE', -1., 
+        Hyperparameter('MODEL_WEIGHT_TUNE', -1., 4., power_10),
         Hyperparameter('POLICY_WEIGHT_TUNE', -2., 2., power_10),
-        Hyperparameter('LEARNING_RATE_TUNE', -5., 
+        Hyperparameter('LEARNING_RATE_TUNE', -5., 0., power_10),
         Hyperparameter('LAYER1_TUNE', 1, 8, power_2),
         Hyperparameter('LAYER2_TUNE', 1, 8, power_2),
         Hyperparameter('ROLLOUT_HORIZON_TUNE', 1, min(env.horizon, 100), int)
@@ -64,7 +65,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
         eval_trials=trials,
         num_workers=workers,
         gp_iters=iters)
-    tuning.tune(key=42, 
+    tuning.tune(key=42,
+                log_file=f'gp_{method}_{domain}_{instance}.csv',
+                show_dashboard=dashboard)
 
     # evaluate the agent on the best parameters
     planner_args, _, train_args = load_config_from_string(tuning.best_config)
@@ -77,7 +80,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
 
 def run_from_args(args):
     if len(args) < 3:
-        print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>]')
+        print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]')
         exit(1)
     if args[2] not in ['drp', 'slp', 'replan']:
         print('<method> in [drp, slp, replan]')
@@ -86,6 +89,7 @@ run_from_args(args):
     if len(args) >= 4: kwargs['trials'] = int(args[3])
     if len(args) >= 5: kwargs['iters'] = int(args[4])
    if len(args) >= 6: kwargs['workers'] = int(args[5])
+    if len(args) >= 7: kwargs['dashboard'] = bool(args[6])
     main(**kwargs)