pyRDDLGym-jax 0.4__py3-none-any.whl → 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +463 -592
- pyRDDLGym_jax/core/logic.py +832 -530
- pyRDDLGym_jax/core/planner.py +422 -474
- pyRDDLGym_jax/core/simulator.py +7 -5
- pyRDDLGym_jax/core/tuning.py +390 -584
- pyRDDLGym_jax/core/visualization.py +1463 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +5 -5
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +5 -6
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +5 -4
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +5 -4
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +5 -5
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +7 -6
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_drp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +5 -4
- pyRDDLGym_jax/examples/configs/default_slp.cfg +3 -3
- pyRDDLGym_jax/examples/configs/tuning_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/tuning_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/tuning_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_plan.py +4 -1
- pyRDDLGym_jax/examples/run_tune.py +40 -29
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/METADATA +164 -105
- pyRDDLGym_jax-1.0.dist-info/RECORD +45 -0
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_drp.cfg +0 -19
- pyRDDLGym_jax/examples/configs/MarsRover_ippc2023_slp.cfg +0 -20
- pyRDDLGym_jax/examples/configs/Pendulum_gym_slp.cfg +0 -18
- pyRDDLGym_jax-0.4.dist-info/RECORD +0 -44
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.4.dist-info → pyRDDLGym_jax-1.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/logic.py
CHANGED
@@ -4,63 +4,169 @@ import jax
 import jax.numpy as jnp
 import jax.random as random
 
+
+def enumerate_literals(shape, axis, dtype=jnp.int32):
+    literals = jnp.arange(shape[axis], dtype=dtype)
+    literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
+    literals = jnp.moveaxis(literals, source=0, destination=axis)
+    literals = jnp.broadcast_to(literals, shape=shape)
+    return literals
 
 
 # ===========================================================================
+# RELATIONAL OPERATIONS
 # - abstract class
+# - sigmoid comparison
 #
 # ===========================================================================
 
+class Comparison:
+    '''Base class for approximate comparison operations.'''
 
+    def greater_equal(self, id, init_params):
         raise NotImplementedError
+
+    def greater(self, id, init_params):
+        raise NotImplementedError
+
+    def equal(self, id, init_params):
+        raise NotImplementedError
+
+    def sgn(self, id, init_params):
+        raise NotImplementedError
+
+    def argmax(self, id, init_params):
+        raise NotImplementedError
+
 
-    '''The standard approximate logical complement given by x -> 1 - x.'''
+class SigmoidComparison(Comparison):
+    '''Comparison operations approximated using sigmoid functions.'''
 
+    def __init__(self, weight: float=10.0):
+        self.weight = weight
+
+    # https://arxiv.org/abs/2110.05651
+    def greater_equal(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_greater_equal_approx(x, y, params):
+            gre_eq = jax.nn.sigmoid(params[id_] * (x - y))
+            return gre_eq, params
+        return _jax_wrapped_calc_greater_equal_approx
+
+    def greater(self, id, init_params):
+        return self.greater_equal(id, init_params)
+
+    def equal(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_equal_approx(x, y, params):
+            equal = 1.0 - jnp.square(jnp.tanh(params[id_] * (y - x)))
+            return equal, params
+        return _jax_wrapped_calc_equal_approx
+
+    def sgn(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_sgn_approx(x, params):
+            sgn = jnp.tanh(params[id_] * x)
+            return sgn, params
+        return _jax_wrapped_calc_sgn_approx
+
+    # https://arxiv.org/abs/2110.05651
+    def argmax(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_argmax_approx(x, axis, params):
+            literals = enumerate_literals(x.shape, axis=axis)
+            softmax = jax.nn.softmax(params[id_] * x, axis=axis)
+            sample = jnp.sum(literals * softmax, axis=axis)
+            return sample, params
+        return _jax_wrapped_calc_argmax_approx
+
+    def __str__(self) -> str:
+        return f'Sigmoid comparison with weight {self.weight}'
 
 
 # ===========================================================================
+# ROUNDING OPERATIONS
 # - abstract class
+# - soft rounding
 #
 # ===========================================================================
 
+class Rounding:
+    '''Base class for approximate rounding operations.'''
 
+    def floor(self, id, init_params):
         raise NotImplementedError
 
+    def round(self, id, init_params):
         raise NotImplementedError
+
+
+class SoftRounding(Rounding):
+    '''Rounding operations approximated using soft operations.'''
+
+    def __init__(self, weight: float=10.0):
+        self.weight = weight
+
+    # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
+    def floor(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_floor_approx(x, params):
+            param = params[id_]
+            denom = jnp.tanh(param / 4.0)
+            floor = (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
+                     jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)
+            return floor, params
+        return _jax_wrapped_calc_floor_approx
+
+    # https://arxiv.org/abs/2006.09952
+    def round(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_round_approx(x, params):
+            param = params[id_]
+            m = jnp.floor(x) + 0.5
+            rounded = m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)
+            return rounded, params
+        return _jax_wrapped_calc_round_approx
+
+    def __str__(self) -> str:
+        return f'SoftFloor and SoftRound with weight {self.weight}'
+
+
+# ===========================================================================
+# LOGICAL COMPLEMENT
+# - abstract class
+# - standard complement
+#
+# ===========================================================================
+
+class Complement:
+    '''Base class for approximate logical complement operations.'''
 
+    def __call__(self, id, init_params):
         raise NotImplementedError
 
+
+class StandardComplement(Complement):
+    '''The standard approximate logical complement given by x -> 1 - x.'''
 
+    # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
+    @staticmethod
+    def _jax_wrapped_calc_not_approx(x, params):
+        return 1.0 - x, params
 
+    def __call__(self, id, init_params):
+        return self._jax_wrapped_calc_not_approx
 
+    def __str__(self) -> str:
+        return 'Standard complement'
+
+
 # ===========================================================================
 # TNORMS
 # - abstract tnorm
@@ -69,48 +175,84 @@ class SigmoidComparison(Comparison):
 # - Lukasiewicz tnorm
 # - Yager(p) tnorm
 #
+# https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
 # ===========================================================================
 
 class TNorm:
     '''Base class for fuzzy differentiable t-norms.'''
 
+    def norm(self, id, init_params):
         '''Elementwise t-norm of x and y.'''
         raise NotImplementedError
 
+    def norms(self, id, init_params):
         '''T-norm computed for tensor x along axis.'''
         raise NotImplementedError
+
 
 class ProductTNorm(TNorm):
     '''Product t-norm given by the expression (x, y) -> x * y.'''
 
+    @staticmethod
+    def _jax_wrapped_calc_and_approx(x, y, params):
+        return x * y, params
+
+    def norm(self, id, init_params):
+        return self._jax_wrapped_calc_and_approx
 
+    @staticmethod
+    def _jax_wrapped_calc_forall_approx(x, axis, params):
+        return jnp.prod(x, axis=axis), params
+
+    def norms(self, id, init_params):
+        return self._jax_wrapped_calc_forall_approx
 
+    def __str__(self) -> str:
+        return 'Product t-norm'
+
 
 class GodelTNorm(TNorm):
     '''Godel t-norm given by the expression (x, y) -> min(x, y).'''
 
+    @staticmethod
+    def _jax_wrapped_calc_and_approx(x, y, params):
+        return jnp.minimum(x, y), params
+
+    def norm(self, id, init_params):
+        return self._jax_wrapped_calc_and_approx
+
+    @staticmethod
+    def _jax_wrapped_calc_forall_approx(x, axis, params):
+        return jnp.min(x, axis=axis), params
+
+    def norms(self, id, init_params):
+        return self._jax_wrapped_calc_forall_approx
+
+    def __str__(self) -> str:
+        return 'Godel t-norm'
 
-    def norms(self, x, axis):
-        return jnp.min(x, axis=axis)
 
 class LukasiewiczTNorm(TNorm):
     '''Lukasiewicz t-norm given by the expression (x, y) -> max(x + y - 1, 0).'''
 
+    @staticmethod
+    def _jax_wrapped_calc_and_approx(x, y, params):
+        land = jax.nn.relu(x + y - 1.0)
+        return land, params
+
+    def norm(self, id, init_params):
+        return self._jax_wrapped_calc_and_approx
 
+    @staticmethod
+    def _jax_wrapped_calc_forall_approx(x, axis, params):
+        forall = jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
+        return forall, params
+
+    def norms(self, id, init_params):
+        return self._jax_wrapped_calc_forall_approx
+
+    def __str__(self) -> str:
+        return 'Lukasiewicz t-norm'
 
 
 class YagerTNorm(TNorm):
@@ -118,19 +260,32 @@ class YagerTNorm(TNorm):
     (x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
 
     def __init__(self, p=2.0):
-        self.p = p
+        self.p = float(p)
+
+    def norm(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.p
+        def _jax_wrapped_calc_and_approx(x, y, params):
+            base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
+            arg = jnp.linalg.norm(base, ord=params[id_], axis=0)
+            land = jax.nn.relu(1.0 - arg)
+            return land, params
+        return _jax_wrapped_calc_and_approx
+
+    def norms(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.p
+        def _jax_wrapped_calc_forall_approx(x, axis, params):
+            arg = jax.nn.relu(1.0 - x)
+            for ax in sorted(axis, reverse=True):
+                arg = jnp.linalg.norm(arg, ord=params[id_], axis=ax)
+            forall = jax.nn.relu(1.0 - arg)
+            return forall, params
+        return _jax_wrapped_calc_forall_approx
+
+    def __str__(self) -> str:
+        return f'Yager({self.p}) t-norm'
+
 
 # ===========================================================================
 # RANDOM SAMPLING
@@ -144,141 +299,147 @@ class RandomSampling:
     '''An abstract class that describes how discrete and non-reparameterizable
    random variables are sampled.'''
 
-    def discrete(self, logic):
+    def discrete(self, id, init_params, logic):
         raise NotImplementedError
 
-    def bernoulli(self, logic):
-        def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
+    def bernoulli(self, id, init_params, logic):
+        discrete_approx = self.discrete(id, init_params, logic)
+        def _jax_wrapped_calc_bernoulli_approx(key, prob, params):
             prob = jnp.stack([1.0 - prob, prob], axis=-1)
-        return _jax_wrapped_calc_bernoulli_approx, jax_param
+            return discrete_approx(key, prob, params)
+        return _jax_wrapped_calc_bernoulli_approx
 
+    @staticmethod
+    def _jax_wrapped_calc_poisson_exact(key, rate, params):
+        sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
+        return sample, params
 
-        jax_floor, jax_param = logic.floor()
-        def _jax_wrapped_calc_geometric_approx(key, prob, param):
+    def poisson(self, id, init_params, logic):
+        return self._jax_wrapped_calc_poisson_exact
+
+    def geometric(self, id, init_params, logic):
+        approx_floor = logic.floor(id, init_params)
+        def _jax_wrapped_calc_geometric_approx(key, prob, params):
             U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
-        return _jax_wrapped_calc_geometric_approx
+            floor, params = approx_floor(jnp.log(U) / jnp.log(1.0 - prob), params)
+            sample = floor + 1
+            return sample, params
+        return _jax_wrapped_calc_geometric_approx
 
 
 class GumbelSoftmax(RandomSampling):
     '''Random sampling of discrete variables using Gumbel-softmax trick.'''
 
-        jax_argmax, jax_param = logic.argmax()
-        def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
+    # https://arxiv.org/pdf/1611.01144
+    def discrete(self, id, init_params, logic):
+        argmax_approx = logic.argmax(id, init_params)
+        def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, params):
             Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
             sample = Gumbel01 + jnp.log(prob + logic.eps)
+            return argmax_approx(sample, axis=-1, params=params)
+        return _jax_wrapped_calc_discrete_gumbel_softmax
+
+    def __str__(self) -> str:
+        return 'Gumbel-Softmax'
 
 
 class Determinization(RandomSampling):
     '''Random sampling of variables using their deterministic mean estimate.'''
 
+    @staticmethod
+    def _jax_wrapped_calc_discrete_determinized(key, prob, params):
+        literals = enumerate_literals(prob.shape, axis=-1)
+        sample = jnp.sum(literals * prob, axis=-1)
+        return sample, params
 
+    def discrete(self, id, init_params, logic):
+        return self._jax_wrapped_calc_discrete_determinized
+
+    @staticmethod
+    def _jax_wrapped_calc_poisson_determinized(key, rate, params):
+        return rate, params
 
+    def poisson(self, id, init_params, logic):
+        return self._jax_wrapped_calc_poisson_determinized
+
+    @staticmethod
+    def _jax_wrapped_calc_geometric_determinized(key, prob, params):
+        sample = 1.0 / prob
+        return sample, params
+
+    def geometric(self, id, init_params, logic):
+        return self._jax_wrapped_calc_geometric_determinized
+
+    def __str__(self) -> str:
+        return 'Deterministic'
+
 
 # ===========================================================================
+# CONTROL FLOW
+# - soft flow
 #
 # ===========================================================================
 
+class ControlFlow:
+    '''A base class for control flow, including if and switch statements.'''
 
+    def if_then_else(self, id, init_params):
+        raise NotImplementedError
 
+    def switch(self, id, init_params):
+        raise NotImplementedError
+
+
+class SoftControlFlow(ControlFlow):
+    '''Soft control flow using a probabilistic interpretation.'''
+
+    def __init__(self, weight: float=10.0) -> None:
+        self.weight = weight
 
+    @staticmethod
+    def _jax_wrapped_calc_if_then_else_soft(c, a, b, params):
+        sample = c * a + (1.0 - c) * b
+        return sample, params
+
+    def if_then_else(self, id, init_params):
+        return self._jax_wrapped_calc_if_then_else_soft
+
+    def switch(self, id, init_params):
+        id_ = str(id)
+        init_params[id_] = self.weight
+        def _jax_wrapped_calc_switch_soft(pred, cases, params):
+            literals = enumerate_literals(cases.shape, axis=0)
+            pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
+            proximity = -jnp.square(pred - literals)
+            softcase = jax.nn.softmax(params[id_] * proximity, axis=0)
+            sample = jnp.sum(cases * softcase, axis=0)
+            return sample, params
+        return _jax_wrapped_calc_switch_soft
+
+    def __str__(self) -> str:
+        return f'Soft control flow with weight {self.weight}'
+
+
+# ===========================================================================
+# LOGIC
+# - exact logic
+# - fuzzy logic
+#
+# ===========================================================================
+
+
+class Logic:
+    '''A base class for representing logic computations in JAX.'''
+
+    def __init__(self, use64bit: bool=False) -> None:
        self.set_use64bit(use64bit)
 
+    def summarize_hyperparameters(self) -> None:
+        print(f'model relaxation:\n'
+              f'    use_64_bit ={self.use64bit}')
+
     def set_use64bit(self, use64bit: bool) -> None:
+        '''Toggles whether or not the JAX system will use 64 bit precision.'''
         self.use64bit = use64bit
         if use64bit:
             self.REAL = jnp.float64
@@ -288,384 +449,507 @@ class FuzzyLogic:
             self.REAL = jnp.float32
             self.INT = jnp.int32
             jax.config.update('jax_enable_x64', False)
+
     # ===========================================================================
     # logical operators
     # ===========================================================================
 
+    def logical_and(self, id, init_params):
+        raise NotImplementedError
 
+    def logical_not(self, id, init_params):
+        raise NotImplementedError
 
+    def logical_or(self, id, init_params):
+        raise NotImplementedError
 
+    def xor(self, id, init_params):
+        raise NotImplementedError
 
+    def implies(self, id, init_params):
+        raise NotImplementedError
+
+    def equiv(self, id, init_params):
+        raise NotImplementedError
+
+    def forall(self, id, init_params):
+        raise NotImplementedError
+
+    def exists(self, id, init_params):
+        raise NotImplementedError
 
     # ===========================================================================
     # comparison operators
     # ===========================================================================
+
+    def greater_equal(self, id, init_params):
+        raise NotImplementedError
+
+    def greater(self, id, init_params):
+        raise NotImplementedError
+
+    def less_equal(self, id, init_params):
+        raise NotImplementedError
+
+    def less(self, id, init_params):
+        raise NotImplementedError
+
+    def equal(self, id, init_params):
+        raise NotImplementedError
+
+    def not_equal(self, id, init_params):
+        raise NotImplementedError
+
+    # ===========================================================================
+    # special functions
+    # ===========================================================================
 
+    def sgn(self, id, init_params):
+        raise NotImplementedError
 
+    def floor(self, id, init_params):
+        raise NotImplementedError
+
+    def round(self, id, init_params):
+        raise NotImplementedError
+
+    def ceil(self, id, init_params):
+        raise NotImplementedError
+
+    def div(self, id, init_params):
+        raise NotImplementedError
+
+    def mod(self, id, init_params):
+        raise NotImplementedError
+
+    def sqrt(self, id, init_params):
+        raise NotImplementedError
+
+    # ===========================================================================
+    # indexing
+    # ===========================================================================
+
+    def argmax(self, id, init_params):
+        raise NotImplementedError
+
+    def argmin(self, id, init_params):
+        raise NotImplementedError
+
+    # ===========================================================================
+    # control flow
+    # ===========================================================================
+
+    def control_if(self, id, init_params):
+        raise NotImplementedError
 
+    def control_switch(self, id, init_params):
+        raise NotImplementedError
+
+    # ===========================================================================
+    # random variables
+    # ===========================================================================
+
+    def discrete(self, id, init_params):
+        raise NotImplementedError
+
+    def bernoulli(self, id, init_params):
+        raise NotImplementedError
 
+    def poisson(self, id, init_params):
+        raise NotImplementedError
+
+    def geometric(self, id, init_params):
+        raise NotImplementedError
 
 
+class ExactLogic(Logic):
+    '''A class representing exact logic in JAX.'''
 
+    @staticmethod
+    def exact_unary_function(op):
+        def _jax_wrapped_calc_unary_function_exact(x, params):
+            return op(x), params
+        return _jax_wrapped_calc_unary_function_exact
+
+    @staticmethod
+    def exact_binary_function(op):
+        def _jax_wrapped_calc_binary_function_exact(x, y, params):
+            return op(x, y), params
+        return _jax_wrapped_calc_binary_function_exact
+
+    @staticmethod
+    def exact_aggregation(op):
+        def _jax_wrapped_calc_aggregation_exact(x, axis, params):
+            return op(x, axis=axis), params
+        return _jax_wrapped_calc_aggregation_exact
 
+    # ===========================================================================
+    # logical operators
+    # ===========================================================================
+
+    def logical_and(self, id, init_params):
+        return self.exact_binary_function(jnp.logical_and)
+
+    def logical_not(self, id, init_params):
+        return self.exact_unary_function(jnp.logical_not)
+
+    def logical_or(self, id, init_params):
+        return self.exact_binary_function(jnp.logical_or)
+
+    def xor(self, id, init_params):
+        return self.exact_binary_function(jnp.logical_xor)
+
+    @staticmethod
+    def exact_binary_implies(x, y, params):
+        return jnp.logical_or(jnp.logical_not(x), y), params
+
+    def implies(self, id, init_params):
+        return self.exact_binary_implies
+
+    def equiv(self, id, init_params):
+        return self.exact_binary_function(jnp.equal)
+
+    def forall(self, id, init_params):
+        return self.exact_aggregation(jnp.all)
+
+    def exists(self, id, init_params):
+        return self.exact_aggregation(jnp.any)
+
+    # ===========================================================================
+    # comparison operators
+    # ===========================================================================
+
+    def greater_equal(self, id, init_params):
+        return self.exact_binary_function(jnp.greater_equal)
+
+    def greater(self, id, init_params):
+        return self.exact_binary_function(jnp.greater)
+
+    def less_equal(self, id, init_params):
+        return self.exact_binary_function(jnp.less_equal)
+
+    def less(self, id, init_params):
+        return self.exact_binary_function(jnp.less)
+
+    def equal(self, id, init_params):
+        return self.exact_binary_function(jnp.equal)
+
+    def not_equal(self, id, init_params):
+        return self.exact_binary_function(jnp.not_equal)
+
     # ===========================================================================
     # special functions
     # ===========================================================================
 
+    def sgn(self, id, init_params):
+        return self.exact_unary_function(jnp.sign)
 
+    def floor(self, id, init_params):
+        return self.exact_unary_function(jnp.floor)
 
+    def round(self, id, init_params):
+        return self.exact_unary_function(jnp.round)
 
+    def ceil(self, id, init_params):
+        return self.exact_unary_function(jnp.ceil)
 
+    def div(self, id, init_params):
+        return self.exact_binary_function(jnp.floor_divide)
+
+    def mod(self, id, init_params):
+        return self.exact_binary_function(jnp.mod)
+
+    def sqrt(self, id, init_params):
+        return self.exact_unary_function(jnp.sqrt)
 
     # ===========================================================================
     # indexing
     # ===========================================================================
 
+    def argmax(self, id, init_params):
+        return self.exact_aggregation(jnp.argmax)
+
+    def argmin(self, id, init_params):
+        return self.exact_aggregation(jnp.argmin)
+
+    # ===========================================================================
+    # control flow
+    # ===========================================================================
+
     @staticmethod
+    def exact_if_then_else(c, a, b, params):
+        return jnp.where(c > 0.5, a, b), params
+
+    def control_if(self, id, init_params):
+        return self.exact_if_then_else
+
+    @staticmethod
+    def exact_switch(pred, cases, params):
+        pred = pred[jnp.newaxis, ...]
+        sample = jnp.take_along_axis(cases, pred, axis=0)
+        assert sample.shape[0] == 1
+        return sample[0, ...], params
+
+    def control_switch(self, id, init_params):
+        return self.exact_switch
+
+    # ===========================================================================
+    # random variables
+    # ===========================================================================
+
+    @staticmethod
+    def exact_discrete(key, prob, params):
+        return random.categorical(key=key, logits=jnp.log(prob), axis=-1), params
+
+    def discrete(self, id, init_params):
+        return self.exact_discrete
+
+    @staticmethod
+    def exact_bernoulli(key, prob, params):
+        return random.bernoulli(key, prob), params
+
+    def bernoulli(self, id, init_params):
+        return self.exact_bernoulli
+
+    @staticmethod
+    def exact_poisson(key, rate, params):
+        return random.poisson(key=key, lam=rate), params
+
+    def poisson(self, id, init_params):
+        return self.exact_poisson
+
+    @staticmethod
+    def exact_geometric(key, prob, params):
+        return random.geometric(key=key, p=prob), params
+
+    def geometric(self, id, init_params):
+        return self.exact_geometric
+
+
+class FuzzyLogic(Logic):
+    '''A class representing fuzzy logic in JAX.'''
+
+    def __init__(self, tnorm: TNorm=ProductTNorm(),
+                 complement: Complement=StandardComplement(),
+                 comparison: Comparison=SigmoidComparison(),
+                 sampling: RandomSampling=GumbelSoftmax(),
+                 rounding: Rounding=SoftRounding(),
+                 control: ControlFlow=SoftControlFlow(),
+                 eps: float=1e-15,
+                 use64bit: bool=False) -> None:
+        '''Creates a new fuzzy logic in Jax.
+
+        :param tnorm: fuzzy operator for logical AND
+        :param complement: fuzzy operator for logical NOT
+        :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
+        :param sampling: random sampling of non-reparameterizable distributions
+        :param rounding: rounding floating values to integers
+        :param control: if and switch control structures
+        :param eps: small positive float to mitigate underflow
+        :param use64bit: whether to perform arithmetic in 64 bit
+        '''
+        super().__init__(use64bit=use64bit)
+        self.tnorm = tnorm
+        self.complement = complement
+        self.comparison = comparison
+        self.sampling = sampling
+        self.rounding = rounding
+        self.control = control
+        self.eps = eps
+
+    def __str__(self) -> str:
+        return (f'model relaxation:\n'
+                f'    tnorm ={str(self.tnorm)}\n'
+                f'    complement ={str(self.complement)}\n'
+                f'    comparison ={str(self.comparison)}\n'
+                f'    sampling ={str(self.sampling)}\n'
+                f'    rounding ={str(self.rounding)}\n'
+                f'    control ={str(self.control)}\n'
+                f'    underflow_tol ={self.eps}\n'
+                f'    use_64_bit ={self.use64bit}')
+
+    def summarize_hyperparameters(self) -> None:
+        print(self.__str__())
 
+    # ===========================================================================
+    # logical operators
+    # ===========================================================================
+
+    def logical_and(self, id, init_params):
+        return self.tnorm.norm(id, init_params)
+
+    def logical_not(self, id, init_params):
+        return self.complement(id, init_params)
+
+    def logical_or(self, id, init_params):
+        _not1 = self.complement(f'{id}_~1', init_params)
+        _not2 = self.complement(f'{id}_~2', init_params)
+        _and = self.tnorm.norm(f'{id}_^', init_params)
+        _not = self.complement(f'{id}_~', init_params)
+
+        def _jax_wrapped_calc_or_approx(x, y, params):
+            not_x, params = _not1(x, params)
+            not_y, params = _not2(y, params)
+            not_x_and_not_y, params = _and(not_x, not_y, params)
+            return _not(not_x_and_not_y, params)
+        return _jax_wrapped_calc_or_approx
+
+    def xor(self, id, init_params):
+        _not = self.complement(f'{id}_~', init_params)
+        _and1 = self.tnorm.norm(f'{id}_^1', init_params)
+        _and2 = self.tnorm.norm(f'{id}_^2', init_params)
+        _or = self.logical_or(f'{id}_|', init_params)
+
+        def _jax_wrapped_calc_xor_approx(x, y, params):
+            x_and_y, params = _and1(x, y, params)
+            not_x_and_y, params = _not(x_and_y, params)
+            x_or_y, params = _or(x, y, params)
+            return _and2(x_or_y, not_x_and_y, params)
+        return _jax_wrapped_calc_xor_approx
+
+    def implies(self, id, init_params):
+        _not = self.complement(f'{id}_~', init_params)
+        _or = self.logical_or(f'{id}_|', init_params)
+
+        def _jax_wrapped_calc_implies_approx(x, y, params):
+            not_x, params = _not(x, params)
+            return _or(not_x, y, params)
+        return _jax_wrapped_calc_implies_approx
+
+    def equiv(self, id, init_params):
+        _implies1 = self.implies(f'{id}_=>1', init_params)
+        _implies2 = self.implies(f'{id}_=>2', init_params)
+        _and = self.tnorm.norm(f'{id}_^', init_params)
+
+        def _jax_wrapped_calc_equiv_approx(x, y, params):
+            x_implies_y, params = _implies1(x, y, params)
+            y_implies_x, params = _implies2(y, x, params)
+            return _and(x_implies_y, y_implies_x, params)
+        return _jax_wrapped_calc_equiv_approx
+
+    def forall(self, id, init_params):
+        return self.tnorm.norms(id, init_params)
+
+    def exists(self, id, init_params):
+        _not1 = self.complement(f'{id}_~1', init_params)
+        _not2 = self.complement(f'{id}_~2', init_params)
+        _forall = self.forall(f'{id}_forall', init_params)
+
+        def _jax_wrapped_calc_exists_approx(x, axis, params):
+            not_x, params = _not1(x, params)
+            forall_not_x, params = _forall(not_x, axis, params)
+            return _not2(forall_not_x, params)
+        return _jax_wrapped_calc_exists_approx
+
+    # ===========================================================================
+    # comparison operators
+    # ===========================================================================
+
+    def greater_equal(self, id, init_params):
+        return self.comparison.greater_equal(id, init_params)
+
+    def greater(self, id, init_params):
+        return self.comparison.greater(id, init_params)
+
+    def less_equal(self, id, init_params):
+        _greater_eq = self.greater_equal(id, init_params)
+        def _jax_wrapped_calc_leq_approx(x, y, params):
+            return _greater_eq(-x, -y, params)
+        return _jax_wrapped_calc_leq_approx
+
+    def less(self, id, init_params):
+        _greater = self.greater(id, init_params)
+        def _jax_wrapped_calc_less_approx(x, y, params):
+            return _greater(-x, -y, params)
+        return _jax_wrapped_calc_less_approx
+
+    def equal(self, id, init_params):
+        return self.comparison.equal(id, init_params)
+
+    def not_equal(self, id, init_params):
+        _not = self.complement(f'{id}_~', init_params)
+        _equal = self.comparison.equal(f'{id}_==', init_params)
+        def _jax_wrapped_calc_neq_approx(x, y, params):
+            equal, params = _equal(x, y, params)
+            return _not(equal, params)
+        return _jax_wrapped_calc_neq_approx
 
+    # ===========================================================================
+    # special functions
+    # ===========================================================================
+
+    def sgn(self, id, init_params):
+        return self.comparison.sgn(id, init_params)
 
+    def floor(self, id, init_params):
+        return self.rounding.floor(id, init_params)
 
+    def round(self, id, init_params):
+        return self.rounding.round(id, init_params)
+
+    def ceil(self, id, init_params):
+        _floor = self.rounding.floor(id, init_params)
+        def _jax_wrapped_calc_ceil_approx(x, params):
+            neg_floor, params = _floor(-x, params)
+            return -neg_floor, params
+        return _jax_wrapped_calc_ceil_approx
+
+    def div(self, id, init_params):
+        _floor = self.rounding.floor(id, init_params)
+        def _jax_wrapped_calc_div_approx(x, y, params):
+            return _floor(x / y, params)
+        return _jax_wrapped_calc_div_approx
+
+    def mod(self, id, init_params):
+        _div = self.div(id, init_params)
+        def _jax_wrapped_calc_mod_approx(x, y, params):
+            div, params = _div(x, y, params)
+            return x - y * div, params
+        return _jax_wrapped_calc_mod_approx
+
+    def sqrt(self, id, init_params):
+        def _jax_wrapped_calc_sqrt_approx(x, params):
+            return jnp.sqrt(x + self.eps), params
+        return _jax_wrapped_calc_sqrt_approx
+
+    # ===========================================================================
+    # indexing
+    # ===========================================================================
+
+    def argmax(self, id, init_params):
+        return self.comparison.argmax(id, init_params)
+
+    def argmin(self, id, init_params):
+        _argmax = self.argmax(id, init_params)
        def _jax_wrapped_calc_argmin_approx(x, axis, param):
+            return _argmax(-x, axis, param)
+        return _jax_wrapped_calc_argmin_approx
 
     # ===========================================================================
     # control flow
     # ===========================================================================
 
+    def control_if(self, id, init_params):
+        return self.control.if_then_else(id, init_params)
+
+    def control_switch(self, id, init_params):
+        return self.control.switch(id, init_params)
 
     # ===========================================================================
     # random variables
     # ===========================================================================
 
-    def discrete(self):
-        return self.sampling.discrete(self)
+    def discrete(self, id, init_params):
+        return self.sampling.discrete(id, init_params, self)
 
-    def bernoulli(self):
-        return self.sampling.bernoulli(self)
+    def bernoulli(self, id, init_params):
+        return self.sampling.bernoulli(id, init_params, self)
 
-    def poisson(self):
-        return self.sampling.poisson(self)
+    def poisson(self, id, init_params):
+        return self.sampling.poisson(id, init_params, self)
 
-    def geometric(self):
-        return self.sampling.geometric(self)
+    def geometric(self, id, init_params):
+        return self.sampling.geometric(id, init_params, self)
 
 
 # ===========================================================================
@@ -673,103 +957,121 @@ class FuzzyLogic:
 #
 # ===========================================================================
 
-logic = FuzzyLogic()
+logic = FuzzyLogic(comparison=SigmoidComparison(10000.0),
+                   rounding=SoftRounding(10000.0),
+                   control=SoftControlFlow(10000.0))
 
 
 def _test_logical():
     print('testing logical')
+    init_params = {}
+    _and = logic.logical_and(0, init_params)
+    _not = logic.logical_not(1, init_params)
+    _gre = logic.greater(2, init_params)
+    _or = logic.logical_or(3, init_params)
+    _if = logic.control_if(4, init_params)
+    print(init_params)
+
     # https://towardsdatascience.com/emulating-logical-gates-with-a-neural-network-75c229ec4cc9
-    def test_logic(x1, x2):
+    def test_logic(x1, x2, w):
+        q1, w = _gre(x1, 0, w)
+        q2, w = _gre(x2, 0, w)
+        q3, w = _and(q1, q2, w)
+        q4, w = _not(q1, w)
+        q5, w = _not(q2, w)
+        q6, w = _and(q4, q5, w)
+        cond, w = _or(q3, q6, w)
+        pred, w = _if(cond, +1, -1, w)
         return pred
 
     x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5]).astype(float)
     x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6]).astype(float)
-    print(test_logic(x1, x2))
+    print(test_logic(x1, x2, init_params))
 
 
 def _test_indexing():
     print('testing indexing')
+    init_params = {}
+    _argmax = logic.argmax(0, init_params)
+    _argmin = logic.argmin(1, init_params)
+    print(init_params)
 
-    def argmaxmin(x):
-        amax = _argmax(x, 0, w)
-        amin = _argmin(x, 0, w)
+    def argmaxmin(x, w):
+        amax, w = _argmax(x, 0, w)
+        amin, w = _argmin(x, 0, w)
         return amax, amin
 
     values = jnp.asarray([2., 3., 5., 4.9, 4., 1., -1., -2.])
-    amax, amin = argmaxmin(values)
+    amax, amin = argmaxmin(values, init_params)
     print(amax)
     print(amin)


def _test_control():
     print('testing control')
+    init_params = {}
+    _switch = logic.control_switch(0, init_params)
+    print(init_params)
 
     pred = jnp.asarray(jnp.linspace(0, 2, 10))
     case1 = jnp.asarray([-10.] * 10)
     case2 = jnp.asarray([1.5] * 10)
     case3 = jnp.asarray([10.] * 10)
     cases = jnp.asarray([case1, case2, case3])
+    switch, _ = _switch(pred, cases, init_params)
+    print(switch)
 
 
 def _test_random():
     print('testing random')
     key = random.PRNGKey(42)
+    init_params = {}
+    _bernoulli = logic.bernoulli(0, init_params)
+    _discrete = logic.discrete(1, init_params)
+    _geometric = logic.geometric(2, init_params)
+    print(init_params)
 
-    def bern(n):
+    def bern(n, w):
         prob = jnp.asarray([0.3] * n)
-        sample = _bernoulli(key, prob, w)
+        sample, _ = _bernoulli(key, prob, w)
         return sample
 
-    samples = bern(50000)
+    samples = bern(50000, init_params)
     print(jnp.mean(samples))
 
-    def disc(n):
+    def disc(n, w):
         prob = jnp.asarray([0.1, 0.4, 0.5])
         prob = jnp.tile(prob, (n, 1))
-        sample = _discrete(key, prob, w)
+        sample, _ = _discrete(key, prob, w)
         return sample
 
-    samples = disc(50000)
+    samples = disc(50000, init_params)
     samples = jnp.round(samples)
     print([jnp.mean(samples == i) for i in range(3)])
 
-    def geom(n):
+    def geom(n, w):
         prob = jnp.asarray([0.3] * n)
-        sample = _geometric(key, prob, w)
+        sample, _ = _geometric(key, prob, w)
         return sample
 
-    samples = geom(50000)
+    samples = geom(50000, init_params)
     print(jnp.mean(samples))
 
 
 def _test_rounding():
     print('testing rounding')
+    init_params = {}
+    _floor = logic.floor(0, init_params)
+    _ceil = logic.ceil(1, init_params)
+    _round = logic.round(2, init_params)
+    _mod = logic.mod(3, init_params)
+    print(init_params)
+
+    x = jnp.asarray([2.1, 0.6, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
+    print(_floor(x, init_params)[0])
+    print(_ceil(x, init_params)[0])
+    print(_round(x, init_params)[0])
+    print(_mod(x, 2.0, init_params)[0])
 
 
 if __name__ == '__main__':
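
The rewritten logic.py above replaces the old single-weight FuzzyLogic interface with operator factories: each factory takes an (id, init_params) pair, registers its relaxation weight in init_params under str(id), and returns a JAX-traceable function that threads a params dictionary through every call and returns (value, params). A minimal usage sketch of that calling convention, modeled on the _test_logical routine shown in the diff; the operator ids and the weight value are illustrative, and pyRDDLGym-jax 1.0 is assumed to be installed:

import jax.numpy as jnp
from pyRDDLGym_jax.core.logic import FuzzyLogic, SigmoidComparison

# build a relaxed logic with a sharper sigmoid comparison (weight is illustrative)
logic = FuzzyLogic(comparison=SigmoidComparison(weight=100.0))

# each factory registers its tunable weight in init_params under str(id)
init_params = {}
_gre = logic.greater('q_gt_0', init_params)      # relaxed x > y via a sigmoid
_and = logic.logical_and('q_and', init_params)   # product t-norm by default
_if = logic.control_if('q_if', init_params)      # soft if-then-else

x1 = jnp.asarray([1.0, -1.0, 0.1])
x2 = jnp.asarray([1.0, 1.0, -0.5])

# every relaxed operation returns (value, params), so the weights can be tuned end to end
q1, init_params = _gre(x1, 0.0, init_params)
q2, init_params = _gre(x2, 0.0, init_params)
both, init_params = _and(q1, q2, init_params)
pred, init_params = _if(both, +1.0, -1.0, init_params)
print(pred)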