pyRDDLGym-jax 2.2-py3-none-any.whl → 2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +14 -8
- pyRDDLGym_jax/core/logic.py +118 -55
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/METADATA +1 -1
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/RECORD +9 -9
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/LICENSE +0 -0
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
@@ -1 +1 @@
-__version__ = '2.2'
+__version__ = '2.3'
pyRDDLGym_jax/core/compiler.py
CHANGED
@@ -1019,6 +1019,9 @@ class JaxRDDLCompiler:
     # UnnormDiscrete: complete (subclass uses Gumbel-softmax)
     # Discrete(p): complete (subclass uses Gumbel-softmax)
     # UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)
+    # Poisson (subclass uses Gumbel-softmax or Poisson process trick)
+    # Binomial (subclass uses Gumbel-softmax or Normal approximation)
+    # NegativeBinomial (subclass uses Poisson-Gamma mixture)

     # distributions which seem to support backpropagation (need more testing):
     # Beta
@@ -1026,11 +1029,8 @@ class JaxRDDLCompiler:
     # Gamma
     # ChiSquare
     # Dirichlet
-    # Poisson (subclass uses Gumbel-softmax or Poisson process trick)

     # distributions with incomplete reparameterization support (TODO):
-    # Binomial
-    # NegativeBinomial
     # Multinomial

     def _jax_random(self, expr, init_params):
@@ -1299,8 +1299,17 @@ class JaxRDDLCompiler:
     def _jax_negative_binomial(self, expr, init_params):
         ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
         JaxRDDLCompiler._check_num_args(expr, 2)
-
         arg_trials, arg_prob = expr.args
+
+        # if trials and prob are non-fluent, always use the exact operation
+        if self.compile_non_fluent_exact \
+        and not self.traced.cached_is_fluent(arg_trials) \
+        and not self.traced.cached_is_fluent(arg_prob):
+            negbin_op = self.EXACT_OPS['sampling']['NegativeBinomial']
+        else:
+            negbin_op = self.OPS['sampling']['NegativeBinomial']
+        jax_op = negbin_op(expr.id, init_params)
+
         jax_trials = self._jax(arg_trials, init_params)
         jax_prob = self._jax(arg_prob, init_params)
@@ -1308,11 +1317,8 @@ class JaxRDDLCompiler:
         def _jax_wrapped_distribution_negative_binomial(x, params, key):
             trials, key, err2, params = jax_trials(x, params, key)
             prob, key, err1, params = jax_prob(x, params, key)
-            trials = jnp.asarray(trials, dtype=self.REAL)
-            prob = jnp.asarray(prob, dtype=self.REAL)
             key, subkey = random.split(key)
-            …
-            sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
+            sample, params = jax_op(subkey, trials, prob, params)
             out_of_bounds = jnp.logical_not(jnp.all(
                 (prob >= 0) & (prob <= 1) & (trials > 0)))
             err = err1 | err2 | (out_of_bounds * ERR)
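Note on the hunk above: the new dispatch mirrors the other distributions in this compiler. When `compile_non_fluent_exact` is set and both arguments are non-fluent (compile-time constants that carry no gradient), the exact sampler is safe to use; otherwise the relaxed sampler from the planner's op table is compiled. A minimal sketch of the selection pattern, with hypothetical stand-ins for the op tables and fluency flags:

# hypothetical stand-ins for the compiler's exact/relaxed op tables
def select_negbin_sampler(exact_ops, relaxed_ops, compile_non_fluent_exact,
                          trials_is_fluent, prob_is_fluent):
    # non-fluent arguments are compile-time constants and carry no
    # gradient, so the exact (non-differentiable) sampler is safe
    if compile_non_fluent_exact and not (trials_is_fluent or prob_is_fluent):
        return exact_ops['sampling']['NegativeBinomial']
    return relaxed_ops['sampling']['NegativeBinomial']
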
pyRDDLGym_jax/core/logic.py
CHANGED
@@ -29,15 +29,27 @@
 #
 # ***********************************************************************

-
+import traceback
+from typing import Callable, Dict, Tuple, Union

 import jax
 import jax.numpy as jnp
 import jax.random as random
 import jax.scipy as scipy

+from pyRDDLGym.core.debug.exception import raise_warning

-
+# more robust approach - if tensorflow-probability is missing or broken, try to continue
+try:
+    from tensorflow_probability.substrates import jax as tfp
+except Exception:
+    raise_warning('Failed to import tensorflow-probability: '
+                  'compilation of some probability distributions will fail.', 'red')
+    traceback.print_exc()
+    tfp = None
+
+
 def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32) -> jnp.ndarray:
     literals = jnp.arange(shape[axis], dtype=dtype)
     literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
     literals = jnp.moveaxis(literals, source=0, destination=axis)
@@ -74,7 +86,7 @@ class Comparison:
 class SigmoidComparison(Comparison):
     '''Comparison operations approximated using sigmoid functions.'''

-    def __init__(self, weight: float=10.0):
+    def __init__(self, weight: float=10.0) -> None:
         self.weight = weight

     # https://arxiv.org/abs/2110.05651
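The hunk above only adds a return annotation; for context, SigmoidComparison (per the linked paper) relaxes hard comparisons with a temperature-scaled sigmoid so they admit gradients. A minimal sketch of the idea, with a hypothetical soft_greater helper:

import jax

def soft_greater(x, y, weight=10.0):
    # differentiable surrogate for the hard test (x > y);
    # approaches a 0/1 step as the weight grows
    return jax.nn.sigmoid(weight * (x - y))

print(soft_greater(2.0, 1.0))  # ~1.0
print(soft_greater(1.0, 2.0))  # ~0.0
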
@@ -140,7 +152,7 @@ class Rounding:
 class SoftRounding(Rounding):
     '''Rounding operations approximated using soft operations.'''

-    def __init__(self, weight: float=10.0):
+    def __init__(self, weight: float=10.0) -> None:
         self.weight = weight

     # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
@@ -291,7 +303,7 @@ class YagerTNorm(TNorm):
     '''Yager t-norm given by the expression
     (x, y) -> max(0, 1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''

-    def __init__(self, p=2.0):
+    def __init__(self, p: float=2.0) -> None:
         self.p = float(p)

     def norm(self, id, init_params):
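The change above just adds type hints; the docstring's formula (with the max against 0, dropped in the extracted view, restored) can be checked directly. A small sketch with a hypothetical yager_and helper:

import jax.numpy as jnp

def yager_and(x, y, p=2.0):
    # Yager t-norm; p -> infinity recovers min(x, y)
    return jnp.maximum(0.0, 1.0 - ((1.0 - x) ** p + (1.0 - y) ** p) ** (1.0 / p))

print(yager_and(1.0, 1.0))  # 1.0 (true AND true)
print(yager_and(1.0, 0.0))  # 0.0 (true AND false)
print(yager_and(0.8, 0.9))  # ~0.78, a soft AND of two near-true values
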
@@ -339,6 +351,9 @@ class RandomSampling:
     def binomial(self, id, init_params, logic):
         raise NotImplementedError

+    def negative_binomial(self, id, init_params, logic):
+        raise NotImplementedError
+
     def geometric(self, id, init_params, logic):
         raise NotImplementedError

@@ -386,8 +401,7 @@ class SoftRandomSampling(RandomSampling):
     def _poisson_gumbel_softmax(self, id, init_params, logic):
         argmax_approx = logic.argmax(id, init_params)
         def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
-            ks = jnp.arange(self.poisson_bins)
-            ks = ks[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
+            ks = jnp.arange(self.poisson_bins)[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
             rate = rate[..., jnp.newaxis]
             log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
             Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
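This helper truncates the Poisson pmf to `poisson_bins` support points and samples with the Gumbel-max trick, where the hard argmax is swapped for the logic's differentiable soft argmax. A runnable sketch of the hard-argmax version (the bin count and rate here are illustrative):

import jax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import gammaln

def poisson_gumbel_max(key, rate, bins=64):
    # log-pmf of Poisson(rate) truncated to {0, ..., bins - 1}
    ks = jnp.arange(bins)
    log_prob = ks * jnp.log(rate) - rate - gammaln(ks + 1.0)
    # Gumbel-max trick: argmax over log-probs plus Gumbel noise is an
    # exact categorical sample; the library instead uses a soft argmax
    g = random.gumbel(key, shape=(bins,))
    return jnp.argmax(log_prob + g)

keys = random.split(random.PRNGKey(0), 2000)
samples = jax.vmap(lambda k: poisson_gumbel_max(k, 3.0))(keys)
print(samples.mean())  # ~3.0, the Poisson rate
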
@@ -400,10 +414,7 @@ class SoftRandomSampling(RandomSampling):
         less_approx = logic.less(id, init_params)
         def _jax_wrapped_calc_poisson_exponential(key, rate, params):
             Exp1 = random.exponential(
-                key=key,
-                shape=(self.poisson_bins,) + jnp.shape(rate),
-                dtype=logic.REAL
-            )
+                key=key, shape=(self.poisson_bins,) + jnp.shape(rate), dtype=logic.REAL)
             delta_t = Exp1 / rate[jnp.newaxis, ...]
             times = jnp.cumsum(delta_t, axis=0)
             indicator, params = less_approx(times, 1.0, params)
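This is the Poisson process trick named in the compiler comments: inter-arrival times of a rate-r Poisson process are Exp(r), so counting how many cumulative arrival times land inside [0, 1] yields a Poisson(r) sample, and the hard count relaxes through the soft less-than. A standalone sketch of the hard version (bin count and rate are illustrative):

import jax
import jax.numpy as jnp
import jax.random as random

def poisson_via_arrivals(key, rate, bins=128):
    # inter-arrival gaps of a Poisson process with this rate are Exp(rate)
    gaps = random.exponential(key, shape=(bins,)) / rate
    times = jnp.cumsum(gaps)
    # number of arrivals inside [0, 1] is Poisson(rate); the library
    # replaces this hard indicator with its soft less-than
    return jnp.sum(times <= 1.0)

keys = random.split(random.PRNGKey(0), 2000)
samples = jax.vmap(lambda k: poisson_via_arrivals(k, 5.0))(keys)
print(samples.mean())  # ~5.0
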
@@ -411,72 +422,98 @@ class SoftRandomSampling(RandomSampling):
             return sample, params
         return _jax_wrapped_calc_poisson_exponential

+    # normal approximation to Poisson: Poisson(rate) -> Normal(rate, rate)
+    def _poisson_normal_approx(self, logic):
+        def _jax_wrapped_calc_poisson_normal_approx(key, rate, params):
+            normal = random.normal(key=key, shape=jnp.shape(rate), dtype=logic.REAL)
+            sample = rate + jnp.sqrt(rate) * normal
+            return sample, params
+        return _jax_wrapped_calc_poisson_normal_approx
+
     def poisson(self, id, init_params, logic):
-        def _jax_wrapped_calc_poisson_exact(key, rate, params):
-            sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
-            sample = jnp.asarray(sample, dtype=logic.REAL)
-            return sample, params
-
         if self.poisson_exp_method:
             _jax_wrapped_calc_poisson_diff = self._poisson_exponential(
                 id, init_params, logic)
         else:
             _jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
                 id, init_params, logic)
+        _jax_wrapped_calc_poisson_normal = self._poisson_normal_approx(logic)

+        # for small rate use the Poisson process or Gumbel-softmax reparameterization
+        # for large rate use the normal approximation
         def _jax_wrapped_calc_poisson_approx(key, rate, params):
-
-            # determine if error of truncation at rate is acceptable
             if self.poisson_bins > 0:
                 cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
-                …
+                small_rate = jax.lax.stop_gradient(cuml_prob >= self.poisson_min_cdf)
+                small_sample, params = _jax_wrapped_calc_poisson_diff(key, rate, params)
+                large_sample, params = _jax_wrapped_calc_poisson_normal(key, rate, params)
+                sample = jnp.where(small_rate, small_sample, large_sample)
+                return sample, params
             else:
-                …
-                # for acceptable truncation use the approximation, use exact otherwise
-                return jax.lax.cond(
-                    approx_cond,
-                    _jax_wrapped_calc_poisson_diff,
-                    _jax_wrapped_calc_poisson_exact,
-                    key, rate, params
-                )
+                return _jax_wrapped_calc_poisson_normal(key, rate, params)
         return _jax_wrapped_calc_poisson_approx

-    …
+    # normal approximation to Binomial: Bin(n, p) -> Normal(np, np(1-p))
+    def _binomial_normal_approx(self, logic):
+        def _jax_wrapped_calc_binomial_normal_approx(key, trials, prob, params):
+            normal = random.normal(key=key, shape=jnp.shape(trials), dtype=logic.REAL)
+            mean = trials * prob
+            std = jnp.sqrt(trials * prob * (1.0 - prob))
+            sample = mean + std * normal
+            return sample, params
+        return _jax_wrapped_calc_binomial_normal_approx

-    def …
-        …
+    def _binomial_gumbel_softmax(self, id, init_params, logic):
+        argmax_approx = logic.argmax(id, init_params)
+        def _jax_wrapped_calc_binomial_gumbel_softmax(key, trials, prob, params):
+            ks = jnp.arange(self.binomial_bins)[(jnp.newaxis,) * jnp.ndim(trials) + (...,)]
+            trials = trials[..., jnp.newaxis]
+            prob = prob[..., jnp.newaxis]
+            in_support = ks <= trials
+            ks = jnp.minimum(ks, trials)
+            log_prob = ((scipy.special.gammaln(trials + 1) -
+                         scipy.special.gammaln(ks + 1) -
+                         scipy.special.gammaln(trials - ks + 1)) +
+                        ks * jnp.log(prob + logic.eps) +
+                        (trials - ks) * jnp.log1p(-prob + logic.eps))
+            log_prob = jnp.where(in_support, log_prob, jnp.log(logic.eps))
+            Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
+            sample = Gumbel01 + log_prob
+            return argmax_approx(sample, axis=-1, params=params)
+        return _jax_wrapped_calc_binomial_gumbel_softmax

-
+    def binomial(self, id, init_params, logic):
+        _jax_wrapped_calc_binomial_normal = self._binomial_normal_approx(logic)
+        _jax_wrapped_calc_binomial_gs = self._binomial_gumbel_softmax(id, init_params, logic)
+
+        # for small trials use the Bernoulli relaxation
+        # for large trials use the normal approximation
         def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
-            …
-            )
+            small_trials = jax.lax.stop_gradient(trials < self.binomial_bins)
+            small_sample, params = _jax_wrapped_calc_binomial_gs(key, trials, prob, params)
+            large_sample, params = _jax_wrapped_calc_binomial_normal(key, trials, prob, params)
+            sample = jnp.where(small_trials, small_sample, large_sample)
+            return sample, params
         return _jax_wrapped_calc_binomial_approx

+    # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
+    def negative_binomial(self, id, init_params, logic):
+        poisson_approx = self.poisson(id, init_params, logic)
+        def _jax_wrapped_calc_negative_binomial_approx(key, trials, prob, params):
+            key, subkey = random.split(key)
+            trials = jnp.asarray(trials, dtype=logic.REAL)
+            Gamma = random.gamma(key=key, a=trials, dtype=logic.REAL)
+            scale = (1.0 - prob) / prob
+            poisson_rate = scale * Gamma
+            return poisson_approx(subkey, poisson_rate, params)
+        return _jax_wrapped_calc_negative_binomial_approx
+
     def geometric(self, id, init_params, logic):
         approx_floor = logic.floor(id, init_params)
         def _jax_wrapped_calc_geometric_approx(key, prob, params):
             U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
-            floor, params = approx_floor(jnp.log1p(-U) / jnp.log1p(-prob + logic.eps), params)
+            floor, params = approx_floor(
+                jnp.log1p(-U) / jnp.log1p(-prob + logic.eps), params)
             sample = floor + 1
             return sample, params
         return _jax_wrapped_calc_geometric_approx
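The new negative_binomial above leans on the Gamma-Poisson mixture cited in the hunk: if the rate is drawn as Gamma(trials) scaled by (1 - prob)/prob, then a Poisson with that rate is NegativeBinomial(trials, prob), counting failures before `trials` successes. A quick sanity check with JAX's exact samplers (the relaxed path instead feeds the mixed rate through the soft Poisson above):

import jax
import jax.random as random

def negbin_gamma_poisson(key, trials, prob):
    # Gamma-Poisson mixture: a Gamma(trials, scale=(1-prob)/prob) rate
    # fed to a Poisson yields NegativeBinomial(trials, prob)
    k1, k2 = random.split(key)
    rate = random.gamma(k1, a=trials) * (1.0 - prob) / prob
    return random.poisson(k2, rate)

keys = random.split(random.PRNGKey(7), 5000)
samples = jax.vmap(lambda k: negbin_gamma_poisson(k, 4.0, 0.5))(keys)
print(samples.mean())  # ~trials * (1 - prob)/prob = 4.0
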
@@ -532,6 +569,14 @@ class Determinization(RandomSampling):
     def binomial(self, id, init_params, logic):
         return self._jax_wrapped_calc_binomial_determinized

+    @staticmethod
+    def _jax_wrapped_calc_negative_binomial_determinized(key, trials, prob, params):
+        sample = trials * ((1.0 / prob) - 1.0)
+        return sample, params
+
+    def negative_binomial(self, id, init_params, logic):
+        return self._jax_wrapped_calc_negative_binomial_determinized
+
     @staticmethod
     def _jax_wrapped_calc_geometric_determinized(key, prob, params):
         sample = 1.0 / prob
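Determinization swaps each sample for the distribution's mean, so the determinized NegativeBinomial is its expectation E[X] = trials * (1 - prob)/prob, which is exactly the trials * ((1.0 / prob) - 1.0) above. A one-line check:

trials, prob = 4.0, 0.5
# mean of NegativeBinomial(trials, prob): trials * (1 - prob)/prob
print(trials * ((1.0 / prob) - 1.0))  # 4.0
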
@@ -712,7 +757,8 @@ class Logic:
             'Discrete': self.discrete,
             'Poisson': self.poisson,
             'Geometric': self.geometric,
-            'Binomial': self.binomial
+            'Binomial': self.binomial,
+            'NegativeBinomial': self.negative_binomial
         }
     }

@@ -830,6 +876,9 @@ class Logic:
     def binomial(self, id, init_params):
         raise NotImplementedError

+    def negative_binomial(self, id, init_params):
+        raise NotImplementedError
+

 class ExactLogic(Logic):
     '''A class representing exact logic in JAX.'''
@@ -1005,6 +1054,17 @@ class ExactLogic(Logic):
             sample = jnp.asarray(sample, dtype=self.INT)
             return sample, params
         return _jax_wrapped_calc_binomial_exact
+
+    # note: tfp defines this as the number of successes before trials failures;
+    # here it is defined as the number of failures before trials successes
+    def negative_binomial(self, id, init_params):
+        def _jax_wrapped_calc_negative_binomial_exact(key, trials, prob, params):
+            trials = jnp.asarray(trials, dtype=self.REAL)
+            prob = jnp.asarray(prob, dtype=self.REAL)
+            dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=1.0 - prob)
+            sample = jnp.asarray(dist.sample(seed=key), dtype=self.INT)
+            return sample, params
+        return _jax_wrapped_calc_negative_binomial_exact


 class FuzzyLogic(Logic):
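The comment in the hunk is about a convention mismatch: tfp's NegativeBinomial counts successes before total_count failures, with probs the per-trial success probability, while pyRDDLGym counts failures before `trials` successes. Flipping probs to 1 - prob reconciles the two, as the means confirm (a sketch, requiring tensorflow-probability):

from tensorflow_probability.substrates import jax as tfp

trials, prob = 4.0, 0.5
dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=1.0 - prob)
# tfp mean: total_count * probs / (1 - probs) = trials * (1 - prob) / prob
print(dist.mean())                    # 4.0
print(trials * (1.0 - prob) / prob)   # 4.0
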
@@ -1234,6 +1294,9 @@ class FuzzyLogic(Logic):

     def binomial(self, id, init_params):
         return self.sampling.binomial(id, init_params, self)
+
+    def negative_binomial(self, id, init_params):
+        return self.sampling.negative_binomial(id, init_params, self)


 # ===========================================================================
{pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: pyRDDLGym-jax
-Version: 2.2
+Version: 2.3
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
{pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/RECORD
CHANGED

@@ -1,8 +1,8 @@
-pyRDDLGym_jax/__init__.py,sha256=…
+pyRDDLGym_jax/__init__.py,sha256=ab_pLSTaKv50-5b6lazl75TqhQi0bNsErQ8JlBepVII,19
 pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
 pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/compiler.py,sha256=…
-pyRDDLGym_jax/core/logic.py,sha256=…
+pyRDDLGym_jax/core/compiler.py,sha256=fLOdJED-Cxtm_IT4LRiZ461Alp9Qjr0vBsOnw1s__EY,82612
+pyRDDLGym_jax/core/logic.py,sha256=0NNm0OaeKv46K0VNY6vL0PHOUFZPNxqQLOvQYkHCswM,56093
 pyRDDLGym_jax/core/planner.py,sha256=0rluBXKGNHRPEPfegOWcx9__cJHr8KjZdDJtG7i1JjI,122793
 pyRDDLGym_jax/core/simulator.py,sha256=DnPL93WVCMZqtqMUoiJdfWcH9pEvNgGfDfO4NV0wIS0,9271
 pyRDDLGym_jax/core/tuning.py,sha256=RKKtDZp7unvfbhZEoaunZtcAn5xtzGYqXBB_Ij_Aapc,24205
@@ -41,9 +41,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
 pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
 pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
 pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
-pyrddlgym_jax-2.2.dist-info/…
-pyrddlgym_jax-2.2.dist-info/…
-pyrddlgym_jax-2.2.dist-info/…
-pyrddlgym_jax-2.2.dist-info/…
-pyrddlgym_jax-2.2.dist-info/…
-pyrddlgym_jax-2.2.dist-info/…
+pyrddlgym_jax-2.3.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+pyrddlgym_jax-2.3.dist-info/METADATA,sha256=MS6tckyg-bAQBGZJ112VQPZm5at660EfhntCnfrlUbE,17021
+pyrddlgym_jax-2.3.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
+pyrddlgym_jax-2.3.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+pyrddlgym_jax-2.3.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyrddlgym_jax-2.3.dist-info/RECORD,,
{pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/LICENSE
File without changes

{pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/entry_points.txt
File without changes

{pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.3.dist-info}/top_level.txt
File without changes