pyRDDLGym-jax 2.4-py3-none-any.whl → 2.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +8 -4
- pyRDDLGym_jax/core/planner.py +144 -78
- pyRDDLGym_jax/core/simulator.py +37 -13
- pyRDDLGym_jax/core/tuning.py +25 -10
- pyRDDLGym_jax/entry_point.py +39 -7
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +1 -0
- pyRDDLGym_jax/examples/run_plan.py +1 -1
- pyRDDLGym_jax/examples/run_tune.py +8 -2
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info}/METADATA +13 -18
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info}/RECORD +17 -17
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info/licenses}/LICENSE +0 -0
- {pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
@@ -1 +1 @@
-__version__ = '2.4'
+__version__ = '2.5'
pyRDDLGym_jax/core/compiler.py
CHANGED
@@ -430,7 +430,7 @@ class JaxRDDLCompiler:
                 _jax_wrapped_single_step_policy,
                 in_axes=(0, None, None, None, 0, None)
             )(keys, policy_params, hyperparams, step, subs, model_params)
-            model_params = jax.tree_map(partial(jnp.mean, axis=0), model_params)
+            model_params = jax.tree_util.tree_map(partial(jnp.mean, axis=0), model_params)
             carry = (key, policy_params, hyperparams, subs, model_params)
             return carry, log

@@ -440,7 +440,7 @@ class JaxRDDLCompiler:
         start = (key, policy_params, hyperparams, subs, model_params)
         steps = jnp.arange(n_steps)
         end, log = jax.lax.scan(_jax_wrapped_batched_step_policy, start, steps)
-        log = jax.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
+        log = jax.tree_util.tree_map(partial(jnp.swapaxes, axis1=0, axis2=1), log)
         model_params = end[-1]
         return log, model_params

@@ -707,7 +707,10 @@ class JaxRDDLCompiler:
         sample = jnp.asarray(value, dtype=self._fix_dtype(value))
         new_slices = [None] * len(jax_nested_expr)
         for (i, jax_expr) in enumerate(jax_nested_expr):
-            new_slices[i], key, err, params = jax_expr(x, params, key)
+            new_slice, key, err, params = jax_expr(x, params, key)
+            if not jnp.issubdtype(jnp.result_type(new_slice), jnp.integer):
+                new_slice = jnp.asarray(new_slice, dtype=self.INT)
+            new_slices[i] = new_slice
             error |= err
         new_slices = tuple(new_slices)
         sample = sample[new_slices]

@@ -986,7 +989,8 @@ class JaxRDDLCompiler:
         sample_cases = [None] * len(jax_cases)
         for (i, jax_case) in enumerate(jax_cases):
             sample_cases[i], key, err_case, params = jax_case(x, params, key)
-            err |= err_case
+            err |= err_case
+        sample_cases = jnp.asarray(sample_cases)
         sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))

         # predicate (enum) is an integer - use it to extract from case array
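A recurring change across this release is the rename from `jax.tree_map` to `jax.tree_util.tree_map`; the old alias was deprecated and later removed from JAX. The sketch below is not package code, it just illustrates the rename the hunks above apply:

```python
# Minimal sketch (not from the package) of the pytree-mapping rename applied
# throughout this release: jax.tree_map was deprecated and later removed
# from JAX, so all call sites switch to jax.tree_util.tree_map.
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}

# old spelling (removed in recent JAX): jax.tree_map(jnp.mean, params)
means = jax.tree_util.tree_map(jnp.mean, params)  # {'b': 0.0, 'w': 1.0}
```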
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -39,6 +39,7 @@ import configparser
 from enum import Enum
 from functools import partial
 import os
+import pickle
 import sys
 import time
 import traceback

@@ -229,13 +230,19 @@ def _load_config(config, args):


 def load_config(path: str) -> Tuple[Kwargs, ...]:
-    '''Loads a config file at the specified file path.'''
+    '''Loads a config file at the specified file path.
+
+    :param path: the path of the config file to load and parse
+    '''
     config, args = _parse_config_file(path)
     return _load_config(config, args)


 def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
-    '''Loads config file contents specified explicitly as a string value.'''
+    '''Loads config file contents specified explicitly as a string value.
+
+    :param value: the string in json format containing the config contents to parse
+    '''
     config, args = _parse_config_string(value)
     return _load_config(config, args)

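A brief usage sketch for the loaders documented above (the path is a placeholder; the three-dict unpacking mirrors `run_tune.py` later in this diff, which does `planner_args, _, train_args = load_config_from_string(...)`):

```python
# Usage sketch for load_config / load_config_from_string; the config path is
# hypothetical. Both return a tuple of keyword-argument dicts (planner, plan,
# and training settings) per the run scripts further down in this diff.
from pyRDDLGym_jax.core.planner import load_config, load_config_from_string

planner_args, plan_args, train_args = load_config('path/to/config.cfg')

with open('path/to/config.cfg') as file:
    planner_args, plan_args, train_args = load_config_from_string(file.read())
```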
@@ -258,6 +265,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
     def __init__(self, *args,
                  logic: Logic=FuzzyLogic(),
                  cpfs_without_grad: Optional[Set[str]]=None,
+                 print_warnings: bool=True,
                  **kwargs) -> None:
         '''Creates a new RDDL to Jax compiler, where operations that are not
         differentiable are converted to approximate forms that have defined gradients.

@@ -268,6 +276,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
             to customize these operations
         :param cpfs_without_grad: which CPFs do not have gradients (use straight
             through gradient trick)
+        :param print_warnings: whether to print warnings
         :param *kwargs: keyword arguments to pass to base compiler
         '''
         super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)

@@ -277,6 +286,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
         if cpfs_without_grad is None:
             cpfs_without_grad = set()
         self.cpfs_without_grad = cpfs_without_grad
+        self.print_warnings = print_warnings

         # actions and CPFs must be continuous
         pvars_cast = set()

@@ -284,7 +294,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
             self.init_values[var] = np.asarray(values, dtype=self.REAL)
             if not np.issubdtype(np.result_type(values), np.floating):
                 pvars_cast.add(var)
-        if pvars_cast:
+        if self.print_warnings and pvars_cast:
             message = termcolor.colored(
                 f'[INFO] JAX gradient compiler will cast p-vars {pvars_cast} to float.',
                 'green')

@@ -314,12 +324,12 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
             if cpf in self.cpfs_without_grad:
                 jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])

-        if cpfs_cast:
+        if self.print_warnings and cpfs_cast:
             message = termcolor.colored(
                 f'[INFO] JAX gradient compiler will cast CPFs {cpfs_cast} to float.',
                 'green')
             print(message)
-        if self.cpfs_without_grad:
+        if self.print_warnings and self.cpfs_without_grad:
             message = termcolor.colored(
                 f'[INFO] Gradients will not flow through CPFs {self.cpfs_without_grad}.',
                 'green')

@@ -436,10 +446,11 @@ class JaxPlan(metaclass=ABCMeta):
                 ~lower_finite & upper_finite,
                 ~lower_finite & ~upper_finite]
             bounds[name] = (lower, upper)
-            message = termcolor.colored(
-                f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
-                'green')
-            print(message)
+            if compiled.print_warnings:
+                message = termcolor.colored(
+                    f'[INFO] Bounds of action-fluent <{name}> set to {bounds[name]}.',
+                    'green')
+                print(message)
         return shapes, bounds, bounds_safe, cond_lists

     def _count_bool_actions(self, rddl: RDDLLiftedModel):

@@ -519,7 +530,7 @@ class JaxStraightLinePlan(JaxPlan):
         # action concurrency check
         bool_action_count, allowed_actions = self._count_bool_actions(rddl)
         use_constraint_satisfaction = allowed_actions < bool_action_count
-        if use_constraint_satisfaction:
+        if compiled.print_warnings and use_constraint_satisfaction:
             message = termcolor.colored(
                 f'[INFO] SLP will use projected gradient to satisfy '
                 f'max_nondef_actions since total boolean actions '

@@ -605,7 +616,7 @@ class JaxStraightLinePlan(JaxPlan):
                 start = 0
                 for (name, size) in action_sizes.items():
                     action = output[..., start:start + size]
-                    action = jnp.reshape(action,
+                    action = jnp.reshape(action, shapes[name][1:])
                     if noop[name]:
                         action = 1.0 - action
                     actions[name] = action

@@ -838,7 +849,7 @@ class JaxStraightLinePlan(JaxPlan):

     def guess_next_epoch(self, params: Pytree) -> Pytree:
         next_fn = JaxStraightLinePlan._guess_next_epoch
-        return jax.tree_map(next_fn, params)
+        return jax.tree_util.tree_map(next_fn, params)


 class JaxDeepReactivePolicy(JaxPlan):

@@ -946,17 +957,19 @@ class JaxDeepReactivePolicy(JaxPlan):
             if ranges[var] != 'bool':
                 value_size = np.size(values)
                 if normalize_per_layer and value_size == 1:
-                    message = termcolor.colored(
-                        f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
-                        f'of size 1: setting normalize_per_layer = False.', 'yellow')
-                    print(message)
+                    if compiled.print_warnings:
+                        message = termcolor.colored(
+                            f'[WARN] Cannot apply layer norm to state-fluent <{var}> '
+                            f'of size 1: setting normalize_per_layer = False.', 'yellow')
+                        print(message)
                     normalize_per_layer = False
                 non_bool_dims += value_size
         if not normalize_per_layer and non_bool_dims == 1:
-            message = termcolor.colored(
-                '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
-                'setting normalize = False.', 'yellow')
-            print(message)
+            if compiled.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] Cannot apply layer norm to state-fluents of total size 1: '
+                    'setting normalize = False.', 'yellow')
+                print(message)
             normalize = False

         # convert subs dictionary into a state vector to feed to the MLP

@@ -1054,7 +1067,7 @@ class JaxDeepReactivePolicy(JaxPlan):
         for (name, size) in layer_sizes.items():
             if ranges[name] == 'bool':
                 action = output[..., start:start + size]
-                action = jnp.reshape(action,
+                action = jnp.reshape(action, shapes[name])
                 if noop[name]:
                     action = 1.0 - action
                 actions[name] = action

@@ -1226,6 +1239,7 @@ class PGPE(metaclass=ABCMeta):

     @abstractmethod
     def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+                print_warnings: bool,
                 parallel_updates: Optional[int]=None) -> None:
         pass

@@ -1322,6 +1336,7 @@ class GaussianPGPE(PGPE):
         )

     def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type,
+                print_warnings: bool,
                 parallel_updates: Optional[int]=None) -> None:
         sigma0 = self.init_sigma
         sigma_lo, sigma_hi = self.sigma_range

@@ -1347,7 +1362,7 @@ class GaussianPGPE(PGPE):

         def _jax_wrapped_pgpe_init(key, policy_params):
             mu = policy_params
-            sigma = jax.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
+            sigma = jax.tree_util.tree_map(partial(jnp.full_like, fill_value=sigma0), mu)
             pgpe_params = (mu, sigma)
             pgpe_opt_state = (mu_optimizer.init(mu), sigma_optimizer.init(sigma))
             r_max = -jnp.inf

@@ -1395,13 +1410,14 @@ class GaussianPGPE(PGPE):
             treedef = jax.tree_util.tree_structure(sigma)
             keys = random.split(key, num=treedef.num_leaves)
             keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
-            epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
-            p1 = jax.tree_map(jnp.add, mu, epsilon)
-            p2 = jax.tree_map(jnp.subtract, mu, epsilon)
+            epsilon = jax.tree_util.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
+            p1 = jax.tree_util.tree_map(jnp.add, mu, epsilon)
+            p2 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon)
             if super_symmetric:
-                epsilon_star = jax.tree_map(
-                    _jax_wrapped_epsilon_star, sigma, epsilon)
+                epsilon_star = jax.tree_util.tree_map(
+                    _jax_wrapped_epsilon_star, sigma, epsilon)
+                p3 = jax.tree_util.tree_map(jnp.add, mu, epsilon_star)
+                p4 = jax.tree_util.tree_map(jnp.subtract, mu, epsilon_star)
             else:
                 epsilon_star, p3, p4 = epsilon, p1, p2
             return p1, p2, p3, p4, epsilon, epsilon_star

@@ -1469,11 +1485,11 @@ class GaussianPGPE(PGPE):
                     r_max = jnp.maximum(r_max, r4)
                 else:
                     r3, r4 = r1, r2
-                grad_mu = jax.tree_map(
+                grad_mu = jax.tree_util.tree_map(
                     partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
                     epsilon, epsilon_star
                 )
-                grad_sigma = jax.tree_map(
+                grad_sigma = jax.tree_util.tree_map(
                     partial(_jax_wrapped_sigma_grad,
                             r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
                     epsilon, epsilon_star, sigma

@@ -1492,7 +1508,7 @@ class GaussianPGPE(PGPE):
                 _jax_wrapped_pgpe_grad,
                 in_axes=(0, None, None, None, None, None, None, None)
             )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
-            mu_grad, sigma_grad = jax.tree_map(
+            mu_grad, sigma_grad = jax.tree_util.tree_map(
                 partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
             new_r_max = jnp.max(r_maxs)
             return mu_grad, sigma_grad, new_r_max

@@ -1516,7 +1532,7 @@ class GaussianPGPE(PGPE):
                 sigma_grad, sigma_state, params=sigma)
             new_mu = optax.apply_updates(mu, mu_updates)
             new_sigma = optax.apply_updates(sigma, sigma_updates)
-            new_sigma = jax.tree_map(
+            new_sigma = jax.tree_util.tree_map(
                 partial(jnp.clip, min=sigma_lo, max=sigma_hi), new_sigma)
             return new_mu, new_sigma, new_mu_state, new_sigma_state

@@ -1537,7 +1553,7 @@ class GaussianPGPE(PGPE):
             if max_kl is not None:
                 old_mu_lr = new_mu_state.hyperparams['learning_rate']
                 old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
-                kl_terms = jax.tree_map(
+                kl_terms = jax.tree_util.tree_map(
                     _jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
                 total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
                 kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))

@@ -1672,6 +1688,7 @@ class JaxBackpropPlanner:
                 compile_non_fluent_exact: bool=True,
                 logger: Optional[Logger]=None,
                 dashboard_viz: Optional[Any]=None,
+                print_warnings: bool=True,
                 parallel_updates: Optional[int]=None) -> None:
         '''Creates a new gradient-based algorithm for optimizing action sequences
         (plan) in the given RDDL. Some operations will be converted to their

@@ -1712,6 +1729,7 @@ class JaxBackpropPlanner:
         :param logger: to log information about compilation to file
         :param dashboard_viz: optional visualizer object from the environment
             to pass to the dashboard to visualize the policy
+        :param print_warnings: whether to print warnings
         :param parallel_updates: how many optimizers to run independently in parallel
         '''
         self.rddl = rddl

@@ -1737,6 +1755,7 @@ class JaxBackpropPlanner:
         self.noise_kwargs = noise_kwargs
         self.pgpe = pgpe
         self.use_pgpe = pgpe is not None
+        self.print_warnings = print_warnings

         # set optimizer
         try:
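The new `print_warnings` flag is stored on the planner and, as later hunks show, forwarded to the relaxed compiler and the PGPE optimizer. A hedged construction sketch (domain and instance names are placeholders; the default `True` preserves the 2.4 console output):

```python
# Sketch only: silencing the [INFO]/[WARN] console output now gated behind
# print_warnings. The domain/instance names here are placeholders.
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxStraightLinePlan

env = pyRDDLGym.make('CartPole_Continuous_gym', '0', vectorized=True)
planner = JaxBackpropPlanner(rddl=env.model,
                             plan=JaxStraightLinePlan(),
                             print_warnings=False)  # default True keeps 2.4 behavior
```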
@@ -1789,7 +1808,11 @@ class JaxBackpropPlanner:
         self._jax_compile_rddl()
         self._jax_compile_optimizer()

-    def summarize_system(self) -> str:
+    @staticmethod
+    def summarize_system() -> str:
+        '''Returns a string containing information about the system, Python version
+        and jax-related packages that are relevant to the current planner.
+        '''
         try:
             jaxlib_version = jax._src.lib.version_str
         except Exception as _:

@@ -1818,6 +1841,9 @@ r"""
             f'devices: {devices_short}\n')

     def summarize_relaxations(self) -> str:
+        '''Returns a summary table containing all non-differentiable operators
+        and their relaxations.
+        '''
         result = ''
         if self.compiled.model_params:
             result += ('Some RDDL operations are non-differentiable '

@@ -1834,6 +1860,9 @@ r"""
         return result

     def summarize_hyperparameters(self) -> str:
+        '''Returns a string summarizing the hyper-parameters of the current planner
+        instance.
+        '''
         result = (f'objective hyper-parameters:\n'
                   f'    utility_fn   ={self.utility.__name__}\n'
                   f'    utility args ={self.utility_kwargs}\n'

@@ -1873,7 +1902,8 @@ r"""
             logger=self.logger,
             use64bit=self.use64bit,
             cpfs_without_grad=self.cpfs_without_grad,
-            compile_non_fluent_exact=self.compile_non_fluent_exact
+            compile_non_fluent_exact=self.compile_non_fluent_exact,
+            print_warnings=self.print_warnings
         )
         self.compiled.compile(log_jax_expr=True, heading='RELAXED MODEL')

@@ -1922,7 +1952,8 @@ r"""

         # optimization
         self.update = self._jax_update(train_loss)
-        self.pytree_at = jax.jit(
+        self.pytree_at = jax.jit(
+            lambda tree, i: jax.tree_util.tree_map(lambda x: x[i], tree))

         # pgpe option
         if self.use_pgpe:

@@ -1930,6 +1961,7 @@ r"""
                 loss_fn=test_loss,
                 projection=self.plan.projection,
                 real_dtype=self.test_compiled.REAL,
+                print_warnings=self.print_warnings,
                 parallel_updates=self.parallel_updates
             )
             self.merge_pgpe = self._jax_merge_pgpe_jaxplan()

@@ -2010,7 +2042,7 @@ r"""
         # check if the gradients are all zeros
         def _jax_wrapped_zero_gradients(grad):
             leaves, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(partial(jnp.allclose, b=0), grad))
+                jax.tree_util.tree_map(partial(jnp.allclose, b=0), grad))
             return jnp.all(jnp.asarray(leaves))

         # calculate the plan gradient w.r.t. return loss and update optimizer

@@ -2069,7 +2101,7 @@ r"""
         def select_fn(leaf1, leaf2):
             expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(leaf1) - 1)]
             return jnp.where(expanded_mask, leaf1, leaf2)
-        policy_params = jax.tree_map(select_fn, pgpe_param, policy_params)
+        policy_params = jax.tree_util.tree_map(select_fn, pgpe_param, policy_params)
         test_loss = jnp.where(pgpe_mask, pgpe_loss, test_loss)
         test_loss_smooth = jnp.where(pgpe_mask, pgpe_loss_smooth, test_loss_smooth)
         expanded_mask = pgpe_mask[(...,) + (jnp.newaxis,) * (jnp.ndim(converged) - 1)]

@@ -2091,7 +2123,9 @@ r"""
                     f'Variable <{name}> in subs argument is not a '
                     f'valid p-variable, must be one of '
                     f'{set(self.test_compiled.init_values.keys())}.')
-            value = np.reshape(value,
+            value = np.reshape(value, np.shape(init_value))[np.newaxis, ...]
+            if value.dtype.type is np.str_:
+                value = rddl.object_string_to_index_array(rddl.variable_ranges[name], value)
             train_value = np.repeat(value, repeats=n_train, axis=0)
             train_value = np.asarray(train_value, dtype=self.compiled.REAL)
             init_train[name] = train_value

@@ -2121,7 +2155,7 @@ r"""
                 x[np.newaxis, ...], shape=(self.parallel_updates,) + np.shape(x))
             return x

-        return jax.tree_map(make_batched, pytree)
+        return jax.tree_util.tree_map(make_batched, pytree)

     def as_optimization_problem(
             self, key: Optional[random.PRNGKey]=None,

@@ -2165,10 +2199,11 @@ r"""
         train_subs, _ = self._batched_init_subs(subs)
         model_params = self.compiled.model_params
         if policy_hyperparams is None:
-            message = termcolor.colored(
-                '[WARN] policy_hyperparams is not set, setting 1.0 for '
-                'all action-fluents which could be suboptimal.', 'yellow')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] policy_hyperparams is not set, setting 1.0 for '
+                    'all action-fluents which could be suboptimal.', 'yellow')
+                print(message)
             policy_hyperparams = {action: 1.0
                                   for action in self.rddl.action_fluents}
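The hunk above makes the missing-`policy_hyperparams` warning suppressible; supplying the sigmoid weights explicitly avoids the 1.0 fallback altogether. A sketch, assuming `planner` is a compiled `JaxBackpropPlanner` and the action-fluent name is hypothetical:

```python
# Sketch: pass explicit sigmoid weights for boolean action-fluents so the
# [WARN] fallback above never fires. 'planner' is assumed to be a
# JaxBackpropPlanner; the action name 'advance' is hypothetical.
import jax.random as random

key = random.PRNGKey(42)
callback = planner.optimize(key=key, policy_hyperparams={'advance': 5.0})
best_params = callback['best_params']  # the same key the controllers read below
```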
@@ -2318,10 +2353,11 @@ r"""

         # cannot run dashboard with parallel updates
         if dashboard is not None and self.parallel_updates is not None:
-            message = termcolor.colored(
-                '[WARN] Dashboard is unavailable if parallel_updates is not None: '
-                'setting dashboard to None.', 'yellow')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] Dashboard is unavailable if parallel_updates is not None: '
+                    'setting dashboard to None.', 'yellow')
+                print(message)
             dashboard = None

         # if PRNG key is not provided

@@ -2331,19 +2367,21 @@ r"""

         # if policy_hyperparams is not provided
         if policy_hyperparams is None:
-            message = termcolor.colored(
-                '[WARN] policy_hyperparams is not set, setting 1.0 for '
-                'all action-fluents which could be suboptimal.', 'yellow')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    '[WARN] policy_hyperparams is not set, setting 1.0 for '
+                    'all action-fluents which could be suboptimal.', 'yellow')
+                print(message)
             policy_hyperparams = {action: 1.0
                                   for action in self.rddl.action_fluents}

         # if policy_hyperparams is a scalar
         elif isinstance(policy_hyperparams, (int, float, np.number)):
-            message = termcolor.colored(
-                f'[INFO] policy_hyperparams is {policy_hyperparams}, '
-                f'setting this value for all action-fluents.', 'green')
-            print(message)
+            if self.print_warnings:
+                message = termcolor.colored(
+                    f'[INFO] policy_hyperparams is {policy_hyperparams}, '
+                    f'setting this value for all action-fluents.', 'green')
+                print(message)
             hyperparam_value = float(policy_hyperparams)
             policy_hyperparams = {action: hyperparam_value
                                   for action in self.rddl.action_fluents}

@@ -2352,11 +2390,12 @@ r"""
         elif isinstance(policy_hyperparams, dict):
             for action in self.rddl.action_fluents:
                 if action not in policy_hyperparams:
-                    message = termcolor.colored(
-                        f'[WARN] policy_hyperparams[{action}] is not set, '
-                        f'setting 1.0 for missing action-fluents '
-                        f'which could be suboptimal.', 'yellow')
-                    print(message)
+                    if self.print_warnings:
+                        message = termcolor.colored(
+                            f'[WARN] policy_hyperparams[{action}] is not set, '
+                            f'setting 1.0 for missing action-fluents '
+                            f'which could be suboptimal.', 'yellow')
+                        print(message)
                     policy_hyperparams[action] = 1.0

         # print summary of parameters:

@@ -2396,7 +2435,7 @@ r"""
             if var not in subs:
                 subs[var] = value
                 added_pvars_to_subs.append(var)
-        if added_pvars_to_subs:
+        if self.print_warnings and added_pvars_to_subs:
             message = termcolor.colored(
                 f'[INFO] p-variables {added_pvars_to_subs} is not in '
                 f'provided subs, using their initial values.', 'green')

@@ -2648,7 +2687,7 @@ r"""
                 policy_params, opt_state, opt_aux = self.initialize(
                     subkey, policy_hyperparams, train_subs)
                 no_progress_count = 0
-                if progress_bar is not None:
+                if self.print_warnings and progress_bar is not None:
                     message = termcolor.colored(
                         f'[INFO] Optimizer restarted at iteration {it} '
                         f'due to lack of progress.', 'green')

@@ -2658,7 +2697,7 @@ r"""

             # stopping condition reached
             if stopping_rule is not None and stopping_rule.monitor(callback):
-                if progress_bar is not None:
+                if self.print_warnings and progress_bar is not None:
                     message = termcolor.colored(
                         '[SUCC] Stopping rule has been reached.', 'green')
                     progress_bar.write(message)

@@ -2699,7 +2738,8 @@ r"""

             # summarize and test for convergence
             if print_summary:
-                grad_norm = jax.tree_map(
+                grad_norm = jax.tree_util.tree_map(
+                    lambda x: np.linalg.norm(x).item(), best_grad)
                 diagnosis = self._perform_diagnosis(
                     last_iter_improve, -np.min(train_loss), -np.min(test_loss_smooth),
                     -best_loss, grad_norm)

@@ -2777,6 +2817,7 @@ r"""
         :param policy_hyperparams: hyper-parameters for the policy/plan, such as
             weights for sigmoid wrapping boolean actions (optional)
         '''
+        subs = subs.copy()

         # check compatibility of the subs dictionary
         for (var, values) in subs.items():

@@ -2795,13 +2836,17 @@ r"""
                 if step == 0 and var in self.rddl.observ_fluents:
                     subs[var] = self.test_compiled.init_values[var]
                 else:
-                    raise ValueError(
-                        f'Values {values} assigned to p-variable <{var}> are '
-                        f'non-numeric of type {dtype}.')
+                    if dtype.type is np.str_:
+                        prange = self.rddl.variable_ranges[var]
+                        subs[var] = self.rddl.object_string_to_index_array(prange, subs[var])
+                    else:
+                        raise ValueError(
+                            f'Values {values} assigned to p-variable <{var}> are '
+                            f'non-numeric of type {dtype}.')

         # cast device arrays to numpy
         actions = self.test_policy(key, params, policy_hyperparams, step, subs)
-        actions = jax.tree_map(np.asarray, actions)
+        actions = jax.tree_util.tree_map(np.asarray, actions)
         return actions


@@ -2822,8 +2867,9 @@ class JaxOfflineController(BaseAgent):
     def __init__(self, planner: JaxBackpropPlanner,
                  key: Optional[random.PRNGKey]=None,
                  eval_hyperparams: Optional[Dict[str, Any]]=None,
-                 params: Optional[Pytree]=None,
+                 params: Optional[Union[str, Pytree]]=None,
                  train_on_reset: bool=False,
+                 save_path: Optional[str]=None,
                  **train_kwargs) -> None:
         '''Creates a new JAX offline control policy that is trained once, then
         deployed later.

@@ -2834,8 +2880,10 @@ class JaxOfflineController(BaseAgent):
         :param eval_hyperparams: policy hyperparameters to apply for evaluation
             or whenever sample_action is called
         :param params: use the specified policy parameters instead of calling
-            planner.optimize()
+            planner.optimize(); can be a string pointing to a valid file path where params
+            have been saved, or a pytree of parameters
         :param train_on_reset: retrain policy parameters on every episode reset
+        :param save_path: optional path to save parameters to
         :param **train_kwargs: any keyword arguments to be passed to the planner
             for optimization
         '''

@@ -2848,12 +2896,24 @@ class JaxOfflineController(BaseAgent):
         self.train_kwargs = train_kwargs
         self.params_given = params is not None

+        # load the policy from file
+        if not self.train_on_reset and params is not None and isinstance(params, str):
+            with open(params, 'rb') as file:
+                params = pickle.load(file)
+
+        # train the policy
         self.step = 0
         self.callback = None
         if not self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
             self.callback = callback
             params = callback['best_params']
+
+        # save the policy
+        if save_path is not None:
+            with open(save_path, 'wb') as file:
+                pickle.dump(params, file)
+
         self.params = params

     def sample_action(self, state: Dict[str, Any]) -> Dict[str, Any]:
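The hunk above is the new parameter round trip: `save_path` pickles the trained parameters, and a string `params` argument is unpickled instead of retraining. A sketch (the `.pkl` path is a placeholder; `planner` is assumed from the earlier sketches):

```python
# Sketch of the new save/load round trip added in this hunk; the file path
# is a placeholder. The first construction trains via planner.optimize() and
# pickles the result; the second skips training by loading that file.
from pyRDDLGym_jax.core.planner import JaxOfflineController

agent = JaxOfflineController(planner, save_path='slp_params.pkl')  # trains, then saves

agent = JaxOfflineController(planner, params='slp_params.pkl')     # loads, no training
```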
@@ -2865,6 +2925,8 @@ class JaxOfflineController(BaseAgent):

     def reset(self) -> None:
         self.step = 0
+
+        # train the policy if required to reset at the start of every episode
         if self.train_on_reset and not self.params_given:
             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
             self.callback = callback

@@ -2915,18 +2977,22 @@ class JaxOnlineController(BaseAgent):
         attempts = 0
         while attempts < self.max_attempts and callback['iteration'] <= 1:
             attempts += 1
-            message = termcolor.colored(
-                f'[WARN] JIT compilation dominated the execution time: '
-                f'executing the optimizer again on the traced model '
-                f'[attempt {attempts}].', 'yellow')
-            print(message)
+            if self.planner.print_warnings:
+                message = termcolor.colored(
+                    f'[WARN] JIT compilation dominated the execution time: '
+                    f'executing the optimizer again on the traced model '
+                    f'[attempt {attempts}].', 'yellow')
+                print(message)
             callback = planner.optimize(
-                key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
+                key=self.key, guess=self.guess, subs=state, **self.train_kwargs)
         self.callback = callback
         params = callback['best_params']
+
+        # get the action from the parameters for the current state
         self.key, subkey = random.split(self.key)
         actions = planner.get_action(subkey, params, 0, state, self.eval_hyperparams)
+
+        # apply warm start for the next epoch
         if self.warm_start:
             self.guess = planner.plan.guess_next_epoch(params)
         return actions
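The last hunk re-runs the optimizer when the first call was dominated by JIT compilation, now warning only when the planner's `print_warnings` is set. A hedged usage sketch of the online controller (`planner` and `env` are assumed from the earlier sketches, and the `warm_start` keyword is inferred from the attributes visible in this hunk):

```python
# Sketch: the online controller replans at every decision step; on the first
# step it may re-run planner.optimize() per the retry loop above. 'planner'
# and 'env' are assumed; evaluate() comes from pyRDDLGym's BaseAgent.
from pyRDDLGym_jax.core.planner import JaxOnlineController

agent = JaxOnlineController(planner, warm_start=True)
stats = agent.evaluate(env, episodes=1)
```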
pyRDDLGym_jax/core/simulator.py
CHANGED
@@ -19,10 +19,12 @@


 import time
-from typing import Dict, Optional
+import numpy as np
+from typing import Dict, Optional, Union

 import jax

+from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
 from pyRDDLGym.core.compiler.model import RDDLLiftedModel
 from pyRDDLGym.core.debug.exception import (
     RDDLActionPreconditionNotSatisfiedError,

@@ -35,7 +37,7 @@ from pyRDDLGym.core.simulator import RDDLSimulator

 from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler

-Args = Dict[str, Value]
+Args = Dict[str, Union[np.ndarray, Value]]


 class JaxRDDLSimulator(RDDLSimulator):

@@ -45,6 +47,7 @@ class JaxRDDLSimulator(RDDLSimulator):
                  raise_error: bool=True,
                  logger: Optional[Logger]=None,
                  keep_tensors: bool=False,
+                 objects_as_strings: bool=True,
                  **compiler_args) -> None:
         '''Creates a new simulator for the given RDDL model with Jax as a backend.

@@ -57,6 +60,8 @@ class JaxRDDLSimulator(RDDLSimulator):
         :param logger: to log information about compilation to file
         :param keep_tensors: whether the sampler takes actions and
             returns state in numpy array form
+        :param objects_as_strings: whether to return object values as strings (defaults
+            to integer indices if False)
         :param **compiler_args: keyword arguments to pass to the Jax compiler
         '''
         if key is None:

@@ -67,7 +72,8 @@ class JaxRDDLSimulator(RDDLSimulator):

         # generate direct sampling with default numpy RNG and operations
         super(JaxRDDLSimulator, self).__init__(
-            rddl, logger=logger,
+            rddl, logger=logger,
+            keep_tensors=keep_tensors, objects_as_strings=objects_as_strings)

     def seed(self, seed: int) -> None:
         super(JaxRDDLSimulator, self).seed(seed)

@@ -84,11 +90,11 @@ class JaxRDDLSimulator(RDDLSimulator):
         self.levels = compiled.levels
         self.traced = compiled.traced

-        self.invariants = jax.tree_map(jax.jit, compiled.invariants)
-        self.preconds = jax.tree_map(jax.jit, compiled.preconditions)
-        self.terminals = jax.tree_map(jax.jit, compiled.terminations)
+        self.invariants = jax.tree_util.tree_map(jax.jit, compiled.invariants)
+        self.preconds = jax.tree_util.tree_map(jax.jit, compiled.preconditions)
+        self.terminals = jax.tree_util.tree_map(jax.jit, compiled.terminations)
         self.reward = jax.jit(compiled.reward)
-        jax_cpfs = jax.tree_map(jax.jit, compiled.cpfs)
+        jax_cpfs = jax.tree_util.tree_map(jax.jit, compiled.cpfs)
         self.model_params = compiled.model_params

         # level analysis

@@ -139,7 +145,6 @@ class JaxRDDLSimulator(RDDLSimulator):

     def check_action_preconditions(self, actions: Args, silent: bool=False) -> bool:
         '''Throws an exception if the action preconditions are not satisfied.'''
-        actions = self._process_actions(actions)
         subs = self.subs
         subs.update(actions)

@@ -180,7 +185,6 @@ class JaxRDDLSimulator(RDDLSimulator):
         '''
         rddl = self.rddl
         keep_tensors = self.keep_tensors
-        actions = self._process_actions(actions)
         subs = self.subs
         subs.update(actions)

@@ -196,20 +200,40 @@ class JaxRDDLSimulator(RDDLSimulator):
         # update state
         self.state = {}
         for (state, next_state) in rddl.next_state.items():
+
+            # set state = state' for the next epoch
             subs[state] = subs[next_state]
+
+            # convert object integer to string representation
+            state_values = subs[state]
+            if self.objects_as_strings:
+                ptype = rddl.variable_ranges[state]
+                if ptype not in RDDLValueInitializer.NUMPY_TYPES:
+                    state_values = rddl.index_to_object_string_array(ptype, state_values)
+
+            # optional grounding of state dictionary
             if keep_tensors:
-                self.state[state] = subs[state]
+                self.state[state] = state_values
             else:
-                self.state.update(rddl.ground_var_with_values(state, subs[state]))
+                self.state.update(rddl.ground_var_with_values(state, state_values))

         # update observation
         if self._pomdp:
             obs = {}
             for var in rddl.observ_fluents:
+
+                # convert object integer to string representation
+                obs_values = subs[var]
+                if self.objects_as_strings:
+                    ptype = rddl.variable_ranges[var]
+                    if ptype not in RDDLValueInitializer.NUMPY_TYPES:
+                        obs_values = rddl.index_to_object_string_array(ptype, obs_values)
+
+                # optional grounding of observ-fluent dictionary
                 if keep_tensors:
-                    obs[var] = subs[var]
+                    obs[var] = obs_values
                 else:
-                    obs.update(rddl.ground_var_with_values(var, subs[var]))
+                    obs.update(rddl.ground_var_with_values(var, obs_values))
         else:
             obs = self.state
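The simulator changes above add `objects_as_strings` and route object-valued fluents through `index_to_object_string_array`. A sketch, assuming `rddl` is a lifted model (e.g. `env.model`):

```python
# Sketch of the new simulator flag; 'rddl' is assumed to be a lifted RDDL
# model. With objects_as_strings=False, object-valued state/observation
# fluents are returned as integer indices; the default True maps them back
# to object-name strings per the hunk above.
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

sim = JaxRDDLSimulator(rddl, keep_tensors=True, objects_as_strings=False)
obs = sim.reset()  # object-valued fluents arrive as index arrays here
```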
pyRDDLGym_jax/core/tuning.py
CHANGED
@@ -371,16 +371,30 @@ class JaxParameterTuning:
         '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
         print(f'Kernel: {repr(optimizer._gp.kernel_)}.')

-    def tune(self, key: int,
+    def tune(self, key: int,
+             log_file: Optional[str]=None,
+             show_dashboard: bool=False,
+             print_hyperparams: bool=False) -> ParameterValues:
+        '''Tunes the hyper-parameters for Jax planner, returns the best found.

+        :param key: RNG key to seed the hyper-parameter optimizer
+        :param log_file: optional path to file where tuning progress will be saved
+        :param show_dashboard: whether to display tuning results in a dashboard
+        :param print_hyperparams: whether to print a hyper-parameter summary of the
+            optimizer
+        '''

+        if self.verbose:
+            print(JaxBackpropPlanner.summarize_system())
+        if print_hyperparams:
+            print(self.summarize_hyperparameters())

+        # clear and prepare output file
+        if log_file is not None:
+            with open(log_file, 'w', newline='') as file:
+                writer = csv.writer(file)
+                writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
+
         # create a dash-board for visualizing experiment runs
         if show_dashboard and JaxPlannerDashboard is not None:
             dashboard = JaxPlannerDashboard()

@@ -519,9 +533,10 @@ class JaxParameterTuning:
         self.tune_optimizer(optimizer)

         # write results of all processes in current iteration to file
-        with open(log_file, 'a', newline='') as file:
-            writer = csv.writer(file)
-            writer.writerows(rows)
+        if log_file is not None:
+            with open(log_file, 'a', newline='') as file:
+                writer = csv.writer(file)
+                writer.writerows(rows)

         # update the dashboard tuning
         if show_dashboard:
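The reworked `tune()` now owns the CSV bookkeeping: the header row is (re)written up front and one row per worker is appended each iteration. A usage sketch, assuming `tuning` is a configured `JaxParameterTuning`:

```python
# Usage sketch of the new tune() signature from the hunks above; 'tuning' is
# assumed to be a configured JaxParameterTuning instance.
best = tuning.tune(key=42,
                   log_file='gp_log.csv',   # optional; None disables CSV logging
                   show_dashboard=False,
                   print_hyperparams=True)
print(tuning.best_config)  # config string that run_tune.py re-trains from
```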
pyRDDLGym_jax/entry_point.py
CHANGED
@@ -2,24 +2,56 @@ import argparse

 from pyRDDLGym_jax.examples import run_plan, run_tune

+EPILOG = 'For complete documentation, see https://pyrddlgym.readthedocs.io/en/latest/jax.html.'
+
 def main():
-    parser = argparse.ArgumentParser(
+    parser = argparse.ArgumentParser(prog='jaxplan',
+        description="command line parser for the jaxplan planner",
+        epilog=EPILOG)
     subparsers = parser.add_subparsers(dest="jaxplan", required=True)

     # planning
-    parser_plan = subparsers.add_parser("plan",
+    parser_plan = subparsers.add_parser("plan",
+        help="execute jaxplan on a specified RDDL problem",
+        epilog=EPILOG)
+    parser_plan.add_argument('domain', type=str,
+        help='name of domain in rddlrepository or a valid file path')
+    parser_plan.add_argument('instance', type=str,
+        help='name of instance in rddlrepository or a valid file path')
+    parser_plan.add_argument('method', type=str,
+        help='training method to apply: [slp, drp] are offline methods, and [replan] are online')
+    parser_plan.add_argument('-e', '--episodes', type=int, required=False, default=1,
+        help='number of training or evaluation episodes')

     # tuning
-    parser_tune = subparsers.add_parser("tune",
+    parser_tune = subparsers.add_parser("tune",
+        help="tune jaxplan on a specified RDDL problem",
+        epilog=EPILOG)
+    parser_tune.add_argument('domain', type=str,
+        help='name of domain in rddlrepository or a valid file path')
+    parser_tune.add_argument('instance', type=str,
+        help='name of instance in rddlrepository or a valid file path')
+    parser_tune.add_argument('method', type=str,
+        help='training method to apply: [slp, drp] are offline methods, and [replan] are online')
+    parser_tune.add_argument('-t', '--trials', type=int, required=False, default=5,
+        help='number of evaluation rollouts per hyper-parameter choice')
+    parser_tune.add_argument('-i', '--iters', type=int, required=False, default=20,
+        help='number of iterations of bayesian optimization')
+    parser_tune.add_argument('-w', '--workers', type=int, required=False, default=4,
+        help='number of parallel hyper-parameters to evaluate per iteration')
+    parser_tune.add_argument('-d', '--dashboard', type=bool, required=False, default=False,
+        help='show the dashboard')
+    parser_tune.add_argument('-f', '--filepath', type=str, required=False, default='',
+        help='where to save the config file of the best hyper-parameters')

     # dispatch
     args = parser.parse_args()
     if args.jaxplan == "plan":
-        run_plan.
+        run_plan.main(args.domain, args.instance, args.method, args.episodes)
     elif args.jaxplan == "tune":
-        run_tune.
+        run_tune.main(args.domain, args.instance, args.method,
+                      args.trials, args.iters, args.workers, args.dashboard,
+                      args.filepath)
     else:
         parser.print_help()
pyRDDLGym_jax/examples/run_plan.py
CHANGED
@@ -26,7 +26,7 @@ from pyRDDLGym_jax.core.planner import (
 )


-def main(domain, instance, method, episodes=1):
+def main(domain: str, instance: str, method: str, episodes: int=1) -> None:

     # set up the environment
     env = pyRDDLGym.make(domain, instance, vectorized=True)
pyRDDLGym_jax/examples/run_tune.py
CHANGED
@@ -36,7 +36,9 @@ def power_10(x):
     return 10.0 ** x


-def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=False):
+def main(domain: str, instance: str, method: str,
+         trials: int=5, iters: int=20, workers: int=4, dashboard: bool=False,
+         filepath: str='') -> None:

     # set up the environment
     env = pyRDDLGym.make(domain, instance, vectorized=True)

@@ -68,6 +70,9 @@ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=False):
     tuning.tune(key=42,
                 log_file=f'gp_{method}_{domain}_{instance}.csv',
                 show_dashboard=dashboard)
+    if filepath is not None and filepath:
+        with open(filepath, "w") as file:
+            file.write(tuning.best_config)

     # evaluate the agent on the best parameters
     planner_args, _, train_args = load_config_from_string(tuning.best_config)

@@ -80,7 +85,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4, dashboard=False):

 def run_from_args(args):
     if len(args) < 3:
-        print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>]')
+        print('python run_tune.py <domain> <instance> <method> [<trials>] [<iters>] [<workers>] [<dashboard>] [<filepath>]')
         exit(1)
     if args[2] not in ['drp', 'slp', 'replan']:
         print('<method> in [drp, slp, replan]')

@@ -90,6 +95,7 @@ def run_from_args(args):
     if len(args) >= 5: kwargs['iters'] = int(args[4])
     if len(args) >= 6: kwargs['workers'] = int(args[5])
     if len(args) >= 7: kwargs['dashboard'] = bool(args[6])
+    if len(args) >= 8: kwargs['filepath'] = bool(args[7])
     main(**kwargs)
{pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: pyRDDLGym-jax
-Version: 2.4
+Version: 2.5
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner

@@ -39,6 +39,7 @@ Dynamic: description
 Dynamic: description-content-type
 Dynamic: home-page
 Dynamic: license
+Dynamic: license-file
 Dynamic: provides-extra
 Dynamic: requires-dist
 Dynamic: requires-python

@@ -116,7 +117,7 @@ pip install pyRDDLGym-jax[extra,dashboard]
 A basic run script is provided to train JaxPlan on any RDDL problem:

 ```shell
-jaxplan plan <domain> <instance> <method> <episodes>
+jaxplan plan <domain> <instance> <method> --episodes <episodes>
 ```

 where:

@@ -241,7 +242,7 @@ More documentation about this and other new features will be coming soon.
 A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:

 ```shell
-jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
 ```

 where:

@@ -251,7 +252,8 @@ where:
 - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
 - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
 - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
-- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``dashboard`` is whether the optimizations are tracked in the dashboard application
+- ``filepath`` is the optional file path where a config file with the best hyper-parameter setting will be saved.

@@ -291,23 +293,16 @@ env = pyRDDLGym.make(domain, instance, vectorized=True)
 with open('path/to/config.cfg', 'r') as file:
     config_template = file.read()

-#
+# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
 def power_10(x):
-    return 10.0 ** x
-
-hyperparams = [
-    Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),  # tune weight from 10^-1 ... 10^5
-    Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10),  # tune lr from 10^-5 ... 10^1
-]
+    return 10.0 ** x
+hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
+               Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]

 # build the tuner and tune
 tuning = JaxParameterTuning(env=env,
-                            config_template=config_template,
-                            online=False,
-                            eval_trials=trials,
-                            num_workers=workers,
-                            gp_iters=iters)
+                            config_template=config_template, hyperparams=hyperparams,
+                            online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
 tuning.tune(key=42, log_file='path/to/log.csv')
 ```
{pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info}/RECORD
CHANGED
@@ -1,20 +1,20 @@
-pyRDDLGym_jax/__init__.py,sha256=
-pyRDDLGym_jax/entry_point.py,sha256=
+pyRDDLGym_jax/__init__.py,sha256=VoxLo_sy8RlJIIyu7szqL-cdMGBJdQPg-aSeyOVVIkY,19
+pyRDDLGym_jax/entry_point.py,sha256=K0zy1oe66jfBHkHHCM6aGHbbiVqnQvDhDb8se4uaKHE,3319
 pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/compiler.py,sha256=
+pyRDDLGym_jax/core/compiler.py,sha256=uFCtoipsIa3MM9nGgT3X8iCViPl2XSPNXh0jMdzN0ko,82895
 pyRDDLGym_jax/core/logic.py,sha256=lfc2ak_ap_ajMEFlB5EHCRNgJym31dNyA-5d-7N4CZA,56271
-pyRDDLGym_jax/core/planner.py,sha256=
-pyRDDLGym_jax/core/simulator.py,sha256=
-pyRDDLGym_jax/core/tuning.py,sha256=
+pyRDDLGym_jax/core/planner.py,sha256=M6GKzN7Ml57B4ZrFZhhkpsQCvReKaCQNzer7zeHCM9E,140275
+pyRDDLGym_jax/core/simulator.py,sha256=ayCATTUL3clLaZPQ5OUg2bI_c26KKCTq6TbrxbMsVdc,10470
+pyRDDLGym_jax/core/tuning.py,sha256=BWcQZk02TMLexTz1Sw4lX2EQKvmPbp7biC51M-IiNUw,25153
 pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
 pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
 pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
 pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
-pyRDDLGym_jax/examples/run_plan.py,sha256=
+pyRDDLGym_jax/examples/run_plan.py,sha256=4y7JHqTxY5O1ltP6N7rar0jMiw7u9w1nuAIOcmDaAuE,2806
 pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
-pyRDDLGym_jax/examples/run_tune.py,sha256=
+pyRDDLGym_jax/examples/run_tune.py,sha256=F5KWgtoCPbf7XHB6HW9LjxarD57U2LvuGdTz67OL1DY,4114
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=nFFYHCKQUMn8x-OpJwu2pwe1tycNSJ8iAIwSkCBn33E,370
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=eJ3HvHjODoKdtX7u-AM51xQaHJnYgzEy2t3omNG2oCs,340

@@ -38,12 +38,12 @@ pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5
 pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=XeMWAAG_OFZo7JAMxS5-XXroZaeVMzfM0NswmEobIns,373
 pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=CK4cEz8ReXyAZPLaLG9clIIRXAqM3IplUCxbLt_V2lY,407
 pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qGIJDIw73XCe6pyIPtg,369
-pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=
-pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=
-pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=
-pyrddlgym_jax-2.4.dist-info/LICENSE,sha256=
-pyrddlgym_jax-2.4.dist-info/METADATA,sha256=
-pyrddlgym_jax-2.4.dist-info/WHEEL,sha256=
-pyrddlgym_jax-2.4.dist-info/entry_points.txt,sha256=
-pyrddlgym_jax-2.4.dist-info/top_level.txt,sha256=
-pyrddlgym_jax-2.4.dist-info/RECORD,,
+pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=zocZn_cVarH5i0hOlt2Zu0NwmXYBmTTghLaXLtQOGto,526
+pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=9oIhtw9cuikmlbDgCgbrTc5G7hUio-HeAv_3CEGVclY,523
+pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=QqnyR__5-HhKeCDfGDel8VIlqsjxRHk4SSH089zJP8s,486
+pyrddlgym_jax-2.5.dist-info/licenses/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+pyrddlgym_jax-2.5.dist-info/METADATA,sha256=XAaEJfbsYW-txxZhFZ6o_HmvqxkIMTqBF9LbV-KdTzI,17058
+pyrddlgym_jax-2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+pyrddlgym_jax-2.5.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+pyrddlgym_jax-2.5.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyrddlgym_jax-2.5.dist-info/RECORD,,
{pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info}/entry_points.txt
File without changes
{pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info/licenses}/LICENSE
File without changes
{pyrddlgym_jax-2.4.dist-info → pyrddlgym_jax-2.5.dist-info}/top_level.txt
File without changes