pyRDDLGym-jax 2.8-py3-none-any.whl → 3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +1080 -906
- pyRDDLGym_jax/core/logic.py +1537 -1369
- pyRDDLGym_jax/core/model.py +75 -86
- pyRDDLGym_jax/core/planner.py +883 -935
- pyRDDLGym_jax/core/simulator.py +20 -17
- pyRDDLGym_jax/core/tuning.py +11 -7
- pyRDDLGym_jax/core/visualization.py +115 -78
- pyRDDLGym_jax/entry_point.py +2 -1
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +8 -9
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +7 -8
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_drp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_physics_slp.cfg +17 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +4 -7
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +5 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +6 -7
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +6 -7
- pyRDDLGym_jax/examples/configs/default_drp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -8
- pyRDDLGym_jax/examples/configs/default_slp.cfg +5 -8
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +6 -8
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +6 -8
- pyRDDLGym_jax/examples/run_plan.py +2 -2
- pyRDDLGym_jax/examples/run_tune.py +2 -2
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/METADATA +22 -23
- pyrddlgym_jax-3.0.dist-info/RECORD +51 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/run_gradient.py +0 -102
- pyrddlgym_jax-2.8.dist-info/RECORD +0 -50
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/licenses/LICENSE +0 -0
- {pyrddlgym_jax-2.8.dist-info → pyrddlgym_jax-3.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/model.py
CHANGED
@@ -16,8 +16,7 @@ import optax
 
 from pyRDDLGym.core.compiler.model import RDDLLiftedModel
 
-from pyRDDLGym_jax.core.logic import …
-from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
+from pyRDDLGym_jax.core.logic import JaxRDDLCompilerWithGrad
 
 Kwargs = Dict[str, Any]
 State = Dict[str, np.ndarray]
@@ -38,25 +37,22 @@ LossFunction = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
 
 def mean_squared_error() -> LossFunction:
     def _jax_wrapped_mse_loss(target, pred):
-        …
-        return loss_values
+        return jnp.square(target - pred)
     return jax.jit(_jax_wrapped_mse_loss)
 
 
-def binary_cross_entropy(eps: float=1e-…
+def binary_cross_entropy(eps: float=1e-8) -> LossFunction:
     def _jax_wrapped_binary_cross_entropy_loss(target, pred):
         pred = jnp.clip(pred, eps, 1.0 - eps)
         log_pred = jnp.log(pred)
         log_not_pred = jnp.log(1.0 - pred)
-        …
-        return loss_values
+        return -target * log_pred - (1.0 - target) * log_not_pred
     return jax.jit(_jax_wrapped_binary_cross_entropy_loss)
 
 
 def optax_loss(loss_fn: LossFunction, **kwargs) -> LossFunction:
     def _jax_wrapped_optax_loss(target, pred):
-        …
-        return loss_values
+        return loss_fn(pred, target, **kwargs)
     return jax.jit(_jax_wrapped_optax_loss)
 
 
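Note on the hunk above: each factory returns a jitted callable with signature loss(target, pred), and optax_loss adapts an optax-style loss that expects (predictions, targets). A minimal usage sketch, assuming these names are importable from pyRDDLGym_jax.core.model and using optax.huber_loss purely as an illustrative choice:

    import optax
    from pyRDDLGym_jax.core.model import binary_cross_entropy, optax_loss

    # loss for bool-valued fluents: clipped negative log-likelihood
    bool_loss = binary_cross_entropy(eps=1e-8)

    # loss for real-valued fluents: Huber in place of the default squared error
    real_loss = optax_loss(optax.huber_loss, delta=1.0)

    # these would then be passed as bool_fluent_loss=... / real_fluent_loss=...
    # to the JaxModelLearner constructor shown in the hunks below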
@@ -96,11 +92,11 @@ class JaxModelLearner:
                  optimizer_kwargs: Optional[Kwargs]=None,
                  initializer: initializers.Initializer = initializers.normal(),
                  wrap_non_bool: bool=True,
-                 use64bit: bool=False,
                  bool_fluent_loss: LossFunction=binary_cross_entropy(),
                  real_fluent_loss: LossFunction=mean_squared_error(),
                  int_fluent_loss: LossFunction=mean_squared_error(),
-                 …
+                 compiler: JaxRDDLCompilerWithGrad=JaxRDDLCompilerWithGrad,
+                 compiler_kwargs: Optional[Kwargs]=None,
                  model_params_reduction: Callable=lambda x: x[0]) -> None:
         '''Creates a new gradient-based algorithm for inferring unknown non-fluents
         in a RDDL domain from a data set or stream coming from the real environment.
@@ -117,12 +113,11 @@ class JaxModelLearner:
         :param initializer: how to initialize non-fluents
         :param wrap_non_bool: whether to wrap non-boolean trainable parameters to satisfy
         required ranges as specified in param_ranges (use a projected gradient otherwise)
-        :param use64bit: whether to perform arithmetic in 64 bit
         :param bool_fluent_loss: loss function to optimize for bool-valued fluents
         :param real_fluent_loss: loss function to optimize for real-valued fluents
         :param int_fluent_loss: loss function to optimize for int-valued fluents
-        :param …
-        …
+        :param compiler: compiler instance to use for planning
+        :param compiler_kwargs: compiler instances kwargs for initialization
         :param model_params_reduction: how to aggregate updated model_params across runs
         in the batch (defaults to selecting the first element's parameters in the batch)
         '''
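Note on the two hunks above: the use64bit and logic arguments are gone; the relaxed compiler is now injected via compiler and compiler_kwargs. A minimal sketch of the new call surface, assuming the keyword arguments the old code hard-coded in _jax_compile_rddl (use64bit, compile_non_fluent_exact, print_warnings) are still accepted by JaxRDDLCompilerWithGrad in 3.0:

    import pyRDDLGym
    from pyRDDLGym_jax.core.logic import JaxRDDLCompilerWithGrad
    from pyRDDLGym_jax.core.model import JaxModelLearner

    env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
    learner = JaxModelLearner(
        rddl=env.model,
        param_ranges={},                      # non-fluents to infer and their ranges
        batch_size_train=32,
        compiler=JaxRDDLCompilerWithGrad,     # replaces the removed use64bit/logic args
        compiler_kwargs=dict(use64bit=False,
                             compile_non_fluent_exact=False,
                             print_warnings=True))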
@@ -135,11 +130,9 @@ class JaxModelLearner:
         self.optimizer_kwargs = optimizer_kwargs
         self.initializer = initializer
         self.wrap_non_bool = wrap_non_bool
-        self.use64bit = use64bit
         self.bool_fluent_loss = bool_fluent_loss
         self.real_fluent_loss = real_fluent_loss
         self.int_fluent_loss = int_fluent_loss
-        self.logic = logic
         self.model_params_reduction = model_params_reduction
 
         # validate param_ranges
@@ -166,6 +159,11 @@ class JaxModelLearner:
         self.optimizer = optax.chain(*pipeline)
 
         # build the computation graph
+        if compiler_kwargs is None:
+            compiler_kwargs = {}
+        self.compiler_kwargs = compiler_kwargs
+        self.compiler_type = compiler
+
         self.step_fn = self._jax_compile_rddl()
         self.map_fn = self._jax_map()
         self.loss_fn = self._jax_loss(map_fn=self.map_fn, step_fn=self.step_fn)
@@ -179,44 +177,39 @@ class JaxModelLearner:
     def _jax_compile_rddl(self):
 
         # compile the RDDL model
-        self.compiled = …
-            …
-            …
-            use64bit=self.use64bit,
-            compile_non_fluent_exact=False,
-            print_warnings=True
+        self.compiled = self.compiler_type(
+            self.rddl,
+            **self.compiler_kwargs
         )
         self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
 
         # compile the transition step function
         step_fn = self.compiled.compile_transition()
 
-        def _jax_wrapped_step(key, param_fluents, …
-            …
-            …
-            …
-            return …
+        def _jax_wrapped_step(key, param_fluents, fls, nfls, actions, hyperparams):
+            nfls = nfls.copy()
+            nfls.update(param_fluents)
+            fls, _, hyperparams = step_fn(key, actions, fls, nfls, hyperparams)
+            return fls, hyperparams
 
         # batched step function
-        def _jax_wrapped_batched_step(key, param_fluents, …
-            keys = …
-            …
-                _jax_wrapped_step, in_axes=(0, None, 0, 0, None)
-            )(keys, param_fluents, …
+        def _jax_wrapped_batched_step(key, param_fluents, fls, nfls, actions, hyperparams):
+            keys = random.split(key, num=self.batch_size_train)
+            fls, hyperparams = jax.vmap(
+                _jax_wrapped_step, in_axes=(0, None, 0, None, 0, None)
+            )(keys, param_fluents, fls, nfls, actions, hyperparams)
             hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
-            return …
+            return fls, hyperparams
 
         # batched step function with parallel samples per data point
-        def _jax_wrapped_batched_parallel_step(key, param_fluents, …
-            keys = …
-            …
-                _jax_wrapped_batched_step, in_axes=(0, None, None, None, None)
-            )(keys, param_fluents, …
+        def _jax_wrapped_batched_parallel_step(key, param_fluents, fls, nfls, actions, hyperparams):
+            keys = random.split(key, num=self.samples_per_datapoint)
+            fls, hyperparams = jax.vmap(
+                _jax_wrapped_batched_step, in_axes=(0, None, None, None, None, None)
+            )(keys, param_fluents, fls, nfls, actions, hyperparams)
             hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
-            return …
-        …
-        batched_step_fn = jax.jit(_jax_wrapped_batched_parallel_step)
-        return batched_step_fn
+            return fls, hyperparams
+        return jax.jit(_jax_wrapped_batched_parallel_step)
 
     def _jax_map(self):
 
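Note on the hunk above: the substitution dictionary is now split into batched fluents (fls) and shared non-fluents (nfls), and the in_axes tuples map the PRNG keys, fluents and actions over the batch axis while broadcasting param_fluents, non-fluents and the model hyper-parameters. A self-contained sketch of that vmap pattern with made-up one-dimensional dynamics (step, a, dt and u are illustrative names, not part of the package):

    import jax
    import jax.numpy as jnp
    from jax import random

    def step(key, params, fls, nfls, actions):
        # toy Euler step: params and nfls are shared, fls and actions are per-sample
        noise = 0.01 * random.normal(key)
        return {'x': fls['x'] + nfls['dt'] * (params['a'] * actions['u'] + noise)}

    batched = jax.vmap(step, in_axes=(0, None, 0, None, 0))
    keys = random.split(random.PRNGKey(0), 32)
    out = batched(keys,
                  {'a': jnp.array(2.0)},   # broadcast, like param_fluents
                  {'x': jnp.zeros(32)},    # batched over axis 0, like fls
                  {'dt': jnp.array(0.1)},  # broadcast, like nfls
                  {'u': jnp.ones(32)})     # batched over axis 0, like actions
    # out['x'] has shape (32,), one next state per batch element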
@@ -254,20 +247,18 @@ class JaxModelLearner:
             else:
                 param_fluents[name] = param
             return param_fluents
-        …
-        map_fn = jax.jit(_jax_wrapped_params_to_fluents)
-        return map_fn
+        return jax.jit(_jax_wrapped_params_to_fluents)
 
     def _jax_loss(self, map_fn, step_fn):
 
         # use binary cross entropy for bool fluents
         # mean squared error for continuous and integer fluents
-        def _jax_wrapped_batched_model_loss(key, param_fluents, …
-                                            hyperparams):
-            …
+        def _jax_wrapped_batched_model_loss(key, param_fluents, fls, nfls, actions,
+                                            next_fluents, hyperparams):
+            fls, hyperparams = step_fn(key, param_fluents, fls, nfls, actions, hyperparams)
             total_loss = 0.0
             for (name, next_value) in next_fluents.items():
-                preds = jnp.asarray(…
+                preds = jnp.asarray(fls[name], dtype=self.compiled.REAL)
                 targets = jnp.asarray(next_value, dtype=self.compiled.REAL)[jnp.newaxis, ...]
                 if self.rddl.variable_ranges[name] == 'bool':
                     loss_values = self.bool_fluent_loss(targets, preds)
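Note on the shapes in the loss above: predictions returned by the parallel batched step carry a leading samples-per-datapoint axis, so the observed next-state values receive a jnp.newaxis and broadcast against every sample before the per-fluent losses are reduced. A small shape check (the sizes 4 and 32 are arbitrary):

    import jax.numpy as jnp

    preds = jnp.ones((4, 32))                     # (samples_per_datapoint, batch)
    targets = jnp.zeros((32,))[jnp.newaxis, ...]  # (1, batch), broadcasts over samples
    loss_values = jnp.square(targets - preds)     # what mean_squared_error computes
    assert loss_values.shape == (4, 32)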
@@ -279,14 +270,12 @@ class JaxModelLearner:
             return total_loss, hyperparams
 
         # loss with the parameters mapped to their fluents
-        def _jax_wrapped_batched_loss(key, params, …
+        def _jax_wrapped_batched_loss(key, params, fls, nfls, actions, next_fluents,
+                                      hyperparams):
             param_fluents = map_fn(params)
-            …
-                key, param_fluents, …
-            …
-            …
-        loss_fn = jax.jit(_jax_wrapped_batched_loss)
-        return loss_fn
+            return _jax_wrapped_batched_model_loss(
+                key, param_fluents, fls, nfls, actions, next_fluents, hyperparams)
+        return jax.jit(_jax_wrapped_batched_loss)
 
     def _jax_init(self, project_fn):
         optimizer = self.optimizer
@@ -325,20 +314,19 @@ class JaxModelLearner:
             if self.rddl.variable_ranges[name] == 'bool':
                 new_params[name] = value
             else:
-                …
-                new_params[name] = jnp.clip(value, lower, upper)
+                new_params[name] = jnp.clip(value, *self.param_ranges[name])
             return new_params
 
         # gradient descent update
-        def _jax_wrapped_params_update(key, params, …
+        def _jax_wrapped_params_update(key, params, fls, nfls, actions, next_fluents,
                                        hyperparams, opt_state):
             (loss_val, hyperparams), grad = jax.value_and_grad(
                 loss_fn, argnums=1, has_aux=True
-            )(key, params, …
+            )(key, params, fls, nfls, actions, next_fluents, hyperparams)
             updates, opt_state = optimizer.update(grad, opt_state)
             params = optax.apply_updates(params, updates)
             params = _jax_wrapped_project_params(params)
-            zero_grads = jax.tree_util.tree_map(partial(jnp.allclose, b=0…
+            zero_grads = jax.tree_util.tree_map(partial(jnp.allclose, b=0), grad)
             return params, opt_state, loss_val, zero_grads, hyperparams
 
         update_fn = jax.jit(_jax_wrapped_params_update)
@@ -346,15 +334,17 @@ class JaxModelLearner:
         return update_fn, project_fn
 
     def _batched_init_subs(self):
-        …
+        init_fls, init_nfls = {}, {}
         for (name, value) in self.compiled.init_values.items():
-            value = np.reshape(value, np.shape(value))[np.newaxis, ...]
-            value = np.repeat(value, repeats=self.batch_size_train, axis=0)
             value = np.asarray(value, dtype=self.compiled.REAL)
-            …
+            if name in self.rddl.non_fluents:
+                init_nfls[name] = value
+            else:
+                init_fls[name] = np.repeat(
+                    value[np.newaxis, ...], repeats=self.batch_size_train, axis=0)
         for (state, next_state) in self.rddl.next_state.items():
-            …
-        return …
+            init_fls[next_state] = init_fls[state]
+        return init_fls, init_nfls
 
     # ===========================================================================
     # ESTIMATE API
@@ -415,7 +405,7 @@ class JaxModelLearner:
         key = random.PRNGKey(round(time.time() * 1000))
 
         # prepare initial subs
-        …
+        fls, nfls = self._batched_init_subs()
 
         # initialize parameter fluents to optimize
         if guess is None:
@@ -425,7 +415,7 @@ class JaxModelLearner:
         params, opt_state = self.init_opt_fn(guess)
 
         # initialize model hyper-parameters
-        hyperparams = self.compiled.…
+        hyperparams = self.compiled.model_aux['params']
 
         # progress bar
         if print_progress:
@@ -439,10 +429,10 @@ class JaxModelLearner:
             status = JaxLearnerStatus.NORMAL
 
             # gradient update
-            …
+            fls.update(states)
             key, subkey = random.split(key)
             params, opt_state, loss, zero_grads, hyperparams = self.update_fn(
-                subkey, params, …
+                subkey, params, fls, nfls, actions, next_states, hyperparams, opt_state)
 
             # extract non-fluent values from the trainable parameters
             param_fluents = self.map_fn(params)
@@ -450,7 +440,8 @@ class JaxModelLearner:
 
             # check for learnability
             params_zero_grads = {
-                name for (name, zero_grad) in zero_grads.items() if zero_grad}
+                name for (name, zero_grad) in zero_grads.items() if zero_grad
+            }
             if params_zero_grads:
                 status = JaxLearnerStatus.NO_PROGRESS
 
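Note on the hunk above: the learnability check flags any trainable non-fluent whose gradient pytree leaf is identically zero, in which case the learner reports NO_PROGRESS for that iteration. A minimal sketch of the same test in isolation (the dictionary contents are made up for illustration):

    from functools import partial
    import jax
    import jax.numpy as jnp

    grad = {'gravity': jnp.zeros(3), 'mass': jnp.array([0.1, -0.2])}
    zero_grads = jax.tree_util.tree_map(partial(jnp.allclose, b=0), grad)
    stuck = {name for (name, zero_grad) in zero_grads.items() if zero_grad}
    # stuck == {'gravity'}: its gradient is all zeros, so the data cannot identify it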
@@ -504,14 +495,14 @@ class JaxModelLearner:
         '''
         if key is None:
             key = random.PRNGKey(round(time.time() * 1000))
-        …
-        hyperparams = self.compiled.…
+        fls, nfls = self._batched_init_subs()
+        hyperparams = self.compiled.model_aux['params']
         mean_loss = 0.0
         for (it, (states, actions, next_states)) in enumerate(data):
-            …
+            fls.update(states)
             key, subkey = random.split(key)
             loss_value, _ = self.loss_fn(
-                subkey, param_fluents, …
+                subkey, param_fluents, fls, nfls, actions, next_states, hyperparams)
             mean_loss += (loss_value - mean_loss) / (it + 1)
         return mean_loss
 
@@ -524,15 +515,13 @@ class JaxModelLearner:
         model = deepcopy(self.rddl)
         for (name, values) in param_fluents.items():
             prange = model.variable_ranges[name]
-            if prange == '…
-                pass
-            elif prange == 'bool':
+            if prange == 'bool':
                 values = values > 0.5
-            …
+            elif prange != 'real':
                 values = np.asarray(values, dtype=self.compiled.INT)
             values = np.ravel(values, order='C').tolist()
             if not self.rddl.variable_params[name]:
-                assert(len(values) == 1)
+                assert (len(values) == 1)
                 values = values[0]
             model.non_fluents[name] = values
         return model
@@ -549,7 +538,7 @@ if __name__ == '__main__':
     env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
     model = JaxModelLearner(rddl=env.model, param_ranges={}, batch_size_train=bs)
     key = random.PRNGKey(round(time.time() * 1000))
-    …
+    fls, nfls = model._batched_init_subs()
     param_fluents = {}
     while True:
         states = {
@@ -558,14 +547,14 @@ if __name__ == '__main__':
             'ang-pos': np.random.uniform(-0.21, 0.21, (bs,)),
             'ang-vel': np.random.uniform(-0.21, 0.21, (bs,))
         }
-        …
+        fls.update(states)
         actions = {
             'force': np.random.uniform(-10., 10., (bs,))
         }
         key, subkey = random.split(key)
-        …
-        …
-        next_states = {k: …
+        fls, _ = model.step_fn(subkey, param_fluents, fls, nfls, actions, {})
+        fls = {k: np.asarray(v)[0, ...] for k, v in fls.items()}
+        next_states = {k: fls[k] for k in model.rddl.state_fluents}
         yield (states, actions, next_states)
 
     # train it