pyRDDLGym-jax 1.3-py3-none-any.whl → 2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +16 -1
- pyRDDLGym_jax/core/logic.py +36 -9
- pyRDDLGym_jax/core/planner.py +445 -90
- pyRDDLGym_jax/core/simulator.py +20 -0
- pyRDDLGym_jax/core/tuning.py +15 -0
- pyRDDLGym_jax/core/visualization.py +48 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +1 -0
- {pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/METADATA +1 -1
- {pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/RECORD +23 -23
- {pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/WHEEL +0 -0
- {pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/entry_points.txt +0 -0
- {pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
@@ -1 +1 @@
-__version__ = '
+__version__ = '2.0'
pyRDDLGym_jax/core/compiler.py
CHANGED
@@ -1,3 +1,18 @@
+# ***********************************************************************
+# JAXPLAN
+#
+# Author: Michael Gimelfarb
+#
+# REFERENCES:
+#
+# [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
+# Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
+# Probabilistic Domains." Proceedings of the International Conference on Automated
+# Planning and Scheduling. Vol. 34. 2024.
+#
+# ***********************************************************************
+
+
 from functools import partial
 import traceback
 from typing import Any, Callable, Dict, List, Optional

@@ -524,7 +539,7 @@ class JaxRDDLCompiler:
 _jax_wrapped_single_step_policy,
 in_axes=(0, None, None, None, 0, None)
 )(keys, policy_params, hyperparams, step, subs, model_params)
-model_params = jax.tree_map(
+model_params = jax.tree_map(partial(jnp.mean, axis=0), model_params)
 carry = (key, policy_params, hyperparams, subs, model_params)
 return carry, log
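As a brief aside (not part of the package diff), the replacement line above uses the common JAX pattern of mapping functools.partial(jnp.mean, axis=0) over a pytree, which averages the leading batch axis of every leaf; a small illustration:

    # illustration only: average the batch axis of each leaf in a pytree
    from functools import partial
    import jax
    import jax.numpy as jnp

    params = {'w': jnp.ones((4, 3)), 'b': jnp.zeros((4,))}
    averaged = jax.tree_map(partial(jnp.mean, axis=0), params)
    # averaged['w'].shape == (3,) and averaged['b'].shape == ()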
pyRDDLGym_jax/core/logic.py
CHANGED
@@ -1,4 +1,31 @@
-
+# ***********************************************************************
+# JAXPLAN
+#
+# Author: Michael Gimelfarb
+#
+# REFERENCES:
+#
+# [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
+# Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
+# Probabilistic Domains." Proceedings of the International Conference on Automated
+# Planning and Scheduling. Vol. 34. 2024.
+#
+# [2] Petersen, Felix, Christian Borgelt, Hilde Kuehne, and Oliver Deussen. "Learning with
+# algorithmic supervision via continuous relaxations." Advances in Neural Information
+# Processing Systems 34 (2021): 16520-16531.
+#
+# [3] Agustsson, Eirikur, and Lucas Theis. "Universally quantized neural compression."
+# Advances in neural information processing systems 33 (2020): 12367-12376.
+#
+# [4] Gupta, Madan M., and J. Qi. "Theory of T-norms and fuzzy inference
+# methods." Fuzzy sets and systems 40, no. 3 (1991): 431-450.
+#
+# [5] Jang, Eric, Shixiang Gu, and Ben Poole. "Categorical Reparametrization with
+# Gumble-Softmax." In International Conference on Learning Representations (ICLR 2017).
+# OpenReview. net, 2017.
+#
+# ***********************************************************************
+
 
 import jax
 import jax.numpy as jnp

@@ -759,14 +786,14 @@ class FuzzyLogic(Logic):
 
 def __str__(self) -> str:
 return (f'model relaxation:\n'
-f' tnorm
-f' complement
-f' comparison
-f' sampling
-f' rounding
-f' control
-f' underflow_tol
-f' use_64_bit
+f' tnorm ={str(self.tnorm)}\n'
+f' complement ={str(self.complement)}\n'
+f' comparison ={str(self.comparison)}\n'
+f' sampling ={str(self.sampling)}\n'
+f' rounding ={str(self.rounding)}\n'
+f' control ={str(self.control)}\n'
+f' underflow_tol={self.eps}\n'
+f' use_64_bit ={self.use64bit}\n')
 
 def summarize_hyperparameters(self) -> None:
 print(self.__str__())
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -1,12 +1,43 @@
+# ***********************************************************************
+# JAXPLAN
+#
+# Author: Michael Gimelfarb
+#
+# RELEVANT SOURCES:
+#
+# [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
+# Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
+# Probabilistic Domains." Proceedings of the International Conference on Automated
+# Planning and Scheduling. Vol. 34. 2024.
+#
+# [2] Patton, Noah, Jihwan Jeong, Mike Gimelfarb, and Scott Sanner. "A Distributional
+# Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs." In Proceedings of
+# the AAAI Conference on Artificial Intelligence, vol. 36, no. 9, pp. 9894-9901. 2022.
+#
+# [3] Bueno, Thiago P., Leliane N. de Barros, Denis D. Mauá, and Scott Sanner. "Deep
+# reactive policies for planning in stochastic nonlinear domains." In Proceedings of the
+# AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 7530-7537. 2019.
+#
+# [4] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
+# nonlinear domains." Advances in Neural Information Processing Systems 30 (2017).
+#
+# [5] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
+# policy gradients: Super symmetric pgpe." Artificial Neural Networks: Methods and
+# Applications in Bio-/Neuroinformatics. Springer International Publishing, 2015.
+#
+# ***********************************************************************
+
+
 from ast import literal_eval
 from collections import deque
 import configparser
 from enum import Enum
+from functools import partial
 import os
 import sys
 import time
 import traceback
-from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, Union
 
 import haiku as hk
 import jax
@@ -163,7 +194,20 @@ def _load_config(config, args):
 del planner_args['optimizer']
 else:
 planner_args['optimizer'] = optimizer
-
+
+# pgpe optimizer
+pgpe_method = planner_args.get('pgpe', 'GaussianPGPE')
+pgpe_kwargs = planner_args.pop('pgpe_kwargs', {})
+if pgpe_method is not None:
+if 'optimizer' in pgpe_kwargs:
+pgpe_optimizer = _getattr_any(packages=[optax], item=pgpe_kwargs['optimizer'])
+if pgpe_optimizer is None:
+raise_warning(f'Ignoring invalid optimizer <{pgpe_optimizer}>.', 'red')
+del pgpe_kwargs['optimizer']
+else:
+pgpe_kwargs['optimizer'] = pgpe_optimizer
+planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
+
 # optimize call RNG key
 planner_key = train_args.get('key', None)
 if planner_key is not None:
@@ -469,16 +513,16 @@ class JaxStraightLinePlan(JaxPlan):
 bounds = '\n '.join(
 map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
 return (f'policy hyper-parameters:\n'
-f' initializer
-f'constraint-sat strategy (simple):\n'
-f'
-f'
-f'
-f'
-f'constraint-sat strategy (complex):\n'
-f'
-f'
-f'
+f' initializer={self._initializer_base}\n'
+f' constraint-sat strategy (simple):\n'
+f' parsed_action_bounds =\n {bounds}\n'
+f' wrap_sigmoid ={self._wrap_sigmoid}\n'
+f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
+f' wrap_non_bool ={self._wrap_non_bool}\n'
+f' constraint-sat strategy (complex):\n'
+f' wrap_softmax ={self._wrap_softmax}\n'
+f' use_new_projection ={self._use_new_projection}\n'
+f' max_projection_iters={self._max_constraint_iter}\n')
 
 def compile(self, compiled: JaxRDDLCompilerWithGrad,
 _bounds: Bounds,
@@ -856,15 +900,16 @@ class JaxDeepReactivePolicy(JaxPlan):
 bounds = '\n '.join(
 map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
 return (f'policy hyper-parameters:\n'
-f' topology
-f' activation_fn
-f' initializer
-f'
-f'
-f'
-f'
-f'
-f'
+f' topology ={self._topology}\n'
+f' activation_fn={self._activations[0].__name__}\n'
+f' initializer ={type(self._initializer_base).__name__}\n'
+f' input norm:\n'
+f' apply_input_norm ={self._normalize}\n'
+f' input_norm_layerwise={self._normalize_per_layer}\n'
+f' input_norm_args ={self._normalizer_kwargs}\n'
+f' constraint-sat strategy:\n'
+f' parsed_action_bounds=\n {bounds}\n'
+f' wrap_non_bool ={self._wrap_non_bool}\n')
 
 def compile(self, compiled: JaxRDDLCompilerWithGrad,
 _bounds: Bounds,
@@ -1090,10 +1135,11 @@ class JaxDeepReactivePolicy(JaxPlan):
 
 
 # ***********************************************************************
-#
+# SUPPORTING FUNCTIONS
 #
-# -
-# -
+# - smoothed mean calculation
+# - planner status
+# - stopping criteria
 #
 # ***********************************************************************
 
@@ -1167,6 +1213,264 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
 return f'No improvement for {self.patience} iterations'
 
 
+# ***********************************************************************
+# PARAMETER EXPLORING POLICY GRADIENTS (PGPE)
+#
+# - simple Gaussian PGPE
+#
+# ***********************************************************************
+
+
+class PGPE:
+"""Base class for all PGPE strategies."""
+
+def __init__(self) -> None:
+self._initializer = None
+self._update = None
+
+@property
+def initialize(self):
+return self._initializer
+
+@property
+def update(self):
+return self._update
+
+def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
+raise NotImplementedError
+
+
+class GaussianPGPE(PGPE):
+'''PGPE with a Gaussian parameter distribution.'''
+
+def __init__(self, batch_size: int=1,
+init_sigma: float=1.0,
+sigma_range: Tuple[float, float]=(1e-5, 1e5),
+scale_reward: bool=True,
+super_symmetric: bool=True,
+super_symmetric_accurate: bool=True,
+optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
+optimizer_kwargs_mu: Optional[Kwargs]=None,
+optimizer_kwargs_sigma: Optional[Kwargs]=None) -> None:
+'''Creates a new Gaussian PGPE planner.
+
+:param batch_size: how many policy parameters to sample per optimization step
+:param init_sigma: initial standard deviation of Gaussian
+:param sigma_range: bounds to constrain standard deviation
+:param scale_reward: whether to apply reward scaling as in the paper
+:param super_symmetric: whether to use super-symmetric sampling as in the paper
+:param super_symmetric_accurate: whether to use the accurate formula for super-
+symmetric sampling or the simplified but biased formula
+:param optimizer: a factory for an optax SGD algorithm
+:param optimizer_kwargs_mu: a dictionary of parameters to pass to the SGD
+factory for the mean optimizer
+:param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
+factory for the standard deviation optimizer
+'''
+super().__init__()
+
+self.batch_size = batch_size
+self.init_sigma = init_sigma
+self.sigma_range = sigma_range
+self.scale_reward = scale_reward
+self.super_symmetric = super_symmetric
+self.super_symmetric_accurate = super_symmetric_accurate
+
+# set optimizers
+if optimizer_kwargs_mu is None:
+optimizer_kwargs_mu = {'learning_rate': 0.1}
+self.optimizer_kwargs_mu = optimizer_kwargs_mu
+if optimizer_kwargs_sigma is None:
+optimizer_kwargs_sigma = {'learning_rate': 0.1}
+self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
+self.optimizer_name = optimizer
+mu_optimizer = optimizer(**optimizer_kwargs_mu)
+sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
+self.optimizers = (mu_optimizer, sigma_optimizer)
+
+def __str__(self) -> str:
+return (f'PGPE hyper-parameters:\n'
+f' method ={self.__class__.__name__}\n'
+f' batch_size ={self.batch_size}\n'
+f' init_sigma ={self.init_sigma}\n'
+f' sigma_range ={self.sigma_range}\n'
+f' scale_reward ={self.scale_reward}\n'
+f' super_symmetric={self.super_symmetric}\n'
+f' accurate ={self.super_symmetric_accurate}\n'
+f' optimizer ={self.optimizer_name}\n'
+f' optimizer_kwargs:\n'
+f' mu ={self.optimizer_kwargs_mu}\n'
+f' sigma={self.optimizer_kwargs_sigma}\n'
+)
+
+def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
+MIN_NORM = 1e-5
+sigma0 = self.init_sigma
+sigma_range = self.sigma_range
+scale_reward = self.scale_reward
+super_symmetric = self.super_symmetric
+super_symmetric_accurate = self.super_symmetric_accurate
+batch_size = self.batch_size
+optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
+
+# initializer
+def _jax_wrapped_pgpe_init(key, policy_params):
+mu = policy_params
+sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
+pgpe_params = (mu, sigma)
+pgpe_opt_state = tuple(opt.init(param)
+for (opt, param) in zip(optimizers, pgpe_params))
+return pgpe_params, pgpe_opt_state
+
+self._initializer = jax.jit(_jax_wrapped_pgpe_init)
+
+# parameter sampling functions
+def _jax_wrapped_mu_noise(key, sigma):
+return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
+
+def _jax_wrapped_epsilon_star(sigma, epsilon):
+c1, c2, c3 = -0.06655, -0.9706, 0.124
+phi = 0.67449 * sigma
+a = (sigma - jnp.abs(epsilon)) / sigma
+if super_symmetric_accurate:
+aa = jnp.abs(a)
+epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
+a <= 0,
+jnp.exp(c1 * aa * (aa * aa - 1) / jnp.log(aa + 1e-10) + c2 * aa),
+jnp.exp(aa - c3 * aa * jnp.log(1.0 - jnp.power(aa, 3) + 1e-10))
+)
+else:
+epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
+return epsilon_star
+
+def _jax_wrapped_sample_params(key, mu, sigma):
+keys = random.split(key, num=len(jax.tree_util.tree_leaves(mu)))
+keys_pytree = jax.tree_util.tree_unflatten(
+treedef=jax.tree_util.tree_structure(mu), leaves=keys)
+epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
+p1 = jax.tree_map(jnp.add, mu, epsilon)
+p2 = jax.tree_map(jnp.subtract, mu, epsilon)
+if super_symmetric:
+epsilon_star = jax.tree_map(_jax_wrapped_epsilon_star, sigma, epsilon)
+p3 = jax.tree_map(jnp.add, mu, epsilon_star)
+p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
+else:
+epsilon_star, p3, p4 = epsilon, p1, p2
+return (p1, p2, p3, p4), (epsilon, epsilon_star)
+
+# policy gradient update functions
+def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
+if super_symmetric:
+if scale_reward:
+scale1 = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
+scale2 = jnp.maximum(MIN_NORM, m - (r3 + r4) / 2)
+else:
+scale1 = scale2 = 1.0
+r_mu1 = (r1 - r2) / (2 * scale1)
+r_mu2 = (r3 - r4) / (2 * scale2)
+grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
+else:
+if scale_reward:
+scale = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
+else:
+scale = 1.0
+r_mu = (r1 - r2) / (2 * scale)
+grad = -r_mu * epsilon
+return grad
+
+def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
+if super_symmetric:
+mask = r1 + r2 >= r3 + r4
+epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
+s = epsilon_tau * epsilon_tau / sigma - sigma
+if scale_reward:
+scale = jnp.maximum(MIN_NORM, m - (r1 + r2 + r3 + r4) / 4)
+else:
+scale = 1.0
+r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
+else:
+s = epsilon * epsilon / sigma - sigma
+if scale_reward:
+scale = jnp.maximum(MIN_NORM, jnp.abs(m))
+else:
+scale = 1.0
+r_sigma = (r1 + r2) / (2 * scale)
+grad = -r_sigma * s
+return grad
+
+def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
+policy_hyperparams, subs, model_params):
+key, subkey = random.split(key)
+(p1, p2, p3, p4), (epsilon, epsilon_star) = _jax_wrapped_sample_params(
+key, mu, sigma)
+r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
+r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
+r_max = jnp.maximum(r_max, r1)
+r_max = jnp.maximum(r_max, r2)
+if super_symmetric:
+r3 = -loss_fn(subkey, p3, policy_hyperparams, subs, model_params)[0]
+r4 = -loss_fn(subkey, p4, policy_hyperparams, subs, model_params)[0]
+r_max = jnp.maximum(r_max, r3)
+r_max = jnp.maximum(r_max, r4)
+else:
+r3, r4 = r1, r2
+grad_mu = jax.tree_map(
+partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
+epsilon, epsilon_star
+)
+grad_sigma = jax.tree_map(
+partial(_jax_wrapped_sigma_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
+epsilon, epsilon_star, sigma
+)
+return grad_mu, grad_sigma, r_max
+
+def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
+policy_hyperparams, subs, model_params):
+mu, sigma = pgpe_params
+if batch_size == 1:
+mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
+key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+else:
+keys = random.split(key, num=batch_size)
+mu_grads, sigma_grads, r_maxs = jax.vmap(
+_jax_wrapped_pgpe_grad,
+in_axes=(0, None, None, None, None, None, None)
+)(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+mu_grad = jax.tree_map(partial(jnp.mean, axis=0), mu_grads)
+sigma_grad = jax.tree_map(partial(jnp.mean, axis=0), sigma_grads)
+new_r_max = jnp.max(r_maxs)
+return mu_grad, sigma_grad, new_r_max
+
+def _jax_wrapped_pgpe_update(key, pgpe_params, r_max,
+policy_hyperparams, subs, model_params,
+pgpe_opt_state):
+mu, sigma = pgpe_params
+mu_state, sigma_state = pgpe_opt_state
+mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
+key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
+mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
+sigma_updates, new_sigma_state = sigma_optimizer.update(
+sigma_grad, sigma_state, params=sigma)
+new_mu = optax.apply_updates(mu, mu_updates)
+new_mu, converged = projection(new_mu, policy_hyperparams)
+new_sigma = optax.apply_updates(sigma, sigma_updates)
+new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+new_pgpe_params = (new_mu, new_sigma)
+new_pgpe_opt_state = (new_mu_state, new_sigma_state)
+policy_params = new_mu
+return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged
+
+self._update = jax.jit(_jax_wrapped_pgpe_update)
+
+
+# ***********************************************************************
+# ALL VERSIONS OF JAX PLANNER
+#
+# - simple gradient descent based planner
+#
+# ***********************************************************************
+
+
 class JaxBackpropPlanner:
 '''A class for optimizing an action sequence in the given RDDL MDP using
 gradient descent.'''
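For orientation, the following is a hypothetical usage sketch (not taken from the package) of pairing the new GaussianPGPE strategy with the planner through the pgpe argument introduced below; the constructor arguments mirror the ones added in this release, and the remaining JaxBackpropPlanner arguments are omitted:

    # hypothetical sketch, assuming the pyRDDLGym-jax 2.0 planner API shown in this diff
    import optax
    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, GaussianPGPE

    pgpe = GaussianPGPE(batch_size=1,
                        init_sigma=1.0,
                        super_symmetric=True,
                        optimizer=optax.adam,
                        optimizer_kwargs_mu={'learning_rate': 0.1},
                        optimizer_kwargs_sigma={'learning_rate': 0.1})
    # planner = JaxBackpropPlanner(..., pgpe=pgpe)  # pass pgpe=None to disable the PGPE update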
@@ -1183,6 +1487,7 @@ class JaxBackpropPlanner:
 clip_grad: Optional[float]=None,
 line_search_kwargs: Optional[Kwargs]=None,
 noise_kwargs: Optional[Kwargs]=None,
+pgpe: Optional[PGPE]=GaussianPGPE(),
 logic: Logic=FuzzyLogic(),
 use_symlog_reward: bool=False,
 utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1213,6 +1518,7 @@ class JaxBackpropPlanner:
 :param line_search_kwargs: parameters to pass to optional line search
 method to scale learning rate
 :param noise_kwargs: parameters of optional gradient noise
+:param pgpe: optional policy gradient to run alongside the planner
 :param logic: a subclass of Logic for mapping exact mathematical
 operations to their differentiable counterparts
 :param use_symlog_reward: whether to use the symlog transform on the
@@ -1251,6 +1557,8 @@ class JaxBackpropPlanner:
 self.clip_grad = clip_grad
 self.line_search_kwargs = line_search_kwargs
 self.noise_kwargs = noise_kwargs
+self.pgpe = pgpe
+self.use_pgpe = pgpe is not None
 
 # set optimizer
 try:
@@ -1355,24 +1663,25 @@ r"""
 f' line_search_kwargs={self.line_search_kwargs}\n'
 f' noise_kwargs ={self.noise_kwargs}\n'
 f' batch_size_train ={self.batch_size_train}\n'
-f' batch_size_test ={self.batch_size_test}')
-result +=
-
+f' batch_size_test ={self.batch_size_test}\n')
+result += str(self.plan)
+if self.use_pgpe:
+result += str(self.pgpe)
+result += str(self.logic)
 
 # print model relaxation information
-if
-
-
-
-
-
-
-
-
-
-
-f' init_values={values_by_rddl_op[rddl_op]}\n')
+if self.compiled.model_params:
+result += ('Some RDDL operations are non-differentiable '
+'and will be approximated as follows:' + '\n')
+exprs_by_rddl_op, values_by_rddl_op = {}, {}
+for info in self.compiled.model_parameter_info().values():
+rddl_op = info['rddl_op']
+exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
+values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
+for rddl_op in sorted(exprs_by_rddl_op.keys()):
+result += (f' {rddl_op}:\n'
+f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
+f' init_values={values_by_rddl_op[rddl_op]}\n')
 return result
 
 def summarize_hyperparameters(self) -> None:
@@ -1438,6 +1747,15 @@ r"""
 # optimization
 self.update = self._jax_update(train_loss)
 self.check_zero_grad = self._jax_check_zero_gradients()
+
+# pgpe option
+if self.use_pgpe:
+loss_fn = self._jax_loss(rollouts=test_rollouts)
+self.pgpe.compile(
+loss_fn=loss_fn,
+projection=self.plan.projection,
+real_dtype=self.test_compiled.REAL
+)
 
 def _jax_return(self, use_symlog):
 gamma = self.rddl.discount
@@ -1646,7 +1964,7 @@ r"""
 return grad
 
 return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
-
+
 # ===========================================================================
 # OPTIMIZE API
 # ===========================================================================
@@ -1819,7 +2137,17 @@ r"""
 policy_params = guess
 opt_state = self.optimizer.init(policy_params)
 opt_aux = {}
-
+
+# initialize pgpe parameters
+if self.use_pgpe:
+pgpe_params, pgpe_opt_state = self.pgpe.initialize(key, policy_params)
+rolling_pgpe_loss = RollingMean(test_rolling_window)
+else:
+pgpe_params, pgpe_opt_state = None, None
+rolling_pgpe_loss = None
+total_pgpe_it = 0
+r_max = -jnp.inf
+
 # ======================================================================
 # INITIALIZATION OF RUNNING STATISTICS
 # ======================================================================
@@ -1860,17 +2188,47 @@ r"""
 
 # update the parameters of the plan
 key, subkey = random.split(key)
-(policy_params, converged, opt_state, opt_aux,
-
-
-
-
+(policy_params, converged, opt_state, opt_aux, train_loss, train_log,
+model_params) = self.update(subkey, policy_params, policy_hyperparams,
+train_subs, model_params, opt_state, opt_aux)
+test_loss, (test_log, model_params_test) = self.test_loss(
+subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
+test_loss_smooth = rolling_test_loss.update(test_loss)
+
+# pgpe update of the plan
+pgpe_improve = False
+if self.use_pgpe:
+key, subkey = random.split(key)
+pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
+self.pgpe.update(subkey, pgpe_params, r_max, policy_hyperparams,
+test_subs, model_params, pgpe_opt_state)
+pgpe_loss, _ = self.test_loss(
+subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
+pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
+pgpe_return = -pgpe_loss_smooth
+
+# replace with PGPE if it reaches a new minimum or train loss invalid
+if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
+policy_params = pgpe_param
+test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
+converged = pgpe_converged
+pgpe_improve = True
+total_pgpe_it += 1
+else:
+pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
+
+# evaluate test losses and record best plan so far
+if test_loss_smooth < best_loss:
+best_params, best_loss, best_grad = \
+policy_params, test_loss_smooth, train_log['grad']
+last_iter_improve = it
+
 # ==================================================================
 # STATUS CHECKS AND LOGGING
 # ==================================================================
 
 # no progress
-if self.check_zero_grad(train_log['grad']):
+if (not pgpe_improve) and self.check_zero_grad(train_log['grad']):
 status = JaxPlannerStatus.NO_PROGRESS
 
 # constraint satisfaction problem
@@ -1882,21 +2240,14 @@ r"""
 status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
 
 # numerical error
-if
-
-
+if self.use_pgpe:
+invalid_loss = not (np.isfinite(train_loss) or np.isfinite(pgpe_loss))
+else:
+invalid_loss = not np.isfinite(train_loss)
+if invalid_loss:
+raise_warning(f'Planner aborted due to invalid loss {train_loss}.', 'red')
 status = JaxPlannerStatus.INVALID_GRADIENT
 
-# evaluate test losses and record best plan so far
-test_loss, (log, model_params_test) = self.test_loss(
-subkey, policy_params, policy_hyperparams,
-test_subs, model_params_test)
-test_loss = rolling_test_loss.update(test_loss)
-if test_loss < best_loss:
-best_params, best_loss, best_grad = \
-policy_params, test_loss, train_log['grad']
-last_iter_improve = it
-
 # reached computation budget
 elapsed = time.time() - start_time - elapsed_outside_loop
 if elapsed >= train_seconds:
@@ -1910,11 +2261,14 @@ r"""
 'status': status,
 'iteration': it,
 'train_return':-train_loss,
-'test_return':-
+'test_return':-test_loss_smooth,
 'best_return':-best_loss,
+'pgpe_return': pgpe_return,
 'params': policy_params,
 'best_params': best_params,
+'pgpe_params': pgpe_params,
 'last_iteration_improved': last_iter_improve,
+'pgpe_improved': pgpe_improve,
 'grad': train_log['grad'],
 'best_grad': best_grad,
 'updates': train_log['updates'],
@@ -1923,7 +2277,7 @@ r"""
 'model_params': model_params,
 'progress': progress_percent,
 'train_log': train_log,
-**
+**test_log
 }
 
 # stopping condition reached
@@ -1934,9 +2288,9 @@ r"""
 if print_progress:
 iters.n = progress_percent
 iters.set_description(
-f'{position_str} {it:6} it / {-train_loss:14.
-f'{-
-f'{status.value} status'
+f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
+f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
+f'{status.value} status / {total_pgpe_it:6} pgpe'
 )
 
 # dash-board
@@ -1955,7 +2309,7 @@ r"""
 # ======================================================================
 # POST-PROCESSING AND CLEANUP
 # ======================================================================
-
+
 # release resources
 if print_progress:
 iters.close()
@@ -1967,7 +2321,7 @@ r"""
 messages.update(JaxRDDLCompiler.get_error_messages(error_code))
 if messages:
 messages = '\n'.join(messages)
-raise_warning('
+raise_warning('JAX compiler encountered the following '
 'error(s) in the original RDDL formulation '
 f'during test evaluation:\n{messages}', 'red')
 
@@ -1975,14 +2329,14 @@ r"""
 if print_summary:
 grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
 diagnosis = self._perform_diagnosis(
-last_iter_improve, -train_loss, -
+last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
 print(f'summary of optimization:\n'
-f'
-f'
+f' status ={status}\n'
+f' time ={elapsed:.6f} sec.\n'
 f' iterations ={it}\n'
-f'
-f'
-f'
+f' best objective={-best_loss:.6f}\n'
+f' best grad norm={grad_norm}\n'
+f'diagnosis: {diagnosis}\n')
 
 def _perform_diagnosis(self, last_iter_improve,
 train_return, test_return, best_return, grad_norm):
@@ -2002,23 +2356,24 @@ r"""
 if last_iter_improve <= 1:
 if grad_is_zero:
 return termcolor.colored(
-'[FAILURE] no progress was made
-f'and max grad norm {max_grad_norm:.6f}
-'solver likely stuck in a plateau.', 'red')
+'[FAILURE] no progress was made '
+f'and max grad norm {max_grad_norm:.6f} was zero: '
+'the solver was likely stuck in a plateau.', 'red')
 else:
 return termcolor.colored(
-'[FAILURE] no progress was made
-f'but max grad norm {max_grad_norm:.6f}
-'
+'[FAILURE] no progress was made '
+f'but max grad norm {max_grad_norm:.6f} was non-zero: '
+'the learning rate or other hyper-parameters were likely suboptimal.',
+'red')
 
 # model is likely poor IF:
 # 1. the train and test return disagree
 if not (validation_error < 20):
 return termcolor.colored(
-'[WARNING] progress was made
-f'but relative train-test error {validation_error:.6f}
-'
-'or the batch size
+'[WARNING] progress was made '
+f'but relative train-test error {validation_error:.6f} was high: '
+'model relaxation around the solution was poor '
+'or the batch size was too small.', 'yellow')
 
 # model likely did not converge IF:
 # 1. the max grad relative to the return is high
@@ -2026,15 +2381,15 @@ r"""
 return_to_grad_norm = abs(best_return) / max_grad_norm
 if not (return_to_grad_norm > 1):
 return termcolor.colored(
-'[WARNING] progress was made
-f'but max grad norm {max_grad_norm:.6f}
-'
-'or the relaxed model
-'or the batch size
+'[WARNING] progress was made '
+f'but max grad norm {max_grad_norm:.6f} was high: '
+'the solution was likely locally suboptimal, '
+'or the relaxed model was not smooth around the solution, '
+'or the batch size was too small.', 'yellow')
 
 # likely successful
 return termcolor.colored(
-'[SUCCESS]
+'[SUCCESS] solver converged successfully '
 '(note: not all potential problems can be ruled out).', 'green')
 
 def get_action(self, key: random.PRNGKey,
pyRDDLGym_jax/core/simulator.py
CHANGED
@@ -1,3 +1,23 @@
+# ***********************************************************************
+# JAXPLAN
+#
+# Author: Michael Gimelfarb
+#
+# REFERENCES:
+#
+# [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
+# Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
+# Probabilistic Domains." Proceedings of the International Conference on Automated
+# Planning and Scheduling. Vol. 34. 2024.
+#
+# [2] Taitler, Ayal, Michael Gimelfarb, Jihwan Jeong, Sriram Gopalakrishnan, Martin
+# Mladenov, Xiaotian Liu, and Scott Sanner. "pyRDDLGym: From RDDL to Gym Environments."
+# In PRL Workshop Series {\textendash} Bridging the Gap Between AI Planning and
+# Reinforcement Learning.
+#
+# ***********************************************************************
+
+
 import time
 from typing import Dict, Optional
 
pyRDDLGym_jax/core/tuning.py
CHANGED
@@ -1,3 +1,18 @@
+# ***********************************************************************
+# JAXPLAN
+#
+# Author: Michael Gimelfarb
+#
+# REFERENCES:
+#
+# [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
+# Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
+# Probabilistic Domains." Proceedings of the International Conference on Automated
+# Planning and Scheduling. Vol. 34. 2024.
+#
+# ***********************************************************************
+
+
 import csv
 import datetime
 import threading
pyRDDLGym_jax/core/visualization.py
CHANGED
@@ -1,3 +1,18 @@
+# ***********************************************************************
+# JAXPLAN
+#
+# Author: Michael Gimelfarb
+#
+# REFERENCES:
+#
+# [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
+# Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
+# Probabilistic Domains." Proceedings of the International Conference on Automated
+# Planning and Scheduling. Vol. 34. 2024.
+#
+# ***********************************************************************
+
+
 import ast
 import os
 from datetime import datetime

@@ -61,6 +76,7 @@ class JaxPlannerDashboard:
 self.xticks = {}
 self.test_return = {}
 self.train_return = {}
+self.pgpe_return = {}
 self.return_dist = {}
 self.return_dist_ticks = {}
 self.return_dist_last_progress = {}

@@ -299,6 +315,9 @@ class JaxPlannerDashboard:
 dbc.Col(Graph(id='train-return-graph'), width=6),
 dbc.Col(Graph(id='test-return-graph'), width=6),
 ]),
+dbc.Row([
+dbc.Col(Graph(id='pgpe-return-graph'), width=6)
+]),
 dbc.Row([
 Graph(id='dist-return-graph')
 ])

@@ -661,6 +680,33 @@ class JaxPlannerDashboard:
 )
 return fig
 
+@app.callback(
+Output('pgpe-return-graph', 'figure'),
+[Input('interval', 'n_intervals'),
+Input('trigger-experiment-check', 'children'),
+Input('tabs-main', 'active_tab')]
+)
+def update_pgpe_return_graph(n, trigger, active_tab):
+if active_tab != 'tab-performance': return dash.no_update
+fig = go.Figure()
+for (row, checked) in self.checked.copy().items():
+if checked:
+fig.add_trace(go.Scatter(
+x=self.xticks[row], y=self.pgpe_return[row],
+name=f'id={row}',
+mode='lines+markers',
+marker=dict(size=3), line=dict(width=2)
+))
+fig.update_layout(
+title=dict(text="PGPE Return"),
+xaxis=dict(title=dict(text="Training Iteration")),
+yaxis=dict(title=dict(text="Cumulative Reward")),
+font=dict(size=PLOT_AXES_FONT_SIZE),
+legend=dict(bgcolor='rgba(0,0,0,0)'),
+template="plotly_white"
+)
+return fig
+
 @app.callback(
 Output('dist-return-graph', 'figure'),
 [Input('interval', 'n_intervals'),

@@ -1316,6 +1362,7 @@ class JaxPlannerDashboard:
 self.xticks[experiment_id] = []
 self.train_return[experiment_id] = []
 self.test_return[experiment_id] = []
+self.pgpe_return[experiment_id] = []
 self.return_dist_ticks[experiment_id] = []
 self.return_dist_last_progress[experiment_id] = 0
 self.return_dist[experiment_id] = []

@@ -1367,6 +1414,7 @@ class JaxPlannerDashboard:
 self.xticks[experiment_id].append(iteration)
 self.train_return[experiment_id].append(callback['train_return'])
 self.test_return[experiment_id].append(callback['best_return'])
+self.pgpe_return[experiment_id].append(callback['pgpe_return'])
 
 # data for return distributions
 progress = callback['progress']
pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg
CHANGED
@@ -1,8 +1,8 @@
 [Model]
 logic='FuzzyLogic'
-comparison_kwargs={'weight':
-rounding_kwargs={'weight':
-control_kwargs={'weight':
+comparison_kwargs={'weight': 20}
+rounding_kwargs={'weight': 20}
+control_kwargs={'weight': 20}
 
 [Optimizer]
 method='JaxStraightLinePlan'
pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg
CHANGED
@@ -1,14 +1,14 @@
 [Model]
 logic='FuzzyLogic'
-comparison_kwargs={'weight':
-rounding_kwargs={'weight':
-control_kwargs={'weight':
+comparison_kwargs={'weight': 20}
+rounding_kwargs={'weight': 20}
+control_kwargs={'weight': 20}
 
 [Optimizer]
 method='JaxStraightLinePlan'
 method_kwargs={}
 optimizer='rmsprop'
-optimizer_kwargs={'learning_rate': 0.
+optimizer_kwargs={'learning_rate': 0.001}
 batch_size_train=1
 batch_size_test=1
 clip_grad=1.0
pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg
CHANGED
@@ -1,8 +1,8 @@
 [Model]
 logic='FuzzyLogic'
-comparison_kwargs={'weight':
-rounding_kwargs={'weight':
-control_kwargs={'weight':
+comparison_kwargs={'weight': 10}
+rounding_kwargs={'weight': 10}
+control_kwargs={'weight': 10}
 
 [Optimizer]
 method='JaxStraightLinePlan'

@@ -11,6 +11,7 @@ optimizer='rmsprop'
 optimizer_kwargs={'learning_rate': 0.03}
 batch_size_train=1
 batch_size_test=1
+pgpe=None
 
 [Training]
 key=42
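Judging from the configuration parsing added to planner.py above, the PGPE strategy can also be selected from the [Optimizer] section of a config file; the snippet below is a hypothetical example (the key names come from the diff, the values are illustrative), while pgpe=None, as in the config above, disables it:

    [Optimizer]
    method='JaxStraightLinePlan'
    method_kwargs={}
    optimizer='rmsprop'
    optimizer_kwargs={'learning_rate': 0.03}
    pgpe='GaussianPGPE'
    pgpe_kwargs={'init_sigma': 1.0, 'optimizer': 'adam'}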
{pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: pyRDDLGym-jax
-Version:
+Version: 2.0
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
{pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
-pyRDDLGym_jax/__init__.py,sha256=
+pyRDDLGym_jax/__init__.py,sha256=TiPG4w8nN4AzPkhugwVvZkHmAgP955NltD4QRmBLhRU,19
 pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
 pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/compiler.py,sha256=
-pyRDDLGym_jax/core/logic.py,sha256=
-pyRDDLGym_jax/core/planner.py,sha256=
-pyRDDLGym_jax/core/simulator.py,sha256=
-pyRDDLGym_jax/core/tuning.py,sha256=
-pyRDDLGym_jax/core/visualization.py,sha256=
+pyRDDLGym_jax/core/compiler.py,sha256=Rn-aIqfgfWqu45bvCfPb9tB8RIOBVdbj-pI-V3WS2Z8,89212
+pyRDDLGym_jax/core/logic.py,sha256=_A6eGYtLVU3pbLAezxJVB9bnClJoaFIa2mBIDdFrqoU,39655
+pyRDDLGym_jax/core/planner.py,sha256=4j56l7SL7F89g2QA4nOpyhODmY0DamvxYLfCMKxJNbQ,118593
+pyRDDLGym_jax/core/simulator.py,sha256=DnPL93WVCMZqtqMUoiJdfWcH9pEvNgGfDfO4NV0wIS0,9271
+pyRDDLGym_jax/core/tuning.py,sha256=RKKtDZp7unvfbhZEoaunZtcAn5xtzGYqXBB_Ij_Aapc,24205
+pyRDDLGym_jax/core/visualization.py,sha256=XtQL1A5dQIlfeUpte-r3lNVw-GNLxj2EYUNMz7AFOtc,70359
 pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
 pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0

@@ -16,8 +16,8 @@ pyRDDLGym_jax/examples/run_plan.py,sha256=v2AvwgIa4Ejr626vBOgWFJIQvay3IPKWno02zt
 pyRDDLGym_jax/examples/run_scipy.py,sha256=wvcpWCvdjvYHntO95a7JYfY2fuCMUTKnqjJikW0PnL4,2291
 pyRDDLGym_jax/examples/run_tune.py,sha256=zqrhvLR5PeWJv0NsRxDCzAPmvgPgz_1NrtM1xBy6ndU,3606
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
-pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=
-pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=
+pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=nFFYHCKQUMn8x-OpJwu2pwe1tycNSJ8iAIwSkCBn33E,370
+pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=eJ3HvHjODoKdtX7u-AM51xQaHJnYgzEy2t3omNG2oCs,340
 pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg,sha256=9-QMZPZuecAEaerD79ZAbGX-tgfL8Y2W-tfkAyD15Cw,362
 pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg,sha256=BiY6wwSYkR9-T46AA4n3okJ1Qvj8Iu-y1V5BrfCbqrM,340
 pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg,sha256=VBlTiHFQG72D1wpebMsuzSokwqlPVD99WjPp4YoWs84,356

@@ -25,15 +25,15 @@ pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg,sha256=bH_5O13-Y6ztv
 pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg,sha256=Pq6E9RYksue7X2cWjdWyUsV0LqQTjTvq6p0aLBVKWfY,370
 pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg,sha256=SGVQAOqrOjEsZEtxL_Z6aGbLR19h5gKCcy0oz2vtQp8,382
 pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg,sha256=6obQik2FBldoJ3VwoVfGhQqKpKdnYox770cF-SGRi3Q,345
-pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg,sha256=
-pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg,sha256=
-pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg,sha256=
+pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg,sha256=rs-CzOAyZV_NvwSh2f6Fm9XNw5Z8WIYgpAOzgTm_Gv8,403
+pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg,sha256=EtSCTjd8gWm7akQdfHFxdpGnQvHzjo2IHbAuVxTAX4U,356
+pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg,sha256=7nPOJCo3eaZuq1pCyIJJJkDM0jjJThDuDECJDZzX-uc,379
 pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg,sha256=V3jzPGuNq2IAxYy_EeZWin4Y_uf0HvGhzg06ODNSY-I,381
-pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg,sha256=
-pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg,sha256=
-pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg,sha256=
-pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg,sha256=
-pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg,sha256=
+pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg,sha256=SYAJmoUIUhhvAej3XOzC5boGxKVHnSiVi5-ZGj2S29M,354
+pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg,sha256=osoIPfrldPw7oJF2AaAw0-ke6YHQNdrslFBCTytsqmo,354
+pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg,sha256=oNX8uW8Bw2uG9zHX1zeLF3mHWDHRIlJXYvbFcY0pfCI,382
+pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg,sha256=exCfGI3WU7IFO7n5rRe5cO1ZHAdFwttRYzjIdD4Pz2Y,451
+pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg,sha256=e6Ikgv2uBbKuXHfVKt4KQ01LDUBGbc31D28bCcztJ58,413
 pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=XeMWAAG_OFZo7JAMxS5-XXroZaeVMzfM0NswmEobIns,373
 pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=CK4cEz8ReXyAZPLaLG9clIIRXAqM3IplUCxbLt_V2lY,407

@@ -41,9 +41,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
 pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
 pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
 pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
-pyRDDLGym_jax-
-pyRDDLGym_jax-
-pyRDDLGym_jax-
-pyRDDLGym_jax-
-pyRDDLGym_jax-
-pyRDDLGym_jax-
+pyRDDLGym_jax-2.0.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+pyRDDLGym_jax-2.0.dist-info/METADATA,sha256=ZYIe9c_Tar4WO8qQOvcUIJVMmZznPUBRaegS0DH2un8,15090
+pyRDDLGym_jax-2.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+pyRDDLGym_jax-2.0.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+pyRDDLGym_jax-2.0.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyRDDLGym_jax-2.0.dist-info/RECORD,,
{pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/LICENSE
File without changes
{pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/WHEEL
File without changes
{pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/entry_points.txt
File without changes
{pyRDDLGym_jax-1.3.dist-info → pyRDDLGym_jax-2.0.dist-info}/top_level.txt
File without changes