pyRDDLGym-jax 1.3__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +101 -191
- pyRDDLGym_jax/core/logic.py +349 -65
- pyRDDLGym_jax/core/planner.py +554 -208
- pyRDDLGym_jax/core/simulator.py +20 -0
- pyRDDLGym_jax/core/tuning.py +15 -0
- pyRDDLGym_jax/core/visualization.py +55 -8
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +1 -0
- pyRDDLGym_jax/examples/run_tune.py +10 -6
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/METADATA +22 -12
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/RECORD +24 -24
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/WHEEL +1 -1
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/entry_points.txt +0 -0
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -1,12 +1,43 @@
+# ***********************************************************************
+# JAXPLAN
+#
+# Author: Michael Gimelfarb
+#
+# RELEVANT SOURCES:
+#
+# [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
+# Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
+# Probabilistic Domains." Proceedings of the International Conference on Automated
+# Planning and Scheduling. Vol. 34. 2024.
+#
+# [2] Patton, Noah, Jihwan Jeong, Mike Gimelfarb, and Scott Sanner. "A Distributional
+# Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs." In Proceedings of
+# the AAAI Conference on Artificial Intelligence, vol. 36, no. 9, pp. 9894-9901. 2022.
+#
+# [3] Bueno, Thiago P., Leliane N. de Barros, Denis D. Mauá, and Scott Sanner. "Deep
+# reactive policies for planning in stochastic nonlinear domains." In Proceedings of the
+# AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 7530-7537. 2019.
+#
+# [4] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
+# nonlinear domains." Advances in Neural Information Processing Systems 30 (2017).
+#
+# [5] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
+# policy gradients: Super symmetric pgpe." Artificial Neural Networks: Methods and
+# Applications in Bio-/Neuroinformatics. Springer International Publishing, 2015.
+#
+# ***********************************************************************
+
+
 from ast import literal_eval
 from collections import deque
 import configparser
 from enum import Enum
+from functools import partial
 import os
 import sys
 import time
 import traceback
-from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, Union
 
 import haiku as hk
 import jax
@@ -38,8 +69,7 @@ try:
     from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
 except Exception:
     raise_warning('Failed to load the dashboard visualization tool: '
-                  'please make sure you have installed the required packages.',
-                  'red')
+                  'please make sure you have installed the required packages.', 'red')
     traceback.print_exc()
     JaxPlannerDashboard = None
 
@@ -102,7 +132,7 @@ def _load_config(config, args):
     comp_kwargs = model_args.get('complement_kwargs', {})
     compare_name = model_args.get('comparison', 'SigmoidComparison')
     compare_kwargs = model_args.get('comparison_kwargs', {})
-    sampling_name = model_args.get('sampling', '
+    sampling_name = model_args.get('sampling', 'SoftRandomSampling')
     sampling_kwargs = model_args.get('sampling_kwargs', {})
     rounding_name = model_args.get('rounding', 'SoftRounding')
     rounding_kwargs = model_args.get('rounding_kwargs', {})
@@ -125,8 +155,7 @@ def _load_config(config, args):
         initializer = _getattr_any(
             packages=[initializers, hk.initializers], item=plan_initializer)
         if initializer is None:
-            raise_warning(
-                f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
+            raise_warning(f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
             del plan_kwargs['initializer']
         else:
             init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
@@ -143,8 +172,7 @@ def _load_config(config, args):
         activation = _getattr_any(
             packages=[jax.nn, jax.numpy], item=plan_activation)
         if activation is None:
-            raise_warning(
-                f'Ignoring invalid activation <{plan_activation}>.', 'red')
+            raise_warning(f'Ignoring invalid activation <{plan_activation}>.', 'red')
             del plan_kwargs['activation']
         else:
             plan_kwargs['activation'] = activation
@@ -158,12 +186,24 @@ def _load_config(config, args):
     if planner_optimizer is not None:
         optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
         if optimizer is None:
-            raise_warning(
-                f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
+            raise_warning(f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
             del planner_args['optimizer']
         else:
             planner_args['optimizer'] = optimizer
-
+
+    # pgpe optimizer
+    pgpe_method = planner_args.get('pgpe', 'GaussianPGPE')
+    pgpe_kwargs = planner_args.pop('pgpe_kwargs', {})
+    if pgpe_method is not None:
+        if 'optimizer' in pgpe_kwargs:
+            pgpe_optimizer = _getattr_any(packages=[optax], item=pgpe_kwargs['optimizer'])
+            if pgpe_optimizer is None:
+                raise_warning(f'Ignoring invalid optimizer <{pgpe_optimizer}>.', 'red')
+                del pgpe_kwargs['optimizer']
+            else:
+                pgpe_kwargs['optimizer'] = pgpe_optimizer
+        planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
+
     # optimize call RNG key
     planner_key = train_args.get('key', None)
     if planner_key is not None:
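The `pgpe` and `pgpe_kwargs` keys above are read from the planner section of a standard planner config and resolved with `literal_eval`, so the new PGPE strategy can be switched on from a .cfg file. A minimal sketch of such a section, parsed the same way (the `[Optimizer]` section name and the keys other than `pgpe`/`pgpe_kwargs` are assumptions, not taken from this diff):

from ast import literal_eval
import configparser

cfg = configparser.ConfigParser()
cfg.read_string("""
[Optimizer]
method='JaxStraightLinePlan'
pgpe='GaussianPGPE'
pgpe_kwargs={'batch_size': 2, 'init_sigma': 0.5, 'optimizer': 'adam'}
""")
# evaluate each raw value into a Python literal, mirroring _load_config
planner_args = {k: literal_eval(v) for (k, v) in cfg['Optimizer'].items()}
print(planner_args['pgpe'], planner_args['pgpe_kwargs']['batch_size'])  # GaussianPGPE 2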
@@ -241,48 +281,14 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         pvars_cast = set()
         for (var, values) in self.init_values.items():
             self.init_values[var] = np.asarray(values, dtype=self.REAL)
-            if not np.issubdtype(np.
+            if not np.issubdtype(np.result_type(values), np.floating):
                 pvars_cast.add(var)
         if pvars_cast:
             raise_warning(f'JAX gradient compiler requires that initial values '
                           f'of p-variables {pvars_cast} be cast to float.')
 
         # overwrite basic operations with fuzzy ones
-        self.
-            '>=': logic.greater_equal,
-            '<=': logic.less_equal,
-            '<': logic.less,
-            '>': logic.greater,
-            '==': logic.equal,
-            '~=': logic.not_equal
-        }
-        self.LOGICAL_NOT = logic.logical_not
-        self.LOGICAL_OPS = {
-            '^': logic.logical_and,
-            '&': logic.logical_and,
-            '|': logic.logical_or,
-            '~': logic.xor,
-            '=>': logic.implies,
-            '<=>': logic.equiv
-        }
-        self.AGGREGATION_OPS['forall'] = logic.forall
-        self.AGGREGATION_OPS['exists'] = logic.exists
-        self.AGGREGATION_OPS['argmin'] = logic.argmin
-        self.AGGREGATION_OPS['argmax'] = logic.argmax
-        self.KNOWN_UNARY['sgn'] = logic.sgn
-        self.KNOWN_UNARY['floor'] = logic.floor
-        self.KNOWN_UNARY['ceil'] = logic.ceil
-        self.KNOWN_UNARY['round'] = logic.round
-        self.KNOWN_UNARY['sqrt'] = logic.sqrt
-        self.KNOWN_BINARY['div'] = logic.div
-        self.KNOWN_BINARY['mod'] = logic.mod
-        self.KNOWN_BINARY['fmod'] = logic.mod
-        self.IF_HELPER = logic.control_if
-        self.SWITCH_HELPER = logic.control_switch
-        self.BERNOULLI_HELPER = logic.bernoulli
-        self.DISCRETE_HELPER = logic.discrete
-        self.POISSON_HELPER = logic.poisson
-        self.GEOMETRIC_HELPER = logic.geometric
+        self.OPS = logic.get_operator_dicts()
 
     def _jax_stop_grad(self, jax_expr):
         def _jax_wrapped_stop_grad(x, params, key):
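The long list of per-operator overrides on the left is collapsed into a single lookup produced by the logic object. The exact structure returned by `FuzzyLogic.get_operator_dicts()` lives in `core/logic.py` (also changed in this release); the sketch below only illustrates the idea of bundling smoothed operators into one table and is not the library's actual layout:

import jax
import jax.numpy as jnp

def make_soft_ops(weight=10.0):
    # hypothetical table: operator symbol -> differentiable stand-in
    relational = {
        '>': lambda a, b: jax.nn.sigmoid(weight * (a - b)),
        '>=': lambda a, b: jax.nn.sigmoid(weight * (a - b)),
        '==': lambda a, b: jnp.exp(-weight * jnp.abs(a - b)),
    }
    logical = {
        '^': lambda a, b: a * b,
        '|': lambda a, b: a + b - a * b,
        '~': lambda a: 1.0 - a,
    }
    return {'relational': relational, 'logical': logical}

ops = make_soft_ops()
print(float(ops['relational']['>'](1.0, 0.0)))  # ~0.9999, a smoothed "true"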
@@ -469,16 +475,16 @@ class JaxStraightLinePlan(JaxPlan):
         bounds = '\n '.join(
             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         return (f'policy hyper-parameters:\n'
-                f' initializer
-                f'constraint-sat strategy (simple):\n'
-                f'
-                f'
-                f'
-                f'
-                f'constraint-sat strategy (complex):\n'
-                f'
-                f'
-                f'
+                f' initializer={self._initializer_base}\n'
+                f' constraint-sat strategy (simple):\n'
+                f' parsed_action_bounds =\n {bounds}\n'
+                f' wrap_sigmoid ={self._wrap_sigmoid}\n'
+                f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
+                f' wrap_non_bool ={self._wrap_non_bool}\n'
+                f' constraint-sat strategy (complex):\n'
+                f' wrap_softmax ={self._wrap_softmax}\n'
+                f' use_new_projection ={self._use_new_projection}\n'
+                f' max_projection_iters={self._max_constraint_iter}\n')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -531,7 +537,7 @@ class JaxStraightLinePlan(JaxPlan):
         def _jax_non_bool_param_to_action(var, param, hyperparams):
             if wrap_non_bool:
                 lower, upper = bounds_safe[var]
-                mb, ml, mu, mn = [
+                mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
                                   for mask in cond_lists[var]]
                 action = (
                     mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
@@ -616,7 +622,7 @@ class JaxStraightLinePlan(JaxPlan):
                     action = _jax_non_bool_param_to_action(var, action, hyperparams)
                     action = jnp.clip(action, *bounds[var])
                     if ranges[var] == 'int':
-                        action = jnp.round(action)
+                        action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
                     actions[var] = action
                 return actions
 
@@ -856,15 +862,16 @@ class JaxDeepReactivePolicy(JaxPlan):
         bounds = '\n '.join(
             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         return (f'policy hyper-parameters:\n'
-                f' topology
-                f' activation_fn
-                f' initializer
-                f'
-                f'
-                f'
-                f'
-                f'
-                f'
+                f' topology ={self._topology}\n'
+                f' activation_fn={self._activations[0].__name__}\n'
+                f' initializer ={type(self._initializer_base).__name__}\n'
+                f' input norm:\n'
+                f' apply_input_norm ={self._normalize}\n'
+                f' input_norm_layerwise={self._normalize_per_layer}\n'
+                f' input_norm_args ={self._normalizer_kwargs}\n'
+                f' constraint-sat strategy:\n'
+                f' parsed_action_bounds=\n {bounds}\n'
+                f' wrap_non_bool ={self._wrap_non_bool}\n')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -916,12 +923,11 @@ class JaxDeepReactivePolicy(JaxPlan):
         non_bool_dims = 0
         for (var, values) in observed_vars.items():
             if ranges[var] != 'bool':
-                value_size = np.
+                value_size = np.size(values)
                 if normalize_per_layer and value_size == 1:
                     raise_warning(
                         f'Cannot apply layer norm to state-fluent <{var}> '
-                        f'of size 1: setting normalize_per_layer = False.',
-                        'red')
+                        f'of size 1: setting normalize_per_layer = False.', 'red')
                     normalize_per_layer = False
                 non_bool_dims += value_size
         if not normalize_per_layer and non_bool_dims == 1:
@@ -945,9 +951,11 @@ class JaxDeepReactivePolicy(JaxPlan):
                 else:
                     if normalize and normalize_per_layer:
                         normalizer = hk.LayerNorm(
-                            axis=-1,
+                            axis=-1,
+                            param_axis=-1,
                             name=f'input_norm_{input_names[var]}',
-                            **self._normalizer_kwargs
+                            **self._normalizer_kwargs
+                        )
                         state = normalizer(state)
                     states_non_bool.append(state)
                     non_bool_dims += state.size
@@ -956,8 +964,11 @@ class JaxDeepReactivePolicy(JaxPlan):
             # optionally perform layer normalization on the non-bool inputs
             if normalize and not normalize_per_layer and non_bool_dims:
                 normalizer = hk.LayerNorm(
-                    axis=-1,
-
+                    axis=-1,
+                    param_axis=-1,
+                    name='input_norm',
+                    **self._normalizer_kwargs
+                )
                 normalized = normalizer(state[:non_bool_dims])
                 state = state.at[:non_bool_dims].set(normalized)
             return state
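Both LayerNorm call sites above now pass `param_axis=-1` explicitly, which recent Haiku releases ask for whenever learned scale/offset parameters are created. A minimal standalone check of that call pattern (the `create_scale`/`create_offset` flags are assumptions here; in the diff they would arrive through `self._normalizer_kwargs`):

import haiku as hk
import jax
import jax.numpy as jnp

def net(x):
    # axis selects what is normalized; param_axis selects where scale/offset live
    norm = hk.LayerNorm(axis=-1, param_axis=-1, create_scale=True, create_offset=True)
    return norm(x)

fwd = hk.transform(net)
x = jnp.ones((4, 3))
params = fwd.init(jax.random.PRNGKey(0), x)
print(fwd.apply(params, None, x).shape)  # (4, 3)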
@@ -976,7 +987,8 @@ class JaxDeepReactivePolicy(JaxPlan):
             actions = {}
             for (var, size) in layer_sizes.items():
                 linear = hk.Linear(size, name=layer_names[var], w_init=init)
-                reshape = hk.Reshape(output_shape=shapes[var],
+                reshape = hk.Reshape(output_shape=shapes[var],
+                                     preserve_dims=-1,
                                      name=f'reshape_{layer_names[var]}')
                 output = reshape(linear(hidden))
                 if not shapes[var]:
@@ -989,7 +1001,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                 else:
                     if wrap_non_bool:
                         lower, upper = bounds_safe[var]
-                        mb, ml, mu, mn = [
+                        mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
                                           for mask in cond_lists[var]]
                         action = (
                             mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
@@ -1003,8 +1015,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
             # for constraint satisfaction wrap bool actions with softmax
             if use_constraint_satisfaction:
-                linear = hk.Linear(
-                    bool_action_count, name='output_bool', w_init=init)
+                linear = hk.Linear(bool_action_count, name='output_bool', w_init=init)
                 output = jax.nn.softmax(linear(hidden))
                 actions[bool_key] = output
 
@@ -1042,8 +1053,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
         # test action prediction
         def _jax_wrapped_drp_predict_test(key, params, hyperparams, step, subs):
-            actions = _jax_wrapped_drp_predict_train(
-                key, params, hyperparams, step, subs)
+            actions = _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs)
             new_actions = {}
             for (var, action) in actions.items():
                 prange = ranges[var]
@@ -1051,7 +1061,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                     new_action = action > 0.5
                 elif prange == 'int':
                     action = jnp.clip(action, *bounds[var])
-                    new_action = jnp.round(action)
+                    new_action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
                 else:
                     new_action = jnp.clip(action, *bounds[var])
                 new_actions[var] = new_action
@@ -1090,10 +1100,11 @@ class JaxDeepReactivePolicy(JaxPlan):
 
 
 # ***********************************************************************
-#
+# SUPPORTING FUNCTIONS
 #
-# -
-# -
+# - smoothed mean calculation
+# - planner status
+# - stopping criteria
 #
 # ***********************************************************************
 
@@ -1167,6 +1178,329 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
         return f'No improvement for {self.patience} iterations'
 
 
+# ***********************************************************************
+# PARAMETER EXPLORING POLICY GRADIENTS (PGPE)
+#
+# - simple Gaussian PGPE
+#
+# ***********************************************************************
+
+
+class PGPE:
+    """Base class for all PGPE strategies."""
+
+    def __init__(self) -> None:
+        self._initializer = None
+        self._update = None
+
+    @property
+    def initialize(self):
+        return self._initializer
+
+    @property
+    def update(self):
+        return self._update
+
+    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
+        raise NotImplementedError
+
+
+class GaussianPGPE(PGPE):
+    '''PGPE with a Gaussian parameter distribution.'''
+
+    def __init__(self, batch_size: int=1,
+                 init_sigma: float=1.0,
+                 sigma_range: Tuple[float, float]=(1e-5, 1e5),
+                 scale_reward: bool=True,
+                 super_symmetric: bool=True,
+                 super_symmetric_accurate: bool=True,
+                 optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
+                 optimizer_kwargs_mu: Optional[Kwargs]=None,
+                 optimizer_kwargs_sigma: Optional[Kwargs]=None) -> None:
+        '''Creates a new Gaussian PGPE planner.
+
+        :param batch_size: how many policy parameters to sample per optimization step
+        :param init_sigma: initial standard deviation of Gaussian
+        :param sigma_range: bounds to constrain standard deviation
+        :param scale_reward: whether to apply reward scaling as in the paper
+        :param super_symmetric: whether to use super-symmetric sampling as in the paper
+        :param super_symmetric_accurate: whether to use the accurate formula for super-
+            symmetric sampling or the simplified but biased formula
+        :param optimizer: a factory for an optax SGD algorithm
+        :param optimizer_kwargs_mu: a dictionary of parameters to pass to the SGD
+            factory for the mean optimizer
+        :param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
+            factory for the standard deviation optimizer
+        '''
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.init_sigma = init_sigma
+        self.sigma_range = sigma_range
+        self.scale_reward = scale_reward
+        self.super_symmetric = super_symmetric
+        self.super_symmetric_accurate = super_symmetric_accurate
+
+        # set optimizers
+        if optimizer_kwargs_mu is None:
+            optimizer_kwargs_mu = {'learning_rate': 0.1}
+        self.optimizer_kwargs_mu = optimizer_kwargs_mu
+        if optimizer_kwargs_sigma is None:
+            optimizer_kwargs_sigma = {'learning_rate': 0.1}
+        self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
+        self.optimizer_name = optimizer
+        mu_optimizer = optimizer(**optimizer_kwargs_mu)
+        sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
+        self.optimizers = (mu_optimizer, sigma_optimizer)
+
+    def __str__(self) -> str:
+        return (f'PGPE hyper-parameters:\n'
+                f' method ={self.__class__.__name__}\n'
+                f' batch_size ={self.batch_size}\n'
+                f' init_sigma ={self.init_sigma}\n'
+                f' sigma_range ={self.sigma_range}\n'
+                f' scale_reward ={self.scale_reward}\n'
+                f' super_symmetric={self.super_symmetric}\n'
+                f' accurate ={self.super_symmetric_accurate}\n'
+                f' optimizer ={self.optimizer_name}\n'
+                f' optimizer_kwargs:\n'
+                f' mu ={self.optimizer_kwargs_mu}\n'
+                f' sigma={self.optimizer_kwargs_sigma}\n'
+                )
+
+    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
+        MIN_NORM = 1e-5
+        sigma0 = self.init_sigma
+        sigma_range = self.sigma_range
+        scale_reward = self.scale_reward
+        super_symmetric = self.super_symmetric
+        super_symmetric_accurate = self.super_symmetric_accurate
+        batch_size = self.batch_size
+        optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
+
+        # initializer
+        def _jax_wrapped_pgpe_init(key, policy_params):
+            mu = policy_params
+            sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
+            pgpe_params = (mu, sigma)
+            pgpe_opt_state = tuple(opt.init(param)
+                                   for (opt, param) in zip(optimizers, pgpe_params))
+            return pgpe_params, pgpe_opt_state
+
+        self._initializer = jax.jit(_jax_wrapped_pgpe_init)
+
+        # parameter sampling functions
+        def _jax_wrapped_mu_noise(key, sigma):
+            return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
+
+        def _jax_wrapped_epsilon_star(sigma, epsilon):
+            c1, c2, c3 = -0.06655, -0.9706, 0.124
+            phi = 0.67449 * sigma
+            a = (sigma - jnp.abs(epsilon)) / sigma
+            if super_symmetric_accurate:
+                aa = jnp.abs(a)
+                epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
+                    a <= 0,
+                    jnp.exp(c1 * aa * (aa * aa - 1) / jnp.log(aa + 1e-10) + c2 * aa),
+                    jnp.exp(aa - c3 * aa * jnp.log(1.0 - jnp.power(aa, 3) + 1e-10))
+                )
+            else:
+                epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
+            return epsilon_star
+
+        def _jax_wrapped_sample_params(key, mu, sigma):
+            keys = random.split(key, num=len(jax.tree_util.tree_leaves(mu)))
+            keys_pytree = jax.tree_util.tree_unflatten(
+                treedef=jax.tree_util.tree_structure(mu), leaves=keys)
+            epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
+            p1 = jax.tree_map(jnp.add, mu, epsilon)
+            p2 = jax.tree_map(jnp.subtract, mu, epsilon)
+            if super_symmetric:
+                epsilon_star = jax.tree_map(_jax_wrapped_epsilon_star, sigma, epsilon)
+                p3 = jax.tree_map(jnp.add, mu, epsilon_star)
+                p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
+            else:
+                epsilon_star, p3, p4 = epsilon, p1, p2
+            return (p1, p2, p3, p4), (epsilon, epsilon_star)
+
+        # policy gradient update functions
+        def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
+            if super_symmetric:
+                if scale_reward:
+                    scale1 = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
+                    scale2 = jnp.maximum(MIN_NORM, m - (r3 + r4) / 2)
+                else:
+                    scale1 = scale2 = 1.0
+                r_mu1 = (r1 - r2) / (2 * scale1)
+                r_mu2 = (r3 - r4) / (2 * scale2)
+                grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
+            else:
+                if scale_reward:
+                    scale = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
+                else:
+                    scale = 1.0
+                r_mu = (r1 - r2) / (2 * scale)
+                grad = -r_mu * epsilon
+            return grad
+
+        def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
+            if super_symmetric:
+                mask = r1 + r2 >= r3 + r4
+                epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
+                s = epsilon_tau * epsilon_tau / sigma - sigma
+                if scale_reward:
+                    scale = jnp.maximum(MIN_NORM, m - (r1 + r2 + r3 + r4) / 4)
+                else:
+                    scale = 1.0
+                r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
+            else:
+                s = epsilon * epsilon / sigma - sigma
+                if scale_reward:
+                    scale = jnp.maximum(MIN_NORM, jnp.abs(m))
+                else:
+                    scale = 1.0
+                r_sigma = (r1 + r2) / (2 * scale)
+            grad = -r_sigma * s
+            return grad
+
+        def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
+                                   policy_hyperparams, subs, model_params):
+            key, subkey = random.split(key)
+            (p1, p2, p3, p4), (epsilon, epsilon_star) = _jax_wrapped_sample_params(
+                key, mu, sigma)
+            r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
+            r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
+            r_max = jnp.maximum(r_max, r1)
+            r_max = jnp.maximum(r_max, r2)
+            if super_symmetric:
+                r3 = -loss_fn(subkey, p3, policy_hyperparams, subs, model_params)[0]
+                r4 = -loss_fn(subkey, p4, policy_hyperparams, subs, model_params)[0]
+                r_max = jnp.maximum(r_max, r3)
+                r_max = jnp.maximum(r_max, r4)
+            else:
+                r3, r4 = r1, r2
+            grad_mu = jax.tree_map(
+                partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
+                epsilon, epsilon_star
+            )
+            grad_sigma = jax.tree_map(
+                partial(_jax_wrapped_sigma_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
+                epsilon, epsilon_star, sigma
+            )
+            return grad_mu, grad_sigma, r_max
+
+        def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
+                                           policy_hyperparams, subs, model_params):
+            mu, sigma = pgpe_params
+            if batch_size == 1:
+                mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
+                    key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+            else:
+                keys = random.split(key, num=batch_size)
+                mu_grads, sigma_grads, r_maxs = jax.vmap(
+                    _jax_wrapped_pgpe_grad,
+                    in_axes=(0, None, None, None, None, None, None)
+                )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+                mu_grad, sigma_grad = jax.tree_map(
+                    partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
+                new_r_max = jnp.max(r_maxs)
+            return mu_grad, sigma_grad, new_r_max
+
+        def _jax_wrapped_pgpe_update(key, pgpe_params, r_max,
+                                     policy_hyperparams, subs, model_params,
+                                     pgpe_opt_state):
+            mu, sigma = pgpe_params
+            mu_state, sigma_state = pgpe_opt_state
+            mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
+                key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
+            mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
+            sigma_updates, new_sigma_state = sigma_optimizer.update(
+                sigma_grad, sigma_state, params=sigma)
+            new_mu = optax.apply_updates(mu, mu_updates)
+            new_mu, converged = projection(new_mu, policy_hyperparams)
+            new_sigma = optax.apply_updates(sigma, sigma_updates)
+            new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+            new_pgpe_params = (new_mu, new_sigma)
+            new_pgpe_opt_state = (new_mu_state, new_sigma_state)
+            policy_params = new_mu
+            return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged
+
+        self._update = jax.jit(_jax_wrapped_pgpe_update)
+
+
+# ***********************************************************************
+# ALL VERSIONS OF RISK FUNCTIONS
+#
+# Based on the original paper "A Distributional Framework for Risk-Sensitive
+# End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
+#
+# Original risk functions:
+# - entropic utility
+# - mean-variance
+# - mean-semideviation
+# - conditional value at risk with straight-through gradient trick
+#
+# ***********************************************************************
+
+
+@jax.jit
+def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
+    return (-1.0 / beta) * jax.scipy.special.logsumexp(
+        -beta * returns, b=1.0 / returns.size)
+
+
+@jax.jit
+def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
+    return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
+
+
+@jax.jit
+def mean_deviation_utility(returns: jnp.ndarray, beta: float) -> float:
+    return jnp.mean(returns) - 0.5 * beta * jnp.std(returns)
+
+
+@jax.jit
+def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> float:
+    mu = jnp.mean(returns)
+    msd = jnp.sqrt(jnp.mean(jnp.minimum(0.0, returns - mu) ** 2))
+    return mu - 0.5 * beta * msd
+
+
+@jax.jit
+def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
+    mu = jnp.mean(returns)
+    msv = jnp.mean(jnp.minimum(0.0, returns - mu) ** 2)
+    return mu - 0.5 * beta * msv
+
+
+@jax.jit
+def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
+    var = jnp.percentile(returns, q=100 * alpha)
+    mask = returns <= var
+    weights = mask / jnp.maximum(1, jnp.sum(mask))
+    return jnp.sum(returns * weights)
+
+
+UTILITY_LOOKUP = {
+    'mean': jnp.mean,
+    'mean_var': mean_variance_utility,
+    'mean_std': mean_deviation_utility,
+    'mean_semivar': mean_semivariance_utility,
+    'mean_semidev': mean_semideviation_utility,
+    'entropic': entropic_utility,
+    'exponential': entropic_utility,
+    'cvar': cvar_utility
+}
+
+
+# ***********************************************************************
+# ALL VERSIONS OF JAX PLANNER
+#
+# - simple gradient descent based planner
+#
+# ***********************************************************************
+
+
 class JaxBackpropPlanner:
     '''A class for optimizing an action sequence in the given RDDL MDP using
     gradient descent.'''
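GaussianPGPE follows the parameter-exploring policy gradient scheme of reference [5] in the new file header: policy parameters are sampled around a mean mu with a per-parameter sigma, paired (mirrored) rollouts give gradient estimates for both mu and sigma, and optional super-symmetric samples plus reward scaling reduce variance. A stripped-down sketch of that idea on a toy objective (single tensor instead of a parameter pytree, no super-symmetric sampling or reward scaling, illustrative step sizes only):

import jax
import jax.numpy as jnp
import optax

def reward(theta):
    # toy objective, maximized at theta = [3, 3]
    return -jnp.sum((theta - 3.0) ** 2)

key = jax.random.PRNGKey(0)
mu, sigma = jnp.zeros(2), jnp.ones(2)
mu_opt, sigma_opt = optax.adam(0.1), optax.adam(0.1)
mu_state, sigma_state = mu_opt.init(mu), sigma_opt.init(sigma)

for _ in range(300):
    key, sub = jax.random.split(key)
    eps = sigma * jax.random.normal(sub, mu.shape)
    r_plus, r_minus = reward(mu + eps), reward(mu - eps)
    # descent directions fed to optax (negated ascent estimates)
    g_mu = -(r_plus - r_minus) / 2.0 * eps
    g_sigma = -(r_plus + r_minus) / 2.0 * (eps ** 2 / sigma - sigma)
    upd, mu_state = mu_opt.update(g_mu, mu_state)
    mu = optax.apply_updates(mu, upd)
    upd, sigma_state = sigma_opt.update(g_sigma, sigma_state)
    sigma = jnp.clip(optax.apply_updates(sigma, upd), 1e-3, 10.0)

print(mu)  # approximately [3. 3.]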
@@ -1183,6 +1517,7 @@ class JaxBackpropPlanner:
                  clip_grad: Optional[float]=None,
                  line_search_kwargs: Optional[Kwargs]=None,
                  noise_kwargs: Optional[Kwargs]=None,
+                 pgpe: Optional[PGPE]=GaussianPGPE(),
                  logic: Logic=FuzzyLogic(),
                  use_symlog_reward: bool=False,
                  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1213,14 +1548,14 @@ class JaxBackpropPlanner:
         :param line_search_kwargs: parameters to pass to optional line search
             method to scale learning rate
         :param noise_kwargs: parameters of optional gradient noise
+        :param pgpe: optional policy gradient to run alongside the planner
         :param logic: a subclass of Logic for mapping exact mathematical
             operations to their differentiable counterparts
         :param use_symlog_reward: whether to use the symlog transform on the
             reward as a form of normalization
         :param utility: how to aggregate return observations to compute utility
             of a policy or plan; must be either a function mapping jax array to a
-            scalar, or a a string identifying the utility function by name
-            ("mean", "mean_var", "entropic", or "cvar" are currently supported)
+            scalar, or a a string identifying the utility function by name
         :param utility_kwargs: additional keyword arguments to pass hyper-
             parameters to the utility function call
         :param cpfs_without_grad: which CPFs do not have gradients (use straight
@@ -1251,6 +1586,8 @@ class JaxBackpropPlanner:
         self.clip_grad = clip_grad
         self.line_search_kwargs = line_search_kwargs
         self.noise_kwargs = noise_kwargs
+        self.pgpe = pgpe
+        self.use_pgpe = pgpe is not None
 
         # set optimizer
         try:
@@ -1276,18 +1613,11 @@ class JaxBackpropPlanner:
         # set utility
         if isinstance(utility, str):
             utility = utility.lower()
-
-
-            elif utility == 'mean_var':
-                utility_fn = mean_variance_utility
-            elif utility == 'entropic':
-                utility_fn = entropic_utility
-            elif utility == 'cvar':
-                utility_fn = cvar_utility
-            else:
+            utility_fn = UTILITY_LOOKUP.get(utility, None)
+            if utility_fn is None:
                 raise RDDLNotImplementedError(
-                    f'Utility
-                    'must be one of
+                    f'Utility <{utility}> is not supported, '
+                    f'must be one of {list(UTILITY_LOOKUP.keys())}.')
         else:
             utility_fn = utility
         self.utility = utility_fn
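With the risk utilities collected into UTILITY_LOOKUP, the string-to-function dispatch above becomes table-driven and the error message can list every supported name. A quick standalone illustration of how two of the risk-averse utilities rank the same return sample (re-implemented inline for clarity rather than calling the jitted library versions):

import jax.numpy as jnp

returns = jnp.array([-10.0, -2.0, 0.0, 1.0, 3.0])

mean = jnp.mean(returns)                                              # -1.6
cvar_10 = jnp.mean(returns[returns <= jnp.percentile(returns, 10)])   # -10.0: worst-tail average
entropic = (-1.0 / 2.0) * jnp.log(jnp.mean(jnp.exp(-2.0 * returns)))  # beta=2, about -9.2, far below the mean

print(mean, cvar_10, entropic)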
@@ -1355,24 +1685,25 @@ r"""
                   f' line_search_kwargs={self.line_search_kwargs}\n'
                   f' noise_kwargs ={self.noise_kwargs}\n'
                   f' batch_size_train ={self.batch_size_train}\n'
-                  f' batch_size_test ={self.batch_size_test}')
-        result +=
-
+                  f' batch_size_test ={self.batch_size_test}\n')
+        result += str(self.plan)
+        if self.use_pgpe:
+            result += str(self.pgpe)
+        result += str(self.logic)
 
         # print model relaxation information
-        if
-
-
-
-
-
-
-
-
-
-
-
-            f' init_values={values_by_rddl_op[rddl_op]}\n')
+        if self.compiled.model_params:
+            result += ('Some RDDL operations are non-differentiable '
+                       'and will be approximated as follows:' + '\n')
+            exprs_by_rddl_op, values_by_rddl_op = {}, {}
+            for info in self.compiled.model_parameter_info().values():
+                rddl_op = info['rddl_op']
+                exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
+                values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
+            for rddl_op in sorted(exprs_by_rddl_op.keys()):
+                result += (f' {rddl_op}:\n'
+                           f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
+                           f' init_values={values_by_rddl_op[rddl_op]}\n')
         return result
 
     def summarize_hyperparameters(self) -> None:
@@ -1438,6 +1769,15 @@ r"""
         # optimization
         self.update = self._jax_update(train_loss)
         self.check_zero_grad = self._jax_check_zero_gradients()
+
+        # pgpe option
+        if self.use_pgpe:
+            loss_fn = self._jax_loss(rollouts=test_rollouts)
+            self.pgpe.compile(
+                loss_fn=loss_fn,
+                projection=self.plan.projection,
+                real_dtype=self.test_compiled.REAL
+            )
 
     def _jax_return(self, use_symlog):
         gamma = self.rddl.discount
@@ -1547,7 +1887,7 @@ r"""
                                  f'{set(self.test_compiled.init_values.keys())}.')
             value = np.reshape(value, newshape=np.shape(init_value))[np.newaxis, ...]
             train_value = np.repeat(value, repeats=n_train, axis=0)
-            train_value =
+            train_value = np.asarray(train_value, dtype=self.compiled.REAL)
             init_train[name] = train_value
             init_test[name] = np.repeat(value, repeats=n_test, axis=0)
 
@@ -1646,7 +1986,7 @@ r"""
             return grad
 
         return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
-
+
     # ===========================================================================
     # OPTIMIZE API
     # ===========================================================================
@@ -1819,7 +2159,17 @@ r"""
             policy_params = guess
         opt_state = self.optimizer.init(policy_params)
         opt_aux = {}
-
+
+        # initialize pgpe parameters
+        if self.use_pgpe:
+            pgpe_params, pgpe_opt_state = self.pgpe.initialize(key, policy_params)
+            rolling_pgpe_loss = RollingMean(test_rolling_window)
+        else:
+            pgpe_params, pgpe_opt_state = None, None
+            rolling_pgpe_loss = None
+        total_pgpe_it = 0
+        r_max = -jnp.inf
+
         # ======================================================================
         # INITIALIZATION OF RUNNING STATISTICS
         # ======================================================================
@@ -1847,7 +2197,9 @@ r"""
 
         iters = range(epochs)
         if print_progress:
-            iters = tqdm(iters, total=100,
+            iters = tqdm(iters, total=100,
+                         bar_format='{l_bar}{bar}| {elapsed} {postfix}',
+                         position=tqdm_position)
         position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
 
         for it in iters:
@@ -1860,17 +2212,47 @@ r"""
 
             # update the parameters of the plan
             key, subkey = random.split(key)
-            (policy_params, converged, opt_state, opt_aux,
-
-
-
-
+            (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
+             model_params) = self.update(subkey, policy_params, policy_hyperparams,
+                                         train_subs, model_params, opt_state, opt_aux)
+            test_loss, (test_log, model_params_test) = self.test_loss(
+                subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
+            test_loss_smooth = rolling_test_loss.update(test_loss)
+
+            # pgpe update of the plan
+            pgpe_improve = False
+            if self.use_pgpe:
+                key, subkey = random.split(key)
+                pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
+                    self.pgpe.update(subkey, pgpe_params, r_max, policy_hyperparams,
+                                     test_subs, model_params, pgpe_opt_state)
+                pgpe_loss, _ = self.test_loss(
+                    subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
+                pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
+                pgpe_return = -pgpe_loss_smooth
+
+                # replace with PGPE if it reaches a new minimum or train loss invalid
+                if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
+                    policy_params = pgpe_param
+                    test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
+                    converged = pgpe_converged
+                    pgpe_improve = True
+                    total_pgpe_it += 1
+            else:
+                pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
+
+            # evaluate test losses and record best plan so far
+            if test_loss_smooth < best_loss:
+                best_params, best_loss, best_grad = \
+                    policy_params, test_loss_smooth, train_log['grad']
+                last_iter_improve = it
+
             # ==================================================================
             # STATUS CHECKS AND LOGGING
             # ==================================================================
 
             # no progress
-            if self.check_zero_grad(train_log['grad']):
+            if (not pgpe_improve) and self.check_zero_grad(train_log['grad']):
                 status = JaxPlannerStatus.NO_PROGRESS
 
             # constraint satisfaction problem
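In the reworked loop, the gradient-descent candidate and the PGPE candidate are scored with the same smoothed test loss, and the PGPE mean replaces the gradient step whenever its rolling average beats the best loss seen so far or the train loss becomes non-finite. The selection rule, sketched with a stand-in rolling mean (RollingMean in the library plays this role; names and window size here are illustrative):

import math
from collections import deque

class Rolling:
    """Stand-in for the library's RollingMean: average of the last n values."""
    def __init__(self, n):
        self.buf = deque(maxlen=n)
    def update(self, x):
        self.buf.append(x)
        return sum(self.buf) / len(self.buf)

roll_test, roll_pgpe = Rolling(10), Rolling(10)

def choose(gd_loss, pgpe_loss, gd_params, pgpe_params, best_loss):
    # PGPE wins the iteration if its smoothed loss beats the best loss so far,
    # or if the gradient step produced a non-finite loss; otherwise keep the gradient step
    gd_smooth = roll_test.update(gd_loss)
    pgpe_smooth = roll_pgpe.update(pgpe_loss)
    if pgpe_smooth < best_loss or not math.isfinite(gd_loss):
        return pgpe_params, pgpe_smooth
    return gd_params, gd_smooth

print(choose(1.2, 0.4, 'gd', 'pgpe', best_loss=0.8))  # ('pgpe', 0.4)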
@@ -1882,21 +2264,14 @@ r"""
                 status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
 
             # numerical error
-            if
-
-
+            if self.use_pgpe:
+                invalid_loss = not (np.isfinite(train_loss) or np.isfinite(pgpe_loss))
+            else:
+                invalid_loss = not np.isfinite(train_loss)
+            if invalid_loss:
+                raise_warning(f'Planner aborted due to invalid loss {train_loss}.', 'red')
                 status = JaxPlannerStatus.INVALID_GRADIENT
 
-            # evaluate test losses and record best plan so far
-            test_loss, (log, model_params_test) = self.test_loss(
-                subkey, policy_params, policy_hyperparams,
-                test_subs, model_params_test)
-            test_loss = rolling_test_loss.update(test_loss)
-            if test_loss < best_loss:
-                best_params, best_loss, best_grad = \
-                    policy_params, test_loss, train_log['grad']
-                last_iter_improve = it
-
             # reached computation budget
             elapsed = time.time() - start_time - elapsed_outside_loop
             if elapsed >= train_seconds:
@@ -1905,16 +2280,20 @@ r"""
                 status = JaxPlannerStatus.ITER_BUDGET_REACHED
 
             # build a callback
-            progress_percent =
+            progress_percent = 100 * min(
+                1, max(0, elapsed / train_seconds, it / (epochs - 1)))
             callback = {
                 'status': status,
                 'iteration': it,
                 'train_return':-train_loss,
-                'test_return':-
+                'test_return':-test_loss_smooth,
                 'best_return':-best_loss,
+                'pgpe_return': pgpe_return,
                 'params': policy_params,
                 'best_params': best_params,
+                'pgpe_params': pgpe_params,
                 'last_iteration_improved': last_iter_improve,
+                'pgpe_improved': pgpe_improve,
                 'grad': train_log['grad'],
                 'best_grad': best_grad,
                 'updates': train_log['updates'],
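The callback dictionary now reports PGPE progress (`pgpe_return`, `pgpe_params`, `pgpe_improved`) next to the existing fields. A small helper that formats those entries, assuming a callback like the one built above is already in hand (how it is obtained from the planner's optimize loop is not shown in this hunk):

def summarize(callback):
    # format one training callback, including the new PGPE fields when present
    line = (f"it={callback['iteration']} "
            f"train={callback['train_return']:.3f} "
            f"test={callback['test_return']:.3f} "
            f"best={callback['best_return']:.3f}")
    if callback.get('pgpe_return') is not None:
        line += (f" pgpe={callback['pgpe_return']:.3f}"
                 f" pgpe_improved={callback['pgpe_improved']}")
    return line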
@@ -1923,9 +2302,9 @@ r"""
                 'model_params': model_params,
                 'progress': progress_percent,
                 'train_log': train_log,
-                **
+                **test_log
             }
-
+
             # stopping condition reached
             if stopping_rule is not None and stopping_rule.monitor(callback):
                 callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
@@ -1934,10 +2313,12 @@ r"""
             if print_progress:
                 iters.n = progress_percent
                 iters.set_description(
-                    f'{position_str} {it:6} it / {-train_loss:14.
-                    f'{-
-                    f'{status.value} status'
+                    f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
+                    f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
+                    f'{status.value} status / {total_pgpe_it:6} pgpe',
+                    refresh=False
                 )
+                iters.set_postfix_str(f"{(it + 1) / elapsed:.2f}it/s", refresh=True)
 
             # dash-board
             if dashboard is not None:
@@ -1955,7 +2336,7 @@ r"""
         # ======================================================================
         # POST-PROCESSING AND CLEANUP
         # ======================================================================
-
+
         # release resources
         if print_progress:
             iters.close()
@@ -1967,7 +2348,7 @@ r"""
         messages.update(JaxRDDLCompiler.get_error_messages(error_code))
         if messages:
             messages = '\n'.join(messages)
-            raise_warning('
+            raise_warning('JAX compiler encountered the following '
                           'error(s) in the original RDDL formulation '
                           f'during test evaluation:\n{messages}', 'red')
 
@@ -1975,14 +2356,14 @@ r"""
         if print_summary:
             grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
             diagnosis = self._perform_diagnosis(
-                last_iter_improve, -train_loss, -
+                last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
             print(f'summary of optimization:\n'
-                  f'
-                  f'
+                  f' status ={status}\n'
+                  f' time ={elapsed:.3f} sec.\n'
                   f' iterations ={it}\n'
-                  f'
-                  f'
-                  f'
+                  f' best objective={-best_loss:.6f}\n'
+                  f' best grad norm={grad_norm}\n'
+                  f'diagnosis: {diagnosis}\n')
 
     def _perform_diagnosis(self, last_iter_improve,
                            train_return, test_return, best_return, grad_norm):
@@ -2002,23 +2383,24 @@ r"""
         if last_iter_improve <= 1:
             if grad_is_zero:
                 return termcolor.colored(
-                    '[FAILURE] no progress was made
-                    f'and max grad norm {max_grad_norm:.6f}
+                    '[FAILURE] no progress was made '
+                    f'and max grad norm {max_grad_norm:.6f} was zero: '
                     'solver likely stuck in a plateau.', 'red')
             else:
                 return termcolor.colored(
-                    '[FAILURE] no progress was made
-                    f'but max grad norm {max_grad_norm:.6f}
-                    '
+                    '[FAILURE] no progress was made '
+                    f'but max grad norm {max_grad_norm:.6f} was non-zero: '
+                    'learning rate or other hyper-parameters likely suboptimal.',
+                    'red')
 
         # model is likely poor IF:
         # 1. the train and test return disagree
         if not (validation_error < 20):
             return termcolor.colored(
-                '[WARNING] progress was made
-                f'but relative train-test error {validation_error:.6f}
-                '
-                '
+                '[WARNING] progress was made '
+                f'but relative train-test error {validation_error:.6f} was high: '
+                'poor model relaxation around solution or batch size too small.',
+                'yellow')
 
         # model likely did not converge IF:
         # 1. the max grad relative to the return is high
@@ -2026,15 +2408,15 @@ r"""
         return_to_grad_norm = abs(best_return) / max_grad_norm
         if not (return_to_grad_norm > 1):
             return termcolor.colored(
-                '[WARNING] progress was made
-                f'but max grad norm {max_grad_norm:.6f}
-                '
-                'or
-                'or
+                '[WARNING] progress was made '
+                f'but max grad norm {max_grad_norm:.6f} was high: '
+                'solution locally suboptimal '
+                'or relaxed model not smooth around solution '
+                'or batch size too small.', 'yellow')
 
         # likely successful
         return termcolor.colored(
-            '[SUCCESS]
+            '[SUCCESS] solver converged successfully '
             '(note: not all potential problems can be ruled out).', 'green')
 
     def get_action(self, key: random.PRNGKey,
@@ -2057,8 +2439,7 @@ r"""
         for (var, values) in subs.items():
 
             # must not be grounded
-            if RDDLPlanningModel.FLUENT_SEP in var
-            or RDDLPlanningModel.OBJECT_SEP in var:
+            if RDDLPlanningModel.FLUENT_SEP in var or RDDLPlanningModel.OBJECT_SEP in var:
                 raise ValueError(f'State dictionary passed to the JAX policy is '
                                  f'grounded, since it contains the key <{var}>, '
                                  f'but a vectorized environment is required: '
@@ -2066,9 +2447,8 @@ r"""
 
             # must be numeric array
             # exception is for POMDPs at 1st epoch when observ-fluents are None
-            dtype = np.
-            if not np.issubdtype(dtype, np.number)
-            and not np.issubdtype(dtype, np.bool_):
+            dtype = np.result_type(values)
+            if not np.issubdtype(dtype, np.number) and not np.issubdtype(dtype, np.bool_):
                 if step == 0 and var in self.rddl.observ_fluents:
                     subs[var] = self.test_compiled.init_values[var]
                 else:
@@ -2080,40 +2460,7 @@ r"""
         actions = self.test_policy(key, params, policy_hyperparams, step, subs)
         actions = jax.tree_map(np.asarray, actions)
         return actions
-
-
-# ***********************************************************************
-# ALL VERSIONS OF RISK FUNCTIONS
-#
-# Based on the original paper "A Distributional Framework for Risk-Sensitive
-# End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
-#
-# Original risk functions:
-# - entropic utility
-# - mean-variance approximation
-# - conditional value at risk with straight-through gradient trick
-#
-# ***********************************************************************
-
-
-@jax.jit
-def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
-    return (-1.0 / beta) * jax.scipy.special.logsumexp(
-        -beta * returns, b=1.0 / returns.size)
-
-
-@jax.jit
-def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
-    return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
-
-
-@jax.jit
-def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
-    var = jnp.percentile(returns, q=100 * alpha)
-    mask = returns <= var
-    weights = mask / jnp.maximum(1, jnp.sum(mask))
-    return jnp.sum(returns * weights)
-
+
 
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
@@ -2225,8 +2572,7 @@ class JaxOnlineController(BaseAgent):
         self.callback = callback
         params = callback['best_params']
         self.key, subkey = random.split(self.key)
-        actions = planner.get_action(
-            subkey, params, 0, state, self.eval_hyperparams)
+        actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
        if self.warm_start:
             self.guess = planner.plan.guess_next_epoch(params)
         return actions