pyRDDLGym-jax 1.3__py3-none-any.whl → 2.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,8 +1,40 @@
- from typing import Optional, Set
+ # ***********************************************************************
+ # JAXPLAN
+ #
+ # Author: Michael Gimelfarb
+ #
+ # REFERENCES:
+ #
+ # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
+ # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
+ # Probabilistic Domains." Proceedings of the International Conference on Automated
+ # Planning and Scheduling. Vol. 34. 2024.
+ #
+ # [2] Petersen, Felix, Christian Borgelt, Hilde Kuehne, and Oliver Deussen. "Learning with
+ # algorithmic supervision via continuous relaxations." Advances in Neural Information
+ # Processing Systems 34 (2021): 16520-16531.
+ #
+ # [3] Agustsson, Eirikur, and Lucas Theis. "Universally quantized neural compression."
+ # Advances in neural information processing systems 33 (2020): 12367-12376.
+ #
+ # [4] Gupta, Madan M., and J. Qi. "Theory of T-norms and fuzzy inference
+ # methods." Fuzzy sets and systems 40, no. 3 (1991): 431-450.
+ #
+ # [5] Jang, Eric, Shixiang Gu, and Ben Poole. "Categorical Reparameterization with
+ # Gumbel-Softmax." In International Conference on Learning Representations (ICLR 2017).
+ # OpenReview.net, 2017.
+ #
+ # [6] Vafaii, H., Galor, D., & Yates, J. (2025). Poisson Variational Autoencoder.
+ # Advances in Neural Information Processing Systems, 37, 44871-44906.
+ #
+ # ***********************************************************************
+
+ from typing import Callable, Dict, Union
 
  import jax
  import jax.numpy as jnp
  import jax.random as random
+ import jax.scipy as scipy
 
 
  def enumerate_literals(shape, axis, dtype=jnp.int32):
@@ -78,7 +110,7 @@ class SigmoidComparison(Comparison):
  id_ = str(id)
  init_params[id_] = self.weight
  def _jax_wrapped_calc_argmax_approx(x, axis, params):
- literals = enumerate_literals(x.shape, axis=axis)
+ literals = enumerate_literals(jnp.shape(x), axis=axis)
  softmax = jax.nn.softmax(params[id_] * x, axis=axis)
  sample = jnp.sum(literals * softmax, axis=axis)
  return sample, params
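
Note: the relaxed argmax above averages the index literals under softmax(weight * x), which approaches the hard argmax as the weight grows. A minimal standalone sketch of the same idea (soft_argmax and the fixed weight w are illustrative, not part of the package):

    import jax
    import jax.numpy as jnp

    def soft_argmax(x, w, axis=-1):
        # softmax concentrates on the maximum entry as w grows
        literals = jnp.arange(x.shape[axis])
        softmax = jax.nn.softmax(w * x, axis=axis)
        return jnp.sum(literals * softmax, axis=axis)

    x = jnp.array([0.1, 2.0, 0.5])
    print(soft_argmax(x, w=1.0))    # soft index estimate, differentiable in x
    print(soft_argmax(x, w=100.0))  # ~1.0, matching jnp.argmax(x)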
@@ -296,62 +328,192 @@ class YagerTNorm(TNorm):
  # ===========================================================================
 
  class RandomSampling:
- '''An abstract class that describes how discrete and non-reparameterizable
- random variables are sampled.'''
+ '''Describes how non-reparameterizable random variables are sampled.'''
 
  def discrete(self, id, init_params, logic):
  raise NotImplementedError
 
- def bernoulli(self, id, init_params, logic):
- discrete_approx = self.discrete(id, init_params, logic)
- def _jax_wrapped_calc_bernoulli_approx(key, prob, params):
- prob = jnp.stack([1.0 - prob, prob], axis=-1)
- return discrete_approx(key, prob, params)
- return _jax_wrapped_calc_bernoulli_approx
-
- @staticmethod
- def _jax_wrapped_calc_poisson_exact(key, rate, params):
- sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
- return sample, params
-
  def poisson(self, id, init_params, logic):
- return self._jax_wrapped_calc_poisson_exact
+ raise NotImplementedError
+
+ def binomial(self, id, init_params, logic):
+ raise NotImplementedError
 
  def geometric(self, id, init_params, logic):
- approx_floor = logic.floor(id, init_params)
- def _jax_wrapped_calc_geometric_approx(key, prob, params):
- U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
- floor, params = approx_floor(jnp.log(U) / jnp.log(1.0 - prob), params)
- sample = floor + 1
- return sample, params
- return _jax_wrapped_calc_geometric_approx
+ raise NotImplementedError
+
+ def bernoulli(self, id, init_params, logic):
+ raise NotImplementedError
+
+ def __str__(self) -> str:
+ return 'RandomSampling'
 
 
- class GumbelSoftmax(RandomSampling):
+ class SoftRandomSampling(RandomSampling):
  '''Random sampling of discrete variables using Gumbel-softmax trick.'''
 
+ def __init__(self, poisson_max_bins: int=100,
+ poisson_min_cdf: float=0.999,
+ poisson_exp_sampling: bool=True,
+ binomial_max_bins: int=100,
+ bernoulli_gumbel_softmax: bool=False) -> None:
+ '''Creates a new instance of soft random sampling.
+
+ :param poisson_max_bins: maximum bins to use for Poisson distribution relaxation
+ :param poisson_min_cdf: minimum cdf value of Poisson within truncated region
+ in order to use Poisson relaxation
+ :param poisson_exp_sampling: whether to use Poisson process sampling method
+ instead of truncated Gumbel-Softmax
+ :param binomial_max_bins: maximum bins to use for Binomial distribution relaxation
+ :param bernoulli_gumbel_softmax: whether to use Gumbel-Softmax to approximate
+ Bernoulli samples, or the standard uniform reparameterization instead
+ '''
+ self.poisson_bins = poisson_max_bins
+ self.poisson_min_cdf = poisson_min_cdf
+ self.poisson_exp_method = poisson_exp_sampling
+ self.binomial_bins = binomial_max_bins
+ self.bernoulli_gumbel_softmax = bernoulli_gumbel_softmax
+
  # https://arxiv.org/pdf/1611.01144
  # https://arxiv.org/pdf/1611.01144
  def discrete(self, id, init_params, logic):
  argmax_approx = logic.argmax(id, init_params)
  def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, params):
- Gumbel01 = random.gumbel(key=key, shape=prob.shape, dtype=logic.REAL)
+ Gumbel01 = random.gumbel(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
  sample = Gumbel01 + jnp.log(prob + logic.eps)
  return argmax_approx(sample, axis=-1, params=params)
  return _jax_wrapped_calc_discrete_gumbel_softmax
 
+ def _poisson_gumbel_softmax(self, id, init_params, logic):
+ argmax_approx = logic.argmax(id, init_params)
+ def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
+ ks = jnp.arange(0, self.poisson_bins)
+ ks = ks[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
+ rate = rate[..., jnp.newaxis]
+ log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
+ Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
+ sample = Gumbel01 + log_prob
+ return argmax_approx(sample, axis=-1, params=params)
+ return _jax_wrapped_calc_poisson_gumbel_softmax
+
+ # https://arxiv.org/abs/2405.14473
+ def _poisson_exponential(self, id, init_params, logic):
+ less_approx = logic.less(id, init_params)
+ def _jax_wrapped_calc_poisson_exponential(key, rate, params):
+ Exp1 = random.exponential(
+ key=key,
+ shape=(self.poisson_bins,) + jnp.shape(rate),
+ dtype=logic.REAL
+ )
+ delta_t = Exp1 / rate[jnp.newaxis, ...]
+ times = jnp.cumsum(delta_t, axis=0)
+ indicator, params = less_approx(times, 1.0, params)
+ sample = jnp.sum(indicator, axis=0)
+ return sample, params
+ return _jax_wrapped_calc_poisson_exponential
+
+ def poisson(self, id, init_params, logic):
+ def _jax_wrapped_calc_poisson_exact(key, rate, params):
+ sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
+ sample = jnp.asarray(sample, dtype=logic.REAL)
+ return sample, params
+
+ if self.poisson_exp_method:
+ _jax_wrapped_calc_poisson_diff = self._poisson_exponential(
+ id, init_params, logic)
+ else:
+ _jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
+ id, init_params, logic)
+
+ def _jax_wrapped_calc_poisson_approx(key, rate, params):
+
+ # determine if error of truncation at rate is acceptable
+ if self.poisson_bins > 0:
+ cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
+ approx_cond = jax.lax.stop_gradient(
+ jnp.min(cuml_prob) > self.poisson_min_cdf)
+ else:
+ approx_cond = False
+
+ # for acceptable truncation use the approximation, use exact otherwise
+ return jax.lax.cond(
+ approx_cond,
+ _jax_wrapped_calc_poisson_diff,
+ _jax_wrapped_calc_poisson_exact,
+ key, rate, params
+ )
+ return _jax_wrapped_calc_poisson_approx
+
+ def binomial(self, id, init_params, logic):
+ def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
+ trials = jnp.asarray(trials, dtype=logic.REAL)
+ prob = jnp.asarray(prob, dtype=logic.REAL)
+ sample = random.binomial(key=key, n=trials, p=prob, dtype=logic.REAL)
+ return sample, params
+
+ # Binomial(n, p) = sum_{i = 1 ... n} Bernoulli(p)
+ bernoulli_approx = self.bernoulli(id, init_params, logic)
+ def _jax_wrapped_calc_binomial_sum(key, trials, prob, params):
+ prob_full = jnp.broadcast_to(
+ prob[..., jnp.newaxis], shape=jnp.shape(prob) + (self.binomial_bins,))
+ sample_bern, params = bernoulli_approx(key, prob_full, params)
+ indices = jnp.arange(self.binomial_bins)[
+ (jnp.newaxis,) * jnp.ndim(prob) + (...,)]
+ mask = indices < trials[..., jnp.newaxis]
+ sample = jnp.sum(sample_bern * mask, axis=-1)
+ return sample, params
+
+ # for trials not too large use the Bernoulli relaxation, use exact otherwise
+ def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
+ return jax.lax.cond(
+ jax.lax.stop_gradient(jnp.max(trials) < self.binomial_bins),
+ _jax_wrapped_calc_binomial_sum,
+ _jax_wrapped_calc_binomial_exact,
+ key, trials, prob, params
+ )
+ return _jax_wrapped_calc_binomial_approx
+
+ def geometric(self, id, init_params, logic):
+ approx_floor = logic.floor(id, init_params)
+ def _jax_wrapped_calc_geometric_approx(key, prob, params):
+ U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
+ floor, params = approx_floor(jnp.log1p(-U) / jnp.log1p(-prob), params)
+ sample = floor + 1
+ return sample, params
+ return _jax_wrapped_calc_geometric_approx
+
+ def _bernoulli_uniform(self, id, init_params, logic):
+ less_approx = logic.less(id, init_params)
+ def _jax_wrapped_calc_bernoulli_uniform(key, prob, params):
+ U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
+ return less_approx(U, prob, params)
+ return _jax_wrapped_calc_bernoulli_uniform
+
+ def _bernoulli_gumbel_softmax(self, id, init_params, logic):
+ discrete_approx = self.discrete(id, init_params, logic)
+ def _jax_wrapped_calc_bernoulli_gumbel_softmax(key, prob, params):
+ prob = jnp.stack([1.0 - prob, prob], axis=-1)
+ return discrete_approx(key, prob, params)
+ return _jax_wrapped_calc_bernoulli_gumbel_softmax
+
+ def bernoulli(self, id, init_params, logic):
+ if self.bernoulli_gumbel_softmax:
+ return self._bernoulli_gumbel_softmax(id, init_params, logic)
+ else:
+ return self._bernoulli_uniform(id, init_params, logic)
+
  def __str__(self) -> str:
- return 'Gumbel-Softmax'
+ return 'SoftRandomSampling'
 
 
  class Determinization(RandomSampling):
  '''Random sampling of variables using their deterministic mean estimate.'''
-
+
  @staticmethod
  def _jax_wrapped_calc_discrete_determinized(key, prob, params):
- literals = enumerate_literals(prob.shape, axis=-1)
+ literals = enumerate_literals(jnp.shape(prob), axis=-1)
  sample = jnp.sum(literals * prob, axis=-1)
  return sample, params
-
+
  def discrete(self, id, init_params, logic):
  return self._jax_wrapped_calc_discrete_determinized
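
Note: _poisson_exponential above is the Poisson-process relaxation referenced at https://arxiv.org/abs/2405.14473 (see also [6]): a Poisson(rate) draw counts how many exponential inter-arrival times fit into a unit interval, and replacing the hard count with a soft comparison keeps gradients with respect to the rate. A minimal sketch under that reading, with a hypothetical fixed sigmoid weight w standing in for the learned comparison parameters:

    import jax
    import jax.numpy as jnp
    import jax.random as random

    def poisson_race(key, rate, bins=100, w=10.0):
        # inter-arrival times of a rate-`rate` Poisson process
        delta_t = random.exponential(key, shape=(bins,)) / rate
        times = jnp.cumsum(delta_t, axis=0)
        # soft count of arrivals before time 1: sigmoid in place of (times < 1)
        return jnp.sum(jax.nn.sigmoid(w * (1.0 - times)))

    key = random.PRNGKey(42)
    print(poisson_race(key, rate=3.0))                  # roughly a Poisson(3) draw
    print(jax.grad(poisson_race, argnums=1)(key, 3.0))  # usable gradient w.r.t. rate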
 
@@ -362,6 +524,14 @@ class Determinization(RandomSampling):
  def poisson(self, id, init_params, logic):
  return self._jax_wrapped_calc_poisson_determinized
 
+ @staticmethod
+ def _jax_wrapped_calc_binomial_determinized(key, trials, prob, params):
+ sample = trials * prob
+ return sample, params
+
+ def binomial(self, id, init_params, logic):
+ return self._jax_wrapped_calc_binomial_determinized
+
  @staticmethod
  def _jax_wrapped_calc_geometric_determinized(key, prob, params):
  sample = 1.0 / prob
@@ -370,6 +540,14 @@ class Determinization(RandomSampling):
  def geometric(self, id, init_params, logic):
  return self._jax_wrapped_calc_geometric_determinized
 
+ @staticmethod
+ def _jax_wrapped_calc_bernoulli_determinized(key, prob, params):
+ sample = prob
+ return sample, params
+
+ def bernoulli(self, id, init_params, logic):
+ return self._jax_wrapped_calc_bernoulli_determinized
+
  def __str__(self) -> str:
  return 'Deterministic'
 
@@ -408,8 +586,8 @@ class SoftControlFlow(ControlFlow):
  id_ = str(id)
  init_params[id_] = self.weight
  def _jax_wrapped_calc_switch_soft(pred, cases, params):
- literals = enumerate_literals(cases.shape, axis=0)
- pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=cases.shape)
+ literals = enumerate_literals(jnp.shape(cases), axis=0)
+ pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=jnp.shape(cases))
  proximity = -jnp.square(pred - literals)
  softcase = jax.nn.softmax(params[id_] * proximity, axis=0)
  sample = jnp.sum(cases * softcase, axis=0)
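
Note: the soft switch above scores each case index by its squared distance to the (possibly relaxed) predicate and mixes the case values under a softmax over those scores. A minimal standalone sketch with an illustrative fixed weight w:

    import jax
    import jax.numpy as jnp

    def soft_switch(pred, cases, w=10.0):
        literals = jnp.arange(cases.shape[0])     # case indices 0..n-1
        proximity = -jnp.square(pred - literals)  # closeness of pred to each index
        softcase = jax.nn.softmax(w * proximity, axis=0)
        return jnp.sum(cases * softcase, axis=0)

    cases = jnp.array([10.0, 20.0, 30.0])
    print(soft_switch(1.0, cases))  # ~20.0: effectively selects case 1
    print(soft_switch(1.5, cases))  # ~25.0: blends cases 1 and 2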
@@ -450,6 +628,94 @@ class Logic:
  self.INT = jnp.int32
  jax.config.update('jax_enable_x64', False)
 
+ @staticmethod
+ def wrap_logic(func):
+ def exact_func(id, init_params):
+ return func
+ return exact_func
+
+ def get_operator_dicts(self) -> Dict[str, Union[Callable, Dict[str, Callable]]]:
+ '''Returns a dictionary of all operators in the current logic.'''
+ return {
+ 'negative': self.wrap_logic(ExactLogic.exact_unary_function(jnp.negative)),
+ 'arithmetic': {
+ '+': self.wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
+ '-': self.wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
+ '*': self.wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
+ '/': self.wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
+ },
+ 'relational': {
+ '>=': self.greater_equal,
+ '<=': self.less_equal,
+ '<': self.less,
+ '>': self.greater,
+ '==': self.equal,
+ '~=': self.not_equal
+ },
+ 'logical_not': self.logical_not,
+ 'logical': {
+ '^': self.logical_and,
+ '&': self.logical_and,
+ '|': self.logical_or,
+ '~': self.xor,
+ '=>': self.implies,
+ '<=>': self.equiv
+ },
+ 'aggregation': {
+ 'sum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
+ 'avg': self.wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
+ 'prod': self.wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
+ 'minimum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
+ 'maximum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
+ 'forall': self.forall,
+ 'exists': self.exists,
+ 'argmin': self.argmin,
+ 'argmax': self.argmax
+ },
+ 'unary': {
+ 'abs': self.wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
+ 'sgn': self.sgn,
+ 'round': self.round,
+ 'floor': self.floor,
+ 'ceil': self.ceil,
+ 'cos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
+ 'sin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
+ 'tan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
+ 'acos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
+ 'asin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
+ 'atan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
+ 'cosh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
+ 'sinh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
+ 'tanh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
+ 'exp': self.wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
+ 'ln': self.wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
+ 'sqrt': self.sqrt,
+ 'lngamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
+ 'gamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
+ },
+ 'binary': {
+ 'div': self.div,
+ 'mod': self.mod,
+ 'fmod': self.mod,
+ 'min': self.wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
+ 'max': self.wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
+ 'pow': self.wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
+ 'log': self.wrap_logic(ExactLogic.exact_binary_log),
+ 'hypot': self.wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
+ },
+ 'control': {
+ 'if': self.control_if,
+ 'switch': self.control_switch
+ },
+ 'sampling': {
+ 'Bernoulli': self.bernoulli,
+ 'Discrete': self.discrete,
+ 'Poisson': self.poisson,
+ 'Geometric': self.geometric,
+ 'Binomial': self.binomial
+ }
+ }
+
  # ===========================================================================
  # logical operators
  # ===========================================================================
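
Note: get_operator_dicts exposes one lookup table per operator family, so a compiler can dispatch on RDDL operator symbols instead of hard-coding jnp calls. A hedged usage sketch; the import path follows the package layout but is an assumption, as is constructing ExactLogic with default arguments:

    import jax.numpy as jnp
    from pyRDDLGym_jax.core.logic import ExactLogic  # assumed module path

    logic = ExactLogic()
    init_params = {}
    ops = logic.get_operator_dicts()

    # each entry maps (expression id, init_params) to a callable (args..., params) -> (value, params)
    add_fn = ops['arithmetic']['+']('expr1', init_params)
    value, init_params = add_fn(jnp.asarray(2.0), jnp.asarray(3.0), init_params)
    print(value)  # 5.0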
@@ -560,6 +826,9 @@ class Logic:
 
  def geometric(self, id, init_params):
  raise NotImplementedError
+
+ def binomial(self, id, init_params):
+ raise NotImplementedError
 
 
  class ExactLogic(Logic):
@@ -600,11 +869,11 @@ class ExactLogic(Logic):
  return self.exact_binary_function(jnp.logical_xor)
 
  @staticmethod
- def exact_binary_implies(x, y, params):
+ def _jax_wrapped_calc_implies_exact(x, y, params):
  return jnp.logical_or(jnp.logical_not(x), y), params
 
  def implies(self, id, init_params):
- return self.exact_binary_implies
+ return self._jax_wrapped_calc_implies_exact
 
  def equiv(self, id, init_params):
  return self.exact_binary_function(jnp.equal)
@@ -641,6 +910,10 @@ class ExactLogic(Logic):
  # special functions
  # ===========================================================================
 
+ @staticmethod
+ def exact_binary_log(x, y, params):
+ return jnp.log(x) / jnp.log(y), params
+
  def sgn(self, id, init_params):
  return self.exact_unary_function(jnp.sign)
 
@@ -677,54 +950,62 @@ class ExactLogic(Logic):
  # ===========================================================================
 
  @staticmethod
- def exact_if_then_else(c, a, b, params):
+ def _jax_wrapped_calc_if_then_else_exact(c, a, b, params):
  return jnp.where(c > 0.5, a, b), params
 
  def control_if(self, id, init_params):
- return self.exact_if_then_else
+ return self._jax_wrapped_calc_if_then_else_exact
 
  @staticmethod
- def exact_switch(pred, cases, params):
+ def _jax_wrapped_calc_switch_exact(pred, cases, params):
  pred = pred[jnp.newaxis, ...]
  sample = jnp.take_along_axis(cases, pred, axis=0)
  assert sample.shape[0] == 1
  return sample[0, ...], params
 
  def control_switch(self, id, init_params):
- return self.exact_switch
+ return self._jax_wrapped_calc_switch_exact
 
  # ===========================================================================
  # random variables
  # ===========================================================================
 
  @staticmethod
- def exact_discrete(key, prob, params):
- return random.categorical(key=key, logits=jnp.log(prob), axis=-1), params
+ def _jax_wrapped_calc_discrete_exact(key, prob, params):
+ sample = random.categorical(key=key, logits=jnp.log(prob), axis=-1)
+ return sample, params
 
  def discrete(self, id, init_params):
- return self.exact_discrete
+ return self._jax_wrapped_calc_discrete_exact
 
  @staticmethod
- def exact_bernoulli(key, prob, params):
+ def _jax_wrapped_calc_bernoulli_exact(key, prob, params):
  return random.bernoulli(key, prob), params
 
  def bernoulli(self, id, init_params):
- return self.exact_bernoulli
-
- @staticmethod
- def exact_poisson(key, rate, params):
- return random.poisson(key=key, lam=rate), params
+ return self._jax_wrapped_calc_bernoulli_exact
 
  def poisson(self, id, init_params):
- return self.exact_poisson
-
- @staticmethod
- def exact_geometric(key, prob, params):
- return random.geometric(key=key, p=prob), params
+ def _jax_wrapped_calc_poisson_exact(key, rate, params):
+ sample = random.poisson(key=key, lam=rate, dtype=self.INT)
+ return sample, params
+ return _jax_wrapped_calc_poisson_exact
 
  def geometric(self, id, init_params):
- return self.exact_geometric
-
+ def _jax_wrapped_calc_geometric_exact(key, prob, params):
+ sample = random.geometric(key=key, p=prob, dtype=self.INT)
+ return sample, params
+ return _jax_wrapped_calc_geometric_exact
+
+ def binomial(self, id, init_params):
+ def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
+ trials = jnp.asarray(trials, dtype=self.REAL)
+ prob = jnp.asarray(prob, dtype=self.REAL)
+ sample = random.binomial(key=key, n=trials, p=prob, dtype=self.REAL)
+ sample = jnp.asarray(sample, dtype=self.INT)
+ return sample, params
+ return _jax_wrapped_calc_binomial_exact
+
 
  class FuzzyLogic(Logic):
  '''A class representing fuzzy logic in JAX.'''
@@ -732,7 +1013,7 @@ class FuzzyLogic(Logic):
  def __init__(self, tnorm: TNorm=ProductTNorm(),
  complement: Complement=StandardComplement(),
  comparison: Comparison=SigmoidComparison(),
- sampling: RandomSampling=GumbelSoftmax(),
+ sampling: RandomSampling=SoftRandomSampling(),
  rounding: Rounding=SoftRounding(),
  control: ControlFlow=SoftControlFlow(),
  eps: float=1e-15,
@@ -759,14 +1040,14 @@ class FuzzyLogic(Logic):
 
  def __str__(self) -> str:
  return (f'model relaxation:\n'
- f' tnorm ={str(self.tnorm)}\n'
- f' complement ={str(self.complement)}\n'
- f' comparison ={str(self.comparison)}\n'
- f' sampling ={str(self.sampling)}\n'
- f' rounding ={str(self.rounding)}\n'
- f' control ={str(self.control)}\n'
- f' underflow_tol ={self.eps}\n'
- f' use_64_bit ={self.use64bit}')
+ f' tnorm ={str(self.tnorm)}\n'
+ f' complement ={str(self.complement)}\n'
+ f' comparison ={str(self.comparison)}\n'
+ f' sampling ={str(self.sampling)}\n'
+ f' rounding ={str(self.rounding)}\n'
+ f' control ={str(self.control)}\n'
+ f' underflow_tol={self.eps}\n'
+ f' use_64_bit ={self.use64bit}\n')
 
  def summarize_hyperparameters(self) -> None:
  print(self.__str__())
@@ -950,6 +1231,9 @@ class FuzzyLogic(Logic):
 
  def geometric(self, id, init_params):
  return self.sampling.geometric(id, init_params, self)
+
+ def binomial(self, id, init_params):
+ return self.sampling.binomial(id, init_params, self)
 
 
  # ===========================================================================
@@ -984,8 +1268,8 @@ def _test_logical():
  pred, w = _if(cond, +1, -1, w)
  return pred
 
- x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5]).astype(float)
- x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6]).astype(float)
+ x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5], dtype=float)
+ x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6], dtype=float)
  print(test_logic(x1, x2, init_params))
 
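
Note: as of 2.1, FuzzyLogic defaults to SoftRandomSampling, so Bernoulli draws become differentiable through the uniform reparameterization shown above. A hedged end-to-end sketch based only on the signatures visible in this diff (the import path and constructor defaults are assumptions):

    import jax
    import jax.random as random
    from pyRDDLGym_jax.core.logic import FuzzyLogic, SoftRandomSampling  # assumed module path

    logic = FuzzyLogic(sampling=SoftRandomSampling())
    init_params = {}
    bernoulli_fn = logic.bernoulli('b1', init_params)  # relaxed sampler (key, prob, params)
    key = random.PRNGKey(0)

    def sample_bernoulli(prob):
        sample, _ = bernoulli_fn(key, prob, init_params)
        return sample

    print(sample_bernoulli(0.7))            # relaxed sample in [0, 1]
    print(jax.grad(sample_bernoulli)(0.7))  # nonzero gradient w.r.t. prob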