pyRDDLGym-jax 1.0__tar.gz → 1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/PKG-INFO +12 -13
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/README.md +11 -12
- pyrddlgym_jax-1.1/pyRDDLGym_jax/__init__.py +1 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/compiler.py +60 -60
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/planner.py +8 -10
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/tuning.py +20 -6
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/visualization.py +1 -1
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/PKG-INFO +12 -13
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/setup.py +1 -1
- pyrddlgym_jax-1.0/pyRDDLGym_jax/__init__.py +0 -1
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/LICENSE +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/logic.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/simulator.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/__init__.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_plan.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_tune.py +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: pyRDDLGym-jax
|
|
3
|
-
Version: 1.0
|
|
3
|
+
Version: 1.1
|
|
4
4
|
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
|
|
5
5
|
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
|
|
6
6
|
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
|
|
@@ -31,7 +31,17 @@ Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"
|
|
|
31
31
|
|
|
32
32
|
# pyRDDLGym-jax
|
|
33
33
|
|
|
34
|
-
|
|
34
|
+

|
|
35
|
+
[](https://pypi.org/project/pyRDDLGym-jax/)
|
|
36
|
+
[](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
|
|
37
|
+

|
|
38
|
+
[](https://pypistats.org/packages/pyrddlgym-jax)
|
|
39
|
+
|
|
40
|
+
[Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
|
|
41
|
+
|
|
42
|
+
**pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
|
|
43
|
+
|
|
44
|
+
Purpose:
|
|
35
45
|
|
|
36
46
|
1. automatic translation of any RDDL description file into a differentiable simulator in JAX
|
|
37
47
|
2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
|
|
@@ -56,17 +66,6 @@ and was moved to the individual logic components which have their own unique wei
|
|
|
56
66
|
> [!NOTE]
|
|
57
67
|
> While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
|
|
58
68
|
> If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
|
|
59
|
-
|
|
60
|
-
## Contents
|
|
61
|
-
|
|
62
|
-
- [Installation](#installation)
|
|
63
|
-
- [Running from the Command Line](#running-from-the-command-line)
|
|
64
|
-
- [Running from Another Python Application](#running-from-another-python-application)
|
|
65
|
-
- [Configuring the Planner](#configuring-the-planner)
|
|
66
|
-
- [JaxPlan Dashboard](#jaxplan-dashboard)
|
|
67
|
-
- [Tuning the Planner](#tuning-the-planner)
|
|
68
|
-
- [Simulation](#simulation)
|
|
69
|
-
- [Citing JaxPlan](#citing-jaxplan)
|
|
70
69
|
|
|
71
70
|
## Installation
|
|
72
71
|
|
|
@@ -1,6 +1,16 @@
|
|
|
1
1
|
# pyRDDLGym-jax
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+

|
|
4
|
+
[](https://pypi.org/project/pyRDDLGym-jax/)
|
|
5
|
+
[](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
|
|
6
|
+

|
|
7
|
+
[](https://pypistats.org/packages/pyrddlgym-jax)
|
|
8
|
+
|
|
9
|
+
[Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
|
|
10
|
+
|
|
11
|
+
**pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
|
|
12
|
+
|
|
13
|
+
Purpose:
|
|
4
14
|
|
|
5
15
|
1. automatic translation of any RDDL description file into a differentiable simulator in JAX
|
|
6
16
|
2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
|
|
@@ -25,17 +35,6 @@ and was moved to the individual logic components which have their own unique wei
|
|
|
25
35
|
> [!NOTE]
|
|
26
36
|
> While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
|
|
27
37
|
> If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
|
|
28
|
-
|
|
29
|
-
## Contents
|
|
30
|
-
|
|
31
|
-
- [Installation](#installation)
|
|
32
|
-
- [Running from the Command Line](#running-from-the-command-line)
|
|
33
|
-
- [Running from Another Python Application](#running-from-another-python-application)
|
|
34
|
-
- [Configuring the Planner](#configuring-the-planner)
|
|
35
|
-
- [JaxPlan Dashboard](#jaxplan-dashboard)
|
|
36
|
-
- [Tuning the Planner](#tuning-the-planner)
|
|
37
|
-
- [Simulation](#simulation)
|
|
38
|
-
- [Citing JaxPlan](#citing-jaxplan)
|
|
39
38
|
|
|
40
39
|
## Installation
|
|
41
40
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = '1.1'
|
|
@@ -51,65 +51,65 @@ class JaxRDDLCompiler:
|
|
|
51
51
|
return func
|
|
52
52
|
return exact_func
|
|
53
53
|
|
|
54
|
-
EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic(ExactLogic.exact_unary_function(jnp.negative))
|
|
54
|
+
EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.negative))
|
|
55
55
|
EXACT_RDDL_TO_JAX_ARITHMETIC = {
|
|
56
|
-
'+': wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
|
|
57
|
-
'-': wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
|
|
58
|
-
'*': wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
|
|
59
|
-
'/': wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
|
|
56
|
+
'+': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.add)),
|
|
57
|
+
'-': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.subtract)),
|
|
58
|
+
'*': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.multiply)),
|
|
59
|
+
'/': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.divide))
|
|
60
60
|
}
|
|
61
61
|
|
|
62
62
|
EXACT_RDDL_TO_JAX_RELATIONAL = {
|
|
63
|
-
'>=': wrap_logic(ExactLogic.exact_binary_function(jnp.greater_equal)),
|
|
64
|
-
'<=': wrap_logic(ExactLogic.exact_binary_function(jnp.less_equal)),
|
|
65
|
-
'<': wrap_logic(ExactLogic.exact_binary_function(jnp.less)),
|
|
66
|
-
'>': wrap_logic(ExactLogic.exact_binary_function(jnp.greater)),
|
|
67
|
-
'==': wrap_logic(ExactLogic.exact_binary_function(jnp.equal)),
|
|
68
|
-
'~=': wrap_logic(ExactLogic.exact_binary_function(jnp.not_equal))
|
|
63
|
+
'>=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater_equal)),
|
|
64
|
+
'<=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less_equal)),
|
|
65
|
+
'<': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less)),
|
|
66
|
+
'>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater)),
|
|
67
|
+
'==': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal)),
|
|
68
|
+
'~=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.not_equal))
|
|
69
69
|
}
|
|
70
70
|
|
|
71
|
-
EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic(ExactLogic.exact_unary_function(jnp.logical_not))
|
|
71
|
+
EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.logical_not))
|
|
72
72
|
EXACT_RDDL_TO_JAX_LOGICAL = {
|
|
73
|
-
'^': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
|
|
74
|
-
'&': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
|
|
75
|
-
'|': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_or)),
|
|
76
|
-
'~': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_xor)),
|
|
77
|
-
'=>': wrap_logic(ExactLogic.exact_binary_implies),
|
|
78
|
-
'<=>': wrap_logic(ExactLogic.exact_binary_function(jnp.equal))
|
|
73
|
+
'^': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
|
|
74
|
+
'&': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
|
|
75
|
+
'|': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_or)),
|
|
76
|
+
'~': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_xor)),
|
|
77
|
+
'=>': wrap_logic.__func__(ExactLogic.exact_binary_implies),
|
|
78
|
+
'<=>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal))
|
|
79
79
|
}
|
|
80
80
|
|
|
81
81
|
EXACT_RDDL_TO_JAX_AGGREGATION = {
|
|
82
|
-
'sum': wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
|
|
83
|
-
'avg': wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
|
|
84
|
-
'prod': wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
|
|
85
|
-
'minimum': wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
|
|
86
|
-
'maximum': wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
|
|
87
|
-
'forall': wrap_logic(ExactLogic.exact_aggregation(jnp.all)),
|
|
88
|
-
'exists': wrap_logic(ExactLogic.exact_aggregation(jnp.any)),
|
|
89
|
-
'argmin': wrap_logic(ExactLogic.exact_aggregation(jnp.argmin)),
|
|
90
|
-
'argmax': wrap_logic(ExactLogic.exact_aggregation(jnp.argmax))
|
|
82
|
+
'sum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.sum)),
|
|
83
|
+
'avg': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.mean)),
|
|
84
|
+
'prod': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.prod)),
|
|
85
|
+
'minimum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.min)),
|
|
86
|
+
'maximum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.max)),
|
|
87
|
+
'forall': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.all)),
|
|
88
|
+
'exists': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.any)),
|
|
89
|
+
'argmin': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmin)),
|
|
90
|
+
'argmax': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmax))
|
|
91
91
|
}
|
|
92
92
|
|
|
93
93
|
EXACT_RDDL_TO_JAX_UNARY = {
|
|
94
|
-
'abs': wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
|
|
95
|
-
'sgn': wrap_logic(ExactLogic.exact_unary_function(jnp.sign)),
|
|
96
|
-
'round': wrap_logic(ExactLogic.exact_unary_function(jnp.round)),
|
|
97
|
-
'floor': wrap_logic(ExactLogic.exact_unary_function(jnp.floor)),
|
|
98
|
-
'ceil': wrap_logic(ExactLogic.exact_unary_function(jnp.ceil)),
|
|
99
|
-
'cos': wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
|
|
100
|
-
'sin': wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
|
|
101
|
-
'tan': wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
|
|
102
|
-
'acos': wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
|
|
103
|
-
'asin': wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
|
|
104
|
-
'atan': wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
|
|
105
|
-
'cosh': wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
|
|
106
|
-
'sinh': wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
|
|
107
|
-
'tanh': wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
|
|
108
|
-
'exp': wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
|
|
109
|
-
'ln': wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
|
|
110
|
-
'sqrt': wrap_logic(ExactLogic.exact_unary_function(jnp.sqrt)),
|
|
111
|
-
'lngamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
|
|
112
|
-
'gamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
|
|
94
|
+
'abs': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.abs)),
|
|
95
|
+
'sgn': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sign)),
|
|
96
|
+
'round': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.round)),
|
|
97
|
+
'floor': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.floor)),
|
|
98
|
+
'ceil': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.ceil)),
|
|
99
|
+
'cos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cos)),
|
|
100
|
+
'sin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sin)),
|
|
101
|
+
'tan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tan)),
|
|
102
|
+
'acos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arccos)),
|
|
103
|
+
'asin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arcsin)),
|
|
104
|
+
'atan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arctan)),
|
|
105
|
+
'cosh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cosh)),
|
|
106
|
+
'sinh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sinh)),
|
|
107
|
+
'tanh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tanh)),
|
|
108
|
+
'exp': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.exp)),
|
|
109
|
+
'ln': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.log)),
|
|
110
|
+
'sqrt': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sqrt)),
|
|
111
|
+
'lngamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gammaln)),
|
|
112
|
+
'gamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gamma))
|
|
113
113
|
}
|
|
114
114
|
|
|
115
115
|
@staticmethod
|
|
@@ -117,23 +117,23 @@ class JaxRDDLCompiler:
|
|
|
117
117
|
return jnp.log(x) / jnp.log(y), params
|
|
118
118
|
|
|
119
119
|
EXACT_RDDL_TO_JAX_BINARY = {
|
|
120
|
-
'div': wrap_logic(ExactLogic.exact_binary_function(jnp.floor_divide)),
|
|
121
|
-
'mod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
|
|
122
|
-
'fmod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
|
|
123
|
-
'min': wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
|
|
124
|
-
'max': wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
|
|
125
|
-
'pow': wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
|
|
126
|
-
'log': wrap_logic(_jax_wrapped_calc_log_exact),
|
|
127
|
-
'hypot': wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
|
|
120
|
+
'div': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.floor_divide)),
|
|
121
|
+
'mod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
|
|
122
|
+
'fmod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
|
|
123
|
+
'min': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.minimum)),
|
|
124
|
+
'max': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.maximum)),
|
|
125
|
+
'pow': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.power)),
|
|
126
|
+
'log': wrap_logic.__func__(_jax_wrapped_calc_log_exact.__func__),
|
|
127
|
+
'hypot': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.hypot)),
|
|
128
128
|
}
|
|
129
129
|
|
|
130
|
-
EXACT_RDDL_TO_JAX_IF = wrap_logic(ExactLogic.exact_if_then_else)
|
|
131
|
-
EXACT_RDDL_TO_JAX_SWITCH = wrap_logic(ExactLogic.exact_switch)
|
|
130
|
+
EXACT_RDDL_TO_JAX_IF = wrap_logic.__func__(ExactLogic.exact_if_then_else)
|
|
131
|
+
EXACT_RDDL_TO_JAX_SWITCH = wrap_logic.__func__(ExactLogic.exact_switch)
|
|
132
132
|
|
|
133
|
-
EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic(ExactLogic.exact_bernoulli)
|
|
134
|
-
EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic(ExactLogic.exact_discrete)
|
|
135
|
-
EXACT_RDDL_TO_JAX_POISSON = wrap_logic(ExactLogic.exact_poisson)
|
|
136
|
-
EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic(ExactLogic.exact_geometric)
|
|
133
|
+
EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic.__func__(ExactLogic.exact_bernoulli)
|
|
134
|
+
EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic.__func__(ExactLogic.exact_discrete)
|
|
135
|
+
EXACT_RDDL_TO_JAX_POISSON = wrap_logic.__func__(ExactLogic.exact_poisson)
|
|
136
|
+
EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic.__func__(ExactLogic.exact_geometric)
|
|
137
137
|
|
|
138
138
|
def __init__(self, rddl: RDDLLiftedModel,
|
|
139
139
|
allow_synchronous_state: bool=True,
|
|
@@ -65,9 +65,8 @@ def _parse_config_file(path: str):
|
|
|
65
65
|
config = configparser.RawConfigParser()
|
|
66
66
|
config.optionxform = str
|
|
67
67
|
config.read(path)
|
|
68
|
-
args = {k: literal_eval(v)
|
|
69
|
-
for section in config.sections()
|
|
70
|
-
for (k, v) in config.items(section)}
|
|
68
|
+
args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
|
|
69
|
+
for section in config.sections()}
|
|
71
70
|
return config, args
|
|
72
71
|
|
|
73
72
|
|
|
@@ -75,9 +74,8 @@ def _parse_config_string(value: str):
|
|
|
75
74
|
config = configparser.RawConfigParser()
|
|
76
75
|
config.optionxform = str
|
|
77
76
|
config.read_string(value)
|
|
78
|
-
args = {k: literal_eval(v)
|
|
79
|
-
for section in config.sections()
|
|
80
|
-
for (k, v) in config.items(section)}
|
|
77
|
+
args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
|
|
78
|
+
for section in config.sections()}
|
|
81
79
|
return config, args
|
|
82
80
|
|
|
83
81
|
|
|
@@ -90,9 +88,9 @@ def _getattr_any(packages, item):
|
|
|
90
88
|
|
|
91
89
|
|
|
92
90
|
def _load_config(config, args):
|
|
93
|
-
model_args = {k: args[k] for (k, _) in config.items('Model')}
|
|
94
|
-
planner_args = {k: args[k] for (k, _) in config.items('Optimizer')}
|
|
95
|
-
train_args = {k: args[k] for (k, _) in config.items('Training')}
|
|
91
|
+
model_args = {k: args['Model'][k] for (k, _) in config.items('Model')}
|
|
92
|
+
planner_args = {k: args['Optimizer'][k] for (k, _) in config.items('Optimizer')}
|
|
93
|
+
train_args = {k: args['Training'][k] for (k, _) in config.items('Training')}
|
|
96
94
|
|
|
97
95
|
# read the model settings
|
|
98
96
|
logic_name = model_args.get('logic', 'FuzzyLogic')
|
|
@@ -1661,7 +1659,7 @@ r"""
|
|
|
1661
1659
|
def optimize_generator(self, key: Optional[random.PRNGKey]=None,
|
|
1662
1660
|
epochs: int=999999,
|
|
1663
1661
|
train_seconds: float=120.,
|
|
1664
|
-
dashboard: Optional[
|
|
1662
|
+
dashboard: Optional[Any]=None,
|
|
1665
1663
|
dashboard_id: Optional[str]=None,
|
|
1666
1664
|
model_params: Optional[Dict[str, Any]]=None,
|
|
1667
1665
|
policy_hyperparams: Optional[Dict[str, Any]]=None,
|
|
@@ -4,6 +4,7 @@ import threading
|
|
|
4
4
|
import multiprocessing
|
|
5
5
|
import os
|
|
6
6
|
import time
|
|
7
|
+
import traceback
|
|
7
8
|
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
|
|
8
9
|
import warnings
|
|
9
10
|
warnings.filterwarnings("ignore")
|
|
@@ -14,6 +15,7 @@ from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
|
|
|
14
15
|
import jax
|
|
15
16
|
import numpy as np
|
|
16
17
|
|
|
18
|
+
from pyRDDLGym.core.debug.exception import raise_warning
|
|
17
19
|
from pyRDDLGym.core.env import RDDLEnv
|
|
18
20
|
|
|
19
21
|
from pyRDDLGym_jax.core.planner import (
|
|
@@ -64,6 +66,7 @@ class JaxParameterTuning:
|
|
|
64
66
|
hyperparams: Hyperparameters,
|
|
65
67
|
online: bool,
|
|
66
68
|
eval_trials: int=5,
|
|
69
|
+
rollouts_per_trial: int=1,
|
|
67
70
|
verbose: bool=True,
|
|
68
71
|
timeout_tuning: float=np.inf,
|
|
69
72
|
pool_context: str='spawn',
|
|
@@ -87,6 +90,8 @@ class JaxParameterTuning:
|
|
|
87
90
|
hyperparameters in general (in seconds)
|
|
88
91
|
:param eval_trials: how many trials to perform independent training
|
|
89
92
|
in order to estimate the return for each set of hyper-parameters
|
|
93
|
+
:param rollouts_per_trial: how many rollouts to perform during evaluation
|
|
94
|
+
at the end of each training trial (only applies when online=False)
|
|
90
95
|
:param verbose: whether to print intermediate results of tuning
|
|
91
96
|
:param pool_context: context for multiprocessing pool (default "spawn")
|
|
92
97
|
:param num_workers: how many points to evaluate in parallel
|
|
@@ -108,6 +113,7 @@ class JaxParameterTuning:
|
|
|
108
113
|
self.hyperparams_dict = hyperparams_dict
|
|
109
114
|
self.online = online
|
|
110
115
|
self.eval_trials = eval_trials
|
|
116
|
+
self.rollouts_per_trial = rollouts_per_trial
|
|
111
117
|
self.verbose = verbose
|
|
112
118
|
|
|
113
119
|
# Bayesian parameters
|
|
@@ -154,6 +160,7 @@ class JaxParameterTuning:
|
|
|
154
160
|
f' mp_pool_poll_frequency ={self.poll_frequency}\n'
|
|
155
161
|
f'meta-objective parameters:\n'
|
|
156
162
|
f' planning_trials_per_iter ={self.eval_trials}\n'
|
|
163
|
+
f' rollouts_per_trial ={self.rollouts_per_trial}\n'
|
|
157
164
|
f' acquisition_fn ={self.acquisition}')
|
|
158
165
|
|
|
159
166
|
@staticmethod
|
|
@@ -200,12 +207,14 @@ class JaxParameterTuning:
|
|
|
200
207
|
|
|
201
208
|
@staticmethod
|
|
202
209
|
def offline_trials(env, planner, train_args, key, iteration, index, num_trials,
|
|
203
|
-
verbose, viz, queue):
|
|
210
|
+
rollouts_per_trial, verbose, viz, queue):
|
|
204
211
|
average_reward = 0.0
|
|
205
212
|
for trial in range(num_trials):
|
|
206
213
|
key, subkey = jax.random.split(key)
|
|
214
|
+
|
|
215
|
+
# for the dashboard
|
|
207
216
|
experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
|
|
208
|
-
if queue is not None:
|
|
217
|
+
if queue is not None and JaxPlannerDashboard is not None:
|
|
209
218
|
queue.put((
|
|
210
219
|
experiment_id,
|
|
211
220
|
JaxPlannerDashboard.get_planner_info(planner),
|
|
@@ -224,7 +233,8 @@ class JaxParameterTuning:
|
|
|
224
233
|
policy = JaxOfflineController(
|
|
225
234
|
planner=planner, key=subkey, tqdm_position=index,
|
|
226
235
|
params=best_params, train_on_reset=False)
|
|
227
|
-
total_reward = policy.evaluate(env,
|
|
236
|
+
total_reward = policy.evaluate(env, episodes=rollouts_per_trial,
|
|
237
|
+
seed=np.array(subkey)[0])['mean']
|
|
228
238
|
|
|
229
239
|
# update average reward
|
|
230
240
|
if verbose:
|
|
@@ -243,8 +253,10 @@ class JaxParameterTuning:
|
|
|
243
253
|
average_reward = 0.0
|
|
244
254
|
for trial in range(num_trials):
|
|
245
255
|
key, subkey = jax.random.split(key)
|
|
256
|
+
|
|
257
|
+
# for the dashboard
|
|
246
258
|
experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
|
|
247
|
-
if queue is not None:
|
|
259
|
+
if queue is not None and JaxPlannerDashboard is not None:
|
|
248
260
|
queue.put((
|
|
249
261
|
experiment_id,
|
|
250
262
|
JaxPlannerDashboard.get_planner_info(planner),
|
|
@@ -304,6 +316,7 @@ class JaxParameterTuning:
|
|
|
304
316
|
domain = kwargs['domain']
|
|
305
317
|
instance = kwargs['instance']
|
|
306
318
|
num_trials = kwargs['eval_trials']
|
|
319
|
+
rollouts_per_trial = kwargs['rollouts_per_trial']
|
|
307
320
|
viz = kwargs['viz']
|
|
308
321
|
verbose = kwargs['verbose']
|
|
309
322
|
|
|
@@ -332,7 +345,7 @@ class JaxParameterTuning:
|
|
|
332
345
|
else:
|
|
333
346
|
average_reward = JaxParameterTuning.offline_trials(
|
|
334
347
|
env, planner, train_args, key, iteration, index,
|
|
335
|
-
num_trials, verbose, viz, queue
|
|
348
|
+
num_trials, rollouts_per_trial, verbose, viz, queue
|
|
336
349
|
)
|
|
337
350
|
|
|
338
351
|
pid = os.getpid()
|
|
@@ -353,7 +366,7 @@ class JaxParameterTuning:
|
|
|
353
366
|
writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
|
|
354
367
|
|
|
355
368
|
# create a dash-board for visualizing experiment runs
|
|
356
|
-
if show_dashboard:
|
|
369
|
+
if show_dashboard and JaxPlannerDashboard is not None:
|
|
357
370
|
dashboard = JaxPlannerDashboard()
|
|
358
371
|
dashboard.launch()
|
|
359
372
|
|
|
@@ -365,6 +378,7 @@ class JaxParameterTuning:
|
|
|
365
378
|
'domain': self.env.domain_text,
|
|
366
379
|
'instance': self.env.instance_text,
|
|
367
380
|
'eval_trials': self.eval_trials,
|
|
381
|
+
'rollouts_per_trial': self.rollouts_per_trial,
|
|
368
382
|
'viz': self.env._visualizer,
|
|
369
383
|
'verbose': self.verbose
|
|
370
384
|
}
|
|
@@ -1405,7 +1405,7 @@ class JaxPlannerDashboard:
|
|
|
1405
1405
|
self.test_reward_dist[experiment_id] = callback['reward']
|
|
1406
1406
|
self.train_state_fluents[experiment_id] = {
|
|
1407
1407
|
name: np.asarray(callback['train_log']['fluents'][name])
|
|
1408
|
-
for name in rddl.state_fluents
|
|
1408
|
+
for name in rddl.state_fluents
|
|
1409
1409
|
}
|
|
1410
1410
|
self.test_state_fluents[experiment_id] = {
|
|
1411
1411
|
name: np.asarray(callback['fluents'][name])
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: pyRDDLGym-jax
|
|
3
|
-
Version: 1.0
|
|
3
|
+
Version: 1.1
|
|
4
4
|
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
|
|
5
5
|
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
|
|
6
6
|
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
|
|
@@ -31,7 +31,17 @@ Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"
|
|
|
31
31
|
|
|
32
32
|
# pyRDDLGym-jax
|
|
33
33
|
|
|
34
|
-
|
|
34
|
+

|
|
35
|
+
[](https://pypi.org/project/pyRDDLGym-jax/)
|
|
36
|
+
[](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
|
|
37
|
+

|
|
38
|
+
[](https://pypistats.org/packages/pyrddlgym-jax)
|
|
39
|
+
|
|
40
|
+
[Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
|
|
41
|
+
|
|
42
|
+
**pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
|
|
43
|
+
|
|
44
|
+
Purpose:
|
|
35
45
|
|
|
36
46
|
1. automatic translation of any RDDL description file into a differentiable simulator in JAX
|
|
37
47
|
2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
|
|
@@ -56,17 +66,6 @@ and was moved to the individual logic components which have their own unique wei
|
|
|
56
66
|
> [!NOTE]
|
|
57
67
|
> While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
|
|
58
68
|
> If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
|
|
59
|
-
|
|
60
|
-
## Contents
|
|
61
|
-
|
|
62
|
-
- [Installation](#installation)
|
|
63
|
-
- [Running from the Command Line](#running-from-the-command-line)
|
|
64
|
-
- [Running from Another Python Application](#running-from-another-python-application)
|
|
65
|
-
- [Configuring the Planner](#configuring-the-planner)
|
|
66
|
-
- [JaxPlan Dashboard](#jaxplan-dashboard)
|
|
67
|
-
- [Tuning the Planner](#tuning-the-planner)
|
|
68
|
-
- [Simulation](#simulation)
|
|
69
|
-
- [Citing JaxPlan](#citing-jaxplan)
|
|
70
69
|
|
|
71
70
|
## Installation
|
|
72
71
|
|
|
@@ -19,7 +19,7 @@ long_description = (Path(__file__).parent / "README.md").read_text()
|
|
|
19
19
|
|
|
20
20
|
setup(
|
|
21
21
|
name='pyRDDLGym-jax',
|
|
22
|
-
version='1.0',
|
|
22
|
+
version='1.1',
|
|
23
23
|
author="Michael Gimelfarb, Ayal Taitler, Scott Sanner",
|
|
24
24
|
author_email="mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca",
|
|
25
25
|
description="pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.",
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = '1.0'
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg
RENAMED
|
File without changes
|
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg
RENAMED
|
File without changes
|
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg
RENAMED
|
File without changes
|
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg
RENAMED
|
File without changes
|
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
{pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|