pyRDDLGym-jax 0.5__py3-none-any.whl → 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +784 -544
  4. pyRDDLGym_jax/core/planner.py +329 -463
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +379 -568
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -27
  35. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
  36. pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,13 @@ import jax
4
4
  import jax.numpy as jnp
5
5
  import jax.random as random
6
6
 
7
- from pyRDDLGym.core.debug.exception import raise_warning
7
+
8
+ def enumerate_literals(shape, axis, dtype=jnp.int32):
9
+ literals = jnp.arange(shape[axis], dtype=dtype)
10
+ literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
11
+ literals = jnp.moveaxis(literals, source=0, destination=axis)
12
+ literals = jnp.broadcast_to(literals, shape=shape)
13
+ return literals
8
14
 
9
15
 
10
16
  # ===========================================================================
@@ -17,36 +23,71 @@ from pyRDDLGym.core.debug.exception import raise_warning
17
23
  class Comparison:
18
24
  '''Base class for approximate comparison operations.'''
19
25
 
20
- def greater_equal(self, x, y, param):
26
+ def greater_equal(self, id, init_params):
27
+ raise NotImplementedError
28
+
29
+ def greater(self, id, init_params):
21
30
  raise NotImplementedError
22
31
 
23
- def greater(self, x, y, param):
32
+ def equal(self, id, init_params):
24
33
  raise NotImplementedError
25
34
 
26
- def equal(self, x, y, param):
35
+ def sgn(self, id, init_params):
27
36
  raise NotImplementedError
28
37
 
29
- def sgn(self, x, param):
38
+ def argmax(self, id, init_params):
30
39
  raise NotImplementedError
31
40
 
32
41
 
33
42
  class SigmoidComparison(Comparison):
34
43
  '''Comparison operations approximated using sigmoid functions.'''
35
44
 
45
+ def __init__(self, weight: float=10.0):
46
+ self.weight = weight
47
+
36
48
  # https://arxiv.org/abs/2110.05651
37
- def greater_equal(self, x, y, param):
38
- return jax.nn.sigmoid(param * (x - y))
39
-
40
- def greater(self, x, y, param):
41
- return jax.nn.sigmoid(param * (x - y))
42
-
43
- def equal(self, x, y, param):
44
- return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
49
+ def greater_equal(self, id, init_params):
50
+ id_ = str(id)
51
+ init_params[id_] = self.weight
52
+ def _jax_wrapped_calc_greater_equal_approx(x, y, params):
53
+ gre_eq = jax.nn.sigmoid(params[id_] * (x - y))
54
+ return gre_eq, params
55
+ return _jax_wrapped_calc_greater_equal_approx
56
+
57
+ def greater(self, id, init_params):
58
+ return self.greater_equal(id, init_params)
59
+
60
+ def equal(self, id, init_params):
61
+ id_ = str(id)
62
+ init_params[id_] = self.weight
63
+ def _jax_wrapped_calc_equal_approx(x, y, params):
64
+ equal = 1.0 - jnp.square(jnp.tanh(params[id_] * (y - x)))
65
+ return equal, params
66
+ return _jax_wrapped_calc_equal_approx
67
+
68
+ def sgn(self, id, init_params):
69
+ id_ = str(id)
70
+ init_params[id_] = self.weight
71
+ def _jax_wrapped_calc_sgn_approx(x, params):
72
+ sgn = jnp.tanh(params[id_] * x)
73
+ return sgn, params
74
+ return _jax_wrapped_calc_sgn_approx
45
75
 
46
- def sgn(self, x, param):
47
- return jnp.tanh(param * x)
48
-
49
-
76
+ # https://arxiv.org/abs/2110.05651
77
+ def argmax(self, id, init_params):
78
+ id_ = str(id)
79
+ init_params[id_] = self.weight
80
+ def _jax_wrapped_calc_argmax_approx(x, axis, params):
81
+ literals = enumerate_literals(x.shape, axis=axis)
82
+ softmax = jax.nn.softmax(params[id_] * x, axis=axis)
83
+ sample = jnp.sum(literals * softmax, axis=axis)
84
+ return sample, params
85
+ return _jax_wrapped_calc_argmax_approx
86
+
87
+ def __str__(self) -> str:
88
+ return f'Sigmoid comparison with weight {self.weight}'
89
+
90
+
50
91
  # ===========================================================================
51
92
  # ROUNDING OPERATIONS
52
93
  # - abstract class
@@ -57,26 +98,44 @@ class SigmoidComparison(Comparison):
57
98
  class Rounding:
58
99
  '''Base class for approximate rounding operations.'''
59
100
 
60
- def floor(self, x, param):
101
+ def floor(self, id, init_params):
61
102
  raise NotImplementedError
62
103
 
63
- def round(self, x, param):
104
+ def round(self, id, init_params):
64
105
  raise NotImplementedError
65
106
 
66
107
 
67
108
  class SoftRounding(Rounding):
68
109
  '''Rounding operations approximated using soft operations.'''
69
110
 
111
+ def __init__(self, weight: float=10.0):
112
+ self.weight = weight
113
+
70
114
  # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
71
- def floor(self, x, param):
72
- denom = jnp.tanh(param / 4.0)
73
- return (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
74
- jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)
115
+ def floor(self, id, init_params):
116
+ id_ = str(id)
117
+ init_params[id_] = self.weight
118
+ def _jax_wrapped_calc_floor_approx(x, params):
119
+ param = params[id_]
120
+ denom = jnp.tanh(param / 4.0)
121
+ floor = (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
122
+ jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)
123
+ return floor, params
124
+ return _jax_wrapped_calc_floor_approx
75
125
 
76
126
  # https://arxiv.org/abs/2006.09952
77
- def round(self, x, param):
78
- m = jnp.floor(x) + 0.5
79
- return m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)
127
+ def round(self, id, init_params):
128
+ id_ = str(id)
129
+ init_params[id_] = self.weight
130
+ def _jax_wrapped_calc_round_approx(x, params):
131
+ param = params[id_]
132
+ m = jnp.floor(x) + 0.5
133
+ rounded = m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)
134
+ return rounded, params
135
+ return _jax_wrapped_calc_round_approx
136
+
137
+ def __str__(self) -> str:
138
+ return f'SoftFloor and SoftRound with weight {self.weight}'
80
139
 
81
140
 
82
141
  # ===========================================================================
@@ -89,7 +148,7 @@ class SoftRounding(Rounding):
89
148
  class Complement:
90
149
  '''Base class for approximate logical complement operations.'''
91
150
 
92
- def __call__(self, x):
151
+ def __call__(self, id, init_params):
93
152
  raise NotImplementedError
94
153
 
95
154
 
@@ -97,9 +156,16 @@ class StandardComplement(Complement):
97
156
  '''The standard approximate logical complement given by x -> 1 - x.'''
98
157
 
99
158
  # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
100
- def __call__(self, x):
101
- return 1.0 - x
102
-
159
+ @staticmethod
160
+ def _jax_wrapped_calc_not_approx(x, params):
161
+ return 1.0 - x, params
162
+
163
+ def __call__(self, id, init_params):
164
+ return self._jax_wrapped_calc_not_approx
165
+
166
+ def __str__(self) -> str:
167
+ return 'Standard complement'
168
+
103
169
 
104
170
  # ===========================================================================
105
171
  # TNORMS
@@ -115,43 +181,78 @@ class StandardComplement(Complement):
115
181
  class TNorm:
116
182
  '''Base class for fuzzy differentiable t-norms.'''
117
183
 
118
- def norm(self, x, y):
184
+ def norm(self, id, init_params):
119
185
  '''Elementwise t-norm of x and y.'''
120
186
  raise NotImplementedError
121
187
 
122
- def norms(self, x, axis):
188
+ def norms(self, id, init_params):
123
189
  '''T-norm computed for tensor x along axis.'''
124
190
  raise NotImplementedError
125
-
191
+
126
192
 
127
193
  class ProductTNorm(TNorm):
128
194
  '''Product t-norm given by the expression (x, y) -> x * y.'''
129
195
 
130
- def norm(self, x, y):
131
- return x * y
196
+ @staticmethod
197
+ def _jax_wrapped_calc_and_approx(x, y, params):
198
+ return x * y, params
199
+
200
+ def norm(self, id, init_params):
201
+ return self._jax_wrapped_calc_and_approx
132
202
 
133
- def norms(self, x, axis):
134
- return jnp.prod(x, axis=axis)
203
+ @staticmethod
204
+ def _jax_wrapped_calc_forall_approx(x, axis, params):
205
+ return jnp.prod(x, axis=axis), params
206
+
207
+ def norms(self, id, init_params):
208
+ return self._jax_wrapped_calc_forall_approx
135
209
 
210
+ def __str__(self) -> str:
211
+ return 'Product t-norm'
212
+
136
213
 
137
214
  class GodelTNorm(TNorm):
138
215
  '''Godel t-norm given by the expression (x, y) -> min(x, y).'''
139
216
 
140
- def norm(self, x, y):
141
- return jnp.minimum(x, y)
217
+ @staticmethod
218
+ def _jax_wrapped_calc_and_approx(x, y, params):
219
+ return jnp.minimum(x, y), params
220
+
221
+ def norm(self, id, init_params):
222
+ return self._jax_wrapped_calc_and_approx
223
+
224
+ @staticmethod
225
+ def _jax_wrapped_calc_forall_approx(x, axis, params):
226
+ return jnp.min(x, axis=axis), params
227
+
228
+ def norms(self, id, init_params):
229
+ return self._jax_wrapped_calc_forall_approx
230
+
231
+ def __str__(self) -> str:
232
+ return 'Godel t-norm'
142
233
 
143
- def norms(self, x, axis):
144
- return jnp.min(x, axis=axis)
145
-
146
234
 
147
235
  class LukasiewiczTNorm(TNorm):
148
236
  '''Lukasiewicz t-norm given by the expression (x, y) -> max(x + y - 1, 0).'''
149
237
 
150
- def norm(self, x, y):
151
- return jax.nn.relu(x + y - 1.0)
238
+ @staticmethod
239
+ def _jax_wrapped_calc_and_approx(x, y, params):
240
+ land = jax.nn.relu(x + y - 1.0)
241
+ return land, params
242
+
243
+ def norm(self, id, init_params):
244
+ return self._jax_wrapped_calc_and_approx
152
245
 
153
- def norms(self, x, axis):
154
- return jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
246
+ @staticmethod
247
+ def _jax_wrapped_calc_forall_approx(x, axis, params):
248
+ forall = jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
249
+ return forall, params
250
+
251
+ def norms(self, id, init_params):
252
+ return self._jax_wrapped_calc_forall_approx
253
+
254
+ def __str__(self) -> str:
255
+ return 'Lukasiewicz t-norm'
155
256
 
156
257
 
157
258
  class YagerTNorm(TNorm):
@@ -161,17 +262,30 @@ class YagerTNorm(TNorm):
161
262
  def __init__(self, p=2.0):
162
263
  self.p = float(p)
163
264
 
164
- def norm(self, x, y):
165
- base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
166
- arg = jnp.linalg.norm(base, ord=self.p, axis=0)
167
- return jax.nn.relu(1.0 - arg)
265
+ def norm(self, id, init_params):
266
+ id_ = str(id)
267
+ init_params[id_] = self.p
268
+ def _jax_wrapped_calc_and_approx(x, y, params):
269
+ base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
270
+ arg = jnp.linalg.norm(base, ord=params[id_], axis=0)
271
+ land = jax.nn.relu(1.0 - arg)
272
+ return land, params
273
+ return _jax_wrapped_calc_and_approx
274
+
275
+ def norms(self, id, init_params):
276
+ id_ = str(id)
277
+ init_params[id_] = self.p
278
+ def _jax_wrapped_calc_forall_approx(x, axis, params):
279
+ arg = jax.nn.relu(1.0 - x)
280
+ for ax in sorted(axis, reverse=True):
281
+ arg = jnp.linalg.norm(arg, ord=params[id_], axis=ax)
282
+ forall = jax.nn.relu(1.0 - arg)
283
+ return forall, params
284
+ return _jax_wrapped_calc_forall_approx
285
+
286
+ def __str__(self) -> str:
287
+ return f'Yager({self.p}) t-norm'
168
288
 
169
- def norms(self, x, axis):
170
- arg = jax.nn.relu(1.0 - x)
171
- for ax in sorted(axis, reverse=True):
172
- arg = jnp.linalg.norm(arg, ord=self.p, axis=ax)
173
- return jax.nn.relu(1.0 - arg)
174
-
175
289
 
176
290
  # ===========================================================================
177
291
  # RANDOM SAMPLING
@@ -185,117 +299,443 @@ class RandomSampling:
185
299
  '''An abstract class that describes how discrete and non-reparameterizable
186
300
  random variables are sampled.'''
187
301
 
188
- def discrete(self, logic):
302
+ def discrete(self, id, init_params, logic):
189
303
  raise NotImplementedError
190
304
 
191
- def bernoulli(self, logic):
192
- jax_discrete, jax_param = self.discrete(logic)
193
-
194
- def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
305
+ def bernoulli(self, id, init_params, logic):
306
+ discrete_approx = self.discrete(id, init_params, logic)
307
+ def _jax_wrapped_calc_bernoulli_approx(key, prob, params):
195
308
  prob = jnp.stack([1.0 - prob, prob], axis=-1)
196
- sample = jax_discrete(key, prob, param)
197
- return sample
198
-
199
- return _jax_wrapped_calc_bernoulli_approx, jax_param
309
+ return discrete_approx(key, prob, params)
310
+ return _jax_wrapped_calc_bernoulli_approx
200
311
 
201
- def poisson(self, logic):
202
-
203
- def _jax_wrapped_calc_poisson_exact(key, rate, param):
204
- return random.poisson(key=key, lam=rate, dtype=logic.INT)
312
+ @staticmethod
313
+ def _jax_wrapped_calc_poisson_exact(key, rate, params):
314
+ sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
315
+ return sample, params
205
316
 
206
- return _jax_wrapped_calc_poisson_exact, None
207
-
208
- def geometric(self, logic):
209
- if logic.verbose:
210
- raise_warning('Using the replacement rule: '
211
- 'Geometric(p) --> floor(log(U) / log(1 - p)) + 1')
212
-
213
- jax_floor, jax_param = logic.floor()
214
-
215
- def _jax_wrapped_calc_geometric_approx(key, prob, param):
317
+ def poisson(self, id, init_params, logic):
318
+ return self._jax_wrapped_calc_poisson_exact
319
+
320
+ def geometric(self, id, init_params, logic):
321
+ approx_floor = logic.floor(id, init_params)
322
+ def _jax_wrapped_calc_geometric_approx(key, prob, params):
216
323
  U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
217
- sample = jax_floor(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
218
- return sample
219
-
220
- return _jax_wrapped_calc_geometric_approx, jax_param
324
+ floor, params = approx_floor(jnp.log(U) / jnp.log(1.0 - prob), params)
325
+ sample = floor + 1
326
+ return sample, params
327
+ return _jax_wrapped_calc_geometric_approx
221
328
 
222
329
 
223
330
  class GumbelSoftmax(RandomSampling):
224
331
  '''Random sampling of discrete variables using Gumbel-softmax trick.'''
225
332
 
226
- def discrete(self, logic):
227
- if logic.verbose:
228
- raise_warning('Using the replacement rule: '
229
- 'Discrete(p) --> Gumbel-Softmax(p)')
230
-
231
- jax_argmax, jax_param = logic.argmax()
232
-
233
- # https://arxiv.org/pdf/1611.01144
234
- def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
333
+ # https://arxiv.org/pdf/1611.01144
334
+ def discrete(self, id, init_params, logic):
335
+ argmax_approx = logic.argmax(id, init_params)
336
+ def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, params):
235
337
  Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
236
338
  sample = Gumbel01 + jnp.log(prob + logic.eps)
237
- sample = jax_argmax(sample, axis=-1, param=param)
238
- return sample
239
-
240
- return _jax_wrapped_calc_discrete_gumbel_softmax, jax_param
339
+ return argmax_approx(sample, axis=-1, params=params)
340
+ return _jax_wrapped_calc_discrete_gumbel_softmax
341
+
342
+ def __str__(self) -> str:
343
+ return 'Gumbel-Softmax'
241
344
 
242
345
 
243
346
  class Determinization(RandomSampling):
244
347
  '''Random sampling of variables using their deterministic mean estimate.'''
245
348
 
246
- def discrete(self, logic):
247
- if logic.verbose:
248
- raise_warning('Using the replacement rule: '
249
- 'Discrete(p) --> sum(i * p[i])')
250
-
251
- def _jax_wrapped_calc_discrete_determinized(key, prob, param):
252
- literals = FuzzyLogic.enumerate_literals(prob.shape, axis=-1)
253
- sample = jnp.sum(literals * prob, axis=-1)
254
- return sample
255
-
256
- return _jax_wrapped_calc_discrete_determinized, None
257
-
258
- def poisson(self, logic):
259
- if logic.verbose:
260
- raise_warning('Using the replacement rule: Poisson(rate) --> rate')
261
-
262
- def _jax_wrapped_calc_poisson_determinized(key, rate, param):
263
- return rate
264
-
265
- return _jax_wrapped_calc_poisson_determinized, None
266
-
267
- def geometric(self, logic):
268
- if logic.verbose:
269
- raise_warning('Using the replacement rule: Geometric(p) --> 1 / p')
270
-
271
- def _jax_wrapped_calc_geometric_determinized(key, prob, param):
272
- sample = 1.0 / prob
273
- return sample
349
+ @staticmethod
350
+ def _jax_wrapped_calc_discrete_determinized(key, prob, params):
351
+ literals = enumerate_literals(prob.shape, axis=-1)
352
+ sample = jnp.sum(literals * prob, axis=-1)
353
+ return sample, params
274
354
 
275
- return _jax_wrapped_calc_geometric_determinized, None
355
+ def discrete(self, id, init_params, logic):
356
+ return self._jax_wrapped_calc_discrete_determinized
357
+
358
+ @staticmethod
359
+ def _jax_wrapped_calc_poisson_determinized(key, rate, params):
360
+ return rate, params
276
361
 
362
+ def poisson(self, id, init_params, logic):
363
+ return self._jax_wrapped_calc_poisson_determinized
364
+
365
+ @staticmethod
366
+ def _jax_wrapped_calc_geometric_determinized(key, prob, params):
367
+ sample = 1.0 / prob
368
+ return sample, params
369
+
370
+ def geometric(self, id, init_params, logic):
371
+ return self._jax_wrapped_calc_geometric_determinized
372
+
373
+ def __str__(self) -> str:
374
+ return 'Deterministic'
375
+
277
376
 
278
377
  # ===========================================================================
279
- # FUZZY LOGIC
378
+ # CONTROL FLOW
379
+ # - soft flow
280
380
  #
281
381
  # ===========================================================================
282
382
 
283
- class FuzzyLogic:
284
- '''A class representing fuzzy logic in JAX.
383
+ class ControlFlow:
384
+ '''A base class for control flow, including if and switch statements.'''
385
+
386
+ def if_then_else(self, id, init_params):
387
+ raise NotImplementedError
285
388
 
286
- Functionality can be customized by either providing a tnorm as parameters,
287
- or by overriding its methods.
288
- '''
389
+ def switch(self, id, init_params):
390
+ raise NotImplementedError
391
+
392
+
393
+ class SoftControlFlow(ControlFlow):
394
+ '''Soft control flow using a probabilistic interpretation.'''
395
+
396
+ def __init__(self, weight: float=10.0) -> None:
397
+ self.weight = weight
398
+
399
+ @staticmethod
400
+ def _jax_wrapped_calc_if_then_else_soft(c, a, b, params):
401
+ sample = c * a + (1.0 - c) * b
402
+ return sample, params
403
+
404
+ def if_then_else(self, id, init_params):
405
+ return self._jax_wrapped_calc_if_then_else_soft
406
+
407
+ def switch(self, id, init_params):
408
+ id_ = str(id)
409
+ init_params[id_] = self.weight
410
+ def _jax_wrapped_calc_switch_soft(pred, cases, params):
411
+ literals = enumerate_literals(cases.shape, axis=0)
412
+ pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
413
+ proximity = -jnp.square(pred - literals)
414
+ softcase = jax.nn.softmax(params[id_] * proximity, axis=0)
415
+ sample = jnp.sum(cases * softcase, axis=0)
416
+ return sample, params
417
+ return _jax_wrapped_calc_switch_soft
418
+
419
+ def __str__(self) -> str:
420
+ return f'Soft control flow with weight {self.weight}'
421
+
422
+
423
+ # ===========================================================================
424
+ # LOGIC
425
+ # - exact logic
426
+ # - fuzzy logic
427
+ #
428
+ # ===========================================================================
429
+
430
+
431
+ class Logic:
432
+ '''A base class for representing logic computations in JAX.'''
433
+
434
+ def __init__(self, use64bit: bool=False) -> None:
435
+ self.set_use64bit(use64bit)
436
+
437
+ def summarize_hyperparameters(self) -> None:
438
+ print(f'model relaxation:\n'
439
+ f' use_64_bit ={self.use64bit}')
440
+
441
+ def set_use64bit(self, use64bit: bool) -> None:
442
+ '''Toggles whether or not the JAX system will use 64 bit precision.'''
443
+ self.use64bit = use64bit
444
+ if use64bit:
445
+ self.REAL = jnp.float64
446
+ self.INT = jnp.int64
447
+ jax.config.update('jax_enable_x64', True)
448
+ else:
449
+ self.REAL = jnp.float32
450
+ self.INT = jnp.int32
451
+ jax.config.update('jax_enable_x64', False)
452
+
453
+ # ===========================================================================
454
+ # logical operators
455
+ # ===========================================================================
456
+
457
+ def logical_and(self, id, init_params):
458
+ raise NotImplementedError
459
+
460
+ def logical_not(self, id, init_params):
461
+ raise NotImplementedError
462
+
463
+ def logical_or(self, id, init_params):
464
+ raise NotImplementedError
465
+
466
+ def xor(self, id, init_params):
467
+ raise NotImplementedError
468
+
469
+ def implies(self, id, init_params):
470
+ raise NotImplementedError
471
+
472
+ def equiv(self, id, init_params):
473
+ raise NotImplementedError
474
+
475
+ def forall(self, id, init_params):
476
+ raise NotImplementedError
477
+
478
+ def exists(self, id, init_params):
479
+ raise NotImplementedError
480
+
481
+ # ===========================================================================
482
+ # comparison operators
483
+ # ===========================================================================
484
+
485
+ def greater_equal(self, id, init_params):
486
+ raise NotImplementedError
487
+
488
+ def greater(self, id, init_params):
489
+ raise NotImplementedError
490
+
491
+ def less_equal(self, id, init_params):
492
+ raise NotImplementedError
493
+
494
+ def less(self, id, init_params):
495
+ raise NotImplementedError
496
+
497
+ def equal(self, id, init_params):
498
+ raise NotImplementedError
499
+
500
+ def not_equal(self, id, init_params):
501
+ raise NotImplementedError
502
+
503
+ # ===========================================================================
504
+ # special functions
505
+ # ===========================================================================
506
+
507
+ def sgn(self, id, init_params):
508
+ raise NotImplementedError
509
+
510
+ def floor(self, id, init_params):
511
+ raise NotImplementedError
512
+
513
+ def round(self, id, init_params):
514
+ raise NotImplementedError
515
+
516
+ def ceil(self, id, init_params):
517
+ raise NotImplementedError
518
+
519
+ def div(self, id, init_params):
520
+ raise NotImplementedError
521
+
522
+ def mod(self, id, init_params):
523
+ raise NotImplementedError
524
+
525
+ def sqrt(self, id, init_params):
526
+ raise NotImplementedError
527
+
528
+ # ===========================================================================
529
+ # indexing
530
+ # ===========================================================================
531
+
532
+ def argmax(self, id, init_params):
533
+ raise NotImplementedError
534
+
535
+ def argmin(self, id, init_params):
536
+ raise NotImplementedError
537
+
538
+ # ===========================================================================
539
+ # control flow
540
+ # ===========================================================================
541
+
542
+ def control_if(self, id, init_params):
543
+ raise NotImplementedError
544
+
545
+ def control_switch(self, id, init_params):
546
+ raise NotImplementedError
547
+
548
+ # ===========================================================================
549
+ # random variables
550
+ # ===========================================================================
551
+
552
+ def discrete(self, id, init_params):
553
+ raise NotImplementedError
554
+
555
+ def bernoulli(self, id, init_params):
556
+ raise NotImplementedError
557
+
558
+ def poisson(self, id, init_params):
559
+ raise NotImplementedError
560
+
561
+ def geometric(self, id, init_params):
562
+ raise NotImplementedError
563
+
564
+
565
+ class ExactLogic(Logic):
566
+ '''A class representing exact logic in JAX.'''
567
+
568
+ @staticmethod
569
+ def exact_unary_function(op):
570
+ def _jax_wrapped_calc_unary_function_exact(x, params):
571
+ return op(x), params
572
+ return _jax_wrapped_calc_unary_function_exact
573
+
574
+ @staticmethod
575
+ def exact_binary_function(op):
576
+ def _jax_wrapped_calc_binary_function_exact(x, y, params):
577
+ return op(x, y), params
578
+ return _jax_wrapped_calc_binary_function_exact
579
+
580
+ @staticmethod
581
+ def exact_aggregation(op):
582
+ def _jax_wrapped_calc_aggregation_exact(x, axis, params):
583
+ return op(x, axis=axis), params
584
+ return _jax_wrapped_calc_aggregation_exact
585
+
586
+ # ===========================================================================
587
+ # logical operators
588
+ # ===========================================================================
589
+
590
+ def logical_and(self, id, init_params):
591
+ return self.exact_binary_function(jnp.logical_and)
592
+
593
+ def logical_not(self, id, init_params):
594
+ return self.exact_unary_function(jnp.logical_not)
595
+
596
+ def logical_or(self, id, init_params):
597
+ return self.exact_binary_function(jnp.logical_or)
598
+
599
+ def xor(self, id, init_params):
600
+ return self.exact_binary_function(jnp.logical_xor)
601
+
602
+ @staticmethod
603
+ def exact_binary_implies(x, y, params):
604
+ return jnp.logical_or(jnp.logical_not(x), y), params
605
+
606
+ def implies(self, id, init_params):
607
+ return self.exact_binary_implies
608
+
609
+ def equiv(self, id, init_params):
610
+ return self.exact_binary_function(jnp.equal)
611
+
612
+ def forall(self, id, init_params):
613
+ return self.exact_aggregation(jnp.all)
614
+
615
+ def exists(self, id, init_params):
616
+ return self.exact_aggregation(jnp.any)
617
+
618
+ # ===========================================================================
619
+ # comparison operators
620
+ # ===========================================================================
621
+
622
+ def greater_equal(self, id, init_params):
623
+ return self.exact_binary_function(jnp.greater_equal)
624
+
625
+ def greater(self, id, init_params):
626
+ return self.exact_binary_function(jnp.greater)
627
+
628
+ def less_equal(self, id, init_params):
629
+ return self.exact_binary_function(jnp.less_equal)
630
+
631
+ def less(self, id, init_params):
632
+ return self.exact_binary_function(jnp.less)
633
+
634
+ def equal(self, id, init_params):
635
+ return self.exact_binary_function(jnp.equal)
636
+
637
+ def not_equal(self, id, init_params):
638
+ return self.exact_binary_function(jnp.not_equal)
639
+
640
+ # ===========================================================================
641
+ # special functions
642
+ # ===========================================================================
643
+
644
+ def sgn(self, id, init_params):
645
+ return self.exact_unary_function(jnp.sign)
646
+
647
+ def floor(self, id, init_params):
648
+ return self.exact_unary_function(jnp.floor)
649
+
650
+ def round(self, id, init_params):
651
+ return self.exact_unary_function(jnp.round)
652
+
653
+ def ceil(self, id, init_params):
654
+ return self.exact_unary_function(jnp.ceil)
655
+
656
+ def div(self, id, init_params):
657
+ return self.exact_binary_function(jnp.floor_divide)
658
+
659
+ def mod(self, id, init_params):
660
+ return self.exact_binary_function(jnp.mod)
661
+
662
+ def sqrt(self, id, init_params):
663
+ return self.exact_unary_function(jnp.sqrt)
664
+
665
+ # ===========================================================================
666
+ # indexing
667
+ # ===========================================================================
668
+
669
+ def argmax(self, id, init_params):
670
+ return self.exact_aggregation(jnp.argmax)
671
+
672
+ def argmin(self, id, init_params):
673
+ return self.exact_aggregation(jnp.argmin)
674
+
675
+ # ===========================================================================
676
+ # control flow
677
+ # ===========================================================================
678
+
679
+ @staticmethod
680
+ def exact_if_then_else(c, a, b, params):
681
+ return jnp.where(c > 0.5, a, b), params
682
+
683
+ def control_if(self, id, init_params):
684
+ return self.exact_if_then_else
685
+
686
+ @staticmethod
687
+ def exact_switch(pred, cases, params):
688
+ pred = pred[jnp.newaxis, ...]
689
+ sample = jnp.take_along_axis(cases, pred, axis=0)
690
+ assert sample.shape[0] == 1
691
+ return sample[0, ...], params
692
+
693
+ def control_switch(self, id, init_params):
694
+ return self.exact_switch
695
+
696
+ # ===========================================================================
697
+ # random variables
698
+ # ===========================================================================
699
+
700
+ @staticmethod
701
+ def exact_discrete(key, prob, params):
702
+ return random.categorical(key=key, logits=jnp.log(prob), axis=-1), params
703
+
704
+ def discrete(self, id, init_params):
705
+ return self.exact_discrete
706
+
707
+ @staticmethod
708
+ def exact_bernoulli(key, prob, params):
709
+ return random.bernoulli(key, prob), params
710
+
711
+ def bernoulli(self, id, init_params):
712
+ return self.exact_bernoulli
713
+
714
+ @staticmethod
715
+ def exact_poisson(key, rate, params):
716
+ return random.poisson(key=key, lam=rate), params
717
+
718
+ def poisson(self, id, init_params):
719
+ return self.exact_poisson
720
+
721
+ @staticmethod
722
+ def exact_geometric(key, prob, params):
723
+ return random.geometric(key=key, p=prob), params
724
+
725
+ def geometric(self, id, init_params):
726
+ return self.exact_geometric
727
+
728
+
729
+ class FuzzyLogic(Logic):
730
+ '''A class representing fuzzy logic in JAX.'''
289
731
 
290
732
  def __init__(self, tnorm: TNorm=ProductTNorm(),
291
733
  complement: Complement=StandardComplement(),
292
734
  comparison: Comparison=SigmoidComparison(),
293
735
  sampling: RandomSampling=GumbelSoftmax(),
294
736
  rounding: Rounding=SoftRounding(),
295
- weight: float=10.0,
296
- debias: Optional[Set[str]]=None,
737
+ control: ControlFlow=SoftControlFlow(),
297
738
  eps: float=1e-15,
298
- verbose: bool=False,
299
739
  use64bit: bool=False) -> None:
300
740
  '''Creates a new fuzzy logic in Jax.
301
741
 
@@ -304,428 +744,212 @@ class FuzzyLogic:
304
744
  :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
305
745
  :param sampling: random sampling of non-reparameterizable distributions
306
746
  :param rounding: rounding floating values to integers
307
- :param weight: a sharpness parameter for sigmoid and softmax activations
308
- :param debias: which functions to de-bias approximate on forward pass
747
+ :param control: if and switch control structures
309
748
  :param eps: small positive float to mitigate underflow
310
- :param verbose: whether to dump replacements and other info to console
311
749
  :param use64bit: whether to perform arithmetic in 64 bit
312
750
  '''
751
+ super().__init__(use64bit=use64bit)
313
752
  self.tnorm = tnorm
314
753
  self.complement = complement
315
754
  self.comparison = comparison
316
755
  self.sampling = sampling
317
756
  self.rounding = rounding
318
- self.weight = float(weight)
319
- if debias is None:
320
- debias = set()
321
- self.debias = debias
757
+ self.control = control
322
758
  self.eps = eps
323
- self.verbose = verbose
324
- self.set_use64bit(use64bit)
325
759
 
326
- def set_use64bit(self, use64bit: bool) -> None:
327
- self.use64bit = use64bit
328
- if use64bit:
329
- self.REAL = jnp.float64
330
- self.INT = jnp.int64
331
- jax.config.update('jax_enable_x64', True)
332
- else:
333
- self.REAL = jnp.float32
334
- self.INT = jnp.int32
335
- jax.config.update('jax_enable_x64', False)
336
-
760
+ def __str__(self) -> str:
761
+ return (f'model relaxation:\n'
762
+ f' tnorm ={str(self.tnorm)}\n'
763
+ f' complement ={str(self.complement)}\n'
764
+ f' comparison ={str(self.comparison)}\n'
765
+ f' sampling ={str(self.sampling)}\n'
766
+ f' rounding ={str(self.rounding)}\n'
767
+ f' control ={str(self.control)}\n'
768
+ f' underflow_tol ={self.eps}\n'
769
+ f' use_64_bit ={self.use64bit}')
770
+
337
771
  def summarize_hyperparameters(self) -> None:
338
- print(f'model relaxation:\n'
339
- f' tnorm ={type(self.tnorm).__name__}\n'
340
- f' complement ={type(self.complement).__name__}\n'
341
- f' comparison ={type(self.comparison).__name__}\n'
342
- f' sampling ={type(self.sampling).__name__}\n'
343
- f' rounding ={type(self.rounding).__name__}\n'
344
- f' sigmoid_weight={self.weight}\n'
345
- f' cpfs_to_debias={self.debias}\n'
346
- f' underflow_tol ={self.eps}\n'
347
- f' use_64_bit ={self.use64bit}')
772
+ print(self.__str__())
348
773
 
349
774
  # ===========================================================================
350
775
  # logical operators
351
776
  # ===========================================================================
352
777
 
353
- def logical_and(self):
354
- if self.verbose:
355
- raise_warning('Using the replacement rule: a ^ b --> tnorm(a, b).')
356
-
357
- _and = self.tnorm.norm
358
-
359
- def _jax_wrapped_calc_and_approx(a, b, param):
360
- return _and(a, b)
361
-
362
- return _jax_wrapped_calc_and_approx, None
363
-
364
- def logical_not(self):
365
- if self.verbose:
366
- raise_warning('Using the replacement rule: ~a --> complement(a)')
367
-
368
- _not = self.complement
369
-
370
- def _jax_wrapped_calc_not_approx(x, param):
371
- return _not(x)
372
-
373
- return _jax_wrapped_calc_not_approx, None
374
-
375
- def logical_or(self):
376
- if self.verbose:
377
- raise_warning('Using the replacement rule: a | b --> tconorm(a, b).')
378
-
379
- _not = self.complement
380
- _and = self.tnorm.norm
381
-
382
- def _jax_wrapped_calc_or_approx(a, b, param):
383
- return _not(_and(_not(a), _not(b)))
384
-
385
- return _jax_wrapped_calc_or_approx, None
778
+ def logical_and(self, id, init_params):
779
+ return self.tnorm.norm(id, init_params)
780
+
781
+ def logical_not(self, id, init_params):
782
+ return self.complement(id, init_params)
783
+
784
+ def logical_or(self, id, init_params):
785
+ _not1 = self.complement(f'{id}_~1', init_params)
786
+ _not2 = self.complement(f'{id}_~2', init_params)
787
+ _and = self.tnorm.norm(f'{id}_^', init_params)
788
+ _not = self.complement(f'{id}_~', init_params)
789
+
790
+ def _jax_wrapped_calc_or_approx(x, y, params):
791
+ not_x, params = _not1(x, params)
792
+ not_y, params = _not2(y, params)
793
+ not_x_and_not_y, params = _and(not_x, not_y, params)
794
+ return _not(not_x_and_not_y, params)
795
+ return _jax_wrapped_calc_or_approx
386
796
 
387
- def xor(self):
388
- if self.verbose:
389
- raise_warning('Using the replacement rule: '
390
- 'a ~ b --> (a | b) ^ (a ^ b).')
391
-
392
- _not = self.complement
393
- _and = self.tnorm.norm
394
-
395
- def _jax_wrapped_calc_xor_approx(a, b, param):
396
- _or = _not(_and(_not(a), _not(b)))
397
- return _and(_or(a, b), _not(_and(a, b)))
398
-
399
- return _jax_wrapped_calc_xor_approx, None
400
-
401
- def implies(self):
402
- if self.verbose:
403
- raise_warning('Using the replacement rule: a => b --> ~a ^ b')
404
-
405
- _not = self.complement
406
- _and = self.tnorm.norm
407
-
408
- def _jax_wrapped_calc_implies_approx(a, b, param):
409
- return _not(_and(a, _not(b)))
410
-
411
- return _jax_wrapped_calc_implies_approx, None
412
-
413
- def equiv(self):
414
- if self.verbose:
415
- raise_warning('Using the replacement rule: '
416
- 'a <=> b --> (a => b) ^ (b => a)')
417
-
418
- _not = self.complement
419
- _and = self.tnorm.norm
420
-
421
- def _jax_wrapped_calc_equiv_approx(a, b, param):
422
- atob = _not(_and(a, _not(b)))
423
- btoa = _not(_and(b, _not(a)))
424
- return _and(atob, btoa)
425
-
426
- return _jax_wrapped_calc_equiv_approx, None
427
-
428
- def forall(self):
429
- if self.verbose:
430
- raise_warning('Using the replacement rule: '
431
- 'forall(a) --> a[1] ^ a[2] ^ ...')
432
-
433
- _forall = self.tnorm.norms
434
-
435
- def _jax_wrapped_calc_forall_approx(x, axis, param):
436
- return _forall(x, axis=axis)
437
-
438
- return _jax_wrapped_calc_forall_approx, None
439
-
440
- def exists(self):
441
- _not = self.complement
442
- jax_forall, jax_param = self.forall()
443
-
444
- def _jax_wrapped_calc_exists_approx(x, axis, param):
445
- return _not(jax_forall(_not(x), axis, param))
446
-
447
- return _jax_wrapped_calc_exists_approx, jax_param
797
+ def xor(self, id, init_params):
798
+ _not = self.complement(f'{id}_~', init_params)
799
+ _and1 = self.tnorm.norm(f'{id}_^1', init_params)
800
+ _and2 = self.tnorm.norm(f'{id}_^2', init_params)
801
+ _or = self.logical_or(f'{id}_|', init_params)
802
+
803
+ def _jax_wrapped_calc_xor_approx(x, y, params):
804
+ x_and_y, params = _and1(x, y, params)
805
+ not_x_and_y, params = _not(x_and_y, params)
806
+ x_or_y, params = _or(x, y, params)
807
+ return _and2(x_or_y, not_x_and_y, params)
808
+ return _jax_wrapped_calc_xor_approx
809
+
810
+ def implies(self, id, init_params):
811
+ _not = self.complement(f'{id}_~', init_params)
812
+ _or = self.logical_or(f'{id}_|', init_params)
813
+
814
+ def _jax_wrapped_calc_implies_approx(x, y, params):
815
+ not_x, params = _not(x, params)
816
+ return _or(not_x, y, params)
817
+ return _jax_wrapped_calc_implies_approx
818
+
819
+ def equiv(self, id, init_params):
820
+ _implies1 = self.implies(f'{id}_=>1', init_params)
821
+ _implies2 = self.implies(f'{id}_=>2', init_params)
822
+ _and = self.tnorm.norm(f'{id}_^', init_params)
823
+
824
+ def _jax_wrapped_calc_equiv_approx(x, y, params):
825
+ x_implies_y, params = _implies1(x, y, params)
826
+ y_implies_x, params = _implies2(y, x, params)
827
+ return _and(x_implies_y, y_implies_x, params)
828
+ return _jax_wrapped_calc_equiv_approx
829
+
830
+ def forall(self, id, init_params):
831
+ return self.tnorm.norms(id, init_params)
832
+
833
+ def exists(self, id, init_params):
834
+ _not1 = self.complement(f'{id}_~1', init_params)
835
+ _not2 = self.complement(f'{id}_~2', init_params)
836
+ _forall = self.forall(f'{id}_forall', init_params)
837
+
838
+ def _jax_wrapped_calc_exists_approx(x, axis, params):
839
+ not_x, params = _not1(x, params)
840
+ forall_not_x, params = _forall(not_x, axis, params)
841
+ return _not2(forall_not_x, params)
842
+ return _jax_wrapped_calc_exists_approx
448
843
 
449
844
  # ===========================================================================
450
845
  # comparison operators
451
846
  # ===========================================================================
452
-
453
- def greater_equal(self):
454
- if self.verbose:
455
- raise_warning('Using the replacement rule: '
456
- 'a >= b --> comparison.greater_equal(a, b)')
457
-
458
- greater_equal_op = self.comparison.greater_equal
459
- debias = 'greater_equal' in self.debias
460
-
461
- def _jax_wrapped_calc_geq_approx(a, b, param):
462
- sample = greater_equal_op(a, b, param)
463
- if debias:
464
- hard_sample = jnp.greater_equal(a, b)
465
- sample += jax.lax.stop_gradient(hard_sample - sample)
466
- return sample
467
-
468
- tags = ('weight', 'greater_equal')
469
- new_param = (tags, self.weight)
470
- return _jax_wrapped_calc_geq_approx, new_param
471
-
472
- def greater(self):
473
- if self.verbose:
474
- raise_warning('Using the replacement rule: '
475
- 'a > b --> comparison.greater(a, b)')
476
-
477
- greater_op = self.comparison.greater
478
- debias = 'greater' in self.debias
479
-
480
- def _jax_wrapped_calc_gre_approx(a, b, param):
481
- sample = greater_op(a, b, param)
482
- if debias:
483
- hard_sample = jnp.greater(a, b)
484
- sample += jax.lax.stop_gradient(hard_sample - sample)
485
- return sample
486
-
487
- tags = ('weight', 'greater')
488
- new_param = (tags, self.weight)
489
- return _jax_wrapped_calc_gre_approx, new_param
490
847
 
491
- def less_equal(self):
492
- jax_geq, jax_param = self.greater_equal()
493
-
494
- def _jax_wrapped_calc_leq_approx(a, b, param):
495
- return jax_geq(-a, -b, param)
496
-
497
- return _jax_wrapped_calc_leq_approx, jax_param
848
+ def greater_equal(self, id, init_params):
849
+ return self.comparison.greater_equal(id, init_params)
498
850
 
499
- def less(self):
500
- jax_gre, jax_param = self.greater()
501
-
502
- def _jax_wrapped_calc_less_approx(a, b, param):
503
- return jax_gre(-a, -b, param)
504
-
505
- return _jax_wrapped_calc_less_approx, jax_param
506
-
507
- def equal(self):
508
- if self.verbose:
509
- raise_warning('Using the replacement rule: '
510
- 'a == b --> comparison.equal(a, b)')
511
-
512
- equal_op = self.comparison.equal
513
- debias = 'equal' in self.debias
514
-
515
- def _jax_wrapped_calc_equal_approx(a, b, param):
516
- sample = equal_op(a, b, param)
517
- if debias:
518
- hard_sample = jnp.equal(a, b)
519
- sample += jax.lax.stop_gradient(hard_sample - sample)
520
- return sample
521
-
522
- tags = ('weight', 'equal')
523
- new_param = (tags, self.weight)
524
- return _jax_wrapped_calc_equal_approx, new_param
851
+ def greater(self, id, init_params):
852
+ return self.comparison.greater(id, init_params)
525
853
 
526
- def not_equal(self):
527
- _not = self.complement
528
- jax_eq, jax_param = self.equal()
529
-
530
- def _jax_wrapped_calc_neq_approx(a, b, param):
531
- return _not(jax_eq(a, b, param))
854
+ def less_equal(self, id, init_params):
855
+ _greater_eq = self.greater_equal(id, init_params)
856
+ def _jax_wrapped_calc_leq_approx(x, y, params):
857
+ return _greater_eq(-x, -y, params)
858
+ return _jax_wrapped_calc_leq_approx
859
+
860
+ def less(self, id, init_params):
861
+ _greater = self.greater(id, init_params)
862
+ def _jax_wrapped_calc_less_approx(x, y, params):
863
+ return _greater(-x, -y, params)
864
+ return _jax_wrapped_calc_less_approx
532
865
 
533
- return _jax_wrapped_calc_neq_approx, jax_param
866
+ def equal(self, id, init_params):
867
+ return self.comparison.equal(id, init_params)
868
+
869
+ def not_equal(self, id, init_params):
870
+ _not = self.complement(f'{id}_~', init_params)
871
+ _equal = self.comparison.equal(f'{id}_==', init_params)
872
+ def _jax_wrapped_calc_neq_approx(x, y, params):
873
+ equal, params = _equal(x, y, params)
874
+ return _not(equal, params)
875
+ return _jax_wrapped_calc_neq_approx
534
876
 
535
877
  # ===========================================================================
536
878
  # special functions
537
879
  # ===========================================================================
538
-
539
- def sgn(self):
540
- if self.verbose:
541
- raise_warning('Using the replacement rule: '
542
- 'sgn(x) --> comparison.sgn(x)')
543
-
544
- sgn_op = self.comparison.sgn
545
- debias = 'sgn' in self.debias
546
-
547
- def _jax_wrapped_calc_sgn_approx(x, param):
548
- sample = sgn_op(x, param)
549
- if debias:
550
- hard_sample = jnp.sign(x)
551
- sample += jax.lax.stop_gradient(hard_sample - sample)
552
- return sample
553
-
554
- tags = ('weight', 'sgn')
555
- new_param = (tags, self.weight)
556
- return _jax_wrapped_calc_sgn_approx, new_param
557
-
558
- def floor(self):
559
- if self.verbose:
560
- raise_warning('Using the replacement rule: '
561
- 'floor(x) --> rounding.floor(x)')
562
-
563
- floor_op = self.rounding.floor
564
- debias = 'floor' in self.debias
565
-
566
- def _jax_wrapped_calc_floor_approx(x, param):
567
- sample = floor_op(x, param)
568
- if debias:
569
- hard_sample = jnp.floor(x)
570
- sample += jax.lax.stop_gradient(hard_sample - sample)
571
- return sample
572
-
573
- tags = ('weight', 'floor')
574
- new_param = (tags, self.weight)
575
- return _jax_wrapped_calc_floor_approx, new_param
576
-
577
- def round(self):
578
- if self.verbose:
579
- raise_warning('Using the replacement rule: '
580
- 'round(x) --> rounding.round(x)')
581
-
582
- round_op = self.rounding.round
583
- debias = 'round' in self.debias
584
-
585
- def _jax_wrapped_calc_round_approx(x, param):
586
- sample = round_op(x, param)
587
- if debias:
588
- hard_sample = jnp.round(x)
589
- sample += jax.lax.stop_gradient(hard_sample - sample)
590
- return sample
591
-
592
- tags = ('weight', 'round')
593
- new_param = (tags, self.weight)
594
- return _jax_wrapped_calc_round_approx, new_param
595
880
 
596
- def ceil(self):
597
- jax_floor, jax_param = self.floor()
598
-
599
- def _jax_wrapped_calc_ceil_approx(x, param):
600
- return -jax_floor(-x, param)
601
-
602
- return _jax_wrapped_calc_ceil_approx, jax_param
881
+ def sgn(self, id, init_params):
882
+ return self.comparison.sgn(id, init_params)
603
883
 
604
- def mod(self):
605
- jax_floor, jax_param = self.floor()
884
+ def floor(self, id, init_params):
885
+ return self.rounding.floor(id, init_params)
606
886
 
607
- def _jax_wrapped_calc_mod_approx(x, y, param):
608
- return x - y * jax_floor(x / y, param)
609
-
610
- return _jax_wrapped_calc_mod_approx, jax_param
887
+ def round(self, id, init_params):
888
+ return self.rounding.round(id, init_params)
611
889
 
612
- def div(self):
613
- jax_floor, jax_param = self.floor()
614
-
615
- def _jax_wrapped_calc_div_approx(x, y, param):
616
- return jax_floor(x / y, param)
617
-
618
- return _jax_wrapped_calc_div_approx, jax_param
890
+ def ceil(self, id, init_params):
891
+ _floor = self.rounding.floor(id, init_params)
892
+ def _jax_wrapped_calc_ceil_approx(x, params):
893
+ neg_floor, params = _floor(-x, params)
894
+ return -neg_floor, params
895
+ return _jax_wrapped_calc_ceil_approx
619
896
 
620
- def sqrt(self):
621
- if self.verbose:
622
- raise_warning('Using the replacement rule: sqrt(x) --> sqrt(x + eps)')
623
-
624
- def _jax_wrapped_calc_sqrt_approx(x, param):
625
- return jnp.sqrt(x + self.eps)
626
-
627
- return _jax_wrapped_calc_sqrt_approx, None
897
+ def div(self, id, init_params):
898
+ _floor = self.rounding.floor(id, init_params)
899
+ def _jax_wrapped_calc_div_approx(x, y, params):
900
+ return _floor(x / y, params)
901
+ return _jax_wrapped_calc_div_approx
902
+
903
+ def mod(self, id, init_params):
904
+ _div = self.div(id, init_params)
905
+ def _jax_wrapped_calc_mod_approx(x, y, params):
906
+ div, params = _div(x, y, params)
907
+ return x - y * div, params
908
+ return _jax_wrapped_calc_mod_approx
909
+
910
+ def sqrt(self, id, init_params):
911
+ def _jax_wrapped_calc_sqrt_approx(x, params):
912
+ return jnp.sqrt(x + self.eps), params
913
+ return _jax_wrapped_calc_sqrt_approx
628
914
 
629
915
  # ===========================================================================
630
916
  # indexing
631
917
  # ===========================================================================
632
918
 
633
- @staticmethod
634
- def enumerate_literals(shape, axis):
635
- literals = jnp.arange(shape[axis])
636
- literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
637
- literals = jnp.moveaxis(literals, source=0, destination=axis)
638
- literals = jnp.broadcast_to(literals, shape=shape)
639
- return literals
640
-
641
- def argmax(self):
642
- if self.verbose:
643
- raise_warning('Using the replacement rule: '
644
- 'argmax(x) --> sum(i * softmax(x[i]))')
645
-
646
- debias = 'argmax' in self.debias
647
-
648
- # https://arxiv.org/abs/2110.05651
649
- def _jax_wrapped_calc_argmax_approx(x, axis, param):
650
- literals = FuzzyLogic.enumerate_literals(x.shape, axis=axis)
651
- soft_max = jax.nn.softmax(param * x, axis=axis)
652
- sample = jnp.sum(literals * soft_max, axis=axis)
653
- if debias:
654
- hard_sample = jnp.argmax(x, axis=axis)
655
- sample += jax.lax.stop_gradient(hard_sample - sample)
656
- return sample
657
-
658
- tags = ('weight', 'argmax')
659
- new_param = (tags, self.weight)
660
- return _jax_wrapped_calc_argmax_approx, new_param
919
+ def argmax(self, id, init_params):
920
+ return self.comparison.argmax(id, init_params)
661
921
 
662
- def argmin(self):
663
- jax_argmax, jax_param = self.argmax()
664
-
922
+ def argmin(self, id, init_params):
923
+ _argmax = self.argmax(id, init_params)
665
924
  def _jax_wrapped_calc_argmin_approx(x, axis, param):
666
- return jax_argmax(-x, axis, param)
667
-
668
- return _jax_wrapped_calc_argmin_approx, jax_param
925
+ return _argmax(-x, axis, param)
926
+ return _jax_wrapped_calc_argmin_approx
669
927
 
670
928
  # ===========================================================================
671
929
  # control flow
672
930
  # ===========================================================================
673
931
 
674
- def control_if(self):
675
- if self.verbose:
676
- raise_warning('Using the replacement rule: '
677
- 'if c then a else b --> c * a + (1 - c) * b')
678
-
679
- debias = 'if' in self.debias
680
-
681
- def _jax_wrapped_calc_if_approx(c, a, b, param):
682
- sample = c * a + (1.0 - c) * b
683
- if debias:
684
- hard_sample = jnp.where(c > 0.5, a, b)
685
- sample += jax.lax.stop_gradient(hard_sample - sample)
686
- return sample
687
-
688
- return _jax_wrapped_calc_if_approx, None
689
-
690
- def control_switch(self):
691
- if self.verbose:
692
- raise_warning('Using the replacement rule: '
693
- 'switch(pred) { cases } --> '
694
- 'sum(cases[i] * softmax(-(pred - i)^2))')
695
-
696
- debias = 'switch' in self.debias
697
-
698
- def _jax_wrapped_calc_switch_approx(pred, cases, param):
699
- literals = FuzzyLogic.enumerate_literals(cases.shape, axis=0)
700
- pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
701
- proximity = -jnp.square(pred - literals)
702
- soft_case = jax.nn.softmax(param * proximity, axis=0)
703
- sample = jnp.sum(cases * soft_case, axis=0)
704
- if debias:
705
- hard_case = jnp.argmax(proximity, axis=0)[jnp.newaxis, ...]
706
- hard_sample = jnp.take_along_axis(cases, hard_case, axis=0)[0, ...]
707
- sample += jax.lax.stop_gradient(hard_sample - sample)
708
- return sample
709
-
710
- tags = ('weight', 'switch')
711
- new_param = (tags, self.weight)
712
- return _jax_wrapped_calc_switch_approx, new_param
932
+ def control_if(self, id, init_params):
933
+ return self.control.if_then_else(id, init_params)
934
+
935
+ def control_switch(self, id, init_params):
936
+ return self.control.switch(id, init_params)
713
937
 
714
938
  # ===========================================================================
715
939
  # random variables
716
940
  # ===========================================================================
717
941
 
718
- def discrete(self):
719
- return self.sampling.discrete(self)
942
+ def discrete(self, id, init_params):
943
+ return self.sampling.discrete(id, init_params, self)
720
944
 
721
- def bernoulli(self):
722
- return self.sampling.bernoulli(self)
945
+ def bernoulli(self, id, init_params):
946
+ return self.sampling.bernoulli(id, init_params, self)
723
947
 
724
- def poisson(self):
725
- return self.sampling.poisson(self)
948
+ def poisson(self, id, init_params):
949
+ return self.sampling.poisson(id, init_params, self)
726
950
 
727
- def geometric(self):
728
- return self.sampling.geometric(self)
951
+ def geometric(self, id, init_params):
952
+ return self.sampling.geometric(id, init_params, self)
729
953
 
730
954
 
731
955
  # ===========================================================================
@@ -733,105 +957,121 @@ class FuzzyLogic:
733
957
  #
734
958
  # ===========================================================================
735
959
 
736
- logic = FuzzyLogic()
737
- w = 1000.0
960
+ logic = FuzzyLogic(comparison=SigmoidComparison(10000.0),
961
+ rounding=SoftRounding(10000.0),
962
+ control=SoftControlFlow(10000.0))
738
963
 
739
964
 
740
965
  def _test_logical():
741
966
  print('testing logical')
742
- _and, _ = logic.logical_and()
743
- _not, _ = logic.logical_not()
744
- _gre, _ = logic.greater()
745
- _or, _ = logic.logical_or()
746
- _if, _ = logic.control_if()
747
-
967
+ init_params = {}
968
+ _and = logic.logical_and(0, init_params)
969
+ _not = logic.logical_not(1, init_params)
970
+ _gre = logic.greater(2, init_params)
971
+ _or = logic.logical_or(3, init_params)
972
+ _if = logic.control_if(4, init_params)
973
+ print(init_params)
974
+
748
975
  # https://towardsdatascience.com/emulating-logical-gates-with-a-neural-network-75c229ec4cc9
749
- def test_logic(x1, x2):
750
- q1 = _and(_gre(x1, 0, w), _gre(x2, 0, w), w)
751
- q2 = _and(_not(_gre(x1, 0, w), w), _not(_gre(x2, 0, w), w), w)
752
- cond = _or(q1, q2, w)
753
- pred = _if(cond, +1, -1, w)
976
+ def test_logic(x1, x2, w):
977
+ q1, w = _gre(x1, 0, w)
978
+ q2, w = _gre(x2, 0, w)
979
+ q3, w = _and(q1, q2, w)
980
+ q4, w = _not(q1, w)
981
+ q5, w = _not(q2, w)
982
+ q6, w = _and(q4, q5, w)
983
+ cond, w = _or(q3, q6, w)
984
+ pred, w = _if(cond, +1, -1, w)
754
985
  return pred
755
986
 
756
987
  x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5]).astype(float)
757
988
  x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6]).astype(float)
758
- print(test_logic(x1, x2))
989
+ print(test_logic(x1, x2, init_params))
759
990
 
760
991
 
761
992
  def _test_indexing():
762
993
  print('testing indexing')
763
- _argmax, _ = logic.argmax()
764
- _argmin, _ = logic.argmin()
994
+ init_params = {}
995
+ _argmax = logic.argmax(0, init_params)
996
+ _argmin = logic.argmin(1, init_params)
997
+ print(init_params)
765
998
 
766
- def argmaxmin(x):
767
- amax = _argmax(x, 0, w)
768
- amin = _argmin(x, 0, w)
999
+ def argmaxmin(x, w):
1000
+ amax, w = _argmax(x, 0, w)
1001
+ amin, w = _argmin(x, 0, w)
769
1002
  return amax, amin
770
1003
 
771
1004
  values = jnp.asarray([2., 3., 5., 4.9, 4., 1., -1., -2.])
772
- amax, amin = argmaxmin(values)
1005
+ amax, amin = argmaxmin(values, init_params)
773
1006
  print(amax)
774
1007
  print(amin)
775
1008
 
776
1009
 
777
1010
  def _test_control():
778
1011
  print('testing control')
779
- _switch, _ = logic.control_switch()
1012
+ init_params = {}
1013
+ _switch = logic.control_switch(0, init_params)
1014
+ print(init_params)
780
1015
 
781
1016
  pred = jnp.asarray(jnp.linspace(0, 2, 10))
782
1017
  case1 = jnp.asarray([-10.] * 10)
783
1018
  case2 = jnp.asarray([1.5] * 10)
784
1019
  case3 = jnp.asarray([10.] * 10)
785
1020
  cases = jnp.asarray([case1, case2, case3])
786
- print(_switch(pred, cases, w))
1021
+ switch, _ = _switch(pred, cases, init_params)
1022
+ print(switch)
787
1023
 
788
1024
 
789
1025
  def _test_random():
790
1026
  print('testing random')
791
1027
  key = random.PRNGKey(42)
792
- _bernoulli, _ = logic.bernoulli()
793
- _discrete, _ = logic.discrete()
794
- _geometric, _ = logic.geometric()
1028
+ init_params = {}
1029
+ _bernoulli = logic.bernoulli(0, init_params)
1030
+ _discrete = logic.discrete(1, init_params)
1031
+ _geometric = logic.geometric(2, init_params)
1032
+ print(init_params)
795
1033
 
796
- def bern(n):
1034
+ def bern(n, w):
797
1035
  prob = jnp.asarray([0.3] * n)
798
- sample = _bernoulli(key, prob, w)
1036
+ sample, _ = _bernoulli(key, prob, w)
799
1037
  return sample
800
1038
 
801
- samples = bern(50000)
1039
+ samples = bern(50000, init_params)
802
1040
  print(jnp.mean(samples))
803
1041
 
804
- def disc(n):
1042
+ def disc(n, w):
805
1043
  prob = jnp.asarray([0.1, 0.4, 0.5])
806
1044
  prob = jnp.tile(prob, (n, 1))
807
- sample = _discrete(key, prob, w)
1045
+ sample, _ = _discrete(key, prob, w)
808
1046
  return sample
809
1047
 
810
- samples = disc(50000)
1048
+ samples = disc(50000, init_params)
811
1049
  samples = jnp.round(samples)
812
1050
  print([jnp.mean(samples == i) for i in range(3)])
813
1051
 
814
- def geom(n):
1052
+ def geom(n, w):
815
1053
  prob = jnp.asarray([0.3] * n)
816
- sample = _geometric(key, prob, w)
1054
+ sample, _ = _geometric(key, prob, w)
817
1055
  return sample
818
1056
 
819
- samples = geom(50000)
1057
+ samples = geom(50000, init_params)
820
1058
  print(jnp.mean(samples))
821
1059
 
822
1060
 
823
1061
  def _test_rounding():
824
1062
  print('testing rounding')
825
- _floor, _ = logic.floor()
826
- _ceil, _ = logic.ceil()
827
- _round, _ = logic.round()
828
- _mod, _ = logic.mod()
1063
+ init_params = {}
1064
+ _floor = logic.floor(0, init_params)
1065
+ _ceil = logic.ceil(1, init_params)
1066
+ _round = logic.round(2, init_params)
1067
+ _mod = logic.mod(3, init_params)
1068
+ print(init_params)
829
1069
 
830
1070
  x = jnp.asarray([2.1, 0.6, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
831
- print(_floor(x, w))
832
- print(_ceil(x, w))
833
- print(_round(x, w))
834
- print(_mod(x, 2.0, w))
1071
+ print(_floor(x, init_params)[0])
1072
+ print(_ceil(x, init_params)[0])
1073
+ print(_round(x, init_params)[0])
1074
+ print(_mod(x, 2.0, init_params)[0])
835
1075
 
836
1076
 
837
1077
  if __name__ == '__main__':