pyRDDLGym-jax 0.3-py3-none-any.whl → 0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +90 -67
- pyRDDLGym_jax/core/logic.py +286 -82
- pyRDDLGym_jax/core/planner.py +191 -97
- pyRDDLGym_jax/core/simulator.py +2 -1
- pyRDDLGym_jax/core/tuning.py +58 -63
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +2 -1
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +4 -3
- pyRDDLGym_jax/examples/configs/default_replan.cfg +2 -1
- pyRDDLGym_jax/examples/run_tune.py +1 -3
- pyRDDLGym_jax-0.5.dist-info/METADATA +278 -0
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/RECORD +17 -17
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax-0.3.dist-info/METADATA +0 -26
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.3.dist-info → pyRDDLGym_jax-0.5.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/logic.py
CHANGED
@@ -1,24 +1,18 @@
+from typing import Optional, Set
+
 import jax
 import jax.numpy as jnp
 import jax.random as random
-from typing import Optional, Set
 
 from pyRDDLGym.core.debug.exception import raise_warning
 
 
-
-
-
-
-
-
-
-class StandardComplement(Complement):
-    '''The standard approximate logical complement given by x -> 1 - x.'''
-
-    def __call__(self, x):
-        return 1.0 - x
-
+# ===========================================================================
+# RELATIONAL OPERATIONS
+# - abstract class
+# - sigmoid comparison
+#
+# ===========================================================================
 
 class Comparison:
     '''Base class for approximate comparison operations.'''
@@ -32,10 +26,14 @@ class Comparison:
     def equal(self, x, y, param):
         raise NotImplementedError
 
+    def sgn(self, x, param):
+        raise NotImplementedError
+
 
 class SigmoidComparison(Comparison):
     '''Comparison operations approximated using sigmoid functions.'''
 
+    # https://arxiv.org/abs/2110.05651
     def greater_equal(self, x, y, param):
         return jax.nn.sigmoid(param * (x - y))
 
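The sigmoid relaxation cited above (https://arxiv.org/abs/2110.05651) replaces hard comparisons with smooth surrogates. A standalone sketch, independent of the package, of what `greater_equal` computes and why it helps gradient-based planning:

```python
# Standalone sketch of the sigmoid comparison relaxation: sigmoid(w * (x - y))
# tends to the hard indicator 1[x >= y] as the sharpness w grows, but stays
# differentiable at every w.
import jax
import jax.numpy as jnp

def soft_greater_equal(x, y, w):
    return jax.nn.sigmoid(w * (x - y))

x, y = jnp.asarray(2.0), jnp.asarray(1.0)
for w in (1.0, 10.0, 100.0):
    print(w, soft_greater_equal(x, y, w))    # 0.73..., 0.99995..., ~1.0
# the relaxation has a usable gradient, unlike the hard comparison:
print(jax.grad(soft_greater_equal)(x, y, 10.0))
```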
@@ -44,7 +42,75 @@ class SigmoidComparison(Comparison):
 
     def equal(self, x, y, param):
         return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
-
+
+    def sgn(self, x, param):
+        return jnp.tanh(param * x)
+
+
+# ===========================================================================
+# ROUNDING OPERATIONS
+# - abstract class
+# - soft rounding
+#
+# ===========================================================================
+
+class Rounding:
+    '''Base class for approximate rounding operations.'''
+
+    def floor(self, x, param):
+        raise NotImplementedError
+
+    def round(self, x, param):
+        raise NotImplementedError
+
+
+class SoftRounding(Rounding):
+    '''Rounding operations approximated using soft operations.'''
+
+    # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
+    def floor(self, x, param):
+        denom = jnp.tanh(param / 4.0)
+        return (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
+                jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)
+
+    # https://arxiv.org/abs/2006.09952
+    def round(self, x, param):
+        m = jnp.floor(x) + 0.5
+        return m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)
+
+
+# ===========================================================================
+# LOGICAL COMPLEMENT
+# - abstract class
+# - standard complement
+#
+# ===========================================================================
+
+class Complement:
+    '''Base class for approximate logical complement operations.'''
+
+    def __call__(self, x):
+        raise NotImplementedError
+
+
+class StandardComplement(Complement):
+    '''The standard approximate logical complement given by x -> 1 - x.'''
+
+    # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
+    def __call__(self, x):
+        return 1.0 - x
+
+
+# ===========================================================================
+# TNORMS
+# - abstract tnorm
+# - product tnorm
+# - Godel tnorm
+# - Lukasiewicz tnorm
+# - Yager(p) tnorm
+#
+# https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
+# ===========================================================================
 
 class TNorm:
     '''Base class for fuzzy differentiable t-norms.'''
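The `SoftRounding` operations added here follow the cited references (TFP's Softfloor bijector and https://arxiv.org/abs/2006.09952). A minimal sketch copying the two formulas out of the class so they can be probed directly; both converge to `jnp.floor`/`jnp.round` as the sharpness parameter grows:

```python
# The SoftRounding formulas from the hunk above, extracted for inspection.
import jax
import jax.numpy as jnp

def soft_floor(x, param):
    denom = jnp.tanh(param / 4.0)
    return (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
            jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)

def soft_round(x, param):
    m = jnp.floor(x) + 0.5
    return m + 0.5 * jnp.tanh(param * (x - m)) / jnp.tanh(param / 2.0)

x = jnp.asarray([2.1, -0.4, 0.7])
print(soft_floor(x, 100.0), jnp.floor(x))   # approx [2., -1., 0.]
print(soft_round(x, 100.0), jnp.round(x))   # approx [2., -0., 1.]
```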
@@ -86,8 +152,134 @@ class LukasiewiczTNorm(TNorm):
 
     def norms(self, x, axis):
         return jax.nn.relu(jnp.sum(x - 1.0, axis=axis) + 1.0)
+
+
+class YagerTNorm(TNorm):
+    '''Yager t-norm given by the expression
+    (x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
+
+    def __init__(self, p=2.0):
+        self.p = float(p)
+
+    def norm(self, x, y):
+        base = jax.nn.relu(1.0 - jnp.stack([x, y], axis=0))
+        arg = jnp.linalg.norm(base, ord=self.p, axis=0)
+        return jax.nn.relu(1.0 - arg)
+
+    def norms(self, x, axis):
+        arg = jax.nn.relu(1.0 - x)
+        for ax in sorted(axis, reverse=True):
+            arg = jnp.linalg.norm(arg, ord=self.p, axis=ax)
+        return jax.nn.relu(1.0 - arg)
+
+
+# ===========================================================================
+# RANDOM SAMPLING
+# - abstract sampler
+# - Gumbel-softmax sampler
+# - determinization
+#
+# ===========================================================================
+
+class RandomSampling:
+    '''An abstract class that describes how discrete and non-reparameterizable
+    random variables are sampled.'''
 
+    def discrete(self, logic):
+        raise NotImplementedError
 
+    def bernoulli(self, logic):
+        jax_discrete, jax_param = self.discrete(logic)
+
+        def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
+            prob = jnp.stack([1.0 - prob, prob], axis=-1)
+            sample = jax_discrete(key, prob, param)
+            return sample
+
+        return _jax_wrapped_calc_bernoulli_approx, jax_param
+
+    def poisson(self, logic):
+
+        def _jax_wrapped_calc_poisson_exact(key, rate, param):
+            return random.poisson(key=key, lam=rate, dtype=logic.INT)
+
+        return _jax_wrapped_calc_poisson_exact, None
+
+    def geometric(self, logic):
+        if logic.verbose:
+            raise_warning('Using the replacement rule: '
+                          'Geometric(p) --> floor(log(U) / log(1 - p)) + 1')
+
+        jax_floor, jax_param = logic.floor()
+
+        def _jax_wrapped_calc_geometric_approx(key, prob, param):
+            U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
+            sample = jax_floor(jnp.log(U) / jnp.log(1.0 - prob), param) + 1
+            return sample
+
+        return _jax_wrapped_calc_geometric_approx, jax_param
+
+
+class GumbelSoftmax(RandomSampling):
+    '''Random sampling of discrete variables using Gumbel-softmax trick.'''
+
+    def discrete(self, logic):
+        if logic.verbose:
+            raise_warning('Using the replacement rule: '
+                          'Discrete(p) --> Gumbel-Softmax(p)')
+
+        jax_argmax, jax_param = logic.argmax()
+
+        # https://arxiv.org/pdf/1611.01144
+        def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, param):
+            Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
+            sample = Gumbel01 + jnp.log(prob + logic.eps)
+            sample = jax_argmax(sample, axis=-1, param=param)
+            return sample
+
+        return _jax_wrapped_calc_discrete_gumbel_softmax, jax_param
+
+
+class Determinization(RandomSampling):
+    '''Random sampling of variables using their deterministic mean estimate.'''
+
+    def discrete(self, logic):
+        if logic.verbose:
+            raise_warning('Using the replacement rule: '
+                          'Discrete(p) --> sum(i * p[i])')
+
+        def _jax_wrapped_calc_discrete_determinized(key, prob, param):
+            literals = FuzzyLogic.enumerate_literals(prob.shape, axis=-1)
+            sample = jnp.sum(literals * prob, axis=-1)
+            return sample
+
+        return _jax_wrapped_calc_discrete_determinized, None
+
+    def poisson(self, logic):
+        if logic.verbose:
+            raise_warning('Using the replacement rule: Poisson(rate) --> rate')
+
+        def _jax_wrapped_calc_poisson_determinized(key, rate, param):
+            return rate
+
+        return _jax_wrapped_calc_poisson_determinized, None
+
+    def geometric(self, logic):
+        if logic.verbose:
+            raise_warning('Using the replacement rule: Geometric(p) --> 1 / p')
+
+        def _jax_wrapped_calc_geometric_determinized(key, prob, param):
+            sample = 1.0 / prob
+            return sample
+
+        return _jax_wrapped_calc_geometric_determinized, None
+
+
+# ===========================================================================
+# FUZZY LOGIC
+#
+# ===========================================================================
+
 class FuzzyLogic:
     '''A class representing fuzzy logic in JAX.
 
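`GumbelSoftmax` above implements the trick from https://arxiv.org/pdf/1611.01144: perturb log-probabilities with Gumbel(0, 1) noise, then take a (soft) argmax. A self-contained sketch, with a simplified soft argmax standing in for `FuzzyLogic.argmax()`:

```python
# Minimal sketch of Gumbel-softmax sampling: argmax(g + log p) with Gumbel
# noise g draws from Discrete(p); a softmax-weighted index keeps the draw
# differentiable. Simplified relative to the package's wiring.
import jax
import jax.numpy as jnp
import jax.random as random

def soft_argmax(x, w, axis=-1):
    literals = jnp.arange(x.shape[axis])
    return jnp.sum(literals * jax.nn.softmax(w * x, axis=axis), axis=axis)

def gumbel_softmax_discrete(key, prob, w, eps=1e-15):
    g = random.gumbel(key=key, shape=prob.shape)
    return soft_argmax(g + jnp.log(prob + eps), w)

key = random.PRNGKey(0)
prob = jnp.tile(jnp.asarray([0.2, 0.5, 0.3]), (50000, 1))
samples = jnp.round(gumbel_softmax_discrete(key, prob, w=100.0))
print([float(jnp.mean(samples == i)) for i in range(3)])  # ~[0.2, 0.5, 0.3]
```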
@@ -98,9 +290,11 @@ class FuzzyLogic:
     def __init__(self, tnorm: TNorm=ProductTNorm(),
                  complement: Complement=StandardComplement(),
                  comparison: Comparison=SigmoidComparison(),
+                 sampling: RandomSampling=GumbelSoftmax(),
+                 rounding: Rounding=SoftRounding(),
                  weight: float=10.0,
                  debias: Optional[Set[str]]=None,
-                 eps: float=1e-
+                 eps: float=1e-15,
                  verbose: bool=False,
                  use64bit: bool=False) -> None:
         '''Creates a new fuzzy logic in Jax.
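With the two new constructor arguments, the sampling and rounding strategies become pluggable alongside the existing t-norm, complement, and comparison operators. A usage sketch, assuming the 0.5 module layout shown in this diff:

```python
# Usage sketch for the extended constructor (per the signature in this hunk).
from pyRDDLGym_jax.core.logic import (
    FuzzyLogic, LukasiewiczTNorm, SigmoidComparison, SoftRounding, Determinization)

logic = FuzzyLogic(
    tnorm=LukasiewiczTNorm(),        # fuzzy AND
    comparison=SigmoidComparison(),  # sigmoid relaxations of >, >=, ==, sgn
    rounding=SoftRounding(),         # soft floor/round
    sampling=Determinization(),      # replace Discrete/Poisson/Geometric by means
    weight=100.0,                    # sharpness of the relaxations
    eps=1e-15)
```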
@@ -108,8 +302,9 @@ class FuzzyLogic:
         :param tnorm: fuzzy operator for logical AND
         :param complement: fuzzy operator for logical NOT
         :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
+        :param sampling: random sampling of non-reparameterizable distributions
+        :param rounding: rounding floating values to integers
         :param weight: a sharpness parameter for sigmoid and softmax activations
-        :param error: an error parameter (e.g. floor) (smaller means better accuracy)
         :param debias: which functions to de-bias approximate on forward pass
         :param eps: small positive float to mitigate underflow
         :param verbose: whether to dump replacements and other info to console
@@ -118,6 +313,8 @@ class FuzzyLogic:
         self.tnorm = tnorm
         self.complement = complement
         self.comparison = comparison
+        self.sampling = sampling
+        self.rounding = rounding
         self.weight = float(weight)
         if debias is None:
             debias = set()
@@ -142,10 +339,12 @@ class FuzzyLogic:
             f'    tnorm         ={type(self.tnorm).__name__}\n'
             f'    complement    ={type(self.complement).__name__}\n'
             f'    comparison    ={type(self.comparison).__name__}\n'
+            f'    sampling      ={type(self.sampling).__name__}\n'
+            f'    rounding      ={type(self.rounding).__name__}\n'
             f'    sigmoid_weight={self.weight}\n'
             f'    cpfs_to_debias={self.debias}\n'
             f'    underflow_tol ={self.eps}\n'
-            f'
+            f'    use_64_bit    ={self.use64bit}')
 
     # ===========================================================================
     # logical operators
@@ -339,12 +538,14 @@ class FuzzyLogic:
 
     def sgn(self):
         if self.verbose:
-            raise_warning('Using the replacement rule:
-
+            raise_warning('Using the replacement rule: '
+                          'sgn(x) --> comparison.sgn(x)')
+
+        sgn_op = self.comparison.sgn
         debias = 'sgn' in self.debias
 
         def _jax_wrapped_calc_sgn_approx(x, param):
-            sample =
+            sample = sgn_op(x, param)
             if debias:
                 hard_sample = jnp.sign(x)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
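The debias branch above is a straight-through estimator: `stop_gradient` makes the forward pass return the hard `jnp.sign(x)` while gradients flow through the soft surrogate. A sketch outside the class:

```python
# Straight-through trick: forward value is hard, backward gradient is soft.
import jax
import jax.numpy as jnp

def sgn_st(x, w):
    soft = jnp.tanh(w * x)                                   # surrogate
    return soft + jax.lax.stop_gradient(jnp.sign(x) - soft)  # forward: sign(x)

x = jnp.asarray(0.2)
print(sgn_st(x, 10.0))             # exactly 1.0 (the hard sign)
print(jax.grad(sgn_st)(x, 10.0))   # w * (1 - tanh(w x)^2), gradient of tanh
```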
@@ -357,37 +558,48 @@ class FuzzyLogic:
     def floor(self):
         if self.verbose:
             raise_warning('Using the replacement rule: '
-                          'floor(x) -->
+                          'floor(x) --> rounding.floor(x)')
+
+        floor_op = self.rounding.floor
+        debias = 'floor' in self.debias
 
         def _jax_wrapped_calc_floor_approx(x, param):
-
-
+            sample = floor_op(x, param)
+            if debias:
+                hard_sample = jnp.floor(x)
+                sample += jax.lax.stop_gradient(hard_sample - sample)
             return sample
 
-
-
-
-        jax_floor, jax_param = self.floor()
-
-        def _jax_wrapped_calc_ceil_approx(x, param):
-            return -jax_floor(-x, param)
+        tags = ('weight', 'floor')
+        new_param = (tags, self.weight)
+        return _jax_wrapped_calc_floor_approx, new_param
 
-        return _jax_wrapped_calc_ceil_approx, jax_param
-
     def round(self):
         if self.verbose:
-            raise_warning('Using the replacement rule:
+            raise_warning('Using the replacement rule: '
+                          'round(x) --> rounding.round(x)')
 
+        round_op = self.rounding.round
         debias = 'round' in self.debias
 
         def _jax_wrapped_calc_round_approx(x, param):
-            sample = x
+            sample = round_op(x, param)
             if debias:
                 hard_sample = jnp.round(x)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
             return sample
 
-
+        tags = ('weight', 'round')
+        new_param = (tags, self.weight)
+        return _jax_wrapped_calc_round_approx, new_param
+
+    def ceil(self):
+        jax_floor, jax_param = self.floor()
+
+        def _jax_wrapped_calc_ceil_approx(x, param):
+            return -jax_floor(-x, param)
+
+        return _jax_wrapped_calc_ceil_approx, jax_param
 
     def mod(self):
         jax_floor, jax_param = self.floor()
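Two things change here: `floor()` and `round()` now return their sharpness parameter tagged as `(('weight', 'floor'), weight)` rather than borrowing another operator's parameter, and `ceil()` is rebuilt on top of `floor()` through the identity ceil(x) = -floor(-x). A quick check of that identity against the soft floor formula (repeated from the earlier sketch):

```python
# Any soft floor yields a matching soft ceil via ceil(x) = -floor(-x).
import jax
import jax.numpy as jnp

def soft_floor(x, param):
    denom = jnp.tanh(param / 4.0)
    return (jax.nn.sigmoid(param * (x - jnp.floor(x) - 1.0)) -
            jax.nn.sigmoid(-param / 2.0)) / denom + jnp.floor(x)

def soft_ceil(x, param):
    return -soft_floor(-x, param)

x = jnp.asarray([2.1, -0.4, 0.7])
print(soft_ceil(x, 100.0), jnp.ceil(x))   # approx [3., -0., 1.]
```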
@@ -419,7 +631,7 @@ class FuzzyLogic:
     # ===========================================================================
 
     @staticmethod
-    def
+    def enumerate_literals(shape, axis):
         literals = jnp.arange(shape[axis])
         literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
         literals = jnp.moveaxis(literals, source=0, destination=axis)
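The `enumerate_literals` helper builds an index array along one axis, shaped to broadcast against its input; it is what `argmax`, `switch`, and `Determinization.discrete` sum against. A sketch of the same three lines:

```python
# What enumerate_literals produces: indices along the chosen axis, shaped
# so they broadcast against an array of the given shape.
import jax.numpy as jnp

def enumerate_literals(shape, axis):
    literals = jnp.arange(shape[axis])
    literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
    return jnp.moveaxis(literals, source=0, destination=axis)

print(enumerate_literals((2, 3), axis=-1))   # shape (1, 3): [[0 1 2]]
print(enumerate_literals((2, 3), axis=0))    # shape (2, 1): [[0] [1]]
```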
@@ -433,8 +645,9 @@ class FuzzyLogic:
 
         debias = 'argmax' in self.debias
 
+        # https://arxiv.org/abs/2110.05651
        def _jax_wrapped_calc_argmax_approx(x, axis, param):
-            literals = FuzzyLogic.
+            literals = FuzzyLogic.enumerate_literals(x.shape, axis=axis)
             soft_max = jax.nn.softmax(param * x, axis=axis)
             sample = jnp.sum(literals * soft_max, axis=axis)
             if debias:
@@ -468,7 +681,7 @@ class FuzzyLogic:
         def _jax_wrapped_calc_if_approx(c, a, b, param):
             sample = c * a + (1.0 - c) * b
             if debias:
-                hard_sample = jnp.
+                hard_sample = jnp.where(c > 0.5, a, b)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
             return sample
 
@@ -478,14 +691,14 @@ class FuzzyLogic:
         if self.verbose:
             raise_warning('Using the replacement rule: '
                           'switch(pred) { cases } --> '
-                          'sum(cases[i] * softmax(-
+                          'sum(cases[i] * softmax(-(pred - i)^2))')
 
         debias = 'switch' in self.debias
 
         def _jax_wrapped_calc_switch_approx(pred, cases, param):
-            literals = FuzzyLogic.
+            literals = FuzzyLogic.enumerate_literals(cases.shape, axis=0)
             pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
-            proximity = -jnp.
+            proximity = -jnp.square(pred - literals)
             soft_case = jax.nn.softmax(param * proximity, axis=0)
             sample = jnp.sum(cases * soft_case, axis=0)
             if debias:
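The switch relaxation weights each case by a softmax over the negative squared distance between `pred` and the case index, so mass concentrates on the case nearest to `pred` as the weight grows. A standalone sketch:

```python
# Soft switch: cases weighted by softmax of -(pred - i)^2 over case index i.
import jax
import jax.numpy as jnp

def soft_switch(pred, cases, w):
    literals = jnp.arange(cases.shape[0])
    proximity = -jnp.square(pred - literals)
    return jnp.sum(cases * jax.nn.softmax(w * proximity), axis=0)

cases = jnp.asarray([10.0, 20.0, 30.0])
print(soft_switch(jnp.asarray(1.1), cases, w=10.0))   # close to 20.0
```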
@@ -502,46 +715,26 @@ class FuzzyLogic:
     # random variables
     # ===========================================================================
 
-    def
-
-
-        return sample
-
+    def discrete(self):
+        return self.sampling.discrete(self)
+
     def bernoulli(self):
-
-        raise_warning('Using the replacement rule: '
-                      'Bernoulli(p) --> Gumbel-softmax(p)')
-
-        jax_gs = self._gumbel_softmax
-        jax_argmax, jax_param = self.argmax()
-
-        def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
-            prob = jnp.stack([1.0 - prob, prob], axis=-1)
-            sample = jax_gs(key, prob)
-            sample = jax_argmax(sample, axis=-1, param=param)
-            return sample
-
-        return _jax_wrapped_calc_bernoulli_approx, jax_param
+        return self.sampling.bernoulli(self)
 
-    def
-
-
-
-
-        jax_gs = self._gumbel_softmax
-        jax_argmax, jax_param = self.argmax()
-
-        def _jax_wrapped_calc_discrete_approx(key, prob, param):
-            sample = jax_gs(key, prob)
-            sample = jax_argmax(sample, axis=-1, param=param)
-            return sample
-
-        return _jax_wrapped_calc_discrete_approx, jax_param
-
+    def poisson(self):
+        return self.sampling.poisson(self)
+
+    def geometric(self):
+        return self.sampling.geometric(self)
 
+
+# ===========================================================================
 # UNIT TESTS
+#
+# ===========================================================================
+
 logic = FuzzyLogic()
-w =
+w = 1000.0
 
 
 def _test_logical():
@@ -568,7 +761,7 @@ def _test_logical():
 def _test_indexing():
     print('testing indexing')
     _argmax, _ = logic.argmax()
-    _argmin, _ = logic.
+    _argmin, _ = logic.argmin()
 
     def argmaxmin(x):
         amax = _argmax(x, 0, w)
@@ -598,13 +791,14 @@ def _test_random():
     key = random.PRNGKey(42)
     _bernoulli, _ = logic.bernoulli()
     _discrete, _ = logic.discrete()
+    _geometric, _ = logic.geometric()
 
     def bern(n):
         prob = jnp.asarray([0.3] * n)
         sample = _bernoulli(key, prob, w)
         return sample
 
-    samples = bern(
+    samples = bern(50000)
     print(jnp.mean(samples))
 
     def disc(n):
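The new `geom` test in the next hunk exercises the inverse-CDF replacement rule that `RandomSampling.geometric` uses. A sanity sketch of that rule, independent of the package; for p = 0.3 the printed mean should approach 1/p ≈ 3.33:

```python
# floor(log(U) / log(1 - p)) + 1 with U ~ Uniform(0, 1) is an exact
# inverse-CDF draw from Geometric(p) on {1, 2, ...}, so E[X] = 1 / p.
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(42)
p = 0.3
U = random.uniform(key=key, shape=(50000,))
samples = jnp.floor(jnp.log(U) / jnp.log(1.0 - p)) + 1
print(jnp.mean(samples))   # close to 1 / 0.3 = 3.33
```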
@@ -613,20 +807,30 @@ def _test_random():
         sample = _discrete(key, prob, w)
         return sample
 
-    samples = disc(
+    samples = disc(50000)
     samples = jnp.round(samples)
     print([jnp.mean(samples == i) for i in range(3)])
-
+
+    def geom(n):
+        prob = jnp.asarray([0.3] * n)
+        sample = _geometric(key, prob, w)
+        return sample
+
+    samples = geom(50000)
+    print(jnp.mean(samples))
+
 
 def _test_rounding():
     print('testing rounding')
     _floor, _ = logic.floor()
     _ceil, _ = logic.ceil()
+    _round, _ = logic.round()
     _mod, _ = logic.mod()
 
-    x = jnp.asarray([2.1, 0.
+    x = jnp.asarray([2.1, 0.6, 1.99, -2.01, -3.2, -0.1, -1.01, 23.01, -101.99, 200.01])
     print(_floor(x, w))
     print(_ceil(x, w))
+    print(_round(x, w))
     print(_mod(x, 2.0, w))
 
 
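An end-to-end sketch tying the new pieces together, assuming the 0.5 API shown in this diff: the same Bernoulli(0.3) node is sampled stochastically under `GumbelSoftmax` and collapsed to its exact mean under `Determinization`; both print a value near 0.3:

```python
# Swapping the pluggable sampler on the same Bernoulli node (0.5 API as
# shown in this diff; the weight 1000.0 mirrors the module's unit tests).
import jax.numpy as jnp
import jax.random as random
from pyRDDLGym_jax.core.logic import FuzzyLogic, GumbelSoftmax, Determinization

key = random.PRNGKey(42)
prob = jnp.asarray([0.3] * 50000)
for sampling in (GumbelSoftmax(), Determinization()):
    logic = FuzzyLogic(sampling=sampling, weight=1000.0)
    bern_fn, _ = logic.bernoulli()
    samples = bern_fn(key, prob, 1000.0)
    print(type(sampling).__name__, float(jnp.mean(samples)))  # both near 0.3
```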