pyRDDLGym-jax 2.4__py3-none-any.whl → 2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,595 @@
+ from collections import deque
+ from copy import deepcopy
+ from enum import Enum
+ from functools import partial
+ import sys
+ import time
+ from tqdm import tqdm
+ from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple
+
+ import jax
+ import jax.nn.initializers as initializers
+ import jax.numpy as jnp
+ import jax.random as random
+ import numpy as np
+ import optax
+
+ from pyRDDLGym.core.compiler.model import RDDLLiftedModel
+
+ from pyRDDLGym_jax.core.logic import Logic, ExactLogic
+ from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
+
+ Kwargs = Dict[str, Any]
+ State = Dict[str, np.ndarray]
+ Action = Dict[str, np.ndarray]
+ DataStream = Iterable[Tuple[State, Action, State]]
+ Params = Dict[str, np.ndarray]
+ Callback = Dict[str, Any]
+ LossFunction = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
+
+
+ # ***********************************************************************
+ # ALL VERSIONS OF LOSS FUNCTIONS
+ #
+ # - loss functions based on specific likelihood assumptions (MSE, cross-entropy)
+ #
+ # ***********************************************************************
+
+
+ def mean_squared_error() -> LossFunction:
+     def _jax_wrapped_mse_loss(target, pred):
+         loss_values = jnp.square(target - pred)
+         return loss_values
+     return jax.jit(_jax_wrapped_mse_loss)
+
+
+ def binary_cross_entropy(eps: float=1e-6) -> LossFunction:
+     def _jax_wrapped_binary_cross_entropy_loss(target, pred):
+         pred = jnp.clip(pred, eps, 1.0 - eps)
+         log_pred = jnp.log(pred)
+         log_not_pred = jnp.log(1.0 - pred)
+         loss_values = -target * log_pred - (1.0 - target) * log_not_pred
+         return loss_values
+     return jax.jit(_jax_wrapped_binary_cross_entropy_loss)
+
+
+ def optax_loss(loss_fn: LossFunction, **kwargs) -> LossFunction:
+     def _jax_wrapped_optax_loss(target, pred):
+         loss_values = loss_fn(pred, target, **kwargs)
+         return loss_values
+     return jax.jit(_jax_wrapped_optax_loss)
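+
+ # Example (illustrative): optax losses such as optax.huber_loss expect
+ # (predictions, targets), so they can be adapted to this module's (target, pred)
+ # convention via the wrapper above, e.g.
+ #     huber = optax_loss(optax.huber_loss, delta=1.0)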
+
+
+ # ***********************************************************************
+ # ALL VERSIONS OF JAX MODEL LEARNER
+ #
+ # - gradient-based model learning
+ #
+ # ***********************************************************************
+
+
+ class JaxLearnerStatus(Enum):
+     '''Represents the status of a parameter update from the JAX model learner,
+     including whether the update resulted in a NaN gradient,
+     whether progress was made, whether a computation budget was reached, or other
+     information that can be used to monitor the learner's progress and act on it.'''
+
+     NORMAL = 0
+     NO_PROGRESS = 1
+     INVALID_GRADIENT = 2
+     TIME_BUDGET_REACHED = 3
+     ITER_BUDGET_REACHED = 4
+
+     def is_terminal(self) -> bool:
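+         # INVALID_GRADIENT and the budget statuses (value >= 2) signal that training should stop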
+         return self.value >= 2
+
+
+ class JaxModelLearner:
+     '''A class for data-driven estimation of unknown parameters in a given RDDL MDP using
+     gradient descent.'''
+
+     def __init__(self, rddl: RDDLLiftedModel,
+                  param_ranges: Dict[str, Tuple[Optional[np.ndarray], Optional[np.ndarray]]],
+                  batch_size_train: int=32,
+                  samples_per_datapoint: int=1,
+                  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
+                  optimizer_kwargs: Optional[Kwargs]=None,
+                  initializer: initializers.Initializer=initializers.normal(),
+                  wrap_non_bool: bool=True,
+                  use64bit: bool=False,
+                  bool_fluent_loss: LossFunction=binary_cross_entropy(),
+                  real_fluent_loss: LossFunction=mean_squared_error(),
+                  int_fluent_loss: LossFunction=mean_squared_error(),
+                  logic: Logic=ExactLogic(),
+                  model_params_reduction: Callable=lambda x: x[0]) -> None:
+         '''Creates a new gradient-based algorithm for inferring unknown non-fluents
+         in an RDDL domain from a data set or stream coming from the real environment.
+
+         :param rddl: the RDDL domain to learn
+         :param param_ranges: the ranges of all learnable non-fluents
+         :param batch_size_train: how many transitions to compute per optimization
+             step
+         :param samples_per_datapoint: how many random samples to produce from the step
+             function per data point during training
+         :param optimizer: a factory for an optax SGD algorithm
+         :param optimizer_kwargs: a dictionary of keyword arguments to pass to the SGD
+             factory (e.g. learning rate)
+         :param initializer: how to initialize non-fluents
+         :param wrap_non_bool: whether to wrap non-boolean trainable parameters to satisfy
+             required ranges as specified in param_ranges (use a projected gradient otherwise)
+         :param use64bit: whether to perform arithmetic in 64 bit
+         :param bool_fluent_loss: loss function to optimize for bool-valued fluents
+         :param real_fluent_loss: loss function to optimize for real-valued fluents
+         :param int_fluent_loss: loss function to optimize for int-valued fluents
+         :param logic: a subclass of Logic for mapping exact mathematical
+             operations to their differentiable counterparts
+         :param model_params_reduction: how to aggregate updated model_params across runs
+             in the batch (defaults to selecting the first element's parameters in the batch)
+         '''
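+         # param_ranges maps each learnable non-fluent to a (lower, upper) pair, where None
+         # leaves that side unbounded; a hypothetical example (fluent names are illustrative):
+         #     param_ranges={'MASS': (0.0, None), 'FRICTION': (0.0, 1.0), 'BIAS': (None, None)}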
+         self.rddl = rddl
+         self.param_ranges = param_ranges.copy()
+         self.batch_size_train = batch_size_train
+         self.samples_per_datapoint = samples_per_datapoint
+         if optimizer_kwargs is None:
+             optimizer_kwargs = {'learning_rate': 0.001}
+         self.optimizer_kwargs = optimizer_kwargs
+         self.initializer = initializer
+         self.wrap_non_bool = wrap_non_bool
+         self.use64bit = use64bit
+         self.bool_fluent_loss = bool_fluent_loss
+         self.real_fluent_loss = real_fluent_loss
+         self.int_fluent_loss = int_fluent_loss
+         self.logic = logic
+         self.model_params_reduction = model_params_reduction
+
+         # validate param_ranges
+         for (name, values) in param_ranges.items():
+             if name not in rddl.non_fluents:
+                 raise ValueError(
+                     f'param_ranges key <{name}> is not a valid non-fluent '
+                     f'in the current rddl.')
+             if not isinstance(values, (tuple, list)):
+                 raise ValueError(
+                     f'param_ranges values with key <{name}> are neither a tuple nor a list.')
+             if len(values) != 2:
+                 raise ValueError(
+                     f'param_ranges values with key <{name}> must be of length 2, '
+                     f'got length {len(values)}.')
+             lower, upper = values
+             if lower is not None and upper is not None and not np.all(lower <= upper):
+                 raise ValueError(
+                     f'param_ranges values with key <{name}> do not satisfy lower <= upper.')
+
+         # build the optimizer
+         optimizer = optimizer(**optimizer_kwargs)
+         pipeline = [optimizer]
+         self.optimizer = optax.chain(*pipeline)
+
+         # build the computation graph
+         self.step_fn = self._jax_compile_rddl()
+         self.map_fn = self._jax_map()
+         self.loss_fn = self._jax_loss(map_fn=self.map_fn, step_fn=self.step_fn)
+         self.update_fn, self.project_fn = self._jax_update(loss_fn=self.loss_fn)
+         self.init_fn, self.init_opt_fn = self._jax_init(project_fn=self.project_fn)
+
+     # ===========================================================================
+     # COMPILATION SUBROUTINES
+     # ===========================================================================
+
+     def _jax_compile_rddl(self):
+
+         # compile the RDDL model
+         self.compiled = JaxRDDLCompilerWithGrad(
+             rddl=self.rddl,
+             logic=self.logic,
+             use64bit=self.use64bit,
+             compile_non_fluent_exact=False,
+             print_warnings=True
+         )
+         self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
+
+         # compile the transition step function
+         step_fn = self.compiled.compile_transition()
+
+         def _jax_wrapped_step(key, param_fluents, subs, actions, hyperparams):
+             for (name, param) in param_fluents.items():
+                 subs[name] = param
+             subs, _, hyperparams = step_fn(key, actions, subs, hyperparams)
+             return subs, hyperparams
+
+         # batched step function
+         def _jax_wrapped_batched_step(key, param_fluents, subs, actions, hyperparams):
+             keys = jnp.asarray(random.split(key, num=self.batch_size_train))
+             subs, hyperparams = jax.vmap(
+                 _jax_wrapped_step, in_axes=(0, None, 0, 0, None)
+             )(keys, param_fluents, subs, actions, hyperparams)
+             hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
+             return subs, hyperparams
+
+         # batched step function with parallel samples per data point
+         def _jax_wrapped_batched_parallel_step(key, param_fluents, subs, actions, hyperparams):
+             keys = jnp.asarray(random.split(key, num=self.samples_per_datapoint))
+             subs, hyperparams = jax.vmap(
+                 _jax_wrapped_batched_step, in_axes=(0, None, None, None, None)
+             )(keys, param_fluents, subs, actions, hyperparams)
+             hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
+             return subs, hyperparams
+
+         batched_step_fn = jax.jit(_jax_wrapped_batched_parallel_step)
+         return batched_step_fn
+
+     def _jax_map(self):
+
+         # compute case indices for bounding
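+         # each learnable parameter gets an index into the bounding cases used below:
+         #   0 = both bounds finite, 1 = only lower finite, 2 = only upper finite, 3 = unbounded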
+         case_indices = {}
+         if self.wrap_non_bool:
+             for (name, (lower, upper)) in self.param_ranges.items():
+                 if lower is None: lower = -np.inf
+                 if upper is None: upper = +np.inf
+                 self.param_ranges[name] = (lower, upper)
+                 case_indices[name] = (
+                     0 * (np.isfinite(lower) & np.isfinite(upper)) +
+                     1 * (np.isfinite(lower) & ~np.isfinite(upper)) +
+                     2 * (~np.isfinite(lower) & np.isfinite(upper)) +
+                     3 * (~np.isfinite(lower) & ~np.isfinite(upper))
+                 )
+
+         # map trainable parameters to their non-fluent values
+         def _jax_wrapped_params_to_fluents(params):
+             param_fluents = {}
+             for (name, param) in params.items():
+                 if self.rddl.variable_ranges[name] == 'bool':
+                     param_fluents[name] = jax.nn.sigmoid(param)
+                 else:
+                     if self.wrap_non_bool:
+                         lower, upper = self.param_ranges[name]
+                         cases = [
+                             lambda x: lower + (upper - lower) * jax.nn.sigmoid(x),
+                             lambda x: lower + (jax.nn.elu(x) + 1.0),
+                             lambda x: upper - (jax.nn.elu(-x) + 1.0),
+                             lambda x: x
+                         ]
+                         indices = case_indices[name]
+                         param_fluents[name] = jax.lax.switch(indices, cases, param)
+                     else:
+                         param_fluents[name] = param
+             return param_fluents
+
+         map_fn = jax.jit(_jax_wrapped_params_to_fluents)
+         return map_fn
+
+     def _jax_loss(self, map_fn, step_fn):
+
+         # by default, binary cross entropy for bool fluents and
+         # mean squared error for continuous and integer fluents
+         def _jax_wrapped_batched_model_loss(key, param_fluents, subs, actions, next_fluents,
+                                             hyperparams):
+             next_subs, hyperparams = step_fn(key, param_fluents, subs, actions, hyperparams)
+             total_loss = 0.0
+             for (name, next_value) in next_fluents.items():
+                 preds = jnp.asarray(next_subs[name], dtype=self.compiled.REAL)
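+                 # the leading new axis lets targets broadcast over the samples_per_datapoint
+                 # dimension of the predictions produced by the batched parallel step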
+                 targets = jnp.asarray(next_value, dtype=self.compiled.REAL)[jnp.newaxis, ...]
+                 if self.rddl.variable_ranges[name] == 'bool':
+                     loss_values = self.bool_fluent_loss(targets, preds)
+                 elif self.rddl.variable_ranges[name] == 'real':
+                     loss_values = self.real_fluent_loss(targets, preds)
+                 else:
+                     loss_values = self.int_fluent_loss(targets, preds)
+                 total_loss += jnp.mean(loss_values) / len(next_fluents)
+             return total_loss, hyperparams
+
+         # loss with the parameters mapped to their fluents
+         def _jax_wrapped_batched_loss(key, params, subs, actions, next_fluents, hyperparams):
+             param_fluents = map_fn(params)
+             loss, hyperparams = _jax_wrapped_batched_model_loss(
+                 key, param_fluents, subs, actions, next_fluents, hyperparams)
+             return loss, hyperparams
+
+         loss_fn = jax.jit(_jax_wrapped_batched_loss)
+         return loss_fn
+
+     def _jax_init(self, project_fn):
+         optimizer = self.optimizer
+
+         # initialize both the non-fluents and optimizer
+         def _jax_wrapped_init_params_optimizer(key):
+             params = {}
+             for name in self.param_ranges:
+                 shape = jnp.shape(self.compiled.init_values[name])
+                 key, subkey = random.split(key)
+                 params[name] = self.initializer(subkey, shape, dtype=self.compiled.REAL)
+             params = project_fn(params)
+             opt_state = optimizer.init(params)
+             return params, opt_state
+
+         # initialize just the optimizer given the non-fluents
+         def _jax_wrapped_init_optimizer(params):
+             params = project_fn(params)
+             opt_state = optimizer.init(params)
+             return params, opt_state
+
+         init_fn = jax.jit(_jax_wrapped_init_params_optimizer)
+         init_opt_fn = jax.jit(_jax_wrapped_init_optimizer)
+         return init_fn, init_opt_fn
+
+     def _jax_update(self, loss_fn):
+         optimizer = self.optimizer
+
+         # projected gradient trick to satisfy box constraints on params
+         def _jax_wrapped_project_params(params):
+             if self.wrap_non_bool:
+                 return params
+             else:
+                 new_params = {}
+                 for (name, value) in params.items():
+                     if self.rddl.variable_ranges[name] == 'bool':
+                         new_params[name] = value
+                     else:
+                         lower, upper = self.param_ranges[name]
+                         new_params[name] = jnp.clip(value, lower, upper)
+                 return new_params
+
+         # gradient descent update
+         def _jax_wrapped_params_update(key, params, subs, actions, next_fluents,
+                                        hyperparams, opt_state):
+             (loss_val, hyperparams), grad = jax.value_and_grad(
+                 loss_fn, argnums=1, has_aux=True
+             )(key, params, subs, actions, next_fluents, hyperparams)
+             updates, opt_state = optimizer.update(grad, opt_state)
+             params = optax.apply_updates(params, updates)
+             params = _jax_wrapped_project_params(params)
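+             # flag parameters whose gradient is numerically zero (reported as NO_PROGRESS)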
+             zero_grads = jax.tree_util.tree_map(partial(jnp.allclose, b=0.0), grad)
+             return params, opt_state, loss_val, zero_grads, hyperparams
+
+         update_fn = jax.jit(_jax_wrapped_params_update)
+         project_fn = jax.jit(_jax_wrapped_project_params)
+         return update_fn, project_fn
+
+     def _batched_init_subs(self):
+         init_train = {}
+         for (name, value) in self.compiled.init_values.items():
+             value = np.reshape(value, np.shape(value))[np.newaxis, ...]
+             value = np.repeat(value, repeats=self.batch_size_train, axis=0)
+             value = np.asarray(value, dtype=self.compiled.REAL)
+             init_train[name] = value
+         for (state, next_state) in self.rddl.next_state.items():
+             init_train[next_state] = init_train[state]
+         return init_train
+
+     # ===========================================================================
+     # ESTIMATE API
+     # ===========================================================================
+
+     def optimize(self, *args, **kwargs) -> Optional[Callback]:
+         '''Estimates the unknown parameters from the given data set and
+         returns the final callback from training.
+
+         :param data: a data stream represented as a (possibly infinite) sequence of
+         transition batches of the form (states, actions, next-states), where each element
+         is a numpy array of leading dimension equal to batch_size_train
+         :param key: JAX PRNG key (derived from clock if not provided)
+         :param epochs: the maximum number of steps of gradient descent
+         :param train_seconds: total time allocated for gradient descent
+         :param guess: initial non-fluent parameters; if None, uses the initializer
+         specified in this instance
+         :param print_progress: whether to print the progress bar during training
+         '''
+         it = self.optimize_generator(*args, **kwargs)
+
+         # https://stackoverflow.com/questions/50937966/fastest-most-pythonic-way-to-consume-an-iterator
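+         # exhaust the generator, keeping only the last yielded callback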
+         callback = None
+         if sys.implementation.name == 'cpython':
+             last_callback = deque(it, maxlen=1)
+             if last_callback:
+                 callback = last_callback.pop()
+         else:
+             for callback in it:
+                 pass
+         return callback
+
+     def optimize_generator(self, data: DataStream,
+                            key: Optional[random.PRNGKey]=None,
+                            epochs: int=999999,
+                            train_seconds: float=120.,
+                            guess: Optional[Params]=None,
+                            print_progress: bool=True) -> Generator[Callback, None, None]:
+         '''Returns a generator for estimating the unknown parameters from the given data set.
+         The generator can be iterated over to lazily estimate the parameters, yielding
+         a dictionary of intermediate computations.
+
+         :param data: a data stream represented as a (possibly infinite) sequence of
+         transition batches of the form (states, actions, next-states), where each element
+         is a numpy array of leading dimension equal to batch_size_train
+         :param key: JAX PRNG key (derived from clock if not provided)
+         :param epochs: the maximum number of steps of gradient descent
+         :param train_seconds: total time allocated for gradient descent
+         :param guess: initial non-fluent parameters; if None, uses the initializer
+         specified in this instance
+         :param print_progress: whether to print the progress bar during training
+         '''
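+         # illustrative usage (learner and stream are hypothetical names):
+         #     for cb in learner.optimize_generator(stream, train_seconds=60.0):
+         #         print(cb['iteration'], cb['train_loss'])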
+         start_time = time.time()
+         elapsed_outside_loop = 0
+
+         # if PRNG key is not provided
+         if key is None:
+             key = random.PRNGKey(round(time.time() * 1000))
+
+         # prepare initial subs
+         subs = self._batched_init_subs()
+
+         # initialize parameter fluents to optimize
+         if guess is None:
+             key, subkey = random.split(key)
+             params, opt_state = self.init_fn(subkey)
+         else:
+             params, opt_state = self.init_opt_fn(guess)
+
+         # initialize model hyper-parameters
+         hyperparams = self.compiled.model_params
+
+         # progress bar
+         if print_progress:
+             progress_bar = tqdm(
+                 None, total=100, bar_format='{l_bar}{bar}| {elapsed} {postfix}')
+         else:
+             progress_bar = None
+
+         # main training loop
+         for (it, (states, actions, next_states)) in enumerate(data):
+             status = JaxLearnerStatus.NORMAL
+
+             # gradient update
+             subs.update(states)
+             key, subkey = random.split(key)
+             params, opt_state, loss, zero_grads, hyperparams = self.update_fn(
+                 subkey, params, subs, actions, next_states, hyperparams, opt_state)
+
+             # extract non-fluent values from the trainable parameters
+             param_fluents = self.map_fn(params)
+             param_fluents = {name: param_fluents[name] for name in self.param_ranges}
+
+             # check for learnability
+             params_zero_grads = {
+                 name for (name, zero_grad) in zero_grads.items() if zero_grad}
+             if params_zero_grads:
+                 status = JaxLearnerStatus.NO_PROGRESS
+
+             # reached computation budget
+             elapsed = time.time() - start_time - elapsed_outside_loop
+             if elapsed >= train_seconds:
+                 status = JaxLearnerStatus.TIME_BUDGET_REACHED
+             if it >= epochs - 1:
+                 status = JaxLearnerStatus.ITER_BUDGET_REACHED
+
+             # build a callback
+             progress_percent = 100 * min(
+                 1, max(0, elapsed / train_seconds, it / (epochs - 1)))
+             callback = {
+                 'status': status,
+                 'iteration': it,
+                 'train_loss': loss,
+                 'params': params,
+                 'param_fluents': param_fluents,
+                 'key': key,
+                 'progress': progress_percent
+             }
+
+             # update progress
+             if print_progress:
+                 progress_bar.set_description(
+                     f'{it:7} it / {loss:12.8f} train / {status.value} status', refresh=False)
+                 progress_bar.set_postfix_str(
+                     f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False)
+                 progress_bar.update(progress_percent - progress_bar.n)
+
+             # yield the callback
+             start_time_outside = time.time()
+             yield callback
+             elapsed_outside_loop += (time.time() - start_time_outside)
+
+             # termination check
+             if status.is_terminal():
+                 break
+
+     def evaluate_loss(self, data: DataStream,
+                       key: Optional[random.PRNGKey],
+                       param_fluents: Params) -> float:
+         '''Evaluates the model loss of the given learned non-fluent values on the data.
+
+         :param data: a finite data stream of transition batches of the form
+         (states, actions, next-states), where each element is a numpy array of
+         leading dimension equal to batch_size_train
+         :param key: JAX PRNG key (derived from clock if not provided)
+         :param param_fluents: the learned non-fluent values
+         '''
+         if key is None:
+             key = random.PRNGKey(round(time.time() * 1000))
+         subs = self._batched_init_subs()
+         hyperparams = self.compiled.model_params
+         mean_loss = 0.0
+         for (it, (states, actions, next_states)) in enumerate(data):
+             subs.update(states)
+             key, subkey = random.split(key)
+             loss_value, _ = self.loss_fn(
+                 subkey, param_fluents, subs, actions, next_states, hyperparams)
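+             # incremental (running) mean of the per-batch losses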
+             mean_loss += (loss_value - mean_loss) / (it + 1)
+         return mean_loss
+
+     def learned_model(self, param_fluents: Params) -> RDDLLiftedModel:
+         '''Substitutes the given learned non-fluent values into the RDDL model and returns
+         the new model.
+
+         :param param_fluents: the learned non-fluent values
+         '''
+         model = deepcopy(self.rddl)
+         for (name, values) in param_fluents.items():
+             prange = model.variable_ranges[name]
+             if prange == 'real':
+                 pass
+             elif prange == 'bool':
+                 values = values > 0.5
+             else:
+                 values = np.asarray(values, dtype=self.compiled.INT)
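+             # flatten array values into a flat list (C order) for storage in the model's non-fluents table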
+             values = np.ravel(values, order='C').tolist()
+             if not self.rddl.variable_params[name]:
+                 assert len(values) == 1
+                 values = values[0]
+             model.non_fluents[name] = values
+         return model
+
+
+ if __name__ == '__main__':
+     import os
+     import pyRDDLGym
+     from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOfflineController
+     bs = 32
+
+     # make some data
+     def data_iterator():
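+         # roll out the compiled CartPole dynamics to produce a synthetic stream of
+         # (states, actions, next-states) batches for training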
+         env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
+         model = JaxModelLearner(rddl=env.model, param_ranges={}, batch_size_train=bs)
+         key = random.PRNGKey(round(time.time() * 1000))
+         subs = model._batched_init_subs()
+         param_fluents = {}
+         while True:
+             states = {
+                 'pos': np.random.uniform(-2.4, 2.4, (bs,)),
+                 'vel': np.random.uniform(-2.4, 2.4, (bs,)),
+                 'ang-pos': np.random.uniform(-0.21, 0.21, (bs,)),
+                 'ang-vel': np.random.uniform(-0.21, 0.21, (bs,))
+             }
+             subs.update(states)
+             actions = {
+                 'force': np.random.uniform(-10., 10., (bs,))
+             }
+             key, subkey = random.split(key)
+             subs, _ = model.step_fn(subkey, param_fluents, subs, actions, {})
+             subs = {k: np.asarray(v)[0, ...] for k, v in subs.items()}
+             next_states = {k: subs[k] for k in model.rddl.state_fluents}
+             yield (states, actions, next_states)
+
+     # train it
+     env = pyRDDLGym.make('TestJax', '0', vectorized=True)
+     model_learner = JaxModelLearner(rddl=env.model,
+                                     param_ranges={
+                                         'w1': (None, None), 'b1': (None, None),
+                                         'w2': (None, None), 'b2': (None, None),
+                                         'w1o': (None, None), 'b1o': (None, None),
+                                         'w2o': (None, None), 'b2o': (None, None)
+                                     },
+                                     batch_size_train=bs,
+                                     optimizer_kwargs={'learning_rate': 0.0003})
+     for cb in model_learner.optimize_generator(data_iterator(), epochs=10000):
+         pass
+
+     # planning in the trained model
+     model = model_learner.learned_model(cb['param_fluents'])
+     abs_path = os.path.dirname(os.path.abspath(__file__))
+     config_path = os.path.join(os.path.dirname(abs_path), 'examples', 'configs', 'default_drp.cfg')
+     planner_args, _, train_args = load_config(config_path)
+     planner = JaxBackpropPlanner(rddl=model, **planner_args)
+     controller = JaxOfflineController(planner, **train_args)
+
+     # evaluation of the plan
+     test_env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
+     controller.evaluate(test_env, episodes=1, verbose=True, render=True)