pyRDDLGym-jax 0.4__py3-none-any.whl → 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. pyRDDLGym_jax/__init__.py +1 -1
  2. pyRDDLGym_jax/core/compiler.py +463 -592
  3. pyRDDLGym_jax/core/logic.py +832 -530
  4. pyRDDLGym_jax/core/planner.py +422 -474
  5. pyRDDLGym_jax/core/simulator.py +7 -5
  6. pyRDDLGym_jax/core/tuning.py +390 -584
  7. pyRDDLGym_jax/core/visualization.py +1463 -0
  8. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
  9. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -5
  10. pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
  11. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
  12. pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
  13. pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
  14. pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
  15. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
  16. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +5 -4
  17. pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
  18. pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
  19. pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
  20. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
  21. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -4
  22. pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
  23. pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
  24. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
  25. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +7 -6
  26. pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
  27. pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
  28. pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -4
  29. pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
  30. pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
  31. pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
  32. pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
  33. pyRDDLGym_jax/examples/run_plan.py +4 -1
  34. pyRDDLGym_jax/examples/run_tune.py +40 -29
  35. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +164 -105
  36. pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
  37. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
  38. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
  39. pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
  40. pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
  41. pyRDDLGym_jax-0.4.dist-info/RECORD +0 -44
  42. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
  43. {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
@@ -4,63 +4,169 @@ import jax
4
4
  import jax.numpy as jnp
5
5
  import jax.random as random
6
6
 
7
- from pyRDDLGym.core.debug.exception import raise_warning
7
+
8
+ def enumerate_literals(shape, axis, dtype=jnp.int32):
9
+ literals = jnp.arange(shape[axis], dtype=dtype)
10
+ literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
11
+ literals = jnp.moveaxis(literals, source=0, destination=axis)
12
+ literals = jnp.broadcast_to(literals, shape=shape)
13
+ return literals
8
14
 
9
15
 
10
16
  # ===========================================================================
11
- # LOGICAL COMPLEMENT
17
+ # RELATIONAL OPERATIONS
12
18
  # - abstract class
13
- # - standard complement
19
+ # - sigmoid comparison
14
20
  #
15
21
  # ===========================================================================
16
22
 
17
- class Complement:
18
- '''Base class for approximate logical complement operations.'''
23
+ class Comparison:
24
+ '''Base class for approximate comparison operations.'''
19
25
 
20
- def __call__(self, x):
26
+ def greater_equal(self, id, init_params):
21
27
  raise NotImplementedError
28
+
29
+ def greater(self, id, init_params):
30
+ raise NotImplementedError
31
+
32
+ def equal(self, id, init_params):
33
+ raise NotImplementedError
34
+
35
+ def sgn(self, id, init_params):
36
+ raise NotImplementedError
37
+
38
+ def argmax(self, id, init_params):
39
+ raise NotImplementedError
40
+
22
41
 
23
-
24
- class StandardComplement(Complement):
25
- '''The standard approximate logical complement given by x -> 1 - x.'''
42
+ class SigmoidComparison(Comparison):
43
+ '''Comparison operations approximated using sigmoid functions.'''
26
44
 
27
- def __call__(self, x):
28
- return 1.0 - x
45
+ def __init__(self, weight: float=10.0):
46
+ self.weight = weight
47
+
48
+ # https://arxiv.org/abs/2110.05651
49
+ def greater_equal(self, id, init_params):
50
+ id_ = str(id)
51
+ init_params[id_] = self.weight
52
+ def _jax_wrapped_calc_greater_equal_approx(x, y, params):
53
+ gre_eq = jax.nn.sigmoid(params[id_] * (x - y))
54
+ return gre_eq, params
55
+ return _jax_wrapped_calc_greater_equal_approx
56
+
57
+ def greater(self, id, init_params):
58
+ return self.greater_equal(id, init_params)
59
+
60
+ def equal(self, id, init_params):
61
+ id_ = str(id)
62
+ init_params[id_] = self.weight
63
+ def _jax_wrapped_calc_equal_approx(x, y, params):
64
+ equal = 1.0 - jnp.square(jnp.tanh(params[id_] * (y - x)))
65
+ return equal, params
66
+ return _jax_wrapped_calc_equal_approx
67
+
68
+ def sgn(self, id, init_params):
69
+ id_ = str(id)
70
+ init_params[id_] = self.weight
71
+ def _jax_wrapped_calc_sgn_approx(x, params):
72
+ sgn = jnp.tanh(params[id_] * x)
73
+ return sgn, params
74
+ return _jax_wrapped_calc_sgn_approx
75
+
76
+ # https://arxiv.org/abs/2110.05651
77
+ def argmax(self, id, init_params):
78
+ id_ = str(id)
79
+ init_params[id_] = self.weight
80
+ def _jax_wrapped_calc_argmax_approx(x, axis, params):
81
+ literals = enumerate_literals(x.shape, axis=axis)
82
+ softmax = jax.nn.softmax(params[id_] * x, axis=axis)
83
+ sample = jnp.sum(literals * softmax, axis=axis)
84
+ return sample, params
85
+ return _jax_wrapped_calc_argmax_approx
86
+
87
+ def __str__(self) -> str:
88
+ return f'Sigmoid comparison with weight {self.weight}'
29
89
 
30
90
 
31
91
  # ===========================================================================
32
- # RELATIONAL OPERATIONS
92
+ # ROUNDING OPERATIONS
33
93
  # - abstract class
34
- # - sigmoid comparison
94
+ # - soft rounding
35
95
  #
36
96
  # ===========================================================================
37
97
 
38
- class Comparison:
39
- '''Base class for approximate comparison operations.'''
98
+ class Rounding:
99
+ '''Base class for approximate rounding operations.'''
40
100
 
41
- def greater_equal(self, x, y, param):
101
+ def floor(self, id, init_params):
42
102
  raise NotImplementedError
43
103
 
44
- def greater(self, x, y, param):
104
+ def round(self, id, init_params):
45
105
  raise NotImplementedError
106
+
107
+
108
+ class SoftRounding(Rounding):
109
+ '''Rounding operations approximated using soft operations.'''
110
+
111
+ def __init__(self, weight: float=10.0):
112
+ self.weight = weight
113
+
114
+ # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
115
+ def floor(self, id, init_params):
116
+ id_ = str(id)
117
+ init_params[id_] = self.weight
118
+ def _jax_wrapped_calc_floor_approx(x, params):
119
+ param = params[id_]
120
+ denom = jnp.tanh(param / 4.0)
121
+ floor = (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
122
+ jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)
123
+ return floor, params
124
+ return _jax_wrapped_calc_floor_approx
125
+
126
+ # https://arxiv.org/abs/2006.09952
127
+ def round(self, id, init_params):
128
+ id_ = str(id)
129
+ init_params[id_] = self.weight
130
+ def _jax_wrapped_calc_round_approx(x, params):
131
+ param = params[id_]
132
+ m = jnp.floor(x) + 0.5
133
+ rounded = m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)
134
+ return rounded, params
135
+ return _jax_wrapped_calc_round_approx
136
+
137
+ def __str__(self) -> str:
138
+ return f'SoftFloor and SoftRound with weight {self.weight}'
139
+
140
+
141
+ # ===========================================================================
142
+ # LOGICAL COMPLEMENT
143
+ # - abstract class
144
+ # - standard complement
145
+ #
146
+ # ===========================================================================
147
+
148
+ class Complement:
149
+ '''Base class for approximate logical complement operations.'''
46
150
 
47
- def equal(self, x, y, param):
151
+ def __call__(self, id, init_params):
48
152
  raise NotImplementedError
49
-
50
153
 
51
- class SigmoidComparison(Comparison):
52
- '''Comparison operations approximated using sigmoid functions.'''
154
+
155
+ class StandardComplement(Complement):
156
+ '''The standard approximate logical complement given by x -> 1 - x.'''
53
157
 
54
- def greater_equal(self, x, y, param):
55
- return jax.nn.sigmoid(param * (x - y))
158
+ # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
159
+ @staticmethod
160
+ def _jax_wrapped_calc_not_approx(x, params):
161
+ return 1.0 - x, params
56
162
 
57
- def greater(self, x, y, param):
58
- return jax.nn.sigmoid(param * (x - y))
163
+ def __call__(self, id, init_params):
164
+ return self._jax_wrapped_calc_not_approx
59
165
 
60
- def equal(self, x, y, param):
61
- return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
62
-
63
-
166
+ def __str__(self) -> str:
167
+ return 'Standard complement'
168
+
169
+
64
170
  # ===========================================================================
65
171
  # TNORMS
66
172
  # - abstract tnorm
@@ -69,48 +175,84 @@ class SigmoidComparison(Comparison):
69
175
  # - Lukasiewicz tnorm
70
176
  # - Yager(p) tnorm
71
177
  #
178
+ # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
72
179
  # ===========================================================================
73
180
 
74
181
  class TNorm:
75
182
  '''Base class for fuzzy differentiable t-norms.'''
76
183
 
77
- def norm(self, x, y):
184
+ def norm(self, id, init_params):
78
185
  '''Elementwise t-norm of x and y.'''
79
186
  raise NotImplementedError
80
187
 
81
- def norms(self, x, axis):
188
+ def norms(self, id, init_params):
82
189
  '''T-norm computed for tensor x along axis.'''
83
190
  raise NotImplementedError
84
-
191
+
85
192
 
86
193
  class ProductTNorm(TNorm):
87
194
  '''Product t-norm given by the expression (x, y) -> x * y.'''
88
195
 
89
- def norm(self, x, y):
90
- return x * y
196
+ @staticmethod
197
+ def _jax_wrapped_calc_and_approx(x, y, params):
198
+ return x * y, params
199
+
200
+ def norm(self, id, init_params):
201
+ return self._jax_wrapped_calc_and_approx
91
202
 
92
- def norms(self, x, axis):
93
- return jnp.prod(x, axis=axis)
203
+ @staticmethod
204
+ def _jax_wrapped_calc_forall_approx(x, axis, params):
205
+ return jnp.prod(x, axis=axis), params
206
+
207
+ def norms(self, id, init_params):
208
+ return self._jax_wrapped_calc_forall_approx
94
209
 
210
+ def __str__(self) -> str:
211
+ return 'Product t-norm'
212
+
95
213
 
96
214
  class GodelTNorm(TNorm):
97
215
  '''Godel t-norm given by the expression (x, y) -> min(x, y).'''
98
216
 
99
- def norm(self, x, y):
100
- return jnp.minimum(x, y)
217
+ @staticmethod
218
+ def _jax_wrapped_calc_and_approx(x, y, params):
219
+ return jnp.minimum(x, y), params
220
+
221
+ def norm(self, id, init_params):
222
+ return self._jax_wrapped_calc_and_approx
223
+
224
+ @staticmethod
225
+ def _jax_wrapped_calc_forall_approx(x, axis, params):
226
+ return jnp.min(x, axis=axis), params
227
+
228
+ def norms(self, id, init_params):
229
+ return self._jax_wrapped_calc_forall_approx
230
+
231
+ def __str__(self) -> str:
232
+ return 'Godel t-norm'
101
233
 
102
- def norms(self, x, axis):
103
- return jnp.min(x, axis=axis)
104
-
105
234
 
106
235
  class LukasiewiczTNorm(TNorm):
107
236
  '''Lukasiewicz t-norm given by the expression (x, y) -> max(x + y - 1, 0).'''
108
237
 
109
- def norm(self, x, y):
110
- return jax.nn.relu(x + y - 1.0)
238
+ @staticmethod
239
+ def _jax_wrapped_calc_and_approx(x, y, params):
240
+ land = jax.nn.relu(x + y - 1.0)
241
+ return land, params
242
+
243
+ def norm(self, id, init_params):
244
+ return self._jax_wrapped_calc_and_approx
111
245
 
112
- def norms(self, x, axis):
113
- return jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
246
+ @staticmethod
247
+ def _jax_wrapped_calc_forall_approx(x, axis, params):
248
+ forall = jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
249
+ return forall, params
250
+
251
+ def norms(self, id, init_params):
252
+ return self._jax_wrapped_calc_forall_approx
253
+
254
+ def __str__(self) -> str:
255
+ return 'Lukasiewicz t-norm'
114
256
 
115
257
 
116
258
  class YagerTNorm(TNorm):
@@ -118,19 +260,32 @@ class YagerTNorm(TNorm):
118
260
  (x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
119
261
 
120
262
  def __init__(self, p=2.0):
121
- self.p = p
122
-
123
- def norm(self, x, y):
124
- base_x = jax.nn.relu(1.0 - x)
125
- base_y = jax.nn.relu(1.0 - y)
126
- arg = jnp.power(base_x ** self.p + base_y ** self.p, 1.0 / self.p)
127
- return jax.nn.relu(1.0 - arg)
128
-
129
- def norms(self, x, axis):
130
- base = jax.nn.relu(1.0 - x)
131
- arg = jnp.power(jnp.sum(base ** self.p, axis=axis), 1.0 / self.p)
132
- return jax.nn.relu(1.0 - arg)
133
-
263
+ self.p = float(p)
264
+
265
+ def norm(self, id, init_params):
266
+ id_ = str(id)
267
+ init_params[id_] = self.p
268
+ def _jax_wrapped_calc_and_approx(x, y, params):
269
+ base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
270
+ arg = jnp.linalg.norm(base, ord=params[id_], axis=0)
271
+ land = jax.nn.relu(1.0 - arg)
272
+ return land, params
273
+ return _jax_wrapped_calc_and_approx
274
+
275
+ def norms(self, id, init_params):
276
+ id_ = str(id)
277
+ init_params[id_] = self.p
278
+ def _jax_wrapped_calc_forall_approx(x, axis, params):
279
+ arg = jax.nn.relu(1.0 - x)
280
+ for ax in sorted(axis, reverse=True):
281
+ arg = jnp.linalg.norm(arg, ord=params[id_], axis=ax)
282
+ forall = jax.nn.relu(1.0 - arg)
283
+ return forall, params
284
+ return _jax_wrapped_calc_forall_approx
285
+
286
+ def __str__(self) -> str:
287
+ return f'Yager({self.p}) t-norm'
288
+
134
289
 
135
290
  # ===========================================================================
136
291
  # RANDOM SAMPLING
@@ -144,141 +299,147 @@ class RandomSampling:
144
299
  '''An abstract class that describes how discrete and non-reparameterizable
145
300
  random variables are sampled.'''
146
301
 
147
- def discrete(self, logic):
302
+ def discrete(self, id, init_params, logic):
148
303
  raise NotImplementedError
149
304
 
150
- def bernoulli(self, logic):
151
- jax_discrete, jax_param = self.discrete(logic)
152
-
153
- def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
305
+ def bernoulli(self, id, init_params, logic):
306
+ discrete_approx = self.discrete(id, init_params, logic)
307
+ def _jax_wrapped_calc_bernoulli_approx(key, prob, params):
154
308
  prob = jnp.stack([1.0 - prob, prob], axis=-1)
155
- sample = jax_discrete(key, prob, param)
156
- return sample
157
-
158
- return _jax_wrapped_calc_bernoulli_approx, jax_param
309
+ return discrete_approx(key, prob, params)
310
+ return _jax_wrapped_calc_bernoulli_approx
159
311
 
160
- def poisson(self, logic):
161
-
162
- def _jax_wrapped_calc_poisson_exact(key, rate, param):
163
- return random.poisson(key=key, lam=rate, dtype=logic.INT)
312
+ @staticmethod
313
+ def _jax_wrapped_calc_poisson_exact(key, rate, params):
314
+ sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
315
+ return sample, params
164
316
 
165
- return _jax_wrapped_calc_poisson_exact, None
166
-
167
- def geometric(self, logic):
168
- if logic.verbose:
169
- raise_warning('Using the replacement rule: '
170
- 'Geometric(p) --> floor(log(U) / log(1 - p)) + 1')
171
-
172
- jax_floor, jax_param = logic.floor()
173
-
174
- def _jax_wrapped_calc_geometric_approx(key, prob, param):
317
+ def poisson(self, id, init_params, logic):
318
+ return self._jax_wrapped_calc_poisson_exact
319
+
320
+ def geometric(self, id, init_params, logic):
321
+ approx_floor = logic.floor(id, init_params)
322
+ def _jax_wrapped_calc_geometric_approx(key, prob, params):
175
323
  U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
176
- sample = jax_floor(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
177
- return sample
178
-
179
- return _jax_wrapped_calc_geometric_approx, jax_param
324
+ floor, params = approx_floor(jnp.log(U) / jnp.log(1.0 - prob), params)
325
+ sample = floor + 1
326
+ return sample, params
327
+ return _jax_wrapped_calc_geometric_approx
180
328
 
181
329
 
182
330
  class GumbelSoftmax(RandomSampling):
183
331
  '''Random sampling of discrete variables using Gumbel-softmax trick.'''
184
332
 
185
- def discrete(self, logic):
186
- if logic.verbose:
187
- raise_warning('Using the replacement rule: '
188
- 'Discrete(p) --> Gumbel-softmax(p)')
189
-
190
- jax_argmax, jax_param = logic.argmax()
191
-
192
- def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
333
+ # https://arxiv.org/pdf/1611.01144
334
+ def discrete(self, id, init_params, logic):
335
+ argmax_approx = logic.argmax(id, init_params)
336
+ def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, params):
193
337
  Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
194
338
  sample = Gumbel01 + jnp.log(prob + logic.eps)
195
- sample = jax_argmax(sample, axis=-1, param=param)
196
- return sample
197
-
198
- return _jax_wrapped_calc_discrete_gumbel_softmax, jax_param
339
+ return argmax_approx(sample, axis=-1, params=params)
340
+ return _jax_wrapped_calc_discrete_gumbel_softmax
341
+
342
+ def __str__(self) -> str:
343
+ return 'Gumbel-Softmax'
199
344
 
200
345
 
201
346
  class Determinization(RandomSampling):
202
347
  '''Random sampling of variables using their deterministic mean estimate.'''
203
348
 
204
- def discrete(self, logic):
205
- if logic.verbose:
206
- raise_warning('Using the replacement rule: '
207
- 'Discrete(p) --> sum(i * p[i])')
208
-
209
- def _jax_wrapped_calc_discrete_determinized(key, prob, param):
210
- literals = FuzzyLogic.enumerate_literals(prob.shape, axis=-1)
211
- sample = jnp.sum(literals * prob, axis=-1)
212
- return sample
213
-
214
- return _jax_wrapped_calc_discrete_determinized, None
215
-
216
- def poisson(self, logic):
217
- if logic.verbose:
218
- raise_warning('Using the replacement rule: Poisson(rate) --> rate')
219
-
220
- def _jax_wrapped_calc_poisson_determinized(key, rate, param):
221
- return rate
222
-
223
- return _jax_wrapped_calc_poisson_determinized, None
224
-
225
- def geometric(self, logic):
226
- if logic.verbose:
227
- raise_warning('Using the replacement rule: Geometric(p) --> 1 / p')
228
-
229
- def _jax_wrapped_calc_geometric_determinized(key, prob, param):
230
- sample = 1.0 / prob
231
- return sample
349
+ @staticmethod
350
+ def _jax_wrapped_calc_discrete_determinized(key, prob, params):
351
+ literals = enumerate_literals(prob.shape, axis=-1)
352
+ sample = jnp.sum(literals * prob, axis=-1)
353
+ return sample, params
232
354
 
233
- return _jax_wrapped_calc_geometric_determinized, None
355
+ def discrete(self, id, init_params, logic):
356
+ return self._jax_wrapped_calc_discrete_determinized
357
+
358
+ @staticmethod
359
+ def _jax_wrapped_calc_poisson_determinized(key, rate, params):
360
+ return rate, params
234
361
 
362
+ def poisson(self, id, init_params, logic):
363
+ return self._jax_wrapped_calc_poisson_determinized
364
+
365
+ @staticmethod
366
+ def _jax_wrapped_calc_geometric_determinized(key, prob, params):
367
+ sample = 1.0 / prob
368
+ return sample, params
369
+
370
+ def geometric(self, id, init_params, logic):
371
+ return self._jax_wrapped_calc_geometric_determinized
372
+
373
+ def __str__(self) -> str:
374
+ return 'Deterministic'
375
+
235
376
 
236
377
  # ===========================================================================
237
- # FUZZY LOGIC
378
+ # CONTROL FLOW
379
+ # - soft flow
238
380
  #
239
381
  # ===========================================================================
240
382
 
241
- class FuzzyLogic:
242
- '''A class representing fuzzy logic in JAX.
383
+ class ControlFlow:
384
+ '''A base class for control flow, including if and switch statements.'''
243
385
 
244
- Functionality can be customized by either providing a tnorm as parameters,
245
- or by overriding its methods.
246
- '''
386
+ def if_then_else(self, id, init_params):
387
+ raise NotImplementedError
247
388
 
248
- def __init__(self, tnorm: TNorm=ProductTNorm(),
249
- complement: Complement=StandardComplement(),
250
- comparison: Comparison=SigmoidComparison(),
251
- sampling: RandomSampling=GumbelSoftmax(),
252
- weight: float=10.0,
253
- debias: Optional[Set[str]]=None,
254
- eps: float=1e-15,
255
- verbose: bool=False,
256
- use64bit: bool=False) -> None:
257
- '''Creates a new fuzzy logic in Jax.
389
+ def switch(self, id, init_params):
390
+ raise NotImplementedError
391
+
392
+
393
+ class SoftControlFlow(ControlFlow):
394
+ '''Soft control flow using a probabilistic interpretation.'''
395
+
396
+ def __init__(self, weight: float=10.0) -> None:
397
+ self.weight = weight
258
398
 
259
- :param tnorm: fuzzy operator for logical AND
260
- :param complement: fuzzy operator for logical NOT
261
- :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
262
- :param sampling: random sampling of non-reparameterizable distributions
263
- :param weight: a sharpness parameter for sigmoid and softmax activations
264
- :param debias: which functions to de-bias approximate on forward pass
265
- :param eps: small positive float to mitigate underflow
266
- :param verbose: whether to dump replacements and other info to console
267
- :param use64bit: whether to perform arithmetic in 64 bit
268
- '''
269
- self.tnorm = tnorm
270
- self.complement = complement
271
- self.comparison = comparison
272
- self.sampling = sampling
273
- self.weight = float(weight)
274
- if debias is None:
275
- debias = set()
276
- self.debias = debias
277
- self.eps = eps
278
- self.verbose = verbose
399
+ @staticmethod
400
+ def _jax_wrapped_calc_if_then_else_soft(c, a, b, params):
401
+ sample = c * a + (1.0 - c) * b
402
+ return sample, params
403
+
404
+ def if_then_else(self, id, init_params):
405
+ return self._jax_wrapped_calc_if_then_else_soft
406
+
407
+ def switch(self, id, init_params):
408
+ id_ = str(id)
409
+ init_params[id_] = self.weight
410
+ def _jax_wrapped_calc_switch_soft(pred, cases, params):
411
+ literals = enumerate_literals(cases.shape, axis=0)
412
+ pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
413
+ proximity = -jnp.square(pred - literals)
414
+ softcase = jax.nn.softmax(params[id_] * proximity, axis=0)
415
+ sample = jnp.sum(cases * softcase, axis=0)
416
+ return sample, params
417
+ return _jax_wrapped_calc_switch_soft
418
+
419
+ def __str__(self) -> str:
420
+ return f'Soft control flow with weight {self.weight}'
421
+
422
+
423
+ # ===========================================================================
424
+ # LOGIC
425
+ # - exact logic
426
+ # - fuzzy logic
427
+ #
428
+ # ===========================================================================
429
+
430
+
431
+ class Logic:
432
+ '''A base class for representing logic computations in JAX.'''
433
+
434
+ def __init__(self, use64bit: bool=False) -> None:
279
435
  self.set_use64bit(use64bit)
280
436
 
437
+ def summarize_hyperparameters(self) -> None:
438
+ print(f'model relaxation:\n'
439
+ f' use_64_bit ={self.use64bit}')
440
+
281
441
  def set_use64bit(self, use64bit: bool) -> None:
442
+ '''Toggles whether or not the JAX system will use 64 bit precision.'''
282
443
  self.use64bit = use64bit
283
444
  if use64bit:
284
445
  self.REAL = jnp.float64
@@ -288,384 +449,507 @@ class FuzzyLogic:
288
449
  self.REAL = jnp.float32
289
450
  self.INT = jnp.int32
290
451
  jax.config.update('jax_enable_x64', False)
291
-
292
- def summarize_hyperparameters(self) -> None:
293
- print(f'model relaxation:\n'
294
- f' tnorm ={type(self.tnorm).__name__}\n'
295
- f' complement ={type(self.complement).__name__}\n'
296
- f' comparison ={type(self.comparison).__name__}\n'
297
- f' sampling ={type(self.sampling).__name__}\n'
298
- f' sigmoid_weight={self.weight}\n'
299
- f' cpfs_to_debias={self.debias}\n'
300
- f' underflow_tol ={self.eps}\n'
301
- f' use_64_bit ={self.use64bit}')
302
-
452
+
303
453
  # ===========================================================================
304
454
  # logical operators
305
455
  # ===========================================================================
306
456
 
307
- def logical_and(self):
308
- if self.verbose:
309
- raise_warning('Using the replacement rule: a ^ b --> tnorm(a, b).')
310
-
311
- _and = self.tnorm.norm
312
-
313
- def _jax_wrapped_calc_and_approx(a, b, param):
314
- return _and(a, b)
315
-
316
- return _jax_wrapped_calc_and_approx, None
457
+ def logical_and(self, id, init_params):
458
+ raise NotImplementedError
317
459
 
318
- def logical_not(self):
319
- if self.verbose:
320
- raise_warning('Using the replacement rule: ~a --> complement(a)')
321
-
322
- _not = self.complement
323
-
324
- def _jax_wrapped_calc_not_approx(x, param):
325
- return _not(x)
326
-
327
- return _jax_wrapped_calc_not_approx, None
328
-
329
- def logical_or(self):
330
- if self.verbose:
331
- raise_warning('Using the replacement rule: a | b --> tconorm(a, b).')
332
-
333
- _not = self.complement
334
- _and = self.tnorm.norm
335
-
336
- def _jax_wrapped_calc_or_approx(a, b, param):
337
- return _not(_and(_not(a), _not(b)))
338
-
339
- return _jax_wrapped_calc_or_approx, None
340
-
341
- def xor(self):
342
- if self.verbose:
343
- raise_warning('Using the replacement rule: '
344
- 'a ~ b --> (a | b) ^ (a ^ b).')
345
-
346
- _not = self.complement
347
- _and = self.tnorm.norm
348
-
349
- def _jax_wrapped_calc_xor_approx(a, b, param):
350
- _or = _not(_and(_not(a), _not(b)))
351
- return _and(_or(a, b), _not(_and(a, b)))
352
-
353
- return _jax_wrapped_calc_xor_approx, None
354
-
355
- def implies(self):
356
- if self.verbose:
357
- raise_warning('Using the replacement rule: a => b --> ~a ^ b')
358
-
359
- _not = self.complement
360
- _and = self.tnorm.norm
361
-
362
- def _jax_wrapped_calc_implies_approx(a, b, param):
363
- return _not(_and(a, _not(b)))
364
-
365
- return _jax_wrapped_calc_implies_approx, None
460
+ def logical_not(self, id, init_params):
461
+ raise NotImplementedError
366
462
 
367
- def equiv(self):
368
- if self.verbose:
369
- raise_warning('Using the replacement rule: '
370
- 'a <=> b --> (a => b) ^ (b => a)')
371
-
372
- _not = self.complement
373
- _and = self.tnorm.norm
374
-
375
- def _jax_wrapped_calc_equiv_approx(a, b, param):
376
- atob = _not(_and(a, _not(b)))
377
- btoa = _not(_and(b, _not(a)))
378
- return _and(atob, btoa)
379
-
380
- return _jax_wrapped_calc_equiv_approx, None
463
+ def logical_or(self, id, init_params):
464
+ raise NotImplementedError
381
465
 
382
- def forall(self):
383
- if self.verbose:
384
- raise_warning('Using the replacement rule: '
385
- 'forall(a) --> a[1] ^ a[2] ^ ...')
386
-
387
- _forall = self.tnorm.norms
388
-
389
- def _jax_wrapped_calc_forall_approx(x, axis, param):
390
- return _forall(x, axis=axis)
391
-
392
- return _jax_wrapped_calc_forall_approx, None
466
+ def xor(self, id, init_params):
467
+ raise NotImplementedError
393
468
 
394
- def exists(self):
395
- _not = self.complement
396
- jax_forall, jax_param = self.forall()
397
-
398
- def _jax_wrapped_calc_exists_approx(x, axis, param):
399
- return _not(jax_forall(_not(x), axis, param))
400
-
401
- return _jax_wrapped_calc_exists_approx, jax_param
469
+ def implies(self, id, init_params):
470
+ raise NotImplementedError
471
+
472
+ def equiv(self, id, init_params):
473
+ raise NotImplementedError
474
+
475
+ def forall(self, id, init_params):
476
+ raise NotImplementedError
477
+
478
+ def exists(self, id, init_params):
479
+ raise NotImplementedError
402
480
 
403
481
  # ===========================================================================
404
482
  # comparison operators
405
483
  # ===========================================================================
484
+
485
+ def greater_equal(self, id, init_params):
486
+ raise NotImplementedError
487
+
488
+ def greater(self, id, init_params):
489
+ raise NotImplementedError
490
+
491
+ def less_equal(self, id, init_params):
492
+ raise NotImplementedError
493
+
494
+ def less(self, id, init_params):
495
+ raise NotImplementedError
496
+
497
+ def equal(self, id, init_params):
498
+ raise NotImplementedError
499
+
500
+ def not_equal(self, id, init_params):
501
+ raise NotImplementedError
502
+
503
+ # ===========================================================================
504
+ # special functions
505
+ # ===========================================================================
406
506
 
407
- def greater_equal(self):
408
- if self.verbose:
409
- raise_warning('Using the replacement rule: '
410
- 'a >= b --> comparison.greater_equal(a, b)')
411
-
412
- greater_equal_op = self.comparison.greater_equal
413
- debias = 'greater_equal' in self.debias
414
-
415
- def _jax_wrapped_calc_geq_approx(a, b, param):
416
- sample = greater_equal_op(a, b, param)
417
- if debias:
418
- hard_sample = jnp.greater_equal(a, b)
419
- sample += jax.lax.stop_gradient(hard_sample - sample)
420
- return sample
421
-
422
- tags = ('weight', 'greater_equal')
423
- new_param = (tags, self.weight)
424
- return _jax_wrapped_calc_geq_approx, new_param
425
-
426
- def greater(self):
427
- if self.verbose:
428
- raise_warning('Using the replacement rule: '
429
- 'a > b --> comparison.greater(a, b)')
430
-
431
- greater_op = self.comparison.greater
432
- debias = 'greater' in self.debias
433
-
434
- def _jax_wrapped_calc_gre_approx(a, b, param):
435
- sample = greater_op(a, b, param)
436
- if debias:
437
- hard_sample = jnp.greater(a, b)
438
- sample += jax.lax.stop_gradient(hard_sample - sample)
439
- return sample
440
-
441
- tags = ('weight', 'greater')
442
- new_param = (tags, self.weight)
443
- return _jax_wrapped_calc_gre_approx, new_param
507
+ def sgn(self, id, init_params):
508
+ raise NotImplementedError
444
509
 
445
- def less_equal(self):
446
- jax_geq, jax_param = self.greater_equal()
447
-
448
- def _jax_wrapped_calc_leq_approx(a, b, param):
449
- return jax_geq(-a, -b, param)
510
+ def floor(self, id, init_params):
511
+ raise NotImplementedError
512
+
513
+ def round(self, id, init_params):
514
+ raise NotImplementedError
515
+
516
+ def ceil(self, id, init_params):
517
+ raise NotImplementedError
518
+
519
+ def div(self, id, init_params):
520
+ raise NotImplementedError
521
+
522
+ def mod(self, id, init_params):
523
+ raise NotImplementedError
524
+
525
+ def sqrt(self, id, init_params):
526
+ raise NotImplementedError
527
+
528
+ # ===========================================================================
529
+ # indexing
530
+ # ===========================================================================
531
+
532
+ def argmax(self, id, init_params):
533
+ raise NotImplementedError
534
+
535
+ def argmin(self, id, init_params):
536
+ raise NotImplementedError
537
+
538
+ # ===========================================================================
539
+ # control flow
540
+ # ===========================================================================
541
+
542
+ def control_if(self, id, init_params):
543
+ raise NotImplementedError
450
544
 
451
- return _jax_wrapped_calc_leq_approx, jax_param
545
def control_switch(self, id, init_params):
    '''Placeholder for the switch operation.

    Always raises NotImplementedError; the ExactLogic and FuzzyLogic classes
    later in this file supply working versions.
    '''
    raise NotImplementedError
547
+
548
+ # ===========================================================================
549
+ # random variables
550
+ # ===========================================================================
551
+
552
def discrete(self, id, init_params):
    '''Placeholder for discrete-distribution sampling.

    Always raises NotImplementedError; the ExactLogic and FuzzyLogic classes
    later in this file supply working versions.
    '''
    raise NotImplementedError
554
+
555
def bernoulli(self, id, init_params):
    '''Placeholder for Bernoulli sampling.

    Always raises NotImplementedError; the ExactLogic and FuzzyLogic classes
    later in this file supply working versions.
    '''
    raise NotImplementedError
452
557
 
453
- def less(self):
454
- jax_gre, jax_param = self.greater()
558
def poisson(self, id, init_params):
    '''Placeholder for Poisson sampling.

    Always raises NotImplementedError; the ExactLogic and FuzzyLogic classes
    later in this file supply working versions.
    '''
    raise NotImplementedError
560
+
561
def geometric(self, id, init_params):
    '''Placeholder for geometric sampling.

    Always raises NotImplementedError; the ExactLogic and FuzzyLogic classes
    later in this file supply working versions.
    '''
    raise NotImplementedError
455
563
 
456
- def _jax_wrapped_calc_less_approx(a, b, param):
457
- return jax_gre(-a, -b, param)
458
-
459
- return _jax_wrapped_calc_less_approx, jax_param
460
564
 
461
- def equal(self):
462
- if self.verbose:
463
- raise_warning('Using the replacement rule: '
464
- 'a == b --> comparison.equal(a, b)')
465
-
466
- equal_op = self.comparison.equal
467
- debias = 'equal' in self.debias
468
-
469
- def _jax_wrapped_calc_equal_approx(a, b, param):
470
- sample = equal_op(a, b, param)
471
- if debias:
472
- hard_sample = jnp.equal(a, b)
473
- sample += jax.lax.stop_gradient(hard_sample - sample)
474
- return sample
475
-
476
- tags = ('weight', 'equal')
477
- new_param = (tags, self.weight)
478
- return _jax_wrapped_calc_equal_approx, new_param
565
+ class ExactLogic(Logic):
566
+ '''A class representing exact logic in JAX.'''
479
567
 
480
- def not_equal(self):
481
- _not = self.complement
482
- jax_eq, jax_param = self.equal()
483
-
484
- def _jax_wrapped_calc_neq_approx(a, b, param):
485
- return _not(jax_eq(a, b, param))
568
+ @staticmethod
569
+ def exact_unary_function(op):
570
+ def _jax_wrapped_calc_unary_function_exact(x, params):
571
+ return op(x), params
572
+ return _jax_wrapped_calc_unary_function_exact
573
+
574
+ @staticmethod
575
+ def exact_binary_function(op):
576
+ def _jax_wrapped_calc_binary_function_exact(x, y, params):
577
+ return op(x, y), params
578
+ return _jax_wrapped_calc_binary_function_exact
579
+
580
+ @staticmethod
581
+ def exact_aggregation(op):
582
+ def _jax_wrapped_calc_aggregation_exact(x, axis, params):
583
+ return op(x, axis=axis), params
584
+ return _jax_wrapped_calc_aggregation_exact
486
585
 
487
- return _jax_wrapped_calc_neq_approx, jax_param
488
-
586
+ # ===========================================================================
587
+ # logical operators
588
+ # ===========================================================================
589
+
590
def logical_and(self, id, init_params):
    '''Return exact logical AND, wrapping jnp.logical_and.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.logical_and)
    return exact_op
592
+
593
def logical_not(self, id, init_params):
    '''Return exact logical NOT, wrapping jnp.logical_not.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_unary_function(jnp.logical_not)
    return exact_op
595
+
596
def logical_or(self, id, init_params):
    '''Return exact logical OR, wrapping jnp.logical_or.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.logical_or)
    return exact_op
598
+
599
def xor(self, id, init_params):
    '''Return exact exclusive OR, wrapping jnp.logical_xor.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.logical_xor)
    return exact_op
601
+
602
+ @staticmethod
603
+ def exact_binary_implies(x, y, params):
604
+ return jnp.logical_or(jnp.logical_not(x), y), params
605
+
606
def implies(self, id, init_params):
    '''Return the exact implication operation (self.exact_binary_implies).

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_implies
    return exact_op
608
+
609
def equiv(self, id, init_params):
    '''Return exact logical equivalence, wrapping jnp.equal.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.equal)
    return exact_op
611
+
612
def forall(self, id, init_params):
    '''Return the exact universal quantifier, a jnp.all reduction over an axis.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_aggregation(jnp.all)
    return exact_op
614
+
615
def exists(self, id, init_params):
    '''Return the exact existential quantifier, a jnp.any reduction over an axis.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_aggregation(jnp.any)
    return exact_op
617
+
618
+ # ===========================================================================
619
+ # comparison operators
620
+ # ===========================================================================
621
+
622
def greater_equal(self, id, init_params):
    '''Return the exact >= comparison, wrapping jnp.greater_equal.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.greater_equal)
    return exact_op
624
+
625
def greater(self, id, init_params):
    '''Return the exact > comparison, wrapping jnp.greater.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.greater)
    return exact_op
627
+
628
def less_equal(self, id, init_params):
    '''Return the exact <= comparison, wrapping jnp.less_equal.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.less_equal)
    return exact_op
630
+
631
def less(self, id, init_params):
    '''Return the exact < comparison, wrapping jnp.less.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.less)
    return exact_op
633
+
634
def equal(self, id, init_params):
    '''Return the exact == comparison, wrapping jnp.equal.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.equal)
    return exact_op
636
+
637
def not_equal(self, id, init_params):
    '''Return the exact != comparison, wrapping jnp.not_equal.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.not_equal)
    return exact_op
639
+
489
640
  # ===========================================================================
490
641
  # special functions
491
642
  # ===========================================================================
492
643
 
493
- def sgn(self):
494
- if self.verbose:
495
- raise_warning('Using the replacement rule: sgn(x) --> tanh(x)')
496
-
497
- debias = 'sgn' in self.debias
498
-
499
- def _jax_wrapped_calc_sgn_approx(x, param):
500
- sample = jnp.tanh(param * x)
501
- if debias:
502
- hard_sample = jnp.sign(x)
503
- sample += jax.lax.stop_gradient(hard_sample - sample)
504
- return sample
505
-
506
- tags = ('weight', 'sgn')
507
- new_param = (tags, self.weight)
508
- return _jax_wrapped_calc_sgn_approx, new_param
509
-
510
- def floor(self):
511
- if self.verbose:
512
- raise_warning('Using the replacement rule: '
513
- 'floor(x) --> x - atan(-1.0 / tan(pi * x)) / pi - 0.5')
514
-
515
- def _jax_wrapped_calc_floor_approx(x, param):
516
- sawtooth_part = jnp.arctan(-1.0 / jnp.tan(x * jnp.pi)) / jnp.pi + 0.5
517
- sample = x - jax.lax.stop_gradient(sawtooth_part)
518
- return sample
519
-
520
- return _jax_wrapped_calc_floor_approx, None
521
-
522
- def ceil(self):
523
- jax_floor, jax_param = self.floor()
524
-
525
- def _jax_wrapped_calc_ceil_approx(x, param):
526
- return -jax_floor(-x, param)
527
-
528
- return _jax_wrapped_calc_ceil_approx, jax_param
644
def sgn(self, id, init_params):
    '''Return the exact sign function, wrapping jnp.sign.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_unary_function(jnp.sign)
    return exact_op
529
646
 
530
- def round(self):
531
- if self.verbose:
532
- raise_warning('Using the replacement rule: round(x) --> x')
533
-
534
- debias = 'round' in self.debias
535
-
536
- def _jax_wrapped_calc_round_approx(x, param):
537
- sample = x
538
- if debias:
539
- hard_sample = jnp.round(x)
540
- sample += jax.lax.stop_gradient(hard_sample - sample)
541
- return sample
542
-
543
- return _jax_wrapped_calc_round_approx, None
647
def floor(self, id, init_params):
    '''Return the exact floor function, wrapping jnp.floor.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_unary_function(jnp.floor)
    return exact_op
544
649
 
545
- def mod(self):
546
- jax_floor, jax_param = self.floor()
547
-
548
- def _jax_wrapped_calc_mod_approx(x, y, param):
549
- return x - y * jax_floor(x / y, param)
550
-
551
- return _jax_wrapped_calc_mod_approx, jax_param
650
def round(self, id, init_params):
    '''Return the exact rounding function, wrapping jnp.round.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_unary_function(jnp.round)
    return exact_op
552
652
 
553
- def div(self):
554
- jax_floor, jax_param = self.floor()
555
-
556
- def _jax_wrapped_calc_div_approx(x, y, param):
557
- return jax_floor(x / y, param)
558
-
559
- return _jax_wrapped_calc_div_approx, jax_param
653
def ceil(self, id, init_params):
    '''Return the exact ceiling function, wrapping jnp.ceil.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_unary_function(jnp.ceil)
    return exact_op
560
655
 
561
- def sqrt(self):
562
- if self.verbose:
563
- raise_warning('Using the replacement rule: sqrt(x) --> sqrt(x + eps)')
564
-
565
- def _jax_wrapped_calc_sqrt_approx(x, param):
566
- return jnp.sqrt(x + self.eps)
567
-
568
- return _jax_wrapped_calc_sqrt_approx, None
656
def div(self, id, init_params):
    '''Return exact integer division, wrapping jnp.floor_divide.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.floor_divide)
    return exact_op
658
+
659
def mod(self, id, init_params):
    '''Return the exact modulo operation, wrapping jnp.mod.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_binary_function(jnp.mod)
    return exact_op
661
+
662
def sqrt(self, id, init_params):
    '''Return the exact square root, wrapping jnp.sqrt.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_unary_function(jnp.sqrt)
    return exact_op
569
664
 
570
665
  # ===========================================================================
571
666
  # indexing
572
667
  # ===========================================================================
573
668
 
669
def argmax(self, id, init_params):
    '''Return the exact argmax, a jnp.argmax reduction over an axis.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_aggregation(jnp.argmax)
    return exact_op
671
+
672
def argmin(self, id, init_params):
    '''Return the exact argmin, a jnp.argmin reduction over an axis.

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_aggregation(jnp.argmin)
    return exact_op
674
+
675
+ # ===========================================================================
676
+ # control flow
677
+ # ===========================================================================
678
+
574
679
  @staticmethod
575
- def enumerate_literals(shape, axis):
576
- literals = jnp.arange(shape[axis])
577
- literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
578
- literals = jnp.moveaxis(literals, source=0, destination=axis)
579
- literals = jnp.broadcast_to(literals, shape=shape)
580
- return literals
581
-
582
- def argmax(self):
583
- if self.verbose:
584
- raise_warning('Using the replacement rule: '
585
- 'argmax(x) --> sum(i * softmax(x[i]))')
586
-
587
- debias = 'argmax' in self.debias
680
def exact_if_then_else(c, a, b, params):
    '''Select a where the condition exceeds 0.5 and b elsewhere.

    :return: tuple of the element-wise selection and the unchanged params
    '''
    condition_holds = c > 0.5
    chosen = jnp.where(condition_holds, a, b)
    return chosen, params
682
+
683
def control_if(self, id, init_params):
    '''Return the exact if-then-else operation (self.exact_if_then_else).

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_if_then_else
    return exact_op
685
+
686
+ @staticmethod
687
+ def exact_switch(pred, cases, params):
688
+ pred = pred[jnp.newaxis, ...]
689
+ sample = jnp.take_along_axis(cases, pred, axis=0)
690
+ assert sample.shape[0] == 1
691
+ return sample[0, ...], params
692
+
693
def control_switch(self, id, init_params):
    '''Return the exact switch operation (self.exact_switch).

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_switch
    return exact_op
695
+
696
+ # ===========================================================================
697
+ # random variables
698
+ # ===========================================================================
699
+
700
+ @staticmethod
701
+ def exact_discrete(key, prob, params):
702
+ return random.categorical(key=key, logits=jnp.log(prob), axis=-1), params
703
+
704
def discrete(self, id, init_params):
    '''Return the exact discrete sampler (self.exact_discrete).

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_discrete
    return exact_op
706
+
707
+ @staticmethod
708
+ def exact_bernoulli(key, prob, params):
709
+ return random.bernoulli(key, prob), params
710
+
711
def bernoulli(self, id, init_params):
    '''Return the exact Bernoulli sampler (self.exact_bernoulli).

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_bernoulli
    return exact_op
713
+
714
+ @staticmethod
715
+ def exact_poisson(key, rate, params):
716
+ return random.poisson(key=key, lam=rate), params
717
+
718
def poisson(self, id, init_params):
    '''Return the exact Poisson sampler (self.exact_poisson).

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_poisson
    return exact_op
720
+
721
+ @staticmethod
722
+ def exact_geometric(key, prob, params):
723
+ return random.geometric(key=key, p=prob), params
724
+
725
def geometric(self, id, init_params):
    '''Return the exact geometric sampler (self.exact_geometric).

    id and init_params are accepted but not read.
    '''
    exact_op = self.exact_geometric
    return exact_op
727
+
728
+
729
+ class FuzzyLogic(Logic):
730
+ '''A class representing fuzzy logic in JAX.'''
731
+
732
def __init__(self, tnorm: TNorm=ProductTNorm(),
             complement: Complement=StandardComplement(),
             comparison: Comparison=SigmoidComparison(),
             sampling: RandomSampling=GumbelSoftmax(),
             rounding: Rounding=SoftRounding(),
             control: ControlFlow=SoftControlFlow(),
             eps: float=1e-15,
             use64bit: bool=False) -> None:
    '''Sets up a fuzzy logic from its component relaxations.

    :param tnorm: fuzzy operator used for logical AND
    :param complement: fuzzy operator used for logical NOT
    :param comparison: fuzzy operator used for comparisons (>, >=, <, ==, ~=, ...)
    :param sampling: sampler for non-reparameterizable distributions
    :param rounding: operator that rounds floating values to integers
    :param control: operator for if and switch control structures
    :param eps: small positive float to mitigate underflow
    :param use64bit: whether to perform arithmetic in 64 bit
    '''
    super().__init__(use64bit=use64bit)

    # NOTE(review): the default operator objects in the signature are created
    # once, when the function is defined, so every FuzzyLogic built without
    # explicit arguments holds the same instances -- confirm none is mutated
    self.eps = eps
    self.tnorm, self.complement = tnorm, complement
    self.comparison, self.sampling = comparison, sampling
    self.rounding, self.control = rounding, control
759
+
760
def __str__(self) -> str:
    '''Returns a multi-line summary of the configured relaxation components.'''
    # NOTE(review): the spacing inside these labels is reproduced exactly as it
    # appears in the diff view, which may have collapsed alignment whitespace
    # -- confirm against the packaged source
    rows = (
        'model relaxation:',
        f' tnorm ={self.tnorm!s}',
        f' complement ={self.complement!s}',
        f' comparison ={self.comparison!s}',
        f' sampling ={self.sampling!s}',
        f' rounding ={self.rounding!s}',
        f' control ={self.control!s}',
        f' underflow_tol ={self.eps}',
        f' use_64_bit ={self.use64bit}',
    )
    return '\n'.join(rows)
770
+
771
def summarize_hyperparameters(self) -> None:
    '''Prints the text produced by this object's __str__ method.'''
    summary = self.__str__()
    print(summary)
588
773
 
589
- def _jax_wrapped_calc_argmax_approx(x, axis, param):
590
- literals = FuzzyLogic.enumerate_literals(x.shape, axis=axis)
591
- soft_max = jax.nn.softmax(param * x, axis=axis)
592
- sample = jnp.sum(literals * soft_max, axis=axis)
593
- if debias:
594
- hard_sample = jnp.argmax(x, axis=axis)
595
- sample += jax.lax.stop_gradient(hard_sample - sample)
596
- return sample
774
+ # ===========================================================================
775
+ # logical operators
776
+ # ===========================================================================
777
+
778
def logical_and(self, id, init_params):
    '''Return the relaxed AND, built by the configured t-norm's norm().'''
    relaxed_and = self.tnorm.norm(id, init_params)
    return relaxed_and
780
+
781
def logical_not(self, id, init_params):
    '''Return the relaxed NOT, built by calling the configured complement.'''
    relaxed_not = self.complement(id, init_params)
    return relaxed_not
783
+
784
def logical_or(self, id, init_params):
    '''Return the relaxed OR, composed as NOT(NOT x AND NOT y).

    Four sub-operations are built, with ids derived from id by the suffixes
    _~1, _~2, _^ and _~. The returned function maps (x, y, params) to a
    (value, params) pair, threading params through every sub-operation.
    '''
    negate_lhs = self.complement(f'{id}_~1', init_params)
    negate_rhs = self.complement(f'{id}_~2', init_params)
    conjoin = self.tnorm.norm(f'{id}_^', init_params)
    negate_result = self.complement(f'{id}_~', init_params)

    def _jax_wrapped_calc_or_approx(x, y, params):
        lhs_c, params = negate_lhs(x, params)
        rhs_c, params = negate_rhs(y, params)
        neither, params = conjoin(lhs_c, rhs_c, params)
        return negate_result(neither, params)

    return _jax_wrapped_calc_or_approx
796
+
797
def xor(self, id, init_params):
    '''Return the relaxed XOR, composed as (x OR y) AND NOT(x AND y).

    Sub-operation ids are derived from id by the suffixes _~, _^1, _^2 and
    _|. The returned function maps (x, y, params) to a (value, params) pair.
    '''
    negate = self.complement(f'{id}_~', init_params)
    conjoin_inputs = self.tnorm.norm(f'{id}_^1', init_params)
    conjoin_final = self.tnorm.norm(f'{id}_^2', init_params)
    disjoin = self.logical_or(f'{id}_|', init_params)

    def _jax_wrapped_calc_xor_approx(x, y, params):
        both, params = conjoin_inputs(x, y, params)
        not_both, params = negate(both, params)
        either, params = disjoin(x, y, params)
        return conjoin_final(either, not_both, params)

    return _jax_wrapped_calc_xor_approx
809
+
810
def implies(self, id, init_params):
    '''Return the relaxed implication, composed as (NOT x) OR y.

    Sub-operation ids are derived from id by the suffixes _~ and _|. The
    returned function maps (x, y, params) to a (value, params) pair.
    '''
    negate = self.complement(f'{id}_~', init_params)
    disjoin = self.logical_or(f'{id}_|', init_params)

    def _jax_wrapped_calc_implies_approx(x, y, params):
        premise_fails, params = negate(x, params)
        return disjoin(premise_fails, y, params)

    return _jax_wrapped_calc_implies_approx
818
+
819
def equiv(self, id, init_params):
    '''Return the relaxed equivalence, composed as (x => y) AND (y => x).

    Sub-operation ids are derived from id by the suffixes _=>1, _=>2 and _^.
    The returned function maps (x, y, params) to a (value, params) pair.
    '''
    forward = self.implies(f'{id}_=>1', init_params)
    backward = self.implies(f'{id}_=>2', init_params)
    conjoin = self.tnorm.norm(f'{id}_^', init_params)

    def _jax_wrapped_calc_equiv_approx(x, y, params):
        x_to_y, params = forward(x, y, params)
        y_to_x, params = backward(y, x, params)
        return conjoin(x_to_y, y_to_x, params)

    return _jax_wrapped_calc_equiv_approx
829
+
830
def forall(self, id, init_params):
    '''Return the relaxed universal quantifier, built by the t-norm's norms().'''
    relaxed_forall = self.tnorm.norms(id, init_params)
    return relaxed_forall
832
+
833
def exists(self, id, init_params):
    '''Return the relaxed existential quantifier, composed as NOT(FORALL(NOT x)).

    Sub-operation ids are derived from id by the suffixes _~1, _~2 and
    _forall. The returned function maps (x, axis, params) to a
    (value, params) pair.
    '''
    negate_input = self.complement(f'{id}_~1', init_params)
    negate_output = self.complement(f'{id}_~2', init_params)
    quantify = self.forall(f'{id}_forall', init_params)

    def _jax_wrapped_calc_exists_approx(x, axis, params):
        negated, params = negate_input(x, params)
        none_hold, params = quantify(negated, axis, params)
        return negate_output(none_hold, params)

    return _jax_wrapped_calc_exists_approx
843
+
844
+ # ===========================================================================
845
+ # comparison operators
846
+ # ===========================================================================
847
+
848
def greater_equal(self, id, init_params):
    '''Return the relaxed >= comparison from the configured comparison object.'''
    relaxed_op = self.comparison.greater_equal(id, init_params)
    return relaxed_op
850
+
851
def greater(self, id, init_params):
    '''Return the relaxed > comparison from the configured comparison object.'''
    relaxed_op = self.comparison.greater(id, init_params)
    return relaxed_op
853
+
854
def less_equal(self, id, init_params):
    '''Return the relaxed <= comparison, expressed through >=.

    Uses the identity x <= y if and only if -x >= -y, so the operation built
    by self.greater_equal is reused with both arguments negated.
    '''
    at_least = self.greater_equal(id, init_params)

    def _jax_wrapped_calc_leq_approx(x, y, params):
        neg_x, neg_y = -x, -y
        return at_least(neg_x, neg_y, params)

    return _jax_wrapped_calc_leq_approx
859
+
860
def less(self, id, init_params):
    '''Return the relaxed < comparison, expressed through >.

    Uses the identity x < y if and only if -x > -y, so the operation built by
    self.greater is reused with both arguments negated.
    '''
    strictly_above = self.greater(id, init_params)

    def _jax_wrapped_calc_less_approx(x, y, params):
        neg_x, neg_y = -x, -y
        return strictly_above(neg_x, neg_y, params)

    return _jax_wrapped_calc_less_approx
865
+
866
def equal(self, id, init_params):
    '''Return the relaxed == comparison from the configured comparison object.'''
    relaxed_op = self.comparison.equal(id, init_params)
    return relaxed_op
868
+
869
def not_equal(self, id, init_params):
    '''Return the relaxed != comparison, composed as NOT(x == y).

    Sub-operation ids are derived from id by the suffixes _~ and _==. The
    returned function maps (x, y, params) to a (value, params) pair.
    '''
    negate = self.complement(f'{id}_~', init_params)
    same = self.comparison.equal(f'{id}_==', init_params)

    def _jax_wrapped_calc_neq_approx(x, y, params):
        is_same, params = same(x, y, params)
        return negate(is_same, params)

    return _jax_wrapped_calc_neq_approx
597
876
 
598
- tags = ('weight', 'argmax')
599
- new_param = (tags, self.weight)
600
- return _jax_wrapped_calc_argmax_approx, new_param
877
+ # ===========================================================================
878
+ # special functions
879
+ # ===========================================================================
880
+
881
def sgn(self, id, init_params):
    '''Return the relaxed sign function from the configured comparison object.'''
    relaxed_op = self.comparison.sgn(id, init_params)
    return relaxed_op
601
883
 
602
- def argmin(self):
603
- jax_argmax, jax_param = self.argmax()
884
def floor(self, id, init_params):
    '''Return the relaxed floor function from the configured rounding object.'''
    relaxed_op = self.rounding.floor(id, init_params)
    return relaxed_op
604
886
 
887
def round(self, id, init_params):
    '''Return the relaxed rounding function from the configured rounding object.'''
    relaxed_op = self.rounding.round(id, init_params)
    return relaxed_op
889
+
890
def ceil(self, id, init_params):
    '''Return the relaxed ceiling, expressed through the relaxed floor.

    Uses the identity ceil(x) == -floor(-x). The returned function maps
    (x, params) to a (value, params) pair.
    '''
    relaxed_floor = self.rounding.floor(id, init_params)

    def _jax_wrapped_calc_ceil_approx(x, params):
        floor_of_negated, params = relaxed_floor(-x, params)
        return -floor_of_negated, params

    return _jax_wrapped_calc_ceil_approx
896
+
897
def div(self, id, init_params):
    '''Return relaxed integer division: the relaxed floor applied to x / y.

    The returned function maps (x, y, params) to a (value, params) pair.
    '''
    relaxed_floor = self.rounding.floor(id, init_params)

    def _jax_wrapped_calc_div_approx(x, y, params):
        quotient = x / y
        return relaxed_floor(quotient, params)

    return _jax_wrapped_calc_div_approx
902
+
903
def mod(self, id, init_params):
    '''Return the relaxed modulo, computed as x - y * div(x, y).

    The quotient comes from the operation built by self.div. The returned
    function maps (x, y, params) to a (value, params) pair.
    '''
    relaxed_div = self.div(id, init_params)

    def _jax_wrapped_calc_mod_approx(x, y, params):
        quotient, params = relaxed_div(x, y, params)
        remainder = x - y * quotient
        return remainder, params

    return _jax_wrapped_calc_mod_approx
909
+
910
def sqrt(self, id, init_params):
    '''Return a square root that adds self.eps to its argument first.

    id and init_params are not read. The returned function maps (x, params)
    to (jnp.sqrt(x + self.eps), params).
    '''
    def _jax_wrapped_calc_sqrt_approx(x, params):
        shifted = x + self.eps
        return jnp.sqrt(shifted), params

    return _jax_wrapped_calc_sqrt_approx
914
+
915
+ # ===========================================================================
916
+ # indexing
917
+ # ===========================================================================
918
+
919
def argmax(self, id, init_params):
    '''Return the relaxed argmax from the configured comparison object.'''
    relaxed_op = self.comparison.argmax(id, init_params)
    return relaxed_op
921
+
922
def argmin(self, id, init_params):
    '''Return the relaxed argmin, expressed through the relaxed argmax.

    Negating the input turns its smallest entry into the largest, so the
    operation built by self.argmax is reused on -x.

    :param id: identifier forwarded to self.argmax
    :param init_params: parameter dictionary forwarded to self.argmax
    :return: function (x, axis, params) returning whatever the argmax
    operation returns for (-x, axis, params)
    '''
    _argmax = self.argmax(id, init_params)

    # the third argument is named params, like every other operator closure
    # in this class, so callers may pass it by keyword uniformly
    def _jax_wrapped_calc_argmin_approx(x, axis, params):
        return _argmax(-x, axis, params)

    return _jax_wrapped_calc_argmin_approx
609
927
 
610
928
  # ===========================================================================
611
929
  # control flow
612
930
  # ===========================================================================
613
931
 
614
- def control_if(self):
615
- if self.verbose:
616
- raise_warning('Using the replacement rule: '
617
- 'if c then a else b --> c * a + (1 - c) * b')
618
-
619
- debias = 'if' in self.debias
620
-
621
- def _jax_wrapped_calc_if_approx(c, a, b, param):
622
- sample = c * a + (1.0 - c) * b
623
- if debias:
624
- hard_sample = jnp.where(c > 0.5, a, b)
625
- sample += jax.lax.stop_gradient(hard_sample - sample)
626
- return sample
627
-
628
- return _jax_wrapped_calc_if_approx, None
629
-
630
- def control_switch(self):
631
- if self.verbose:
632
- raise_warning('Using the replacement rule: '
633
- 'switch(pred) { cases } --> '
634
- 'sum(cases[i] * softmax(-abs(pred - i)))')
635
-
636
- debias = 'switch' in self.debias
637
-
638
- def _jax_wrapped_calc_switch_approx(pred, cases, param):
639
- literals = FuzzyLogic.enumerate_literals(cases.shape, axis=0)
640
- pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
641
- proximity = -jnp.abs(pred - literals)
642
- soft_case = jax.nn.softmax(param * proximity, axis=0)
643
- sample = jnp.sum(cases * soft_case, axis=0)
644
- if debias:
645
- hard_case = jnp.argmax(proximity, axis=0)[jnp.newaxis, ...]
646
- hard_sample = jnp.take_along_axis(cases, hard_case, axis=0)[0, ...]
647
- sample += jax.lax.stop_gradient(hard_sample - sample)
648
- return sample
649
-
650
- tags = ('weight', 'switch')
651
- new_param = (tags, self.weight)
652
- return _jax_wrapped_calc_switch_approx, new_param
932
def control_if(self, id, init_params):
    '''Return the relaxed if-then-else from the configured control object.'''
    relaxed_op = self.control.if_then_else(id, init_params)
    return relaxed_op
934
+
935
def control_switch(self, id, init_params):
    '''Return the relaxed switch from the configured control object.'''
    relaxed_op = self.control.switch(id, init_params)
    return relaxed_op
653
937
 
654
938
  # ===========================================================================
655
939
  # random variables
656
940
  # ===========================================================================
657
941
 
658
- def discrete(self):
659
- return self.sampling.discrete(self)
942
def discrete(self, id, init_params):
    '''Return the relaxed discrete sampler from the configured sampling object.

    This logic instance is passed along as the third argument.
    '''
    relaxed_sampler = self.sampling.discrete(id, init_params, self)
    return relaxed_sampler
660
944
 
661
- def bernoulli(self):
662
- return self.sampling.bernoulli(self)
945
def bernoulli(self, id, init_params):
    '''Return the relaxed Bernoulli sampler from the configured sampling object.

    This logic instance is passed along as the third argument.
    '''
    relaxed_sampler = self.sampling.bernoulli(id, init_params, self)
    return relaxed_sampler
663
947
 
664
- def poisson(self):
665
- return self.sampling.poisson(self)
948
def poisson(self, id, init_params):
    '''Return the relaxed Poisson sampler from the configured sampling object.

    This logic instance is passed along as the third argument.
    '''
    relaxed_sampler = self.sampling.poisson(id, init_params, self)
    return relaxed_sampler
666
950
 
667
- def geometric(self):
668
- return self.sampling.geometric(self)
951
def geometric(self, id, init_params):
    '''Return the relaxed geometric sampler from the configured sampling object.

    This logic instance is passed along as the third argument.
    '''
    relaxed_sampler = self.sampling.geometric(id, init_params, self)
    return relaxed_sampler
669
953
 
670
954
 
671
955
  # ===========================================================================
@@ -673,103 +957,121 @@ class FuzzyLogic:
673
957
  #
674
958
  # ===========================================================================
675
959
 
676
- logic = FuzzyLogic()
677
- w = 100.0
960
# shared FuzzyLogic instance exercised by the _test_* functions below; the
# comparison, rounding and control components are each built with 10000.0
logic = FuzzyLogic(
    comparison=SigmoidComparison(10000.0),
    rounding=SoftRounding(10000.0),
    control=SoftControlFlow(10000.0),
)
678
963
 
679
964
 
680
965
def _test_logical():
    '''Smoke test: prints the relaxed value of an XNOR-style rule on sample inputs.

    The rule yields +1 when x1 and x2 are both positive or both non-positive,
    and -1 otherwise.
    '''
    print('testing logical')
    init_params = {}
    soft_and = logic.logical_and(0, init_params)
    soft_not = logic.logical_not(1, init_params)
    soft_gt = logic.greater(2, init_params)
    soft_or = logic.logical_or(3, init_params)
    soft_if = logic.control_if(4, init_params)
    print(init_params)

    # https://towardsdatascience.com/emulating-logical-gates-with-a-neural-network-75c229ec4cc9
    def test_logic(x1, x2, w):
        x1_pos, w = soft_gt(x1, 0, w)
        x2_pos, w = soft_gt(x2, 0, w)
        both_pos, w = soft_and(x1_pos, x2_pos, w)
        x1_nonpos, w = soft_not(x1_pos, w)
        x2_nonpos, w = soft_not(x2_pos, w)
        both_nonpos, w = soft_and(x1_nonpos, x2_nonpos, w)
        cond, w = soft_or(both_pos, both_nonpos, w)
        pred, w = soft_if(cond, +1, -1, w)
        return pred

    x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5]).astype(float)
    x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6]).astype(float)
    result = test_logic(x1, x2, init_params)
    print(result)
699
990
 
700
991
 
701
992
def _test_indexing():
    '''Smoke test: prints the relaxed argmax and argmin of a fixed vector.'''
    print('testing indexing')
    init_params = {}
    soft_argmax = logic.argmax(0, init_params)
    soft_argmin = logic.argmin(1, init_params)
    print(init_params)

    values = jnp.asarray([2., 3., 5., 4.9, 4., 1., -1., -2.])

    # the params returned by argmax are the ones handed to argmin
    amax, params = soft_argmax(values, 0, init_params)
    amin, params = soft_argmin(values, 0, params)
    print(amax)
    print(amin)
715
1008
 
716
1009
 
717
1010
def _test_control():
    '''Smoke test: prints the relaxed switch over three constant cases.

    The predicate sweeps ten evenly spaced values from 0 to 2.
    '''
    print('testing control')
    init_params = {}
    soft_switch = logic.control_switch(0, init_params)
    print(init_params)

    pred = jnp.asarray(jnp.linspace(0, 2, 10))
    constant_cases = [jnp.asarray([level] * 10) for level in (-10., 1.5, 10.)]
    cases = jnp.asarray(constant_cases)
    selected, _ = soft_switch(pred, cases, init_params)
    print(selected)
727
1023
 
728
1024
 
729
1025
def _test_random():
    '''Smoke test: prints summary statistics of the relaxed samplers.

    Draws 50000 Bernoulli, discrete and geometric samples, all with the same
    random key, and prints the mean (Bernoulli, geometric) or the frequency
    of each category (discrete).
    '''
    print('testing random')
    key = random.PRNGKey(42)
    init_params = {}
    soft_bernoulli = logic.bernoulli(0, init_params)
    soft_discrete = logic.discrete(1, init_params)
    soft_geometric = logic.geometric(2, init_params)
    print(init_params)

    def bern(n, w):
        draws, _ = soft_bernoulli(key, jnp.asarray([0.3] * n), w)
        return draws

    samples = bern(50000, init_params)
    print(jnp.mean(samples))

    def disc(n, w):
        row = jnp.asarray([0.1, 0.4, 0.5])
        draws, _ = soft_discrete(key, jnp.tile(row, (n, 1)), w)
        return draws

    samples = disc(50000, init_params)
    samples = jnp.round(samples)
    print([jnp.mean(samples == i) for i in range(3)])

    def geom(n, w):
        draws, _ = soft_geometric(key, jnp.asarray([0.3] * n), w)
        return draws

    samples = geom(50000, init_params)
    print(jnp.mean(samples))
761
1059
 
762
1060
 
763
1061
def _test_rounding():
    '''Smoke test: prints relaxed floor, ceil, round and mod-2 of a fixed vector.'''
    print('testing rounding')
    init_params = {}
    soft_floor = logic.floor(0, init_params)
    soft_ceil = logic.ceil(1, init_params)
    soft_round = logic.round(2, init_params)
    soft_mod = logic.mod(3, init_params)
    print(init_params)

    x = jnp.asarray([2.1, 0.6, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
    # each operation returns (value, params); only the value is printed
    for unary_op in (soft_floor, soft_ceil, soft_round):
        print(unary_op(x, init_params)[0])
    print(soft_mod(x, 2.0, init_params)[0])
773
1075
 
774
1076
 
775
1077
  if __name__ == '__main__':