pyRDDLGym-jax 0.1__py3-none-any.whl → 0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -0
- pyRDDLGym_jax/core/compiler.py +444 -221
- pyRDDLGym_jax/core/logic.py +129 -62
- pyRDDLGym_jax/core/planner.py +965 -394
- pyRDDLGym_jax/core/simulator.py +5 -7
- pyRDDLGym_jax/core/tuning.py +29 -15
- pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
- pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +4 -4
- pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +1 -0
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +1 -1
- pyRDDLGym_jax/examples/configs/{Pong_slp.cfg → Quadcopter_drp.cfg} +5 -5
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_gradient.py +1 -1
- pyRDDLGym_jax/examples/run_gym.py +3 -7
- pyRDDLGym_jax/examples/run_plan.py +10 -5
- pyRDDLGym_jax/examples/run_scipy.py +61 -0
- pyRDDLGym_jax/examples/run_tune.py +8 -3
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/METADATA +1 -1
- pyRDDLGym_jax-0.3.dist-info/RECORD +44 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg +0 -18
- pyRDDLGym_jax/examples/configs/Traffic_slp.cfg +0 -20
- pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.3.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -1,6 +1,7 @@
 from ast import literal_eval
 from collections import deque
 import configparser
+from enum import Enum
 import haiku as hk
 import jax
 import jax.numpy as jnp
@@ -12,12 +13,33 @@ import os
 import sys
 import termcolor
 import time
+import traceback
 from tqdm import tqdm
-from typing import Callable, Dict, Generator, Set, Sequence, Tuple
+from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union

+Activation = Callable[[jnp.ndarray], jnp.ndarray]
+Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
+Kwargs = Dict[str, Any]
+Pytree = Any
+
+from pyRDDLGym.core.debug.exception import raise_warning
+
+from pyRDDLGym_jax import __version__
+
+# try to import matplotlib, if failed then skip plotting
+try:
+    import matplotlib
+    import matplotlib.pyplot as plt
+    matplotlib.use('TkAgg')
+except Exception:
+    raise_warning('failed to import matplotlib: '
+                  'plotting functionality will be disabled.', 'red')
+    traceback.print_exc()
+    plt = None
+
 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
+from pyRDDLGym.core.debug.logger import Logger
 from pyRDDLGym.core.debug.exception import (
-    raise_warning,
     RDDLNotImplementedError,
     RDDLUndefinedVariableError,
     RDDLTypeError
@@ -37,6 +59,7 @@ from pyRDDLGym_jax.core.logic import FuzzyLogic
 # - instantiate planner
 #
 # ***********************************************************************
+
 def _parse_config_file(path: str):
     if not os.path.isfile(path):
         raise FileNotFoundError(f'File {path} does not exist.')
@@ -59,51 +82,96 @@ def _parse_config_string(value: str):
     return config, args


+def _getattr_any(packages, item):
+    for package in packages:
+        loaded = getattr(package, item, None)
+        if loaded is not None:
+            return loaded
+    return None
+
+
 def _load_config(config, args):
     model_args = {k: args[k] for (k, _) in config.items('Model')}
     planner_args = {k: args[k] for (k, _) in config.items('Optimizer')}
     train_args = {k: args[k] for (k, _) in config.items('Training')}

-    train_args['key'] = jax.random.PRNGKey(train_args['key'])
-
     # read the model settings
-
-
-
-
+    logic_name = model_args.get('logic', 'FuzzyLogic')
+    logic_kwargs = model_args.get('logic_kwargs', {})
+    tnorm_name = model_args.get('tnorm', 'ProductTNorm')
+    tnorm_kwargs = model_args.get('tnorm_kwargs', {})
+    comp_name = model_args.get('complement', 'StandardComplement')
+    comp_kwargs = model_args.get('complement_kwargs', {})
+    compare_name = model_args.get('comparison', 'SigmoidComparison')
+    compare_kwargs = model_args.get('comparison_kwargs', {})
     logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
-
+    logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
+    logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)

-    # read the
+    # read the policy settings
     plan_method = planner_args.pop('method')
     plan_kwargs = planner_args.pop('method_kwargs', {})

-
-
-
-
-
-
-
-
-    plan_kwargs['initializer']
-
-
-
+    # policy initialization
+    plan_initializer = plan_kwargs.get('initializer', None)
+    if plan_initializer is not None:
+        initializer = _getattr_any(
+            packages=[initializers, hk.initializers], item=plan_initializer)
+        if initializer is None:
+            raise_warning(
+                f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
+            del plan_kwargs['initializer']
+        else:
+            init_kwargs = plan_kwargs.pop('initializer_kwargs', {})
+            try:
+                plan_kwargs['initializer'] = initializer(**init_kwargs)
+            except Exception as _:
+                raise_warning(
+                    f'Ignoring invalid initializer_kwargs <{init_kwargs}>.', 'red')
+                plan_kwargs['initializer'] = initializer
+
+    # policy activation
+    plan_activation = plan_kwargs.get('activation', None)
+    if plan_activation is not None:
+        activation = _getattr_any(
+            packages=[jax.nn, jax.numpy], item=plan_activation)
+        if activation is None:
+            raise_warning(
+                f'Ignoring invalid activation <{plan_activation}>.', 'red')
+            del plan_kwargs['activation']
+        else:
+            plan_kwargs['activation'] = activation

+    # read the planner settings
+    planner_args['logic'] = getattr(logic, logic_name)(**logic_kwargs)
     planner_args['plan'] = getattr(sys.modules[__name__], plan_method)(**plan_kwargs)
-
+
+    # planner optimizer
+    planner_optimizer = planner_args.get('optimizer', None)
+    if planner_optimizer is not None:
+        optimizer = _getattr_any(packages=[optax], item=planner_optimizer)
+        if optimizer is None:
+            raise_warning(
+                f'Ignoring invalid optimizer <{planner_optimizer}>.', 'red')
+            del planner_args['optimizer']
+        else:
+            planner_args['optimizer'] = optimizer
+
+    # read the optimize call settings
+    planner_key = train_args.get('key', None)
+    if planner_key is not None:
+        train_args['key'] = random.PRNGKey(planner_key)

     return planner_args, plan_kwargs, train_args


-def load_config(path: str) -> Tuple[
+def load_config(path: str) -> Tuple[Kwargs, ...]:
     '''Loads a config file at the specified file path.'''
     config, args = _parse_config_file(path)
     return _load_config(config, args)


-def load_config_from_string(value: str) -> Tuple[
+def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
     '''Loads config file contents specified explicitly as a string value.'''
     config, args = _parse_config_string(value)
     return _load_config(config, args)
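Note: the rewritten _load_config above now resolves the fuzzy-logic components, the policy initializer/activation, and the optax optimizer by name from the [Model], [Optimizer] and [Training] sections. A minimal sketch of driving it through the new load_config_from_string API; the key names follow the loader above, but the particular values are illustrative rather than copied from a shipped .cfg (see the default_*.cfg examples in the file list for the authoritative format):

    from pyRDDLGym_jax.core.planner import load_config_from_string

    CONFIG = (
        "[Model]\n"
        "logic='FuzzyLogic'\n"
        "comparison_kwargs={'weight': 20}\n"
        "[Optimizer]\n"
        "method='JaxStraightLinePlan'\n"
        "method_kwargs={}\n"
        "optimizer='rmsprop'\n"
        "optimizer_kwargs={'learning_rate': 0.001}\n"
        "[Training]\n"
        "key=42\n")

    # returns keyword dicts for the planner, the plan/policy, and optimize()
    planner_args, plan_kwargs, train_args = load_config_from_string(CONFIG)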
@@ -115,6 +183,20 @@ def load_config_from_string(value: str) -> Tuple[Dict[str, object], ...]:
 # - replace discrete ops in state dynamics/reward with differentiable ones
 #
 # ***********************************************************************
+
+def _function_discrete_approx_named(logic):
+    jax_discrete, jax_param = logic.discrete()
+
+    def _jax_wrapped_discrete_calc_approx(key, prob, params):
+        sample = jax_discrete(key, prob, params)
+        out_of_bounds = jnp.logical_not(jnp.logical_and(
+            jnp.all(prob >= 0),
+            jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
+        return sample, out_of_bounds
+
+    return _jax_wrapped_discrete_calc_approx, jax_param
+
+
 class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
     '''Compiles a RDDL AST representation to an equivalent JAX representation.
     Unlike its parent class, this class treats all fluents as real-valued, and
@@ -124,7 +206,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):

     def __init__(self, *args,
                  logic: FuzzyLogic=FuzzyLogic(),
-                 cpfs_without_grad: Set=
+                 cpfs_without_grad: Optional[Set[str]]=None,
                  **kwargs) -> None:
         '''Creates a new RDDL to Jax compiler, where operations that are not
         differentiable are converted to approximate forms that have defined
@@ -139,28 +221,37 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         :param *kwargs: keyword arguments to pass to base compiler
         '''
         super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
+
         self.logic = logic
+        self.logic.set_use64bit(self.use64bit)
+        if cpfs_without_grad is None:
+            cpfs_without_grad = set()
         self.cpfs_without_grad = cpfs_without_grad

         # actions and CPFs must be continuous
-
+        pvars_cast = set()
         for (var, values) in self.init_values.items():
             self.init_values[var] = np.asarray(values, dtype=self.REAL)
+            if not np.issubdtype(np.atleast_1d(values).dtype, np.floating):
+                pvars_cast.add(var)
+        if pvars_cast:
+            raise_warning(f'JAX gradient compiler requires that initial values '
+                          f'of p-variables {pvars_cast} be cast to float.')

         # overwrite basic operations with fuzzy ones
         self.RELATIONAL_OPS = {
-            '>=': logic.
-            '<=': logic.
+            '>=': logic.greater_equal(),
+            '<=': logic.less_equal(),
             '<': logic.less(),
             '>': logic.greater(),
             '==': logic.equal(),
-            '~=': logic.
+            '~=': logic.not_equal()
         }
-        self.LOGICAL_NOT = logic.
+        self.LOGICAL_NOT = logic.logical_not()
         self.LOGICAL_OPS = {
-            '^': logic.
-            '&': logic.
-            '|': logic.
+            '^': logic.logical_and(),
+            '&': logic.logical_and(),
+            '|': logic.logical_or(),
             '~': logic.xor(),
             '=>': logic.implies(),
             '<=>': logic.equiv()
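For orientation: the relaxed compiler above swaps the exact relational and logical operators for differentiable surrogates supplied by FuzzyLogic (SigmoidComparison is the default comparison in the config loader). A rough sketch of the kind of relaxation this refers to, not the package's actual implementation:

    import jax
    import jax.numpy as jnp

    def soft_greater_equal(a, b, weight=10.0):
        # sigmoid relaxation of the indicator a >= b; approaches the exact
        # 0/1 result as weight grows (illustrative only)
        return jax.nn.sigmoid(weight * (a - b))

    print(soft_greater_equal(jnp.array(2.0), jnp.array(1.0)))  # close to 1
    print(soft_greater_equal(jnp.array(1.0), jnp.array(2.0)))  # close to 0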
@@ -169,15 +260,19 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         self.AGGREGATION_OPS['exists'] = logic.exists()
         self.AGGREGATION_OPS['argmin'] = logic.argmin()
         self.AGGREGATION_OPS['argmax'] = logic.argmax()
-        self.KNOWN_UNARY['sgn'] = logic.
+        self.KNOWN_UNARY['sgn'] = logic.sgn()
         self.KNOWN_UNARY['floor'] = logic.floor()
         self.KNOWN_UNARY['ceil'] = logic.ceil()
         self.KNOWN_UNARY['round'] = logic.round()
         self.KNOWN_UNARY['sqrt'] = logic.sqrt()
-        self.KNOWN_BINARY['div'] = logic.
+        self.KNOWN_BINARY['div'] = logic.div()
         self.KNOWN_BINARY['mod'] = logic.mod()
         self.KNOWN_BINARY['fmod'] = logic.mod()
-
+        self.IF_HELPER = logic.control_if()
+        self.SWITCH_HELPER = logic.control_switch()
+        self.BERNOULLI_HELPER = logic.bernoulli()
+        self.DISCRETE_HELPER = _function_discrete_approx_named(logic)
+
     def _jax_stop_grad(self, jax_expr):

         def _jax_wrapped_stop_grad(x, params, key):
@@ -188,46 +283,33 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         return _jax_wrapped_stop_grad

     def _compile_cpfs(self, info):
-
+        cpfs_cast = set()
         jax_cpfs = {}
         for (_, cpfs) in self.levels.items():
             for cpf in cpfs:
                 _, expr = self.rddl.cpfs[cpf]
                 jax_cpfs[cpf] = self._jax(expr, info, dtype=self.REAL)
+                if self.rddl.variable_ranges[cpf] != 'real':
+                    cpfs_cast.add(cpf)
                 if cpf in self.cpfs_without_grad:
-                    raise_warning(f'CPF <{cpf}> stops gradient.')
                     jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
+
+        if cpfs_cast:
+            raise_warning(f'JAX gradient compiler requires that outputs of CPFs '
+                          f'{cpfs_cast} be cast to float.')
+        if self.cpfs_without_grad:
+            raise_warning(f'User requested that gradients not flow '
+                          f'through CPFs {self.cpfs_without_grad}.')
         return jax_cpfs

-    def _jax_if_helper(self):
-        return self.logic.If()
-
-    def _jax_switch_helper(self):
-        return self.logic.Switch()
-
     def _jax_kron(self, expr, info):
         if self.logic.verbose:
-            raise_warning('
-
+            raise_warning('JAX gradient compiler ignores KronDelta '
+                          'during compilation.')
         arg, = expr.args
         arg = self._jax(arg, info)
         return arg

-    def _jax_bernoulli_helper(self):
-        return self.logic.bernoulli()
-
-    def _jax_discrete_helper(self):
-        jax_discrete, jax_param = self.logic.discrete()
-
-        def _jax_wrapped_discrete_calc_approx(key, prob, params):
-            sample = jax_discrete(key, prob, params)
-            out_of_bounds = jnp.logical_not(jnp.logical_and(
-                jnp.all(prob >= 0),
-                jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
-            return sample, out_of_bounds
-
-        return _jax_wrapped_discrete_calc_approx, jax_param
-

 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
@@ -236,6 +318,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
 # - deep reactive policy
 #
 # ***********************************************************************
+
 class JaxPlan:
     '''Base class for all JAX policy representations.'''

@@ -244,16 +327,17 @@ class JaxPlan:
         self._train_policy = None
         self._test_policy = None
         self._projection = None
-
-
+        self.bounds = None
+
+    def summarize_hyperparameters(self) -> None:
         pass

     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds:
+                _bounds: Bounds,
                 horizon: int) -> None:
         raise NotImplementedError

-    def guess_next_epoch(self, params:
+    def guess_next_epoch(self, params: Pytree) -> Pytree:
         raise NotImplementedError

     @property
@@ -289,7 +373,8 @@ class JaxPlan:
         self._projection = value

     def _calculate_action_info(self, compiled: JaxRDDLCompilerWithGrad,
-                               user_bounds:
+                               user_bounds: Bounds,
+                               horizon: int):
         shapes, bounds, bounds_safe, cond_lists = {}, {}, {}, {}
         for (name, prange) in compiled.rddl.variable_ranges.items():
             if compiled.rddl.variable_types[name] != 'action-fluent':
@@ -298,7 +383,7 @@ class JaxPlan:
             # check invalid type
             if prange not in compiled.JAX_TYPES:
                 raise RDDLTypeError(
-                    f'Invalid range <{prange}
+                    f'Invalid range <{prange}> of action-fluent <{name}>, '
                     f'must be one of {set(compiled.JAX_TYPES.keys())}.')

             # clip boolean to (0, 1), otherwise use the RDDL action bounds
@@ -309,8 +394,8 @@ class JaxPlan:
             else:
                 lower, upper = compiled.constraints.bounds[name]
                 lower, upper = user_bounds.get(name, (lower, upper))
-                lower = np.asarray(lower, dtype=
-                upper = np.asarray(upper, dtype=
+                lower = np.asarray(lower, dtype=compiled.REAL)
+                upper = np.asarray(upper, dtype=compiled.REAL)
                 lower_finite = np.isfinite(lower)
                 upper_finite = np.isfinite(upper)
                 bounds_safe[name] = (np.where(lower_finite, lower, 0.0),
@@ -320,7 +405,7 @@ class JaxPlan:
                                 ~lower_finite & upper_finite,
                                 ~lower_finite & ~upper_finite]
             bounds[name] = (lower, upper)
-            raise_warning(f'Bounds of action
+            raise_warning(f'Bounds of action-fluent <{name}> set to {bounds[name]}.')
         return shapes, bounds, bounds_safe, cond_lists

     def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -336,7 +421,7 @@ class JaxStraightLinePlan(JaxPlan):

     def __init__(self, initializer: initializers.Initializer=initializers.normal(),
                  wrap_sigmoid: bool=True,
-                 min_action_prob: float=1e-
+                 min_action_prob: float=1e-6,
                  wrap_non_bool: bool=False,
                  wrap_softmax: bool=False,
                  use_new_projection: bool=False,
@@ -362,6 +447,7 @@ class JaxStraightLinePlan(JaxPlan):
         use_new_projection = True
         '''
         super(JaxStraightLinePlan, self).__init__()
+
         self._initializer_base = initializer
         self._initializer = initializer
         self._wrap_sigmoid = wrap_sigmoid
@@ -371,10 +457,13 @@ class JaxStraightLinePlan(JaxPlan):
         self._use_new_projection = use_new_projection
         self._max_constraint_iter = max_constraint_iter

-    def summarize_hyperparameters(self):
+    def summarize_hyperparameters(self) -> None:
+        bounds = '\n '.join(
+            map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         print(f'policy hyper-parameters:\n'
-              f' initializer ={
+              f' initializer ={self._initializer_base}\n'
               f'constraint-sat strategy (simple):\n'
+              f' parsed_action_bounds =\n {bounds}\n'
               f' wrap_sigmoid ={self._wrap_sigmoid}\n'
               f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
               f' wrap_non_bool ={self._wrap_non_bool}\n'
@@ -383,7 +472,8 @@ class JaxStraightLinePlan(JaxPlan):
               f' use_new_projection ={self._use_new_projection}')

     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds:
+                _bounds: Bounds,
+                horizon: int) -> None:
         rddl = compiled.rddl

         # calculate the correct action box bounds
@@ -423,7 +513,7 @@ class JaxStraightLinePlan(JaxPlan):
         def _jax_bool_action_to_param(var, action, hyperparams):
             if wrap_sigmoid:
                 weight = hyperparams[var]
-                return (-1.0 / weight) * jnp.
+                return (-1.0 / weight) * jnp.log(1.0 / action - 1.0)
             else:
                 return action

@@ -506,7 +596,7 @@ class JaxStraightLinePlan(JaxPlan):
         def _jax_wrapped_slp_predict_test(key, params, hyperparams, step, subs):
             actions = {}
             for (var, param) in params.items():
-                action = jnp.asarray(param[step, ...])
+                action = jnp.asarray(param[step, ...], dtype=compiled.REAL)
                 if var == bool_key:
                     output = jax.nn.softmax(action)
                     bool_actions = _jax_unstack_bool_from_softmax(output)
@@ -537,7 +627,7 @@ class JaxStraightLinePlan(JaxPlan):
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
                 f'Straight-line plans with wrap_softmax currently '
-                f'do not support max-nondef-actions
+                f'do not support max-nondef-actions {allowed_actions} > 1.')

         # potentially apply projection but to non-bool actions only
         self.projection = _jax_wrapped_slp_project_to_box
@@ -668,14 +758,14 @@ class JaxStraightLinePlan(JaxPlan):
             for (var, shape) in shapes.items():
                 if ranges[var] != 'bool' or not stack_bool_params:
                     key, subkey = random.split(key)
-                    param = init(subkey, shape, dtype=compiled.REAL)
+                    param = init(key=subkey, shape=shape, dtype=compiled.REAL)
                     if ranges[var] == 'bool':
                         param += bool_threshold
                     params[var] = param
             if stack_bool_params:
                 key, subkey = random.split(key)
                 bool_shape = (horizon, bool_action_count)
-                bool_param = init(subkey, bool_shape, dtype=compiled.REAL)
+                bool_param = init(key=subkey, shape=bool_shape, dtype=compiled.REAL)
                 params[bool_key] = bool_param
             params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
             return params
@@ -688,7 +778,7 @@ class JaxStraightLinePlan(JaxPlan):
         # "progress" the plan one step forward and set last action to second-last
         return jnp.append(param[1:, ...], param[-1:, ...], axis=0)

-    def guess_next_epoch(self, params:
+    def guess_next_epoch(self, params: Pytree) -> Pytree:
         next_fn = JaxStraightLinePlan._guess_next_epoch
         return jax.tree_map(next_fn, params)

@@ -696,10 +786,13 @@ class JaxStraightLinePlan(JaxPlan):
 class JaxDeepReactivePolicy(JaxPlan):
     '''A deep reactive policy network implementation in JAX.'''

-    def __init__(self, topology: Sequence[int],
-                 activation:
+    def __init__(self, topology: Optional[Sequence[int]]=None,
+                 activation: Activation=jnp.tanh,
                  initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
-                 normalize: bool=
+                 normalize: bool=False,
+                 normalize_per_layer: bool=False,
+                 normalizer_kwargs: Optional[Kwargs]=None,
+                 wrap_non_bool: bool=False) -> None:
         '''Creates a new deep reactive policy in JAX.

         :param neurons: sequence consisting of the number of neurons in each
@@ -707,23 +800,45 @@ class JaxDeepReactivePolicy(JaxPlan):
         :param activation: function to apply after each layer of the policy
         :param initializer: weight initialization
         :param normalize: whether to apply layer norm to the inputs
+        :param normalize_per_layer: whether to apply layer norm to each input
+        individually (only active if normalize is True)
+        :param normalizer_kwargs: if normalize is True, apply additional arguments
+        to layer norm
+        :param wrap_non_bool: whether to wrap real or int action fluent parameters
+        with non-linearity (e.g. sigmoid or ELU) to satisfy box constraints
         '''
         super(JaxDeepReactivePolicy, self).__init__()
+
+        if topology is None:
+            topology = [128, 64]
         self._topology = topology
         self._activations = [activation for _ in topology]
         self._initializer_base = initializer
         self._initializer = initializer
         self._normalize = normalize
+        self._normalize_per_layer = normalize_per_layer
+        if normalizer_kwargs is None:
+            normalizer_kwargs = {'create_offset': True, 'create_scale': True}
+        self._normalizer_kwargs = normalizer_kwargs
+        self._wrap_non_bool = wrap_non_bool

-    def summarize_hyperparameters(self):
+    def summarize_hyperparameters(self) -> None:
+        bounds = '\n '.join(
+            map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         print(f'policy hyper-parameters:\n'
-              f' topology
-              f' activation_fn
-              f' initializer
-              f'
+              f' topology ={self._topology}\n'
+              f' activation_fn ={self._activations[0].__name__}\n'
+              f' initializer ={type(self._initializer_base).__name__}\n'
+              f' apply_input_norm ={self._normalize}\n'
+              f' input_norm_layerwise={self._normalize_per_layer}\n'
+              f' input_norm_args ={self._normalizer_kwargs}\n'
+              f'constraint-sat strategy:\n'
+              f' parsed_action_bounds=\n {bounds}\n'
+              f' wrap_non_bool ={self._wrap_non_bool}')

     def compile(self, compiled: JaxRDDLCompilerWithGrad,
-                _bounds:
+                _bounds: Bounds,
+                horizon: int) -> None:
         rddl = compiled.rddl

         # calculate the correct action box bounds
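A deep reactive policy can now be configured with the extra normalization and wrapping options introduced above. A small construction sketch using only arguments from the new 0.3 signature (the chosen values are illustrative):

    import haiku as hk
    import jax.numpy as jnp
    from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy

    policy = JaxDeepReactivePolicy(
        topology=[64, 32],                 # hidden layer sizes (default [128, 64])
        activation=jnp.tanh,
        initializer=hk.initializers.VarianceScaling(scale=2.0),
        normalize=True,                    # layer-norm the non-bool inputs
        normalizer_kwargs={'create_offset': True, 'create_scale': True},
        wrap_non_bool=True)                # squash real/int actions into their box bounds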
@@ -737,7 +852,7 @@ class JaxDeepReactivePolicy(JaxPlan):
         if 1 < allowed_actions < bool_action_count:
             raise RDDLNotImplementedError(
                 f'Deep reactive policies currently do not support '
-                f'max-nondef-actions
+                f'max-nondef-actions {allowed_actions} > 1.')
         use_constraint_satisfaction = allowed_actions < bool_action_count

         noop = {var: (values[0] if isinstance(values, list) else values)
@@ -751,22 +866,75 @@ class JaxDeepReactivePolicy(JaxPlan):

         ranges = rddl.variable_ranges
         normalize = self._normalize
+        normalize_per_layer = self._normalize_per_layer
+        wrap_non_bool = self._wrap_non_bool
         init = self._initializer
         layers = list(enumerate(zip(self._topology, self._activations)))
         layer_sizes = {var: np.prod(shape, dtype=int)
                        for (var, shape) in shapes.items()}
         layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}

-        #
-
+        # inputs for the policy network
+        if rddl.observ_fluents:
+            observed_vars = rddl.observ_fluents
+        else:
+            observed_vars = rddl.state_fluents
+        input_names = {var: f'{var}'.replace('-', '_') for var in observed_vars}
+
+        # catch if input norm is applied to size 1 tensor
+        if normalize:
+            non_bool_dims = 0
+            for (var, values) in observed_vars.items():
+                if ranges[var] != 'bool':
+                    value_size = np.atleast_1d(values).size
+                    if normalize_per_layer and value_size == 1:
+                        raise_warning(
+                            f'Cannot apply layer norm to state-fluent <{var}> '
+                            f'of size 1: setting normalize_per_layer = False.',
+                            'red')
+                        normalize_per_layer = False
+                    non_bool_dims += value_size
+            if not normalize_per_layer and non_bool_dims == 1:
+                raise_warning(
+                    'Cannot apply layer norm to state-fluents of total size 1: '
+                    'setting normalize = False.', 'red')
+                normalize = False
+
+        # convert subs dictionary into a state vector to feed to the MLP
+        def _jax_wrapped_policy_input(subs):

-            #
-
+            # concatenate all state variables into a single vector
+            # optionally apply layer norm to each input tensor
+            states_bool, states_non_bool = [], []
+            non_bool_dims = 0
+            for (var, value) in subs.items():
+                if var in observed_vars:
+                    state = jnp.ravel(value)
+                    if ranges[var] == 'bool':
+                        states_bool.append(state)
+                    else:
+                        if normalize and normalize_per_layer:
+                            normalizer = hk.LayerNorm(
+                                axis=-1, param_axis=-1,
+                                name=f'input_norm_{input_names[var]}',
+                                **self._normalizer_kwargs)
+                            state = normalizer(state)
+                        states_non_bool.append(state)
+                        non_bool_dims += state.size
+            state = jnp.concatenate(states_non_bool + states_bool)
+
+            # optionally perform layer normalization on the non-bool inputs
+            if normalize and not normalize_per_layer and non_bool_dims:
                 normalizer = hk.LayerNorm(
-                    axis=-1, param_axis=-1,
-
-
-            state =
+                    axis=-1, param_axis=-1, name='input_norm',
+                    **self._normalizer_kwargs)
+                normalized = normalizer(state[:non_bool_dims])
+                state = state.at[:non_bool_dims].set(normalized)
+            return state
+
+        # predict actions from the policy network for current state
+        def _jax_wrapped_policy_network_predict(subs):
+            state = _jax_wrapped_policy_input(subs)

             # feed state vector through hidden layers
             hidden = state
@@ -789,16 +957,19 @@ class JaxDeepReactivePolicy(JaxPlan):
                 if not use_constraint_satisfaction:
                     actions[var] = jax.nn.sigmoid(output)
                 else:
-
-
-
-
-
-
-
-
-
-
+                    if wrap_non_bool:
+                        lower, upper = bounds_safe[var]
+                        action = jnp.select(
+                            condlist=cond_lists[var],
+                            choicelist=[
+                                lower + (upper - lower) * jax.nn.sigmoid(output),
+                                lower + (jax.nn.elu(output) + 1.0),
+                                upper - (jax.nn.elu(-output) + 1.0),
+                                output
+                            ]
+                        )
+                    else:
+                        action = output
                     actions[var] = action

             # for constraint satisfaction wrap bool actions with softmax
@@ -826,21 +997,14 @@ class JaxDeepReactivePolicy(JaxPlan):
                 actions[name] = action
                 start += size
             return actions
-
-        # state is concatenated into single tensor
-        def _jax_wrapped_subs_to_state(subs):
-            subs = {var: value
-                    for (var, value) in subs.items()
-                    if var in rddl.state_fluents}
-            flat_subs = jax.tree_map(jnp.ravel, subs)
-            states = list(flat_subs.values())
-            state = jnp.concatenate(states)
-            return state

         # train action prediction
         def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
-
-
+            actions = predict_fn.apply(params, subs)
+            if not wrap_non_bool:
+                for (var, action) in actions.items():
+                    if var != bool_key and ranges[var] != 'bool':
+                        actions[var] = jnp.clip(action, *bounds[var])
             if use_constraint_satisfaction:
                 bool_actions = _jax_unstack_bool_from_softmax(actions[bool_key])
                 actions.update(bool_actions)
@@ -886,14 +1050,13 @@ class JaxDeepReactivePolicy(JaxPlan):
         def _jax_wrapped_drp_init(key, hyperparams, subs):
             subs = {var: value[0, ...]
                     for (var, value) in subs.items()
-                    if var in
-
-            params = predict_fn.init(key, state)
+                    if var in observed_vars}
+            params = predict_fn.init(key, subs)
             return params

         self.initializer = _jax_wrapped_drp_init

-    def guess_next_epoch(self, params:
+    def guess_next_epoch(self, params: Pytree) -> Pytree:
         return params

@@ -904,24 +1067,170 @@ class JaxDeepReactivePolicy(JaxPlan):
 # - more stable but slower line search based planner
 #
 # ***********************************************************************
+
+class RollingMean:
+    '''Maintains an estimate of the rolling mean of a stream of real-valued
+    observations.'''
+
+    def __init__(self, window_size: int) -> None:
+        self._window_size = window_size
+        self._memory = deque(maxlen=window_size)
+        self._total = 0
+
+    def update(self, x: float) -> float:
+        memory = self._memory
+        self._total += x
+        if len(memory) == self._window_size:
+            self._total -= memory.popleft()
+        memory.append(x)
+        return self._total / len(memory)
+
+
+class JaxPlannerPlot:
+    '''Supports plotting and visualization of a JAX policy in real time.'''
+
+    def __init__(self, rddl: RDDLPlanningModel, horizon: int,
+                 show_violin: bool=True, show_action: bool=True) -> None:
+        '''Creates a new planner visualizer.
+
+        :param rddl: the planning model to optimize
+        :param horizon: the lookahead or planning horizon
+        :param show_violin: whether to show the distribution of batch losses
+        :param show_action: whether to show heatmaps of the action fluents
+        '''
+        num_plots = 1
+        if show_violin:
+            num_plots += 1
+        if show_action:
+            num_plots += len(rddl.action_fluents)
+        self._fig, axes = plt.subplots(num_plots)
+        if num_plots == 1:
+            axes = [axes]
+
+        # prepare the loss plot
+        self._loss_ax = axes[0]
+        self._loss_ax.autoscale(enable=True)
+        self._loss_ax.set_xlabel('training time')
+        self._loss_ax.set_ylabel('loss value')
+        self._loss_plot = self._loss_ax.plot(
+            [], [], linestyle=':', marker='o', markersize=2)[0]
+        self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
+
+        # prepare the violin plot
+        if show_violin:
+            self._hist_ax = axes[1]
+        else:
+            self._hist_ax = None
+
+        # prepare the action plots
+        if show_action:
+            self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
+                               for (idx, name) in enumerate(rddl.action_fluents)}
+            self._action_plots = {}
+            for name in rddl.action_fluents:
+                ax = self._action_ax[name]
+                if rddl.variable_ranges[name] == 'bool':
+                    vmin, vmax = 0.0, 1.0
+                else:
+                    vmin, vmax = None, None
+                action_dim = 1
+                for dim in rddl.object_counts(rddl.variable_params[name]):
+                    action_dim *= dim
+                action_plot = ax.pcolormesh(
+                    np.zeros((action_dim, horizon)),
+                    cmap='seismic', vmin=vmin, vmax=vmax)
+                ax.set_aspect('auto')
+                ax.set_xlabel('decision epoch')
+                ax.set_ylabel(name)
+                plt.colorbar(action_plot, ax=ax)
+                self._action_plots[name] = action_plot
+            self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
+                                 for (name, ax) in self._action_ax.items()}
+        else:
+            self._action_ax = None
+            self._action_plots = None
+            self._action_back = None
+
+        plt.tight_layout()
+        plt.show(block=False)
+
+    def redraw(self, xticks, losses, actions, returns) -> None:
+
+        # draw the loss curve
+        self._fig.canvas.restore_region(self._loss_back)
+        self._loss_plot.set_xdata(xticks)
+        self._loss_plot.set_ydata(losses)
+        self._loss_ax.set_xlim([0, len(xticks)])
+        self._loss_ax.set_ylim([np.min(losses), np.max(losses)])
+        self._loss_ax.draw_artist(self._loss_plot)
+        self._fig.canvas.blit(self._loss_ax.bbox)
+
+        # draw the violin plot
+        if self._hist_ax is not None:
+            self._hist_ax.clear()
+            self._hist_ax.set_xlabel('loss value')
+            self._hist_ax.set_ylabel('density')
+            self._hist_ax.violinplot(returns, vert=False, showmeans=True)
+
+        # draw the actions
+        if self._action_ax is not None:
+            for (name, values) in actions.items():
+                values = np.mean(values, axis=0, dtype=float)
+                values = np.reshape(values, newshape=(values.shape[0], -1)).T
+                self._fig.canvas.restore_region(self._action_back[name])
+                self._action_plots[name].set_array(values)
+                self._action_ax[name].draw_artist(self._action_plots[name])
+                self._fig.canvas.blit(self._action_ax[name].bbox)
+                self._action_plots[name].set_clim([np.min(values), np.max(values)])
+
+        self._fig.canvas.draw()
+        self._fig.canvas.flush_events()
+
+    def close(self) -> None:
+        plt.close(self._fig)
+        del self._loss_ax, self._hist_ax, self._action_ax, \
+            self._loss_plot, self._action_plots, self._fig, \
+            self._loss_back, self._action_back
+
+
+class JaxPlannerStatus(Enum):
+    '''Represents the status of a policy update from the JAX planner,
+    including whether the update resulted in nan gradient,
+    whether progress was made, budget was reached, or other information that
+    can be used to monitor and act based on the planner's progress.'''
+
+    NORMAL = 0
+    NO_PROGRESS = 1
+    PRECONDITION_POSSIBLY_UNSATISFIED = 2
+    INVALID_GRADIENT = 3
+    TIME_BUDGET_REACHED = 4
+    ITER_BUDGET_REACHED = 5
+
+    def is_failure(self) -> bool:
+        return self.value >= 3
+
+
 class JaxBackpropPlanner:
     '''A class for optimizing an action sequence in the given RDDL MDP using
     gradient descent.'''

     def __init__(self, rddl: RDDLLiftedModel,
                  plan: JaxPlan,
-                 batch_size_train: int,
-                 batch_size_test: int=None,
-                 rollout_horizon: int=None,
+                 batch_size_train: int=32,
+                 batch_size_test: Optional[int]=None,
+                 rollout_horizon: Optional[int]=None,
                  use64bit: bool=False,
-                 action_bounds:
+                 action_bounds: Optional[Bounds]=None,
                  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
-                 optimizer_kwargs:
-                 clip_grad: float=None,
+                 optimizer_kwargs: Optional[Kwargs]=None,
+                 clip_grad: Optional[float]=None,
                  logic: FuzzyLogic=FuzzyLogic(),
                  use_symlog_reward: bool=False,
-                 utility
-
+                 utility: Union[Callable[[jnp.ndarray], float], str]='mean',
+                 utility_kwargs: Optional[Kwargs]=None,
+                 cpfs_without_grad: Optional[Set[str]]=None,
+                 compile_non_fluent_exact: bool=True,
+                 logger: Optional[Logger]=None) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
         (plan) in the given RDDL. Some operations will be converted to their
         differentiable counterparts; the specific operations can be customized
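The helpers added above are small and self-contained; a short usage sketch (illustrative only):

    from pyRDDLGym_jax.core.planner import JaxPlannerStatus, RollingMean

    # smooth a stream of test returns over the last 10 observations
    smoother = RollingMean(window_size=10)
    for ret in (1.0, 3.0, 2.0):
        smoothed = smoother.update(ret)

    # statuses with value >= 3 (invalid gradient, time/iteration budget reached)
    # are reported as failures
    assert JaxPlannerStatus.INVALID_GRADIENT.is_failure()
    assert not JaxPlannerStatus.NORMAL.is_failure()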
@@ -946,9 +1255,16 @@ class JaxBackpropPlanner:
         :param use_symlog_reward: whether to use the symlog transform on the
         reward as a form of normalization
         :param utility: how to aggregate return observations to compute utility
-        of a policy or plan
+        of a policy or plan; must be either a function mapping jax array to a
+        scalar, or a a string identifying the utility function by name
+        ("mean", "mean_var", "entropic", or "cvar" are currently supported)
+        :param utility_kwargs: additional keyword arguments to pass hyper-
+        parameters to the utility function call
         :param cpfs_without_grad: which CPFs do not have gradients (use straight
         through gradient trick)
+        :param compile_non_fluent_exact: whether non-fluent expressions
+        are always compiled using exact JAX expressions
+        :param logger: to log information about compilation to file
         '''
         self.rddl = rddl
         self.plan = plan
@@ -959,22 +1275,25 @@ class JaxBackpropPlanner:
         if rollout_horizon is None:
             rollout_horizon = rddl.horizon
         self.horizon = rollout_horizon
+        if action_bounds is None:
+            action_bounds = {}
         self._action_bounds = action_bounds
         self.use64bit = use64bit
         self._optimizer_name = optimizer
+        if optimizer_kwargs is None:
+            optimizer_kwargs = {'learning_rate': 0.1}
         self._optimizer_kwargs = optimizer_kwargs
         self.clip_grad = clip_grad

         # set optimizer
         try:
             optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs)
-        except:
+        except Exception as _:
             raise_warning(
                 'Failed to inject hyperparameters into optax optimizer, '
                 'rolling back to safer method: please note that modification of '
                 'optimizer hyperparameters will not work, and it is '
-                'recommended to update
-                'red')
+                'recommended to update optax and related packages.', 'red')
             optimizer = optimizer(**optimizer_kwargs)
         if clip_grad is None:
             self.optimizer = optimizer
@@ -983,33 +1302,84 @@ class JaxBackpropPlanner:
             optax.clip(clip_grad),
             optimizer
         )
-
+
+        # set utility
+        if isinstance(utility, str):
+            utility = utility.lower()
+            if utility == 'mean':
+                utility_fn = jnp.mean
+            elif utility == 'mean_var':
+                utility_fn = mean_variance_utility
+            elif utility == 'entropic':
+                utility_fn = entropic_utility
+            elif utility == 'cvar':
+                utility_fn = cvar_utility
+            else:
+                raise RDDLNotImplementedError(
+                    f'Utility function <{utility}> is not supported: '
+                    'must be one of ["mean", "mean_var", "entropic", "cvar"].')
+        else:
+            utility_fn = utility
+        self.utility = utility_fn
+
+        if utility_kwargs is None:
+            utility_kwargs = {}
+        self.utility_kwargs = utility_kwargs
+
         self.logic = logic
+        self.logic.set_use64bit(self.use64bit)
         self.use_symlog_reward = use_symlog_reward
-
+        if cpfs_without_grad is None:
+            cpfs_without_grad = set()
         self.cpfs_without_grad = cpfs_without_grad
+        self.compile_non_fluent_exact = compile_non_fluent_exact
+        self.logger = logger

         self._jax_compile_rddl()
         self._jax_compile_optimizer()
-
-    def
-
-
-
-
-
-
-
+
+    def _summarize_system(self) -> None:
+        try:
+            jaxlib_version = jax._src.lib.version_str
+        except Exception as _:
+            jaxlib_version = 'N/A'
+        try:
+            devices_short = ', '.join(
+                map(str, jax._src.xla_bridge.devices())).replace('\n', '')
+        except Exception as _:
+            devices_short = 'N/A'
+        print('\n'
+              f'JAX Planner version {__version__}\n'
+              f'Python {sys.version}\n'
+              f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
+              f'optax {optax.__version__}, haiku {hk.__version__}, '
+              f'numpy {np.__version__}\n'
+              f'devices: {devices_short}\n')
+
+    def summarize_hyperparameters(self) -> None:
+        print(f'objective hyper-parameters:\n'
+              f' utility_fn ={self.utility.__name__}\n'
+              f' utility args ={self.utility_kwargs}\n'
+              f' use_symlog ={self.use_symlog_reward}\n'
+              f' lookahead ={self.horizon}\n'
+              f' user_action_bounds={self._action_bounds}\n'
+              f' fuzzy logic type ={type(self.logic).__name__}\n'
+              f' nonfluents exact ={self.compile_non_fluent_exact}\n'
+              f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
               f'optimizer hyper-parameters:\n'
-              f' use_64_bit
-              f' optimizer
-              f' optimizer args
-              f' clip_gradient
-              f' batch_size_train={self.batch_size_train}\n'
-              f' batch_size_test
+              f' use_64_bit ={self.use64bit}\n'
+              f' optimizer ={self._optimizer_name.__name__}\n'
+              f' optimizer args ={self._optimizer_kwargs}\n'
+              f' clip_gradient ={self.clip_grad}\n'
+              f' batch_size_train ={self.batch_size_train}\n'
+              f' batch_size_test ={self.batch_size_test}')
         self.plan.summarize_hyperparameters()
         self.logic.summarize_hyperparameters()

+    # ===========================================================================
+    # COMPILATION SUBROUTINES
+    # ===========================================================================
+
     def _jax_compile_rddl(self):
         rddl = self.rddl

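The planner now accepts a risk-aware utility either by name ('mean', 'mean_var', 'entropic', 'cvar') or as a callable applied to the batch of returns as utility_fn(returns, **utility_kwargs). The built-in implementations are not shown in this hunk; the following is only a sketch of a custom entropic-risk utility with the same calling convention, not the package's own entropic_utility:

    import jax
    import jax.numpy as jnp

    def my_entropic_utility(returns, beta=1.0):
        # entropic risk measure -(1/beta) * log E[exp(-beta * R)],
        # computed stably with logsumexp (illustrative)
        n = returns.size
        return (-1.0 / beta) * (jax.scipy.special.logsumexp(-beta * returns) - jnp.log(n))

    # e.g. JaxBackpropPlanner(..., utility=my_entropic_utility, utility_kwargs={'beta': 0.1})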
@@ -1017,13 +1387,18 @@ class JaxBackpropPlanner:
         self.compiled = JaxRDDLCompilerWithGrad(
             rddl=rddl,
             logic=self.logic,
+            logger=self.logger,
             use64bit=self.use64bit,
-            cpfs_without_grad=self.cpfs_without_grad
-
+            cpfs_without_grad=self.cpfs_without_grad,
+            compile_non_fluent_exact=self.compile_non_fluent_exact)
+        self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')

         # Jax compilation of the exact RDDL for testing
-        self.test_compiled = JaxRDDLCompiler(
-
+        self.test_compiled = JaxRDDLCompiler(
+            rddl=rddl,
+            logger=self.logger,
+            use64bit=self.use64bit)
+        self.test_compiled.compile(log_jax_expr=True, heading='EXACT MODEL')

     def _jax_compile_optimizer(self):

@@ -1039,6 +1414,7 @@ class JaxBackpropPlanner:
             policy=self.plan.train_policy,
             n_steps=self.horizon,
             n_batch=self.batch_size_train)
+        self.train_rollouts = train_rollouts

         test_rollouts = self.test_compiled.compile_rollouts(
             policy=self.plan.test_policy,
@@ -1051,11 +1427,10 @@ class JaxBackpropPlanner:

         # losses
         train_loss = self._jax_loss(train_rollouts, use_symlog=self.use_symlog_reward)
-        self.train_loss = jax.jit(train_loss)
         self.test_loss = jax.jit(self._jax_loss(test_rollouts, use_symlog=False))

         # optimization
-        self.update =
+        self.update = self._jax_update(train_loss)

     def _jax_return(self, use_symlog):
         gamma = self.rddl.discount
@@ -1068,13 +1443,14 @@ class JaxBackpropPlanner:
             rewards = rewards * discount[jnp.newaxis, ...]
             returns = jnp.sum(rewards, axis=1)
             if use_symlog:
-                returns = jnp.sign(returns) * jnp.
+                returns = jnp.sign(returns) * jnp.log(1.0 + jnp.abs(returns))
             return returns

         return _jax_wrapped_returns

     def _jax_loss(self, rollouts, use_symlog=False):
-        utility_fn = self.utility
+        utility_fn = self.utility
+        utility_kwargs = self.utility_kwargs
         _jax_wrapped_returns = self._jax_return(use_symlog)

         # the loss is the average cumulative reward across all roll-outs
@@ -1083,7 +1459,7 @@ class JaxBackpropPlanner:
             log = rollouts(key, policy_params, hyperparams, subs, model_params)
             rewards = log['reward']
             returns = _jax_wrapped_returns(rewards)
-            utility = utility_fn(returns)
+            utility = utility_fn(returns, **utility_kwargs)
             loss = -utility
             return loss, log

@@ -1096,7 +1472,7 @@ class JaxBackpropPlanner:
         def _jax_wrapped_init_policy(key, hyperparams, subs):
             policy_params = init(key, hyperparams, subs)
             opt_state = optimizer.init(policy_params)
-            return policy_params, opt_state
+            return policy_params, opt_state, None

         return _jax_wrapped_init_policy

@@ -1107,17 +1483,18 @@ class JaxBackpropPlanner:
         # calculate the plan gradient w.r.t. return loss and update optimizer
         # also perform a projection step to satisfy constraints on actions
         def _jax_wrapped_plan_update(key, policy_params, hyperparams,
-                                     subs, model_params, opt_state):
-            grad_fn = jax.
-
+                                     subs, model_params, opt_state, opt_aux):
+            grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
+            (loss_val, log), grad = grad_fn(
+                key, policy_params, hyperparams, subs, model_params)
             updates, opt_state = optimizer.update(grad, opt_state)
             policy_params = optax.apply_updates(policy_params, updates)
             policy_params, converged = projection(policy_params, hyperparams)
             log['grad'] = grad
             log['updates'] = updates
-            return policy_params, converged, opt_state, log
+            return policy_params, converged, opt_state, None, loss_val, log

-        return _jax_wrapped_plan_update
+        return jax.jit(_jax_wrapped_plan_update)

     def _batched_init_subs(self, subs):
         rddl = self.rddl
@@ -1145,15 +1522,106 @@ class JaxBackpropPlanner:
|
|
|
1145
1522
|
|
|
1146
1523
|
return init_train, init_test
|
|
1147
1524
|
|
|
1148
|
-
def
|
|
1149
|
-
|
|
1150
|
-
|
|
1525
|
+
def as_optimization_problem(
|
|
1526
|
+
self, key: Optional[random.PRNGKey]=None,
|
|
1527
|
+
policy_hyperparams: Optional[Pytree]=None,
|
|
1528
|
+
loss_function_updates_key: bool=True,
|
|
1529
|
+
grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
|
|
1530
|
+
'''Returns a function that computes the loss and a function that
|
|
1531
|
+
computes gradient of the return as a 1D vector given a 1D representation
|
|
1532
|
+
of policy parameters. These functions are designed to be compatible with
|
|
1533
|
+
off-the-shelf optimizers such as scipy.
|
|
1534
|
+
|
|
1535
|
+
Also returns the initial parameter vector to seed an optimizer,
|
|
1536
|
+
as well as a mapping that recovers the parameter pytree from the vector.
|
|
1537
|
+
The PRNG key is updated internally starting from the optional given key.
|
|
1538
|
+
|
|
1539
|
+
Constraints on actions, if they are required, cannot be constructed
|
|
1540
|
+
automatically in the general case. The user should build constraints
|
|
1541
|
+
for each problem in the format required by the downstream optimizer.
|
|
1542
|
+
|
|
1543
|
+
:param key: JAX PRNG key (derived from clock if not provided)
|
|
1544
|
+
:param policy_hyperparameters: hyper-parameters for the policy/plan,
|
|
1545
|
+
such as weights for sigmoid wrapping boolean actions (defaults to 1
|
|
1546
|
+
for all action-fluents if not provided)
|
|
1547
|
+
:param loss_function_updates_key: if True, the loss function
|
|
1548
|
+
updates the PRNG key internally independently of the grad function
|
|
1549
|
+
:param grad_function_updates_key: if True, the gradient function
|
|
1550
|
+
updates the PRNG key internally independently of the loss function.
|
|
1551
|
+
'''
|
|
1151
1552
|
|
|
1152
|
-
|
|
1553
|
+
# if PRNG key is not provided
|
|
1554
|
+
if key is None:
|
|
1555
|
+
key = random.PRNGKey(round(time.time() * 1000))
|
|
1556
|
+
|
|
1557
|
+
# initialize the initial fluents, model parameters, policy hyper-params
|
|
1558
|
+
subs = self.test_compiled.init_values
|
|
1559
|
+
train_subs, _ = self._batched_init_subs(subs)
|
|
1560
|
+
model_params = self.compiled.model_params
|
|
1561
|
+
if policy_hyperparams is None:
|
|
1562
|
+
raise_warning('policy_hyperparams is not set, setting 1.0 for '
|
|
1563
|
+
'all action-fluents which could be suboptimal.')
|
|
1564
|
+
policy_hyperparams = {action: 1.0
|
|
1565
|
+
for action in self.rddl.action_fluents}
|
|
1566
|
+
|
|
1567
|
+
# initialize the policy parameters
|
|
1568
|
+
params_guess, *_ = self.initialize(key, policy_hyperparams, train_subs)
|
|
1569
|
+
guess_1d, unravel_fn = jax.flatten_util.ravel_pytree(params_guess)
|
|
1570
|
+
guess_1d = np.asarray(guess_1d)
|
|
1571
|
+
|
|
1572
|
+
# computes the training loss function and its 1D gradient
|
|
1573
|
+
loss_fn = self._jax_loss(self.train_rollouts)
|
|
1574
|
+
|
|
1575
|
+
@jax.jit
|
|
1576
|
+
def _loss_with_key(key, params_1d):
|
|
1577
|
+
policy_params = unravel_fn(params_1d)
|
|
1578
|
+
loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
|
|
1579
|
+
train_subs, model_params)
|
|
1580
|
+
return loss_val
|
|
1581
|
+
|
|
1582
|
+
@jax.jit
|
|
1583
|
+
def _grad_with_key(key, params_1d):
|
|
1584
|
+
policy_params = unravel_fn(params_1d)
|
|
1585
|
+
grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
|
|
1586
|
+
grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
|
|
1587
|
+
train_subs, model_params)
|
|
1588
|
+
grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
|
|
1589
|
+
return grad_1d
|
|
1590
|
+
|
|
1591
|
+
def _loss_function(params_1d):
|
|
1592
|
+
nonlocal key
|
|
1593
|
+
if loss_function_updates_key:
|
|
1594
|
+
key, subkey = random.split(key)
|
|
1595
|
+
else:
|
|
1596
|
+
subkey = key
|
|
1597
|
+
loss_val = _loss_with_key(subkey, params_1d)
|
|
1598
|
+
loss_val = float(loss_val)
|
|
1599
|
+
return loss_val
|
|
1600
|
+
|
|
1601
|
+
def _grad_function(params_1d):
|
|
1602
|
+
nonlocal key
|
|
1603
|
+
if grad_function_updates_key:
|
|
1604
|
+
key, subkey = random.split(key)
|
|
1605
|
+
else:
|
|
1606
|
+
subkey = key
|
|
1607
|
+
grad = _grad_with_key(subkey, params_1d)
|
|
1608
|
+
grad = np.asarray(grad)
|
|
1609
|
+
return grad
|
|
1610
|
+
|
|
1611
|
+
return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
|
|
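Since the returned loss and gradient callables consume and produce plain NumPy objects, they can be handed directly to a third-party optimizer. A minimal sketch, assuming a JaxBackpropPlanner instance named planner has already been constructed (the optimizer method and iteration budget below are arbitrary example choices, not package defaults):

    from scipy.optimize import minimize

    loss_fn, grad_fn, guess_1d, unravel_fn = planner.as_optimization_problem()
    result = minimize(loss_fn, guess_1d, jac=grad_fn, method='L-BFGS-B',
                      options={'maxiter': 200})
    best_policy_params = unravel_fn(result.x)  # recover the policy parameter pytree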
1612
|
+
|
|
1613
|
+
# ===========================================================================
|
|
1614
|
+
# OPTIMIZE API
|
|
1615
|
+
# ===========================================================================
|
|
1616
|
+
|
|
1617
|
+
def optimize(self, *args, **kwargs) -> Dict[str, Any]:
|
|
1618
|
+
'''Compute an optimal policy or plan. Return the callback from training.
|
|
1619
|
+
|
|
1620
|
+
:param key: JAX PRNG key (derived from clock if not provided)
|
|
1153
1621
|
:param epochs: the maximum number of steps of gradient descent
|
|
1154
|
-
:param the maximum number of steps of gradient descent
|
|
1155
1622
|
:param train_seconds: total time allocated for gradient descent
|
|
1156
1623
|
:param plot_step: frequency to plot the plan and save result to disk
|
|
1624
|
+
:param plot_kwargs: additional arguments to pass to the plotter
|
|
1157
1625
|
:param model_params: optional model-parameters to override default
|
|
1158
1626
|
:param policy_hyperparams: hyper-parameters for the policy/plan, such as
|
|
1159
1627
|
weights for sigmoid wrapping boolean actions
|
|
@@ -1161,64 +1629,110 @@ class JaxBackpropPlanner:
|
|
|
1161
1629
|
their values: if None initializes all variables from the RDDL instance
|
|
1162
1630
|
:param guess: initial policy parameters: if None will use the initializer
|
|
1163
1631
|
specified in this instance
|
|
1164
|
-
:param
|
|
1165
|
-
|
|
1166
|
-
|
|
1632
|
+
:param print_summary: whether to print planner header, parameter
|
|
1633
|
+
summary, and diagnosis
|
|
1634
|
+
:param print_progress: whether to print the progress bar during training
|
|
1635
|
+
:param test_rolling_window: the test return is averaged on a rolling
|
|
1636
|
+
window of the past test_rolling_window returns when updating the best
|
|
1637
|
+
parameters found so far
|
|
1638
|
+
:param tqdm_position: position of tqdm progress bar (for multiprocessing)
|
|
1167
1639
|
'''
|
|
1168
1640
|
it = self.optimize_generator(*args, **kwargs)
|
|
1169
|
-
|
|
1170
|
-
if
|
|
1171
|
-
|
|
1641
|
+
|
|
1642
|
+
# on CPython, draining the iterator through a deque runs in native C and is
|
|
1643
|
+
# much faster than naively exhausting it in a Python-level loop; this is not
|
|
1644
|
+
# the case on other implementations (e.g. PyPy); for details, see
|
|
1645
|
+
# https://stackoverflow.com/questions/50937966/fastest-most-pythonic-way-to-consume-an-iterator
|
|
1646
|
+
callback = None
|
|
1647
|
+
if sys.implementation.name == 'cpython':
|
|
1648
|
+
last_callback = deque(it, maxlen=1)
|
|
1649
|
+
if last_callback:
|
|
1650
|
+
callback = last_callback.pop()
|
|
1172
1651
|
else:
|
|
1173
|
-
|
|
1652
|
+
for callback in it:
|
|
1653
|
+
pass
|
|
1654
|
+
return callback
|
|
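For reference, the generator variant defined below can also be consumed manually when intermediate progress is needed: each yielded callback dictionary carries the keys used elsewhere in this file ('status', 'iteration', 'train_return', 'test_return', 'best_params'). A hedged sketch, assuming a constructed planner and arbitrary budgets:

    for callback in planner.optimize_generator(epochs=1000, train_seconds=60.0):
        if callback['iteration'] % 100 == 0:
            print(callback['iteration'], callback['test_return'])
        if callback['status'].is_failure():
            break
    best_params = callback['best_params']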
1174
1655
|
|
|
1175
|
-
def optimize_generator(self, key: random.PRNGKey,
|
|
1656
|
+
def optimize_generator(self, key: Optional[random.PRNGKey]=None,
|
|
1176
1657
|
epochs: int=999999,
|
|
1177
1658
|
train_seconds: float=120.,
|
|
1178
|
-
plot_step: int=None,
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1659
|
+
plot_step: Optional[int]=None,
|
|
1660
|
+
plot_kwargs: Optional[Dict[str, Any]]=None,
|
|
1661
|
+
model_params: Optional[Dict[str, Any]]=None,
|
|
1662
|
+
policy_hyperparams: Optional[Dict[str, Any]]=None,
|
|
1663
|
+
subs: Optional[Dict[str, Any]]=None,
|
|
1664
|
+
guess: Optional[Pytree]=None,
|
|
1665
|
+
print_summary: bool=True,
|
|
1666
|
+
print_progress: bool=True,
|
|
1667
|
+
test_rolling_window: int=10,
|
|
1668
|
+
tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
|
|
1669
|
+
'''Returns a generator for computing an optimal policy or plan.
|
|
1186
1670
|
Generator can be iterated over to lazily optimize the plan, yielding
|
|
1187
1671
|
a dictionary of intermediate computations.
|
|
1188
1672
|
|
|
1189
|
-
:param key: JAX PRNG key
|
|
1673
|
+
:param key: JAX PRNG key (derived from clock if not provided)
|
|
1190
1674
|
:param epochs: the maximum number of steps of gradient descent
|
|
1191
|
-
:param the maximum number of steps of gradient descent
|
|
1192
1675
|
:param train_seconds: total time allocated for gradient descent
|
|
1193
1676
|
:param plot_step: frequency to plot the plan and save result to disk
|
|
1677
|
+
:param plot_kwargs: additional arguments to pass to the plotter
|
|
1194
1678
|
:param model_params: optional model-parameters to override default
|
|
1195
1679
|
:param policy_hyperparams: hyper-parameters for the policy/plan, such as
|
|
1196
1680
|
weights for sigmoid wrapping boolean actions
|
|
1197
1681
|
:param subs: dictionary mapping initial state and non-fluents to
|
|
1198
1682
|
their values: if None initializes all variables from the RDDL instance
|
|
1199
1683
|
:param guess: initial policy parameters: if None will use the initializer
|
|
1200
|
-
specified in this instance
|
|
1201
|
-
:param
|
|
1684
|
+
specified in this instance
|
|
1685
|
+
:param print_summary: whether to print planner header, parameter
|
|
1686
|
+
summary, and diagnosis
|
|
1687
|
+
:param print_progress: whether to print the progress bar during training
|
|
1688
|
+
:param test_rolling_window: the test return is averaged on a rolling
|
|
1689
|
+
window of the past test_rolling_window returns when updating the best
|
|
1690
|
+
parameters found so far
|
|
1202
1691
|
:param tqdm_position: position of tqdm progress bar (for multiprocessing)
|
|
1203
1692
|
'''
|
|
1204
|
-
verbose = int(verbose)
|
|
1205
1693
|
start_time = time.time()
|
|
1206
1694
|
elapsed_outside_loop = 0
|
|
1207
1695
|
|
|
1696
|
+
# if PRNG key is not provided
|
|
1697
|
+
if key is None:
|
|
1698
|
+
key = random.PRNGKey(round(time.time() * 1000))
|
|
1699
|
+
|
|
1700
|
+
# if policy_hyperparams is not provided
|
|
1701
|
+
if policy_hyperparams is None:
|
|
1702
|
+
raise_warning('policy_hyperparams is not set, setting 1.0 for '
|
|
1703
|
+
'all action-fluents which could be suboptimal.')
|
|
1704
|
+
policy_hyperparams = {action: 1.0
|
|
1705
|
+
for action in self.rddl.action_fluents}
|
|
1706
|
+
|
|
1707
|
+
# if policy_hyperparams is a scalar
|
|
1708
|
+
elif isinstance(policy_hyperparams, (int, float, np.number)):
|
|
1709
|
+
raise_warning(f'policy_hyperparams is {policy_hyperparams}, '
|
|
1710
|
+
'setting this value for all action-fluents.')
|
|
1711
|
+
hyperparam_value = float(policy_hyperparams)
|
|
1712
|
+
policy_hyperparams = {action: hyperparam_value
|
|
1713
|
+
for action in self.rddl.action_fluents}
|
|
1714
|
+
|
|
1208
1715
|
# print summary of parameters:
|
|
1209
|
-
if
|
|
1210
|
-
|
|
1211
|
-
'JAX PLANNER PARAMETER SUMMARY\n'
|
|
1212
|
-
'==============================================')
|
|
1716
|
+
if print_summary:
|
|
1717
|
+
self._summarize_system()
|
|
1213
1718
|
self.summarize_hyperparameters()
|
|
1214
1719
|
print(f'optimize() call hyper-parameters:\n'
|
|
1720
|
+
f' PRNG key ={key}\n'
|
|
1215
1721
|
f' max_iterations ={epochs}\n'
|
|
1216
1722
|
f' max_seconds ={train_seconds}\n'
|
|
1217
1723
|
f' model_params ={model_params}\n'
|
|
1218
1724
|
f' policy_hyper_params={policy_hyperparams}\n'
|
|
1219
1725
|
f' override_subs_dict ={subs is not None}\n'
|
|
1220
|
-
f' provide_param_guess={guess is not None}\n'
|
|
1221
|
-
f'
|
|
1726
|
+
f' provide_param_guess={guess is not None}\n'
|
|
1727
|
+
f' test_rolling_window={test_rolling_window}\n'
|
|
1728
|
+
f' plot_frequency ={plot_step}\n'
|
|
1729
|
+
f' plot_kwargs ={plot_kwargs}\n'
|
|
1730
|
+
f' print_summary ={print_summary}\n'
|
|
1731
|
+
f' print_progress ={print_progress}\n')
|
|
1732
|
+
if self.compiled.relaxations:
|
|
1733
|
+
print('Some RDDL operations are non-differentiable, '
|
|
1734
|
+
'replacing them with differentiable relaxations:')
|
|
1735
|
+
print(self.compiled.summarize_model_relaxations())
|
|
1222
1736
|
|
|
1223
1737
|
# compute a batched version of the initial values
|
|
1224
1738
|
if subs is None:
|
|
@@ -1237,7 +1751,7 @@ class JaxBackpropPlanner:
|
|
|
1237
1751
|
'from the RDDL files.')
|
|
1238
1752
|
train_subs, test_subs = self._batched_init_subs(subs)
|
|
1239
1753
|
|
|
1240
|
-
# initialize
|
|
1754
|
+
# initialize model parameters
|
|
1241
1755
|
if model_params is None:
|
|
1242
1756
|
model_params = self.compiled.model_params
|
|
1243
1757
|
model_params_test = self.test_compiled.model_params
|
|
@@ -1245,63 +1759,103 @@ class JaxBackpropPlanner:
|
|
|
1245
1759
|
# initialize policy parameters
|
|
1246
1760
|
if guess is None:
|
|
1247
1761
|
key, subkey = random.split(key)
|
|
1248
|
-
policy_params, opt_state = self.initialize(
|
|
1762
|
+
policy_params, opt_state, opt_aux = self.initialize(
|
|
1249
1763
|
subkey, policy_hyperparams, train_subs)
|
|
1250
1764
|
else:
|
|
1251
1765
|
policy_params = guess
|
|
1252
1766
|
opt_state = self.optimizer.init(policy_params)
|
|
1767
|
+
opt_aux = None
|
|
1768
|
+
|
|
1769
|
+
# initialize running statistics
|
|
1253
1770
|
best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
|
|
1254
1771
|
last_iter_improve = 0
|
|
1772
|
+
rolling_test_loss = RollingMean(test_rolling_window)
|
|
1255
1773
|
log = {}
|
|
1774
|
+
status = JaxPlannerStatus.NORMAL
|
|
1775
|
+
|
|
1776
|
+
# initialize plot area
|
|
1777
|
+
if plot_step is None or plot_step <= 0 or plt is None:
|
|
1778
|
+
plot = None
|
|
1779
|
+
else:
|
|
1780
|
+
if plot_kwargs is None:
|
|
1781
|
+
plot_kwargs = {}
|
|
1782
|
+
plot = JaxPlannerPlot(self.rddl, self.horizon, **plot_kwargs)
|
|
1783
|
+
xticks, loss_values = [], []
|
|
1256
1784
|
|
|
1257
1785
|
# training loop
|
|
1258
1786
|
iters = range(epochs)
|
|
1259
|
-
if
|
|
1787
|
+
if print_progress:
|
|
1260
1788
|
iters = tqdm(iters, total=100, position=tqdm_position)
|
|
1261
1789
|
|
|
1262
1790
|
for it in iters:
|
|
1791
|
+
status = JaxPlannerStatus.NORMAL
|
|
1263
1792
|
|
|
1264
1793
|
# update the parameters of the plan
|
|
1265
|
-
key,
|
|
1266
|
-
policy_params, converged, opt_state,
|
|
1267
|
-
|
|
1268
|
-
|
|
1794
|
+
key, subkey = random.split(key)
|
|
1795
|
+
policy_params, converged, opt_state, opt_aux, \
|
|
1796
|
+
train_loss, train_log = \
|
|
1797
|
+
self.update(subkey, policy_params, policy_hyperparams,
|
|
1798
|
+
train_subs, model_params, opt_state, opt_aux)
|
|
1799
|
+
|
|
1800
|
+
# no progress
|
|
1801
|
+
grad_norm_zero, _ = jax.tree_util.tree_flatten(
|
|
1802
|
+
jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
|
|
1803
|
+
if np.all(grad_norm_zero):
|
|
1804
|
+
status = JaxPlannerStatus.NO_PROGRESS
|
|
1805
|
+
|
|
1806
|
+
# constraint satisfaction problem
|
|
1269
1807
|
if not np.all(converged):
|
|
1270
1808
|
raise_warning(
|
|
1271
1809
|
'Projected gradient method for satisfying action concurrency '
|
|
1272
1810
|
'constraints reached the iteration limit: plan is possibly '
|
|
1273
1811
|
'invalid for the current instance.', 'red')
|
|
1812
|
+
status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
|
|
1274
1813
|
|
|
1275
|
-
#
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1814
|
+
# numerical error
|
|
1815
|
+
if not np.isfinite(train_loss):
|
|
1816
|
+
raise_warning(
|
|
1817
|
+
f'Aborting JAX planner due to invalid train loss {train_loss}.',
|
|
1818
|
+
'red')
|
|
1819
|
+
status = JaxPlannerStatus.INVALID_GRADIENT
|
|
1820
|
+
|
|
1821
|
+
# evaluate test losses and record best plan so far
|
|
1279
1822
|
test_loss, log = self.test_loss(
|
|
1280
|
-
|
|
1823
|
+
subkey, policy_params, policy_hyperparams,
|
|
1281
1824
|
test_subs, model_params_test)
|
|
1282
|
-
|
|
1283
|
-
# record the best plan so far
|
|
1825
|
+
test_loss = rolling_test_loss.update(test_loss)
|
|
1284
1826
|
if test_loss < best_loss:
|
|
1285
1827
|
best_params, best_loss, best_grad = \
|
|
1286
1828
|
policy_params, test_loss, train_log['grad']
|
|
1287
1829
|
last_iter_improve = it
|
|
1288
1830
|
|
|
1289
1831
|
# save the plan figure
|
|
1290
|
-
if
|
|
1291
|
-
|
|
1292
|
-
|
|
1832
|
+
if plot is not None and it % plot_step == 0:
|
|
1833
|
+
xticks.append(it // plot_step)
|
|
1834
|
+
loss_values.append(test_loss.item())
|
|
1835
|
+
action_values = {name: values
|
|
1836
|
+
for (name, values) in log['fluents'].items()
|
|
1837
|
+
if name in self.rddl.action_fluents}
|
|
1838
|
+
returns = -np.sum(np.asarray(log['reward']), axis=1)
|
|
1839
|
+
plot.redraw(xticks, loss_values, action_values, returns)
|
|
1293
1840
|
|
|
1294
1841
|
# if the progress bar is used
|
|
1295
1842
|
elapsed = time.time() - start_time - elapsed_outside_loop
|
|
1296
|
-
if
|
|
1843
|
+
if print_progress:
|
|
1297
1844
|
iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
|
|
1298
1845
|
iters.set_description(
|
|
1299
|
-
f'[{tqdm_position}] {it:6} it / {-train_loss:14.
|
|
1300
|
-
f'{-test_loss:14.
|
|
1846
|
+
f'[{tqdm_position}] {it:6} it / {-train_loss:14.6f} train / '
|
|
1847
|
+
f'{-test_loss:14.6f} test / {-best_loss:14.6f} best')
|
|
1848
|
+
|
|
1849
|
+
# reached computation budget
|
|
1850
|
+
if elapsed >= train_seconds:
|
|
1851
|
+
status = JaxPlannerStatus.TIME_BUDGET_REACHED
|
|
1852
|
+
if it >= epochs - 1:
|
|
1853
|
+
status = JaxPlannerStatus.ITER_BUDGET_REACHED
|
|
1301
1854
|
|
|
1302
1855
|
# return a callback
|
|
1303
1856
|
start_time_outside = time.time()
|
|
1304
1857
|
yield {
|
|
1858
|
+
'status': status,
|
|
1305
1859
|
'iteration': it,
|
|
1306
1860
|
'train_return':-train_loss,
|
|
1307
1861
|
'test_return':-test_loss,
|
|
@@ -1318,16 +1872,15 @@ class JaxBackpropPlanner:
|
|
|
1318
1872
|
}
|
|
1319
1873
|
elapsed_outside_loop += (time.time() - start_time_outside)
|
|
1320
1874
|
|
|
1321
|
-
#
|
|
1322
|
-
if
|
|
1323
|
-
break
|
|
1324
|
-
|
|
1325
|
-
# numerical error
|
|
1326
|
-
if not np.isfinite(train_loss):
|
|
1875
|
+
# abortion check
|
|
1876
|
+
if status.is_failure():
|
|
1327
1877
|
break
|
|
1328
|
-
|
|
1329
|
-
|
|
1878
|
+
|
|
1879
|
+
# release resources
|
|
1880
|
+
if print_progress:
|
|
1330
1881
|
iters.close()
|
|
1882
|
+
if plot is not None:
|
|
1883
|
+
plot.close()
|
|
1331
1884
|
|
|
1332
1885
|
# validate the test return
|
|
1333
1886
|
if log:
|
|
@@ -1337,24 +1890,23 @@ class JaxBackpropPlanner:
|
|
|
1337
1890
|
if messages:
|
|
1338
1891
|
messages = '\n'.join(messages)
|
|
1339
1892
|
raise_warning('The JAX compiler encountered the following '
|
|
1340
|
-
'
|
|
1893
|
+
'error(s) in the original RDDL formulation '
|
|
1341
1894
|
f'during test evaluation:\n{messages}', 'red')
|
|
1342
1895
|
|
|
1343
1896
|
# summarize and test for convergence
|
|
1344
|
-
if
|
|
1345
|
-
grad_norm = jax.tree_map(
|
|
1346
|
-
lambda x: np.array(jnp.linalg.norm(x)).item(), best_grad)
|
|
1897
|
+
if print_summary:
|
|
1898
|
+
grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
|
|
1347
1899
|
diagnosis = self._perform_diagnosis(
|
|
1348
|
-
last_iter_improve,
|
|
1349
|
-
-train_loss, -test_loss, -best_loss, grad_norm)
|
|
1900
|
+
last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
|
|
1350
1901
|
print(f'summary of optimization:\n'
|
|
1902
|
+
f' status_code ={status}\n'
|
|
1351
1903
|
f' time_elapsed ={elapsed}\n'
|
|
1352
1904
|
f' iterations ={it}\n'
|
|
1353
1905
|
f' best_objective={-best_loss}\n'
|
|
1354
|
-
f'
|
|
1906
|
+
f' best_grad_norm={grad_norm}\n'
|
|
1355
1907
|
f'diagnosis: {diagnosis}\n')
|
|
1356
1908
|
|
|
1357
|
-
def _perform_diagnosis(self, last_iter_improve,
|
|
1909
|
+
def _perform_diagnosis(self, last_iter_improve,
|
|
1358
1910
|
train_return, test_return, best_return, grad_norm):
|
|
1359
1911
|
max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
|
|
1360
1912
|
grad_is_zero = np.allclose(max_grad_norm, 0)
|
|
@@ -1373,20 +1925,20 @@ class JaxBackpropPlanner:
|
|
|
1373
1925
|
if grad_is_zero:
|
|
1374
1926
|
return termcolor.colored(
|
|
1375
1927
|
'[FAILURE] no progress was made, '
|
|
1376
|
-
f'and max grad norm
|
|
1377
|
-
'likely stuck in a plateau.', 'red')
|
|
1928
|
+
f'and max grad norm {max_grad_norm:.6f} is zero: '
|
|
1929
|
+
'solver likely stuck in a plateau.', 'red')
|
|
1378
1930
|
else:
|
|
1379
1931
|
return termcolor.colored(
|
|
1380
1932
|
'[FAILURE] no progress was made, '
|
|
1381
|
-
f'but max grad norm
|
|
1382
|
-
'likely
|
|
1933
|
+
f'but max grad norm {max_grad_norm:.6f} is non-zero: '
|
|
1934
|
+
'likely poor learning rate or other hyper-parameter.', 'red')
|
|
1383
1935
|
|
|
1384
1936
|
# model is likely poor IF:
|
|
1385
1937
|
# 1. the train and test return disagree
|
|
1386
1938
|
if not (validation_error < 20):
|
|
1387
1939
|
return termcolor.colored(
|
|
1388
1940
|
'[WARNING] progress was made, '
|
|
1389
|
-
f'but relative train
|
|
1941
|
+
f'but relative train-test error {validation_error:.6f} is high: '
|
|
1390
1942
|
'likely poor model relaxation around the solution, '
|
|
1391
1943
|
'or the batch size is too small.', 'yellow')
|
|
1392
1944
|
|
|
@@ -1397,208 +1949,213 @@ class JaxBackpropPlanner:
|
|
|
1397
1949
|
if not (return_to_grad_norm > 1):
|
|
1398
1950
|
return termcolor.colored(
|
|
1399
1951
|
'[WARNING] progress was made, '
|
|
1400
|
-
f'but max grad norm
|
|
1401
|
-
'likely
|
|
1402
|
-
'or the model is not smooth around the solution, '
|
|
1952
|
+
f'but max grad norm {max_grad_norm:.6f} is high: '
|
|
1953
|
+
'likely the solution is not locally optimal, '
|
|
1954
|
+
'or the relaxed model is not smooth around the solution, '
|
|
1403
1955
|
'or the batch size is too small.', 'yellow')
|
|
1404
1956
|
|
|
1405
1957
|
# likely successful
|
|
1406
1958
|
return termcolor.colored(
|
|
1407
|
-
'[SUCCESS] planner
|
|
1959
|
+
'[SUCCESS] planner has converged successfully '
|
|
1408
1960
|
'(note: not all potential problems can be ruled out).', 'green')
|
|
1409
1961
|
|
|
1410
1962
|
def get_action(self, key: random.PRNGKey,
|
|
1411
|
-
params:
|
|
1963
|
+
params: Pytree,
|
|
1412
1964
|
step: int,
|
|
1413
|
-
subs: Dict,
|
|
1414
|
-
policy_hyperparams: Dict[str,
|
|
1965
|
+
subs: Dict[str, Any],
|
|
1966
|
+
policy_hyperparams: Optional[Dict[str, Any]]=None) -> Dict[str, Any]:
|
|
1415
1967
|
'''Returns an action dictionary from the policy or plan with the given
|
|
1416
1968
|
parameters.
|
|
1417
1969
|
|
|
1418
1970
|
:param key: the JAX PRNG key
|
|
1419
1971
|
:param params: the trainable parameter PyTree of the policy
|
|
1420
1972
|
:param step: the time step at which decision is made
|
|
1421
|
-
:param policy_hyperparams: hyper-parameters for the policy/plan, such as
|
|
1422
|
-
weights for sigmoid wrapping boolean actions
|
|
1423
1973
|
:param subs: the dict of pvariables
|
|
1974
|
+
:param policy_hyperparams: hyper-parameters for the policy/plan, such as
|
|
1975
|
+
weights for sigmoid wrapping boolean actions (optional)
|
|
1424
1976
|
'''
|
|
1425
1977
|
|
|
1426
1978
|
# check compatibility of the subs dictionary
|
|
1427
|
-
for var in subs.
|
|
1979
|
+
for (var, values) in subs.items():
|
|
1980
|
+
|
|
1981
|
+
# must not be grounded
|
|
1428
1982
|
if RDDLPlanningModel.FLUENT_SEP in var \
|
|
1429
1983
|
or RDDLPlanningModel.OBJECT_SEP in var:
|
|
1430
|
-
raise
|
|
1431
|
-
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
|
|
1984
|
+
raise ValueError(f'State dictionary passed to the JAX policy is '
|
|
1985
|
+
f'grounded, since it contains the key <{var}>, '
|
|
1986
|
+
f'but a vectorized environment is required: '
|
|
1987
|
+
f'make sure vectorized = True in the RDDLEnv.')
|
|
1988
|
+
|
|
1989
|
+
# must be numeric array
|
|
1990
|
+
# exception is for POMDPs at 1st epoch when observ-fluents are None
|
|
1991
|
+
dtype = np.atleast_1d(values).dtype
|
|
1992
|
+
if not jnp.issubdtype(dtype, jnp.number) \
|
|
1993
|
+
and not jnp.issubdtype(dtype, jnp.bool_):
|
|
1994
|
+
if step == 0 and var in self.rddl.observ_fluents:
|
|
1995
|
+
subs[var] = self.test_compiled.init_values[var]
|
|
1996
|
+
else:
|
|
1997
|
+
raise ValueError(
|
|
1998
|
+
f'Values {values} assigned to p-variable <{var}> are '
|
|
1999
|
+
f'non-numeric of type {dtype}.')
|
|
2000
|
+
|
|
1435
2001
|
# cast device arrays to numpy
|
|
1436
2002
|
actions = self.test_policy(key, params, policy_hyperparams, step, subs)
|
|
1437
2003
|
actions = jax.tree_map(np.asarray, actions)
|
|
1438
2004
|
return actions
|
|
1439
|
-
|
|
1440
|
-
def _plot_actions(self, key, params, hyperparams, subs, it):
|
|
1441
|
-
rddl = self.rddl
|
|
1442
|
-
try:
|
|
1443
|
-
import matplotlib.pyplot as plt
|
|
1444
|
-
except Exception:
|
|
1445
|
-
print('matplotlib is not installed, aborting plot...')
|
|
1446
|
-
return
|
|
1447
|
-
|
|
1448
|
-
# predict actions from the trained policy or plan
|
|
1449
|
-
actions = self.test_rollouts(key, params, hyperparams, subs, {})['action']
|
|
1450
|
-
|
|
1451
|
-
# plot the action sequences as color maps
|
|
1452
|
-
fig, axs = plt.subplots(nrows=len(actions), constrained_layout=True)
|
|
1453
|
-
for (ax, name) in zip(axs, actions):
|
|
1454
|
-
action = np.mean(actions[name], axis=0, dtype=float)
|
|
1455
|
-
action = np.reshape(action, newshape=(action.shape[0], -1)).T
|
|
1456
|
-
if rddl.variable_ranges[name] == 'bool':
|
|
1457
|
-
vmin, vmax = 0.0, 1.0
|
|
1458
|
-
else:
|
|
1459
|
-
vmin, vmax = None, None
|
|
1460
|
-
img = ax.imshow(
|
|
1461
|
-
action, vmin=vmin, vmax=vmax, cmap='seismic', aspect='auto')
|
|
1462
|
-
ax.set_xlabel('time')
|
|
1463
|
-
ax.set_ylabel(name)
|
|
1464
|
-
plt.colorbar(img, ax=ax)
|
|
1465
|
-
|
|
1466
|
-
# write plot to disk
|
|
1467
|
-
plt.savefig(f'plan_{rddl.domain_name}_{rddl.instance_name}_{it}.pdf',
|
|
1468
|
-
bbox_inches='tight')
|
|
1469
|
-
plt.clf()
|
|
1470
|
-
plt.close(fig)
|
|
1471
2005
|
|
|
1472
2006
|
|
|
1473
|
-
class
|
|
2007
|
+
class JaxLineSearchPlanner(JaxBackpropPlanner):
|
|
1474
2008
|
'''A class for optimizing an action sequence in the given RDDL MDP using
|
|
1475
|
-
|
|
2009
|
+
line search gradient descent with the Armijo condition.'''
|
|
1476
2010
|
|
|
1477
2011
|
def __init__(self, *args,
|
|
1478
|
-
|
|
1479
|
-
optimizer_kwargs: Dict[str, object]={'learning_rate': 1.0},
|
|
1480
|
-
beta: float=0.8,
|
|
2012
|
+
decay: float=0.8,
|
|
1481
2013
|
c: float=0.1,
|
|
1482
|
-
|
|
1483
|
-
|
|
2014
|
+
step_max: float=1.0,
|
|
2015
|
+
step_min: float=1e-6,
|
|
1484
2016
|
**kwargs) -> None:
|
|
1485
2017
|
'''Creates a new gradient-based algorithm for optimizing action sequences
|
|
1486
|
-
(plan) in the given RDDL using
|
|
2018
|
+
(plan) in the given RDDL using line search. All arguments are the
|
|
1487
2019
|
same as in the parent class, except:
|
|
1488
2020
|
|
|
1489
|
-
:param
|
|
1490
|
-
:param c: coefficient in Armijo condition
|
|
1491
|
-
:param
|
|
1492
|
-
:param
|
|
2021
|
+
:param decay: reduction factor of learning rate per line search iteration
|
|
2022
|
+
:param c: positive coefficient in Armijo condition, should be in (0, 1)
|
|
2023
|
+
:param step_max: initial learning rate for line search
|
|
2024
|
+
:param step_min: minimum possible learning rate (line search halts)
|
|
1493
2025
|
'''
|
|
1494
|
-
self.
|
|
2026
|
+
self.decay = decay
|
|
1495
2027
|
self.c = c
|
|
1496
|
-
self.
|
|
1497
|
-
self.
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
def summarize_hyperparameters(self):
|
|
1505
|
-
super(
|
|
2028
|
+
self.step_max = step_max
|
|
2029
|
+
self.step_min = step_min
|
|
2030
|
+
if 'clip_grad' in kwargs:
|
|
2031
|
+
raise_warning('clip_grad parameter conflicts with '
|
|
2032
|
+
'line search planner and will be ignored.', 'red')
|
|
2033
|
+
del kwargs['clip_grad']
|
|
2034
|
+
super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
|
|
2035
|
+
|
|
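A hedged construction sketch for this subclass; everything other than the four line-search arguments is inherited from JaxBackpropPlanner, and the rddl handle is a placeholder assumption rather than a documented constructor signature:

    planner = JaxLineSearchPlanner(rddl=env.model, decay=0.8, c=0.1,
                                   step_max=1.0, step_min=1e-6)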
2036
|
+
def summarize_hyperparameters(self) -> None:
|
|
2037
|
+
super(JaxLineSearchPlanner, self).summarize_hyperparameters()
|
|
1506
2038
|
print(f'linesearch hyper-parameters:\n'
|
|
1507
|
-
f'
|
|
2039
|
+
f' decay ={self.decay}\n'
|
|
1508
2040
|
f' c ={self.c}\n'
|
|
1509
|
-
f' lr_range=({self.
|
|
2041
|
+
f' lr_range=({self.step_min}, {self.step_max})')
|
|
1510
2042
|
|
|
1511
2043
|
def _jax_update(self, loss):
|
|
1512
2044
|
optimizer = self.optimizer
|
|
1513
2045
|
projection = self.plan.projection
|
|
1514
|
-
|
|
1515
|
-
|
|
1516
|
-
#
|
|
1517
|
-
|
|
1518
|
-
def
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1526
|
-
old_x, _, old_g, _, old_state = old
|
|
1527
|
-
_, _, lr, iters = new
|
|
1528
|
-
_, best_f, _, _ = best
|
|
1529
|
-
key, hyperparams, *other = aux
|
|
1530
|
-
|
|
1531
|
-
# anneal learning rate and apply a gradient step
|
|
1532
|
-
new_lr = beta * lr
|
|
1533
|
-
old_state.hyperparams['learning_rate'] = new_lr
|
|
1534
|
-
updates, new_state = optimizer.update(old_g, old_state)
|
|
1535
|
-
new_x = optax.apply_updates(old_x, updates)
|
|
1536
|
-
new_x, _ = projection(new_x, hyperparams)
|
|
1537
|
-
|
|
1538
|
-
# evaluate new loss and record best so far
|
|
1539
|
-
new_f, _ = loss(key, new_x, hyperparams, *other)
|
|
1540
|
-
new = (new_x, new_f, new_lr, iters + 1)
|
|
1541
|
-
best = jax.lax.cond(
|
|
1542
|
-
new_f < best_f,
|
|
1543
|
-
lambda: (new_x, new_f, new_lr, new_state),
|
|
1544
|
-
lambda: best
|
|
1545
|
-
)
|
|
1546
|
-
return old, new, best, aux
|
|
2046
|
+
decay, c, lrmax, lrmin = self.decay, self.c, self.step_max, self.step_min
|
|
2047
|
+
|
|
2048
|
+
# initialize the line search routine
|
|
2049
|
+
@jax.jit
|
|
2050
|
+
def _jax_wrapped_line_search_init(key, policy_params, hyperparams,
|
|
2051
|
+
subs, model_params):
|
|
2052
|
+
(f, log), grad = jax.value_and_grad(loss, argnums=1, has_aux=True)(
|
|
2053
|
+
key, policy_params, hyperparams, subs, model_params)
|
|
2054
|
+
gnorm2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), grad)
|
|
2055
|
+
gnorm2 = jax.tree_util.tree_reduce(jnp.add, gnorm2)
|
|
2056
|
+
log['grad'] = grad
|
|
2057
|
+
return f, grad, gnorm2, log
|
|
1547
2058
|
|
|
2059
|
+
# compute the next trial solution
|
|
2060
|
+
@jax.jit
|
|
2061
|
+
def _jax_wrapped_line_search_trial(
|
|
2062
|
+
step, grad, key, params, hparams, subs, mparams, state):
|
|
2063
|
+
state.hyperparams['learning_rate'] = step
|
|
2064
|
+
updates, new_state = optimizer.update(grad, state)
|
|
2065
|
+
new_params = optax.apply_updates(params, updates)
|
|
2066
|
+
new_params, _ = projection(new_params, hparams)
|
|
2067
|
+
f_step, _ = loss(key, new_params, hparams, subs, mparams)
|
|
2068
|
+
return f_step, new_params, new_state
|
|
2069
|
+
|
|
2070
|
+
# main iteration of line search
|
|
1548
2071
|
def _jax_wrapped_plan_update(key, policy_params, hyperparams,
|
|
1549
|
-
subs, model_params, opt_state):
|
|
1550
|
-
|
|
1551
|
-
# calculate initial loss value, gradient and squared norm
|
|
1552
|
-
old_x = policy_params
|
|
1553
|
-
loss_and_grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
|
|
1554
|
-
(old_f, log), old_g = loss_and_grad_fn(
|
|
1555
|
-
key, old_x, hyperparams, subs, model_params)
|
|
1556
|
-
old_norm_g2 = jax.tree_map(lambda x: jnp.sum(jnp.square(x)), old_g)
|
|
1557
|
-
old_norm_g2 = jax.tree_util.tree_reduce(jnp.add, old_norm_g2)
|
|
1558
|
-
log['grad'] = old_g
|
|
2072
|
+
subs, model_params, opt_state, opt_aux):
|
|
1559
2073
|
|
|
1560
|
-
# initialize
|
|
1561
|
-
|
|
1562
|
-
|
|
1563
|
-
new = (old_x, old_f, new_lr, 0)
|
|
1564
|
-
best = (old_x, jnp.inf, jnp.nan, opt_state)
|
|
1565
|
-
aux = (key, hyperparams, subs, model_params)
|
|
2074
|
+
# initialize the line search
|
|
2075
|
+
f, grad, gnorm2, log = _jax_wrapped_line_search_init(
|
|
2076
|
+
key, policy_params, hyperparams, subs, model_params)
|
|
1566
2077
|
|
|
1567
|
-
#
|
|
1568
|
-
|
|
1569
|
-
|
|
2078
|
+
# continue to reduce the learning rate until the Armijo condition holds
|
|
2079
|
+
trials = 0
|
|
2080
|
+
step = lrmax / decay
|
|
2081
|
+
f_step = np.inf
|
|
2082
|
+
best_f, best_step, best_params, best_state = np.inf, None, None, None
|
|
2083
|
+
while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
|
|
2084
|
+
or not trials:
|
|
2085
|
+
trials += 1
|
|
2086
|
+
step *= decay
|
|
2087
|
+
f_step, new_params, new_state = _jax_wrapped_line_search_trial(
|
|
2088
|
+
step, grad, key, policy_params, hyperparams, subs,
|
|
2089
|
+
model_params, opt_state)
|
|
2090
|
+
if f_step < best_f:
|
|
2091
|
+
best_f, best_step, best_params, best_state = \
|
|
2092
|
+
f_step, step, new_params, new_state
|
|
1570
2093
|
|
|
1571
|
-
# continue to anneal the learning rate until Armijo condition holds
|
|
1572
|
-
# or the learning rate becomes too small, then use the best parameter
|
|
1573
|
-
_, (*_, iters), (best_params, _, best_lr, best_state), _ = \
|
|
1574
|
-
jax.lax.while_loop(
|
|
1575
|
-
cond_fun=_jax_wrapped_line_search_armijo_check,
|
|
1576
|
-
body_fun=_jax_wrapped_line_search_iteration,
|
|
1577
|
-
init_val=init_val
|
|
1578
|
-
)
|
|
1579
|
-
best_state.hyperparams['learning_rate'] = best_lr
|
|
1580
2094
|
log['updates'] = None
|
|
1581
|
-
log['line_search_iters'] =
|
|
1582
|
-
log['learning_rate'] =
|
|
1583
|
-
return best_params, True, best_state, log
|
|
2095
|
+
log['line_search_iters'] = trials
|
|
2096
|
+
log['learning_rate'] = best_step
|
|
2097
|
+
return best_params, True, best_state, best_step, best_f, log
|
|
1584
2098
|
|
|
1585
2099
|
return _jax_wrapped_plan_update
|
|
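In words, the backtracking loop above keeps shrinking the step size alpha by the factor decay until either the sufficient-decrease (Armijo) condition

    f(theta_alpha) <= f(theta) - c * alpha * ||grad f(theta)||^2

holds, or alpha would drop below step_min; here theta_alpha denotes the projected parameters obtained with step size alpha, f is the training loss at the current parameters, and the squared gradient norm is the gnorm2 accumulated during initialization. The best trial encountered is returned even if the condition is never satisfied.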
1586
2100
|
|
|
1587
|
-
|
|
2101
|
+
|
|
2102
|
+
# ***********************************************************************
|
|
2103
|
+
# ALL VERSIONS OF RISK FUNCTIONS
|
|
2104
|
+
#
|
|
2105
|
+
# Based on the original paper "A Distributional Framework for Risk-Sensitive
|
|
2106
|
+
# End-to-End Planning in Continuous MDPs" by Patton et al., AAAI 2022.
|
|
2107
|
+
#
|
|
2108
|
+
# Original risk functions:
|
|
2109
|
+
# - entropic utility
|
|
2110
|
+
# - mean-variance approximation
|
|
2111
|
+
# - conditional value at risk with straight-through gradient trick
|
|
2112
|
+
#
|
|
2113
|
+
# ***********************************************************************
|
|
2114
|
+
|
|
2115
|
+
|
|
2116
|
+
@jax.jit
|
|
2117
|
+
def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
|
|
2118
|
+
return (-1.0 / beta) * jax.scipy.special.logsumexp(
|
|
2119
|
+
-beta * returns, b=1.0 / returns.size)
|
|
2120
|
+
|
|
2121
|
+
|
|
2122
|
+
@jax.jit
|
|
2123
|
+
def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
|
|
2124
|
+
return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
|
|
2125
|
+
|
|
2126
|
+
|
|
2127
|
+
@jax.jit
|
|
2128
|
+
def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
|
|
2129
|
+
alpha_mask = jax.lax.stop_gradient(
|
|
2130
|
+
returns <= jnp.percentile(returns, q=100 * alpha))
|
|
2131
|
+
return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
|
|
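Each of these helpers maps a batch of sampled returns to a single risk-adjusted objective value. A minimal illustration with arbitrary beta/alpha settings, assuming the helpers are called from within this module or imported from it:

    import jax.numpy as jnp

    returns = jnp.array([1.0, 2.0, -0.5, 3.0])
    print(entropic_utility(returns, beta=0.1))       # certainty equivalent of exponential utility
    print(mean_variance_utility(returns, beta=0.1))  # mean penalized by 0.5 * beta * variance
    print(cvar_utility(returns, alpha=0.25))         # average of the worst 25% of sampled returns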
2132
|
+
|
|
2133
|
+
|
|
2134
|
+
# ***********************************************************************
|
|
2135
|
+
# ALL VERSIONS OF CONTROLLERS
|
|
2136
|
+
#
|
|
2137
|
+
# - offline controller is the straight-line planner
|
|
2138
|
+
# - online controller is the replanning mode
|
|
2139
|
+
#
|
|
2140
|
+
# ***********************************************************************
|
|
2141
|
+
|
|
1588
2142
|
class JaxOfflineController(BaseAgent):
|
|
1589
2143
|
'''A container class for a Jax policy trained offline.'''
|
|
2144
|
+
|
|
1590
2145
|
use_tensor_obs = True
|
|
1591
2146
|
|
|
1592
|
-
def __init__(self, planner: JaxBackpropPlanner,
|
|
1593
|
-
|
|
1594
|
-
|
|
2147
|
+
def __init__(self, planner: JaxBackpropPlanner,
|
|
2148
|
+
key: Optional[random.PRNGKey]=None,
|
|
2149
|
+
eval_hyperparams: Optional[Dict[str, Any]]=None,
|
|
2150
|
+
params: Optional[Pytree]=None,
|
|
1595
2151
|
train_on_reset: bool=False,
|
|
1596
2152
|
**train_kwargs) -> None:
|
|
1597
2153
|
'''Creates a new JAX offline control policy that is trained once, then
|
|
1598
2154
|
deployed later.
|
|
1599
2155
|
|
|
1600
2156
|
:param planner: underlying planning algorithm for optimizing actions
|
|
1601
|
-
:param key: the RNG key to seed randomness
|
|
2157
|
+
:param key: the RNG key to seed randomness (derived from clock if not
|
|
2158
|
+
provided)
|
|
1602
2159
|
:param eval_hyperparams: policy hyperparameters to apply for evaluation
|
|
1603
2160
|
or whenever sample_action is called
|
|
1604
2161
|
:param params: use the specified policy parameters instead of calling
|
|
@@ -1608,6 +2165,8 @@ class JaxOfflineController(BaseAgent):
|
|
|
1608
2165
|
for optimization
|
|
1609
2166
|
'''
|
|
1610
2167
|
self.planner = planner
|
|
2168
|
+
if key is None:
|
|
2169
|
+
key = random.PRNGKey(round(time.time() * 1000))
|
|
1611
2170
|
self.key = key
|
|
1612
2171
|
self.eval_hyperparams = eval_hyperparams
|
|
1613
2172
|
self.train_on_reset = train_on_reset
|
|
@@ -1616,60 +2175,72 @@ class JaxOfflineController(BaseAgent):
|
|
|
1616
2175
|
|
|
1617
2176
|
self.step = 0
|
|
1618
2177
|
if not self.train_on_reset and not self.params_given:
|
|
1619
|
-
|
|
2178
|
+
callback = self.planner.optimize(key=self.key, **self.train_kwargs)
|
|
2179
|
+
params = callback['best_params']
|
|
1620
2180
|
self.params = params
|
|
1621
2181
|
|
|
1622
|
-
def sample_action(self, state):
|
|
2182
|
+
def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
1623
2183
|
self.key, subkey = random.split(self.key)
|
|
1624
2184
|
actions = self.planner.get_action(
|
|
1625
2185
|
subkey, self.params, self.step, state, self.eval_hyperparams)
|
|
1626
2186
|
self.step += 1
|
|
1627
2187
|
return actions
|
|
1628
2188
|
|
|
1629
|
-
def reset(self):
|
|
2189
|
+
def reset(self) -> None:
|
|
1630
2190
|
self.step = 0
|
|
1631
2191
|
if self.train_on_reset and not self.params_given:
|
|
1632
|
-
|
|
2192
|
+
callback = self.planner.optimize(key=self.key, **self.train_kwargs)
|
|
2193
|
+
self.params = callback['best_params']
|
|
1633
2194
|
|
|
1634
2195
|
|
|
1635
2196
|
class JaxOnlineController(BaseAgent):
|
|
1636
2197
|
'''A container class for a Jax controller continuously updated using state
|
|
1637
2198
|
feedback.'''
|
|
2199
|
+
|
|
1638
2200
|
use_tensor_obs = True
|
|
1639
2201
|
|
|
1640
|
-
def __init__(self, planner: JaxBackpropPlanner,
|
|
1641
|
-
|
|
2202
|
+
def __init__(self, planner: JaxBackpropPlanner,
|
|
2203
|
+
key: Optional[random.PRNGKey]=None,
|
|
2204
|
+
eval_hyperparams: Optional[Dict[str, Any]]=None,
|
|
2205
|
+
warm_start: bool=True,
|
|
1642
2206
|
**train_kwargs) -> None:
|
|
1643
2207
|
'''Creates a new JAX control policy that is trained online in a closed-
|
|
1644
2208
|
loop fashion.
|
|
1645
2209
|
|
|
1646
2210
|
:param planner: underlying planning algorithm for optimizing actions
|
|
1647
|
-
:param key: the RNG key to seed randomness
|
|
2211
|
+
:param key: the RNG key to seed randomness (derived from clock if not
|
|
2212
|
+
provided)
|
|
1648
2213
|
:param eval_hyperparams: policy hyperparameters to apply for evaluation
|
|
1649
2214
|
or whenever sample_action is called
|
|
2215
|
+
:param warm_start: whether to use the previous decision epoch final
|
|
2216
|
+
policy parameters to warm-start the next decision epoch
|
|
1650
2217
|
:param **train_kwargs: any keyword arguments to be passed to the planner
|
|
1651
2218
|
for optimization
|
|
1652
2219
|
'''
|
|
1653
2220
|
self.planner = planner
|
|
2221
|
+
if key is None:
|
|
2222
|
+
key = random.PRNGKey(round(time.time() * 1000))
|
|
1654
2223
|
self.key = key
|
|
1655
2224
|
self.eval_hyperparams = eval_hyperparams
|
|
1656
2225
|
self.warm_start = warm_start
|
|
1657
2226
|
self.train_kwargs = train_kwargs
|
|
1658
2227
|
self.reset()
|
|
1659
2228
|
|
|
1660
|
-
def sample_action(self, state):
|
|
2229
|
+
def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
1661
2230
|
planner = self.planner
|
|
1662
|
-
|
|
2231
|
+
callback = planner.optimize(
|
|
1663
2232
|
key=self.key,
|
|
1664
2233
|
guess=self.guess,
|
|
1665
2234
|
subs=state,
|
|
1666
2235
|
**self.train_kwargs)
|
|
2236
|
+
params = callback['best_params']
|
|
1667
2237
|
self.key, subkey = random.split(self.key)
|
|
1668
|
-
actions = planner.get_action(
|
|
2238
|
+
actions = planner.get_action(
|
|
2239
|
+
subkey, params, 0, state, self.eval_hyperparams)
|
|
1669
2240
|
if self.warm_start:
|
|
1670
2241
|
self.guess = planner.plan.guess_next_epoch(params)
|
|
1671
2242
|
return actions
|
|
1672
2243
|
|
|
1673
|
-
def reset(self):
|
|
2244
|
+
def reset(self) -> None:
|
|
1674
2245
|
self.guess = None
|
|
1675
2246
|
|
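The online controller follows the same pattern but re-plans at every decision epoch, optionally warm-starting from the shifted parameters of the previous epoch. A hedged sketch that reuses the hypothetical env and planner from the offline example, with an arbitrary one-second budget per decision:

    from pyRDDLGym_jax.core.planner import JaxOnlineController

    controller = JaxOnlineController(planner, warm_start=True, train_seconds=1.0,
                                     print_summary=False, print_progress=False)
    state, _ = env.reset()
    for _ in range(env.horizon):
        action = controller.sample_action(state)  # re-optimizes with subs=state
        state, reward, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            break
    controller.reset()  # clears the warm-start guess between episodes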