pyRDDLGym-jax 1.3__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +101 -191
- pyRDDLGym_jax/core/logic.py +349 -65
- pyRDDLGym_jax/core/planner.py +554 -208
- pyRDDLGym_jax/core/simulator.py +20 -0
- pyRDDLGym_jax/core/tuning.py +15 -0
- pyRDDLGym_jax/core/visualization.py +55 -8
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +1 -0
- pyRDDLGym_jax/examples/run_tune.py +10 -6
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/METADATA +22 -12
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/RECORD +24 -24
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/WHEEL +1 -1
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/entry_points.txt +0 -0
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/logic.py
CHANGED
|
@@ -1,8 +1,40 @@
|
|
|
1
|
-
|
|
1
|
+
# ***********************************************************************
|
|
2
|
+
# JAXPLAN
|
|
3
|
+
#
|
|
4
|
+
# Author: Michael Gimelfarb
|
|
5
|
+
#
|
|
6
|
+
# REFERENCES:
|
|
7
|
+
#
|
|
8
|
+
# [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
|
|
9
|
+
# Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
|
|
10
|
+
# Probabilistic Domains." Proceedings of the International Conference on Automated
|
|
11
|
+
# Planning and Scheduling. Vol. 34. 2024.
|
|
12
|
+
#
|
|
13
|
+
# [2] Petersen, Felix, Christian Borgelt, Hilde Kuehne, and Oliver Deussen. "Learning with
|
|
14
|
+
# algorithmic supervision via continuous relaxations." Advances in Neural Information
|
|
15
|
+
# Processing Systems 34 (2021): 16520-16531.
|
|
16
|
+
#
|
|
17
|
+
# [3] Agustsson, Eirikur, and Lucas Theis. "Universally quantized neural compression."
|
|
18
|
+
# Advances in neural information processing systems 33 (2020): 12367-12376.
|
|
19
|
+
#
|
|
20
|
+
# [4] Gupta, Madan M., and J11043360726 Qi. "Theory of T-norms and fuzzy inference
|
|
21
|
+
# methods." Fuzzy sets and systems 40, no. 3 (1991): 431-450.
|
|
22
|
+
#
|
|
23
|
+
# [5] Jang, Eric, Shixiang Gu, and Ben Poole. "Categorical Reparametrization with
|
|
24
|
+
# Gumble-Softmax." In International Conference on Learning Representations (ICLR 2017).
|
|
25
|
+
# OpenReview. net, 2017.
|
|
26
|
+
#
|
|
27
|
+
# [6] Vafaii, H., Galor, D., & Yates, J. (2025). Poisson Variational Autoencoder.
|
|
28
|
+
# Advances in Neural Information Processing Systems, 37, 44871-44906.
|
|
29
|
+
#
|
|
30
|
+
# ***********************************************************************
|
|
31
|
+
|
|
32
|
+
from typing import Callable, Dict, Union
|
|
2
33
|
|
|
3
34
|
import jax
|
|
4
35
|
import jax.numpy as jnp
|
|
5
36
|
import jax.random as random
|
|
37
|
+
import jax.scipy as scipy
|
|
6
38
|
|
|
7
39
|
|
|
8
40
|
def enumerate_literals(shape, axis, dtype=jnp.int32):
|
|
@@ -78,7 +110,7 @@ class SigmoidComparison(Comparison):
|
|
|
78
110
|
id_ = str(id)
|
|
79
111
|
init_params[id_] = self.weight
|
|
80
112
|
def _jax_wrapped_calc_argmax_approx(x, axis, params):
|
|
81
|
-
literals = enumerate_literals(
|
|
113
|
+
literals = enumerate_literals(jnp.shape(x), axis=axis)
|
|
82
114
|
softmax = jax.nn.softmax(params[id_] * x, axis=axis)
|
|
83
115
|
sample = jnp.sum(literals * softmax, axis=axis)
|
|
84
116
|
return sample, params
|
|
@@ -296,62 +328,192 @@ class YagerTNorm(TNorm):
|
|
|
296
328
|
# ===========================================================================
|
|
297
329
|
|
|
298
330
|
class RandomSampling:
|
|
299
|
-
'''
|
|
300
|
-
random variables are sampled.'''
|
|
331
|
+
'''Describes how non-reparameterizable random variables are sampled.'''
|
|
301
332
|
|
|
302
333
|
def discrete(self, id, init_params, logic):
|
|
303
334
|
raise NotImplementedError
|
|
304
335
|
|
|
305
|
-
def bernoulli(self, id, init_params, logic):
|
|
306
|
-
discrete_approx = self.discrete(id, init_params, logic)
|
|
307
|
-
def _jax_wrapped_calc_bernoulli_approx(key, prob, params):
|
|
308
|
-
prob = jnp.stack([1.0 - prob, prob], axis=-1)
|
|
309
|
-
return discrete_approx(key, prob, params)
|
|
310
|
-
return _jax_wrapped_calc_bernoulli_approx
|
|
311
|
-
|
|
312
|
-
@staticmethod
|
|
313
|
-
def _jax_wrapped_calc_poisson_exact(key, rate, params):
|
|
314
|
-
sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
|
|
315
|
-
return sample, params
|
|
316
|
-
|
|
317
336
|
def poisson(self, id, init_params, logic):
|
|
318
|
-
|
|
337
|
+
raise NotImplementedError
|
|
338
|
+
|
|
339
|
+
def binomial(self, id, init_params, logic):
|
|
340
|
+
raise NotImplementedError
|
|
319
341
|
|
|
320
342
|
def geometric(self, id, init_params, logic):
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
return
|
|
343
|
+
raise NotImplementedError
|
|
344
|
+
|
|
345
|
+
def bernoulli(self, id, init_params, logic):
|
|
346
|
+
raise NotImplementedError
|
|
347
|
+
|
|
348
|
+
def __str__(self) -> str:
|
|
349
|
+
return 'RandomSampling'
|
|
328
350
|
|
|
329
351
|
|
|
330
|
-
class
|
|
352
|
+
class SoftRandomSampling(RandomSampling):
|
|
331
353
|
'''Random sampling of discrete variables using Gumbel-softmax trick.'''
|
|
332
354
|
|
|
355
|
+
def __init__(self, poisson_max_bins: int=100,
|
|
356
|
+
poisson_min_cdf: float=0.999,
|
|
357
|
+
poisson_exp_sampling: bool=True,
|
|
358
|
+
binomial_max_bins: int=100,
|
|
359
|
+
bernoulli_gumbel_softmax: bool=False) -> None:
|
|
360
|
+
'''Creates a new instance of soft random sampling.
|
|
361
|
+
|
|
362
|
+
:param poisson_max_bins: maximum bins to use for Poisson distribution relaxation
|
|
363
|
+
:param poisson_min_cdf: minimum cdf value of Poisson within truncated region
|
|
364
|
+
in order to use Poisson relaxation
|
|
365
|
+
:param poisson_exp_sampling: whether to use Poisson process sampling method
|
|
366
|
+
instead of truncated Gumbel-Softmax
|
|
367
|
+
:param binomial_max_bins: maximum bins to use for Binomial distribution relaxation
|
|
368
|
+
:param bernoulli_gumbel_softmax: whether to use Gumbel-Softmax to approximate
|
|
369
|
+
Bernoulli samples, or the standard uniform reparameterization instead
|
|
370
|
+
'''
|
|
371
|
+
self.poisson_bins = poisson_max_bins
|
|
372
|
+
self.poisson_min_cdf = poisson_min_cdf
|
|
373
|
+
self.poisson_exp_method = poisson_exp_sampling
|
|
374
|
+
self.binomial_bins = binomial_max_bins
|
|
375
|
+
self.bernoulli_gumbel_softmax = bernoulli_gumbel_softmax
|
|
376
|
+
|
|
333
377
|
# https://arxiv.org/pdf/1611.01144
|
|
334
378
|
def discrete(self, id, init_params, logic):
|
|
335
379
|
argmax_approx = logic.argmax(id, init_params)
|
|
336
380
|
def _jax_wrapped_calc_discrete_gumbel_softmax(key, prob, params):
|
|
337
|
-
Gumbel01 = random.gumbel(key=key, shape=
|
|
381
|
+
Gumbel01 = random.gumbel(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
|
|
338
382
|
sample = Gumbel01 + jnp.log(prob + logic.eps)
|
|
339
383
|
return argmax_approx(sample, axis=-1, params=params)
|
|
340
384
|
return _jax_wrapped_calc_discrete_gumbel_softmax
|
|
341
385
|
|
|
386
|
+
def _poisson_gumbel_softmax(self, id, init_params, logic):
|
|
387
|
+
argmax_approx = logic.argmax(id, init_params)
|
|
388
|
+
def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
|
|
389
|
+
ks = jnp.arange(0, self.poisson_bins)
|
|
390
|
+
ks = ks[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
|
|
391
|
+
rate = rate[..., jnp.newaxis]
|
|
392
|
+
log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
|
|
393
|
+
Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
|
|
394
|
+
sample = Gumbel01 + log_prob
|
|
395
|
+
return argmax_approx(sample, axis=-1, params=params)
|
|
396
|
+
return _jax_wrapped_calc_poisson_gumbel_softmax
|
|
397
|
+
|
|
398
|
+
# https://arxiv.org/abs/2405.14473
|
|
399
|
+
def _poisson_exponential(self, id, init_params, logic):
|
|
400
|
+
less_approx = logic.less(id, init_params)
|
|
401
|
+
def _jax_wrapped_calc_poisson_exponential(key, rate, params):
|
|
402
|
+
Exp1 = random.exponential(
|
|
403
|
+
key=key,
|
|
404
|
+
shape=(self.poisson_bins,) + jnp.shape(rate),
|
|
405
|
+
dtype=logic.REAL
|
|
406
|
+
)
|
|
407
|
+
delta_t = Exp1 / rate[jnp.newaxis, ...]
|
|
408
|
+
times = jnp.cumsum(delta_t, axis=0)
|
|
409
|
+
indicator, params = less_approx(times, 1.0, params)
|
|
410
|
+
sample = jnp.sum(indicator, axis=0)
|
|
411
|
+
return sample, params
|
|
412
|
+
return _jax_wrapped_calc_poisson_exponential
|
|
413
|
+
|
|
414
|
+
def poisson(self, id, init_params, logic):
|
|
415
|
+
def _jax_wrapped_calc_poisson_exact(key, rate, params):
|
|
416
|
+
sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
|
|
417
|
+
sample = jnp.asarray(sample, dtype=logic.REAL)
|
|
418
|
+
return sample, params
|
|
419
|
+
|
|
420
|
+
if self.poisson_exp_method:
|
|
421
|
+
_jax_wrapped_calc_poisson_diff = self._poisson_exponential(
|
|
422
|
+
id, init_params, logic)
|
|
423
|
+
else:
|
|
424
|
+
_jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
|
|
425
|
+
id, init_params, logic)
|
|
426
|
+
|
|
427
|
+
def _jax_wrapped_calc_poisson_approx(key, rate, params):
|
|
428
|
+
|
|
429
|
+
# determine if error of truncation at rate is acceptable
|
|
430
|
+
if self.poisson_bins > 0:
|
|
431
|
+
cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
|
|
432
|
+
approx_cond = jax.lax.stop_gradient(
|
|
433
|
+
jnp.min(cuml_prob) > self.poisson_min_cdf)
|
|
434
|
+
else:
|
|
435
|
+
approx_cond = False
|
|
436
|
+
|
|
437
|
+
# for acceptable truncation use the approximation, use exact otherwise
|
|
438
|
+
return jax.lax.cond(
|
|
439
|
+
approx_cond,
|
|
440
|
+
_jax_wrapped_calc_poisson_diff,
|
|
441
|
+
_jax_wrapped_calc_poisson_exact,
|
|
442
|
+
key, rate, params
|
|
443
|
+
)
|
|
444
|
+
return _jax_wrapped_calc_poisson_approx
|
|
445
|
+
|
|
446
|
+
def binomial(self, id, init_params, logic):
|
|
447
|
+
def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
|
|
448
|
+
trials = jnp.asarray(trials, dtype=logic.REAL)
|
|
449
|
+
prob = jnp.asarray(prob, dtype=logic.REAL)
|
|
450
|
+
sample = random.binomial(key=key, n=trials, p=prob, dtype=logic.REAL)
|
|
451
|
+
return sample, params
|
|
452
|
+
|
|
453
|
+
# Binomial(n, p) = sum_{i = 1 ... n} Bernoulli(p)
|
|
454
|
+
bernoulli_approx = self.bernoulli(id, init_params, logic)
|
|
455
|
+
def _jax_wrapped_calc_binomial_sum(key, trials, prob, params):
|
|
456
|
+
prob_full = jnp.broadcast_to(
|
|
457
|
+
prob[..., jnp.newaxis], shape=jnp.shape(prob) + (self.binomial_bins,))
|
|
458
|
+
sample_bern, params = bernoulli_approx(key, prob_full, params)
|
|
459
|
+
indices = jnp.arange(self.binomial_bins)[
|
|
460
|
+
(jnp.newaxis,) * jnp.ndim(prob) + (...,)]
|
|
461
|
+
mask = indices < trials[..., jnp.newaxis]
|
|
462
|
+
sample = jnp.sum(sample_bern * mask, axis=-1)
|
|
463
|
+
return sample, params
|
|
464
|
+
|
|
465
|
+
# for trials not too large use the Bernoulli relaxation, use exact otherwise
|
|
466
|
+
def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
|
|
467
|
+
return jax.lax.cond(
|
|
468
|
+
jax.lax.stop_gradient(jnp.max(trials) < self.binomial_bins),
|
|
469
|
+
_jax_wrapped_calc_binomial_sum,
|
|
470
|
+
_jax_wrapped_calc_binomial_exact,
|
|
471
|
+
key, trials, prob, params
|
|
472
|
+
)
|
|
473
|
+
return _jax_wrapped_calc_binomial_approx
|
|
474
|
+
|
|
475
|
+
def geometric(self, id, init_params, logic):
|
|
476
|
+
approx_floor = logic.floor(id, init_params)
|
|
477
|
+
def _jax_wrapped_calc_geometric_approx(key, prob, params):
|
|
478
|
+
U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
|
|
479
|
+
floor, params = approx_floor(jnp.log1p(-U) / jnp.log1p(-prob), params)
|
|
480
|
+
sample = floor + 1
|
|
481
|
+
return sample, params
|
|
482
|
+
return _jax_wrapped_calc_geometric_approx
|
|
483
|
+
|
|
484
|
+
def _bernoulli_uniform(self, id, init_params, logic):
|
|
485
|
+
less_approx = logic.less(id, init_params)
|
|
486
|
+
def _jax_wrapped_calc_bernoulli_uniform(key, prob, params):
|
|
487
|
+
U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
|
|
488
|
+
return less_approx(U, prob, params)
|
|
489
|
+
return _jax_wrapped_calc_bernoulli_uniform
|
|
490
|
+
|
|
491
|
+
def _bernoulli_gumbel_softmax(self, id, init_params, logic):
|
|
492
|
+
discrete_approx = self.discrete(id, init_params, logic)
|
|
493
|
+
def _jax_wrapped_calc_bernoulli_gumbel_softmax(key, prob, params):
|
|
494
|
+
prob = jnp.stack([1.0 - prob, prob], axis=-1)
|
|
495
|
+
return discrete_approx(key, prob, params)
|
|
496
|
+
return _jax_wrapped_calc_bernoulli_gumbel_softmax
|
|
497
|
+
|
|
498
|
+
def bernoulli(self, id, init_params, logic):
|
|
499
|
+
if self.bernoulli_gumbel_softmax:
|
|
500
|
+
return self._bernoulli_gumbel_softmax(id, init_params, logic)
|
|
501
|
+
else:
|
|
502
|
+
return self._bernoulli_uniform(id, init_params, logic)
|
|
503
|
+
|
|
342
504
|
def __str__(self) -> str:
|
|
343
|
-
return '
|
|
505
|
+
return 'SoftRandomSampling'
|
|
344
506
|
|
|
345
507
|
|
|
346
508
|
class Determinization(RandomSampling):
|
|
347
509
|
'''Random sampling of variables using their deterministic mean estimate.'''
|
|
348
|
-
|
|
510
|
+
|
|
349
511
|
@staticmethod
|
|
350
512
|
def _jax_wrapped_calc_discrete_determinized(key, prob, params):
|
|
351
|
-
literals = enumerate_literals(
|
|
513
|
+
literals = enumerate_literals(jnp.shape(prob), axis=-1)
|
|
352
514
|
sample = jnp.sum(literals * prob, axis=-1)
|
|
353
515
|
return sample, params
|
|
354
|
-
|
|
516
|
+
|
|
355
517
|
def discrete(self, id, init_params, logic):
|
|
356
518
|
return self._jax_wrapped_calc_discrete_determinized
|
|
357
519
|
|
|
@@ -362,6 +524,14 @@ class Determinization(RandomSampling):
|
|
|
362
524
|
def poisson(self, id, init_params, logic):
|
|
363
525
|
return self._jax_wrapped_calc_poisson_determinized
|
|
364
526
|
|
|
527
|
+
@staticmethod
|
|
528
|
+
def _jax_wrapped_calc_binomial_determinized(key, trials, prob, params):
|
|
529
|
+
sample = trials * prob
|
|
530
|
+
return sample, params
|
|
531
|
+
|
|
532
|
+
def binomial(self, id, init_params, logic):
|
|
533
|
+
return self._jax_wrapped_calc_binomial_determinized
|
|
534
|
+
|
|
365
535
|
@staticmethod
|
|
366
536
|
def _jax_wrapped_calc_geometric_determinized(key, prob, params):
|
|
367
537
|
sample = 1.0 / prob
|
|
@@ -370,6 +540,14 @@ class Determinization(RandomSampling):
|
|
|
370
540
|
def geometric(self, id, init_params, logic):
|
|
371
541
|
return self._jax_wrapped_calc_geometric_determinized
|
|
372
542
|
|
|
543
|
+
@staticmethod
|
|
544
|
+
def _jax_wrapped_calc_bernoulli_determinized(key, prob, params):
|
|
545
|
+
sample = prob
|
|
546
|
+
return sample, params
|
|
547
|
+
|
|
548
|
+
def bernoulli(self, id, init_params, logic):
|
|
549
|
+
return self._jax_wrapped_calc_bernoulli_determinized
|
|
550
|
+
|
|
373
551
|
def __str__(self) -> str:
|
|
374
552
|
return 'Deterministic'
|
|
375
553
|
|
|
@@ -408,8 +586,8 @@ class SoftControlFlow(ControlFlow):
|
|
|
408
586
|
id_ = str(id)
|
|
409
587
|
init_params[id_] = self.weight
|
|
410
588
|
def _jax_wrapped_calc_switch_soft(pred, cases, params):
|
|
411
|
-
literals = enumerate_literals(
|
|
412
|
-
pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=
|
|
589
|
+
literals = enumerate_literals(jnp.shape(cases), axis=0)
|
|
590
|
+
pred = jnp.broadcast_to(pred[jnp.newaxis, ...], shape=jnp.shape(cases))
|
|
413
591
|
proximity = -jnp.square(pred - literals)
|
|
414
592
|
softcase = jax.nn.softmax(params[id_] * proximity, axis=0)
|
|
415
593
|
sample = jnp.sum(cases * softcase, axis=0)
|
|
@@ -450,6 +628,94 @@ class Logic:
|
|
|
450
628
|
self.INT = jnp.int32
|
|
451
629
|
jax.config.update('jax_enable_x64', False)
|
|
452
630
|
|
|
631
|
+
@staticmethod
|
|
632
|
+
def wrap_logic(func):
|
|
633
|
+
def exact_func(id, init_params):
|
|
634
|
+
return func
|
|
635
|
+
return exact_func
|
|
636
|
+
|
|
637
|
+
def get_operator_dicts(self) -> Dict[str, Union[Callable, Dict[str, Callable]]]:
|
|
638
|
+
'''Returns a dictionary of all operators in the current logic.'''
|
|
639
|
+
return {
|
|
640
|
+
'negative': self.wrap_logic(ExactLogic.exact_unary_function(jnp.negative)),
|
|
641
|
+
'arithmetic': {
|
|
642
|
+
'+': self.wrap_logic(ExactLogic.exact_binary_function(jnp.add)),
|
|
643
|
+
'-': self.wrap_logic(ExactLogic.exact_binary_function(jnp.subtract)),
|
|
644
|
+
'*': self.wrap_logic(ExactLogic.exact_binary_function(jnp.multiply)),
|
|
645
|
+
'/': self.wrap_logic(ExactLogic.exact_binary_function(jnp.divide))
|
|
646
|
+
},
|
|
647
|
+
'relational': {
|
|
648
|
+
'>=': self.greater_equal,
|
|
649
|
+
'<=': self.less_equal,
|
|
650
|
+
'<': self.less,
|
|
651
|
+
'>': self.greater,
|
|
652
|
+
'==': self.equal,
|
|
653
|
+
'~=': self.not_equal
|
|
654
|
+
},
|
|
655
|
+
'logical_not': self.logical_not,
|
|
656
|
+
'logical': {
|
|
657
|
+
'^': self.logical_and,
|
|
658
|
+
'&': self.logical_and,
|
|
659
|
+
'|': self.logical_or,
|
|
660
|
+
'~': self.xor,
|
|
661
|
+
'=>': self.implies,
|
|
662
|
+
'<=>': self.equiv
|
|
663
|
+
},
|
|
664
|
+
'aggregation': {
|
|
665
|
+
'sum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.sum)),
|
|
666
|
+
'avg': self.wrap_logic(ExactLogic.exact_aggregation(jnp.mean)),
|
|
667
|
+
'prod': self.wrap_logic(ExactLogic.exact_aggregation(jnp.prod)),
|
|
668
|
+
'minimum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.min)),
|
|
669
|
+
'maximum': self.wrap_logic(ExactLogic.exact_aggregation(jnp.max)),
|
|
670
|
+
'forall': self.forall,
|
|
671
|
+
'exists': self.exists,
|
|
672
|
+
'argmin': self.argmin,
|
|
673
|
+
'argmax': self.argmax
|
|
674
|
+
},
|
|
675
|
+
'unary': {
|
|
676
|
+
'abs': self.wrap_logic(ExactLogic.exact_unary_function(jnp.abs)),
|
|
677
|
+
'sgn': self.sgn,
|
|
678
|
+
'round': self.round,
|
|
679
|
+
'floor': self.floor,
|
|
680
|
+
'ceil': self.ceil,
|
|
681
|
+
'cos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cos)),
|
|
682
|
+
'sin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sin)),
|
|
683
|
+
'tan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tan)),
|
|
684
|
+
'acos': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arccos)),
|
|
685
|
+
'asin': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arcsin)),
|
|
686
|
+
'atan': self.wrap_logic(ExactLogic.exact_unary_function(jnp.arctan)),
|
|
687
|
+
'cosh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.cosh)),
|
|
688
|
+
'sinh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.sinh)),
|
|
689
|
+
'tanh': self.wrap_logic(ExactLogic.exact_unary_function(jnp.tanh)),
|
|
690
|
+
'exp': self.wrap_logic(ExactLogic.exact_unary_function(jnp.exp)),
|
|
691
|
+
'ln': self.wrap_logic(ExactLogic.exact_unary_function(jnp.log)),
|
|
692
|
+
'sqrt': self.sqrt,
|
|
693
|
+
'lngamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gammaln)),
|
|
694
|
+
'gamma': self.wrap_logic(ExactLogic.exact_unary_function(scipy.special.gamma))
|
|
695
|
+
},
|
|
696
|
+
'binary': {
|
|
697
|
+
'div': self.div,
|
|
698
|
+
'mod': self.mod,
|
|
699
|
+
'fmod': self.mod,
|
|
700
|
+
'min': self.wrap_logic(ExactLogic.exact_binary_function(jnp.minimum)),
|
|
701
|
+
'max': self.wrap_logic(ExactLogic.exact_binary_function(jnp.maximum)),
|
|
702
|
+
'pow': self.wrap_logic(ExactLogic.exact_binary_function(jnp.power)),
|
|
703
|
+
'log': self.wrap_logic(ExactLogic.exact_binary_log),
|
|
704
|
+
'hypot': self.wrap_logic(ExactLogic.exact_binary_function(jnp.hypot)),
|
|
705
|
+
},
|
|
706
|
+
'control': {
|
|
707
|
+
'if': self.control_if,
|
|
708
|
+
'switch': self.control_switch
|
|
709
|
+
},
|
|
710
|
+
'sampling': {
|
|
711
|
+
'Bernoulli': self.bernoulli,
|
|
712
|
+
'Discrete': self.discrete,
|
|
713
|
+
'Poisson': self.poisson,
|
|
714
|
+
'Geometric': self.geometric,
|
|
715
|
+
'Binomial': self.binomial
|
|
716
|
+
}
|
|
717
|
+
}
|
|
718
|
+
|
|
453
719
|
# ===========================================================================
|
|
454
720
|
# logical operators
|
|
455
721
|
# ===========================================================================
|
|
@@ -560,6 +826,9 @@ class Logic:
|
|
|
560
826
|
|
|
561
827
|
def geometric(self, id, init_params):
|
|
562
828
|
raise NotImplementedError
|
|
829
|
+
|
|
830
|
+
def binomial(self, id, init_params):
|
|
831
|
+
raise NotImplementedError
|
|
563
832
|
|
|
564
833
|
|
|
565
834
|
class ExactLogic(Logic):
|
|
@@ -600,11 +869,11 @@ class ExactLogic(Logic):
|
|
|
600
869
|
return self.exact_binary_function(jnp.logical_xor)
|
|
601
870
|
|
|
602
871
|
@staticmethod
|
|
603
|
-
def
|
|
872
|
+
def _jax_wrapped_calc_implies_exact(x, y, params):
|
|
604
873
|
return jnp.logical_or(jnp.logical_not(x), y), params
|
|
605
874
|
|
|
606
875
|
def implies(self, id, init_params):
|
|
607
|
-
return self.
|
|
876
|
+
return self._jax_wrapped_calc_implies_exact
|
|
608
877
|
|
|
609
878
|
def equiv(self, id, init_params):
|
|
610
879
|
return self.exact_binary_function(jnp.equal)
|
|
@@ -641,6 +910,10 @@ class ExactLogic(Logic):
|
|
|
641
910
|
# special functions
|
|
642
911
|
# ===========================================================================
|
|
643
912
|
|
|
913
|
+
@staticmethod
|
|
914
|
+
def exact_binary_log(x, y, params):
|
|
915
|
+
return jnp.log(x) / jnp.log(y), params
|
|
916
|
+
|
|
644
917
|
def sgn(self, id, init_params):
|
|
645
918
|
return self.exact_unary_function(jnp.sign)
|
|
646
919
|
|
|
@@ -677,54 +950,62 @@ class ExactLogic(Logic):
|
|
|
677
950
|
# ===========================================================================
|
|
678
951
|
|
|
679
952
|
@staticmethod
|
|
680
|
-
def
|
|
953
|
+
def _jax_wrapped_calc_if_then_else_exact(c, a, b, params):
|
|
681
954
|
return jnp.where(c > 0.5, a, b), params
|
|
682
955
|
|
|
683
956
|
def control_if(self, id, init_params):
|
|
684
|
-
return self.
|
|
957
|
+
return self._jax_wrapped_calc_if_then_else_exact
|
|
685
958
|
|
|
686
959
|
@staticmethod
|
|
687
|
-
def
|
|
960
|
+
def _jax_wrapped_calc_switch_exact(pred, cases, params):
|
|
688
961
|
pred = pred[jnp.newaxis, ...]
|
|
689
962
|
sample = jnp.take_along_axis(cases, pred, axis=0)
|
|
690
963
|
assert sample.shape[0] == 1
|
|
691
964
|
return sample[0, ...], params
|
|
692
965
|
|
|
693
966
|
def control_switch(self, id, init_params):
|
|
694
|
-
return self.
|
|
967
|
+
return self._jax_wrapped_calc_switch_exact
|
|
695
968
|
|
|
696
969
|
# ===========================================================================
|
|
697
970
|
# random variables
|
|
698
971
|
# ===========================================================================
|
|
699
972
|
|
|
700
973
|
@staticmethod
|
|
701
|
-
def
|
|
702
|
-
|
|
974
|
+
def _jax_wrapped_calc_discrete_exact(key, prob, params):
|
|
975
|
+
sample = random.categorical(key=key, logits=jnp.log(prob), axis=-1)
|
|
976
|
+
return sample, params
|
|
703
977
|
|
|
704
978
|
def discrete(self, id, init_params):
|
|
705
|
-
return self.
|
|
979
|
+
return self._jax_wrapped_calc_discrete_exact
|
|
706
980
|
|
|
707
981
|
@staticmethod
|
|
708
|
-
def
|
|
982
|
+
def _jax_wrapped_calc_bernoulli_exact(key, prob, params):
|
|
709
983
|
return random.bernoulli(key, prob), params
|
|
710
984
|
|
|
711
985
|
def bernoulli(self, id, init_params):
|
|
712
|
-
return self.
|
|
713
|
-
|
|
714
|
-
@staticmethod
|
|
715
|
-
def exact_poisson(key, rate, params):
|
|
716
|
-
return random.poisson(key=key, lam=rate), params
|
|
986
|
+
return self._jax_wrapped_calc_bernoulli_exact
|
|
717
987
|
|
|
718
988
|
def poisson(self, id, init_params):
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
return random.geometric(key=key, p=prob), params
|
|
989
|
+
def _jax_wrapped_calc_poisson_exact(key, rate, params):
|
|
990
|
+
sample = random.poisson(key=key, lam=rate, dtype=self.INT)
|
|
991
|
+
return sample, params
|
|
992
|
+
return _jax_wrapped_calc_poisson_exact
|
|
724
993
|
|
|
725
994
|
def geometric(self, id, init_params):
|
|
726
|
-
|
|
727
|
-
|
|
995
|
+
def _jax_wrapped_calc_geometric_exact(key, prob, params):
|
|
996
|
+
sample = random.geometric(key=key, p=prob, dtype=self.INT)
|
|
997
|
+
return sample, params
|
|
998
|
+
return _jax_wrapped_calc_geometric_exact
|
|
999
|
+
|
|
1000
|
+
def binomial(self, id, init_params):
|
|
1001
|
+
def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
|
|
1002
|
+
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1003
|
+
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1004
|
+
sample = random.binomial(key=key, n=trials, p=prob, dtype=self.REAL)
|
|
1005
|
+
sample = jnp.asarray(sample, dtype=self.INT)
|
|
1006
|
+
return sample, params
|
|
1007
|
+
return _jax_wrapped_calc_binomial_exact
|
|
1008
|
+
|
|
728
1009
|
|
|
729
1010
|
class FuzzyLogic(Logic):
|
|
730
1011
|
'''A class representing fuzzy logic in JAX.'''
|
|
@@ -732,7 +1013,7 @@ class FuzzyLogic(Logic):
|
|
|
732
1013
|
def __init__(self, tnorm: TNorm=ProductTNorm(),
|
|
733
1014
|
complement: Complement=StandardComplement(),
|
|
734
1015
|
comparison: Comparison=SigmoidComparison(),
|
|
735
|
-
sampling: RandomSampling=
|
|
1016
|
+
sampling: RandomSampling=SoftRandomSampling(),
|
|
736
1017
|
rounding: Rounding=SoftRounding(),
|
|
737
1018
|
control: ControlFlow=SoftControlFlow(),
|
|
738
1019
|
eps: float=1e-15,
|
|
@@ -759,14 +1040,14 @@ class FuzzyLogic(Logic):
|
|
|
759
1040
|
|
|
760
1041
|
def __str__(self) -> str:
|
|
761
1042
|
return (f'model relaxation:\n'
|
|
762
|
-
f' tnorm
|
|
763
|
-
f' complement
|
|
764
|
-
f' comparison
|
|
765
|
-
f' sampling
|
|
766
|
-
f' rounding
|
|
767
|
-
f' control
|
|
768
|
-
f' underflow_tol
|
|
769
|
-
f' use_64_bit
|
|
1043
|
+
f' tnorm ={str(self.tnorm)}\n'
|
|
1044
|
+
f' complement ={str(self.complement)}\n'
|
|
1045
|
+
f' comparison ={str(self.comparison)}\n'
|
|
1046
|
+
f' sampling ={str(self.sampling)}\n'
|
|
1047
|
+
f' rounding ={str(self.rounding)}\n'
|
|
1048
|
+
f' control ={str(self.control)}\n'
|
|
1049
|
+
f' underflow_tol={self.eps}\n'
|
|
1050
|
+
f' use_64_bit ={self.use64bit}\n')
|
|
770
1051
|
|
|
771
1052
|
def summarize_hyperparameters(self) -> None:
|
|
772
1053
|
print(self.__str__())
|
|
@@ -950,6 +1231,9 @@ class FuzzyLogic(Logic):
|
|
|
950
1231
|
|
|
951
1232
|
def geometric(self, id, init_params):
|
|
952
1233
|
return self.sampling.geometric(id, init_params, self)
|
|
1234
|
+
|
|
1235
|
+
def binomial(self, id, init_params):
|
|
1236
|
+
return self.sampling.binomial(id, init_params, self)
|
|
953
1237
|
|
|
954
1238
|
|
|
955
1239
|
# ===========================================================================
|
|
@@ -984,8 +1268,8 @@ def _test_logical():
|
|
|
984
1268
|
pred, w = _if(cond, +1, -1, w)
|
|
985
1269
|
return pred
|
|
986
1270
|
|
|
987
|
-
x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5]
|
|
988
|
-
x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6]
|
|
1271
|
+
x1 = jnp.asarray([1, 1, -1, -1, 0.1, 15, -0.5], dtype=float)
|
|
1272
|
+
x2 = jnp.asarray([1, -1, 1, -1, 10, -30, 6], dtype=float)
|
|
989
1273
|
print(test_logic(x1, x2, init_params))
|
|
990
1274
|
|
|
991
1275
|
|