pyRDDLGym-jax 1.0__tar.gz → 1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/PKG-INFO +12 -13
  2. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/README.md +11 -12
  3. pyrddlgym_jax-1.1/pyRDDLGym_jax/__init__.py +1 -0
  4. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/compiler.py +60 -60
  5. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/planner.py +8 -10
  6. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/tuning.py +20 -6
  7. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/visualization.py +1 -1
  8. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/PKG-INFO +12 -13
  9. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/setup.py +1 -1
  10. pyrddlgym_jax-1.0/pyRDDLGym_jax/__init__.py +0 -1
  11. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/LICENSE +0 -0
  12. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/__init__.py +0 -0
  13. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/logic.py +0 -0
  14. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/core/simulator.py +0 -0
  15. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/__init__.py +0 -0
  16. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  17. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
  18. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  19. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  20. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  21. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  22. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  23. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  24. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  25. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  26. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  27. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  28. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  29. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  30. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  31. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  32. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  33. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
  34. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  35. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  36. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  37. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  38. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  39. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
  40. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
  41. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
  42. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  43. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  44. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_plan.py +0 -0
  45. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  46. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax/examples/run_tune.py +0 -0
  47. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
  48. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  49. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
  50. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  51. {pyrddlgym_jax-1.0 → pyrddlgym_jax-1.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pyRDDLGym-jax
3
- Version: 1.0
3
+ Version: 1.1
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -31,7 +31,17 @@ Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"
31
31
 
32
32
  # pyRDDLGym-jax
33
33
 
34
- **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.** It provides:
34
+ ![Python Version](https://img.shields.io/badge/python-3.9%2B-blue)
35
+ [![PyPI Version](https://img.shields.io/pypi/v/pyRDDLGym-jax.svg)](https://pypi.org/project/pyRDDLGym-jax/)
36
+ [![Documentation Status](https://readthedocs.org/projects/pyrddlgym/badge/?version=latest)](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
37
+ ![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)
38
+ [![Cumulative PyPI Downloads](https://img.shields.io/pypi/dm/pyrddlgym-jax)](https://pypistats.org/packages/pyrddlgym-jax)
39
+
40
+ [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
41
+
42
+ **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
43
+
44
+ Purpose:
35
45
 
36
46
  1. automatic translation of any RDDL description file into a differentiable simulator in JAX
37
47
  2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
@@ -56,17 +66,6 @@ and was moved to the individual logic components which have their own unique wei
56
66
  > [!NOTE]
57
67
  > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
58
68
  > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
59
-
60
- ## Contents
61
-
62
- - [Installation](#installation)
63
- - [Running from the Command Line](#running-from-the-command-line)
64
- - [Running from Another Python Application](#running-from-another-python-application)
65
- - [Configuring the Planner](#configuring-the-planner)
66
- - [JaxPlan Dashboard](#jaxplan-dashboard)
67
- - [Tuning the Planner](#tuning-the-planner)
68
- - [Simulation](#simulation)
69
- - [Citing JaxPlan](#citing-jaxplan)
70
69
 
71
70
  ## Installation
72
71
 
@@ -1,6 +1,16 @@
1
1
  # pyRDDLGym-jax
2
2
 
3
- **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.** It provides:
3
+ ![Python Version](https://img.shields.io/badge/python-3.9%2B-blue)
4
+ [![PyPI Version](https://img.shields.io/pypi/v/pyRDDLGym-jax.svg)](https://pypi.org/project/pyRDDLGym-jax/)
5
+ [![Documentation Status](https://readthedocs.org/projects/pyrddlgym/badge/?version=latest)](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
6
+ ![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)
7
+ [![Cumulative PyPI Downloads](https://img.shields.io/pypi/dm/pyrddlgym-jax)](https://pypistats.org/packages/pyrddlgym-jax)
8
+
9
+ [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
10
+
11
+ **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
12
+
13
+ Purpose:
4
14
 
5
15
  1. automatic translation of any RDDL description file into a differentiable simulator in JAX
6
16
  2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
@@ -25,17 +35,6 @@ and was moved to the individual logic components which have their own unique wei
25
35
  > [!NOTE]
26
36
  > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
27
37
  > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
28
-
29
- ## Contents
30
-
31
- - [Installation](#installation)
32
- - [Running from the Command Line](#running-from-the-command-line)
33
- - [Running from Another Python Application](#running-from-another-python-application)
34
- - [Configuring the Planner](#configuring-the-planner)
35
- - [JaxPlan Dashboard](#jaxplan-dashboard)
36
- - [Tuning the Planner](#tuning-the-planner)
37
- - [Simulation](#simulation)
38
- - [Citing JaxPlan](#citing-jaxplan)
39
38
 
40
39
  ## Installation
41
40
 
@@ -0,0 +1 @@
1
+ __version__ = '1.1'
@@ -51,65 +51,65 @@ class JaxRDDLCompiler:
51
51
  return func
52
52
  return exact_func
53
53
 
54
- EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic(ExactLogic.exact_unary_function(jnp.negative))
54
+ EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.negative))
55
55
  EXACT_RDDL_TO_JAX_ARITHMETIC = {
56
- '+': wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
57
- '-': wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
58
- '*': wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
59
- '/': wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
56
+ '+': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.add)),
57
+ '-': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.subtract)),
58
+ '*': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.multiply)),
59
+ '/': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.divide))
60
60
  }
61
61
 
62
62
  EXACT_RDDL_TO_JAX_RELATIONAL = {
63
- '>=': wrap_logic(ExactLogic.exact_binary_function(jnp.greater_equal)),
64
- '<=': wrap_logic(ExactLogic.exact_binary_function(jnp.less_equal)),
65
- '<': wrap_logic(ExactLogic.exact_binary_function(jnp.less)),
66
- '>': wrap_logic(ExactLogic.exact_binary_function(jnp.greater)),
67
- '==': wrap_logic(ExactLogic.exact_binary_function(jnp.equal)),
68
- '~=': wrap_logic(ExactLogic.exact_binary_function(jnp.not_equal))
63
+ '>=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater_equal)),
64
+ '<=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less_equal)),
65
+ '<': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less)),
66
+ '>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater)),
67
+ '==': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal)),
68
+ '~=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.not_equal))
69
69
  }
70
70
 
71
- EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic(ExactLogic.exact_unary_function(jnp.logical_not))
71
+ EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.logical_not))
72
72
  EXACT_RDDL_TO_JAX_LOGICAL = {
73
- '^': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
74
- '&': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_and)),
75
- '|': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_or)),
76
- '~': wrap_logic(ExactLogic.exact_binary_function(jnp.logical_xor)),
77
- '=>': wrap_logic(ExactLogic.exact_binary_implies),
78
- '<=>': wrap_logic(ExactLogic.exact_binary_function(jnp.equal))
73
+ '^': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
74
+ '&': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
75
+ '|': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_or)),
76
+ '~': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_xor)),
77
+ '=>': wrap_logic.__func__(ExactLogic.exact_binary_implies),
78
+ '<=>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal))
79
79
  }
80
80
 
81
81
  EXACT_RDDL_TO_JAX_AGGREGATION = {
82
- 'sum': wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
83
- 'avg': wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
84
- 'prod': wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
85
- 'minimum': wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
86
- 'maximum': wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
87
- 'forall': wrap_logic(ExactLogic.exact_aggregation(jnp.all)),
88
- 'exists': wrap_logic(ExactLogic.exact_aggregation(jnp.any)),
89
- 'argmin': wrap_logic(ExactLogic.exact_aggregation(jnp.argmin)),
90
- 'argmax': wrap_logic(ExactLogic.exact_aggregation(jnp.argmax))
82
+ 'sum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.sum)),
83
+ 'avg': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.mean)),
84
+ 'prod': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.prod)),
85
+ 'minimum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.min)),
86
+ 'maximum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.max)),
87
+ 'forall': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.all)),
88
+ 'exists': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.any)),
89
+ 'argmin': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmin)),
90
+ 'argmax': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmax))
91
91
  }
92
92
 
93
93
  EXACT_RDDL_TO_JAX_UNARY = {
94
- 'abs': wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
95
- 'sgn': wrap_logic(ExactLogic.exact_unary_function(jnp.sign)),
96
- 'round': wrap_logic(ExactLogic.exact_unary_function(jnp.round)),
97
- 'floor': wrap_logic(ExactLogic.exact_unary_function(jnp.floor)),
98
- 'ceil': wrap_logic(ExactLogic.exact_unary_function(jnp.ceil)),
99
- 'cos': wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
100
- 'sin': wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
101
- 'tan': wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
102
- 'acos': wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
103
- 'asin': wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
104
- 'atan': wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
105
- 'cosh': wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
106
- 'sinh': wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
107
- 'tanh': wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
108
- 'exp': wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
109
- 'ln': wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
110
- 'sqrt': wrap_logic(ExactLogic.exact_unary_function(jnp.sqrt)),
111
- 'lngamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
112
- 'gamma': wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
94
+ 'abs': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.abs)),
95
+ 'sgn': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sign)),
96
+ 'round': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.round)),
97
+ 'floor': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.floor)),
98
+ 'ceil': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.ceil)),
99
+ 'cos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cos)),
100
+ 'sin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sin)),
101
+ 'tan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tan)),
102
+ 'acos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arccos)),
103
+ 'asin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arcsin)),
104
+ 'atan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arctan)),
105
+ 'cosh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cosh)),
106
+ 'sinh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sinh)),
107
+ 'tanh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tanh)),
108
+ 'exp': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.exp)),
109
+ 'ln': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.log)),
110
+ 'sqrt': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sqrt)),
111
+ 'lngamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gammaln)),
112
+ 'gamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gamma))
113
113
  }
114
114
 
115
115
  @staticmethod
@@ -117,23 +117,23 @@ class JaxRDDLCompiler:
117
117
  return jnp.log(x) / jnp.log(y), params
118
118
 
119
119
  EXACT_RDDL_TO_JAX_BINARY = {
120
- 'div': wrap_logic(ExactLogic.exact_binary_function(jnp.floor_divide)),
121
- 'mod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
122
- 'fmod': wrap_logic(ExactLogic.exact_binary_function(jnp.mod)),
123
- 'min': wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
124
- 'max': wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
125
- 'pow': wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
126
- 'log': wrap_logic(_jax_wrapped_calc_log_exact),
127
- 'hypot': wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
120
+ 'div': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.floor_divide)),
121
+ 'mod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
122
+ 'fmod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
123
+ 'min': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.minimum)),
124
+ 'max': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.maximum)),
125
+ 'pow': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.power)),
126
+ 'log': wrap_logic.__func__(_jax_wrapped_calc_log_exact.__func__),
127
+ 'hypot': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.hypot)),
128
128
  }
129
129
 
130
- EXACT_RDDL_TO_JAX_IF = wrap_logic(ExactLogic.exact_if_then_else)
131
- EXACT_RDDL_TO_JAX_SWITCH = wrap_logic(ExactLogic.exact_switch)
130
+ EXACT_RDDL_TO_JAX_IF = wrap_logic.__func__(ExactLogic.exact_if_then_else)
131
+ EXACT_RDDL_TO_JAX_SWITCH = wrap_logic.__func__(ExactLogic.exact_switch)
132
132
 
133
- EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic(ExactLogic.exact_bernoulli)
134
- EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic(ExactLogic.exact_discrete)
135
- EXACT_RDDL_TO_JAX_POISSON = wrap_logic(ExactLogic.exact_poisson)
136
- EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic(ExactLogic.exact_geometric)
133
+ EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic.__func__(ExactLogic.exact_bernoulli)
134
+ EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic.__func__(ExactLogic.exact_discrete)
135
+ EXACT_RDDL_TO_JAX_POISSON = wrap_logic.__func__(ExactLogic.exact_poisson)
136
+ EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic.__func__(ExactLogic.exact_geometric)
137
137
 
138
138
  def __init__(self, rddl: RDDLLiftedModel,
139
139
  allow_synchronous_state: bool=True,
@@ -65,9 +65,8 @@ def _parse_config_file(path: str):
65
65
  config = configparser.RawConfigParser()
66
66
  config.optionxform = str
67
67
  config.read(path)
68
- args = {k: literal_eval(v)
69
- for section in config.sections()
70
- for (k, v) in config.items(section)}
68
+ args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
69
+ for section in config.sections()}
71
70
  return config, args
72
71
 
73
72
 
@@ -75,9 +74,8 @@ def _parse_config_string(value: str):
75
74
  config = configparser.RawConfigParser()
76
75
  config.optionxform = str
77
76
  config.read_string(value)
78
- args = {k: literal_eval(v)
79
- for section in config.sections()
80
- for (k, v) in config.items(section)}
77
+ args = {section: {k: literal_eval(v) for (k, v) in config.items(section)}
78
+ for section in config.sections()}
81
79
  return config, args
82
80
 
83
81
 
@@ -90,9 +88,9 @@ def _getattr_any(packages, item):
90
88
 
91
89
 
92
90
  def _load_config(config, args):
93
- model_args = {k: args[k] for (k, _) in config.items('Model')}
94
- planner_args = {k: args[k] for (k, _) in config.items('Optimizer')}
95
- train_args = {k: args[k] for (k, _) in config.items('Training')}
91
+ model_args = {k: args['Model'][k] for (k, _) in config.items('Model')}
92
+ planner_args = {k: args['Optimizer'][k] for (k, _) in config.items('Optimizer')}
93
+ train_args = {k: args['Training'][k] for (k, _) in config.items('Training')}
96
94
 
97
95
  # read the model settings
98
96
  logic_name = model_args.get('logic', 'FuzzyLogic')
@@ -1661,7 +1659,7 @@ r"""
1661
1659
  def optimize_generator(self, key: Optional[random.PRNGKey]=None,
1662
1660
  epochs: int=999999,
1663
1661
  train_seconds: float=120.,
1664
- dashboard: Optional[JaxPlannerDashboard]=None,
1662
+ dashboard: Optional[Any]=None,
1665
1663
  dashboard_id: Optional[str]=None,
1666
1664
  model_params: Optional[Dict[str, Any]]=None,
1667
1665
  policy_hyperparams: Optional[Dict[str, Any]]=None,
@@ -4,6 +4,7 @@ import threading
4
4
  import multiprocessing
5
5
  import os
6
6
  import time
7
+ import traceback
7
8
  from typing import Any, Callable, Dict, Iterable, Optional, Tuple
8
9
  import warnings
9
10
  warnings.filterwarnings("ignore")
@@ -14,6 +15,7 @@ from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
14
15
  import jax
15
16
  import numpy as np
16
17
 
18
+ from pyRDDLGym.core.debug.exception import raise_warning
17
19
  from pyRDDLGym.core.env import RDDLEnv
18
20
 
19
21
  from pyRDDLGym_jax.core.planner import (
@@ -64,6 +66,7 @@ class JaxParameterTuning:
64
66
  hyperparams: Hyperparameters,
65
67
  online: bool,
66
68
  eval_trials: int=5,
69
+ rollouts_per_trial: int=1,
67
70
  verbose: bool=True,
68
71
  timeout_tuning: float=np.inf,
69
72
  pool_context: str='spawn',
@@ -87,6 +90,8 @@ class JaxParameterTuning:
87
90
  hyperparameters in general (in seconds)
88
91
  :param eval_trials: how many trials to perform independent training
89
92
  in order to estimate the return for each set of hyper-parameters
93
+ :param rollouts_per_trial: how many rollouts to perform during evaluation
94
+ at the end of each training trial (only applies when online=False)
90
95
  :param verbose: whether to print intermediate results of tuning
91
96
  :param pool_context: context for multiprocessing pool (default "spawn")
92
97
  :param num_workers: how many points to evaluate in parallel
@@ -108,6 +113,7 @@ class JaxParameterTuning:
108
113
  self.hyperparams_dict = hyperparams_dict
109
114
  self.online = online
110
115
  self.eval_trials = eval_trials
116
+ self.rollouts_per_trial = rollouts_per_trial
111
117
  self.verbose = verbose
112
118
 
113
119
  # Bayesian parameters
@@ -154,6 +160,7 @@ class JaxParameterTuning:
154
160
  f' mp_pool_poll_frequency ={self.poll_frequency}\n'
155
161
  f'meta-objective parameters:\n'
156
162
  f' planning_trials_per_iter ={self.eval_trials}\n'
163
+ f' rollouts_per_trial ={self.rollouts_per_trial}\n'
157
164
  f' acquisition_fn ={self.acquisition}')
158
165
 
159
166
  @staticmethod
@@ -200,12 +207,14 @@ class JaxParameterTuning:
200
207
 
201
208
  @staticmethod
202
209
  def offline_trials(env, planner, train_args, key, iteration, index, num_trials,
203
- verbose, viz, queue):
210
+ rollouts_per_trial, verbose, viz, queue):
204
211
  average_reward = 0.0
205
212
  for trial in range(num_trials):
206
213
  key, subkey = jax.random.split(key)
214
+
215
+ # for the dashboard
207
216
  experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
208
- if queue is not None:
217
+ if queue is not None and JaxPlannerDashboard is not None:
209
218
  queue.put((
210
219
  experiment_id,
211
220
  JaxPlannerDashboard.get_planner_info(planner),
@@ -224,7 +233,8 @@ class JaxParameterTuning:
224
233
  policy = JaxOfflineController(
225
234
  planner=planner, key=subkey, tqdm_position=index,
226
235
  params=best_params, train_on_reset=False)
227
- total_reward = policy.evaluate(env, seed=np.array(subkey)[0])['mean']
236
+ total_reward = policy.evaluate(env, episodes=rollouts_per_trial,
237
+ seed=np.array(subkey)[0])['mean']
228
238
 
229
239
  # update average reward
230
240
  if verbose:
@@ -243,8 +253,10 @@ class JaxParameterTuning:
243
253
  average_reward = 0.0
244
254
  for trial in range(num_trials):
245
255
  key, subkey = jax.random.split(key)
256
+
257
+ # for the dashboard
246
258
  experiment_id = f'iter={iteration}, worker={index}, trial={trial}'
247
- if queue is not None:
259
+ if queue is not None and JaxPlannerDashboard is not None:
248
260
  queue.put((
249
261
  experiment_id,
250
262
  JaxPlannerDashboard.get_planner_info(planner),
@@ -304,6 +316,7 @@ class JaxParameterTuning:
304
316
  domain = kwargs['domain']
305
317
  instance = kwargs['instance']
306
318
  num_trials = kwargs['eval_trials']
319
+ rollouts_per_trial = kwargs['rollouts_per_trial']
307
320
  viz = kwargs['viz']
308
321
  verbose = kwargs['verbose']
309
322
 
@@ -332,7 +345,7 @@ class JaxParameterTuning:
332
345
  else:
333
346
  average_reward = JaxParameterTuning.offline_trials(
334
347
  env, planner, train_args, key, iteration, index,
335
- num_trials, verbose, viz, queue
348
+ num_trials, rollouts_per_trial, verbose, viz, queue
336
349
  )
337
350
 
338
351
  pid = os.getpid()
@@ -353,7 +366,7 @@ class JaxParameterTuning:
353
366
  writer.writerow(COLUMNS + list(self.hyperparams_dict.keys()))
354
367
 
355
368
  # create a dash-board for visualizing experiment runs
356
- if show_dashboard:
369
+ if show_dashboard and JaxPlannerDashboard is not None:
357
370
  dashboard = JaxPlannerDashboard()
358
371
  dashboard.launch()
359
372
 
@@ -365,6 +378,7 @@ class JaxParameterTuning:
365
378
  'domain': self.env.domain_text,
366
379
  'instance': self.env.instance_text,
367
380
  'eval_trials': self.eval_trials,
381
+ 'rollouts_per_trial': self.rollouts_per_trial,
368
382
  'viz': self.env._visualizer,
369
383
  'verbose': self.verbose
370
384
  }
@@ -1405,7 +1405,7 @@ class JaxPlannerDashboard:
1405
1405
  self.test_reward_dist[experiment_id] = callback['reward']
1406
1406
  self.train_state_fluents[experiment_id] = {
1407
1407
  name: np.asarray(callback['train_log']['fluents'][name])
1408
- for name in rddl.state_fluents or name in rddl.observ_fluents
1408
+ for name in rddl.state_fluents
1409
1409
  }
1410
1410
  self.test_state_fluents[experiment_id] = {
1411
1411
  name: np.asarray(callback['fluents'][name])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pyRDDLGym-jax
3
- Version: 1.0
3
+ Version: 1.1
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -31,7 +31,17 @@ Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "dashboard"
31
31
 
32
32
  # pyRDDLGym-jax
33
33
 
34
- **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.** It provides:
34
+ ![Python Version](https://img.shields.io/badge/python-3.9%2B-blue)
35
+ [![PyPI Version](https://img.shields.io/pypi/v/pyRDDLGym-jax.svg)](https://pypi.org/project/pyRDDLGym-jax/)
36
+ [![Documentation Status](https://readthedocs.org/projects/pyrddlgym/badge/?version=latest)](https://pyrddlgym.readthedocs.io/en/latest/jax.html)
37
+ ![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)
38
+ [![Cumulative PyPI Downloads](https://img.shields.io/pypi/dm/pyrddlgym-jax)](https://pypistats.org/packages/pyrddlgym-jax)
39
+
40
+ [Installation](#installation) | [Run cmd](#running-from-the-command-line) | [Run python](#running-from-another-python-application) | [Configuration](#configuring-the-planner) | [Dashboard](#jaxplan-dashboard) | [Tuning](#tuning-the-planner) | [Simulation](#simulation) | [Citing](#citing-jaxplan)
41
+
42
+ **pyRDDLGym-jax (known in the literature as JaxPlan) is an efficient gradient-based/differentiable planning algorithm in JAX.**
43
+
44
+ Purpose:
35
45
 
36
46
  1. automatic translation of any RDDL description file into a differentiable simulator in JAX
37
47
  2. flexible policy class representations, automatic model relaxations for working in discrete and hybrid domains, and Bayesian hyper-parameter tuning.
@@ -56,17 +66,6 @@ and was moved to the individual logic components which have their own unique wei
56
66
  > [!NOTE]
57
67
  > While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!).
58
68
  > If you find it is not making sufficient progress, check out the [PROST planner](https://github.com/pyrddlgym-project/pyRDDLGym-prost) (for discrete spaces) or the [deep reinforcement learning wrappers](https://github.com/pyrddlgym-project/pyRDDLGym-rl).
59
-
60
- ## Contents
61
-
62
- - [Installation](#installation)
63
- - [Running from the Command Line](#running-from-the-command-line)
64
- - [Running from Another Python Application](#running-from-another-python-application)
65
- - [Configuring the Planner](#configuring-the-planner)
66
- - [JaxPlan Dashboard](#jaxplan-dashboard)
67
- - [Tuning the Planner](#tuning-the-planner)
68
- - [Simulation](#simulation)
69
- - [Citing JaxPlan](#citing-jaxplan)
70
69
 
71
70
  ## Installation
72
71
 
@@ -19,7 +19,7 @@ long_description = (Path(__file__).parent / "README.md").read_text()
19
19
 
20
20
  setup(
21
21
  name='pyRDDLGym-jax',
22
- version='1.0',
22
+ version='1.1',
23
23
  author="Michael Gimelfarb, Ayal Taitler, Scott Sanner",
24
24
  author_email="mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca",
25
25
  description="pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.",
@@ -1 +0,0 @@
1
- __version__ = '1.0'
File without changes
File without changes