pyRDDLGym-jax 0.5-py3-none-any.whl → 1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +784 -544
- pyRDDLGym_jax/core/planner.py +336 -472
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +392 -567
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -27
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/METADATA +167 -111
- pyRDDLGym_jax-1.1.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.1.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -31,22 +31,24 @@ from pyRDDLGym.core.policy import BaseAgent
 from pyRDDLGym_jax import __version__
 from pyRDDLGym_jax.core import logic
 from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
-from pyRDDLGym_jax.core.logic import FuzzyLogic
+from pyRDDLGym_jax.core.logic import Logic, FuzzyLogic

-# try to
+# try to load the dash board
 try:
-
+    from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
 except Exception:
-    raise_warning('
-                  '
+    raise_warning('Failed to load the dashboard visualization tool: '
+                  'please make sure you have installed the required packages.',
+                  'red')
     traceback.print_exc()
-
-
+    JaxPlannerDashboard = None
+
 Activation = Callable[[jnp.ndarray], jnp.ndarray]
 Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
 Kwargs = Dict[str, Any]
 Pytree = Any

+
 # ***********************************************************************
 # CONFIG FILE MANAGEMENT
 #
@@ -63,9 +65,8 @@ def _parse_config_file(path: str):
     config = configparser.RawConfigParser()
     config.optionxform = str
     config.read(path)
-    args = {k: literal_eval(v)
-            for section in config.sections()
-            for (k, v) in config.items(section)}
+    args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
+            for section in config.sections()}
     return config, args


@@ -73,9 +74,8 @@ def _parse_config_string(value: str):
     config = configparser.RawConfigParser()
     config.optionxform = str
    config.read_string(value)
-    args = {k: literal_eval(v)
-            for section in config.sections()
-            for (k, v) in config.items(section)}
+    args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
+            for section in config.sections()}
     return config, args


@@ -88,28 +88,32 @@ def _getattr_any(packages, item):


 def _load_config(config, args):
-    model_args = {k: args[k] for (k, _) in config.items('Model')}
-    planner_args = {k: args[k] for (k, _) in config.items('Optimizer')}
-    train_args = {k: args[k] for (k, _) in config.items('Training')}
+    model_args = {k: args['Model'][k] for (k, _) in config.items('Model')}
+    planner_args = {k: args['Optimizer'][k] for (k, _) in config.items('Optimizer')}
+    train_args = {k: args['Training'][k] for (k, _) in config.items('Training')}

     # read the model settings
     logic_name = model_args.get('logic', 'FuzzyLogic')
     logic_kwargs = model_args.get('logic_kwargs', {})
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if logic_name == 'FuzzyLogic':
+        tnorm_name = model_args.get('tnorm', 'ProductTNorm')
+        tnorm_kwargs = model_args.get('tnorm_kwargs', {})
+        comp_name = model_args.get('complement', 'StandardComplement')
+        comp_kwargs = model_args.get('complement_kwargs', {})
+        compare_name = model_args.get('comparison', 'SigmoidComparison')
+        compare_kwargs = model_args.get('comparison_kwargs', {})
+        sampling_name = model_args.get('sampling', 'GumbelSoftmax')
+        sampling_kwargs = model_args.get('sampling_kwargs', {})
+        rounding_name = model_args.get('rounding', 'SoftRounding')
+        rounding_kwargs = model_args.get('rounding_kwargs', {})
+        control_name = model_args.get('control', 'SoftControlFlow')
+        control_kwargs = model_args.get('control_kwargs', {})
+        logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
+        logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
+        logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
+        logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
+        logic_kwargs['rounding'] = getattr(logic, rounding_name)(**rounding_kwargs)
+        logic_kwargs['control'] = getattr(logic, control_name)(**control_kwargs)

     # read the policy settings
     plan_method = planner_args.pop('method')
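
The config parser now returns one nested dictionary per section instead of one flat dictionary, and _load_config indexes into args['Model'], args['Optimizer'], and args['Training'] accordingly. The new shape can be reproduced with the standard library alone; this is a minimal sketch, assuming a config with the [Model]/[Optimizer]/[Training] sections used in the bundled example .cfg files (the specific keys shown are illustrative, not an authoritative schema):

    # minimal standalone sketch of the section-nested parse (standard library only)
    import configparser
    from ast import literal_eval

    CFG = """
    [Model]
    logic='FuzzyLogic'
    logic_kwargs={'weight': 20}

    [Optimizer]
    method='JaxStraightLinePlan'
    optimizer_kwargs={'learning_rate': 0.01}

    [Training]
    epochs=1000
    """

    config = configparser.RawConfigParser()
    config.optionxform = str       # preserve key case, as in _parse_config_string
    config.read_string(CFG)

    # one nested dict per section, instead of the old flat key -> value dict
    args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
            for section in config.sections()}
    assert args['Model']['logic'] == 'FuzzyLogic'
    assert args['Training']['epochs'] == 1000
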
@@ -165,6 +169,13 @@ def _load_config(config, args):
     if planner_key is not None:
         train_args['key'] = random.PRNGKey(planner_key)

+    # dashboard
+    dashboard_key = train_args.get('dashboard', None)
+    if dashboard_key is not None and dashboard_key and JaxPlannerDashboard is not None:
+        train_args['dashboard'] = JaxPlannerDashboard()
+    elif dashboard_key is not None:
+        del train_args['dashboard']
+
     # optimize call stopping rule
     stopping_rule = train_args.get('stopping_rule', None)
     if stopping_rule is not None:
@@ -186,6 +197,7 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
     config, args = _parse_config_string(value)
     return _load_config(config, args)

+
 # ***********************************************************************
 # MODEL RELAXATIONS
 #
@@ -202,7 +214,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
     '''

     def __init__(self, *args,
-                 logic:
+                 logic: Logic=FuzzyLogic(),
                  cpfs_without_grad: Optional[Set[str]]=None,
                  **kwargs) -> None:
         '''Creates a new RDDL to Jax compiler, where operations that are not
@@ -237,57 +249,55 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):

         # overwrite basic operations with fuzzy ones
         self.RELATIONAL_OPS = {
-            '>=': logic.greater_equal
-            '<=': logic.less_equal
-            '<': logic.less
-            '>': logic.greater
-            '==': logic.equal
-            '~=': logic.not_equal
+            '>=': logic.greater_equal,
+            '<=': logic.less_equal,
+            '<': logic.less,
+            '>': logic.greater,
+            '==': logic.equal,
+            '~=': logic.not_equal
         }
-        self.LOGICAL_NOT = logic.logical_not
+        self.LOGICAL_NOT = logic.logical_not
         self.LOGICAL_OPS = {
-            '^': logic.logical_and
-            '&': logic.logical_and
-            '|': logic.logical_or
-            '~': logic.xor
-            '=>': logic.implies
-            '<=>': logic.equiv
+            '^': logic.logical_and,
+            '&': logic.logical_and,
+            '|': logic.logical_or,
+            '~': logic.xor,
+            '=>': logic.implies,
+            '<=>': logic.equiv
         }
-        self.AGGREGATION_OPS['forall'] = logic.forall
-        self.AGGREGATION_OPS['exists'] = logic.exists
-        self.AGGREGATION_OPS['argmin'] = logic.argmin
-        self.AGGREGATION_OPS['argmax'] = logic.argmax
-        self.KNOWN_UNARY['sgn'] = logic.sgn
-        self.KNOWN_UNARY['floor'] = logic.floor
-        self.KNOWN_UNARY['ceil'] = logic.ceil
-        self.KNOWN_UNARY['round'] = logic.round
-        self.KNOWN_UNARY['sqrt'] = logic.sqrt
-        self.KNOWN_BINARY['div'] = logic.div
-        self.KNOWN_BINARY['mod'] = logic.mod
-        self.KNOWN_BINARY['fmod'] = logic.mod
-        self.IF_HELPER = logic.control_if
-        self.SWITCH_HELPER = logic.control_switch
-        self.BERNOULLI_HELPER = logic.bernoulli
-        self.DISCRETE_HELPER = logic.discrete
-        self.POISSON_HELPER = logic.poisson
-        self.GEOMETRIC_HELPER = logic.geometric
-
-    def _jax_stop_grad(self, jax_expr):
-
+        self.AGGREGATION_OPS['forall'] = logic.forall
+        self.AGGREGATION_OPS['exists'] = logic.exists
+        self.AGGREGATION_OPS['argmin'] = logic.argmin
+        self.AGGREGATION_OPS['argmax'] = logic.argmax
+        self.KNOWN_UNARY['sgn'] = logic.sgn
+        self.KNOWN_UNARY['floor'] = logic.floor
+        self.KNOWN_UNARY['ceil'] = logic.ceil
+        self.KNOWN_UNARY['round'] = logic.round
+        self.KNOWN_UNARY['sqrt'] = logic.sqrt
+        self.KNOWN_BINARY['div'] = logic.div
+        self.KNOWN_BINARY['mod'] = logic.mod
+        self.KNOWN_BINARY['fmod'] = logic.mod
+        self.IF_HELPER = logic.control_if
+        self.SWITCH_HELPER = logic.control_switch
+        self.BERNOULLI_HELPER = logic.bernoulli
+        self.DISCRETE_HELPER = logic.discrete
+        self.POISSON_HELPER = logic.poisson
+        self.GEOMETRIC_HELPER = logic.geometric
+
+    def _jax_stop_grad(self, jax_expr):
         def _jax_wrapped_stop_grad(x, params, key):
-            sample, key, error = jax_expr(x, params, key)
+            sample, key, error, params = jax_expr(x, params, key)
             sample = jax.lax.stop_gradient(sample)
-            return sample, key, error
-
+            return sample, key, error, params
         return _jax_wrapped_stop_grad

-    def _compile_cpfs(self,
+    def _compile_cpfs(self, init_params):
         cpfs_cast = set()
         jax_cpfs = {}
         for (_, cpfs) in self.levels.items():
             for cpf in cpfs:
                 _, expr = self.rddl.cpfs[cpf]
-                jax_cpfs[cpf] = self._jax(expr,
+                jax_cpfs[cpf] = self._jax(expr, init_params, dtype=self.REAL)
                 if self.rddl.variable_ranges[cpf] != 'real':
                     cpfs_cast.add(cpf)
                 if cpf in self.cpfs_without_grad:
@@ -298,17 +308,16 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
                           f'{cpfs_cast} be cast to float.')
         if self.cpfs_without_grad:
             raise_warning(f'User requested that gradients not flow '
-                          f'through CPFs {self.cpfs_without_grad}.')
+                          f'through CPFs {self.cpfs_without_grad}.')
+
         return jax_cpfs

-    def _jax_kron(self, expr,
-        if self.logic.verbose:
-            raise_warning('JAX gradient compiler ignores KronDelta '
-                          'during compilation.')
+    def _jax_kron(self, expr, init_params):
         arg, = expr.args
-        arg = self._jax(arg,
+        arg = self._jax(arg, init_params)
         return arg

+
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
 #
@@ -329,7 +338,7 @@ class JaxPlan:
         self.bounds = None

     def summarize_hyperparameters(self) -> None:
-
+        print(self.__str__())

     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -455,21 +464,21 @@ class JaxStraightLinePlan(JaxPlan):
         self._wrap_softmax = wrap_softmax
         self._use_new_projection = use_new_projection
         self._max_constraint_iter = max_constraint_iter
-
-    def
+
+    def __str__(self) -> str:
         bounds = '\n '.join(
             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
-
-
-
-
-
-
-
-
-
-
-
+        return (f'policy hyper-parameters:\n'
+                f' initializer ={self._initializer_base}\n'
+                f'constraint-sat strategy (simple):\n'
+                f' parsed_action_bounds =\n {bounds}\n'
+                f' wrap_sigmoid ={self._wrap_sigmoid}\n'
+                f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
+                f' wrap_non_bool ={self._wrap_non_bool}\n'
+                f'constraint-sat strategy (complex):\n'
+                f' wrap_softmax ={self._wrap_softmax}\n'
+                f' use_new_projection ={self._use_new_projection}\n'
+                f' max_projection_iters ={self._max_constraint_iter}')

     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -820,21 +829,21 @@ class JaxDeepReactivePolicy(JaxPlan):
             normalizer_kwargs = {'create_offset': True, 'create_scale': True}
         self._normalizer_kwargs = normalizer_kwargs
         self._wrap_non_bool = wrap_non_bool
-
-    def
+
+    def __str__(self) -> str:
         bounds = '\n '.join(
             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
-
-
-
-
-
-
-
-
-
-
-
+        return (f'policy hyper-parameters:\n'
+                f' topology ={self._topology}\n'
+                f' activation_fn ={self._activations[0].__name__}\n'
+                f' initializer ={type(self._initializer_base).__name__}\n'
+                f' apply_input_norm ={self._normalize}\n'
+                f' input_norm_layerwise={self._normalize_per_layer}\n'
+                f' input_norm_args ={self._normalizer_kwargs}\n'
+                f'constraint-sat strategy:\n'
+                f' parsed_action_bounds=\n {bounds}\n'
+                f' wrap_non_bool ={self._wrap_non_bool}')
+
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
                 horizon: int) -> None:
@@ -1057,6 +1066,7 @@ class JaxDeepReactivePolicy(JaxPlan):
     def guess_next_epoch(self, params: Pytree) -> Pytree:
         return params

+
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANNER
 #
@@ -1084,113 +1094,6 @@ class RollingMean:
         return self._total / len(memory)


-class JaxPlannerPlot:
-    '''Supports plotting and visualization of a JAX policy in real time.'''
-
-    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
-                 show_violin: bool=True, show_action: bool=True) -> None:
-        '''Creates a new planner visualizer.
-
-        :param rddl: the planning model to optimize
-        :param horizon: the lookahead or planning horizon
-        :param show_violin: whether to show the distribution of batch losses
-        :param show_action: whether to show heatmaps of the action fluents
-        '''
-        num_plots = 1
-        if show_violin:
-            num_plots += 1
-        if show_action:
-            num_plots += len(rddl.action_fluents)
-        self._fig, axes = plt.subplots(num_plots)
-        if num_plots == 1:
-            axes = [axes]
-
-        # prepare the loss plot
-        self._loss_ax = axes[0]
-        self._loss_ax.autoscale(enable=True)
-        self._loss_ax.set_xlabel('training time')
-        self._loss_ax.set_ylabel('loss value')
-        self._loss_plot = self._loss_ax.plot(
-            [], [], linestyle=':', marker='o', markersize=2)[0]
-        self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
-
-        # prepare the violin plot
-        if show_violin:
-            self._hist_ax = axes[1]
-        else:
-            self._hist_ax = None
-
-        # prepare the action plots
-        if show_action:
-            self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
-                               for (idx, name) in enumerate(rddl.action_fluents)}
-            self._action_plots = {}
-            for name in rddl.action_fluents:
-                ax = self._action_ax[name]
-                if rddl.variable_ranges[name] == 'bool':
-                    vmin, vmax = 0.0, 1.0
-                else:
-                    vmin, vmax = None, None
-                action_dim = 1
-                for dim in rddl.object_counts(rddl.variable_params[name]):
-                    action_dim *= dim
-                action_plot = ax.pcolormesh(
-                    np.zeros((action_dim, horizon)),
-                    cmap='seismic', vmin=vmin, vmax=vmax)
-                ax.set_aspect('auto')
-                ax.set_xlabel('decision epoch')
-                ax.set_ylabel(name)
-                plt.colorbar(action_plot, ax=ax)
-                self._action_plots[name] = action_plot
-            self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
-                                 for (name, ax) in self._action_ax.items()}
-        else:
-            self._action_ax = None
-            self._action_plots = None
-            self._action_back = None
-
-        plt.tight_layout()
-        plt.show(block=False)
-
-    def redraw(self, xticks, losses, actions, returns) -> None:
-
-        # draw the loss curve
-        self._fig.canvas.restore_region(self._loss_back)
-        self._loss_plot.set_xdata(xticks)
-        self._loss_plot.set_ydata(losses)
-        self._loss_ax.set_xlim([0, len(xticks)])
-        self._loss_ax.set_ylim([np.min(losses), np.max(losses)])
-        self._loss_ax.draw_artist(self._loss_plot)
-        self._fig.canvas.blit(self._loss_ax.bbox)
-
-        # draw the violin plot
-        if self._hist_ax is not None:
-            self._hist_ax.clear()
-            self._hist_ax.set_xlabel('loss value')
-            self._hist_ax.set_ylabel('density')
-            self._hist_ax.violinplot(returns, vert=False, showmeans=True)
-
-        # draw the actions
-        if self._action_ax is not None:
-            for (name, values) in actions.items():
-                values = np.mean(values, axis=0, dtype=float)
-                values = np.reshape(values, newshape=(values.shape[0], -1)).T
-                self._fig.canvas.restore_region(self._action_back[name])
-                self._action_plots[name].set_array(values)
-                self._action_ax[name].draw_artist(self._action_plots[name])
-                self._fig.canvas.blit(self._action_ax[name].bbox)
-                self._action_plots[name].set_clim([np.min(values), np.max(values)])
-
-        self._fig.canvas.draw()
-        self._fig.canvas.flush_events()
-
-    def close(self) -> None:
-        plt.close(self._fig)
-        del self._loss_ax, self._hist_ax, self._action_ax, \
-            self._loss_plot, self._action_plots, self._fig, \
-            self._loss_back, self._action_back
-
-
 class JaxPlannerStatus(Enum):
     '''Represents the status of a policy update from the JAX planner,
     including whether the update resulted in nan gradient,
@@ -1198,14 +1101,15 @@ class JaxPlannerStatus(Enum):
     can be used to monitor and act based on the planner's progress.'''

     NORMAL = 0
-
-
-
-
-
+    STOPPING_RULE_REACHED = 1
+    NO_PROGRESS = 2
+    PRECONDITION_POSSIBLY_UNSATISFIED = 3
+    INVALID_GRADIENT = 4
+    TIME_BUDGET_REACHED = 5
+    ITER_BUDGET_REACHED = 6

-    def
-        return self.value >=
+    def is_terminal(self) -> bool:
+        return self.value == 1 or self.value >= 4


 class JaxPlannerStoppingRule:
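
The terminal-status rule introduced above (stop on STOPPING_RULE_REACHED or on any code at INVALID_GRADIENT and above) can be mirrored in a small self-contained sketch, useful as a reference when driving the planner from outside the package; the class below is an illustrative stand-in, not the package's own definition:

    # self-contained mirror of the status semantics from the hunk above
    from enum import Enum

    class PlannerStatus(Enum):          # illustrative stand-in for JaxPlannerStatus
        NORMAL = 0
        STOPPING_RULE_REACHED = 1
        NO_PROGRESS = 2
        PRECONDITION_POSSIBLY_UNSATISFIED = 3
        INVALID_GRADIENT = 4
        TIME_BUDGET_REACHED = 5
        ITER_BUDGET_REACHED = 6

        def is_terminal(self) -> bool:
            return self.value == 1 or self.value >= 4

    assert not PlannerStatus.NO_PROGRESS.is_terminal()      # planner keeps iterating
    assert PlannerStatus.TIME_BUDGET_REACHED.is_terminal()  # planner breaks out of the loop
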
@@ -1255,15 +1159,16 @@ class JaxBackpropPlanner:
                  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
                  optimizer_kwargs: Optional[Kwargs]=None,
                  clip_grad: Optional[float]=None,
-
-
-                 logic:
+                 line_search_kwargs: Optional[Kwargs]=None,
+                 noise_kwargs: Optional[Kwargs]=None,
+                 logic: Logic=FuzzyLogic(),
                  use_symlog_reward: bool=False,
                  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
                  utility_kwargs: Optional[Kwargs]=None,
                  cpfs_without_grad: Optional[Set[str]]=None,
                  compile_non_fluent_exact: bool=True,
-                 logger: Optional[Logger]=None
+                 logger: Optional[Logger]=None,
+                 dashboard_viz: Optional[Any]=None) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
         (plan) in the given RDDL. Some operations will be converted to their
         differentiable counterparts; the specific operations can be customized
@@ -1283,9 +1188,10 @@ class JaxBackpropPlanner:
         :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
             factory (e.g. which parameters are controllable externally)
         :param clip_grad: maximum magnitude of gradient updates
-        :param
-
-        :param
+        :param line_search_kwargs: parameters to pass to optional line search
+            method to scale learning rate
+        :param noise_kwargs: parameters of optional gradient noise
+        :param logic: a subclass of Logic for mapping exact mathematical
             operations to their differentiable counterparts
         :param use_symlog_reward: whether to use the symlog transform on the
             reward as a form of normalization
@@ -1300,6 +1206,8 @@ class JaxBackpropPlanner:
         :param compile_non_fluent_exact: whether non-fluent expressions
             are always compiled using exact JAX expressions
         :param logger: to log information about compilation to file
+        :param dashboard_viz: optional visualizer object from the environment
+            to pass to the dashboard to visualize the policy
         '''
         self.rddl = rddl
         self.plan = plan
@@ -1314,31 +1222,34 @@ class JaxBackpropPlanner:
             action_bounds = {}
         self._action_bounds = action_bounds
         self.use64bit = use64bit
-        self.
+        self.optimizer_name = optimizer
         if optimizer_kwargs is None:
             optimizer_kwargs = {'learning_rate': 0.1}
-        self.
+        self.optimizer_kwargs = optimizer_kwargs
         self.clip_grad = clip_grad
-        self.
-        self.
+        self.line_search_kwargs = line_search_kwargs
+        self.noise_kwargs = noise_kwargs

         # set optimizer
         try:
             optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
         except Exception as _:
             raise_warning(
-                'Failed to inject hyperparameters into optax optimizer, '
+                f'Failed to inject hyperparameters into optax optimizer {optimizer}, '
                 'rolling back to safer method: please note that modification of '
-                'optimizer hyperparameters will not work,
-
-
-
-
-
-
-
-
+                'optimizer hyperparameters will not work.', 'red')
+            optimizer = optimizer(**optimizer_kwargs)
+
+        # apply optimizer chain of transformations
+        pipeline = []
+        if clip_grad is not None:
+            pipeline.append(optax.clip(clip_grad))
+        if noise_kwargs is not None:
+            pipeline.append(optax.add_noise(**noise_kwargs))
+        pipeline.append(optimizer)
+        if line_search_kwargs is not None:
+            pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))
+        self.optimizer = optax.chain(*pipeline)

         # set utility
         if isinstance(utility, str):
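
The optimizer is now assembled as an optax transformation chain (optional gradient clipping and noise, the base optimizer, optional zoom line search) rather than a bespoke noise/line-search implementation. A minimal standalone sketch of the same pipeline shape follows; the hyper-parameter values are placeholders, and scale_by_zoom_linesearch requires a reasonably recent optax release:

    # standalone sketch of the optax pipeline shape used above (placeholder values)
    import optax

    clip_grad = 1.0
    noise_kwargs = {'eta': 0.01, 'gamma': 0.55, 'seed': 42}   # placeholder values
    line_search_kwargs = {'max_linesearch_steps': 15}         # placeholder values

    pipeline = []
    if clip_grad is not None:
        pipeline.append(optax.clip(clip_grad))
    if noise_kwargs is not None:
        pipeline.append(optax.add_noise(**noise_kwargs))
    pipeline.append(optax.inject_hyperparams(optax.rmsprop)(learning_rate=0.1))
    if line_search_kwargs is not None:
        pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))

    optimizer = optax.chain(*pipeline)   # one GradientTransformation, applied left to right
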
@@ -1371,11 +1282,12 @@ class JaxBackpropPlanner:
         self.cpfs_without_grad = cpfs_without_grad
         self.compile_non_fluent_exact = compile_non_fluent_exact
         self.logger = logger
+        self.dashboard_viz = dashboard_viz

         self._jax_compile_rddl()
         self._jax_compile_optimizer()

-    def
+    def summarize_system(self) -> str:
         try:
             jaxlib_version = jax._src.lib.version_str
         except Exception as _:
@@ -1394,36 +1306,55 @@ r"""
 \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
 """

-
-
-
-
-
-
-
-
+        return ('\n'
+                f'{LOGO}\n'
+                f'Version {__version__}\n'
+                f'Python {sys.version}\n'
+                f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
+                f'optax {optax.__version__}, haiku {hk.__version__}, '
+                f'numpy {np.__version__}\n'
+                f'devices: {devices_short}\n')
+
+    def __str__(self) -> str:
+        result = (f'objective hyper-parameters:\n'
+                  f' utility_fn ={self.utility.__name__}\n'
+                  f' utility args ={self.utility_kwargs}\n'
+                  f' use_symlog ={self.use_symlog_reward}\n'
+                  f' lookahead ={self.horizon}\n'
+                  f' user_action_bounds={self._action_bounds}\n'
+                  f' fuzzy logic type ={type(self.logic).__name__}\n'
+                  f' non_fluents exact ={self.compile_non_fluent_exact}\n'
+                  f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
+                  f'optimizer hyper-parameters:\n'
+                  f' use_64_bit ={self.use64bit}\n'
+                  f' optimizer ={self.optimizer_name}\n'
+                  f' optimizer args ={self.optimizer_kwargs}\n'
+                  f' clip_gradient ={self.clip_grad}\n'
+                  f' line_search_kwargs={self.line_search_kwargs}\n'
+                  f' noise_kwargs ={self.noise_kwargs}\n'
+                  f' batch_size_train ={self.batch_size_train}\n'
+                  f' batch_size_test ={self.batch_size_test}')
+        result += '\n' + str(self.plan)
+        result += '\n' + str(self.logic)
+
+        # print model relaxation information
+        if not self.compiled.model_params:
+            return result
+        result += '\n' + ('Some RDDL operations are non-differentiable '
+                          'and will be approximated as follows:' + '\n')
+        exprs_by_rddl_op, values_by_rddl_op = {}, {}
+        for info in self.compiled.model_parameter_info().values():
+            rddl_op = info['rddl_op']
+            exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
+            values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
+        for rddl_op in sorted(exprs_by_rddl_op.keys()):
+            result += (f' {rddl_op}:\n'
+                       f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
+                       f' init_values={values_by_rddl_op[rddl_op]}\n')
+        return result

     def summarize_hyperparameters(self) -> None:
-        print(
-              f' utility_fn ={self.utility.__name__}\n'
-              f' utility args ={self.utility_kwargs}\n'
-              f' use_symlog ={self.use_symlog_reward}\n'
-              f' lookahead ={self.horizon}\n'
-              f' user_action_bounds={self._action_bounds}\n'
-              f' fuzzy logic type ={type(self.logic).__name__}\n'
-              f' nonfluents exact ={self.compile_non_fluent_exact}\n'
-              f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
-              f'optimizer hyper-parameters:\n'
-              f' use_64_bit ={self.use64bit}\n'
-              f' optimizer ={self._optimizer_name.__name__}\n'
-              f' optimizer args ={self._optimizer_kwargs}\n'
-              f' clip_gradient ={self.clip_grad}\n'
-              f' noise_grad_eta ={self.noise_grad_eta}\n'
-              f' noise_grad_gamma ={self.noise_grad_gamma}\n'
-              f' batch_size_train ={self.batch_size_train}\n'
-              f' batch_size_test ={self.batch_size_test}')
-        self.plan.summarize_hyperparameters()
-        self.logic.summarize_hyperparameters()
+        print(self.__str__())

     # ===========================================================================
     # COMPILATION SUBROUTINES
@@ -1439,14 +1370,16 @@ r"""
             logger=self.logger,
             use64bit=self.use64bit,
             cpfs_without_grad=self.cpfs_without_grad,
-            compile_non_fluent_exact=self.compile_non_fluent_exact
+            compile_non_fluent_exact=self.compile_non_fluent_exact
+        )
         self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')

         # Jax compilation of the exact RDDL for testing
         self.test_compiled = JaxRDDLCompiler(
             rddl=rddl,
             logger=self.logger,
-            use64bit=self.use64bit
+            use64bit=self.use64bit
+        )
         self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')

     def _jax_compile_optimizer(self):
@@ -1462,13 +1395,15 @@ r"""
         train_rollouts = self.compiled.compile_rollouts(
             policy=self.plan.train_policy,
             n_steps=self.horizon,
-            n_batch=self.batch_size_train
+            n_batch=self.batch_size_train
+        )
         self.train_rollouts = train_rollouts

         test_rollouts = self.test_compiled.compile_rollouts(
             policy=self.plan.test_policy,
             n_steps=self.horizon,
-            n_batch=self.batch_size_test
+            n_batch=self.batch_size_test
+        )
         self.test_rollouts = jax.jit(test_rollouts)

         # initialization
@@ -1503,14 +1438,16 @@ r"""
         _jax_wrapped_returns = self._jax_return(use_symlog)

         # the loss is the average cumulative reward across all roll-outs
-        def _jax_wrapped_plan_loss(key, policy_params,
+        def _jax_wrapped_plan_loss(key, policy_params, policy_hyperparams,
                                    subs, model_params):
-            log = rollouts(
+            log, model_params = rollouts(
+                key, policy_params, policy_hyperparams, subs, model_params)
             rewards = log['reward']
             returns = _jax_wrapped_returns(rewards)
             utility = utility_fn(returns, **utility_kwargs)
             loss = -utility
-
+            aux = (log, model_params)
+            return loss, aux

         return _jax_wrapped_plan_loss

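
The loss now returns the auxiliary pair (log, model_params) alongside the scalar, which is the standard JAX has_aux pattern threaded through the call sites below. A minimal illustration of how such a loss is differentiated and unpacked (the names are generic toys, not the planner's own functions):

    # toy illustration of the has_aux pattern used by the planner loss
    import jax
    import jax.numpy as jnp

    def loss_fn(params, model_params):
        loss = jnp.sum(jnp.square(params))          # scalar objective
        log = {'reward': -loss}                     # auxiliary diagnostics
        return loss, (log, model_params)            # (scalar, aux) as above

    params = jnp.arange(3.0)
    model_params = {'weight': 10.0}

    # value_and_grad threads the aux output through untouched
    (loss_val, (log, model_params)), grad = jax.value_and_grad(
        loss_fn, argnums=0, has_aux=True)(params, model_params)
    print(loss_val, log['reward'], grad)
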
@@ -1518,8 +1455,9 @@ r"""
         init = self.plan.initializer
         optimizer = self.optimizer

-
-
+        # initialize both the policy and its optimizer
+        def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
+            policy_params = init(key, policy_hyperparams, subs)
             opt_state = optimizer.init(policy_params)
             return policy_params, opt_state, {}

@@ -1528,35 +1466,34 @@ r"""
     def _jax_update(self, loss):
         optimizer = self.optimizer
         projection = self.plan.projection
-
-        # add Gaussian gradient noise per Neelakantan et al., 2016.
-        def _jax_wrapped_gaussian_param_noise(key, grads, sigma):
-            treedef = jax.tree_util.tree_structure(grads)
-            keys_flat = random.split(key, num=treedef.num_leaves)
-            keys_tree = jax.tree_util.tree_unflatten(treedef, keys_flat)
-            new_grads = jax.tree_map(
-                lambda g, k: g + sigma * random.normal(
-                    key=k, shape=g.shape, dtype=g.dtype),
-                grads,
-                keys_tree
-            )
-            return new_grads
+        use_ls = self.line_search_kwargs is not None

         # calculate the plan gradient w.r.t. return loss and update optimizer
         # also perform a projection step to satisfy constraints on actions
-        def
+        def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
+                                      subs, model_params):
+            return loss(key, policy_params, policy_hyperparams, subs, model_params)[0]
+
+        def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams,
                                      subs, model_params, opt_state, opt_aux):
             grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
-            (loss_val, log), grad = grad_fn(
-                key, policy_params,
-
-
-
+            (loss_val, (log, model_params)), grad = grad_fn(
+                key, policy_params, policy_hyperparams, subs, model_params)
+            if use_ls:
+                updates, opt_state = optimizer.update(
+                    grad, opt_state, params=policy_params,
+                    value=loss_val, grad=grad, value_fn=_jax_wrapped_loss_swapped,
+                    key=key, policy_hyperparams=policy_hyperparams, subs=subs,
+                    model_params=model_params)
+            else:
+                updates, opt_state = optimizer.update(
+                    grad, opt_state, params=policy_params)
             policy_params = optax.apply_updates(policy_params, updates)
-            policy_params, converged = projection(policy_params,
+            policy_params, converged = projection(policy_params, policy_hyperparams)
             log['grad'] = grad
             log['updates'] = updates
-            return policy_params, converged, opt_state, opt_aux,
+            return policy_params, converged, opt_state, opt_aux, \
+                loss_val, log, model_params

         return jax.jit(_jax_wrapped_plan_update)

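
When line search is enabled, optax's zoom line search needs the current value, gradient, and a callable that re-evaluates the loss at trial points, which is why the update call above passes value, grad, and value_fn. A minimal sketch of that update on a toy objective, following the documented optax pattern rather than the planner's exact wiring (requires a recent optax release):

    # toy sketch of the optax zoom-linesearch update call
    import jax
    import jax.numpy as jnp
    import optax

    def loss_fn(params):
        return jnp.sum(jnp.square(params - 3.0))

    optimizer = optax.chain(
        optax.sgd(learning_rate=1.0),
        optax.scale_by_zoom_linesearch(max_linesearch_steps=15))

    params = jnp.zeros(4)
    opt_state = optimizer.init(params)

    value, grad = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(
        grad, opt_state, params,
        value=value, grad=grad, value_fn=loss_fn)   # extra args consumed by the line search
    params = optax.apply_updates(params, updates)
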
@@ -1583,7 +1520,6 @@ r"""
         for (state, next_state) in rddl.next_state.items():
             init_train[next_state] = init_train[state]
             init_test[next_state] = init_test[state]
-
         return init_train, init_test

     def as_optimization_problem(
@@ -1637,38 +1573,40 @@ r"""
         loss_fn = self._jax_loss(self.train_rollouts)

         @jax.jit
-        def _loss_with_key(key, params_1d):
+        def _loss_with_key(key, params_1d, model_params):
             policy_params = unravel_fn(params_1d)
-            loss_val, _ = loss_fn(
-
-            return loss_val
+            loss_val, (_, model_params) = loss_fn(
+                key, policy_params, policy_hyperparams, train_subs, model_params)
+            return loss_val, model_params

         @jax.jit
-        def _grad_with_key(key, params_1d):
+        def _grad_with_key(key, params_1d, model_params):
             policy_params = unravel_fn(params_1d)
             grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
-            grad_val, _ = grad_fn(
-
-
-            return
+            grad_val, (_, model_params) = grad_fn(
+                key, policy_params, policy_hyperparams, train_subs, model_params)
+            grad_val = jax.flatten_util.ravel_pytree(grad_val)[0]
+            return grad_val, model_params

         def _loss_function(params_1d):
             nonlocal key
+            nonlocal model_params
             if loss_function_updates_key:
                 key, subkey = random.split(key)
             else:
                 subkey = key
-            loss_val = _loss_with_key(subkey, params_1d)
+            loss_val, model_params = _loss_with_key(subkey, params_1d, model_params)
             loss_val = float(loss_val)
             return loss_val

         def _grad_function(params_1d):
             nonlocal key
+            nonlocal model_params
             if grad_function_updates_key:
                 key, subkey = random.split(key)
             else:
                 subkey = key
-            grad = _grad_with_key(subkey, params_1d)
+            grad, model_params = _grad_with_key(subkey, params_1d, model_params)
             grad = np.asarray(grad)
             return grad

@@ -1683,9 +1621,9 @@ r"""

         :param key: JAX PRNG key (derived from clock if not provided)
         :param epochs: the maximum number of steps of gradient descent
-        :param train_seconds: total time allocated for gradient descent
-        :param
-        :param
+        :param train_seconds: total time allocated for gradient descent
+        :param dashboard: dashboard to display training results
+        :param dashboard_id: experiment id for the dashboard
         :param model_params: optional model-parameters to override default
         :param policy_hyperparams: hyper-parameters for the policy/plan, such as
             weights for sigmoid wrapping boolean actions
@@ -1721,8 +1659,8 @@ r"""
     def optimize_generator(self, key: Optional[random.PRNGKey]=None,
                            epochs: int=999999,
                            train_seconds: float=120.,
-
-
+                           dashboard: Optional[Any]=None,
+                           dashboard_id: Optional[str]=None,
                            model_params: Optional[Dict[str, Any]]=None,
                            policy_hyperparams: Optional[Dict[str, Any]]=None,
                            subs: Optional[Dict[str, Any]]=None,
@@ -1738,9 +1676,9 @@ r"""

         :param key: JAX PRNG key (derived from clock if not provided)
         :param epochs: the maximum number of steps of gradient descent
-        :param train_seconds: total time allocated for gradient descent
-        :param
-        :param
+        :param train_seconds: total time allocated for gradient descent
+        :param dashboard: dashboard to display training results
+        :param dashboard_id: experiment id for the dashboard
         :param model_params: optional model-parameters to override default
         :param policy_hyperparams: hyper-parameters for the policy/plan, such as
             weights for sigmoid wrapping boolean actions
@@ -1760,9 +1698,14 @@ r"""
         start_time = time.time()
         elapsed_outside_loop = 0

+        # ======================================================================
+        # INITIALIZATION OF HYPER-PARAMETERS
+        # ======================================================================
+
         # if PRNG key is not provided
         if key is None:
             key = random.PRNGKey(round(time.time() * 1000))
+        dash_key = key[1].item()

         # if policy_hyperparams is not provided
         if policy_hyperparams is None:
@@ -1789,7 +1732,7 @@ r"""

         # print summary of parameters:
         if print_summary:
-            self.
+            print(self.summarize_system())
             self.summarize_hyperparameters()
             print(f'optimize() call hyper-parameters:\n'
                   f' PRNG key ={key}\n'
@@ -1800,16 +1743,16 @@ r"""
                   f' override_subs_dict ={subs is not None}\n'
                   f' provide_param_guess={guess is not None}\n'
                   f' test_rolling_window={test_rolling_window}\n'
-                  f'
-                  f'
+                  f' dashboard ={dashboard is not None}\n'
+                  f' dashboard_id ={dashboard_id}\n'
                   f' print_summary ={print_summary}\n'
                   f' print_progress ={print_progress}\n'
                   f' stopping_rule ={stopping_rule}\n')
-
-
-
-
-
+
+        # ======================================================================
+        # INITIALIZATION OF STATE AND POLICY
+        # ======================================================================
+
         # compute a batched version of the initial values
         if subs is None:
             subs = self.test_compiled.init_values
@@ -1841,6 +1784,10 @@ r"""
             policy_params = guess
             opt_state = self.optimizer.init(policy_params)
         opt_aux = {}
+
+        # ======================================================================
+        # INITIALIZATION OF RUNNING STATISTICS
+        # ======================================================================

         # initialize running statistics
         best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
@@ -1854,35 +1801,39 @@ r"""
         if stopping_rule is not None:
             stopping_rule.reset()

-        # initialize
-        if
-
-
-
-
-
-
+        # initialize dash board
+        if dashboard is not None:
+            dashboard_id = dashboard.register_experiment(
+                dashboard_id, dashboard.get_planner_info(self),
+                key=dash_key, viz=self.dashboard_viz)
+
+        # ======================================================================
+        # MAIN TRAINING LOOP BEGINS
+        # ======================================================================

-        # training loop
         iters = range(epochs)
         if print_progress:
             iters = tqdm(iters, total=100, position=tqdm_position)
         position_str = '' if tqdm_position is None else f'[{tqdm_position}]'

         for it in iters:
-            status = JaxPlannerStatus.NORMAL

-            #
-
-
-
+            # ==================================================================
+            # NEXT GRADIENT DESCENT STEP
+            # ==================================================================
+
+            status = JaxPlannerStatus.NORMAL

             # update the parameters of the plan
             key, subkey = random.split(key)
-            policy_params, converged, opt_state, opt_aux,
-
+            (policy_params, converged, opt_state, opt_aux,
+             train_loss, train_log, model_params) = \
                 self.update(subkey, policy_params, policy_hyperparams,
                             train_subs, model_params, opt_state, opt_aux)
+
+            # ==================================================================
+            # STATUS CHECKS AND LOGGING
+            # ==================================================================

             # no progress
             grad_norm_zero, _ = jax.tree_util.tree_flatten(
@@ -1901,12 +1852,11 @@ r"""
             # numerical error
             if not np.isfinite(train_loss):
                 raise_warning(
-                    f'
-                    'red')
+                    f'JAX planner aborted due to invalid loss {train_loss}.', 'red')
                 status = JaxPlannerStatus.INVALID_GRADIENT

             # evaluate test losses and record best plan so far
-            test_loss, log = self.test_loss(
+            test_loss, (log, model_params_test) = self.test_loss(
                 subkey, policy_params, policy_hyperparams,
                 test_subs, model_params_test)
             test_loss = rolling_test_loss.update(test_loss)
@@ -1915,32 +1865,15 @@ r"""
                     policy_params, test_loss, train_log['grad']
                 last_iter_improve = it

-            # save the plan figure
-            if plot is not None and it % plot_step == 0:
-                xticks.append(it // plot_step)
-                loss_values.append(test_loss.item())
-                action_values = {name: values
-                                 for (name, values) in log['fluents'].items()
-                                 if name in self.rddl.action_fluents}
-                returns = -np.sum(np.asarray(log['reward']), axis=1)
-                plot.redraw(xticks, loss_values, action_values, returns)
-
-            # if the progress bar is used
-            elapsed = time.time() - start_time - elapsed_outside_loop
-            if print_progress:
-                iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
-                iters.set_description(
-                    f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
-                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
-                    f'{status.value} status')
-
             # reached computation budget
+            elapsed = time.time() - start_time - elapsed_outside_loop
             if elapsed >= train_seconds:
                 status = JaxPlannerStatus.TIME_BUDGET_REACHED
             if it >= epochs - 1:
                 status = JaxPlannerStatus.ITER_BUDGET_REACHED

-            #
+            # build a callback
+            progress_percent = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
             callback = {
                 'status': status,
                 'iteration': it,
@@ -1952,29 +1885,48 @@ r"""
                 'last_iteration_improved': last_iter_improve,
                 'grad': train_log['grad'],
                 'best_grad': best_grad,
-                'noise_sigma': noise_sigma,
                 'updates': train_log['updates'],
                 'elapsed_time': elapsed,
                 'key': key,
+                'model_params': model_params,
+                'progress': progress_percent,
+                'train_log': train_log,
                 **log
             }
+
+            # stopping condition reached
+            if stopping_rule is not None and stopping_rule.monitor(callback):
+                callback['status'] = status = JaxPlannerStatus.STOPPING_RULE_REACHED
+
+            # if the progress bar is used
+            if print_progress:
+                iters.n = progress_percent
+                iters.set_description(
+                    f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
+                    f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
+                    f'{status.value} status'
+                )
+
+            # dash-board
+            if dashboard is not None:
+                dashboard.update_experiment(dashboard_id, callback)
+
+            # yield the callback
             start_time_outside = time.time()
             yield callback
             elapsed_outside_loop += (time.time() - start_time_outside)

             # abortion check
-            if status.
-                break
-
-
-
-
-
+            if status.is_terminal():
+                break
+
+        # ======================================================================
+        # POST-PROCESSING AND CLEANUP
+        # ======================================================================
+
         # release resources
         if print_progress:
             iters.close()
-        if plot is not None:
-            plot.close()

         # validate the test return
         if log:
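
Each iteration now yields a richer callback (including 'progress', 'model_params', and 'train_log') and pushes the same dictionary to the dashboard when one is registered. A hedged usage sketch of consuming that generator from user code follows; it assumes a planner has already been constructed from an RDDL environment as in the bundled run_plan.py example, and only the callback keys taken from the hunk above ('iteration', 'progress', 'status') are relied on:

    # hedged usage sketch: `planner` is any already-constructed JaxBackpropPlanner
    from jax import random

    def monitor_training(planner, seed: int = 42, epochs: int = 2000, seconds: float = 60.0):
        key = random.PRNGKey(seed)
        last = None
        for callback in planner.optimize_generator(key=key, epochs=epochs,
                                                    train_seconds=seconds):
            if callback['iteration'] % 100 == 0:
                print(f"iter={callback['iteration']:6d} "
                      f"progress={callback['progress']}% status={callback['status']}")
            last = callback
        # the generator stops itself once a terminal status is reached
        return last
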
@@ -2098,101 +2050,6 @@ r"""
         return actions


-class JaxLineSearchPlanner(JaxBackpropPlanner):
-    '''A class for optimizing an action sequence in the given RDDL MDP using
-    linear search gradient descent, with the Armijo condition.'''
-
-    def __init__(self, *args,
-                 decay: float=0.8,
-                 c: float=0.1,
-                 step_max: float=1.0,
-                 step_min: float=1e-6,
-                 **kwargs) -> None:
-        '''Creates a new gradient-based algorithm for optimizing action sequences
-        (plan) in the given RDDL using line search. All arguments are the
-        same as in the parent class, except:
-
-        :param decay: reduction factor of learning rate per line search iteration
-        :param c: positive coefficient in Armijo condition, should be in (0, 1)
-        :param step_max: initial learning rate for line search
-        :param step_min: minimum possible learning rate (line search halts)
-        '''
-        self.decay = decay
-        self.c = c
-        self.step_max = step_max
-        self.step_min = step_min
-        if 'clip_grad' in kwargs:
-            raise_warning('clip_grad parameter conflicts with '
-                          'line search planner and will be ignored.', 'red')
-            del kwargs['clip_grad']
-        super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
-
-    def summarize_hyperparameters(self) -> None:
-        super(JaxLineSearchPlanner, self).summarize_hyperparameters()
-        print(f'linesearch hyper-parameters:\n'
-              f' decay ={self.decay}\n'
-              f' c ={self.c}\n'
-              f' lr_range=({self.step_min}, {self.step_max})')
-
-    def _jax_update(self, loss):
-        optimizer = self.optimizer
-        projection = self.plan.projection
-        decay, c, lrmax, lrmin = self.decay, self.c, self.step_max, self.step_min
-
-        # initialize the line search routine
-        @jax.jit
-        def _jax_wrapped_line_search_init(key, policy_params, hyperparams,
-                                          subs, model_params):
-            (f, log), grad = jax.value_and_grad(loss, argnums=1, has_aux=True)(
-                key, policy_params, hyperparams, subs, model_params)
-            gnorm2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), grad)
-            gnorm2 = jax.tree_util.tree_reduce(jnp.add, gnorm2)
-            log['grad'] = grad
-            return f, grad, gnorm2, log
-
-        # compute the next trial solution
-        @jax.jit
-        def _jax_wrapped_line_search_trial(
-                step, grad, key, params, hparams, subs, mparams, state):
-            state.hyperparams['learning_rate'] = step
-            updates, new_state = optimizer.update(grad, state)
-            new_params = optax.apply_updates(params, updates)
-            new_params, _ = projection(new_params, hparams)
-            f_step, _ = loss(key, new_params, hparams, subs, mparams)
-            return f_step, new_params, new_state
-
-        # main iteration of line search
-        def _jax_wrapped_plan_update(key, policy_params, hyperparams,
-                                     subs, model_params, opt_state, opt_aux):
-
-            # initialize the line search
-            f, grad, gnorm2, log = _jax_wrapped_line_search_init(
-                key, policy_params, hyperparams, subs, model_params)
-
-            # continue to reduce the learning rate until the Armijo condition holds
-            trials = 0
-            step = lrmax / decay
-            f_step = np.inf
-            best_f, best_step, best_params, best_state = np.inf, None, None, None
-            while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
-                or not trials:
-                trials += 1
-                step *= decay
-                f_step, new_params, new_state = _jax_wrapped_line_search_trial(
-                    step, grad, key, policy_params, hyperparams, subs,
-                    model_params, opt_state)
-                if f_step < best_f:
-                    best_f, best_step, best_params, best_state = \
-                        f_step, step, new_params, new_state
-
-            log['updates'] = None
-            log['line_search_iters'] = trials
-            log['learning_rate'] = best_step
-            opt_aux['best_step'] = best_step
-            return best_params, True, best_state, opt_aux, best_f, log
-
-        return _jax_wrapped_plan_update
-
 # ***********************************************************************
 # ALL VERSIONS OF RISK FUNCTIONS
 #
@@ -2224,6 +2081,7 @@ def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
         returns <= jnp.percentile(returns, q=100 * alpha))
     return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)

+
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
 #
@@ -2268,8 +2126,10 @@ class JaxOfflineController(BaseAgent):
         self.params_given = params is not None

         self.step = 0
+        self.callback = None
         if not self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.callback = callback
             params = callback['best_params']
         self.params = params

@@ -2284,6 +2144,7 @@ class JaxOfflineController(BaseAgent):
         self.step = 0
         if self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            self.callback = callback
             self.params = callback['best_params']


@@ -2326,7 +2187,9 @@ class JaxOnlineController(BaseAgent):
             key=self.key,
             guess=self.guess,
             subs=state,
-            **self.train_kwargs
+            **self.train_kwargs
+        )
+        self.callback = callback
         params = callback['best_params']
         self.key, subkey = random.split(self.key)
         actions = planner.get_action(
@@ -2337,4 +2200,5 @@ class JaxOnlineController(BaseAgent):

     def reset(self) -> None:
         self.guess = None
+        self.callback = None
