pyRDDLGym-jax 0.3__py3-none-any.whl → 0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,8 @@
- import jax
  import time
  from typing import Dict, Optional

+ import jax
+
  from pyRDDLGym.core.compiler.model import RDDLLiftedModel
  from pyRDDLGym.core.debug.exception import (
  RDDLActionPreconditionNotSatisfiedError,
@@ -1,20 +1,18 @@
- from bayes_opt import BayesianOptimization
- from bayes_opt.util import UtilityFunction
  from copy import deepcopy
  import csv
  import datetime
- import jax
  from multiprocessing import get_context
- import numpy as np
  import os
  import time
  from typing import Any, Callable, Dict, Optional, Tuple
-
- Kwargs = Dict[str, Any]
-
  import warnings
  warnings.filterwarnings("ignore")

+ from bayes_opt import BayesianOptimization
+ from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
+ import jax
+ import numpy as np
+
  from pyRDDLGym.core.debug.exception import raise_warning
  from pyRDDLGym.core.env import RDDLEnv

@@ -26,6 +24,7 @@ from pyRDDLGym_jax.core.planner import (
  JaxOnlineController
  )

+ Kwargs = Dict[str, Any]

  # ===============================================================================
  #
@@ -37,6 +36,9 @@ from pyRDDLGym_jax.core.planner import (
  # 3. deep reactive policies
  #
  # ===============================================================================
+ COLUMNS = ['pid', 'worker', 'iteration', 'target', 'best_target', 'acq_params']
+
+
  class JaxParameterTuning:
  '''A general-purpose class for tuning a Jax planner.'''

@@ -53,7 +55,7 @@ class JaxParameterTuning:
  num_workers: int=1,
  poll_frequency: float=0.2,
  gp_iters: int=25,
- acquisition: Optional[UtilityFunction]=None,
+ acquisition: Optional[AcquisitionFunction]=None,
  gp_init_kwargs: Optional[Kwargs]=None,
  gp_params: Optional[Kwargs]=None) -> None:
  '''Creates a new instance for tuning hyper-parameters for Jax planners
@@ -113,10 +115,9 @@ class JaxParameterTuning:
  self.gp_params = gp_params

  # create acquisition function
- self.acq_args = None
  if acquisition is None:
  num_samples = self.gp_iters * self.num_workers
- acquisition, self.acq_args = JaxParameterTuning._annealing_utility(num_samples)
+ acquisition = JaxParameterTuning._annealing_acquisition(num_samples)
  self.acquisition = acquisition

  def summarize_hyperparameters(self) -> None:
@@ -133,23 +134,15 @@ class JaxParameterTuning:
  f' planning_trials_per_iter ={self.eval_trials}\n'
  f' planning_iters_per_trial ={self.train_epochs}\n'
  f' planning_timeout_per_trial={self.timeout_training}\n'
- f' acquisition_fn ={type(self.acquisition).__name__}')
- if self.acq_args is not None:
- print(f'using default acquisition function:\n'
- f' utility_kind ={self.acq_args[0]}\n'
- f' initial_kappa={self.acq_args[1]}\n'
- f' kappa_decay ={self.acq_args[2]}')
+ f' acquisition_fn ={self.acquisition}')

  @staticmethod
- def _annealing_utility(n_samples, n_delay_samples=0, kappa1=10.0, kappa2=1.0):
- kappa_decay = (kappa2 / kappa1) ** (1.0 / (n_samples - n_delay_samples))
- utility_fn = UtilityFunction(
- kind='ucb',
+ def _annealing_acquisition(n_samples, n_delay_samples=0, kappa1=10.0, kappa2=1.0):
+ acq_fn = UpperConfidenceBound(
  kappa=kappa1,
- kappa_decay=kappa_decay,
- kappa_decay_delay=n_delay_samples)
- utility_args = ['ucb', kappa1, kappa_decay]
- return utility_fn, utility_args
+ exploration_decay=(kappa2 / kappa1) ** (1.0 / (n_samples - n_delay_samples)),
+ exploration_decay_delay=n_delay_samples)
+ return acq_fn

  def _pickleable_objective_with_kwargs(self):
  raise NotImplementedError
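The hunk above migrates from ``bayes_opt.util.UtilityFunction`` to the ``bayes_opt.acquisition`` API of bayesian-optimization 2.0. As a minimal, illustrative sketch (the single bound, the toy target, and the sample counts are invented for illustration and are not taken from the package), an annealed upper-confidence-bound acquisition plugs into ``BayesianOptimization`` like this:

```python
from bayes_opt import BayesianOptimization
from bayes_opt.acquisition import UpperConfidenceBound

# anneal kappa from kappa1 down to roughly kappa2 over n_samples suggestions
n_samples, kappa1, kappa2 = 100, 10.0, 1.0
acq = UpperConfidenceBound(
    kappa=kappa1,
    exploration_decay=(kappa2 / kappa1) ** (1.0 / n_samples),
    exploration_decay_delay=0)

optimizer = BayesianOptimization(
    f=None,                       # points are evaluated and registered manually
    acquisition_function=acq,
    pbounds={'lr': (-7.0, 1.0)},  # hypothetical hyper-parameter bound
    allow_duplicate_points=True)

for _ in range(5):
    probe = optimizer.suggest()   # exploration decays as suggestions are drawn
    optimizer.register(params=probe, target=-sum(v ** 2 for v in probe.values()))
```

Here ``exploration_decay`` plays the role of the old ``kappa_decay``: after ``n_samples`` suggestions the effective ``kappa`` has annealed from ``kappa1`` to approximately ``kappa2``, which is what ``_annealing_acquisition`` computes.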
@@ -160,7 +153,7 @@ class JaxParameterTuning:
  pid = os.getpid()
  return index, pid, params, target

- def tune(self, key: jax.random.PRNGKey,
+ def tune(self, key: jax.random.PRNGKey,
  filename: str,
  save_plot: bool=False) -> Dict[str, Any]:
  '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
@@ -178,32 +171,28 @@ class JaxParameterTuning:
  for (name, hparam) in self.hyperparams_dict.items()
  }
  optimizer = BayesianOptimization(
- f=None, # probe() is not called
+ f=None,
+ acquisition_function=self.acquisition,
  pbounds=hyperparams_bounds,
  allow_duplicate_points=True, # to avoid crash
  random_state=np.random.RandomState(key),
  **self.gp_init_kwargs
  )
  optimizer.set_gp_params(**self.gp_params)
- utility = self.acquisition

  # suggest initial parameters to evaluate
  num_workers = self.num_workers
- suggested, kappas = [], []
+ suggested, acq_params = [], []
  for _ in range(num_workers):
- utility.update_params()
- probe = optimizer.suggest(utility)
- suggested.append(probe)
- kappas.append(utility.kappa)
+ probe = optimizer.suggest()
+ suggested.append(probe)
+ acq_params.append(vars(optimizer.acquisition_function))

  # clear and prepare output file
  filename = self._filename(filename, 'csv')
  with open(filename, 'w', newline='') as file:
  writer = csv.writer(file)
- writer.writerow(
- ['pid', 'worker', 'iteration', 'target', 'best_target', 'kappa'] + \
- list(hyperparams_bounds.keys())
- )
+ writer.writerow(COLUMNS + list(hyperparams_bounds.keys()))

  # start multiprocess evaluation
  worker_ids = list(range(num_workers))
@@ -219,8 +208,8 @@ class JaxParameterTuning:

  # continue with next iteration
  print('\n' + '*' * 25 +
- '\n' + f'[{datetime.timedelta(seconds=elapsed)}] ' +
- f'starting iteration {it}' +
+ f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
+ f'starting iteration {it + 1}' +
  '\n' + '*' * 25)
  key, *subkeys = jax.random.split(key, num=num_workers + 1)
  rows = [None] * num_workers
@@ -256,10 +245,9 @@ class JaxParameterTuning:
  optimizer.register(params, target)

  # update acquisition function and suggest a new point
- utility.update_params()
- suggested[index] = optimizer.suggest(utility)
- old_kappa = kappas[index]
- kappas[index] = utility.kappa
+ suggested[index] = optimizer.suggest()
+ old_acq_params = acq_params[index]
+ acq_params[index] = vars(optimizer.acquisition_function)

  # transform suggestion back to natural space
  rddl_params = {
@@ -272,8 +260,8 @@ class JaxParameterTuning:
  best_params, best_target = rddl_params, target

  # write progress to file in real time
- rows[index] = [pid, index, it, target, best_target, old_kappa] + \
- list(rddl_params.values())
+ info_i = [pid, index, it, target, best_target, old_acq_params]
+ rows[index] = info_i + list(rddl_params.values())

  # write results of all processes in current iteration to file
  with open(filename, 'a', newline='') as file:
@@ -308,16 +296,20 @@ class JaxParameterTuning:
  raise_warning(f'failed to import packages matplotlib or sklearn, '
  f'aborting plot of search space\n{e}', 'red')
  else:
- data = np.loadtxt(filename, delimiter=',', dtype=object)
- data, target = data[1:, 3:], data[1:, 2]
- data = data.astype(np.float64)
- target = target.astype(np.float64)
+ with open(filename, 'r') as file:
+ data_iter = csv.reader(file, delimiter=',')
+ data = [row for row in data_iter]
+ data = np.asarray(data, dtype=object)
+ hparam = data[1:, len(COLUMNS):].astype(np.float64)
+ target = data[1:, 3].astype(np.float64)
  target = (target - np.min(target)) / (np.max(target) - np.min(target))
  embedding = MDS(n_components=2, normalized_stress='auto')
- data1 = embedding.fit_transform(data)
- sc = plt.scatter(data1[:, 0], data1[:, 1], c=target, s=4.,
- cmap='seismic', edgecolor='gray',
- linewidth=0.01, alpha=0.4)
+ hparam_low = embedding.fit_transform(hparam)
+ sc = plt.scatter(hparam_low[:, 0], hparam_low[:, 1], c=target, s=5,
+ cmap='seismic', edgecolor='gray', linewidth=0)
+ ax = plt.gca()
+ for i in range(len(target)):
+ ax.annotate(str(i), (hparam_low[i, 0], hparam_low[i, 1]), fontsize=3)
  plt.colorbar(sc)
  plt.savefig(self._filename('gp_points', 'pdf'))
  plt.clf()
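The revised plotting code above loads the tuning CSV and projects each hyper-parameter vector to 2D with multidimensional scaling (MDS), colored by its normalized target. Below is a self-contained sketch of the same idea on synthetic data (the hyper-parameter matrix and targets are random stand-ins rather than tuner output):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import MDS

# stand-ins for the tuned hyper-parameter vectors and their returns
rng = np.random.default_rng(0)
hparam = rng.uniform(size=(100, 4))
target = rng.normal(size=100)
target = (target - target.min()) / (target.max() - target.min())

# project hyper-parameter configurations to 2D and color by normalized target
embedding = MDS(n_components=2, normalized_stress='auto')
hparam_low = embedding.fit_transform(hparam)
sc = plt.scatter(hparam_low[:, 0], hparam_low[:, 1], c=target, s=5,
                 cmap='seismic', edgecolor='gray', linewidth=0)
for i in range(len(target)):
    plt.gca().annotate(str(i), (hparam_low[i, 0], hparam_low[i, 1]), fontsize=3)
plt.colorbar(sc)
plt.savefig('gp_points.pdf')
```

Annotating each point with its row index, as the new code does, makes it easy to map regions of the plot back to rows of the CSV log.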
@@ -342,9 +334,11 @@ def objective_slp(params, kwargs, key, index):
  std, lr, w, wa = param_values
  else:
  std, lr, w = param_values
- wa = None
+ wa = None
+ key, subkey = jax.random.split(key)
  if kwargs['verbose']:
- print(f'[{index}] key={key}, std={std}, lr={lr}, w={w}, wa={wa}...', flush=True)
+ print(f'[{index}] key={subkey[0]}, '
+ f'std={std}, lr={lr}, w={w}, wa={wa}...', flush=True)

  # initialize planning algorithm
  planner = JaxBackpropPlanner(
@@ -358,7 +352,6 @@ def objective_slp(params, kwargs, key, index):
  model_params = {name: w for name in planner.compiled.model_params}

  # initialize policy
- key, subkey = jax.random.split(key)
  policy = JaxOfflineController(
  planner=planner,
  key=subkey,
@@ -384,7 +377,7 @@ def objective_slp(params, kwargs, key, index):
  key, subkey = jax.random.split(key)
  total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
  if kwargs['verbose']:
- print(f' [{index}] trial {trial + 1} key={subkey}, '
+ print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
  f'reward={total_reward}', flush=True)
  average_reward += total_reward / kwargs['eval_trials']
  if kwargs['verbose']:
@@ -474,8 +467,10 @@ def objective_replan(params, kwargs, key, index):
  else:
  std, lr, w, T = param_values
  wa = None
+ key, subkey = jax.random.split(key)
  if kwargs['verbose']:
- print(f'[{index}] key={key}, std={std}, lr={lr}, w={w}, wa={wa}, T={T}...', flush=True)
+ print(f'[{index}] key={subkey[0]}, '
+ f'std={std}, lr={lr}, w={w}, wa={wa}, T={T}...', flush=True)

  # initialize planning algorithm
  planner = JaxBackpropPlanner(
@@ -490,7 +485,6 @@ def objective_replan(params, kwargs, key, index):
  model_params = {name: w for name in planner.compiled.model_params}

  # initialize controller
- key, subkey = jax.random.split(key)
  policy = JaxOnlineController(
  planner=planner,
  key=subkey,
@@ -516,7 +510,7 @@ def objective_replan(params, kwargs, key, index):
  key, subkey = jax.random.split(key)
  total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
  if kwargs['verbose']:
- print(f' [{index}] trial {trial + 1} key={subkey}, '
+ print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
  f'reward={total_reward}', flush=True)
  average_reward += total_reward / kwargs['eval_trials']
  if kwargs['verbose']:
@@ -602,9 +596,11 @@ def objective_drp(params, kwargs, key, index):
  ]

  # unpack hyper-parameters
- lr, w, layers, neurons = param_values
+ lr, w, layers, neurons = param_values
+ key, subkey = jax.random.split(key)
  if kwargs['verbose']:
- print(f'[{index}] key={key}, lr={lr}, w={w}, layers={layers}, neurons={neurons}...', flush=True)
+ print(f'[{index}] key={subkey[0]}, '
+ f'lr={lr}, w={w}, layers={layers}, neurons={neurons}...', flush=True)

  # initialize planning algorithm
  planner = JaxBackpropPlanner(
@@ -618,7 +614,6 @@ def objective_drp(params, kwargs, key, index):
  model_params = {name: w for name in planner.compiled.model_params}

  # initialize policy
- key, subkey = jax.random.split(key)
  policy = JaxOfflineController(
  planner=planner,
  key=subkey,
@@ -644,7 +639,7 @@ def objective_drp(params, kwargs, key, index):
  key, subkey = jax.random.split(key)
  total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
  if kwargs['verbose']:
- print(f' [{index}] trial {trial + 1} key={subkey}, '
+ print(f' [{index}] trial {trial + 1} key={subkey[0]}, '
  f'reward={total_reward}', flush=True)
  average_reward += total_reward / kwargs['eval_trials']
  if kwargs['verbose']:
@@ -16,4 +16,5 @@ rollout_horizon=30
  [Training]
  key=42
  epochs=1000
- train_seconds=1
+ train_seconds=1
+ print_summary=False
@@ -16,4 +16,5 @@ rollout_horizon=5
  [Training]
  key=42
  epochs=2000
- train_seconds=1
+ train_seconds=1
+ print_summary=False
@@ -16,4 +16,5 @@ rollout_horizon=5
  [Training]
  key=42
  epochs=500
- train_seconds=1
+ train_seconds=1
+ print_summary=False
@@ -6,9 +6,9 @@ tnorm_kwargs={}

  [Optimizer]
  method='JaxStraightLinePlan'
- method_kwargs={'initializer': 'normal', 'initializer_kwargs': {'stddev': 0.001}}
+ method_kwargs={}
  optimizer='rmsprop'
- optimizer_kwargs={'learning_rate': 0.001}
+ optimizer_kwargs={'learning_rate': 0.1}
  batch_size_train=32
  batch_size_test=32
  rollout_horizon=5
@@ -17,4 +17,5 @@ rollout_horizon=5
  key=42
  epochs=1000
  train_seconds=1
- policy_hyperparams={'cut-out': 10.0, 'put-out': 10.0}
+ policy_hyperparams={'cut-out': 10.0, 'put-out': 10.0}
+ print_summary=False
@@ -17,4 +17,5 @@ rollout_horizon=5
  key=42
  epochs=2000
  train_seconds=1
- policy_hyperparams=2.0
+ policy_hyperparams=2.0
+ print_summary=False
@@ -59,9 +59,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
  gp_iters=iters)

  # perform tuning and report best parameters
- best = tuning.tune(key=train_args['key'], filename=f'gp_{method}',
- save_plot=True)
- print(f'best parameters found: {best}')
+ tuning.tune(key=train_args['key'], filename=f'gp_{method}', save_plot=True)


  if __name__ == "__main__":
@@ -0,0 +1,278 @@
+ Metadata-Version: 2.1
+ Name: pyRDDLGym-jax
+ Version: 0.5
+ Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
+ Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
+ Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
+ Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
+ License: MIT License
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Natural Language :: English
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: pyRDDLGym >=2.0
+ Requires-Dist: tqdm >=4.66
+ Requires-Dist: jax >=0.4.12
+ Requires-Dist: optax >=0.1.9
+ Requires-Dist: dm-haiku >=0.0.10
+ Requires-Dist: tensorflow-probability >=0.21.0
+ Provides-Extra: extra
+ Requires-Dist: bayesian-optimization >=2.0.0 ; extra == 'extra'
+ Requires-Dist: rddlrepository >=2.0 ; extra == 'extra'
+
+ # pyRDDLGym-jax
+
+ Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)
+
+ This directory provides:
+ 1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain to a differentiable simulator!
+ 2. powerful, fast and scalable gradient-based planning algorithms, with extensible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
+
+ > [!NOTE]
+ > While Jax planners can support some discrete state/action problems through model relaxations, they can perform poorly on some discrete problems (though there is an ongoing effort to remedy this!).
+ > If you find they are not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
+
+ ## Contents
+
+ - [Installation](#installation)
+ - [Running from the Command Line](#running-from-the-command-line)
+ - [Running from within Python](#running-from-within-python)
+ - [Configuring the Planner](#configuring-the-planner)
+ - [Simulation](#simulation)
+ - [Manual Gradient Calculation](#manual-gradient-calculation)
+ - [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
+
+ ## Installation
+
+ To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
+ - ``pyRDDLGym>=2.0``
+ - ``tqdm>=4.66``
+ - ``jax>=0.4.12``
+ - ``optax>=0.1.9``
+ - ``dm-haiku>=0.0.10``
+ - ``tensorflow-probability>=0.21.0``
+
+ Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
+ To run the automated tuning optimization, you will also need ``bayesian-optimization>=2.0.0``.
+
+ You can install pyRDDLGym-jax with all requirements using pip:
+
+ ```shell
+ pip install pyRDDLGym-jax[extra]
+ ```
+
+ ## Running from the Command Line
+
+ A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository (e.g. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
+ - ``instance`` is the instance identifier (e.g. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
+ - ``method`` is the planning method to use (e.g. drp, slp, replan)
+ - ``episodes`` is the (optional) number of episodes to evaluate the learned policy.
+
+ The ``method`` parameter supports three possible modes:
+ - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
+ - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
+ - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
+
+ A basic run script is also provided to run the automatic hyper-parameter tuning:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository
+ - ``instance`` is the instance identifier
+ - ``method`` is the planning method to use (e.g. drp, slp, replan)
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total evaluations = ``iters * workers``.
+
+ For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
+ ```
+
+ After several minutes of optimization, you should get a visualization as follows:
+
+ <p align="center">
+ <img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
+ </p>
+
+ ## Running from within Python
+
+ To run the Jax planner from within a Python application, refer to the following example:
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController
+
+ # set up the environment (note the vectorized option must be True)
+ env = pyRDDLGym.make("domain", "instance", vectorized=True)
+
+ # create the planning algorithm
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+ controller = JaxOfflineController(planner, **train_args)
+
+ # evaluate the planner
+ controller.evaluate(env, episodes=1, verbose=True, render=True)
+ env.close()
+ ```
+
+ Here, we have used the straight-line controller, although you can configure the combination of planner and policy representation if you wish.
+ All controllers are instances of pyRDDLGym's ``BaseAgent`` class, so they provide the ``evaluate()`` function to streamline interaction with the environment.
+ The ``**planner_args`` and ``**train_args`` are keyword arguments passed during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.
+
+ ## Configuring the Planner
+
+ The simplest way to configure the planner is to write and pass a configuration file with the necessary [hyper-parameters](https://pyrddlgym.readthedocs.io/en/latest/jax.html#configuring-pyrddlgym-jax).
+ The basic structure of a configuration file is provided below for a straight-line planner:
+
+ ```ini
+ [Model]
+ logic='FuzzyLogic'
+ logic_kwargs={'weight': 20}
+ tnorm='ProductTNorm'
+ tnorm_kwargs={}
+
+ [Optimizer]
+ method='JaxStraightLinePlan'
+ method_kwargs={}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': 0.001}
+ batch_size_train=1
+ batch_size_test=1
+
+ [Training]
+ key=42
+ epochs=5000
+ train_seconds=30
+ ```
+
+ The configuration file contains three sections:
+ - ``[Model]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the ``weight`` dictates the quality of the approximation,
+ and ``tnorm`` specifies the type of [fuzzy logic](https://en.wikipedia.org/wiki/T-norm_fuzzy_logics) for replacing logical operations in RDDL (e.g. ``ProductTNorm``, ``GodelTNorm``, ``LukasiewiczTNorm``)
+ - ``[Optimizer]`` specifies the optimizer and plan settings; ``method`` selects the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), alongside the gradient descent settings such as learning rate and batch size
+ - ``[Training]`` specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
+
+ For a policy network approach, simply change the ``[Optimizer]`` settings like so:
+
+ ```ini
+ ...
+ [Optimizer]
+ method='JaxDeepReactivePolicy'
+ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
+ ...
+ ```
+
+ The configuration file must then be passed to the planner during initialization.
+ For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
+
+ ```python
+ from pyRDDLGym_jax.core.planner import load_config
+
+ # load the config file with planner settings
+ planner_args, _, train_args = load_config("/path/to/config.cfg")
+
+ # create the planning algorithm
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+ controller = JaxOfflineController(planner, **train_args)
+ ...
+ ```
+
+ ## Simulation
+
+ The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym.core.policy import RandomAgent
+ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
+
+ # create the environment
+ env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)
+
+ # evaluate the random policy
+ agent = RandomAgent(action_space=env.action_space,
+                     num_actions=env.max_allowed_actions)
+ agent.evaluate(env, verbose=True, render=True)
+ ```
+
+ For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
+ In any event, the simulation results using the JAX backend should (almost) always match those of the numpy backend.
+
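+ As a rough, illustrative sketch (the domain and instance names are placeholders, and the ``seed`` arguments on ``RandomAgent`` and ``evaluate()`` are assumed to be accepted as in the tuning code elsewhere in this release), one way to sanity-check this is to roll out the same seeded random policy under both backends and compare the mean returns:
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym.core.policy import RandomAgent
+ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
+
+ returns = {}
+ for name, kwargs in [('numpy', {}), ('jax', {'backend': JaxRDDLSimulator})]:
+     env = pyRDDLGym.make("domain", "instance", **kwargs)
+     agent = RandomAgent(action_space=env.action_space,
+                         num_actions=env.max_allowed_actions, seed=42)
+     returns[name] = agent.evaluate(env, episodes=1, seed=42)['mean']
+     env.close()
+ print(returns)  # the two mean returns should (almost) coincide
+ ```
+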
+ ## Manual Gradient Calculation
+
+ For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
+ Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
+
+ # set up the environment
+ env = pyRDDLGym.make("domain", "instance", vectorized=True)
+
+ # create the step function
+ compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
+ compiled.compile()
+ step_fn = compiled.compile_transition()
+ ```
+
+ This will return a JAX compiled (pure) function requiring the following inputs:
+ - ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
+ - ``actions`` is the dictionary of action fluent tensors
+ - ``subs`` is the dictionary of state-fluent and non-fluent tensors
+ - ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
+
+ The function returns a dictionary containing a variety of variables, such as the updated pvariables including next-state fluents (``pvar``), the reward obtained (``reward``), and error codes (``error``).
+ It is thus possible to apply any JAX transformation to this function, such as computing gradients using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
+
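+ As a minimal sketch of the gradient case, continuing the snippet above for real-valued action fluents (the positional order ``(key, actions, subs, model_params)`` and the use of ``compiled.model_params`` are assumptions to adapt to your setup; ``actions`` and ``subs`` are fluent-tensor dictionaries you prepare):
+
+ ```python
+ import jax
+ import jax.numpy as jnp
+
+ def step_reward(actions, subs, key):
+     # assumes the compiled function takes (key, actions, subs, model_params);
+     # summing makes the output a scalar in case the reward is batched
+     return jnp.sum(step_fn(key, actions, subs, compiled.model_params)['reward'])
+
+ # dictionary of reward gradients, one entry per action fluent tensor
+ grads = jax.grad(step_reward, argnums=0)(actions, subs, jax.random.PRNGKey(42))
+ ```
+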
+ Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
+ An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
+
+ ## Citing pyRDDLGym-jax
+
+ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
+
+ ```
+ @inproceedings{gimelfarb2024jaxplan,
+     title={JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains},
+     author={Michael Gimelfarb and Ayal Taitler and Scott Sanner},
+     booktitle={34th International Conference on Automated Planning and Scheduling},
+     year={2024},
+     url={https://openreview.net/forum?id=7IKtmUpLEH}
+ }
+ ```
+
+ The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
+
+ ```
+ @inproceedings{patton2022distributional,
+     title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
+     author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
+     booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+     volume={36},
+     number={9},
+     pages={9894--9901},
+     year={2022}
+ }
+ ```
+
+ Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
+ - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
+ - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
+