pyRDDLGym-jax 0.2__py3-none-any.whl → 0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,8 @@
- import jax
  import time
  from typing import Dict, Optional

+ import jax
+
  from pyRDDLGym.core.compiler.model import RDDLLiftedModel
  from pyRDDLGym.core.debug.exception import (
  RDDLActionPreconditionNotSatisfiedError,
@@ -1,20 +1,18 @@
- from bayes_opt import BayesianOptimization
- from bayes_opt.util import UtilityFunction
  from copy import deepcopy
  import csv
  import datetime
- import jax
  from multiprocessing import get_context
- import numpy as np
  import os
  import time
  from typing import Any, Callable, Dict, Optional, Tuple
-
- Kwargs = Dict[str, Any]
-
  import warnings
  warnings.filterwarnings("ignore")

+ from bayes_opt import BayesianOptimization
+ from bayes_opt.util import UtilityFunction
+ import jax
+ import numpy as np
+
  from pyRDDLGym.core.debug.exception import raise_warning
  from pyRDDLGym.core.env import RDDLEnv

@@ -26,6 +24,8 @@ from pyRDDLGym_jax.core.planner import (
  JaxOnlineController
  )

+ Kwargs = Dict[str, Any]
+

  # ===============================================================================
  #
@@ -368,7 +368,8 @@ def objective_slp(params, kwargs, key, index):
  train_seconds=kwargs['timeout_training'],
  model_params=model_params,
  policy_hyperparams=policy_hparams,
- verbose=0,
+ print_summary=False,
+ print_progress=False,
  tqdm_position=index)

  # initialize env for evaluation (need fresh copy to avoid concurrency)
@@ -499,7 +500,8 @@ def objective_replan(params, kwargs, key, index):
  train_seconds=kwargs['timeout_training'],
  model_params=model_params,
  policy_hyperparams=policy_hparams,
- verbose=0,
+ print_summary=False,
+ print_progress=False,
  tqdm_position=index)

  # initialize env for evaluation (need fresh copy to avoid concurrency)
@@ -626,7 +628,8 @@ def objective_drp(params, kwargs, key, index):
  train_seconds=kwargs['timeout_training'],
  model_params=model_params,
  policy_hyperparams=policy_hparams,
- verbose=0,
+ print_summary=False,
+ print_progress=False,
  tqdm_position=index)

  # initialize env for evaluation (need fresh copy to avoid concurrency)
@@ -6,7 +6,7 @@ tnorm_kwargs={}

  [Optimizer]
  method='JaxDeepReactivePolicy'
- method_kwargs={'topology': [128, 128]}
+ method_kwargs={'topology': [64, 64]}
  optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': 0.001}
  batch_size_train=1
@@ -14,5 +14,5 @@ batch_size_test=1

  [Training]
  key=42
- epochs=3000
- train_seconds=30
+ epochs=6000
+ train_seconds=60
@@ -11,6 +11,7 @@ optimizer='rmsprop'
  optimizer_kwargs={'learning_rate': 0.01}
  batch_size_train=1
  batch_size_test=1
+ action_bounds={'power-x': (-0.09999, 0.09999), 'power-y': (-0.09999, 0.09999)}

  [Training]
  key=42
@@ -8,7 +8,7 @@ tnorm_kwargs={}
  method='JaxStraightLinePlan'
  method_kwargs={}
  optimizer='rmsprop'
- optimizer_kwargs={'learning_rate': 5.0}
+ optimizer_kwargs={'learning_rate': 1.0}
  batch_size_train=1
  batch_size_test=1

@@ -15,5 +15,5 @@ batch_size_test=32
  [Training]
  key=42
  epochs=30000
- train_seconds=90
+ train_seconds=60
  policy_hyperparams=2.0
@@ -15,5 +15,5 @@ batch_size_test=32
  [Training]
  key=42
  epochs=30000
- train_seconds=90
+ train_seconds=60
  policy_hyperparams=2.0
@@ -23,16 +23,13 @@ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
  def main(domain, instance, episodes=1, seed=42):

  # create the environment
- env = pyRDDLGym.make(domain, instance, enforce_action_constraints=True,
- backend=JaxRDDLSimulator)
+ env = pyRDDLGym.make(domain, instance, backend=JaxRDDLSimulator)

- # set up a random policy
+ # evaluate a random policy
  agent = RandomAgent(action_space=env.action_space,
  num_actions=env.max_allowed_actions,
  seed=seed)
  agent.evaluate(env, episodes=episodes, verbose=True, render=True, seed=seed)
-
- # important when logging to save all traces
  env.close()


@@ -13,6 +13,7 @@ where:
  <domain> is the name of a domain located in the /Examples directory
  <instance> is the instance number
  <method> is either slp, drp, or replan
+ <episodes> is the optional number of evaluation rollouts
  '''
  import os
  import sys
@@ -28,29 +29,26 @@ from pyRDDLGym_jax.core.planner import (
  def main(domain, instance, method, episodes=1):

  # set up the environment
- env = pyRDDLGym.make(domain, instance, vectorized=True, enforce_action_constraints=True)
+ env = pyRDDLGym.make(domain, instance, vectorized=True)

  # load the config file with planner settings
  abs_path = os.path.dirname(os.path.abspath(__file__))
  config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
  if not os.path.isfile(config_path):
- raise_warning(f'Config file {domain}_{method}.cfg was not found, '
- f'using default config (parameters could be suboptimal).',
- 'red')
+ raise_warning(f'Config file {config_path} was not found, '
+ f'using default_{method}.cfg.', 'red')
  config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
  planner_args, _, train_args = load_config(config_path)

  # create the planning algorithm
  planner = JaxBackpropPlanner(rddl=env.model, **planner_args)

- # create the controller
+ # evaluate the controller
  if method == 'replan':
  controller = JaxOnlineController(planner, **train_args)
  else:
- controller = JaxOfflineController(planner, **train_args)
-
+ controller = JaxOfflineController(planner, **train_args)
  controller.evaluate(env, episodes=episodes, verbose=True, render=True)
-
  env.close()

@@ -0,0 +1,61 @@
+ '''In this example, the user has the choice to run the Jax planner using an
+ optimizer from scipy.minimize.
+
+ The syntax for running this example is:
+
+ python run_scipy.py <domain> <instance> <method> [<episodes>]
+
+ where:
+ <domain> is the name of a domain located in the /Examples directory
+ <instance> is the instance number
+ <method> is the name of a method provided to scipy.optimize.minimize()
+ <episodes> is the optional number of evaluation rollouts
+ '''
+ import os
+ import sys
+ import jax
+ from scipy.optimize import minimize
+
+ import pyRDDLGym
+ from pyRDDLGym.core.debug.exception import raise_warning
+
+ from pyRDDLGym_jax.core.planner import load_config, JaxBackpropPlanner, JaxOfflineController
+
+
+ def main(domain, instance, method, episodes=1):
+
+ # set up the environment
+ env = pyRDDLGym.make(domain, instance, vectorized=True)
+
+ # load the config file with planner settings
+ abs_path = os.path.dirname(os.path.abspath(__file__))
+ config_path = os.path.join(abs_path, 'configs', f'{domain}_slp.cfg')
+ if not os.path.isfile(config_path):
+ raise_warning(f'Config file {config_path} was not found, '
+ f'using default_slp.cfg.', 'red')
+ config_path = os.path.join(abs_path, 'configs', 'default_slp.cfg')
+ planner_args, _, train_args = load_config(config_path)
+
+ # create the planning algorithm
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+
+ # find the optimal plan
+ loss_fn, grad_fn, guess, unravel_fn = planner.as_optimization_problem()
+ opt = minimize(loss_fn, jac=grad_fn, x0=guess, method=method, options={'disp': True})
+ params = unravel_fn(opt.x)
+
+ # evaluate the optimal plan
+ controller = JaxOfflineController(planner, params=params, **train_args)
+ controller.evaluate(env, episodes=episodes, verbose=True, render=True)
+ env.close()
+
+
+ if __name__ == "__main__":
+ args = sys.argv[1:]
+ if len(args) < 3:
+ print('python run_scipy.py <domain> <instance> <method> [<episodes>]')
+ exit(1)
+ kwargs = {'domain': args[0], 'instance': args[1], 'method': args[2]}
+ if len(args) >= 4: kwargs['episodes'] = int(args[3])
+ main(**kwargs)
+
@@ -31,15 +31,14 @@ from pyRDDLGym_jax.core.planner import load_config
  def main(domain, instance, method, trials=5, iters=20, workers=4):

  # set up the environment
- env = pyRDDLGym.make(domain, instance, vectorized=True, enforce_action_constraints=True)
+ env = pyRDDLGym.make(domain, instance, vectorized=True)

  # load the config file with planner settings
  abs_path = os.path.dirname(os.path.abspath(__file__))
  config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
  if not os.path.isfile(config_path):
- raise_warning(f'Config file {domain}_{method}.cfg was not found, '
- f'using default config (parameters could be suboptimal).',
- 'red')
+ raise_warning(f'Config file {config_path} was not found, '
+ f'using default_{method}.cfg.', 'red')
  config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
  planner_args, plan_args, train_args = load_config(config_path)

@@ -49,8 +48,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
  elif method == 'drp':
  tuning_class = JaxParameterTuningDRP
  elif method == 'replan':
- tuning_class = JaxParameterTuningSLPReplan
-
+ tuning_class = JaxParameterTuningSLPReplan
  tuning = tuning_class(env=env,
  train_epochs=train_args['epochs'],
  timeout_training=train_args['train_seconds'],
@@ -60,6 +58,7 @@ def main(domain, instance, method, trials=5, iters=20, workers=4):
  num_workers=workers,
  gp_iters=iters)

+ # perform tuning and report best parameters
  best = tuning.tune(key=train_args['key'], filename=f'gp_{method}',
  save_plot=True)
  print(f'best parameters found: {best}')
@@ -0,0 +1,276 @@
+ Metadata-Version: 2.1
+ Name: pyRDDLGym-jax
+ Version: 0.4
+ Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
+ Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
+ Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
+ Author-email: mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca
+ License: MIT License
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Natural Language :: English
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: pyRDDLGym >=2.0
+ Requires-Dist: tqdm >=4.66
+ Requires-Dist: bayesian-optimization >=1.4.3
+ Requires-Dist: jax >=0.4.12
+ Requires-Dist: optax >=0.1.9
+ Requires-Dist: dm-haiku >=0.0.10
+ Requires-Dist: tensorflow-probability >=0.21.0
+
+ # pyRDDLGym-jax
+
+ Author: [Mike Gimelfarb](https://mike-gimelfarb.github.io)
+
+ This directory provides:
+ 1. automated translation and compilation of RDDL description files into [JAX](https://github.com/google/jax), converting any RDDL domain to a differentiable simulator!
+ 2. powerful, fast and scalable gradient-based planning algorithms, with extendible and flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and much more!
+
+ > [!NOTE]
+ > While Jax planners can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
+ > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
+
+ ## Contents
+
+ - [Installation](#installation)
+ - [Running from the Command Line](#running-from-the-command-line)
+ - [Running from within Python](#running-from-within-python)
+ - [Configuring the Planner](#configuring-the-planner)
+ - [Simulation](#simulation)
+ - [Manual Gradient Calculation](#manual-gradient-calculation)
+ - [Citing pyRDDLGym-jax](#citing-pyrddlgym-jax)
+
+ ## Installation
+
+ To use the compiler or planner without the automated hyper-parameter tuning, you will need the following packages installed:
+ - ``pyRDDLGym>=2.0``
+ - ``tqdm>=4.66``
+ - ``jax>=0.4.12``
+ - ``optax>=0.1.9``
+ - ``dm-haiku>=0.0.10``
+ - ``tensorflow-probability>=0.21.0``
+
+ Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
+ To run the automated tuning optimization, you will also need ``bayesian-optimization>=1.4.3``.
+
+ You can install this package, together with all of its requirements, via pip:
+
+ ```shell
+ pip install rddlrepository pyRDDLGym-jax
+ ```
+
+ ## Running from the Command Line
+
+ A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository``, and can be launched in the command line from the install directory of pyRDDLGym-jax:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014), or a path pointing to a valid ``domain.rddl`` file
+ - ``instance`` is the instance identifier (i.e. 1, 2, ... 10), or a path pointing to a valid ``instance.rddl`` file
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``episodes`` is the (optional) number of episodes to evaluate the learned policy.
+
+ The ``method`` parameter supports three possible modes:
+ - ``slp`` is the basic straight line planner described [in this paper](https://proceedings.neurips.cc/paper_files/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
+ - ``drp`` is the deep reactive policy network described [in this paper](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
+ - ``replan`` is the same as ``slp`` except the plan is recalculated at every decision time step.
+
+ A basic run script is also provided to run the automatic hyper-parameter tuning:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials> <iters> <workers>
+ ```
+
+ where:
+ - ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014)
+ - ``instance`` is the instance identifier (i.e. 1, 2, ... 10)
+ - ``method`` is the planning method to use (i.e. drp, slp, replan)
+ - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
+ - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
+ - ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``.
+
+ For example, the following will train the Jax Planner on the Quadcopter domain with 4 drones:
+
+ ```shell
+ python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
+ ```
+
+ After several minutes of optimization, you should get a visualization as follows:
+
+ <p align="center">
+ <img src="Images/quadcopter.gif" width="400" height="400" margin=1/>
+ </p>
+
+ ## Running from within Python
+
+ To run the Jax planner from within a Python application, refer to the following example:
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController
+
+ # set up the environment (note the vectorized option must be True)
+ env = pyRDDLGym.make("domain", "instance", vectorized=True)
+
+ # create the planning algorithm
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+ controller = JaxOfflineController(planner, **train_args)
+
+ # evaluate the planner
+ controller.evaluate(env, episodes=1, verbose=True, render=True)
+ env.close()
+ ```
+
+ Here, we have used the straight-line controller, although you can configure the combination of planner and policy representation if you wish.
+ All controllers are instances of pyRDDLGym's ``BaseAgent`` class, so they provide the ``evaluate()`` function to streamline interaction with the environment.
+ The ``**planner_args`` and ``**train_args`` are keyword argument parameters to pass during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.
+
+ ## Configuring the Planner
+
+ The simplest way to configure the planner is to write and pass a configuration file with the necessary [hyper-parameters](https://pyrddlgym.readthedocs.io/en/latest/jax.html#configuring-pyrddlgym-jax).
+ The basic structure of a configuration file is provided below for a straight-line planner:
+
+ ```ini
+ [Model]
+ logic='FuzzyLogic'
+ logic_kwargs={'weight': 20}
+ tnorm='ProductTNorm'
+ tnorm_kwargs={}
+
+ [Optimizer]
+ method='JaxStraightLinePlan'
+ method_kwargs={}
+ optimizer='rmsprop'
+ optimizer_kwargs={'learning_rate': 0.001}
+ batch_size_train=1
+ batch_size_test=1
+
+ [Training]
+ key=42
+ epochs=5000
+ train_seconds=30
+ ```
+
+ The configuration file contains three sections:
+ - ``[Model]`` specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the ``weight`` dictates the quality of the approximation,
+ and ``tnorm`` specifies the type of [fuzzy logic](https://en.wikipedia.org/wiki/T-norm_fuzzy_logics) for relacing logical operations in RDDL (e.g. ``ProductTNorm``, ``GodelTNorm``, ``LukasiewiczTNorm``)
+ - ``[Optimizer]`` generally specify the optimizer and plan settings; the ``method`` specifies the plan/policy representation (e.g. ``JaxStraightLinePlan``, ``JaxDeepReactivePolicy``), the gradient descent settings, learning rate, batch size, etc.
+ - ``[Training]`` specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
+
+ For a policy network approach, simply change the ``[Optimizer]`` settings like so:
+
+ ```ini
+ ...
+ [Optimizer]
+ method='JaxDeepReactivePolicy'
+ method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
+ ...
+ ```
+
+ The configuration file must then be passed to the planner during initialization.
+ For example, the [previous script here](#running-from-within-python) can be modified to set parameters from a config file:
+
+ ```python
+ from pyRDDLGym_jax.core.planner import load_config
+
+ # load the config file with planner settings
+ planner_args, _, train_args = load_config("/path/to/config.cfg")
+
+ # create the planning algorithm
+ planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
+ controller = JaxOfflineController(planner, **train_args)
+ ...
+ ```
+
+ ## Simulation
+
+ The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym.core.policy import RandomAgent
+ from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
+
+ # create the environment
+ env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)
+
+ # evaluate the random policy
+ agent = RandomAgent(action_space=env.action_space,
+ num_actions=env.max_allowed_actions)
+ agent.evaluate(env, verbose=True, render=True)
+ ```
+
+ For some domains, the JAX backend could perform better than the numpy-based one, due to various compiler optimizations.
+ In any event, the simulation results using the JAX backend should (almost) always match the numpy backend.
+
+ ## Manual Gradient Calculation
+
+ For custom applications, it is desirable to compute gradients of the model that can be optimized downstream.
+ Fortunately, we provide a very convenient function for compiling the transition/step function ``P(s, a, s')`` of the environment into JAX.
+
+ ```python
+ import pyRDDLGym
+ from pyRDDLGym_jax.core.planner import JaxRDDLCompilerWithGrad
+
+ # set up the environment
+ env = pyRDDLGym.make("domain", "instance", vectorized=True)
+
+ # create the step function
+ compiled = JaxRDDLCompilerWithGrad(rddl=env.model)
+ compiled.compile()
+ step_fn = compiled.compile_transition()
+ ```
+
+ This will return a JAX compiled (pure) function requiring the following inputs:
+ - ``key`` is the ``jax.random.PRNGKey`` key for reproducible randomness
+ - ``actions`` is the dictionary of action fluent tensors
+ - ``subs`` is the dictionary of state-fluent and non-fluent tensors
+ - ``model_params`` are the parameters of the differentiable relaxations, such as ``weight``
+
+ The function returns a dictionary containing a variety of variables, such as updated pvariables including next-state fluents (``pvar``), reward obtained (``reward``), error codes (``error``).
+ It is thus possible to apply any JAX transformation to the output of the function, such as computing gradient using ``jax.grad()`` or batched simulation using ``jax.vmap()``.
+
+ Compilation of entire rollouts is also possible by calling the ``compile_rollouts`` function.
+ An [example is provided to illustrate how you can define your own policy class and compute the return gradient manually](https://github.com/pyrddlgym-project/pyRDDLGym-jax/blob/main/pyRDDLGym_jax/examples/run_gradient.py).
+
+ ## Citing pyRDDLGym-jax
+
+ The [following citation](https://ojs.aaai.org/index.php/ICAPS/article/view/31480) describes the main ideas of the framework. Please cite it if you found it useful:
+
+ ```
+ @inproceedings{gimelfarb2024jaxplan,
+ title={JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains},
+ author={Michael Gimelfarb and Ayal Taitler and Scott Sanner},
+ booktitle={34th International Conference on Automated Planning and Scheduling},
+ year={2024},
+ url={https://openreview.net/forum?id=7IKtmUpLEH}
+ }
+ ```
+
+ The utility optimization is discussed in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/21226):
+
+ ```
+ @inproceedings{patton2022distributional,
+ title={A distributional framework for risk-sensitive end-to-end planning in continuous mdps},
+ author={Patton, Noah and Jeong, Jihwan and Gimelfarb, Mike and Sanner, Scott},
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+ volume={36},
+ number={9},
+ pages={9894--9901},
+ year={2022}
+ }
+ ```
+
+ Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
+ - [Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019](https://ojs.aaai.org/index.php/AAAI/article/view/4744)
+ - [Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017](https://proceedings.neurips.cc/paper/2017/file/98b17f068d5d9b7668e19fb8ae470841-Paper.pdf)
+
@@ -1,26 +1,26 @@
- pyRDDLGym_jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ pyRDDLGym_jax/__init__.py,sha256=rexmxcBiCOcwctw4wGvk7UxS9MfZn_1CYXp53SoLKlU,19
  pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pyRDDLGym_jax/core/compiler.py,sha256=K1p99rKdx2PMdsN1jq6ZEtSANUWACWmOzJrGirj5wq8,89176
- pyRDDLGym_jax/core/logic.py,sha256=zujSHiR5KhTO81E5Zn8Gy_xSzVzfDskFCGvZygFRdMI,21930
- pyRDDLGym_jax/core/planner.py,sha256=vEc-Um_3q1QGhlqk-6oq0eGW7iFAbu-6kpSnyKYS9tI,91731
- pyRDDLGym_jax/core/simulator.py,sha256=fp6bep3XwwBWED0w7_4qhiwDjkSka6B2prwdNcPRCMc,8329
- pyRDDLGym_jax/core/tuning.py,sha256=uhpL3UCfSIgxDEvKI8PibwgTafCMLR_8LrRj5cBKLWE,29466
+ pyRDDLGym_jax/core/compiler.py,sha256=SnDN3-J84Wv_YVHoDmfM_U4Ob8uaFLGX4vEaeWC-ERY,90037
+ pyRDDLGym_jax/core/logic.py,sha256=o1YAjMnXfi8gwb42kAigBmaf9uIYUWal9__FEkWohrk,26733
+ pyRDDLGym_jax/core/planner.py,sha256=Hrwfn88bUu1LNZcnFC5psHPzcIUbPeF4Rn1pFO6_qH0,102655
+ pyRDDLGym_jax/core/simulator.py,sha256=hWv6pr-4V-SSCzBYgdIPmKdUDMalft-Zh6dzOo5O9-0,8331
+ pyRDDLGym_jax/core/tuning.py,sha256=D_kD8wjqMroCdtjE9eksR2UqrqXJqazsAKrMEHwPxYM,29589
  pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
- pyRDDLGym_jax/examples/run_gym.py,sha256=tY6xvLCX9gLCYLK2o0gr44j26SHEwOrhyEQF0wNNIWY,1639
- pyRDDLGym_jax/examples/run_plan.py,sha256=Z40aoikg_mN3rXPl-EGU-BQ9QdHl96P1Qvp4GNZeo5c,2499
- pyRDDLGym_jax/examples/run_tune.py,sha256=N3mMCiRsxWtb94xfk6-v2nelkL5VekneeDmaLxPJND4,3318
+ pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
+ pyRDDLGym_jax/examples/run_plan.py,sha256=OENf8s-SrMlh7CYXNhanQiau35b4atLBJMNjgP88DCg,2463
+ pyRDDLGym_jax/examples/run_scipy.py,sha256=wvcpWCvdjvYHntO95a7JYfY2fuCMUTKnqjJikW0PnL4,2291
+ pyRDDLGym_jax/examples/run_tune.py,sha256=-M4KoBpg5lshQ4mmU0cnLs2i7-ldSIr_OcxHK7YA6bw,3273
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=pbkz6ccgk5dHXp7cfYbZNFyJobpGyxUZleCy4fvlmaU,336
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=OswO9YD4Xh1pw3R3LkUBb67WLtj5XlE3qnMQ5CKwPsM,332
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg,sha256=FxZ4xcg2j2PzeH-wUseRR280juQN5bJjoyt6PtI1W7c,329
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg,sha256=FZSZc93617oglT66U80vF2QAPv18tHR1NqTbVJRjlfs,338
+ pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg,sha256=FTGFwRAGyeRrbDMh_FV8iv8ZHrlj3Htju4pfPNmKIcw,336
  pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg,sha256=wjtz86_Gz0RfQu3bbrz56PTXL8JMernINx7AtJuZCPs,314
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg,sha256=MGLcEIixCzdR_UHR4Ydr8hjB1-lff7U2Zj_cZ0iuPqo,335
+ pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg,sha256=C_0BFyhGXbtF7N4vyeua2XkORbkj10HELC1GpzM0Uh4,415
  pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg,sha256=Yb4tFzUOj4epCCsofXAZo70lm5C2KzPIzI5PQHsa_Vk,429
  pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg,sha256=e7j-1Z66o7F-KZDSf2e8TQRWwkXOPRwrRFkIavK8G7g,327
  pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg,sha256=Z6CxaOxHv4oF6nW7SfSn_HshlQGDlNCPGASTnDTdL7Q,327
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg,sha256=gS66TWLSWL9D2DSRU9XK_5geEz2Nq0aBkoF9Oi2tTkc,315
- pyRDDLGym_jax/examples/configs/Pong_slp.cfg,sha256=S45mBj5hTEshdeJ4rdRaty6YliggtEMkLQV6IYxEkyU,315
+ pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg,sha256=Uy1mrX-AZMS-KBAhWXJ3c_QAhd4bRSWttDoFGYQ08lQ,315
  pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg,sha256=SM5_U4RwvvucHVAOdMG4vqH0Eg43f3WX9ZlV6aFPgTw,341
  pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg,sha256=lcqQ7P7X4qAbMlpkKKuYGn2luSZH-yFB7oi-eHj9Qng,332
  pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg,sha256=kG1-02ScmwsEwX7QIAZTD7si90Mb06b79G5oqcMQ9Hg,316
@@ -29,18 +29,16 @@ pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg,sha256=9QNl58PyoJYhmwvrhzUxlLE
  pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg,sha256=rrubYvC1q7Ff0ADV0GXtLw-rD9E4m7qfR66qxdYNTD8,339
  pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg,sha256=DAb-J2KwvJXViRRSHZe8aJwZiPljC28HtrKJPieeUCY,331
  pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg,sha256=QwKzCAFaErrTCHaJwDPLOxPHpNGNuAKMUoZjLLnMrNc,314
- pyRDDLGym_jax/examples/configs/SupplyChain_slp.cfg,sha256=vU_m6KjfNfaPuYosFdAWeYiV1zQGd6eNA17Yn5QB_BI,319
- pyRDDLGym_jax/examples/configs/Traffic_slp.cfg,sha256=03scuHAl6032YhyYy0w5MLMbTibhdbUZFHLhH2WWaPI,370
  pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg,sha256=QiJCJYOrdXXZfOTuPleGswREFxjGlqQSA0rw00YJWWI,318
  pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg,sha256=PGkgll7h5vhSF13JScKoQ-vpWaAGNJ_PUEhK7jEjNx4,340
  pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg,sha256=kEDAwsJQ_t9WPzPhIxfS0hRtgOhtFdJFfmPtTTJuwUE,454
  pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg,sha256=w2wipsA8PE5OBkYVIKajjtCOtiHqmMeY3XQVPAApwFk,371
  pyRDDLGym_jax/examples/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=AtHBy2G9xXTWb6CfDGeKrC2tycwhe2u-JjLnAVLTugQ,344
+ pyRDDLGym_jax/examples/configs/default_drp.cfg,sha256=S2-5hPZtgAwUAFpiCAgSi-cnGhYHSDzMGMmatwhbM78,344
  pyRDDLGym_jax/examples/configs/default_replan.cfg,sha256=VWWPhOYBRq4cWwtrChw5pPqRmlX_nHbMvwciHd9hoLc,357
- pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=iicFMK1XJRo1OnLPPY9DKKHgt8T0FttZxRT5UYMJCRE,340
- pyRDDLGym_jax-0.2.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
- pyRDDLGym_jax-0.2.dist-info/METADATA,sha256=sXEBaZx1czXcRRMj0vXyJrtJUJHnkorsWgqrWU9axV0,1085
- pyRDDLGym_jax-0.2.dist-info/WHEEL,sha256=y4mX-SOX4fYIkonsAGA5N0Oy-8_gI4FXw5HNI1xqvWg,91
- pyRDDLGym_jax-0.2.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
- pyRDDLGym_jax-0.2.dist-info/RECORD,,
+ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=TG3mtHUnCA7J2Gm9SczENpqAymTnzCE9dj1Z_R-FnVk,340
+ pyRDDLGym_jax-0.4.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+ pyRDDLGym_jax-0.4.dist-info/METADATA,sha256=-Kf8PLxf_7MiiYXzlZAf31kV1pT-Rurc7QY7dT3Fwk0,12857
+ pyRDDLGym_jax-0.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ pyRDDLGym_jax-0.4.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+ pyRDDLGym_jax-0.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (70.2.0)
+ Generator: setuptools (75.3.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

@@ -1,18 +0,0 @@
- [Model]
- logic='FuzzyLogic'
- logic_kwargs={'weight': 1.0}
- tnorm='ProductTNorm'
- tnorm_kwargs={}
-
- [Optimizer]
- method='JaxStraightLinePlan'
- method_kwargs={}
- optimizer='rmsprop'
- optimizer_kwargs={'learning_rate': 0.001}
- batch_size_train=1
- batch_size_test=1
-
- [Training]
- key=42
- epochs=2000
- train_seconds=30