pyRDDLGym-jax 2.2__py3-none-any.whl → 2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +16 -11
- pyRDDLGym_jax/core/logic.py +233 -119
- pyRDDLGym_jax/core/planner.py +489 -218
- pyRDDLGym_jax/core/tuning.py +28 -22
- pyRDDLGym_jax/examples/run_plan.py +2 -2
- pyRDDLGym_jax/examples/run_scipy.py +2 -2
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/METADATA +1 -1
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/RECORD +13 -13
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/LICENSE +0 -0
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/logic.py
CHANGED
|
@@ -29,15 +29,29 @@
|
|
|
29
29
|
#
|
|
30
30
|
# ***********************************************************************
|
|
31
31
|
|
|
32
|
-
|
|
32
|
+
|
|
33
|
+
from abc import ABCMeta, abstractmethod
|
|
34
|
+
import traceback
|
|
35
|
+
from typing import Callable, Dict, Tuple, Union
|
|
33
36
|
|
|
34
37
|
import jax
|
|
35
38
|
import jax.numpy as jnp
|
|
36
39
|
import jax.random as random
|
|
37
40
|
import jax.scipy as scipy
|
|
38
41
|
|
|
42
|
+
from pyRDDLGym.core.debug.exception import raise_warning
|
|
43
|
+
|
|
44
|
+
# more robust approach - if user does not have this or broken try to continue
|
|
45
|
+
try:
|
|
46
|
+
from tensorflow_probability.substrates import jax as tfp
|
|
47
|
+
except Exception:
|
|
48
|
+
raise_warning('Failed to import tensorflow-probability: '
|
|
49
|
+
'compilation of some probability distributions will fail.', 'red')
|
|
50
|
+
traceback.print_exc()
|
|
51
|
+
tfp = None
|
|
52
|
+
|
|
39
53
|
|
|
40
|
-
def enumerate_literals(shape, axis, dtype=jnp.int32):
|
|
54
|
+
def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32) -> jnp.ndarray:
|
|
41
55
|
literals = jnp.arange(shape[axis], dtype=dtype)
|
|
42
56
|
literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
|
|
43
57
|
literals = jnp.moveaxis(literals, source=0, destination=axis)
|
|
@@ -52,30 +66,35 @@ def enumerate_literals(shape, axis, dtype=jnp.int32):
|
|
|
52
66
|
#
|
|
53
67
|
# ===========================================================================
|
|
54
68
|
|
|
55
|
-
class Comparison:
|
|
69
|
+
class Comparison(metaclass=ABCMeta):
|
|
56
70
|
'''Base class for approximate comparison operations.'''
|
|
57
71
|
|
|
72
|
+
@abstractmethod
|
|
58
73
|
def greater_equal(self, id, init_params):
|
|
59
|
-
|
|
74
|
+
pass
|
|
60
75
|
|
|
76
|
+
@abstractmethod
|
|
61
77
|
def greater(self, id, init_params):
|
|
62
|
-
|
|
78
|
+
pass
|
|
63
79
|
|
|
80
|
+
@abstractmethod
|
|
64
81
|
def equal(self, id, init_params):
|
|
65
|
-
|
|
82
|
+
pass
|
|
66
83
|
|
|
84
|
+
@abstractmethod
|
|
67
85
|
def sgn(self, id, init_params):
|
|
68
|
-
|
|
86
|
+
pass
|
|
69
87
|
|
|
88
|
+
@abstractmethod
|
|
70
89
|
def argmax(self, id, init_params):
|
|
71
|
-
|
|
90
|
+
pass
|
|
72
91
|
|
|
73
92
|
|
|
74
93
|
class SigmoidComparison(Comparison):
|
|
75
94
|
'''Comparison operations approximated using sigmoid functions.'''
|
|
76
95
|
|
|
77
|
-
def __init__(self, weight: float=10.0):
|
|
78
|
-
self.weight = weight
|
|
96
|
+
def __init__(self, weight: float=10.0) -> None:
|
|
97
|
+
self.weight = float(weight)
|
|
79
98
|
|
|
80
99
|
# https://arxiv.org/abs/2110.05651
|
|
81
100
|
def greater_equal(self, id, init_params):
|
|
@@ -127,21 +146,23 @@ class SigmoidComparison(Comparison):
|
|
|
127
146
|
#
|
|
128
147
|
# ===========================================================================
|
|
129
148
|
|
|
130
|
-
class Rounding:
|
|
149
|
+
class Rounding(metaclass=ABCMeta):
|
|
131
150
|
'''Base class for approximate rounding operations.'''
|
|
132
151
|
|
|
152
|
+
@abstractmethod
|
|
133
153
|
def floor(self, id, init_params):
|
|
134
|
-
|
|
154
|
+
pass
|
|
135
155
|
|
|
156
|
+
@abstractmethod
|
|
136
157
|
def round(self, id, init_params):
|
|
137
|
-
|
|
158
|
+
pass
|
|
138
159
|
|
|
139
160
|
|
|
140
161
|
class SoftRounding(Rounding):
|
|
141
162
|
'''Rounding operations approximated using soft operations.'''
|
|
142
163
|
|
|
143
|
-
def __init__(self, weight: float=10.0):
|
|
144
|
-
self.weight = weight
|
|
164
|
+
def __init__(self, weight: float=10.0) -> None:
|
|
165
|
+
self.weight = float(weight)
|
|
145
166
|
|
|
146
167
|
# https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
|
|
147
168
|
def floor(self, id, init_params):
|
|
@@ -177,11 +198,12 @@ class SoftRounding(Rounding):
|
|
|
177
198
|
#
|
|
178
199
|
# ===========================================================================
|
|
179
200
|
|
|
180
|
-
class Complement:
|
|
201
|
+
class Complement(metaclass=ABCMeta):
|
|
181
202
|
'''Base class for approximate logical complement operations.'''
|
|
182
203
|
|
|
204
|
+
@abstractmethod
|
|
183
205
|
def __call__(self, id, init_params):
|
|
184
|
-
|
|
206
|
+
pass
|
|
185
207
|
|
|
186
208
|
|
|
187
209
|
class StandardComplement(Complement):
|
|
@@ -210,16 +232,18 @@ class StandardComplement(Complement):
|
|
|
210
232
|
# https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
|
|
211
233
|
# ===========================================================================
|
|
212
234
|
|
|
213
|
-
class TNorm:
|
|
235
|
+
class TNorm(metaclass=ABCMeta):
|
|
214
236
|
'''Base class for fuzzy differentiable t-norms.'''
|
|
215
237
|
|
|
238
|
+
@abstractmethod
|
|
216
239
|
def norm(self, id, init_params):
|
|
217
240
|
'''Elementwise t-norm of x and y.'''
|
|
218
|
-
|
|
241
|
+
pass
|
|
219
242
|
|
|
243
|
+
@abstractmethod
|
|
220
244
|
def norms(self, id, init_params):
|
|
221
245
|
'''T-norm computed for tensor x along axis.'''
|
|
222
|
-
|
|
246
|
+
pass
|
|
223
247
|
|
|
224
248
|
|
|
225
249
|
class ProductTNorm(TNorm):
|
|
@@ -291,7 +315,7 @@ class YagerTNorm(TNorm):
|
|
|
291
315
|
'''Yager t-norm given by the expression
|
|
292
316
|
(x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
|
|
293
317
|
|
|
294
|
-
def __init__(self, p=2.0):
|
|
318
|
+
def __init__(self, p: float=2.0) -> None:
|
|
295
319
|
self.p = float(p)
|
|
296
320
|
|
|
297
321
|
def norm(self, id, init_params):
|
|
@@ -327,23 +351,32 @@ class YagerTNorm(TNorm):
|
|
|
327
351
|
#
|
|
328
352
|
# ===========================================================================
|
|
329
353
|
|
|
330
|
-
class RandomSampling:
|
|
354
|
+
class RandomSampling(metaclass=ABCMeta):
|
|
331
355
|
'''Describes how non-reparameterizable random variables are sampled.'''
|
|
332
356
|
|
|
357
|
+
@abstractmethod
|
|
333
358
|
def discrete(self, id, init_params, logic):
|
|
334
|
-
|
|
359
|
+
pass
|
|
335
360
|
|
|
361
|
+
@abstractmethod
|
|
336
362
|
def poisson(self, id, init_params, logic):
|
|
337
|
-
|
|
363
|
+
pass
|
|
338
364
|
|
|
365
|
+
@abstractmethod
|
|
339
366
|
def binomial(self, id, init_params, logic):
|
|
340
|
-
|
|
367
|
+
pass
|
|
341
368
|
|
|
369
|
+
@abstractmethod
|
|
370
|
+
def negative_binomial(self, id, init_params, logic):
|
|
371
|
+
pass
|
|
372
|
+
|
|
373
|
+
@abstractmethod
|
|
342
374
|
def geometric(self, id, init_params, logic):
|
|
343
|
-
|
|
375
|
+
pass
|
|
344
376
|
|
|
377
|
+
@abstractmethod
|
|
345
378
|
def bernoulli(self, id, init_params, logic):
|
|
346
|
-
|
|
379
|
+
pass
|
|
347
380
|
|
|
348
381
|
def __str__(self) -> str:
|
|
349
382
|
return 'RandomSampling'
|
|
@@ -386,8 +419,7 @@ class SoftRandomSampling(RandomSampling):
|
|
|
386
419
|
def _poisson_gumbel_softmax(self, id, init_params, logic):
|
|
387
420
|
argmax_approx = logic.argmax(id, init_params)
|
|
388
421
|
def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
|
|
389
|
-
ks = jnp.arange(
|
|
390
|
-
ks = ks[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
|
|
422
|
+
ks = jnp.arange(self.poisson_bins)[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
|
|
391
423
|
rate = rate[..., jnp.newaxis]
|
|
392
424
|
log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
|
|
393
425
|
Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
|
|
@@ -400,10 +432,7 @@ class SoftRandomSampling(RandomSampling):
|
|
|
400
432
|
less_approx = logic.less(id, init_params)
|
|
401
433
|
def _jax_wrapped_calc_poisson_exponential(key, rate, params):
|
|
402
434
|
Exp1 = random.exponential(
|
|
403
|
-
key=key,
|
|
404
|
-
shape=(self.poisson_bins,) + jnp.shape(rate),
|
|
405
|
-
dtype=logic.REAL
|
|
406
|
-
)
|
|
435
|
+
key=key, shape=(self.poisson_bins,) + jnp.shape(rate), dtype=logic.REAL)
|
|
407
436
|
delta_t = Exp1 / rate[jnp.newaxis, ...]
|
|
408
437
|
times = jnp.cumsum(delta_t, axis=0)
|
|
409
438
|
indicator, params = less_approx(times, 1.0, params)
|
|
@@ -411,72 +440,98 @@ class SoftRandomSampling(RandomSampling):
|
|
|
411
440
|
return sample, params
|
|
412
441
|
return _jax_wrapped_calc_poisson_exponential
|
|
413
442
|
|
|
443
|
+
# normal approximation to Poisson: Poisson(rate) -> Normal(rate, rate)
|
|
444
|
+
def _poisson_normal_approx(self, logic):
|
|
445
|
+
def _jax_wrapped_calc_poisson_normal_approx(key, rate, params):
|
|
446
|
+
normal = random.normal(key=key, shape=jnp.shape(rate), dtype=logic.REAL)
|
|
447
|
+
sample = rate + jnp.sqrt(rate) * normal
|
|
448
|
+
return sample, params
|
|
449
|
+
return _jax_wrapped_calc_poisson_normal_approx
|
|
450
|
+
|
|
414
451
|
def poisson(self, id, init_params, logic):
|
|
415
|
-
def _jax_wrapped_calc_poisson_exact(key, rate, params):
|
|
416
|
-
sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
|
|
417
|
-
sample = jnp.asarray(sample, dtype=logic.REAL)
|
|
418
|
-
return sample, params
|
|
419
|
-
|
|
420
452
|
if self.poisson_exp_method:
|
|
421
453
|
_jax_wrapped_calc_poisson_diff = self._poisson_exponential(
|
|
422
454
|
id, init_params, logic)
|
|
423
455
|
else:
|
|
424
456
|
_jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
|
|
425
457
|
id, init_params, logic)
|
|
458
|
+
_jax_wrapped_calc_poisson_normal = self._poisson_normal_approx(logic)
|
|
426
459
|
|
|
460
|
+
# for small rate use the Poisson process or gumbel-softmax reparameterization
|
|
461
|
+
# for large rate use the normal approximation
|
|
427
462
|
def _jax_wrapped_calc_poisson_approx(key, rate, params):
|
|
428
|
-
|
|
429
|
-
# determine if error of truncation at rate is acceptable
|
|
430
463
|
if self.poisson_bins > 0:
|
|
431
464
|
cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
|
|
432
|
-
|
|
433
|
-
|
|
465
|
+
small_rate = jax.lax.stop_gradient(cuml_prob >= self.poisson_min_cdf)
|
|
466
|
+
small_sample, params = _jax_wrapped_calc_poisson_diff(key, rate, params)
|
|
467
|
+
large_sample, params = _jax_wrapped_calc_poisson_normal(key, rate, params)
|
|
468
|
+
sample = jnp.where(small_rate, small_sample, large_sample)
|
|
469
|
+
return sample, params
|
|
434
470
|
else:
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
# for acceptable truncation use the approximation, use exact otherwise
|
|
438
|
-
return jax.lax.cond(
|
|
439
|
-
approx_cond,
|
|
440
|
-
_jax_wrapped_calc_poisson_diff,
|
|
441
|
-
_jax_wrapped_calc_poisson_exact,
|
|
442
|
-
key, rate, params
|
|
443
|
-
)
|
|
471
|
+
return _jax_wrapped_calc_poisson_normal(key, rate, params)
|
|
444
472
|
return _jax_wrapped_calc_poisson_approx
|
|
445
473
|
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
474
|
+
# normal approximation to Binomial: Bin(n, p) -> Normal(np, np(1-p))
|
|
475
|
+
def _binomial_normal_approx(self, logic):
|
|
476
|
+
def _jax_wrapped_calc_binomial_normal_approx(key, trials, prob, params):
|
|
477
|
+
normal = random.normal(key=key, shape=jnp.shape(trials), dtype=logic.REAL)
|
|
478
|
+
mean = trials * prob
|
|
479
|
+
std = jnp.sqrt(trials * prob * (1.0 - prob))
|
|
480
|
+
sample = mean + std * normal
|
|
481
|
+
return sample, params
|
|
482
|
+
return _jax_wrapped_calc_binomial_normal_approx
|
|
452
483
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
def
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
484
|
+
def _binomial_gumbel_softmax(self, id, init_params, logic):
|
|
485
|
+
argmax_approx = logic.argmax(id, init_params)
|
|
486
|
+
def _jax_wrapped_calc_binomial_gumbel_softmax(key, trials, prob, params):
|
|
487
|
+
ks = jnp.arange(self.binomial_bins)[(jnp.newaxis,) * jnp.ndim(trials) + (...,)]
|
|
488
|
+
trials = trials[..., jnp.newaxis]
|
|
489
|
+
prob = prob[..., jnp.newaxis]
|
|
490
|
+
in_support = ks <= trials
|
|
491
|
+
ks = jnp.minimum(ks, trials)
|
|
492
|
+
log_prob = ((scipy.special.gammaln(trials + 1) -
|
|
493
|
+
scipy.special.gammaln(ks + 1) -
|
|
494
|
+
scipy.special.gammaln(trials - ks + 1)) +
|
|
495
|
+
ks * jnp.log(prob + logic.eps) +
|
|
496
|
+
(trials - ks) * jnp.log1p(-prob + logic.eps))
|
|
497
|
+
log_prob = jnp.where(in_support, log_prob, jnp.log(logic.eps))
|
|
498
|
+
Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
|
|
499
|
+
sample = Gumbel01 + log_prob
|
|
500
|
+
return argmax_approx(sample, axis=-1, params=params)
|
|
501
|
+
return _jax_wrapped_calc_binomial_gumbel_softmax
|
|
502
|
+
|
|
503
|
+
def binomial(self, id, init_params, logic):
|
|
504
|
+
_jax_wrapped_calc_binomial_normal = self._binomial_normal_approx(logic)
|
|
505
|
+
_jax_wrapped_calc_binomial_gs = self._binomial_gumbel_softmax(id, init_params, logic)
|
|
464
506
|
|
|
465
|
-
# for trials
|
|
507
|
+
# for small trials use the Bernoulli relaxation
|
|
508
|
+
# for large trials use the normal approximation
|
|
466
509
|
def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
)
|
|
510
|
+
small_trials = jax.lax.stop_gradient(trials < self.binomial_bins)
|
|
511
|
+
small_sample, params = _jax_wrapped_calc_binomial_gs(key, trials, prob, params)
|
|
512
|
+
large_sample, params = _jax_wrapped_calc_binomial_normal(key, trials, prob, params)
|
|
513
|
+
sample = jnp.where(small_trials, small_sample, large_sample)
|
|
514
|
+
return sample, params
|
|
473
515
|
return _jax_wrapped_calc_binomial_approx
|
|
474
516
|
|
|
517
|
+
# https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
|
|
518
|
+
def negative_binomial(self, id, init_params, logic):
|
|
519
|
+
poisson_approx = self.poisson(id, init_params, logic)
|
|
520
|
+
def _jax_wrapped_calc_negative_binomial_approx(key, trials, prob, params):
|
|
521
|
+
key, subkey = random.split(key)
|
|
522
|
+
trials = jnp.asarray(trials, dtype=logic.REAL)
|
|
523
|
+
Gamma = random.gamma(key=key, a=trials, dtype=logic.REAL)
|
|
524
|
+
scale = (1.0 - prob) / prob
|
|
525
|
+
poisson_rate = scale * Gamma
|
|
526
|
+
return poisson_approx(subkey, poisson_rate, params)
|
|
527
|
+
return _jax_wrapped_calc_negative_binomial_approx
|
|
528
|
+
|
|
475
529
|
def geometric(self, id, init_params, logic):
|
|
476
530
|
approx_floor = logic.floor(id, init_params)
|
|
477
531
|
def _jax_wrapped_calc_geometric_approx(key, prob, params):
|
|
478
532
|
U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
|
|
479
|
-
floor, params = approx_floor(
|
|
533
|
+
floor, params = approx_floor(
|
|
534
|
+
jnp.log1p(-U) / jnp.log1p(-prob + logic.eps), params)
|
|
480
535
|
sample = floor + 1
|
|
481
536
|
return sample, params
|
|
482
537
|
return _jax_wrapped_calc_geometric_approx
|
|
@@ -532,6 +587,14 @@ class Determinization(RandomSampling):
|
|
|
532
587
|
def binomial(self, id, init_params, logic):
|
|
533
588
|
return self._jax_wrapped_calc_binomial_determinized
|
|
534
589
|
|
|
590
|
+
@staticmethod
|
|
591
|
+
def _jax_wrapped_calc_negative_binomial_determinized(key, trials, prob, params):
|
|
592
|
+
sample = trials * ((1.0 / prob) - 1.0)
|
|
593
|
+
return sample, params
|
|
594
|
+
|
|
595
|
+
def negative_binomial(self, id, init_params, logic):
|
|
596
|
+
return self._jax_wrapped_calc_negative_binomial_determinized
|
|
597
|
+
|
|
535
598
|
@staticmethod
|
|
536
599
|
def _jax_wrapped_calc_geometric_determinized(key, prob, params):
|
|
537
600
|
sample = 1.0 / prob
|
|
@@ -558,21 +621,23 @@ class Determinization(RandomSampling):
|
|
|
558
621
|
#
|
|
559
622
|
# ===========================================================================
|
|
560
623
|
|
|
561
|
-
class ControlFlow:
|
|
624
|
+
class ControlFlow(metaclass=ABCMeta):
|
|
562
625
|
'''A base class for control flow, including if and switch statements.'''
|
|
563
626
|
|
|
627
|
+
@abstractmethod
|
|
564
628
|
def if_then_else(self, id, init_params):
|
|
565
|
-
|
|
629
|
+
pass
|
|
566
630
|
|
|
631
|
+
@abstractmethod
|
|
567
632
|
def switch(self, id, init_params):
|
|
568
|
-
|
|
633
|
+
pass
|
|
569
634
|
|
|
570
635
|
|
|
571
636
|
class SoftControlFlow(ControlFlow):
|
|
572
637
|
'''Soft control flow using a probabilistic interpretation.'''
|
|
573
638
|
|
|
574
639
|
def __init__(self, weight: float=10.0) -> None:
|
|
575
|
-
self.weight = weight
|
|
640
|
+
self.weight = float(weight)
|
|
576
641
|
|
|
577
642
|
@staticmethod
|
|
578
643
|
def _jax_wrapped_calc_if_then_else_soft(c, a, b, params):
|
|
@@ -606,15 +671,15 @@ class SoftControlFlow(ControlFlow):
|
|
|
606
671
|
# ===========================================================================
|
|
607
672
|
|
|
608
673
|
|
|
609
|
-
class Logic:
|
|
674
|
+
class Logic(metaclass=ABCMeta):
|
|
610
675
|
'''A base class for representing logic computations in JAX.'''
|
|
611
676
|
|
|
612
677
|
def __init__(self, use64bit: bool=False) -> None:
|
|
613
678
|
self.set_use64bit(use64bit)
|
|
614
679
|
|
|
615
|
-
def summarize_hyperparameters(self) ->
|
|
616
|
-
|
|
617
|
-
|
|
680
|
+
def summarize_hyperparameters(self) -> str:
|
|
681
|
+
return (f'model relaxation:\n'
|
|
682
|
+
f' use_64_bit ={self.use64bit}')
|
|
618
683
|
|
|
619
684
|
def set_use64bit(self, use64bit: bool) -> None:
|
|
620
685
|
'''Toggles whether or not the JAX system will use 64 bit precision.'''
|
|
@@ -712,123 +777,158 @@ class Logic:
|
|
|
712
777
|
'Discrete': self.discrete,
|
|
713
778
|
'Poisson': self.poisson,
|
|
714
779
|
'Geometric': self.geometric,
|
|
715
|
-
'Binomial': self.binomial
|
|
780
|
+
'Binomial': self.binomial,
|
|
781
|
+
'NegativeBinomial': self.negative_binomial
|
|
716
782
|
}
|
|
717
783
|
}
|
|
718
784
|
|
|
719
785
|
# ===========================================================================
|
|
720
786
|
# logical operators
|
|
721
787
|
# ===========================================================================
|
|
722
|
-
|
|
788
|
+
|
|
789
|
+
@abstractmethod
|
|
723
790
|
def logical_and(self, id, init_params):
|
|
724
|
-
|
|
791
|
+
pass
|
|
725
792
|
|
|
793
|
+
@abstractmethod
|
|
726
794
|
def logical_not(self, id, init_params):
|
|
727
|
-
|
|
795
|
+
pass
|
|
728
796
|
|
|
797
|
+
@abstractmethod
|
|
729
798
|
def logical_or(self, id, init_params):
|
|
730
|
-
|
|
799
|
+
pass
|
|
731
800
|
|
|
801
|
+
@abstractmethod
|
|
732
802
|
def xor(self, id, init_params):
|
|
733
|
-
|
|
803
|
+
pass
|
|
734
804
|
|
|
805
|
+
@abstractmethod
|
|
735
806
|
def implies(self, id, init_params):
|
|
736
|
-
|
|
807
|
+
pass
|
|
737
808
|
|
|
809
|
+
@abstractmethod
|
|
738
810
|
def equiv(self, id, init_params):
|
|
739
|
-
|
|
811
|
+
pass
|
|
740
812
|
|
|
813
|
+
@abstractmethod
|
|
741
814
|
def forall(self, id, init_params):
|
|
742
|
-
|
|
815
|
+
pass
|
|
743
816
|
|
|
817
|
+
@abstractmethod
|
|
744
818
|
def exists(self, id, init_params):
|
|
745
|
-
|
|
819
|
+
pass
|
|
746
820
|
|
|
747
821
|
# ===========================================================================
|
|
748
822
|
# comparison operators
|
|
749
823
|
# ===========================================================================
|
|
750
824
|
|
|
825
|
+
@abstractmethod
|
|
751
826
|
def greater_equal(self, id, init_params):
|
|
752
|
-
|
|
827
|
+
pass
|
|
753
828
|
|
|
829
|
+
@abstractmethod
|
|
754
830
|
def greater(self, id, init_params):
|
|
755
|
-
|
|
831
|
+
pass
|
|
756
832
|
|
|
833
|
+
@abstractmethod
|
|
757
834
|
def less_equal(self, id, init_params):
|
|
758
|
-
|
|
835
|
+
pass
|
|
759
836
|
|
|
837
|
+
@abstractmethod
|
|
760
838
|
def less(self, id, init_params):
|
|
761
|
-
|
|
839
|
+
pass
|
|
762
840
|
|
|
841
|
+
@abstractmethod
|
|
763
842
|
def equal(self, id, init_params):
|
|
764
|
-
|
|
843
|
+
pass
|
|
765
844
|
|
|
845
|
+
@abstractmethod
|
|
766
846
|
def not_equal(self, id, init_params):
|
|
767
|
-
|
|
847
|
+
pass
|
|
768
848
|
|
|
769
849
|
# ===========================================================================
|
|
770
850
|
# special functions
|
|
771
851
|
# ===========================================================================
|
|
772
852
|
|
|
853
|
+
@abstractmethod
|
|
773
854
|
def sgn(self, id, init_params):
|
|
774
|
-
|
|
855
|
+
pass
|
|
775
856
|
|
|
857
|
+
@abstractmethod
|
|
776
858
|
def floor(self, id, init_params):
|
|
777
|
-
|
|
859
|
+
pass
|
|
778
860
|
|
|
861
|
+
@abstractmethod
|
|
779
862
|
def round(self, id, init_params):
|
|
780
|
-
|
|
863
|
+
pass
|
|
781
864
|
|
|
865
|
+
@abstractmethod
|
|
782
866
|
def ceil(self, id, init_params):
|
|
783
|
-
|
|
867
|
+
pass
|
|
784
868
|
|
|
869
|
+
@abstractmethod
|
|
785
870
|
def div(self, id, init_params):
|
|
786
|
-
|
|
871
|
+
pass
|
|
787
872
|
|
|
873
|
+
@abstractmethod
|
|
788
874
|
def mod(self, id, init_params):
|
|
789
|
-
|
|
875
|
+
pass
|
|
790
876
|
|
|
877
|
+
@abstractmethod
|
|
791
878
|
def sqrt(self, id, init_params):
|
|
792
|
-
|
|
879
|
+
pass
|
|
793
880
|
|
|
794
881
|
# ===========================================================================
|
|
795
882
|
# indexing
|
|
796
883
|
# ===========================================================================
|
|
797
|
-
|
|
884
|
+
|
|
885
|
+
@abstractmethod
|
|
798
886
|
def argmax(self, id, init_params):
|
|
799
|
-
|
|
887
|
+
pass
|
|
800
888
|
|
|
889
|
+
@abstractmethod
|
|
801
890
|
def argmin(self, id, init_params):
|
|
802
|
-
|
|
891
|
+
pass
|
|
803
892
|
|
|
804
893
|
# ===========================================================================
|
|
805
894
|
# control flow
|
|
806
895
|
# ===========================================================================
|
|
807
896
|
|
|
897
|
+
@abstractmethod
|
|
808
898
|
def control_if(self, id, init_params):
|
|
809
|
-
|
|
899
|
+
pass
|
|
810
900
|
|
|
901
|
+
@abstractmethod
|
|
811
902
|
def control_switch(self, id, init_params):
|
|
812
|
-
|
|
903
|
+
pass
|
|
813
904
|
|
|
814
905
|
# ===========================================================================
|
|
815
906
|
# random variables
|
|
816
907
|
# ===========================================================================
|
|
817
908
|
|
|
909
|
+
@abstractmethod
|
|
818
910
|
def discrete(self, id, init_params):
|
|
819
|
-
|
|
911
|
+
pass
|
|
820
912
|
|
|
913
|
+
@abstractmethod
|
|
821
914
|
def bernoulli(self, id, init_params):
|
|
822
|
-
|
|
915
|
+
pass
|
|
823
916
|
|
|
917
|
+
@abstractmethod
|
|
824
918
|
def poisson(self, id, init_params):
|
|
825
|
-
|
|
919
|
+
pass
|
|
826
920
|
|
|
921
|
+
@abstractmethod
|
|
827
922
|
def geometric(self, id, init_params):
|
|
828
|
-
|
|
923
|
+
pass
|
|
829
924
|
|
|
925
|
+
@abstractmethod
|
|
830
926
|
def binomial(self, id, init_params):
|
|
831
|
-
|
|
927
|
+
pass
|
|
928
|
+
|
|
929
|
+
@abstractmethod
|
|
930
|
+
def negative_binomial(self, id, init_params):
|
|
931
|
+
pass
|
|
832
932
|
|
|
833
933
|
|
|
834
934
|
class ExactLogic(Logic):
|
|
@@ -1005,6 +1105,17 @@ class ExactLogic(Logic):
|
|
|
1005
1105
|
sample = jnp.asarray(sample, dtype=self.INT)
|
|
1006
1106
|
return sample, params
|
|
1007
1107
|
return _jax_wrapped_calc_binomial_exact
|
|
1108
|
+
|
|
1109
|
+
# note: for some reason tfp defines it as number of successes before trials failures
|
|
1110
|
+
# I will define it as the number of failures before trials successes
|
|
1111
|
+
def negative_binomial(self, id, init_params):
|
|
1112
|
+
def _jax_wrapped_calc_negative_binomial_exact(key, trials, prob, params):
|
|
1113
|
+
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1114
|
+
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1115
|
+
dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=1.0 - prob)
|
|
1116
|
+
sample = jnp.asarray(dist.sample(seed=key), dtype=self.INT)
|
|
1117
|
+
return sample, params
|
|
1118
|
+
return _jax_wrapped_calc_negative_binomial_exact
|
|
1008
1119
|
|
|
1009
1120
|
|
|
1010
1121
|
class FuzzyLogic(Logic):
|
|
@@ -1049,8 +1160,8 @@ class FuzzyLogic(Logic):
|
|
|
1049
1160
|
f' underflow_tol={self.eps}\n'
|
|
1050
1161
|
f' use_64_bit ={self.use64bit}\n')
|
|
1051
1162
|
|
|
1052
|
-
def summarize_hyperparameters(self) ->
|
|
1053
|
-
|
|
1163
|
+
def summarize_hyperparameters(self) -> str:
|
|
1164
|
+
return self.__str__()
|
|
1054
1165
|
|
|
1055
1166
|
# ===========================================================================
|
|
1056
1167
|
# logical operators
|
|
@@ -1234,6 +1345,9 @@ class FuzzyLogic(Logic):
|
|
|
1234
1345
|
|
|
1235
1346
|
def binomial(self, id, init_params):
|
|
1236
1347
|
return self.sampling.binomial(id, init_params, self)
|
|
1348
|
+
|
|
1349
|
+
def negative_binomial(self, id, init_params):
|
|
1350
|
+
return self.sampling.negative_binomial(id, init_params, self)
|
|
1237
1351
|
|
|
1238
1352
|
|
|
1239
1353
|
# ===========================================================================
|