pyRDDLGym-jax 2.4__py3-none-any.whl → 2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +23 -10
- pyRDDLGym_jax/core/logic.py +6 -8
- pyRDDLGym_jax/core/model.py +595 -0
- pyRDDLGym_jax/core/planner.py +317 -99
- pyRDDLGym_jax/core/simulator.py +37 -13
- pyRDDLGym_jax/core/tuning.py +25 -10
- pyRDDLGym_jax/entry_point.py +39 -7
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
- pyRDDLGym_jax/examples/run_plan.py +1 -1
- pyRDDLGym_jax/examples/run_tune.py +8 -2
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/METADATA +17 -30
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/RECORD +19 -18
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info/licenses}/LICENSE +0 -0
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,595 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from functools import partial
|
|
5
|
+
import sys
|
|
6
|
+
import time
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple
|
|
9
|
+
|
|
10
|
+
import jax
|
|
11
|
+
import jax.nn.initializers as initializers
|
|
12
|
+
import jax.numpy as jnp
|
|
13
|
+
import jax.random as random
|
|
14
|
+
import numpy as np
|
|
15
|
+
import optax
|
|
16
|
+
|
|
17
|
+
from pyRDDLGym.core.compiler.model import RDDLLiftedModel
|
|
18
|
+
|
|
19
|
+
from pyRDDLGym_jax.core.logic import Logic, ExactLogic
|
|
20
|
+
from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
|
|
21
|
+
|
|
22
|
+
# keyword arguments forwarded to a factory (e.g. the optax optimizer factory)
Kwargs = Dict[str, Any]

# mappings from fluent / non-fluent names to their (batched) values
State = Dict[str, np.ndarray]
Action = Dict[str, np.ndarray]
Params = Dict[str, np.ndarray]

# a (possibly infinite) stream of (states, actions, next-states) batches
DataStream = Iterable[Tuple[State, Action, State]]

# the dictionary of intermediate results yielded while training
Callback = Dict[str, Any]

# maps (target, pred) to elementwise loss values
LossFunction = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# ***********************************************************************
|
|
32
|
+
# ALL VERSIONS OF LOSS FUNCTIONS
|
|
33
|
+
#
|
|
34
|
+
# - loss functions based on specific likelihood assumptions (MSE, cross-entropy)
|
|
35
|
+
#
|
|
36
|
+
# ***********************************************************************
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def mean_squared_error() -> LossFunction:
    '''Builds an elementwise squared-error loss.

    The returned jitted function maps (target, pred) to (target - pred) ** 2
    for every element, without any reduction.
    '''
    def _jax_wrapped_mse_loss(target, pred):
        residual = target - pred
        return jnp.square(residual)

    mse_loss = jax.jit(_jax_wrapped_mse_loss)
    return mse_loss
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def binary_cross_entropy(eps: float=1e-6) -> LossFunction:
    '''Builds an elementwise binary cross-entropy loss.

    The returned jitted function maps (target, pred) to
    -target * log(p) - (1 - target) * log(1 - p), where p is pred clipped
    into [eps, 1 - eps] so both logarithms stay finite. No reduction is applied.

    :param eps: clipping margin applied to the predictions
    '''
    def _jax_wrapped_binary_cross_entropy_loss(target, pred):
        prob = jnp.clip(pred, eps, 1.0 - eps)
        positive_term = target * jnp.log(prob)
        negative_term = (1.0 - target) * jnp.log(1.0 - prob)
        return -positive_term - negative_term

    bce_loss = jax.jit(_jax_wrapped_binary_cross_entropy_loss)
    return bce_loss
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def optax_loss(loss_fn: LossFunction, **kwargs) -> LossFunction:
    '''Adapts a loss taking (pred, target) to the (target, pred) convention
    used by the other loss builders in this module.

    The arguments are swapped before calling loss_fn, matching the
    (predictions, targets) order of optax's loss functions. Any extra keyword
    arguments are forwarded unchanged on every call.

    :param loss_fn: a loss function called as loss_fn(pred, target, **kwargs)
    '''
    forwarded_kwargs = kwargs

    def _jax_wrapped_optax_loss(target, pred):
        return loss_fn(pred, target, **forwarded_kwargs)

    wrapped_loss = jax.jit(_jax_wrapped_optax_loss)
    return wrapped_loss
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# ***********************************************************************
|
|
64
|
+
# ALL VERSIONS OF JAX MODEL LEARNER
|
|
65
|
+
#
|
|
66
|
+
# - gradient based model learning
|
|
67
|
+
#
|
|
68
|
+
# ***********************************************************************
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class JaxLearnerStatus(Enum):
    '''Status of a single parameter update made by the JAX model learner.

    Besides normal operation, it reports whether the update produced a nan
    gradient, whether progress was made, and whether the time or iteration
    budget was exhausted, so callers can monitor and react to the learner.'''

    NORMAL = 0
    NO_PROGRESS = 1
    INVALID_GRADIENT = 2
    TIME_BUDGET_REACHED = 3
    ITER_BUDGET_REACHED = 4

    def is_terminal(self) -> bool:
        # every status from INVALID_GRADIENT onwards (value >= 2) is terminal
        return self.value > type(self).NO_PROGRESS.value
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class JaxModelLearner:
    '''A class for data-driven estimation of unknown parameters in a given RDDL MDP using
    gradient descent.'''

    def __init__(self, rddl: RDDLLiftedModel,
                 param_ranges: Dict[str, Tuple[Optional[np.ndarray], Optional[np.ndarray]]],
                 batch_size_train: int=32,
                 samples_per_datapoint: int=1,
                 optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
                 optimizer_kwargs: Optional[Kwargs]=None,
                 initializer: initializers.Initializer = initializers.normal(),
                 wrap_non_bool: bool=True,
                 use64bit: bool=False,
                 bool_fluent_loss: LossFunction=binary_cross_entropy(),
                 real_fluent_loss: LossFunction=mean_squared_error(),
                 int_fluent_loss: LossFunction=mean_squared_error(),
                 logic: Optional[Logic]=None,
                 model_params_reduction: Callable=lambda x: x[0]) -> None:
        '''Creates a new gradient-based algorithm for inferring unknown non-fluents
        in a RDDL domain from a data set or stream coming from the real environment.

        :param rddl: the RDDL domain to learn
        :param param_ranges: the ranges of all learnable non-fluents, mapping each
        non-fluent name to a (lower, upper) pair; each bound may be None (unbounded),
        a scalar, or an array broadcastable to the non-fluent's shape (elementwise)
        :param batch_size_train: how many transitions to compute per optimization
        step
        :param samples_per_datapoint: how many random samples to produce from the step
        function per data point during training
        :param optimizer: a factory for an optax SGD algorithm
        :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
        factory (e.g. which parameters are controllable externally)
        :param initializer: how to initialize non-fluents
        :param wrap_non_bool: whether to wrap non-boolean trainable parameters to satisfy
        required ranges as specified in param_ranges (use a projected gradient otherwise)
        :param use64bit: whether to perform arithmetic in 64 bit
        :param bool_fluent_loss: loss function to optimize for bool-valued fluents
        :param real_fluent_loss: loss function to optimize for real-valued fluents
        :param int_fluent_loss: loss function to optimize for int-valued fluents
        :param logic: a subclass of Logic for mapping exact mathematical
        operations to their differentiable counterparts (if None, a new ExactLogic
        instance is created for this learner)
        :param model_params_reduction: how to aggregate updated model_params across runs
        in the batch (defaults to selecting the first element's parameters in the batch)
        :raises ValueError: if a key of param_ranges is not a non-fluent of rddl, or
        its value is not a (lower, upper) pair with lower <= upper
        '''
        self.rddl = rddl
        self.param_ranges = param_ranges.copy()
        self.batch_size_train = batch_size_train
        self.samples_per_datapoint = samples_per_datapoint
        if optimizer_kwargs is None:
            optimizer_kwargs = {'learning_rate': 0.001}
        self.optimizer_kwargs = optimizer_kwargs
        self.initializer = initializer
        self.wrap_non_bool = wrap_non_bool
        self.use64bit = use64bit
        self.bool_fluent_loss = bool_fluent_loss
        self.real_fluent_loss = real_fluent_loss
        self.int_fluent_loss = int_fluent_loss
        # create a fresh Logic per learner rather than sharing one default instance;
        # NOTE(review): the compiler receives this object and may configure it
        # (e.g. for 64 bit), so a shared instance could leak settings across learners
        if logic is None:
            logic = ExactLogic()
        self.logic = logic
        self.model_params_reduction = model_params_reduction

        # validate param_ranges
        for (name, values) in param_ranges.items():
            if name not in rddl.non_fluents:
                raise ValueError(
                    f'param_ranges key <{name}> is not a valid non-fluent '
                    f'in the current rddl.')
            if not isinstance(values, (tuple, list)):
                raise ValueError(
                    f'param_ranges values with key <{name}> are neither a tuple nor a list.')
            if len(values) != 2:
                raise ValueError(
                    f'param_ranges values with key <{name}> must be of length 2, '
                    f'got length {len(values)}.')
            lower, upper = values
            if lower is not None and upper is not None and not np.all(lower <= upper):
                raise ValueError(
                    f'param_ranges values with key <{name}> do not satisfy lower <= upper.')

        # build the optimizer
        optimizer = optimizer(**optimizer_kwargs)
        pipeline = [optimizer]
        self.optimizer = optax.chain(*pipeline)

        # build the computation graph
        self.step_fn = self._jax_compile_rddl()
        self.map_fn = self._jax_map()
        # loss evaluated directly on non-fluent values (no parameter mapping)
        self.model_loss_fn = self._jax_model_loss(step_fn=self.step_fn)
        # loss evaluated on the raw trainable parameters (mapped through map_fn)
        self.loss_fn = self._jax_loss(map_fn=self.map_fn, step_fn=self.step_fn)
        self.update_fn, self.project_fn = self._jax_update(loss_fn=self.loss_fn)
        self.init_fn, self.init_opt_fn = self._jax_init(project_fn=self.project_fn)

    # ===========================================================================
    # COMPILATION SUBROUTINES
    # ===========================================================================

    def _jax_compile_rddl(self):
        '''Compiles the relaxed RDDL model and returns a jitted transition function
        that is batched over data points and over parallel samples per data point.'''

        # compile the RDDL model
        self.compiled = JaxRDDLCompilerWithGrad(
            rddl=self.rddl,
            logic=self.logic,
            use64bit=self.use64bit,
            compile_non_fluent_exact=False,
            print_warnings=True
        )
        self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')

        # compile the transition step function
        step_fn = self.compiled.compile_transition()

        # single transition with the candidate non-fluent values substituted in
        def _jax_wrapped_step(key, param_fluents, subs, actions, hyperparams):
            for (name, param) in param_fluents.items():
                subs[name] = param
            subs, _, hyperparams = step_fn(key, actions, subs, hyperparams)
            return subs, hyperparams

        # batched step function: keys, subs and actions carry the batch on axis 0,
        # while the non-fluent values and hyper-parameters are shared
        def _jax_wrapped_batched_step(key, param_fluents, subs, actions, hyperparams):
            keys = jnp.asarray(random.split(key, num=self.batch_size_train))
            subs, hyperparams = jax.vmap(
                _jax_wrapped_step, in_axes=(0, None, 0, 0, None)
            )(keys, param_fluents, subs, actions, hyperparams)
            hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
            return subs, hyperparams

        # batched step function with parallel samples per data point: the outputs
        # gain a leading sample axis of size samples_per_datapoint
        def _jax_wrapped_batched_parallel_step(key, param_fluents, subs, actions, hyperparams):
            keys = jnp.asarray(random.split(key, num=self.samples_per_datapoint))
            subs, hyperparams = jax.vmap(
                _jax_wrapped_batched_step, in_axes=(0, None, None, None, None)
            )(keys, param_fluents, subs, actions, hyperparams)
            hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
            return subs, hyperparams

        batched_step_fn = jax.jit(_jax_wrapped_batched_parallel_step)
        return batched_step_fn

    def _jax_map(self):
        '''Returns a jitted function mapping raw trainable parameters to non-fluent
        values: bool non-fluents go through a sigmoid, and (if wrap_non_bool)
        non-bool ones are squashed into their ranges elementwise.'''

        # for each parameter, classify every element by which of its bounds are
        # finite: 0 = both, 1 = lower only, 2 = upper only, 3 = neither
        case_indices = {}

        # bounds with infinite entries replaced by 0, so every candidate transform
        # below stays finite and the unselected ones cannot inject nan into the
        # value or its gradient
        safe_bounds = {}
        if self.wrap_non_bool:
            for (name, (lower, upper)) in self.param_ranges.items():
                if lower is None: lower = -np.inf
                if upper is None: upper = +np.inf
                self.param_ranges[name] = (lower, upper)
                lower_finite = np.isfinite(lower)
                upper_finite = np.isfinite(upper)
                case_indices[name] = (
                    0 * (lower_finite & upper_finite) +
                    1 * (lower_finite & ~upper_finite) +
                    2 * (~lower_finite & upper_finite) +
                    3 * (~lower_finite & ~upper_finite)
                )
                safe_bounds[name] = (
                    np.asarray(np.where(lower_finite, lower, 0.0), dtype=self.compiled.REAL),
                    np.asarray(np.where(upper_finite, upper, 0.0), dtype=self.compiled.REAL)
                )

        # map trainable parameters to their non-fluent values
        def _jax_wrapped_params_to_fluents(params):
            param_fluents = {}
            for (name, param) in params.items():
                if self.rddl.variable_ranges[name] == 'bool':
                    param_fluents[name] = jax.nn.sigmoid(param)
                elif self.wrap_non_bool:
                    lower, upper = safe_bounds[name]
                    indices = case_indices[name]

                    # select elementwise (jax.lax.switch only accepts a scalar
                    # index, which fails for array-valued bounds)
                    both_bounded = lower + (upper - lower) * jax.nn.sigmoid(param)
                    lower_bounded = lower + (jax.nn.elu(param) + 1.0)
                    upper_bounded = upper - (jax.nn.elu(-param) + 1.0)
                    param_fluents[name] = jnp.select(
                        [indices == 0, indices == 1, indices == 2],
                        [both_bounded, lower_bounded, upper_bounded],
                        default=param)
                else:
                    param_fluents[name] = param
            return param_fluents

        map_fn = jax.jit(_jax_wrapped_params_to_fluents)
        return map_fn

    def _jax_model_loss(self, step_fn):
        '''Returns a jitted loss taking non-fluent values directly (no mapping).

        The loss simulates one batched transition and averages, over all target
        next-state fluents, the mean per-fluent loss chosen by the fluent's range
        (bool_fluent_loss, real_fluent_loss, or int_fluent_loss otherwise).
        '''
        def _jax_wrapped_batched_model_loss(key, param_fluents, subs, actions, next_fluents,
                                            hyperparams):
            next_subs, hyperparams = step_fn(key, param_fluents, subs, actions, hyperparams)
            total_loss = 0.0
            for (name, next_value) in next_fluents.items():
                preds = jnp.asarray(next_subs[name], dtype=self.compiled.REAL)

                # leading singleton axis broadcasts the targets against the
                # sample axis produced by the parallel step function
                targets = jnp.asarray(next_value, dtype=self.compiled.REAL)[jnp.newaxis, ...]
                if self.rddl.variable_ranges[name] == 'bool':
                    loss_values = self.bool_fluent_loss(targets, preds)
                elif self.rddl.variable_ranges[name] == 'real':
                    loss_values = self.real_fluent_loss(targets, preds)
                else:
                    loss_values = self.int_fluent_loss(targets, preds)
                total_loss += jnp.mean(loss_values) / len(next_fluents)
            return total_loss, hyperparams

        return jax.jit(_jax_wrapped_batched_model_loss)

    def _jax_loss(self, map_fn, step_fn):
        '''Returns a jitted loss taking the raw trainable parameters, which are
        first mapped to non-fluent values by map_fn.'''
        model_loss_fn = self._jax_model_loss(step_fn)

        # loss with the parameters mapped to their fluents
        def _jax_wrapped_batched_loss(key, params, subs, actions, next_fluents, hyperparams):
            param_fluents = map_fn(params)
            loss, hyperparams = model_loss_fn(
                key, param_fluents, subs, actions, next_fluents, hyperparams)
            return loss, hyperparams

        loss_fn = jax.jit(_jax_wrapped_batched_loss)
        return loss_fn

    def _jax_init(self, project_fn):
        '''Returns jitted initializers for (params, optimizer state): one drawing
        params from the initializer, the other starting from given params.'''
        optimizer = self.optimizer

        # initialize both the non-fluents and optimizer
        def _jax_wrapped_init_params_optimizer(key):
            params = {}
            for name in self.param_ranges:
                shape = jnp.shape(self.compiled.init_values[name])
                key, subkey = random.split(key)
                params[name] = self.initializer(subkey, shape, dtype=self.compiled.REAL)
            params = project_fn(params)
            opt_state = optimizer.init(params)
            return params, opt_state

        # initialize just the optimizer given the non-fluents
        def _jax_wrapped_init_optimizer(params):
            params = project_fn(params)
            opt_state = optimizer.init(params)
            return params, opt_state

        init_fn = jax.jit(_jax_wrapped_init_params_optimizer)
        init_opt_fn = jax.jit(_jax_wrapped_init_optimizer)
        return init_fn, init_opt_fn

    def _jax_update(self, loss_fn):
        '''Returns the jitted gradient update and projection functions.'''
        optimizer = self.optimizer

        # projected gradient trick to satisfy box constraints on params
        # (only needed when parameters are not wrapped into their ranges)
        def _jax_wrapped_project_params(params):
            if self.wrap_non_bool:
                return params
            else:
                new_params = {}
                for (name, value) in params.items():
                    if self.rddl.variable_ranges[name] == 'bool':
                        new_params[name] = value
                    else:
                        lower, upper = self.param_ranges[name]
                        new_params[name] = jnp.clip(value, lower, upper)
                return new_params

        # gradient descent update
        def _jax_wrapped_params_update(key, params, subs, actions, next_fluents,
                                       hyperparams, opt_state):
            (loss_val, hyperparams), grad = jax.value_and_grad(
                loss_fn, argnums=1, has_aux=True
            )(key, params, subs, actions, next_fluents, hyperparams)
            updates, opt_state = optimizer.update(grad, opt_state)
            params = optax.apply_updates(params, updates)
            params = _jax_wrapped_project_params(params)
            zero_grads = jax.tree_util.tree_map(partial(jnp.allclose, b=0.0), grad)
            return params, opt_state, loss_val, zero_grads, hyperparams

        update_fn = jax.jit(_jax_wrapped_params_update)
        project_fn = jax.jit(_jax_wrapped_project_params)
        return update_fn, project_fn

    def _batched_init_subs(self):
        '''Tiles every initial value of the compiled model along a new leading axis
        of size batch_size_train; each next-state fluent shares its state's array.'''
        init_train = {}
        for (name, value) in self.compiled.init_values.items():
            value = np.reshape(value, np.shape(value))[np.newaxis, ...]
            value = np.repeat(value, repeats=self.batch_size_train, axis=0)
            value = np.asarray(value, dtype=self.compiled.REAL)
            init_train[name] = value
        for (state, next_state) in self.rddl.next_state.items():
            init_train[next_state] = init_train[state]
        return init_train

    @staticmethod
    def _progress_fraction(it, elapsed, epochs, train_seconds):
        '''Fraction in [0, 1] of the time or iteration budget used so far, whichever
        is larger. A budget of at most one iteration or a non-positive time budget
        counts as fully used, avoiding a division by zero.'''
        time_fraction = elapsed / train_seconds if train_seconds > 0 else 1.0
        iter_fraction = it / (epochs - 1) if epochs > 1 else 1.0
        return min(1, max(0, time_fraction, iter_fraction))

    # ===========================================================================
    # ESTIMATE API
    # ===========================================================================

    def optimize(self, *args, **kwargs) -> Optional[Callback]:
        '''Estimate the unknown parameters from the given data set.
        Return the callback from training (None if the data stream is empty).

        :param data: a data stream represented as a (possibly infinite) sequence of
        transition batches of the form (states, actions, next-states), where each element
        is a numpy array of leading dimension equal to batch_size_train
        :param key: JAX PRNG key (derived from clock if not provided)
        :param epochs: the maximum number of steps of gradient descent
        :param train_seconds: total time allocated for gradient descent
        :param guess: initial non-fluent parameters: if None will use the initializer
        specified in this instance
        :param print_progress: whether to print the progress bar during training
        '''
        it = self.optimize_generator(*args, **kwargs)

        # https://stackoverflow.com/questions/50937966/fastest-most-pythonic-way-to-consume-an-iterator
        callback = None
        if sys.implementation.name == 'cpython':
            last_callback = deque(it, maxlen=1)
            if last_callback:
                callback = last_callback.pop()
        else:
            for callback in it:
                pass
        return callback

    def optimize_generator(self, data: DataStream,
                           key: Optional[random.PRNGKey]=None,
                           epochs: int=999999,
                           train_seconds: float=120.,
                           guess: Optional[Params]=None,
                           print_progress: bool=True) -> Generator[Callback, None, None]:
        '''Return a generator for estimating the unknown parameters from the given data set.
        Generator can be iterated over to lazily estimate the parameters, yielding
        a dictionary of intermediate computations.

        :param data: a data stream represented as a (possibly infinite) sequence of
        transition batches of the form (states, actions, next-states), where each element
        is a numpy array of leading dimension equal to batch_size_train
        :param key: JAX PRNG key (derived from clock if not provided)
        :param epochs: the maximum number of steps of gradient descent
        :param train_seconds: total time allocated for gradient descent
        :param guess: initial non-fluent parameters: if None will use the initializer
        specified in this instance
        :param print_progress: whether to print the progress bar during training
        '''
        start_time = time.time()
        elapsed_outside_loop = 0

        # if PRNG key is not provided
        if key is None:
            key = random.PRNGKey(round(time.time() * 1000))

        # prepare initial subs
        subs = self._batched_init_subs()

        # initialize parameter fluents to optimize
        if guess is None:
            key, subkey = random.split(key)
            params, opt_state = self.init_fn(subkey)
        else:
            params, opt_state = self.init_opt_fn(guess)

        # initialize model hyper-parameters
        hyperparams = self.compiled.model_params

        # progress bar
        if print_progress:
            progress_bar = tqdm(
                None, total=100, bar_format='{l_bar}{bar}| {elapsed} {postfix}')
        else:
            progress_bar = None

        # main training loop: the progress bar is closed however the loop ends
        # (budget reached, data exhausted, or the consumer closing this generator)
        try:
            for (it, (states, actions, next_states)) in enumerate(data):
                status = JaxLearnerStatus.NORMAL

                # gradient update
                subs.update(states)
                key, subkey = random.split(key)
                params, opt_state, loss, zero_grads, hyperparams = self.update_fn(
                    subkey, params, subs, actions, next_states, hyperparams, opt_state)

                # extract non-fluent values from the trainable parameters
                param_fluents = self.map_fn(params)
                param_fluents = {name: param_fluents[name] for name in self.param_ranges}

                # check for learnability
                params_zero_grads = {
                    name for (name, zero_grad) in zero_grads.items() if zero_grad}
                if params_zero_grads:
                    status = JaxLearnerStatus.NO_PROGRESS

                # a non-finite loss means the gradient step cannot be trusted
                if not np.isfinite(loss):
                    status = JaxLearnerStatus.INVALID_GRADIENT

                # reached computation budget
                elapsed = time.time() - start_time - elapsed_outside_loop
                if elapsed >= train_seconds:
                    status = JaxLearnerStatus.TIME_BUDGET_REACHED
                if it >= epochs - 1:
                    status = JaxLearnerStatus.ITER_BUDGET_REACHED

                # build a callback
                progress_percent = 100 * self._progress_fraction(
                    it, elapsed, epochs, train_seconds)
                callback = {
                    'status': status,
                    'iteration': it,
                    'train_loss': loss,
                    'params': params,
                    'param_fluents': param_fluents,
                    'key': key,
                    'progress': progress_percent
                }

                # update progress
                if print_progress:
                    progress_bar.set_description(
                        f'{it:7} it / {loss:12.8f} train / {status.value} status', refresh=False)
                    progress_bar.set_postfix_str(
                        f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False)
                    progress_bar.update(progress_percent - progress_bar.n)

                # yield the callback; time spent by the consumer is excluded from
                # the training time budget
                start_time_outside = time.time()
                yield callback
                elapsed_outside_loop += (time.time() - start_time_outside)

                # abortion check
                if status.is_terminal():
                    break
        finally:
            if progress_bar is not None:
                progress_bar.close()

    def evaluate_loss(self, data: DataStream,
                      key: Optional[random.PRNGKey],
                      param_fluents: Params) -> float:
        '''Evaluates the model loss of the given learned non-fluent values and the data.

        The values are substituted into the model as given, i.e. they are NOT passed
        through the parameter mapping again (pass e.g. the 'param_fluents' entry of
        a training callback, not its raw 'params').

        :param data: a data stream represented as a (possibly infinite) sequence of
        transition batches of the form (states, actions, next-states), where each element
        is a numpy array of leading dimension equal to batch_size_train
        :param key: JAX PRNG key (derived from clock if not provided)
        :param param_fluents: the learned non-fluent values
        :returns: the running mean of the loss over all batches (0.0 if data is empty)
        '''
        if key is None:
            key = random.PRNGKey(round(time.time() * 1000))
        subs = self._batched_init_subs()
        hyperparams = self.compiled.model_params
        mean_loss = 0.0
        for (it, (states, actions, next_states)) in enumerate(data):
            subs.update(states)
            key, subkey = random.split(key)
            loss_value, _ = self.model_loss_fn(
                subkey, param_fluents, subs, actions, next_states, hyperparams)
            mean_loss += (float(loss_value) - mean_loss) / (it + 1)
        return mean_loss

    def learned_model(self, param_fluents: Params) -> RDDLLiftedModel:
        '''Substitutes the given learned non-fluent values into the RDDL model and returns
        the new model.

        Bool-valued non-fluents are thresholded at 0.5, real-valued ones are used
        as is, and all others are rounded to the nearest integer.

        :param param_fluents: the learned non-fluent values
        :raises ValueError: if a parameter-free non-fluent is given more than one value
        '''
        model = deepcopy(self.rddl)
        for (name, values) in param_fluents.items():
            prange = model.variable_ranges[name]
            if prange == 'real':
                pass
            elif prange == 'bool':
                values = np.asarray(values) > 0.5
            else:
                # round to nearest: a plain integer cast truncates toward zero,
                # turning e.g. a learned 2.9999 into 2
                values = np.asarray(np.rint(values), dtype=self.compiled.INT)
            values = np.ravel(values, order='C').tolist()
            if not self.rddl.variable_params[name]:
                # a non-fluent without parameters is stored as a single scalar
                if len(values) != 1:
                    raise ValueError(
                        f'Non-fluent <{name}> has no parameters but received '
                        f'{len(values)} values.')
                values = values[0]
            model.non_fluents[name] = values
        return model
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
if __name__ == '__main__':
    import os
    import pyRDDLGym
    from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOfflineController
    bs = 32

    # make some data
    def data_iterator():
        '''Yields an endless stream of random CartPole transition batches of size bs,
        simulated by the compiled CartPole model.'''
        env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
        model = JaxModelLearner(rddl=env.model, param_ranges={}, batch_size_train=bs)
        key = random.PRNGKey(round(time.time() * 1000))
        subs = model._batched_init_subs()
        param_fluents = {}

        # pass the compiled model's own hyper-parameters to the step function
        # rather than an empty dict, since the relaxed transition may look them up
        hyperparams = model.compiled.model_params
        while True:
            states = {
                'pos': np.random.uniform(-2.4, 2.4, (bs,)),
                'vel': np.random.uniform(-2.4, 2.4, (bs,)),
                'ang-pos': np.random.uniform(-0.21, 0.21, (bs,)),
                'ang-vel': np.random.uniform(-0.21, 0.21, (bs,))
            }
            subs.update(states)
            actions = {
                'force': np.random.uniform(-10., 10., (bs,))
            }
            key, subkey = random.split(key)
            subs, _ = model.step_fn(subkey, param_fluents, subs, actions, hyperparams)

            # drop the leading sample axis added by the parallel step function
            subs = {k: np.asarray(v)[0, ...] for k, v in subs.items()}
            next_states = {k: subs[k] for k in model.rddl.state_fluents}
            yield (states, actions, next_states)

    # train it
    env = pyRDDLGym.make('TestJax', '0', vectorized=True)
    model_learner = JaxModelLearner(rddl=env.model,
                                    param_ranges={
                                        'w1': (None, None), 'b1': (None, None),
                                        'w2': (None, None), 'b2': (None, None),
                                        'w1o': (None, None), 'b1o': (None, None),
                                        'w2o': (None, None), 'b2o': (None, None)
                                    },
                                    batch_size_train=bs,
                                    optimizer_kwargs = {'learning_rate': 0.0003})
    for cb in model_learner.optimize_generator(data_iterator(), epochs=10000):
        pass

    # planning in the trained model
    model = model_learner.learned_model(cb['param_fluents'])
    abs_path = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(os.path.dirname(abs_path), 'examples', 'configs', 'default_drp.cfg')
    planner_args, _, train_args = load_config(config_path)
    planner = JaxBackpropPlanner(rddl=model, **planner_args)
    controller = JaxOfflineController(planner, **train_args)

    # evaluation of the plan
    test_env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
    controller.evaluate(test_env, episodes=1, verbose=True, render=True)
|