pyRDDLGym-jax 0.1-py3-none-any.whl → 0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/core/compiler.py +445 -221
- pyRDDLGym_jax/core/logic.py +129 -62
- pyRDDLGym_jax/core/planner.py +699 -332
- pyRDDLGym_jax/core/simulator.py +5 -7
- pyRDDLGym_jax/core/tuning.py +23 -12
- pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
- pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +2 -2
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +18 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_gradient.py +1 -1
- pyRDDLGym_jax/examples/run_gym.py +1 -2
- pyRDDLGym_jax/examples/run_plan.py +7 -0
- pyRDDLGym_jax/examples/run_tune.py +6 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/METADATA +1 -1
- pyRDDLGym_jax-0.2.dist-info/RECORD +46 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -1,6 +1,9 @@
+__version__ = '0.2'
+
 from ast import literal_eval
 from collections import deque
 import configparser
+from enum import Enum
 import haiku as hk
 import jax
 import jax.numpy as jnp
@@ -13,11 +16,28 @@ import sys
 import termcolor
 import time
 from tqdm import tqdm
-from typing import Callable, Dict, Generator, Set, Sequence, Tuple
+from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
+
+Activation = Callable[[jnp.ndarray], jnp.ndarray]
+Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
+Kwargs = Dict[str, Any]
+Pytree = Any
+
+from pyRDDLGym.core.debug.exception import raise_warning

+# try to import matplotlib, if failed then skip plotting
+try:
+    import matplotlib
+    import matplotlib.pyplot as plt
+    matplotlib.use('TkAgg')
+except Exception:
+    raise_warning('matplotlib is not installed, '
+                  'plotting functionality is disabled.', 'red')
+    plt = None
+
 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
+from pyRDDLGym.core.debug.logger import Logger
 from pyRDDLGym.core.debug.exception import (
-    raise_warning,
     RDDLNotImplementedError,
     RDDLUndefinedVariableError,
     RDDLTypeError
@@ -37,6 +57,7 @@ from pyRDDLGym_jax.core.logic import FuzzyLogic
 # - instantiate planner
 #
 # ***********************************************************************
+
 def _parse_config_file(path: str):
     if not os.path.isfile(path):
         raise FileNotFoundError(f'File {path} does not exist.')
@@ -59,51 +80,94 @@ def _parse_config_string(value: str):
     return config, args


+def _getattr_any(packages, item):
+    for package in packages:
+        loaded = getattr(package, item, None)
+        if loaded is not None:
+            return loaded
+    return None
+
+
 def _load_config(config, args):
     model_args = {k: args[k] for (k, _) in config.items('Model')}
     planner_args = {k: args[k] for (k, _) in config.items('Optimizer')}
     train_args = {k: args[k] for (k, _) in config.items('Training')}

-    train_args['key'] = jax.random.PRNGKey(train_args['key'])
-
     # read the model settings
-
-
-
-
+    logic_name = model_args.get('logic', 'FuzzyLogic')
+    logic_kwargs = model_args.get('logic_kwargs', {})
+    tnorm_name = model_args.get('tnorm', 'ProductTNorm')
+    tnorm_kwargs = model_args.get('tnorm_kwargs', {})
+    comp_name = model_args.get('complement', 'StandardComplement')
+    comp_kwargs = model_args.get('complement_kwargs', {})
+    compare_name = model_args.get('comparison', 'SigmoidComparison')
+    compare_kwargs = model_args.get('comparison_kwargs', {})
     logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
-
+    logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
+    logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)

-    # read the
+    # read the policy settings
     plan_method = planner_args.pop('method')
     plan_kwargs = planner_args.pop('method_kwargs', {})

-
-
-
-
-
-
-
-
-
-
-
-
+    # policy initialization
+    plan_initializer = plan_kwargs.get('initializer', None)
+    if plan_initializer is not None:
+        initializer = _getattr_any(packages=[initializers], item=plan_initializer)
+        if initializer is None:
+            raise_warning(
+                f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
+            del plan_kwargs['initializer']
+        else:
+            init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
+            try:
+                plan_kwargs['initializer'] = initializer(**init_kwargs)
+            except Exception as _:
+                raise_warning(
+                    f'Ignoring invalid initializer_kwargs <{init_kwargs}>.', 'red')
+                plan_kwargs['initializer'] = initializer

+    # policy activation
+    plan_activation = plan_kwargs.get('activation', None)
+    if plan_activation is not None:
+        activation = _getattr_any(packages=[jax.nn, jax.numpy], item=plan_activation)
+        if activation is None:
+            raise_warning(
+                f'Ignoring invalid activation <{plan_activation}>.', 'red')
+            del plan_kwargs['activation']
+        else:
+            plan_kwargs['activation'] = activation
+
+    # read the planner settings
+    planner_args['logic'] = getattr(logic, logic_name)(**logic_kwargs)
     planner_args['plan'] = getattr(sys.modules[__name__], plan_method)(**plan_kwargs)
-
+
+    # planner optimizer
+    planner_optimizer = planner_args.get('optimizer', None)
+    if planner_optimizer is not None:
+        optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
+        if optimizer is None:
+            raise_warning(
+                f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
+            del planner_args['optimizer']
+        else:
+            planner_args['optimizer'] = optimizer
+
+    # read the optimize call settings
+    planner_key = train_args.get('key', None)
+    if planner_key is not None:
+        train_args['key'] = random.PRNGKey(planner_key)

     return planner_args, plan_kwargs, train_args


-def load_config(path: str) -> Tuple[
+def load_config(path: str) -> Tuple[Kwargs, ...]:
     '''Loads a config file at the specified file path.'''
     config, args = _parse_config_file(path)
     return _load_config(config, args)


-def load_config_from_string(value: str) -> Tuple[
+def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
     '''Loads config file contents specified explicitly as a string value.'''
     config, args = _parse_config_string(value)
     return _load_config(config, args)
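The reworked config loader above resolves the [Model], [Optimizer] and [Training] sections into keyword arguments for the planner. A minimal usage sketch, not part of this diff: the environment construction follows the pyRDDLGym 2.x API as far as I can tell, and the domain, instance and config filename are placeholders.

    # illustrative only: placeholder domain/instance/config names
    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, load_config

    # parse the [Model], [Optimizer] and [Training] sections of a planner config
    planner_args, plan_kwargs, train_args = load_config('Quadcopter_drp.cfg')

    env = pyRDDLGym.make('Quadcopter', '0')          # placeholder names
    planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
    callback = planner.optimize(**train_args)        # returns the last training callback
    print(callback['test_return'])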
@@ -115,6 +179,20 @@ def load_config_from_string(value: str) -> Tuple[Dict[str, object], ...]:
 # - replace discrete ops in state dynamics/reward with differentiable ones
 #
 # ***********************************************************************
+
+def _function_discrete_approx_named(logic):
+    jax_discrete, jax_param = logic.discrete()
+
+    def _jax_wrapped_discrete_calc_approx(key, prob, params):
+        sample = jax_discrete(key, prob, params)
+        out_of_bounds = jnp.logical_not(jnp.logical_and(
+            jnp.all(prob >= 0),
+            jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
+        return sample, out_of_bounds
+
+    return _jax_wrapped_discrete_calc_approx, jax_param
+
+
 class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
     '''Compiles a RDDL AST representation to an equivalent JAX representation.
     Unlike its parent class, this class treats all fluents as real-valued, and
@@ -124,7 +202,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):

     def __init__(self, *args,
                  logic: FuzzyLogic=FuzzyLogic(),
-                 cpfs_without_grad: Set=
+                 cpfs_without_grad: Optional[Set[str]]=None,
                  **kwargs) -> None:
         '''Creates a new RDDL to Jax compiler, where operations that are not
         differentiable are converted to approximate forms that have defined
@@ -140,27 +218,30 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         '''
         super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
         self.logic = logic
+        self.logic.set_use64bit(self.use64bit)
+        if cpfs_without_grad is None:
+            cpfs_without_grad = set()
         self.cpfs_without_grad = cpfs_without_grad

         # actions and CPFs must be continuous
-        raise_warning(
+        raise_warning('Initial values of pvariables will be cast to real.')
         for (var, values) in self.init_values.items():
             self.init_values[var] = np.asarray(values, dtype=self.REAL)

         # overwrite basic operations with fuzzy ones
         self.RELATIONAL_OPS = {
-            '>=': logic.
-            '<=': logic.
+            '>=': logic.greater_equal(),
+            '<=': logic.less_equal(),
             '<': logic.less(),
             '>': logic.greater(),
             '==': logic.equal(),
-            '~=': logic.
+            '~=': logic.not_equal()
         }
-        self.LOGICAL_NOT = logic.
+        self.LOGICAL_NOT = logic.logical_not()
         self.LOGICAL_OPS = {
-            '^': logic.
-            '&': logic.
-            '|': logic.
+            '^': logic.logical_and(),
+            '&': logic.logical_and(),
+            '|': logic.logical_or(),
             '~': logic.xor(),
             '=>': logic.implies(),
             '<=>': logic.equiv()
@@ -169,15 +250,19 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         self.AGGREGATION_OPS['exists'] = logic.exists()
         self.AGGREGATION_OPS['argmin'] = logic.argmin()
         self.AGGREGATION_OPS['argmax'] = logic.argmax()
-        self.KNOWN_UNARY['sgn'] = logic.
+        self.KNOWN_UNARY['sgn'] = logic.sgn()
         self.KNOWN_UNARY['floor'] = logic.floor()
         self.KNOWN_UNARY['ceil'] = logic.ceil()
         self.KNOWN_UNARY['round'] = logic.round()
         self.KNOWN_UNARY['sqrt'] = logic.sqrt()
-        self.KNOWN_BINARY['div'] = logic.
+        self.KNOWN_BINARY['div'] = logic.div()
         self.KNOWN_BINARY['mod'] = logic.mod()
         self.KNOWN_BINARY['fmod'] = logic.mod()
-
+        self.IF_HELPER = logic.control_if()
+        self.SWITCH_HELPER = logic.control_switch()
+        self.BERNOULLI_HELPER = logic.bernoulli()
+        self.DISCRETE_HELPER = _function_discrete_approx_named(logic)
+
     def _jax_stop_grad(self, jax_expr):

         def _jax_wrapped_stop_grad(x, params, key):
@@ -199,35 +284,13 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
             jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
         return jax_cpfs

-    def _jax_if_helper(self):
-        return self.logic.If()
-
-    def _jax_switch_helper(self):
-        return self.logic.Switch()
-
     def _jax_kron(self, expr, info):
         if self.logic.verbose:
             raise_warning('KronDelta will be ignored.')
-
         arg, = expr.args
         arg = self._jax(arg, info)
         return arg

-    def _jax_bernoulli_helper(self):
-        return self.logic.bernoulli()
-
-    def _jax_discrete_helper(self):
-        jax_discrete, jax_param = self.logic.discrete()
-
-        def _jax_wrapped_discrete_calc_approx(key, prob, params):
-            sample = jax_discrete(key, prob, params)
-            out_of_bounds = jnp.logical_not(jnp.logical_and(
-                jnp.all(prob >= 0),
-                jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
-            return sample, out_of_bounds
-
-        return _jax_wrapped_discrete_calc_approx, jax_param
-

 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
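The compiler now looks up every relaxed operation (greater_equal, logical_and, control_if, bernoulli, discrete, ...) on the FuzzyLogic object at construction time instead of through per-node helper methods. For intuition, a sigmoid-style relaxation of a comparison looks roughly like the sketch below; this is a generic illustration in the spirit of the SigmoidComparison option referenced in the config loader, not the package's FuzzyLogic code, and the weight parameter is an assumed sharpness knob.

    # generic illustration, not the package's implementation
    import jax
    import jax.numpy as jnp

    def relaxed_greater_equal(x, y, weight=10.0):
        # smooth surrogate for the indicator x >= y; approaches a hard step
        # function as the assumed sharpness parameter 'weight' grows
        return jax.nn.sigmoid(weight * (x - y))

    value = relaxed_greater_equal(jnp.array(2.0), jnp.array(1.0))   # close to 1.0
    slope = jax.grad(relaxed_greater_equal)(2.0, 1.0)               # well-defined gradient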
@@ -236,6 +299,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
 # - deep reactive policy
 #
 # ***********************************************************************
+
 class JaxPlan:
     '''Base class for all JAX policy representations.'''

@@ -245,15 +309,15 @@ class JaxPlan:
         self._test_policy = None
         self._projection = None

-    def summarize_hyperparameters(self):
+    def summarize_hyperparameters(self) -> None:
         pass

     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds:
+                _bounds: Bounds,
                 horizon: int) -> None:
         raise NotImplementedError

-    def guess_next_epoch(self, params:
+    def guess_next_epoch(self, params: Pytree) -> Pytree:
         raise NotImplementedError

     @property
@@ -289,7 +353,8 @@ class JaxPlan:
         self._projection = value

     def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
-                               user_bounds:
+                               user_bounds: Bounds,
+                               horizon: int):
         shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
         for (name, prange) in compiled.rddl.variable_ranges.items():
             if compiled.rddl.variable_types[name] != 'action-fluent':
@@ -309,8 +374,8 @@ class JaxPlan:
             else:
                 lower, upper = compiled.constraints.bounds[name]
                 lower, upper = user_bounds.get(name, (lower, upper))
-                lower = np.asarray(lower, dtype=
-                upper = np.asarray(upper, dtype=
+                lower = np.asarray(lower, dtype=compiled.REAL)
+                upper = np.asarray(upper, dtype=compiled.REAL)
                 lower_finite = np.isfinite(lower)
                 upper_finite = np.isfinite(upper)
                 bounds_safe[name] = (np.where(lower_finite, lower, 0.0),
@@ -336,7 +401,7 @@ class JaxStraightLinePlan(JaxPlan):

     def __init__(self, initializer: initializers.Initializer=initializers.normal(),
                  wrap_sigmoid: bool=True,
-                 min_action_prob: float=1e-
+                 min_action_prob: float=1e-6,
                  wrap_non_bool: bool=False,
                  wrap_softmax: bool=False,
                  use_new_projection: bool=False,
@@ -371,7 +436,7 @@ class JaxStraightLinePlan(JaxPlan):
         self._use_new_projection = use_new_projection
         self._max_constraint_iter = max_constraint_iter

-    def summarize_hyperparameters(self):
+    def summarize_hyperparameters(self) -> None:
         print(f'policy hyper-parameters:\n'
               f' initializer ={type(self._initializer_base).__name__}\n'
               f'constraint-sat strategy (simple):\n'
@@ -383,7 +448,8 @@ class JaxStraightLinePlan(JaxPlan):
               f' use_new_projection ={self._use_new_projection}')

     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds:
+                _bounds: Bounds,
+                horizon: int) -> None:
         rddl = compiled.rddl

         # calculate the correct action box bounds
@@ -423,7 +489,7 @@ class JaxStraightLinePlan(JaxPlan):
         def _jax_bool_action_to_param(var, action, hyperparams):
             if wrap_sigmoid:
                 weight = hyperparams[var]
-                return (-1.0 / weight) * jnp.
+                return (-1.0 / weight) * jnp.log(1.0 / action - 1.0)
             else:
                 return action

@@ -506,7 +572,7 @@ class JaxStraightLinePlan(JaxPlan):
         def _jax_wrapped_slp_predict_test(key, params, hyperparams, step, subs):
             actions = {}
             for (var, param) in params.items():
-                action = jnp.asarray(param[step, ...])
+                action = jnp.asarray(param[step, ...], dtype=compiled.REAL)
                 if var == bool_key:
                     output = jax.nn.softmax(action)
                     bool_actions = _jax_unstack_bool_from_softmax(output)
@@ -688,7 +754,7 @@ class JaxStraightLinePlan(JaxPlan):
         # "progress" the plan one step forward and set last action to second-last
         return jnp.append(param[1:, ...], param[-1:, ...], axis=0)

-    def guess_next_epoch(self, params:
+    def guess_next_epoch(self, params: Pytree) -> Pytree:
         next_fn = JaxStraightLinePlan._guess_next_epoch
         return jax.tree_map(next_fn, params)

@@ -696,10 +762,12 @@ class JaxStraightLinePlan(JaxPlan):
 class JaxDeepReactivePolicy(JaxPlan):
     '''A deep reactive policy network implementation in JAX.'''

-    def __init__(self, topology: Sequence[int],
-                 activation:
+    def __init__(self, topology: Optional[Sequence[int]]=None,
+                 activation: Activation=jnp.tanh,
                  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
-                 normalize: bool=True
+                 normalize: bool=True,
+                 normalizer_kwargs: Optional[Kwargs]=None,
+                 wrap_non_bool: bool=False) -> None:
         '''Creates a new deep reactive policy in JAX.

         :param neurons: sequence consisting of the number of neurons in each
@@ -707,23 +775,39 @@ class JaxDeepReactivePolicy(JaxPlan):
         :param activation: function to apply after each layer of the policy
         :param initializer: weight initialization
         :param normalize: whether to apply layer norm to the inputs
+        :param normalizer_kwargs: if normalize is True, apply additional arguments
+        to layer norm
+        :param wrap_non_bool: whether to wrap real or int action fluent parameters
+        with non-linearity (e.g. sigmoid or ELU) to satisfy box constraints
         '''
         super(JaxDeepReactivePolicy, self).__init__()
+        if topology is None:
+            topology = [128, 64]
         self._topology = topology
         self._activations = [activation for _ in topology]
         self._initializer_base = initializer
         self._initializer = initializer
         self._normalize = normalize
+        if normalizer_kwargs is None:
+            normalizer_kwargs = {
+                'create_offset': True, 'create_scale': True,
+                'name': 'input_norm'
+            }
+        self._normalizer_kwargs = normalizer_kwargs
+        self._wrap_non_bool = wrap_non_bool

-    def summarize_hyperparameters(self):
+    def summarize_hyperparameters(self) -> None:
         print(f'policy hyper-parameters:\n'
               f' topology ={self._topology}\n'
               f' activation_fn ={self._activations[0].__name__}\n'
               f' initializer ={type(self._initializer_base).__name__}\n'
-              f' apply_layer_norm={self._normalize}'
+              f' apply_layer_norm={self._normalize}\n'
+              f' layer_norm_args ={self._normalizer_kwargs}\n'
+              f' wrap_non_bool ={self._wrap_non_bool}')

     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds:
+                _bounds: Bounds,
+                horizon: int) -> None:
         rddl = compiled.rddl

         # calculate the correct action box bounds
@@ -751,6 +835,7 @@ class JaxDeepReactivePolicy(JaxPlan):

         ranges = rddl.variable_ranges
         normalize = self._normalize
+        wrap_non_bool = self._wrap_non_bool
         init = self._initializer
         layers = list(enumerate(zip(self._topology, self._activations)))
         layer_sizes = {var: np.prod(shape, dtype=int)
@@ -763,9 +848,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             # apply layer norm
             if normalize:
                 normalizer = hk.LayerNorm(
-                    axis=-1, param_axis=-1,
-                    create_offset=True, create_scale=True,
-                    name='input_norm')
+                    axis=-1, param_axis=-1, **self._normalizer_kwargs)
                 state = normalizer(state)

             # feed state vector through hidden layers
@@ -789,16 +872,19 @@ class JaxDeepReactivePolicy(JaxPlan):
                 if not use_constraint_satisfaction:
                     actions[var] = jax.nn.sigmoid(output)
             else:
-
-
-
-
-
-
-
-
-
-
+                if wrap_non_bool:
+                    lower, upper = bounds_safe[var]
+                    action = jnp.select(
+                        condlist=cond_lists[var],
+                        choicelist=[
+                            lower + (upper - lower) * jax.nn.sigmoid(output),
+                            lower + (jax.nn.elu(output) + 1.0),
+                            upper - (jax.nn.elu(-output) + 1.0),
+                            output
+                        ]
+                    )
+                else:
+                    action = output
                 actions[var] = action

         # for constraint satisfaction wrap bool actions with softmax
@@ -826,12 +912,17 @@ class JaxDeepReactivePolicy(JaxPlan):
                 actions[name] = action
                 start += size
             return actions
-
+
+        if rddl.observ_fluents:
+            observed_vars = rddl.observ_fluents
+        else:
+            observed_vars = rddl.state_fluents
+
         # state is concatenated into single tensor
         def _jax_wrapped_subs_to_state(subs):
             subs = {var: value
                     for (var, value) in subs.items()
-                    if var in
+                    if var in observed_vars}
             flat_subs = jax.tree_map(jnp.ravel, subs)
             states = list(flat_subs.values())
             state = jnp.concatenate(states)
@@ -841,6 +932,10 @@ class JaxDeepReactivePolicy(JaxPlan):
         def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
             state = _jax_wrapped_subs_to_state(subs)
             actions = predict_fn.apply(params, state)
+            if not wrap_non_bool:
+                for (var, action) in actions.items():
+                    if var != bool_key and ranges[var] != 'bool':
+                        actions[var] = jnp.clip(action, *bounds[var])
             if use_constraint_satisfaction:
                 bool_actions = _jax_unstack_bool_from_softmax(actions[bool_key])
                 actions.update(bool_actions)
@@ -886,14 +981,14 @@ class JaxDeepReactivePolicy(JaxPlan):
         def _jax_wrapped_drp_init(key, hyperparams, subs):
             subs = {var: value[0, ...]
                     for (var, value) in subs.items()
-                    if var in
+                    if var in observed_vars}
             state = _jax_wrapped_subs_to_state(subs)
             params = predict_fn.init(key, state)
             return params

         self.initializer = _jax_wrapped_drp_init

-    def guess_next_epoch(self, params:
+    def guess_next_epoch(self, params: Pytree) -> Pytree:
         return params

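JaxDeepReactivePolicy now has defaults for every constructor argument (a [128, 64] topology when none is given) plus the new normalizer_kwargs and wrap_non_bool options. A hedged construction sketch that simply exercises the signature shown above; it is not taken from the package's examples:

    # illustrative only: exercises the new 0.2 defaults and options
    import haiku as hk
    import jax.numpy as jnp
    from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy

    policy = JaxDeepReactivePolicy(
        topology=[64, 32],                      # omit to get the new [128, 64] default
        activation=jnp.tanh,
        initializer=hk.initializers.VarianceScaling(scale=2.0),
        normalize=True,
        normalizer_kwargs={'create_offset': True, 'create_scale': True,
                           'name': 'input_norm'},
        wrap_non_bool=True)                     # squash real/int actions into their box bounds
    policy.summarize_hyperparameters()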
@@ -904,24 +999,135 @@ class JaxDeepReactivePolicy(JaxPlan):
 # - more stable but slower line search based planner
 #
 # ***********************************************************************
+
+class RollingMean:
+    '''Maintains an estimate of the rolling mean of a stream of real-valued
+    observations.'''
+
+    def __init__(self, window_size: int) -> None:
+        self._window_size = window_size
+        self._memory = deque(maxlen=window_size)
+        self._total = 0
+
+    def update(self, x: float) -> float:
+        memory = self._memory
+        self._total += x
+        if len(memory) == self._window_size:
+            self._total -= memory.popleft()
+        memory.append(x)
+        return self._total / len(memory)
+
+
+class JaxPlannerPlot:
+    '''Supports plotting and visualization of a JAX policy in real time.'''
+
+    def __init__(self, rddl: RDDLPlanningModel, horizon: int) -> None:
+        self._fig, axes = plt.subplots(1 + len(rddl.action_fluents))
+
+        # prepare the loss plot
+        self._loss_ax = axes[0]
+        self._loss_ax.autoscale(enable=True)
+        self._loss_ax.set_xlabel('decision epoch')
+        self._loss_ax.set_ylabel('loss value')
+        self._loss_plot = self._loss_ax.plot(
+            [], [], linestyle=':', marker='o', markersize=2)[0]
+        self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
+
+        # prepare the action plots
+        self._action_ax = {name: axes[idx + 1]
+                           for (idx, name) in enumerate(rddl.action_fluents)}
+        self._action_plots = {}
+        for name in rddl.action_fluents:
+            ax = self._action_ax[name]
+            if rddl.variable_ranges[name] == 'bool':
+                vmin, vmax = 0.0, 1.0
+            else:
+                vmin, vmax = None, None
+            action_dim = 1
+            for dim in rddl.object_counts(rddl.variable_params[name]):
+                action_dim *= dim
+            action_plot = ax.pcolormesh(
+                np.zeros((action_dim, horizon)),
+                cmap='seismic', vmin=vmin, vmax=vmax)
+            ax.set_aspect('auto')
+            ax.set_xlabel('decision epoch')
+            ax.set_ylabel(name)
+            plt.colorbar(action_plot, ax=ax)
+            self._action_plots[name] = action_plot
+        self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
+                             for (name, ax) in self._action_ax.items()}
+
+        plt.tight_layout()
+        plt.show(block=False)
+
+    def redraw(self, xticks, losses, actions) -> None:
+
+        # draw the loss curve
+        self._fig.canvas.restore_region(self._loss_back)
+        self._loss_plot.set_xdata(xticks)
+        self._loss_plot.set_ydata(losses)
+        self._loss_ax.set_xlim([0, len(xticks)])
+        self._loss_ax.set_ylim([np.min(losses), np.max(losses)])
+        self._loss_ax.draw_artist(self._loss_plot)
+        self._fig.canvas.blit(self._loss_ax.bbox)
+
+        # draw the actions
+        for (name, values) in actions.items():
+            values = np.mean(values, axis=0, dtype=float)
+            values = np.reshape(values, newshape=(values.shape[0], -1)).T
+            self._fig.canvas.restore_region(self._action_back[name])
+            self._action_plots[name].set_array(values)
+            self._action_ax[name].draw_artist(self._action_plots[name])
+            self._fig.canvas.blit(self._action_ax[name].bbox)
+            self._action_plots[name].set_clim([np.min(values), np.max(values)])
+        self._fig.canvas.draw()
+        self._fig.canvas.flush_events()
+
+    def close(self) -> None:
+        plt.close(self._fig)
+        del self._loss_ax, self._action_ax, \
+            self._loss_plot, self._action_plots, self._fig, \
+            self._loss_back, self._action_back
+
+
+class JaxPlannerStatus(Enum):
+    '''Represents the status of a policy update from the JAX planner,
+    including whether the update resulted in nan gradient,
+    whether progress was made, budget was reached, or other information that
+    can be used to monitor and act based on the planner's progress.'''
+
+    NORMAL = 0
+    NO_PROGRESS = 1
+    PRECONDITION_POSSIBLY_UNSATISFIED = 2
+    TIME_BUDGET_REACHED = 3
+    ITER_BUDGET_REACHED = 4
+    INVALID_GRADIENT = 5
+
+    def is_failure(self) -> bool:
+        return self.value >= 3
+
+
 class JaxBackpropPlanner:
     '''A class for optimizing an action sequence in the given RDDL MDP using
     gradient descent.'''

     def __init__(self, rddl: RDDLLiftedModel,
                  plan: JaxPlan,
-                 batch_size_train: int,
-                 batch_size_test: int=None,
-                 rollout_horizon: int=None,
+                 batch_size_train: int=32,
+                 batch_size_test: Optional[int]=None,
+                 rollout_horizon: Optional[int]=None,
                  use64bit: bool=False,
-                 action_bounds:
+                 action_bounds: Optional[Bounds]=None,
                  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
-                 optimizer_kwargs:
-                 clip_grad: float=None,
+                 optimizer_kwargs: Optional[Kwargs]=None,
+                 clip_grad: Optional[float]=None,
                  logic: FuzzyLogic=FuzzyLogic(),
                  use_symlog_reward: bool=False,
-                 utility
-
+                 utility: Union[Callable[[jnp.ndarray], float], str]='mean',
+                 utility_kwargs: Optional[Kwargs]=None,
+                 cpfs_without_grad: Optional[Set[str]]=None,
+                 compile_non_fluent_exact: bool=True,
+                 logger: Optional[Logger]=None) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
         (plan) in the given RDDL. Some operations will be converted to their
         differentiable counterparts; the specific operations can be customized
@@ -946,9 +1152,16 @@ class JaxBackpropPlanner:
         :param use_symlog_reward: whether to use the symlog transform on the
         reward as a form of normalization
         :param utility: how to aggregate return observations to compute utility
-        of a policy or plan
+        of a policy or plan; must be either a function mapping jax array to a
+        scalar, or a a string identifying the utility function by name
+        ("mean", "mean_var", "entropic", or "cvar" are currently supported)
+        :param utility_kwargs: additional keyword arguments to pass hyper-
+        parameters to the utility function call
         :param cpfs_without_grad: which CPFs do not have gradients (use straight
         through gradient trick)
+        :param compile_non_fluent_exact: whether non-fluent expressions
+        are always compiled using exact JAX expressions
+        :param logger: to log information about compilation to file
         '''
         self.rddl = rddl
         self.plan = plan
@@ -959,22 +1172,25 @@ class JaxBackpropPlanner:
         if rollout_horizon is None:
             rollout_horizon = rddl.horizon
         self.horizon = rollout_horizon
+        if action_bounds is None:
+            action_bounds = {}
         self._action_bounds = action_bounds
         self.use64bit = use64bit
         self._optimizer_name = optimizer
+        if optimizer_kwargs is None:
+            optimizer_kwargs = {'learning_rate': 0.1}
         self._optimizer_kwargs = optimizer_kwargs
         self.clip_grad = clip_grad

         # set optimizer
         try:
             optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
-        except:
+        except Exception as _:
             raise_warning(
                 'Failed to inject hyperparameters into optax optimizer, '
                 'rolling back to safer method: please note that modification of '
                 'optimizer hyperparameters will not work, and it is '
-                'recommended to update
-                'red')
+                'recommended to update optax and related packages.', 'red')
             optimizer = optimizer(**optimizer_kwargs)
         if clip_grad is None:
             self.optimizer = optimizer
@@ -983,22 +1199,68 @@ class JaxBackpropPlanner:
                 optax.clip(clip_grad),
                 optimizer
             )
-
+
+        # set utility
+        if isinstance(utility, str):
+            utility = utility.lower()
+            if utility == 'mean':
+                utility_fn = jnp.mean
+            elif utility == 'mean_var':
+                utility_fn = mean_variance_utility
+            elif utility == 'entropic':
+                utility_fn = entropic_utility
+            elif utility == 'cvar':
+                utility_fn = cvar_utility
+            else:
+                raise RDDLNotImplementedError(
+                    f'Utility function <{utility}> is not supported: '
+                    'must be one of ["mean", "mean_var", "entropic", "cvar"].')
+        else:
+            utility_fn = utility
+        self.utility = utility_fn
+
+        if utility_kwargs is None:
+            utility_kwargs = {}
+        self.utility_kwargs = utility_kwargs
+
         self.logic = logic
+        self.logic.set_use64bit(self.use64bit)
         self.use_symlog_reward = use_symlog_reward
-
+        if cpfs_without_grad is None:
+            cpfs_without_grad = set()
         self.cpfs_without_grad = cpfs_without_grad
+        self.compile_non_fluent_exact = compile_non_fluent_exact
+        self.logger = logger

         self._jax_compile_rddl()
         self._jax_compile_optimizer()
-
-    def
-
-
+
+    def _summarize_system(self) -> None:
+        try:
+            jaxlib_version = jax._src.lib.version_str
+        except Exception as _:
+            jaxlib_version = 'N/A'
+        try:
+            devices_short = ', '.join(
+                map(str, jax._src.xla_bridge.devices())).replace('\n', '')
+        except Exception as _:
+            devices_short = 'N/A'
+        print('\n'
+              f'JAX Planner version {__version__}\n'
+              f'Python {sys.version}\n'
+              f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
+              f'numpy {np.__version__}\n'
+              f'devices: {devices_short}\n')
+
+    def summarize_hyperparameters(self) -> None:
+        print(f'objective hyper-parameters:\n'
+              f' utility_fn ={self.utility.__name__}\n'
+              f' utility args ={self.utility_kwargs}\n'
               f' use_symlog ={self.use_symlog_reward}\n'
               f' lookahead ={self.horizon}\n'
-              f' model relaxation={type(self.logic).__name__}\n'
               f' action_bounds ={self._action_bounds}\n'
+              f' fuzzy logic type={type(self.logic).__name__}\n'
+              f' nonfluents exact={self.compile_non_fluent_exact}\n'
               f' cpfs_no_gradient={self.cpfs_without_grad}\n'
               f'optimizer hyper-parameters:\n'
               f' use_64_bit ={self.use64bit}\n'
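The constructor above now accepts utility='mean' | 'mean_var' | 'entropic' | 'cvar', dispatching to mean_variance_utility, entropic_utility and cvar_utility, whose bodies are not shown in this diff. The sketches below give the textbook forms of those risk measures with assumed parameter names (beta, alpha); the package's actual implementations and keyword names may differ.

    # textbook risk-aware utilities (assumed forms, not copied from the package)
    import jax.numpy as jnp

    def mean_variance_utility(returns, beta=0.1):
        # penalize the spread of sampled returns relative to their mean
        return jnp.mean(returns) - beta * jnp.var(returns)

    def entropic_utility(returns, beta=1.0):
        # certainty-equivalent of the exponential (entropic) risk measure
        return -(1.0 / beta) * jnp.log(jnp.mean(jnp.exp(-beta * returns)))

    def cvar_utility(returns, alpha=0.05):
        # lower-tail conditional value at risk (Rockafellar-Uryasev form):
        # expected return over the worst alpha-fraction of rollouts
        var = jnp.percentile(returns, 100.0 * alpha)
        return var + jnp.mean(jnp.minimum(returns - var, 0.0)) / alpha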
@@ -1010,6 +1272,10 @@ class JaxBackpropPlanner:
         self.plan.summarize_hyperparameters()
         self.logic.summarize_hyperparameters()

+    # ===========================================================================
+    # COMPILATION SUBROUTINES
+    # ===========================================================================
+
     def _jax_compile_rddl(self):
         rddl = self.rddl

@@ -1017,13 +1283,18 @@ class JaxBackpropPlanner:
         self.compiled = JaxRDDLCompilerWithGrad(
             rddl=rddl,
             logic=self.logic,
+            logger=self.logger,
             use64bit=self.use64bit,
-            cpfs_without_grad=self.cpfs_without_grad
-
+            cpfs_without_grad=self.cpfs_without_grad,
+            compile_non_fluent_exact=self.compile_non_fluent_exact)
+        self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')

         # Jax compilation of the exact RDDL for testing
-        self.test_compiled = JaxRDDLCompiler(
-
+        self.test_compiled = JaxRDDLCompiler(
+            rddl=rddl,
+            logger=self.logger,
+            use64bit=self.use64bit)
+        self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')

     def _jax_compile_optimizer(self):

@@ -1051,11 +1322,10 @@ class JaxBackpropPlanner:

         # losses
         train_loss = self._jax_loss(train_rollouts, use_symlog=self.use_symlog_reward)
-        self.train_loss = jax.jit(train_loss)
         self.test_loss = jax.jit(self._jax_loss(test_rollouts, use_symlog=False))

         # optimization
-        self.update =
+        self.update = self._jax_update(train_loss)

     def _jax_return(self, use_symlog):
         gamma = self.rddl.discount
@@ -1068,13 +1338,14 @@ class JaxBackpropPlanner:
             rewards = rewards * discount[jnp.newaxis, ...]
             returns = jnp.sum(rewards, axis=1)
             if use_symlog:
-                returns = jnp.sign(returns) * jnp.
+                returns = jnp.sign(returns) * jnp.log(1.0 + jnp.abs(returns))
             return returns

         return _jax_wrapped_returns

     def _jax_loss(self, rollouts, use_symlog=False):
-        utility_fn = self.utility
+        utility_fn = self.utility
+        utility_kwargs = self.utility_kwargs
         _jax_wrapped_returns = self._jax_return(use_symlog)

         # the loss is the average cumulative reward across all roll-outs
@@ -1083,7 +1354,7 @@ class JaxBackpropPlanner:
             log = rollouts(key, policy_params, hyperparams, subs, model_params)
             rewards = log['reward']
             returns = _jax_wrapped_returns(rewards)
-            utility = utility_fn(returns)
+            utility = utility_fn(returns, **utility_kwargs)
             loss = -utility
             return loss, log

@@ -1096,7 +1367,7 @@ class JaxBackpropPlanner:
         def _jax_wrapped_init_policy(key, hyperparams, subs):
             policy_params = init(key, hyperparams, subs)
             opt_state = optimizer.init(policy_params)
-            return policy_params, opt_state
+            return policy_params, opt_state, None

         return _jax_wrapped_init_policy

@@ -1107,17 +1378,18 @@ class JaxBackpropPlanner:
         # calculate the plan gradient w.r.t. return loss and update optimizer
         # also perform a projection step to satisfy constraints on actions
         def _jax_wrapped_plan_update(key, policy_params, hyperparams,
-                                     subs, model_params, opt_state):
-            grad_fn = jax.
-
+                                     subs, model_params, opt_state, opt_aux):
+            grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
+            (loss_val, log), grad = grad_fn(
+                key, policy_params, hyperparams, subs, model_params)
             updates, opt_state = optimizer.update(grad, opt_state)
             policy_params = optax.apply_updates(policy_params, updates)
             policy_params, converged = projection(policy_params, hyperparams)
             log['grad'] = grad
             log['updates'] = updates
-            return policy_params, converged, opt_state, log
+            return policy_params, converged, opt_state, None, loss_val, log

-        return _jax_wrapped_plan_update
+        return jax.jit(_jax_wrapped_plan_update)

     def _batched_init_subs(self, subs):
         rddl = self.rddl
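The update step above now computes the training loss and its gradient in a single pass via jax.value_and_grad(..., has_aux=True) and jit-compiles the whole step, which is why the separate self.train_loss evaluation disappears from the training loop further down. A self-contained sketch of that generic JAX/optax pattern (not package code; the toy loss and parameter names are made up):

    # generic JAX pattern: loss value, auxiliary outputs and gradient in one pass
    import jax
    import jax.numpy as jnp
    import optax

    def loss_fn(params, x):
        pred = params['w'] * x
        loss = jnp.mean((pred - 1.0) ** 2)
        return loss, {'pred': pred}                      # auxiliary log

    params = {'w': jnp.array(0.5)}
    optimizer = optax.rmsprop(learning_rate=0.1)
    opt_state = optimizer.init(params)

    grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True)

    @jax.jit
    def update(params, opt_state, x):
        (loss, aux), grad = grad_fn(params, x)           # one pass, no extra loss call
        updates, opt_state = optimizer.update(grad, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss, aux

    params, opt_state, loss, aux = update(params, opt_state, jnp.arange(3.0))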
@@ -1145,13 +1417,15 @@ class JaxBackpropPlanner:
|
|
|
1145
1417
|
|
|
1146
1418
|
return init_train, init_test
|
|
1147
1419
|
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1420
|
+
# ===========================================================================
|
|
1421
|
+
# OPTIMIZE API
|
|
1422
|
+
# ===========================================================================
|
|
1423
|
+
|
|
1424
|
+
def optimize(self, *args, **kwargs) -> Dict[str, Any]:
|
|
1425
|
+
''' Compute an optimal policy or plan. Return the callback from training.
|
|
1151
1426
|
|
|
1152
|
-
:param key: JAX PRNG key
|
|
1427
|
+
:param key: JAX PRNG key (derived from clock if not provided)
|
|
1153
1428
|
:param epochs: the maximum number of steps of gradient descent
|
|
1154
|
-
:param the maximum number of steps of gradient descent
|
|
1155
1429
|
:param train_seconds: total time allocated for gradient descent
|
|
1156
1430
|
:param plot_step: frequency to plot the plan and save result to disk
|
|
1157
1431
|
:param model_params: optional model-parameters to override default
|
|
@@ -1162,33 +1436,44 @@ class JaxBackpropPlanner:
|
|
|
1162
1436
|
:param guess: initial policy parameters: if None will use the initializer
|
|
1163
1437
|
specified in this instance
|
|
1164
1438
|
:param verbose: not print (0), print summary (1), print progress (2)
|
|
1165
|
-
:param
|
|
1166
|
-
|
|
1439
|
+
:param test_rolling_window: the test return is averaged on a rolling
|
|
1440
|
+
window of the past test_rolling_window returns when updating the best
|
|
1441
|
+
parameters found so far
|
|
1442
|
+
:param tqdm_position: position of tqdm progress bar (for multiprocessing)
|
|
1167
1443
|
'''
|
|
1168
1444
|
it = self.optimize_generator(*args, **kwargs)
|
|
1169
|
-
|
|
1170
|
-
if
|
|
1171
|
-
|
|
1445
|
+
|
|
1446
|
+
# if the python is C-compiled then the deque is native C and much faster
|
|
1447
|
+
# than naively exhausting iterator, but not if the python is some other
|
|
1448
|
+
# version (e.g. PyPi); for details, see
|
|
1449
|
+
# https://stackoverflow.com/questions/50937966/fastest-most-pythonic-way-to-consume-an-iterator
|
|
1450
|
+
callback = None
|
|
1451
|
+
if sys.implementation.name == 'cpython':
|
|
1452
|
+
last_callback = deque(it, maxlen=1)
|
|
1453
|
+
if last_callback:
|
|
1454
|
+
callback = last_callback.pop()
|
|
1172
1455
|
else:
|
|
1173
|
-
|
|
1456
|
+
for callback in it:
|
|
1457
|
+
pass
|
|
1458
|
+
return callback
|
|
1174
1459
|
|
|
1175
|
-
def optimize_generator(self, key: random.PRNGKey,
|
|
1460
|
+
def optimize_generator(self, key: Optional[random.PRNGKey]=None,
|
|
1176
1461
|
epochs: int=999999,
|
|
1177
1462
|
train_seconds: float=120.,
|
|
1178
|
-
plot_step: int=None,
|
|
1179
|
-
model_params: Dict[str,
|
|
1180
|
-
policy_hyperparams: Dict[str,
|
|
1181
|
-
subs: Dict[str,
|
|
1182
|
-
guess:
|
|
1463
|
+
plot_step: Optional[int]=None,
|
|
1464
|
+
model_params: Optional[Dict[str, Any]]=None,
|
|
1465
|
+
policy_hyperparams: Optional[Dict[str, Any]]=None,
|
|
1466
|
+
subs: Optional[Dict[str, Any]]=None,
|
|
1467
|
+
guess: Optional[Pytree]=None,
|
|
1183
1468
|
verbose: int=2,
|
|
1184
|
-
|
|
1185
|
-
|
|
1469
|
+
test_rolling_window: int=10,
|
|
1470
|
+
tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
|
|
1471
|
+
'''Returns a generator for computing an optimal policy or plan.
|
|
1186
1472
|
Generator can be iterated over to lazily optimize the plan, yielding
|
|
1187
1473
|
a dictionary of intermediate computations.
|
|
1188
1474
|
|
|
1189
|
-
:param key: JAX PRNG key
|
|
1475
|
+
:param key: JAX PRNG key (derived from clock if not provided)
|
|
1190
1476
|
:param epochs: the maximum number of steps of gradient descent
|
|
1191
|
-
:param the maximum number of steps of gradient descent
|
|
1192
1477
|
:param train_seconds: total time allocated for gradient descent
|
|
1193
1478
|
:param plot_step: frequency to plot the plan and save result to disk
|
|
1194
1479
|
:param model_params: optional model-parameters to override default
|
|
@@ -1199,26 +1484,53 @@ class JaxBackpropPlanner:
|
|
|
1199
1484
|
:param guess: initial policy parameters: if None will use the initializer
|
|
1200
1485
|
specified in this instance
|
|
1201
1486
|
:param verbose: not print (0), print summary (1), print progress (2)
|
|
1487
|
+
:param test_rolling_window: the test return is averaged on a rolling
|
|
1488
|
+
window of the past test_rolling_window returns when updating the best
|
|
1489
|
+
parameters found so far
|
|
1202
1490
|
:param tqdm_position: position of tqdm progress bar (for multiprocessing)
|
|
1203
1491
|
'''
|
|
1204
1492
|
verbose = int(verbose)
|
|
1205
1493
|
start_time = time.time()
|
|
1206
1494
|
elapsed_outside_loop = 0
|
|
1207
1495
|
|
|
1496
|
+
# if PRNG key is not provided
|
|
1497
|
+
if key is None:
|
|
1498
|
+
key = random.PRNGKey(round(time.time() * 1000))
|
|
1499
|
+
|
|
1500
|
+
# if policy_hyperparams is not provided
|
|
1501
|
+
if policy_hyperparams is None:
|
|
1502
|
+
raise_warning('policy_hyperparams is not set, setting 1.0 for '
|
|
1503
|
+
'all action-fluents which could be suboptimal.')
|
|
1504
|
+
policy_hyperparams = {action: 1.0
|
|
1505
|
+
for action in self.rddl.action_fluents}
|
|
1506
|
+
|
|
1507
|
+
# if policy_hyperparams is a scalar
|
|
1508
|
+
elif isinstance(policy_hyperparams, (int, float, np.number)):
|
|
1509
|
+
raise_warning(f'policy_hyperparams is {policy_hyperparams}, '
|
|
1510
|
+
'setting this value for all action-fluents.')
|
|
1511
|
+
hyperparam_value = float(policy_hyperparams)
|
|
1512
|
+
policy_hyperparams = {action: hyperparam_value
|
|
1513
|
+
for action in self.rddl.action_fluents}
|
|
1514
|
+
|
|
1208
1515
|
# print summary of parameters:
|
|
1209
1516
|
if verbose >= 1:
|
|
1210
|
-
|
|
1211
|
-
'JAX PLANNER PARAMETER SUMMARY\n'
|
|
1212
|
-
'==============================================')
|
|
1517
|
+
self._summarize_system()
|
|
1213
1518
|
self.summarize_hyperparameters()
|
|
1214
1519
|
print(f'optimize() call hyper-parameters:\n'
|
|
1520
|
+
f' PRNG key ={key}\n'
|
|
1215
1521
|
f' max_iterations ={epochs}\n'
|
|
1216
1522
|
f' max_seconds ={train_seconds}\n'
|
|
1217
1523
|
f' model_params ={model_params}\n'
|
|
1218
1524
|
f' policy_hyper_params={policy_hyperparams}\n'
|
|
1219
1525
|
f' override_subs_dict ={subs is not None}\n'
|
|
1220
|
-
f' provide_param_guess={guess is not None}\n'
|
|
1221
|
-
f'
|
|
1526
|
+
f' provide_param_guess={guess is not None}\n'
|
|
1527
|
+
f' test_rolling_window={test_rolling_window}\n'
|
|
1528
|
+
f' plot_frequency ={plot_step}\n'
|
|
1529
|
+
f' verbose ={verbose}\n')
|
|
1530
|
+
if verbose >= 2 and self.compiled.relaxations:
|
|
1531
|
+
print('Some RDDL operations are non-differentiable, '
|
|
1532
|
+
'replacing them with differentiable relaxations:')
|
|
1533
|
+
print(self.compiled.summarize_model_relaxations())
|
|
1222
1534
|
|
|
1223
1535
|
# compute a batched version of the initial values
|
|
1224
1536
|
if subs is None:
|
|
@@ -1245,14 +1557,26 @@ class JaxBackpropPlanner:
|
|
|
1245
1557
|
# initialize policy parameters
|
|
1246
1558
|
if guess is None:
|
|
1247
1559
|
key, subkey = random.split(key)
|
|
1248
|
-
policy_params, opt_state = self.initialize(
|
|
1560
|
+
policy_params, opt_state, opt_aux = self.initialize(
|
|
1249
1561
|
subkey, policy_hyperparams, train_subs)
|
|
1250
1562
|
else:
|
|
1251
1563
|
policy_params = guess
|
|
1252
1564
|
opt_state = self.optimizer.init(policy_params)
|
|
1565
|
+
opt_aux = None
|
|
1566
|
+
|
|
1567
|
+
# initialize running statistics
|
|
1253
1568
|
best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
|
|
1254
1569
|
last_iter_improve = 0
|
|
1570
|
+
rolling_test_loss = RollingMean(test_rolling_window)
|
|
1255
1571
|
log = {}
|
|
1572
|
+
status = JaxPlannerStatus.NORMAL
|
|
1573
|
+
|
|
1574
|
+
# initialize plot area
|
|
1575
|
+
if plot_step is None or plot_step <= 0 or plt is None:
|
|
1576
|
+
plot = None
|
|
1577
|
+
else:
|
|
1578
|
+
plot = JaxPlannerPlot(self.rddl, self.horizon)
|
|
1579
|
+
xticks, loss_values = [], []
|
|
1256
1580
|
|
|
1257
1581
|
# training loop
|
|
1258
1582
|
iters = range(epochs)
|
|
@@ -1260,25 +1584,25 @@ class JaxBackpropPlanner:
|
|
|
1260
1584
|
iters = tqdm(iters, total=100, position=tqdm_position)
|
|
1261
1585
|
|
|
1262
1586
|
for it in iters:
|
|
1587
|
+
status = JaxPlannerStatus.NORMAL
|
|
1263
1588
|
|
|
1264
1589
|
# update the parameters of the plan
|
|
1265
|
-
key,
|
|
1266
|
-
policy_params, converged, opt_state, train_log =
|
|
1267
|
-
|
|
1268
|
-
|
|
1590
|
+
key, subkey = random.split(key)
|
|
1591
|
+
policy_params, converged, opt_state, opt_aux, train_loss, train_log = \
|
|
1592
|
+
self.update(subkey, policy_params, policy_hyperparams,
|
|
1593
|
+
train_subs, model_params, opt_state, opt_aux)
|
|
1269
1594
|
if not np.all(converged):
|
|
1270
1595
|
raise_warning(
|
|
1271
1596
|
'Projected gradient method for satisfying action concurrency '
|
|
1272
1597
|
'constraints reached the iteration limit: plan is possibly '
|
|
1273
1598
|
'invalid for the current instance.', 'red')
|
|
1599
|
+
status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
|
|
1274
1600
|
|
|
1275
1601
|
# evaluate losses
|
|
1276
|
-
train_loss, _ = self.train_loss(
|
|
1277
|
-
subkey2, policy_params, policy_hyperparams,
|
|
1278
|
-
train_subs, model_params)
|
|
1279
1602
|
test_loss, log = self.test_loss(
|
|
1280
|
-
|
|
1603
|
+
subkey, policy_params, policy_hyperparams,
|
|
1281
1604
|
test_subs, model_params_test)
|
|
1605
|
+
test_loss = rolling_test_loss.update(test_loss)
|
|
1282
1606
|
|
|
1283
1607
|
# record the best plan so far
|
|
1284
1608
|
if test_loss < best_loss:
|
|
@@ -1287,21 +1611,45 @@ class JaxBackpropPlanner:
|
|
|
1287
1611
|
last_iter_improve = it
|
|
1288
1612
|
|
|
1289
1613
|
# save the plan figure
|
|
1290
|
-
if
|
|
1291
|
-
|
|
1292
|
-
|
|
1614
|
+
if plot is not None and it % plot_step == 0:
|
|
1615
|
+
xticks.append(it // plot_step)
|
|
1616
|
+
loss_values.append(test_loss.item())
|
|
1617
|
+
action_values = {name: values
|
|
1618
|
+
for (name, values) in log['fluents'].items()
|
|
1619
|
+
if name in self.rddl.action_fluents}
|
|
1620
|
+
plot.redraw(xticks, loss_values, action_values)
|
|
1293
1621
|
|
|
1294
1622
|
# if the progress bar is used
|
|
1295
1623
|
elapsed = time.time() - start_time - elapsed_outside_loop
|
|
1296
1624
|
if verbose >= 2:
|
|
1297
1625
|
iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
|
|
1298
1626
|
iters.set_description(
|
|
1299
|
-
f'[{tqdm_position}] {it:6} it / {-train_loss:14.
|
|
1300
|
-
f'{-test_loss:14.
|
|
1627
|
+
f'[{tqdm_position}] {it:6} it / {-train_loss:14.6f} train / '
|
|
1628
|
+
f'{-test_loss:14.6f} test / {-best_loss:14.6f} best')
|
|
1629
|
+
|
|
1630
|
+
# reached computation budget
|
|
1631
|
+
if elapsed >= train_seconds:
|
|
1632
|
+
status = JaxPlannerStatus.TIME_BUDGET_REACHED
|
|
1633
|
+
if it >= epochs - 1:
|
|
1634
|
+
status = JaxPlannerStatus.ITER_BUDGET_REACHED
|
|
1635
|
+
|
|
1636
|
+
# numerical error
|
|
1637
|
+
if not np.isfinite(train_loss):
|
|
1638
|
+
raise_warning(
|
|
1639
|
+
f'Aborting JAX planner due to invalid train loss {train_loss}.',
|
|
1640
|
+
'red')
|
|
1641
|
+
status = JaxPlannerStatus.INVALID_GRADIENT
|
|
1642
|
+
|
|
1643
|
+
# no progress
|
|
1644
|
+
grad_norm_zero, _ = jax.tree_util.tree_flatten(
|
|
1645
|
+
jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
|
|
1646
|
+
if np.all(grad_norm_zero):
|
|
1647
|
+
status = JaxPlannerStatus.NO_PROGRESS
|
|
1301
1648
|
|
|
1302
1649
|
# return a callback
|
|
1303
1650
|
start_time_outside = time.time()
|
|
1304
1651
|
yield {
|
|
1652
|
+
'status': status,
|
|
1305
1653
|
'iteration': it,
|
|
1306
1654
|
'train_return':-train_loss,
|
|
1307
1655
|
'test_return':-test_loss,
|
|
@@ -1318,16 +1666,15 @@ class JaxBackpropPlanner:
|
|
|
1318
1666
|
}
|
|
1319
1667
|
elapsed_outside_loop += (time.time() - start_time_outside)
|
|
1320
1668
|
|
|
1321
|
-
#
|
|
1322
|
-
if
|
|
1323
|
-
break
|
|
1324
|
-
|
|
1325
|
-
# numerical error
|
|
1326
|
-
if not np.isfinite(train_loss):
|
|
1669
|
+
# abortion check
|
|
1670
|
+
if status.is_failure():
|
|
1327
1671
|
break
|
|
1328
|
-
|
|
1672
|
+
|
|
1673
|
+
# release resources
|
|
1329
1674
|
if verbose >= 2:
|
|
1330
1675
|
iters.close()
|
|
1676
|
+
if plot is not None:
|
|
1677
|
+
plot.close()
|
|
1331
1678
|
|
|
1332
1679
|
# validate the test return
|
|
1333
1680
|
if log:
|
|
@@ -1337,24 +1684,23 @@ class JaxBackpropPlanner:
             if messages:
                 messages = '\n'.join(messages)
                 raise_warning('The JAX compiler encountered the following '
-                              '
+                              'error(s) in the original RDDL formulation '
                               f'during test evaluation:\n{messages}', 'red')
 
         # summarize and test for convergence
         if verbose >= 1:
-            grad_norm = jax.tree_map(
-                lambda x: np.array(jnp.linalg.norm(x)).item(), best_grad)
+            grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
             diagnosis = self._perform_diagnosis(
-                last_iter_improve,
-                -train_loss, -test_loss, -best_loss, grad_norm)
+                last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
             print(f'summary of optimization:\n'
+                  f' status_code ={status}\n'
                   f' time_elapsed ={elapsed}\n'
                   f' iterations ={it}\n'
                   f' best_objective={-best_loss}\n'
-                  f'
+                  f' best_grad_norm={grad_norm}\n'
                   f'diagnosis: {diagnosis}\n')
 
-    def _perform_diagnosis(self, last_iter_improve,
+    def _perform_diagnosis(self, last_iter_improve,
                            train_return, test_return, best_return, grad_norm):
         max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
         grad_is_zero = np.allclose(max_grad_norm, 0)
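The summary reduces the best gradient pytree to per-leaf norms and then, in _perform_diagnosis, to a single maximum norm. For reference, the same two-step reduction on a toy pytree:

import jax
import numpy as np

best_grad = {'layer1': np.array([0.0, 0.5]), 'layer2': np.array([0.25])}

# per-leaf Euclidean norms, as in the summary printout
grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)

# single scalar used by the diagnosis thresholds
max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
print(grad_norm)       # {'layer1': 0.5, 'layer2': 0.25}
print(max_grad_norm)   # 0.5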
@@ -1373,20 +1719,20 @@ class JaxBackpropPlanner:
             if grad_is_zero:
                 return termcolor.colored(
                     '[FAILURE] no progress was made, '
-                    f'and max grad norm
-                    'likely stuck in a plateau.', 'red')
+                    f'and max grad norm {max_grad_norm:.6f} is zero: '
+                    'solver likely stuck in a plateau.', 'red')
             else:
                 return termcolor.colored(
                     '[FAILURE] no progress was made, '
-                    f'but max grad norm
-                    'likely
+                    f'but max grad norm {max_grad_norm:.6f} is non-zero: '
+                    'likely poor learning rate or other hyper-parameter.', 'red')
 
         # model is likely poor IF:
         # 1. the train and test return disagree
         if not (validation_error < 20):
             return termcolor.colored(
                 '[WARNING] progress was made, '
-                f'but relative train
+                f'but relative train-test error {validation_error:.6f} is high: '
                 'likely poor model relaxation around the solution, '
                 'or the batch size is too small.', 'yellow')
 
@@ -1397,208 +1743,216 @@ class JaxBackpropPlanner:
         if not (return_to_grad_norm > 1):
             return termcolor.colored(
                 '[WARNING] progress was made, '
-                f'but max grad norm
-                'likely
-                'or the model is not smooth around the solution, '
+                f'but max grad norm {max_grad_norm:.6f} is high: '
+                'likely the solution is not locally optimal, '
+                'or the relaxed model is not smooth around the solution, '
                 'or the batch size is too small.', 'yellow')
 
         # likely successful
         return termcolor.colored(
-            '[SUCCESS] planner
+            '[SUCCESS] planner has converged successfully '
            '(note: not all potential problems can be ruled out).', 'green')
 
     def get_action(self, key: random.PRNGKey,
-                   params:
+                   params: Pytree,
                    step: int,
-                   subs: Dict,
-                   policy_hyperparams: Dict[str,
+                   subs: Dict[str, Any],
+                   policy_hyperparams: Optional[Dict[str, Any]]=None) -> Dict[str, Any]:
         '''Returns an action dictionary from the policy or plan with the given
         parameters.
 
         :param key: the JAX PRNG key
         :param params: the trainable parameter PyTree of the policy
         :param step: the time step at which decision is made
-        :param policy_hyperparams: hyper-parameters for the policy/plan, such as
-            weights for sigmoid wrapping boolean actions
         :param subs: the dict of pvariables
+        :param policy_hyperparams: hyper-parameters for the policy/plan, such as
+            weights for sigmoid wrapping boolean actions (optional)
         '''
 
         # check compatibility of the subs dictionary
-        for var in subs.
+        for (var, values) in subs.items():
+
+            # must not be grounded
             if RDDLPlanningModel.FLUENT_SEP in var \
             or RDDLPlanningModel.OBJECT_SEP in var:
-                raise
-
-
-
-
+                raise ValueError(f'State dictionary passed to the JAX policy is '
+                                 f'grounded, since it contains the key <{var}>, '
+                                 f'but a vectorized environment is required: '
+                                 f'please make sure vectorized=True in the RDDLEnv.')
+
+            # must be numeric array
+            # exception is for POMDPs at 1st epoch when observ-fluents are None
+            if not jnp.issubdtype(values.dtype, jnp.number) \
+            and not jnp.issubdtype(values.dtype, jnp.bool_):
+                if step == 0 and var in self.rddl.observ_fluents:
+                    subs[var] = self.test_compiled.init_values[var]
+                else:
+                    raise ValueError(f'Values assigned to pvariable {var} are '
+                                     f'non-numeric of type {values.dtype}: {values}.')
+
         # cast device arrays to numpy
         actions = self.test_policy(key, params, policy_hyperparams, step, subs)
         actions = jax.tree_map(np.asarray, actions)
         return actions
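get_action now validates that each value in subs is a numeric or boolean tensor before calling the compiled test policy. A small check in the same spirit, with placeholder arrays rather than fluents from a real domain:

import jax.numpy as jnp
import numpy as np

box_fluent = np.zeros((3, 2), dtype=np.float32)   # vectorized real-valued fluent
flag_fluent = np.array([True, False, True])       # vectorized boolean fluent

for values in (box_fluent, flag_fluent):
    ok = (jnp.issubdtype(values.dtype, jnp.number)
          or jnp.issubdtype(values.dtype, jnp.bool_))
    print(ok)   # True for both; string- or object-typed values would fail
                # both tests and trigger the ValueError shown above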
-
-    def _plot_actions(self, key, params, hyperparams, subs, it):
-        rddl = self.rddl
-        try:
-            import matplotlib.pyplot as plt
-        except Exception:
-            print('matplotlib is not installed, aborting plot...')
-            return
-
-        # predict actions from the trained policy or plan
-        actions = self.test_rollouts(key, params, hyperparams, subs, {})['action']
-
-        # plot the action sequences as color maps
-        fig, axs = plt.subplots(nrows=len(actions), constrained_layout=True)
-        for (ax, name) in zip(axs, actions):
-            action = np.mean(actions[name], axis=0, dtype=float)
-            action = np.reshape(action, newshape=(action.shape[0], -1)).T
-            if rddl.variable_ranges[name] == 'bool':
-                vmin, vmax = 0.0, 1.0
-            else:
-                vmin, vmax = None, None
-            img = ax.imshow(
-                action, vmin=vmin, vmax=vmax, cmap='seismic', aspect='auto')
-            ax.set_xlabel('time')
-            ax.set_ylabel(name)
-            plt.colorbar(img, ax=ax)
-
-        # write plot to disk
-        plt.savefig(f'plan_{rddl.domain_name}_{rddl.instance_name}_{it}.pdf',
-                    bbox_inches='tight')
-        plt.clf()
-        plt.close(fig)
 
 
-class
+class JaxLineSearchPlanner(JaxBackpropPlanner):
     '''A class for optimizing an action sequence in the given RDDL MDP using
-
+    linear search gradient descent, with the Armijo condition.'''
 
     def __init__(self, *args,
                  optimizer: Callable[..., optax.GradientTransformation]=optax.sgd,
-                 optimizer_kwargs:
-
+                 optimizer_kwargs: Kwargs={'learning_rate': 1.0},
+                 decay: float=0.8,
                  c: float=0.1,
-
-
+                 step_max: float=1.0,
+                 step_min: float=1e-6,
                  **kwargs) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
-        (plan) in the given RDDL using
+        (plan) in the given RDDL using line search. All arguments are the
         same as in the parent class, except:
 
-        :param
-        :param c: coefficient in Armijo condition
-        :param
-        :param
+        :param decay: reduction factor of learning rate per line search iteration
+        :param c: positive coefficient in Armijo condition, should be in (0, 1)
+        :param step_max: initial learning rate for line search
+        :param step_min: minimum possible learning rate (line search halts)
         '''
-        self.
+        self.decay = decay
         self.c = c
-        self.
-        self.
-
+        self.step_max = step_max
+        self.step_min = step_min
+        if 'clip_grad' in kwargs:
+            raise_warning('clip_grad parameter conflicts with '
+                          'line search planner and will be ignored.', 'red')
+            del kwargs['clip_grad']
+        super(JaxLineSearchPlanner, self).__init__(
             *args,
             optimizer=optimizer,
             optimizer_kwargs=optimizer_kwargs,
             **kwargs)
 
-    def summarize_hyperparameters(self):
-        super(
+    def summarize_hyperparameters(self) -> None:
+        super(JaxLineSearchPlanner, self).summarize_hyperparameters()
         print(f'linesearch hyper-parameters:\n'
-              f'
+              f' decay ={self.decay}\n'
               f' c ={self.c}\n'
-              f' lr_range=({self.
+              f' lr_range=({self.step_min}, {self.step_max})')
 
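The new hyper-parameters have a simple reading: the search starts from step_max, multiplies the step by decay until the Armijo sufficient-decrease test passes or the step falls below step_min, and c controls how much decrease counts as sufficient. The update logic in _jax_update below applies the same rule to the planner loss; here is a self-contained sketch of that rule on a toy quadratic, independent of the planner classes, using the condition f(x - t*g) <= f(x) - c*t*||g||^2:

import numpy as np

def f(x):
    return 0.5 * np.dot(x, x)      # toy objective whose gradient is g(x) = x

def backtracking_step(x, decay=0.8, c=0.1, step_max=1.0, step_min=1e-6):
    g = x                          # gradient of the toy objective at x
    gnorm2 = np.dot(g, g)
    f_x = f(x)
    step = step_max / decay
    while True:
        step *= decay
        if f(x - step * g) <= f_x - c * step * gnorm2 or step * decay < step_min:
            return step

x = np.array([2.0, -1.0])
print(backtracking_step(x))        # the very first trial (about 1.0) already passes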
     def _jax_update(self, loss):
         optimizer = self.optimizer
         projection = self.plan.projection
-
-
-        #
-
-        def
-
-
-
-
-
-
-
-            old_x, _, old_g, _, old_state = old
-            _, _, lr, iters = new
-            _, best_f, _, _ = best
-            key, hyperparams, *other = aux
-
-            # anneal learning rate and apply a gradient step
-            new_lr = beta * lr
-            old_state.hyperparams['learning_rate'] = new_lr
-            updates, new_state = optimizer.update(old_g, old_state)
-            new_x = optax.apply_updates(old_x, updates)
-            new_x, _ = projection(new_x, hyperparams)
-
-            # evaluate new loss and record best so far
-            new_f, _ = loss(key, new_x, hyperparams, *other)
-            new = (new_x, new_f, new_lr, iters + 1)
-            best = jax.lax.cond(
-                new_f < best_f,
-                lambda: (new_x, new_f, new_lr, new_state),
-                lambda: best
-            )
-            return old, new, best, aux
+        decay, c, lrmax, lrmin = self.decay, self.c, self.step_max, self.step_min
+
+        # initialize the line search routine
+        @jax.jit
+        def _jax_wrapped_line_search_init(key, policy_params, hyperparams,
+                                          subs, model_params):
+            (f, log), grad = jax.value_and_grad(loss, argnums=1, has_aux=True)(
+                key, policy_params, hyperparams, subs, model_params)
+            gnorm2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), grad)
+            gnorm2 = jax.tree_util.tree_reduce(jnp.add, gnorm2)
+            log['grad'] = grad
+            return f, grad, gnorm2, log
 
+        # compute the next trial solution
+        @jax.jit
+        def _jax_wrapped_line_search_trial(
+                step, grad, key, params, hparams, subs, mparams, state):
+            state.hyperparams['learning_rate'] = step
+            updates, new_state = optimizer.update(grad, state)
+            new_params = optax.apply_updates(params, updates)
+            new_params, _ = projection(new_params, hparams)
+            f_step, _ = loss(key, new_params, hparams, subs, mparams)
+            return f_step, new_params, new_state
+
+        # main iteration of line search
         def _jax_wrapped_plan_update(key, policy_params, hyperparams,
-                                     subs, model_params, opt_state):
-
-            # calculate initial loss value, gradient and squared norm
-            old_x = policy_params
-            loss_and_grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
-            (old_f, log), old_g = loss_and_grad_fn(
-                key, old_x, hyperparams, subs, model_params)
-            old_norm_g2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), old_g)
-            old_norm_g2 = jax.tree_util.tree_reduce(jnp.add, old_norm_g2)
-            log['grad'] = old_g
+                                     subs, model_params, opt_state, opt_aux):
 
-            # initialize
-
-
-            new = (old_x, old_f, new_lr, 0)
-            best = (old_x, jnp.inf, jnp.nan, opt_state)
-            aux = (key, hyperparams, subs, model_params)
+            # initialize the line search
+            f, grad, gnorm2, log = _jax_wrapped_line_search_init(
+                key, policy_params, hyperparams, subs, model_params)
 
-            #
-
-
+            # continue to reduce the learning rate until the Armijo condition holds
+            trials = 0
+            step = lrmax / decay
+            f_step = np.inf
+            best_f, best_step, best_params, best_state = np.inf, None, None, None
+            while f_step > f - c * step * gnorm2 and step * decay >= lrmin:
+                trials += 1
+                step *= decay
+                f_step, new_params, new_state = _jax_wrapped_line_search_trial(
+                    step, grad, key, policy_params, hyperparams, subs,
+                    model_params, opt_state)
+                if f_step < best_f:
+                    best_f, best_step, best_params, best_state = \
+                        f_step, step, new_params, new_state
 
-            # continue to anneal the learning rate until Armijo condition holds
-            # or the learning rate becomes too small, then use the best parameter
-            _, (*_, iters), (best_params, _, best_lr, best_state), _ = \
-                jax.lax.while_loop(
-                    cond_fun=_jax_wrapped_line_search_armijo_check,
-                    body_fun=_jax_wrapped_line_search_iteration,
-                    init_val=init_val
-                )
-            best_state.hyperparams['learning_rate'] = best_lr
             log['updates'] = None
-            log['line_search_iters'] =
-            log['learning_rate'] =
-            return best_params, True, best_state, log
+            log['line_search_iters'] = trials
+            log['learning_rate'] = best_step
+            return best_params, True, best_state, best_step, best_f, log
 
         return _jax_wrapped_plan_update
 
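Both the trial step above and the removed annealing code overwrite state.hyperparams['learning_rate'] before calling optimizer.update, which only works when the optimizer exposes its hyper-parameters inside the optimizer state. A minimal optax illustration of that mechanism; the inject_hyperparams wrapping shown here is an assumption about how such an optimizer would be built, not a quote from this file:

import jax.numpy as jnp
import optax

# expose the learning rate inside the optimizer state
optimizer = optax.inject_hyperparams(optax.sgd)(learning_rate=1.0)

params = {'w': jnp.ones(3)}
grads = {'w': jnp.full(3, 0.5)}
state = optimizer.init(params)

# a line search can now rescale the step before each trial update
state.hyperparams['learning_rate'] = 0.25
updates, state = optimizer.update(grads, state)
params = optax.apply_updates(params, updates)
print(params['w'])   # 1 - 0.25 * 0.5 = 0.875 for every entry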
-
+
+# ***********************************************************************
+# ALL VERSIONS OF RISK FUNCTIONS
+#
+# Based on the original paper "A Distributional Framework for Risk-Sensitive
+# End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
+#
+# Original risk functions:
+# - entropic utility
+# - mean-variance approximation
+# - conditional value at risk with straight-through gradient trick
+#
+# ***********************************************************************
+
+
+@jax.jit
+def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
+    return (-1.0 / beta) * jax.scipy.special.logsumexp(
+        -beta * returns, b=1.0 / returns.size)
+
+
+@jax.jit
+def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
+    return jnp.mean(returns) - (beta / 2.0) * jnp.var(returns)
+
+
+@jax.jit
+def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
+    alpha_mask = jax.lax.stop_gradient(
+        returns <= jnp.percentile(returns, q=100 * alpha))
+    return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
+
+
+# ***********************************************************************
+# ALL VERSIONS OF CONTROLLERS
+#
+# - offline controller is the straight-line planner
+# - online controller is the replanning mode
+#
+# ***********************************************************************
+
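The three utilities above replace the plain mean return with risk-sensitive aggregates of the sampled rollout returns. A small numeric check of what they compute, re-implemented here only for illustration; importing them from the module should give the same values if they are exposed at module level as the hunk suggests:

import jax
import jax.numpy as jnp

returns = jnp.array([1.0, 2.0, 3.0, 10.0])

# entropic utility: (-1/beta) * log E[exp(-beta * R)], penalizes low returns
beta = 0.5
entropic = (-1.0 / beta) * jax.scipy.special.logsumexp(
    -beta * returns, b=1.0 / returns.size)

# mean-variance approximation of the entropic utility
mean_var = jnp.mean(returns) - (beta / 2.0) * jnp.var(returns)

# CVaR at level alpha: mean of the worst alpha-fraction of returns
alpha = 0.5
mask = returns <= jnp.percentile(returns, q=100 * alpha)
cvar = jnp.sum(returns * mask) / jnp.sum(mask)

print(entropic, mean_var, cvar)   # all three fall below the plain mean of 4.0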
 class JaxOfflineController(BaseAgent):
     '''A container class for a Jax policy trained offline.'''
+
     use_tensor_obs = True
 
-    def __init__(self, planner: JaxBackpropPlanner,
-
-
+    def __init__(self, planner: JaxBackpropPlanner,
+                 key: Optional[random.PRNGKey]=None,
+                 eval_hyperparams: Optional[Dict[str, Any]]=None,
+                 params: Optional[Pytree]=None,
                  train_on_reset: bool=False,
                  **train_kwargs) -> None:
         '''Creates a new JAX offline control policy that is trained once, then
         deployed later.
 
         :param planner: underlying planning algorithm for optimizing actions
-        :param key: the RNG key to seed randomness
+        :param key: the RNG key to seed randomness (derives from clock if not
+        provided)
         :param eval_hyperparams: policy hyperparameters to apply for evaluation
         or whenever sample_action is called
         :param params: use the specified policy parameters instead of calling
@@ -1608,6 +1962,8 @@ class JaxOfflineController(BaseAgent):
         for optimization
         '''
         self.planner = planner
+        if key is None:
+            key = random.PRNGKey(round(time.time() * 1000))
         self.key = key
         self.eval_hyperparams = eval_hyperparams
         self.train_on_reset = train_on_reset
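Putting the offline controller together: it trains once through the planner and then replays the resulting plan or policy, falling back to a clock-derived PRNG key when none is given, as in the hunk above. The constructor keywords below are the ones introduced in this diff; the planner construction (rddl=env.model) and the epochs keyword forwarded to the optimizer follow the package examples and are assumptions of this sketch, with env standing in for a vectorized pyRDDLGym environment:

import jax.random as random
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController

# placeholder: env is a vectorized pyRDDLGym environment, env.model its RDDL model
planner = JaxBackpropPlanner(rddl=env.model)

controller = JaxOfflineController(
    planner,
    key=random.PRNGKey(42),   # omit to fall back to the clock-based seed above
    eval_hyperparams=None,    # e.g. sigmoid weights for boolean actions
    train_on_reset=False,     # train once up front rather than on every reset
    epochs=1000)              # **train_kwargs forwarded to the planner's optimize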
@@ -1616,17 +1972,18 @@ class JaxOfflineController(BaseAgent):
 
         self.step = 0
         if not self.train_on_reset and not self.params_given:
-
+            callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+            params = callback['best_params']
         self.params = params
 
-    def sample_action(self, state):
+    def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
         self.key, subkey = random.split(self.key)
         actions = self.planner.get_action(
             subkey, self.params, self.step, state, self.eval_hyperparams)
         self.step += 1
         return actions
 
-    def reset(self):
+    def reset(self) -> None:
         self.step = 0
         if self.train_on_reset and not self.params_given:
             self.params = self.planner.optimize(key=self.key, **self.train_kwargs)
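The controller keeps a step counter so that a straight-line plan is indexed by decision epoch: reset() rewinds it and sample_action advances it after each query. A schematic episode loop using only that contract, where controller is an offline controller built as sketched above and env is again a placeholder; the gymnasium-style reset/step return values are an assumption about the environment API:

# schematic episode loop; env stands for a vectorized pyRDDLGym environment
state, _ = env.reset()
controller.reset()                              # step counter back to 0
total_reward = 0.0
for _ in range(env.horizon):
    action = controller.sample_action(state)    # uses the current plan step, then increments it
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    if terminated or truncated:
        break
print(total_reward)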
@@ -1635,41 +1992,51 @@ class JaxOfflineController(BaseAgent):
 class JaxOnlineController(BaseAgent):
     '''A container class for a Jax controller continuously updated using state
     feedback.'''
+
     use_tensor_obs = True
 
-    def __init__(self, planner: JaxBackpropPlanner,
-
+    def __init__(self, planner: JaxBackpropPlanner,
+                 key: Optional[random.PRNGKey]=None,
+                 eval_hyperparams: Optional[Dict[str, Any]]=None,
+                 warm_start: bool=True,
                  **train_kwargs) -> None:
         '''Creates a new JAX control policy that is trained online in a closed-
         loop fashion.
 
         :param planner: underlying planning algorithm for optimizing actions
-        :param key: the RNG key to seed randomness
+        :param key: the RNG key to seed randomness (derives from clock if not
+        provided)
         :param eval_hyperparams: policy hyperparameters to apply for evaluation
         or whenever sample_action is called
+        :param warm_start: whether to use the previous decision epoch final
+        policy parameters to warm the next decision epoch
         :param **train_kwargs: any keyword arguments to be passed to the planner
         for optimization
         '''
         self.planner = planner
+        if key is None:
+            key = random.PRNGKey(round(time.time() * 1000))
         self.key = key
         self.eval_hyperparams = eval_hyperparams
         self.warm_start = warm_start
         self.train_kwargs = train_kwargs
         self.reset()
 
-    def sample_action(self, state):
+    def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
         planner = self.planner
-
+        callback = planner.optimize(
             key=self.key,
             guess=self.guess,
             subs=state,
             **self.train_kwargs)
+        params = callback['best_params']
         self.key, subkey = random.split(self.key)
-        actions = planner.get_action(
+        actions = planner.get_action(
+            subkey, params, 0, state, self.eval_hyperparams)
         if self.warm_start:
             self.guess = planner.plan.guess_next_epoch(params)
         return actions
 
-    def reset(self):
+    def reset(self) -> None:
         self.guess = None
 
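The online controller re-runs the planner at every decision step and, when warm_start=True, reuses the previous epoch's parameters as the initial guess, which is the replanning mode named in the controller banner above. A usage sketch mirroring the offline one; the placeholder environment, the rddl keyword, and the train_seconds budget forwarded to the optimizer are assumptions of this sketch rather than quotes from the package:

from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOnlineController

# placeholder planner over a vectorized environment's model, as before
planner = JaxBackpropPlanner(rddl=env.model)

controller = JaxOnlineController(
    planner,
    warm_start=True,        # reuse the previous epoch's parameters as the guess
    train_seconds=1.0)      # **train_kwargs: per-step optimization budget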