pyRDDLGym_jax-0.5-py3-none-any.whl → pyRDDLGym_jax-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +784 -544
- pyRDDLGym_jax/core/planner.py +329 -463
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +379 -568
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +4 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -27
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +161 -104
- pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.5.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.5.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/logic.py
CHANGED
@@ -4,7 +4,13 @@ import jax
 import jax.numpy as jnp
 import jax.random as random
 
-
+
+def enumerate_literals(shape, axis, dtype=jnp.int32):
+    literals = jnp.arange(shape[axis], dtype=dtype)
+    literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
+    literals = jnp.moveaxis(literals, source=0, destination=axis)
+    literals = jnp.broadcast_to(literals, shape=shape)
+    return literals
 
 
 # ===========================================================================
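For orientation, a small worked example of what the new module-level enumerate_literals helper returns (the shape used here is illustrative, not taken from the package):

    from pyRDDLGym_jax.core.logic import enumerate_literals   # added in 1.0, per the hunk above

    # index grid broadcast along axis 1 of a (2, 3) shape:
    # [[0, 1, 2],
    #  [0, 1, 2]]
    print(enumerate_literals((2, 3), axis=1))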
@@ -17,36 +23,71 @@ from pyRDDLGym.core.debug.exception import raise_warning
 class Comparison:
     '''Base class for approximate comparison operations.'''
 
-    def greater_equal(self,
+    def greater_equal(self, id, init_params):
+        raise NotImplementedError
+
+    def greater(self, id, init_params):
         raise NotImplementedError
 
-    def
+    def equal(self, id, init_params):
         raise NotImplementedError
 
-    def
+    def sgn(self, id, init_params):
         raise NotImplementedError
 
-    def
+    def argmax(self, id, init_params):
         raise NotImplementedError
 
 
 class SigmoidComparison(Comparison):
     '''Comparison operations approximated using sigmoid functions.'''
 
+    def __init__(self, weight: float=10.0):
+        self.weight = weight
+
     # https://arxiv.org/abs/2110.05651
-    def greater_equal(self,
-
-
-
-
-
-
-
+    def greater_equal(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_greater_equal_approx(x, y, params):
+            gre_eq = jax.nn.sigmoid(params[id_] * (x - y))
+            return gre_eq, params
+        return _jax_wrapped_calc_greater_equal_approx
+
+    def greater(self, id, init_params):
+        return self.greater_equal(id, init_params)
+
+    def equal(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_equal_approx(x, y, params):
+            equal = 1.0 - jnp.square(jnp.tanh(params[id_] * (y - x)))
+            return equal, params
+        return _jax_wrapped_calc_equal_approx
+
+    def sgn(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_sgn_approx(x, params):
+            sgn = jnp.tanh(params[id_] * x)
+            return sgn, params
+        return _jax_wrapped_calc_sgn_approx
 
-
-
-
-
+    # https://arxiv.org/abs/2110.05651
+    def argmax(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_argmax_approx(x, axis, params):
+            literals = enumerate_literals(x.shape, axis=axis)
+            softmax = jax.nn.softmax(params[id_] * x, axis=axis)
+            sample = jnp.sum(literals * softmax, axis=axis)
+            return sample, params
+        return _jax_wrapped_calc_argmax_approx
+
+    def __str__(self) -> str:
+        return f'Sigmoid comparison with weight {self.weight}'
+
+
 # ===========================================================================
 # ROUNDING OPERATIONS
 # - abstract class
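Putting the new Comparison interface together: each factory takes an identifier and a shared init_params dictionary, stores its sharpness weight under str(id), and returns a pure JAX callable that threads params through and returns a (value, params) pair. A minimal standalone sketch, assuming the class above is imported from pyRDDLGym_jax.core.logic:

    import jax.numpy as jnp
    from pyRDDLGym_jax.core.logic import SigmoidComparison

    comparison = SigmoidComparison(weight=10.0)
    params = {}
    geq = comparison.greater_equal('c0', params)     # registers params['c0'] = 10.0
    value, params = geq(jnp.array([0.5, 2.0]), jnp.array([1.0, 1.0]), params)
    # value ~= sigmoid(10 * (x - y)); larger weights give harder 0/1 outputs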
@@ -57,26 +98,44 @@ class SigmoidComparison(Comparison):
 class Rounding:
     '''Base class for approximate rounding operations.'''
 
-    def floor(self,
+    def floor(self, id, init_params):
         raise NotImplementedError
 
-    def round(self,
+    def round(self, id, init_params):
         raise NotImplementedError
 
 
 class SoftRounding(Rounding):
     '''Rounding operations approximated using soft operations.'''
 
+    def __init__(self, weight: float=10.0):
+        self.weight = weight
+
     # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
-    def floor(self,
-
-
-
+    def floor(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_floor_approx(x, params):
+            param = params[id_]
+            denom = jnp.tanh(param / 4.0)
+            floor = (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
+                     jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)
+            return floor, params
+        return _jax_wrapped_calc_floor_approx
 
     # https://arxiv.org/abs/2006.09952
-    def round(self,
-
-
+    def round(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_round_approx(x, params):
+            param = params[id_]
+            m = jnp.floor(x) + 0.5
+            rounded = m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)
+            return rounded, params
+        return _jax_wrapped_calc_round_approx
+
+    def __str__(self) -> str:
+        return f'SoftFloor and SoftRound with weight {self.weight}'
 
 
 # ===========================================================================
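The rounding factories follow the same pattern; a short sketch with illustrative values, again assuming the module import:

    import jax.numpy as jnp
    from pyRDDLGym_jax.core.logic import SoftRounding

    rounding = SoftRounding(weight=10.0)
    params = {}
    soft_floor = rounding.floor('f0', params)        # registers params['f0'] = 10.0
    y, params = soft_floor(jnp.asarray([2.3, -1.7]), params)
    # y is a differentiable surrogate that approaches jnp.floor as the weight grows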
@@ -89,7 +148,7 @@ class SoftRounding(Rounding):
 class Complement:
     '''Base class for approximate logical complement operations.'''
 
-    def __call__(self,
+    def __call__(self, id, init_params):
         raise NotImplementedError
 
 
@@ -97,9 +156,16 @@ class StandardComplement(Complement):
     '''The standard approximate logical complement given by x -> 1 - x.'''
 
     # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
-
-
-
+    @staticmethod
+    def _jax_wrapped_calc_not_approx(x, params):
+        return 1.0 - x, params
+
+    def __call__(self, id, init_params):
+        return self._jax_wrapped_calc_not_approx
+
+    def __str__(self) -> str:
+        return 'Standard complement'
+
 
 # ===========================================================================
 # TNORMS
@@ -115,43 +181,78 @@ class StandardComplement(Complement):
 class TNorm:
     '''Base class for fuzzy differentiable t-norms.'''
 
-    def norm(self,
+    def norm(self, id, init_params):
         '''Elementwise t-norm of x and y.'''
         raise NotImplementedError
 
-    def norms(self,
+    def norms(self, id, init_params):
         '''T-norm computed for tensor x along axis.'''
         raise NotImplementedError
-
+
 
 class ProductTNorm(TNorm):
     '''Product t-norm given by the expression (x, y) -> x * y.'''
 
-
-
+    @staticmethod
+    def _jax_wrapped_calc_and_approx(x, y, params):
+        return x * y, params
+
+    def norm(self, id, init_params):
+        return self._jax_wrapped_calc_and_approx
 
-
-
+    @staticmethod
+    def _jax_wrapped_calc_forall_approx(x, axis, params):
+        return jnp.prod(x, axis=axis), params
+
+    def norms(self, id, init_params):
+        return self._jax_wrapped_calc_forall_approx
 
+    def __str__(self) -> str:
+        return 'Product t-norm'
+
 
 class GodelTNorm(TNorm):
     '''Godel t-norm given by the expression (x, y) -> min(x, y).'''
 
-
-
+    @staticmethod
+    def _jax_wrapped_calc_and_approx(x, y, params):
+        return jnp.minimum(x, y), params
+
+    def norm(self, id, init_params):
+        return self._jax_wrapped_calc_and_approx
+
+    @staticmethod
+    def _jax_wrapped_calc_forall_approx(x, axis, params):
+        return jnp.min(x, axis=axis), params
+
+    def norms(self, id, init_params):
+        return self._jax_wrapped_calc_forall_approx
+
+    def __str__(self) -> str:
+        return 'Godel t-norm'
 
-    def norms(self, x, axis):
-        return jnp.min(x, axis=axis)
-
 
 class LukasiewiczTNorm(TNorm):
     '''Lukasiewicz t-norm given by the expression (x, y) -> max(x + y - 1, 0).'''
 
-
-
+    @staticmethod
+    def _jax_wrapped_calc_and_approx(x, y, params):
+        land = jax.nn.relu(x + y - 1.0)
+        return land, params
+
+    def norm(self, id, init_params):
+        return self._jax_wrapped_calc_and_approx
 
-
-
+    @staticmethod
+    def _jax_wrapped_calc_forall_approx(x, axis, params):
+        forall = jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
+        return forall, params
+
+    def norms(self, id, init_params):
+        return self._jax_wrapped_calc_forall_approx
+
+    def __str__(self) -> str:
+        return 'Lukasiewicz t-norm'
 
 
 class YagerTNorm(TNorm):
@@ -161,17 +262,30 @@ class YagerTNorm(TNorm):
     def __init__(self, p=2.0):
         self.p = float(p)
 
-    def norm(self,
-
-
-
+    def norm(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.p
+        def _jax_wrapped_calc_and_approx(x, y, params):
+            base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
+            arg = jnp.linalg.norm(base, ord=params[id_], axis=0)
+            land = jax.nn.relu(1.0 - arg)
+            return land, params
+        return _jax_wrapped_calc_and_approx
+
+    def norms(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.p
+        def _jax_wrapped_calc_forall_approx(x, axis, params):
+            arg = jax.nn.relu(1.0 - x)
+            for ax in sorted(axis, reverse=True):
+                arg = jnp.linalg.norm(arg, ord=params[id_], axis=ax)
+            forall = jax.nn.relu(1.0 - arg)
+            return forall, params
+        return _jax_wrapped_calc_forall_approx
+
+    def __str__(self) -> str:
+        return f'Yager({self.p}) t-norm'
 
-    def norms(self, x, axis):
-        arg = jax.nn.relu(1.0 - x)
-        for ax in sorted(axis, reverse=True):
-            arg = jnp.linalg.norm(arg, ord=self.p, axis=ax)
-        return jax.nn.relu(1.0 - arg)
-
 
 # ===========================================================================
 # RANDOM SAMPLING
@@ -185,117 +299,443 @@ class RandomSampling:
This hunk rewrites the sampling classes around the new (id, init_params) factory convention, in which every returned JAX callable threads a params dictionary through and returns a (sample, params) pair, and it introduces the control-flow and Logic class hierarchy:

- RandomSampling: discrete, bernoulli, poisson and geometric now take (id, init_params, logic). bernoulli stacks [1.0 - prob, prob] and delegates to discrete; poisson samples exactly via random.poisson(key=key, lam=rate, dtype=logic.INT); geometric draws U uniformly and returns floor(jnp.log(U) / jnp.log(1.0 - prob)) + 1 using the floor relaxation from logic.floor(id, init_params).
- GumbelSoftmax.discrete builds on logic.argmax(id, init_params), applying it along the last axis to Gumbel01 + jnp.log(prob + logic.eps) (https://arxiv.org/pdf/1611.01144), and gains a __str__ ('Gumbel-Softmax').
- Determinization replaces the old warning-emitting methods with static determinized rules: discrete -> sum of literals * prob, poisson -> rate, geometric -> 1.0 / prob, plus a __str__ ('Deterministic').
- New CONTROL FLOW section: an abstract ControlFlow class (if_then_else, switch) and SoftControlFlow(weight=10.0), where if_then_else(c, a, b) = c * a + (1.0 - c) * b and switch selects cases with jax.nn.softmax(weight * -(pred - literals)^2) along the case axis.
- New LOGIC section: a Logic base class that owns the use64bit toggle (jnp.float64/int64 and jax.config.update('jax_enable_x64', ...)) and declares (id, init_params) stubs for all logical, comparison, special-function, indexing, control-flow and random-variable operators; and ExactLogic, which maps each stub to its exact counterpart (jnp.logical_and/or/not/xor, jnp.greater_equal/greater/less_equal/less/equal/not_equal, jnp.sign/floor/round/ceil/floor_divide/mod/sqrt, jnp.all/any/argmax/argmin, a hard jnp.where if, a jnp.take_along_axis switch, and random.categorical/bernoulli/poisson/geometric) through small exact_unary_function, exact_binary_function and exact_aggregation wrappers.
- FuzzyLogic is re-declared as a subclass of Logic; its constructor keeps tnorm, complement, comparison, sampling and rounding, gains control: ControlFlow=SoftControlFlow(), and drops the debias and verbose arguments.
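The practical upshot of the new Logic hierarchy is that exact and relaxed computation expose the same factory signatures, so the same compiled model code can swap one for the other. A hedged sketch of the contrast (the identifier 'c0' is illustrative):

    import jax.numpy as jnp
    from pyRDDLGym_jax.core.logic import ExactLogic, FuzzyLogic

    params = {}
    exact_geq = ExactLogic().greater_equal('c0', params)   # exact jnp.greater_equal, adds no parameters
    fuzzy_geq = FuzzyLogic().greater_equal('c0', params)   # sigmoid relaxation, registers a weight under 'c0'

    x, y = jnp.array(0.9), jnp.array(1.0)
    print(exact_geq(x, y, params)[0])   # False
    print(fuzzy_geq(x, y, params)[0])   # ~0.27 with the default weight of 10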
@@ -304,428 +744,212 @@ class FuzzyLogic:
This hunk rewrites the body of FuzzyLogic around the same factory convention:

- __init__ documents the new control argument, calls super().__init__(use64bit=use64bit), stores self.control, and no longer stores weight, debias or verbose; a new __str__ reports the tnorm, complement, comparison, sampling, rounding and control components together with the underflow tolerance eps and use64bit, and summarize_hyperparameters now simply prints __str__().
- Logical operators: logical_and and forall delegate to tnorm.norm/norms and logical_not to the complement; logical_or, xor, implies, equiv and exists are composed from those primitives under derived ids such as f'{id}_~', f'{id}_^1', f'{id}_|', with each composed callable threading params through the chain.
- Comparisons: greater_equal, greater, equal, sgn and argmax delegate to the comparison object; less_equal and less negate both arguments and reuse greater_equal/greater; not_equal composes equal with the complement; argmin applies argmax to -x.
- Special functions: floor and round delegate to the rounding object; ceil(x) = -floor(-x); div(x, y) = floor(x / y); mod(x, y) = x - y * div(x, y); sqrt(x) = jnp.sqrt(x + eps).
- Control flow and random variables: control_if and control_switch delegate to the control object; discrete, bernoulli, poisson and geometric delegate to the sampling object with (id, init_params, self).
- Removed: the old implementations that returned (function, parameter) pairs tagged ('weight', <op>), the per-operator debias straight-through corrections via jax.lax.stop_gradient, the verbose raise_warning replacement-rule messages, and the class-level enumerate_literals helper (now a module-level function).
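Because each wrapped operation is now a pure function of its inputs and the params dictionary, gradients flow through the relaxations with ordinary JAX transforms; a minimal sketch, not taken from the package:

    import jax
    import jax.numpy as jnp
    from pyRDDLGym_jax.core.logic import FuzzyLogic

    params = {}
    geq = FuzzyLogic().greater_equal('c', params)
    grad_fn = jax.grad(lambda x: geq(x, 1.0, params)[0])
    print(grad_fn(0.9))   # nonzero, unlike the exact step function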
@@ -733,105 +957,121 @@ class FuzzyLogic:
 #
 # ===========================================================================
 
-logic = FuzzyLogic()
-
+logic = FuzzyLogic(comparison=SigmoidComparison(10000.0),
+                   rounding=SoftRounding(10000.0),
+                   control=SoftControlFlow(10000.0))
 
 
 def _test_logical():
     print('testing logical')
-
-
-
-
-
-
+    init_params = {}
+    _and = logic.logical_and(0, init_params)
+    _not = logic.logical_not(1, init_params)
+    _gre = logic.greater(2, init_params)
+    _or = logic.logical_or(3, init_params)
+    _if = logic.control_if(4, init_params)
+    print(init_params)
+
     # https://towardsdatascience.com/emulating-logical-gates-with-a-neural-network-75c229ec4cc9
-    def test_logic(x1, x2):
-        q1 =
-        q2
-
-
+    def test_logic(x1, x2, w):
+        q1, w = _gre(x1, 0, w)
+        q2, w = _gre(x2, 0, w)
+        q3, w = _and(q1, q2, w)
+        q4, w = _not(q1, w)
+        q5, w = _not(q2, w)
+        q6, w = _and(q4, q5, w)
+        cond, w = _or(q3, q6, w)
+        pred, w = _if(cond, +1, -1, w)
         return pred
 
     x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5]).astype(float)
     x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6]).astype(float)
-    print(test_logic(x1, x2))
+    print(test_logic(x1, x2, init_params))
 
 
 def _test_indexing():
     print('testing indexing')
-
-
+    init_params = {}
+    _argmax = logic.argmax(0, init_params)
+    _argmin = logic.argmin(1, init_params)
+    print(init_params)
 
-    def argmaxmin(x):
-        amax = _argmax(x, 0, w)
-        amin = _argmin(x, 0, w)
+    def argmaxmin(x, w):
+        amax, w = _argmax(x, 0, w)
+        amin, w = _argmin(x, 0, w)
         return amax, amin
 
     values = jnp.asarray([2., 3., 5., 4.9, 4., 1., -1., -2.])
-    amax, amin = argmaxmin(values)
+    amax, amin = argmaxmin(values, init_params)
     print(amax)
     print(amin)
 
 
 def _test_control():
     print('testing control')
-
+    init_params = {}
+    _switch = logic.control_switch(0, init_params)
+    print(init_params)
 
     pred = jnp.asarray(jnp.linspace(0, 2, 10))
     case1 = jnp.asarray([-10.] * 10)
     case2 = jnp.asarray([1.5] * 10)
     case3 = jnp.asarray([10.] * 10)
     cases = jnp.asarray([case1, case2, case3])
-
+    switch, _ = _switch(pred, cases, init_params)
+    print(switch)
 
 
 def _test_random():
     print('testing random')
     key = random.PRNGKey(42)
-
-
-
+    init_params = {}
+    _bernoulli = logic.bernoulli(0, init_params)
+    _discrete = logic.discrete(1, init_params)
+    _geometric = logic.geometric(2, init_params)
+    print(init_params)
 
-    def bern(n):
+    def bern(n, w):
         prob = jnp.asarray([0.3] * n)
-        sample = _bernoulli(key, prob, w)
+        sample, _ = _bernoulli(key, prob, w)
         return sample
 
-    samples = bern(50000)
+    samples = bern(50000, init_params)
     print(jnp.mean(samples))
 
-    def disc(n):
+    def disc(n, w):
         prob = jnp.asarray([0.1, 0.4, 0.5])
         prob = jnp.tile(prob, (n, 1))
-        sample = _discrete(key, prob, w)
+        sample, _ = _discrete(key, prob, w)
         return sample
 
-    samples = disc(50000)
+    samples = disc(50000, init_params)
     samples = jnp.round(samples)
     print([jnp.mean(samples == i) for i in range(3)])
 
-    def geom(n):
+    def geom(n, w):
         prob = jnp.asarray([0.3] * n)
-        sample = _geometric(key, prob, w)
+        sample, _ = _geometric(key, prob, w)
        return sample
 
-    samples = geom(50000)
+    samples = geom(50000, init_params)
     print(jnp.mean(samples))
 
 
 def _test_rounding():
     print('testing rounding')
-
-
-
-
+    init_params = {}
+    _floor = logic.floor(0, init_params)
+    _ceil = logic.ceil(1, init_params)
+    _round = logic.round(2, init_params)
+    _mod = logic.mod(3, init_params)
+    print(init_params)
 
     x = jnp.asarray([2.1, 0.6, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
-    print(_floor(x,
-    print(_ceil(x,
-    print(_round(x,
-    print(_mod(x, 2.0,
+    print(_floor(x, init_params)[0])
+    print(_ceil(x, init_params)[0])
+    print(_round(x, init_params)[0])
+    print(_mod(x, 2.0, init_params)[0])
 
 
 if __name__ == '__main__':