pyRDDLGym-jax 2.1.tar.gz → 2.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/PKG-INFO +25 -22
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/README.md +24 -21
- pyrddlgym_jax-2.2/pyRDDLGym_jax/__init__.py +1 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/planner.py +159 -76
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/PKG-INFO +25 -22
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/setup.py +1 -1
- pyrddlgym_jax-2.1/pyRDDLGym_jax/__init__.py +0 -1
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/LICENSE +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/compiler.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/logic.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/simulator.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/tuning.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/visualization.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/entry_point.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/__init__.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/run_plan.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/run_tune.py +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/setup.cfg +0 -0
{pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: pyRDDLGym-jax
-Version: 2.1
+Version: 2.2
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -58,8 +58,11 @@ Dynamic: summary
 
 Purpose:
 
-1. automatic translation of
-2.
+1. automatic translation of RDDL description files into differentiable JAX simulators
+2. implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
+3. flexible policy representations and automated Bayesian hyper-parameter tuning
+4. interactive dashboard for dyanmic visualization and debugging
+5. hybridization with parameter-exploring policy gradients.
 
 Some demos of solved problems by JaxPlan:
 
@@ -235,8 +238,23 @@ More documentation about this and other new features will be coming soon.
 
 ## Tuning the Planner
 
-
-
+A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
+
+```shell
+jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
+```
+
+where:
+- ``domain`` is the domain identifier as specified in rddlrepository
+- ``instance`` is the instance identifier
+- ``method`` is the planning method to use (i.e. drp, slp, replan)
+- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
+- ``dashboard`` is whether the optimizations are tracked in the dashboard application.
+
+It is easy to tune a custom range of the planner's hyper-parameters efficiently.
+First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
 
 ```ini
 [Model]
@@ -260,7 +278,7 @@ train_on_reset=True
 
 would allow to tune the sharpness of model relaxations, and the learning rate of the optimizer.
 
-Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
+Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:
 
 ```python
 import pyRDDLGym
@@ -292,22 +310,7 @@ tuning = JaxParameterTuning(env=env,
                             gp_iters=iters)
 tuning.tune(key=42, log_file='path/to/log.csv')
 ```
-
-A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
-
-```shell
-jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
-```
-
-where:
-- ``domain`` is the domain identifier as specified in rddlrepository
-- ``instance`` is the instance identifier
-- ``method`` is the planning method to use (i.e. drp, slp, replan)
-- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
-- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
-- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
-- ``dashboard`` is whether the optimizations are tracked in the dashboard application.
-
+
 
 ## Simulation
 
{pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/README.md

@@ -12,8 +12,11 @@
@@ -189,8 +192,23 @@ More documentation about this and other new features will be coming soon.
@@ -214,7 +232,7 @@ train_on_reset=True
@@ -246,22 +264,7 @@ tuning = JaxParameterTuning(env=env,

The README.md hunks are textually identical to the corresponding hunks of the PKG-INFO diff above (PKG-INFO embeds the README, offset by the metadata header): the Purpose list is expanded from two truncated items to five, the `jaxplan tune` run-script documentation is moved to the head of the "Tuning the Planner" section, and the sentence introducing the tuning example gains the clause ", and run the optimizer".
pyrddlgym_jax-2.2/pyRDDLGym_jax/__init__.py

@@ -0,0 +1 @@
+__version__ = '2.2'
{pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/planner.py

@@ -47,7 +47,9 @@ import jax.random as random
 import numpy as np
 import optax
 import termcolor
-from tqdm import tqdm
+from tqdm import tqdm, TqdmWarning
+import warnings
+warnings.filterwarnings("ignore", category=TqdmWarning)
 
 from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
 from pyRDDLGym.core.debug.logger import Logger
@@ -1212,17 +1214,22 @@ class GaussianPGPE(PGPE):
                  init_sigma: float=1.0,
                  sigma_range: Tuple[float, float]=(1e-5, 1e5),
                  scale_reward: bool=True,
+                 min_reward_scale: float=1e-5,
                  super_symmetric: bool=True,
                  super_symmetric_accurate: bool=True,
                  optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
                  optimizer_kwargs_mu: Optional[Kwargs]=None,
-                 optimizer_kwargs_sigma: Optional[Kwargs]=None) -> None:
+                 optimizer_kwargs_sigma: Optional[Kwargs]=None,
+                 start_entropy_coeff: float=1e-3,
+                 end_entropy_coeff: float=1e-8,
+                 max_kl_update: Optional[float]=None) -> None:
         '''Creates a new Gaussian PGPE planner.
         
         :param batch_size: how many policy parameters to sample per optimization step
         :param init_sigma: initial standard deviation of Gaussian
         :param sigma_range: bounds to constrain standard deviation
         :param scale_reward: whether to apply reward scaling as in the paper
+        :param min_reward_scale: minimum reward scaling to avoid underflow
         :param super_symmetric: whether to use super-symmetric sampling as in the paper
         :param super_symmetric_accurate: whether to use the accurate formula for super-
             symmetric sampling or the simplified but biased formula
@@ -1231,6 +1238,9 @@ class GaussianPGPE(PGPE):
             factory for the mean optimizer
         :param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
             factory for the standard deviation optimizer
+        :param start_entropy_coeff: starting entropy regularization coeffient for Gaussian
+        :param end_entropy_coeff: ending entropy regularization coeffient for Gaussian
+        :param max_kl_update: bound on kl-divergence between parameter updates
         '''
         super().__init__()
@@ -1238,8 +1248,13 @@ class GaussianPGPE(PGPE):
         self.init_sigma = init_sigma
         self.sigma_range = sigma_range
         self.scale_reward = scale_reward
+        self.min_reward_scale = min_reward_scale
         self.super_symmetric = super_symmetric
         self.super_symmetric_accurate = super_symmetric_accurate
+        
+        # entropy regularization penalty is decayed exponentially between these values
+        self.start_entropy_coeff = start_entropy_coeff
+        self.end_entropy_coeff = end_entropy_coeff
         
         # set optimizers
         if optimizer_kwargs_mu is None:
@@ -1249,36 +1264,62 @@ class GaussianPGPE(PGPE):
             optimizer_kwargs_sigma = {'learning_rate': 0.1}
         self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
         self.optimizer_name = optimizer
-        
-        
+        try:
+            mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
+            sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
+        except Exception as _:
+            raise_warning(
+                f'Failed to inject hyperparameters into optax optimizer for PGPE, '
+                'rolling back to safer method: please note that kl-divergence '
+                'constraints will be disabled.', 'red')
+            mu_optimizer = optimizer(**optimizer_kwargs_mu)
+            sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
+            max_kl_update = None
         self.optimizers = (mu_optimizer, sigma_optimizer)
+        self.max_kl = max_kl_update
     
     def __str__(self) -> str:
         return (f'PGPE hyper-parameters:\n'
-                f'    method
-                f'    batch_size
-                f'    init_sigma
-                f'    sigma_range
-                f'    scale_reward
-                f'
-                f'
-                f'
+                f'    method           ={self.__class__.__name__}\n'
+                f'    batch_size       ={self.batch_size}\n'
+                f'    init_sigma       ={self.init_sigma}\n'
+                f'    sigma_range      ={self.sigma_range}\n'
+                f'    scale_reward     ={self.scale_reward}\n'
+                f'    min_reward_scale ={self.min_reward_scale}\n'
+                f'    super_symmetric  ={self.super_symmetric}\n'
+                f'    accurate         ={self.super_symmetric_accurate}\n'
+                f'    optimizer        ={self.optimizer_name}\n'
                 f'    optimizer_kwargs:\n'
                 f'        mu   ={self.optimizer_kwargs_mu}\n'
                 f'        sigma={self.optimizer_kwargs_sigma}\n'
+                f'    start_entropy_coeff={self.start_entropy_coeff}\n'
+                f'    end_entropy_coeff  ={self.end_entropy_coeff}\n'
+                f'    max_kl_update      ={self.max_kl}\n'
                 )
     
     def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
-        MIN_NORM = 1e-5
         sigma0 = self.init_sigma
         sigma_range = self.sigma_range
         scale_reward = self.scale_reward
+        min_reward_scale = self.min_reward_scale
         super_symmetric = self.super_symmetric
         super_symmetric_accurate = self.super_symmetric_accurate
         batch_size = self.batch_size
         optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
-        
-        
+        max_kl = self.max_kl
+        
+        # entropy regularization penalty is decayed exponentially by elapsed budget
+        start_entropy_coeff = self.start_entropy_coeff
+        if start_entropy_coeff == 0:
+            entropy_coeff_decay = 0
+        else:
+            entropy_coeff_decay = (self.end_entropy_coeff / start_entropy_coeff) ** 0.01
+        
+        # ***********************************************************************
+        # INITIALIZATION OF POLICY
+        #
+        # ***********************************************************************
+        
         def _jax_wrapped_pgpe_init(key, policy_params):
             mu = policy_params
             sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
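Two implementation notes on the hunk above. First, `optax.inject_hyperparams` is what later lets the KL-constrained update read and rescale the learning rate through the optimizer state (hence the fallback path disables the KL constraint when injection fails). A minimal standalone optax sketch of that mechanism, using toy parameters rather than planner code:

```python
import jax.numpy as jnp
import optax

# wrapping the factory makes its hyper-parameters part of the optimizer state
opt = optax.inject_hyperparams(optax.adam)(learning_rate=0.1)

params = {'w': jnp.ones(3)}
state = opt.init(params)

# the learning rate is now visible on the state and can be rescaled in place,
# which is how the PGPE update below enforces its KL bound
state.hyperparams['learning_rate'] = 0.05
grads = {'w': jnp.full(3, 0.5)}
updates, state = opt.update(grads, state, params=params)
params = optax.apply_updates(params, updates)
```

Second, the decay constant `(end / start) ** 0.01` makes the entropy coefficient interpolate exponentially from `start_entropy_coeff` at 0% of elapsed budget down to `end_entropy_coeff` at 100%, since `decay ** 100 == end / start`. A quick sketch of the schedule, assuming `progress` is the elapsed-budget percentage in [0, 100]:

```python
def entropy_coeff(progress: float, start: float = 1e-3, end: float = 1e-8) -> float:
    # exponential interpolation: start at progress=0, end at progress=100
    if start == 0:
        return 0.0
    decay = (end / start) ** 0.01
    return start * decay ** progress

assert abs(entropy_coeff(0.0) - 1e-3) < 1e-15
assert abs(entropy_coeff(100.0) - 1e-8) < 1e-12
```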
@@ -1289,7 +1330,11 @@ class GaussianPGPE(PGPE):
 
         self._initializer = jax.jit(_jax_wrapped_pgpe_init)
         
-        #
+        # ***********************************************************************
+        # PARAMETER SAMPLING FUNCTIONS
+        #
+        # ***********************************************************************
+        
         def _jax_wrapped_mu_noise(key, sigma):
             return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
 
@@ -1299,19 +1344,20 @@ class GaussianPGPE(PGPE):
             a = (sigma - jnp.abs(epsilon)) / sigma
             if super_symmetric_accurate:
                 aa = jnp.abs(a)
+                aa3 = jnp.power(aa, 3)
                 epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
                     a <= 0,
-                    jnp.exp(c1 *
-                    jnp.exp(aa - c3 * aa * jnp.log(1.0 -
+                    jnp.exp(c1 * (aa3 - aa) / jnp.log(aa + 1e-10) + c2 * aa),
+                    jnp.exp(aa - c3 * aa * jnp.log(1.0 - aa3 + 1e-10))
                 )
             else:
                 epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
             return epsilon_star
 
         def _jax_wrapped_sample_params(key, mu, sigma):
-            
-            
-            
+            treedef = jax.tree_util.tree_structure(sigma)
+            keys = random.split(key, num=treedef.num_leaves)
+            keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
             epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
             p1 = jax.tree_map(jnp.add, mu, epsilon)
             p2 = jax.tree_map(jnp.subtract, mu, epsilon)
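The three added lines in `_jax_wrapped_sample_params` are the usual JAX recipe for drawing independent noise for every leaf of a parameter pytree: split one key into `num_leaves` keys, then rebuild them into a pytree with the same structure as the parameters. A self-contained sketch of the idiom, with toy shapes rather than planner code:

```python
import jax
import jax.numpy as jnp
import jax.random as random

sigma = {'b': jnp.ones(3), 'w': jnp.ones((2, 3))}  # toy parameter pytree

# one fresh PRNG key per leaf, arranged into a pytree of the same structure
treedef = jax.tree_util.tree_structure(sigma)
keys = random.split(random.PRNGKey(0), num=treedef.num_leaves)
keys_pytree = jax.tree_util.tree_unflatten(treedef, list(keys))

# each leaf now draws its noise with its own independent key
epsilon = jax.tree_util.tree_map(
    lambda k, s: s * random.normal(k, jnp.shape(s)), keys_pytree, sigma)
```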
@@ -1321,14 +1367,18 @@ class GaussianPGPE(PGPE):
                 p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
             else:
                 epsilon_star, p3, p4 = epsilon, p1, p2
-            return
+            return p1, p2, p3, p4, epsilon, epsilon_star
         
-        #
+        # ***********************************************************************
+        # POLICY GRADIENT CALCULATION
+        #
+        # ***********************************************************************
+        
         def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
             if super_symmetric:
                 if scale_reward:
-                    scale1 = jnp.maximum(
-                    scale2 = jnp.maximum(
+                    scale1 = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
+                    scale2 = jnp.maximum(min_reward_scale, m - (r3 + r4) / 2)
                 else:
                     scale1 = scale2 = 1.0
                 r_mu1 = (r1 - r2) / (2 * scale1)
@@ -1336,37 +1386,37 @@ class GaussianPGPE(PGPE):
                 grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
             else:
                 if scale_reward:
-                    scale = jnp.maximum(
+                    scale = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
                 else:
                     scale = 1.0
                 r_mu = (r1 - r2) / (2 * scale)
                 grad = -r_mu * epsilon
             return grad
         
-        def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
+        def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
             if super_symmetric:
                 mask = r1 + r2 >= r3 + r4
                 epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
-                s = epsilon_tau
+                s = jnp.square(epsilon_tau) / sigma - sigma
                 if scale_reward:
-                    scale = jnp.maximum(
+                    scale = jnp.maximum(min_reward_scale, m - (r1 + r2 + r3 + r4) / 4)
                 else:
                     scale = 1.0
                 r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
             else:
-                s = epsilon
+                s = jnp.square(epsilon) / sigma - sigma
                 if scale_reward:
-                    scale = jnp.maximum(
+                    scale = jnp.maximum(min_reward_scale, jnp.abs(m))
                 else:
                     scale = 1.0
                 r_sigma = (r1 + r2) / (2 * scale)
-            grad = -r_sigma * s
+            grad = -(r_sigma * s + ent / sigma)
             return grad
         
-        def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
+        def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
                                    policy_hyperparams, subs, model_params):
             key, subkey = random.split(key)
-            
+            p1, p2, p3, p4, epsilon, epsilon_star = _jax_wrapped_sample_params(
                 key, mu, sigma)
             r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
             r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
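A note on the math in this hunk, following the standard PGPE derivation rather than anything stated in the diff itself: the rewritten `s` is the `(epsilon**2 - sigma**2) / sigma` form of the σ-update from the PGPE paper (the exact log-likelihood derivative scaled by σ²), and the added `ent / sigma` term is the σ-derivative of the Gaussian entropy, so descending `grad` also ascends entropy and keeps σ from collapsing early in training:

```latex
\frac{\partial}{\partial\sigma}\log\mathcal{N}(\mu+\epsilon \mid \mu,\sigma^{2})
  = \frac{\epsilon^{2}-\sigma^{2}}{\sigma^{3}},
\qquad
H\!\left(\mathcal{N}(\mu,\sigma^{2})\right) = \tfrac{1}{2}\log\!\left(2\pi e\,\sigma^{2}\right),
\qquad
\frac{\partial H}{\partial\sigma} = \frac{1}{\sigma}.
```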
@@ -1384,42 +1434,76 @@ class GaussianPGPE(PGPE):
                 epsilon, epsilon_star
             )
             grad_sigma = jax.tree_map(
-                partial(_jax_wrapped_sigma_grad,
+                partial(_jax_wrapped_sigma_grad,
+                        r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
                 epsilon, epsilon_star, sigma
             )
             return grad_mu, grad_sigma, r_max
         
-        def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
+        def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max, ent,
                                            policy_hyperparams, subs, model_params):
             mu, sigma = pgpe_params
             if batch_size == 1:
                 mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
-                    key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+                    key, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
             else:
                 keys = random.split(key, num=batch_size)
                 mu_grads, sigma_grads, r_maxs = jax.vmap(
                     _jax_wrapped_pgpe_grad,
-                    in_axes=(0, None, None, None, None, None, None)
-                )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+                    in_axes=(0, None, None, None, None, None, None, None)
+                )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
                 mu_grad, sigma_grad = jax.tree_map(
                     partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
                 new_r_max = jnp.max(r_maxs)
             return mu_grad, sigma_grad, new_r_max
+        
+        # ***********************************************************************
+        # PARAMETER UPDATE
+        #
+        # ***********************************************************************
         
-        def
+        def _jax_wrapped_pgpe_kl_term(mu, sigma, old_mu, old_sigma):
+            return 0.5 * jnp.sum(2 * jnp.log(sigma / old_sigma) +
+                                 jnp.square(old_sigma / sigma) +
+                                 jnp.square((mu - old_mu) / sigma) - 1)
+        
+        def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
                                      policy_hyperparams, subs, model_params,
                                      pgpe_opt_state):
+            # regular update
             mu, sigma = pgpe_params
             mu_state, sigma_state = pgpe_opt_state
+            ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
             mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
-                key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
+                key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
             mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
             sigma_updates, new_sigma_state = sigma_optimizer.update(
                 sigma_grad, sigma_state, params=sigma)
             new_mu = optax.apply_updates(mu, mu_updates)
-            new_mu, converged = projection(new_mu, policy_hyperparams)
             new_sigma = optax.apply_updates(sigma, sigma_updates)
             new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+            
+            # respect KL divergence contraint with old parameters
+            if max_kl is not None:
+                old_mu_lr = new_mu_state.hyperparams['learning_rate']
+                old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
+                kl_terms = jax.tree_map(
+                    _jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
+                total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
+                kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
+                mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
+                sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
+                mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
+                sigma_updates, new_sigma_state = sigma_optimizer.update(
+                    sigma_grad, sigma_state, params=sigma)
+                new_mu = optax.apply_updates(mu, mu_updates)
+                new_sigma = optax.apply_updates(sigma, sigma_updates)
+                new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+                new_mu_state.hyperparams['learning_rate'] = old_mu_lr
+                new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr
+            
+            # apply projection step and finalize results
+            new_mu, converged = projection(new_mu, policy_hyperparams)
             new_pgpe_params = (new_mu, new_sigma)
             new_pgpe_opt_state = (new_mu_state, new_sigma_state)
             policy_params = new_mu
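For reference, `_jax_wrapped_pgpe_kl_term` is the closed-form KL divergence from the old to the new diagonal Gaussian search distribution, standard Gaussian algebra rearranged exactly as the code computes it:

```latex
\mathrm{KL}\!\left(\mathcal{N}(\mu_{\text{old}},\sigma_{\text{old}}^{2}) \,\middle\|\, \mathcal{N}(\mu,\sigma^{2})\right)
  = \sum_{i}\left[\log\frac{\sigma_{i}}{\sigma_{\text{old},i}}
  + \frac{\sigma_{\text{old},i}^{2}+(\mu_{i}-\mu_{\text{old},i})^{2}}{2\sigma_{i}^{2}}
  - \frac{1}{2}\right]
```

When the summed KL exceeds `max_kl`, both learning rates are shrunk by `min(1, sqrt(max_kl / total_kl))` and the update is redone, a trust-region-style bound on the update size; the original learning rates are then restored on the new optimizer states.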
@@ -1462,14 +1546,14 @@ def mean_deviation_utility(returns: jnp.ndarray, beta: float) -> float:
 @jax.jit
 def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> float:
     mu = jnp.mean(returns)
-    msd = jnp.sqrt(jnp.mean(jnp.minimum(0.0, returns - mu)
+    msd = jnp.sqrt(jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu))))
     return mu - 0.5 * beta * msd
 
 
 @jax.jit
 def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
     mu = jnp.mean(returns)
-    msv = jnp.mean(jnp.minimum(0.0, returns - mu)
+    msv = jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu)))
     return mu - 0.5 * beta * msv
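In 2.2 the downside deviations are squared before averaging, so the two risk-averse utilities compute, with R-bar the mean return:

```latex
U_{\text{msd}}(R) = \bar{R} - \tfrac{\beta}{2}\sqrt{\mathbb{E}\!\left[\min(0,\,R-\bar{R})^{2}\right]},
\qquad
U_{\text{msv}}(R) = \bar{R} - \tfrac{\beta}{2}\,\mathbb{E}\!\left[\min(0,\,R-\bar{R})^{2}\right]
```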
@@ -1768,7 +1852,6 @@ r"""
|
|
|
1768
1852
|
|
|
1769
1853
|
# optimization
|
|
1770
1854
|
self.update = self._jax_update(train_loss)
|
|
1771
|
-
self.check_zero_grad = self._jax_check_zero_gradients()
|
|
1772
1855
|
|
|
1773
1856
|
# pgpe option
|
|
1774
1857
|
if self.use_pgpe:
|
|
@@ -1831,6 +1914,12 @@ r"""
|
|
|
1831
1914
|
projection = self.plan.projection
|
|
1832
1915
|
use_ls = self.line_search_kwargs is not None
|
|
1833
1916
|
|
|
1917
|
+
# check if the gradients are all zeros
|
|
1918
|
+
def _jax_wrapped_zero_gradients(grad):
|
|
1919
|
+
leaves, _ = jax.tree_util.tree_flatten(
|
|
1920
|
+
jax.tree_map(lambda g: jnp.allclose(g, 0), grad))
|
|
1921
|
+
return jnp.all(jnp.asarray(leaves))
|
|
1922
|
+
|
|
1834
1923
|
# calculate the plan gradient w.r.t. return loss and update optimizer
|
|
1835
1924
|
# also perform a projection step to satisfy constraints on actions
|
|
1836
1925
|
def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
|
|
@@ -1855,23 +1944,12 @@ r"""
|
|
|
1855
1944
|
policy_params, converged = projection(policy_params, policy_hyperparams)
|
|
1856
1945
|
log['grad'] = grad
|
|
1857
1946
|
log['updates'] = updates
|
|
1947
|
+
zero_grads = _jax_wrapped_zero_gradients(grad)
|
|
1858
1948
|
return policy_params, converged, opt_state, opt_aux, \
|
|
1859
|
-
loss_val, log, model_params
|
|
1949
|
+
loss_val, log, model_params, zero_grads
|
|
1860
1950
|
|
|
1861
1951
|
return jax.jit(_jax_wrapped_plan_update)
|
|
1862
1952
|
|
|
1863
|
-
def _jax_check_zero_gradients(self):
|
|
1864
|
-
|
|
1865
|
-
def _jax_wrapped_zero_gradient(grad):
|
|
1866
|
-
return jnp.allclose(grad, 0)
|
|
1867
|
-
|
|
1868
|
-
def _jax_wrapped_zero_gradients(grad):
|
|
1869
|
-
leaves, _ = jax.tree_util.tree_flatten(
|
|
1870
|
-
jax.tree_map(_jax_wrapped_zero_gradient, grad))
|
|
1871
|
-
return jnp.all(jnp.asarray(leaves))
|
|
1872
|
-
|
|
1873
|
-
return jax.jit(_jax_wrapped_zero_gradients)
|
|
1874
|
-
|
|
1875
1953
|
def _batched_init_subs(self, subs):
|
|
1876
1954
|
rddl = self.rddl
|
|
1877
1955
|
n_train, n_test = self.batch_size_train, self.batch_size_test
|
|
@@ -2175,11 +2253,12 @@ r"""
|
|
|
2175
2253
|
# ======================================================================
|
|
2176
2254
|
|
|
2177
2255
|
# initialize running statistics
|
|
2178
|
-
best_params, best_loss, best_grad = policy_params, jnp.inf,
|
|
2256
|
+
best_params, best_loss, best_grad = policy_params, jnp.inf, None
|
|
2179
2257
|
last_iter_improve = 0
|
|
2180
2258
|
rolling_test_loss = RollingMean(test_rolling_window)
|
|
2181
2259
|
log = {}
|
|
2182
2260
|
status = JaxPlannerStatus.NORMAL
|
|
2261
|
+
progress_percent = 0
|
|
2183
2262
|
|
|
2184
2263
|
# initialize stopping criterion
|
|
2185
2264
|
if stopping_rule is not None:
|
|
@@ -2191,18 +2270,19 @@ r"""
|
|
|
2191
2270
|
dashboard_id, dashboard.get_planner_info(self),
|
|
2192
2271
|
key=dash_key, viz=self.dashboard_viz)
|
|
2193
2272
|
|
|
2273
|
+
# progress bar
|
|
2274
|
+
if print_progress:
|
|
2275
|
+
progress_bar = tqdm(None, total=100, position=tqdm_position,
|
|
2276
|
+
bar_format='{l_bar}{bar}| {elapsed} {postfix}')
|
|
2277
|
+
else:
|
|
2278
|
+
progress_bar = None
|
|
2279
|
+
position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
|
|
2280
|
+
|
|
2194
2281
|
# ======================================================================
|
|
2195
2282
|
# MAIN TRAINING LOOP BEGINS
|
|
2196
2283
|
# ======================================================================
|
|
2197
2284
|
|
|
2198
|
-
|
|
2199
|
-
if print_progress:
|
|
2200
|
-
iters = tqdm(iters, total=100,
|
|
2201
|
-
bar_format='{l_bar}{bar}| {elapsed} {postfix}',
|
|
2202
|
-
position=tqdm_position)
|
|
2203
|
-
position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
|
|
2204
|
-
|
|
2205
|
-
for it in iters:
|
|
2285
|
+
for it in range(epochs):
|
|
2206
2286
|
|
|
2207
2287
|
# ==================================================================
|
|
2208
2288
|
# NEXT GRADIENT DESCENT STEP
|
|
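As the progress-bar hunks show, 2.2 stops wrapping the epoch iterator in tqdm and instead drives a manual bar whose unit is percent of elapsed budget, advancing it by `progress_percent - progress_bar.n`. A standalone sketch of that tqdm idiom, with a toy loop and timing in place of the planner's budget logic:

```python
import time
from tqdm import tqdm

# a bar that tracks a 0-100 progress percentage instead of the loop counter
progress_bar = tqdm(None, total=100, bar_format='{l_bar}{bar}| {elapsed} {postfix}')
start = time.time()
budget = 2.0  # seconds of compute budget, for the toy example

for it in range(10 ** 6):
    time.sleep(0.01)  # stand-in for one gradient step
    elapsed = time.time() - start
    progress_percent = min(100, int(100 * elapsed / budget))
    progress_bar.set_postfix_str(f'{(it + 1) / (elapsed + 1e-6):.2f}it/s', refresh=False)
    progress_bar.update(progress_percent - progress_bar.n)  # jump to the absolute percent
    if progress_percent >= 100:
        break

progress_bar.close()
```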
@@ -2213,8 +2293,9 @@ r"""
|
|
|
2213
2293
|
# update the parameters of the plan
|
|
2214
2294
|
key, subkey = random.split(key)
|
|
2215
2295
|
(policy_params, converged, opt_state, opt_aux, train_loss, train_log,
|
|
2216
|
-
model_params) = self.update(
|
|
2217
|
-
|
|
2296
|
+
model_params, zero_grads) = self.update(
|
|
2297
|
+
subkey, policy_params, policy_hyperparams, train_subs, model_params,
|
|
2298
|
+
opt_state, opt_aux)
|
|
2218
2299
|
test_loss, (test_log, model_params_test) = self.test_loss(
|
|
2219
2300
|
subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
|
|
2220
2301
|
test_loss_smooth = rolling_test_loss.update(test_loss)
|
|
@@ -2224,8 +2305,9 @@ r"""
|
|
|
2224
2305
|
if self.use_pgpe:
|
|
2225
2306
|
key, subkey = random.split(key)
|
|
2226
2307
|
pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
|
|
2227
|
-
self.pgpe.update(subkey, pgpe_params, r_max,
|
|
2228
|
-
test_subs,
|
|
2308
|
+
self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
|
|
2309
|
+
policy_hyperparams, test_subs, model_params_test,
|
|
2310
|
+
pgpe_opt_state)
|
|
2229
2311
|
pgpe_loss, _ = self.test_loss(
|
|
2230
2312
|
subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
|
|
2231
2313
|
pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
|
|
@@ -2252,7 +2334,7 @@ r"""
|
|
|
2252
2334
|
# ==================================================================
|
|
2253
2335
|
|
|
2254
2336
|
# no progress
|
|
2255
|
-
if (not pgpe_improve) and
|
|
2337
|
+
if (not pgpe_improve) and zero_grads:
|
|
2256
2338
|
status = JaxPlannerStatus.NO_PROGRESS
|
|
2257
2339
|
|
|
2258
2340
|
# constraint satisfaction problem
|
|
@@ -2311,14 +2393,15 @@ r"""
|
|
|
2311
2393
|
|
|
2312
2394
|
# if the progress bar is used
|
|
2313
2395
|
if print_progress:
|
|
2314
|
-
|
|
2315
|
-
iters.set_description(
|
|
2396
|
+
progress_bar.set_description(
|
|
2316
2397
|
f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
|
|
2317
2398
|
f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
|
|
2318
2399
|
f'{status.value} status / {total_pgpe_it:6} pgpe',
|
|
2319
2400
|
refresh=False
|
|
2320
2401
|
)
|
|
2321
|
-
|
|
2402
|
+
progress_bar.set_postfix_str(
|
|
2403
|
+
f"{(it + 1) / (elapsed + 1e-6):.2f}it/s", refresh=False)
|
|
2404
|
+
progress_bar.update(progress_percent - progress_bar.n)
|
|
2322
2405
|
|
|
2323
2406
|
# dash-board
|
|
2324
2407
|
if dashboard is not None:
|
|
@@ -2339,7 +2422,7 @@ r"""
|
|
|
2339
2422
|
|
|
2340
2423
|
# release resources
|
|
2341
2424
|
if print_progress:
|
|
2342
|
-
|
|
2425
|
+
progress_bar.close()
|
|
2343
2426
|
|
|
2344
2427
|
# validate the test return
|
|
2345
2428
|
if log:
|
|
{pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
@@ -58,8 +58,11 @@ Dynamic: summary
@@ -235,8 +238,23 @@ More documentation about this and other new features will be coming soon.
@@ -260,7 +278,7 @@ train_on_reset=True
@@ -292,22 +310,7 @@ tuning = JaxParameterTuning(env=env,

The egg-info copy of PKG-INFO receives exactly the same changes as the top-level PKG-INFO diff above: the version bump from 2.1 to 2.2, the expanded five-item Purpose list, and the relocated "Tuning the Planner" documentation.
@@ -19,7 +19,7 @@ long_description = (Path(__file__).parent / "README.md").read_text()
|
|
|
19
19
|
|
|
20
20
|
setup(
|
|
21
21
|
name='pyRDDLGym-jax',
|
|
22
|
-
version='2.
|
|
22
|
+
version='2.2',
|
|
23
23
|
author="Michael Gimelfarb, Ayal Taitler, Scott Sanner",
|
|
24
24
|
author_email="mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca",
|
|
25
25
|
description="pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.",
|
|
pyrddlgym_jax-2.1/pyRDDLGym_jax/__init__.py

@@ -1 +0,0 @@
-__version__ = '2.1'