pyRDDLGym-jax 0.4.tar.gz → 0.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/PKG-INFO +11 -9
  2. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/README.md +6 -6
  3. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/logic.py +115 -53
  4. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/planner.py +140 -58
  5. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/tuning.py +53 -58
  6. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +2 -1
  7. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +2 -1
  8. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +2 -1
  9. pyrddlgym_jax-0.5/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +21 -0
  10. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/default_replan.cfg +2 -1
  11. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/run_tune.py +1 -3
  12. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax.egg-info/PKG-INFO +11 -9
  13. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax.egg-info/requires.txt +4 -1
  14. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/setup.py +6 -5
  15. pyrddlgym_jax-0.4/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -20
  16. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/LICENSE +0 -0
  17. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/__init__.py +0 -0
  18. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/__init__.py +0 -0
  19. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/compiler.py +0 -0
  20. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/core/simulator.py +0 -0
  21. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/__init__.py +0 -0
  22. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  23. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  24. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  25. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  26. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -0
  27. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -0
  28. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  29. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  30. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -0
  31. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  32. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  33. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  34. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  35. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  36. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  37. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  38. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  39. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  40. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  41. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  42. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  43. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  44. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  45. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/run_plan.py +0 -0
  46. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  47. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
  48. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  49. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  50. {pyrddlgym_jax-0.4 → pyrddlgym_jax-0.5}/setup.cfg +0 -0
pyrddlgym_jax-0.4/PKG-INFO → pyrddlgym_jax-0.5/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: pyRDDLGym-jax
-Version: 0.4
+Version: 0.5
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -13,16 +13,18 @@ Classifier: Natural Language :: English
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Python: >=3.8
+Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: pyRDDLGym>=2.0
 Requires-Dist: tqdm>=4.66
-Requires-Dist: bayesian-optimization>=1.4.3
 Requires-Dist: jax>=0.4.12
 Requires-Dist: optax>=0.1.9
 Requires-Dist: dm-haiku>=0.0.10
 Requires-Dist: tensorflow-probability>=0.21.0
+Provides-Extra: extra
+Requires-Dist: bayesian-optimization>=2.0.0; extra == "extra"
+Requires-Dist: rddlrepository>=2.0; extra == "extra"

 # pyRDDLGym-jax

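The new `Provides-Extra` / `extra == "extra"` markers are what setuptools emits for optional dependencies. Below is a sketch of the kind of `extras_require` declaration that would produce this metadata; the project's actual `setup.py` (also changed in this release, +6 -5) may differ:

```python
# Hypothetical setup.py fragment: optional dependencies declared under
# extras_require appear in PKG-INFO as Provides-Extra/Requires-Dist markers.
from setuptools import setup, find_packages

setup(
    name='pyRDDLGym-jax',
    version='0.5',
    python_requires='>=3.9',
    packages=find_packages(),
    install_requires=[
        'pyRDDLGym>=2.0', 'tqdm>=4.66', 'jax>=0.4.12',
        'optax>=0.1.9', 'dm-haiku>=0.0.10', 'tensorflow-probability>=0.21.0',
    ],
    # installed only when requested: pip install pyRDDLGym-jax[extra]
    extras_require={'extra': ['bayesian-optimization>=2.0.0', 'rddlrepository>=2.0']},
)
```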
@@ -57,17 +59,17 @@ To use the compiler or planner without the automated hyper-parameter tuning, you
 - ``tensorflow-probability>=0.21.0``

 Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
-To run the automated tuning optimization, you will also need ``bayesian-optimization>=1.4.3``.
+To run the automated tuning optimization, you will also need ``bayesian-optimization>=2.0.0``.

-You can install this package, together with all of its requirements, via pip:
+You can install pyRDDLGym-jax with all requirements using pip:

 ```shell
-pip install rddlrepository pyRDDLGym-jax
+pip install pyRDDLGym-jax[extra]
 ```

 ## Running from the Command Line

-A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository``, and can be launched in the command line from the install directory of pyRDDLGym-jax:
+A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:

 ```shell
 python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
@@ -91,8 +93,8 @@ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials>
 ```

 where:
-- ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014)
-- ``instance`` is the instance identifier (i.e. 1, 2, ... 10)
+- ``domain`` is the domain identifier as specified in rddlrepository
+- ``instance`` is the instance identifier
 - ``method`` is the planning method to use (i.e. drp, slp, replan)
 - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
 - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
pyrddlgym_jax-0.4/README.md → pyrddlgym_jax-0.5/README.md
@@ -31,17 +31,17 @@ To use the compiler or planner without the automated hyper-parameter tuning, you
 - ``tensorflow-probability>=0.21.0``

 Additionally, if you wish to run the examples, you need ``rddlrepository>=2``.
-To run the automated tuning optimization, you will also need ``bayesian-optimization>=1.4.3``.
+To run the automated tuning optimization, you will also need ``bayesian-optimization>=2.0.0``.

-You can install this package, together with all of its requirements, via pip:
+You can install pyRDDLGym-jax with all requirements using pip:

 ```shell
-pip install rddlrepository pyRDDLGym-jax
+pip install pyRDDLGym-jax[extra]
 ```

 ## Running from the Command Line

-A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository``, and can be launched in the command line from the install directory of pyRDDLGym-jax:
+A basic run script is provided to run the Jax Planner on any domain in ``rddlrepository`` from the install directory of pyRDDLGym-jax:

 ```shell
 python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
@@ -65,8 +65,8 @@ python -m pyRDDLGym_jax.examples.run_tune <domain> <instance> <method> <trials>
 ```

 where:
-- ``domain`` is the domain identifier as specified in rddlrepository (i.e. Wildfire_MDP_ippc2014)
-- ``instance`` is the instance identifier (i.e. 1, 2, ... 10)
+- ``domain`` is the domain identifier as specified in rddlrepository
+- ``instance`` is the instance identifier
 - ``method`` is the planning method to use (i.e. drp, slp, replan)
 - ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
 - ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
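For context, a concrete tuning invocation might look like the following; the domain and instance are illustrative (the Wildfire example was dropped from the README text, but its replan config ships with this release):

```shell
# hypothetical example: tune the replanning method on a rddlrepository domain
python -m pyRDDLGym_jax.examples.run_tune Wildfire_MDP_ippc2014 1 replan 5 20
```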
pyrddlgym_jax-0.4/pyRDDLGym_jax/core/logic.py → pyrddlgym_jax-0.5/pyRDDLGym_jax/core/logic.py
@@ -7,27 +7,6 @@ import jax.random as random
 from pyRDDLGym.core.debug.exception import raise_warning


-# ===========================================================================
-# LOGICAL COMPLEMENT
-# - abstract class
-# - standard complement
-#
-# ===========================================================================
-
-class Complement:
-    '''Base class for approximate logical complement operations.'''
-
-    def __call__(self, x):
-        raise NotImplementedError
-
-
-class StandardComplement(Complement):
-    '''The standard approximate logical complement given by x -> 1 - x.'''
-
-    def __call__(self, x):
-        return 1.0 - x
-
-
 # ===========================================================================
 # RELATIONAL OPERATIONS
 # - abstract class
@@ -47,10 +26,14 @@ class Comparison:
     def equal(self, x, y, param):
         raise NotImplementedError

+    def sgn(self, x, param):
+        raise NotImplementedError
+

 class SigmoidComparison(Comparison):
     '''Comparison operations approximated using sigmoid functions.'''

+    # https://arxiv.org/abs/2110.05651
     def greater_equal(self, x, y, param):
         return jax.nn.sigmoid(param * (x - y))

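The new `sgn` hook follows the same pattern as the existing comparisons: as the sharpness parameter grows, the smooth surrogate approaches the hard operation. A standalone sketch using only the formulas shown above:

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 2.0, 2.0])
for param in (1.0, 10.0, 1000.0):
    soft_geq = jax.nn.sigmoid(param * (x - y))  # soft x >= y
    soft_sgn = jnp.tanh(param * (x - y))        # soft sign(x - y)
    print(param, soft_geq, soft_sgn)
# at param=1000.0, soft_geq ~ [0, 0.5, 1] and soft_sgn ~ [-1, 0, 1]
```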
@@ -59,8 +42,65 @@ class SigmoidComparison(Comparison):

     def equal(self, x, y, param):
         return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
+
+    def sgn(self, x, param):
+        return jnp.tanh(param * x)

-
+
+# ===========================================================================
+# ROUNDING OPERATIONS
+# - abstract class
+# - soft rounding
+#
+# ===========================================================================
+
+class Rounding:
+    '''Base class for approximate rounding operations.'''
+
+    def floor(self, x, param):
+        raise NotImplementedError
+
+    def round(self, x, param):
+        raise NotImplementedError
+
+
+class SoftRounding(Rounding):
+    '''Rounding operations approximated using soft operations.'''
+
+    # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
+    def floor(self, x, param):
+        denom = jnp.tanh(param / 4.0)
+        return (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
+                jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)
+
+    # https://arxiv.org/abs/2006.09952
+    def round(self, x, param):
+        m = jnp.floor(x) + 0.5
+        return m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)
+
+
+# ===========================================================================
+# LOGICAL COMPLEMENT
+# - abstract class
+# - standard complement
+#
+# ===========================================================================
+
+class Complement:
+    '''Base class for approximate logical complement operations.'''
+
+    def __call__(self, x):
+        raise NotImplementedError
+
+
+class StandardComplement(Complement):
+    '''The standard approximate logical complement given by x -> 1 - x.'''
+
+    # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
+    def __call__(self, x):
+        return 1.0 - x
+
+
 # ===========================================================================
 # TNORMS
 # - abstract tnorm
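A quick check of what the new soft rounding computes, with the two formulas copied from the hunk above; both approach their hard counterparts as `param` grows while remaining differentiable everywhere:

```python
import jax
import jax.numpy as jnp

def soft_floor(x, param):
    # Softfloor bijector formula, as in SoftRounding.floor above
    denom = jnp.tanh(param / 4.0)
    return (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
            jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)

def soft_round(x, param):
    # soft rounding formula, as in SoftRounding.round above
    m = jnp.floor(x) + 0.5
    return m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)

x = jnp.array([2.1, 0.6, -1.2])
print(soft_floor(x, 100.0))                          # ~ [2., 0., -2.]
print(soft_round(x, 100.0))                          # ~ [2., 1., -1.]
print(jax.grad(lambda v: soft_round(v, 10.0))(0.6))  # nonzero gradient
```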
@@ -69,6 +109,7 @@ class SigmoidComparison(Comparison):
 # - Lukasiewicz tnorm
 # - Yager(p) tnorm
 #
+# https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
 # ===========================================================================

 class TNorm:
@@ -118,17 +159,17 @@ class YagerTNorm(TNorm):
     (x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''

     def __init__(self, p=2.0):
-        self.p = p
+        self.p = float(p)

-    def norm(self, x, y):
-        base_x = jax.nn.relu(1.0 - x)
-        base_y = jax.nn.relu(1.0 - y)
-        arg = jnp.power(base_x ** self.p + base_y ** self.p, 1.0 / self.p)
+    def norm(self, x, y):
+        base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
+        arg = jnp.linalg.norm(base, ord=self.p, axis=0)
         return jax.nn.relu(1.0 - arg)

     def norms(self, x, axis):
-        base = jax.nn.relu(1.0 - x)
-        arg = jnp.power(jnp.sum(base ** self.p, axis=axis), 1.0 / self.p)
+        arg = jax.nn.relu(1.0 - x)
+        for ax in sorted(axis, reverse=True):
+            arg = jnp.linalg.norm(arg, ord=self.p, axis=ax)
         return jax.nn.relu(1.0 - arg)

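The rewritten `norm` is algebraically equivalent to the 0.4 version: for the nonnegative entries produced by `relu`, `jnp.linalg.norm(v, ord=p, axis=0)` is exactly `(sum_i v_i^p)^(1/p)`. A quick equivalence check, assuming p >= 1:

```python
import jax
import jax.numpy as jnp

p = 2.0
x, y = jnp.array([0.3, 0.9]), jnp.array([0.8, 0.7])

# 0.4-style formulation
base_x, base_y = jax.nn.relu(1.0 - x), jax.nn.relu(1.0 - y)
old = jax.nn.relu(1.0 - jnp.power(base_x ** p + base_y ** p, 1.0 / p))

# 0.5-style formulation
base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
new = jax.nn.relu(1.0 - jnp.linalg.norm(base, ord=p, axis=0))

print(jnp.allclose(old, new))  # True
```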
@@ -185,10 +226,11 @@ class GumbelSoftmax(RandomSampling):
     def discrete(self, logic):
         if logic.verbose:
             raise_warning('Using the replacement rule: '
-                          'Discrete(p) --> Gumbel-softmax(p)')
+                          'Discrete(p) --> Gumbel-Softmax(p)')

         jax_argmax, jax_param = logic.argmax()

+        # https://arxiv.org/pdf/1611.01144
         def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
             Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
             sample = Gumbel01 + jnp.log(prob + logic.eps)
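The cited trick (Jang et al., 2017, the arXiv link added above): perturbing log-probabilities with Gumbel(0, 1) noise makes the hard argmax an exact categorical sample, and a softmax over the same perturbed logits is its differentiable relaxation. A self-contained sketch:

```python
import jax
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(42)
prob = jnp.array([0.1, 0.2, 0.7])
eps = 1e-15  # underflow guard, mirroring logic.eps

g = random.gumbel(key=key, shape=prob.shape)    # Gumbel(0, 1) noise
perturbed = g + jnp.log(prob + eps)

hard_sample = jnp.argmax(perturbed)             # exact draw from Discrete(prob)
soft_sample = jax.nn.softmax(10.0 * perturbed)  # differentiable relaxation
print(hard_sample, soft_sample)
```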
@@ -249,6 +291,7 @@ class FuzzyLogic:
                  complement: Complement=StandardComplement(),
                  comparison: Comparison=SigmoidComparison(),
                  sampling: RandomSampling=GumbelSoftmax(),
+                 rounding: Rounding=SoftRounding(),
                  weight: float=10.0,
                  debias: Optional[Set[str]]=None,
                  eps: float=1e-15,
@@ -260,6 +303,7 @@ class FuzzyLogic:
         :param complement: fuzzy operator for logical NOT
         :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
         :param sampling: random sampling of non-reparameterizable distributions
+        :param rounding: rounding floating values to integers
         :param weight: a sharpness parameter for sigmoid and softmax activations
         :param debias: which functions to de-bias approximate on forward pass
         :param eps: small positive float to mitigate underflow
@@ -270,6 +314,7 @@ class FuzzyLogic:
         self.complement = complement
         self.comparison = comparison
         self.sampling = sampling
+        self.rounding = rounding
         self.weight = float(weight)
         if debias is None:
             debias = set()
@@ -295,6 +340,7 @@ class FuzzyLogic:
                f'    complement    ={type(self.complement).__name__}\n'
                f'    comparison    ={type(self.comparison).__name__}\n'
                f'    sampling      ={type(self.sampling).__name__}\n'
+               f'    rounding      ={type(self.rounding).__name__}\n'
                f'    sigmoid_weight={self.weight}\n'
                f'    cpfs_to_debias={self.debias}\n'
                f'    underflow_tol ={self.eps}\n'
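Based on the constructor signature in this hunk, wiring a custom rounding operator in 0.5 might look as follows; this is a sketch, and the imports assume the module layout shown in this diff:

```python
from pyRDDLGym_jax.core.logic import (
    FuzzyLogic, SigmoidComparison, SoftRounding)

logic = FuzzyLogic(
    comparison=SigmoidComparison(),  # soft >, >=, ==, and now sgn
    rounding=SoftRounding(),         # new in 0.5: soft floor/round
    weight=10.0,                     # sharpness of all soft operations
    debias={'floor', 'round'},       # exact hard values on the forward pass
)
jax_round, round_param = logic.round()  # (traced function, parameter spec)
```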
@@ -492,12 +538,14 @@ class FuzzyLogic:

     def sgn(self):
         if self.verbose:
-            raise_warning('Using the replacement rule: sgn(x) --> tanh(x)')
-
+            raise_warning('Using the replacement rule: '
+                          'sgn(x) --> comparison.sgn(x)')
+
+        sgn_op = self.comparison.sgn
         debias = 'sgn' in self.debias

         def _jax_wrapped_calc_sgn_approx(x, param):
-            sample = jnp.tanh(param * x)
+            sample = sgn_op(x, param)
             if debias:
                 hard_sample = jnp.sign(x)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
@@ -510,37 +558,48 @@ class FuzzyLogic:
     def floor(self):
         if self.verbose:
             raise_warning('Using the replacement rule: '
-                          'floor(x) --> x - atan(-1.0 / tan(pi * x)) / pi - 0.5')
+                          'floor(x) --> rounding.floor(x)')
+
+        floor_op = self.rounding.floor
+        debias = 'floor' in self.debias

         def _jax_wrapped_calc_floor_approx(x, param):
-            sawtooth_part = jnp.arctan(-1.0 / jnp.tan(x * jnp.pi)) / jnp.pi + 0.5
-            sample = x - jax.lax.stop_gradient(sawtooth_part)
+            sample = floor_op(x, param)
+            if debias:
+                hard_sample = jnp.floor(x)
+                sample += jax.lax.stop_gradient(hard_sample - sample)
             return sample

-        return _jax_wrapped_calc_floor_approx, None
-
-    def ceil(self):
-        jax_floor, jax_param = self.floor()
-
-        def _jax_wrapped_calc_ceil_approx(x, param):
-            return -jax_floor(-x, param)
+        tags = ('weight', 'floor')
+        new_param = (tags, self.weight)
+        return _jax_wrapped_calc_floor_approx, new_param

-        return _jax_wrapped_calc_ceil_approx, jax_param
-
     def round(self):
         if self.verbose:
-            raise_warning('Using the replacement rule: round(x) --> x')
+            raise_warning('Using the replacement rule: '
+                          'round(x) --> rounding.round(x)')

+        round_op = self.rounding.round
         debias = 'round' in self.debias

         def _jax_wrapped_calc_round_approx(x, param):
-            sample = x
+            sample = round_op(x, param)
             if debias:
                 hard_sample = jnp.round(x)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
             return sample

-        return _jax_wrapped_calc_round_approx, None
+        tags = ('weight', 'round')
+        new_param = (tags, self.weight)
+        return _jax_wrapped_calc_round_approx, new_param
+
+    def ceil(self):
+        jax_floor, jax_param = self.floor()
+
+        def _jax_wrapped_calc_ceil_approx(x, param):
+            return -jax_floor(-x, param)
+
+        return _jax_wrapped_calc_ceil_approx, jax_param

     def mod(self):
         jax_floor, jax_param = self.floor()
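The `debias` branches added to `floor` and `round` use the straight-through pattern: the forward pass returns the exact hard value, while gradients flow through the soft surrogate, because `stop_gradient` hides the correction term from autodiff. In isolation:

```python
import jax
import jax.numpy as jnp

def st_floor(x):
    soft = x - 0.5  # stand-in differentiable surrogate, for illustration only
    hard = jnp.floor(x)
    # forward value equals hard; gradient equals the surrogate's gradient
    return soft + jax.lax.stop_gradient(hard - soft)

print(st_floor(2.7))            # 2.0 (hard forward value)
print(jax.grad(st_floor)(2.7))  # 1.0 (gradient of the surrogate)
```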
@@ -586,6 +645,7 @@ class FuzzyLogic:

         debias = 'argmax' in self.debias

+        # https://arxiv.org/abs/2110.05651
         def _jax_wrapped_calc_argmax_approx(x, axis, param):
             literals = FuzzyLogic.enumerate_literals(x.shape, axis=axis)
             soft_max = jax.nn.softmax(param * x, axis=axis)
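For reference, the soft argmax referenced by the new comment is the softmax-weighted average of the index literals; a minimal one-dimensional sketch:

```python
import jax
import jax.numpy as jnp

def soft_argmax(x, param):
    literals = jnp.arange(x.shape[0])  # index literals 0..n-1
    return jnp.sum(literals * jax.nn.softmax(param * x))

x = jnp.array([0.1, 2.0, 0.3])
print(soft_argmax(x, 100.0))  # ~1.0, the index of the maximum
```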
@@ -631,14 +691,14 @@ class FuzzyLogic:
         if self.verbose:
             raise_warning('Using the replacement rule: '
                           'switch(pred) { cases } --> '
-                          'sum(cases[i] * softmax(-abs(pred - i)))')
+                          'sum(cases[i] * softmax(-(pred - i)^2))')

         debias = 'switch' in self.debias

         def _jax_wrapped_calc_switch_approx(pred, cases, param):
             literals = FuzzyLogic.enumerate_literals(cases.shape, axis=0)
             pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
-            proximity = -jnp.abs(pred - literals)
+            proximity = -jnp.square(pred - literals)
             soft_case = jax.nn.softmax(param * proximity, axis=0)
             sample = jnp.sum(cases * soft_case, axis=0)
             if debias:
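The switch relaxation now weights each case by a softmax over negative squared distance between the predicate and the case index (a Gaussian-like kernel in place of the former absolute-distance kernel). A minimal sketch:

```python
import jax
import jax.numpy as jnp

def soft_switch(pred, cases, param):
    literals = jnp.arange(cases.shape[0], dtype=cases.dtype)
    proximity = -jnp.square(pred - literals)  # squared-distance kernel (new in 0.5)
    return jnp.sum(cases * jax.nn.softmax(param * proximity))

cases = jnp.array([10.0, 20.0, 30.0])
print(soft_switch(1.0, cases, 100.0))  # ~20.0: case index 1 dominates
```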
@@ -674,7 +734,7 @@ class FuzzyLogic:
 # ===========================================================================

 logic = FuzzyLogic()
-w = 100.0
+w = 1000.0


 def _test_logical():
@@ -701,7 +761,7 @@ def _test_logical():
 def _test_indexing():
     print('testing indexing')
     _argmax, _ = logic.argmax()
-    _argmin, _ = logic.argmax()
+    _argmin, _ = logic.argmin()

     def argmaxmin(x):
         amax = _argmax(x, 0, w)
@@ -764,11 +824,13 @@ def _test_rounding():
     print('testing rounding')
     _floor, _ = logic.floor()
     _ceil, _ = logic.ceil()
+    _round, _ = logic.round()
     _mod, _ = logic.mod()

-    x = jnp.asarray([2.1, 0.5001, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
+    x = jnp.asarray([2.1, 0.6, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
     print(_floor(x, w))
     print(_ceil(x, w))
+    print(_round(x, w))
     print(_mod(x, 2.0, w))
