pyRDDLGym-jax 2.5__py3-none-any.whl → 2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyRDDLGym_jax/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '2.5'
1
+ __version__ = '2.6'
@@ -237,7 +237,8 @@ class JaxRDDLCompiler:
237
237
 
238
238
  def compile_transition(self, check_constraints: bool=False,
239
239
  constraint_func: bool=False,
240
- init_params_constr: Dict[str, Any]={}) -> Callable:
240
+ init_params_constr: Dict[str, Any]={},
241
+ cache_path_info: bool=False) -> Callable:
241
242
  '''Compiles the current RDDL model into a JAX transition function that
242
243
  samples the next state.
243
244
 
@@ -274,6 +275,7 @@ class JaxRDDLCompiler:
274
275
  returned log and does not raise an exception
275
276
  :param constraint_func: produces the h(s, a) function described above
276
277
  in addition to the usual outputs
278
+ :param cache_path_info: whether to save full path traces as part of the log
277
279
  '''
278
280
  NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
279
281
  rddl = self.rddl
@@ -322,8 +324,11 @@ class JaxRDDLCompiler:
322
324
  errors |= err
323
325
 
324
326
  # calculate fluent values
325
- fluents = {name: values for (name, values) in subs.items()
326
- if name not in rddl.non_fluents}
327
+ if cache_path_info:
328
+ fluents = {name: values for (name, values) in subs.items()
329
+ if name not in rddl.non_fluents}
330
+ else:
331
+ fluents = {}
327
332
 
328
333
  # set the next state to the current state
329
334
  for (state, next_state) in rddl.next_state.items():
@@ -367,7 +372,9 @@ class JaxRDDLCompiler:
367
372
  n_batch: int,
368
373
  check_constraints: bool=False,
369
374
  constraint_func: bool=False,
370
- init_params_constr: Dict[str, Any]={}) -> Callable:
375
+ init_params_constr: Dict[str, Any]={},
376
+ model_params_reduction: Callable=lambda x: x[0],
377
+ cache_path_info: bool=False) -> Callable:
371
378
  '''Compiles the current RDDL model into a JAX transition function that
372
379
  samples trajectories with a fixed horizon from a policy.
373
380
 
@@ -399,10 +406,13 @@ class JaxRDDLCompiler:
399
406
  returned log and does not raise an exception
400
407
  :param constraint_func: produces the h(s, a) constraint function
401
408
  in addition to the usual outputs
409
+ :param model_params_reduction: how to aggregate updated model_params across runs
410
+ in the batch (defaults to selecting the first element's parameters in the batch)
411
+ :param cache_path_info: whether to save full path traces as part of the log
402
412
  '''
403
413
  rddl = self.rddl
404
414
  jax_step_fn = self.compile_transition(
405
- check_constraints, constraint_func, init_params_constr)
415
+ check_constraints, constraint_func, init_params_constr, cache_path_info)
406
416
 
407
417
  # for POMDP only observ-fluents are assumed visible to the policy
408
418
  if rddl.observ_fluents:
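The two new keyword arguments are related: `cache_path_info` controls whether per-step fluent values are kept in the rollout log (the transition hunk above builds an empty `fluents` dict when it is off), and `model_params_reduction` replaces the hard-coded batch mean that the hunk further below removes. A minimal calling sketch, assuming `compiler` is an already-built `JaxRDDLCompiler` and that `my_policy` and `horizon` are placeholders supplied by the caller:

```python
# Minimal sketch (placeholders: compiler, my_policy, horizon). Passing
# partial(jnp.mean, axis=0) restores the pre-2.6 behaviour of averaging
# model_params over the batch; the new default keeps the first batch element.
# cache_path_info=False skips storing full fluent traces in the returned log.
from functools import partial
import jax.numpy as jnp

rollout_fn = compiler.compile_rollouts(
    policy=my_policy,
    n_steps=horizon,
    n_batch=32,
    model_params_reduction=partial(jnp.mean, axis=0),
    cache_path_info=False,
)
```

The planner hunks later in this diff enable `cache_path_info` for the training rollouts only when a preprocessor is attached, and leave it off for the test rollouts.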
@@ -421,7 +431,6 @@ class JaxRDDLCompiler:
421
431
  return jax_step_fn(subkey, actions, subs, model_params)
422
432
 
423
433
  # do a batched step update from the policy
424
- # TODO: come up with a better way to reduce the model_param batch dim
425
434
  def _jax_wrapped_batched_step_policy(carry, step):
426
435
  key, policy_params, hyperparams, subs, model_params = carry
427
436
  key, *subkeys = random.split(key, num=1 + n_batch)
@@ -430,7 +439,7 @@ class JaxRDDLCompiler:
430
439
  _jax_wrapped_single_step_policy,
431
440
  in_axes=(0, None, None, None, 0, None)
432
441
  )(keys, policy_params, hyperparams, step, subs, model_params)
433
- model_params = jax.tree_util.tree_map(partial(jnp.mean, axis=0), model_params)
442
+ model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
434
443
  carry = (key, policy_params, hyperparams, subs, model_params)
435
444
  return carry, log
436
445
 
@@ -1056,15 +1056,13 @@ class ExactLogic(Logic):
1056
1056
  def control_if(self, id, init_params):
1057
1057
  return self._jax_wrapped_calc_if_then_else_exact
1058
1058
 
1059
- @staticmethod
1060
- def _jax_wrapped_calc_switch_exact(pred, cases, params):
1061
- pred = pred[jnp.newaxis, ...]
1062
- sample = jnp.take_along_axis(cases, pred, axis=0)
1063
- assert sample.shape[0] == 1
1064
- return sample[0, ...], params
1065
-
1066
1059
  def control_switch(self, id, init_params):
1067
- return self._jax_wrapped_calc_switch_exact
1060
+ def _jax_wrapped_calc_switch_exact(pred, cases, params):
1061
+ pred = jnp.asarray(pred[jnp.newaxis, ...], dtype=self.INT)
1062
+ sample = jnp.take_along_axis(cases, pred, axis=0)
1063
+ assert sample.shape[0] == 1
1064
+ return sample[0, ...], params
1065
+ return _jax_wrapped_calc_switch_exact
1068
1066
 
1069
1067
  # ===========================================================================
1070
1068
  # random variables
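The switch handler is now built inside `control_switch` so that it can cast the predicate to the compiler's integer dtype (`self.INT`) before indexing; `jnp.take_along_axis` only accepts integer indices, so a boolean or float-typed predicate would otherwise be rejected. A toy illustration of the indexing pattern (not package code):

```python
# Toy illustration: `cases` stacks one row per branch and `pred` holds the
# branch index for each element, the same shape convention as above.
import jax.numpy as jnp

cases = jnp.stack([jnp.array([10., 20.]),    # branch 0
                   jnp.array([30., 40.])])   # branch 1
pred = jnp.array([1, 0], dtype=jnp.int32)[jnp.newaxis, ...]
print(jnp.take_along_axis(cases, pred, axis=0)[0])   # [30. 20.]
```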
@@ -0,0 +1,595 @@
1
+ from collections import deque
2
+ from copy import deepcopy
3
+ from enum import Enum
4
+ from functools import partial
5
+ import sys
6
+ import time
7
+ from tqdm import tqdm
8
+ from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple
9
+
10
+ import jax
11
+ import jax.nn.initializers as initializers
12
+ import jax.numpy as jnp
13
+ import jax.random as random
14
+ import numpy as np
15
+ import optax
16
+
17
+ from pyRDDLGym.core.compiler.model import RDDLLiftedModel
18
+
19
+ from pyRDDLGym_jax.core.logic import Logic, ExactLogic
20
+ from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
21
+
22
+ Kwargs = Dict[str, Any]
23
+ State = Dict[str, np.ndarray]
24
+ Action = Dict[str, np.ndarray]
25
+ DataStream = Iterable[Tuple[State, Action, State]]
26
+ Params = Dict[str, np.ndarray]
27
+ Callback = Dict[str, Any]
28
+ LossFunction = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
29
+
30
+
31
+ # ***********************************************************************
32
+ # ALL VERSIONS OF LOSS FUNCTIONS
33
+ #
34
+ # - loss functions based on specific likelihood assumptions (MSE, cross-entropy)
35
+ #
36
+ # ***********************************************************************
37
+
38
+
39
+ def mean_squared_error() -> LossFunction:
40
+ def _jax_wrapped_mse_loss(target, pred):
41
+ loss_values = jnp.square(target - pred)
42
+ return loss_values
43
+ return jax.jit(_jax_wrapped_mse_loss)
44
+
45
+
46
+ def binary_cross_entropy(eps: float=1e-6) -> LossFunction:
47
+ def _jax_wrapped_binary_cross_entropy_loss(target, pred):
48
+ pred = jnp.clip(pred, eps, 1.0 - eps)
49
+ log_pred = jnp.log(pred)
50
+ log_not_pred = jnp.log(1.0 - pred)
51
+ loss_values = -target * log_pred - (1.0 - target) * log_not_pred
52
+ return loss_values
53
+ return jax.jit(_jax_wrapped_binary_cross_entropy_loss)
54
+
55
+
56
+ def optax_loss(loss_fn: LossFunction, **kwargs) -> LossFunction:
57
+ def _jax_wrapped_optax_loss(target, pred):
58
+ loss_values = loss_fn(pred, target, **kwargs)
59
+ return loss_values
60
+ return jax.jit(_jax_wrapped_optax_loss)
61
+
62
+
63
+ # ***********************************************************************
64
+ # ALL VERSIONS OF JAX MODEL LEARNER
65
+ #
66
+ # - gradient based model learning
67
+ #
68
+ # ***********************************************************************
69
+
70
+
71
+ class JaxLearnerStatus(Enum):
72
+ '''Represents the status of a parameter update from the JAX model learner,
73
+ including whether the update resulted in a NaN gradient,
74
+ whether progress was made, whether the budget was reached, or other information that
75
+ can be used to monitor and act on the learner's progress.'''
76
+
77
+ NORMAL = 0
78
+ NO_PROGRESS = 1
79
+ INVALID_GRADIENT = 2
80
+ TIME_BUDGET_REACHED = 3
81
+ ITER_BUDGET_REACHED = 4
82
+
83
+ def is_terminal(self) -> bool:
84
+ return self.value >= 2
85
+
86
+
87
+ class JaxModelLearner:
88
+ '''A class for data-driven estimation of unknown parameters in a given RDDL MDP using
89
+ gradient descent.'''
90
+
91
+ def __init__(self, rddl: RDDLLiftedModel,
92
+ param_ranges: Dict[str, Tuple[Optional[np.ndarray], Optional[np.ndarray]]],
93
+ batch_size_train: int=32,
94
+ samples_per_datapoint: int=1,
95
+ optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
96
+ optimizer_kwargs: Optional[Kwargs]=None,
97
+ initializer: initializers.Initializer = initializers.normal(),
98
+ wrap_non_bool: bool=True,
99
+ use64bit: bool=False,
100
+ bool_fluent_loss: LossFunction=binary_cross_entropy(),
101
+ real_fluent_loss: LossFunction=mean_squared_error(),
102
+ int_fluent_loss: LossFunction=mean_squared_error(),
103
+ logic: Logic=ExactLogic(),
104
+ model_params_reduction: Callable=lambda x: x[0]) -> None:
105
+ '''Creates a new gradient-based algorithm for inferring unknown non-fluents
106
+ in an RDDL domain from a data set or stream coming from the real environment.
107
+
108
+ :param rddl: the RDDL domain to learn
109
+ :param param_ranges: the ranges of all learnable non-fluents
110
+ :param batch_size_train: how many transitions to compute per optimization
111
+ step
112
+ :param samples_per_datapoint: how many random samples to produce from the step
113
+ function per data point during training
114
+ :param optimizer: a factory for an optax SGD algorithm
115
+ :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
116
+ factory (e.g. the learning rate)
117
+ :param initializer: how to initialize non-fluents
118
+ :param wrap_non_bool: whether to wrap non-boolean trainable parameters to satisfy
119
+ required ranges as specified in param_ranges (use a projected gradient otherwise)
120
+ :param use64bit: whether to perform arithmetic in 64 bit
121
+ :param bool_fluent_loss: loss function to optimize for bool-valued fluents
122
+ :param real_fluent_loss: loss function to optimize for real-valued fluents
123
+ :param int_fluent_loss: loss function to optimize for int-valued fluents
124
+ :param logic: a subclass of Logic for mapping exact mathematical
125
+ operations to their differentiable counterparts
126
+ :param model_params_reduction: how to aggregate updated model_params across runs
127
+ in the batch (defaults to selecting the first element's parameters in the batch)
128
+ '''
129
+ self.rddl = rddl
130
+ self.param_ranges = param_ranges.copy()
131
+ self.batch_size_train = batch_size_train
132
+ self.samples_per_datapoint = samples_per_datapoint
133
+ if optimizer_kwargs is None:
134
+ optimizer_kwargs = {'learning_rate': 0.001}
135
+ self.optimizer_kwargs = optimizer_kwargs
136
+ self.initializer = initializer
137
+ self.wrap_non_bool = wrap_non_bool
138
+ self.use64bit = use64bit
139
+ self.bool_fluent_loss = bool_fluent_loss
140
+ self.real_fluent_loss = real_fluent_loss
141
+ self.int_fluent_loss = int_fluent_loss
142
+ self.logic = logic
143
+ self.model_params_reduction = model_params_reduction
144
+
145
+ # validate param_ranges
146
+ for (name, values) in param_ranges.items():
147
+ if name not in rddl.non_fluents:
148
+ raise ValueError(
149
+ f'param_ranges key <{name}> is not a valid non-fluent '
150
+ f'in the current rddl.')
151
+ if not isinstance(values, (tuple, list)):
152
+ raise ValueError(
153
+ f'param_ranges values with key <{name}> are neither a tuple nor a list.')
154
+ if len(values) != 2:
155
+ raise ValueError(
156
+ f'param_ranges values with key <{name}> must be of length 2, '
157
+ f'got length {len(values)}.')
158
+ lower, upper = values
159
+ if lower is not None and upper is not None and not np.all(lower <= upper):
160
+ raise ValueError(
161
+ f'param_ranges values with key <{name}> do not satisfy lower <= upper.')
162
+
163
+ # build the optimizer
164
+ optimizer = optimizer(**optimizer_kwargs)
165
+ pipeline = [optimizer]
166
+ self.optimizer = optax.chain(*pipeline)
167
+
168
+ # build the computation graph
169
+ self.step_fn = self._jax_compile_rddl()
170
+ self.map_fn = self._jax_map()
171
+ self.loss_fn = self._jax_loss(map_fn=self.map_fn, step_fn=self.step_fn)
172
+ self.update_fn, self.project_fn = self._jax_update(loss_fn=self.loss_fn)
173
+ self.init_fn, self.init_opt_fn = self._jax_init(project_fn=self.project_fn)
174
+
175
+ # ===========================================================================
176
+ # COMPILATION SUBROUTINES
177
+ # ===========================================================================
178
+
179
+ def _jax_compile_rddl(self):
180
+
181
+ # compile the RDDL model
182
+ self.compiled = JaxRDDLCompilerWithGrad(
183
+ rddl=self.rddl,
184
+ logic=self.logic,
185
+ use64bit=self.use64bit,
186
+ compile_non_fluent_exact=False,
187
+ print_warnings=True
188
+ )
189
+ self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
190
+
191
+ # compile the transition step function
192
+ step_fn = self.compiled.compile_transition()
193
+
194
+ def _jax_wrapped_step(key, param_fluents, subs, actions, hyperparams):
195
+ for (name, param) in param_fluents.items():
196
+ subs[name] = param
197
+ subs, _, hyperparams = step_fn(key, actions, subs, hyperparams)
198
+ return subs, hyperparams
199
+
200
+ # batched step function
201
+ def _jax_wrapped_batched_step(key, param_fluents, subs, actions, hyperparams):
202
+ keys = jnp.asarray(random.split(key, num=self.batch_size_train))
203
+ subs, hyperparams = jax.vmap(
204
+ _jax_wrapped_step, in_axes=(0, None, 0, 0, None)
205
+ )(keys, param_fluents, subs, actions, hyperparams)
206
+ hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
207
+ return subs, hyperparams
208
+
209
+ # batched step function with parallel samples per data point
210
+ def _jax_wrapped_batched_parallel_step(key, param_fluents, subs, actions, hyperparams):
211
+ keys = jnp.asarray(random.split(key, num=self.samples_per_datapoint))
212
+ subs, hyperparams = jax.vmap(
213
+ _jax_wrapped_batched_step, in_axes=(0, None, None, None, None)
214
+ )(keys, param_fluents, subs, actions, hyperparams)
215
+ hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
216
+ return subs, hyperparams
217
+
218
+ batched_step_fn = jax.jit(_jax_wrapped_batched_parallel_step)
219
+ return batched_step_fn
220
+
221
+ def _jax_map(self):
222
+
223
+ # compute case indices for bounding
224
+ case_indices = {}
225
+ if self.wrap_non_bool:
226
+ for (name, (lower, upper)) in self.param_ranges.items():
227
+ if lower is None: lower = -np.inf
228
+ if upper is None: upper = +np.inf
229
+ self.param_ranges[name] = (lower, upper)
230
+ case_indices[name] = (
231
+ 0 * (np.isfinite(lower) & np.isfinite(upper)) +
232
+ 1 * (np.isfinite(lower) & ~np.isfinite(upper)) +
233
+ 2 * (~np.isfinite(lower) & np.isfinite(upper)) +
234
+ 3 * (~np.isfinite(lower) & ~np.isfinite(upper))
235
+ )
236
+
237
+ # map trainable parameters to their non-fluent values
238
+ def _jax_wrapped_params_to_fluents(params):
239
+ param_fluents = {}
240
+ for (name, param) in params.items():
241
+ if self.rddl.variable_ranges[name] == 'bool':
242
+ param_fluents[name] = jax.nn.sigmoid(param)
243
+ else:
244
+ if self.wrap_non_bool:
245
+ lower, upper = self.param_ranges[name]
246
+ cases = [
247
+ lambda x: lower + (upper - lower) * jax.nn.sigmoid(x),
248
+ lambda x: lower + (jax.nn.elu(x) + 1.0),
249
+ lambda x: upper - (jax.nn.elu(-x) + 1.0),
250
+ lambda x: x
251
+ ]
252
+ indices = case_indices[name]
253
+ param_fluents[name] = jax.lax.switch(indices, cases, param)
254
+ else:
255
+ param_fluents[name] = param
256
+ return param_fluents
257
+
258
+ map_fn = jax.jit(_jax_wrapped_params_to_fluents)
259
+ return map_fn
260
+
261
+ def _jax_loss(self, map_fn, step_fn):
262
+
263
+ # use binary cross entropy for bool fluents
264
+ # mean squared error for continuous and integer fluents
265
+ def _jax_wrapped_batched_model_loss(key, param_fluents, subs, actions, next_fluents,
266
+ hyperparams):
267
+ next_subs, hyperparams = step_fn(key, param_fluents, subs, actions, hyperparams)
268
+ total_loss = 0.0
269
+ for (name, next_value) in next_fluents.items():
270
+ preds = jnp.asarray(next_subs[name], dtype=self.compiled.REAL)
271
+ targets = jnp.asarray(next_value, dtype=self.compiled.REAL)[jnp.newaxis, ...]
272
+ if self.rddl.variable_ranges[name] == 'bool':
273
+ loss_values = self.bool_fluent_loss(targets, preds)
274
+ elif self.rddl.variable_ranges[name] == 'real':
275
+ loss_values = self.real_fluent_loss(targets, preds)
276
+ else:
277
+ loss_values = self.int_fluent_loss(targets, preds)
278
+ total_loss += jnp.mean(loss_values) / len(next_fluents)
279
+ return total_loss, hyperparams
280
+
281
+ # loss with the parameters mapped to their fluents
282
+ def _jax_wrapped_batched_loss(key, params, subs, actions, next_fluents, hyperparams):
283
+ param_fluents = map_fn(params)
284
+ loss, hyperparams = _jax_wrapped_batched_model_loss(
285
+ key, param_fluents, subs, actions, next_fluents, hyperparams)
286
+ return loss, hyperparams
287
+
288
+ loss_fn = jax.jit(_jax_wrapped_batched_loss)
289
+ return loss_fn
290
+
291
+ def _jax_init(self, project_fn):
292
+ optimizer = self.optimizer
293
+
294
+ # initialize both the non-fluents and optimizer
295
+ def _jax_wrapped_init_params_optimizer(key):
296
+ params = {}
297
+ for name in self.param_ranges:
298
+ shape = jnp.shape(self.compiled.init_values[name])
299
+ key, subkey = random.split(key)
300
+ params[name] = self.initializer(subkey, shape, dtype=self.compiled.REAL)
301
+ params = project_fn(params)
302
+ opt_state = optimizer.init(params)
303
+ return params, opt_state
304
+
305
+ # initialize just the optimizer given the non-fluents
306
+ def _jax_wrapped_init_optimizer(params):
307
+ params = project_fn(params)
308
+ opt_state = optimizer.init(params)
309
+ return params, opt_state
310
+
311
+ init_fn = jax.jit(_jax_wrapped_init_params_optimizer)
312
+ init_opt_fn = jax.jit(_jax_wrapped_init_optimizer)
313
+ return init_fn, init_opt_fn
314
+
315
+ def _jax_update(self, loss_fn):
316
+ optimizer = self.optimizer
317
+
318
+ # projected gradient trick to satisfy box constraints on params
319
+ def _jax_wrapped_project_params(params):
320
+ if self.wrap_non_bool:
321
+ return params
322
+ else:
323
+ new_params = {}
324
+ for (name, value) in params.items():
325
+ if self.rddl.variable_ranges[name] == 'bool':
326
+ new_params[name] = value
327
+ else:
328
+ lower, upper = self.param_ranges[name]
329
+ new_params[name] = jnp.clip(value, lower, upper)
330
+ return new_params
331
+
332
+ # gradient descent update
333
+ def _jax_wrapped_params_update(key, params, subs, actions, next_fluents,
334
+ hyperparams, opt_state):
335
+ (loss_val, hyperparams), grad = jax.value_and_grad(
336
+ loss_fn, argnums=1, has_aux=True
337
+ )(key, params, subs, actions, next_fluents, hyperparams)
338
+ updates, opt_state = optimizer.update(grad, opt_state)
339
+ params = optax.apply_updates(params, updates)
340
+ params = _jax_wrapped_project_params(params)
341
+ zero_grads = jax.tree_util.tree_map(partial(jnp.allclose, b=0.0), grad)
342
+ return params, opt_state, loss_val, zero_grads, hyperparams
343
+
344
+ update_fn = jax.jit(_jax_wrapped_params_update)
345
+ project_fn = jax.jit(_jax_wrapped_project_params)
346
+ return update_fn, project_fn
347
+
348
+ def _batched_init_subs(self):
349
+ init_train = {}
350
+ for (name, value) in self.compiled.init_values.items():
351
+ value = np.reshape(value, np.shape(value))[np.newaxis, ...]
352
+ value = np.repeat(value, repeats=self.batch_size_train, axis=0)
353
+ value = np.asarray(value, dtype=self.compiled.REAL)
354
+ init_train[name] = value
355
+ for (state, next_state) in self.rddl.next_state.items():
356
+ init_train[next_state] = init_train[state]
357
+ return init_train
358
+
359
+ # ===========================================================================
360
+ # ESTIMATE API
361
+ # ===========================================================================
362
+
363
+ def optimize(self, *args, **kwargs) -> Optional[Callback]:
364
+ '''Estimate the unknown parameters from the given data set.
365
+ Return the callback from training.
366
+
367
+ :param data: a data stream represented as a (possibly infinite) sequence of
368
+ transition batches of the form (states, actions, next-states), where each element
369
+ is a numpy array of leading dimension equal to batch_size_train
370
+ :param key: JAX PRNG key (derived from clock if not provided)
371
+ :param epochs: the maximum number of steps of gradient descent
372
+ :param train_seconds: total time allocated for gradient descent
373
+ :param guess: initial non-fluent parameters; if None, will use the initializer
374
+ specified in this instance
375
+ :param print_progress: whether to print the progress bar during training
376
+ '''
377
+ it = self.optimize_generator(*args, **kwargs)
378
+
379
+ # https://stackoverflow.com/questions/50937966/fastest-most-pythonic-way-to-consume-an-iterator
380
+ callback = None
381
+ if sys.implementation.name == 'cpython':
382
+ last_callback = deque(it, maxlen=1)
383
+ if last_callback:
384
+ callback = last_callback.pop()
385
+ else:
386
+ for callback in it:
387
+ pass
388
+ return callback
389
+
390
+ def optimize_generator(self, data: DataStream,
391
+ key: Optional[random.PRNGKey]=None,
392
+ epochs: int=999999,
393
+ train_seconds: float=120.,
394
+ guess: Optional[Params]=None,
395
+ print_progress: bool=True) -> Generator[Callback, None, None]:
396
+ '''Return a generator for estimating the unknown parameters from the given data set.
397
+ Generator can be iterated over to lazily estimate the parameters, yielding
398
+ a dictionary of intermediate computations.
399
+
400
+ :param data: a data stream represented as a (possibly infinite) sequence of
401
+ transition batches of the form (states, actions, next-states), where each element
402
+ is a numpy array of leading dimension equal to batch_size_train
403
+ :param key: JAX PRNG key (derived from clock if not provided)
404
+ :param epochs: the maximum number of steps of gradient descent
405
+ :param train_seconds: total time allocated for gradient descent
406
+ :param guess: initial non-fluent parameters; if None, will use the initializer
407
+ specified in this instance
408
+ :param print_progress: whether to print the progress bar during training
409
+ '''
410
+ start_time = time.time()
411
+ elapsed_outside_loop = 0
412
+
413
+ # if PRNG key is not provided
414
+ if key is None:
415
+ key = random.PRNGKey(round(time.time() * 1000))
416
+
417
+ # prepare initial subs
418
+ subs = self._batched_init_subs()
419
+
420
+ # initialize parameter fluents to optimize
421
+ if guess is None:
422
+ key, subkey = random.split(key)
423
+ params, opt_state = self.init_fn(subkey)
424
+ else:
425
+ params, opt_state = self.init_opt_fn(guess)
426
+
427
+ # initialize model hyper-parameters
428
+ hyperparams = self.compiled.model_params
429
+
430
+ # progress bar
431
+ if print_progress:
432
+ progress_bar = tqdm(
433
+ None, total=100, bar_format='{l_bar}{bar}| {elapsed} {postfix}')
434
+ else:
435
+ progress_bar = None
436
+
437
+ # main training loop
438
+ for (it, (states, actions, next_states)) in enumerate(data):
439
+ status = JaxLearnerStatus.NORMAL
440
+
441
+ # gradient update
442
+ subs.update(states)
443
+ key, subkey = random.split(key)
444
+ params, opt_state, loss, zero_grads, hyperparams = self.update_fn(
445
+ subkey, params, subs, actions, next_states, hyperparams, opt_state)
446
+
447
+ # extract non-fluent values from the trainable parameters
448
+ param_fluents = self.map_fn(params)
449
+ param_fluents = {name: param_fluents[name] for name in self.param_ranges}
450
+
451
+ # check for learnability
452
+ params_zero_grads = {
453
+ name for (name, zero_grad) in zero_grads.items() if zero_grad}
454
+ if params_zero_grads:
455
+ status = JaxLearnerStatus.NO_PROGRESS
456
+
457
+ # reached computation budget
458
+ elapsed = time.time() - start_time - elapsed_outside_loop
459
+ if elapsed >= train_seconds:
460
+ status = JaxLearnerStatus.TIME_BUDGET_REACHED
461
+ if it >= epochs - 1:
462
+ status = JaxLearnerStatus.ITER_BUDGET_REACHED
463
+
464
+ # build a callback
465
+ progress_percent = 100 * min(
466
+ 1, max(0, elapsed / train_seconds, it / (epochs - 1)))
467
+ callback = {
468
+ 'status': status,
469
+ 'iteration': it,
470
+ 'train_loss': loss,
471
+ 'params': params,
472
+ 'param_fluents': param_fluents,
473
+ 'key': key,
474
+ 'progress': progress_percent
475
+ }
476
+
477
+ # update progress
478
+ if print_progress:
479
+ progress_bar.set_description(
480
+ f'{it:7} it / {loss:12.8f} train / {status.value} status', refresh=False)
481
+ progress_bar.set_postfix_str(
482
+ f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False)
483
+ progress_bar.update(progress_percent - progress_bar.n)
484
+
485
+ # yield the callback
486
+ start_time_outside = time.time()
487
+ yield callback
488
+ elapsed_outside_loop += (time.time() - start_time_outside)
489
+
490
+ # abortion check
491
+ if status.is_terminal():
492
+ break
493
+
494
+ def evaluate_loss(self, data: DataStream,
495
+ key: Optional[random.PRNGKey],
496
+ param_fluents: Params) -> float:
497
+ '''Evaluates the model loss for the given learned non-fluent values on the given data.
498
+
499
+ :param data: a data stream represented as a (possibly infinite) sequence of
500
+ transition batches of the form (states, actions, next-states), where each element
501
+ is a numpy array of leading dimension equal to batch_size_train
502
+ :param key: JAX PRNG key (derived from clock if not provided)
503
+ :param param_fluents: the learned non-fluent values
504
+ '''
505
+ if key is None:
506
+ key = random.PRNGKey(round(time.time() * 1000))
507
+ subs = self._batched_init_subs()
508
+ hyperparams = self.compiled.model_params
509
+ mean_loss = 0.0
510
+ for (it, (states, actions, next_states)) in enumerate(data):
511
+ subs.update(states)
512
+ key, subkey = random.split(key)
513
+ loss_value, _ = self.loss_fn(
514
+ subkey, param_fluents, subs, actions, next_states, hyperparams)
515
+ mean_loss += (loss_value - mean_loss) / (it + 1)
516
+ return mean_loss
517
+
518
+ def learned_model(self, param_fluents: Params) -> RDDLLiftedModel:
519
+ '''Substitutes the given learned non-fluent values into the RDDL model and returns
520
+ the new model.
521
+
522
+ :param param_fluents: the learned non-fluent values
523
+ '''
524
+ model = deepcopy(self.rddl)
525
+ for (name, values) in param_fluents.items():
526
+ prange = model.variable_ranges[name]
527
+ if prange == 'real':
528
+ pass
529
+ elif prange == 'bool':
530
+ values = values > 0.5
531
+ else:
532
+ values = np.asarray(values, dtype=self.compiled.INT)
533
+ values = np.ravel(values, order='C').tolist()
534
+ if not self.rddl.variable_params[name]:
535
+ assert(len(values) == 1)
536
+ values = values[0]
537
+ model.non_fluents[name] = values
538
+ return model
539
+
540
+
541
+ if __name__ == '__main__':
542
+ import os
543
+ import pyRDDLGym
544
+ from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOfflineController
545
+ bs = 32
546
+
547
+ # make some data
548
+ def data_iterator():
549
+ env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
550
+ model = JaxModelLearner(rddl=env.model, param_ranges={}, batch_size_train=bs)
551
+ key = random.PRNGKey(round(time.time() * 1000))
552
+ subs = model._batched_init_subs()
553
+ param_fluents = {}
554
+ while True:
555
+ states = {
556
+ 'pos': np.random.uniform(-2.4, 2.4, (bs,)),
557
+ 'vel': np.random.uniform(-2.4, 2.4, (bs,)),
558
+ 'ang-pos': np.random.uniform(-0.21, 0.21, (bs,)),
559
+ 'ang-vel': np.random.uniform(-0.21, 0.21, (bs,))
560
+ }
561
+ subs.update(states)
562
+ actions = {
563
+ 'force': np.random.uniform(-10., 10., (bs,))
564
+ }
565
+ key, subkey = random.split(key)
566
+ subs, _ = model.step_fn(subkey, param_fluents, subs, actions, {})
567
+ subs = {k: np.asarray(v)[0, ...] for k, v in subs.items()}
568
+ next_states = {k: subs[k] for k in model.rddl.state_fluents}
569
+ yield (states, actions, next_states)
570
+
571
+ # train it
572
+ env = pyRDDLGym.make('TestJax', '0', vectorized=True)
573
+ model_learner = JaxModelLearner(rddl=env.model,
574
+ param_ranges={
575
+ 'w1': (None, None), 'b1': (None, None),
576
+ 'w2': (None, None), 'b2': (None, None),
577
+ 'w1o': (None, None), 'b1o': (None, None),
578
+ 'w2o': (None, None), 'b2o': (None, None)
579
+ },
580
+ batch_size_train=bs,
581
+ optimizer_kwargs = {'learning_rate': 0.0003})
582
+ for cb in model_learner.optimize_generator(data_iterator(), epochs=10000):
583
+ pass
584
+
585
+ # planning in the trained model
586
+ model = model_learner.learned_model(cb['param_fluents'])
587
+ abs_path = os.path.dirname(os.path.abspath(__file__))
588
+ config_path = os.path.join(os.path.dirname(abs_path), 'examples', 'configs', 'default_drp.cfg')
589
+ planner_args, _, train_args = load_config(config_path)
590
+ planner = JaxBackpropPlanner(rddl=model, **planner_args)
591
+ controller = JaxOfflineController(planner, **train_args)
592
+
593
+ # evaluation of the plan
594
+ test_env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
595
+ controller.evaluate(test_env, episodes=1, verbose=True, render=True)
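This new file appears to be `pyRDDLGym_jax/core/model.py` (the only new entry in the RECORD at the bottom of this diff). Beyond the CartPole demo in the `__main__` block above, the basic workflow is: declare which non-fluents to learn, stream `(states, actions, next_states)` batches, and substitute the estimates back into the model. A condensed sketch under that assumption; the domain name, the `MAX-SPEED` non-fluent, and the `transition_batches` iterable are illustrative placeholders:

```python
# Condensed sketch (placeholders: 'SomeDomain', 'MAX-SPEED', transition_batches).
# Any optax loss can be wrapped with optax_loss(); here a Huber loss replaces
# the default mean squared error for real-valued fluents.
import optax
import pyRDDLGym
from pyRDDLGym_jax.core.model import JaxModelLearner, optax_loss

env = pyRDDLGym.make('SomeDomain', '0', vectorized=True)
learner = JaxModelLearner(
    rddl=env.model,
    param_ranges={'MAX-SPEED': (0.0, None)},   # non-fluent to learn, lower-bounded at 0
    batch_size_train=32,
    optimizer_kwargs={'learning_rate': 0.001},
    real_fluent_loss=optax_loss(optax.huber_loss, delta=1.0),
)

# transition_batches must yield (states, actions, next_states) dicts of arrays
# whose leading dimension equals batch_size_train.
callback = learner.optimize(transition_batches, epochs=5000, train_seconds=60.0)
learned_rddl = learner.learned_model(callback['param_fluents'])
```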
@@ -207,6 +207,13 @@ def _load_config(config, args):
207
207
  pgpe_kwargs['optimizer'] = pgpe_optimizer
208
208
  planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
209
209
 
210
+ # preprocessor settings
211
+ preproc_method = planner_args.get('preprocessor', None)
212
+ preproc_kwargs = planner_args.pop('preprocessor_kwargs', {})
213
+ if preproc_method is not None:
214
+ planner_args['preprocessor'] = getattr(
215
+ sys.modules[__name__], preproc_method)(**preproc_kwargs)
216
+
210
217
  # optimize call RNG key
211
218
  planner_key = train_args.get('key', None)
212
219
  if planner_key is not None:
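The new keys follow the same pattern as the `pgpe` settings directly above: the value of `preprocessor` is looked up by name in the planner module and instantiated with `preprocessor_kwargs`. A hypothetical config fragment (the section name and the `pos` fluent are assumptions, modelled on the shipped example configs):

```ini
[Optimizer]
...
preprocessor='StaticNormalizer'
preprocessor_kwargs={'fluent_bounds': {'pos': (-2.4, 2.4)}}
...
```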
@@ -343,6 +350,100 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
343
350
  return arg
344
351
 
345
352
 
353
+ # ***********************************************************************
354
+ # ALL VERSIONS OF STATE PREPROCESSING FOR DRP
355
+ #
356
+ # - static normalization
357
+ #
358
+ # ***********************************************************************
359
+
360
+
361
+ class Preprocessor(metaclass=ABCMeta):
362
+ '''Base class for all state preprocessors.'''
363
+
364
+ HYPERPARAMS_KEY = 'preprocessor__'
365
+
366
+ def __init__(self) -> None:
367
+ self._initializer = None
368
+ self._update = None
369
+ self._transform = None
370
+
371
+ @property
372
+ def initialize(self):
373
+ return self._initializer
374
+
375
+ @property
376
+ def update(self):
377
+ return self._update
378
+
379
+ @property
380
+ def transform(self):
381
+ return self._transform
382
+
383
+ @abstractmethod
384
+ def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
385
+ pass
386
+
387
+
388
+ class StaticNormalizer(Preprocessor):
389
+ '''Normalize values by box constraints on fluents computed from the RDDL domain.'''
390
+
391
+ def __init__(self, fluent_bounds: Dict[str, Tuple[np.ndarray, np.ndarray]]={}) -> None:
392
+ '''Create a new instance of the static normalizer.
393
+
394
+ :param fluent_bounds: optional bounds on fluents to overwrite default values.
395
+ '''
396
+ self.fluent_bounds = fluent_bounds
397
+
398
+ def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
399
+
400
+ # adjust for partial observability
401
+ rddl = compiled.rddl
402
+ if rddl.observ_fluents:
403
+ observed_vars = rddl.observ_fluents
404
+ else:
405
+ observed_vars = rddl.state_fluents
406
+
407
+ # ignore boolean fluents and infinite bounds
408
+ bounded_vars = {}
409
+ for var in observed_vars:
410
+ if rddl.variable_ranges[var] != 'bool':
411
+ lower, upper = compiled.constraints.bounds[var]
412
+ if np.all(np.isfinite(lower) & np.isfinite(upper) & (lower < upper)):
413
+ bounded_vars[var] = (lower, upper)
414
+ user_bounds = self.fluent_bounds.get(var, None)
415
+ if user_bounds is not None:
416
+ bounded_vars[var] = tuple(user_bounds)
417
+
418
+ # initialize to ranges computed by the constraint parser
419
+ def _jax_wrapped_normalizer_init():
420
+ return bounded_vars
421
+ self._initializer = jax.jit(_jax_wrapped_normalizer_init)
422
+
423
+ # static bounds
424
+ def _jax_wrapped_normalizer_update(subs, stats):
425
+ stats = {var: (jnp.asarray(lower, dtype=compiled.REAL),
426
+ jnp.asarray(upper, dtype=compiled.REAL))
427
+ for (var, (lower, upper)) in bounded_vars.items()}
428
+ return stats
429
+ self._update = jax.jit(_jax_wrapped_normalizer_update)
430
+
431
+ # apply min max scaling
432
+ def _jax_wrapped_normalizer_transform(subs, stats):
433
+ new_subs = {}
434
+ for (var, values) in subs.items():
435
+ if var in stats:
436
+ lower, upper = stats[var]
437
+ new_dims = jnp.ndim(values) - jnp.ndim(lower)
438
+ lower = lower[(jnp.newaxis,) * new_dims + (...,)]
439
+ upper = upper[(jnp.newaxis,) * new_dims + (...,)]
440
+ new_subs[var] = (values - lower) / (upper - lower)
441
+ else:
442
+ new_subs[var] = values
443
+ return new_subs
444
+ self._transform = jax.jit(_jax_wrapped_normalizer_transform)
445
+
446
+
346
447
  # ***********************************************************************
347
448
  # ALL VERSIONS OF JAX PLANS
348
449
  #
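In the hunks that follow, only the deep reactive policy's input path actually consumes the preprocessor (the straight-line plan accepts the new argument but is not shown using it), and the normalizer statistics are carried in the policy hyper-parameters. A minimal construction sketch, assuming `env` comes from `pyRDDLGym.make(...)`; the `pos` bound and the network topology are illustrative:

```python
# Minimal sketch: min-max scale the DRP inputs using bounds from the constraint
# parser, overriding the bound for one fluent by hand ('pos' is illustrative).
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxDeepReactivePolicy, StaticNormalizer)

planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxDeepReactivePolicy(topology=[64, 64]),
    preprocessor=StaticNormalizer(fluent_bounds={'pos': (-2.4, 2.4)}),
)
```

As the later hunks show, the updated statistics are returned in the optimization callback under `policy_hyperparams`, and the offline/online controllers now adopt them automatically when no `eval_hyperparams` are supplied.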
@@ -368,7 +469,8 @@ class JaxPlan(metaclass=ABCMeta):
368
469
  @abstractmethod
369
470
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
370
471
  _bounds: Bounds,
371
- horizon: int) -> None:
472
+ horizon: int,
473
+ preprocessor: Optional[Preprocessor]=None) -> None:
372
474
  pass
373
475
 
374
476
  @abstractmethod
@@ -519,7 +621,8 @@ class JaxStraightLinePlan(JaxPlan):
519
621
 
520
622
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
521
623
  _bounds: Bounds,
522
- horizon: int) -> None:
624
+ horizon: int,
625
+ preprocessor: Optional[Preprocessor]=None) -> None:
523
626
  rddl = compiled.rddl
524
627
 
525
628
  # calculate the correct action box bounds
@@ -607,7 +710,7 @@ class JaxStraightLinePlan(JaxPlan):
607
710
  return new_params, True
608
711
 
609
712
  # convert softmax action back to action dict
610
- action_sizes = {var: np.prod(shape[1:], dtype=int)
713
+ action_sizes = {var: np.prod(shape[1:], dtype=np.int64)
611
714
  for (var, shape) in shapes.items()
612
715
  if ranges[var] == 'bool'}
613
716
 
@@ -691,7 +794,7 @@ class JaxStraightLinePlan(JaxPlan):
691
794
  scores = []
692
795
  for (var, param) in params.items():
693
796
  if ranges[var] == 'bool':
694
- param_flat = jnp.ravel(param)
797
+ param_flat = jnp.ravel(param, order='C')
695
798
  if noop[var]:
696
799
  if wrap_sigmoid:
697
800
  param_flat = -param_flat
@@ -908,7 +1011,8 @@ class JaxDeepReactivePolicy(JaxPlan):
908
1011
 
909
1012
  def compile(self, compiled: JaxRDDLCompilerWithGrad,
910
1013
  _bounds: Bounds,
911
- horizon: int) -> None:
1014
+ horizon: int,
1015
+ preprocessor: Optional[Preprocessor]=None) -> None:
912
1016
  rddl = compiled.rddl
913
1017
 
914
1018
  # calculate the correct action box bounds
@@ -939,7 +1043,7 @@ class JaxDeepReactivePolicy(JaxPlan):
939
1043
  wrap_non_bool = self._wrap_non_bool
940
1044
  init = self._initializer
941
1045
  layers = list(enumerate(zip(self._topology, self._activations)))
942
- layer_sizes = {var: np.prod(shape, dtype=int)
1046
+ layer_sizes = {var: np.prod(shape, dtype=np.int64)
943
1047
  for (var, shape) in shapes.items()}
944
1048
  layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}
945
1049
 
@@ -973,7 +1077,12 @@ class JaxDeepReactivePolicy(JaxPlan):
973
1077
  normalize = False
974
1078
 
975
1079
  # convert subs dictionary into a state vector to feed to the MLP
976
- def _jax_wrapped_policy_input(subs):
1080
+ def _jax_wrapped_policy_input(subs, hyperparams):
1081
+
1082
+ # optional state preprocessing
1083
+ if preprocessor is not None:
1084
+ stats = hyperparams[preprocessor.HYPERPARAMS_KEY]
1085
+ subs = preprocessor.transform(subs, stats)
977
1086
 
978
1087
  # concatenate all state variables into a single vector
979
1088
  # optionally apply layer norm to each input tensor
@@ -981,7 +1090,7 @@ class JaxDeepReactivePolicy(JaxPlan):
981
1090
  non_bool_dims = 0
982
1091
  for (var, value) in subs.items():
983
1092
  if var in observed_vars:
984
- state = jnp.ravel(value)
1093
+ state = jnp.ravel(value, order='C')
985
1094
  if ranges[var] == 'bool':
986
1095
  states_bool.append(state)
987
1096
  else:
@@ -1010,8 +1119,8 @@ class JaxDeepReactivePolicy(JaxPlan):
1010
1119
  return state
1011
1120
 
1012
1121
  # predict actions from the policy network for current state
1013
- def _jax_wrapped_policy_network_predict(subs):
1014
- state = _jax_wrapped_policy_input(subs)
1122
+ def _jax_wrapped_policy_network_predict(subs, hyperparams):
1123
+ state = _jax_wrapped_policy_input(subs, hyperparams)
1015
1124
 
1016
1125
  # feed state vector through hidden layers
1017
1126
  hidden = state
@@ -1076,7 +1185,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1076
1185
 
1077
1186
  # train action prediction
1078
1187
  def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
1079
- actions = predict_fn.apply(params, subs)
1188
+ actions = predict_fn.apply(params, subs, hyperparams)
1080
1189
  if not wrap_non_bool:
1081
1190
  for (var, action) in actions.items():
1082
1191
  if var != bool_key and ranges[var] != 'bool':
@@ -1126,7 +1235,7 @@ class JaxDeepReactivePolicy(JaxPlan):
1126
1235
  subs = {var: value[0, ...]
1127
1236
  for (var, value) in subs.items()
1128
1237
  if var in observed_vars}
1129
- params = predict_fn.init(key, subs)
1238
+ params = predict_fn.init(key, subs, hyperparams)
1130
1239
  return params
1131
1240
 
1132
1241
  self.initializer = _jax_wrapped_drp_init
@@ -1634,12 +1743,21 @@ def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
1634
1743
  return mu - 0.5 * beta * msv
1635
1744
 
1636
1745
 
1746
+ @jax.jit
1747
+ def sharpe_utility(returns: jnp.ndarray, risk_free: float) -> float:
1748
+ return (jnp.mean(returns) - risk_free) / (jnp.std(returns) + 1e-10)
1749
+
1750
+
1751
+ @jax.jit
1752
+ def var_utility(returns: jnp.ndarray, alpha: float) -> float:
1753
+ return jnp.percentile(returns, q=100 * alpha)
1754
+
1755
+
1637
1756
  @jax.jit
1638
1757
  def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
1639
1758
  var = jnp.percentile(returns, q=100 * alpha)
1640
1759
  mask = returns <= var
1641
- weights = mask / jnp.maximum(1, jnp.sum(mask))
1642
- return jnp.sum(returns * weights)
1760
+ return jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))
1643
1761
 
1644
1762
 
1645
1763
  # set of all currently valid built-in utility functions
@@ -1649,8 +1767,10 @@ UTILITY_LOOKUP = {
1649
1767
  'mean_std': mean_deviation_utility,
1650
1768
  'mean_semivar': mean_semivariance_utility,
1651
1769
  'mean_semidev': mean_semideviation_utility,
1770
+ 'sharpe': sharpe_utility,
1652
1771
  'entropic': entropic_utility,
1653
1772
  'exponential': entropic_utility,
1773
+ 'var': var_utility,
1654
1774
  'cvar': cvar_utility
1655
1775
  }
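The two additions are plain functions of the batched returns and are registered above under the `'sharpe'` and `'var'` keys, so they should be selectable the same way as the existing utilities. A quick sanity check of `sharpe_utility` with illustrative numbers:

```python
# Illustrative values only: mean = 2.5, std = sqrt(1.25) ~ 1.118, so the
# Sharpe utility with a zero risk-free rate is about 2.236.
import jax.numpy as jnp

returns = jnp.array([1.0, 2.0, 3.0, 4.0])
print(sharpe_utility(returns, risk_free=0.0))   # ~2.2360
```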
1656
1776
 
@@ -1689,7 +1809,8 @@ class JaxBackpropPlanner:
1689
1809
  logger: Optional[Logger]=None,
1690
1810
  dashboard_viz: Optional[Any]=None,
1691
1811
  print_warnings: bool=True,
1692
- parallel_updates: Optional[int]=None) -> None:
1812
+ parallel_updates: Optional[int]=None,
1813
+ preprocessor: Optional[Preprocessor]=None) -> None:
1693
1814
  '''Creates a new gradient-based algorithm for optimizing action sequences
1694
1815
  (plan) in the given RDDL. Some operations will be converted to their
1695
1816
  differentiable counterparts; the specific operations can be customized
@@ -1731,6 +1852,7 @@ class JaxBackpropPlanner:
1731
1852
  to pass to the dashboard to visualize the policy
1732
1853
  :param print_warnings: whether to print warnings
1733
1854
  :param parallel_updates: how many optimizers to run independently in parallel
1855
+ :param preprocessor: optional preprocessor for state inputs to plan
1734
1856
  '''
1735
1857
  self.rddl = rddl
1736
1858
  self.plan = plan
@@ -1756,6 +1878,7 @@ class JaxBackpropPlanner:
1756
1878
  self.pgpe = pgpe
1757
1879
  self.use_pgpe = pgpe is not None
1758
1880
  self.print_warnings = print_warnings
1881
+ self.preprocessor = preprocessor
1759
1882
 
1760
1883
  # set optimizer
1761
1884
  try:
@@ -1881,7 +2004,8 @@ r"""
1881
2004
  f' noise_kwargs ={self.noise_kwargs}\n'
1882
2005
  f' batch_size_train ={self.batch_size_train}\n'
1883
2006
  f' batch_size_test ={self.batch_size_test}\n'
1884
- f' parallel_updates ={self.parallel_updates}\n')
2007
+ f' parallel_updates ={self.parallel_updates}\n'
2008
+ f' preprocessor ={self.preprocessor}\n')
1885
2009
  result += str(self.plan)
1886
2010
  if self.use_pgpe:
1887
2011
  result += str(self.pgpe)
@@ -1917,10 +2041,15 @@ r"""
1917
2041
 
1918
2042
  def _jax_compile_optimizer(self):
1919
2043
 
2044
+ # preprocessor
2045
+ if self.preprocessor is not None:
2046
+ self.preprocessor.compile(self.compiled)
2047
+
1920
2048
  # policy
1921
2049
  self.plan.compile(self.compiled,
1922
2050
  _bounds=self._action_bounds,
1923
- horizon=self.horizon)
2051
+ horizon=self.horizon,
2052
+ preprocessor=self.preprocessor)
1924
2053
  self.train_policy = jax.jit(self.plan.train_policy)
1925
2054
  self.test_policy = jax.jit(self.plan.test_policy)
1926
2055
 
@@ -1928,14 +2057,16 @@ r"""
1928
2057
  train_rollouts = self.compiled.compile_rollouts(
1929
2058
  policy=self.plan.train_policy,
1930
2059
  n_steps=self.horizon,
1931
- n_batch=self.batch_size_train
2060
+ n_batch=self.batch_size_train,
2061
+ cache_path_info=self.preprocessor is not None
1932
2062
  )
1933
2063
  self.train_rollouts = train_rollouts
1934
2064
 
1935
2065
  test_rollouts = self.test_compiled.compile_rollouts(
1936
2066
  policy=self.plan.test_policy,
1937
2067
  n_steps=self.horizon,
1938
- n_batch=self.batch_size_test
2068
+ n_batch=self.batch_size_test,
2069
+ cache_path_info=False
1939
2070
  )
1940
2071
  self.test_rollouts = jax.jit(test_rollouts)
1941
2072
 
@@ -2397,7 +2528,13 @@ r"""
2397
2528
  f'which could be suboptimal.', 'yellow')
2398
2529
  print(message)
2399
2530
  policy_hyperparams[action] = 1.0
2400
-
2531
+
2532
+ # initialize preprocessor
2533
+ preproc_key = None
2534
+ if self.preprocessor is not None:
2535
+ preproc_key = self.preprocessor.HYPERPARAMS_KEY
2536
+ policy_hyperparams[preproc_key] = self.preprocessor.initialize()
2537
+
2401
2538
  # print summary of parameters:
2402
2539
  if print_summary:
2403
2540
  print(self.summarize_system())
@@ -2524,6 +2661,11 @@ r"""
2524
2661
  subkey, policy_params, policy_hyperparams, train_subs, model_params,
2525
2662
  opt_state, opt_aux)
2526
2663
 
2664
+ # update the preprocessor
2665
+ if self.preprocessor is not None:
2666
+ policy_hyperparams[preproc_key] = self.preprocessor.update(
2667
+ train_log['fluents'], policy_hyperparams[preproc_key])
2668
+
2527
2669
  # evaluate
2528
2670
  test_loss, (test_log, model_params_test) = self.test_loss(
2529
2671
  subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
@@ -2676,6 +2818,7 @@ r"""
2676
2818
  'model_params': model_params,
2677
2819
  'progress': progress_percent,
2678
2820
  'train_log': train_log,
2821
+ 'policy_hyperparams': policy_hyperparams,
2679
2822
  **test_log
2680
2823
  }
2681
2824
 
@@ -2753,7 +2896,8 @@ r"""
2753
2896
 
2754
2897
  def _perform_diagnosis(self, last_iter_improve,
2755
2898
  train_return, test_return, best_return, grad_norm):
2756
- max_grad_norm = max(jax.tree_util.tree_leaves(grad_norm))
2899
+ grad_norms = jax.tree_util.tree_leaves(grad_norm)
2900
+ max_grad_norm = max(grad_norms) if grad_norms else np.nan
2757
2901
  grad_is_zero = np.allclose(max_grad_norm, 0)
2758
2902
 
2759
2903
  # divergence if the solution is not finite
@@ -2895,6 +3039,7 @@ class JaxOfflineController(BaseAgent):
2895
3039
  self.train_on_reset = train_on_reset
2896
3040
  self.train_kwargs = train_kwargs
2897
3041
  self.params_given = params is not None
3042
+ self.hyperparams_given = eval_hyperparams is not None
2898
3043
 
2899
3044
  # load the policy from file
2900
3045
  if not self.train_on_reset and params is not None and isinstance(params, str):
@@ -2908,6 +3053,8 @@ class JaxOfflineController(BaseAgent):
2908
3053
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2909
3054
  self.callback = callback
2910
3055
  params = callback['best_params']
3056
+ if not self.hyperparams_given:
3057
+ self.eval_hyperparams = callback['policy_hyperparams']
2911
3058
 
2912
3059
  # save the policy
2913
3060
  if save_path is not None:
@@ -2931,6 +3078,8 @@ class JaxOfflineController(BaseAgent):
2931
3078
  callback = self.planner.optimize(key=self.key, **self.train_kwargs)
2932
3079
  self.callback = callback
2933
3080
  self.params = callback['best_params']
3081
+ if not self.hyperparams_given:
3082
+ self.eval_hyperparams = callback['policy_hyperparams']
2934
3083
 
2935
3084
 
2936
3085
  class JaxOnlineController(BaseAgent):
@@ -2963,6 +3112,7 @@ class JaxOnlineController(BaseAgent):
2963
3112
  key = random.PRNGKey(round(time.time() * 1000))
2964
3113
  self.key = key
2965
3114
  self.eval_hyperparams = eval_hyperparams
3115
+ self.hyperparams_given = eval_hyperparams is not None
2966
3116
  self.warm_start = warm_start
2967
3117
  self.train_kwargs = train_kwargs
2968
3118
  self.max_attempts = max_attempts
@@ -2987,6 +3137,8 @@ class JaxOnlineController(BaseAgent):
2987
3137
  key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
2988
3138
  self.callback = callback
2989
3139
  params = callback['best_params']
3140
+ if not self.hyperparams_given:
3141
+ self.eval_hyperparams = callback['policy_hyperparams']
2990
3142
 
2991
3143
  # get the action from the parameters for the current state
2992
3144
  self.key, subkey = random.split(self.key)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyRDDLGym-jax
3
- Version: 2.5
3
+ Version: 2.6
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
20
  Requires-Python: >=3.9
21
21
  Description-Content-Type: text/markdown
22
22
  License-File: LICENSE
23
- Requires-Dist: pyRDDLGym>=2.0
23
+ Requires-Dist: pyRDDLGym>=2.3
24
24
  Requires-Dist: tqdm>=4.66
25
25
  Requires-Dist: jax>=0.4.12
26
26
  Requires-Dist: optax>=0.1.9
@@ -55,7 +55,7 @@ Dynamic: summary
55
55
 
56
56
  [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
57
57
 
58
- **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
58
+ **pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**
59
59
 
60
60
  Purpose:
61
61
 
@@ -84,7 +84,7 @@ and was moved to the individual logic components which have their own unique wei
84
84
 
85
85
  > [!NOTE]
86
86
  > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
87
- > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
87
+ > If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
88
88
 
89
89
  ## Installation
90
90
 
@@ -220,13 +220,7 @@ controller = JaxOfflineController(planner, **train_args)
220
220
  ## JaxPlan Dashboard
221
221
 
222
222
  Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
223
- and visualization of the policy or model, and other useful debugging features.
224
-
225
- <p align="middle">
226
- <img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
227
- </p>
228
-
229
- To run the dashboard, add the following entry to your config file:
223
+ visualizing the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:
230
224
 
231
225
  ```ini
232
226
  ...
@@ -235,8 +229,6 @@ dashboard=True
235
229
  ...
236
230
  ```
237
231
 
238
- More documentation about this and other new features will be coming soon.
239
-
240
232
  ## Tuning the Planner
241
233
 
242
234
  A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
@@ -1,9 +1,10 @@
1
- pyRDDLGym_jax/__init__.py,sha256=VoxLo_sy8RlJIIyu7szqL-cdMGBJdQPg-aSeyOVVIkY,19
1
+ pyRDDLGym_jax/__init__.py,sha256=VUmQViJtwUg1JGcgXlmNm0fE3Njyruyt_76c16R-LTo,19
2
2
  pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
3
3
  pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- pyRDDLGym_jax/core/compiler.py,sha256=uFCtoipsIa3MM9nGgT3X8iCViPl2XSPNXh0jMdzN0ko,82895
5
- pyRDDLGym_jax/core/logic.py,sha256=lfc2ak_ap_ajMEFlB5EHCRNgJym31dNyA-5d-7N4CZA,56271
6
- pyRDDLGym_jax/core/planner.py,sha256=M6GKzN7Ml57B4ZrFZhhkpsQCvReKaCQNzer7zeHCM9E,140275
4
+ pyRDDLGym_jax/core/compiler.py,sha256=Bpgfw4nqRFqiTju7ioR0B0Dhp3wMvk-9LmTRpMmLIOc,83457
5
+ pyRDDLGym_jax/core/logic.py,sha256=9rRpKJCx4Us_2c6BiSWRN9k2sM_iYsAK1B7zcgwu3ZA,56290
6
+ pyRDDLGym_jax/core/model.py,sha256=4WfmtUVN1EKCD-7eWeQByWk8_zKyDcMABAMdlxN1LOU,27215
7
+ pyRDDLGym_jax/core/planner.py,sha256=a684ss5TAkJ-P2SEbZA90FSpDwFxHwRoaLtbRIBspAA,146450
7
8
  pyRDDLGym_jax/core/simulator.py,sha256=ayCATTUL3clLaZPQ5OUg2bI_c26KKCTq6TbrxbMsVdc,10470
8
9
  pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
9
10
  pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
@@ -41,9 +42,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
41
42
  pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
42
43
  pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
43
44
  pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
44
- pyrddlgym_jax-2.5.dist-info/licenses/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
45
- pyrddlgym_jax-2.5.dist-info/METADATA,sha256=XAaEJfbsYW-txxZhFZ6o_HmvqxkIMTqBF9LbV-KdTzI,17058
46
- pyrddlgym_jax-2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
47
- pyrddlgym_jax-2.5.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
48
- pyrddlgym_jax-2.5.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
49
- pyrddlgym_jax-2.5.dist-info/RECORD,,
45
+ pyrddlgym_jax-2.6.dist-info/licenses/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
46
+ pyrddlgym_jax-2.6.dist-info/METADATA,sha256=1gY3EPRHKMVeZYYgq4DCqWvw3Q1Ak5XVYRaIO2UlQXc,16770
47
+ pyrddlgym_jax-2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
48
+ pyrddlgym_jax-2.6.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
49
+ pyrddlgym_jax-2.6.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
50
+ pyrddlgym_jax-2.6.dist-info/RECORD,,