pyRDDLGym-jax 0.4__tar.gz → 0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/PKG-INFO +11 -9
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/README.md +6 -6
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/logic.py +115 -53
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/planner.py +140 -58
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/tuning.py +53 -58
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +2 -1
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +2 -1
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +2 -1
- pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +21 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/default_replan.cfg +2 -1
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/run_tune.py +1 -3
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax.egg-info/PKG-INFO +11 -9
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax.egg-info/requires.txt +4 -1
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/setup.py +6 -5
- pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -20
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/LICENSE +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/__init__.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/__init__.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/compiler.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/simulator.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/__init__.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/run_gym.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/run_plan.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
- {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: pyRDDLGym-jax
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.5
|
|
4
4
|
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
|
|
5
5
|
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
|
|
6
6
|
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
|
|
@@ -13,16 +13,18 @@ Classifier: Natural Language :: English
|
|
|
13
13
|
Classifier: Operating System :: OS Independent
|
|
14
14
|
Classifier: Programming Language :: Python :: 3
|
|
15
15
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
|
-
Requires-Python: >=3.
|
|
16
|
+
Requires-Python: >=3.9
|
|
17
17
|
Description-Content-Type: text/markdown
|
|
18
18
|
License-File: LICENSE
|
|
19
19
|
Requires-Dist: pyRDDLGym>=2.0
|
|
20
20
|
Requires-Dist: tqdm>=4.66
|
|
21
|
-
Requires-Dist: bayesian-optimization>=1.4.3
|
|
22
21
|
Requires-Dist: jax>=0.4.12
|
|
23
22
|
Requires-Dist: optax>=0.1.9
|
|
24
23
|
Requires-Dist: dm-haiku>=0.0.10
|
|
25
24
|
Requires-Dist: tensorflow-probability>=0.21.0
|
|
25
|
+
Provides-Extra: extra
|
|
26
|
+
Requires-Dist: bayesian-optimization>=2.0.0; extra == "extra"
|
|
27
|
+
Requires-Dist: rddlrepository>=2.0; extra == "extra"
|
|
26
28
|
|
|
27
29
|
# pyRDDLGym-jax
|
|
28
30
|
|
|
@@ -57,17 +59,17 @@ To use the compiler or planner without the automated hyper-parameter tuning, you
|
|
|
57
59
|
- ``tensorflow-probability>=0.21.0``
|
|
58
60
|
|
|
59
61
|
Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
|
|
60
|
-
To run the automated tuning optimization, you will also need ``bayesian-optimization>=
|
|
62
|
+
To run the automated tuning optimization, you will also need ``bayesian-optimization>=2.0.0``.
|
|
61
63
|
|
|
62
|
-
You can install
|
|
64
|
+
You can install pyRDDLGym-jax with all requirements using pip:
|
|
63
65
|
|
|
64
66
|
```shell
|
|
65
|
-
pip install
|
|
67
|
+
pip install pyRDDLGym-jax[extra]
|
|
66
68
|
```
|
|
67
69
|
|
|
68
70
|
## Running from the Command Line
|
|
69
71
|
|
|
70
|
-
A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository
|
|
72
|
+
A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
|
|
71
73
|
|
|
72
74
|
```shell
|
|
73
75
|
python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
|
|
@@ -91,8 +93,8 @@ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials>
|
|
|
91
93
|
```
|
|
92
94
|
|
|
93
95
|
where:
|
|
94
|
-
- ``domain`` is the domain identifier as specified in rddlrepository
|
|
95
|
-
- ``instance`` is the instance identifier
|
|
96
|
+
- ``domain`` is the domain identifier as specified in rddlrepository
|
|
97
|
+
- ``instance`` is the instance identifier
|
|
96
98
|
- ``method`` is the planning method to use (i.e. drp, slp, replan)
|
|
97
99
|
- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
|
|
98
100
|
- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
|
|
@@ -31,17 +31,17 @@ To use the compiler or planner without the automated hyper-parameter tuning, you
|
|
|
31
31
|
- ``tensorflow-probability>=0.21.0``
|
|
32
32
|
|
|
33
33
|
Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
|
|
34
|
-
To run the automated tuning optimization, you will also need ``bayesian-optimization>=
|
|
34
|
+
To run the automated tuning optimization, you will also need ``bayesian-optimization>=2.0.0``.
|
|
35
35
|
|
|
36
|
-
You can install
|
|
36
|
+
You can install pyRDDLGym-jax with all requirements using pip:
|
|
37
37
|
|
|
38
38
|
```shell
|
|
39
|
-
pip install
|
|
39
|
+
pip install pyRDDLGym-jax[extra]
|
|
40
40
|
```
|
|
41
41
|
|
|
42
42
|
## Running from the Command Line
|
|
43
43
|
|
|
44
|
-
A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository
|
|
44
|
+
A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:
|
|
45
45
|
|
|
46
46
|
```shell
|
|
47
47
|
python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
|
|
@@ -65,8 +65,8 @@ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials>
|
|
|
65
65
|
```
|
|
66
66
|
|
|
67
67
|
where:
|
|
68
|
-
- ``domain`` is the domain identifier as specified in rddlrepository
|
|
69
|
-
- ``instance`` is the instance identifier
|
|
68
|
+
- ``domain`` is the domain identifier as specified in rddlrepository
|
|
69
|
+
- ``instance`` is the instance identifier
|
|
70
70
|
- ``method`` is the planning method to use (i.e. drp, slp, replan)
|
|
71
71
|
- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
|
|
72
72
|
- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
|
|
@@ -7,27 +7,6 @@ import jax.random as random
|
|
|
7
7
|
from pyRDDLGym.core.debug.exception import raise_warning
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
# ===========================================================================
|
|
11
|
-
# LOGICAL COMPLEMENT
|
|
12
|
-
# - abstract class
|
|
13
|
-
# - standard complement
|
|
14
|
-
#
|
|
15
|
-
# ===========================================================================
|
|
16
|
-
|
|
17
|
-
class Complement:
|
|
18
|
-
'''Base class for approximate logical complement operations.'''
|
|
19
|
-
|
|
20
|
-
def __call__(self, x):
|
|
21
|
-
raise NotImplementedError
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class StandardComplement(Complement):
|
|
25
|
-
'''The standard approximate logical complement given by x -> 1 - x.'''
|
|
26
|
-
|
|
27
|
-
def __call__(self, x):
|
|
28
|
-
return 1.0 - x
|
|
29
|
-
|
|
30
|
-
|
|
31
10
|
# ===========================================================================
|
|
32
11
|
# RELATIONAL OPERATIONS
|
|
33
12
|
# - abstract class
|
|
@@ -47,10 +26,14 @@ class Comparison:
|
|
|
47
26
|
def equal(self, x, y, param):
|
|
48
27
|
raise NotImplementedError
|
|
49
28
|
|
|
29
|
+
def sgn(self, x, param):
|
|
30
|
+
raise NotImplementedError
|
|
31
|
+
|
|
50
32
|
|
|
51
33
|
class SigmoidComparison(Comparison):
|
|
52
34
|
'''Comparison operations approximated using sigmoid functions.'''
|
|
53
35
|
|
|
36
|
+
# https://arxiv.org/abs/2110.05651
|
|
54
37
|
def greater_equal(self, x, y, param):
|
|
55
38
|
return jax.nn.sigmoid(param * (x - y))
|
|
56
39
|
|
|
@@ -59,8 +42,65 @@ class SigmoidComparison(Comparison):
|
|
|
59
42
|
|
|
60
43
|
def equal(self, x, y, param):
|
|
61
44
|
return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
|
|
45
|
+
|
|
46
|
+
def sgn(self, x, param):
|
|
47
|
+
return jnp.tanh(param * x)
|
|
62
48
|
|
|
63
|
-
|
|
49
|
+
|
|
50
|
+
# ===========================================================================
|
|
51
|
+
# ROUNDING OPERATIONS
|
|
52
|
+
# - abstract class
|
|
53
|
+
# - soft rounding
|
|
54
|
+
#
|
|
55
|
+
# ===========================================================================
|
|
56
|
+
|
|
57
|
+
class Rounding:
|
|
58
|
+
'''Base class for approximate rounding operations.'''
|
|
59
|
+
|
|
60
|
+
def floor(self, x, param):
|
|
61
|
+
raise NotImplementedError
|
|
62
|
+
|
|
63
|
+
def round(self, x, param):
|
|
64
|
+
raise NotImplementedError
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class SoftRounding(Rounding):
|
|
68
|
+
'''Rounding operations approximated using soft operations.'''
|
|
69
|
+
|
|
70
|
+
# https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
|
|
71
|
+
def floor(self, x, param):
|
|
72
|
+
denom = jnp.tanh(param / 4.0)
|
|
73
|
+
return (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
|
|
74
|
+
jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)
|
|
75
|
+
|
|
76
|
+
# https://arxiv.org/abs/2006.09952
|
|
77
|
+
def round(self, x, param):
|
|
78
|
+
m = jnp.floor(x) + 0.5
|
|
79
|
+
return m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# ===========================================================================
|
|
83
|
+
# LOGICAL COMPLEMENT
|
|
84
|
+
# - abstract class
|
|
85
|
+
# - standard complement
|
|
86
|
+
#
|
|
87
|
+
# ===========================================================================
|
|
88
|
+
|
|
89
|
+
class Complement:
|
|
90
|
+
'''Base class for approximate logical complement operations.'''
|
|
91
|
+
|
|
92
|
+
def __call__(self, x):
|
|
93
|
+
raise NotImplementedError
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class StandardComplement(Complement):
|
|
97
|
+
'''The standard approximate logical complement given by x -> 1 - x.'''
|
|
98
|
+
|
|
99
|
+
# https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
|
|
100
|
+
def __call__(self, x):
|
|
101
|
+
return 1.0 - x
|
|
102
|
+
|
|
103
|
+
|
|
64
104
|
# ===========================================================================
|
|
65
105
|
# TNORMS
|
|
66
106
|
# - abstract tnorm
|
|
@@ -69,6 +109,7 @@ class SigmoidComparison(Comparison):
|
|
|
69
109
|
# - Lukasiewicz tnorm
|
|
70
110
|
# - Yager(p) tnorm
|
|
71
111
|
#
|
|
112
|
+
# https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
|
|
72
113
|
# ===========================================================================
|
|
73
114
|
|
|
74
115
|
class TNorm:
|
|
@@ -118,17 +159,17 @@ class YagerTNorm(TNorm):
|
|
|
118
159
|
(x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
|
|
119
160
|
|
|
120
161
|
def __init__(self, p=2.0):
|
|
121
|
-
self.p = p
|
|
162
|
+
self.p = float(p)
|
|
122
163
|
|
|
123
|
-
def norm(self, x, y):
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
arg = jnp.power(base_x ** self.p + base_y ** self.p, 1.0 / self.p)
|
|
164
|
+
def norm(self, x, y):
|
|
165
|
+
base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
|
|
166
|
+
arg = jnp.linalg.norm(base, ord=self.p, axis=0)
|
|
127
167
|
return jax.nn.relu(1.0 - arg)
|
|
128
168
|
|
|
129
169
|
def norms(self, x, axis):
|
|
130
|
-
|
|
131
|
-
|
|
170
|
+
arg = jax.nn.relu(1.0 - x)
|
|
171
|
+
for ax in sorted(axis, reverse=True):
|
|
172
|
+
arg = jnp.linalg.norm(arg, ord=self.p, axis=ax)
|
|
132
173
|
return jax.nn.relu(1.0 - arg)
|
|
133
174
|
|
|
134
175
|
|
|
@@ -185,10 +226,11 @@ class GumbelSoftmax(RandomSampling):
|
|
|
185
226
|
def discrete(self, logic):
|
|
186
227
|
if logic.verbose:
|
|
187
228
|
raise_warning('Using the replacement rule: '
|
|
188
|
-
'Discrete(p) --> Gumbel-
|
|
229
|
+
'Discrete(p) --> Gumbel-Softmax(p)')
|
|
189
230
|
|
|
190
231
|
jax_argmax, jax_param = logic.argmax()
|
|
191
232
|
|
|
233
|
+
# https://arxiv.org/pdf/1611.01144
|
|
192
234
|
def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
|
|
193
235
|
Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
|
|
194
236
|
sample = Gumbel01 + jnp.log(prob + logic.eps)
|
|
@@ -249,6 +291,7 @@ class FuzzyLogic:
|
|
|
249
291
|
complement: Complement=StandardComplement(),
|
|
250
292
|
comparison: Comparison=SigmoidComparison(),
|
|
251
293
|
sampling: RandomSampling=GumbelSoftmax(),
|
|
294
|
+
rounding: Rounding=SoftRounding(),
|
|
252
295
|
weight: float=10.0,
|
|
253
296
|
debias: Optional[Set[str]]=None,
|
|
254
297
|
eps: float=1e-15,
|
|
@@ -260,6 +303,7 @@ class FuzzyLogic:
|
|
|
260
303
|
:param complement: fuzzy operator for logical NOT
|
|
261
304
|
:param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
|
|
262
305
|
:param sampling: random sampling of non-reparameterizable distributions
|
|
306
|
+
:param rounding: rounding floating values to integers
|
|
263
307
|
:param weight: a sharpness parameter for sigmoid and softmax activations
|
|
264
308
|
:param debias: which functions to de-bias approximate on forward pass
|
|
265
309
|
:param eps: small positive float to mitigate underflow
|
|
@@ -270,6 +314,7 @@ class FuzzyLogic:
|
|
|
270
314
|
self.complement = complement
|
|
271
315
|
self.comparison = comparison
|
|
272
316
|
self.sampling = sampling
|
|
317
|
+
self.rounding = rounding
|
|
273
318
|
self.weight = float(weight)
|
|
274
319
|
if debias is None:
|
|
275
320
|
debias = set()
|
|
@@ -295,6 +340,7 @@ class FuzzyLogic:
|
|
|
295
340
|
f' complement ={type(self.complement).__name__}\n'
|
|
296
341
|
f' comparison ={type(self.comparison).__name__}\n'
|
|
297
342
|
f' sampling ={type(self.sampling).__name__}\n'
|
|
343
|
+
f' rounding ={type(self.rounding).__name__}\n'
|
|
298
344
|
f' sigmoid_weight={self.weight}\n'
|
|
299
345
|
f' cpfs_to_debias={self.debias}\n'
|
|
300
346
|
f' underflow_tol ={self.eps}\n'
|
|
@@ -492,12 +538,14 @@ class FuzzyLogic:
|
|
|
492
538
|
|
|
493
539
|
def sgn(self):
|
|
494
540
|
if self.verbose:
|
|
495
|
-
raise_warning('Using the replacement rule:
|
|
496
|
-
|
|
541
|
+
raise_warning('Using the replacement rule: '
|
|
542
|
+
'sgn(x) --> comparison.sgn(x)')
|
|
543
|
+
|
|
544
|
+
sgn_op = self.comparison.sgn
|
|
497
545
|
debias = 'sgn' in self.debias
|
|
498
546
|
|
|
499
547
|
def _jax_wrapped_calc_sgn_approx(x, param):
|
|
500
|
-
sample =
|
|
548
|
+
sample = sgn_op(x, param)
|
|
501
549
|
if debias:
|
|
502
550
|
hard_sample = jnp.sign(x)
|
|
503
551
|
sample += jax.lax.stop_gradient(hard_sample - sample)
|
|
@@ -510,37 +558,48 @@ class FuzzyLogic:
|
|
|
510
558
|
def floor(self):
|
|
511
559
|
if self.verbose:
|
|
512
560
|
raise_warning('Using the replacement rule: '
|
|
513
|
-
'floor(x) -->
|
|
561
|
+
'floor(x) --> rounding.floor(x)')
|
|
562
|
+
|
|
563
|
+
floor_op = self.rounding.floor
|
|
564
|
+
debias = 'floor' in self.debias
|
|
514
565
|
|
|
515
566
|
def _jax_wrapped_calc_floor_approx(x, param):
|
|
516
|
-
|
|
517
|
-
|
|
567
|
+
sample = floor_op(x, param)
|
|
568
|
+
if debias:
|
|
569
|
+
hard_sample = jnp.floor(x)
|
|
570
|
+
sample += jax.lax.stop_gradient(hard_sample - sample)
|
|
518
571
|
return sample
|
|
519
572
|
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
jax_floor, jax_param = self.floor()
|
|
524
|
-
|
|
525
|
-
def _jax_wrapped_calc_ceil_approx(x, param):
|
|
526
|
-
return -jax_floor(-x, param)
|
|
573
|
+
tags = ('weight', 'floor')
|
|
574
|
+
new_param = (tags, self.weight)
|
|
575
|
+
return _jax_wrapped_calc_floor_approx, new_param
|
|
527
576
|
|
|
528
|
-
return _jax_wrapped_calc_ceil_approx, jax_param
|
|
529
|
-
|
|
530
577
|
def round(self):
|
|
531
578
|
if self.verbose:
|
|
532
|
-
raise_warning('Using the replacement rule:
|
|
579
|
+
raise_warning('Using the replacement rule: '
|
|
580
|
+
'round(x) --> rounding.round(x)')
|
|
533
581
|
|
|
582
|
+
round_op = self.rounding.round
|
|
534
583
|
debias = 'round' in self.debias
|
|
535
584
|
|
|
536
585
|
def _jax_wrapped_calc_round_approx(x, param):
|
|
537
|
-
sample = x
|
|
586
|
+
sample = round_op(x, param)
|
|
538
587
|
if debias:
|
|
539
588
|
hard_sample = jnp.round(x)
|
|
540
589
|
sample += jax.lax.stop_gradient(hard_sample - sample)
|
|
541
590
|
return sample
|
|
542
591
|
|
|
543
|
-
|
|
592
|
+
tags = ('weight', 'round')
|
|
593
|
+
new_param = (tags, self.weight)
|
|
594
|
+
return _jax_wrapped_calc_round_approx, new_param
|
|
595
|
+
|
|
596
|
+
def ceil(self):
|
|
597
|
+
jax_floor, jax_param = self.floor()
|
|
598
|
+
|
|
599
|
+
def _jax_wrapped_calc_ceil_approx(x, param):
|
|
600
|
+
return -jax_floor(-x, param)
|
|
601
|
+
|
|
602
|
+
return _jax_wrapped_calc_ceil_approx, jax_param
|
|
544
603
|
|
|
545
604
|
def mod(self):
|
|
546
605
|
jax_floor, jax_param = self.floor()
|
|
@@ -586,6 +645,7 @@ class FuzzyLogic:
|
|
|
586
645
|
|
|
587
646
|
debias = 'argmax' in self.debias
|
|
588
647
|
|
|
648
|
+
# https://arxiv.org/abs/2110.05651
|
|
589
649
|
def _jax_wrapped_calc_argmax_approx(x, axis, param):
|
|
590
650
|
literals = FuzzyLogic.enumerate_literals(x.shape, axis=axis)
|
|
591
651
|
soft_max = jax.nn.softmax(param * x, axis=axis)
|
|
@@ -631,14 +691,14 @@ class FuzzyLogic:
|
|
|
631
691
|
if self.verbose:
|
|
632
692
|
raise_warning('Using the replacement rule: '
|
|
633
693
|
'switch(pred) { cases } --> '
|
|
634
|
-
'sum(cases[i] * softmax(-
|
|
694
|
+
'sum(cases[i] * softmax(-(pred - i)^2))')
|
|
635
695
|
|
|
636
696
|
debias = 'switch' in self.debias
|
|
637
697
|
|
|
638
698
|
def _jax_wrapped_calc_switch_approx(pred, cases, param):
|
|
639
699
|
literals = FuzzyLogic.enumerate_literals(cases.shape, axis=0)
|
|
640
700
|
pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
|
|
641
|
-
proximity = -jnp.
|
|
701
|
+
proximity = -jnp.square(pred - literals)
|
|
642
702
|
soft_case = jax.nn.softmax(param * proximity, axis=0)
|
|
643
703
|
sample = jnp.sum(cases * soft_case, axis=0)
|
|
644
704
|
if debias:
|
|
@@ -674,7 +734,7 @@ class FuzzyLogic:
|
|
|
674
734
|
# ===========================================================================
|
|
675
735
|
|
|
676
736
|
logic = FuzzyLogic()
|
|
677
|
-
w =
|
|
737
|
+
w = 1000.0
|
|
678
738
|
|
|
679
739
|
|
|
680
740
|
def _test_logical():
|
|
@@ -701,7 +761,7 @@ def _test_logical():
|
|
|
701
761
|
def _test_indexing():
|
|
702
762
|
print('testing indexing')
|
|
703
763
|
_argmax, _ = logic.argmax()
|
|
704
|
-
_argmin, _ = logic.
|
|
764
|
+
_argmin, _ = logic.argmin()
|
|
705
765
|
|
|
706
766
|
def argmaxmin(x):
|
|
707
767
|
amax = _argmax(x, 0, w)
|
|
@@ -764,11 +824,13 @@ def _test_rounding():
|
|
|
764
824
|
print('testing rounding')
|
|
765
825
|
_floor, _ = logic.floor()
|
|
766
826
|
_ceil, _ = logic.ceil()
|
|
827
|
+
_round, _ = logic.round()
|
|
767
828
|
_mod, _ = logic.mod()
|
|
768
829
|
|
|
769
|
-
x = jnp.asarray([2.1, 0.
|
|
830
|
+
x = jnp.asarray([2.1, 0.6, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
|
|
770
831
|
print(_floor(x, w))
|
|
771
832
|
print(_ceil(x, w))
|
|
833
|
+
print(_round(x, w))
|
|
772
834
|
print(_mod(x, 2.0, w))
|
|
773
835
|
|
|
774
836
|
|