pyRDDLGym-jax 2.8__py3-none-any.whl → 3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +1080 -906
- pyRDDLGym_jax/core/logic.py +1537 -1369
- pyRDDLGym_jax/core/model.py +75 -86
- pyRDDLGym_jax/core/planner.py +883 -935
- pyRDDLGym_jax/core/simulator.py +20 -17
- pyRDDLGym_jax/core/tuning.py +11 -7
- pyRDDLGym_jax/core/visualization.py +115 -78
- pyRDDLGym_jax/entry_point.py +2 -1
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
- pyRDDLGym_jax/examples/run_plan.py +2 -2
- pyRDDLGym_jax/examples/run_tune.py +2 -2
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
- pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/run_gradient.py +0 -102
- pyrddlgym_jax-2.8.dist-info/RECORD +0 -50
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -43,8 +43,7 @@ import pickle
 import sys
 import time
 import traceback
-from typing import Any, Callable, Dict, Generator, Optional,
-    Union
+from typing import Any, Callable, Dict, Generator, Optional, Sequence, Type, Tuple, Union
 
 import haiku as hk
 import jax
@@ -70,16 +69,18 @@ from pyRDDLGym.core.debug.exception import (
 from pyRDDLGym.core.policy import BaseAgent
 
 from pyRDDLGym_jax import __version__
-from pyRDDLGym_jax.core import logic
 from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
-from pyRDDLGym_jax.core
+from pyRDDLGym_jax.core import logic
+from pyRDDLGym_jax.core.logic import (
+    JaxRDDLCompilerWithGrad, DefaultJaxRDDLCompilerWithGrad, stable_sigmoid
+)
 
 # try to load the dash board
 try:
     from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
 except Exception:
     raise_warning('Failed to load the dashboard visualization tool: '
-                  '
+                  'ensure all prerequisite packages are installed.', 'red')
     traceback.print_exc()
     JaxPlannerDashboard = None
 
@@ -128,33 +129,15 @@ def _getattr_any(packages, item):
 
 
 def _load_config(config, args):
-
-    planner_args = {k: args['
-    train_args = {k: args['
-
-    # read the model settings
-    logic_name = model_args.get('logic', 'FuzzyLogic')
-    logic_kwargs = model_args.get('logic_kwargs', {})
-    if logic_name == 'FuzzyLogic':
-        tnorm_name = model_args.get('tnorm', 'ProductTNorm')
-        tnorm_kwargs = model_args.get('tnorm_kwargs', {})
-        comp_name = model_args.get('complement', 'StandardComplement')
-        comp_kwargs = model_args.get('complement_kwargs', {})
-        compare_name = model_args.get('comparison', 'SigmoidComparison')
-        compare_kwargs = model_args.get('comparison_kwargs', {})
-        sampling_name = model_args.get('sampling', 'SoftRandomSampling')
-        sampling_kwargs = model_args.get('sampling_kwargs', {})
-        rounding_name = model_args.get('rounding', 'SoftRounding')
-        rounding_kwargs = model_args.get('rounding_kwargs', {})
-        control_name = model_args.get('control', 'SoftControlFlow')
-        control_kwargs = model_args.get('control_kwargs', {})
-        logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
-        logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
-        logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
-        logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
-        logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
-        logic_kwargs['control'] = getattr(logic, control_name)(**control_kwargs)
+    compiler_kwargs = {k: args['Compiler'][k] for (k, _) in config.items('Compiler')}
+    planner_args = {k: args['Planner'][k] for (k, _) in config.items('Planner')}
+    train_args = {k: args['Optimize'][k] for (k, _) in config.items('Optimize')}
 
+    # read the compiler settings
+    compiler_name = compiler_kwargs.pop('method', 'DefaultJaxRDDLCompilerWithGrad')
+    planner_args['compiler'] = getattr(logic, compiler_name)
+    planner_args['compiler_kwargs'] = compiler_kwargs
+
     # read the policy settings
     plan_method = planner_args.pop('method')
     plan_kwargs = planner_args.pop('method_kwargs', {})
@@ -183,7 +166,6 @@ def _load_config(config, args):
     plan_kwargs['activation'] = activation
 
     # read the planner settings
-    planner_args['logic'] = getattr(logic, logic_name)(**logic_kwargs)
     planner_args['plan'] = getattr(sys.modules[__name__], plan_method)(**plan_kwargs)
 
     # planner optimizer
@@ -220,11 +202,11 @@ def _load_config(config, args):
     train_args['key'] = random.PRNGKey(planner_key)
 
     # dashboard
-    dashboard_key =
+    dashboard_key = planner_args.get('dashboard', None)
    if dashboard_key is not None and dashboard_key and JaxPlannerDashboard is not None:
-
+        planner_args['dashboard'] = JaxPlannerDashboard()
    elif dashboard_key is not None:
-        del
+        del planner_args['dashboard']
 
    # optimize call stopping rule
    stopping_rule = train_args.get('stopping_rule', None)
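In 3.0 the config loader reads three sections (Compiler, Planner, Optimize) and instantiates the compiler class named by Compiler.method, instead of assembling a FuzzyLogic object from the old model settings. A minimal standalone sketch of feeding a 3.0-style config through the load_config_from_string helper defined in this file (the section and method key names come from the new loader above; the specific values shown are illustrative placeholders, not package defaults):

from pyRDDLGym_jax.core.planner import load_config_from_string

# hypothetical 3.0-style configuration string: only keys visible in the new
# _load_config (Compiler.method, Planner.method, Planner.method_kwargs) are used
CONFIG = """
[Compiler]
method='DefaultJaxRDDLCompilerWithGrad'

[Planner]
method='JaxStraightLinePlan'
method_kwargs={}

[Optimize]
"""

config_tuple = load_config_from_string(CONFIG)   # returns the parsed keyword-argument groups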
@@ -253,102 +235,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
     config, args = _parse_config_string(value)
     return _load_config(config, args)
 
-
-# ***********************************************************************
-# MODEL RELAXATIONS
-#
-# - replace discrete ops in state dynamics/reward with differentiable ones
-#
-# ***********************************************************************
-
-
-class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
-    '''Compiles a RDDL AST representation to an equivalent JAX representation.
-    Unlike its parent class, this class treats all fluents as real-valued, and
-    replaces all mathematical operations by equivalent ones with a well defined
-    (e.g. non-zero) gradient where appropriate.
-    '''
-
-    def __init__(self, *args,
-                 logic: Logic=FuzzyLogic(),
-                 cpfs_without_grad: Optional[Set[str]]=None,
-                 print_warnings: bool=True,
-                 **kwargs) -> None:
-        '''Creates a new RDDL to Jax compiler, where operations that are not
-        differentiable are converted to approximate forms that have defined gradients.
-
-        :param *args: arguments to pass to base compiler
-        :param logic: Fuzzy logic object that specifies how exact operations
-        are converted to their approximate forms: this class may be subclassed
-        to customize these operations
-        :param cpfs_without_grad: which CPFs do not have gradients (use straight
-        through gradient trick)
-        :param print_warnings: whether to print warnings
-        :param *kwargs: keyword arguments to pass to base compiler
-        '''
-        super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
-
-        self.logic = logic
-        self.logic.set_use64bit(self.use64bit)
-        if cpfs_without_grad is None:
-            cpfs_without_grad = set()
-        self.cpfs_without_grad = cpfs_without_grad
-        self.print_warnings = print_warnings
-
-        # actions and CPFs must be continuous
-        pvars_cast = set()
-        for (var, values) in self.init_values.items():
-            self.init_values[var] = np.asarray(values, dtype=self.REAL)
-            if not np.issubdtype(np.result_type(values), np.floating):
-                pvars_cast.add(var)
-        if self.print_warnings and pvars_cast:
-            message = termcolor.colored(
-                f'[INFO] JAX gradient compiler will cast p-vars {pvars_cast} to float.',
-                'green')
-            print(message)
-
-        # overwrite basic operations with fuzzy ones
-        self.OPS = logic.get_operator_dicts()
-
-    def _jax_stop_grad(self, jax_expr):
-        def _jax_wrapped_stop_grad(x, params, key):
-            sample, key, error, params = jax_expr(x, params, key)
-            sample = jax.lax.stop_gradient(sample)
-            return sample, key, error, params
-        return _jax_wrapped_stop_grad
-
-    def _compile_cpfs(self, init_params):
-
-        # cpfs will all be cast to float
-        cpfs_cast = set()
-        jax_cpfs = {}
-        for (_, cpfs) in self.levels.items():
-            for cpf in cpfs:
-                _, expr = self.rddl.cpfs[cpf]
-                jax_cpfs[cpf] = self._jax(expr, init_params, dtype=self.REAL)
-                if self.rddl.variable_ranges[cpf] != 'real':
-                    cpfs_cast.add(cpf)
-                if cpf in self.cpfs_without_grad:
-                    jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
-
-        if self.print_warnings and cpfs_cast:
-            message = termcolor.colored(
-                f'[INFO] JAX gradient compiler will cast CPFs {cpfs_cast} to float.',
-                'green')
-            print(message)
-        if self.print_warnings and self.cpfs_without_grad:
-            message = termcolor.colored(
-                f'[INFO] Gradients will not flow through CPFs {self.cpfs_without_grad}.',
-                'green')
-            print(message)
-
-        return jax_cpfs
-
-    def _jax_kron(self, expr, init_params):
-        arg, = expr.args
-        arg = self._jax(arg, init_params)
-        return arg
-
 
 # ***********************************************************************
 # ALL VERSIONS OF STATE PREPROCESSING FOR DRP
@@ -369,15 +255,15 @@ class Preprocessor(metaclass=ABCMeta):
         self._transform = None
 
     @property
-    def initialize(self):
+    def initialize(self) -> Callable:
         return self._initializer
 
     @property
-    def update(self):
+    def update(self) -> Callable:
         return self._update
 
     @property
-    def transform(self):
+    def transform(self) -> Callable:
         return self._transform
 
     @abstractmethod
@@ -421,26 +307,25 @@ class StaticNormalizer(Preprocessor):
         self._initializer = jax.jit(_jax_wrapped_normalizer_init)
 
         # static bounds
-        def _jax_wrapped_normalizer_update(
-
-
-
-            return stats
+        def _jax_wrapped_normalizer_update(fls, stats):
+            return {var: (jnp.asarray(lower, dtype=compiled.REAL),
+                          jnp.asarray(upper, dtype=compiled.REAL))
+                    for (var, (lower, upper)) in bounded_vars.items()}
         self._update = jax.jit(_jax_wrapped_normalizer_update)
 
         # apply min max scaling
-        def _jax_wrapped_normalizer_transform(
-
-            for (var, values) in
+        def _jax_wrapped_normalizer_transform(fls, stats):
+            new_fls = {}
+            for (var, values) in fls.items():
                 if var in stats:
                     lower, upper = stats[var]
                     new_dims = jnp.ndim(values) - jnp.ndim(lower)
                     lower = lower[(jnp.newaxis,) * new_dims + (...,)]
                     upper = upper[(jnp.newaxis,) * new_dims + (...,)]
-
+                    new_fls[var] = (values - lower) / (upper - lower)
                 else:
-
-            return
+                    new_fls[var] = values
+            return new_fls
         self._transform = jax.jit(_jax_wrapped_normalizer_transform)
 
 
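The rewritten normalizer transform above applies per-fluent min-max scaling, broadcasting the stored (lower, upper) bounds over any extra leading batch dimensions. A small standalone sketch of that scaling rule (the variable names and bound values here are illustrative, not library values):

import jax
import jax.numpy as jnp

def min_max_scale(values, lower, upper):
    # broadcast the (lower, upper) bounds over any extra leading batch dims
    new_dims = jnp.ndim(values) - jnp.ndim(lower)
    lower = lower[(jnp.newaxis,) * new_dims + (...,)]
    upper = upper[(jnp.newaxis,) * new_dims + (...,)]
    return (values - lower) / (upper - lower)

values = jnp.array([[10.0, 0.5], [20.0, 1.0]])   # batch of 2 observations
lower = jnp.array([0.0, 0.0])
upper = jnp.array([20.0, 1.0])
print(jax.jit(min_max_scale)(values, lower, upper))  # every entry scaled into [0, 1]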
@@ -478,40 +363,38 @@ class JaxPlan(metaclass=ABCMeta):
         pass
 
     @property
-    def initializer(self):
+    def initializer(self) -> Callable:
         return self._initializer
 
     @initializer.setter
-    def initializer(self, value):
+    def initializer(self, value: Callable) -> None:
         self._initializer = value
 
     @property
-    def train_policy(self):
+    def train_policy(self) -> Callable:
         return self._train_policy
 
     @train_policy.setter
-    def train_policy(self, value):
+    def train_policy(self, value: Callable) -> None:
         self._train_policy = value
 
     @property
-    def test_policy(self):
+    def test_policy(self) -> Callable:
         return self._test_policy
 
     @test_policy.setter
-    def test_policy(self, value):
+    def test_policy(self, value: Callable) -> None:
         self._test_policy = value
 
     @property
-    def projection(self):
+    def projection(self) -> Callable:
         return self._projection
 
     @projection.setter
-    def projection(self, value):
+    def projection(self, value: Callable) -> None:
         self._projection = value
 
-    def _calculate_action_info(self, compiled
-                               user_bounds: Bounds,
-                               horizon: int):
+    def _calculate_action_info(self, compiled, user_bounds, horizon):
         shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
         for (name, prange) in compiled.rddl.variable_ranges.items():
             if compiled.rddl.variable_types[name] != 'action-fluent':
@@ -522,7 +405,8 @@ class JaxPlan(metaclass=ABCMeta):
                 keys = list(compiled.JAX_TYPES.keys()) + list(compiled.rddl.enum_types)
                 raise RDDLTypeError(
                     f'Invalid range <{prange}> of action-fluent <{name}>, '
-                    f'must be one of {keys}.'
+                    f'must be one of {keys}.'
+                )
 
             # clip boolean to (0, 1), otherwise use the RDDL action bounds
             # or the user defined action bounds if provided
@@ -530,15 +414,22 @@ class JaxPlan(metaclass=ABCMeta):
             if prange == 'bool':
                 lower, upper = None, None
             else:
+
+                # enum values are ordered from 0 to number of objects - 1
                 if prange in compiled.rddl.enum_types:
                     lower = np.zeros(shape=shapes[name][1:])
                     upper = len(compiled.rddl.type_to_objects[prange]) - 1
                     upper = np.ones(shape=shapes[name][1:]) * upper
                 else:
                     lower, upper = compiled.constraints.bounds[name]
+
+                # override with user defined bounds
                 lower, upper = user_bounds.get(name, (lower, upper))
                 lower = np.asarray(lower, dtype=compiled.REAL)
                 upper = np.asarray(upper, dtype=compiled.REAL)
+
+                # get masks for a jax conditional statement to avoid numerical errors
+                # for infinite values
                 lower_finite = np.isfinite(lower)
                 upper_finite = np.isfinite(upper)
                 bounds_safe[name] = (np.where(lower_finite, lower, 0.0),
@@ -548,21 +439,173 @@ class JaxPlan(metaclass=ABCMeta):
                                      ~lower_finite & upper_finite,
                                      ~lower_finite & ~upper_finite]
             bounds[name] = (lower, upper)
+
             if compiled.print_warnings:
-
+                print(termcolor.colored(
                     f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
-                    '
-
+                    'dark_grey'
+                ))
         return shapes, bounds, bounds_safe, cond_lists
 
-    def _count_bool_actions(self, rddl
+    def _count_bool_actions(self, rddl):
         constraint = rddl.max_allowed_actions
         num_bool_actions = sum(np.size(values)
                                for (var, values) in rddl.action_fluents.items()
                                if rddl.variable_ranges[var] == 'bool')
         return num_bool_actions, constraint
+
+
+class JaxActionProjection(metaclass=ABCMeta):
+    '''Base of all straight-line plan action projections.'''
+
+    @abstractmethod
+    def compile(self, *args, **kwargs) -> Callable:
+        pass
 
+
+class JaxSortingActionProjection(JaxActionProjection):
+    '''Action projection using sorting method.'''
+
+    def compile(self, ranges: Dict[str, str], noop: Dict[str, Any],
+                wrap_sigmoid: bool, allowed_actions: int, bool_threshold: float,
+                jax_bool_to_box: Callable, *args, **kwargs) -> Callable:
+
+        # shift the boolean actions uniformly, clipping at the min/max values
+        # the amount to move is such that only top allowed_actions actions
+        # are still active (e.g. not equal to noop) after the shift
+        def _jax_wrapped_sorting_project(params, hyperparams):
+
+            # find the amount to shift action parameters: if noop=True reflect parameter
+            scores = []
+            for (var, param) in params.items():
+                if ranges[var] == 'bool':
+                    param_flat = jnp.ravel(param, order='C')
+                    if noop[var]:
+                        if wrap_sigmoid:
+                            param_flat = -param_flat
+                        else:
+                            param_flat = 1.0 - param_flat
+                    scores.append(param_flat)
+            scores = jnp.concatenate(scores)
+            descending = jnp.sort(scores)[::-1]
+            kplus1st_greatest = descending[allowed_actions]
+            surplus = jnp.maximum(kplus1st_greatest - bool_threshold, 0.0)
+
+            # perform the shift
+            new_params = {}
+            for (var, param) in params.items():
+                if ranges[var] == 'bool':
+                    if noop[var]:
+                        new_param = param + surplus
+                    else:
+                        new_param = param - surplus
+                    new_params[var] = jax_bool_to_box(var, new_param, hyperparams)
+                else:
+                    new_params[var] = param
+            converged = jnp.array(True, dtype=jnp.bool_)
+            return new_params, converged
+        return _jax_wrapped_sorting_project
+
+
+class JaxSogbofaActionProjection(JaxActionProjection):
+    '''Action projection using the SOGBOFA method.'''
+
+    def compile(self, ranges: Dict[str, str], noop: Dict[str, Any],
+                allowed_actions: int, max_constraint_iter: int,
+                jax_param_to_action: Callable, jax_action_to_param: Callable,
+                min_action: float, max_action: float, real_dtype: type, *args, **kwargs) -> Callable:
+
+        # calculate the surplus of actions above max-nondef-actions
+        def _jax_wrapped_sogbofa_surplus(actions):
+            sum_action = jnp.array(0.0, dtype=real_dtype)
+            k = jnp.array(0, dtype=jnp.int32)
+            for (var, action) in actions.items():
+                if ranges[var] == 'bool':
+                    if noop[var]:
+                        action = 1 - action
+                    sum_action = sum_action + jnp.sum(action)
+                    k = k + jnp.count_nonzero(action)
+            surplus = jnp.maximum(sum_action - allowed_actions, 0.0)
+            return surplus, k
+
+        # return whether the surplus is positive or reached compute limit
+        def _jax_wrapped_sogbofa_continue(values):
+            it, _, surplus, k = values
+            return jnp.logical_and(
+                it < max_constraint_iter, jnp.logical_and(surplus > 0, k > 0))
+
+        # reduce all bool action values by the surplus clipping at minimum
+        # for no-op = True, do the opposite, i.e. increase all
+        # bool action values by surplus clipping at maximum
+        def _jax_wrapped_sogbofa_subtract_surplus(values):
+            it, actions, surplus, k = values
+            amount = surplus / k
+            new_actions = {}
+            for (var, action) in actions.items():
+                if ranges[var] == 'bool':
+                    if noop[var]:
+                        new_actions[var] = jnp.minimum(action + amount, 1)
+                    else:
+                        new_actions[var] = jnp.maximum(action - amount, 0)
+                else:
+                    new_actions[var] = action
+            new_surplus, new_k = _jax_wrapped_sogbofa_surplus(new_actions)
+            new_it = it + 1
+            return new_it, new_actions, new_surplus, new_k
+
+        # apply the surplus to the actions until it becomes zero
+        def _jax_wrapped_sogbofa_project(params, hyperparams):
+
+            # convert parameters to actions
+            actions = {}
+            for (var, param) in params.items():
+                if ranges[var] == 'bool':
+                    actions[var] = jax_param_to_action(var, param, hyperparams)
+                else:
+                    actions[var] = param
+
+            # run SOGBOFA loop on the actions to get adjusted actions
+            surplus, k = _jax_wrapped_sogbofa_surplus(actions)
+            _, actions, surplus, k = jax.lax.while_loop(
+                cond_fun=_jax_wrapped_sogbofa_continue,
+                body_fun=_jax_wrapped_sogbofa_subtract_surplus,
+                init_val=(0, actions, surplus, k)
+            )
+            converged = jnp.logical_not(surplus > 0)
+
+            # check for any remaining constraint violation
+            total_bool = jnp.array(0, dtype=jnp.int32)
+            for (var, action) in actions.items():
+                if ranges[var] == 'bool':
+                    if noop[var]:
+                        total_bool = total_bool + jnp.count_nonzero(action < 0.5)
+                    else:
+                        total_bool = total_bool + jnp.count_nonzero(action > 0.5)
+            excess = jnp.maximum(total_bool - allowed_actions, 0)
+
+            # convert the adjusted actions back to parameters
+            # reduce the excess number of parameters that are non-noop above constraint
+            new_params = {}
+            for (var, action) in actions.items():
+                if ranges[var] == 'bool':
+                    action = jnp.clip(action, min_action, max_action)
+                    flat_action = jnp.ravel(action, order='C')
+                    if noop[var]:
+                        ranks = jnp.cumsum(flat_action < 0.5)
+                        replace_mask = (flat_action < 0.5) & (ranks <= excess)
+                    else:
+                        ranks = jnp.cumsum(flat_action > 0.5)
+                        replace_mask = (flat_action > 0.5) & (ranks <= excess)
+                    flat_action = jnp.where(replace_mask, 0.5, flat_action)
+                    action = jnp.reshape(flat_action, jnp.shape(action))
+                    new_params[var] = jax_action_to_param(var, action, hyperparams)
+                    excess = jnp.maximum(excess - jnp.count_nonzero(replace_mask), 0)
+                else:
+                    new_params[var] = action
+            return new_params, converged
+        return _jax_wrapped_sogbofa_project
+
+
 class JaxStraightLinePlan(JaxPlan):
     '''A straight line plan implementation in JAX'''
 
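The sorting projection factored out above enforces max-nondef-actions by shifting every boolean action score down by the (k+1)-th largest score, so at most k scores stay above the activation threshold. A standalone sketch of that shift on a toy array (not library code; the threshold 0.0 matches the sigmoid-wrapped case):

import jax.numpy as jnp

def topk_shift(scores, k, threshold=0.0):
    descending = jnp.sort(scores)[::-1]
    surplus = jnp.maximum(descending[k] - threshold, 0.0)
    return scores - surplus

scores = jnp.array([2.0, 0.5, -1.0, 1.5])
print(topk_shift(scores, k=2))   # only the two largest entries remain above 0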
@@ -607,7 +650,7 @@ class JaxStraightLinePlan(JaxPlan):
     def __str__(self) -> str:
         bounds = '\n '.join(
             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
-        return (f'policy hyper-parameters:\n'
+        return (f'[INFO] policy hyper-parameters:\n'
                 f'    initializer={self._initializer_base}\n'
                 f'    constraint-sat strategy (simple):\n'
                 f'        parsed_action_bounds =\n {bounds}\n'
@@ -630,16 +673,7 @@ class JaxStraightLinePlan(JaxPlan):
             compiled, _bounds, horizon)
         self.bounds = bounds
 
-        # action
-        bool_action_count, allowed_actions = self._count_bool_actions(rddl)
-        use_constraint_satisfaction = allowed_actions < bool_action_count
-        if compiled.print_warnings and use_constraint_satisfaction:
-            message = termcolor.colored(
-                f'[INFO] SLP will use projected gradient to satisfy '
-                f'max_nondef_actions since total boolean actions '
-                f'{bool_action_count} > max_nondef_actions {allowed_actions}.', 'green')
-            print(message)
-
+        # get the noop action values
         noop = {var: (values[0] if isinstance(values, list) else values)
                 for (var, values) in rddl.action_fluents.items()}
         bool_key = 'bool__'
@@ -649,14 +683,18 @@ class JaxStraightLinePlan(JaxPlan):
         #
         # ***********************************************************************
 
-        #
+        # boolean actions are parameters wrapped by sigmoid to ensure [0, 1]:
+        #
+        # action = sigmoid(weight * param)
+        #
+        # here weight is a hyper-parameter and param is the trainable policy parameter
         wrap_sigmoid = self._wrap_sigmoid
         bool_threshold = 0.0 if wrap_sigmoid else 0.5
 
         def _jax_bool_param_to_action(var, param, hyperparams):
             if wrap_sigmoid:
                 weight = hyperparams[var]
-                return
+                return stable_sigmoid(weight * param)
             else:
                 return param
 
@@ -666,7 +704,10 @@ class JaxStraightLinePlan(JaxPlan):
                 return jax.scipy.special.logit(action) / weight
             else:
                 return action
-
+
+        # the same technique could be applied to non-bool actions following Bueno et al.
+        # this is disabled by default since the gradient projection trick seems to work
+        # better, especially for one-sided bounds (-inf, B) or (B, +inf)
         wrap_non_bool = self._wrap_non_bool
 
         def _jax_non_bool_param_to_action(var, param, hyperparams):
@@ -675,45 +716,24 @@ class JaxStraightLinePlan(JaxPlan):
                 mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
                                   for mask in cond_lists[var]]
                 action = (
-                    mb * (lower + (upper - lower) *
-                    ml * (lower +
-                    mu * (upper -
+                    mb * (lower + (upper - lower) * stable_sigmoid(param)) +
+                    ml * (lower + jax.nn.softplus(param)) +
+                    mu * (upper - jax.nn.softplus(-param)) +
                     mn * param
                 )
             else:
                 action = param
             return action
-
-        # handle
-
-        max_action = 1.0 - min_action
-
-        def _jax_project_bool_to_box(var, param, hyperparams):
-            lower = _jax_bool_action_to_param(var, min_action, hyperparams)
-            upper = _jax_bool_action_to_param(var, max_action, hyperparams)
-            valid_param = jnp.clip(param, lower, upper)
-            return valid_param
-
+
+        # a different option to handle boolean action concurrency constraints with |A| = 1
+        # is to use a softmax activation layer over pooled action parameters
         ranges = rddl.variable_ranges
-
-        def _jax_wrapped_slp_project_to_box(params, hyperparams):
-            new_params = {}
-            for (var, param) in params.items():
-                if var == bool_key:
-                    new_params[var] = param
-                elif ranges[var] == 'bool':
-                    new_params[var] = _jax_project_bool_to_box(var, param, hyperparams)
-                elif wrap_non_bool:
-                    new_params[var] = param
-                else:
-                    new_params[var] = jnp.clip(param, *bounds[var])
-            return new_params, True
-
-        # convert softmax action back to action dict
         action_sizes = {var: np.prod(shape[1:], dtype=np.int64)
                         for (var, shape) in shapes.items()
                         if ranges[var] == 'bool'}
 
+        # given a softmax output, this simply unpacks the result of the softmax back into
+        # the original action fluent dictionary
         def _jax_unstack_bool_from_softmax(output):
             actions = {}
             start = 0
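The non-boolean parameter-to-action transform above picks one of four cases depending on which bounds are finite: sigmoid for a two-sided box, softplus for a one-sided bound, identity when unbounded. A standalone sketch of the same mapping for concrete scalar bounds (stable_sigmoid is replaced here by jax.nn.sigmoid; the library precomputes finiteness masks so its version also works under jit/vmap):

import jax
import jax.numpy as jnp

def wrap_to_bounds(param, lower, upper):
    if bool(jnp.isfinite(lower)) and bool(jnp.isfinite(upper)):
        return lower + (upper - lower) * jax.nn.sigmoid(param)    # both bounds finite
    elif bool(jnp.isfinite(lower)):
        return lower + jax.nn.softplus(param)                      # lower bound only
    elif bool(jnp.isfinite(upper)):
        return upper - jax.nn.softplus(-param)                     # upper bound only
    else:
        return param                                               # unbounded

print(wrap_to_bounds(jnp.array(3.0), 0.0, 5.0))        # stays inside [0, 5]
print(wrap_to_bounds(jnp.array(-4.0), 2.0, jnp.inf))   # stays above 2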
@@ -723,11 +743,12 @@ class JaxStraightLinePlan(JaxPlan):
                 if noop[name]:
                     action = 1.0 - action
                 actions[name] = action
-                start
+                start = start + size
             return actions
 
-        #
-
+        # the main subroutine to compute the trainable rddl actions from the trainable
+        # parameters (TODO: implement one-hot for integer actions)
+        def _jax_wrapped_slp_predict_train(key, params, hyperparams, step, fls):
             actions = {}
             for (var, param) in params.items():
                 action = jnp.asarray(param[step, ...], dtype=compiled.REAL)
@@ -740,9 +761,12 @@ class JaxStraightLinePlan(JaxPlan):
                 else:
                     actions[var] = _jax_non_bool_param_to_action(var, action, hyperparams)
             return actions
+        self.train_policy = _jax_wrapped_slp_predict_train
 
-        # test
-
+        # the main subroutine to compute the test rddl actions from the trainable
+        # parameters: the difference here is that actions are converted to their required
+        # types (i.e. bool, int, float)
+        def _jax_wrapped_slp_predict_test(key, params, hyperparams, step, fls):
             actions = {}
             for (var, param) in params.items():
                 action = jnp.asarray(param[step, ...], dtype=compiled.REAL)
@@ -760,8 +784,6 @@ class JaxStraightLinePlan(JaxPlan):
                     action = jnp.asarray(jnp.round(action), dtype=compiled.INT)
                 actions[var] = action
             return actions
-
-        self.train_policy = _jax_wrapped_slp_predict_train
         self.test_policy = _jax_wrapped_slp_predict_test
 
         # ***********************************************************************
@@ -769,148 +791,76 @@ class JaxStraightLinePlan(JaxPlan):
         #
         # ***********************************************************************
 
-        #
+        # if the user wants min/max values for clipping boolean action parameters
+        # this might be a good idea to avoid saturation of action-fluents since the
+        # gradient could vanish as a result
+        min_action = self._min_action_prob
+        max_action = 1.0 - min_action
+
+        def _jax_project_bool_to_box(var, param, hyperparams):
+            lower = _jax_bool_action_to_param(var, min_action, hyperparams)
+            upper = _jax_bool_action_to_param(var, max_action, hyperparams)
+            return jnp.clip(param, lower, upper)
+
+        def _jax_wrapped_slp_project_to_box(params, hyperparams):
+            new_params = {}
+            for (var, param) in params.items():
+                if var == bool_key:
+                    new_params[var] = param
+                elif ranges[var] == 'bool':
+                    new_params[var] = _jax_project_bool_to_box(var, param, hyperparams)
+                elif wrap_non_bool:
+                    new_params[var] = param
+                else:
+                    new_params[var] = jnp.clip(param, *bounds[var])
+            converged = jnp.array(True, dtype=jnp.bool_)
+            return new_params, converged
+
+        # enable constraint satisfaction subroutines during optimization
+        # if there are nontrivial concurrency constraints in the problem description
+        bool_action_count, allowed_actions = self._count_bool_actions(rddl)
+        use_constraint_satisfaction = allowed_actions < bool_action_count
+        if compiled.print_warnings and use_constraint_satisfaction:
+            print(termcolor.colored(
+                f'[INFO] Number of boolean actions {bool_action_count} '
+                f'> max_nondef_actions {allowed_actions}: enabling projected gradient to '
+                f'satisfy constraints on action-fluents.', 'dark_grey'
+            ))
+
+        # use a softmax output activation: only allow one action non-noop for now
         if use_constraint_satisfaction and self._wrap_softmax:
-
-            # only allow one action non-noop for now
             if 1 < allowed_actions < bool_action_count:
                 raise RDDLNotImplementedError(
                     f'SLPs with wrap_softmax currently '
-                    f'do not support max-nondef-actions {allowed_actions} > 1.'
-
-            # potentially apply projection but to non-bool actions only
+                    f'do not support max-nondef-actions {allowed_actions} > 1.'
+                )
             self.projection = _jax_wrapped_slp_project_to_box
 
-        # use new gradient projection method
+        # use new gradient projection method
         elif use_constraint_satisfaction and self._use_new_projection:
-
-
-
-
-
-
-            # find the amount to shift action parameters
-            # if noop is True pretend it is False and reflect the parameter
-            scores = []
-            for (var, param) in params.items():
-                if ranges[var] == 'bool':
-                    param_flat = jnp.ravel(param, order='C')
-                    if noop[var]:
-                        if wrap_sigmoid:
-                            param_flat = -param_flat
-                        else:
-                            param_flat = 1.0 - param_flat
-                    scores.append(param_flat)
-            scores = jnp.concatenate(scores)
-            descending = jnp.sort(scores)[::-1]
-            kplus1st_greatest = descending[allowed_actions]
-            surplus = jnp.maximum(kplus1st_greatest - bool_threshold, 0.0)
-
-            # perform the shift
-            new_params = {}
-            for (var, param) in params.items():
-                if ranges[var] == 'bool':
-                    if noop[var]:
-                        new_param = param + surplus
-                    else:
-                        new_param = param - surplus
-                    new_param = _jax_project_bool_to_box(var, new_param, hyperparams)
-                else:
-                    new_param = param
-                new_params[var] = new_param
-            return new_params, True
-
+            jax_project_fn = JaxSortingActionProjection().compile(
+                ranges, noop, wrap_sigmoid, allowed_actions, bool_threshold,
+                _jax_project_bool_to_box
+            )
+
             # clip actions to valid bounds and satisfy constraint on max actions
             def _jax_wrapped_slp_project_to_max_constraint(params, hyperparams):
                 params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
-
-                    _jax_wrapped_sorting_project, in_axes=(0, None)
-                )(params, hyperparams)
-                return project_over_horizon
-
+                return jax.vmap(jax_project_fn, in_axes=(0, None))(params, hyperparams)
             self.projection = _jax_wrapped_slp_project_to_max_constraint
 
-        # use SOGBOFA projection method
+        # use SOGBOFA projection method
         elif use_constraint_satisfaction and not self._use_new_projection:
+            jax_project_fn = JaxSogbofaActionProjection().compile(
+                ranges, noop, allowed_actions, self._max_constraint_iter,
+                _jax_bool_param_to_action, _jax_bool_action_to_param,
+                min_action, max_action, compiled.REAL
+            )
 
-            # calculate the surplus of actions above max-nondef-actions
-            def _jax_wrapped_sogbofa_surplus(actions):
-                sum_action, k = 0.0, 0
-                for (var, action) in actions.items():
-                    if ranges[var] == 'bool':
-                        if noop[var]:
-                            action = 1 - action
-                        sum_action += jnp.sum(action)
-                        k += jnp.count_nonzero(action)
-                surplus = jnp.maximum(sum_action - allowed_actions, 0.0)
-                return surplus, k
-
-            # return whether the surplus is positive or reached compute limit
-            max_constraint_iter = self._max_constraint_iter
-
-            def _jax_wrapped_sogbofa_continue(values):
-                it, _, surplus, k = values
-                return jnp.logical_and(
-                    it < max_constraint_iter, jnp.logical_and(surplus > 0, k > 0))
-
-            # reduce all bool action values by the surplus clipping at minimum
-            # for no-op = True, do the opposite, i.e. increase all
-            # bool action values by surplus clipping at maximum
-            def _jax_wrapped_sogbofa_subtract_surplus(values):
-                it, actions, surplus, k = values
-                amount = surplus / k
-                new_actions = {}
-                for (var, action) in actions.items():
-                    if ranges[var] == 'bool':
-                        if noop[var]:
-                            new_actions[var] = jnp.minimum(action + amount, 1)
-                        else:
-                            new_actions[var] = jnp.maximum(action - amount, 0)
-                    else:
-                        new_actions[var] = action
-                new_surplus, new_k = _jax_wrapped_sogbofa_surplus(new_actions)
-                new_it = it + 1
-                return new_it, new_actions, new_surplus, new_k
-
-            # apply the surplus to the actions until it becomes zero
-            def _jax_wrapped_sogbofa_project(params, hyperparams):
-
-                # convert parameters to actions
-                actions = {}
-                for (var, param) in params.items():
-                    if ranges[var] == 'bool':
-                        actions[var] = _jax_bool_param_to_action(var, param, hyperparams)
-                    else:
-                        actions[var] = param
-
-                # run SOGBOFA loop on the actions to get adjusted actions
-                surplus, k = _jax_wrapped_sogbofa_surplus(actions)
-                _, actions, surplus, k = jax.lax.while_loop(
-                    cond_fun=_jax_wrapped_sogbofa_continue,
-                    body_fun=_jax_wrapped_sogbofa_subtract_surplus,
-                    init_val=(0, actions, surplus, k)
-                )
-                converged = jnp.logical_not(surplus > 0)
-
-                # convert the adjusted actions back to parameters
-                new_params = {}
-                for (var, action) in actions.items():
-                    if ranges[var] == 'bool':
-                        action = jnp.clip(action, min_action, max_action)
-                        param = _jax_bool_action_to_param(var, action, hyperparams)
-                        new_params[var] = param
-                    else:
-                        new_params[var] = action
-                return new_params, converged
-
             # clip actions to valid bounds and satisfy constraint on max actions
             def _jax_wrapped_slp_project_to_max_constraint(params, hyperparams):
                 params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
-
-                    _jax_wrapped_sogbofa_project, in_axes=(0, None)
-                )(params, hyperparams)
-                return project_over_horizon
-
+                return jax.vmap(jax_project_fn, in_axes=(0, None))(params, hyperparams)
             self.projection = _jax_wrapped_slp_project_to_max_constraint
 
         # just project to box constraints
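Both constraint-satisfaction branches above now end the same way: the compiled single-step projection is mapped over the leading horizon axis of every parameter tensor with jax.vmap(..., in_axes=(0, None)). A standalone sketch of that pattern (the projection body here is a trivial clip, purely illustrative, with a hypothetical action name):

import jax
import jax.numpy as jnp

def project_one_step(params, hyperparams):
    # toy per-step projection: clip every parameter into its box
    new_params = {var: jnp.clip(p, *hyperparams[var]) for (var, p) in params.items()}
    converged = jnp.array(True, dtype=jnp.bool_)
    return new_params, converged

horizon = 5
params = {'release': jnp.linspace(-2.0, 2.0, horizon)}   # leading axis = decision epoch
hyperparams = {'release': (0.0, 1.0)}
projected, ok = jax.vmap(project_one_step, in_axes=(0, None))(params, hyperparams)
print(projected['release'].shape, ok.shape)               # (5,) and (5,)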
@@ -925,34 +875,32 @@ class JaxStraightLinePlan(JaxPlan):
         init = self._initializer
         stack_bool_params = use_constraint_satisfaction and self._wrap_softmax
 
-
+        # use the user required initializer and project actions to feasible range
+        def _jax_wrapped_slp_init(key, hyperparams, fls):
             params = {}
             for (var, shape) in shapes.items():
                 if ranges[var] != 'bool' or not stack_bool_params:
                     key, subkey = random.split(key)
                     param = init(key=subkey, shape=shape, dtype=compiled.REAL)
                     if ranges[var] == 'bool':
-                        param
+                        param = param + bool_threshold
                     params[var] = param
             if stack_bool_params:
                 key, subkey = random.split(key)
                 bool_shape = (horizon, bool_action_count)
-
-                params[bool_key] = bool_param
+                params[bool_key] = init(key=subkey, shape=bool_shape, dtype=compiled.REAL)
             params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
             return params
-
         self.initializer = _jax_wrapped_slp_init
 
     @staticmethod
     @jax.jit
     def _guess_next_epoch(param):
-        # "progress" the plan one step forward and set last action to second-last
         return jnp.append(param[1:, ...], param[-1:, ...], axis=0)
 
     def guess_next_epoch(self, params: Pytree) -> Pytree:
-
-        return jax.tree_util.tree_map(
+        # "progress" the plan one step forward and set last action to second-last
+        return jax.tree_util.tree_map(JaxStraightLinePlan._guess_next_epoch, params)
 
 
 class JaxDeepReactivePolicy(JaxPlan):
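_guess_next_epoch above warm-starts the next planning round by rolling the straight-line plan one step forward and repeating the final action. A standalone sketch of that shift on a toy parameter tensor:

import jax.numpy as jnp

def shift_plan(param):
    # drop the first decision epoch and duplicate the last one
    return jnp.append(param[1:, ...], param[-1:, ...], axis=0)

plan = jnp.array([[0.0], [1.0], [2.0]])   # horizon of 3, one action dimension
print(shift_plan(plan))                    # [[1.], [2.], [2.]]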
@@ -997,7 +945,7 @@ class JaxDeepReactivePolicy(JaxPlan):
     def __str__(self) -> str:
         bounds = '\n '.join(
             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
-        return (f'policy hyper-parameters:\n'
+        return (f'[INFO] policy hyper-parameters:\n'
                 f'    topology     ={self._topology}\n'
                 f'    activation_fn={self._activations[0].__name__}\n'
                 f'    initializer  ={type(self._initializer_base).__name__}\n'
@@ -1021,13 +969,18 @@ class JaxDeepReactivePolicy(JaxPlan):
         shapes = {var: value[1:] for (var, value) in shapes.items()}
         self.bounds = bounds
 
-        #
+        # enable constraint satisfaction subroutines during optimization
+        # if there are nontrivial concurrency constraints in the problem description
+        # only handles the case where |A| = 1 for now, as there is no way to do projection
+        # currently (TODO: fix this)
         bool_action_count, allowed_actions = self._count_bool_actions(rddl)
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
-                f'DRPs currently do not support max-nondef-actions {allowed_actions} > 1.'
+                f'DRPs currently do not support max-nondef-actions {allowed_actions} > 1.'
+            )
         use_constraint_satisfaction = allowed_actions < bool_action_count
-
+
+        # get the noop action values
         noop = {var: (values[0] if isinstance(values, list) else values)
                 for (var, values) in rddl.action_fluents.items()}
         bool_key = 'bool__'
@@ -1036,7 +989,8 @@ class JaxDeepReactivePolicy(JaxPlan):
         # POLICY NETWORK PREDICTION
         #
         # ***********************************************************************
-
+
+        # compute the correct shapes of the output layers based on the action-fluent shape
         ranges = rddl.variable_ranges
         normalize = self._normalize
         normalize_per_layer = self._normalize_per_layer
@@ -1047,14 +1001,15 @@ class JaxDeepReactivePolicy(JaxPlan):
                         for (var, shape) in shapes.items()}
         layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}
 
-        # inputs for the policy network
+        # inputs for the policy network are states for fully observed and obs for POMDPs
         if rddl.observ_fluents:
             observed_vars = rddl.observ_fluents
         else:
             observed_vars = rddl.state_fluents
         input_names = {var: f'{var}'.replace('-', '_') for var in observed_vars}
 
-        # catch if input norm is applied to size 1 tensor
+        # catch if input norm is applied to size 1 tensor:
+        # this leads to incorrect behavior as the input is always "1"
         if normalize:
             non_bool_dims = 0
             for (var, values) in observed_vars.items():
@@ -1062,33 +1017,33 @@ class JaxDeepReactivePolicy(JaxPlan):
                 value_size = np.size(values)
                 if normalize_per_layer and value_size == 1:
                     if compiled.print_warnings:
-
+                        print(termcolor.colored(
                             f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
-                            f'of size 1: setting normalize_per_layer = False.', 'yellow'
-
+                            f'of size 1: setting normalize_per_layer = False.', 'yellow'
+                        ))
                     normalize_per_layer = False
                 non_bool_dims += value_size
             if not normalize_per_layer and non_bool_dims == 1:
                 if compiled.print_warnings:
-
+                    print(termcolor.colored(
                         '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
-                        'setting normalize = False.', 'yellow'
-
+                        'setting normalize = False.', 'yellow'
+                    ))
                 normalize = False
 
-        # convert
-        def _jax_wrapped_policy_input(
+        # convert fluents dictionary into a state vector to feed to the MLP
+        def _jax_wrapped_policy_input(fls, hyperparams):
 
             # optional state preprocessing
             if preprocessor is not None:
                 stats = hyperparams[preprocessor.HYPERPARAMS_KEY]
-
+                fls = preprocessor.transform(fls, stats)
 
             # concatenate all state variables into a single vector
             # optionally apply layer norm to each input tensor
             states_bool, states_non_bool = [], []
             non_bool_dims = 0
-            for (var, value) in
+            for (var, value) in fls.items():
                 if var in observed_vars:
                     state = jnp.ravel(value, order='C')
                     if ranges[var] == 'bool':
@@ -1103,7 +1058,7 @@ class JaxDeepReactivePolicy(JaxPlan):
                         )
                         state = normalizer(state)
                     states_non_bool.append(state)
-                    non_bool_dims
+                    non_bool_dims = non_bool_dims + state.size
             state = jnp.concatenate(states_non_bool + states_bool)
 
             # optionally perform layer normalization on the non-bool inputs
@@ -1119,8 +1074,8 @@ class JaxDeepReactivePolicy(JaxPlan):
             return state
 
         # predict actions from the policy network for current state
-        def _jax_wrapped_policy_network_predict(
-            state = _jax_wrapped_policy_input(
+        def _jax_wrapped_policy_network_predict(fls, hyperparams):
+            state = _jax_wrapped_policy_input(fls, hyperparams)
 
             # feed state vector through hidden layers
             hidden = state
@@ -1139,37 +1094,37 @@ class JaxDeepReactivePolicy(JaxPlan):
                 if not shapes[var]:
                     output = jnp.squeeze(output)
 
-                # project action output to valid box constraints
+                # project action output to valid box constraints following Bueno et. al.
                 if ranges[var] == 'bool':
                     if not use_constraint_satisfaction:
-                        actions[var] =
+                        actions[var] = stable_sigmoid(output)
                 else:
                     if wrap_non_bool:
                         lower, upper = bounds_safe[var]
                         mb, ml, mu, mn = [jnp.asarray(mask, dtype=compiled.REAL)
                                           for mask in cond_lists[var]]
-
-                            mb * (lower + (upper - lower) *
-                            ml * (lower +
-                            mu * (upper -
+                        actions[var] = (
+                            mb * (lower + (upper - lower) * stable_sigmoid(output)) +
+                            ml * (lower + jax.nn.softplus(output)) +
+                            mu * (upper - jax.nn.softplus(-output)) +
                             mn * output
                         )
                     else:
-
-                        actions[var] = action
+                        actions[var] = output
 
-            # for constraint satisfaction wrap bool actions with softmax
+            # for constraint satisfaction wrap bool actions with softmax:
+            # this only works when |A| = 1
             if use_constraint_satisfaction:
                 linear = hk.Linear(bool_action_count, name='output_bool', w_init=init)
-
-                actions[bool_key] = output
-
+                actions[bool_key] = jax.nn.softmax(linear(hidden))
             return actions
 
+        # we need pure JAX functions for the policy network prediction
         predict_fn = hk.transform(_jax_wrapped_policy_network_predict)
         predict_fn = hk.without_apply_rng(predict_fn)
 
-        #
+        # given a softmax output, this simply unpacks the result of the softmax back into
+        # the original action fluent dictionary
         def _jax_unstack_bool_from_softmax(output):
             actions = {}
             start = 0
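The DRP relies on hk.transform plus hk.without_apply_rng to turn the policy-network function into a pair of pure (init, apply) functions, which is what lets the planner treat the network weights as an explicit parameter pytree. A minimal standalone sketch of that Haiku pattern (the toy network shape and names below are illustrative, not the library's):

import haiku as hk
import jax
import jax.numpy as jnp

def _policy(obs):
    return jax.nn.sigmoid(hk.Linear(4)(obs))   # toy output head

predict_fn = hk.without_apply_rng(hk.transform(_policy))
params = predict_fn.init(jax.random.PRNGKey(0), jnp.zeros(8))   # pure initialization
actions = predict_fn.apply(params, jnp.ones(8))                 # pure forward pass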
@@ -1180,12 +1135,13 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
1180
1135
|
if noop[name]:
|
|
1181
1136
|
action = 1.0 - action
|
|
1182
1137
|
actions[name] = action
|
|
1183
|
-
start
|
|
1138
|
+
start = start + size
|
|
1184
1139
|
return actions
|
|
1185
1140
|
|
|
1186
|
-
#
|
|
1187
|
-
|
|
1188
|
-
|
|
1141
|
+
# the main subroutine to compute the trainable rddl actions from the trainable
|
|
1142
|
+
# parameters and the current state/obs dictionary
|
|
1143
|
+
def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, fls):
|
|
1144
|
+
actions = predict_fn.apply(params, fls, hyperparams)
|
|
1189
1145
|
if not wrap_non_bool:
|
|
1190
1146
|
for (var, action) in actions.items():
|
|
1191
1147
|
if var != bool_key and ranges[var] != 'bool':
|
|
@@ -1195,10 +1151,13 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
1195
1151
|
actions.update(bool_actions)
|
|
1196
1152
|
del actions[bool_key]
|
|
1197
1153
|
return actions
|
|
1154
|
+
self.train_policy = _jax_wrapped_drp_predict_train
|
|
1198
1155
|
|
|
1199
|
-
# test
|
|
1200
|
-
|
|
1201
|
-
|
|
1156
|
+
# the main subroutine to compute the test rddl actions from the trainable
|
|
1157
|
+
# parameters and state/obs dict: the difference here is that actions are converted
|
|
1158
|
+
# to their required types (i.e. bool, int, float)
|
|
1159
|
+
def _jax_wrapped_drp_predict_test(key, params, hyperparams, step, fls):
|
|
1160
|
+
actions = _jax_wrapped_drp_predict_train(key, params, hyperparams, step, fls)
|
|
1202
1161
|
new_actions = {}
|
|
1203
1162
|
for (var, action) in actions.items():
|
|
1204
1163
|
prange = ranges[var]
|
|
@@ -1211,8 +1170,6 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
1211
1170
|
new_action = jnp.clip(action, *bounds[var])
|
|
1212
1171
|
new_actions[var] = new_action
|
|
1213
1172
|
return new_actions
|
|
1214
|
-
|
|
1215
|
-
self.train_policy = _jax_wrapped_drp_predict_train
|
|
1216
1173
|
self.test_policy = _jax_wrapped_drp_predict_test
|
|
1217
1174
|
|
|
1218
1175
|
# ***********************************************************************
|
|
@@ -1222,8 +1179,8 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
1222
1179
|
|
|
1223
1180
|
# no projection applied since the actions are already constrained
|
|
1224
1181
|
def _jax_wrapped_drp_no_projection(params, hyperparams):
|
|
1225
|
-
|
|
1226
|
-
|
|
1182
|
+
converged = jnp.array(True, dtype=jnp.bool_)
|
|
1183
|
+
return params, converged
|
|
1227
1184
|
self.projection = _jax_wrapped_drp_no_projection
|
|
1228
1185
|
|
|
1229
1186
|
# ***********************************************************************
|
|
@@ -1231,16 +1188,16 @@ class JaxDeepReactivePolicy(JaxPlan):
|
|
|
1231
1188
|
#
|
|
1232
1189
|
# ***********************************************************************
|
|
1233
1190
|
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
return
|
|
1240
|
-
|
|
1191
|
+
# initialize policy parameters according to user-desired weight initializer
|
|
1192
|
+
def _jax_wrapped_drp_init(key, hyperparams, fls):
|
|
1193
|
+
obs_vars = {var: value[0, ...]
|
|
1194
|
+
for (var, value) in fls.items()
|
|
1195
|
+
if var in observed_vars}
|
|
1196
|
+
return predict_fn.init(key, obs_vars, hyperparams)
|
|
1241
1197
|
self.initializer = _jax_wrapped_drp_init
|
|
1242
1198
|
|
|
1243
1199
|
def guess_next_epoch(self, params: Pytree) -> Pytree:
|
|
1200
|
+
# this is easy: just warm-start from the previously obtained policy
|
|
1244
1201
|
return params
|
|
1245
1202
|
|
|
1246
1203
|
|
|
@@ -1339,17 +1296,16 @@ class PGPE(metaclass=ABCMeta):
|
|
|
1339
1296
|
self._update = None
|
|
1340
1297
|
|
|
1341
1298
|
@property
|
|
1342
|
-
def initialize(self):
|
|
1299
|
+
def initialize(self) -> Callable:
|
|
1343
1300
|
return self._initializer
|
|
1344
1301
|
|
|
1345
1302
|
@property
|
|
1346
|
-
def update(self):
|
|
1303
|
+
def update(self) -> Callable:
|
|
1347
1304
|
return self._update
|
|
1348
1305
|
|
|
1349
1306
|
@abstractmethod
|
|
1350
1307
|
def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
|
|
1351
|
-
print_warnings: bool,
|
|
1352
|
-
parallel_updates: Optional[int]=None) -> None:
|
|
1308
|
+
print_warnings: bool, parallel_updates: int=1) -> None:
|
|
1353
1309
|
pass
|
|
1354
1310
|
|
|
1355
1311
|
|
|
@@ -1414,11 +1370,10 @@ class GaussianPGPE(PGPE):
  mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
  sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
  except Exception as _:
-
- '[
- '
-
- print(message)
+ print(termcolor.colored(
+ '[WARN] Could not inject hyperparameters into PGPE optimizer: '
+ 'kl-divergence constraint will be disabled.', 'yellow'
+ ))
  mu_optimizer = optimizer(**optimizer_kwargs_mu)
  sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
  max_kl_update = None
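Both PGPE optimizers are wrapped with `optax.inject_hyperparams` so that the learning rate stays visible in the optimizer state and can be rescaled later; if wrapping fails, the code falls back to a plain optimizer and disables the KL constraint. A small sketch of the same pattern, with assumed hyper-parameter values:

```python
import optax

kwargs = {'learning_rate': 0.05}
try:
    # exposes the learning rate in opt_state.hyperparams for later rescaling
    mu_optimizer = optax.inject_hyperparams(optax.rmsprop)(**kwargs)
except Exception:
    mu_optimizer = optax.rmsprop(**kwargs)   # hyper-parameters become fixed

state = mu_optimizer.init({'w': 0.0})
print(getattr(state, 'hyperparams', None))   # {'learning_rate': ...} if injection worked
```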
@@ -1426,7 +1381,7 @@ class GaussianPGPE(PGPE):
  self.max_kl = max_kl_update

  def __str__(self) -> str:
- return (f'PGPE hyper-parameters:\n'
+ return (f'[INFO] PGPE hyper-parameters:\n'
  f' method ={self.__class__.__name__}\n'
  f' batch_size ={self.batch_size}\n'
  f' init_sigma ={self.init_sigma}\n'
@@ -1444,9 +1399,11 @@ class GaussianPGPE(PGPE):
  f' max_kl_update ={self.max_kl}\n'
  )

- def compile(self, loss_fn: Callable,
+ def compile(self, loss_fn: Callable,
+ projection: Callable,
+ real_dtype: Type,
  print_warnings: bool,
- parallel_updates:
+ parallel_updates: int=1) -> None:
  sigma0 = self.init_sigma
  sigma_lo, sigma_hi = self.sigma_range
  scale_reward = self.scale_reward
@@ -1458,6 +1415,7 @@ class GaussianPGPE(PGPE):
  max_kl = self.max_kl

  # entropy regularization penalty is decayed exponentially by elapsed budget
+ # this uses the optimizer progress (as percentage) to move the decay
  start_entropy_coeff = self.start_entropy_coeff
  if start_entropy_coeff == 0:
  entropy_coeff_decay = 0
@@ -1469,6 +1427,8 @@ class GaussianPGPE(PGPE):
  #
  # ***********************************************************************

+ # use the default initializer for the (mean, sigma) parameters
+ # these parameters define the sampling distribution over policy parameters
  def _jax_wrapped_pgpe_init(key, policy_params):
  mu = policy_params
  sigma = jax.tree_util.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
@@ -1477,51 +1437,60 @@ class GaussianPGPE(PGPE):
  r_max = -jnp.inf
  return pgpe_params, pgpe_opt_state, r_max

-
-
-
-
-
-
-
- return jax.vmap(_jax_wrapped_pgpe_init, in_axes=0)(keys, policy_params)
-
- self._initializer = jax.jit(_jax_wrapped_pgpe_inits)
+ # for parallel policy update, initialize multiple indepdendent (mean, sigma)
+ # gaussians that will be optimized in parallel
+ def _jax_wrapped_batched_pgpe_init(key, policy_params):
+ keys = random.split(key, num=parallel_updates)
+ return jax.vmap(_jax_wrapped_pgpe_init, in_axes=0)(keys, policy_params)
+
+ self._initializer = jax.jit(_jax_wrapped_batched_pgpe_init)

  # ***********************************************************************
  # PARAMETER SAMPLING FUNCTIONS
  #
  # ***********************************************************************

+ # sample from i.i.d. Normal(0, sigma)
  def _jax_wrapped_mu_noise(key, sigma):
  return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)

  # this samples a noise variable epsilon* from epsilon with the N(0, 1) density
- # according to super-symmetric sampling paper
+ # according to super-symmetric sampling paper:
+ # the paper presents a more accurate formula which is used by default
  def _jax_wrapped_epsilon_star(sigma, epsilon):
- c1, c2, c3 = -0.06655, -0.9706, 0.124
  phi = 0.67449 * sigma
  a = (sigma - jnp.abs(epsilon)) / sigma
+
+ # more accurate formula
  if super_symmetric_accurate:
  aa = jnp.abs(a)
-
-
-
-
-
-
+ atol = 1e-10
+ c1, c2, c3 = -0.06655, -0.9706, 0.124
+ term_neg_log = c1 * (aa * aa - 1.) / jnp.log(aa + atol) + c2
+ term_pos_log = 1. - c3 * jnp.log1p(-aa ** 3 + atol)
+ epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(
+ aa * jnp.where(a <= 0, term_neg_log, term_pos_log))
+
+ # less accurate and simple formula
  else:
  epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
  return epsilon_star

  # implements baseline-free super-symmetric sampling to generate 4 trajectories
+ # this type of sampling removes the need for the baseline completely
  def _jax_wrapped_sample_params(key, mu, sigma):
+
+ # this samples the basic two policy parameters from Gaussian(mean, sigma)
+ # using the control variates
  treedef = jax.tree_util.tree_structure(sigma)
  keys = random.split(key, num=treedef.num_leaves)
  keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
  epsilon = jax.tree_util.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
  p1 = jax.tree_util.tree_map(jnp.add, mu, epsilon)
  p2 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon)
+
+ # sumer-symmetric sampling removes the need for a baseline but requires
+ # two additional policies to be sampled
  if super_symmetric:
  epsilon_star = jax.tree_util.tree_map(
  _jax_wrapped_epsilon_star, sigma, epsilon)
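A self-contained sketch of the mirrored-noise ("epsilon star") formula introduced in the hunk above, using the same constants; the function and argument names are illustrative, not the library API:

```python
import jax.numpy as jnp

def epsilon_star(sigma, epsilon, accurate=True, atol=1e-10):
    phi = 0.67449 * sigma
    a = (sigma - jnp.abs(epsilon)) / sigma
    if accurate:
        aa = jnp.abs(a)
        c1, c2, c3 = -0.06655, -0.9706, 0.124
        neg = c1 * (aa * aa - 1.) / jnp.log(aa + atol) + c2
        pos = 1. - c3 * jnp.log1p(-aa ** 3 + atol)
        return jnp.sign(epsilon) * phi * jnp.exp(aa * jnp.where(a <= 0, neg, pos))
    return jnp.sign(epsilon) * phi * jnp.exp(a)   # simpler, less accurate variant

eps = jnp.array([0.36, -0.12, 0.015])   # noise drawn from N(0, sigma) with sigma = 0.3
print(epsilon_star(0.3, eps))
```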
@@ -1538,6 +1507,8 @@ class GaussianPGPE(PGPE):

  # gradient with respect to mean
  def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
+
+ # for super symmetric sampling
  if super_symmetric:
  if scale_reward:
  scale1 = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
@@ -1547,6 +1518,8 @@ class GaussianPGPE(PGPE):
  r_mu1 = (r1 - r2) / (2 * scale1)
  r_mu2 = (r3 - r4) / (2 * scale2)
  grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
+
+ # for the basic pgpe
  else:
  if scale_reward:
  scale = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
@@ -1558,6 +1531,8 @@ class GaussianPGPE(PGPE):

  # gradient with respect to std. deviation
  def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
+
+ # for super symmetric sampling
  if super_symmetric:
  mask = r1 + r2 >= r3 + r4
  epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
@@ -1567,6 +1542,8 @@ class GaussianPGPE(PGPE):
  else:
  scale = 1.0
  r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
+
+ # for basic pgpe
  else:
  s = jnp.square(epsilon) / sigma - sigma
  if scale_reward:
@@ -1574,30 +1551,40 @@ class GaussianPGPE(PGPE):
  else:
  scale = 1.0
  r_sigma = (r1 + r2) / (2 * scale)
-
- return
+
+ return -(r_sigma * s + ent / sigma)

  # calculate the policy gradients
  def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
- policy_hyperparams,
+ policy_hyperparams, fls, nfls, model_params):
+
+ # basic pgpe sampling with return estimation
  key, subkey = random.split(key)
  p1, p2, p3, p4, epsilon, epsilon_star = _jax_wrapped_sample_params(
  key, mu, sigma)
- r1 = -loss_fn(subkey, p1, policy_hyperparams,
- r2 = -loss_fn(subkey, p2, policy_hyperparams,
+ r1 = -loss_fn(subkey, p1, policy_hyperparams, fls, nfls, model_params)[0]
+ r2 = -loss_fn(subkey, p2, policy_hyperparams, fls, nfls, model_params)[0]
+
+ # do a return normalization for optimizer stability
  r_max = jnp.maximum(r_max, r1)
  r_max = jnp.maximum(r_max, r2)
+
+ # super symmetric sampling requires two more trajectories and their returns
  if super_symmetric:
- r3 = -loss_fn(subkey, p3, policy_hyperparams,
- r4 = -loss_fn(subkey, p4, policy_hyperparams,
+ r3 = -loss_fn(subkey, p3, policy_hyperparams, fls, nfls, model_params)[0]
+ r4 = -loss_fn(subkey, p4, policy_hyperparams, fls, nfls, model_params)[0]
  r_max = jnp.maximum(r_max, r3)
  r_max = jnp.maximum(r_max, r4)
  else:
- r3, r4 = r1, r2
+ r3, r4 = r1, r2
+
+ # calculate gradient with respect to the mean
  grad_mu = jax.tree_util.tree_map(
  partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
  epsilon, epsilon_star
  )
+
+ # calculate gradient with respect to the sigma
  grad_sigma = jax.tree_util.tree_map(
  partial(_jax_wrapped_sigma_grad,
  r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
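To make the reward-scaled mean-gradient estimate above concrete, here is a tiny numeric walk-through of the super-symmetric case; all values are made up for illustration:

```python
import jax.numpy as jnp

r1, r2, r3, r4, m = 4.0, 2.0, 3.5, 1.0, 5.0        # four rollout returns and the running max
eps, eps_star = jnp.array(0.2), jnp.array(-0.15)    # noise and its mirrored counterpart
min_reward_scale = 1e-6

scale1 = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
scale2 = jnp.maximum(min_reward_scale, m - (r3 + r4) / 2)
r_mu1 = (r1 - r2) / (2 * scale1)
r_mu2 = (r3 - r4) / (2 * scale2)
print(-(r_mu1 * eps + r_mu2 * eps_star))            # descent direction for the mean parameter
```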
@@ -1605,21 +1592,30 @@ class GaussianPGPE(PGPE):
  )
  return grad_mu, grad_sigma, r_max

+ # calculate the policy gradients with batching on the first dimension
  def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max, ent,
- policy_hyperparams,
+ policy_hyperparams, fls, nfls, model_params):
  mu, sigma = pgpe_params
+
+ # no batching required
  if batch_size == 1:
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
- key, mu, sigma, r_max, ent, policy_hyperparams,
+ key, mu, sigma, r_max, ent, policy_hyperparams, fls, nfls, model_params)
+
+ # for batching need to handle how meta-gradients of mean, sigma are aggregated
  else:
+ # do the batched calculation of mean and sigma gradients
  keys = random.split(key, num=batch_size)
  mu_grads, sigma_grads, r_maxs = jax.vmap(
  _jax_wrapped_pgpe_grad,
- in_axes=(0, None, None, None, None, None, None, None)
- )(keys, mu, sigma, r_max, ent, policy_hyperparams,
+ in_axes=(0, None, None, None, None, None, None, None, None)
+ )(keys, mu, sigma, r_max, ent, policy_hyperparams, fls, nfls, model_params)
+
+ # calculate the average gradient for aggregation
  mu_grad, sigma_grad = jax.tree_util.tree_map(
  partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
  new_r_max = jnp.max(r_maxs)
+
  return mu_grad, sigma_grad, new_r_max

  # ***********************************************************************
@@ -1646,17 +1642,16 @@ class GaussianPGPE(PGPE):
  return new_mu, new_sigma, new_mu_state, new_sigma_state

  def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
- policy_hyperparams,
+ policy_hyperparams, fls, nfls, model_params,
  pgpe_opt_state):
- # regular update
+ # regular update for pgpe
  mu, sigma = pgpe_params
  mu_state, sigma_state = pgpe_opt_state
  ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
- key, pgpe_params, r_max, ent, policy_hyperparams,
- new_mu, new_sigma, new_mu_state, new_sigma_state =
-
- mu_state, sigma_state)
+ key, pgpe_params, r_max, ent, policy_hyperparams, fls, nfls, model_params)
+ new_mu, new_sigma, new_mu_state, new_sigma_state = _jax_wrapped_pgpe_update_helper(
+ mu, sigma, mu_grad, sigma_grad, mu_state, sigma_state)

  # respect KL divergence contraint with old parameters
  if max_kl is not None:
@@ -1668,34 +1663,30 @@ class GaussianPGPE(PGPE):
  kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
  mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
  sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
- new_mu, new_sigma, new_mu_state, new_sigma_state =
-
- mu_state, sigma_state)
+ new_mu, new_sigma, new_mu_state, new_sigma_state = _jax_wrapped_pgpe_update_helper(
+ mu, sigma, mu_grad, sigma_grad, mu_state, sigma_state)
  new_mu_state.hyperparams['learning_rate'] = old_mu_lr
  new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr

- # apply projection step
+ # apply projection step to the sampled policy
  new_mu, converged = projection(new_mu, policy_hyperparams)
+
  new_pgpe_params = (new_mu, new_sigma)
  new_pgpe_opt_state = (new_mu_state, new_sigma_state)
  policy_params = new_mu
  return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged

-
-
-
-
-
-
-
-
-
-
-
- )(keys, pgpe_params, r_max, progress, policy_hyperparams, subs,
- model_params, pgpe_opt_state)
-
- self._update = jax.jit(_jax_wrapped_pgpe_updates)
+ # for parallel policy update
+ def _jax_wrapped_batched_pgpe_updates(key, pgpe_params, r_max, progress,
+ policy_hyperparams, fls, nfls, model_params,
+ pgpe_opt_state):
+ keys = random.split(key, num=parallel_updates)
+ return jax.vmap(
+ _jax_wrapped_pgpe_update, in_axes=(0, 0, 0, None, None, None, None, 0, 0)
+ )(keys, pgpe_params, r_max, progress, policy_hyperparams, fls, nfls,
+ model_params, pgpe_opt_state)
+
+ self._update = jax.jit(_jax_wrapped_batched_pgpe_updates)


  # ***********************************************************************
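The KL-constrained branch above shrinks both learning rates through the `hyperparams` dict that `optax.inject_hyperparams` exposes on the optimizer state. A minimal sketch of that mechanism on a toy quadratic loss (the scaling factor is arbitrary):

```python
import jax
import jax.numpy as jnp
import optax

opt = optax.inject_hyperparams(optax.sgd)(learning_rate=0.5)
params = jnp.array(1.0)
state = opt.init(params)

grad = jax.grad(lambda p: p ** 2)(params)
state.hyperparams['learning_rate'] = 0.5 * 0.1      # e.g. scaled down by a KL reduction factor
updates, state = opt.update(grad, state, params)
print(optax.apply_updates(params, updates))         # 1.0 - 0.05 * 2.0 = 0.9
```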
@@ -1757,7 +1748,7 @@ def var_utility(returns: jnp.ndarray, alpha: float) -> float:
  def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
  var = jnp.percentile(returns, q=100 * alpha)
  mask = returns <= var
- return jnp.sum(returns * mask) / jnp.maximum(1, jnp.
+ return jnp.sum(returns * mask) / jnp.maximum(1, jnp.count_nonzero(mask))


  # set of all currently valid built-in utility functions
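A quick numeric check of the `cvar_utility` line above: it averages only the worst `alpha` fraction of the return observations.

```python
import jax.numpy as jnp

returns = jnp.array([-5., -1., 0., 2., 10.])
alpha = 0.4
var = jnp.percentile(returns, q=100 * alpha)   # -0.4 with linear interpolation
mask = returns <= var                          # selects -5 and -1
print(jnp.sum(returns * mask) / jnp.maximum(1, jnp.count_nonzero(mask)))   # -3.0
```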
@@ -1783,6 +1774,11 @@ UTILITY_LOOKUP = {
  # ***********************************************************************


+ @jax.jit
+ def pytree_at(tree: Pytree, i: int) -> Pytree:
+ return jax.tree_util.tree_map(lambda x: x[i], tree)
+
+
  class JaxBackpropPlanner:
  '''A class for optimizing an action sequence in the given RDDL MDP using
  gradient descent.'''
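Usage sketch for the new `pytree_at` helper: it slices the i-th entry out of every leaf of a batched parameter pytree, for example to pull a single policy out of a parallel batch. The example data is made up:

```python
import jax
import jax.numpy as jnp

batched_params = {'w': jnp.arange(6).reshape(3, 2), 'b': jnp.arange(3)}
first = jax.tree_util.tree_map(lambda x: x[0], batched_params)   # same effect as pytree_at(tree, 0)
print(first)   # {'w': [0, 1], 'b': 0}
```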
@@ -1792,30 +1788,29 @@ class JaxBackpropPlanner:
  batch_size_train: int=32,
  batch_size_test: Optional[int]=None,
  rollout_horizon: Optional[int]=None,
-
+ parallel_updates: int=1,
  action_bounds: Optional[Bounds]=None,
  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
  optimizer_kwargs: Optional[Kwargs]=None,
  clip_grad: Optional[float]=None,
  line_search_kwargs: Optional[Kwargs]=None,
  noise_kwargs: Optional[Kwargs]=None,
+ ema_decay: Optional[float]=None,
  pgpe: Optional[PGPE]=GaussianPGPE(),
-
+ compiler: JaxRDDLCompilerWithGrad=DefaultJaxRDDLCompilerWithGrad,
+ compiler_kwargs: Optional[Kwargs]=None,
  use_symlog_reward: bool=False,
  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
  utility_kwargs: Optional[Kwargs]=None,
- cpfs_without_grad: Optional[Set[str]]=None,
- compile_non_fluent_exact: bool=True,
  logger: Optional[Logger]=None,
+ dashboard: Optional[Any]=None,
  dashboard_viz: Optional[Any]=None,
- print_warnings: bool=True,
- parallel_updates: Optional[int]=None,
  preprocessor: Optional[Preprocessor]=None,
  python_functions: Optional[Dict[str, Callable]]=None) -> None:
  '''Creates a new gradient-based algorithm for optimizing action sequences
  (plan) in the given RDDL. Some operations will be converted to their
  differentiable counterparts; the specific operations can be customized
- by providing a
+ by providing a tailored compiler instance.

  :param rddl: the RDDL domain to optimize
  :param plan: the policy/plan representation to optimize
@@ -1823,9 +1818,8 @@ class JaxBackpropPlanner:
  step
  :param batch_size_test: how many rollouts to use to test the plan at each
  optimization step
- :param rollout_horizon: lookahead planning horizon: None uses the
- :param
- horizon parameter in the RDDL instance
+ :param rollout_horizon: lookahead planning horizon: None uses the env horizon
+ :param parallel_updates: how many optimizers to run independently in parallel
  :param action_bounds: box constraints on actions
  :param optimizer: a factory for an optax SGD algorithm
  :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
@@ -1834,9 +1828,10 @@ class JaxBackpropPlanner:
  :param line_search_kwargs: parameters to pass to optional line search
  method to scale learning rate
  :param noise_kwargs: parameters of optional gradient noise
+ :param ema_decay: optional exponential moving average of past parameters
  :param pgpe: optional policy gradient to run alongside the planner
- :param
-
+ :param compiler: compiler instance to use for planning
+ :param compiler_kwargs: compiler instances kwargs for initialization
  :param use_symlog_reward: whether to use the symlog transform on the
  reward as a form of normalization
  :param utility: how to aggregate return observations to compute utility
@@ -1844,15 +1839,10 @@ class JaxBackpropPlanner:
  scalar, or a a string identifying the utility function by name
  :param utility_kwargs: additional keyword arguments to pass hyper-
  parameters to the utility function call
- :param cpfs_without_grad: which CPFs do not have gradients (use straight
- through gradient trick)
- :param compile_non_fluent_exact: whether non-fluent expressions
- are always compiled using exact JAX expressions
  :param logger: to log information about compilation to file
+ :param dashboard: optional dashboard to display training progress and results
  :param dashboard_viz: optional visualizer object from the environment
  to pass to the dashboard to visualize the policy
- :param print_warnings: whether to print warnings
- :param parallel_updates: how many optimizers to run independently in parallel
  :param preprocessor: optional preprocessor for state inputs to plan
  :param python_functions: dictionary of external Python functions to call from RDDL
  '''
@@ -1869,7 +1859,6 @@ class JaxBackpropPlanner:
  if action_bounds is None:
  action_bounds = {}
  self._action_bounds = action_bounds
- self.use64bit = use64bit
  self.optimizer_name = optimizer
  if optimizer_kwargs is None:
  optimizer_kwargs = {'learning_rate': 0.1}
@@ -1877,9 +1866,9 @@ class JaxBackpropPlanner:
  self.clip_grad = clip_grad
  self.line_search_kwargs = line_search_kwargs
  self.noise_kwargs = noise_kwargs
+ self.ema_decay = ema_decay
  self.pgpe = pgpe
  self.use_pgpe = pgpe is not None
- self.print_warnings = print_warnings
  self.preprocessor = preprocessor
  if python_functions is None:
  python_functions = {}
@@ -1889,11 +1878,10 @@ class JaxBackpropPlanner:
  try:
  optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
  except Exception as _:
-
- '[
- '
-
- print(message)
+ print(termcolor.colored(
+ '[WARN] Could not inject hyperparameters into JaxPlan optimizer: '
+ 'runtime modification of hyperparameters will be disabled.', 'yellow'
+ ))
  optimizer = optimizer(**optimizer_kwargs)

  # apply optimizer chain of transformations
@@ -1905,6 +1893,8 @@ class JaxBackpropPlanner:
  pipeline.append(optimizer)
  if line_search_kwargs is not None:
  pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))
+ if ema_decay is not None:
+ pipeline.append(optax.ema(ema_decay))
  self.optimizer = optax.chain(*pipeline)

  # set utility
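The new `ema_decay` option simply appends `optax.ema` to the optimizer pipeline before chaining. A sketch of that pipeline construction with arbitrary settings (the real pipeline may use a different clipping transform):

```python
import jax.numpy as jnp
import optax

clip_grad, ema_decay = 1.0, 0.99
pipeline = [optax.clip(clip_grad), optax.rmsprop(learning_rate=0.1)]
if ema_decay is not None:
    pipeline.append(optax.ema(ema_decay))   # exponential moving average of the updates
optimizer = optax.chain(*pipeline)
print(optimizer.init(jnp.zeros(3)) is not None)   # True
```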
@@ -1914,99 +1904,75 @@ r"""
  if utility_fn is None:
  raise RDDLNotImplementedError(
  f'Utility function <{utility}> is not supported, '
- f'must be one of {list(UTILITY_LOOKUP.keys())}.'
+ f'must be one of {list(UTILITY_LOOKUP.keys())}.'
+ )
  else:
  utility_fn = utility
  self.utility = utility_fn
-
  if utility_kwargs is None:
  utility_kwargs = {}
  self.utility_kwargs = utility_kwargs

-
-
+ if compiler_kwargs is None:
+ compiler_kwargs = {}
+ self.compiler_type = compiler
+ self.compiler_kwargs = compiler_kwargs
  self.use_symlog_reward = use_symlog_reward
-
- cpfs_without_grad = set()
- self.cpfs_without_grad = cpfs_without_grad
- self.compile_non_fluent_exact = compile_non_fluent_exact
+
  self.logger = logger
+ self.dashboard = dashboard
  self.dashboard_viz = dashboard_viz

- self.
- self._jax_compile_optimizer()
+ self._jax_compile_graph()

  @staticmethod
  def summarize_system() -> str:
  '''Returns a string containing information about the system, Python version
  and jax-related packages that are relevant to the current planner.
- '''
-
-
-
-
-
-
-
- except Exception as _:
- devices_short = 'N/A'
- LOGO = \
- r"""
- __ ______ __ __ ______ __ ______ __ __
- /\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
- _\_\ \\ \ __ \\/_/\_\/_\ \ _-/\ \ \____\ \ __ \\ \ \-. \
- /\_____\\ \_\ \_\ /\_\/\_\\ \_\ \ \_____\\ \_\ \_\\ \_\\"\_\
- \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
- """
-
- return (f'\n'
- f'{LOGO}\n'
- f'Version {__version__}\n'
- f'Python {sys.version}\n'
- f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
- f'optax {optax.__version__}, haiku {hk.__version__}, '
- f'numpy {np.__version__}\n'
- f'devices: {devices_short}\n')
+ '''
+ devices = jax.devices()
+ default_device = devices[0] if devices else 'n/a'
+ return termcolor.colored(
+ '\n'
+ f'Starting JaxPlan v{__version__} '
+ f'on device {default_device.platform}{default_device.id}\n', attrs=['bold']
+ )

  def summarize_relaxations(self) -> str:
  '''Returns a summary table containing all non-differentiable operators
  and their relaxations.
  '''
  result = ''
-
-
+ overriden_ops_info = self.compiled.overriden_ops_info()
+ if overriden_ops_info:
+ result += ('[INFO] Some RDDL operations are non-differentiable '
  'and will be approximated as follows:' + '\n')
-
-
-
-
-
-
-
- f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
- f' init_values={values_by_rddl_op[rddl_op]}\n')
+ for (class_, op_to_ids_dict) in overriden_ops_info.items():
+ result += f' {class_}:\n'
+ for (op, ids) in op_to_ids_dict.items():
+ result += (
+ f' {op} ' +
+ termcolor.colored(f'[{len(ids)} occurences]\n', 'dark_grey')
+ )
  return result

  def summarize_hyperparameters(self) -> str:
  '''Returns a string summarizing the hyper-parameters of the current planner
  instance.
  '''
- result = (f'objective hyper-parameters:\n'
+ result = (f'[INFO] objective hyper-parameters:\n'
  f' utility_fn ={self.utility.__name__}\n'
  f' utility args ={self.utility_kwargs}\n'
  f' use_symlog ={self.use_symlog_reward}\n'
  f' lookahead ={self.horizon}\n'
  f' user_action_bounds={self._action_bounds}\n'
- f'
- f' non_fluents exact ={self.compile_non_fluent_exact}\n'
- f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
- f'optimizer hyper-parameters:\n'
- f' use_64_bit ={self.use64bit}\n'
+ f'[INFO] optimizer hyper-parameters:\n'
  f' optimizer ={self.optimizer_name}\n'
  f' optimizer args ={self.optimizer_kwargs}\n'
  f' clip_gradient ={self.clip_grad}\n'
  f' line_search_kwargs={self.line_search_kwargs}\n'
  f' noise_kwargs ={self.noise_kwargs}\n'
+ f' ema_decay ={self.ema_decay}\n'
  f' batch_size_train ={self.batch_size_train}\n'
  f' batch_size_test ={self.batch_size_test}\n'
  f' parallel_updates ={self.parallel_updates}\n'
@@ -2014,89 +1980,78 @@ r"""
  result += str(self.plan)
  if self.use_pgpe:
  result += str(self.pgpe)
- result +=
+ result += 'test compiler:\n'
+ for k, v in self.test_compiled.get_kwargs().items():
+ result += f' {k}={v}\n'
+ result += 'train compiler:\n'
+ for k, v in self.compiled.get_kwargs().items():
+ result += f' {k}={v}\n'
  return result

  # ===========================================================================
- #
+ # COMPILE RDDL
  # ===========================================================================

  def _jax_compile_rddl(self):
-
-
- # Jax compilation of the differentiable RDDL for training
- self.compiled = JaxRDDLCompilerWithGrad(
- rddl=rddl,
- logic=self.logic,
+ self.compiled = self.compiler_type(
+ rddl=self.rddl,
  logger=self.logger,
-
-
- compile_non_fluent_exact=self.compile_non_fluent_exact,
- print_warnings=self.print_warnings,
- python_functions=self.python_functions
+ python_functions=self.python_functions,
+ **self.compiler_kwargs
  )
+ self.print_warnings = self.compiled.print_warnings
  self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
-
- # Jax compilation of the exact RDDL for testing
+
  self.test_compiled = JaxRDDLCompiler(
- rddl=rddl,
+ rddl=self.rddl,
+ allow_synchronous_state=True,
  logger=self.logger,
- use64bit=self.use64bit,
+ use64bit=self.compiled.use64bit,
  python_functions=self.python_functions
  )
  self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')
-
- def
-
- # preprocessor
+
+ def _jax_compile_policy(self):
  if self.preprocessor is not None:
  self.preprocessor.compile(self.compiled)
-
- # policy
  self.plan.compile(self.compiled,
  _bounds=self._action_bounds,
  horizon=self.horizon,
  preprocessor=self.preprocessor)
  self.train_policy = jax.jit(self.plan.train_policy)
  self.test_policy = jax.jit(self.plan.test_policy)
-
-
- train_rollouts = self.compiled.compile_rollouts(
+
+ def _jax_compile_rollouts(self):
+ self.train_rollouts = self.compiled.compile_rollouts(
  policy=self.plan.train_policy,
  n_steps=self.horizon,
  n_batch=self.batch_size_train,
- cache_path_info=self.preprocessor is not None
+ cache_path_info=self.preprocessor is not None or self.dashboard is not None
  )
- self.train_rollouts = train_rollouts
-
  test_rollouts = self.test_compiled.compile_rollouts(
  policy=self.plan.test_policy,
  n_steps=self.horizon,
  n_batch=self.batch_size_test,
- cache_path_info=
+ cache_path_info=self.dashboard is not None
  )
  self.test_rollouts = jax.jit(test_rollouts)
-
-
- self.initialize, self.init_optimizer = self.
-
-
- train_loss = self._jax_loss(train_rollouts, use_symlog=self.use_symlog_reward)
- test_loss = self._jax_loss(test_rollouts, use_symlog=False)
- if self.parallel_updates is None:
- self.test_loss = jax.jit(test_loss)
- else:
- self.test_loss = jax.jit(jax.vmap(test_loss, in_axes=(None, 0, None, None, 0)))
-
- # optimization
+
+ def _jax_compile_train_update(self):
+ self.initialize, self.init_optimizer = self._jax_init_optimizer()
+ train_loss = self._jax_loss(self.train_rollouts, use_symlog=self.use_symlog_reward)
+ self.single_train_loss = train_loss
  self.update = self._jax_update(train_loss)
-
+
+ def _jax_compile_test_loss(self):
+ test_loss = self._jax_loss(self.test_rollouts, use_symlog=False)
+ self.single_test_loss = test_loss
+ self.test_loss = jax.jit(jax.vmap(
+ test_loss, in_axes=(None, 0, None, None, None, 0)))

-
+ def _jax_compile_pgpe(self):
  if self.use_pgpe:
  self.pgpe.compile(
- loss_fn=
+ loss_fn=self.single_test_loss,
  projection=self.plan.projection,
  real_dtype=self.test_compiled.REAL,
  print_warnings=self.print_warnings,
@@ -2106,6 +2061,14 @@ r"""
  else:
  self.merge_pgpe = None

+ def _jax_compile_graph(self):
+ self._jax_compile_rddl()
+ self._jax_compile_policy()
+ self._jax_compile_rollouts()
+ self._jax_compile_train_update()
+ self._jax_compile_test_loss()
+ self._jax_compile_pgpe()
+
  def _jax_return(self, use_symlog):
  gamma = self.rddl.discount

@@ -2117,9 +2080,8 @@ r"""
  rewards = rewards * discount[jnp.newaxis, ...]
  returns = jnp.sum(rewards, axis=1)
  if use_symlog:
- returns = jnp.sign(returns) * jnp.
+ returns = jnp.sign(returns) * jnp.log1p(jnp.abs(returns))
  return returns
-
  return _jax_wrapped_returns

  def _jax_loss(self, rollouts, use_symlog=False):
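The symlog line above compresses large returns symmetrically, `symlog(x) = sign(x) * log(1 + |x|)`, which is why `log1p` is used. A quick check:

```python
import jax.numpy as jnp

returns = jnp.array([-1000., -1., 0., 1., 1000.])
print(jnp.sign(returns) * jnp.log1p(jnp.abs(returns)))   # approx [-6.91, -0.69, 0., 0.69, 6.91]
```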
@@ -2128,48 +2090,44 @@ r"""
  _jax_wrapped_returns = self._jax_return(use_symlog)

  # the loss is the average cumulative reward across all roll-outs
-
-
+ # but applies a utility function if requested to each return observation:
+ # by default, the utility function is the mean
+ def _jax_wrapped_plan_loss(key, policy_params, policy_hyperparams, fls, nfls,
+ model_params):
  log, model_params = rollouts(
- key, policy_params, policy_hyperparams,
+ key, policy_params, policy_hyperparams, fls, nfls, model_params)
  rewards = log['reward']
  returns = _jax_wrapped_returns(rewards)
  utility = utility_fn(returns, **utility_kwargs)
  loss = -utility
  aux = (log, model_params)
  return loss, aux
-
  return _jax_wrapped_plan_loss

- def
+ def _jax_init_optimizer(self):
  init = self.plan.initializer
  optimizer = self.optimizer
  num_parallel = self.parallel_updates

  # initialize both the policy and its optimizer
- def _jax_wrapped_init_policy(key, policy_hyperparams,
- policy_params = init(key, policy_hyperparams,
+ def _jax_wrapped_init_policy(key, policy_hyperparams, fls):
+ policy_params = init(key, policy_hyperparams, fls)
  opt_state = optimizer.init(policy_params)
  return policy_params, opt_state, {}

  # initialize just the optimizer from the policy
  def _jax_wrapped_init_opt(policy_params):
-
- opt_state = optimizer.init(policy_params)
- else:
- opt_state = jax.vmap(optimizer.init, in_axes=0)(policy_params)
+ opt_state = jax.vmap(optimizer.init, in_axes=0)(policy_params)
  return opt_state, {}

-
-
-
-
-
- keys
- return jax.vmap(_jax_wrapped_init_policy, in_axes=(0, None, None))(
- keys, policy_hyperparams, subs)
+ # initialize multiple policies to be optimized in parallel
+ def _jax_wrapped_batched_init_policy(key, policy_hyperparams, fls):
+ keys = random.split(key, num=num_parallel)
+ return jax.vmap(
+ _jax_wrapped_init_policy, in_axes=(0, None, None)
+ )(keys, policy_hyperparams, fls)

- return jax.jit(
+ return jax.jit(_jax_wrapped_batched_init_policy), jax.jit(_jax_wrapped_init_opt)

  def _jax_update(self, loss):
  optimizer = self.optimizer
@@ -2185,114 +2143,121 @@ r"""

  # calculate the plan gradient w.r.t. return loss and update optimizer
  # also perform a projection step to satisfy constraints on actions
- def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
-
- return loss(key, policy_params, policy_hyperparams,
+ def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams, fls, nfls,
+ model_params):
+ return loss(key, policy_params, policy_hyperparams, fls, nfls, model_params)[0]

- def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams,
-
+ def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams, fls, nfls,
+ model_params, opt_state, opt_aux):
+
+ # calculate the gradient of the loss with respect to the policy
  grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
  (loss_val, (log, model_params)), grad = grad_fn(
- key, policy_params, policy_hyperparams,
+ key, policy_params, policy_hyperparams, fls, nfls, model_params)
+
+ # require a slightly different update if line search is used
  if use_ls:
  updates, opt_state = optimizer.update(
  grad, opt_state, params=policy_params,
  value=loss_val, grad=grad, value_fn=_jax_wrapped_loss_swapped,
- key=key, policy_hyperparams=policy_hyperparams,
- model_params=model_params
+ key=key, policy_hyperparams=policy_hyperparams, fls=fls, nfls=nfls,
+ model_params=model_params
+ )
  else:
- updates, opt_state = optimizer.update(
-
+ updates, opt_state = optimizer.update(grad, opt_state, params=policy_params)
+
+ # apply optimizer and optional policy projection
  policy_params = optax.apply_updates(policy_params, updates)
  policy_params, converged = projection(policy_params, policy_hyperparams)
+
  log['grad'] = grad
  log['updates'] = updates
  zero_grads = _jax_wrapped_zero_gradients(grad)
- return policy_params, converged, opt_state, opt_aux,
-
+ return (policy_params, converged, opt_state, opt_aux,
+ loss_val, log, model_params, zero_grads)

-
-
-
-
- def _jax_wrapped_plan_updates(key, policy_params, policy_hyperparams,
- subs, model_params, opt_state, opt_aux):
- keys = jnp.asarray(random.split(key, num=num_parallel))
+ # for parallel policy update, just do each policy update in parallel
+ def _jax_wrapped_batched_plan_update(key, policy_params, policy_hyperparams,
+ fls, nfls, model_params, opt_state, opt_aux):
+ keys = random.split(key, num=num_parallel)
  return jax.vmap(
- _jax_wrapped_plan_update, in_axes=(0, 0, None, None, 0, 0, 0)
- )(keys, policy_params, policy_hyperparams,
+ _jax_wrapped_plan_update, in_axes=(0, 0, None, None, None, 0, 0, 0)
+ )(keys, policy_params, policy_hyperparams, fls, nfls, model_params,
  opt_state, opt_aux)
-
- return jax.jit(_jax_wrapped_plan_updates)
+ return jax.jit(_jax_wrapped_batched_plan_update)

  def _jax_merge_pgpe_jaxplan(self):
-
- return None
-
- # for parallel policy update
+
  # currently implements a hard replacement where the jaxplan parameter
  # is replaced by the PGPE parameter if the latter is an improvement
- def
+ def _jax_wrapped_batched_pgpe_merge(pgpe_mask, pgpe_param, policy_params,
  pgpe_loss, test_loss,
  pgpe_loss_smooth, test_loss_smooth,
  pgpe_converged, converged):
-
-
-
- policy_params = jax.tree_util.tree_map(
+ mask_tree = jax.tree_util.tree_map(
+ lambda leaf: pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf) - 1)],
+ pgpe_param)
+ policy_params = jax.tree_util.tree_map(
+ jnp.where, mask_tree, pgpe_param, policy_params)
  test_loss = jnp.where(pgpe_mask, pgpe_loss, test_loss)
  test_loss_smooth = jnp.where(pgpe_mask, pgpe_loss_smooth, test_loss_smooth)
-
- converged = jnp.where(expanded_mask, pgpe_converged, converged)
+ converged = jnp.where(pgpe_mask, pgpe_converged, converged)
  return policy_params, test_loss, test_loss_smooth, converged
+ return jax.jit(_jax_wrapped_batched_pgpe_merge)

-
-
- def _batched_init_subs(self, subs):
+ def _batched_init_subs(self, init_values):
  rddl = self.rddl
  n_train, n_test = self.batch_size_train, self.batch_size_test

-
-
-
+ init_train_fls, init_train_nfls, init_test_fls, init_test_nfls = {}, {}, {}, {}
+ for (name, value) in init_values.items():
+
+ # get the initial fluent values and check validity
  init_value = self.test_compiled.init_values.get(name, None)
  if init_value is None:
  raise RDDLUndefinedVariableError(
- f'Variable <{name}> in
+ f'Variable <{name}> in init_values argument is not a '
  f'valid p-variable, must be one of '
- f'{set(self.test_compiled.init_values.keys())}.'
-
+ f'{set(self.test_compiled.init_values.keys())}.'
+ )
+
+ # for enum types need to convert the string values to integer indices
+ if np.size(value) != np.size(init_value):
+ value = init_value
+ value = np.reshape(value, np.shape(init_value))
  if value.dtype.type is np.str_:
  value = rddl.object_string_to_index_array(rddl.variable_ranges[name], value)
-
-
-
-
+
+ # train and test fluents have a batch dimension added, non-fluents do not
+ # train fluents are also converted to float
+ if name not in rddl.non_fluents:
+ train_value = np.repeat(value[np.newaxis, ...], repeats=n_train, axis=0)
+ init_train_fls[name] = np.asarray(train_value, dtype=self.compiled.REAL)
+ init_test_fls[name] = np.repeat(value[np.newaxis, ...], repeats=n_test, axis=0)
+ else:
+ init_train_nfls[name] = np.asarray(value, dtype=self.compiled.REAL)
+ init_test_nfls[name] = value

- # safely cast test
+ # safely cast test variable to required type in case the type is wrong
  if name in rddl.variable_ranges:
  required_type = RDDLValueInitializer.NUMPY_TYPES.get(
  rddl.variable_ranges[name], RDDLValueInitializer.INT)
+ init_test = init_test_nfls if name in rddl.non_fluents else init_test_fls
  if np.result_type(init_test[name]) != required_type:
  init_test[name] = np.asarray(init_test[name], dtype=required_type)

  # make sure next-state fluents are also set
  for (state, next_state) in rddl.next_state.items():
-
-
- return
+ init_train_fls[next_state] = init_train_fls[state]
+ init_test_fls[next_state] = init_test_fls[state]
+ return (init_train_fls, init_train_nfls), (init_test_fls, init_test_nfls)

  def _broadcast_pytree(self, pytree):
- if self.parallel_updates is None:
- return pytree
-
- # for parallel policy update
  def make_batched(x):
  x = np.asarray(x)
  x = np.broadcast_to(
  x[np.newaxis, ...], shape=(self.parallel_updates,) + np.shape(x))
  return x
-
  return jax.tree_util.tree_map(make_batched, pytree)

  def as_optimization_problem(
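The parallel-update machinery in the hunk above vmaps a single update function over a batch of PRNG keys and per-run parameters while broadcasting the shared inputs. A standalone sketch of that axis layout, using a toy update rule and assumed shapes:

```python
import jax
import jax.numpy as jnp
from jax import random

def single_update(key, params, lr):
    # toy update: perturb the parameters with random noise scaled by lr
    noise = random.normal(key, shape=params.shape)
    return params - lr * noise

num_parallel = 4
keys = random.split(random.PRNGKey(0), num=num_parallel)        # axis 0: one key per run
params = jnp.zeros((num_parallel, 3))                            # axis 0: one parameter vector per run
batched_update = jax.vmap(single_update, in_axes=(0, 0, None))   # lr is shared by all runs
print(batched_update(keys, params, 0.1).shape)                   # (4, 3)
```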
@@ -2324,7 +2289,7 @@ r"""
  '''

  # make sure parallel updates are disabled
- if self.parallel_updates
+ if self.parallel_updates > 1:
  raise ValueError('Cannot compile static optimization problem '
  'when parallel_updates is not None.')

@@ -2333,42 +2298,45 @@ r"""
  key = random.PRNGKey(round(time.time() * 1000))

  # initialize the initial fluents, model parameters, policy hyper-params
-
-
- model_params = self.compiled.model_params
+ (fls, nfls), _ = self._batched_init_subs(self.test_compiled.init_values)
+ model_params = self.compiled.model_aux['params']
  if policy_hyperparams is None:
  if self.print_warnings:
-
- '[WARN] policy_hyperparams is not set
- 'all action-fluents which could be suboptimal.', 'yellow'
-
- policy_hyperparams = {action: 1.
- for action in self.rddl.action_fluents}
+ print(termcolor.colored(
+ '[WARN] policy_hyperparams is not set: setting values to 1.0 for '
+ 'all action-fluents, which could be suboptimal.', 'yellow'
+ ))
+ policy_hyperparams = {action: 1. for action in self.rddl.action_fluents}

  # initialize the policy parameters
- params_guess
+ params_guess = self.initialize(key, policy_hyperparams, fls)[0]
+ params_guess = pytree_at(params_guess, 0)
+
+ # get the params mapping to a 1D vector
  guess_1d, unravel_fn = jax.flatten_util.ravel_pytree(params_guess)
  guess_1d = np.asarray(guess_1d)

- # computes the training loss function
- loss_fn = self._jax_loss(self.train_rollouts)
-
+ # computes the training loss function in a 1D vector
  @jax.jit
  def _loss_with_key(key, params_1d, model_params):
  policy_params = unravel_fn(params_1d)
- loss_val, (_, model_params) =
- key, policy_params, policy_hyperparams,
+ loss_val, (_, model_params) = self.single_train_loss(
+ key, policy_params, policy_hyperparams, fls, nfls, model_params)
  return loss_val, model_params

+ # computes the training loss gradient function in a 1D vector
+ grad_fn = jax.grad(self.single_train_loss, argnums=1, has_aux=True)
+
  @jax.jit
  def _grad_with_key(key, params_1d, model_params):
  policy_params = unravel_fn(params_1d)
- grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
  grad_val, (_, model_params) = grad_fn(
- key, policy_params, policy_hyperparams,
+ key, policy_params, policy_hyperparams, fls, nfls, model_params)
  grad_val = jax.flatten_util.ravel_pytree(grad_val)[0]
  return grad_val, model_params

+ # store a global reference to the key on every JAX function call and pass when
+ # required by JAX, then update it upon return
  def _loss_function(params_1d):
  nonlocal key
  nonlocal model_params
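`as_optimization_problem` above flattens the policy pytree into a single 1-D vector so a black-box optimizer can consume it. The flatten/unflatten round trip used there looks like this (illustrative shapes):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(3)}
flat, unravel = ravel_pytree(params)   # 1-D vector plus a function to rebuild the pytree
print(flat.shape)                      # (7,)
print(unravel(flat)['w'].shape)        # (2, 2)
```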
@@ -2402,9 +2370,7 @@ r"""

  :param key: JAX PRNG key (derived from clock if not provided)
  :param epochs: the maximum number of steps of gradient descent
- :param train_seconds: total time allocated for gradient descent
- :param dashboard: dashboard to display training results
- :param dashboard_id: experiment id for the dashboard
+ :param train_seconds: total time allocated for gradient descent
  :param model_params: optional model-parameters to override default
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
  weights for sigmoid wrapping boolean actions
@@ -2414,10 +2380,9 @@ r"""
  specified in this instance
  :param print_summary: whether to print planner header and diagnosis
  :param print_progress: whether to print the progress bar during training
- :param print_hyperparams: whether to print list of hyper-parameter settings
+ :param print_hyperparams: whether to print list of hyper-parameter settings
+ :param dashboard_id: experiment id for the dashboard
  :param stopping_rule: stopping criterion
- :param restart_epochs: restart the optimizer from a random policy configuration
- if there is no progress for this many consecutive iterations
  :param test_rolling_window: the test return is averaged on a rolling
  window of the past test_rolling_window returns when updating the best
  parameters found so far
@@ -2442,8 +2407,6 @@ r"""
  def optimize_generator(self, key: Optional[random.PRNGKey]=None,
  epochs: int=999999,
  train_seconds: float=120.,
- dashboard: Optional[Any]=None,
- dashboard_id: Optional[str]=None,
  model_params: Optional[Dict[str, Any]]=None,
  policy_hyperparams: Optional[Dict[str, Any]]=None,
  subs: Optional[Dict[str, Any]]=None,
@@ -2451,8 +2414,8 @@ r"""
  print_summary: bool=True,
  print_progress: bool=True,
  print_hyperparams: bool=False,
+ dashboard_id: Optional[str]=None,
  stopping_rule: Optional[JaxPlannerStoppingRule]=None,
- restart_epochs: int=999999,
  test_rolling_window: int=10,
  tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
  '''Returns a generator for computing an optimal policy or plan.
@@ -2461,9 +2424,7 @@ r"""

  :param key: JAX PRNG key (derived from clock if not provided)
  :param epochs: the maximum number of steps of gradient descent
- :param train_seconds: total time allocated for gradient descent
- :param dashboard: dashboard to display training results
- :param dashboard_id: experiment id for the dashboard
+ :param train_seconds: total time allocated for gradient descent
  :param model_params: optional model-parameters to override default
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
  weights for sigmoid wrapping boolean actions
@@ -2474,14 +2435,15 @@ r"""
  :param print_summary: whether to print planner header and diagnosis
  :param print_progress: whether to print the progress bar during training
  :param print_hyperparams: whether to print list of hyper-parameter settings
+ :param dashboard_id: experiment id for the dashboard
  :param stopping_rule: stopping criterion
- :param restart_epochs: restart the optimizer from a random policy configuration
- if there is no progress for this many consecutive iterations
  :param test_rolling_window: the test return is averaged on a rolling
  window of the past test_rolling_window returns when updating the best
  parameters found so far
  :param tqdm_position: position of tqdm progress bar (for multiprocessing)
  '''
+
+ # start measuring execution time here, including time spent outside optimize loop
  start_time = time.time()
  elapsed_outside_loop = 0

@@ -2489,39 +2451,27 @@ r"""
  # INITIALIZATION OF HYPER-PARAMETERS
  # ======================================================================

- # cannot run dashboard with parallel updates
- if dashboard is not None and self.parallel_updates is not None:
- if self.print_warnings:
- message = termcolor.colored(
- '[WARN] Dashboard is unavailable if parallel_updates is not None: '
- 'setting dashboard to None.', 'yellow')
- print(message)
- dashboard = None
-
  # if PRNG key is not provided
  if key is None:
  key = random.PRNGKey(round(time.time() * 1000))
+ if self.print_warnings:
+ print(termcolor.colored(
+ '[WARN] PRNG key is not set: setting from clock.', 'yellow'
+ ))
  dash_key = key[1].item()

  # if policy_hyperparams is not provided
  if policy_hyperparams is None:
  if self.print_warnings:
-
- '[WARN] policy_hyperparams is not set
- 'all action-fluents which could be suboptimal.', 'yellow'
-
- policy_hyperparams = {action: 1.
- for action in self.rddl.action_fluents}
+ print(termcolor.colored(
+ '[WARN] policy_hyperparams is not set: setting values to 1.0 for '
+ 'all action-fluents, which could be suboptimal.', 'yellow'
+ ))
+ policy_hyperparams = {action: 1. for action in self.rddl.action_fluents}

  # if policy_hyperparams is a scalar
  elif isinstance(policy_hyperparams, (int, float, np.number)):
-
- message = termcolor.colored(
- f'[INFO] policy_hyperparams is {policy_hyperparams}, '
- f'setting this value for all action-fluents.', 'green')
- print(message)
- hyperparam_value = float(policy_hyperparams)
- policy_hyperparams = {action: hyperparam_value
+ policy_hyperparams = {action: float(policy_hyperparams)
  for action in self.rddl.action_fluents}

  # fill in missing entries
@@ -2529,12 +2479,12 @@ r"""
  for action in self.rddl.action_fluents:
  if action not in policy_hyperparams:
  if self.print_warnings:
- …
- f'[WARN] policy_hyperparams[{action}] is not set
- f'setting 1.0 for missing action-fluents '
- f'which could be suboptimal.', 'yellow'
- …
- policy_hyperparams[action] = 1.
+ print(termcolor.colored(
+ f'[WARN] policy_hyperparams[{action}] is not set: '
+ f'setting values to 1.0 for missing action-fluents, '
+ f'which could be suboptimal.', 'yellow'
+ ))
+ policy_hyperparams[action] = 1.

  # initialize preprocessor
  preproc_key = None
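Taken together, the two hunks above normalize policy_hyperparams before training: None becomes 1.0 for every action-fluent, a bare scalar is broadcast to every action-fluent, and any missing entry is filled with 1.0. A standalone sketch of that behavior, using a hypothetical helper name and a plain list in place of self.rddl.action_fluents:

    # Standalone sketch of the normalization above (illustrative, not package API).
    import numpy as np

    def normalize_hyperparams(policy_hyperparams, action_fluents):
        if policy_hyperparams is None:
            policy_hyperparams = {a: 1.0 for a in action_fluents}        # default weight
        elif isinstance(policy_hyperparams, (int, float, np.number)):
            policy_hyperparams = {a: float(policy_hyperparams) for a in action_fluents}
        for a in action_fluents:                                         # fill in missing entries
            policy_hyperparams.setdefault(a, 1.0)
        return policy_hyperparams

    print(normalize_hyperparams(5.0, ['put-out', 'cut-out']))   # {'put-out': 5.0, 'cut-out': 5.0}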
@@ -2548,21 +2498,21 @@ r"""
  print(self.summarize_relaxations())
  if print_hyperparams:
  print(self.summarize_hyperparameters())
- print(
- …
- …
- …
- …
- …
- …
- …
- …
- …
- …
- …
- …
- …
- …
+ print(
+ f'[INFO] optimize call hyper-parameters:\n'
+ f' PRNG key ={key}\n'
+ f' max_iterations ={epochs}\n'
+ f' max_seconds ={train_seconds}\n'
+ f' model_params ={model_params}\n'
+ f' policy_hyper_params={policy_hyperparams}\n'
+ f' override_subs_dict ={subs is not None}\n'
+ f' provide_param_guess={guess is not None}\n'
+ f' test_rolling_window={test_rolling_window}\n'
+ f' print_summary ={print_summary}\n'
+ f' print_progress ={print_progress}\n'
+ f' dashboard_id ={dashboard_id}\n'
+ f' stopping_rule ={stopping_rule}\n'
+ )

  # ======================================================================
  # INITIALIZATION OF STATE AND POLICY
@@ -2580,23 +2530,23 @@ r"""
  subs[var] = value
  added_pvars_to_subs.append(var)
  if self.print_warnings and added_pvars_to_subs:
- …
- f'[INFO] p-
- f'provided subs
- …
+ print(termcolor.colored(
+ f'[INFO] p-variable(s) {added_pvars_to_subs} are not in '
+ f'provided subs: using their initial values.', 'dark_grey'
+ ))
  train_subs, test_subs = self._batched_init_subs(subs)

  # initialize model parameters
  if model_params is None:
- model_params = self.compiled.
+ model_params = self.compiled.model_aux['params']
  model_params = self._broadcast_pytree(model_params)
- model_params_test = self._broadcast_pytree(self.test_compiled.
+ model_params_test = self._broadcast_pytree(self.test_compiled.model_aux['params'])

  # initialize policy parameters
  if guess is None:
  key, subkey = random.split(key)
  policy_params, opt_state, opt_aux = self.initialize(
- subkey, policy_hyperparams, train_subs)
+ subkey, policy_hyperparams, train_subs[0])
  else:
  policy_params = self._broadcast_pytree(guess)
  opt_state, opt_aux = self.init_optimizer(policy_params)
@@ -2606,8 +2556,7 @@ r"""
  pgpe_params, pgpe_opt_state, r_max = self.pgpe.initialize(key, policy_params)
  rolling_pgpe_loss = RollingMean(test_rolling_window)
  else:
- pgpe_params
- rolling_pgpe_loss = None
+ pgpe_params = pgpe_opt_state = r_max = rolling_pgpe_loss = None
  total_pgpe_it = 0

  # ======================================================================
@@ -2615,13 +2564,10 @@ r"""
  # ======================================================================

  # initialize running statistics
- …
- best_params = policy_params
- else:
- best_params = self.pytree_at(policy_params, 0)
+ best_params = pytree_at(policy_params, 0)
  best_loss, pbest_loss, best_grad = np.inf, np.inf, None
+ best_index = 0
  last_iter_improve = 0
- no_progress_count = 0
  rolling_test_loss = RollingMean(test_rolling_window)
  status = JaxPlannerStatus.NORMAL
  progress_percent = 0
@@ -2630,11 +2576,15 @@ r"""
  if stopping_rule is not None:
  stopping_rule.reset()

- # initialize
+ # initialize dashboard
+ dashboard = self.dashboard
  if dashboard is not None:
  dashboard_id = dashboard.register_experiment(
- dashboard_id,
- …
+ dashboard_id,
+ dashboard.get_planner_info(self),
+ key=dash_key,
+ viz=self.dashboard_viz
+ )

  # progress bar
  if print_progress:
@@ -2646,8 +2596,8 @@ r"""

  # error handlers (to avoid spam messaging)
  policy_constraint_msg_shown = False
- jax_train_msg_shown =
- jax_test_msg_shown =
+ jax_train_msg_shown = set()
+ jax_test_msg_shown = set()

  # ======================================================================
  # MAIN TRAINING LOOP BEGINS
@@ -2656,7 +2606,7 @@ r"""
  for it in range(epochs):

  # ==================================================================
- #
+ # JAXPLAN GRADIENT DESCENT STEP
  # ==================================================================

  status = JaxPlannerStatus.NORMAL
@@ -2665,135 +2615,113 @@ r"""
  key, subkey = random.split(key)
  (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
  model_params, zero_grads) = self.update(
- subkey, policy_params, policy_hyperparams, train_subs, model_params,
- opt_state, opt_aux
+ subkey, policy_params, policy_hyperparams, *train_subs, model_params,
+ opt_state, opt_aux
+ )

  # update the preprocessor
  if self.preprocessor is not None:
  policy_hyperparams[preproc_key] = self.preprocessor.update(
- train_log['fluents'], policy_hyperparams[preproc_key]
+ train_log['fluents'], policy_hyperparams[preproc_key]
+ )

  # evaluate
  test_loss, (test_log, model_params_test) = self.test_loss(
- subkey, policy_params, policy_hyperparams, test_subs, model_params_test
- …
- …
- …
+ subkey, policy_params, policy_hyperparams, *test_subs, model_params_test
+ )
+ train_loss = np.asarray(train_loss)
+ test_loss = np.asarray(test_loss)
  test_loss_smooth = rolling_test_loss.update(test_loss)

- #
+ # ==================================================================
+ # PGPE GRADIENT DESCENT STEP
+ # ==================================================================
+
  pgpe_improve = False
  if self.use_pgpe:
+
+ # pgpe update of the plan
  key, subkey = random.split(key)
- pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged =
- …
- …
- …
+ pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = self.pgpe.update(
+ subkey, pgpe_params, r_max, progress_percent,
+ policy_hyperparams, *test_subs, model_params_test, pgpe_opt_state
+ )

  # evaluate
  pgpe_loss, _ = self.test_loss(
- subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
- …
- pgpe_loss = np.asarray(pgpe_loss)
+ subkey, pgpe_param, policy_hyperparams, *test_subs, model_params_test)
+ pgpe_loss = np.asarray(pgpe_loss)
  pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
  pgpe_return = -pgpe_loss_smooth

  # replace JaxPlan with PGPE if new minimum reached or train loss invalid
- …
- …
- …
- …
- …
- …
- …
- …
- …
- if np.any(pgpe_mask):
- policy_params, test_loss, test_loss_smooth, converged = \
- self.merge_pgpe(pgpe_mask, pgpe_param, policy_params,
- pgpe_loss, test_loss,
- pgpe_loss_smooth, test_loss_smooth,
- pgpe_converged, converged)
- pgpe_improve = True
- total_pgpe_it += 1
+ pgpe_mask = (pgpe_loss_smooth < pbest_loss) | ~np.isfinite(train_loss)
+ if np.any(pgpe_mask):
+ policy_params, test_loss, test_loss_smooth, converged = self.merge_pgpe(
+ pgpe_mask, pgpe_param, policy_params,
+ pgpe_loss, test_loss, pgpe_loss_smooth, test_loss_smooth,
+ pgpe_converged, converged
+ )
+ pgpe_improve = True
+ total_pgpe_it += 1
  else:
- pgpe_loss
+ pgpe_loss = pgpe_loss_smooth = pgpe_return = None

- # evaluate test losses and record best parameters so far
- if self.parallel_updates is None:
- if test_loss_smooth < best_loss:
- best_params, best_loss, best_grad = \
- policy_params, test_loss_smooth, train_log['grad']
- pbest_loss = best_loss
- else:
- best_index = np.argmin(test_loss_smooth)
- if test_loss_smooth[best_index] < best_loss:
- best_params = self.pytree_at(policy_params, best_index)
- best_grad = self.pytree_at(train_log['grad'], best_index)
- best_loss = test_loss_smooth[best_index]
- pbest_loss = np.minimum(pbest_loss, test_loss_smooth)
-
  # ==================================================================
  # STATUS CHECKS AND LOGGING
  # ==================================================================
-
+
+ # evaluate test losses and record best parameters so far
+ best_index = np.argmin(test_loss_smooth)
+ if test_loss_smooth[best_index] < best_loss:
+ best_params = pytree_at(policy_params, best_index)
+ best_grad = pytree_at(train_log['grad'], best_index)
+ best_loss = test_loss_smooth[best_index]
+ last_iter_improve = it
+ pbest_loss = np.minimum(pbest_loss, test_loss_smooth)
+
  # no progress
- …
- if no_progress_flag:
+ if (not pgpe_improve) and np.all(zero_grads):
  status = JaxPlannerStatus.NO_PROGRESS

  # constraint satisfaction problem
  if not np.all(converged):
  if progress_bar is not None and not policy_constraint_msg_shown:
- …
- '[FAIL] Policy update
- …
- progress_bar.write(message)
+ progress_bar.write(termcolor.colored(
+ '[FAIL] Policy update violated action constraints.', 'red'
+ ))
  policy_constraint_msg_shown = True
  status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED

  # numerical error
+ invalid_loss = not np.any(np.isfinite(train_loss))
  if self.use_pgpe:
- invalid_loss = not
- np.any(np.isfinite(pgpe_loss)))
- else:
- invalid_loss = not np.any(np.isfinite(train_loss))
+ invalid_loss = invalid_loss and not np.any(np.isfinite(pgpe_loss))
  if invalid_loss:
  if progress_bar is not None:
- …
- f'[FAIL] Planner aborted
- …
- progress_bar.write(message)
+ progress_bar.write(termcolor.colored(
+ f'[FAIL] Planner aborted early with train loss {train_loss}.', 'red'
+ ))
  status = JaxPlannerStatus.INVALID_GRADIENT

  # problem in the model compilation
  if progress_bar is not None:

  # train model
- …
- …
- …
- …
- …
- …
- message = termcolor.colored(
- f'[FAIL] Compiler encountered the following '
- f'error(s) in the training model:\n {messages}', 'red')
- progress_bar.write(message)
- jax_train_msg_shown = True
+ for error_code in np.unique(train_log['error']):
+ if error_code not in jax_train_msg_shown:
+ jax_train_msg_shown.add(error_code)
+ for message in JaxRDDLCompiler.get_error_messages(error_code):
+ progress_bar.write(termcolor.colored(
+ '[FAIL] Training model error: ' + message, 'red'))

  # test model
- …
- …
- …
- …
- …
- …
- message = termcolor.colored(
- f'[FAIL] Compiler encountered the following '
- f'error(s) in the testing model:\n {messages}', 'red')
- progress_bar.write(message)
- jax_test_msg_shown = True
+ for error_code in np.unique(test_log['error']):
+ if error_code not in jax_test_msg_shown:
+ jax_test_msg_shown.add(error_code)
+ for message in JaxRDDLCompiler.get_error_messages(error_code):
+ progress_bar.write(termcolor.colored(
+ '[FAIL] Testing model error: ' + message, 'red'))

  # reached computation budget
  elapsed = time.time() - start_time - elapsed_outside_loop
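The update loop above now always treats policy_params as a batch of candidate plans and keeps the best one with pytree_at(policy_params, best_index), where the previous code used self.pytree_at and a separate non-parallel branch. The helper's implementation is not shown in this diff; presumably it slices one batch index out of every leaf of a pytree, along these lines:

    # Assumed behavior of pytree_at (not shown in this hunk): pick element i of every
    # leaf along the leading batch axis of a pytree of arrays. Illustrative only.
    import jax
    import jax.numpy as jnp

    def pytree_at_sketch(tree, i):
        return jax.tree_util.tree_map(lambda leaf: leaf[i], tree)

    batched = {'w': jnp.ones((4, 3)), 'b': jnp.zeros((4,))}   # 4 parallel candidates
    best = pytree_at_sketch(batched, 2)                       # {'w': shape (3,), 'b': scalar}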
@@ -2806,66 +2734,53 @@ r"""
  progress_percent = 100 * min(
  1, max(0, elapsed / train_seconds, it / (epochs - 1)))
  callback = {
- 'status': status,
  'iteration': it,
+ 'elapsed_time': elapsed,
+ 'progress': progress_percent,
+ 'status': status,
+ 'key': key,
  'train_return':-train_loss,
  'test_return':-test_loss_smooth,
  'best_return':-best_loss,
  'pgpe_return': pgpe_return,
+ 'last_iteration_improved': last_iter_improve,
+ 'pgpe_improved': pgpe_improve,
  'params': policy_params,
  'best_params': best_params,
+ 'best_index': best_index,
  'pgpe_params': pgpe_params,
- '
- '
+ 'model_params': model_params,
+ 'policy_hyperparams': policy_hyperparams,
  'grad': train_log['grad'],
  'best_grad': best_grad,
- 'updates': train_log['updates'],
- 'elapsed_time': elapsed,
- 'key': key,
- 'model_params': model_params,
- 'progress': progress_percent,
  'train_log': train_log,
- '
- **test_log
+ 'test_log': test_log
  }

- # hard restart
- if guess is None and no_progress_flag:
- no_progress_count += 1
- if no_progress_count > restart_epochs:
- key, subkey = random.split(key)
- policy_params, opt_state, opt_aux = self.initialize(
- subkey, policy_hyperparams, train_subs)
- no_progress_count = 0
- if self.print_warnings and progress_bar is not None:
- message = termcolor.colored(
- f'[INFO] Optimizer restarted at iteration {it} '
- f'due to lack of progress.', 'green')
- progress_bar.write(message)
- else:
- no_progress_count = 0
-
  # stopping condition reached
  if stopping_rule is not None and stopping_rule.monitor(callback):
  if self.print_warnings and progress_bar is not None:
- …
- '[SUCC] Stopping rule has been reached.', 'green'
- …
+ progress_bar.write(termcolor.colored(
+ '[SUCC] Stopping rule has been reached.', 'green'
+ ))
  callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED

  # if the progress bar is used
  if print_progress:
  progress_bar.set_description(
- f'{position_str} {it
- f'{-np.min(test_loss_smooth):
- f'{
- …
+ f'{position_str} {it} it | {-np.min(train_loss):13.5f} train | '
+ f'{-np.min(test_loss_smooth):13.5f} test | '
+ f'{-best_loss:13.5f} best | '
+ f'{total_pgpe_it} pgpe | {status.value} status',
+ refresh=False
+ )
  progress_bar.set_postfix_str(
- f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False
+ f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False
+ )
  progress_bar.update(progress_percent - progress_bar.n)

- #
- if dashboard is not None:
+ # dashboard
+ if dashboard is not None:
  dashboard.update_experiment(dashboard_id, callback)

  # yield the callback
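The per-iteration callback dictionary was reorganized above: elapsed_time, progress, status, and key move to the top, test_log is now nested under its own key instead of being splatted into the dictionary with **test_log, and best_index, last_iteration_improved, pgpe_improved, model_params, and policy_hyperparams are explicit entries. A sketch of a consumer reading the reorganized keys; the generator entry point name is an assumption inferred from the "yield the callback" comment above, so the call shown here is illustrative rather than a verified API:

    # Illustrative consumer of the reorganized callback (entry point name assumed).
    from pyRDDLGym_jax.core.planner import JaxPlannerStatus

    for cb in planner.optimize_generator(key=key, epochs=500):   # hypothetical generator call
        print(cb['iteration'], f"{cb['progress']:.0f}%", cb['best_return'])
        if cb['status'] != JaxPlannerStatus.NORMAL:              # enum used throughout this diff
            break
    best_params = cb['best_params']
    test_log = cb['test_log']                                    # previously flattened via **test_log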
@@ -2884,28 +2799,51 @@ r"""
  # release resources
  if print_progress:
  progress_bar.close()
- print()

  # summarize and test for convergence
  if print_summary:
- …
- …
+
+ # calculate gradient norm
+ grad_norm = jax.tree_util.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
+ grad_norms = jax.tree_util.tree_leaves(grad_norm)
+ max_grad_norm = max(grad_norms) if grad_norms else np.nan
+
+ # calculate best policy return
+ _, (final_log, _) = self.test_loss(
+ key, self._broadcast_pytree(best_params), policy_hyperparams,
+ *test_subs, model_params_test
+ )
+ best_returns = np.ravel(np.sum(final_log['reward'], axis=2))
+ mean, rlo, rhi = self.ci_bootstrap(best_returns)
+
+ # diagnosis
  diagnosis = self._perform_diagnosis(
  last_iter_improve, -np.min(train_loss), -np.min(test_loss_smooth),
- -best_loss,
- …
- …
- …
- …
- …
- …
- …
+ -best_loss, max_grad_norm
+ )
+ print(
+ f'[INFO] Summary of optimization:\n'
+ f' status: {status}\n'
+ f' time: {elapsed:.2f} seconds\n'
+ f' iterations: {it}\n'
+ f' best objective: {-best_loss:.5f}\n'
+ f' best grad norm: {max_grad_norm:.5f}\n'
+ f' best cuml reward: Mean = {mean:.5f}, 95% CI [{rlo:.5f}, {rhi:.5f}]\n'
+ f' diagnosis: {diagnosis}\n'
+ )

- …
- …
- …
- …
- …
+ @staticmethod
+ def ci_bootstrap(returns, confidence=0.95, n_boot=10000):
+ means = np.zeros((n_boot,))
+ for i in range(n_boot):
+ means[i] = np.mean(np.random.choice(returns, size=len(returns), replace=True))
+ lower = np.percentile(means, (1 - confidence) / 2 * 100)
+ upper = np.percentile(means, (1 + confidence) / 2 * 100)
+ mean = np.mean(returns)
+ return mean, lower, upper
+
+ def _perform_diagnosis(self, last_iter_improve, train_return, test_return, best_return,
+ max_grad_norm):

  # divergence if the solution is not finite
  if not np.isfinite(train_return):
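The new ci_bootstrap above estimates a percentile-bootstrap confidence interval for the mean cumulative reward with a Python loop over resamples. For reference, an equivalent vectorized NumPy version (an illustration only, not part of the package) computes the same statistics in one shot:

    # Vectorized percentile-bootstrap sketch equivalent to ci_bootstrap above.
    import numpy as np

    def bootstrap_ci(returns, confidence=0.95, n_boot=10000, rng=None):
        rng = np.random.default_rng() if rng is None else rng
        returns = np.asarray(returns)
        samples = rng.choice(returns, size=(n_boot, returns.size), replace=True)
        means = samples.mean(axis=1)                       # mean of each bootstrap resample
        lower = np.percentile(means, (1 - confidence) / 2 * 100)
        upper = np.percentile(means, (1 + confidence) / 2 * 100)
        return returns.mean(), lower, upper

    mean, lo, hi = bootstrap_ci(np.random.normal(10.0, 2.0, size=256))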
@@ -2914,64 +2852,61 @@ r"""
  # hit a plateau is likely IF:
  # 1. planner does not improve at all
  # 2. the gradient norm at the best solution is zero
+ grad_is_zero = np.allclose(max_grad_norm, 0)
  if last_iter_improve <= 1:
  if grad_is_zero:
  return termcolor.colored(
- f'[FAIL] No progress
- f'
- …
+ f'[FAIL] No progress and ||g||={max_grad_norm:.4f}, '
+ f'solver initialized in a plateau.', 'red'
+ )
  else:
  return termcolor.colored(
- f'[FAIL] No progress
- f'
- …
- 'red')
+ f'[FAIL] No progress and ||g||={max_grad_norm:.4f}, '
+ f'adjust learning rate or other parameters.', 'red'
+ )

  # model is likely poor IF:
  # 1. the train and test return disagree
- validation_error =
- max(abs(train_return), abs(test_return))
- if not (validation_error <
+ validation_error = (abs(test_return - train_return) /
+ max(abs(train_return), abs(test_return)))
+ if not (validation_error < 0.2):
  return termcolor.colored(
- f'[WARN] Progress
- f'
- …
- 'yellow')
+ f'[WARN] Progress but large rel. train/test error {validation_error:.4f}, '
+ f'adjust model or batch size.', 'yellow'
+ )

  # model likely did not converge IF:
  # 1. the max grad relative to the return is high
  if not grad_is_zero:
- …
- if not (return_to_grad_norm > 1):
+ if not (abs(best_return) > 1.0 * max_grad_norm):
  return termcolor.colored(
- f'[WARN] Progress
- f'
- …
- f'or batch size too small.', 'yellow')
+ f'[WARN] Progress but large ||g||={max_grad_norm:.4f}, '
+ f'adjust learning rate or budget.', 'yellow'
+ )

  # likely successful
  return termcolor.colored(
- '[SUCC]
- …
+ '[SUCC] No convergence problems found.', 'green'
+ )

  def get_action(self, key: random.PRNGKey,
  params: Pytree,
  step: int,
- …
+ state: Dict[str, Any],
  policy_hyperparams: Optional[Dict[str, Any]]=None) -> Dict[str, Any]:
  '''Returns an action dictionary from the policy or plan with the given parameters.

  :param key: the JAX PRNG key
  :param params: the trainable parameter PyTree of the policy
  :param step: the time step at which decision is made
- :param
+ :param state: the dict of state p-variables
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
  weights for sigmoid wrapping boolean actions (optional)
  '''
- …
+ state = state.copy()

- # check compatibility of the
- for (var, values) in
+ # check compatibility of the state dictionary
+ for (var, values) in state.items():

  # must not be grounded
  if RDDLPlanningModel.FLUENT_SEP in var or RDDLPlanningModel.OBJECT_SEP in var:
@@ -2985,18 +2920,19 @@ r"""
  dtype = np.result_type(values)
  if not np.issubdtype(dtype, np.number) and not np.issubdtype(dtype, np.bool_):
  if step == 0 and var in self.rddl.observ_fluents:
- …
+ state[var] = self.test_compiled.init_values[var]
  else:
  if dtype.type is np.str_:
  prange = self.rddl.variable_ranges[var]
- …
+ state[var] = self.rddl.object_string_to_index_array(prange, state[var])
  else:
  raise ValueError(
  f'Values {values} assigned to p-variable <{var}> are '
- f'non-numeric of type {dtype}.'
+ f'non-numeric of type {dtype}.'
+ )

  # cast device arrays to numpy
- actions = self.test_policy(key, params, policy_hyperparams, step,
+ actions = self.test_policy(key, params, policy_hyperparams, step, state)
  actions = jax.tree_util.tree_map(np.asarray, actions)
  return actions

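get_action now receives the current state dictionary through a parameter named state (the previous parameter name is cut off in this view), converts string-valued object fluents to index arrays, and forwards everything to test_policy. A minimal calling sketch; planner and params are assumed to come from a previous optimize call, and the fluent name is made up:

    # Minimal calling sketch for the signature shown above (illustrative).
    import jax
    import numpy as np

    key = jax.random.PRNGKey(0)
    state = {'rlevel': np.array([55.0, 60.0])}    # lifted p-variables, not grounded names
    actions = planner.get_action(key, params, step=0, state=state)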
@@ -3053,7 +2989,8 @@ class JaxOfflineController(BaseAgent):
  with open(params, 'rb') as file:
  params = pickle.load(file)

- # train the policy
+ # train the policy once before starting to step() through the environment
+ # and then execute this policy in open-loop fashion
  self.step = 0
  self.callback = None
  if not self.train_on_reset and not self.params_given:
@@ -3126,22 +3063,33 @@ class JaxOnlineController(BaseAgent):
  self.reset()

  def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
+
+ # we train the policy from the current state every time we step()
  planner = self.planner
  callback = planner.optimize(
  key=self.key, guess=self.guess, subs=state, **self.train_kwargs)

  # optimize again if jit compilation takes up the entire time budget
+ # this can be done for several attempts until the optimizer has traced the
+ # computation graph: we report the callback of the successful attempt (if exists)
  attempts = 0
  while attempts < self.max_attempts and callback['iteration'] <= 1:
  attempts += 1
  if self.planner.print_warnings:
- …
- f'[
+ print(termcolor.colored(
+ f'[INFO] JIT compilation dominated the execution time: '
  f'executing the optimizer again on the traced model '
- f'[attempt {attempts}].', '
- …
+ f'[attempt {attempts}].', 'dark_grey'
+ ))
  callback = planner.optimize(
- key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
+ key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
+ if callback['iteration'] <= 1 and self.planner.print_warnings:
+ print(termcolor.colored(
+ f'[FAIL] JIT compilation dominated the execution time and '
+ f'ran out of attempts: increase max_attempts or the training time.', 'red'
+ ))
+
+ # use the last callback obtained
  self.callback = callback
  params = callback['best_params']
  if not self.hyperparams_given:
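The sample_action hunk above retries planner.optimize when JIT compilation consumes the whole per-step budget (detected as callback['iteration'] <= 1), up to max_attempts times, and then warns if the budget was never usable. The same pattern in isolation, with all objects (planner, key, state, train_kwargs) assumed to exist already:

    # The retry pattern above, in isolation (illustrative sketch).
    callback = planner.optimize(key=key, subs=state, **train_kwargs)
    attempts, max_attempts = 0, 3
    while attempts < max_attempts and callback['iteration'] <= 1:
        # the first call mostly paid for JIT tracing; retry on the already-traced model
        attempts += 1
        callback = planner.optimize(key=key, subs=state, **train_kwargs)
    actions = planner.get_action(key, callback['best_params'], step=0, state=state)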