pyRDDLGym-jax 0.4-py3-none-any.whl → 1.0-py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry, and is provided for informational purposes only.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +832 -530
- pyRDDLGym_jax/core/planner.py +422 -474
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +390 -584
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +5 -4
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +7 -6
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -4
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -29
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +164 -105
- pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.4.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
|
@@ -31,17 +31,18 @@ from pyRDDLGym.core.policy import BaseAgent
|
|
|
31
31
|
from pyRDDLGym_jax import __version__
|
|
32
32
|
from pyRDDLGym_jax.core import logic
|
|
33
33
|
from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
|
|
34
|
-
from pyRDDLGym_jax.core.logic import FuzzyLogic
|
|
34
|
+
from pyRDDLGym_jax.core.logic import Logic, FuzzyLogic
|
|
35
35
|
|
|
36
|
-
# try to
|
|
36
|
+
# try to load the dashboard
|
|
37
37
|
try:
|
|
38
|
-
|
|
38
|
+
from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
|
|
39
39
|
except Exception:
|
|
40
|
-
raise_warning('
|
|
41
|
-
'
|
|
40
|
+
raise_warning('Failed to load the dashboard visualization tool: '
|
|
41
|
+
'please make sure you have installed the required packages.',
|
|
42
|
+
'red')
|
|
42
43
|
traceback.print_exc()
|
|
43
|
-
|
|
44
|
-
|
|
44
|
+
JaxPlannerDashboard = None
|
|
45
|
+
|
|
45
46
|
Activation = Callable[[jnp.ndarray], jnp.ndarray]
|
|
46
47
|
Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
|
|
47
48
|
Kwargs = Dict[str, Any]
|
|
@@ -57,6 +58,7 @@ Pytree = Any
|
|
|
57
58
|
#
|
|
58
59
|
# ***********************************************************************
|
|
59
60
|
|
|
61
|
+
|
|
60
62
|
def _parse_config_file(path: str):
|
|
61
63
|
if not os.path.isfile(path):
|
|
62
64
|
raise FileNotFoundError(f'File {path} does not exist.')
|
|
@@ -95,18 +97,25 @@ def _load_config(config, args):
|
|
|
95
97
|
# read the model settings
|
|
96
98
|
logic_name = model_args.get('logic', 'FuzzyLogic')
|
|
97
99
|
logic_kwargs = model_args.get('logic_kwargs', {})
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
100
|
+
if logic_name == 'FuzzyLogic':
|
|
101
|
+
tnorm_name = model_args.get('tnorm', 'ProductTNorm')
|
|
102
|
+
tnorm_kwargs = model_args.get('tnorm_kwargs', {})
|
|
103
|
+
comp_name = model_args.get('complement', 'StandardComplement')
|
|
104
|
+
comp_kwargs = model_args.get('complement_kwargs', {})
|
|
105
|
+
compare_name = model_args.get('comparison', 'SigmoidComparison')
|
|
106
|
+
compare_kwargs = model_args.get('comparison_kwargs', {})
|
|
107
|
+
sampling_name = model_args.get('sampling', 'GumbelSoftmax')
|
|
108
|
+
sampling_kwargs = model_args.get('sampling_kwargs', {})
|
|
109
|
+
rounding_name = model_args.get('rounding', 'SoftRounding')
|
|
110
|
+
rounding_kwargs = model_args.get('rounding_kwargs', {})
|
|
111
|
+
control_name = model_args.get('control', 'SoftControlFlow')
|
|
112
|
+
control_kwargs = model_args.get('control_kwargs', {})
|
|
113
|
+
logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
|
|
114
|
+
logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
|
|
115
|
+
logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
|
|
116
|
+
logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
|
|
117
|
+
logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
|
|
118
|
+
logic_kwargs['control'] = getattr(logic, control_name)(**control_kwargs)
|
|
110
119
|
|
|
111
120
|
# read the policy settings
|
|
112
121
|
plan_method = planner_args.pop('method')
|
|
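The new branch above assembles the six differentiable-logic components from config keys. A minimal sketch of the equivalent direct construction, assuming logic_kwargs is ultimately passed to the FuzzyLogic constructor and that the default component names read by the .get() calls exist in pyRDDLGym_jax.core.logic:

    from pyRDDLGym_jax.core import logic
    from pyRDDLGym_jax.core.logic import FuzzyLogic

    # mirrors the defaults read above (sketch, not the package's exact call)
    fuzzy = FuzzyLogic(
        tnorm=logic.ProductTNorm(),
        complement=logic.StandardComplement(),
        comparison=logic.SigmoidComparison(),
        sampling=logic.GumbelSoftmax(),
        rounding=logic.SoftRounding(),
        control=logic.SoftControlFlow())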
@@ -157,11 +166,25 @@ def _load_config(config, args):
|
|
|
157
166
|
else:
|
|
158
167
|
planner_args['optimizer'] = optimizer
|
|
159
168
|
|
|
160
|
-
#
|
|
169
|
+
# optimize call RNG key
|
|
161
170
|
planner_key = train_args.get('key', None)
|
|
162
171
|
if planner_key is not None:
|
|
163
172
|
train_args['key'] = random.PRNGKey(planner_key)
|
|
164
173
|
|
|
174
|
+
# dashboard
|
|
175
|
+
dashboard_key = train_args.get('dashboard', None)
|
|
176
|
+
if dashboard_key is not None and dashboard_key and JaxPlannerDashboard is not None:
|
|
177
|
+
train_args['dashboard'] = JaxPlannerDashboard()
|
|
178
|
+
elif dashboard_key is not None:
|
|
179
|
+
del train_args['dashboard']
|
|
180
|
+
|
|
181
|
+
# optimize call stopping rule
|
|
182
|
+
stopping_rule = train_args.get('stopping_rule', None)
|
|
183
|
+
if stopping_rule is not None:
|
|
184
|
+
stopping_rule_kwargs = train_args.pop('stopping_rule_kwargs', {})
|
|
185
|
+
train_args['stopping_rule'] = getattr(
|
|
186
|
+
sys.modules[__name__], stopping_rule)(**stopping_rule_kwargs)
|
|
187
|
+
|
|
165
188
|
return planner_args, plan_kwargs, train_args
|
|
166
189
|
|
|
167
190
|
|
|
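The two new blocks read optional 'dashboard' and 'stopping_rule' keys from the training settings and instantiate the named stopping rule from this module. A hypothetical config exercising them (section and key names are inferred from the variable names in _load_config, and the value syntax is assumed):

    from pyRDDLGym_jax.core.planner import load_config_from_string

    CONFIG = '''
    [Model]
    logic='FuzzyLogic'

    [Optimizer]
    method='JaxStraightLinePlan'

    [Training]
    key=42
    dashboard=True
    stopping_rule='NoImprovementStoppingRule'
    stopping_rule_kwargs={'patience': 50}
    '''
    planner_args, plan_kwargs, train_args = load_config_from_string(CONFIG)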
@@ -175,7 +198,7 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
|
|
|
175
198
|
'''Loads config file contents specified explicitly as a string value.'''
|
|
176
199
|
config, args = _parse_config_string(value)
|
|
177
200
|
return _load_config(config, args)
|
|
178
|
-
|
|
201
|
+
|
|
179
202
|
|
|
180
203
|
# ***********************************************************************
|
|
181
204
|
# MODEL RELAXATIONS
|
|
@@ -193,7 +216,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
|
|
|
193
216
|
'''
|
|
194
217
|
|
|
195
218
|
def __init__(self, *args,
|
|
196
|
-
logic:
|
|
219
|
+
logic: Logic=FuzzyLogic(),
|
|
197
220
|
cpfs_without_grad: Optional[Set[str]]=None,
|
|
198
221
|
**kwargs) -> None:
|
|
199
222
|
'''Creates a new RDDL to Jax compiler, where operations that are not
|
|
@@ -228,57 +251,55 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
|
|
|
228
251
|
|
|
229
252
|
# overwrite basic operations with fuzzy ones
|
|
230
253
|
self.RELATIONAL_OPS = {
|
|
231
|
-
'>=': logic.greater_equal
|
|
232
|
-
'<=': logic.less_equal
|
|
233
|
-
'<': logic.less
|
|
234
|
-
'>': logic.greater
|
|
235
|
-
'==': logic.equal
|
|
236
|
-
'~=': logic.not_equal
|
|
254
|
+
'>=': logic.greater_equal,
|
|
255
|
+
'<=': logic.less_equal,
|
|
256
|
+
'<': logic.less,
|
|
257
|
+
'>': logic.greater,
|
|
258
|
+
'==': logic.equal,
|
|
259
|
+
'~=': logic.not_equal
|
|
237
260
|
}
|
|
238
|
-
self.LOGICAL_NOT = logic.logical_not
|
|
261
|
+
self.LOGICAL_NOT = logic.logical_not
|
|
239
262
|
self.LOGICAL_OPS = {
|
|
240
|
-
'^': logic.logical_and
|
|
241
|
-
'&': logic.logical_and
|
|
242
|
-
'|': logic.logical_or
|
|
243
|
-
'~': logic.xor
|
|
244
|
-
'=>': logic.implies
|
|
245
|
-
'<=>': logic.equiv
|
|
263
|
+
'^': logic.logical_and,
|
|
264
|
+
'&': logic.logical_and,
|
|
265
|
+
'|': logic.logical_or,
|
|
266
|
+
'~': logic.xor,
|
|
267
|
+
'=>': logic.implies,
|
|
268
|
+
'<=>': logic.equiv
|
|
246
269
|
}
|
|
247
|
-
self.AGGREGATION_OPS['forall'] = logic.forall
|
|
248
|
-
self.AGGREGATION_OPS['exists'] = logic.exists
|
|
249
|
-
self.AGGREGATION_OPS['argmin'] = logic.argmin
|
|
250
|
-
self.AGGREGATION_OPS['argmax'] = logic.argmax
|
|
251
|
-
self.KNOWN_UNARY['sgn'] = logic.sgn
|
|
252
|
-
self.KNOWN_UNARY['floor'] = logic.floor
|
|
253
|
-
self.KNOWN_UNARY['ceil'] = logic.ceil
|
|
254
|
-
self.KNOWN_UNARY['round'] = logic.round
|
|
255
|
-
self.KNOWN_UNARY['sqrt'] = logic.sqrt
|
|
256
|
-
self.KNOWN_BINARY['div'] = logic.div
|
|
257
|
-
self.KNOWN_BINARY['mod'] = logic.mod
|
|
258
|
-
self.KNOWN_BINARY['fmod'] = logic.mod
|
|
259
|
-
self.IF_HELPER = logic.control_if
|
|
260
|
-
self.SWITCH_HELPER = logic.control_switch
|
|
261
|
-
self.BERNOULLI_HELPER = logic.bernoulli
|
|
262
|
-
self.DISCRETE_HELPER = logic.discrete
|
|
263
|
-
self.POISSON_HELPER = logic.poisson
|
|
264
|
-
self.GEOMETRIC_HELPER = logic.geometric
|
|
265
|
-
|
|
266
|
-
def _jax_stop_grad(self, jax_expr):
|
|
267
|
-
|
|
270
|
+
self.AGGREGATION_OPS['forall'] = logic.forall
|
|
271
|
+
self.AGGREGATION_OPS['exists'] = logic.exists
|
|
272
|
+
self.AGGREGATION_OPS['argmin'] = logic.argmin
|
|
273
|
+
self.AGGREGATION_OPS['argmax'] = logic.argmax
|
|
274
|
+
self.KNOWN_UNARY['sgn'] = logic.sgn
|
|
275
|
+
self.KNOWN_UNARY['floor'] = logic.floor
|
|
276
|
+
self.KNOWN_UNARY['ceil'] = logic.ceil
|
|
277
|
+
self.KNOWN_UNARY['round'] = logic.round
|
|
278
|
+
self.KNOWN_UNARY['sqrt'] = logic.sqrt
|
|
279
|
+
self.KNOWN_BINARY['div'] = logic.div
|
|
280
|
+
self.KNOWN_BINARY['mod'] = logic.mod
|
|
281
|
+
self.KNOWN_BINARY['fmod'] = logic.mod
|
|
282
|
+
self.IF_HELPER = logic.control_if
|
|
283
|
+
self.SWITCH_HELPER = logic.control_switch
|
|
284
|
+
self.BERNOULLI_HELPER = logic.bernoulli
|
|
285
|
+
self.DISCRETE_HELPER = logic.discrete
|
|
286
|
+
self.POISSON_HELPER = logic.poisson
|
|
287
|
+
self.GEOMETRIC_HELPER = logic.geometric
|
|
288
|
+
|
|
289
|
+
def _jax_stop_grad(self, jax_expr):
|
|
268
290
|
def _jax_wrapped_stop_grad(x, params, key):
|
|
269
|
-
sample, key, error = jax_expr(x, params, key)
|
|
291
|
+
sample, key, error, params = jax_expr(x, params, key)
|
|
270
292
|
sample = jax.lax.stop_gradient(sample)
|
|
271
|
-
return sample, key, error
|
|
272
|
-
|
|
293
|
+
return sample, key, error, params
|
|
273
294
|
return _jax_wrapped_stop_grad
|
|
274
295
|
|
|
275
|
-
def _compile_cpfs(self,
|
|
296
|
+
def _compile_cpfs(self, init_params):
|
|
276
297
|
cpfs_cast = set()
|
|
277
298
|
jax_cpfs = {}
|
|
278
299
|
for (_, cpfs) in self.levels.items():
|
|
279
300
|
for cpf in cpfs:
|
|
280
301
|
_, expr = self.rddl.cpfs[cpf]
|
|
281
|
-
jax_cpfs[cpf] = self._jax(expr,
|
|
302
|
+
jax_cpfs[cpf] = self._jax(expr, init_params, dtype=self.REAL)
|
|
282
303
|
if self.rddl.variable_ranges[cpf] != 'real':
|
|
283
304
|
cpfs_cast.add(cpf)
|
|
284
305
|
if cpf in self.cpfs_without_grad:
|
|
@@ -289,17 +310,15 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
|
|
|
289
310
|
f'{cpfs_cast} be cast to float.')
|
|
290
311
|
if self.cpfs_without_grad:
|
|
291
312
|
raise_warning(f'User requested that gradients not flow '
|
|
292
|
-
f'through CPFs {self.cpfs_without_grad}.')
|
|
313
|
+
f'through CPFs {self.cpfs_without_grad}.')
|
|
314
|
+
|
|
293
315
|
return jax_cpfs
|
|
294
316
|
|
|
295
|
-
def _jax_kron(self, expr,
|
|
296
|
-
if self.logic.verbose:
|
|
297
|
-
raise_warning('JAX gradient compiler ignores KronDelta '
|
|
298
|
-
'during compilation.')
|
|
317
|
+
def _jax_kron(self, expr, init_params):
|
|
299
318
|
arg, = expr.args
|
|
300
|
-
arg = self._jax(arg,
|
|
319
|
+
arg = self._jax(arg, init_params)
|
|
301
320
|
return arg
|
|
302
|
-
|
|
321
|
+
|
|
303
322
|
|
|
304
323
|
# ***********************************************************************
|
|
305
324
|
# ALL VERSIONS OF JAX PLANS
|
|
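The operator tables above swap exact relational and logical operations for differentiable surrogates supplied by the logic object. As a standalone illustration of the idea (not the package's exact implementation), a SigmoidComparison-style surrogate for a > b:

    import jax
    import jax.numpy as jnp

    # smooth stand-in for (a > b): approaches {0, 1} as |a - b| grows
    def soft_greater(a, b, weight=10.0):
        return jax.nn.sigmoid(weight * (a - b))

    print(soft_greater(2.0, 1.0))             # ~1.0
    print(jax.grad(soft_greater)(1.0, 1.0))   # 2.5 == weight / 4 at a == b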
@@ -309,6 +328,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
|
|
|
309
328
|
#
|
|
310
329
|
# ***********************************************************************
|
|
311
330
|
|
|
331
|
+
|
|
312
332
|
class JaxPlan:
|
|
313
333
|
'''Base class for all JAX policy representations.'''
|
|
314
334
|
|
|
@@ -320,7 +340,7 @@ class JaxPlan:
|
|
|
320
340
|
self.bounds = None
|
|
321
341
|
|
|
322
342
|
def summarize_hyperparameters(self) -> None:
|
|
323
|
-
|
|
343
|
+
print(self.__str__())
|
|
324
344
|
|
|
325
345
|
def compile(self, compiled: JaxRDDLCompilerWithGrad,
|
|
326
346
|
_bounds: Bounds,
|
|
@@ -363,7 +383,7 @@ class JaxPlan:
|
|
|
363
383
|
self._projection = value
|
|
364
384
|
|
|
365
385
|
def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
|
|
366
|
-
user_bounds: Bounds,
|
|
386
|
+
user_bounds: Bounds,
|
|
367
387
|
horizon: int):
|
|
368
388
|
shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
|
|
369
389
|
for (name, prange) in compiled.rddl.variable_ranges.items():
|
|
@@ -446,24 +466,24 @@ class JaxStraightLinePlan(JaxPlan):
|
|
|
446
466
|
self._wrap_softmax = wrap_softmax
|
|
447
467
|
self._use_new_projection = use_new_projection
|
|
448
468
|
self._max_constraint_iter = max_constraint_iter
|
|
449
|
-
|
|
450
|
-
def
|
|
469
|
+
|
|
470
|
+
def __str__(self) -> str:
|
|
451
471
|
bounds = '\n '.join(
|
|
452
472
|
map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
473
|
+
return (f'policy hyper-parameters:\n'
|
|
474
|
+
f' initializer ={self._initializer_base}\n'
|
|
475
|
+
f'constraint-sat strategy (simple):\n'
|
|
476
|
+
f' parsed_action_bounds =\n {bounds}\n'
|
|
477
|
+
f' wrap_sigmoid ={self._wrap_sigmoid}\n'
|
|
478
|
+
f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
|
|
479
|
+
f' wrap_non_bool ={self._wrap_non_bool}\n'
|
|
480
|
+
f'constraint-sat strategy (complex):\n'
|
|
481
|
+
f' wrap_softmax ={self._wrap_softmax}\n'
|
|
482
|
+
f' use_new_projection ={self._use_new_projection}\n'
|
|
483
|
+
f' max_projection_iters ={self._max_constraint_iter}')
|
|
464
484
|
|
|
465
485
|
def compile(self, compiled: JaxRDDLCompilerWithGrad,
|
|
466
|
-
_bounds: Bounds,
|
|
486
|
+
_bounds: Bounds,
|
|
467
487
|
horizon: int) -> None:
|
|
468
488
|
rddl = compiled.rddl
|
|
469
489
|
|
|
@@ -504,7 +524,7 @@ class JaxStraightLinePlan(JaxPlan):
|
|
|
504
524
|
def _jax_bool_action_to_param(var, action, hyperparams):
|
|
505
525
|
if wrap_sigmoid:
|
|
506
526
|
weight = hyperparams[var]
|
|
507
|
-
return
|
|
527
|
+
return jax.scipy.special.logit(action) / weight
|
|
508
528
|
else:
|
|
509
529
|
return action
|
|
510
530
|
|
|
@@ -513,14 +533,13 @@ class JaxStraightLinePlan(JaxPlan):
|
|
|
513
533
|
def _jax_non_bool_param_to_action(var, param, hyperparams):
|
|
514
534
|
if wrap_non_bool:
|
|
515
535
|
lower, upper = bounds_safe[var]
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
]
|
|
536
|
+
mb, ml, mu, mn = [mask.astype(compiled.REAL)
|
|
537
|
+
for mask in cond_lists[var]]
|
|
538
|
+
action = (
|
|
539
|
+
mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
|
|
540
|
+
ml * (lower + (jax.nn.elu(param) + 1.0)) +
|
|
541
|
+
mu * (upper - (jax.nn.elu(-param) + 1.0)) +
|
|
542
|
+
mn * param
|
|
524
543
|
)
|
|
525
544
|
else:
|
|
526
545
|
action = param
|
|
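The restored expression implements a four-case bound transform (the same transform reappears in JaxDeepReactivePolicy below): the masks mb, ml, mu, mn select fluents that are bounded on both sides, bounded below only, bounded above only, and unbounded, respectively. A self-contained sketch of the transform:

    import jax
    import jax.numpy as jnp

    # exactly one mask is 1 per fluent; each branch maps R into the legal range
    def to_action(param, lower, upper, mb, ml, mu, mn):
        return (mb * (lower + (upper - lower) * jax.nn.sigmoid(param)) +
                ml * (lower + (jax.nn.elu(param) + 1.0)) +   # (lower, inf)
                mu * (upper - (jax.nn.elu(-param) + 1.0)) +  # (-inf, upper)
                mn * param)                                  # unbounded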
@@ -780,7 +799,7 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
780
799
|
def __init__(self, topology: Optional[Sequence[int]]=None,
|
|
781
800
|
activation: Activation=jnp.tanh,
|
|
782
801
|
initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
|
|
783
|
-
normalize: bool=False,
|
|
802
|
+
normalize: bool=False,
|
|
784
803
|
normalize_per_layer: bool=False,
|
|
785
804
|
normalizer_kwargs: Optional[Kwargs]=None,
|
|
786
805
|
wrap_non_bool: bool=False) -> None:
|
|
@@ -812,23 +831,23 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
812
831
|
normalizer_kwargs = {'create_offset': True, 'create_scale': True}
|
|
813
832
|
self._normalizer_kwargs = normalizer_kwargs
|
|
814
833
|
self._wrap_non_bool = wrap_non_bool
|
|
815
|
-
|
|
816
|
-
def
|
|
834
|
+
|
|
835
|
+
def __str__(self) -> str:
|
|
817
836
|
bounds = '\n '.join(
|
|
818
837
|
map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
838
|
+
return (f'policy hyper-parameters:\n'
|
|
839
|
+
f' topology ={self._topology}\n'
|
|
840
|
+
f' activation_fn ={self._activations[0].__name__}\n'
|
|
841
|
+
f' initializer ={type(self._initializer_base).__name__}\n'
|
|
842
|
+
f' apply_input_norm ={self._normalize}\n'
|
|
843
|
+
f' input_norm_layerwise={self._normalize_per_layer}\n'
|
|
844
|
+
f' input_norm_args ={self._normalizer_kwargs}\n'
|
|
845
|
+
f'constraint-sat strategy:\n'
|
|
846
|
+
f' parsed_action_bounds=\n {bounds}\n'
|
|
847
|
+
f' wrap_non_bool ={self._wrap_non_bool}')
|
|
848
|
+
|
|
830
849
|
def compile(self, compiled: JaxRDDLCompilerWithGrad,
|
|
831
|
-
_bounds: Bounds,
|
|
850
|
+
_bounds: Bounds,
|
|
832
851
|
horizon: int) -> None:
|
|
833
852
|
rddl = compiled.rddl
|
|
834
853
|
|
|
@@ -881,7 +900,7 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
881
900
|
if normalize_per_layer and value_size == 1:
|
|
882
901
|
raise_warning(
|
|
883
902
|
f'Cannot apply layer norm to state-fluent <{var}> '
|
|
884
|
-
f'of size 1: setting normalize_per_layer = False.',
|
|
903
|
+
f'of size 1: setting normalize_per_layer = False.',
|
|
885
904
|
'red')
|
|
886
905
|
normalize_per_layer = False
|
|
887
906
|
non_bool_dims += value_size
|
|
@@ -906,8 +925,8 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
906
925
|
else:
|
|
907
926
|
if normalize and normalize_per_layer:
|
|
908
927
|
normalizer = hk.LayerNorm(
|
|
909
|
-
axis=-1, param_axis=-1,
|
|
910
|
-
name=f'input_norm_{input_names[var]}',
|
|
928
|
+
axis=-1, param_axis=-1,
|
|
929
|
+
name=f'input_norm_{input_names[var]}',
|
|
911
930
|
**self._normalizer_kwargs)
|
|
912
931
|
state = normalizer(state)
|
|
913
932
|
states_non_bool.append(state)
|
|
@@ -917,7 +936,7 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
917
936
|
# optionally perform layer normalization on the non-bool inputs
|
|
918
937
|
if normalize and not normalize_per_layer and non_bool_dims:
|
|
919
938
|
normalizer = hk.LayerNorm(
|
|
920
|
-
axis=-1, param_axis=-1, name='input_norm',
|
|
939
|
+
axis=-1, param_axis=-1, name='input_norm',
|
|
921
940
|
**self._normalizer_kwargs)
|
|
922
941
|
normalized = normalizer(state[:non_bool_dims])
|
|
923
942
|
state = state.at[:non_bool_dims].set(normalized)
|
|
@@ -950,14 +969,13 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
950
969
|
else:
|
|
951
970
|
if wrap_non_bool:
|
|
952
971
|
lower, upper = bounds_safe[var]
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
]
|
|
972
|
+
mb, ml, mu, mn = [mask.astype(compiled.REAL)
|
|
973
|
+
for mask in cond_lists[var]]
|
|
974
|
+
action = (
|
|
975
|
+
mb * (lower + (upper - lower) * jax.nn.sigmoid(output)) +
|
|
976
|
+
ml * (lower + (jax.nn.elu(output) + 1.0)) +
|
|
977
|
+
mu * (upper - (jax.nn.elu(-output) + 1.0)) +
|
|
978
|
+
mn * output
|
|
961
979
|
)
|
|
962
980
|
else:
|
|
963
981
|
action = output
|
|
@@ -1049,7 +1067,7 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
1049
1067
|
|
|
1050
1068
|
def guess_next_epoch(self, params: Pytree) -> Pytree:
|
|
1051
1069
|
return params
|
|
1052
|
-
|
|
1070
|
+
|
|
1053
1071
|
|
|
1054
1072
|
# ***********************************************************************
|
|
1055
1073
|
# ALL VERSIONS OF JAX PLANNER
|
|
@@ -1059,6 +1077,7 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
1059
1077
|
#
|
|
1060
1078
|
# ***********************************************************************
|
|
1061
1079
|
|
|
1080
|
+
|
|
1062
1081
|
class RollingMean:
|
|
1063
1082
|
'''Maintains an estimate of the rolling mean of a stream of real-valued
|
|
1064
1083
|
observations.'''
|
|
@@ -1077,113 +1096,6 @@ class RollingMean:
|
|
|
1077
1096
|
return self._total / len(memory)
|
|
1078
1097
|
|
|
1079
1098
|
|
|
1080
|
-
class JaxPlannerPlot:
|
|
1081
|
-
'''Supports plotting and visualization of a JAX policy in real time.'''
|
|
1082
|
-
|
|
1083
|
-
def __init__(self, rddl: RDDLPlanningModel, horizon: int,
|
|
1084
|
-
show_violin: bool=True, show_action: bool=True) -> None:
|
|
1085
|
-
'''Creates a new planner visualizer.
|
|
1086
|
-
|
|
1087
|
-
:param rddl: the planning model to optimize
|
|
1088
|
-
:param horizon: the lookahead or planning horizon
|
|
1089
|
-
:param show_violin: whether to show the distribution of batch losses
|
|
1090
|
-
:param show_action: whether to show heatmaps of the action fluents
|
|
1091
|
-
'''
|
|
1092
|
-
num_plots = 1
|
|
1093
|
-
if show_violin:
|
|
1094
|
-
num_plots += 1
|
|
1095
|
-
if show_action:
|
|
1096
|
-
num_plots += len(rddl.action_fluents)
|
|
1097
|
-
self._fig, axes = plt.subplots(num_plots)
|
|
1098
|
-
if num_plots == 1:
|
|
1099
|
-
axes = [axes]
|
|
1100
|
-
|
|
1101
|
-
# prepare the loss plot
|
|
1102
|
-
self._loss_ax = axes[0]
|
|
1103
|
-
self._loss_ax.autoscale(enable=True)
|
|
1104
|
-
self._loss_ax.set_xlabel('training time')
|
|
1105
|
-
self._loss_ax.set_ylabel('loss value')
|
|
1106
|
-
self._loss_plot = self._loss_ax.plot(
|
|
1107
|
-
[], [], linestyle=':', marker='o', markersize=2)[0]
|
|
1108
|
-
self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
|
|
1109
|
-
|
|
1110
|
-
# prepare the violin plot
|
|
1111
|
-
if show_violin:
|
|
1112
|
-
self._hist_ax = axes[1]
|
|
1113
|
-
else:
|
|
1114
|
-
self._hist_ax = None
|
|
1115
|
-
|
|
1116
|
-
# prepare the action plots
|
|
1117
|
-
if show_action:
|
|
1118
|
-
self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
|
|
1119
|
-
for (idx, name) in enumerate(rddl.action_fluents)}
|
|
1120
|
-
self._action_plots = {}
|
|
1121
|
-
for name in rddl.action_fluents:
|
|
1122
|
-
ax = self._action_ax[name]
|
|
1123
|
-
if rddl.variable_ranges[name] == 'bool':
|
|
1124
|
-
vmin, vmax = 0.0, 1.0
|
|
1125
|
-
else:
|
|
1126
|
-
vmin, vmax = None, None
|
|
1127
|
-
action_dim = 1
|
|
1128
|
-
for dim in rddl.object_counts(rddl.variable_params[name]):
|
|
1129
|
-
action_dim *= dim
|
|
1130
|
-
action_plot = ax.pcolormesh(
|
|
1131
|
-
np.zeros((action_dim, horizon)),
|
|
1132
|
-
cmap='seismic', vmin=vmin, vmax=vmax)
|
|
1133
|
-
ax.set_aspect('auto')
|
|
1134
|
-
ax.set_xlabel('decision epoch')
|
|
1135
|
-
ax.set_ylabel(name)
|
|
1136
|
-
plt.colorbar(action_plot, ax=ax)
|
|
1137
|
-
self._action_plots[name] = action_plot
|
|
1138
|
-
self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
|
|
1139
|
-
for (name, ax) in self._action_ax.items()}
|
|
1140
|
-
else:
|
|
1141
|
-
self._action_ax = None
|
|
1142
|
-
self._action_plots = None
|
|
1143
|
-
self._action_back = None
|
|
1144
|
-
|
|
1145
|
-
plt.tight_layout()
|
|
1146
|
-
plt.show(block=False)
|
|
1147
|
-
|
|
1148
|
-
def redraw(self, xticks, losses, actions, returns) -> None:
|
|
1149
|
-
|
|
1150
|
-
# draw the loss curve
|
|
1151
|
-
self._fig.canvas.restore_region(self._loss_back)
|
|
1152
|
-
self._loss_plot.set_xdata(xticks)
|
|
1153
|
-
self._loss_plot.set_ydata(losses)
|
|
1154
|
-
self._loss_ax.set_xlim([0, len(xticks)])
|
|
1155
|
-
self._loss_ax.set_ylim([np.min(losses), np.max(losses)])
|
|
1156
|
-
self._loss_ax.draw_artist(self._loss_plot)
|
|
1157
|
-
self._fig.canvas.blit(self._loss_ax.bbox)
|
|
1158
|
-
|
|
1159
|
-
# draw the violin plot
|
|
1160
|
-
if self._hist_ax is not None:
|
|
1161
|
-
self._hist_ax.clear()
|
|
1162
|
-
self._hist_ax.set_xlabel('loss value')
|
|
1163
|
-
self._hist_ax.set_ylabel('density')
|
|
1164
|
-
self._hist_ax.violinplot(returns, vert=False, showmeans=True)
|
|
1165
|
-
|
|
1166
|
-
# draw the actions
|
|
1167
|
-
if self._action_ax is not None:
|
|
1168
|
-
for (name, values) in actions.items():
|
|
1169
|
-
values = np.mean(values, axis=0, dtype=float)
|
|
1170
|
-
values = np.reshape(values, newshape=(values.shape[0], -1)).T
|
|
1171
|
-
self._fig.canvas.restore_region(self._action_back[name])
|
|
1172
|
-
self._action_plots[name].set_array(values)
|
|
1173
|
-
self._action_ax[name].draw_artist(self._action_plots[name])
|
|
1174
|
-
self._fig.canvas.blit(self._action_ax[name].bbox)
|
|
1175
|
-
self._action_plots[name].set_clim([np.min(values), np.max(values)])
|
|
1176
|
-
|
|
1177
|
-
self._fig.canvas.draw()
|
|
1178
|
-
self._fig.canvas.flush_events()
|
|
1179
|
-
|
|
1180
|
-
def close(self) -> None:
|
|
1181
|
-
plt.close(self._fig)
|
|
1182
|
-
del self._loss_ax, self._hist_ax, self._action_ax, \
|
|
1183
|
-
self._loss_plot, self._action_plots, self._fig, \
|
|
1184
|
-
self._loss_back, self._action_back
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
1099
|
class JaxPlannerStatus(Enum):
|
|
1188
1100
|
'''Represents the status of a policy update from the JAX planner,
|
|
1189
1101
|
including whether the update resulted in nan gradient,
|
|
@@ -1191,16 +1103,50 @@ class JaxPlannerStatus(Enum):
|
|
|
1191
1103
|
can be used to monitor and act based on the planner's progress.'''
|
|
1192
1104
|
|
|
1193
1105
|
NORMAL = 0
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1106
|
+
STOPPING_RULE_REACHED = 1
|
|
1107
|
+
NO_PROGRESS = 2
|
|
1108
|
+
PRECONDITION_POSSIBLY_UNSATISFIED = 3
|
|
1109
|
+
INVALID_GRADIENT = 4
|
|
1110
|
+
TIME_BUDGET_REACHED = 5
|
|
1111
|
+
ITER_BUDGET_REACHED = 6
|
|
1199
1112
|
|
|
1200
|
-
def
|
|
1201
|
-
return self.value >=
|
|
1113
|
+
def is_terminal(self) -> bool:
|
|
1114
|
+
return self.value == 1 or self.value >= 4
|
|
1202
1115
|
|
|
1203
1116
|
|
|
1117
|
+
class JaxPlannerStoppingRule:
|
|
1118
|
+
'''The base class of all planner stopping rules.'''
|
|
1119
|
+
|
|
1120
|
+
def reset(self) -> None:
|
|
1121
|
+
raise NotImplementedError
|
|
1122
|
+
|
|
1123
|
+
def monitor(self, callback: Dict[str, Any]) -> bool:
|
|
1124
|
+
raise NotImplementedError
|
|
1125
|
+
|
|
1126
|
+
|
|
1127
|
+
class NoImprovementStoppingRule(JaxPlannerStoppingRule):
|
|
1128
|
+
'''Stopping rule based on no improvement for a fixed number of iterations.'''
|
|
1129
|
+
|
|
1130
|
+
def __init__(self, patience: int) -> None:
|
|
1131
|
+
self.patience = patience
|
|
1132
|
+
|
|
1133
|
+
def reset(self) -> None:
|
|
1134
|
+
self.callback = None
|
|
1135
|
+
self.iters_since_last_update = 0
|
|
1136
|
+
|
|
1137
|
+
def monitor(self, callback: Dict[str, Any]) -> bool:
|
|
1138
|
+
if self.callback is None \
|
|
1139
|
+
or callback['best_return'] > self.callback['best_return']:
|
|
1140
|
+
self.callback = callback
|
|
1141
|
+
self.iters_since_last_update = 0
|
|
1142
|
+
else:
|
|
1143
|
+
self.iters_since_last_update += 1
|
|
1144
|
+
return self.iters_since_last_update >= self.patience
|
|
1145
|
+
|
|
1146
|
+
def __str__(self) -> str:
|
|
1147
|
+
return f'No improvement for {self.patience} iterations'
|
|
1148
|
+
|
|
1149
|
+
|
|
1204
1150
|
class JaxBackpropPlanner:
|
|
1205
1151
|
'''A class for optimizing an action sequence in the given RDDL MDP using
|
|
1206
1152
|
gradient descent.'''
|
|
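A minimal sketch of the new stopping rule in isolation; monitor() receives the per-iteration callback dict and reports True once best_return has stagnated for patience iterations (the rule can also be passed to optimize() via the new stopping_rule argument):

    from pyRDDLGym_jax.core.planner import NoImprovementStoppingRule

    rule = NoImprovementStoppingRule(patience=2)
    rule.reset()
    for best_return in [1.0, 1.0, 1.0]:
        if rule.monitor({'best_return': best_return}):
            print('stop')   # fires on the third stagnant iteration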
@@ -1215,13 +1161,16 @@ class JaxBackpropPlanner:
|
|
|
1215
1161
|
optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
|
|
1216
1162
|
optimizer_kwargs: Optional[Kwargs]=None,
|
|
1217
1163
|
clip_grad: Optional[float]=None,
|
|
1218
|
-
|
|
1164
|
+
line_search_kwargs: Optional[Kwargs]=None,
|
|
1165
|
+
noise_kwargs: Optional[Kwargs]=None,
|
|
1166
|
+
logic: Logic=FuzzyLogic(),
|
|
1219
1167
|
use_symlog_reward: bool=False,
|
|
1220
1168
|
utility: Union[Callable[[jnp.ndarray], float], str]='mean',
|
|
1221
1169
|
utility_kwargs: Optional[Kwargs]=None,
|
|
1222
1170
|
cpfs_without_grad: Optional[Set[str]]=None,
|
|
1223
1171
|
compile_non_fluent_exact: bool=True,
|
|
1224
|
-
logger: Optional[Logger]=None
|
|
1172
|
+
logger: Optional[Logger]=None,
|
|
1173
|
+
dashboard_viz: Optional[Any]=None) -> None:
|
|
1225
1174
|
'''Creates a new gradient-based algorithm for optimizing action sequences
|
|
1226
1175
|
(plan) in the given RDDL. Some operations will be converted to their
|
|
1227
1176
|
differentiable counterparts; the specific operations can be customized
|
|
@@ -1241,7 +1190,10 @@ class JaxBackpropPlanner:
|
|
|
1241
1190
|
:param optimizer_kwargs: a dictionary of parameters to pass to the SGD
|
|
1242
1191
|
factory (e.g. which parameters are controllable externally)
|
|
1243
1192
|
:param clip_grad: maximum magnitude of gradient updates
|
|
1244
|
-
:param
|
|
1193
|
+
:param line_search_kwargs: parameters to pass to optional line search
|
|
1194
|
+
method to scale learning rate
|
|
1195
|
+
:param noise_kwargs: parameters of optional gradient noise
|
|
1196
|
+
:param logic: a subclass of Logic for mapping exact mathematical
|
|
1245
1197
|
operations to their differentiable counterparts
|
|
1246
1198
|
:param use_symlog_reward: whether to use the symlog transform on the
|
|
1247
1199
|
reward as a form of normalization
|
|
@@ -1256,6 +1208,8 @@ class JaxBackpropPlanner:
|
|
|
1256
1208
|
:param compile_non_fluent_exact: whether non-fluent expressions
|
|
1257
1209
|
are always compiled using exact JAX expressions
|
|
1258
1210
|
:param logger: to log information about compilation to file
|
|
1211
|
+
:param dashboard_viz: optional visualizer object from the environment
|
|
1212
|
+
to pass to the dashboard to visualize the policy
|
|
1259
1213
|
'''
|
|
1260
1214
|
self.rddl = rddl
|
|
1261
1215
|
self.plan = plan
|
|
@@ -1270,29 +1224,34 @@ class JaxBackpropPlanner:
|
|
|
1270
1224
|
action_bounds = {}
|
|
1271
1225
|
self._action_bounds = action_bounds
|
|
1272
1226
|
self.use64bit = use64bit
|
|
1273
|
-
self.
|
|
1227
|
+
self.optimizer_name = optimizer
|
|
1274
1228
|
if optimizer_kwargs is None:
|
|
1275
1229
|
optimizer_kwargs = {'learning_rate': 0.1}
|
|
1276
|
-
self.
|
|
1230
|
+
self.optimizer_kwargs = optimizer_kwargs
|
|
1277
1231
|
self.clip_grad = clip_grad
|
|
1232
|
+
self.line_search_kwargs = line_search_kwargs
|
|
1233
|
+
self.noise_kwargs = noise_kwargs
|
|
1278
1234
|
|
|
1279
1235
|
# set optimizer
|
|
1280
1236
|
try:
|
|
1281
1237
|
optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
|
|
1282
1238
|
except Exception as _:
|
|
1283
1239
|
raise_warning(
|
|
1284
|
-
'Failed to inject hyperparameters into optax optimizer, '
|
|
1240
|
+
f'Failed to inject hyperparameters into optax optimizer {optimizer}, '
|
|
1285
1241
|
'rolling back to safer method: please note that modification of '
|
|
1286
|
-
'optimizer hyperparameters will not work,
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1242
|
+
'optimizer hyperparameters will not work.', 'red')
|
|
1243
|
+
optimizer = optimizer(**optimizer_kwargs)
|
|
1244
|
+
|
|
1245
|
+
# apply optimizer chain of transformations
|
|
1246
|
+
pipeline = []
|
|
1247
|
+
if clip_grad is not None:
|
|
1248
|
+
pipeline.append(optax.clip(clip_grad))
|
|
1249
|
+
if noise_kwargs is not None:
|
|
1250
|
+
pipeline.append(optax.add_noise(**noise_kwargs))
|
|
1251
|
+
pipeline.append(optimizer)
|
|
1252
|
+
if line_search_kwargs is not None:
|
|
1253
|
+
pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))
|
|
1254
|
+
self.optimizer = optax.chain(*pipeline)
|
|
1296
1255
|
|
|
1297
1256
|
# set utility
|
|
1298
1257
|
if isinstance(utility, str):
|
|
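The constructor now assembles the optimizer as an optax chain in a fixed order: gradient clipping, optional gradient noise, the base optimizer, then optional zoom line search. A standalone sketch of the same pipeline (argument values are illustrative, not package defaults):

    import optax

    pipeline = [
        optax.clip(1.0),                                 # if clip_grad is set
        optax.add_noise(eta=0.01, gamma=0.55, seed=0),   # if noise_kwargs is set
        optax.inject_hyperparams(optax.rmsprop)(learning_rate=0.1),
        optax.scale_by_zoom_linesearch(max_linesearch_steps=15),  # if line_search_kwargs is set
    ]
    optimizer = optax.chain(*pipeline)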
@@ -1325,11 +1284,12 @@ class JaxBackpropPlanner:
|
|
|
1325
1284
|
self.cpfs_without_grad = cpfs_without_grad
|
|
1326
1285
|
self.compile_non_fluent_exact = compile_non_fluent_exact
|
|
1327
1286
|
self.logger = logger
|
|
1287
|
+
self.dashboard_viz = dashboard_viz
|
|
1328
1288
|
|
|
1329
1289
|
self._jax_compile_rddl()
|
|
1330
1290
|
self._jax_compile_optimizer()
|
|
1331
1291
|
|
|
1332
|
-
def
|
|
1292
|
+
def summarize_system(self) -> str:
|
|
1333
1293
|
try:
|
|
1334
1294
|
jaxlib_version = jax._src.lib.version_str
|
|
1335
1295
|
except Exception as _:
|
|
@@ -1340,42 +1300,63 @@ class JaxBackpropPlanner:
|
|
|
1340
1300
|
except Exception as _:
|
|
1341
1301
|
devices_short = 'N/A'
|
|
1342
1302
|
LOGO = \
|
|
1303
|
+
r"""
|
|
1304
|
+
__ ______ __ __ ______ __ ______ __ __
|
|
1305
|
+
/\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
|
|
1306
|
+
_\_\ \\ \ __ \\/_/\_\/_\ \ _-/\ \ \____\ \ __ \\ \ \-. \
|
|
1307
|
+
/\_____\\ \_\ \_\ /\_\/\_\\ \_\ \ \_____\\ \_\ \_\\ \_\\"\_\
|
|
1308
|
+
\/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
|
|
1343
1309
|
"""
|
|
1344
|
-
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
|
|
1356
|
-
|
|
1357
|
-
|
|
1358
|
-
|
|
1310
|
+
|
|
1311
|
+
return ('\n'
|
|
1312
|
+
f'{LOGO}\n'
|
|
1313
|
+
f'Version {__version__}\n'
|
|
1314
|
+
f'Python {sys.version}\n'
|
|
1315
|
+
f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
|
|
1316
|
+
f'optax {optax.__version__}, haiku {hk.__version__}, '
|
|
1317
|
+
f'numpy {np.__version__}\n'
|
|
1318
|
+
f'devices: {devices_short}\n')
|
|
1319
|
+
|
|
1320
|
+
def __str__(self) -> str:
|
|
1321
|
+
result = (f'objective hyper-parameters:\n'
|
|
1322
|
+
f' utility_fn ={self.utility.__name__}\n'
|
|
1323
|
+
f' utility args ={self.utility_kwargs}\n'
|
|
1324
|
+
f' use_symlog ={self.use_symlog_reward}\n'
|
|
1325
|
+
f' lookahead ={self.horizon}\n'
|
|
1326
|
+
f' user_action_bounds={self._action_bounds}\n'
|
|
1327
|
+
f' fuzzy logic type ={type(self.logic).__name__}\n'
|
|
1328
|
+
f' non_fluents exact ={self.compile_non_fluent_exact}\n'
|
|
1329
|
+
f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
|
|
1330
|
+
f'optimizer hyper-parameters:\n'
|
|
1331
|
+
f' use_64_bit ={self.use64bit}\n'
|
|
1332
|
+
f' optimizer ={self.optimizer_name}\n'
|
|
1333
|
+
f' optimizer args ={self.optimizer_kwargs}\n'
|
|
1334
|
+
f' clip_gradient ={self.clip_grad}\n'
|
|
1335
|
+
f' line_search_kwargs={self.line_search_kwargs}\n'
|
|
1336
|
+
f' noise_kwargs ={self.noise_kwargs}\n'
|
|
1337
|
+
f' batch_size_train ={self.batch_size_train}\n'
|
|
1338
|
+
f' batch_size_test ={self.batch_size_test}')
|
|
1339
|
+
result += '\n' + str(self.plan)
|
|
1340
|
+
result += '\n' + str(self.logic)
|
|
1341
|
+
|
|
1342
|
+
# print model relaxation information
|
|
1343
|
+
if not self.compiled.model_params:
|
|
1344
|
+
return result
|
|
1345
|
+
result += '\n' + ('Some RDDL operations are non-differentiable '
|
|
1346
|
+
'and will be approximated as follows:' + '\n')
|
|
1347
|
+
exprs_by_rddl_op, values_by_rddl_op = {}, {}
|
|
1348
|
+
for info in self.compiled.model_parameter_info().values():
|
|
1349
|
+
rddl_op = info['rddl_op']
|
|
1350
|
+
exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
|
|
1351
|
+
values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
|
|
1352
|
+
for rddl_op in sorted(exprs_by_rddl_op.keys()):
|
|
1353
|
+
result += (f' {rddl_op}:\n'
|
|
1354
|
+
f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
|
|
1355
|
+
f' init_values={values_by_rddl_op[rddl_op]}\n')
|
|
1356
|
+
return result
|
|
1359
1357
|
|
|
1360
1358
|
def summarize_hyperparameters(self) -> None:
|
|
1361
|
-
print(
|
|
1362
|
-
f' utility_fn ={self.utility.__name__}\n'
|
|
1363
|
-
f' utility args ={self.utility_kwargs}\n'
|
|
1364
|
-
f' use_symlog ={self.use_symlog_reward}\n'
|
|
1365
|
-
f' lookahead ={self.horizon}\n'
|
|
1366
|
-
f' user_action_bounds={self._action_bounds}\n'
|
|
1367
|
-
f' fuzzy logic type ={type(self.logic).__name__}\n'
|
|
1368
|
-
f' nonfluents exact ={self.compile_non_fluent_exact}\n'
|
|
1369
|
-
f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
|
|
1370
|
-
f'optimizer hyper-parameters:\n'
|
|
1371
|
-
f' use_64_bit ={self.use64bit}\n'
|
|
1372
|
-
f' optimizer ={self._optimizer_name.__name__}\n'
|
|
1373
|
-
f' optimizer args ={self._optimizer_kwargs}\n'
|
|
1374
|
-
f' clip_gradient ={self.clip_grad}\n'
|
|
1375
|
-
f' batch_size_train ={self.batch_size_train}\n'
|
|
1376
|
-
f' batch_size_test ={self.batch_size_test}')
|
|
1377
|
-
self.plan.summarize_hyperparameters()
|
|
1378
|
-
self.logic.summarize_hyperparameters()
|
|
1359
|
+
print(self.__str__())
|
|
1379
1360
|
|
|
1380
1361
|
# ===========================================================================
|
|
1381
1362
|
# COMPILATION SUBROUTINES
|
|
@@ -1391,14 +1372,16 @@ class JaxBackpropPlanner:
|
|
|
1391
1372
|
logger=self.logger,
|
|
1392
1373
|
use64bit=self.use64bit,
|
|
1393
1374
|
cpfs_without_grad=self.cpfs_without_grad,
|
|
1394
|
-
compile_non_fluent_exact=self.compile_non_fluent_exact
|
|
1375
|
+
compile_non_fluent_exact=self.compile_non_fluent_exact
|
|
1376
|
+
)
|
|
1395
1377
|
self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
|
|
1396
1378
|
|
|
1397
1379
|
# Jax compilation of the exact RDDL for testing
|
|
1398
1380
|
self.test_compiled = JaxRDDLCompiler(
|
|
1399
|
-
rddl=rddl,
|
|
1381
|
+
rddl=rddl,
|
|
1400
1382
|
logger=self.logger,
|
|
1401
|
-
use64bit=self.use64bit
|
|
1383
|
+
use64bit=self.use64bit
|
|
1384
|
+
)
|
|
1402
1385
|
self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
|
|
1403
1386
|
|
|
1404
1387
|
def _jax_compile_optimizer(self):
|
|
@@ -1414,13 +1397,15 @@ class JaxBackpropPlanner:
|
|
|
1414
1397
|
train_rollouts = self.compiled.compile_rollouts(
|
|
1415
1398
|
policy=self.plan.train_policy,
|
|
1416
1399
|
n_steps=self.horizon,
|
|
1417
|
-
n_batch=self.batch_size_train
|
|
1400
|
+
n_batch=self.batch_size_train
|
|
1401
|
+
)
|
|
1418
1402
|
self.train_rollouts = train_rollouts
|
|
1419
1403
|
|
|
1420
1404
|
test_rollouts = self.test_compiled.compile_rollouts(
|
|
1421
1405
|
policy=self.plan.test_policy,
|
|
1422
1406
|
n_steps=self.horizon,
|
|
1423
|
-
n_batch=self.batch_size_test
|
|
1407
|
+
n_batch=self.batch_size_test
|
|
1408
|
+
)
|
|
1424
1409
|
self.test_rollouts = jax.jit(test_rollouts)
|
|
1425
1410
|
|
|
1426
1411
|
# initialization
|
|
@@ -1455,14 +1440,16 @@ class JaxBackpropPlanner:
|
|
|
1455
1440
|
_jax_wrapped_returns = self._jax_return(use_symlog)
|
|
1456
1441
|
|
|
1457
1442
|
# the loss is the average cumulative reward across all roll-outs
|
|
1458
|
-
def _jax_wrapped_plan_loss(key, policy_params,
|
|
1443
|
+
def _jax_wrapped_plan_loss(key, policy_params, policy_hyperparams,
|
|
1459
1444
|
subs, model_params):
|
|
1460
|
-
log = rollouts(
|
|
1445
|
+
log, model_params = rollouts(
|
|
1446
|
+
key, policy_params, policy_hyperparams, subs, model_params)
|
|
1461
1447
|
rewards = log['reward']
|
|
1462
1448
|
returns = _jax_wrapped_returns(rewards)
|
|
1463
1449
|
utility = utility_fn(returns, **utility_kwargs)
|
|
1464
1450
|
loss = -utility
|
|
1465
|
-
|
|
1451
|
+
aux = (log, model_params)
|
|
1452
|
+
return loss, aux
|
|
1466
1453
|
|
|
1467
1454
|
return _jax_wrapped_plan_loss
|
|
1468
1455
|
|
|
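The loss now returns (loss, (log, model_params)), so the rollout log and the threaded model parameters ride along as JAX auxiliary output. A standalone reminder of the has_aux pattern this relies on:

    import jax
    import jax.numpy as jnp

    def loss_fn(params, model_params):
        loss = jnp.sum(params ** 2)
        return loss, ({'loss': loss}, model_params)   # (value, aux)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (value, (log, model_params)), grad = grad_fn(jnp.arange(3.0), {'step': 0})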
@@ -1470,30 +1457,45 @@ class JaxBackpropPlanner:
|
|
|
1470
1457
|
init = self.plan.initializer
|
|
1471
1458
|
optimizer = self.optimizer
|
|
1472
1459
|
|
|
1473
|
-
|
|
1474
|
-
|
|
1460
|
+
# initialize both the policy and its optimizer
|
|
1461
|
+
def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
|
|
1462
|
+
policy_params = init(key, policy_hyperparams, subs)
|
|
1475
1463
|
opt_state = optimizer.init(policy_params)
|
|
1476
|
-
return policy_params, opt_state,
|
|
1464
|
+
return policy_params, opt_state, {}
|
|
1477
1465
|
|
|
1478
1466
|
return _jax_wrapped_init_policy
|
|
1479
1467
|
|
|
1480
1468
|
def _jax_update(self, loss):
|
|
1481
1469
|
optimizer = self.optimizer
|
|
1482
1470
|
projection = self.plan.projection
|
|
1471
|
+
use_ls = self.line_search_kwargs is not None
|
|
1483
1472
|
|
|
1484
1473
|
# calculate the plan gradient w.r.t. return loss and update optimizer
|
|
1485
1474
|
# also perform a projection step to satisfy constraints on actions
|
|
1486
|
-
def
|
|
1475
|
+
def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
|
|
1476
|
+
subs, model_params):
|
|
1477
|
+
return loss(key, policy_params, policy_hyperparams, subs, model_params)[0]
|
|
1478
|
+
|
|
1479
|
+
def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams,
|
|
1487
1480
|
subs, model_params, opt_state, opt_aux):
|
|
1488
1481
|
grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
|
|
1489
|
-
(loss_val, log), grad = grad_fn(
|
|
1490
|
-
key, policy_params,
|
|
1491
|
-
|
|
1482
|
+
(loss_val, (log, model_params)), grad = grad_fn(
|
|
1483
|
+
key, policy_params, policy_hyperparams, subs, model_params)
|
|
1484
|
+
if use_ls:
|
|
1485
|
+
updates, opt_state = optimizer.update(
|
|
1486
|
+
grad, opt_state, params=policy_params,
|
|
1487
|
+
value=loss_val, grad=grad, value_fn=_jax_wrapped_loss_swapped,
|
|
1488
|
+
key=key, policy_hyperparams=policy_hyperparams, subs=subs,
|
|
1489
|
+
model_params=model_params)
|
|
1490
|
+
else:
|
|
1491
|
+
updates, opt_state = optimizer.update(
|
|
1492
|
+
grad, opt_state, params=policy_params)
|
|
1492
1493
|
policy_params = optax.apply_updates(policy_params, updates)
|
|
1493
|
-
policy_params, converged = projection(policy_params,
|
|
1494
|
+
policy_params, converged = projection(policy_params, policy_hyperparams)
|
|
1494
1495
|
log['grad'] = grad
|
|
1495
1496
|
log['updates'] = updates
|
|
1496
|
-
return policy_params, converged, opt_state,
|
|
1497
|
+
return policy_params, converged, opt_state, opt_aux, \
|
|
1498
|
+
loss_val, log, model_params
|
|
1497
1499
|
|
|
1498
1500
|
return jax.jit(_jax_wrapped_plan_update)
|
|
1499
1501
|
|
|
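optax's zoom line search needs the current loss value, the gradient, and a loss function of the parameters alone so it can probe trial step sizes, which is why _jax_wrapped_loss_swapped reorders the arguments above. A standalone optax example of the extended update signature (sketch, not this package's code):

    import jax
    import jax.numpy as jnp
    import optax

    f = lambda p: jnp.sum(p ** 2)
    opt = optax.chain(optax.sgd(learning_rate=1.0),
                      optax.scale_by_zoom_linesearch(max_linesearch_steps=10))
    params = jnp.array([1.0, -2.0])
    state = opt.init(params)
    value, grad = jax.value_and_grad(f)(params)
    # the line-search transform consumes value, grad, and a re-evaluatable value_fn
    updates, state = opt.update(grad, state, params,
                                value=value, grad=grad, value_fn=f)
    params = optax.apply_updates(params, updates)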
@@ -1520,11 +1522,10 @@ class JaxBackpropPlanner:
|
|
|
1520
1522
|
for (state, next_state) in rddl.next_state.items():
|
|
1521
1523
|
init_train[next_state] = init_train[state]
|
|
1522
1524
|
init_test[next_state] = init_test[state]
|
|
1523
|
-
|
|
1524
1525
|
return init_train, init_test
|
|
1525
1526
|
|
|
1526
1527
|
def as_optimization_problem(
|
|
1527
|
-
self, key: Optional[random.PRNGKey]=None,
|
|
1528
|
+
self, key: Optional[random.PRNGKey]=None,
|
|
1528
1529
|
policy_hyperparams: Optional[Pytree]=None,
|
|
1529
1530
|
loss_function_updates_key: bool=True,
|
|
1530
1531
|
grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
|
|
@@ -1574,38 +1575,40 @@ class JaxBackpropPlanner:
|
|
|
1574
1575
|
loss_fn = self._jax_loss(self.train_rollouts)
|
|
1575
1576
|
|
|
1576
1577
|
@jax.jit
|
|
1577
|
-
def _loss_with_key(key, params_1d):
|
|
1578
|
+
def _loss_with_key(key, params_1d, model_params):
|
|
1578
1579
|
policy_params = unravel_fn(params_1d)
|
|
1579
|
-
loss_val, _ = loss_fn(
|
|
1580
|
-
|
|
1581
|
-
return loss_val
|
|
1580
|
+
loss_val, (_, model_params) = loss_fn(
|
|
1581
|
+
key, policy_params, policy_hyperparams, train_subs, model_params)
|
|
1582
|
+
return loss_val, model_params
|
|
1582
1583
|
|
|
1583
1584
|
@jax.jit
|
|
1584
|
-
def _grad_with_key(key, params_1d):
|
|
1585
|
+
def _grad_with_key(key, params_1d, model_params):
|
|
1585
1586
|
policy_params = unravel_fn(params_1d)
|
|
1586
1587
|
grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
|
|
1587
|
-
grad_val, _ = grad_fn(
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
return
|
|
1588
|
+
grad_val, (_, model_params) = grad_fn(
|
|
1589
|
+
key, policy_params, policy_hyperparams, train_subs, model_params)
|
|
1590
|
+
grad_val = jax.flatten_util.ravel_pytree(grad_val)[0]
|
|
1591
|
+
return grad_val, model_params
|
|
1591
1592
|
|
|
1592
1593
|
def _loss_function(params_1d):
|
|
1593
1594
|
nonlocal key
|
|
1595
|
+
nonlocal model_params
|
|
1594
1596
|
if loss_function_updates_key:
|
|
1595
1597
|
key, subkey = random.split(key)
|
|
1596
1598
|
else:
|
|
1597
1599
|
subkey = key
|
|
1598
|
-
loss_val = _loss_with_key(subkey, params_1d)
|
|
1600
|
+
loss_val, model_params = _loss_with_key(subkey, params_1d, model_params)
|
|
1599
1601
|
loss_val = float(loss_val)
|
|
1600
1602
|
return loss_val
|
|
1601
1603
|
|
|
1602
1604
|
def _grad_function(params_1d):
|
|
1603
1605
|
nonlocal key
|
|
1606
|
+
nonlocal model_params
|
|
1604
1607
|
if grad_function_updates_key:
|
|
1605
1608
|
key, subkey = random.split(key)
|
|
1606
1609
|
else:
|
|
1607
1610
|
subkey = key
|
|
1608
|
-
grad = _grad_with_key(subkey, params_1d)
|
|
1611
|
+
grad, model_params = _grad_with_key(subkey, params_1d, model_params)
|
|
1609
1612
|
grad = np.asarray(grad)
|
|
1610
1613
|
return grad
|
|
1611
1614
|
|
|
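Both flattened functions now thread model_params through nonlocal state, so repeated calls keep the relaxed model's parameters current. A hedged usage sketch with scipy, assuming the returned tuple order is (loss_fn, grad_fn, initial 1-D parameters, unravel_fn) per the Tuple[Callable, Callable, np.ndarray, Callable] hint:

    from scipy.optimize import minimize

    # planner: a JaxBackpropPlanner instance (construction elided)
    loss_fn, grad_fn, x0, unravel_fn = planner.as_optimization_problem()
    result = minimize(loss_fn, x0, jac=grad_fn, method='L-BFGS-B')
    policy_params = unravel_fn(result.x)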
@@ -1620,9 +1623,9 @@ class JaxBackpropPlanner:
|
|
|
1620
1623
|
|
|
1621
1624
|
:param key: JAX PRNG key (derived from clock if not provided)
|
|
1622
1625
|
:param epochs: the maximum number of steps of gradient descent
|
|
1623
|
-
:param train_seconds: total time allocated for gradient descent
|
|
1624
|
-
:param
|
|
1625
|
-
:param
|
|
1626
|
+
:param train_seconds: total time allocated for gradient descent
|
|
1627
|
+
:param dashboard: dashboard to display training results
|
|
1628
|
+
:param dashboard_id: experiment id for the dashboard
|
|
1626
1629
|
:param model_params: optional model-parameters to override default
|
|
1627
1630
|
:param policy_hyperparams: hyper-parameters for the policy/plan, such as
|
|
1628
1631
|
weights for sigmoid wrapping boolean actions
|
|
@@ -1633,6 +1636,7 @@ class JaxBackpropPlanner:
|
|
|
1633
1636
|
:param print_summary: whether to print planner header, parameter
|
|
1634
1637
|
summary, and diagnosis
|
|
1635
1638
|
:param print_progress: whether to print the progress bar during training
|
|
1639
|
+
:param stopping_rule: stopping criterion
|
|
1636
1640
|
:param test_rolling_window: the test return is averaged on a rolling
|
|
1637
1641
|
window of the past test_rolling_window returns when updating the best
|
|
1638
1642
|
parameters found so far
|
|
@@ -1657,14 +1661,15 @@ class JaxBackpropPlanner:
|
|
|
1657
1661
|
def optimize_generator(self, key: Optional[random.PRNGKey]=None,
|
|
1658
1662
|
epochs: int=999999,
|
|
1659
1663
|
train_seconds: float=120.,
|
|
1660
|
-
|
|
1661
|
-
|
|
1664
|
+
dashboard: Optional[JaxPlannerDashboard]=None,
|
|
1665
|
+
dashboard_id: Optional[str]=None,
|
|
1662
1666
|
model_params: Optional[Dict[str, Any]]=None,
|
|
1663
1667
|
policy_hyperparams: Optional[Dict[str, Any]]=None,
|
|
1664
1668
|
subs: Optional[Dict[str, Any]]=None,
|
|
1665
1669
|
guess: Optional[Pytree]=None,
|
|
1666
1670
|
print_summary: bool=True,
|
|
1667
1671
|
print_progress: bool=True,
|
|
1672
|
+
stopping_rule: Optional[JaxPlannerStoppingRule]=None,
|
|
1668
1673
|
test_rolling_window: int=10,
|
|
1669
1674
|
tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
|
|
1670
1675
|
'''Returns a generator for computing an optimal policy or plan.
|
|
@@ -1673,9 +1678,9 @@ class JaxBackpropPlanner:
|
|
|
1673
1678
|
|
|
1674
1679
|
:param key: JAX PRNG key (derived from clock if not provided)
|
|
1675
1680
|
:param epochs: the maximum number of steps of gradient descent
|
|
1676
|
-
:param train_seconds: total time allocated for gradient descent
|
|
1677
|
-
:param
|
|
1678
|
-
:param
|
|
1681
|
+
:param train_seconds: total time allocated for gradient descent
|
|
1682
|
+
:param dashboard: dashboard to display training results
|
|
1683
|
+
:param dashboard_id: experiment id for the dashboard
|
|
1679
1684
|
:param model_params: optional model-parameters to override default
|
|
1680
1685
|
:param policy_hyperparams: hyper-parameters for the policy/plan, such as
|
|
1681
1686
|
weights for sigmoid wrapping boolean actions
|
|
@@ -1686,6 +1691,7 @@ class JaxBackpropPlanner:
|
|
|
1686
1691
|
:param print_summary: whether to print planner header, parameter
|
|
1687
1692
|
summary, and diagnosis
|
|
1688
1693
|
:param print_progress: whether to print the progress bar during training
|
|
1694
|
+
:param stopping_rule: stopping criterion
|
|
1689
1695
|
:param test_rolling_window: the test return is averaged on a rolling
|
|
1690
1696
|
window of the past test_rolling_window returns when updating the best
|
|
1691
1697
|
parameters found so far
|
|
@@ -1694,9 +1700,14 @@ class JaxBackpropPlanner:
|
|
|
1694
1700
|
start_time = time.time()
|
|
1695
1701
|
elapsed_outside_loop = 0
|
|
1696
1702
|
|
|
1703
|
+
# ======================================================================
|
|
1704
|
+
# INITIALIZATION OF HYPER-PARAMETERS
|
|
1705
|
+
# ======================================================================
|
|
1706
|
+
|
|
1697
1707
|
# if PRNG key is not provided
|
|
1698
1708
|
if key is None:
|
|
1699
1709
|
key = random.PRNGKey(round(time.time() * 1000))
|
|
1710
|
+
dash_key = key[1].item()
|
|
1700
1711
|
|
|
1701
1712
|
# if policy_hyperparams is not provided
|
|
1702
1713
|
if policy_hyperparams is None:
|
|
@@ -1723,7 +1734,7 @@ class JaxBackpropPlanner:
|
|
|
1723
1734
|
|
|
1724
1735
|
# print summary of parameters:
|
|
1725
1736
|
if print_summary:
|
|
1726
|
-
self.
|
|
1737
|
+
print(self.summarize_system())
|
|
1727
1738
|
self.summarize_hyperparameters()
|
|
1728
1739
|
print(f'optimize() call hyper-parameters:\n'
|
|
1729
1740
|
f' PRNG key ={key}\n'
|
|
@@ -1734,15 +1745,16 @@ class JaxBackpropPlanner:
|
|
|
1734
1745
|
f' override_subs_dict ={subs is not None}\n'
|
|
1735
1746
|
f' provide_param_guess={guess is not None}\n'
|
|
1736
1747
|
f' test_rolling_window={test_rolling_window}\n'
|
|
1737
|
-
f'
|
|
1738
|
-
f'
|
|
1748
|
+
f' dashboard ={dashboard is not None}\n'
|
|
1749
|
+
f' dashboard_id ={dashboard_id}\n'
|
|
1739
1750
|
f' print_summary ={print_summary}\n'
|
|
1740
|
-
f' print_progress ={print_progress}\n'
|
|
1741
|
-
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
|
|
1745
|
-
|
|
1751
|
+
f' print_progress ={print_progress}\n'
|
|
1752
|
+
f' stopping_rule ={stopping_rule}\n')
|
|
1753
|
+
|
|
1754
|
+
# ======================================================================
|
|
1755
|
+
# INITIALIZATION OF STATE AND POLICY
|
|
1756
|
+
# ======================================================================
|
|
1757
|
+
|
|
1746
1758
|
# compute a batched version of the initial values
|
|
1747
1759
|
if subs is None:
|
|
1748
1760
|
subs = self.test_compiled.init_values
|
|
@@ -1773,7 +1785,11 @@ class JaxBackpropPlanner:
|
|
|
1773
1785
|
else:
|
|
1774
1786
|
policy_params = guess
|
|
1775
1787
|
opt_state = self.optimizer.init(policy_params)
|
|
1776
|
-
opt_aux =
|
|
1788
|
+
opt_aux = {}
|
|
1789
|
+
|
|
1790
|
+
# ======================================================================
|
|
1791
|
+
# INITIALIZATION OF RUNNING STATISTICS
|
|
1792
|
+
# ======================================================================
|
|
1777
1793
|
|
|
1778
1794
|
# initialize running statistics
|
|
1779
1795
|
best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
|
|
@@ -1783,30 +1799,43 @@ class JaxBackpropPlanner:
|
|
|
1783
1799
|
status = JaxPlannerStatus.NORMAL
|
|
1784
1800
|
is_all_zero_fn = lambda x: np.allclose(x, 0)
|
|
1785
1801
|
|
|
1786
|
-
# initialize
|
|
1787
|
-
if
|
|
1788
|
-
|
|
1789
|
-
|
|
1790
|
-
|
|
1791
|
-
|
|
1792
|
-
|
|
1793
|
-
|
|
1802
|
+
# initialize stopping criterion
|
|
1803
|
+
if stopping_rule is not None:
|
|
1804
|
+
stopping_rule.reset()
|
|
1805
|
+
|
|
1806
|
+
# initialize dashboard
|
|
1807
|
+
if dashboard is not None:
|
|
1808
|
+
dashboard_id = dashboard.register_experiment(
|
|
1809
|
+
dashboard_id, dashboard.get_planner_info(self),
|
|
1810
|
+
key=dash_key, viz=self.dashboard_viz)
|
|
1811
|
+
|
|
1812
|
+
# ======================================================================
|
|
1813
|
+
# MAIN TRAINING LOOP BEGINS
|
|
1814
|
+
# ======================================================================
|
|
1794
1815
|
|
|
1795
|
-
# training loop
|
|
1796
1816
|
iters = range(epochs)
|
|
1797
1817
|
if print_progress:
|
|
1798
1818
|
iters = tqdm(iters, total=100, position=tqdm_position)
|
|
1799
1819
|
position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
|
|
1800
1820
|
|
|
1801
1821
|
for it in iters:
|
|
1822
|
+
|
|
1823
|
+
# ==================================================================
|
|
1824
|
+
# NEXT GRADIENT DESCENT STEP
|
|
1825
|
+
# ==================================================================
|
|
1826
|
+
|
|
1802
1827
|
status = JaxPlannerStatus.NORMAL
|
|
1803
1828
|
|
|
1804
1829
|
# update the parameters of the plan
|
|
1805
1830
|
key, subkey = random.split(key)
|
|
1806
|
-
policy_params, converged, opt_state, opt_aux,
|
|
1807
|
-
|
|
1831
|
+
(policy_params, converged, opt_state, opt_aux,
|
|
1832
|
+
train_loss, train_log, model_params) = \
|
|
1808
1833
|
self.update(subkey, policy_params, policy_hyperparams,
|
|
1809
1834
|
train_subs, model_params, opt_state, opt_aux)
|
|
1835
|
+
|
|
1836
|
+
# ==================================================================
|
|
1837
|
+
# STATUS CHECKS AND LOGGING
|
|
1838
|
+
# ==================================================================
|
|
1810
1839
|
|
|
1811
1840
|
# no progress
|
|
1812
1841
|
grad_norm_zero, _ = jax.tree_util.tree_flatten(
|
|
@@ -1825,12 +1854,11 @@ class JaxBackpropPlanner:
|
|
|
1825
1854
|
# numerical error
|
|
1826
1855
|
if not np.isfinite(train_loss):
|
|
1827
1856
|
raise_warning(
|
|
1828
|
-
f'
|
|
1829
|
-
'red')
|
|
1857
|
+
f'JAX planner aborted due to invalid loss {train_loss}.', 'red')
|
|
1830
1858
|
status = JaxPlannerStatus.INVALID_GRADIENT
|
|
1831
1859
|
|
|
1832
1860
|
# evaluate test losses and record best plan so far
|
|
1833
|
-
test_loss, log = self.test_loss(
|
|
1861
|
+
test_loss, (log, model_params_test) = self.test_loss(
|
|
1834
1862
|
subkey, policy_params, policy_hyperparams,
|
|
1835
1863
|
test_subs, model_params_test)
|
|
1836
1864
|
test_loss = rolling_test_loss.update(test_loss)
|
|
@@ -1839,34 +1867,16 @@ class JaxBackpropPlanner:
|
|
|
1839
1867
|
policy_params, test_loss, train_log['grad']
|
|
1840
1868
|
last_iter_improve = it
|
|
1841
1869
|
|
|
1842
|
-
# save the plan figure
|
|
1843
|
-
if plot is not None and it % plot_step == 0:
|
|
1844
|
-
xticks.append(it // plot_step)
|
|
1845
|
-
loss_values.append(test_loss.item())
|
|
1846
|
-
action_values = {name: values
|
|
1847
|
-
for (name, values) in log['fluents'].items()
|
|
1848
|
-
if name in self.rddl.action_fluents}
|
|
1849
|
-
returns = -np.sum(np.asarray(log['reward']), axis=1)
|
|
1850
|
-
plot.redraw(xticks, loss_values, action_values, returns)
|
|
1851
|
-
|
|
1852
|
-
# if the progress bar is used
|
|
1853
|
-
elapsed = time.time() - start_time - elapsed_outside_loop
|
|
1854
|
-
if print_progress:
|
|
1855
|
-
iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
|
|
1856
|
-
iters.set_description(
|
|
1857
|
-
f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
|
|
1858
|
-
f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
|
|
1859
|
-
f'{status.value} status')
|
|
1860
|
-
|
|
1861
1870
|
# reached computation budget
|
|
1871
|
+
elapsed = time.time() - start_time - elapsed_outside_loop
|
|
1862
1872
|
if elapsed >= train_seconds:
|
|
1863
1873
|
status = JaxPlannerStatus.TIME_BUDGET_REACHED
|
|
1864
1874
|
if it >= epochs - 1:
|
|
1865
1875
|
status = JaxPlannerStatus.ITER_BUDGET_REACHED
|
|
1866
1876
|
|
|
1867
|
-
#
|
|
1868
|
-
|
|
1869
|
-
|
|
1877
|
+
# build a callback
|
|
1878
|
+
progress_percent = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
|
|
1879
|
+
callback = {
|
|
1870
1880
|
'status': status,
|
|
1871
1881
|
'iteration': it,
|
|
1872
1882
|
'train_return':-train_loss,
|
|
@@ -1880,19 +1890,45 @@ class JaxBackpropPlanner:
                 'updates': train_log['updates'],
                 'elapsed_time': elapsed,
                 'key': key,
+                'model_params': model_params,
+                'progress': progress_percent,
+                'train_log': train_log,
                 **log
             }
+            
+            # stopping condition reached
+            if stopping_rule is not None and stopping_rule.monitor(callback):
+                callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
+            
+            # if the progress bar is used
+            if print_progress:
+                iters.n = progress_percent
+                iters.set_description(
+                    f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
+                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
+                    f'{status.value} status'
+                )
+            
+            # dash-board
+            if dashboard is not None:
+                dashboard.update_experiment(dashboard_id, callback)
+            
+            # yield the callback
+            start_time_outside = time.time()
+            yield callback
             elapsed_outside_loop += (time.time() - start_time_outside)
             
             # abortion check
-            if status.
-                break
-            
+            if status.is_terminal():
+                break
+        
+        # ======================================================================
+        # POST-PROCESSING AND CLEANUP
+        # ======================================================================
+        
         # release resources
         if print_progress:
             iters.close()
-        if plot is not None:
-            plot.close()
         
         # validate the test return
         if log:
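
This hunk reworks the inner loop around a callback dictionary: the loop yields the callback each iteration, forwards it to an optional dashboard, and consults an optional `stopping_rule` whose `monitor` hook can end training early. A hedged sketch of a compatible stopping rule and consumer; the class name, the `reset` hook, and the `optimize_generator` spelling are assumptions based only on the calls visible in this diff:

class NoImprovementStoppingRule:
    '''Stops when the train return has not improved for `patience` iterations.'''

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.reset()

    def reset(self) -> None:
        self.best = -float('inf')
        self.waiting = 0

    def monitor(self, callback) -> bool:
        # called once per iteration with the callback dict built above
        if callback['train_return'] > self.best:
            self.best = callback['train_return']
            self.waiting = 0
        else:
            self.waiting += 1
        return self.waiting >= self.patience

# consume the yielded callbacks one iteration at a time
for callback in planner.optimize_generator(
        key=key, stopping_rule=NoImprovementStoppingRule(patience=50)):
    print(f"iteration {callback['iteration']}: "
          f"train return {callback['train_return']:.4f}")
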
@@ -1918,7 +1954,7 @@ class JaxBackpropPlanner:
               f'    best_grad_norm={grad_norm}\n'
               f'    diagnosis: {diagnosis}\n')
         
-    def _perform_diagnosis(self, last_iter_improve, 
+    def _perform_diagnosis(self, last_iter_improve,
                            train_return, test_return, best_return, grad_norm):
         max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
         grad_is_zero = np.allclose(max_grad_norm, 0)
@@ -2016,101 +2052,6 @@ class JaxBackpropPlanner:
         return actions
     
     
-class JaxLineSearchPlanner(JaxBackpropPlanner):
-    '''A class for optimizing an action sequence in the given RDDL MDP using
-    linear search gradient descent, with the Armijo condition.'''
-    
-    def __init__(self, *args,
-                 decay: float=0.8,
-                 c: float=0.1,
-                 step_max: float=1.0,
-                 step_min: float=1e-6,
-                 **kwargs) -> None:
-        '''Creates a new gradient-based algorithm for optimizing action sequences
-        (plan) in the given RDDL using line search. All arguments are the
-        same as in the parent class, except:
-        
-        :param decay: reduction factor of learning rate per line search iteration
-        :param c: positive coefficient in Armijo condition, should be in (0, 1)
-        :param step_max: initial learning rate for line search
-        :param step_min: minimum possible learning rate (line search halts)
-        '''
-        self.decay = decay
-        self.c = c
-        self.step_max = step_max
-        self.step_min = step_min
-        if 'clip_grad' in kwargs:
-            raise_warning('clip_grad parameter conflicts with '
-                          'line search planner and will be ignored.', 'red')
-            del kwargs['clip_grad']
-        super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
-    
-    def summarize_hyperparameters(self) -> None:
-        super(JaxLineSearchPlanner, self).summarize_hyperparameters()
-        print(f'linesearch hyper-parameters:\n'
-              f'    decay   ={self.decay}\n'
-              f'    c       ={self.c}\n'
-              f'    lr_range=({self.step_min}, {self.step_max})')
-    
-    def _jax_update(self, loss):
-        optimizer = self.optimizer
-        projection = self.plan.projection
-        decay, c, lrmax, lrmin = self.decay, self.c, self.step_max, self.step_min
-        
-        # initialize the line search routine
-        @jax.jit
-        def _jax_wrapped_line_search_init(key, policy_params, hyperparams,
-                                          subs, model_params):
-            (f, log), grad = jax.value_and_grad(loss, argnums=1, has_aux=True)(
-                key, policy_params, hyperparams, subs, model_params)
-            gnorm2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), grad)
-            gnorm2 = jax.tree_util.tree_reduce(jnp.add, gnorm2)
-            log['grad'] = grad
-            return f, grad, gnorm2, log
-        
-        # compute the next trial solution
-        @jax.jit
-        def _jax_wrapped_line_search_trial(
-                step, grad, key, params, hparams, subs, mparams, state):
-            state.hyperparams['learning_rate'] = step
-            updates, new_state = optimizer.update(grad, state)
-            new_params = optax.apply_updates(params, updates)
-            new_params, _ = projection(new_params, hparams)
-            f_step, _ = loss(key, new_params, hparams, subs, mparams)
-            return f_step, new_params, new_state
-        
-        # main iteration of line search
-        def _jax_wrapped_plan_update(key, policy_params, hyperparams,
-                                     subs, model_params, opt_state, opt_aux):
-            
-            # initialize the line search
-            f, grad, gnorm2, log = _jax_wrapped_line_search_init(
-                key, policy_params, hyperparams, subs, model_params)
-            
-            # continue to reduce the learning rate until the Armijo condition holds
-            trials = 0
-            step = lrmax / decay
-            f_step = np.inf
-            best_f, best_step, best_params, best_state = np.inf, None, None, None
-            while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
-            or not trials:
-                trials += 1
-                step *= decay
-                f_step, new_params, new_state = _jax_wrapped_line_search_trial(
-                    step, grad, key, policy_params, hyperparams, subs,
-                    model_params, opt_state)
-                if f_step < best_f:
-                    best_f, best_step, best_params, best_state = \
-                        f_step, step, new_params, new_state
-            
-            log['updates'] = None
-            log['line_search_iters'] = trials
-            log['learning_rate'] = best_step
-            return best_params, True, best_state, best_step, best_f, log
-        
-        return _jax_wrapped_plan_update
-    
-
 # ***********************************************************************
 # ALL VERSIONS OF RISK FUNCTIONS
 #
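
The removed `JaxLineSearchPlanner` implemented Armijo backtracking: starting from `step_max`, the learning rate is repeatedly multiplied by `decay` until the trial loss satisfies `f_step <= f - c * step * ||grad||^2` or the rate falls below `step_min`, keeping the best trial seen. A self-contained NumPy sketch of that rule on a plain function (not the class's JAX machinery):

import numpy as np

def armijo_backtracking(f, grad_f, x, step_max=1.0, step_min=1e-6,
                        decay=0.8, c=0.1):
    # shrink the step until the Armijo sufficient-decrease condition holds:
    # f(x - step * g) <= f(x) - c * step * ||g||^2
    fx, g = f(x), grad_f(x)
    gnorm2 = float(np.sum(g * g))
    step, best = step_max, (np.inf, None, None)
    while step >= step_min:
        x_new = x - step * g
        f_new = f(x_new)
        if f_new < best[0]:
            best = (f_new, step, x_new)
        if f_new <= fx - c * step * gnorm2:
            break
        step *= decay
    return best  # (best loss, accepted step, new iterate)

# example on a quadratic: the condition holds at step = 0.8 after one shrink
f = lambda x: float(np.sum(x ** 2))
grad_f = lambda x: 2 * x
loss, step, x_new = armijo_backtracking(f, grad_f, np.array([1.0, -2.0]))
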
@@ -2141,7 +2082,7 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
     alpha_mask = jax.lax.stop_gradient(
         returns <= jnp.percentile(returns, q=100 * alpha))
     return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
-    
+

 
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
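
`cvar_utility` (context above) averages the worst `alpha`-fraction of sampled returns; the `stop_gradient` keeps the selection mask out of differentiation while gradients still flow through the selected returns. A small worked example with made-up numbers:

import jax.numpy as jnp

returns = jnp.array([-10.0, -2.0, 0.0, 3.0, 5.0])
alpha = 0.4
# jnp.percentile(returns, 40.0) interpolates to -0.8, so the mask keeps the
# two worst returns and the CVaR estimate is (-10 - 2) / 2 = -6.0
threshold = jnp.percentile(returns, q=100 * alpha)
mask = returns <= threshold
cvar = jnp.sum(returns * mask) / jnp.sum(mask)   # -> -6.0
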
@@ -2151,12 +2092,13 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
 #
 # ***********************************************************************
 
+
 class JaxOfflineController(BaseAgent):
     '''A container class for a Jax policy trained offline.'''
     
     use_tensor_obs = True
     
-    def __init__(self, planner: JaxBackpropPlanner, 
+    def __init__(self, planner: JaxBackpropPlanner,
                  key: Optional[random.PRNGKey]=None,
                  eval_hyperparams: Optional[Dict[str, Any]]=None,
                  params: Optional[Pytree]=None,
@@ -2186,8 +2128,10 @@ class JaxOfflineController(BaseAgent):
         self.params_given = params is not None
         
         self.step = 0
+        self.callback = None
         if not self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.callback = callback
             params = callback['best_params']
         self.params = params
         
@@ -2202,6 +2146,7 @@ class JaxOfflineController(BaseAgent):
         self.step = 0
         if self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.callback = callback
             self.params = callback['best_params']
     
 
@@ -2211,7 +2156,7 @@ class JaxOnlineController(BaseAgent):
     
     use_tensor_obs = True
     
-    def __init__(self, planner: JaxBackpropPlanner, 
+    def __init__(self, planner: JaxBackpropPlanner,
                  key: Optional[random.PRNGKey]=None,
                  eval_hyperparams: Optional[Dict[str, Any]]=None,
                  warm_start: bool=True,
@@ -2244,7 +2189,9 @@ class JaxOnlineController(BaseAgent):
             key=self.key,
             guess=self.guess,
             subs=state,
-            **self.train_kwargs)
+            **self.train_kwargs
+        )
+        self.callback = callback
         params = callback['best_params']
         self.key, subkey = random.split(self.key)
         actions = planner.get_action(
@@ -2255,4 +2202,5 @@ class JaxOnlineController(BaseAgent):
     
     def reset(self) -> None:
         self.guess = None
+        self.callback = None
 
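
With these changes, both controllers keep the most recent training callback on `self.callback`, so optimizer diagnostics remain inspectable after training. A hedged usage sketch: the domain and instance names and the `evaluate` call come from the wider pyRDDLGym API and are assumptions here, as is the pre-configured `planner` instance:

import pyRDDLGym
from jax import random

env = pyRDDLGym.make('Cartpole_Continuous_gym', '0', vectorized=True)
controller = JaxOfflineController(planner, key=random.PRNGKey(42))

# the offline controller trains at construction, so its callback is already set
print(controller.callback['iteration'], controller.callback['train_return'])

controller.evaluate(env, episodes=1)
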