pyRDDLGym-jax: pyRDDLGym_jax-0.5-py3-none-any.whl → pyRDDLGym_jax-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +784 -544
- pyRDDLGym_jax/core/planner.py +329 -463
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +379 -568
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -27
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
- pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
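The bulk of the release is a rewrite of pyRDDLGym_jax/core/planner.py, reproduced below. For orientation, here is a minimal sketch of the workflow the new module serves, assembled from symbols visible in this diff; treat load_config as an assumed public wrapper around the _load_config shown below, and the domain, instance, and config names as placeholders:

    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import (
        JaxBackpropPlanner, JaxOfflineController, load_config)

    # hypothetical domain/instance and config file
    env = pyRDDLGym.make('HVAC', '1', vectorized=True)
    planner_args, _, train_args = load_config('default_slp.cfg')
    planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
    agent = JaxOfflineController(planner, **train_args)
    agent.evaluate(env, episodes=1)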
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -31,22 +31,24 @@ from pyRDDLGym.core.policy import BaseAgent
 from pyRDDLGym_jax import __version__
 from pyRDDLGym_jax.core import logic
 from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
-from pyRDDLGym_jax.core.logic import FuzzyLogic
+from pyRDDLGym_jax.core.logic import Logic, FuzzyLogic
 
-# try to
+# try to load the dash board
 try:
-
+    from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
 except Exception:
-    raise_warning('
-                  '
+    raise_warning('Failed to load the dashboard visualization tool: '
+                  'please make sure you have installed the required packages.',
+                  'red')
     traceback.print_exc()
-
-
+    JaxPlannerDashboard = None
+
 Activation = Callable[[jnp.ndarray], jnp.ndarray]
 Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
 Kwargs = Dict[str, Any]
 Pytree = Any
 
+
 # ***********************************************************************
 # CONFIG FILE MANAGEMENT
 #
@@ -95,21 +97,25 @@ def _load_config(config, args):
     # read the model settings
     logic_name = model_args.get('logic', 'FuzzyLogic')
     logic_kwargs = model_args.get('logic_kwargs', {})
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if logic_name == 'FuzzyLogic':
+        tnorm_name = model_args.get('tnorm', 'ProductTNorm')
+        tnorm_kwargs = model_args.get('tnorm_kwargs', {})
+        comp_name = model_args.get('complement', 'StandardComplement')
+        comp_kwargs = model_args.get('complement_kwargs', {})
+        compare_name = model_args.get('comparison', 'SigmoidComparison')
+        compare_kwargs = model_args.get('comparison_kwargs', {})
+        sampling_name = model_args.get('sampling', 'GumbelSoftmax')
+        sampling_kwargs = model_args.get('sampling_kwargs', {})
+        rounding_name = model_args.get('rounding', 'SoftRounding')
+        rounding_kwargs = model_args.get('rounding_kwargs', {})
+        control_name = model_args.get('control', 'SoftControlFlow')
+        control_kwargs = model_args.get('control_kwargs', {})
+        logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
+        logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
+        logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
+        logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
+        logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
+        logic_kwargs['control'] = getattr(logic, control_name)(**control_kwargs)
 
     # read the policy settings
     plan_method = planner_args.pop('method')
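Note: the branch above only assembles keyword arguments; the components themselves live in pyRDDLGym_jax.core.logic. Spelling out what the defaults construct (class names taken from the .get defaults above; any constructor arguments beyond these are not shown in this diff):

    from pyRDDLGym_jax.core import logic

    # what logic_kwargs amounts to when every key is left at its default
    logic_kwargs = dict(
        tnorm=logic.ProductTNorm(),
        complement=logic.StandardComplement(),
        comparison=logic.SigmoidComparison(),
        sampling=logic.GumbelSoftmax(),
        rounding=logic.SoftRounding(),
        control=logic.SoftControlFlow(),
    )
    fuzzy = logic.FuzzyLogic(**logic_kwargs)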
@@ -165,6 +171,13 @@ def _load_config(config, args):
     if planner_key is not None:
         train_args['key'] = random.PRNGKey(planner_key)
 
+    # dashboard
+    dashboard_key = train_args.get('dashboard', None)
+    if dashboard_key is not None and dashboard_key and JaxPlannerDashboard is not None:
+        train_args['dashboard'] = JaxPlannerDashboard()
+    elif dashboard_key is not None:
+        del train_args['dashboard']
+
     # optimize call stopping rule
     stopping_rule = train_args.get('stopping_rule', None)
     if stopping_rule is not None:
@@ -186,6 +199,7 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
     config, args = _parse_config_string(value)
     return _load_config(config, args)
 
+
 # ***********************************************************************
 # MODEL RELAXATIONS
 #
@@ -202,7 +216,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
     '''
 
     def __init__(self, *args,
-                 logic:
+                 logic: Logic=FuzzyLogic(),
                  cpfs_without_grad: Optional[Set[str]]=None,
                  **kwargs) -> None:
         '''Creates a new RDDL to Jax compiler, where operations that are not
@@ -237,57 +251,55 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
 
         # overwrite basic operations with fuzzy ones
         self.RELATIONAL_OPS = {
-            '>=': logic.greater_equal
-            '<=': logic.less_equal
-            '<': logic.less
-            '>': logic.greater
-            '==': logic.equal
-            '~=': logic.not_equal
+            '>=': logic.greater_equal,
+            '<=': logic.less_equal,
+            '<': logic.less,
+            '>': logic.greater,
+            '==': logic.equal,
+            '~=': logic.not_equal
         }
-        self.LOGICAL_NOT = logic.logical_not
+        self.LOGICAL_NOT = logic.logical_not
         self.LOGICAL_OPS = {
-            '^': logic.logical_and
-            '&': logic.logical_and
-            '|': logic.logical_or
-            '~': logic.xor
-            '=>': logic.implies
-            '<=>': logic.equiv
+            '^': logic.logical_and,
+            '&': logic.logical_and,
+            '|': logic.logical_or,
+            '~': logic.xor,
+            '=>': logic.implies,
+            '<=>': logic.equiv
         }
-        self.AGGREGATION_OPS['forall'] = logic.forall
-        self.AGGREGATION_OPS['exists'] = logic.exists
-        self.AGGREGATION_OPS['argmin'] = logic.argmin
-        self.AGGREGATION_OPS['argmax'] = logic.argmax
-        self.KNOWN_UNARY['sgn'] = logic.sgn
-        self.KNOWN_UNARY['floor'] = logic.floor
-        self.KNOWN_UNARY['ceil'] = logic.ceil
-        self.KNOWN_UNARY['round'] = logic.round
-        self.KNOWN_UNARY['sqrt'] = logic.sqrt
-        self.KNOWN_BINARY['div'] = logic.div
-        self.KNOWN_BINARY['mod'] = logic.mod
-        self.KNOWN_BINARY['fmod'] = logic.mod
-        self.IF_HELPER = logic.control_if
-        self.SWITCH_HELPER = logic.control_switch
-        self.BERNOULLI_HELPER = logic.bernoulli
-        self.DISCRETE_HELPER = logic.discrete
-        self.POISSON_HELPER = logic.poisson
-        self.GEOMETRIC_HELPER = logic.geometric
-
-    def _jax_stop_grad(self, jax_expr):
-
+        self.AGGREGATION_OPS['forall'] = logic.forall
+        self.AGGREGATION_OPS['exists'] = logic.exists
+        self.AGGREGATION_OPS['argmin'] = logic.argmin
+        self.AGGREGATION_OPS['argmax'] = logic.argmax
+        self.KNOWN_UNARY['sgn'] = logic.sgn
+        self.KNOWN_UNARY['floor'] = logic.floor
+        self.KNOWN_UNARY['ceil'] = logic.ceil
+        self.KNOWN_UNARY['round'] = logic.round
+        self.KNOWN_UNARY['sqrt'] = logic.sqrt
+        self.KNOWN_BINARY['div'] = logic.div
+        self.KNOWN_BINARY['mod'] = logic.mod
+        self.KNOWN_BINARY['fmod'] = logic.mod
+        self.IF_HELPER = logic.control_if
+        self.SWITCH_HELPER = logic.control_switch
+        self.BERNOULLI_HELPER = logic.bernoulli
+        self.DISCRETE_HELPER = logic.discrete
+        self.POISSON_HELPER = logic.poisson
+        self.GEOMETRIC_HELPER = logic.geometric
+
+    def _jax_stop_grad(self, jax_expr):
         def _jax_wrapped_stop_grad(x, params, key):
-            sample, key, error = jax_expr(x, params, key)
+            sample, key, error, params = jax_expr(x, params, key)
             sample = jax.lax.stop_gradient(sample)
-            return sample, key, error
-
+            return sample, key, error, params
         return _jax_wrapped_stop_grad
 
-    def _compile_cpfs(self,
+    def _compile_cpfs(self, init_params):
        cpfs_cast = set()
        jax_cpfs = {}
        for (_, cpfs) in self.levels.items():
            for cpf in cpfs:
                _, expr = self.rddl.cpfs[cpf]
-                jax_cpfs[cpf] = self._jax(expr,
+                jax_cpfs[cpf] = self._jax(expr, init_params, dtype=self.REAL)
                if self.rddl.variable_ranges[cpf] != 'real':
                    cpfs_cast.add(cpf)
                if cpf in self.cpfs_without_grad:
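Note: the table of overrides above is what makes the compiled model differentiable: every relational, logical, rounding, and sampling operation is dispatched to a soft surrogate supplied by the Logic object. A toy illustration of the idea (my sketch, not the package's implementation; the real operators also thread tuning parameters through the new (sample, key, error, params) 4-tuple visible in _jax_stop_grad):

    import jax
    import jax.numpy as jnp

    def soft_greater(a, b, weight=10.0):
        # exact semantics: float(a > b); relaxed: a steep sigmoid of the gap,
        # which has a usable gradient w.r.t. both arguments everywhere
        return jax.nn.sigmoid(weight * (a - b))

    def soft_and(p, q):
        # product t-norm: agrees with boolean AND on the {0, 1} corners
        return p * q

    print(soft_greater(2.0, 1.0))  # close to 1.0
    print(soft_and(0.9, 0.8))      # 0.72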
@@ -298,17 +310,16 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
                           f'{cpfs_cast} be cast to float.')
         if self.cpfs_without_grad:
             raise_warning(f'User requested that gradients not flow '
-                          f'through CPFs {self.cpfs_without_grad}.')
+                          f'through CPFs {self.cpfs_without_grad}.')
+
         return jax_cpfs
 
-    def _jax_kron(self, expr,
-        if self.logic.verbose:
-            raise_warning('JAX gradient compiler ignores KronDelta '
-                          'during compilation.')
+    def _jax_kron(self, expr, init_params):
         arg, = expr.args
-        arg = self._jax(arg,
+        arg = self._jax(arg, init_params)
         return arg
 
+
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
 #
@@ -329,7 +340,7 @@ class JaxPlan:
         self.bounds = None
 
     def summarize_hyperparameters(self) -> None:
-
+        print(self.__str__())
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -455,21 +466,21 @@ class JaxStraightLinePlan(JaxPlan):
         self._wrap_softmax = wrap_softmax
         self._use_new_projection = use_new_projection
         self._max_constraint_iter = max_constraint_iter
-
-    def
+
+    def __str__(self) -> str:
         bounds = '\n        '.join(
             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
-
-
-
-
-
-
-
-
-
-
-
+        return (f'policy hyper-parameters:\n'
+                f'    initializer          ={self._initializer_base}\n'
+                f'constraint-sat strategy (simple):\n'
+                f'    parsed_action_bounds =\n        {bounds}\n'
+                f'    wrap_sigmoid         ={self._wrap_sigmoid}\n'
+                f'    wrap_sigmoid_min_prob={self._min_action_prob}\n'
+                f'    wrap_non_bool        ={self._wrap_non_bool}\n'
+                f'constraint-sat strategy (complex):\n'
+                f'    wrap_softmax         ={self._wrap_softmax}\n'
+                f'    use_new_projection   ={self._use_new_projection}\n'
+                f'    max_projection_iters ={self._max_constraint_iter}')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -820,21 +831,21 @@ class JaxDeepReactivePolicy(JaxPlan):
             normalizer_kwargs = {'create_offset': True, 'create_scale': True}
         self._normalizer_kwargs = normalizer_kwargs
         self._wrap_non_bool = wrap_non_bool
-
-    def
+
+    def __str__(self) -> str:
         bounds = '\n        '.join(
             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
-
-
-
-
-
-
-
-
-
-
+        return (f'policy hyper-parameters:\n'
+                f'    topology            ={self._topology}\n'
+                f'    activation_fn       ={self._activations[0].__name__}\n'
+                f'    initializer         ={type(self._initializer_base).__name__}\n'
+                f'    apply_input_norm    ={self._normalize}\n'
+                f'    input_norm_layerwise={self._normalize_per_layer}\n'
+                f'    input_norm_args     ={self._normalizer_kwargs}\n'
+                f'constraint-sat strategy:\n'
+                f'    parsed_action_bounds=\n        {bounds}\n'
+                f'    wrap_non_bool       ={self._wrap_non_bool}')
+
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
                 horizon: int) -> None:
@@ -1057,6 +1068,7 @@ class JaxDeepReactivePolicy(JaxPlan):
     def guess_next_epoch(self, params: Pytree) -> Pytree:
         return params
 
+
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANNER
 #
@@ -1084,113 +1096,6 @@ class RollingMean:
         return self._total / len(memory)
 
 
-class JaxPlannerPlot:
-    '''Supports plotting and visualization of a JAX policy in real time.'''
-
-    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
-                 show_violin: bool=True, show_action: bool=True) -> None:
-        '''Creates a new planner visualizer.
-
-        :param rddl: the planning model to optimize
-        :param horizon: the lookahead or planning horizon
-        :param show_violin: whether to show the distribution of batch losses
-        :param show_action: whether to show heatmaps of the action fluents
-        '''
-        num_plots = 1
-        if show_violin:
-            num_plots += 1
-        if show_action:
-            num_plots += len(rddl.action_fluents)
-        self._fig, axes = plt.subplots(num_plots)
-        if num_plots == 1:
-            axes = [axes]
-
-        # prepare the loss plot
-        self._loss_ax = axes[0]
-        self._loss_ax.autoscale(enable=True)
-        self._loss_ax.set_xlabel('training time')
-        self._loss_ax.set_ylabel('loss value')
-        self._loss_plot = self._loss_ax.plot(
-            [], [], linestyle=':', marker='o', markersize=2)[0]
-        self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
-
-        # prepare the violin plot
-        if show_violin:
-            self._hist_ax = axes[1]
-        else:
-            self._hist_ax = None
-
-        # prepare the action plots
-        if show_action:
-            self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
-                               for (idx, name) in enumerate(rddl.action_fluents)}
-            self._action_plots = {}
-            for name in rddl.action_fluents:
-                ax = self._action_ax[name]
-                if rddl.variable_ranges[name] == 'bool':
-                    vmin, vmax = 0.0, 1.0
-                else:
-                    vmin, vmax = None, None
-                action_dim = 1
-                for dim in rddl.object_counts(rddl.variable_params[name]):
-                    action_dim *= dim
-                action_plot = ax.pcolormesh(
-                    np.zeros((action_dim, horizon)),
-                    cmap='seismic', vmin=vmin, vmax=vmax)
-                ax.set_aspect('auto')
-                ax.set_xlabel('decision epoch')
-                ax.set_ylabel(name)
-                plt.colorbar(action_plot, ax=ax)
-                self._action_plots[name] = action_plot
-            self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
-                                 for (name, ax) in self._action_ax.items()}
-        else:
-            self._action_ax = None
-            self._action_plots = None
-            self._action_back = None
-
-        plt.tight_layout()
-        plt.show(block=False)
-
-    def redraw(self, xticks, losses, actions, returns) -> None:
-
-        # draw the loss curve
-        self._fig.canvas.restore_region(self._loss_back)
-        self._loss_plot.set_xdata(xticks)
-        self._loss_plot.set_ydata(losses)
-        self._loss_ax.set_xlim([0, len(xticks)])
-        self._loss_ax.set_ylim([np.min(losses), np.max(losses)])
-        self._loss_ax.draw_artist(self._loss_plot)
-        self._fig.canvas.blit(self._loss_ax.bbox)
-
-        # draw the violin plot
-        if self._hist_ax is not None:
-            self._hist_ax.clear()
-            self._hist_ax.set_xlabel('loss value')
-            self._hist_ax.set_ylabel('density')
-            self._hist_ax.violinplot(returns, vert=False, showmeans=True)
-
-        # draw the actions
-        if self._action_ax is not None:
-            for (name, values) in actions.items():
-                values = np.mean(values, axis=0, dtype=float)
-                values = np.reshape(values, newshape=(values.shape[0], -1)).T
-                self._fig.canvas.restore_region(self._action_back[name])
-                self._action_plots[name].set_array(values)
-                self._action_ax[name].draw_artist(self._action_plots[name])
-                self._fig.canvas.blit(self._action_ax[name].bbox)
-                self._action_plots[name].set_clim([np.min(values), np.max(values)])
-
-        self._fig.canvas.draw()
-        self._fig.canvas.flush_events()
-
-    def close(self) -> None:
-        plt.close(self._fig)
-        del self._loss_ax, self._hist_ax, self._action_ax, \
-            self._loss_plot, self._action_plots, self._fig, \
-            self._loss_back, self._action_back
-
-
 class JaxPlannerStatus(Enum):
     '''Represents the status of a policy update from the JAX planner,
     including whether the update resulted in nan gradient,
@@ -1198,14 +1103,15 @@ class JaxPlannerStatus(Enum):
     can be used to monitor and act based on the planner's progress.'''
 
     NORMAL = 0
-
-
-
-
-
+    STOPPING_RULE_REACHED = 1
+    NO_PROGRESS = 2
+    PRECONDITION_POSSIBLY_UNSATISFIED = 3
+    INVALID_GRADIENT = 4
+    TIME_BUDGET_REACHED = 5
+    ITER_BUDGET_REACHED = 6
 
-    def
-        return self.value >=
+    def is_terminal(self) -> bool:
+        return self.value == 1 or self.value >= 4
 
 
 class JaxPlannerStoppingRule:
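Note: with the enum values now spelled out, the new is_terminal is easy to read: STOPPING_RULE_REACHED (1) and everything from INVALID_GRADIENT (4) upward ends training, while NO_PROGRESS (2) and PRECONDITION_POSSIBLY_UNSATISFIED (3) merely flag the iteration. A self-contained check, copied from the hunk above:

    from enum import Enum

    class JaxPlannerStatus(Enum):
        NORMAL = 0
        STOPPING_RULE_REACHED = 1
        NO_PROGRESS = 2
        PRECONDITION_POSSIBLY_UNSATISFIED = 3
        INVALID_GRADIENT = 4
        TIME_BUDGET_REACHED = 5
        ITER_BUDGET_REACHED = 6

        def is_terminal(self) -> bool:
            return self.value == 1 or self.value >= 4

    print([s.name for s in JaxPlannerStatus if s.is_terminal()])
    # ['STOPPING_RULE_REACHED', 'INVALID_GRADIENT',
    #  'TIME_BUDGET_REACHED', 'ITER_BUDGET_REACHED']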
@@ -1255,15 +1161,16 @@ class JaxBackpropPlanner:
                  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
                  optimizer_kwargs: Optional[Kwargs]=None,
                  clip_grad: Optional[float]=None,
-
-
-                 logic:
+                 line_search_kwargs: Optional[Kwargs]=None,
+                 noise_kwargs: Optional[Kwargs]=None,
+                 logic: Logic=FuzzyLogic(),
                  use_symlog_reward: bool=False,
                  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
                  utility_kwargs: Optional[Kwargs]=None,
                  cpfs_without_grad: Optional[Set[str]]=None,
                  compile_non_fluent_exact: bool=True,
-                 logger: Optional[Logger]=None
+                 logger: Optional[Logger]=None,
+                 dashboard_viz: Optional[Any]=None) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
         (plan) in the given RDDL. Some operations will be converted to their
         differentiable counterparts; the specific operations can be customized
@@ -1283,9 +1190,10 @@ class JaxBackpropPlanner:
         :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
         factory (e.g. which parameters are controllable externally)
         :param clip_grad: maximum magnitude of gradient updates
-        :param
-
-        :param
+        :param line_search_kwargs: parameters to pass to optional line search
+        method to scale learning rate
+        :param noise_kwargs: parameters of optional gradient noise
+        :param logic: a subclass of Logic for mapping exact mathematical
         operations to their differentiable counterparts
         :param use_symlog_reward: whether to use the symlog transform on the
         reward as a form of normalization
@@ -1300,6 +1208,8 @@ class JaxBackpropPlanner:
         :param compile_non_fluent_exact: whether non-fluent expressions
         are always compiled using exact JAX expressions
         :param logger: to log information about compilation to file
+        :param dashboard_viz: optional visualizer object from the environment
+        to pass to the dashboard to visualize the policy
         '''
         self.rddl = rddl
         self.plan = plan
@@ -1314,31 +1224,34 @@ class JaxBackpropPlanner:
             action_bounds = {}
         self._action_bounds = action_bounds
         self.use64bit = use64bit
-        self.
+        self.optimizer_name = optimizer
         if optimizer_kwargs is None:
             optimizer_kwargs = {'learning_rate': 0.1}
-        self.
+        self.optimizer_kwargs = optimizer_kwargs
         self.clip_grad = clip_grad
-        self.
-        self.
+        self.line_search_kwargs = line_search_kwargs
+        self.noise_kwargs = noise_kwargs
 
         # set optimizer
         try:
             optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
         except Exception as _:
             raise_warning(
-                'Failed to inject hyperparameters into optax optimizer, '
+                f'Failed to inject hyperparameters into optax optimizer {optimizer}, '
                 'rolling back to safer method: please note that modification of '
-                'optimizer hyperparameters will not work,
-
-
-
-
-
-
-
-
+                'optimizer hyperparameters will not work.', 'red')
+            optimizer = optimizer(**optimizer_kwargs)
+
+        # apply optimizer chain of transformations
+        pipeline = []
+        if clip_grad is not None:
+            pipeline.append(optax.clip(clip_grad))
+        if noise_kwargs is not None:
+            pipeline.append(optax.add_noise(**noise_kwargs))
+        pipeline.append(optimizer)
+        if line_search_kwargs is not None:
+            pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))
+        self.optimizer = optax.chain(*pipeline)
 
         # set utility
         if isinstance(utility, str):
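Note: the constructor now builds a single optax.chain instead of the hand-rolled gradient-noise and line-search code deleted further down. A standalone sketch of the same pipeline with assumed example kwargs; optax.clip, optax.add_noise, and optax.scale_by_zoom_linesearch are real optax transformations, but check your optax version for their exact parameters:

    import jax.numpy as jnp
    import optax

    # mirrors clip_grad / noise_kwargs / line_search_kwargs of JaxBackpropPlanner
    base = optax.inject_hyperparams(optax.rmsprop)(learning_rate=0.1)
    optimizer = optax.chain(
        optax.clip(1.0),                                 # clip_grad=1.0
        optax.add_noise(eta=0.01, gamma=0.55, seed=42),  # noise_kwargs
        base,
        optax.scale_by_zoom_linesearch(max_linesearch_steps=15),  # line_search_kwargs
    )
    opt_state = optimizer.init({'w': jnp.zeros(3)})

Note that adding the line-search transform changes the update() contract, which is exactly why _jax_update below grows the value/grad/value_fn plumbing.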
@@ -1371,11 +1284,12 @@ class JaxBackpropPlanner:
         self.cpfs_without_grad = cpfs_without_grad
         self.compile_non_fluent_exact = compile_non_fluent_exact
         self.logger = logger
+        self.dashboard_viz = dashboard_viz
 
         self._jax_compile_rddl()
         self._jax_compile_optimizer()
 
-    def
+    def summarize_system(self) -> str:
         try:
             jaxlib_version = jax._src.lib.version_str
         except Exception as _:
@@ -1394,36 +1308,55 @@ r"""
 \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
 """
 
-
-
-
-
-
-
-
-
+        return ('\n'
+                f'{LOGO}\n'
+                f'Version {__version__}\n'
+                f'Python {sys.version}\n'
+                f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
+                f'optax {optax.__version__}, haiku {hk.__version__}, '
+                f'numpy {np.__version__}\n'
+                f'devices: {devices_short}\n')
+
+    def __str__(self) -> str:
+        result = (f'objective hyper-parameters:\n'
+                  f'    utility_fn        ={self.utility.__name__}\n'
+                  f'    utility args      ={self.utility_kwargs}\n'
+                  f'    use_symlog        ={self.use_symlog_reward}\n'
+                  f'    lookahead         ={self.horizon}\n'
+                  f'    user_action_bounds={self._action_bounds}\n'
+                  f'    fuzzy logic type  ={type(self.logic).__name__}\n'
+                  f'    non_fluents exact ={self.compile_non_fluent_exact}\n'
+                  f'    cpfs_no_gradient  ={self.cpfs_without_grad}\n'
+                  f'optimizer hyper-parameters:\n'
+                  f'    use_64_bit        ={self.use64bit}\n'
+                  f'    optimizer         ={self.optimizer_name}\n'
+                  f'    optimizer args    ={self.optimizer_kwargs}\n'
+                  f'    clip_gradient     ={self.clip_grad}\n'
+                  f'    line_search_kwargs={self.line_search_kwargs}\n'
+                  f'    noise_kwargs      ={self.noise_kwargs}\n'
+                  f'    batch_size_train  ={self.batch_size_train}\n'
+                  f'    batch_size_test   ={self.batch_size_test}')
+        result += '\n' + str(self.plan)
+        result += '\n' + str(self.logic)
+
+        # print model relaxation information
+        if not self.compiled.model_params:
+            return result
+        result += '\n' + ('Some RDDL operations are non-differentiable '
+                          'and will be approximated as follows:' + '\n')
+        exprs_by_rddl_op, values_by_rddl_op = {}, {}
+        for info in self.compiled.model_parameter_info().values():
+            rddl_op = info['rddl_op']
+            exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
+            values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
+        for rddl_op in sorted(exprs_by_rddl_op.keys()):
+            result += (f'    {rddl_op}:\n'
+                       f'        addresses  ={exprs_by_rddl_op[rddl_op]}\n'
+                       f'        init_values={values_by_rddl_op[rddl_op]}\n')
+        return result
 
     def summarize_hyperparameters(self) -> None:
-        print(
-            f'    utility_fn        ={self.utility.__name__}\n'
-            f'    utility args      ={self.utility_kwargs}\n'
-            f'    use_symlog        ={self.use_symlog_reward}\n'
-            f'    lookahead         ={self.horizon}\n'
-            f'    user_action_bounds={self._action_bounds}\n'
-            f'    fuzzy logic type  ={type(self.logic).__name__}\n'
-            f'    nonfluents exact  ={self.compile_non_fluent_exact}\n'
-            f'    cpfs_no_gradient  ={self.cpfs_without_grad}\n'
-            f'optimizer hyper-parameters:\n'
-            f'    use_64_bit        ={self.use64bit}\n'
-            f'    optimizer         ={self._optimizer_name.__name__}\n'
-            f'    optimizer args    ={self._optimizer_kwargs}\n'
-            f'    clip_gradient     ={self.clip_grad}\n'
-            f'    noise_grad_eta    ={self.noise_grad_eta}\n'
-            f'    noise_grad_gamma  ={self.noise_grad_gamma}\n'
-            f'    batch_size_train  ={self.batch_size_train}\n'
-            f'    batch_size_test   ={self.batch_size_test}')
-        self.plan.summarize_hyperparameters()
-        self.logic.summarize_hyperparameters()
+        print(self.__str__())
 
 # ===========================================================================
 # COMPILATION SUBROUTINES
@@ -1439,14 +1372,16 @@ r"""
             logger=self.logger,
             use64bit=self.use64bit,
             cpfs_without_grad=self.cpfs_without_grad,
-            compile_non_fluent_exact=self.compile_non_fluent_exact
+            compile_non_fluent_exact=self.compile_non_fluent_exact
+        )
         self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
 
         # Jax compilation of the exact RDDL for testing
         self.test_compiled = JaxRDDLCompiler(
             rddl=rddl,
             logger=self.logger,
-            use64bit=self.use64bit
+            use64bit=self.use64bit
+        )
         self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
 
     def _jax_compile_optimizer(self):
@@ -1462,13 +1397,15 @@ r"""
         train_rollouts = self.compiled.compile_rollouts(
             policy=self.plan.train_policy,
             n_steps=self.horizon,
-            n_batch=self.batch_size_train
+            n_batch=self.batch_size_train
+        )
         self.train_rollouts = train_rollouts
 
         test_rollouts = self.test_compiled.compile_rollouts(
             policy=self.plan.test_policy,
             n_steps=self.horizon,
-            n_batch=self.batch_size_test
+            n_batch=self.batch_size_test
+        )
         self.test_rollouts = jax.jit(test_rollouts)
 
         # initialization
@@ -1503,14 +1440,16 @@ r"""
         _jax_wrapped_returns = self._jax_return(use_symlog)
 
         # the loss is the average cumulative reward across all roll-outs
-        def _jax_wrapped_plan_loss(key, policy_params,
+        def _jax_wrapped_plan_loss(key, policy_params, policy_hyperparams,
                                    subs, model_params):
-            log = rollouts(
+            log, model_params = rollouts(
+                key, policy_params, policy_hyperparams, subs, model_params)
             rewards = log['reward']
             returns = _jax_wrapped_returns(rewards)
             utility = utility_fn(returns, **utility_kwargs)
             loss = -utility
-
+            aux = (log, model_params)
+            return loss, aux
 
         return _jax_wrapped_plan_loss
 
@@ -1518,8 +1457,9 @@ r"""
         init = self.plan.initializer
         optimizer = self.optimizer
 
-
-
+        # initialize both the policy and its optimizer
+        def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
+            policy_params = init(key, policy_hyperparams, subs)
             opt_state = optimizer.init(policy_params)
             return policy_params, opt_state, {}
 
@@ -1528,35 +1468,34 @@ r"""
     def _jax_update(self, loss):
         optimizer = self.optimizer
         projection = self.plan.projection
-
-        # add Gaussian gradient noise per Neelakantan et al., 2016.
-        def _jax_wrapped_gaussian_param_noise(key, grads, sigma):
-            treedef = jax.tree_util.tree_structure(grads)
-            keys_flat = random.split(key, num=treedef.num_leaves)
-            keys_tree = jax.tree_util.tree_unflatten(treedef, keys_flat)
-            new_grads = jax.tree_map(
-                lambda g, k: g + sigma * random.normal(
-                    key=k, shape=g.shape, dtype=g.dtype),
-                grads,
-                keys_tree
-            )
-            return new_grads
+        use_ls = self.line_search_kwargs is not None
 
         # calculate the plan gradient w.r.t. return loss and update optimizer
         # also perform a projection step to satisfy constraints on actions
-        def
+        def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
+                                      subs, model_params):
+            return loss(key, policy_params, policy_hyperparams, subs, model_params)[0]
+
+        def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams,
                                      subs, model_params, opt_state, opt_aux):
             grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
-            (loss_val, log), grad = grad_fn(
-                key, policy_params,
-
-
-
+            (loss_val, (log, model_params)), grad = grad_fn(
+                key, policy_params, policy_hyperparams, subs, model_params)
+            if use_ls:
+                updates, opt_state = optimizer.update(
+                    grad, opt_state, params=policy_params,
+                    value=loss_val, grad=grad, value_fn=_jax_wrapped_loss_swapped,
+                    key=key, policy_hyperparams=policy_hyperparams, subs=subs,
+                    model_params=model_params)
+            else:
+                updates, opt_state = optimizer.update(
+                    grad, opt_state, params=policy_params)
             policy_params = optax.apply_updates(policy_params, updates)
-            policy_params, converged = projection(policy_params,
+            policy_params, converged = projection(policy_params, policy_hyperparams)
             log['grad'] = grad
             log['updates'] = updates
-            return policy_params, converged, opt_state, opt_aux,
+            return policy_params, converged, opt_state, opt_aux, \
+                loss_val, log, model_params
 
         return jax.jit(_jax_wrapped_plan_update)
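Note: the use_ls branch above exists because optax's zoom line search must re-evaluate the objective at trial points, so update() takes extra value/grad/value_fn arguments. A minimal standalone version of that calling convention (toy objective and kwargs assumed; the pattern itself follows the optax documentation):

    import jax
    import jax.numpy as jnp
    import optax

    def loss_fn(params):  # toy objective standing in for the planner loss
        return (params - 3.0) ** 2

    opt = optax.chain(optax.sgd(learning_rate=1.0),
                      optax.scale_by_zoom_linesearch(max_linesearch_steps=10))
    params = jnp.asarray(0.0)
    state = opt.init(params)
    value, grad = jax.value_and_grad(loss_fn)(params)
    updates, state = opt.update(
        grad, state, params,
        value=value, grad=grad, value_fn=loss_fn)  # extras required by the line search
    params = optax.apply_updates(params, updates)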
@@ -1583,7 +1522,6 @@ r"""
         for (state, next_state) in rddl.next_state.items():
             init_train[next_state] = init_train[state]
             init_test[next_state] = init_test[state]
-
         return init_train, init_test
 
     def as_optimization_problem(
@@ -1637,38 +1575,40 @@ r"""
         loss_fn = self._jax_loss(self.train_rollouts)
 
         @jax.jit
-        def _loss_with_key(key, params_1d):
+        def _loss_with_key(key, params_1d, model_params):
             policy_params = unravel_fn(params_1d)
-            loss_val, _ = loss_fn(
-
-            return loss_val
+            loss_val, (_, model_params) = loss_fn(
+                key, policy_params, policy_hyperparams, train_subs, model_params)
+            return loss_val, model_params
 
         @jax.jit
-        def _grad_with_key(key, params_1d):
+        def _grad_with_key(key, params_1d, model_params):
             policy_params = unravel_fn(params_1d)
             grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
-            grad_val, _ = grad_fn(
-
-
-            return
+            grad_val, (_, model_params) = grad_fn(
+                key, policy_params, policy_hyperparams, train_subs, model_params)
+            grad_val = jax.flatten_util.ravel_pytree(grad_val)[0]
+            return grad_val, model_params
 
         def _loss_function(params_1d):
             nonlocal key
+            nonlocal model_params
             if loss_function_updates_key:
                 key, subkey = random.split(key)
             else:
                 subkey = key
-            loss_val = _loss_with_key(subkey, params_1d)
+            loss_val, model_params = _loss_with_key(subkey, params_1d, model_params)
             loss_val = float(loss_val)
             return loss_val
 
         def _grad_function(params_1d):
             nonlocal key
+            nonlocal model_params
             if grad_function_updates_key:
                 key, subkey = random.split(key)
             else:
                 subkey = key
-            grad = _grad_with_key(subkey, params_1d)
+            grad, model_params = _grad_with_key(subkey, params_1d, model_params)
             grad = np.asarray(grad)
             return grad
 
@@ -1683,9 +1623,9 @@ r"""
 
         :param key: JAX PRNG key (derived from clock if not provided)
         :param epochs: the maximum number of steps of gradient descent
-        :param train_seconds: total time allocated for gradient descent
-        :param
-        :param
+        :param train_seconds: total time allocated for gradient descent
+        :param dashboard: dashboard to display training results
+        :param dashboard_id: experiment id for the dashboard
         :param model_params: optional model-parameters to override default
         :param policy_hyperparams: hyper-parameters for the policy/plan, such as
         weights for sigmoid wrapping boolean actions
@@ -1721,8 +1661,8 @@ r"""
     def optimize_generator(self, key: Optional[random.PRNGKey]=None,
                            epochs: int=999999,
                            train_seconds: float=120.,
-
-
+                           dashboard: Optional[JaxPlannerDashboard]=None,
+                           dashboard_id: Optional[str]=None,
                            model_params: Optional[Dict[str, Any]]=None,
                            policy_hyperparams: Optional[Dict[str, Any]]=None,
                            subs: Optional[Dict[str, Any]]=None,
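Note: optimize_generator now takes the dashboard itself rather than the old plot/plot_step pair. A hedged usage sketch (argument names come from the signature above; everything else, including a prepared planner, is assumed):

    from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard

    dashboard = JaxPlannerDashboard()
    for callback in planner.optimize_generator(
            epochs=1000,
            train_seconds=60.0,
            dashboard=dashboard,
            dashboard_id='my-experiment'):
        if callback['status'].is_terminal():
            break  # the generator also stops itself on terminal statuses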
@@ -1738,9 +1678,9 @@ r"""
 
         :param key: JAX PRNG key (derived from clock if not provided)
         :param epochs: the maximum number of steps of gradient descent
-        :param train_seconds: total time allocated for gradient descent
-        :param
-        :param
+        :param train_seconds: total time allocated for gradient descent
+        :param dashboard: dashboard to display training results
+        :param dashboard_id: experiment id for the dashboard
         :param model_params: optional model-parameters to override default
         :param policy_hyperparams: hyper-parameters for the policy/plan, such as
         weights for sigmoid wrapping boolean actions
@@ -1760,9 +1700,14 @@ r"""
         start_time = time.time()
         elapsed_outside_loop = 0
 
+        # ======================================================================
+        # INITIALIZATION OF HYPER-PARAMETERS
+        # ======================================================================
+
         # if PRNG key is not provided
         if key is None:
             key = random.PRNGKey(round(time.time() * 1000))
+        dash_key = key[1].item()
 
         # if policy_hyperparams is not provided
         if policy_hyperparams is None:
@@ -1789,7 +1734,7 @@ r"""
 
         # print summary of parameters:
         if print_summary:
-            self.
+            print(self.summarize_system())
             self.summarize_hyperparameters()
             print(f'optimize() call hyper-parameters:\n'
                   f'    PRNG key           ={key}\n'
@@ -1800,16 +1745,16 @@ r"""
                   f'    override_subs_dict ={subs is not None}\n'
                   f'    provide_param_guess={guess is not None}\n'
                   f'    test_rolling_window={test_rolling_window}\n'
-                  f'
-                  f'
+                  f'    dashboard          ={dashboard is not None}\n'
+                  f'    dashboard_id       ={dashboard_id}\n'
                   f'    print_summary      ={print_summary}\n'
                   f'    print_progress     ={print_progress}\n'
                   f'    stopping_rule      ={stopping_rule}\n')
-
-
-
-
-
+
+        # ======================================================================
+        # INITIALIZATION OF STATE AND POLICY
+        # ======================================================================
+
         # compute a batched version of the initial values
         if subs is None:
             subs = self.test_compiled.init_values
@@ -1841,6 +1786,10 @@ r"""
             policy_params = guess
             opt_state = self.optimizer.init(policy_params)
             opt_aux = {}
+
+        # ======================================================================
+        # INITIALIZATION OF RUNNING STATISTICS
+        # ======================================================================
 
         # initialize running statistics
         best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
@@ -1854,35 +1803,39 @@ r"""
         if stopping_rule is not None:
             stopping_rule.reset()
 
-        # initialize
-        if
-
-
-
-
-
+        # initialize dash board
+        if dashboard is not None:
+            dashboard_id = dashboard.register_experiment(
+                dashboard_id, dashboard.get_planner_info(self),
+                key=dash_key, viz=self.dashboard_viz)
+
+        # ======================================================================
+        # MAIN TRAINING LOOP BEGINS
+        # ======================================================================
 
-        # training loop
         iters = range(epochs)
         if print_progress:
             iters = tqdm(iters, total=100, position=tqdm_position)
         position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
 
         for it in iters:
-            status = JaxPlannerStatus.NORMAL
 
-            #
-
-
-
+            # ==================================================================
+            # NEXT GRADIENT DESCENT STEP
+            # ==================================================================
+
+            status = JaxPlannerStatus.NORMAL
 
             # update the parameters of the plan
             key, subkey = random.split(key)
-            policy_params, converged, opt_state, opt_aux,
-
+            (policy_params, converged, opt_state, opt_aux,
+             train_loss, train_log, model_params) = \
                 self.update(subkey, policy_params, policy_hyperparams,
                             train_subs, model_params, opt_state, opt_aux)
+
+            # ==================================================================
+            # STATUS CHECKS AND LOGGING
+            # ==================================================================
 
             # no progress
             grad_norm_zero, _ = jax.tree_util.tree_flatten(
@@ -1901,12 +1854,11 @@ r"""
             # numerical error
             if not np.isfinite(train_loss):
                 raise_warning(
-                    f'
-                    'red')
+                    f'JAX planner aborted due to invalid loss {train_loss}.', 'red')
                 status = JaxPlannerStatus.INVALID_GRADIENT
 
             # evaluate test losses and record best plan so far
-            test_loss, log = self.test_loss(
+            test_loss, (log, model_params_test) = self.test_loss(
                 subkey, policy_params, policy_hyperparams,
                 test_subs, model_params_test)
             test_loss = rolling_test_loss.update(test_loss)
@@ -1915,32 +1867,15 @@ r"""
                     policy_params, test_loss, train_log['grad']
                 last_iter_improve = it
 
-            # save the plan figure
-            if plot is not None and it % plot_step == 0:
-                xticks.append(it // plot_step)
-                loss_values.append(test_loss.item())
-                action_values = {name: values
-                                 for (name, values) in log['fluents'].items()
-                                 if name in self.rddl.action_fluents}
-                returns = -np.sum(np.asarray(log['reward']), axis=1)
-                plot.redraw(xticks, loss_values, action_values, returns)
-
-            # if the progress bar is used
-            elapsed = time.time() - start_time - elapsed_outside_loop
-            if print_progress:
-                iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
-                iters.set_description(
-                    f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
-                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
-                    f'{status.value} status')
-
             # reached computation budget
+            elapsed = time.time() - start_time - elapsed_outside_loop
             if elapsed >= train_seconds:
                 status = JaxPlannerStatus.TIME_BUDGET_REACHED
             if it >= epochs - 1:
                 status = JaxPlannerStatus.ITER_BUDGET_REACHED
 
-            #
+            # build a callback
+            progress_percent = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
             callback = {
                 'status': status,
                 'iteration': it,
@@ -1952,29 +1887,48 @@ r"""
                 'last_iteration_improved': last_iter_improve,
                 'grad': train_log['grad'],
                 'best_grad': best_grad,
-                'noise_sigma': noise_sigma,
                 'updates': train_log['updates'],
                 'elapsed_time': elapsed,
                 'key': key,
+                'model_params': model_params,
+                'progress': progress_percent,
+                'train_log': train_log,
                 **log
             }
+
+            # stopping condition reached
+            if stopping_rule is not None and stopping_rule.monitor(callback):
+                callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
+
+            # if the progress bar is used
+            if print_progress:
+                iters.n = progress_percent
+                iters.set_description(
+                    f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
+                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
+                    f'{status.value} status'
+                )
+
+            # dash-board
+            if dashboard is not None:
+                dashboard.update_experiment(dashboard_id, callback)
+
+            # yield the callback
             start_time_outside = time.time()
             yield callback
             elapsed_outside_loop += (time.time() - start_time_outside)
 
             # abortion check
-            if status.
-                break
-
-
-
-
-
+            if status.is_terminal():
+                break
+
+        # ======================================================================
+        # POST-PROCESSING AND CLEANUP
+        # ======================================================================
+
         # release resources
         if print_progress:
             iters.close()
-        if plot is not None:
-            plot.close()
 
         # validate the test return
         if log:
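Note: each yielded callback now carries the mutated model_params, a progress percentage, and the full train_log. A sketch of a consumer that watches the new fields (keys as listed above; 'best_params' is the key both controllers read after training):

    last = None
    for cb in planner.optimize_generator(epochs=500, train_seconds=30.0):
        last = cb
        print(f"it={cb['iteration']:5d}  {cb['progress']:3d}%  {cb['status'].name}")
    best_params = last['best_params']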
@@ -2098,101 +2052,6 @@ r"""
         return actions
 
 
-class JaxLineSearchPlanner(JaxBackpropPlanner):
-    '''A class for optimizing an action sequence in the given RDDL MDP using
-    linear search gradient descent, with the Armijo condition.'''
-
-    def __init__(self, *args,
-                 decay: float=0.8,
-                 c: float=0.1,
-                 step_max: float=1.0,
-                 step_min: float=1e-6,
-                 **kwargs) -> None:
-        '''Creates a new gradient-based algorithm for optimizing action sequences
-        (plan) in the given RDDL using line search. All arguments are the
-        same as in the parent class, except:
-
-        :param decay: reduction factor of learning rate per line search iteration
-        :param c: positive coefficient in Armijo condition, should be in (0, 1)
-        :param step_max: initial learning rate for line search
-        :param step_min: minimum possible learning rate (line search halts)
-        '''
-        self.decay = decay
-        self.c = c
-        self.step_max = step_max
-        self.step_min = step_min
-        if 'clip_grad' in kwargs:
-            raise_warning('clip_grad parameter conflicts with '
-                          'line search planner and will be ignored.', 'red')
-            del kwargs['clip_grad']
-        super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
-
-    def summarize_hyperparameters(self) -> None:
-        super(JaxLineSearchPlanner, self).summarize_hyperparameters()
-        print(f'linesearch hyper-parameters:\n'
-              f'    decay   ={self.decay}\n'
-              f'    c       ={self.c}\n'
-              f'    lr_range=({self.step_min}, {self.step_max})')
-
-    def _jax_update(self, loss):
-        optimizer = self.optimizer
-        projection = self.plan.projection
-        decay, c, lrmax, lrmin = self.decay, self.c, self.step_max, self.step_min
-
-        # initialize the line search routine
-        @jax.jit
-        def _jax_wrapped_line_search_init(key, policy_params, hyperparams,
-                                          subs, model_params):
-            (f, log), grad = jax.value_and_grad(loss, argnums=1, has_aux=True)(
-                key, policy_params, hyperparams, subs, model_params)
-            gnorm2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), grad)
-            gnorm2 = jax.tree_util.tree_reduce(jnp.add, gnorm2)
-            log['grad'] = grad
-            return f, grad, gnorm2, log
-
-        # compute the next trial solution
-        @jax.jit
-        def _jax_wrapped_line_search_trial(
-                step, grad, key, params, hparams, subs, mparams, state):
-            state.hyperparams['learning_rate'] = step
-            updates, new_state = optimizer.update(grad, state)
-            new_params = optax.apply_updates(params, updates)
-            new_params, _ = projection(new_params, hparams)
-            f_step, _ = loss(key, new_params, hparams, subs, mparams)
-            return f_step, new_params, new_state
-
-        # main iteration of line search
-        def _jax_wrapped_plan_update(key, policy_params, hyperparams,
-                                     subs, model_params, opt_state, opt_aux):
-
-            # initialize the line search
-            f, grad, gnorm2, log = _jax_wrapped_line_search_init(
-                key, policy_params, hyperparams, subs, model_params)
-
-            # continue to reduce the learning rate until the Armijo condition holds
-            trials = 0
-            step = lrmax / decay
-            f_step = np.inf
-            best_f, best_step, best_params, best_state = np.inf, None, None, None
-            while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
-            or not trials:
-                trials += 1
-                step *= decay
-                f_step, new_params, new_state = _jax_wrapped_line_search_trial(
-                    step, grad, key, policy_params, hyperparams, subs,
-                    model_params, opt_state)
-                if f_step < best_f:
-                    best_f, best_step, best_params, best_state = \
-                        f_step, step, new_params, new_state
-
-            log['updates'] = None
-            log['line_search_iters'] = trials
-            log['learning_rate'] = best_step
-            opt_aux['best_step'] = best_step
-            return best_params, True, best_state, opt_aux, best_f, log
-
-        return _jax_wrapped_plan_update
-
 # ***********************************************************************
 # ALL VERSIONS OF RISK FUNCTIONS
 #
@@ -2224,6 +2083,7 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
         returns <= jnp.percentile(returns, q=100 * alpha))
     return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
 
+
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
 #
@@ -2268,8 +2128,10 @@ class JaxOfflineController(BaseAgent):
         self.params_given = params is not None
 
         self.step = 0
+        self.callback = None
         if not self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.callback = callback
             params = callback['best_params']
         self.params = params
 
@@ -2284,6 +2146,7 @@ class JaxOfflineController(BaseAgent):
         self.step = 0
         if self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.callback = callback
            self.params = callback['best_params']
 
 
@@ -2326,7 +2189,9 @@ class JaxOnlineController(BaseAgent):
             key=self.key,
             guess=self.guess,
             subs=state,
-            **self.train_kwargs
+            **self.train_kwargs
+        )
+        self.callback = callback
         params = callback['best_params']
         self.key, subkey = random.split(self.key)
         actions = planner.get_action(
@@ -2337,4 +2202,5 @@ class JaxOnlineController(BaseAgent):
 
     def reset(self) -> None:
         self.guess = None
+        self.callback = None
 
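Note: both controllers now retain the last optimizer callback on self.callback, so scripts can inspect training diagnostics after the fact. Since JaxOfflineController trains in __init__ unless train_on_reset or explicit params are given, the callback is available immediately after construction (sketch; `planner` and `train_kwargs` as set up elsewhere):

    agent = JaxOfflineController(planner, **train_kwargs)
    cb = agent.callback
    if cb is not None:
        print('final status:', cb['status'].name)
        print('elapsed time:', cb['elapsed_time'])
        best = cb['best_params']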