pyRDDLGym-jax 2.8-py3-none-any.whl → 3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +1080 -906
  3. pyRDDLGym_jax/core/logic.py +1537 -1369
  4. pyRDDLGym_jax/core/model.py +75 -86
  5. pyRDDLGym_jax/core/planner.py +883 -935
  6. pyRDDLGym_jax/core/simulator.py +20 -17
  7. pyRDDLGym_jax/core/tuning.py +11 -7
  8. pyRDDLGym_jax/core/visualization.py +115 -78
  9. pyRDDLGym_jax/entry_point.py +2 -1
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
  11. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
  12. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
  13. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
  14. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
  15. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
  16. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
  18. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
  19. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
  20. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
  21. pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
  22. pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
  23. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
  24. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
  25. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
  26. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
  27. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
  28. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
  29. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
  30. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
  31. pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
  32. pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
  33. pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
  34. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
  35. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
  36. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
  37. pyRDDLGym_jax/examples/run_plan.py +2 -2
  38. pyRDDLGym_jax/examples/run_tune.py +2 -2
  39. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
  40. pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
  41. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
  42. pyRDDLGym_jax/examples/run_gradient.py +0 -102
  43. pyrddlgym_jax-2.8.dist-info/RECORD +0 -50
  44. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
  45. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
  46. {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
@@ -16,8 +16,7 @@ import optax
 
 from pyRDDLGym.core.compiler.model import RDDLLiftedModel
 
-from pyRDDLGym_jax.core.logic import Logic, ExactLogic
-from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
+from pyRDDLGym_jax.core.logic import JaxRDDLCompilerWithGrad
 
 Kwargs = Dict[str, Any]
 State = Dict[str, np.ndarray]
@@ -38,25 +37,22 @@ LossFunction = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
 
 def mean_squared_error() -> LossFunction:
     def _jax_wrapped_mse_loss(target, pred):
-        loss_values = jnp.square(target - pred)
-        return loss_values
+        return jnp.square(target - pred)
     return jax.jit(_jax_wrapped_mse_loss)
 
 
-def binary_cross_entropy(eps: float=1e-6) -> LossFunction:
+def binary_cross_entropy(eps: float=1e-8) -> LossFunction:
     def _jax_wrapped_binary_cross_entropy_loss(target, pred):
         pred = jnp.clip(pred, eps, 1.0 - eps)
         log_pred = jnp.log(pred)
         log_not_pred = jnp.log(1.0 - pred)
-        loss_values = -target * log_pred - (1.0 - target) * log_not_pred
-        return loss_values
+        return -target * log_pred - (1.0 - target) * log_not_pred
     return jax.jit(_jax_wrapped_binary_cross_entropy_loss)
 
 
 def optax_loss(loss_fn: LossFunction, **kwargs) -> LossFunction:
     def _jax_wrapped_optax_loss(target, pred):
-        loss_values = loss_fn(pred, target, **kwargs)
-        return loss_values
+        return loss_fn(pred, target, **kwargs)
     return jax.jit(_jax_wrapped_optax_loss)
 
 
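For reference, a minimal usage sketch of the three loss factories above (illustrative only: the module that defines them is not named in this excerpt, so they are assumed to be in scope, and wrapping optax.huber_loss is an example choice rather than something this release adds):

import jax.numpy as jnp
import optax

# assumes mean_squared_error, binary_cross_entropy, optax_loss from the file above are in scope
mse = mean_squared_error()                        # elementwise (target - pred)^2
bce = binary_cross_entropy()                      # 3.0 clips predictions with eps=1e-8
huber = optax_loss(optax.huber_loss, delta=1.0)   # adapts optax's (pred, target) ordering

target = jnp.array([0.0, 1.0, 1.0])
pred = jnp.array([0.1, 0.8, 0.9])
print(mse(target, pred), bce(target, pred), huber(target, pred))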
@@ -96,11 +92,11 @@ class JaxModelLearner:
                 optimizer_kwargs: Optional[Kwargs]=None,
                 initializer: initializers.Initializer = initializers.normal(),
                 wrap_non_bool: bool=True,
-                 use64bit: bool=False,
                 bool_fluent_loss: LossFunction=binary_cross_entropy(),
                 real_fluent_loss: LossFunction=mean_squared_error(),
                 int_fluent_loss: LossFunction=mean_squared_error(),
-                 logic: Logic=ExactLogic(),
+                 compiler: JaxRDDLCompilerWithGrad=JaxRDDLCompilerWithGrad,
+                 compiler_kwargs: Optional[Kwargs]=None,
                 model_params_reduction: Callable=lambda x: x[0]) -> None:
        '''Creates a new gradient-based algorithm for inferring unknown non-fluents
        in a RDDL domain from a data set or stream coming from the real environment.
@@ -117,12 +113,11 @@ class JaxModelLearner:
        :param initializer: how to initialize non-fluents
        :param wrap_non_bool: whether to wrap non-boolean trainable parameters to satisfy
            required ranges as specified in param_ranges (use a projected gradient otherwise)
-        :param use64bit: whether to perform arithmetic in 64 bit
        :param bool_fluent_loss: loss function to optimize for bool-valued fluents
        :param real_fluent_loss: loss function to optimize for real-valued fluents
        :param int_fluent_loss: loss function to optimize for int-valued fluents
-        :param logic: a subclass of Logic for mapping exact mathematical
-            operations to their differentiable counterparts
+        :param compiler: compiler instance to use for planning
+        :param compiler_kwargs: compiler instances kwargs for initialization
        :param model_params_reduction: how to aggregate updated model_params across runs
            in the batch (defaults to selecting the first element's parameters in the batch)
        '''
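Since the 2.8 logic and use64bit arguments are gone, relaxation options now travel through compiler_kwargs. A hedged migration sketch follows (the JaxModelLearner import is omitted because this excerpt does not name the file being changed, and the keyword names forwarded via compiler_kwargs are assumptions based on the 2.8 call removed further below; check the 3.0 signature of JaxRDDLCompilerWithGrad before relying on them):

import pyRDDLGym
from pyRDDLGym_jax.core.logic import JaxRDDLCompilerWithGrad

env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)

# 2.8 (removed): JaxModelLearner(rddl=env.model, param_ranges={}, logic=ExactLogic(), use64bit=False)
# 3.0: pass the compiler class and its own kwargs instead
learner = JaxModelLearner(
    rddl=env.model,
    param_ranges={},
    compiler=JaxRDDLCompilerWithGrad,
    compiler_kwargs={'compile_non_fluent_exact': False, 'print_warnings': True}  # assumed kwargs
)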
@@ -135,11 +130,9 @@ class JaxModelLearner:
        self.optimizer_kwargs = optimizer_kwargs
        self.initializer = initializer
        self.wrap_non_bool = wrap_non_bool
-        self.use64bit = use64bit
        self.bool_fluent_loss = bool_fluent_loss
        self.real_fluent_loss = real_fluent_loss
        self.int_fluent_loss = int_fluent_loss
-        self.logic = logic
        self.model_params_reduction = model_params_reduction
 
        # validate param_ranges
@@ -166,6 +159,11 @@ class JaxModelLearner:
        self.optimizer = optax.chain(*pipeline)
 
        # build the computation graph
+        if compiler_kwargs is None:
+            compiler_kwargs = {}
+        self.compiler_kwargs = compiler_kwargs
+        self.compiler_type = compiler
+
        self.step_fn = self._jax_compile_rddl()
        self.map_fn = self._jax_map()
        self.loss_fn = self._jax_loss(map_fn=self.map_fn, step_fn=self.step_fn)
@@ -179,44 +177,39 @@ class JaxModelLearner:
    def _jax_compile_rddl(self):
 
        # compile the RDDL model
-        self.compiled = JaxRDDLCompilerWithGrad(
-            rddl=self.rddl,
-            logic=self.logic,
-            use64bit=self.use64bit,
-            compile_non_fluent_exact=False,
-            print_warnings=True
+        self.compiled = self.compiler_type(
+            self.rddl,
+            **self.compiler_kwargs
        )
        self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
 
        # compile the transition step function
        step_fn = self.compiled.compile_transition()
 
-        def _jax_wrapped_step(key, param_fluents, subs, actions, hyperparams):
-            for (name, param) in param_fluents.items():
-                subs[name] = param
-            subs, _, hyperparams = step_fn(key, actions, subs, hyperparams)
-            return subs, hyperparams
+        def _jax_wrapped_step(key, param_fluents, fls, nfls, actions, hyperparams):
+            nfls = nfls.copy()
+            nfls.update(param_fluents)
+            fls, _, hyperparams = step_fn(key, actions, fls, nfls, hyperparams)
+            return fls, hyperparams
 
        # batched step function
-        def _jax_wrapped_batched_step(key, param_fluents, subs, actions, hyperparams):
-            keys = jnp.asarray(random.split(key, num=self.batch_size_train))
-            subs, hyperparams = jax.vmap(
-                _jax_wrapped_step, in_axes=(0, None, 0, 0, None)
-            )(keys, param_fluents, subs, actions, hyperparams)
+        def _jax_wrapped_batched_step(key, param_fluents, fls, nfls, actions, hyperparams):
+            keys = random.split(key, num=self.batch_size_train)
+            fls, hyperparams = jax.vmap(
+                _jax_wrapped_step, in_axes=(0, None, 0, None, 0, None)
+            )(keys, param_fluents, fls, nfls, actions, hyperparams)
            hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
-            return subs, hyperparams
+            return fls, hyperparams
 
        # batched step function with parallel samples per data point
-        def _jax_wrapped_batched_parallel_step(key, param_fluents, subs, actions, hyperparams):
-            keys = jnp.asarray(random.split(key, num=self.samples_per_datapoint))
-            subs, hyperparams = jax.vmap(
-                _jax_wrapped_batched_step, in_axes=(0, None, None, None, None)
-            )(keys, param_fluents, subs, actions, hyperparams)
+        def _jax_wrapped_batched_parallel_step(key, param_fluents, fls, nfls, actions, hyperparams):
+            keys = random.split(key, num=self.samples_per_datapoint)
+            fls, hyperparams = jax.vmap(
+                _jax_wrapped_batched_step, in_axes=(0, None, None, None, None, None)
+            )(keys, param_fluents, fls, nfls, actions, hyperparams)
            hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
-            return subs, hyperparams
-
-        batched_step_fn = jax.jit(_jax_wrapped_batched_parallel_step)
-        return batched_step_fn
+            return fls, hyperparams
+        return jax.jit(_jax_wrapped_batched_parallel_step)
 
    def _jax_map(self):
 
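The new in_axes pattern in _jax_compile_rddl batches the random keys, fluents, and actions while broadcasting the learned parameter fluents, non-fluents, and model hyper-parameters to every batch member. A self-contained illustration of that vmap behaviour, with toy names and shapes that are not part of the package:

import jax
import jax.numpy as jnp
from jax import random

def step(key, param_fluents, fls, nfls, actions, hyperparams):
    # one unbatched transition: fluents and actions are per-sample, non-fluents shared
    noise = random.normal(key, fls['pos'].shape)
    new_pos = fls['pos'] + actions['force'] * nfls['dt'] + hyperparams['sigma'] * noise
    return {'pos': new_pos}

batch = 4
keys = random.split(random.PRNGKey(0), num=batch)
fls = {'pos': jnp.zeros((batch, 3))}        # batched fluents (mapped, in_axes=0)
nfls = {'dt': jnp.asarray(0.1)}             # shared non-fluents (broadcast, in_axes=None)
actions = {'force': jnp.ones((batch, 3))}   # batched actions (mapped, in_axes=0)
out = jax.vmap(step, in_axes=(0, None, 0, None, 0, None))(
    keys, {}, fls, nfls, actions, {'sigma': 0.01})
print(out['pos'].shape)  # (4, 3)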
@@ -254,20 +247,18 @@ class JaxModelLearner:
                else:
                    param_fluents[name] = param
            return param_fluents
-
-        map_fn = jax.jit(_jax_wrapped_params_to_fluents)
-        return map_fn
+        return jax.jit(_jax_wrapped_params_to_fluents)
 
    def _jax_loss(self, map_fn, step_fn):
 
        # use binary cross entropy for bool fluents
        # mean squared error for continuous and integer fluents
-        def _jax_wrapped_batched_model_loss(key, param_fluents, subs, actions, next_fluents,
-                                            hyperparams):
-            next_subs, hyperparams = step_fn(key, param_fluents, subs, actions, hyperparams)
+        def _jax_wrapped_batched_model_loss(key, param_fluents, fls, nfls, actions,
+                                            next_fluents, hyperparams):
+            fls, hyperparams = step_fn(key, param_fluents, fls, nfls, actions, hyperparams)
            total_loss = 0.0
            for (name, next_value) in next_fluents.items():
-                preds = jnp.asarray(next_subs[name], dtype=self.compiled.REAL)
+                preds = jnp.asarray(fls[name], dtype=self.compiled.REAL)
                targets = jnp.asarray(next_value, dtype=self.compiled.REAL)[jnp.newaxis, ...]
                if self.rddl.variable_ranges[name] == 'bool':
                    loss_values = self.bool_fluent_loss(targets, preds)
@@ -279,14 +270,12 @@ class JaxModelLearner:
            return total_loss, hyperparams
 
        # loss with the parameters mapped to their fluents
-        def _jax_wrapped_batched_loss(key, params, subs, actions, next_fluents, hyperparams):
+        def _jax_wrapped_batched_loss(key, params, fls, nfls, actions, next_fluents,
+                                      hyperparams):
            param_fluents = map_fn(params)
-            loss, hyperparams = _jax_wrapped_batched_model_loss(
-                key, param_fluents, subs, actions, next_fluents, hyperparams)
-            return loss, hyperparams
-
-        loss_fn = jax.jit(_jax_wrapped_batched_loss)
-        return loss_fn
+            return _jax_wrapped_batched_model_loss(
+                key, param_fluents, fls, nfls, actions, next_fluents, hyperparams)
+        return jax.jit(_jax_wrapped_batched_loss)
 
    def _jax_init(self, project_fn):
        optimizer = self.optimizer
@@ -325,20 +314,19 @@ class JaxModelLearner:
                if self.rddl.variable_ranges[name] == 'bool':
                    new_params[name] = value
                else:
-                    lower, upper = self.param_ranges[name]
-                    new_params[name] = jnp.clip(value, lower, upper)
+                    new_params[name] = jnp.clip(value, *self.param_ranges[name])
            return new_params
 
        # gradient descent update
-        def _jax_wrapped_params_update(key, params, subs, actions, next_fluents,
+        def _jax_wrapped_params_update(key, params, fls, nfls, actions, next_fluents,
                                       hyperparams, opt_state):
            (loss_val, hyperparams), grad = jax.value_and_grad(
                loss_fn, argnums=1, has_aux=True
-            )(key, params, subs, actions, next_fluents, hyperparams)
+            )(key, params, fls, nfls, actions, next_fluents, hyperparams)
            updates, opt_state = optimizer.update(grad, opt_state)
            params = optax.apply_updates(params, updates)
            params = _jax_wrapped_project_params(params)
-            zero_grads = jax.tree_util.tree_map(partial(jnp.allclose, b=0.0), grad)
+            zero_grads = jax.tree_util.tree_map(partial(jnp.allclose, b=0), grad)
            return params, opt_state, loss_val, zero_grads, hyperparams
 
        update_fn = jax.jit(_jax_wrapped_params_update)
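The zero-gradient check above flags parameters the optimizer cannot improve because their entire gradient is zero. A standalone sketch of the same tree_map/allclose pattern, with toy parameter names and values:

import jax
import jax.numpy as jnp
from functools import partial

grad = {'mass': jnp.zeros(2), 'length': jnp.array([0.3, -0.1])}
zero_grads = jax.tree_util.tree_map(partial(jnp.allclose, b=0), grad)
print(zero_grads)  # 'mass' maps to True (no learning signal), 'length' to False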
@@ -346,15 +334,17 @@
        return update_fn, project_fn
 
    def _batched_init_subs(self):
-        init_train = {}
+        init_fls, init_nfls = {}, {}
        for (name, value) in self.compiled.init_values.items():
-            value = np.reshape(value, np.shape(value))[np.newaxis, ...]
-            value = np.repeat(value, repeats=self.batch_size_train, axis=0)
            value = np.asarray(value, dtype=self.compiled.REAL)
-            init_train[name] = value
+            if name in self.rddl.non_fluents:
+                init_nfls[name] = value
+            else:
+                init_fls[name] = np.repeat(
+                    value[np.newaxis, ...], repeats=self.batch_size_train, axis=0)
        for (state, next_state) in self.rddl.next_state.items():
-            init_train[next_state] = init_train[state]
-        return init_train
+            init_fls[next_state] = init_fls[state]
+        return init_fls, init_nfls
 
    # ===========================================================================
    # ESTIMATE API
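In 3.0 the initial values are split so that only fluents carry a leading batch dimension while non-fluents stay shared across the batch. A small NumPy illustration with hypothetical variable names (not taken from the package):

import numpy as np

batch_size_train = 32
init_values = {'pos': np.zeros(3, dtype=np.float32),   # hypothetical state fluent
               'gravity': np.float32(9.8)}              # hypothetical non-fluent
non_fluents = {'gravity'}

init_fls, init_nfls = {}, {}
for name, value in init_values.items():
    value = np.asarray(value, dtype=np.float32)
    if name in non_fluents:
        init_nfls[name] = value                          # stays unbatched
    else:
        init_fls[name] = np.repeat(
            value[np.newaxis, ...], repeats=batch_size_train, axis=0)

print(init_fls['pos'].shape)       # (32, 3)
print(init_nfls['gravity'].shape)  # ()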
@@ -415,7 +405,7 @@
            key = random.PRNGKey(round(time.time() * 1000))
 
        # prepare initial subs
-        subs = self._batched_init_subs()
+        fls, nfls = self._batched_init_subs()
 
        # initialize parameter fluents to optimize
        if guess is None:
@@ -425,7 +415,7 @@
            params, opt_state = self.init_opt_fn(guess)
 
        # initialize model hyper-parameters
-        hyperparams = self.compiled.model_params
+        hyperparams = self.compiled.model_aux['params']
 
        # progress bar
        if print_progress:
@@ -439,10 +429,10 @@
            status = JaxLearnerStatus.NORMAL
 
            # gradient update
-            subs.update(states)
+            fls.update(states)
            key, subkey = random.split(key)
            params, opt_state, loss, zero_grads, hyperparams = self.update_fn(
-                subkey, params, subs, actions, next_states, hyperparams, opt_state)
+                subkey, params, fls, nfls, actions, next_states, hyperparams, opt_state)
 
            # extract non-fluent values from the trainable parameters
            param_fluents = self.map_fn(params)
@@ -450,7 +440,8 @@
 
            # check for learnability
            params_zero_grads = {
-                name for (name, zero_grad) in zero_grads.items() if zero_grad}
+                name for (name, zero_grad) in zero_grads.items() if zero_grad
+            }
            if params_zero_grads:
                status = JaxLearnerStatus.NO_PROGRESS
 
@@ -504,14 +495,14 @@
        '''
        if key is None:
            key = random.PRNGKey(round(time.time() * 1000))
-        subs = self._batched_init_subs()
-        hyperparams = self.compiled.model_params
+        fls, nfls = self._batched_init_subs()
+        hyperparams = self.compiled.model_aux['params']
        mean_loss = 0.0
        for (it, (states, actions, next_states)) in enumerate(data):
-            subs.update(states)
+            fls.update(states)
            key, subkey = random.split(key)
            loss_value, _ = self.loss_fn(
-                subkey, param_fluents, subs, actions, next_states, hyperparams)
+                subkey, param_fluents, fls, nfls, actions, next_states, hyperparams)
            mean_loss += (loss_value - mean_loss) / (it + 1)
        return mean_loss
 
@@ -524,15 +515,13 @@
        model = deepcopy(self.rddl)
        for (name, values) in param_fluents.items():
            prange = model.variable_ranges[name]
-            if prange == 'real':
-                pass
-            elif prange == 'bool':
+            if prange == 'bool':
                values = values > 0.5
-            else:
+            elif prange != 'real':
                values = np.asarray(values, dtype=self.compiled.INT)
            values = np.ravel(values, order='C').tolist()
            if not self.rddl.variable_params[name]:
-                assert(len(values) == 1)
+                assert (len(values) == 1)
                values = values[0]
            model.non_fluents[name] = values
        return model
@@ -549,7 +538,7 @@ if __name__ == '__main__':
        env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
        model = JaxModelLearner(rddl=env.model, param_ranges={}, batch_size_train=bs)
        key = random.PRNGKey(round(time.time() * 1000))
-        subs = model._batched_init_subs()
+        fls, nfls = model._batched_init_subs()
        param_fluents = {}
        while True:
            states = {
@@ -558,14 +547,14 @@ if __name__ == '__main__':
                'ang-pos': np.random.uniform(-0.21, 0.21, (bs,)),
                'ang-vel': np.random.uniform(-0.21, 0.21, (bs,))
            }
-            subs.update(states)
+            fls.update(states)
            actions = {
                'force': np.random.uniform(-10., 10., (bs,))
            }
            key, subkey = random.split(key)
-            subs, _ = model.step_fn(subkey, param_fluents, subs, actions, {})
-            subs = {k: np.asarray(v)[0, ...] for k, v in subs.items()}
-            next_states = {k: subs[k] for k in model.rddl.state_fluents}
+            fls, _ = model.step_fn(subkey, param_fluents, fls, nfls, actions, {})
+            fls = {k: np.asarray(v)[0, ...] for k, v in fls.items()}
+            next_states = {k: fls[k] for k in model.rddl.state_fluents}
            yield (states, actions, next_states)
 
    # train it