pyRDDLGym-jax 2.2.tar.gz → 2.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/PKG-INFO +1 -1
  2. pyrddlgym_jax-2.3/pyRDDLGym_jax/__init__.py +1 -0
  3. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/compiler.py +14 -8
  4. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/logic.py +118 -55
  5. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/PKG-INFO +1 -1
  6. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/setup.py +1 -1
  7. pyrddlgym_jax-2.2/pyRDDLGym_jax/__init__.py +0 -1
  8. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/LICENSE +0 -0
  9. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/README.md +0 -0
  10. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/__init__.py +0 -0
  11. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/assets/__init__.py +0 -0
  12. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/assets/favicon.ico +0 -0
  13. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/planner.py +0 -0
  14. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/simulator.py +0 -0
  15. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/tuning.py +0 -0
  16. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/core/visualization.py +0 -0
  17. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/entry_point.py +0 -0
  18. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/__init__.py +0 -0
  19. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg +0 -0
  20. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +0 -0
  21. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +0 -0
  22. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_drp.cfg +0 -0
  23. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/HVAC_ippc2023_slp.cfg +0 -0
  24. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/MountainCar_Continuous_gym_slp.cfg +0 -0
  25. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/MountainCar_ippc2023_slp.cfg +0 -0
  26. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_drp.cfg +0 -0
  27. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_replan.cfg +0 -0
  28. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/PowerGen_Continuous_slp.cfg +0 -0
  29. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +0 -0
  30. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +0 -0
  31. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +0 -0
  32. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_replan.cfg +0 -0
  33. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +0 -0
  34. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +0 -0
  35. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +0 -0
  36. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +0 -0
  37. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +0 -0
  38. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/__init__.py +0 -0
  39. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/default_drp.cfg +0 -0
  40. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/default_replan.cfg +0 -0
  41. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/default_slp.cfg +0 -0
  42. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/tuning_drp.cfg +0 -0
  43. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/tuning_replan.cfg +0 -0
  44. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/configs/tuning_slp.cfg +0 -0
  45. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/run_gradient.py +0 -0
  46. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/run_gym.py +0 -0
  47. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/run_plan.py +0 -0
  48. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/run_scipy.py +0 -0
  49. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax/examples/run_tune.py +0 -0
  50. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/SOURCES.txt +0 -0
  51. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/dependency_links.txt +0 -0
  52. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/entry_points.txt +0 -0
  53. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/requires.txt +0 -0
  54. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/pyRDDLGym_jax.egg-info/top_level.txt +0 -0
  55. {pyrddlgym_jax-2.2 → pyrddlgym_jax-2.3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: pyRDDLGym-jax
- Version: 2.2
+ Version: 2.3
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -0,0 +1 @@
+ __version__ = '2.3'
@@ -1019,6 +1019,9 @@ class JaxRDDLCompiler:
      # UnnormDiscrete: complete (subclass uses Gumbel-softmax)
      # Discrete(p): complete (subclass uses Gumbel-softmax)
      # UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)
+     # Poisson (subclass uses Gumbel-softmax or Poisson process trick)
+     # Binomial (subclass uses Gumbel-softmax or Normal approximation)
+     # NegativeBinomial (subclass uses Poisson-Gamma mixture)

      # distributions which seem to support backpropagation (need more testing):
      # Beta
@@ -1026,11 +1029,8 @@ class JaxRDDLCompiler:
      # Gamma
      # ChiSquare
      # Dirichlet
-     # Poisson (subclass uses Gumbel-softmax or Poisson process trick)

      # distributions with incomplete reparameterization support (TODO):
-     # Binomial
-     # NegativeBinomial
      # Multinomial

      def _jax_random(self, expr, init_params):
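These comments track which RDDL distributions the planner can backpropagate through. The Gumbel-softmax trick mentioned throughout draws a relaxed sample from a finite discrete distribution by perturbing log-probabilities with Gumbel(0, 1) noise and taking a temperature-controlled soft argmax. A minimal self-contained sketch of the idea (illustrative names, not the package's internal API):

    import jax
    import jax.numpy as jnp

    def gumbel_softmax_sample(key, log_prob, temperature=0.1):
        # perturb log-probabilities with Gumbel(0, 1) noise; as temperature -> 0
        # the softmax approaches a one-hot sample of the exact discrete distribution
        gumbel = jax.random.gumbel(key, shape=log_prob.shape)
        return jax.nn.softmax((log_prob + gumbel) / temperature, axis=-1)

    key = jax.random.PRNGKey(42)
    log_p = jnp.log(jnp.array([0.2, 0.5, 0.3]))
    soft_one_hot = gumbel_softmax_sample(key, log_p)
    value = jnp.sum(soft_one_hot * jnp.arange(3))   # differentiable w.r.t. log_p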
@@ -1299,8 +1299,17 @@ class JaxRDDLCompiler:
      def _jax_negative_binomial(self, expr, init_params):
          ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
          JaxRDDLCompiler._check_num_args(expr, 2)
-
          arg_trials, arg_prob = expr.args
+
+         # if trials and prob are non-fluent, always use the exact operation
+         if self.compile_non_fluent_exact \
+         and not self.traced.cached_is_fluent(arg_trials) \
+         and not self.traced.cached_is_fluent(arg_prob):
+             negbin_op = self.EXACT_OPS['sampling']['NegativeBinomial']
+         else:
+             negbin_op = self.OPS['sampling']['NegativeBinomial']
+         jax_op = negbin_op(expr.id, init_params)
+
          jax_trials = self._jax(arg_trials, init_params)
          jax_prob = self._jax(arg_prob, init_params)

@@ -1308,11 +1317,8 @@ class JaxRDDLCompiler:
          def _jax_wrapped_distribution_negative_binomial(x, params, key):
              trials, key, err2, params = jax_trials(x, params, key)
              prob, key, err1, params = jax_prob(x, params, key)
-             trials = jnp.asarray(trials, dtype=self.REAL)
-             prob = jnp.asarray(prob, dtype=self.REAL)
              key, subkey = random.split(key)
-             dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
-             sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
+             sample, params = jax_op(subkey, trials, prob, params)
              out_of_bounds = jnp.logical_not(jnp.all(
                  (prob >= 0) & (prob <= 1) & (trials > 0)))
              err = err1 | err2 | (out_of_bounds * ERR)
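The net effect: the compiled sampler now calls a pluggable negbin_op with signature (key, trials, prob, params) -> (sample, params), chosen once at compile time. When compile_non_fluent_exact is set and both arguments are non-fluent (so no gradient has to flow through the sample), the exact TFP-backed sampler is used; otherwise the differentiable relaxation from logic.py is substituted. A schematic of the dispatch pattern (stand-in names, for illustration only):

    # hypothetical stand-in for the compiler's OPS/EXACT_OPS lookup; the real tables
    # map distribution names to factories of (expr_id, init_params) -> sampler
    def pick_negbin_sampler(ops, exact_ops, args_non_fluent, compile_non_fluent_exact):
        if compile_non_fluent_exact and args_non_fluent:
            return exact_ops['sampling']['NegativeBinomial']   # exact, non-differentiable
        return ops['sampling']['NegativeBinomial']             # relaxed, differentiable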
@@ -29,15 +29,27 @@
  #
  # ***********************************************************************

- from typing import Callable, Dict, Union
+ import traceback
+ from typing import Callable, Dict, Tuple, Union

  import jax
  import jax.numpy as jnp
  import jax.random as random
  import jax.scipy as scipy

+ from pyRDDLGym.core.debug.exception import raise_warning

- def enumerate_literals(shape, axis, dtype=jnp.int32):
+ # more robust approach - if the user does not have this or it is broken, try to continue
+ try:
+     from tensorflow_probability.substrates import jax as tfp
+ except Exception:
+     raise_warning('Failed to import tensorflow-probability: '
+                   'compilation of some probability distributions will fail.', 'red')
+     traceback.print_exc()
+     tfp = None
+
+
+ def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32) -> jnp.ndarray:
      literals = jnp.arange(shape[axis], dtype=dtype)
      literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
      literals = jnp.moveaxis(literals, source=0, destination=axis)
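With this guarded import, tfp is simply None when tensorflow-probability is missing or broken, so the module still imports and only the tfp-backed samplers fail. A sketch of how a consuming code path might guard its use (illustrative, not the package's exact error handling):

    # fail lazily, with an actionable message, only when a tfp-backed sampler is built
    def require_tfp():
        if tfp is None:
            raise ImportError(
                'tensorflow-probability is required for this distribution; '
                'install it with: pip install tensorflow-probability')
        return tfp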
@@ -74,7 +86,7 @@ class Comparison:
  class SigmoidComparison(Comparison):
      '''Comparison operations approximated using sigmoid functions.'''

-     def __init__(self, weight: float=10.0):
+     def __init__(self, weight: float=10.0) -> None:
          self.weight = weight

      # https://arxiv.org/abs/2110.05651
@@ -140,7 +152,7 @@ class Rounding:
  class SoftRounding(Rounding):
      '''Rounding operations approximated using soft operations.'''

-     def __init__(self, weight: float=10.0):
+     def __init__(self, weight: float=10.0) -> None:
          self.weight = weight

      # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
@@ -291,7 +303,7 @@ class YagerTNorm(TNorm):
      '''Yager t-norm given by the expression
      (x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''

-     def __init__(self, p=2.0):
+     def __init__(self, p: float=2.0) -> None:
          self.p = float(p)

      def norm(self, id, init_params):
@@ -339,6 +351,9 @@ class RandomSampling:
      def binomial(self, id, init_params, logic):
          raise NotImplementedError

+     def negative_binomial(self, id, init_params, logic):
+         raise NotImplementedError
+
      def geometric(self, id, init_params, logic):
          raise NotImplementedError

@@ -386,8 +401,7 @@ class SoftRandomSampling(RandomSampling):
      def _poisson_gumbel_softmax(self, id, init_params, logic):
          argmax_approx = logic.argmax(id, init_params)
          def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
-             ks = jnp.arange(0, self.poisson_bins)
-             ks = ks[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
+             ks = jnp.arange(self.poisson_bins)[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
              rate = rate[..., jnp.newaxis]
              log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
              Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
@@ -400,10 +414,7 @@ class SoftRandomSampling(RandomSampling):
          less_approx = logic.less(id, init_params)
          def _jax_wrapped_calc_poisson_exponential(key, rate, params):
              Exp1 = random.exponential(
-                 key=key,
-                 shape=(self.poisson_bins,) + jnp.shape(rate),
-                 dtype=logic.REAL
-             )
+                 key=key, shape=(self.poisson_bins,) + jnp.shape(rate), dtype=logic.REAL)
              delta_t = Exp1 / rate[jnp.newaxis, ...]
              times = jnp.cumsum(delta_t, axis=0)
              indicator, params = less_approx(times, 1.0, params)
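This is the Poisson process trick named in the compiler comments: inter-arrival times of a rate-λ Poisson process are Exponential(λ), so the number of cumulative arrival times that land before t = 1 is Poisson(λ); replacing the hard indicator times <= 1 with the relaxed comparison makes the count differentiable in λ. A non-relaxed reference version showing the construction in isolation (illustrative):

    import jax
    import jax.numpy as jnp

    def poisson_via_exponential(key, rate, bins=100):
        # count arrivals of a rate-`rate` Poisson process before t = 1 (exact,
        # non-differentiable; the relaxation replaces `times <= 1.0` with a sigmoid)
        exp1 = jax.random.exponential(key, shape=(bins,) + jnp.shape(rate))
        times = jnp.cumsum(exp1 / rate, axis=0)
        return jnp.sum(times <= 1.0, axis=0)

    keys = jax.random.split(jax.random.PRNGKey(0), 10000)
    samples = jax.vmap(lambda k: poisson_via_exponential(k, jnp.asarray(5.0)))(keys)
    print(samples.mean())   # close to the rate, 5.0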
@@ -411,72 +422,98 @@ class SoftRandomSampling(RandomSampling):
              return sample, params
          return _jax_wrapped_calc_poisson_exponential

+     # normal approximation to Poisson: Poisson(rate) -> Normal(rate, rate)
+     def _poisson_normal_approx(self, logic):
+         def _jax_wrapped_calc_poisson_normal_approx(key, rate, params):
+             normal = random.normal(key=key, shape=jnp.shape(rate), dtype=logic.REAL)
+             sample = rate + jnp.sqrt(rate) * normal
+             return sample, params
+         return _jax_wrapped_calc_poisson_normal_approx
+
      def poisson(self, id, init_params, logic):
-         def _jax_wrapped_calc_poisson_exact(key, rate, params):
-             sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
-             sample = jnp.asarray(sample, dtype=logic.REAL)
-             return sample, params
-
          if self.poisson_exp_method:
              _jax_wrapped_calc_poisson_diff = self._poisson_exponential(
                  id, init_params, logic)
          else:
              _jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
                  id, init_params, logic)
+         _jax_wrapped_calc_poisson_normal = self._poisson_normal_approx(logic)

+         # for small rate use the Poisson process or gumbel-softmax reparameterization
+         # for large rate use the normal approximation
          def _jax_wrapped_calc_poisson_approx(key, rate, params):
-
-             # determine if error of truncation at rate is acceptable
              if self.poisson_bins > 0:
                  cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
-                 approx_cond = jax.lax.stop_gradient(
-                     jnp.min(cuml_prob) > self.poisson_min_cdf)
+                 small_rate = jax.lax.stop_gradient(cuml_prob >= self.poisson_min_cdf)
+                 small_sample, params = _jax_wrapped_calc_poisson_diff(key, rate, params)
+                 large_sample, params = _jax_wrapped_calc_poisson_normal(key, rate, params)
+                 sample = jnp.where(small_rate, small_sample, large_sample)
+                 return sample, params
              else:
-                 approx_cond = False
-
-             # for acceptable truncation use the approximation, use exact otherwise
-             return jax.lax.cond(
-                 approx_cond,
-                 _jax_wrapped_calc_poisson_diff,
-                 _jax_wrapped_calc_poisson_exact,
-                 key, rate, params
-             )
+                 return _jax_wrapped_calc_poisson_normal(key, rate, params)
          return _jax_wrapped_calc_poisson_approx

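Version 2.3 thus drops the non-differentiable exact fallback inside poisson(): where the truncated relaxation would lose too much probability mass (CDF at poisson_bins below poisson_min_cdf), it now switches per element to Normal(λ, λ), which matches the Poisson mean and variance and is accurate for large λ by the central limit theorem. A standalone check of that approximation (illustrative):

    import jax.numpy as jnp
    import jax.scipy as scipy

    rate = 50.0
    k = jnp.arange(30, 71)
    poisson_cdf = scipy.stats.poisson.cdf(k, rate)
    # Normal(rate, rate) surrogate with continuity correction
    normal_cdf = scipy.stats.norm.cdf((k + 0.5 - rate) / jnp.sqrt(rate))
    print(jnp.max(jnp.abs(poisson_cdf - normal_cdf)))   # small for large rate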
-     def binomial(self, id, init_params, logic):
-         def _jax_wrapped_calc_binomial_exact(key, trials, prob, params):
-             trials = jnp.asarray(trials, dtype=logic.REAL)
-             prob = jnp.asarray(prob, dtype=logic.REAL)
-             sample = random.binomial(key=key, n=trials, p=prob, dtype=logic.REAL)
-             return sample, params
+     # normal approximation to Binomial: Bin(n, p) -> Normal(np, np(1-p))
+     def _binomial_normal_approx(self, logic):
+         def _jax_wrapped_calc_binomial_normal_approx(key, trials, prob, params):
+             normal = random.normal(key=key, shape=jnp.shape(trials), dtype=logic.REAL)
+             mean = trials * prob
+             std = jnp.sqrt(trials * prob * (1.0 - prob))
+             sample = mean + std * normal
+             return sample, params
+         return _jax_wrapped_calc_binomial_normal_approx

-         # Binomial(n, p) = sum_{i = 1 ... n} Bernoulli(p)
-         bernoulli_approx = self.bernoulli(id, init_params, logic)
-         def _jax_wrapped_calc_binomial_sum(key, trials, prob, params):
-             prob_full = jnp.broadcast_to(
-                 prob[..., jnp.newaxis], shape=jnp.shape(prob) + (self.binomial_bins,))
-             sample_bern, params = bernoulli_approx(key, prob_full, params)
-             indices = jnp.arange(self.binomial_bins)[
-                 (jnp.newaxis,) * jnp.ndim(prob) + (...,)]
-             mask = indices < trials[..., jnp.newaxis]
-             sample = jnp.sum(sample_bern * mask, axis=-1)
-             return sample, params
+     def _binomial_gumbel_softmax(self, id, init_params, logic):
+         argmax_approx = logic.argmax(id, init_params)
+         def _jax_wrapped_calc_binomial_gumbel_softmax(key, trials, prob, params):
+             ks = jnp.arange(self.binomial_bins)[(jnp.newaxis,) * jnp.ndim(trials) + (...,)]
+             trials = trials[..., jnp.newaxis]
+             prob = prob[..., jnp.newaxis]
+             in_support = ks <= trials
+             ks = jnp.minimum(ks, trials)
+             log_prob = ((scipy.special.gammaln(trials + 1) -
+                          scipy.special.gammaln(ks + 1) -
+                          scipy.special.gammaln(trials - ks + 1)) +
+                         ks * jnp.log(prob + logic.eps) +
+                         (trials - ks) * jnp.log1p(-prob + logic.eps))
+             log_prob = jnp.where(in_support, log_prob, jnp.log(logic.eps))
+             Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
+             sample = Gumbel01 + log_prob
+             return argmax_approx(sample, axis=-1, params=params)
+         return _jax_wrapped_calc_binomial_gumbel_softmax

-         # for trials not too large use the Bernoulli relaxation, use exact otherwise
+     def binomial(self, id, init_params, logic):
+         _jax_wrapped_calc_binomial_normal = self._binomial_normal_approx(logic)
+         _jax_wrapped_calc_binomial_gs = self._binomial_gumbel_softmax(id, init_params, logic)
+
+         # for small trials use the Gumbel-softmax relaxation
+         # for large trials use the normal approximation
          def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
-             return jax.lax.cond(
-                 jax.lax.stop_gradient(jnp.max(trials) < self.binomial_bins),
-                 _jax_wrapped_calc_binomial_sum,
-                 _jax_wrapped_calc_binomial_exact,
-                 key, trials, prob, params
-             )
+             small_trials = jax.lax.stop_gradient(trials < self.binomial_bins)
+             small_sample, params = _jax_wrapped_calc_binomial_gs(key, trials, prob, params)
+             large_sample, params = _jax_wrapped_calc_binomial_normal(key, trials, prob, params)
+             sample = jnp.where(small_trials, small_sample, large_sample)
+             return sample, params
          return _jax_wrapped_calc_binomial_approx

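The binomial relaxation changes in the same spirit: the old Bernoulli-sum relaxation and exact fallback are replaced by a Gumbel-softmax over the truncated support k = 0..binomial_bins-1 using the exact Binomial log-pmf (with out-of-support entries masked to log eps), blended per element with Normal(np, np(1-p)) once trials is large. Since jnp.where replaces jax.lax.cond, a batch mixing small and large trials gets the better approximation entry by entry. A reference check of the moments the normal branch relies on (illustrative):

    import jax
    import jax.numpy as jnp

    n, p = 200.0, 0.3
    normal = jax.random.normal(jax.random.PRNGKey(1), shape=(100000,))
    samples = n * p + jnp.sqrt(n * p * (1 - p)) * normal
    print(samples.mean(), samples.var())   # about n*p = 60 and n*p*(1-p) = 42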
+     # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
+     def negative_binomial(self, id, init_params, logic):
+         poisson_approx = self.poisson(id, init_params, logic)
+         def _jax_wrapped_calc_negative_binomial_approx(key, trials, prob, params):
+             key, subkey = random.split(key)
+             trials = jnp.asarray(trials, dtype=logic.REAL)
+             Gamma = random.gamma(key=key, a=trials, dtype=logic.REAL)
+             scale = (1.0 - prob) / prob
+             poisson_rate = scale * Gamma
+             return poisson_approx(subkey, poisson_rate, params)
+         return _jax_wrapped_calc_negative_binomial_approx
+
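This is the Gamma-Poisson mixture cited in the URL above: if λ ~ Gamma(shape = r, scale = (1-p)/p) and X | λ ~ Poisson(λ), then X ~ NegativeBinomial(r, p), counting failures before the r-th success, with mean E[X] = r(1-p)/p. Gradients flow through both stages, since JAX's gamma sampler is implicitly reparameterized and the Poisson draw uses the relaxations above. A Monte-Carlo sketch of the exact (non-relaxed) mixture (illustrative):

    import jax
    import jax.numpy as jnp

    def negbin_gamma_poisson(key, r, p, n=100000):
        key_gamma, key_pois = jax.random.split(key)
        lam = jax.random.gamma(key_gamma, a=r, shape=(n,)) * (1.0 - p) / p
        return jax.random.poisson(key_pois, lam)

    samples = negbin_gamma_poisson(jax.random.PRNGKey(7), r=5.0, p=0.4)
    print(samples.mean())   # about r*(1-p)/p = 7.5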
      def geometric(self, id, init_params, logic):
          approx_floor = logic.floor(id, init_params)
          def _jax_wrapped_calc_geometric_approx(key, prob, params):
              U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
-             floor, params = approx_floor(jnp.log1p(-U) / jnp.log1p(-prob), params)
+             floor, params = approx_floor(
+                 jnp.log1p(-U) / jnp.log1p(-prob + logic.eps), params)
              sample = floor + 1
              return sample, params
          return _jax_wrapped_calc_geometric_approx
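The geometric sampler is plain inverse-transform sampling: for U ~ Uniform(0, 1), floor(log(1-U)/log(1-p)) + 1 is Geometric(p) on {1, 2, ...}, and the relaxed floor keeps it differentiable; the new logic.eps term guards log1p(-prob) when prob is at or near 1. A non-relaxed reference check (illustrative):

    import jax
    import jax.numpy as jnp

    p = 0.25
    U = jax.random.uniform(jax.random.PRNGKey(3), shape=(100000,))
    samples = jnp.floor(jnp.log1p(-U) / jnp.log1p(-p)) + 1
    print(samples.mean())   # about 1/p = 4.0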
@@ -532,6 +569,14 @@ class Determinization(RandomSampling):
      def binomial(self, id, init_params, logic):
          return self._jax_wrapped_calc_binomial_determinized

+     @staticmethod
+     def _jax_wrapped_calc_negative_binomial_determinized(key, trials, prob, params):
+         sample = trials * ((1.0 / prob) - 1.0)
+         return sample, params
+
+     def negative_binomial(self, id, init_params, logic):
+         return self._jax_wrapped_calc_negative_binomial_determinized
+
      @staticmethod
      def _jax_wrapped_calc_geometric_determinized(key, prob, params):
          sample = 1.0 / prob
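The determinized sample is the distribution's mean: with X the number of failures before `trials` successes, E[X] = trials * (1 - prob)/prob = trials * (1/prob - 1). This is consistent with the geometric determinization 1/prob just below, since Geometric here counts trials rather than failures and so sits one above the trials = 1 negative binomial. A one-line check (illustrative):

    trials, prob = 5.0, 0.4
    assert abs(trials * ((1.0 / prob) - 1.0) - trials * (1.0 - prob) / prob) < 1e-12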
@@ -712,7 +757,8 @@ class Logic:
                  'Discrete': self.discrete,
                  'Poisson': self.poisson,
                  'Geometric': self.geometric,
-                 'Binomial': self.binomial
+                 'Binomial': self.binomial,
+                 'NegativeBinomial': self.negative_binomial
              }
          }

@@ -830,6 +876,9 @@ class Logic:
      def binomial(self, id, init_params):
          raise NotImplementedError

+     def negative_binomial(self, id, init_params):
+         raise NotImplementedError
+

  class ExactLogic(Logic):
      '''A class representing exact logic in JAX.'''
@@ -1005,6 +1054,17 @@ class ExactLogic(Logic):
              sample = jnp.asarray(sample, dtype=self.INT)
              return sample, params
          return _jax_wrapped_calc_binomial_exact
+
+     # note: tfp defines this as the number of successes before `trials` failures;
+     # here it is defined as the number of failures before `trials` successes
+     def negative_binomial(self, id, init_params):
+         def _jax_wrapped_calc_negative_binomial_exact(key, trials, prob, params):
+             trials = jnp.asarray(trials, dtype=self.REAL)
+             prob = jnp.asarray(prob, dtype=self.REAL)
+             dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=1.0 - prob)
+             sample = jnp.asarray(dist.sample(seed=key), dtype=self.INT)
+             return sample, params
+         return _jax_wrapped_calc_negative_binomial_exact


  class FuzzyLogic(Logic):
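The probs flip reconciles the two conventions: tfp.distributions.NegativeBinomial(total_count=r, probs=q) models the number of successes, each occurring with probability q, observed before r failures, so its mean is r*q/(1-q). Passing q = 1 - prob therefore gives mean r*(1-prob)/prob, the number of failures before r successes with success probability prob, matching both the relaxed Gamma-Poisson sampler and the determinized value above. A quick check (requires tensorflow-probability; illustrative):

    import jax
    import jax.numpy as jnp
    from tensorflow_probability.substrates import jax as tfp

    r, p = 5.0, 0.4
    dist = tfp.distributions.NegativeBinomial(total_count=r, probs=1.0 - p)
    print(dist.mean())   # r*(1-p)/p = 7.5
    samples = dist.sample(100000, seed=jax.random.PRNGKey(0))
    print(jnp.mean(samples))   # about 7.5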
@@ -1234,6 +1294,9 @@ class FuzzyLogic(Logic):

      def binomial(self, id, init_params):
          return self.sampling.binomial(id, init_params, self)
+
+     def negative_binomial(self, id, init_params):
+         return self.sampling.negative_binomial(id, init_params, self)


  # ===========================================================================
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: pyRDDLGym-jax
- Version: 2.2
+ Version: 2.3
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -19,7 +19,7 @@ long_description = (Path(__file__).parent / "README.md").read_text()

  setup(
      name='pyRDDLGym-jax',
-     version='2.2',
+     version='2.3',
      author="Michael Gimelfarb, Ayal Taitler, Scott Sanner",
      author_email="mike.gimelfarb@mail.utoronto.ca, ataitler@gmail.com, ssanner@mie.utoronto.ca",
      description="pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.",
@@ -1 +0,0 @@
- __version__ = '2.2'