pyRDDLGym-jax 2.5__py3-none-any.whl → 2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +16 -7
- pyRDDLGym_jax/core/logic.py +6 -8
- pyRDDLGym_jax/core/model.py +595 -0
- pyRDDLGym_jax/core/planner.py +173 -21
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/METADATA +5 -13
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/RECORD +11 -10
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/WHEEL +0 -0
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/licenses/LICENSE +0 -0
- {pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
@@ -1 +1 @@
-__version__ = '2.5'
+__version__ = '2.6'
pyRDDLGym_jax/core/compiler.py
CHANGED
@@ -237,7 +237,8 @@ class JaxRDDLCompiler:
 
     def compile_transition(self, check_constraints: bool=False,
                            constraint_func: bool=False,
-                           init_params_constr: Dict[str, Any]={}
+                           init_params_constr: Dict[str, Any]={},
+                           cache_path_info: bool=False) -> Callable:
         '''Compiles the current RDDL model into a JAX transition function that
         samples the next state.
 
@@ -274,6 +275,7 @@ class JaxRDDLCompiler:
             returned log and does not raise an exception
         :param constraint_func: produces the h(s, a) function described above
             in addition to the usual outputs
+        :param cache_path_info: whether to save full path traces as part of the log
         '''
         NORMAL = JaxRDDLCompiler.ERROR_CODES['NORMAL']
         rddl = self.rddl
@@ -322,8 +324,11 @@
             errors |= err
 
             # calculate fluent values
-
-
+            if cache_path_info:
+                fluents = {name: values for (name, values) in subs.items()
+                           if name not in rddl.non_fluents}
+            else:
+                fluents = {}
 
             # set the next state to the current state
             for (state, next_state) in rddl.next_state.items():
@@ -367,7 +372,9 @@
                          n_batch: int,
                          check_constraints: bool=False,
                          constraint_func: bool=False,
-                         init_params_constr: Dict[str, Any]={}
+                         init_params_constr: Dict[str, Any]={},
+                         model_params_reduction: Callable=lambda x: x[0],
+                         cache_path_info: bool=False) -> Callable:
         '''Compiles the current RDDL model into a JAX transition function that
         samples trajectories with a fixed horizon from a policy.
 
@@ -399,10 +406,13 @@
             returned log and does not raise an exception
         :param constraint_func: produces the h(s, a) constraint function
             in addition to the usual outputs
+        :param model_params_reduction: how to aggregate updated model_params across runs
+            in the batch (defaults to selecting the first element's parameters in the batch)
+        :param cache_path_info: whether to save full path traces as part of the log
         '''
         rddl = self.rddl
         jax_step_fn = self.compile_transition(
-            check_constraints, constraint_func, init_params_constr)
+            check_constraints, constraint_func, init_params_constr, cache_path_info)
 
         # for POMDP only observ-fluents are assumed visible to the policy
         if rddl.observ_fluents:
@@ -421,7 +431,6 @@
                 return jax_step_fn(subkey, actions, subs, model_params)
 
         # do a batched step update from the policy
-        # TODO: come up with a better way to reduce the model_param batch dim
         def _jax_wrapped_batched_step_policy(carry, step):
             key, policy_params, hyperparams, subs, model_params = carry
             key, *subkeys = random.split(key, num=1 + n_batch)
@@ -430,7 +439,7 @@
                 _jax_wrapped_single_step_policy,
                 in_axes=(0, None, None, None, 0, None)
             )(keys, policy_params, hyperparams, step, subs, model_params)
-            model_params = jax.tree_util.tree_map(
+            model_params = jax.tree_util.tree_map(model_params_reduction, model_params)
             carry = (key, policy_params, hyperparams, subs, model_params)
             return carry, log
 
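The two new keyword arguments above thread through from `compile_rollouts` into the batched step: `model_params_reduction` replaces the old TODO by letting the caller choose how the per-rollout model parameters are collapsed back into a single set, and `cache_path_info` controls whether full fluent traces are kept in the returned log. Below is a minimal sketch of how they might be passed; it is not part of the diff, and `compiler`, `policy_fn` and `horizon` are assumed to exist in the caller's code.

```python
import jax.numpy as jnp

# hypothetical caller: `compiler` is a JaxRDDLCompiler, `policy_fn` a policy callable
rollout_fn = compiler.compile_rollouts(
    policy=policy_fn,
    n_steps=horizon,
    n_batch=32,
    # average the per-rollout model parameters instead of taking the first
    # batch element (the default is lambda x: x[0])
    model_params_reduction=lambda x: jnp.mean(x, axis=0),
    # keep full fluent traces in the returned log (used by state preprocessors)
    cache_path_info=True,
)
```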
pyRDDLGym_jax/core/logic.py
CHANGED
@@ -1056,15 +1056,13 @@ class ExactLogic(Logic):
     def control_if(self, id, init_params):
         return self._jax_wrapped_calc_if_then_else_exact
 
-    @staticmethod
-    def _jax_wrapped_calc_switch_exact(pred, cases, params):
-        pred = pred[jnp.newaxis, ...]
-        sample = jnp.take_along_axis(cases, pred, axis=0)
-        assert sample.shape[0] == 1
-        return sample[0, ...], params
-
     def control_switch(self, id, init_params):
-
+        def _jax_wrapped_calc_switch_exact(pred, cases, params):
+            pred = jnp.asarray(pred[jnp.newaxis, ...], dtype=self.INT)
+            sample = jnp.take_along_axis(cases, pred, axis=0)
+            assert sample.shape[0] == 1
+            return sample[0, ...], params
+        return _jax_wrapped_calc_switch_exact
 
     # ===========================================================================
     # random variables
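The rewritten switch is now built inside `control_switch`, so it can cast the predicate to the compiler's integer dtype (`self.INT`) before indexing. As a standalone illustration of the same `take_along_axis` trick (not taken from the package), the case values are stacked along axis 0 and the predicate picks one case per element:

```python
import jax.numpy as jnp

cases = jnp.stack([jnp.array([10., 11.]),   # case 0
                   jnp.array([20., 21.]),   # case 1
                   jnp.array([30., 31.])])  # case 2
pred = jnp.array([2, 0])                    # per-element case index
idx = pred[jnp.newaxis, ...].astype(jnp.int32)
sample = jnp.take_along_axis(cases, idx, axis=0)[0, ...]
# sample == [cases[2, 0], cases[0, 1]] == [30., 11.]
```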
pyRDDLGym_jax/core/model.py
ADDED
@@ -0,0 +1,595 @@
+from collections import deque
+from copy import deepcopy
+from enum import Enum
+from functools import partial
+import sys
+import time
+from tqdm import tqdm
+from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple
+
+import jax
+import jax.nn.initializers as initializers
+import jax.numpy as jnp
+import jax.random as random
+import numpy as np
+import optax
+
+from pyRDDLGym.core.compiler.model import RDDLLiftedModel
+
+from pyRDDLGym_jax.core.logic import Logic, ExactLogic
+from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
+
+Kwargs = Dict[str, Any]
+State = Dict[str, np.ndarray]
+Action = Dict[str, np.ndarray]
+DataStream = Iterable[Tuple[State, Action, State]]
+Params = Dict[str, np.ndarray]
+Callback = Dict[str, Any]
+LossFunction = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
+
+
+# ***********************************************************************
+# ALL VERSIONS OF LOSS FUNCTIONS
+#
+# - loss functions based on specific likelihood assumptions (MSE, cross-entropy)
+#
+# ***********************************************************************
+
+
+def mean_squared_error() -> LossFunction:
+    def _jax_wrapped_mse_loss(target, pred):
+        loss_values = jnp.square(target - pred)
+        return loss_values
+    return jax.jit(_jax_wrapped_mse_loss)
+
+
+def binary_cross_entropy(eps: float=1e-6) -> LossFunction:
+    def _jax_wrapped_binary_cross_entropy_loss(target, pred):
+        pred = jnp.clip(pred, eps, 1.0 - eps)
+        log_pred = jnp.log(pred)
+        log_not_pred = jnp.log(1.0 - pred)
+        loss_values = -target * log_pred - (1.0 - target) * log_not_pred
+        return loss_values
+    return jax.jit(_jax_wrapped_binary_cross_entropy_loss)
+
+
+def optax_loss(loss_fn: LossFunction, **kwargs) -> LossFunction:
+    def _jax_wrapped_optax_loss(target, pred):
+        loss_values = loss_fn(pred, target, **kwargs)
+        return loss_values
+    return jax.jit(_jax_wrapped_optax_loss)
+
+
+# ***********************************************************************
+# ALL VERSIONS OF JAX MODEL LEARNER
+#
+# - gradient based model learning
+#
+# ***********************************************************************
+
+
+class JaxLearnerStatus(Enum):
+    '''Represents the status of a parameter update from the JAX model learner,
+    including whether the update resulted in nan gradient,
+    whether progress was made, budget was reached, or other information that
+    can be used to monitor and act based on the learner's progress.'''
+
+    NORMAL = 0
+    NO_PROGRESS = 1
+    INVALID_GRADIENT = 2
+    TIME_BUDGET_REACHED = 3
+    ITER_BUDGET_REACHED = 4
+
+    def is_terminal(self) -> bool:
+        return self.value >= 2
+
+
+class JaxModelLearner:
+    '''A class for data-driven estimation of unknown parameters in a given RDDL MDP using
+    gradient descent.'''
+
+    def __init__(self, rddl: RDDLLiftedModel,
+                 param_ranges: Dict[str, Tuple[Optional[np.ndarray], Optional[np.ndarray]]],
+                 batch_size_train: int=32,
+                 samples_per_datapoint: int=1,
+                 optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
+                 optimizer_kwargs: Optional[Kwargs]=None,
+                 initializer: initializers.Initializer = initializers.normal(),
+                 wrap_non_bool: bool=True,
+                 use64bit: bool=False,
+                 bool_fluent_loss: LossFunction=binary_cross_entropy(),
+                 real_fluent_loss: LossFunction=mean_squared_error(),
+                 int_fluent_loss: LossFunction=mean_squared_error(),
+                 logic: Logic=ExactLogic(),
+                 model_params_reduction: Callable=lambda x: x[0]) -> None:
+        '''Creates a new gradient-based algorithm for inferring unknown non-fluents
+        in a RDDL domain from a data set or stream coming from the real environment.
+
+        :param rddl: the RDDL domain to learn
+        :param param_ranges: the ranges of all learnable non-fluents
+        :param batch_size_train: how many transitions to compute per optimization
+            step
+        :param samples_per_datapoint: how many random samples to produce from the step
+            function per data point during training
+        :param optimizer: a factory for an optax SGD algorithm
+        :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
+            factory (e.g. which parameters are controllable externally)
+        :param initializer: how to initialize non-fluents
+        :param wrap_non_bool: whether to wrap non-boolean trainable parameters to satisfy
+            required ranges as specified in param_ranges (use a projected gradient otherwise)
+        :param use64bit: whether to perform arithmetic in 64 bit
+        :param bool_fluent_loss: loss function to optimize for bool-valued fluents
+        :param real_fluent_loss: loss function to optimize for real-valued fluents
+        :param int_fluent_loss: loss function to optimize for int-valued fluents
+        :param logic: a subclass of Logic for mapping exact mathematical
+            operations to their differentiable counterparts
+        :param model_params_reduction: how to aggregate updated model_params across runs
+            in the batch (defaults to selecting the first element's parameters in the batch)
+        '''
+        self.rddl = rddl
+        self.param_ranges = param_ranges.copy()
+        self.batch_size_train = batch_size_train
+        self.samples_per_datapoint = samples_per_datapoint
+        if optimizer_kwargs is None:
+            optimizer_kwargs = {'learning_rate': 0.001}
+        self.optimizer_kwargs = optimizer_kwargs
+        self.initializer = initializer
+        self.wrap_non_bool = wrap_non_bool
+        self.use64bit = use64bit
+        self.bool_fluent_loss = bool_fluent_loss
+        self.real_fluent_loss = real_fluent_loss
+        self.int_fluent_loss = int_fluent_loss
+        self.logic = logic
+        self.model_params_reduction = model_params_reduction
+
+        # validate param_ranges
+        for (name, values) in param_ranges.items():
+            if name not in rddl.non_fluents:
+                raise ValueError(
+                    f'param_ranges key <{name}> is not a valid non-fluent '
+                    f'in the current rddl.')
+            if not isinstance(values, (tuple, list)):
+                raise ValueError(
+                    f'param_ranges values with key <{name}> are neither a tuple nor a list.')
+            if len(values) != 2:
+                raise ValueError(
+                    f'param_ranges values with key <{name}> must be of length 2, '
+                    f'got length {len(values)}.')
+            lower, upper = values
+            if lower is not None and upper is not None and not np.all(lower <= upper):
+                raise ValueError(
+                    f'param_ranges values with key <{name}> do not satisfy lower <= upper.')
+
+        # build the optimizer
+        optimizer = optimizer(**optimizer_kwargs)
+        pipeline = [optimizer]
+        self.optimizer = optax.chain(*pipeline)
+
+        # build the computation graph
+        self.step_fn = self._jax_compile_rddl()
+        self.map_fn = self._jax_map()
+        self.loss_fn = self._jax_loss(map_fn=self.map_fn, step_fn=self.step_fn)
+        self.update_fn, self.project_fn = self._jax_update(loss_fn=self.loss_fn)
+        self.init_fn, self.init_opt_fn = self._jax_init(project_fn=self.project_fn)
+
+    # ===========================================================================
+    # COMPILATION SUBROUTINES
+    # ===========================================================================
+
+    def _jax_compile_rddl(self):
+
+        # compile the RDDL model
+        self.compiled = JaxRDDLCompilerWithGrad(
+            rddl=self.rddl,
+            logic=self.logic,
+            use64bit=self.use64bit,
+            compile_non_fluent_exact=False,
+            print_warnings=True
+        )
+        self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')
+
+        # compile the transition step function
+        step_fn = self.compiled.compile_transition()
+
+        def _jax_wrapped_step(key, param_fluents, subs, actions, hyperparams):
+            for (name, param) in param_fluents.items():
+                subs[name] = param
+            subs, _, hyperparams = step_fn(key, actions, subs, hyperparams)
+            return subs, hyperparams
+
+        # batched step function
+        def _jax_wrapped_batched_step(key, param_fluents, subs, actions, hyperparams):
+            keys = jnp.asarray(random.split(key, num=self.batch_size_train))
+            subs, hyperparams = jax.vmap(
+                _jax_wrapped_step, in_axes=(0, None, 0, 0, None)
+            )(keys, param_fluents, subs, actions, hyperparams)
+            hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
+            return subs, hyperparams
+
+        # batched step function with parallel samples per data point
+        def _jax_wrapped_batched_parallel_step(key, param_fluents, subs, actions, hyperparams):
+            keys = jnp.asarray(random.split(key, num=self.samples_per_datapoint))
+            subs, hyperparams = jax.vmap(
+                _jax_wrapped_batched_step, in_axes=(0, None, None, None, None)
+            )(keys, param_fluents, subs, actions, hyperparams)
+            hyperparams = jax.tree_util.tree_map(self.model_params_reduction, hyperparams)
+            return subs, hyperparams
+
+        batched_step_fn = jax.jit(_jax_wrapped_batched_parallel_step)
+        return batched_step_fn
+
+    def _jax_map(self):
+
+        # compute case indices for bounding
+        case_indices = {}
+        if self.wrap_non_bool:
+            for (name, (lower, upper)) in self.param_ranges.items():
+                if lower is None: lower = -np.inf
+                if upper is None: upper = +np.inf
+                self.param_ranges[name] = (lower, upper)
+                case_indices[name] = (
+                    0 * (np.isfinite(lower) & np.isfinite(upper)) +
+                    1 * (np.isfinite(lower) & ~np.isfinite(upper)) +
+                    2 * (~np.isfinite(lower) & np.isfinite(upper)) +
+                    3 * (~np.isfinite(lower) & ~np.isfinite(upper))
+                )
+
+        # map trainable parameters to their non-fluent values
+        def _jax_wrapped_params_to_fluents(params):
+            param_fluents = {}
+            for (name, param) in params.items():
+                if self.rddl.variable_ranges[name] == 'bool':
+                    param_fluents[name] = jax.nn.sigmoid(param)
+                else:
+                    if self.wrap_non_bool:
+                        lower, upper = self.param_ranges[name]
+                        cases = [
+                            lambda x: lower + (upper - lower) * jax.nn.sigmoid(x),
+                            lambda x: lower + (jax.nn.elu(x) + 1.0),
+                            lambda x: upper - (jax.nn.elu(-x) + 1.0),
+                            lambda x: x
+                        ]
+                        indices = case_indices[name]
+                        param_fluents[name] = jax.lax.switch(indices, cases, param)
+                    else:
+                        param_fluents[name] = param
+            return param_fluents
+
+        map_fn = jax.jit(_jax_wrapped_params_to_fluents)
+        return map_fn
+
+    def _jax_loss(self, map_fn, step_fn):
+
+        # use binary cross entropy for bool fluents
+        # mean squared error for continuous and integer fluents
+        def _jax_wrapped_batched_model_loss(key, param_fluents, subs, actions, next_fluents,
+                                            hyperparams):
+            next_subs, hyperparams = step_fn(key, param_fluents, subs, actions, hyperparams)
+            total_loss = 0.0
+            for (name, next_value) in next_fluents.items():
+                preds = jnp.asarray(next_subs[name], dtype=self.compiled.REAL)
+                targets = jnp.asarray(next_value, dtype=self.compiled.REAL)[jnp.newaxis, ...]
+                if self.rddl.variable_ranges[name] == 'bool':
+                    loss_values = self.bool_fluent_loss(targets, preds)
+                elif self.rddl.variable_ranges[name] == 'real':
+                    loss_values = self.real_fluent_loss(targets, preds)
+                else:
+                    loss_values = self.int_fluent_loss(targets, preds)
+                total_loss += jnp.mean(loss_values) / len(next_fluents)
+            return total_loss, hyperparams
+
+        # loss with the parameters mapped to their fluents
+        def _jax_wrapped_batched_loss(key, params, subs, actions, next_fluents, hyperparams):
+            param_fluents = map_fn(params)
+            loss, hyperparams = _jax_wrapped_batched_model_loss(
+                key, param_fluents, subs, actions, next_fluents, hyperparams)
+            return loss, hyperparams
+
+        loss_fn = jax.jit(_jax_wrapped_batched_loss)
+        return loss_fn
+
+    def _jax_init(self, project_fn):
+        optimizer = self.optimizer
+
+        # initialize both the non-fluents and optimizer
+        def _jax_wrapped_init_params_optimizer(key):
+            params = {}
+            for name in self.param_ranges:
+                shape = jnp.shape(self.compiled.init_values[name])
+                key, subkey = random.split(key)
+                params[name] = self.initializer(subkey, shape, dtype=self.compiled.REAL)
+            params = project_fn(params)
+            opt_state = optimizer.init(params)
+            return params, opt_state
+
+        # initialize just the optimizer given the non-fluents
+        def _jax_wrapped_init_optimizer(params):
+            params = project_fn(params)
+            opt_state = optimizer.init(params)
+            return params, opt_state
+
+        init_fn = jax.jit(_jax_wrapped_init_params_optimizer)
+        init_opt_fn = jax.jit(_jax_wrapped_init_optimizer)
+        return init_fn, init_opt_fn
+
+    def _jax_update(self, loss_fn):
+        optimizer = self.optimizer
+
+        # projected gradient trick to satisfy box constraints on params
+        def _jax_wrapped_project_params(params):
+            if self.wrap_non_bool:
+                return params
+            else:
+                new_params = {}
+                for (name, value) in params.items():
+                    if self.rddl.variable_ranges[name] == 'bool':
+                        new_params[name] = value
+                    else:
+                        lower, upper = self.param_ranges[name]
+                        new_params[name] = jnp.clip(value, lower, upper)
+                return new_params
+
+        # gradient descent update
+        def _jax_wrapped_params_update(key, params, subs, actions, next_fluents,
+                                       hyperparams, opt_state):
+            (loss_val, hyperparams), grad = jax.value_and_grad(
+                loss_fn, argnums=1, has_aux=True
+            )(key, params, subs, actions, next_fluents, hyperparams)
+            updates, opt_state = optimizer.update(grad, opt_state)
+            params = optax.apply_updates(params, updates)
+            params = _jax_wrapped_project_params(params)
+            zero_grads = jax.tree_util.tree_map(partial(jnp.allclose, b=0.0), grad)
+            return params, opt_state, loss_val, zero_grads, hyperparams
+
+        update_fn = jax.jit(_jax_wrapped_params_update)
+        project_fn = jax.jit(_jax_wrapped_project_params)
+        return update_fn, project_fn
+
+    def _batched_init_subs(self):
+        init_train = {}
+        for (name, value) in self.compiled.init_values.items():
+            value = np.reshape(value, np.shape(value))[np.newaxis, ...]
+            value = np.repeat(value, repeats=self.batch_size_train, axis=0)
+            value = np.asarray(value, dtype=self.compiled.REAL)
+            init_train[name] = value
+        for (state, next_state) in self.rddl.next_state.items():
+            init_train[next_state] = init_train[state]
+        return init_train
+
+    # ===========================================================================
+    # ESTIMATE API
+    # ===========================================================================
+
+    def optimize(self, *args, **kwargs) -> Optional[Callback]:
+        '''Estimate the unknown parameters from the given data set.
+        Return the callback from training.
+
+        :param data: a data stream represented as a (possibly infinite) sequence of
+            transition batches of the form (states, actions, next-states), where each element
+            is a numpy array of leading dimension equal to batch_size_train
+        :param key: JAX PRNG key (derived from clock if not provided)
+        :param epochs: the maximum number of steps of gradient descent
+        :param train_seconds: total time allocated for gradient descent
+        :param guess: initial non-fluent parameters: if None will use the initializer
+            specified in this instance
+        :param print_progress: whether to print the progress bar during training
+        '''
+        it = self.optimize_generator(*args, **kwargs)
+
+        # https://stackoverflow.com/questions/50937966/fastest-most-pythonic-way-to-consume-an-iterator
+        callback = None
+        if sys.implementation.name == 'cpython':
+            last_callback = deque(it, maxlen=1)
+            if last_callback:
+                callback = last_callback.pop()
+        else:
+            for callback in it:
+                pass
+        return callback
+
+    def optimize_generator(self, data: DataStream,
+                           key: Optional[random.PRNGKey]=None,
+                           epochs: int=999999,
+                           train_seconds: float=120.,
+                           guess: Optional[Params]=None,
+                           print_progress: bool=True) -> Generator[Callback, None, None]:
+        '''Return a generator for estimating the unknown parameters from the given data set.
+        Generator can be iterated over to lazily estimate the parameters, yielding
+        a dictionary of intermediate computations.
+
+        :param data: a data stream represented as a (possibly infinite) sequence of
+            transition batches of the form (states, actions, next-states), where each element
+            is a numpy array of leading dimension equal to batch_size_train
+        :param key: JAX PRNG key (derived from clock if not provided)
+        :param epochs: the maximum number of steps of gradient descent
+        :param train_seconds: total time allocated for gradient descent
+        :param guess: initial non-fluent parameters: if None will use the initializer
+            specified in this instance
+        :param print_progress: whether to print the progress bar during training
+        '''
+        start_time = time.time()
+        elapsed_outside_loop = 0
+
+        # if PRNG key is not provided
+        if key is None:
+            key = random.PRNGKey(round(time.time() * 1000))
+
+        # prepare initial subs
+        subs = self._batched_init_subs()
+
+        # initialize parameter fluents to optimize
+        if guess is None:
+            key, subkey = random.split(key)
+            params, opt_state = self.init_fn(subkey)
+        else:
+            params, opt_state = self.init_opt_fn(guess)
+
+        # initialize model hyper-parameters
+        hyperparams = self.compiled.model_params
+
+        # progress bar
+        if print_progress:
+            progress_bar = tqdm(
+                None, total=100, bar_format='{l_bar}{bar}| {elapsed} {postfix}')
+        else:
+            progress_bar = None
+
+        # main training loop
+        for (it, (states, actions, next_states)) in enumerate(data):
+            status = JaxLearnerStatus.NORMAL
+
+            # gradient update
+            subs.update(states)
+            key, subkey = random.split(key)
+            params, opt_state, loss, zero_grads, hyperparams = self.update_fn(
+                subkey, params, subs, actions, next_states, hyperparams, opt_state)
+
+            # extract non-fluent values from the trainable parameters
+            param_fluents = self.map_fn(params)
+            param_fluents = {name: param_fluents[name] for name in self.param_ranges}
+
+            # check for learnability
+            params_zero_grads = {
+                name for (name, zero_grad) in zero_grads.items() if zero_grad}
+            if params_zero_grads:
+                status = JaxLearnerStatus.NO_PROGRESS
+
+            # reached computation budget
+            elapsed = time.time() - start_time - elapsed_outside_loop
+            if elapsed >= train_seconds:
+                status = JaxLearnerStatus.TIME_BUDGET_REACHED
+            if it >= epochs - 1:
+                status = JaxLearnerStatus.ITER_BUDGET_REACHED
+
+            # build a callback
+            progress_percent = 100 * min(
+                1, max(0, elapsed / train_seconds, it / (epochs - 1)))
+            callback = {
+                'status': status,
+                'iteration': it,
+                'train_loss': loss,
+                'params': params,
+                'param_fluents': param_fluents,
+                'key': key,
+                'progress': progress_percent
+            }
+
+            # update progress
+            if print_progress:
+                progress_bar.set_description(
+                    f'{it:7} it / {loss:12.8f} train / {status.value} status', refresh=False)
+                progress_bar.set_postfix_str(
+                    f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False)
+                progress_bar.update(progress_percent - progress_bar.n)
+
+            # yield the callback
+            start_time_outside = time.time()
+            yield callback
+            elapsed_outside_loop += (time.time() - start_time_outside)
+
+            # abortion check
+            if status.is_terminal():
+                break
+
+    def evaluate_loss(self, data: DataStream,
+                      key: Optional[random.PRNGKey],
+                      param_fluents: Params) -> float:
+        '''Evaluates the model loss of the given learned non-fluent values and the data.
+
+        :param data: a data stream represented as a (possibly infinite) sequence of
+            transition batches of the form (states, actions, next-states), where each element
+            is a numpy array of leading dimension equal to batch_size_train
+        :param key: JAX PRNG key (derived from clock if not provided)
+        :param param_fluents: the learned non-fluent values
+        '''
+        if key is None:
+            key = random.PRNGKey(round(time.time() * 1000))
+        subs = self._batched_init_subs()
+        hyperparams = self.compiled.model_params
+        mean_loss = 0.0
+        for (it, (states, actions, next_states)) in enumerate(data):
+            subs.update(states)
+            key, subkey = random.split(key)
+            loss_value, _ = self.loss_fn(
+                subkey, param_fluents, subs, actions, next_states, hyperparams)
+            mean_loss += (loss_value - mean_loss) / (it + 1)
+        return mean_loss
+
+    def learned_model(self, param_fluents: Params) -> RDDLLiftedModel:
+        '''Substitutes the given learned non-fluent values into the RDDL model and returns
+        the new model.
+
+        :param param_fluents: the learned non-fluent values
+        '''
+        model = deepcopy(self.rddl)
+        for (name, values) in param_fluents.items():
+            prange = model.variable_ranges[name]
+            if prange == 'real':
+                pass
+            elif prange == 'bool':
+                values = values > 0.5
+            else:
+                values = np.asarray(values, dtype=self.compiled.INT)
+            values = np.ravel(values, order='C').tolist()
+            if not self.rddl.variable_params[name]:
+                assert(len(values) == 1)
+                values = values[0]
+            model.non_fluents[name] = values
+        return model
+
+
+if __name__ == '__main__':
+    import os
+    import pyRDDLGym
+    from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOfflineController
+    bs = 32
+
+    # make some data
+    def data_iterator():
+        env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
+        model = JaxModelLearner(rddl=env.model, param_ranges={}, batch_size_train=bs)
+        key = random.PRNGKey(round(time.time() * 1000))
+        subs = model._batched_init_subs()
+        param_fluents = {}
+        while True:
+            states = {
+                'pos': np.random.uniform(-2.4, 2.4, (bs,)),
+                'vel': np.random.uniform(-2.4, 2.4, (bs,)),
+                'ang-pos': np.random.uniform(-0.21, 0.21, (bs,)),
+                'ang-vel': np.random.uniform(-0.21, 0.21, (bs,))
+            }
+            subs.update(states)
+            actions = {
+                'force': np.random.uniform(-10., 10., (bs,))
+            }
+            key, subkey = random.split(key)
+            subs, _ = model.step_fn(subkey, param_fluents, subs, actions, {})
+            subs = {k: np.asarray(v)[0, ...] for k, v in subs.items()}
+            next_states = {k: subs[k] for k in model.rddl.state_fluents}
+            yield (states, actions, next_states)
+
+    # train it
+    env = pyRDDLGym.make('TestJax', '0', vectorized=True)
+    model_learner = JaxModelLearner(rddl=env.model,
+                                    param_ranges={
+                                        'w1': (None, None), 'b1': (None, None),
+                                        'w2': (None, None), 'b2': (None, None),
+                                        'w1o': (None, None), 'b1o': (None, None),
+                                        'w2o': (None, None), 'b2o': (None, None)
+                                    },
+                                    batch_size_train=bs,
+                                    optimizer_kwargs = {'learning_rate': 0.0003})
+    for cb in model_learner.optimize_generator(data_iterator(), epochs=10000):
+        pass
+
+    # planning in the trained model
+    model = model_learner.learned_model(cb['param_fluents'])
+    abs_path = os.path.dirname(os.path.abspath(__file__))
+    config_path = os.path.join(os.path.dirname(abs_path), 'examples', 'configs', 'default_drp.cfg')
+    planner_args, _, train_args = load_config(config_path)
+    planner = JaxBackpropPlanner(rddl=model, **planner_args)
+    controller = JaxOfflineController(planner, **train_args)
+
+    # evaluation of the plan
+    test_env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
+    controller.evaluate(test_env, episodes=1, verbose=True, render=True)
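The `__main__` block above drives the learner with an infinite generator that samples transitions on the fly. Below is a hedged sketch of the same API fed with a fixed, pre-collected batch of transitions instead; it is not part of the diff, and the domain id `MyDomain`, the fluent names and the non-fluent `GRAVITY` are hypothetical placeholders.

```python
import numpy as np
import pyRDDLGym
from pyRDDLGym_jax.core.model import JaxModelLearner

bs = 32
env = pyRDDLGym.make('MyDomain', '0', vectorized=True)   # hypothetical domain

# transitions collected from the real system (hypothetical names and shapes);
# every array has leading dimension bs, matching batch_size_train
states = {'x': np.random.rand(bs)}
actions = {'u': np.random.rand(bs)}
next_states = {'x': np.random.rand(bs)}

def replay(epochs):
    # replay the same batch for a fixed number of gradient steps
    for _ in range(epochs):
        yield (states, actions, next_states)

learner = JaxModelLearner(rddl=env.model,
                          param_ranges={'GRAVITY': (0.0, None)},  # lower-bounded non-fluent
                          batch_size_train=bs)
callback = learner.optimize(replay(1000), epochs=1000, train_seconds=60.)
learned_rddl = learner.learned_model(callback['param_fluents'])
```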
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -207,6 +207,13 @@ def _load_config(config, args):
         pgpe_kwargs['optimizer'] = pgpe_optimizer
         planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
 
+    # preprocessor settings
+    preproc_method = planner_args.get('preprocessor', None)
+    preproc_kwargs = planner_args.pop('preprocessor_kwargs', {})
+    if preproc_method is not None:
+        planner_args['preprocessor'] = getattr(
+            sys.modules[__name__], preproc_method)(**preproc_kwargs)
+
     # optimize call RNG key
     planner_key = train_args.get('key', None)
     if planner_key is not None:
@@ -343,6 +350,100 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
             return arg
 
 
+# ***********************************************************************
+# ALL VERSIONS OF STATE PREPROCESSING FOR DRP
+#
+# - static normalization
+#
+# ***********************************************************************
+
+
+class Preprocessor(metaclass=ABCMeta):
+    '''Base class for all state preprocessors.'''
+
+    HYPERPARAMS_KEY = 'preprocessor__'
+
+    def __init__(self) -> None:
+        self._initializer = None
+        self._update = None
+        self._transform = None
+
+    @property
+    def initialize(self):
+        return self._initializer
+
+    @property
+    def update(self):
+        return self._update
+
+    @property
+    def transform(self):
+        return self._transform
+
+    @abstractmethod
+    def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
+        pass
+
+
+class StaticNormalizer(Preprocessor):
+    '''Normalize values by box constraints on fluents computed from the RDDL domain.'''
+
+    def __init__(self, fluent_bounds: Dict[str, Tuple[np.ndarray, np.ndarray]]={}) -> None:
+        '''Create a new instance of the static normalizer.
+
+        :param fluent_bounds: optional bounds on fluents to overwrite default values.
+        '''
+        self.fluent_bounds = fluent_bounds
+
+    def compile(self, compiled: JaxRDDLCompilerWithGrad) -> None:
+
+        # adjust for partial observability
+        rddl = compiled.rddl
+        if rddl.observ_fluents:
+            observed_vars = rddl.observ_fluents
+        else:
+            observed_vars = rddl.state_fluents
+
+        # ignore boolean fluents and infinite bounds
+        bounded_vars = {}
+        for var in observed_vars:
+            if rddl.variable_ranges[var] != 'bool':
+                lower, upper = compiled.constraints.bounds[var]
+                if np.all(np.isfinite(lower) & np.isfinite(upper) & (lower < upper)):
+                    bounded_vars[var] = (lower, upper)
+                user_bounds = self.fluent_bounds.get(var, None)
+                if user_bounds is not None:
+                    bounded_vars[var] = tuple(user_bounds)
+
+        # initialize to ranges computed by the constraint parser
+        def _jax_wrapped_normalizer_init():
+            return bounded_vars
+        self._initializer = jax.jit(_jax_wrapped_normalizer_init)
+
+        # static bounds
+        def _jax_wrapped_normalizer_update(subs, stats):
+            stats = {var: (jnp.asarray(lower, dtype=compiled.REAL),
+                           jnp.asarray(upper, dtype=compiled.REAL))
+                     for (var, (lower, upper)) in bounded_vars.items()}
+            return stats
+        self._update = jax.jit(_jax_wrapped_normalizer_update)
+
+        # apply min max scaling
+        def _jax_wrapped_normalizer_transform(subs, stats):
+            new_subs = {}
+            for (var, values) in subs.items():
+                if var in stats:
+                    lower, upper = stats[var]
+                    new_dims = jnp.ndim(values) - jnp.ndim(lower)
+                    lower = lower[(jnp.newaxis,) * new_dims + (...,)]
+                    upper = upper[(jnp.newaxis,) * new_dims + (...,)]
+                    new_subs[var] = (values - lower) / (upper - lower)
+                else:
+                    new_subs[var] = values
+            return new_subs
+        self._transform = jax.jit(_jax_wrapped_normalizer_transform)
+
+
 # ***********************************************************************
 # ALL VERSIONS OF JAX PLANS
 #
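A hedged sketch of how the new preprocessor could be attached to the planner follows; it is not part of the diff, the planner's other constructor arguments are omitted, `env` is assumed to exist, and the fluent name `pos` with its bounds is a hypothetical override of the parser-derived ranges.

```python
import numpy as np
from pyRDDLGym_jax.core.planner import (
    JaxBackpropPlanner, JaxDeepReactivePolicy, StaticNormalizer)

# min-max scale DRP inputs to [0, 1]; user bounds overwrite the parsed ones
preprocessor = StaticNormalizer(
    fluent_bounds={'pos': (np.asarray(-2.4), np.asarray(2.4))})
planner = JaxBackpropPlanner(rddl=env.model,
                             plan=JaxDeepReactivePolicy(),
                             preprocessor=preprocessor)
```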
@@ -368,7 +469,8 @@ class JaxPlan(metaclass=ABCMeta):
     @abstractmethod
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
-                horizon: int
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
         pass
 
     @abstractmethod
@@ -519,7 +621,8 @@ class JaxStraightLinePlan(JaxPlan):
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
-                horizon: int
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
         rddl = compiled.rddl
 
         # calculate the correct action box bounds
@@ -607,7 +710,7 @@ class JaxStraightLinePlan(JaxPlan):
             return new_params, True
 
         # convert softmax action back to action dict
-        action_sizes = {var: np.prod(shape[1:], dtype=
+        action_sizes = {var: np.prod(shape[1:], dtype=np.int64)
                         for (var, shape) in shapes.items()
                         if ranges[var] == 'bool'}
 
@@ -691,7 +794,7 @@ class JaxStraightLinePlan(JaxPlan):
             scores = []
             for (var, param) in params.items():
                 if ranges[var] == 'bool':
-                    param_flat = jnp.ravel(param)
+                    param_flat = jnp.ravel(param, order='C')
                     if noop[var]:
                         if wrap_sigmoid:
                             param_flat = -param_flat
@@ -908,7 +1011,8 @@ class JaxDeepReactivePolicy(JaxPlan):
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
-                horizon: int
+                horizon: int,
+                preprocessor: Optional[Preprocessor]=None) -> None:
         rddl = compiled.rddl
 
         # calculate the correct action box bounds
@@ -939,7 +1043,7 @@ class JaxDeepReactivePolicy(JaxPlan):
         wrap_non_bool = self._wrap_non_bool
         init = self._initializer
         layers = list(enumerate(zip(self._topology, self._activations)))
-        layer_sizes = {var: np.prod(shape, dtype=
+        layer_sizes = {var: np.prod(shape, dtype=np.int64)
                        for (var, shape) in shapes.items()}
         layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}
 
@@ -973,7 +1077,12 @@ class JaxDeepReactivePolicy(JaxPlan):
             normalize = False
 
         # convert subs dictionary into a state vector to feed to the MLP
-        def _jax_wrapped_policy_input(subs):
+        def _jax_wrapped_policy_input(subs, hyperparams):
+
+            # optional state preprocessing
+            if preprocessor is not None:
+                stats = hyperparams[preprocessor.HYPERPARAMS_KEY]
+                subs = preprocessor.transform(subs, stats)
 
             # concatenate all state variables into a single vector
             # optionally apply layer norm to each input tensor
@@ -981,7 +1090,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             non_bool_dims = 0
             for (var, value) in subs.items():
                 if var in observed_vars:
-                    state = jnp.ravel(value)
+                    state = jnp.ravel(value, order='C')
                     if ranges[var] == 'bool':
                         states_bool.append(state)
                     else:
@@ -1010,8 +1119,8 @@ class JaxDeepReactivePolicy(JaxPlan):
             return state
 
         # predict actions from the policy network for current state
-        def _jax_wrapped_policy_network_predict(subs):
-            state = _jax_wrapped_policy_input(subs)
+        def _jax_wrapped_policy_network_predict(subs, hyperparams):
+            state = _jax_wrapped_policy_input(subs, hyperparams)
 
             # feed state vector through hidden layers
             hidden = state
@@ -1076,7 +1185,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
         # train action prediction
         def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
-            actions = predict_fn.apply(params, subs)
+            actions = predict_fn.apply(params, subs, hyperparams)
             if not wrap_non_bool:
                 for (var, action) in actions.items():
                     if var != bool_key and ranges[var] != 'bool':
@@ -1126,7 +1235,7 @@ class JaxDeepReactivePolicy(JaxPlan):
             subs = {var: value[0, ...]
                     for (var, value) in subs.items()
                     if var in observed_vars}
-            params = predict_fn.init(key, subs)
+            params = predict_fn.init(key, subs, hyperparams)
             return params
 
         self.initializer = _jax_wrapped_drp_init
@@ -1634,12 +1743,21 @@ def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
     return mu - 0.5 * beta * msv
 
 
+@jax.jit
+def sharpe_utility(returns: jnp.ndarray, risk_free: float) -> float:
+    return (jnp.mean(returns) - risk_free) / (jnp.std(returns) + 1e-10)
+
+
+@jax.jit
+def var_utility(returns: jnp.ndarray, alpha: float) -> float:
+    return jnp.percentile(returns, q=100 * alpha)
+
+
 @jax.jit
 def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
     var = jnp.percentile(returns, q=100 * alpha)
     mask = returns <= var
-
-    return jnp.sum(returns * weights)
+    return jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))
 
 
 # set of all currently valid built-in utility functions
@@ -1649,8 +1767,10 @@ UTILITY_LOOKUP = {
     'mean_std': mean_deviation_utility,
     'mean_semivar': mean_semivariance_utility,
     'mean_semidev': mean_semideviation_utility,
+    'sharpe': sharpe_utility,
     'entropic': entropic_utility,
     'exponential': entropic_utility,
+    'var': var_utility,
     'cvar': cvar_utility
 }
 
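Like the corrected `cvar_utility`, the new `sharpe` and `var` entries map a vector of sampled returns to a scalar objective and are selected through the same utility lookup keys. A standalone numeric illustration of what each computes (not taken from the package):

```python
import jax.numpy as jnp

returns = jnp.array([-2.0, 0.0, 1.0, 3.0])
alpha, risk_free = 0.25, 0.0

sharpe = (jnp.mean(returns) - risk_free) / (jnp.std(returns) + 1e-10)
var = jnp.percentile(returns, q=100 * alpha)        # the alpha-quantile return
mask = returns <= var
cvar = jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))  # mean of the worst tail
```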
@@ -1689,7 +1809,8 @@ class JaxBackpropPlanner:
                  logger: Optional[Logger]=None,
                  dashboard_viz: Optional[Any]=None,
                  print_warnings: bool=True,
-                 parallel_updates: Optional[int]=None
+                 parallel_updates: Optional[int]=None,
+                 preprocessor: Optional[Preprocessor]=None) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
         (plan) in the given RDDL. Some operations will be converted to their
         differentiable counterparts; the specific operations can be customized
@@ -1731,6 +1852,7 @@ class JaxBackpropPlanner:
             to pass to the dashboard to visualize the policy
         :param print_warnings: whether to print warnings
         :param parallel_updates: how many optimizers to run independently in parallel
+        :param preprocessor: optional preprocessor for state inputs to plan
         '''
         self.rddl = rddl
         self.plan = plan
@@ -1756,6 +1878,7 @@ class JaxBackpropPlanner:
         self.pgpe = pgpe
         self.use_pgpe = pgpe is not None
         self.print_warnings = print_warnings
+        self.preprocessor = preprocessor
 
         # set optimizer
         try:
@@ -1881,7 +2004,8 @@ r"""
             f'    noise_kwargs    ={self.noise_kwargs}\n'
             f'    batch_size_train={self.batch_size_train}\n'
             f'    batch_size_test ={self.batch_size_test}\n'
-            f'    parallel_updates={self.parallel_updates}\n'
+            f'    parallel_updates={self.parallel_updates}\n'
+            f'    preprocessor    ={self.preprocessor}\n')
         result += str(self.plan)
         if self.use_pgpe:
             result += str(self.pgpe)
@@ -1917,10 +2041,15 @@ r"""
 
     def _jax_compile_optimizer(self):
 
+        # preprocessor
+        if self.preprocessor is not None:
+            self.preprocessor.compile(self.compiled)
+
         # policy
         self.plan.compile(self.compiled,
                           _bounds=self._action_bounds,
-                          horizon=self.horizon
+                          horizon=self.horizon,
+                          preprocessor=self.preprocessor)
         self.train_policy = jax.jit(self.plan.train_policy)
         self.test_policy = jax.jit(self.plan.test_policy)
 
@@ -1928,14 +2057,16 @@ r"""
         train_rollouts = self.compiled.compile_rollouts(
             policy=self.plan.train_policy,
             n_steps=self.horizon,
-            n_batch=self.batch_size_train
+            n_batch=self.batch_size_train,
+            cache_path_info=self.preprocessor is not None
         )
         self.train_rollouts = train_rollouts
 
         test_rollouts = self.test_compiled.compile_rollouts(
             policy=self.plan.test_policy,
             n_steps=self.horizon,
-            n_batch=self.batch_size_test
+            n_batch=self.batch_size_test,
+            cache_path_info=False
         )
         self.test_rollouts = jax.jit(test_rollouts)
 
@@ -2397,7 +2528,13 @@ r"""
                           f'which could be suboptimal.', 'yellow')
                 print(message)
                 policy_hyperparams[action] = 1.0
-
+
+        # initialize preprocessor
+        preproc_key = None
+        if self.preprocessor is not None:
+            preproc_key = self.preprocessor.HYPERPARAMS_KEY
+            policy_hyperparams[preproc_key] = self.preprocessor.initialize()
+
         # print summary of parameters:
         if print_summary:
             print(self.summarize_system())
@@ -2524,6 +2661,11 @@ r"""
                 subkey, policy_params, policy_hyperparams, train_subs, model_params,
                 opt_state, opt_aux)
 
+            # update the preprocessor
+            if self.preprocessor is not None:
+                policy_hyperparams[preproc_key] = self.preprocessor.update(
+                    train_log['fluents'], policy_hyperparams[preproc_key])
+
             # evaluate
             test_loss, (test_log, model_params_test) = self.test_loss(
                 subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
@@ -2676,6 +2818,7 @@ r"""
                 'model_params': model_params,
                 'progress': progress_percent,
                 'train_log': train_log,
+                'policy_hyperparams': policy_hyperparams,
                 **test_log
             }
 
@@ -2753,7 +2896,8 @@ r"""
 
     def _perform_diagnosis(self, last_iter_improve,
                            train_return, test_return, best_return, grad_norm):
-
+        grad_norms = jax.tree_util.tree_leaves(grad_norm)
+        max_grad_norm = max(grad_norms) if grad_norms else np.nan
         grad_is_zero = np.allclose(max_grad_norm, 0)
 
         # divergence if the solution is not finite
@@ -2895,6 +3039,7 @@ class JaxOfflineController(BaseAgent):
         self.train_on_reset = train_on_reset
         self.train_kwargs = train_kwargs
         self.params_given = params is not None
+        self.hyperparams_given = eval_hyperparams is not None
 
         # load the policy from file
         if not self.train_on_reset and params is not None and isinstance(params, str):
@@ -2908,6 +3053,8 @@ class JaxOfflineController(BaseAgent):
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
             self.callback = callback
             params = callback['best_params']
+            if not self.hyperparams_given:
+                self.eval_hyperparams = callback['policy_hyperparams']
 
             # save the policy
             if save_path is not None:
@@ -2931,6 +3078,8 @@ class JaxOfflineController(BaseAgent):
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
             self.callback = callback
             self.params = callback['best_params']
+            if not self.hyperparams_given:
+                self.eval_hyperparams = callback['policy_hyperparams']
 
 
 class JaxOnlineController(BaseAgent):
@@ -2963,6 +3112,7 @@ class JaxOnlineController(BaseAgent):
             key = random.PRNGKey(round(time.time() * 1000))
         self.key = key
         self.eval_hyperparams = eval_hyperparams
+        self.hyperparams_given = eval_hyperparams is not None
         self.warm_start = warm_start
         self.train_kwargs = train_kwargs
         self.max_attempts = max_attempts
@@ -2987,6 +3137,8 @@ class JaxOnlineController(BaseAgent):
             key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
         self.callback = callback
         params = callback['best_params']
+        if not self.hyperparams_given:
+            self.eval_hyperparams = callback['policy_hyperparams']
 
         # get the action from the parameters for the current state
         self.key, subkey = random.split(self.key)
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyRDDLGym-jax
-Version: 2.5
+Version: 2.6
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: pyRDDLGym>=2.
+Requires-Dist: pyRDDLGym>=2.3
 Requires-Dist: tqdm>=4.66
 Requires-Dist: jax>=0.4.12
 Requires-Dist: optax>=0.1.9
@@ -55,7 +55,7 @@ Dynamic: summary
 
 [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
 
-**pyRDDLGym-jax (
+**pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.**
 
 Purpose:
 
@@ -84,7 +84,7 @@ and was moved to the individual logic components which have their own unique weights
 
 > [!NOTE]
 > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
-> If you find it is not making
+> If you find it is not making progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
 
 ## Installation
 
@@ -220,13 +220,7 @@ controller = JaxOfflineController(planner, **train_args)
 ## JaxPlan Dashboard
 
 Since version 1.0, JaxPlan has an optional dashboard that allows keeping track of the planner performance across multiple runs,
-and visualization of the policy or model, and other useful debugging features.
-
-<p align="middle">
-<img src="https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/Images/dashboard.png" width="480" height="248" margin=0/>
-</p>
-
-To run the dashboard, add the following entry to your config file:
+and visualization of the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:
 
 ```ini
 ...
@@ -235,8 +229,6 @@ dashboard=True
 ...
 ```
 
-More documentation about this and other new features will be coming soon.
-
 ## Tuning the Planner
 
 A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/RECORD
CHANGED
@@ -1,9 +1,10 @@
-pyRDDLGym_jax/__init__.py,sha256=
+pyRDDLGym_jax/__init__.py,sha256=VUmQViJtwUg1JGcgXlmNm0fE3Njyruyt_76c16R-LTo,19
 pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
 pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/compiler.py,sha256=
-pyRDDLGym_jax/core/logic.py,sha256=
-pyRDDLGym_jax/core/
+pyRDDLGym_jax/core/compiler.py,sha256=Bpgfw4nqRFqiTju7ioR0B0Dhp3wMvk-9LmTRpMmLIOc,83457
+pyRDDLGym_jax/core/logic.py,sha256=9rRpKJCx4Us_2c6BiSWRN9k2sM_iYsAK1B7zcgwu3ZA,56290
+pyRDDLGym_jax/core/model.py,sha256=4WfmtUVN1EKCD-7eWeQByWk8_zKyDcMABAMdlxN1LOU,27215
+pyRDDLGym_jax/core/planner.py,sha256=a684ss5TAkJ-P2SEbZA90FSpDwFxHwRoaLtbRIBspAA,146450
 pyRDDLGym_jax/core/simulator.py,sha256=ayCATTUL3clLaZPQ5OUg2bI_c26KKCTq6TbrxbMsVdc,10470
 pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
 pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
@@ -41,9 +42,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
 pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
 pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
 pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
-pyrddlgym_jax-2.
+pyrddlgym_jax-2.6.dist-info/licenses/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+pyrddlgym_jax-2.6.dist-info/METADATA,sha256=1gY3EPRHKMVeZYYgq4DCqWvw3Q1Ak5XVYRaIO2UlQXc,16770
+pyrddlgym_jax-2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+pyrddlgym_jax-2.6.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+pyrddlgym_jax-2.6.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyrddlgym_jax-2.6.dist-info/RECORD,,
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/WHEEL
File without changes
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/entry_points.txt
File without changes
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/licenses/LICENSE
File without changes
{pyrddlgym_jax-2.5.dist-info → pyrddlgym_jax-2.6.dist-info}/top_level.txt
File without changes