pyRDDLGym-jax 2.0__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,12 +24,17 @@
 # Gumbel-Softmax." In International Conference on Learning Representations (ICLR 2017).
 # OpenReview.net, 2017.
 #
+# [6] Vafaii, H., Galor, D., & Yates, J. (2025). Poisson Variational Autoencoder.
+# Advances in Neural Information Processing Systems, 37, 44871-44906.
+#
 # ***********************************************************************

+from typing import Callable, Dict, Union

 import jax
 import jax.numpy as jnp
 import jax.random as random
+import jax.scipy as scipy


 def enumerate_literals(shape, axis, dtype=jnp.int32):
@@ -105,7 +110,7 @@ class SigmoidComparison(Comparison):
         id_ = str(id)
         init_params[id_] = self.weight
         def _jax_wrapped_calc_argmax_approx(x, axis, params):
-            literals = enumerate_literals(x.shape, axis=axis)
+            literals = enumerate_literals(jnp.shape(x), axis=axis)
             softmax = jax.nn.softmax(params[id_] * x, axis=axis)
             sample = jnp.sum(literals * softmax, axis=axis)
             return sample, params
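
The `_jax_wrapped_calc_argmax_approx` function touched above relaxes the non-differentiable argmax into a softmax-weighted average of the category indices. A minimal standalone sketch of the idea (hypothetical names, not package code):

    import jax
    import jax.numpy as jnp

    # Soft argmax: a softmax-weighted average of the literal indices.
    # As the weight grows, the result approaches the hard argmax.
    def soft_argmax(x, weight, axis=-1):
        literals = jnp.arange(x.shape[axis])
        weights = jax.nn.softmax(weight * x, axis=axis)
        return jnp.sum(literals * weights, axis=axis)

    x = jnp.array([0.1, 2.0, 0.5])
    print(soft_argmax(x, weight=1.0))   # ~1.05: a blend of indices
    print(soft_argmax(x, weight=50.0))  # ~1.0: effectively the hard argmax
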
@@ -323,62 +328,192 @@ class YagerTNorm(TNorm):
 # ===========================================================================

 class RandomSampling:
-    '''An abstract class that describes how discrete and non-reparameterizable
-    random variables are sampled.'''
+    '''Describes how non-reparameterizable random variables are sampled.'''

     def discrete(self, id, init_params, logic):
         raise NotImplementedError

-    def bernoulli(self, id, init_params, logic):
-        discrete_approx = self.discrete(id, init_params, logic)
-        def _jax_wrapped_calc_bernoulli_approx(key, prob, params):
-            prob = jnp.stack([1.0 - prob, prob], axis=-1)
-            return discrete_approx(key, prob, params)
-        return _jax_wrapped_calc_bernoulli_approx
-
-    @staticmethod
-    def _jax_wrapped_calc_poisson_exact(key, rate, params):
-        sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
-        return sample, params
-
     def poisson(self, id, init_params, logic):
-        return self._jax_wrapped_calc_poisson_exact
+        raise NotImplementedError
+
+    def binomial(self, id, init_params, logic):
+        raise NotImplementedError

     def geometric(self, id, init_params, logic):
-        approx_floor = logic.floor(id, init_params)
-        def _jax_wrapped_calc_geometric_approx(key, prob, params):
-            U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
-            floor, params = approx_floor(jnp.log(U) / jnp.log(1.0 - prob), params)
-            sample = floor + 1
-            return sample, params
-        return _jax_wrapped_calc_geometric_approx
+        raise NotImplementedError
+
+    def bernoulli(self, id, init_params, logic):
+        raise NotImplementedError
+
+    def __str__(self) -> str:
+        return 'RandomSampling'


-class GumbelSoftmax(RandomSampling):
+class SoftRandomSampling(RandomSampling):
     '''Random sampling of discrete variables using the Gumbel-softmax trick.'''

+    def __init__(self, poisson_max_bins: int=100,
+                 poisson_min_cdf: float=0.999,
+                 poisson_exp_sampling: bool=True,
+                 binomial_max_bins: int=100,
+                 bernoulli_gumbel_softmax: bool=False) -> None:
+        '''Creates a new instance of soft random sampling.
+
+        :param poisson_max_bins: maximum number of bins for the Poisson distribution relaxation
+        :param poisson_min_cdf: minimum cdf mass of the Poisson within the truncated
+        region required in order to use the Poisson relaxation
+        :param poisson_exp_sampling: whether to use the Poisson process sampling method
+        instead of the truncated Gumbel-softmax
+        :param binomial_max_bins: maximum number of bins for the Binomial distribution relaxation
+        :param bernoulli_gumbel_softmax: whether to use Gumbel-softmax to approximate
+        Bernoulli samples, or the standard uniform reparameterization instead
+        '''
+        self.poisson_bins = poisson_max_bins
+        self.poisson_min_cdf = poisson_min_cdf
+        self.poisson_exp_method = poisson_exp_sampling
+        self.binomial_bins = binomial_max_bins
+        self.bernoulli_gumbel_softmax = bernoulli_gumbel_softmax
+
     # https://arxiv.org/pdf/1611.01144
     def discrete(self, id, init_params, logic):
         argmax_approx = logic.argmax(id, init_params)
         def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, params):
-            Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
+            Gumbel01 = random.gumbel(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
             sample = Gumbel01 + jnp.log(prob + logic.eps)
             return argmax_approx(sample, axis=-1, params=params)
         return _jax_wrapped_calc_discrete_gumbel_softmax

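The `discrete` method relaxes the Gumbel-max trick (Jang et al., 2017, linked above): adding i.i.d. standard Gumbel noise to the log-probabilities and taking the argmax yields an exact categorical sample, and substituting the soft argmax makes it differentiable. A standalone sketch of the exact (hard) version, with illustrative names:

    import jax
    import jax.numpy as jnp
    import jax.random as random

    # Gumbel-max trick: argmax(log p + G), G ~ Gumbel(0, 1) i.i.d.,
    # is an exact sample from Categorical(p).
    prob = jnp.array([0.2, 0.5, 0.3])
    keys = random.split(random.PRNGKey(0), 10000)

    def one_sample(key):
        g = random.gumbel(key, shape=prob.shape)
        return jnp.argmax(jnp.log(prob) + g)

    samples = jax.vmap(one_sample)(keys)
    print(jnp.bincount(samples, length=3) / 10000.0)  # roughly [0.2, 0.5, 0.3]
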
+    def _poisson_gumbel_softmax(self, id, init_params, logic):
+        argmax_approx = logic.argmax(id, init_params)
+        def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
+            ks = jnp.arange(0, self.poisson_bins)
+            ks = ks[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
+            rate = rate[..., jnp.newaxis]
+            log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
+            Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
+            sample = Gumbel01 + log_prob
+            return argmax_approx(sample, axis=-1, params=params)
+        return _jax_wrapped_calc_poisson_gumbel_softmax
+
+    # https://arxiv.org/abs/2405.14473
+    def _poisson_exponential(self, id, init_params, logic):
+        less_approx = logic.less(id, init_params)
+        def _jax_wrapped_calc_poisson_exponential(key, rate, params):
+            Exp1 = random.exponential(
+                key=key,
+                shape=(self.poisson_bins,) + jnp.shape(rate),
+                dtype=logic.REAL
+            )
+            delta_t = Exp1 / rate[jnp.newaxis, ...]
+            times = jnp.cumsum(delta_t, axis=0)
+            indicator, params = less_approx(times, 1.0, params)
+            sample = jnp.sum(indicator, axis=0)
+            return sample, params
+        return _jax_wrapped_calc_poisson_exponential
+
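
`_poisson_exponential` uses the Poisson process view from the paper linked above (arXiv:2405.14473, the Poisson Variational Autoencoder added as reference [6]): with i.i.d. Exp(rate) inter-arrival times, the number of arrivals before time 1 is Poisson(rate), and the hard count of `times < 1` is replaced by the soft comparison. A sketch of the hard version, under assumed names:

    import jax.numpy as jnp
    import jax.random as random

    # N ~ Poisson(rate) equals the number of cumulative sums of Exp(rate)
    # inter-arrival times that stay below 1.
    key = random.PRNGKey(42)
    rate, max_bins = 4.0, 100

    delta_t = random.exponential(key, shape=(max_bins,)) / rate
    times = jnp.cumsum(delta_t)
    print(jnp.sum(times < 1.0))  # a Poisson(4.0) sample (hard indicator)
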
+    def poisson(self, id, init_params, logic):
+        def _jax_wrapped_calc_poisson_exact(key, rate, params):
+            sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
+            sample = jnp.asarray(sample, dtype=logic.REAL)
+            return sample, params
+
+        if self.poisson_exp_method:
+            _jax_wrapped_calc_poisson_diff = self._poisson_exponential(
+                id, init_params, logic)
+        else:
+            _jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
+                id, init_params, logic)
+
+        def _jax_wrapped_calc_poisson_approx(key, rate, params):
+
+            # determine if the error of truncating the support at poisson_bins is acceptable
+            if self.poisson_bins > 0:
+                cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
+                approx_cond = jax.lax.stop_gradient(
+                    jnp.min(cuml_prob) > self.poisson_min_cdf)
+            else:
+                approx_cond = False
+
+            # for acceptable truncation use the approximation, use exact otherwise
+            return jax.lax.cond(
+                approx_cond,
+                _jax_wrapped_calc_poisson_diff,
+                _jax_wrapped_calc_poisson_exact,
+                key, rate, params
+            )
+        return _jax_wrapped_calc_poisson_approx
+
+    def binomial(self, id, init_params, logic):
+        def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
+            trials = jnp.asarray(trials, dtype=logic.REAL)
+            prob = jnp.asarray(prob, dtype=logic.REAL)
+            sample = random.binomial(key=key, n=trials, p=prob, dtype=logic.REAL)
+            return sample, params
+
+        # Binomial(n, p) = sum_{i = 1 ... n} Bernoulli(p)
+        bernoulli_approx = self.bernoulli(id, init_params, logic)
+        def _jax_wrapped_calc_binomial_sum(key, trials, prob, params):
+            prob_full = jnp.broadcast_to(
+                prob[..., jnp.newaxis], shape=jnp.shape(prob) + (self.binomial_bins,))
+            sample_bern, params = bernoulli_approx(key, prob_full, params)
+            indices = jnp.arange(self.binomial_bins)[
+                (jnp.newaxis,) * jnp.ndim(prob) + (...,)]
+            mask = indices < trials[..., jnp.newaxis]
+            sample = jnp.sum(sample_bern * mask, axis=-1)
+            return sample, params
+
+        # for trials not too large use the Bernoulli relaxation, use exact otherwise
+        def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
+            return jax.lax.cond(
+                jax.lax.stop_gradient(jnp.max(trials) < self.binomial_bins),
+                _jax_wrapped_calc_binomial_sum,
+                _jax_wrapped_calc_binomial_exact,
+                key, trials, prob, params
+            )
+        return _jax_wrapped_calc_binomial_approx
+
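
`_jax_wrapped_calc_binomial_sum` realizes the identity in the comment with a fixed-size buffer: it draws `binomial_bins` relaxed Bernoulli(p) samples and masks out indices at or beyond `trials`. The hard analogue, as a self-contained sketch:

    import jax.numpy as jnp
    import jax.random as random

    # Binomial(n, p) as a masked sum of a fixed number of Bernoulli(p) draws.
    key = random.PRNGKey(1)
    n, p, max_bins = 7, 0.3, 100

    bern = random.bernoulli(key, p, shape=(max_bins,))
    mask = jnp.arange(max_bins) < n   # keep only the first n draws
    print(jnp.sum(bern * mask))       # a Binomial(7, 0.3) sample
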
+    def geometric(self, id, init_params, logic):
+        approx_floor = logic.floor(id, init_params)
+        def _jax_wrapped_calc_geometric_approx(key, prob, params):
+            U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
+            floor, params = approx_floor(jnp.log1p(-U) / jnp.log1p(-prob), params)
+            sample = floor + 1
+            return sample, params
+        return _jax_wrapped_calc_geometric_approx
+
+    def _bernoulli_uniform(self, id, init_params, logic):
+        less_approx = logic.less(id, init_params)
+        def _jax_wrapped_calc_bernoulli_uniform(key, prob, params):
+            U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
+            return less_approx(U, prob, params)
+        return _jax_wrapped_calc_bernoulli_uniform
+
+    def _bernoulli_gumbel_softmax(self, id, init_params, logic):
+        discrete_approx = self.discrete(id, init_params, logic)
+        def _jax_wrapped_calc_bernoulli_gumbel_softmax(key, prob, params):
+            prob = jnp.stack([1.0 - prob, prob], axis=-1)
+            return discrete_approx(key, prob, params)
+        return _jax_wrapped_calc_bernoulli_gumbel_softmax
+
+    def bernoulli(self, id, init_params, logic):
+        if self.bernoulli_gumbel_softmax:
+            return self._bernoulli_gumbel_softmax(id, init_params, logic)
+        else:
+            return self._bernoulli_uniform(id, init_params, logic)
+
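
The `geometric` method applies the inverse-CDF identity floor(log(1-U) / log(1-p)) + 1 ~ Geometric(p), now written with `log1p` for numerical stability near small `prob` (the 2.0 version's `log(U)` is equivalent in distribution, since U and 1-U are both uniform), with the floor replaced by its soft counterpart. A quick sanity check of the hard identity:

    import jax.numpy as jnp
    import jax.random as random

    # Inverse-CDF sampling: floor(log(1-U) / log(1-p)) + 1 ~ Geometric(p).
    key = random.PRNGKey(7)
    p = 0.25
    U = random.uniform(key, shape=(100000,))
    samples = jnp.floor(jnp.log1p(-U) / jnp.log1p(-p)) + 1
    print(jnp.mean(samples))  # close to the mean 1 / p = 4
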
     def __str__(self) -> str:
-        return 'Gumbel-Softmax'
+        return 'SoftRandomSampling'


 class Determinization(RandomSampling):
     '''Random sampling of variables using their deterministic mean estimate.'''
-
+
     @staticmethod
     def _jax_wrapped_calc_discrete_determinized(key, prob, params):
-        literals = enumerate_literals(prob.shape, axis=-1)
+        literals = enumerate_literals(jnp.shape(prob), axis=-1)
         sample = jnp.sum(literals * prob, axis=-1)
         return sample, params
-
+
     def discrete(self, id, init_params, logic):
         return self._jax_wrapped_calc_discrete_determinized

@@ -389,6 +524,14 @@ class Determinization(RandomSampling):
     def poisson(self, id, init_params, logic):
         return self._jax_wrapped_calc_poisson_determinized

+    @staticmethod
+    def _jax_wrapped_calc_binomial_determinized(key, trials, prob, params):
+        sample = trials * prob
+        return sample, params
+
+    def binomial(self, id, init_params, logic):
+        return self._jax_wrapped_calc_binomial_determinized
+
     @staticmethod
     def _jax_wrapped_calc_geometric_determinized(key, prob, params):
         sample = 1.0 / prob
@@ -397,6 +540,14 @@ class Determinization(RandomSampling):
     def geometric(self, id, init_params, logic):
         return self._jax_wrapped_calc_geometric_determinized

+    @staticmethod
+    def _jax_wrapped_calc_bernoulli_determinized(key, prob, params):
+        sample = prob
+        return sample, params
+
+    def bernoulli(self, id, init_params, logic):
+        return self._jax_wrapped_calc_bernoulli_determinized
+
     def __str__(self) -> str:
         return 'Deterministic'

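The new `binomial` and `bernoulli` methods follow the same pattern as the existing ones: each distribution is replaced by its mean, so rollouts become deterministic and gradients flow through the moments. For instance:

    import jax.numpy as jnp

    trials, prob = jnp.asarray(10.0), jnp.asarray(0.3)
    print(trials * prob)  # Binomial(10, 0.3) -> mean n * p = 3.0
    print(1.0 / prob)     # Geometric(0.3)    -> mean 1 / p ~ 3.33
    print(prob)           # Bernoulli(0.3)    -> mean p = 0.3
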
@@ -435,8 +586,8 @@ class SoftControlFlow(ControlFlow):
         id_ = str(id)
         init_params[id_] = self.weight
         def _jax_wrapped_calc_switch_soft(pred, cases, params):
-            literals = enumerate_literals(cases.shape, axis=0)
-            pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
+            literals = enumerate_literals(jnp.shape(cases), axis=0)
+            pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=jnp.shape(cases))
             proximity = -jnp.square(pred - literals)
             softcase = jax.nn.softmax(params[id_] * proximity, axis=0)
             sample = jnp.sum(cases * softcase, axis=0)
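
`_jax_wrapped_calc_switch_soft` blends the case values with softmax weights derived from how close the (possibly soft) predicate is to each case index. A standalone sketch with hypothetical names:

    import jax
    import jax.numpy as jnp

    def soft_switch(pred, cases, weight):
        literals = jnp.arange(cases.shape[0])
        proximity = -jnp.square(pred - literals)       # closeness to each index
        softcase = jax.nn.softmax(weight * proximity)  # near one-hot for large weight
        return jnp.sum(cases * softcase)

    cases = jnp.array([10.0, 20.0, 30.0])
    print(soft_switch(1.2, cases, weight=10.0))  # ~20.0, dominated by cases[1]
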
@@ -477,6 +628,94 @@ class Logic:
         self.INT = jnp.int32
         jax.config.update('jax_enable_x64', False)

+    @staticmethod
+    def wrap_logic(func):
+        def exact_func(id, init_params):
+            return func
+        return exact_func
+
+    def get_operator_dicts(self) -> Dict[str, Union[Callable, Dict[str, Callable]]]:
+        '''Returns a dictionary of all operators in the current logic.'''
+        return {
+            'negative': self.wrap_logic(ExactLogic.exact_unary_function(jnp.negative)),
+            'arithmetic': {
+                '+': self.wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
+                '-': self.wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
+                '*': self.wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
+                '/': self.wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
+            },
+            'relational': {
+                '>=': self.greater_equal,
+                '<=': self.less_equal,
+                '<': self.less,
+                '>': self.greater,
+                '==': self.equal,
+                '~=': self.not_equal
+            },
+            'logical_not': self.logical_not,
+            'logical': {
+                '^': self.logical_and,
+                '&': self.logical_and,
+                '|': self.logical_or,
+                '~': self.xor,
+                '=>': self.implies,
+                '<=>': self.equiv
+            },
+            'aggregation': {
+                'sum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
+                'avg': self.wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
+                'prod': self.wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
+                'minimum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
+                'maximum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
+                'forall': self.forall,
+                'exists': self.exists,
+                'argmin': self.argmin,
+                'argmax': self.argmax
+            },
+            'unary': {
+                'abs': self.wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
+                'sgn': self.sgn,
+                'round': self.round,
+                'floor': self.floor,
+                'ceil': self.ceil,
+                'cos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
+                'sin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
+                'tan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
+                'acos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
+                'asin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
+                'atan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
+                'cosh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
+                'sinh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
+                'tanh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
+                'exp': self.wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
+                'ln': self.wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
+                'sqrt': self.sqrt,
+                'lngamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
+                'gamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
+            },
+            'binary': {
+                'div': self.div,
+                'mod': self.mod,
+                'fmod': self.mod,
+                'min': self.wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
+                'max': self.wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
+                'pow': self.wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
+                'log': self.wrap_logic(ExactLogic.exact_binary_log),
+                'hypot': self.wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
+            },
+            'control': {
+                'if': self.control_if,
+                'switch': self.control_switch
+            },
+            'sampling': {
+                'Bernoulli': self.bernoulli,
+                'Discrete': self.discrete,
+                'Poisson': self.poisson,
+                'Geometric': self.geometric,
+                'Binomial': self.binomial
+            }
+        }
+
     # ===========================================================================
     # logical operators
     # ===========================================================================
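
The new `get_operator_dicts` accessor exposes every operator as a factory keyed by its RDDL symbol. A hypothetical usage sketch (assuming an `ExactLogic` instance constructed with defaults; per the diff, each factory takes an expression id and parameter dict and returns a function of the operands and parameters):

    import jax.numpy as jnp

    logic = ExactLogic()
    ops = logic.get_operator_dicts()

    init_params = {}
    add_op = ops['arithmetic']['+'](0, init_params)  # factory -> callable
    result, _ = add_op(jnp.asarray(2.0), jnp.asarray(3.0), init_params)
    print(result)  # 5.0
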
@@ -587,6 +826,9 @@ class Logic:

     def geometric(self, id, init_params):
         raise NotImplementedError
+
+    def binomial(self, id, init_params):
+        raise NotImplementedError


 class ExactLogic(Logic):
@@ -627,11 +869,11 @@ class ExactLogic(Logic):
         return self.exact_binary_function(jnp.logical_xor)

     @staticmethod
-    def exact_binary_implies(x, y, params):
+    def _jax_wrapped_calc_implies_exact(x, y, params):
         return jnp.logical_or(jnp.logical_not(x), y), params

     def implies(self, id, init_params):
-        return self.exact_binary_implies
+        return self._jax_wrapped_calc_implies_exact

     def equiv(self, id, init_params):
         return self.exact_binary_function(jnp.equal)
@@ -668,6 +910,10 @@ class ExactLogic(Logic):
     # special functions
     # ===========================================================================

+    @staticmethod
+    def exact_binary_log(x, y, params):
+        return jnp.log(x) / jnp.log(y), params
+
     def sgn(self, id, init_params):
         return self.exact_unary_function(jnp.sign)

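The new `exact_binary_log` implements the two-argument logarithm via the change-of-base formula log_y(x) = ln(x) / ln(y), for example:

    import jax.numpy as jnp

    print(jnp.log(8.0) / jnp.log(2.0))  # log base 2 of 8 = 3.0
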
@@ -704,54 +950,62 @@ class ExactLogic(Logic):
     # ===========================================================================

     @staticmethod
-    def exact_if_then_else(c, a, b, params):
+    def _jax_wrapped_calc_if_then_else_exact(c, a, b, params):
         return jnp.where(c > 0.5, a, b), params

     def control_if(self, id, init_params):
-        return self.exact_if_then_else
+        return self._jax_wrapped_calc_if_then_else_exact

     @staticmethod
-    def exact_switch(pred, cases, params):
+    def _jax_wrapped_calc_switch_exact(pred, cases, params):
         pred = pred[jnp.newaxis, ...]
         sample = jnp.take_along_axis(cases, pred, axis=0)
         assert sample.shape[0] == 1
         return sample[0, ...], params

     def control_switch(self, id, init_params):
-        return self.exact_switch
+        return self._jax_wrapped_calc_switch_exact

     # ===========================================================================
     # random variables
     # ===========================================================================

     @staticmethod
-    def exact_discrete(key, prob, params):
-        return random.categorical(key=key, logits=jnp.log(prob), axis=-1), params
+    def _jax_wrapped_calc_discrete_exact(key, prob, params):
+        sample = random.categorical(key=key, logits=jnp.log(prob), axis=-1)
+        return sample, params

     def discrete(self, id, init_params):
-        return self.exact_discrete
+        return self._jax_wrapped_calc_discrete_exact

     @staticmethod
-    def exact_bernoulli(key, prob, params):
+    def _jax_wrapped_calc_bernoulli_exact(key, prob, params):
         return random.bernoulli(key, prob), params

     def bernoulli(self, id, init_params):
-        return self.exact_bernoulli
-
-    @staticmethod
-    def exact_poisson(key, rate, params):
-        return random.poisson(key=key, lam=rate), params
+        return self._jax_wrapped_calc_bernoulli_exact

     def poisson(self, id, init_params):
-        return self.exact_poisson
-
-    @staticmethod
-    def exact_geometric(key, prob, params):
-        return random.geometric(key=key, p=prob), params
+        def _jax_wrapped_calc_poisson_exact(key, rate, params):
+            sample = random.poisson(key=key, lam=rate, dtype=self.INT)
+            return sample, params
+        return _jax_wrapped_calc_poisson_exact

     def geometric(self, id, init_params):
-        return self.exact_geometric
-
+        def _jax_wrapped_calc_geometric_exact(key, prob, params):
+            sample = random.geometric(key=key, p=prob, dtype=self.INT)
+            return sample, params
+        return _jax_wrapped_calc_geometric_exact
+
+    def binomial(self, id, init_params):
+        def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
+            trials = jnp.asarray(trials, dtype=self.REAL)
+            prob = jnp.asarray(prob, dtype=self.REAL)
+            sample = random.binomial(key=key, n=trials, p=prob, dtype=self.REAL)
+            sample = jnp.asarray(sample, dtype=self.INT)
+            return sample, params
+        return _jax_wrapped_calc_binomial_exact
+


 class FuzzyLogic(Logic):
@@ -759,7 +1013,7 @@ class FuzzyLogic(Logic):
     def __init__(self, tnorm: TNorm=ProductTNorm(),
                  complement: Complement=StandardComplement(),
                  comparison: Comparison=SigmoidComparison(),
-                 sampling: RandomSampling=GumbelSoftmax(),
+                 sampling: RandomSampling=SoftRandomSampling(),
                  rounding: Rounding=SoftRounding(),
                  control: ControlFlow=SoftControlFlow(),
                  eps: float=1e-15,
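
With `SoftRandomSampling` now the default sampler, its relaxation behavior can be tuned through the constructor parameters introduced above. A configuration sketch with illustrative values (assuming the classes in this module are in scope):

    logic = FuzzyLogic(
        sampling=SoftRandomSampling(
            poisson_max_bins=150,           # wider truncation for the Poisson relaxation
            poisson_exp_sampling=True,      # exponential-race method over Gumbel-softmax
            bernoulli_gumbel_softmax=False  # uniform reparameterization for Bernoulli
        )
    )
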
@@ -977,6 +1231,9 @@ class FuzzyLogic(Logic):

     def geometric(self, id, init_params):
         return self.sampling.geometric(id, init_params, self)
+
+    def binomial(self, id, init_params):
+        return self.sampling.binomial(id, init_params, self)


 # ===========================================================================
@@ -1011,8 +1268,8 @@ def _test_logical():
         pred, w = _if(cond, +1, -1, w)
         return pred

-    x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5]).astype(float)
-    x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6]).astype(float)
+    x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5], dtype=float)
+    x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6], dtype=float)
     print(test_logic(x1, x2, init_params))
