pyRDDLGym-jax 0.1-py3-none-any.whl → 0.2-py3-none-any.whl
This diff shows the changes between the publicly released 0.1 and 0.2 versions of the package, as they appear in the public registry.
- pyRDDLGym_jax/core/compiler.py +445 -221
- pyRDDLGym_jax/core/logic.py +129 -62
- pyRDDLGym_jax/core/planner.py +699 -332
- pyRDDLGym_jax/core/simulator.py +5 -7
- pyRDDLGym_jax/core/tuning.py +23 -12
- pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_drp.cfg → Cartpole_Continuous_gym_drp.cfg} +2 -3
- pyRDDLGym_jax/examples/configs/{HVAC_drp.cfg → HVAC_ippc2023_drp.cfg} +2 -2
- pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +18 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +18 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -1
- pyRDDLGym_jax/examples/configs/default_drp.cfg +19 -0
- pyRDDLGym_jax/examples/configs/default_replan.cfg +20 -0
- pyRDDLGym_jax/examples/configs/default_slp.cfg +19 -0
- pyRDDLGym_jax/examples/run_gradient.py +1 -1
- pyRDDLGym_jax/examples/run_gym.py +1 -2
- pyRDDLGym_jax/examples/run_plan.py +7 -0
- pyRDDLGym_jax/examples/run_tune.py +6 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/METADATA +1 -1
- pyRDDLGym_jax-0.2.dist-info/RECORD +46 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/WHEEL +1 -1
- pyRDDLGym_jax-0.1.dist-info/RECORD +0 -40
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_replan.cfg → Cartpole_Continuous_gym_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{Cartpole_Continuous_slp.cfg → Cartpole_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{HVAC_slp.cfg → HVAC_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MarsRover_drp.cfg → MarsRover_ippc2023_drp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MarsRover_slp.cfg → MarsRover_ippc2023_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{MountainCar_slp.cfg → MountainCar_Continuous_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{Pendulum_slp.cfg → Pendulum_gym_slp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_drp.cfg → PowerGen_Continuous_drp.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_replan.cfg → PowerGen_Continuous_replan.cfg} +0 -0
- /pyRDDLGym_jax/examples/configs/{PowerGen_slp.cfg → PowerGen_Continuous_slp.cfg} +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-0.1.dist-info → pyRDDLGym_jax-0.2.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/logic.py
CHANGED
@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 import jax.random as random
-from typing import Set
+from typing import Optional, Set
 
 from pyRDDLGym.core.debug.exception import raise_warning
 
@@ -20,6 +20,32 @@ class StandardComplement(Complement):
         return 1.0 - x
 
 
+class Comparison:
+    '''Base class for approximate comparison operations.'''
+    
+    def greater_equal(self, x, y, param):
+        raise NotImplementedError
+    
+    def greater(self, x, y, param):
+        raise NotImplementedError
+    
+    def equal(self, x, y, param):
+        raise NotImplementedError
+
+
+class SigmoidComparison(Comparison):
+    '''Comparison operations approximated using sigmoid functions.'''
+    
+    def greater_equal(self, x, y, param):
+        return jax.nn.sigmoid(param * (x - y))
+    
+    def greater(self, x, y, param):
+        return jax.nn.sigmoid(param * (x - y))
+    
+    def equal(self, x, y, param):
+        return 1.0 - jnp.square(jnp.tanh(param * (y - x)))
+
+
 class TNorm:
     '''Base class for fuzzy differentiable t-norms.'''
 
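The new `Comparison` interface makes the comparison relaxation pluggable. A minimal standalone sketch of how the sigmoid rule above behaves (the helper name here is illustrative, not part of the package):

```python
import jax
import jax.numpy as jnp

# Sketch of SigmoidComparison.greater_equal from the hunk above: larger
# weights push the soft result toward the hard 0/1 comparison while the
# operation stays differentiable.
def soft_greater_equal(x, y, param):
    return jax.nn.sigmoid(param * (x - y))

x, y = 1.2, 1.0
for w in (1.0, 10.0, 100.0):
    print(w, soft_greater_equal(x, y, w))        # ~0.55, ~0.88, ~1.0
print(jax.grad(soft_greater_equal)(x, y, 10.0))  # nonzero, unlike x >= y
```

The `equal` rule uses 1 - tanh² of the difference, which peaks at 1 when x == y and decays smoothly on either side.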
@@ -71,40 +97,61 @@ class FuzzyLogic:
     
     def __init__(self, tnorm: TNorm=ProductTNorm(),
                  complement: Complement=StandardComplement(),
+                 comparison: Comparison=SigmoidComparison(),
                  weight: float=10.0,
-                 debias: Set[str]=
+                 debias: Optional[Set[str]]=None,
                  eps: float=1e-10,
-                 verbose: bool=False
+                 verbose: bool=False,
+                 use64bit: bool=False) -> None:
         '''Creates a new fuzzy logic in Jax.
         
         :param tnorm: fuzzy operator for logical AND
         :param complement: fuzzy operator for logical NOT
-        :param
+        :param comparison: fuzzy operator for comparisons (>, >=, <, ==, ~=, ...)
+        :param weight: a sharpness parameter for sigmoid and softmax activations
         :param error: an error parameter (e.g. floor) (smaller means better accuracy)
         :param debias: which functions to de-bias approximate on forward pass
         :param eps: small positive float to mitigate underflow
         :param verbose: whether to dump replacements and other info to console
+        :param use64bit: whether to perform arithmetic in 64 bit
         '''
         self.tnorm = tnorm
         self.complement = complement
+        self.comparison = comparison
         self.weight = float(weight)
+        if debias is None:
+            debias = set()
         self.debias = debias
         self.eps = eps
         self.verbose = verbose
-
-
+        self.set_use64bit(use64bit)
+    
+    def set_use64bit(self, use64bit: bool) -> None:
+        self.use64bit = use64bit
+        if use64bit:
+            self.REAL = jnp.float64
+            self.INT = jnp.int64
+            jax.config.update('jax_enable_x64', True)
+        else:
+            self.REAL = jnp.float32
+            self.INT = jnp.int32
+            jax.config.update('jax_enable_x64', False)
+    
+    def summarize_hyperparameters(self) -> None:
         print(f'model relaxation:\n'
               f'    tnorm         ={type(self.tnorm).__name__}\n'
               f'    complement    ={type(self.complement).__name__}\n'
+              f'    comparison    ={type(self.comparison).__name__}\n'
               f'    sigmoid_weight={self.weight}\n'
               f'    cpfs_to_debias={self.debias}\n'
-              f'    underflow_tol ={self.eps}'
+              f'    underflow_tol ={self.eps}\n'
+              f'    use64bit      ={self.use64bit}')
     
     # ===========================================================================
     # logical operators
     # ===========================================================================
     
-    def
+    def logical_and(self):
         if self.verbose:
             raise_warning('Using the replacement rule: a ^ b --> tnorm(a, b).')
 
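The constructor now takes `comparison` and `use64bit` and accepts `debias=None`. A usage sketch built from the signature above (assuming the imports resolve as in this wheel's module layout):

```python
from pyRDDLGym_jax.core.logic import (
    FuzzyLogic, ProductTNorm, SigmoidComparison)

logic = FuzzyLogic(
    tnorm=ProductTNorm(),               # fuzzy AND
    comparison=SigmoidComparison(),     # new in 0.2: pluggable comparisons
    weight=10.0,                        # sigmoid/softmax sharpness
    debias={'greater_equal', 'round'},  # straight-through on these ops
    use64bit=True)                      # new in 0.2: enables jax_enable_x64
logic.summarize_hyperparameters()
```

Note that `set_use64bit` flips the global `jax_enable_x64` flag, so it affects every array created in the process, not just this logic instance.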
@@ -115,9 +162,9 @@ class FuzzyLogic:
         
         return _jax_wrapped_calc_and_approx, None
     
-    def
+    def logical_not(self):
         if self.verbose:
-            raise_warning('Using the replacement rule: ~a -->
+            raise_warning('Using the replacement rule: ~a --> complement(a)')
         
         _not = self.complement
         
@@ -126,9 +173,9 @@ class FuzzyLogic:
         
         return _jax_wrapped_calc_not_approx, None
     
-    def
+    def logical_or(self):
         if self.verbose:
-            raise_warning('Using the replacement rule: a
+            raise_warning('Using the replacement rule: a | b --> tconorm(a, b).')
         
         _not = self.complement
         _and = self.tnorm.norm
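`logical_or` binds the complement and the t-norm, which indicates the t-conorm is obtained by De Morgan duality rather than implemented directly. A sketch with the product t-norm (helper names are illustrative):

```python
import jax.numpy as jnp

# a | b  =  ~(~a ^ ~b): t-conorm from a t-norm plus the complement.
def product_tnorm(a, b):
    return a * b

def complement(a):
    return 1.0 - a

def tconorm(a, b):
    return complement(product_tnorm(complement(a), complement(b)))

print(tconorm(jnp.asarray(0.9), jnp.asarray(0.5)))  # 0.95 = 0.9 + 0.5 - 0.45
```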
@@ -141,7 +188,7 @@ class FuzzyLogic:
     def xor(self):
         if self.verbose:
             raise_warning('Using the replacement rule: '
-                          'a
+                          'a ~ b --> (a | b) ^ (a ^ b).')
         
         _not = self.complement
         _and = self.tnorm.norm
@@ -182,7 +229,7 @@ class FuzzyLogic:
     def forall(self):
         if self.verbose:
             raise_warning('Using the replacement rule: '
-                          'forall(a) -->
+                          'forall(a) --> a[1] ^ a[2] ^ ...')
         
         _forall = self.tnorm.norms
         
@@ -204,31 +251,35 @@ class FuzzyLogic:
     # comparison operators
     # ===========================================================================
     
-    def
+    def greater_equal(self):
         if self.verbose:
-            raise_warning('Using the replacement rule:
-
-
+            raise_warning('Using the replacement rule: '
+                          'a >= b --> comparison.greater_equal(a, b)')
+        
+        greater_equal_op = self.comparison.greater_equal
+        debias = 'greater_equal' in self.debias
         
         def _jax_wrapped_calc_geq_approx(a, b, param):
-            sample =
+            sample = greater_equal_op(a, b, param)
             if debias:
                 hard_sample = jnp.greater_equal(a, b)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
             return sample
         
-        tags = ('weight', '
+        tags = ('weight', 'greater_equal')
         new_param = (tags, self.weight)
         return _jax_wrapped_calc_geq_approx, new_param
     
     def greater(self):
         if self.verbose:
-            raise_warning('Using the replacement rule:
-
+            raise_warning('Using the replacement rule: '
+                          'a > b --> comparison.greater(a, b)')
+        
+        greater_op = self.comparison.greater
         debias = 'greater' in self.debias
         
         def _jax_wrapped_calc_gre_approx(a, b, param):
-            sample =
+            sample = greater_op(a, b, param)
             if debias:
                 hard_sample = jnp.greater(a, b)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
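The `debias` branch is a straight-through estimator: the forward value is replaced by the hard comparison, while the backward pass keeps the gradient of the relaxation. A self-contained sketch:

```python
import jax
import jax.numpy as jnp

def geq_straight_through(a, b, param):
    soft = jax.nn.sigmoid(param * (a - b))
    hard = jnp.greater_equal(a, b).astype(soft.dtype)
    # forward: the hard value; backward: the soft slope, because the
    # correction term is wrapped in stop_gradient as in the diff above
    return soft + jax.lax.stop_gradient(hard - soft)

print(geq_straight_through(1.2, 1.0, 10.0))            # exactly 1.0
print(jax.grad(geq_straight_through)(1.2, 1.0, 10.0))  # ~1.05, the soft slope
```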
@@ -238,8 +289,8 @@ class FuzzyLogic:
         new_param = (tags, self.weight)
         return _jax_wrapped_calc_gre_approx, new_param
     
-    def
-        jax_geq, jax_param = self.
+    def less_equal(self):
+        jax_geq, jax_param = self.greater_equal()
         
         def _jax_wrapped_calc_leq_approx(a, b, param):
             return jax_geq(-a, -b, param)
@@ -256,12 +307,14 @@ class FuzzyLogic:
     
     def equal(self):
         if self.verbose:
-            raise_warning('Using the replacement rule:
-
+            raise_warning('Using the replacement rule: '
+                          'a == b --> comparison.equal(a, b)')
+        
+        equal_op = self.comparison.equal
         debias = 'equal' in self.debias
         
         def _jax_wrapped_calc_equal_approx(a, b, param):
-            sample =
+            sample = equal_op(a, b, param)
             if debias:
                 hard_sample = jnp.equal(a, b)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
@@ -271,7 +324,7 @@ class FuzzyLogic:
         new_param = (tags, self.weight)
         return _jax_wrapped_calc_equal_approx, new_param
     
-    def
+    def not_equal(self):
         _not = self.complement
         jax_eq, jax_param = self.equal()
         
@@ -284,31 +337,32 @@ class FuzzyLogic:
     # special functions
     # ===========================================================================
     
-    def
+    def sgn(self):
         if self.verbose:
-            raise_warning('Using the replacement rule:
+            raise_warning('Using the replacement rule: sgn(x) --> tanh(x)')
         
-        debias = '
+        debias = 'sgn' in self.debias
         
-        def
+        def _jax_wrapped_calc_sgn_approx(x, param):
             sample = jnp.tanh(param * x)
             if debias:
                 hard_sample = jnp.sign(x)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
             return sample
         
-        tags = ('weight', '
+        tags = ('weight', 'sgn')
         new_param = (tags, self.weight)
-        return
+        return _jax_wrapped_calc_sgn_approx, new_param
     
     def floor(self):
         if self.verbose:
             raise_warning('Using the replacement rule: '
                           'floor(x) --> x - atan(-1.0 / tan(pi * x)) / pi - 0.5')
-
+        
         def _jax_wrapped_calc_floor_approx(x, param):
             sawtooth_part = jnp.arctan(-1.0 / jnp.tan(x * jnp.pi)) / jnp.pi + 0.5
-
+            sample = x - jax.lax.stop_gradient(sawtooth_part)
+            return sample
         
         return _jax_wrapped_calc_floor_approx, None
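The floor rule subtracts a stop-gradiented sawtooth, so the forward value equals the exact floor (away from integer inputs) while the derivative is 1 everywhere instead of 0. A sketch:

```python
import jax
import jax.numpy as jnp

# The sawtooth equals the fractional part of x away from integers, so
# x - stop_gradient(sawtooth) is floor(x) on the forward pass.
def soft_floor(x):
    sawtooth = jnp.arctan(-1.0 / jnp.tan(x * jnp.pi)) / jnp.pi + 0.5
    return x - jax.lax.stop_gradient(sawtooth)

print(soft_floor(2.7))            # ~2.0
print(jax.grad(soft_floor)(2.7))  # 1.0
```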
@@ -324,8 +378,14 @@ class FuzzyLogic:
         if self.verbose:
             raise_warning('Using the replacement rule: round(x) --> x')
         
+        debias = 'round' in self.debias
+        
         def _jax_wrapped_calc_round_approx(x, param):
-
+            sample = x
+            if debias:
+                hard_sample = jnp.round(x)
+                sample += jax.lax.stop_gradient(hard_sample - sample)
+            return sample
         
         return _jax_wrapped_calc_round_approx, None
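Round now honors the debias set: with `'round'` in `debias` the forward pass returns the true rounded value while the gradient stays that of the identity relaxation. A sketch:

```python
import jax
import jax.numpy as jnp

# Forward = jnp.round(x) when debiased; gradient = 1, as if round were
# the identity relaxation round(x) --> x from the warning above.
def soft_round(x, debias=True):
    sample = x
    if debias:
        sample += jax.lax.stop_gradient(jnp.round(x) - sample)
    return sample

print(soft_round(2.7))            # 3.0
print(jax.grad(soft_round)(2.7))  # 1.0
```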
@@ -337,13 +397,13 @@ class FuzzyLogic:
         
         return _jax_wrapped_calc_mod_approx, jax_param
     
-    def
+    def div(self):
         jax_floor, jax_param = self.floor()
         
-        def
+        def _jax_wrapped_calc_div_approx(x, y, param):
             return jax_floor(x / y, param)
         
-        return
+        return _jax_wrapped_calc_div_approx, jax_param
     
     def sqrt(self):
         if self.verbose:
@@ -369,14 +429,14 @@ class FuzzyLogic:
     def argmax(self):
         if self.verbose:
             raise_warning('Using the replacement rule: '
-
+                          'argmax(x) --> sum(i * softmax(x[i]))')
         
         debias = 'argmax' in self.debias
         
         def _jax_wrapped_calc_argmax_approx(x, axis, param):
-
-
-            sample = jnp.sum(literals *
+            literals = FuzzyLogic._literals(x.shape, axis=axis)
+            soft_max = jax.nn.softmax(param * x, axis=axis)
+            sample = jnp.sum(literals * soft_max, axis=axis)
             if debias:
                 hard_sample = jnp.argmax(x, axis=axis)
                 sample += jax.lax.stop_gradient(hard_sample - sample)
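The argmax relaxation replaces the index of the maximum with a softmax-weighted average of indices. A sketch where `jnp.arange` stands in for the internal `FuzzyLogic._literals` helper:

```python
import jax
import jax.numpy as jnp

def soft_argmax(x, param, axis=-1):
    literals = jnp.arange(x.shape[axis])          # index grid over the axis
    weights = jax.nn.softmax(param * x, axis=axis)
    return jnp.sum(literals * weights, axis=axis)

x = jnp.asarray([0.1, 0.5, 2.0, 0.3])
print(soft_argmax(x, param=10.0))  # ~2.0
print(jnp.argmax(x))               # 2
```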
@@ -398,26 +458,33 @@ class FuzzyLogic:
     # control flow
     # ===========================================================================
     
-    def
+    def control_if(self):
         if self.verbose:
             raise_warning('Using the replacement rule: '
                           'if c then a else b --> c * a + (1 - c) * b')
         
+        debias = 'if' in self.debias
+        
         def _jax_wrapped_calc_if_approx(c, a, b, param):
-
+            sample = c * a + (1.0 - c) * b
+            if debias:
+                hard_sample = jnp.select([c, ~c], [a, b])
+                sample += jax.lax.stop_gradient(hard_sample - sample)
+            return sample
         
         return _jax_wrapped_calc_if_approx, None
     
-    def
+    def control_switch(self):
         if self.verbose:
             raise_warning('Using the replacement rule: '
-                          'switch(pred) { cases } -->
+                          'switch(pred) { cases } --> '
+                          'sum(cases[i] * softmax(-abs(pred - i)))')
         
-        debias = '
+        debias = 'switch' in self.debias
         
         def _jax_wrapped_calc_switch_approx(pred, cases, param):
-            pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
             literals = FuzzyLogic._literals(cases.shape, axis=0)
+            pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
             proximity = -jnp.abs(pred - literals)
             soft_case = jax.nn.softmax(param * proximity, axis=0)
             sample = jnp.sum(cases * soft_case, axis=0)
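The soft switch weights each case by a softmax over the negated distance between the (soft) predicate and the case index, so the nearest case dominates as the weight grows. A standalone sketch:

```python
import jax
import jax.numpy as jnp

def soft_switch(pred, cases, param):
    literals = jnp.arange(cases.shape[0], dtype=pred.dtype)  # case indices
    proximity = -jnp.abs(pred - literals)     # peaks at the nearest case
    weights = jax.nn.softmax(param * proximity, axis=0)
    return jnp.sum(cases * weights, axis=0)

cases = jnp.asarray([-10.0, 3.0, 7.0])
print(soft_switch(jnp.asarray(1.0), cases, param=100.0))  # ~3.0
print(soft_switch(jnp.asarray(1.4), cases, param=100.0))  # ~3.0 (nearest case)
```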
@@ -436,8 +503,8 @@ class FuzzyLogic:
     # ===========================================================================
     
     def _gumbel_softmax(self, key, prob):
-        Gumbel01 = random.gumbel(key=key, shape=prob.shape)
-        sample = Gumbel01 + jnp.
+        Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=self.REAL)
+        sample = Gumbel01 + jnp.log(prob + self.eps)
         return sample
     
     def bernoulli(self):
@@ -448,13 +515,13 @@ class FuzzyLogic:
         jax_gs = self._gumbel_softmax
         jax_argmax, jax_param = self.argmax()
         
-        def
+        def _jax_wrapped_calc_bernoulli_approx(key, prob, param):
             prob = jnp.stack([1.0 - prob, prob], axis=-1)
             sample = jax_gs(key, prob)
-            sample = jax_argmax(sample,
+            sample = jax_argmax(sample, axis=-1, param=param)
             return sample
         
-        return
+        return _jax_wrapped_calc_bernoulli_approx, jax_param
     
     def discrete(self):
         if self.verbose:
@@ -466,7 +533,7 @@ class FuzzyLogic:
         
         def _jax_wrapped_calc_discrete_approx(key, prob, param):
             sample = jax_gs(key, prob)
-            sample = jax_argmax(sample,
+            sample = jax_argmax(sample, axis=-1, param=param)
             return sample
         
         return _jax_wrapped_calc_discrete_approx, jax_param
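`_gumbel_softmax` implements the Gumbel-max trick: adding Gumbel(0, 1) noise to log-probabilities and taking an argmax yields an exact categorical sample, and pairing it with the soft argmax above keeps the sampler differentiable. A quick empirical check of the hard version:

```python
import jax
import jax.numpy as jnp
import jax.random as random

prob = jnp.asarray([0.2, 0.7, 0.1])

def sample_one(key):
    g = random.gumbel(key=key, shape=prob.shape)
    return jnp.argmax(g + jnp.log(prob))    # hard Gumbel-max sample

keys = random.split(random.PRNGKey(42), 10000)
samples = jax.vmap(sample_one)(keys)
print(jnp.bincount(samples, length=3) / samples.size)  # ~[0.2, 0.7, 0.1]
```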
@@ -479,11 +546,11 @@ w = 100.0
 
 def _test_logical():
     print('testing logical')
-    _and, _ = logic.
-    _not, _ = logic.
+    _and, _ = logic.logical_and()
+    _not, _ = logic.logical_not()
     _gre, _ = logic.greater()
-    _or, _ = logic.
-    _if, _ = logic.
+    _or, _ = logic.logical_or()
+    _if, _ = logic.control_if()
     
     # https://towardsdatascience.com/emulating-logical-gates-with-a-neural-network-75c229ec4cc9
     def test_logic(x1, x2):
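For reference, a sketch of calling the renamed factories directly, assuming the wrapped functions take the sharpness weight as their final `param` argument as the comparison wrappers above do:

```python
from pyRDDLGym_jax.core.logic import FuzzyLogic

logic = FuzzyLogic()
_and, _ = logic.logical_and()
_or, _ = logic.logical_or()
_not, _ = logic.logical_not()

w = 100.0  # sharpness parameter threaded through as `param`
print(_and(0.9, 0.8, w))  # ~0.72 under the default product t-norm
print(_or(0.9, 0.8, w))   # ~0.98 via De Morgan duality
print(_not(0.9, w))       # ~0.1 under the standard complement
```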
@@ -516,7 +583,7 @@ def _test_indexing():
 
 def _test_control():
     print('testing control')
-    _switch, _ = logic.
+    _switch, _ = logic.control_switch()
     
     pred = jnp.asarray(jnp.linspace(0, 2, 10))
     case1 = jnp.asarray([-10.] * 10)