pyRDDLGym-jax 2.0__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +85 -190
- pyRDDLGym_jax/core/logic.py +313 -56
- pyRDDLGym_jax/core/planner.py +121 -130
- pyRDDLGym_jax/core/visualization.py +7 -8
- pyRDDLGym_jax/examples/run_tune.py +10 -6
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/METADATA +22 -12
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/RECORD +12 -12
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/WHEEL +1 -1
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/entry_points.txt +0 -0
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.1.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/logic.py
CHANGED
|
@@ -24,12 +24,17 @@
|
|
|
24
24
|
# Gumbel-Softmax." In International Conference on Learning Representations (ICLR 2017).
|
|
25
25
|
# OpenReview.net, 2017.
|
|
26
26
|
#
|
|
27
|
+
# [6] Vafaii, H., Galor, D., & Yates, J. (2025). Poisson Variational Autoencoder.
|
|
28
|
+
# Advances in Neural Information Processing Systems, 37, 44871-44906.
|
|
29
|
+
#
|
|
27
30
|
# ***********************************************************************
|
|
28
31
|
|
|
32
|
+
from typing import Callable, Dict, Union
|
|
29
33
|
|
|
30
34
|
import jax
|
|
31
35
|
import jax.numpy as jnp
|
|
32
36
|
import jax.random as random
|
|
37
|
+
import jax.scipy as scipy
|
|
33
38
|
|
|
34
39
|
|
|
35
40
|
def enumerate_literals(shape, axis, dtype=jnp.int32):
|
|
@@ -105,7 +110,7 @@ class SigmoidComparison(Comparison):
|
|
|
105
110
|
id_ = str(id)
|
|
106
111
|
init_params[id_] = self.weight
|
|
107
112
|
def _jax_wrapped_calc_argmax_approx(x, axis, params):
|
|
108
|
-
literals = enumerate_literals(
|
|
113
|
+
literals = enumerate_literals(jnp.shape(x), axis=axis)
|
|
109
114
|
softmax = jax.nn.softmax(params[id_] * x, axis=axis)
|
|
110
115
|
sample = jnp.sum(literals * softmax, axis=axis)
|
|
111
116
|
return sample, params
|
|
@@ -323,62 +328,192 @@ class YagerTNorm(TNorm):
|
|
|
323
328
|
# ===========================================================================
|
|
324
329
|
|
|
325
330
|
class RandomSampling:
|
|
326
|
-
'''
|
|
327
|
-
random variables are sampled.'''
|
|
331
|
+
'''Describes how non-reparameterizable random variables are sampled.'''
|
|
328
332
|
|
|
329
333
|
def discrete(self, id, init_params, logic):
|
|
330
334
|
raise NotImplementedError
|
|
331
335
|
|
|
332
|
-
def bernoulli(self, id, init_params, logic):
|
|
333
|
-
discrete_approx = self.discrete(id, init_params, logic)
|
|
334
|
-
def _jax_wrapped_calc_bernoulli_approx(key, prob, params):
|
|
335
|
-
prob = jnp.stack([1.0 - prob, prob], axis=-1)
|
|
336
|
-
return discrete_approx(key, prob, params)
|
|
337
|
-
return _jax_wrapped_calc_bernoulli_approx
|
|
338
|
-
|
|
339
|
-
@staticmethod
|
|
340
|
-
def _jax_wrapped_calc_poisson_exact(key, rate, params):
|
|
341
|
-
sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
|
|
342
|
-
return sample, params
|
|
343
|
-
|
|
344
336
|
def poisson(self, id, init_params, logic):
|
|
345
|
-
|
|
337
|
+
raise NotImplementedError
|
|
338
|
+
|
|
339
|
+
def binomial(self, id, init_params, logic):
|
|
340
|
+
raise NotImplementedError
|
|
346
341
|
|
|
347
342
|
def geometric(self, id, init_params, logic):
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
return
|
|
343
|
+
raise NotImplementedError
|
|
344
|
+
|
|
345
|
+
def bernoulli(self, id, init_params, logic):
|
|
346
|
+
raise NotImplementedError
|
|
347
|
+
|
|
348
|
+
def __str__(self) -> str:
|
|
349
|
+
return 'RandomSampling'
|
|
355
350
|
|
|
356
351
|
|
|
357
|
-
class
|
|
352
|
+
class SoftRandomSampling(RandomSampling):
|
|
358
353
|
'''Random sampling of discrete variables using Gumbel-softmax trick.'''
|
|
359
354
|
|
|
355
|
+
def __init__(self, poisson_max_bins: int=100,
|
|
356
|
+
poisson_min_cdf: float=0.999,
|
|
357
|
+
poisson_exp_sampling: bool=True,
|
|
358
|
+
binomial_max_bins: int=100,
|
|
359
|
+
bernoulli_gumbel_softmax: bool=False) -> None:
|
|
360
|
+
'''Creates a new instance of soft random sampling.
|
|
361
|
+
|
|
362
|
+
:param poisson_max_bins: maximum bins to use for Poisson distribution relaxation
|
|
363
|
+
:param poisson_min_cdf: minimum cdf value of Poisson within truncated region
|
|
364
|
+
in order to use Poisson relaxation
|
|
365
|
+
:param poisson_exp_sampling: whether to use Poisson process sampling method
|
|
366
|
+
instead of truncated Gumbel-Softmax
|
|
367
|
+
:param binomial_max_bins: maximum bins to use for Binomial distribution relaxation
|
|
368
|
+
:param bernoulli_gumbel_softmax: whether to use Gumbel-Softmax to approximate
|
|
369
|
+
Bernoulli samples, or the standard uniform reparameterization instead
|
|
370
|
+
'''
|
|
371
|
+
self.poisson_bins = poisson_max_bins
|
|
372
|
+
self.poisson_min_cdf = poisson_min_cdf
|
|
373
|
+
self.poisson_exp_method = poisson_exp_sampling
|
|
374
|
+
self.binomial_bins = binomial_max_bins
|
|
375
|
+
self.bernoulli_gumbel_softmax = bernoulli_gumbel_softmax
|
|
376
|
+
|
|
360
377
|
# https://arxiv.org/pdf/1611.01144
|
|
361
378
|
def discrete(self, id, init_params, logic):
|
|
362
379
|
argmax_approx = logic.argmax(id, init_params)
|
|
363
380
|
def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, params):
|
|
364
|
-
Gumbel01 = random.gumbel(key=key, shape=
|
|
381
|
+
Gumbel01 = random.gumbel(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
|
|
365
382
|
sample = Gumbel01 + jnp.log(prob + logic.eps)
|
|
366
383
|
return argmax_approx(sample, axis=-1, params=params)
|
|
367
384
|
return _jax_wrapped_calc_discrete_gumbel_softmax
|
|
368
385
|
|
|
386
|
+
def _poisson_gumbel_softmax(self, id, init_params, logic):
|
|
387
|
+
argmax_approx = logic.argmax(id, init_params)
|
|
388
|
+
def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
|
|
389
|
+
ks = jnp.arange(0, self.poisson_bins)
|
|
390
|
+
ks = ks[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
|
|
391
|
+
rate = rate[..., jnp.newaxis]
|
|
392
|
+
log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
|
|
393
|
+
Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
|
|
394
|
+
sample = Gumbel01 + log_prob
|
|
395
|
+
return argmax_approx(sample, axis=-1, params=params)
|
|
396
|
+
return _jax_wrapped_calc_poisson_gumbel_softmax
|
|
397
|
+
|
|
398
|
+
# https://arxiv.org/abs/2405.14473
|
|
399
|
+
def _poisson_exponential(self, id, init_params, logic):
|
|
400
|
+
less_approx = logic.less(id, init_params)
|
|
401
|
+
def _jax_wrapped_calc_poisson_exponential(key, rate, params):
|
|
402
|
+
Exp1 = random.exponential(
|
|
403
|
+
key=key,
|
|
404
|
+
shape=(self.poisson_bins,) + jnp.shape(rate),
|
|
405
|
+
dtype=logic.REAL
|
|
406
|
+
)
|
|
407
|
+
delta_t = Exp1 / rate[jnp.newaxis, ...]
|
|
408
|
+
times = jnp.cumsum(delta_t, axis=0)
|
|
409
|
+
indicator, params = less_approx(times, 1.0, params)
|
|
410
|
+
sample = jnp.sum(indicator, axis=0)
|
|
411
|
+
return sample, params
|
|
412
|
+
return _jax_wrapped_calc_poisson_exponential
|
|
413
|
+
|
|
414
|
+
def poisson(self, id, init_params, logic):
|
|
415
|
+
def _jax_wrapped_calc_poisson_exact(key, rate, params):
|
|
416
|
+
sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
|
|
417
|
+
sample = jnp.asarray(sample, dtype=logic.REAL)
|
|
418
|
+
return sample, params
|
|
419
|
+
|
|
420
|
+
if self.poisson_exp_method:
|
|
421
|
+
_jax_wrapped_calc_poisson_diff = self._poisson_exponential(
|
|
422
|
+
id, init_params, logic)
|
|
423
|
+
else:
|
|
424
|
+
_jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
|
|
425
|
+
id, init_params, logic)
|
|
426
|
+
|
|
427
|
+
def _jax_wrapped_calc_poisson_approx(key, rate, params):
|
|
428
|
+
|
|
429
|
+
# determine if error of truncation at rate is acceptable
|
|
430
|
+
if self.poisson_bins > 0:
|
|
431
|
+
cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
|
|
432
|
+
approx_cond = jax.lax.stop_gradient(
|
|
433
|
+
jnp.min(cuml_prob) > self.poisson_min_cdf)
|
|
434
|
+
else:
|
|
435
|
+
approx_cond = False
|
|
436
|
+
|
|
437
|
+
# for acceptable truncation use the approximation, use exact otherwise
|
|
438
|
+
return jax.lax.cond(
|
|
439
|
+
approx_cond,
|
|
440
|
+
_jax_wrapped_calc_poisson_diff,
|
|
441
|
+
_jax_wrapped_calc_poisson_exact,
|
|
442
|
+
key, rate, params
|
|
443
|
+
)
|
|
444
|
+
return _jax_wrapped_calc_poisson_approx
|
|
445
|
+
|
|
446
|
+
def binomial(self, id, init_params, logic):
|
|
447
|
+
def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
|
|
448
|
+
trials = jnp.asarray(trials, dtype=logic.REAL)
|
|
449
|
+
prob = jnp.asarray(prob, dtype=logic.REAL)
|
|
450
|
+
sample = random.binomial(key=key, n=trials, p=prob, dtype=logic.REAL)
|
|
451
|
+
return sample, params
|
|
452
|
+
|
|
453
|
+
# Binomial(n, p) = sum_{i = 1 ... n} Bernoulli(p)
|
|
454
|
+
bernoulli_approx = self.bernoulli(id, init_params, logic)
|
|
455
|
+
def _jax_wrapped_calc_binomial_sum(key, trials, prob, params):
|
|
456
|
+
prob_full = jnp.broadcast_to(
|
|
457
|
+
prob[..., jnp.newaxis], shape=jnp.shape(prob) + (self.binomial_bins,))
|
|
458
|
+
sample_bern, params = bernoulli_approx(key, prob_full, params)
|
|
459
|
+
indices = jnp.arange(self.binomial_bins)[
|
|
460
|
+
(jnp.newaxis,) * jnp.ndim(prob) + (...,)]
|
|
461
|
+
mask = indices < trials[..., jnp.newaxis]
|
|
462
|
+
sample = jnp.sum(sample_bern * mask, axis=-1)
|
|
463
|
+
return sample, params
|
|
464
|
+
|
|
465
|
+
# for trials not too large use the Bernoulli relaxation, use exact otherwise
|
|
466
|
+
def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
|
|
467
|
+
return jax.lax.cond(
|
|
468
|
+
jax.lax.stop_gradient(jnp.max(trials) < self.binomial_bins),
|
|
469
|
+
_jax_wrapped_calc_binomial_sum,
|
|
470
|
+
_jax_wrapped_calc_binomial_exact,
|
|
471
|
+
key, trials, prob, params
|
|
472
|
+
)
|
|
473
|
+
return _jax_wrapped_calc_binomial_approx
|
|
474
|
+
|
|
475
|
+
def geometric(self, id, init_params, logic):
|
|
476
|
+
approx_floor = logic.floor(id, init_params)
|
|
477
|
+
def _jax_wrapped_calc_geometric_approx(key, prob, params):
|
|
478
|
+
U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
|
|
479
|
+
floor, params = approx_floor(jnp.log1p(-U) / jnp.log1p(-prob), params)
|
|
480
|
+
sample = floor + 1
|
|
481
|
+
return sample, params
|
|
482
|
+
return _jax_wrapped_calc_geometric_approx
|
|
483
|
+
|
|
484
|
+
def _bernoulli_uniform(self, id, init_params, logic):
|
|
485
|
+
less_approx = logic.less(id, init_params)
|
|
486
|
+
def _jax_wrapped_calc_bernoulli_uniform(key, prob, params):
|
|
487
|
+
U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
|
|
488
|
+
return less_approx(U, prob, params)
|
|
489
|
+
return _jax_wrapped_calc_bernoulli_uniform
|
|
490
|
+
|
|
491
|
+
def _bernoulli_gumbel_softmax(self, id, init_params, logic):
|
|
492
|
+
discrete_approx = self.discrete(id, init_params, logic)
|
|
493
|
+
def _jax_wrapped_calc_bernoulli_gumbel_softmax(key, prob, params):
|
|
494
|
+
prob = jnp.stack([1.0 - prob, prob], axis=-1)
|
|
495
|
+
return discrete_approx(key, prob, params)
|
|
496
|
+
return _jax_wrapped_calc_bernoulli_gumbel_softmax
|
|
497
|
+
|
|
498
|
+
def bernoulli(self, id, init_params, logic):
|
|
499
|
+
if self.bernoulli_gumbel_softmax:
|
|
500
|
+
return self._bernoulli_gumbel_softmax(id, init_params, logic)
|
|
501
|
+
else:
|
|
502
|
+
return self._bernoulli_uniform(id, init_params, logic)
|
|
503
|
+
|
|
369
504
|
def __str__(self) -> str:
|
|
370
|
-
return '
|
|
505
|
+
return 'SoftRandomSampling'
|
|
371
506
|
|
|
372
507
|
|
|
373
508
|
class Determinization(RandomSampling):
|
|
374
509
|
'''Random sampling of variables using their deterministic mean estimate.'''
|
|
375
|
-
|
|
510
|
+
|
|
376
511
|
@staticmethod
|
|
377
512
|
def _jax_wrapped_calc_discrete_determinized(key, prob, params):
|
|
378
|
-
literals = enumerate_literals(
|
|
513
|
+
literals = enumerate_literals(jnp.shape(prob), axis=-1)
|
|
379
514
|
sample = jnp.sum(literals * prob, axis=-1)
|
|
380
515
|
return sample, params
|
|
381
|
-
|
|
516
|
+
|
|
382
517
|
def discrete(self, id, init_params, logic):
|
|
383
518
|
return self._jax_wrapped_calc_discrete_determinized
|
|
384
519
|
|
|
@@ -389,6 +524,14 @@ class Determinization(RandomSampling):
|
|
|
389
524
|
def poisson(self, id, init_params, logic):
|
|
390
525
|
return self._jax_wrapped_calc_poisson_determinized
|
|
391
526
|
|
|
527
|
+
@staticmethod
|
|
528
|
+
def _jax_wrapped_calc_binomial_determinized(key, trials, prob, params):
|
|
529
|
+
sample = trials * prob
|
|
530
|
+
return sample, params
|
|
531
|
+
|
|
532
|
+
def binomial(self, id, init_params, logic):
|
|
533
|
+
return self._jax_wrapped_calc_binomial_determinized
|
|
534
|
+
|
|
392
535
|
@staticmethod
|
|
393
536
|
def _jax_wrapped_calc_geometric_determinized(key, prob, params):
|
|
394
537
|
sample = 1.0 / prob
|
|
@@ -397,6 +540,14 @@ class Determinization(RandomSampling):
|
|
|
397
540
|
def geometric(self, id, init_params, logic):
|
|
398
541
|
return self._jax_wrapped_calc_geometric_determinized
|
|
399
542
|
|
|
543
|
+
@staticmethod
|
|
544
|
+
def _jax_wrapped_calc_bernoulli_determinized(key, prob, params):
|
|
545
|
+
sample = prob
|
|
546
|
+
return sample, params
|
|
547
|
+
|
|
548
|
+
def bernoulli(self, id, init_params, logic):
|
|
549
|
+
return self._jax_wrapped_calc_bernoulli_determinized
|
|
550
|
+
|
|
400
551
|
def __str__(self) -> str:
|
|
401
552
|
return 'Deterministic'
|
|
402
553
|
|
|
@@ -435,8 +586,8 @@ class SoftControlFlow(ControlFlow):
|
|
|
435
586
|
id_ = str(id)
|
|
436
587
|
init_params[id_] = self.weight
|
|
437
588
|
def _jax_wrapped_calc_switch_soft(pred, cases, params):
|
|
438
|
-
literals = enumerate_literals(
|
|
439
|
-
pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=
|
|
589
|
+
literals = enumerate_literals(jnp.shape(cases), axis=0)
|
|
590
|
+
pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=jnp.shape(cases))
|
|
440
591
|
proximity = -jnp.square(pred - literals)
|
|
441
592
|
softcase = jax.nn.softmax(params[id_] * proximity, axis=0)
|
|
442
593
|
sample = jnp.sum(cases * softcase, axis=0)
|
|
@@ -477,6 +628,94 @@ class Logic:
|
|
|
477
628
|
self.INT = jnp.int32
|
|
478
629
|
jax.config.update('jax_enable_x64', False)
|
|
479
630
|
|
|
631
|
+
@staticmethod
|
|
632
|
+
def wrap_logic(func):
|
|
633
|
+
def exact_func(id, init_params):
|
|
634
|
+
return func
|
|
635
|
+
return exact_func
|
|
636
|
+
|
|
637
|
+
def get_operator_dicts(self) -> Dict[str, Union[Callable, Dict[str, Callable]]]:
|
|
638
|
+
'''Returns a dictionary of all operators in the current logic.'''
|
|
639
|
+
return {
|
|
640
|
+
'negative': self.wrap_logic(ExactLogic.exact_unary_function(jnp.negative)),
|
|
641
|
+
'arithmetic': {
|
|
642
|
+
'+': self.wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
|
|
643
|
+
'-': self.wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
|
|
644
|
+
'*': self.wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
|
|
645
|
+
'/': self.wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
|
|
646
|
+
},
|
|
647
|
+
'relational': {
|
|
648
|
+
'>=': self.greater_equal,
|
|
649
|
+
'<=': self.less_equal,
|
|
650
|
+
'<': self.less,
|
|
651
|
+
'>': self.greater,
|
|
652
|
+
'==': self.equal,
|
|
653
|
+
'~=': self.not_equal
|
|
654
|
+
},
|
|
655
|
+
'logical_not': self.logical_not,
|
|
656
|
+
'logical': {
|
|
657
|
+
'^': self.logical_and,
|
|
658
|
+
'&': self.logical_and,
|
|
659
|
+
'|': self.logical_or,
|
|
660
|
+
'~': self.xor,
|
|
661
|
+
'=>': self.implies,
|
|
662
|
+
'<=>': self.equiv
|
|
663
|
+
},
|
|
664
|
+
'aggregation': {
|
|
665
|
+
'sum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
|
|
666
|
+
'avg': self.wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
|
|
667
|
+
'prod': self.wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
|
|
668
|
+
'minimum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
|
|
669
|
+
'maximum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
|
|
670
|
+
'forall': self.forall,
|
|
671
|
+
'exists': self.exists,
|
|
672
|
+
'argmin': self.argmin,
|
|
673
|
+
'argmax': self.argmax
|
|
674
|
+
},
|
|
675
|
+
'unary': {
|
|
676
|
+
'abs': self.wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
|
|
677
|
+
'sgn': self.sgn,
|
|
678
|
+
'round': self.round,
|
|
679
|
+
'floor': self.floor,
|
|
680
|
+
'ceil': self.ceil,
|
|
681
|
+
'cos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
|
|
682
|
+
'sin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
|
|
683
|
+
'tan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
|
|
684
|
+
'acos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
|
|
685
|
+
'asin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
|
|
686
|
+
'atan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
|
|
687
|
+
'cosh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
|
|
688
|
+
'sinh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
|
|
689
|
+
'tanh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
|
|
690
|
+
'exp': self.wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
|
|
691
|
+
'ln': self.wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
|
|
692
|
+
'sqrt': self.sqrt,
|
|
693
|
+
'lngamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
|
|
694
|
+
'gamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
|
|
695
|
+
},
|
|
696
|
+
'binary': {
|
|
697
|
+
'div': self.div,
|
|
698
|
+
'mod': self.mod,
|
|
699
|
+
'fmod': self.mod,
|
|
700
|
+
'min': self.wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
|
|
701
|
+
'max': self.wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
|
|
702
|
+
'pow': self.wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
|
|
703
|
+
'log': self.wrap_logic(ExactLogic.exact_binary_log),
|
|
704
|
+
'hypot': self.wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
|
|
705
|
+
},
|
|
706
|
+
'control': {
|
|
707
|
+
'if': self.control_if,
|
|
708
|
+
'switch': self.control_switch
|
|
709
|
+
},
|
|
710
|
+
'sampling': {
|
|
711
|
+
'Bernoulli': self.bernoulli,
|
|
712
|
+
'Discrete': self.discrete,
|
|
713
|
+
'Poisson': self.poisson,
|
|
714
|
+
'Geometric': self.geometric,
|
|
715
|
+
'Binomial': self.binomial
|
|
716
|
+
}
|
|
717
|
+
}
|
|
718
|
+
|
|
480
719
|
# ===========================================================================
|
|
481
720
|
# logical operators
|
|
482
721
|
# ===========================================================================
|
|
@@ -587,6 +826,9 @@ class Logic:
|
|
|
587
826
|
|
|
588
827
|
def geometric(self, id, init_params):
|
|
589
828
|
raise NotImplementedError
|
|
829
|
+
|
|
830
|
+
def binomial(self, id, init_params):
|
|
831
|
+
raise NotImplementedError
|
|
590
832
|
|
|
591
833
|
|
|
592
834
|
class ExactLogic(Logic):
|
|
@@ -627,11 +869,11 @@ class ExactLogic(Logic):
|
|
|
627
869
|
return self.exact_binary_function(jnp.logical_xor)
|
|
628
870
|
|
|
629
871
|
@staticmethod
|
|
630
|
-
def
|
|
872
|
+
def _jax_wrapped_calc_implies_exact(x, y, params):
|
|
631
873
|
return jnp.logical_or(jnp.logical_not(x), y), params
|
|
632
874
|
|
|
633
875
|
def implies(self, id, init_params):
|
|
634
|
-
return self.
|
|
876
|
+
return self._jax_wrapped_calc_implies_exact
|
|
635
877
|
|
|
636
878
|
def equiv(self, id, init_params):
|
|
637
879
|
return self.exact_binary_function(jnp.equal)
|
|
@@ -668,6 +910,10 @@ class ExactLogic(Logic):
|
|
|
668
910
|
# special functions
|
|
669
911
|
# ===========================================================================
|
|
670
912
|
|
|
913
|
+
@staticmethod
|
|
914
|
+
def exact_binary_log(x, y, params):
|
|
915
|
+
return jnp.log(x) / jnp.log(y), params
|
|
916
|
+
|
|
671
917
|
def sgn(self, id, init_params):
|
|
672
918
|
return self.exact_unary_function(jnp.sign)
|
|
673
919
|
|
|
@@ -704,54 +950,62 @@ class ExactLogic(Logic):
|
|
|
704
950
|
# ===========================================================================
|
|
705
951
|
|
|
706
952
|
@staticmethod
|
|
707
|
-
def
|
|
953
|
+
def _jax_wrapped_calc_if_then_else_exact(c, a, b, params):
|
|
708
954
|
return jnp.where(c > 0.5, a, b), params
|
|
709
955
|
|
|
710
956
|
def control_if(self, id, init_params):
|
|
711
|
-
return self.
|
|
957
|
+
return self._jax_wrapped_calc_if_then_else_exact
|
|
712
958
|
|
|
713
959
|
@staticmethod
|
|
714
|
-
def
|
|
960
|
+
def _jax_wrapped_calc_switch_exact(pred, cases, params):
|
|
715
961
|
pred = pred[jnp.newaxis, ...]
|
|
716
962
|
sample = jnp.take_along_axis(cases, pred, axis=0)
|
|
717
963
|
assert sample.shape[0] == 1
|
|
718
964
|
return sample[0, ...], params
|
|
719
965
|
|
|
720
966
|
def control_switch(self, id, init_params):
|
|
721
|
-
return self.
|
|
967
|
+
return self._jax_wrapped_calc_switch_exact
|
|
722
968
|
|
|
723
969
|
# ===========================================================================
|
|
724
970
|
# random variables
|
|
725
971
|
# ===========================================================================
|
|
726
972
|
|
|
727
973
|
@staticmethod
|
|
728
|
-
def
|
|
729
|
-
|
|
974
|
+
def _jax_wrapped_calc_discrete_exact(key, prob, params):
|
|
975
|
+
sample = random.categorical(key=key, logits=jnp.log(prob), axis=-1)
|
|
976
|
+
return sample, params
|
|
730
977
|
|
|
731
978
|
def discrete(self, id, init_params):
|
|
732
|
-
return self.
|
|
979
|
+
return self._jax_wrapped_calc_discrete_exact
|
|
733
980
|
|
|
734
981
|
@staticmethod
|
|
735
|
-
def
|
|
982
|
+
def _jax_wrapped_calc_bernoulli_exact(key, prob, params):
|
|
736
983
|
return random.bernoulli(key, prob), params
|
|
737
984
|
|
|
738
985
|
def bernoulli(self, id, init_params):
|
|
739
|
-
return self.
|
|
740
|
-
|
|
741
|
-
@staticmethod
|
|
742
|
-
def exact_poisson(key, rate, params):
|
|
743
|
-
return random.poisson(key=key, lam=rate), params
|
|
986
|
+
return self._jax_wrapped_calc_bernoulli_exact
|
|
744
987
|
|
|
745
988
|
def poisson(self, id, init_params):
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
return random.geometric(key=key, p=prob), params
|
|
989
|
+
def _jax_wrapped_calc_poisson_exact(key, rate, params):
|
|
990
|
+
sample = random.poisson(key=key, lam=rate, dtype=self.INT)
|
|
991
|
+
return sample, params
|
|
992
|
+
return _jax_wrapped_calc_poisson_exact
|
|
751
993
|
|
|
752
994
|
def geometric(self, id, init_params):
|
|
753
|
-
|
|
754
|
-
|
|
995
|
+
def _jax_wrapped_calc_geometric_exact(key, prob, params):
|
|
996
|
+
sample = random.geometric(key=key, p=prob, dtype=self.INT)
|
|
997
|
+
return sample, params
|
|
998
|
+
return _jax_wrapped_calc_geometric_exact
|
|
999
|
+
|
|
1000
|
+
def binomial(self, id, init_params):
|
|
1001
|
+
def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
|
|
1002
|
+
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1003
|
+
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1004
|
+
sample = random.binomial(key=key, n=trials, p=prob, dtype=self.REAL)
|
|
1005
|
+
sample = jnp.asarray(sample, dtype=self.INT)
|
|
1006
|
+
return sample, params
|
|
1007
|
+
return _jax_wrapped_calc_binomial_exact
|
|
1008
|
+
|
|
755
1009
|
|
|
756
1010
|
class FuzzyLogic(Logic):
|
|
757
1011
|
'''A class representing fuzzy logic in JAX.'''
|
|
@@ -759,7 +1013,7 @@ class FuzzyLogic(Logic):
|
|
|
759
1013
|
def __init__(self, tnorm: TNorm=ProductTNorm(),
|
|
760
1014
|
complement: Complement=StandardComplement(),
|
|
761
1015
|
comparison: Comparison=SigmoidComparison(),
|
|
762
|
-
sampling: RandomSampling=
|
|
1016
|
+
sampling: RandomSampling=SoftRandomSampling(),
|
|
763
1017
|
rounding: Rounding=SoftRounding(),
|
|
764
1018
|
control: ControlFlow=SoftControlFlow(),
|
|
765
1019
|
eps: float=1e-15,
|
|
@@ -977,6 +1231,9 @@ class FuzzyLogic(Logic):
|
|
|
977
1231
|
|
|
978
1232
|
def geometric(self, id, init_params):
|
|
979
1233
|
return self.sampling.geometric(id, init_params, self)
|
|
1234
|
+
|
|
1235
|
+
def binomial(self, id, init_params):
|
|
1236
|
+
return self.sampling.binomial(id, init_params, self)
|
|
980
1237
|
|
|
981
1238
|
|
|
982
1239
|
# ===========================================================================
|
|
@@ -1011,8 +1268,8 @@ def _test_logical():
|
|
|
1011
1268
|
pred, w = _if(cond, +1, -1, w)
|
|
1012
1269
|
return pred
|
|
1013
1270
|
|
|
1014
|
-
x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5]
|
|
1015
|
-
x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6]
|
|
1271
|
+
x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5], dtype=float)
|
|
1272
|
+
x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6], dtype=float)
|
|
1016
1273
|
print(test_logic(x1, x2, init_params))
|
|
1017
1274
|
|
|
1018
1275
|
|