pyRDDLGym-jax 2.1.tar.gz → 2.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/PKG-INFO +25 -22
  2. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/README.md +24 -21
  3. pyrddlgym_jax-2.2/pyRDDLGym_jax/__init__.py +1 -0
  4. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/planner.py +159 -76
  5. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/PKG-INFO +25 -22
  6. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/setup.py +1 -1
  7. pyrddlgym_jax-2.1/pyRDDLGym_jax/__init__.py +0 -1
  8. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/LICENSE +0 -0
  9. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/__init__.py +0 -0
  10. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
  11. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
  12. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/compiler.py +0 -0
  13. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/logic.py +0 -0
  14. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/simulator.py +0 -0
  15. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/tuning.py +0 -0
  16. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/core/visualization.py +0 -0
  17. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/entry_point.py +0 -0
  18. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/__init__.py +0 -0
  19. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  20. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
  21. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  22. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  23. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  24. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  25. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  26. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  27. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  28. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  29. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  30. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  31. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  32. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  33. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  34. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  35. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  36. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
  37. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  38. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  39. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  40. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  41. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  42. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
  43. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
  44. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
  45. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  46. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  47. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/run_plan.py +0 -0
  48. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  49. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax/examples/run_tune.py +0 -0
  50. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
  51. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  52. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
  53. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
  54. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  55. {pyrddlgym_jax-2.1 → pyrddlgym_jax-2.2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pyRDDLGym-jax
3
- Version: 2.1
3
+ Version: 2.2
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -58,8 +58,11 @@ Dynamic: summary
58
58
 
59
59
  Purpose:
60
60
 
61
- 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
62
- 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
61
+ 1. automatic translation of RDDL description files into differentiable JAX simulators
62
+ 2. implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
63
+ 3. flexible policy representations and automated Bayesian hyper-parameter tuning
64
+ 4. interactive dashboard for dynamic visualization and debugging
65
+ 5. hybridization with parameter-exploring policy gradients.
63
66
 
64
67
  Some demos of solved problems by JaxPlan:
65
68
 
@@ -235,8 +238,23 @@ More documentation about this and other new features will be coming soon.
235
238
 
236
239
  ## Tuning the Planner
237
240
 
238
- It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
239
- To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
241
+ A basic run script is provided for automatic Bayesian hyper-parameter tuning of the most sensitive parameters of JaxPlan:
242
+
243
+ ```shell
244
+ jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
245
+ ```
246
+
247
+ where:
248
+ - ``domain`` is the domain identifier as specified in rddlrepository
249
+ - ``instance`` is the instance identifier
250
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
251
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
252
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
253
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
254
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
255
+
256
+ It is easy to tune a custom range of the planner's hyper-parameters efficiently.
257
+ First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
240
258
 
241
259
  ```ini
242
260
  [Model]
@@ -260,7 +278,7 @@ train_on_reset=True
260
278
 
261
279
 would allow tuning of the sharpness of model relaxations and the learning rate of the optimizer.
262
280
 
263
- Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
281
+ Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:
264
282
 
265
283
  ```python
266
284
  import pyRDDLGym
@@ -292,22 +310,7 @@ tuning = JaxParameterTuning(env=env,
292
310
  gp_iters=iters)
293
311
  tuning.tune(key=42, log_file='path/to/log.csv')
294
312
  ```
295
-
296
- A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
297
-
298
- ```shell
299
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
300
- ```
301
-
302
- where:
303
- - ``domain`` is the domain identifier as specified in rddlrepository
304
- - ``instance`` is the instance identifier
305
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
306
- - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
307
- - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
308
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
309
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
310
-
313
+
311
314
 
312
315
  ## Simulation
313
316
 
@@ -12,8 +12,11 @@
12
12
 
13
13
  Purpose:
14
14
 
15
- 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
16
- 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
15
+ 1. automatic translation of RDDL description files into differentiable JAX simulators
16
+ 2. implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
17
+ 3. flexible policy representations and automated Bayesian hyper-parameter tuning
18
+ 4. interactive dashboard for dynamic visualization and debugging
19
+ 5. hybridization with parameter-exploring policy gradients.
17
20
 
18
21
  Some demos of solved problems by JaxPlan:
19
22
 
@@ -189,8 +192,23 @@ More documentation about this and other new features will be coming soon.
189
192
 
190
193
  ## Tuning the Planner
191
194
 
192
- It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
193
- To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
195
+ A basic run script is provided for automatic Bayesian hyper-parameter tuning of the most sensitive parameters of JaxPlan:
196
+
197
+ ```shell
198
+ jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
199
+ ```
200
+
201
+ where:
202
+ - ``domain`` is the domain identifier as specified in rddlrepository
203
+ - ``instance`` is the instance identifier
204
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
205
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
206
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
207
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
208
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
209
+
210
+ It is easy to tune a custom range of the planner's hyper-parameters efficiently.
211
+ First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
194
212
 
195
213
  ```ini
196
214
  [Model]
@@ -214,7 +232,7 @@ train_on_reset=True
214
232
 
215
233
 would allow tuning of the sharpness of model relaxations and the learning rate of the optimizer.
216
234
 
217
- Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
235
+ Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:
218
236
 
219
237
  ```python
220
238
  import pyRDDLGym
@@ -246,22 +264,7 @@ tuning = JaxParameterTuning(env=env,
246
264
  gp_iters=iters)
247
265
  tuning.tune(key=42, log_file='path/to/log.csv')
248
266
  ```
249
-
250
- A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
251
-
252
- ```shell
253
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
254
- ```
255
-
256
- where:
257
- - ``domain`` is the domain identifier as specified in rddlrepository
258
- - ``instance`` is the instance identifier
259
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
260
- - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
261
- - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
262
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
263
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
264
-
267
+
265
268
 
266
269
  ## Simulation
267
270
 
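As a side note to the tuning workflow described in the README hunks above: instantiating a config template whose concrete values have been replaced by placeholder patterns is plain string substitution. The sketch below is purely illustrative; the placeholder names `MODEL_WEIGHT_TUNE` and `LEARNING_RATE_TUNE` and the `weight` key are hypothetical and not taken from this diff.

```python
# Hypothetical template: the real section keys and placeholder names come
# from your own config file, not from this diff.
TEMPLATE = """
[Model]
weight=MODEL_WEIGHT_TUNE

[Optimizer]
optimizer_kwargs={'learning_rate': LEARNING_RATE_TUNE}
"""

def instantiate(template: str, values: dict) -> str:
    """Substitute each placeholder pattern with a concrete candidate value."""
    config = template
    for pattern, value in values.items():
        config = config.replace(pattern, str(value))
    return config

print(instantiate(TEMPLATE, {'MODEL_WEIGHT_TUNE': 10.0,
                             'LEARNING_RATE_TUNE': 0.001}))
```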
@@ -0,0 +1 @@
1
+ __version__ = '2.2'
@@ -47,7 +47,9 @@ import jax.random as random
47
47
  import numpy as np
48
48
  import optax
49
49
  import termcolor
50
- from tqdm import tqdm
50
+ from tqdm import tqdm, TqdmWarning
51
+ import warnings
52
+ warnings.filterwarnings("ignore", category=TqdmWarning)
51
53
 
52
54
  from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
53
55
  from pyRDDLGym.core.debug.logger import Logger
@@ -1212,17 +1214,22 @@ class GaussianPGPE(PGPE):
1212
1214
  init_sigma: float=1.0,
1213
1215
  sigma_range: Tuple[float, float]=(1e-5, 1e5),
1214
1216
  scale_reward: bool=True,
1217
+ min_reward_scale: float=1e-5,
1215
1218
  super_symmetric: bool=True,
1216
1219
  super_symmetric_accurate: bool=True,
1217
1220
  optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
1218
1221
  optimizer_kwargs_mu: Optional[Kwargs]=None,
1219
- optimizer_kwargs_sigma: Optional[Kwargs]=None) -> None:
1222
+ optimizer_kwargs_sigma: Optional[Kwargs]=None,
1223
+ start_entropy_coeff: float=1e-3,
1224
+ end_entropy_coeff: float=1e-8,
1225
+ max_kl_update: Optional[float]=None) -> None:
1220
1226
  '''Creates a new Gaussian PGPE planner.
1221
1227
 
1222
1228
  :param batch_size: how many policy parameters to sample per optimization step
1223
1229
  :param init_sigma: initial standard deviation of Gaussian
1224
1230
  :param sigma_range: bounds to constrain standard deviation
1225
1231
  :param scale_reward: whether to apply reward scaling as in the paper
1232
+ :param min_reward_scale: minimum reward scaling to avoid underflow
1226
1233
  :param super_symmetric: whether to use super-symmetric sampling as in the paper
1227
1234
  :param super_symmetric_accurate: whether to use the accurate formula for super-
1228
1235
  symmetric sampling or the simplified but biased formula
@@ -1231,6 +1238,9 @@ class GaussianPGPE(PGPE):
1231
1238
  factory for the mean optimizer
1232
1239
  :param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
1233
1240
  factory for the standard deviation optimizer
1241
+ :param start_entropy_coeff: starting entropy regularization coefficient for Gaussian
1242
+ :param end_entropy_coeff: ending entropy regularization coefficient for Gaussian
1243
+ :param max_kl_update: bound on kl-divergence between parameter updates
1234
1244
  '''
1235
1245
  super().__init__()
1236
1246
 
@@ -1238,8 +1248,13 @@ class GaussianPGPE(PGPE):
1238
1248
  self.init_sigma = init_sigma
1239
1249
  self.sigma_range = sigma_range
1240
1250
  self.scale_reward = scale_reward
1251
+ self.min_reward_scale = min_reward_scale
1241
1252
  self.super_symmetric = super_symmetric
1242
1253
  self.super_symmetric_accurate = super_symmetric_accurate
1254
+
1255
+ # entropy regularization penalty is decayed exponentially between these values
1256
+ self.start_entropy_coeff = start_entropy_coeff
1257
+ self.end_entropy_coeff = end_entropy_coeff
1243
1258
 
1244
1259
  # set optimizers
1245
1260
  if optimizer_kwargs_mu is None:
@@ -1249,36 +1264,62 @@ class GaussianPGPE(PGPE):
1249
1264
  optimizer_kwargs_sigma = {'learning_rate': 0.1}
1250
1265
  self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
1251
1266
  self.optimizer_name = optimizer
1252
- mu_optimizer = optimizer(**optimizer_kwargs_mu)
1253
- sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
1267
+ try:
1268
+ mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
1269
+ sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
1270
+ except Exception as _:
1271
+ raise_warning(
1272
+ f'Failed to inject hyperparameters into optax optimizer for PGPE, '
1273
+ 'rolling back to safer method: please note that kl-divergence '
1274
+ 'constraints will be disabled.', 'red')
1275
+ mu_optimizer = optimizer(**optimizer_kwargs_mu)
1276
+ sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
1277
+ max_kl_update = None
1254
1278
  self.optimizers = (mu_optimizer, sigma_optimizer)
1279
+ self.max_kl = max_kl_update
1255
1280
 
1256
1281
  def __str__(self) -> str:
1257
1282
  return (f'PGPE hyper-parameters:\n'
1258
- f' method ={self.__class__.__name__}\n'
1259
- f' batch_size ={self.batch_size}\n'
1260
- f' init_sigma ={self.init_sigma}\n'
1261
- f' sigma_range ={self.sigma_range}\n'
1262
- f' scale_reward ={self.scale_reward}\n'
1263
- f' super_symmetric={self.super_symmetric}\n'
1264
- f' accurate ={self.super_symmetric_accurate}\n'
1265
- f' optimizer ={self.optimizer_name}\n'
1283
+ f' method ={self.__class__.__name__}\n'
1284
+ f' batch_size ={self.batch_size}\n'
1285
+ f' init_sigma ={self.init_sigma}\n'
1286
+ f' sigma_range ={self.sigma_range}\n'
1287
+ f' scale_reward ={self.scale_reward}\n'
1288
+ f' min_reward_scale ={self.min_reward_scale}\n'
1289
+ f' super_symmetric ={self.super_symmetric}\n'
1290
+ f' accurate ={self.super_symmetric_accurate}\n'
1291
+ f' optimizer ={self.optimizer_name}\n'
1266
1292
  f' optimizer_kwargs:\n'
1267
1293
  f' mu ={self.optimizer_kwargs_mu}\n'
1268
1294
  f' sigma={self.optimizer_kwargs_sigma}\n'
1295
+ f' start_entropy_coeff={self.start_entropy_coeff}\n'
1296
+ f' end_entropy_coeff ={self.end_entropy_coeff}\n'
1297
+ f' max_kl_update ={self.max_kl}\n'
1269
1298
  )
1270
1299
 
1271
1300
  def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
1272
- MIN_NORM = 1e-5
1273
1301
  sigma0 = self.init_sigma
1274
1302
  sigma_range = self.sigma_range
1275
1303
  scale_reward = self.scale_reward
1304
+ min_reward_scale = self.min_reward_scale
1276
1305
  super_symmetric = self.super_symmetric
1277
1306
  super_symmetric_accurate = self.super_symmetric_accurate
1278
1307
  batch_size = self.batch_size
1279
1308
  optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
1280
-
1281
- # initializer
1309
+ max_kl = self.max_kl
1310
+
1311
+ # entropy regularization penalty is decayed exponentially by elapsed budget
1312
+ start_entropy_coeff = self.start_entropy_coeff
1313
+ if start_entropy_coeff == 0:
1314
+ entropy_coeff_decay = 0
1315
+ else:
1316
+ entropy_coeff_decay = (self.end_entropy_coeff / start_entropy_coeff) ** 0.01
1317
+
1318
+ # ***********************************************************************
1319
+ # INITIALIZATION OF POLICY
1320
+ #
1321
+ # ***********************************************************************
1322
+
1282
1323
  def _jax_wrapped_pgpe_init(key, policy_params):
1283
1324
  mu = policy_params
1284
1325
  sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
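For context on the `start_entropy_coeff`/`end_entropy_coeff` parameters added above: the decay factor `(end / start) ** 0.01` anneals the coefficient exponentially from the starting to the ending value over the training budget. A minimal standalone sketch of that schedule, assuming (as `progress_percent` in the training-loop hunks further down suggests) that progress is measured in percent from 0 to 100:

```python
import jax.numpy as jnp

def entropy_coeff_schedule(progress_percent: float,
                           start_entropy_coeff: float = 1e-3,
                           end_entropy_coeff: float = 1e-8) -> jnp.ndarray:
    """Exponentially anneal the entropy coefficient over a 0-100 budget."""
    if start_entropy_coeff == 0:
        return jnp.asarray(0.0)
    # after 100 steps of decay the starting value has shrunk to the ending value
    decay = (end_entropy_coeff / start_entropy_coeff) ** 0.01
    return start_entropy_coeff * jnp.power(decay, progress_percent)

print(entropy_coeff_schedule(0.0))    # ~1e-3 at the start of training
print(entropy_coeff_schedule(50.0))   # ~3.2e-6, the geometric mean of the endpoints
print(entropy_coeff_schedule(100.0))  # ~1e-8 when the budget is exhausted
```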
@@ -1289,7 +1330,11 @@ class GaussianPGPE(PGPE):
1289
1330
 
1290
1331
  self._initializer = jax.jit(_jax_wrapped_pgpe_init)
1291
1332
 
1292
- # parameter sampling functions
1333
+ # ***********************************************************************
1334
+ # PARAMETER SAMPLING FUNCTIONS
1335
+ #
1336
+ # ***********************************************************************
1337
+
1293
1338
  def _jax_wrapped_mu_noise(key, sigma):
1294
1339
  return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
1295
1340
 
@@ -1299,19 +1344,20 @@ class GaussianPGPE(PGPE):
1299
1344
  a = (sigma - jnp.abs(epsilon)) / sigma
1300
1345
  if super_symmetric_accurate:
1301
1346
  aa = jnp.abs(a)
1347
+ aa3 = jnp.power(aa, 3)
1302
1348
  epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
1303
1349
  a <= 0,
1304
- jnp.exp(c1 * aa * (aa * aa - 1) / jnp.log(aa + 1e-10) + c2 * aa),
1305
- jnp.exp(aa - c3 * aa * jnp.log(1.0 - jnp.power(aa, 3) + 1e-10))
1350
+ jnp.exp(c1 * (aa3 - aa) / jnp.log(aa + 1e-10) + c2 * aa),
1351
+ jnp.exp(aa - c3 * aa * jnp.log(1.0 - aa3 + 1e-10))
1306
1352
  )
1307
1353
  else:
1308
1354
  epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
1309
1355
  return epsilon_star
1310
1356
 
1311
1357
  def _jax_wrapped_sample_params(key, mu, sigma):
1312
- keys = random.split(key, num=len(jax.tree_util.tree_leaves(mu)))
1313
- keys_pytree = jax.tree_util.tree_unflatten(
1314
- treedef=jax.tree_util.tree_structure(mu), leaves=keys)
1358
+ treedef = jax.tree_util.tree_structure(sigma)
1359
+ keys = random.split(key, num=treedef.num_leaves)
1360
+ keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
1315
1361
  epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
1316
1362
  p1 = jax.tree_map(jnp.add, mu, epsilon)
1317
1363
  p2 = jax.tree_map(jnp.subtract, mu, epsilon)
@@ -1321,14 +1367,18 @@ class GaussianPGPE(PGPE):
1321
1367
  p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
1322
1368
  else:
1323
1369
  epsilon_star, p3, p4 = epsilon, p1, p2
1324
- return (p1, p2, p3, p4), (epsilon, epsilon_star)
1370
+ return p1, p2, p3, p4, epsilon, epsilon_star
1325
1371
 
1326
- # policy gradient update functions
1372
+ # ***********************************************************************
1373
+ # POLICY GRADIENT CALCULATION
1374
+ #
1375
+ # ***********************************************************************
1376
+
1327
1377
  def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
1328
1378
  if super_symmetric:
1329
1379
  if scale_reward:
1330
- scale1 = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
1331
- scale2 = jnp.maximum(MIN_NORM, m - (r3 + r4) / 2)
1380
+ scale1 = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
1381
+ scale2 = jnp.maximum(min_reward_scale, m - (r3 + r4) / 2)
1332
1382
  else:
1333
1383
  scale1 = scale2 = 1.0
1334
1384
  r_mu1 = (r1 - r2) / (2 * scale1)
@@ -1336,37 +1386,37 @@ class GaussianPGPE(PGPE):
1336
1386
  grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
1337
1387
  else:
1338
1388
  if scale_reward:
1339
- scale = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
1389
+ scale = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
1340
1390
  else:
1341
1391
  scale = 1.0
1342
1392
  r_mu = (r1 - r2) / (2 * scale)
1343
1393
  grad = -r_mu * epsilon
1344
1394
  return grad
1345
1395
 
1346
- def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
1396
+ def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
1347
1397
  if super_symmetric:
1348
1398
  mask = r1 + r2 >= r3 + r4
1349
1399
  epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
1350
- s = epsilon_tau * epsilon_tau / sigma - sigma
1400
+ s = jnp.square(epsilon_tau) / sigma - sigma
1351
1401
  if scale_reward:
1352
- scale = jnp.maximum(MIN_NORM, m - (r1 + r2 + r3 + r4) / 4)
1402
+ scale = jnp.maximum(min_reward_scale, m - (r1 + r2 + r3 + r4) / 4)
1353
1403
  else:
1354
1404
  scale = 1.0
1355
1405
  r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
1356
1406
  else:
1357
- s = epsilon * epsilon / sigma - sigma
1407
+ s = jnp.square(epsilon) / sigma - sigma
1358
1408
  if scale_reward:
1359
- scale = jnp.maximum(MIN_NORM, jnp.abs(m))
1409
+ scale = jnp.maximum(min_reward_scale, jnp.abs(m))
1360
1410
  else:
1361
1411
  scale = 1.0
1362
1412
  r_sigma = (r1 + r2) / (2 * scale)
1363
- grad = -r_sigma * s
1413
+ grad = -(r_sigma * s + ent / sigma)
1364
1414
  return grad
1365
1415
 
1366
- def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
1416
+ def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
1367
1417
  policy_hyperparams, subs, model_params):
1368
1418
  key, subkey = random.split(key)
1369
- (p1, p2, p3, p4), (epsilon, epsilon_star) = _jax_wrapped_sample_params(
1419
+ p1, p2, p3, p4, epsilon, epsilon_star = _jax_wrapped_sample_params(
1370
1420
  key, mu, sigma)
1371
1421
  r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
1372
1422
  r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
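The reward scaling in the gradient above replaces the hard-coded `MIN_NORM` with the new `min_reward_scale` parameter. A minimal sketch of the basic (non-super-symmetric) mean gradient for a single parameter tensor, using the names from the diff; the helper itself is illustrative and not part of the package API:

```python
import jax.numpy as jnp

def pgpe_mu_grad(epsilon: jnp.ndarray, r_plus: float, r_minus: float,
                 r_max: float, scale_reward: bool = True,
                 min_reward_scale: float = 1e-5) -> jnp.ndarray:
    """Mirrored-sampling gradient of the Gaussian mean for one parameter tensor.

    r_plus and r_minus are the returns of the mirrored samples mu + epsilon
    and mu - epsilon, and r_max is the best return seen so far.
    """
    if scale_reward:
        # clip the scale away from zero so a near-optimal pair cannot blow up
        scale = jnp.maximum(min_reward_scale, r_max - (r_plus + r_minus) / 2)
    else:
        scale = 1.0
    r_mu = (r_plus - r_minus) / (2 * scale)
    # negated because the planner minimizes a loss while r is a return
    return -r_mu * epsilon
```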
@@ -1384,42 +1434,76 @@ class GaussianPGPE(PGPE):
1384
1434
  epsilon, epsilon_star
1385
1435
  )
1386
1436
  grad_sigma = jax.tree_map(
1387
- partial(_jax_wrapped_sigma_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
1437
+ partial(_jax_wrapped_sigma_grad,
1438
+ r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
1388
1439
  epsilon, epsilon_star, sigma
1389
1440
  )
1390
1441
  return grad_mu, grad_sigma, r_max
1391
1442
 
1392
- def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
1443
+ def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max, ent,
1393
1444
  policy_hyperparams, subs, model_params):
1394
1445
  mu, sigma = pgpe_params
1395
1446
  if batch_size == 1:
1396
1447
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
1397
- key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
1448
+ key, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
1398
1449
  else:
1399
1450
  keys = random.split(key, num=batch_size)
1400
1451
  mu_grads, sigma_grads, r_maxs = jax.vmap(
1401
1452
  _jax_wrapped_pgpe_grad,
1402
- in_axes=(0, None, None, None, None, None, None)
1403
- )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
1453
+ in_axes=(0, None, None, None, None, None, None, None)
1454
+ )(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
1404
1455
  mu_grad, sigma_grad = jax.tree_map(
1405
1456
  partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
1406
1457
  new_r_max = jnp.max(r_maxs)
1407
1458
  return mu_grad, sigma_grad, new_r_max
1459
+
1460
+ # ***********************************************************************
1461
+ # PARAMETER UPDATE
1462
+ #
1463
+ # ***********************************************************************
1408
1464
 
1409
- def _jax_wrapped_pgpe_update(key, pgpe_params, r_max,
1465
+ def _jax_wrapped_pgpe_kl_term(mu, sigma, old_mu, old_sigma):
1466
+ return 0.5 * jnp.sum(2 * jnp.log(sigma / old_sigma) +
1467
+ jnp.square(old_sigma / sigma) +
1468
+ jnp.square((mu - old_mu) / sigma) - 1)
1469
+
1470
+ def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
1410
1471
  policy_hyperparams, subs, model_params,
1411
1472
  pgpe_opt_state):
1473
+ # regular update
1412
1474
  mu, sigma = pgpe_params
1413
1475
  mu_state, sigma_state = pgpe_opt_state
1476
+ ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
1414
1477
  mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
1415
- key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
1478
+ key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
1416
1479
  mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1417
1480
  sigma_updates, new_sigma_state = sigma_optimizer.update(
1418
1481
  sigma_grad, sigma_state, params=sigma)
1419
1482
  new_mu = optax.apply_updates(mu, mu_updates)
1420
- new_mu, converged = projection(new_mu, policy_hyperparams)
1421
1483
  new_sigma = optax.apply_updates(sigma, sigma_updates)
1422
1484
  new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
1485
+
1486
+ # respect KL divergence contraint with old parameters
1487
+ if max_kl is not None:
1488
+ old_mu_lr = new_mu_state.hyperparams['learning_rate']
1489
+ old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
1490
+ kl_terms = jax.tree_map(
1491
+ _jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
1492
+ total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
1493
+ kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
1494
+ mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
1495
+ sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
1496
+ mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
1497
+ sigma_updates, new_sigma_state = sigma_optimizer.update(
1498
+ sigma_grad, sigma_state, params=sigma)
1499
+ new_mu = optax.apply_updates(mu, mu_updates)
1500
+ new_sigma = optax.apply_updates(sigma, sigma_updates)
1501
+ new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
1502
+ new_mu_state.hyperparams['learning_rate'] = old_mu_lr
1503
+ new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr
1504
+
1505
+ # apply projection step and finalize results
1506
+ new_mu, converged = projection(new_mu, policy_hyperparams)
1423
1507
  new_pgpe_params = (new_mu, new_sigma)
1424
1508
  new_pgpe_opt_state = (new_mu_state, new_sigma_state)
1425
1509
  policy_params = new_mu
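Two pieces of the KL-constrained update above can be read in isolation: the closed-form KL divergence between the old and new diagonal Gaussian search distributions, and the common factor by which both learning rates are shrunk when the tentative update exceeds `max_kl_update`. A minimal sketch on flat arrays rather than pytrees (the toy numbers are illustrative only):

```python
import jax.numpy as jnp

def diag_gaussian_kl(mu, sigma, old_mu, old_sigma):
    """KL(old || new) for diagonal Gaussians, matching the per-leaf term above."""
    return 0.5 * jnp.sum(2.0 * jnp.log(sigma / old_sigma)
                         + jnp.square(old_sigma / sigma)
                         + jnp.square((mu - old_mu) / sigma) - 1.0)

def kl_learning_rate_factor(total_kl, max_kl):
    """Common factor applied to both learning rates before redoing the update."""
    # no shrinking when the tentative update already satisfies the bound
    return jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))

old_mu, old_sigma = jnp.zeros(3), jnp.ones(3)
new_mu, new_sigma = jnp.full(3, 0.5), jnp.full(3, 1.2)
kl = diag_gaussian_kl(new_mu, new_sigma, old_mu, old_sigma)
print(kl, kl_learning_rate_factor(kl, max_kl=0.1))
```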
@@ -1462,14 +1546,14 @@ def mean_deviation_utility(returns: jnp.ndarray, beta: float) -> float:
1462
1546
  @jax.jit
1463
1547
  def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> float:
1464
1548
  mu = jnp.mean(returns)
1465
- msd = jnp.sqrt(jnp.mean(jnp.minimum(0.0, returns - mu) ** 2))
1549
+ msd = jnp.sqrt(jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu))))
1466
1550
  return mu - 0.5 * beta * msd
1467
1551
 
1468
1552
 
1469
1553
  @jax.jit
1470
1554
  def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
1471
1555
  mu = jnp.mean(returns)
1472
- msv = jnp.mean(jnp.minimum(0.0, returns - mu) ** 2)
1556
+ msv = jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu)))
1473
1557
  return mu - 0.5 * beta * msv
1474
1558
 
1475
1559
 
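The hunk above only replaces `** 2` with `jnp.square`, but the risk-sensitive utilities it touches are compact enough to restate. A standalone sketch of the mean-semideviation utility on a toy batch of returns (the numbers are illustrative only):

```python
import jax
import jax.numpy as jnp

@jax.jit
def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> jnp.ndarray:
    """Mean return penalized only by downside deviation from the mean."""
    mu = jnp.mean(returns)
    downside = jnp.minimum(0.0, returns - mu)        # zero for above-average returns
    msd = jnp.sqrt(jnp.mean(jnp.square(downside)))   # semideviation
    return mu - 0.5 * beta * msd

returns = jnp.array([10.0, 12.0, 9.0, -4.0])         # one bad rollout drags the utility down
print(mean_semideviation_utility(returns, beta=1.0))
```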
@@ -1768,7 +1852,6 @@ r"""
1768
1852
 
1769
1853
  # optimization
1770
1854
  self.update = self._jax_update(train_loss)
1771
- self.check_zero_grad = self._jax_check_zero_gradients()
1772
1855
 
1773
1856
  # pgpe option
1774
1857
  if self.use_pgpe:
@@ -1831,6 +1914,12 @@ r"""
1831
1914
  projection = self.plan.projection
1832
1915
  use_ls = self.line_search_kwargs is not None
1833
1916
 
1917
+ # check if the gradients are all zeros
1918
+ def _jax_wrapped_zero_gradients(grad):
1919
+ leaves, _ = jax.tree_util.tree_flatten(
1920
+ jax.tree_map(lambda g: jnp.allclose(g, 0), grad))
1921
+ return jnp.all(jnp.asarray(leaves))
1922
+
1834
1923
  # calculate the plan gradient w.r.t. return loss and update optimizer
1835
1924
  # also perform a projection step to satisfy constraints on actions
1836
1925
  def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
@@ -1855,23 +1944,12 @@ r"""
1855
1944
  policy_params, converged = projection(policy_params, policy_hyperparams)
1856
1945
  log['grad'] = grad
1857
1946
  log['updates'] = updates
1947
+ zero_grads = _jax_wrapped_zero_gradients(grad)
1858
1948
  return policy_params, converged, opt_state, opt_aux, \
1859
- loss_val, log, model_params
1949
+ loss_val, log, model_params, zero_grads
1860
1950
 
1861
1951
  return jax.jit(_jax_wrapped_plan_update)
1862
1952
 
1863
- def _jax_check_zero_gradients(self):
1864
-
1865
- def _jax_wrapped_zero_gradient(grad):
1866
- return jnp.allclose(grad, 0)
1867
-
1868
- def _jax_wrapped_zero_gradients(grad):
1869
- leaves, _ = jax.tree_util.tree_flatten(
1870
- jax.tree_map(_jax_wrapped_zero_gradient, grad))
1871
- return jnp.all(jnp.asarray(leaves))
1872
-
1873
- return jax.jit(_jax_wrapped_zero_gradients)
1874
-
1875
1953
  def _batched_init_subs(self, subs):
1876
1954
  rddl = self.rddl
1877
1955
  n_train, n_test = self.batch_size_train, self.batch_size_test
@@ -2175,11 +2253,12 @@ r"""
2175
2253
  # ======================================================================
2176
2254
 
2177
2255
  # initialize running statistics
2178
- best_params, best_loss, best_grad = policy_params, jnp.inf, jnp.inf
2256
+ best_params, best_loss, best_grad = policy_params, jnp.inf, None
2179
2257
  last_iter_improve = 0
2180
2258
  rolling_test_loss = RollingMean(test_rolling_window)
2181
2259
  log = {}
2182
2260
  status = JaxPlannerStatus.NORMAL
2261
+ progress_percent = 0
2183
2262
 
2184
2263
  # initialize stopping criterion
2185
2264
  if stopping_rule is not None:
@@ -2191,18 +2270,19 @@ r"""
2191
2270
  dashboard_id, dashboard.get_planner_info(self),
2192
2271
  key=dash_key, viz=self.dashboard_viz)
2193
2272
 
2273
+ # progress bar
2274
+ if print_progress:
2275
+ progress_bar = tqdm(None, total=100, position=tqdm_position,
2276
+ bar_format='{l_bar}{bar}| {elapsed} {postfix}')
2277
+ else:
2278
+ progress_bar = None
2279
+ position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
2280
+
2194
2281
  # ======================================================================
2195
2282
  # MAIN TRAINING LOOP BEGINS
2196
2283
  # ======================================================================
2197
2284
 
2198
- iters = range(epochs)
2199
- if print_progress:
2200
- iters = tqdm(iters, total=100,
2201
- bar_format='{l_bar}{bar}| {elapsed} {postfix}',
2202
- position=tqdm_position)
2203
- position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
2204
-
2205
- for it in iters:
2285
+ for it in range(epochs):
2206
2286
 
2207
2287
  # ==================================================================
2208
2288
  # NEXT GRADIENT DESCENT STEP
@@ -2213,8 +2293,9 @@ r"""
2213
2293
  # update the parameters of the plan
2214
2294
  key, subkey = random.split(key)
2215
2295
  (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
2216
- model_params) = self.update(subkey, policy_params, policy_hyperparams,
2217
- train_subs, model_params, opt_state, opt_aux)
2296
+ model_params, zero_grads) = self.update(
2297
+ subkey, policy_params, policy_hyperparams, train_subs, model_params,
2298
+ opt_state, opt_aux)
2218
2299
  test_loss, (test_log, model_params_test) = self.test_loss(
2219
2300
  subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
2220
2301
  test_loss_smooth = rolling_test_loss.update(test_loss)
@@ -2224,8 +2305,9 @@ r"""
2224
2305
  if self.use_pgpe:
2225
2306
  key, subkey = random.split(key)
2226
2307
  pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
2227
- self.pgpe.update(subkey, pgpe_params, r_max, policy_hyperparams,
2228
- test_subs, model_params, pgpe_opt_state)
2308
+ self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
2309
+ policy_hyperparams, test_subs, model_params_test,
2310
+ pgpe_opt_state)
2229
2311
  pgpe_loss, _ = self.test_loss(
2230
2312
  subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
2231
2313
  pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
@@ -2252,7 +2334,7 @@ r"""
2252
2334
  # ==================================================================
2253
2335
 
2254
2336
  # no progress
2255
- if (not pgpe_improve) and self.check_zero_grad(train_log['grad']):
2337
+ if (not pgpe_improve) and zero_grads:
2256
2338
  status = JaxPlannerStatus.NO_PROGRESS
2257
2339
 
2258
2340
  # constraint satisfaction problem
@@ -2311,14 +2393,15 @@ r"""
2311
2393
 
2312
2394
  # if the progress bar is used
2313
2395
  if print_progress:
2314
- iters.n = progress_percent
2315
- iters.set_description(
2396
+ progress_bar.set_description(
2316
2397
  f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
2317
2398
  f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
2318
2399
  f'{status.value} status / {total_pgpe_it:6} pgpe',
2319
2400
  refresh=False
2320
2401
  )
2321
- iters.set_postfix_str(f"{(it + 1) / elapsed:.2f}it/s", refresh=True)
2402
+ progress_bar.set_postfix_str(
2403
+ f"{(it + 1) / (elapsed + 1e-6):.2f}it/s", refresh=False)
2404
+ progress_bar.update(progress_percent - progress_bar.n)
2322
2405
 
2323
2406
  # dash-board
2324
2407
  if dashboard is not None:
@@ -2339,7 +2422,7 @@ r"""
2339
2422
 
2340
2423
  # release resources
2341
2424
  if print_progress:
2342
- iters.close()
2425
+ progress_bar.close()
2343
2426
 
2344
2427
  # validate the test return
2345
2428
  if log:
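One more small refactor in this file is the zero-gradient check, which now runs inside the jitted update instead of as a separate jitted helper. Its pytree reduction can be tried on a toy gradient as below; `all_gradients_zero` is an illustrative name, not the function name used in the package:

```python
import jax
import jax.numpy as jnp

def all_gradients_zero(grad) -> jnp.ndarray:
    """True only if every leaf of the gradient pytree is numerically all zeros."""
    leaves, _ = jax.tree_util.tree_flatten(
        jax.tree_util.tree_map(lambda g: jnp.allclose(g, 0), grad))
    return jnp.all(jnp.asarray(leaves))

# toy gradient pytree: one zero leaf and one (barely) non-zero leaf
grad = {'weights': jnp.zeros((2, 2)), 'bias': jnp.array([0.0, 1e-3])}
print(all_gradients_zero(grad))                                          # False
print(all_gradients_zero(jax.tree_util.tree_map(jnp.zeros_like, grad)))  # True
```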
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pyRDDLGym-jax
3
- Version: 2.1
3
+ Version: 2.2
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -58,8 +58,11 @@ Dynamic: summary
58
58
 
59
59
  Purpose:
60
60
 
61
- 1. automatic translation of any RDDL description file into a differentiable simulator in JAX
62
- 2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
61
+ 1. automatic translation of RDDL description files into differentiable JAX simulators
62
+ 2. implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
63
+ 3. flexible policy representations and automated Bayesian hyper-parameter tuning
64
+ 4. interactive dashboard for dynamic visualization and debugging
65
+ 5. hybridization with parameter-exploring policy gradients.
63
66
 
64
67
  Some demos of solved problems by JaxPlan:
65
68
 
@@ -235,8 +238,23 @@ More documentation about this and other new features will be coming soon.
235
238
 
236
239
  ## Tuning the Planner
237
240
 
238
- It is easy to tune the planner's hyper-parameters efficiently and automatically using Bayesian optimization.
239
- To do this, first create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
241
+ A basic run script is provided for automatic Bayesian hyper-parameter tuning of the most sensitive parameters of JaxPlan:
242
+
243
+ ```shell
244
+ jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
245
+ ```
246
+
247
+ where:
248
+ - ``domain`` is the domain identifier as specified in rddlrepository
249
+ - ``instance`` is the instance identifier
250
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
251
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
252
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
253
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
254
+ - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
255
+
256
+ It is easy to tune a custom range of the planner's hyper-parameters efficiently.
257
+ First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
240
258
 
241
259
  ```ini
242
260
  [Model]
@@ -260,7 +278,7 @@ train_on_reset=True
260
278
 
261
279
 would allow tuning of the sharpness of model relaxations and the learning rate of the optimizer.
262
280
 
263
- Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
281
+ Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:
264
282
 
265
283
  ```python
266
284
  import pyRDDLGym
@@ -292,22 +310,7 @@ tuning = JaxParameterTuning(env=env,
292
310
  gp_iters=iters)
293
311
  tuning.tune(key=42, log_file='path/to/log.csv')
294
312
  ```
295
-
296
- A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
297
-
298
- ```shell
299
- jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
300
- ```
301
-
302
- where:
303
- - ``domain`` is the domain identifier as specified in rddlrepository
304
- - ``instance`` is the instance identifier
305
- - ``method`` is the planning method to use (i.e. drp, slp, replan)
306
- - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
307
- - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
308
- - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
309
- - ``dashboard`` is whether the optimizations are tracked in the dashboard application.
310
-
313
+
311
314
 
312
315
  ## Simulation
313
316
 
@@ -19,7 +19,7 @@ long_description = (Path(__file__).parent / "README.md").read_text()
19
19
 
20
20
  setup(
21
21
  name='pyRDDLGym-jax',
22
- version='2.1',
22
+ version='2.2',
23
23
  author="Michael Gimelfarb, Ayal Taitler, Scott Sanner",
24
24
  author_email="mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca",
25
25
  description="pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.",
@@ -1 +0,0 @@
1
- __version__ = '2.1'