pyRDDLGym-jax 2.2__py3-none-any.whl → 2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,15 +29,29 @@
29
29
  #
30
30
  # ***********************************************************************
31
31
 
32
- from typing import Callable, Dict, Union
32
+
33
+ from abc import ABCMeta, abstractmethod
34
+ import traceback
35
+ from typing import Callable, Dict, Tuple, Union
33
36
 
34
37
  import jax
35
38
  import jax.numpy as jnp
36
39
  import jax.random as random
37
40
  import jax.scipy as scipy
38
41
 
42
+ from pyRDDLGym.core.debug.exception import raise_warning
43
+
44
+ # more robust approach - if user does not have this or broken try to continue
45
+ try:
46
+ from tensorflow_probability.substrates import jax as tfp
47
+ except Exception:
48
+ raise_warning('Failed to import tensorflow-probability: '
49
+ 'compilation of some probability distributions will fail.', 'red')
50
+ traceback.print_exc()
51
+ tfp = None
52
+
39
53
 
40
- def enumerate_literals(shape, axis, dtype=jnp.int32):
54
+ def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32) -> jnp.ndarray:
41
55
  literals = jnp.arange(shape[axis], dtype=dtype)
42
56
  literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
43
57
  literals = jnp.moveaxis(literals, source=0, destination=axis)
@@ -52,30 +66,35 @@ def enumerate_literals(shape, axis, dtype=jnp.int32):
52
66
  #
53
67
  # ===========================================================================
54
68
 
55
- class Comparison:
69
+ class Comparison(metaclass=ABCMeta):
56
70
  '''Base class for approximate comparison operations.'''
57
71
 
72
+ @abstractmethod
58
73
  def greater_equal(self, id, init_params):
59
- raise NotImplementedError
74
+ pass
60
75
 
76
+ @abstractmethod
61
77
  def greater(self, id, init_params):
62
- raise NotImplementedError
78
+ pass
63
79
 
80
+ @abstractmethod
64
81
  def equal(self, id, init_params):
65
- raise NotImplementedError
82
+ pass
66
83
 
84
+ @abstractmethod
67
85
  def sgn(self, id, init_params):
68
- raise NotImplementedError
86
+ pass
69
87
 
88
+ @abstractmethod
70
89
  def argmax(self, id, init_params):
71
- raise NotImplementedError
90
+ pass
72
91
 
73
92
 
74
93
  class SigmoidComparison(Comparison):
75
94
  '''Comparison operations approximated using sigmoid functions.'''
76
95
 
77
- def __init__(self, weight: float=10.0):
78
- self.weight = weight
96
+ def __init__(self, weight: float=10.0) -> None:
97
+ self.weight = float(weight)
79
98
 
80
99
  # https://arxiv.org/abs/2110.05651
81
100
  def greater_equal(self, id, init_params):
@@ -127,21 +146,23 @@ class SigmoidComparison(Comparison):
127
146
  #
128
147
  # ===========================================================================
129
148
 
130
- class Rounding:
149
+ class Rounding(metaclass=ABCMeta):
131
150
  '''Base class for approximate rounding operations.'''
132
151
 
152
+ @abstractmethod
133
153
  def floor(self, id, init_params):
134
- raise NotImplementedError
154
+ pass
135
155
 
156
+ @abstractmethod
136
157
  def round(self, id, init_params):
137
- raise NotImplementedError
158
+ pass
138
159
 
139
160
 
140
161
  class SoftRounding(Rounding):
141
162
  '''Rounding operations approximated using soft operations.'''
142
163
 
143
- def __init__(self, weight: float=10.0):
144
- self.weight = weight
164
+ def __init__(self, weight: float=10.0) -> None:
165
+ self.weight = float(weight)
145
166
 
146
167
  # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
147
168
  def floor(self, id, init_params):
@@ -177,11 +198,12 @@ class SoftRounding(Rounding):
177
198
  #
178
199
  # ===========================================================================
179
200
 
180
- class Complement:
201
+ class Complement(metaclass=ABCMeta):
181
202
  '''Base class for approximate logical complement operations.'''
182
203
 
204
+ @abstractmethod
183
205
  def __call__(self, id, init_params):
184
- raise NotImplementedError
206
+ pass
185
207
 
186
208
 
187
209
  class StandardComplement(Complement):
@@ -210,16 +232,18 @@ class StandardComplement(Complement):
210
232
  # https://www.sciencedirect.com/science/article/abs/pii/016501149190171L
211
233
  # ===========================================================================
212
234
 
213
- class TNorm:
235
+ class TNorm(metaclass=ABCMeta):
214
236
  '''Base class for fuzzy differentiable t-norms.'''
215
237
 
238
+ @abstractmethod
216
239
  def norm(self, id, init_params):
217
240
  '''Elementwise t-norm of x and y.'''
218
- raise NotImplementedError
241
+ pass
219
242
 
243
+ @abstractmethod
220
244
  def norms(self, id, init_params):
221
245
  '''T-norm computed for tensor x along axis.'''
222
- raise NotImplementedError
246
+ pass
223
247
 
224
248
 
225
249
  class ProductTNorm(TNorm):
@@ -291,7 +315,7 @@ class YagerTNorm(TNorm):
291
315
  '''Yager t-norm given by the expression
292
316
  (x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
293
317
 
294
- def __init__(self, p=2.0):
318
+ def __init__(self, p: float=2.0) -> None:
295
319
  self.p = float(p)
296
320
 
297
321
  def norm(self, id, init_params):
@@ -327,23 +351,32 @@ class YagerTNorm(TNorm):
327
351
  #
328
352
  # ===========================================================================
329
353
 
330
- class RandomSampling:
354
+ class RandomSampling(metaclass=ABCMeta):
331
355
  '''Describes how non-reparameterizable random variables are sampled.'''
332
356
 
357
+ @abstractmethod
333
358
  def discrete(self, id, init_params, logic):
334
- raise NotImplementedError
359
+ pass
335
360
 
361
+ @abstractmethod
336
362
  def poisson(self, id, init_params, logic):
337
- raise NotImplementedError
363
+ pass
338
364
 
365
+ @abstractmethod
339
366
  def binomial(self, id, init_params, logic):
340
- raise NotImplementedError
367
+ pass
341
368
 
369
+ @abstractmethod
370
+ def negative_binomial(self, id, init_params, logic):
371
+ pass
372
+
373
+ @abstractmethod
342
374
  def geometric(self, id, init_params, logic):
343
- raise NotImplementedError
375
+ pass
344
376
 
377
+ @abstractmethod
345
378
  def bernoulli(self, id, init_params, logic):
346
- raise NotImplementedError
379
+ pass
347
380
 
348
381
  def __str__(self) -> str:
349
382
  return 'RandomSampling'
@@ -386,8 +419,7 @@ class SoftRandomSampling(RandomSampling):
386
419
  def _poisson_gumbel_softmax(self, id, init_params, logic):
387
420
  argmax_approx = logic.argmax(id, init_params)
388
421
  def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
389
- ks = jnp.arange(0, self.poisson_bins)
390
- ks = ks[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
422
+ ks = jnp.arange(self.poisson_bins)[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
391
423
  rate = rate[..., jnp.newaxis]
392
424
  log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
393
425
  Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
@@ -400,10 +432,7 @@ class SoftRandomSampling(RandomSampling):
400
432
  less_approx = logic.less(id, init_params)
401
433
  def _jax_wrapped_calc_poisson_exponential(key, rate, params):
402
434
  Exp1 = random.exponential(
403
- key=key,
404
- shape=(self.poisson_bins,) + jnp.shape(rate),
405
- dtype=logic.REAL
406
- )
435
+ key=key, shape=(self.poisson_bins,) + jnp.shape(rate), dtype=logic.REAL)
407
436
  delta_t = Exp1 / rate[jnp.newaxis, ...]
408
437
  times = jnp.cumsum(delta_t, axis=0)
409
438
  indicator, params = less_approx(times, 1.0, params)
@@ -411,72 +440,98 @@ class SoftRandomSampling(RandomSampling):
411
440
  return sample, params
412
441
  return _jax_wrapped_calc_poisson_exponential
413
442
 
443
+ # normal approximation to Poisson: Poisson(rate) -> Normal(rate, rate)
444
+ def _poisson_normal_approx(self, logic):
445
+ def _jax_wrapped_calc_poisson_normal_approx(key, rate, params):
446
+ normal = random.normal(key=key, shape=jnp.shape(rate), dtype=logic.REAL)
447
+ sample = rate + jnp.sqrt(rate) * normal
448
+ return sample, params
449
+ return _jax_wrapped_calc_poisson_normal_approx
450
+
414
451
  def poisson(self, id, init_params, logic):
415
- def _jax_wrapped_calc_poisson_exact(key, rate, params):
416
- sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
417
- sample = jnp.asarray(sample, dtype=logic.REAL)
418
- return sample, params
419
-
420
452
  if self.poisson_exp_method:
421
453
  _jax_wrapped_calc_poisson_diff = self._poisson_exponential(
422
454
  id, init_params, logic)
423
455
  else:
424
456
  _jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
425
457
  id, init_params, logic)
458
+ _jax_wrapped_calc_poisson_normal = self._poisson_normal_approx(logic)
426
459
 
460
+ # for small rate use the Poisson process or gumbel-softmax reparameterization
461
+ # for large rate use the normal approximation
427
462
  def _jax_wrapped_calc_poisson_approx(key, rate, params):
428
-
429
- # determine if error of truncation at rate is acceptable
430
463
  if self.poisson_bins > 0:
431
464
  cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
432
- approx_cond = jax.lax.stop_gradient(
433
- jnp.min(cuml_prob) > self.poisson_min_cdf)
465
+ small_rate = jax.lax.stop_gradient(cuml_prob >= self.poisson_min_cdf)
466
+ small_sample, params = _jax_wrapped_calc_poisson_diff(key, rate, params)
467
+ large_sample, params = _jax_wrapped_calc_poisson_normal(key, rate, params)
468
+ sample = jnp.where(small_rate, small_sample, large_sample)
469
+ return sample, params
434
470
  else:
435
- approx_cond = False
436
-
437
- # for acceptable truncation use the approximation, use exact otherwise
438
- return jax.lax.cond(
439
- approx_cond,
440
- _jax_wrapped_calc_poisson_diff,
441
- _jax_wrapped_calc_poisson_exact,
442
- key, rate, params
443
- )
471
+ return _jax_wrapped_calc_poisson_normal(key, rate, params)
444
472
  return _jax_wrapped_calc_poisson_approx
445
473
 
446
- def binomial(self, id, init_params, logic):
447
- def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
448
- trials = jnp.asarray(trials, dtype=logic.REAL)
449
- prob = jnp.asarray(prob, dtype=logic.REAL)
450
- sample = random.binomial(key=key, n=trials, p=prob, dtype=logic.REAL)
451
- return sample, params
474
+ # normal approximation to Binomial: Bin(n, p) -> Normal(np, np(1-p))
475
+ def _binomial_normal_approx(self, logic):
476
+ def _jax_wrapped_calc_binomial_normal_approx(key, trials, prob, params):
477
+ normal = random.normal(key=key, shape=jnp.shape(trials), dtype=logic.REAL)
478
+ mean = trials * prob
479
+ std = jnp.sqrt(trials * prob * (1.0 - prob))
480
+ sample = mean + std * normal
481
+ return sample, params
482
+ return _jax_wrapped_calc_binomial_normal_approx
452
483
 
453
- # Binomial(n, p) = sum_{i = 1 ... n} Bernoulli(p)
454
- bernoulli_approx = self.bernoulli(id, init_params, logic)
455
- def _jax_wrapped_calc_binomial_sum(key, trials, prob, params):
456
- prob_full = jnp.broadcast_to(
457
- prob[..., jnp.newaxis], shape=jnp.shape(prob) + (self.binomial_bins,))
458
- sample_bern, params = bernoulli_approx(key, prob_full, params)
459
- indices = jnp.arange(self.binomial_bins)[
460
- (jnp.newaxis,) * jnp.ndim(prob) + (...,)]
461
- mask = indices < trials[..., jnp.newaxis]
462
- sample = jnp.sum(sample_bern * mask, axis=-1)
463
- return sample, params
484
+ def _binomial_gumbel_softmax(self, id, init_params, logic):
485
+ argmax_approx = logic.argmax(id, init_params)
486
+ def _jax_wrapped_calc_binomial_gumbel_softmax(key, trials, prob, params):
487
+ ks = jnp.arange(self.binomial_bins)[(jnp.newaxis,) * jnp.ndim(trials) + (...,)]
488
+ trials = trials[..., jnp.newaxis]
489
+ prob = prob[..., jnp.newaxis]
490
+ in_support = ks <= trials
491
+ ks = jnp.minimum(ks, trials)
492
+ log_prob = ((scipy.special.gammaln(trials + 1) -
493
+ scipy.special.gammaln(ks + 1) -
494
+ scipy.special.gammaln(trials - ks + 1)) +
495
+ ks * jnp.log(prob + logic.eps) +
496
+ (trials - ks) * jnp.log1p(-prob + logic.eps))
497
+ log_prob = jnp.where(in_support, log_prob, jnp.log(logic.eps))
498
+ Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
499
+ sample = Gumbel01 + log_prob
500
+ return argmax_approx(sample, axis=-1, params=params)
501
+ return _jax_wrapped_calc_binomial_gumbel_softmax
502
+
503
+ def binomial(self, id, init_params, logic):
504
+ _jax_wrapped_calc_binomial_normal = self._binomial_normal_approx(logic)
505
+ _jax_wrapped_calc_binomial_gs = self._binomial_gumbel_softmax(id, init_params, logic)
464
506
 
465
- # for trials not too large use the Bernoulli relaxation, use exact otherwise
507
+ # for small trials use the Bernoulli relaxation
508
+ # for large trials use the normal approximation
466
509
  def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
467
- return jax.lax.cond(
468
- jax.lax.stop_gradient(jnp.max(trials) < self.binomial_bins),
469
- _jax_wrapped_calc_binomial_sum,
470
- _jax_wrapped_calc_binomial_exact,
471
- key, trials, prob, params
472
- )
510
+ small_trials = jax.lax.stop_gradient(trials < self.binomial_bins)
511
+ small_sample, params = _jax_wrapped_calc_binomial_gs(key, trials, prob, params)
512
+ large_sample, params = _jax_wrapped_calc_binomial_normal(key, trials, prob, params)
513
+ sample = jnp.where(small_trials, small_sample, large_sample)
514
+ return sample, params
473
515
  return _jax_wrapped_calc_binomial_approx
474
516
 
517
+ # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
518
+ def negative_binomial(self, id, init_params, logic):
519
+ poisson_approx = self.poisson(id, init_params, logic)
520
+ def _jax_wrapped_calc_negative_binomial_approx(key, trials, prob, params):
521
+ key, subkey = random.split(key)
522
+ trials = jnp.asarray(trials, dtype=logic.REAL)
523
+ Gamma = random.gamma(key=key, a=trials, dtype=logic.REAL)
524
+ scale = (1.0 - prob) / prob
525
+ poisson_rate = scale * Gamma
526
+ return poisson_approx(subkey, poisson_rate, params)
527
+ return _jax_wrapped_calc_negative_binomial_approx
528
+
475
529
  def geometric(self, id, init_params, logic):
476
530
  approx_floor = logic.floor(id, init_params)
477
531
  def _jax_wrapped_calc_geometric_approx(key, prob, params):
478
532
  U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
479
- floor, params = approx_floor(jnp.log1p(-U) / jnp.log1p(-prob), params)
533
+ floor, params = approx_floor(
534
+ jnp.log1p(-U) / jnp.log1p(-prob + logic.eps), params)
480
535
  sample = floor + 1
481
536
  return sample, params
482
537
  return _jax_wrapped_calc_geometric_approx
@@ -532,6 +587,14 @@ class Determinization(RandomSampling):
532
587
  def binomial(self, id, init_params, logic):
533
588
  return self._jax_wrapped_calc_binomial_determinized
534
589
 
590
+ @staticmethod
591
+ def _jax_wrapped_calc_negative_binomial_determinized(key, trials, prob, params):
592
+ sample = trials * ((1.0 / prob) - 1.0)
593
+ return sample, params
594
+
595
+ def negative_binomial(self, id, init_params, logic):
596
+ return self._jax_wrapped_calc_negative_binomial_determinized
597
+
535
598
  @staticmethod
536
599
  def _jax_wrapped_calc_geometric_determinized(key, prob, params):
537
600
  sample = 1.0 / prob
@@ -558,21 +621,23 @@ class Determinization(RandomSampling):
558
621
  #
559
622
  # ===========================================================================
560
623
 
561
- class ControlFlow:
624
+ class ControlFlow(metaclass=ABCMeta):
562
625
  '''A base class for control flow, including if and switch statements.'''
563
626
 
627
+ @abstractmethod
564
628
  def if_then_else(self, id, init_params):
565
- raise NotImplementedError
629
+ pass
566
630
 
631
+ @abstractmethod
567
632
  def switch(self, id, init_params):
568
- raise NotImplementedError
633
+ pass
569
634
 
570
635
 
571
636
  class SoftControlFlow(ControlFlow):
572
637
  '''Soft control flow using a probabilistic interpretation.'''
573
638
 
574
639
  def __init__(self, weight: float=10.0) -> None:
575
- self.weight = weight
640
+ self.weight = float(weight)
576
641
 
577
642
  @staticmethod
578
643
  def _jax_wrapped_calc_if_then_else_soft(c, a, b, params):
@@ -606,15 +671,15 @@ class SoftControlFlow(ControlFlow):
606
671
  # ===========================================================================
607
672
 
608
673
 
609
- class Logic:
674
+ class Logic(metaclass=ABCMeta):
610
675
  '''A base class for representing logic computations in JAX.'''
611
676
 
612
677
  def __init__(self, use64bit: bool=False) -> None:
613
678
  self.set_use64bit(use64bit)
614
679
 
615
- def summarize_hyperparameters(self) -> None:
616
- print(f'model relaxation:\n'
617
- f' use_64_bit ={self.use64bit}')
680
+ def summarize_hyperparameters(self) -> str:
681
+ return (f'model relaxation:\n'
682
+ f' use_64_bit ={self.use64bit}')
618
683
 
619
684
  def set_use64bit(self, use64bit: bool) -> None:
620
685
  '''Toggles whether or not the JAX system will use 64 bit precision.'''
@@ -712,123 +777,158 @@ class Logic:
712
777
  'Discrete': self.discrete,
713
778
  'Poisson': self.poisson,
714
779
  'Geometric': self.geometric,
715
- 'Binomial': self.binomial
780
+ 'Binomial': self.binomial,
781
+ 'NegativeBinomial': self.negative_binomial
716
782
  }
717
783
  }
718
784
 
719
785
  # ===========================================================================
720
786
  # logical operators
721
787
  # ===========================================================================
722
-
788
+
789
+ @abstractmethod
723
790
  def logical_and(self, id, init_params):
724
- raise NotImplementedError
791
+ pass
725
792
 
793
+ @abstractmethod
726
794
  def logical_not(self, id, init_params):
727
- raise NotImplementedError
795
+ pass
728
796
 
797
+ @abstractmethod
729
798
  def logical_or(self, id, init_params):
730
- raise NotImplementedError
799
+ pass
731
800
 
801
+ @abstractmethod
732
802
  def xor(self, id, init_params):
733
- raise NotImplementedError
803
+ pass
734
804
 
805
+ @abstractmethod
735
806
  def implies(self, id, init_params):
736
- raise NotImplementedError
807
+ pass
737
808
 
809
+ @abstractmethod
738
810
  def equiv(self, id, init_params):
739
- raise NotImplementedError
811
+ pass
740
812
 
813
+ @abstractmethod
741
814
  def forall(self, id, init_params):
742
- raise NotImplementedError
815
+ pass
743
816
 
817
+ @abstractmethod
744
818
  def exists(self, id, init_params):
745
- raise NotImplementedError
819
+ pass
746
820
 
747
821
  # ===========================================================================
748
822
  # comparison operators
749
823
  # ===========================================================================
750
824
 
825
+ @abstractmethod
751
826
  def greater_equal(self, id, init_params):
752
- raise NotImplementedError
827
+ pass
753
828
 
829
+ @abstractmethod
754
830
  def greater(self, id, init_params):
755
- raise NotImplementedError
831
+ pass
756
832
 
833
+ @abstractmethod
757
834
  def less_equal(self, id, init_params):
758
- raise NotImplementedError
835
+ pass
759
836
 
837
+ @abstractmethod
760
838
  def less(self, id, init_params):
761
- raise NotImplementedError
839
+ pass
762
840
 
841
+ @abstractmethod
763
842
  def equal(self, id, init_params):
764
- raise NotImplementedError
843
+ pass
765
844
 
845
+ @abstractmethod
766
846
  def not_equal(self, id, init_params):
767
- raise NotImplementedError
847
+ pass
768
848
 
769
849
  # ===========================================================================
770
850
  # special functions
771
851
  # ===========================================================================
772
852
 
853
+ @abstractmethod
773
854
  def sgn(self, id, init_params):
774
- raise NotImplementedError
855
+ pass
775
856
 
857
+ @abstractmethod
776
858
  def floor(self, id, init_params):
777
- raise NotImplementedError
859
+ pass
778
860
 
861
+ @abstractmethod
779
862
  def round(self, id, init_params):
780
- raise NotImplementedError
863
+ pass
781
864
 
865
+ @abstractmethod
782
866
  def ceil(self, id, init_params):
783
- raise NotImplementedError
867
+ pass
784
868
 
869
+ @abstractmethod
785
870
  def div(self, id, init_params):
786
- raise NotImplementedError
871
+ pass
787
872
 
873
+ @abstractmethod
788
874
  def mod(self, id, init_params):
789
- raise NotImplementedError
875
+ pass
790
876
 
877
+ @abstractmethod
791
878
  def sqrt(self, id, init_params):
792
- raise NotImplementedError
879
+ pass
793
880
 
794
881
  # ===========================================================================
795
882
  # indexing
796
883
  # ===========================================================================
797
-
884
+
885
+ @abstractmethod
798
886
  def argmax(self, id, init_params):
799
- raise NotImplementedError
887
+ pass
800
888
 
889
+ @abstractmethod
801
890
  def argmin(self, id, init_params):
802
- raise NotImplementedError
891
+ pass
803
892
 
804
893
  # ===========================================================================
805
894
  # control flow
806
895
  # ===========================================================================
807
896
 
897
+ @abstractmethod
808
898
  def control_if(self, id, init_params):
809
- raise NotImplementedError
899
+ pass
810
900
 
901
+ @abstractmethod
811
902
  def control_switch(self, id, init_params):
812
- raise NotImplementedError
903
+ pass
813
904
 
814
905
  # ===========================================================================
815
906
  # random variables
816
907
  # ===========================================================================
817
908
 
909
+ @abstractmethod
818
910
  def discrete(self, id, init_params):
819
- raise NotImplementedError
911
+ pass
820
912
 
913
+ @abstractmethod
821
914
  def bernoulli(self, id, init_params):
822
- raise NotImplementedError
915
+ pass
823
916
 
917
+ @abstractmethod
824
918
  def poisson(self, id, init_params):
825
- raise NotImplementedError
919
+ pass
826
920
 
921
+ @abstractmethod
827
922
  def geometric(self, id, init_params):
828
- raise NotImplementedError
923
+ pass
829
924
 
925
+ @abstractmethod
830
926
  def binomial(self, id, init_params):
831
- raise NotImplementedError
927
+ pass
928
+
929
+ @abstractmethod
930
+ def negative_binomial(self, id, init_params):
931
+ pass
832
932
 
833
933
 
834
934
  class ExactLogic(Logic):
@@ -1005,6 +1105,17 @@ class ExactLogic(Logic):
1005
1105
  sample = jnp.asarray(sample, dtype=self.INT)
1006
1106
  return sample, params
1007
1107
  return _jax_wrapped_calc_binomial_exact
1108
+
1109
+ # note: for some reason tfp defines it as number of successes before trials failures
1110
+ # I will define it as the number of failures before trials successes
1111
+ def negative_binomial(self, id, init_params):
1112
+ def _jax_wrapped_calc_negative_binomial_exact(key, trials, prob, params):
1113
+ trials = jnp.asarray(trials, dtype=self.REAL)
1114
+ prob = jnp.asarray(prob, dtype=self.REAL)
1115
+ dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=1.0 - prob)
1116
+ sample = jnp.asarray(dist.sample(seed=key), dtype=self.INT)
1117
+ return sample, params
1118
+ return _jax_wrapped_calc_negative_binomial_exact
1008
1119
 
1009
1120
 
1010
1121
  class FuzzyLogic(Logic):
@@ -1049,8 +1160,8 @@ class FuzzyLogic(Logic):
1049
1160
  f' underflow_tol={self.eps}\n'
1050
1161
  f' use_64_bit ={self.use64bit}\n')
1051
1162
 
1052
- def summarize_hyperparameters(self) -> None:
1053
- print(self.__str__())
1163
+ def summarize_hyperparameters(self) -> str:
1164
+ return self.__str__()
1054
1165
 
1055
1166
  # ===========================================================================
1056
1167
  # logical operators
@@ -1234,6 +1345,9 @@ class FuzzyLogic(Logic):
1234
1345
 
1235
1346
  def binomial(self, id, init_params):
1236
1347
  return self.sampling.binomial(id, init_params, self)
1348
+
1349
+ def negative_binomial(self, id, init_params):
1350
+ return self.sampling.negative_binomial(id, init_params, self)
1237
1351
 
1238
1352
 
1239
1353
  # ===========================================================================