pyRDDLGym-jax 2.1__py3-none-any.whl → 2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +14 -8
- pyRDDLGym_jax/core/logic.py +118 -55
- pyRDDLGym_jax/core/planner.py +159 -76
- {pyrddlgym_jax-2.1.dist-info → pyrddlgym_jax-2.3.dist-info}/METADATA +25 -22
- {pyrddlgym_jax-2.1.dist-info → pyrddlgym_jax-2.3.dist-info}/RECORD +10 -10
- {pyrddlgym_jax-2.1.dist-info → pyrddlgym_jax-2.3.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.1.dist-info → pyrddlgym_jax-2.3.dist-info}/LICENSE +0 -0
- {pyrddlgym_jax-2.1.dist-info → pyrddlgym_jax-2.3.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.1.dist-info → pyrddlgym_jax-2.3.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.
|
|
1
|
+
__version__ = '2.3'
|
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -1019,6 +1019,9 @@ class JaxRDDLCompiler:
|
|
|
1019
1019
|
# UnnormDiscrete: complete (subclass uses Gumbel-softmax)
|
|
1020
1020
|
# Discrete(p): complete (subclass uses Gumbel-softmax)
|
|
1021
1021
|
# UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)
|
|
1022
|
+
# Poisson (subclass uses Gumbel-softmax or Poisson process trick)
|
|
1023
|
+
# Binomial (subclass uses Gumbel-softmax or Normal approximation)
|
|
1024
|
+
# NegativeBinomial (subclass uses Poisson-Gamma mixture)
|
|
1022
1025
|
|
|
1023
1026
|
# distributions which seem to support backpropagation (need more testing):
|
|
1024
1027
|
# Beta
|
|
@@ -1026,11 +1029,8 @@ class JaxRDDLCompiler:
|
|
|
1026
1029
|
# Gamma
|
|
1027
1030
|
# ChiSquare
|
|
1028
1031
|
# Dirichlet
|
|
1029
|
-
# Poisson (subclass uses Gumbel-softmax or Poisson process trick)
|
|
1030
1032
|
|
|
1031
1033
|
# distributions with incomplete reparameterization support (TODO):
|
|
1032
|
-
# Binomial
|
|
1033
|
-
# NegativeBinomial
|
|
1034
1034
|
# Multinomial
|
|
1035
1035
|
|
|
1036
1036
|
def _jax_random(self, expr, init_params):
|
|
@@ -1299,8 +1299,17 @@ class JaxRDDLCompiler:
|
|
|
1299
1299
|
def _jax_negative_binomial(self, expr, init_params):
|
|
1300
1300
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_NEGATIVE_BINOMIAL']
|
|
1301
1301
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1302
|
-
|
|
1303
1302
|
arg_trials, arg_prob = expr.args
|
|
1303
|
+
|
|
1304
|
+
# if prob is non-fluent, always use the exact operation
|
|
1305
|
+
if self.compile_non_fluent_exact \
|
|
1306
|
+
and not self.traced.cached_is_fluent(arg_trials) \
|
|
1307
|
+
and not self.traced.cached_is_fluent(arg_prob):
|
|
1308
|
+
negbin_op = self.EXACT_OPS['sampling']['NegativeBinomial']
|
|
1309
|
+
else:
|
|
1310
|
+
negbin_op = self.OPS['sampling']['NegativeBinomial']
|
|
1311
|
+
jax_op = negbin_op(expr.id, init_params)
|
|
1312
|
+
|
|
1304
1313
|
jax_trials = self._jax(arg_trials, init_params)
|
|
1305
1314
|
jax_prob = self._jax(arg_prob, init_params)
|
|
1306
1315
|
|
|
@@ -1308,11 +1317,8 @@ class JaxRDDLCompiler:
|
|
|
1308
1317
|
def _jax_wrapped_distribution_negative_binomial(x, params, key):
|
|
1309
1318
|
trials, key, err2, params = jax_trials(x, params, key)
|
|
1310
1319
|
prob, key, err1, params = jax_prob(x, params, key)
|
|
1311
|
-
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1312
|
-
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1313
1320
|
key, subkey = random.split(key)
|
|
1314
|
-
|
|
1315
|
-
sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
|
|
1321
|
+
sample, params = jax_op(subkey, trials, prob, params)
|
|
1316
1322
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1317
1323
|
(prob >= 0) & (prob <= 1) & (trials > 0)))
|
|
1318
1324
|
err = err1 | err2 | (out_of_bounds * ERR)
|
pyRDDLGym_jax/core/logic.py
CHANGED
|
@@ -29,15 +29,27 @@
|
|
|
29
29
|
#
|
|
30
30
|
# ***********************************************************************
|
|
31
31
|
|
|
32
|
-
|
|
32
|
+
import traceback
|
|
33
|
+
from typing import Callable, Dict, Tuple, Union
|
|
33
34
|
|
|
34
35
|
import jax
|
|
35
36
|
import jax.numpy as jnp
|
|
36
37
|
import jax.random as random
|
|
37
38
|
import jax.scipy as scipy
|
|
38
39
|
|
|
40
|
+
from pyRDDLGym.core.debug.exception import raise_warning
|
|
39
41
|
|
|
40
|
-
|
|
42
|
+
# more robust approach - if user does not have this or broken try to continue
|
|
43
|
+
try:
|
|
44
|
+
from tensorflow_probability.substrates import jax as tfp
|
|
45
|
+
except Exception:
|
|
46
|
+
raise_warning('Failed to import tensorflow-probability: '
|
|
47
|
+
'compilation of some probability distributions will fail.', 'red')
|
|
48
|
+
traceback.print_exc()
|
|
49
|
+
tfp = None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def enumerate_literals(shape: Tuple[int, ...], axis: int, dtype: type=jnp.int32) -> jnp.ndarray:
|
|
41
53
|
literals = jnp.arange(shape[axis], dtype=dtype)
|
|
42
54
|
literals = literals[(...,) + (jnp.newaxis,) * (len(shape) - 1)]
|
|
43
55
|
literals = jnp.moveaxis(literals, source=0, destination=axis)
|
|
@@ -74,7 +86,7 @@ class Comparison:
|
|
|
74
86
|
class SigmoidComparison(Comparison):
|
|
75
87
|
'''Comparison operations approximated using sigmoid functions.'''
|
|
76
88
|
|
|
77
|
-
def __init__(self, weight: float=10.0):
|
|
89
|
+
def __init__(self, weight: float=10.0) -> None:
|
|
78
90
|
self.weight = weight
|
|
79
91
|
|
|
80
92
|
# https://arxiv.org/abs/2110.05651
|
|
@@ -140,7 +152,7 @@ class Rounding:
|
|
|
140
152
|
class SoftRounding(Rounding):
|
|
141
153
|
'''Rounding operations approximated using soft operations.'''
|
|
142
154
|
|
|
143
|
-
def __init__(self, weight: float=10.0):
|
|
155
|
+
def __init__(self, weight: float=10.0) -> None:
|
|
144
156
|
self.weight = weight
|
|
145
157
|
|
|
146
158
|
# https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/Softfloor
|
|
@@ -291,7 +303,7 @@ class YagerTNorm(TNorm):
|
|
|
291
303
|
'''Yager t-norm given by the expression
|
|
292
304
|
(x, y) -> max(1 - ((1 - x)^p + (1 - y)^p)^(1/p)).'''
|
|
293
305
|
|
|
294
|
-
def __init__(self, p=2.0):
|
|
306
|
+
def __init__(self, p: float=2.0) -> None:
|
|
295
307
|
self.p = float(p)
|
|
296
308
|
|
|
297
309
|
def norm(self, id, init_params):
|
|
@@ -339,6 +351,9 @@ class RandomSampling:
|
|
|
339
351
|
def binomial(self, id, init_params, logic):
|
|
340
352
|
raise NotImplementedError
|
|
341
353
|
|
|
354
|
+
def negative_binomial(self, id, init_params, logic):
|
|
355
|
+
raise NotImplementedError
|
|
356
|
+
|
|
342
357
|
def geometric(self, id, init_params, logic):
|
|
343
358
|
raise NotImplementedError
|
|
344
359
|
|
|
@@ -386,8 +401,7 @@ class SoftRandomSampling(RandomSampling):
|
|
|
386
401
|
def _poisson_gumbel_softmax(self, id, init_params, logic):
|
|
387
402
|
argmax_approx = logic.argmax(id, init_params)
|
|
388
403
|
def _jax_wrapped_calc_poisson_gumbel_softmax(key, rate, params):
|
|
389
|
-
ks = jnp.arange(
|
|
390
|
-
ks = ks[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
|
|
404
|
+
ks = jnp.arange(self.poisson_bins)[(jnp.newaxis,) * jnp.ndim(rate) + (...,)]
|
|
391
405
|
rate = rate[..., jnp.newaxis]
|
|
392
406
|
log_prob = ks * jnp.log(rate + logic.eps) - rate - scipy.special.gammaln(ks + 1)
|
|
393
407
|
Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
|
|
@@ -400,10 +414,7 @@ class SoftRandomSampling(RandomSampling):
|
|
|
400
414
|
less_approx = logic.less(id, init_params)
|
|
401
415
|
def _jax_wrapped_calc_poisson_exponential(key, rate, params):
|
|
402
416
|
Exp1 = random.exponential(
|
|
403
|
-
key=key,
|
|
404
|
-
shape=(self.poisson_bins,) + jnp.shape(rate),
|
|
405
|
-
dtype=logic.REAL
|
|
406
|
-
)
|
|
417
|
+
key=key, shape=(self.poisson_bins,) + jnp.shape(rate), dtype=logic.REAL)
|
|
407
418
|
delta_t = Exp1 / rate[jnp.newaxis, ...]
|
|
408
419
|
times = jnp.cumsum(delta_t, axis=0)
|
|
409
420
|
indicator, params = less_approx(times, 1.0, params)
|
|
@@ -411,72 +422,98 @@ class SoftRandomSampling(RandomSampling):
|
|
|
411
422
|
return sample, params
|
|
412
423
|
return _jax_wrapped_calc_poisson_exponential
|
|
413
424
|
|
|
425
|
+
# normal approximation to Poisson: Poisson(rate) -> Normal(rate, rate)
|
|
426
|
+
def _poisson_normal_approx(self, logic):
|
|
427
|
+
def _jax_wrapped_calc_poisson_normal_approx(key, rate, params):
|
|
428
|
+
normal = random.normal(key=key, shape=jnp.shape(rate), dtype=logic.REAL)
|
|
429
|
+
sample = rate + jnp.sqrt(rate) * normal
|
|
430
|
+
return sample, params
|
|
431
|
+
return _jax_wrapped_calc_poisson_normal_approx
|
|
432
|
+
|
|
414
433
|
def poisson(self, id, init_params, logic):
|
|
415
|
-
def _jax_wrapped_calc_poisson_exact(key, rate, params):
|
|
416
|
-
sample = random.poisson(key=key, lam=rate, dtype=logic.INT)
|
|
417
|
-
sample = jnp.asarray(sample, dtype=logic.REAL)
|
|
418
|
-
return sample, params
|
|
419
|
-
|
|
420
434
|
if self.poisson_exp_method:
|
|
421
435
|
_jax_wrapped_calc_poisson_diff = self._poisson_exponential(
|
|
422
436
|
id, init_params, logic)
|
|
423
437
|
else:
|
|
424
438
|
_jax_wrapped_calc_poisson_diff = self._poisson_gumbel_softmax(
|
|
425
439
|
id, init_params, logic)
|
|
440
|
+
_jax_wrapped_calc_poisson_normal = self._poisson_normal_approx(logic)
|
|
426
441
|
|
|
442
|
+
# for small rate use the Poisson process or gumbel-softmax reparameterization
|
|
443
|
+
# for large rate use the normal approximation
|
|
427
444
|
def _jax_wrapped_calc_poisson_approx(key, rate, params):
|
|
428
|
-
|
|
429
|
-
# determine if error of truncation at rate is acceptable
|
|
430
445
|
if self.poisson_bins > 0:
|
|
431
446
|
cuml_prob = scipy.stats.poisson.cdf(self.poisson_bins, rate)
|
|
432
|
-
|
|
433
|
-
|
|
447
|
+
small_rate = jax.lax.stop_gradient(cuml_prob >= self.poisson_min_cdf)
|
|
448
|
+
small_sample, params = _jax_wrapped_calc_poisson_diff(key, rate, params)
|
|
449
|
+
large_sample, params = _jax_wrapped_calc_poisson_normal(key, rate, params)
|
|
450
|
+
sample = jnp.where(small_rate, small_sample, large_sample)
|
|
451
|
+
return sample, params
|
|
434
452
|
else:
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
# for acceptable truncation use the approximation, use exact otherwise
|
|
438
|
-
return jax.lax.cond(
|
|
439
|
-
approx_cond,
|
|
440
|
-
_jax_wrapped_calc_poisson_diff,
|
|
441
|
-
_jax_wrapped_calc_poisson_exact,
|
|
442
|
-
key, rate, params
|
|
443
|
-
)
|
|
453
|
+
return _jax_wrapped_calc_poisson_normal(key, rate, params)
|
|
444
454
|
return _jax_wrapped_calc_poisson_approx
|
|
445
455
|
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
456
|
+
# normal approximation to Binomial: Bin(n, p) -> Normal(np, np(1-p))
|
|
457
|
+
def _binomial_normal_approx(self, logic):
|
|
458
|
+
def _jax_wrapped_calc_binomial_normal_approx(key, trials, prob, params):
|
|
459
|
+
normal = random.normal(key=key, shape=jnp.shape(trials), dtype=logic.REAL)
|
|
460
|
+
mean = trials * prob
|
|
461
|
+
std = jnp.sqrt(trials * prob * (1.0 - prob))
|
|
462
|
+
sample = mean + std * normal
|
|
463
|
+
return sample, params
|
|
464
|
+
return _jax_wrapped_calc_binomial_normal_approx
|
|
452
465
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
def
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
466
|
+
def _binomial_gumbel_softmax(self, id, init_params, logic):
|
|
467
|
+
argmax_approx = logic.argmax(id, init_params)
|
|
468
|
+
def _jax_wrapped_calc_binomial_gumbel_softmax(key, trials, prob, params):
|
|
469
|
+
ks = jnp.arange(self.binomial_bins)[(jnp.newaxis,) * jnp.ndim(trials) + (...,)]
|
|
470
|
+
trials = trials[..., jnp.newaxis]
|
|
471
|
+
prob = prob[..., jnp.newaxis]
|
|
472
|
+
in_support = ks <= trials
|
|
473
|
+
ks = jnp.minimum(ks, trials)
|
|
474
|
+
log_prob = ((scipy.special.gammaln(trials + 1) -
|
|
475
|
+
scipy.special.gammaln(ks + 1) -
|
|
476
|
+
scipy.special.gammaln(trials - ks + 1)) +
|
|
477
|
+
ks * jnp.log(prob + logic.eps) +
|
|
478
|
+
(trials - ks) * jnp.log1p(-prob + logic.eps))
|
|
479
|
+
log_prob = jnp.where(in_support, log_prob, jnp.log(logic.eps))
|
|
480
|
+
Gumbel01 = random.gumbel(key=key, shape=jnp.shape(log_prob), dtype=logic.REAL)
|
|
481
|
+
sample = Gumbel01 + log_prob
|
|
482
|
+
return argmax_approx(sample, axis=-1, params=params)
|
|
483
|
+
return _jax_wrapped_calc_binomial_gumbel_softmax
|
|
464
484
|
|
|
465
|
-
|
|
485
|
+
def binomial(self, id, init_params, logic):
|
|
486
|
+
_jax_wrapped_calc_binomial_normal = self._binomial_normal_approx(logic)
|
|
487
|
+
_jax_wrapped_calc_binomial_gs = self._binomial_gumbel_softmax(id, init_params, logic)
|
|
488
|
+
|
|
489
|
+
# for small trials use the Bernoulli relaxation
|
|
490
|
+
# for large trials use the normal approximation
|
|
466
491
|
def _jax_wrapped_calc_binomial_approx(key, trials, prob, params):
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
)
|
|
492
|
+
small_trials = jax.lax.stop_gradient(trials < self.binomial_bins)
|
|
493
|
+
small_sample, params = _jax_wrapped_calc_binomial_gs(key, trials, prob, params)
|
|
494
|
+
large_sample, params = _jax_wrapped_calc_binomial_normal(key, trials, prob, params)
|
|
495
|
+
sample = jnp.where(small_trials, small_sample, large_sample)
|
|
496
|
+
return sample, params
|
|
473
497
|
return _jax_wrapped_calc_binomial_approx
|
|
474
498
|
|
|
499
|
+
# https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
|
|
500
|
+
def negative_binomial(self, id, init_params, logic):
|
|
501
|
+
poisson_approx = self.poisson(id, init_params, logic)
|
|
502
|
+
def _jax_wrapped_calc_negative_binomial_approx(key, trials, prob, params):
|
|
503
|
+
key, subkey = random.split(key)
|
|
504
|
+
trials = jnp.asarray(trials, dtype=logic.REAL)
|
|
505
|
+
Gamma = random.gamma(key=key, a=trials, dtype=logic.REAL)
|
|
506
|
+
scale = (1.0 - prob) / prob
|
|
507
|
+
poisson_rate = scale * Gamma
|
|
508
|
+
return poisson_approx(subkey, poisson_rate, params)
|
|
509
|
+
return _jax_wrapped_calc_negative_binomial_approx
|
|
510
|
+
|
|
475
511
|
def geometric(self, id, init_params, logic):
|
|
476
512
|
approx_floor = logic.floor(id, init_params)
|
|
477
513
|
def _jax_wrapped_calc_geometric_approx(key, prob, params):
|
|
478
514
|
U = random.uniform(key=key, shape=jnp.shape(prob), dtype=logic.REAL)
|
|
479
|
-
floor, params = approx_floor(
|
|
515
|
+
floor, params = approx_floor(
|
|
516
|
+
jnp.log1p(-U) / jnp.log1p(-prob + logic.eps), params)
|
|
480
517
|
sample = floor + 1
|
|
481
518
|
return sample, params
|
|
482
519
|
return _jax_wrapped_calc_geometric_approx
|
|
@@ -532,6 +569,14 @@ class Determinization(RandomSampling):
|
|
|
532
569
|
def binomial(self, id, init_params, logic):
|
|
533
570
|
return self._jax_wrapped_calc_binomial_determinized
|
|
534
571
|
|
|
572
|
+
@staticmethod
|
|
573
|
+
def _jax_wrapped_calc_negative_binomial_determinized(key, trials, prob, params):
|
|
574
|
+
sample = trials * ((1.0 / prob) - 1.0)
|
|
575
|
+
return sample, params
|
|
576
|
+
|
|
577
|
+
def negative_binomial(self, id, init_params, logic):
|
|
578
|
+
return self._jax_wrapped_calc_negative_binomial_determinized
|
|
579
|
+
|
|
535
580
|
@staticmethod
|
|
536
581
|
def _jax_wrapped_calc_geometric_determinized(key, prob, params):
|
|
537
582
|
sample = 1.0 / prob
|
|
@@ -712,7 +757,8 @@ class Logic:
|
|
|
712
757
|
'Discrete': self.discrete,
|
|
713
758
|
'Poisson': self.poisson,
|
|
714
759
|
'Geometric': self.geometric,
|
|
715
|
-
'Binomial': self.binomial
|
|
760
|
+
'Binomial': self.binomial,
|
|
761
|
+
'NegativeBinomial': self.negative_binomial
|
|
716
762
|
}
|
|
717
763
|
}
|
|
718
764
|
|
|
@@ -830,6 +876,9 @@ class Logic:
|
|
|
830
876
|
def binomial(self, id, init_params):
|
|
831
877
|
raise NotImplementedError
|
|
832
878
|
|
|
879
|
+
def negative_binomial(self, id, init_params):
|
|
880
|
+
raise NotImplementedError
|
|
881
|
+
|
|
833
882
|
|
|
834
883
|
class ExactLogic(Logic):
|
|
835
884
|
'''A class representing exact logic in JAX.'''
|
|
@@ -1005,6 +1054,17 @@ class ExactLogic(Logic):
|
|
|
1005
1054
|
sample = jnp.asarray(sample, dtype=self.INT)
|
|
1006
1055
|
return sample, params
|
|
1007
1056
|
return _jax_wrapped_calc_binomial_exact
|
|
1057
|
+
|
|
1058
|
+
# note: for some reason tfp defines it as number of successes before trials failures
|
|
1059
|
+
# I will define it as the number of failures before trials successes
|
|
1060
|
+
def negative_binomial(self, id, init_params):
|
|
1061
|
+
def _jax_wrapped_calc_negative_binomial_exact(key, trials, prob, params):
|
|
1062
|
+
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1063
|
+
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1064
|
+
dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=1.0 - prob)
|
|
1065
|
+
sample = jnp.asarray(dist.sample(seed=key), dtype=self.INT)
|
|
1066
|
+
return sample, params
|
|
1067
|
+
return _jax_wrapped_calc_negative_binomial_exact
|
|
1008
1068
|
|
|
1009
1069
|
|
|
1010
1070
|
class FuzzyLogic(Logic):
|
|
@@ -1234,6 +1294,9 @@ class FuzzyLogic(Logic):
|
|
|
1234
1294
|
|
|
1235
1295
|
def binomial(self, id, init_params):
|
|
1236
1296
|
return self.sampling.binomial(id, init_params, self)
|
|
1297
|
+
|
|
1298
|
+
def negative_binomial(self, id, init_params):
|
|
1299
|
+
return self.sampling.negative_binomial(id, init_params, self)
|
|
1237
1300
|
|
|
1238
1301
|
|
|
1239
1302
|
# ===========================================================================
|
pyRDDLGym_jax/core/planner.py
CHANGED
|
@@ -47,7 +47,9 @@ import jax.random as random
|
|
|
47
47
|
import numpy as np
|
|
48
48
|
import optax
|
|
49
49
|
import termcolor
|
|
50
|
-
from tqdm import tqdm
|
|
50
|
+
from tqdm import tqdm, TqdmWarning
|
|
51
|
+
import warnings
|
|
52
|
+
warnings.filterwarnings("ignore", category=TqdmWarning)
|
|
51
53
|
|
|
52
54
|
from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
|
|
53
55
|
from pyRDDLGym.core.debug.logger import Logger
|
|
@@ -1212,17 +1214,22 @@ class GaussianPGPE(PGPE):
|
|
|
1212
1214
|
init_sigma: float=1.0,
|
|
1213
1215
|
sigma_range: Tuple[float, float]=(1e-5, 1e5),
|
|
1214
1216
|
scale_reward: bool=True,
|
|
1217
|
+
min_reward_scale: float=1e-5,
|
|
1215
1218
|
super_symmetric: bool=True,
|
|
1216
1219
|
super_symmetric_accurate: bool=True,
|
|
1217
1220
|
optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
|
|
1218
1221
|
optimizer_kwargs_mu: Optional[Kwargs]=None,
|
|
1219
|
-
optimizer_kwargs_sigma: Optional[Kwargs]=None
|
|
1222
|
+
optimizer_kwargs_sigma: Optional[Kwargs]=None,
|
|
1223
|
+
start_entropy_coeff: float=1e-3,
|
|
1224
|
+
end_entropy_coeff: float=1e-8,
|
|
1225
|
+
max_kl_update: Optional[float]=None) -> None:
|
|
1220
1226
|
'''Creates a new Gaussian PGPE planner.
|
|
1221
1227
|
|
|
1222
1228
|
:param batch_size: how many policy parameters to sample per optimization step
|
|
1223
1229
|
:param init_sigma: initial standard deviation of Gaussian
|
|
1224
1230
|
:param sigma_range: bounds to constrain standard deviation
|
|
1225
1231
|
:param scale_reward: whether to apply reward scaling as in the paper
|
|
1232
|
+
:param min_reward_scale: minimum reward scaling to avoid underflow
|
|
1226
1233
|
:param super_symmetric: whether to use super-symmetric sampling as in the paper
|
|
1227
1234
|
:param super_symmetric_accurate: whether to use the accurate formula for super-
|
|
1228
1235
|
symmetric sampling or the simplified but biased formula
|
|
@@ -1231,6 +1238,9 @@ class GaussianPGPE(PGPE):
|
|
|
1231
1238
|
factory for the mean optimizer
|
|
1232
1239
|
:param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
|
|
1233
1240
|
factory for the standard deviation optimizer
|
|
1241
|
+
:param start_entropy_coeff: starting entropy regularization coeffient for Gaussian
|
|
1242
|
+
:param end_entropy_coeff: ending entropy regularization coeffient for Gaussian
|
|
1243
|
+
:param max_kl_update: bound on kl-divergence between parameter updates
|
|
1234
1244
|
'''
|
|
1235
1245
|
super().__init__()
|
|
1236
1246
|
|
|
@@ -1238,8 +1248,13 @@ class GaussianPGPE(PGPE):
|
|
|
1238
1248
|
self.init_sigma = init_sigma
|
|
1239
1249
|
self.sigma_range = sigma_range
|
|
1240
1250
|
self.scale_reward = scale_reward
|
|
1251
|
+
self.min_reward_scale = min_reward_scale
|
|
1241
1252
|
self.super_symmetric = super_symmetric
|
|
1242
1253
|
self.super_symmetric_accurate = super_symmetric_accurate
|
|
1254
|
+
|
|
1255
|
+
# entropy regularization penalty is decayed exponentially between these values
|
|
1256
|
+
self.start_entropy_coeff = start_entropy_coeff
|
|
1257
|
+
self.end_entropy_coeff = end_entropy_coeff
|
|
1243
1258
|
|
|
1244
1259
|
# set optimizers
|
|
1245
1260
|
if optimizer_kwargs_mu is None:
|
|
@@ -1249,36 +1264,62 @@ class GaussianPGPE(PGPE):
|
|
|
1249
1264
|
optimizer_kwargs_sigma = {'learning_rate': 0.1}
|
|
1250
1265
|
self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
|
|
1251
1266
|
self.optimizer_name = optimizer
|
|
1252
|
-
|
|
1253
|
-
|
|
1267
|
+
try:
|
|
1268
|
+
mu_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_mu)
|
|
1269
|
+
sigma_optimizer = optax.inject_hyperparams(optimizer)(**optimizer_kwargs_sigma)
|
|
1270
|
+
except Exception as _:
|
|
1271
|
+
raise_warning(
|
|
1272
|
+
f'Failed to inject hyperparameters into optax optimizer for PGPE, '
|
|
1273
|
+
'rolling back to safer method: please note that kl-divergence '
|
|
1274
|
+
'constraints will be disabled.', 'red')
|
|
1275
|
+
mu_optimizer = optimizer(**optimizer_kwargs_mu)
|
|
1276
|
+
sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
|
|
1277
|
+
max_kl_update = None
|
|
1254
1278
|
self.optimizers = (mu_optimizer, sigma_optimizer)
|
|
1279
|
+
self.max_kl = max_kl_update
|
|
1255
1280
|
|
|
1256
1281
|
def __str__(self) -> str:
|
|
1257
1282
|
return (f'PGPE hyper-parameters:\n'
|
|
1258
|
-
f' method
|
|
1259
|
-
f' batch_size
|
|
1260
|
-
f' init_sigma
|
|
1261
|
-
f' sigma_range
|
|
1262
|
-
f' scale_reward
|
|
1263
|
-
f'
|
|
1264
|
-
f'
|
|
1265
|
-
f'
|
|
1283
|
+
f' method ={self.__class__.__name__}\n'
|
|
1284
|
+
f' batch_size ={self.batch_size}\n'
|
|
1285
|
+
f' init_sigma ={self.init_sigma}\n'
|
|
1286
|
+
f' sigma_range ={self.sigma_range}\n'
|
|
1287
|
+
f' scale_reward ={self.scale_reward}\n'
|
|
1288
|
+
f' min_reward_scale ={self.min_reward_scale}\n'
|
|
1289
|
+
f' super_symmetric ={self.super_symmetric}\n'
|
|
1290
|
+
f' accurate ={self.super_symmetric_accurate}\n'
|
|
1291
|
+
f' optimizer ={self.optimizer_name}\n'
|
|
1266
1292
|
f' optimizer_kwargs:\n'
|
|
1267
1293
|
f' mu ={self.optimizer_kwargs_mu}\n'
|
|
1268
1294
|
f' sigma={self.optimizer_kwargs_sigma}\n'
|
|
1295
|
+
f' start_entropy_coeff={self.start_entropy_coeff}\n'
|
|
1296
|
+
f' end_entropy_coeff ={self.end_entropy_coeff}\n'
|
|
1297
|
+
f' max_kl_update ={self.max_kl}\n'
|
|
1269
1298
|
)
|
|
1270
1299
|
|
|
1271
1300
|
def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
|
|
1272
|
-
MIN_NORM = 1e-5
|
|
1273
1301
|
sigma0 = self.init_sigma
|
|
1274
1302
|
sigma_range = self.sigma_range
|
|
1275
1303
|
scale_reward = self.scale_reward
|
|
1304
|
+
min_reward_scale = self.min_reward_scale
|
|
1276
1305
|
super_symmetric = self.super_symmetric
|
|
1277
1306
|
super_symmetric_accurate = self.super_symmetric_accurate
|
|
1278
1307
|
batch_size = self.batch_size
|
|
1279
1308
|
optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
|
|
1280
|
-
|
|
1281
|
-
|
|
1309
|
+
max_kl = self.max_kl
|
|
1310
|
+
|
|
1311
|
+
# entropy regularization penalty is decayed exponentially by elapsed budget
|
|
1312
|
+
start_entropy_coeff = self.start_entropy_coeff
|
|
1313
|
+
if start_entropy_coeff == 0:
|
|
1314
|
+
entropy_coeff_decay = 0
|
|
1315
|
+
else:
|
|
1316
|
+
entropy_coeff_decay = (self.end_entropy_coeff / start_entropy_coeff) ** 0.01
|
|
1317
|
+
|
|
1318
|
+
# ***********************************************************************
|
|
1319
|
+
# INITIALIZATION OF POLICY
|
|
1320
|
+
#
|
|
1321
|
+
# ***********************************************************************
|
|
1322
|
+
|
|
1282
1323
|
def _jax_wrapped_pgpe_init(key, policy_params):
|
|
1283
1324
|
mu = policy_params
|
|
1284
1325
|
sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
|
|
@@ -1289,7 +1330,11 @@ class GaussianPGPE(PGPE):
|
|
|
1289
1330
|
|
|
1290
1331
|
self._initializer = jax.jit(_jax_wrapped_pgpe_init)
|
|
1291
1332
|
|
|
1292
|
-
#
|
|
1333
|
+
# ***********************************************************************
|
|
1334
|
+
# PARAMETER SAMPLING FUNCTIONS
|
|
1335
|
+
#
|
|
1336
|
+
# ***********************************************************************
|
|
1337
|
+
|
|
1293
1338
|
def _jax_wrapped_mu_noise(key, sigma):
|
|
1294
1339
|
return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
|
|
1295
1340
|
|
|
@@ -1299,19 +1344,20 @@ class GaussianPGPE(PGPE):
|
|
|
1299
1344
|
a = (sigma - jnp.abs(epsilon)) / sigma
|
|
1300
1345
|
if super_symmetric_accurate:
|
|
1301
1346
|
aa = jnp.abs(a)
|
|
1347
|
+
aa3 = jnp.power(aa, 3)
|
|
1302
1348
|
epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
|
|
1303
1349
|
a <= 0,
|
|
1304
|
-
jnp.exp(c1 *
|
|
1305
|
-
jnp.exp(aa - c3 * aa * jnp.log(1.0 -
|
|
1350
|
+
jnp.exp(c1 * (aa3 - aa) / jnp.log(aa + 1e-10) + c2 * aa),
|
|
1351
|
+
jnp.exp(aa - c3 * aa * jnp.log(1.0 - aa3 + 1e-10))
|
|
1306
1352
|
)
|
|
1307
1353
|
else:
|
|
1308
1354
|
epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
|
|
1309
1355
|
return epsilon_star
|
|
1310
1356
|
|
|
1311
1357
|
def _jax_wrapped_sample_params(key, mu, sigma):
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1358
|
+
treedef = jax.tree_util.tree_structure(sigma)
|
|
1359
|
+
keys = random.split(key, num=treedef.num_leaves)
|
|
1360
|
+
keys_pytree = jax.tree_util.tree_unflatten(treedef=treedef, leaves=keys)
|
|
1315
1361
|
epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
|
|
1316
1362
|
p1 = jax.tree_map(jnp.add, mu, epsilon)
|
|
1317
1363
|
p2 = jax.tree_map(jnp.subtract, mu, epsilon)
|
|
@@ -1321,14 +1367,18 @@ class GaussianPGPE(PGPE):
|
|
|
1321
1367
|
p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
|
|
1322
1368
|
else:
|
|
1323
1369
|
epsilon_star, p3, p4 = epsilon, p1, p2
|
|
1324
|
-
return
|
|
1370
|
+
return p1, p2, p3, p4, epsilon, epsilon_star
|
|
1325
1371
|
|
|
1326
|
-
#
|
|
1372
|
+
# ***********************************************************************
|
|
1373
|
+
# POLICY GRADIENT CALCULATION
|
|
1374
|
+
#
|
|
1375
|
+
# ***********************************************************************
|
|
1376
|
+
|
|
1327
1377
|
def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
|
|
1328
1378
|
if super_symmetric:
|
|
1329
1379
|
if scale_reward:
|
|
1330
|
-
scale1 = jnp.maximum(
|
|
1331
|
-
scale2 = jnp.maximum(
|
|
1380
|
+
scale1 = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
|
|
1381
|
+
scale2 = jnp.maximum(min_reward_scale, m - (r3 + r4) / 2)
|
|
1332
1382
|
else:
|
|
1333
1383
|
scale1 = scale2 = 1.0
|
|
1334
1384
|
r_mu1 = (r1 - r2) / (2 * scale1)
|
|
@@ -1336,37 +1386,37 @@ class GaussianPGPE(PGPE):
|
|
|
1336
1386
|
grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
|
|
1337
1387
|
else:
|
|
1338
1388
|
if scale_reward:
|
|
1339
|
-
scale = jnp.maximum(
|
|
1389
|
+
scale = jnp.maximum(min_reward_scale, m - (r1 + r2) / 2)
|
|
1340
1390
|
else:
|
|
1341
1391
|
scale = 1.0
|
|
1342
1392
|
r_mu = (r1 - r2) / (2 * scale)
|
|
1343
1393
|
grad = -r_mu * epsilon
|
|
1344
1394
|
return grad
|
|
1345
1395
|
|
|
1346
|
-
def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
|
|
1396
|
+
def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m, ent):
|
|
1347
1397
|
if super_symmetric:
|
|
1348
1398
|
mask = r1 + r2 >= r3 + r4
|
|
1349
1399
|
epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
|
|
1350
|
-
s = epsilon_tau
|
|
1400
|
+
s = jnp.square(epsilon_tau) / sigma - sigma
|
|
1351
1401
|
if scale_reward:
|
|
1352
|
-
scale = jnp.maximum(
|
|
1402
|
+
scale = jnp.maximum(min_reward_scale, m - (r1 + r2 + r3 + r4) / 4)
|
|
1353
1403
|
else:
|
|
1354
1404
|
scale = 1.0
|
|
1355
1405
|
r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
|
|
1356
1406
|
else:
|
|
1357
|
-
s = epsilon
|
|
1407
|
+
s = jnp.square(epsilon) / sigma - sigma
|
|
1358
1408
|
if scale_reward:
|
|
1359
|
-
scale = jnp.maximum(
|
|
1409
|
+
scale = jnp.maximum(min_reward_scale, jnp.abs(m))
|
|
1360
1410
|
else:
|
|
1361
1411
|
scale = 1.0
|
|
1362
1412
|
r_sigma = (r1 + r2) / (2 * scale)
|
|
1363
|
-
grad = -r_sigma * s
|
|
1413
|
+
grad = -(r_sigma * s + ent / sigma)
|
|
1364
1414
|
return grad
|
|
1365
1415
|
|
|
1366
|
-
def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
|
|
1416
|
+
def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max, ent,
|
|
1367
1417
|
policy_hyperparams, subs, model_params):
|
|
1368
1418
|
key, subkey = random.split(key)
|
|
1369
|
-
|
|
1419
|
+
p1, p2, p3, p4, epsilon, epsilon_star = _jax_wrapped_sample_params(
|
|
1370
1420
|
key, mu, sigma)
|
|
1371
1421
|
r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
|
|
1372
1422
|
r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
|
|
@@ -1384,42 +1434,76 @@ class GaussianPGPE(PGPE):
|
|
|
1384
1434
|
epsilon, epsilon_star
|
|
1385
1435
|
)
|
|
1386
1436
|
grad_sigma = jax.tree_map(
|
|
1387
|
-
partial(_jax_wrapped_sigma_grad,
|
|
1437
|
+
partial(_jax_wrapped_sigma_grad,
|
|
1438
|
+
r1=r1, r2=r2, r3=r3, r4=r4, m=r_max, ent=ent),
|
|
1388
1439
|
epsilon, epsilon_star, sigma
|
|
1389
1440
|
)
|
|
1390
1441
|
return grad_mu, grad_sigma, r_max
|
|
1391
1442
|
|
|
1392
|
-
def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
|
|
1443
|
+
def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max, ent,
|
|
1393
1444
|
policy_hyperparams, subs, model_params):
|
|
1394
1445
|
mu, sigma = pgpe_params
|
|
1395
1446
|
if batch_size == 1:
|
|
1396
1447
|
mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
|
|
1397
|
-
key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
|
|
1448
|
+
key, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
|
|
1398
1449
|
else:
|
|
1399
1450
|
keys = random.split(key, num=batch_size)
|
|
1400
1451
|
mu_grads, sigma_grads, r_maxs = jax.vmap(
|
|
1401
1452
|
_jax_wrapped_pgpe_grad,
|
|
1402
|
-
in_axes=(0, None, None, None, None, None, None)
|
|
1403
|
-
)(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
|
|
1453
|
+
in_axes=(0, None, None, None, None, None, None, None)
|
|
1454
|
+
)(keys, mu, sigma, r_max, ent, policy_hyperparams, subs, model_params)
|
|
1404
1455
|
mu_grad, sigma_grad = jax.tree_map(
|
|
1405
1456
|
partial(jnp.mean, axis=0), (mu_grads, sigma_grads))
|
|
1406
1457
|
new_r_max = jnp.max(r_maxs)
|
|
1407
1458
|
return mu_grad, sigma_grad, new_r_max
|
|
1459
|
+
|
|
1460
|
+
# ***********************************************************************
|
|
1461
|
+
# PARAMETER UPDATE
|
|
1462
|
+
#
|
|
1463
|
+
# ***********************************************************************
|
|
1408
1464
|
|
|
1409
|
-
def
|
|
1465
|
+
def _jax_wrapped_pgpe_kl_term(mu, sigma, old_mu, old_sigma):
|
|
1466
|
+
return 0.5 * jnp.sum(2 * jnp.log(sigma / old_sigma) +
|
|
1467
|
+
jnp.square(old_sigma / sigma) +
|
|
1468
|
+
jnp.square((mu - old_mu) / sigma) - 1)
|
|
1469
|
+
|
|
1470
|
+
def _jax_wrapped_pgpe_update(key, pgpe_params, r_max, progress,
|
|
1410
1471
|
policy_hyperparams, subs, model_params,
|
|
1411
1472
|
pgpe_opt_state):
|
|
1473
|
+
# regular update
|
|
1412
1474
|
mu, sigma = pgpe_params
|
|
1413
1475
|
mu_state, sigma_state = pgpe_opt_state
|
|
1476
|
+
ent = start_entropy_coeff * jnp.power(entropy_coeff_decay, progress)
|
|
1414
1477
|
mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
|
|
1415
|
-
key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
|
|
1478
|
+
key, pgpe_params, r_max, ent, policy_hyperparams, subs, model_params)
|
|
1416
1479
|
mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
|
|
1417
1480
|
sigma_updates, new_sigma_state = sigma_optimizer.update(
|
|
1418
1481
|
sigma_grad, sigma_state, params=sigma)
|
|
1419
1482
|
new_mu = optax.apply_updates(mu, mu_updates)
|
|
1420
|
-
new_mu, converged = projection(new_mu, policy_hyperparams)
|
|
1421
1483
|
new_sigma = optax.apply_updates(sigma, sigma_updates)
|
|
1422
1484
|
new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
|
|
1485
|
+
|
|
1486
|
+
# respect KL divergence contraint with old parameters
|
|
1487
|
+
if max_kl is not None:
|
|
1488
|
+
old_mu_lr = new_mu_state.hyperparams['learning_rate']
|
|
1489
|
+
old_sigma_lr = new_sigma_state.hyperparams['learning_rate']
|
|
1490
|
+
kl_terms = jax.tree_map(
|
|
1491
|
+
_jax_wrapped_pgpe_kl_term, new_mu, new_sigma, mu, sigma)
|
|
1492
|
+
total_kl = jax.tree_util.tree_reduce(jnp.add, kl_terms)
|
|
1493
|
+
kl_reduction = jnp.minimum(1.0, jnp.sqrt(max_kl / total_kl))
|
|
1494
|
+
mu_state.hyperparams['learning_rate'] = old_mu_lr * kl_reduction
|
|
1495
|
+
sigma_state.hyperparams['learning_rate'] = old_sigma_lr * kl_reduction
|
|
1496
|
+
mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
|
|
1497
|
+
sigma_updates, new_sigma_state = sigma_optimizer.update(
|
|
1498
|
+
sigma_grad, sigma_state, params=sigma)
|
|
1499
|
+
new_mu = optax.apply_updates(mu, mu_updates)
|
|
1500
|
+
new_sigma = optax.apply_updates(sigma, sigma_updates)
|
|
1501
|
+
new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
|
|
1502
|
+
new_mu_state.hyperparams['learning_rate'] = old_mu_lr
|
|
1503
|
+
new_sigma_state.hyperparams['learning_rate'] = old_sigma_lr
|
|
1504
|
+
|
|
1505
|
+
# apply projection step and finalize results
|
|
1506
|
+
new_mu, converged = projection(new_mu, policy_hyperparams)
|
|
1423
1507
|
new_pgpe_params = (new_mu, new_sigma)
|
|
1424
1508
|
new_pgpe_opt_state = (new_mu_state, new_sigma_state)
|
|
1425
1509
|
policy_params = new_mu
|
|
@@ -1462,14 +1546,14 @@ def mean_deviation_utility(returns: jnp.ndarray, beta: float) -> float:
|
|
|
1462
1546
|
@jax.jit
|
|
1463
1547
|
def mean_semideviation_utility(returns: jnp.ndarray, beta: float) -> float:
|
|
1464
1548
|
mu = jnp.mean(returns)
|
|
1465
|
-
msd = jnp.sqrt(jnp.mean(jnp.minimum(0.0, returns - mu)
|
|
1549
|
+
msd = jnp.sqrt(jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu))))
|
|
1466
1550
|
return mu - 0.5 * beta * msd
|
|
1467
1551
|
|
|
1468
1552
|
|
|
1469
1553
|
@jax.jit
|
|
1470
1554
|
def mean_semivariance_utility(returns: jnp.ndarray, beta: float) -> float:
|
|
1471
1555
|
mu = jnp.mean(returns)
|
|
1472
|
-
msv = jnp.mean(jnp.minimum(0.0, returns - mu)
|
|
1556
|
+
msv = jnp.mean(jnp.square(jnp.minimum(0.0, returns - mu)))
|
|
1473
1557
|
return mu - 0.5 * beta * msv
|
|
1474
1558
|
|
|
1475
1559
|
|
|
@@ -1768,7 +1852,6 @@ r"""
|
|
|
1768
1852
|
|
|
1769
1853
|
# optimization
|
|
1770
1854
|
self.update = self._jax_update(train_loss)
|
|
1771
|
-
self.check_zero_grad = self._jax_check_zero_gradients()
|
|
1772
1855
|
|
|
1773
1856
|
# pgpe option
|
|
1774
1857
|
if self.use_pgpe:
|
|
@@ -1831,6 +1914,12 @@ r"""
|
|
|
1831
1914
|
projection = self.plan.projection
|
|
1832
1915
|
use_ls = self.line_search_kwargs is not None
|
|
1833
1916
|
|
|
1917
|
+
# check if the gradients are all zeros
|
|
1918
|
+
def _jax_wrapped_zero_gradients(grad):
|
|
1919
|
+
leaves, _ = jax.tree_util.tree_flatten(
|
|
1920
|
+
jax.tree_map(lambda g: jnp.allclose(g, 0), grad))
|
|
1921
|
+
return jnp.all(jnp.asarray(leaves))
|
|
1922
|
+
|
|
1834
1923
|
# calculate the plan gradient w.r.t. return loss and update optimizer
|
|
1835
1924
|
# also perform a projection step to satisfy constraints on actions
|
|
1836
1925
|
def _jax_wrapped_loss_swapped(policy_params, key, policy_hyperparams,
|
|
@@ -1855,23 +1944,12 @@ r"""
|
|
|
1855
1944
|
policy_params, converged = projection(policy_params, policy_hyperparams)
|
|
1856
1945
|
log['grad'] = grad
|
|
1857
1946
|
log['updates'] = updates
|
|
1947
|
+
zero_grads = _jax_wrapped_zero_gradients(grad)
|
|
1858
1948
|
return policy_params, converged, opt_state, opt_aux, \
|
|
1859
|
-
loss_val, log, model_params
|
|
1949
|
+
loss_val, log, model_params, zero_grads
|
|
1860
1950
|
|
|
1861
1951
|
return jax.jit(_jax_wrapped_plan_update)
|
|
1862
1952
|
|
|
1863
|
-
def _jax_check_zero_gradients(self):
|
|
1864
|
-
|
|
1865
|
-
def _jax_wrapped_zero_gradient(grad):
|
|
1866
|
-
return jnp.allclose(grad, 0)
|
|
1867
|
-
|
|
1868
|
-
def _jax_wrapped_zero_gradients(grad):
|
|
1869
|
-
leaves, _ = jax.tree_util.tree_flatten(
|
|
1870
|
-
jax.tree_map(_jax_wrapped_zero_gradient, grad))
|
|
1871
|
-
return jnp.all(jnp.asarray(leaves))
|
|
1872
|
-
|
|
1873
|
-
return jax.jit(_jax_wrapped_zero_gradients)
|
|
1874
|
-
|
|
1875
1953
|
def _batched_init_subs(self, subs):
|
|
1876
1954
|
rddl = self.rddl
|
|
1877
1955
|
n_train, n_test = self.batch_size_train, self.batch_size_test
|
|
@@ -2175,11 +2253,12 @@ r"""
|
|
|
2175
2253
|
# ======================================================================
|
|
2176
2254
|
|
|
2177
2255
|
# initialize running statistics
|
|
2178
|
-
best_params, best_loss, best_grad = policy_params, jnp.inf,
|
|
2256
|
+
best_params, best_loss, best_grad = policy_params, jnp.inf, None
|
|
2179
2257
|
last_iter_improve = 0
|
|
2180
2258
|
rolling_test_loss = RollingMean(test_rolling_window)
|
|
2181
2259
|
log = {}
|
|
2182
2260
|
status = JaxPlannerStatus.NORMAL
|
|
2261
|
+
progress_percent = 0
|
|
2183
2262
|
|
|
2184
2263
|
# initialize stopping criterion
|
|
2185
2264
|
if stopping_rule is not None:
|
|
@@ -2191,18 +2270,19 @@ r"""
|
|
|
2191
2270
|
dashboard_id, dashboard.get_planner_info(self),
|
|
2192
2271
|
key=dash_key, viz=self.dashboard_viz)
|
|
2193
2272
|
|
|
2273
|
+
# progress bar
|
|
2274
|
+
if print_progress:
|
|
2275
|
+
progress_bar = tqdm(None, total=100, position=tqdm_position,
|
|
2276
|
+
bar_format='{l_bar}{bar}| {elapsed} {postfix}')
|
|
2277
|
+
else:
|
|
2278
|
+
progress_bar = None
|
|
2279
|
+
position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
|
|
2280
|
+
|
|
2194
2281
|
# ======================================================================
|
|
2195
2282
|
# MAIN TRAINING LOOP BEGINS
|
|
2196
2283
|
# ======================================================================
|
|
2197
2284
|
|
|
2198
|
-
|
|
2199
|
-
if print_progress:
|
|
2200
|
-
iters = tqdm(iters, total=100,
|
|
2201
|
-
bar_format='{l_bar}{bar}| {elapsed} {postfix}',
|
|
2202
|
-
position=tqdm_position)
|
|
2203
|
-
position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
|
|
2204
|
-
|
|
2205
|
-
for it in iters:
|
|
2285
|
+
for it in range(epochs):
|
|
2206
2286
|
|
|
2207
2287
|
# ==================================================================
|
|
2208
2288
|
# NEXT GRADIENT DESCENT STEP
|
|
@@ -2213,8 +2293,9 @@ r"""
|
|
|
2213
2293
|
# update the parameters of the plan
|
|
2214
2294
|
key, subkey = random.split(key)
|
|
2215
2295
|
(policy_params, converged, opt_state, opt_aux, train_loss, train_log,
|
|
2216
|
-
model_params) = self.update(
|
|
2217
|
-
|
|
2296
|
+
model_params, zero_grads) = self.update(
|
|
2297
|
+
subkey, policy_params, policy_hyperparams, train_subs, model_params,
|
|
2298
|
+
opt_state, opt_aux)
|
|
2218
2299
|
test_loss, (test_log, model_params_test) = self.test_loss(
|
|
2219
2300
|
subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
|
|
2220
2301
|
test_loss_smooth = rolling_test_loss.update(test_loss)
|
|
@@ -2224,8 +2305,9 @@ r"""
|
|
|
2224
2305
|
if self.use_pgpe:
|
|
2225
2306
|
key, subkey = random.split(key)
|
|
2226
2307
|
pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
|
|
2227
|
-
self.pgpe.update(subkey, pgpe_params, r_max,
|
|
2228
|
-
test_subs,
|
|
2308
|
+
self.pgpe.update(subkey, pgpe_params, r_max, progress_percent,
|
|
2309
|
+
policy_hyperparams, test_subs, model_params_test,
|
|
2310
|
+
pgpe_opt_state)
|
|
2229
2311
|
pgpe_loss, _ = self.test_loss(
|
|
2230
2312
|
subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
|
|
2231
2313
|
pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
|
|
@@ -2252,7 +2334,7 @@ r"""
|
|
|
2252
2334
|
# ==================================================================
|
|
2253
2335
|
|
|
2254
2336
|
# no progress
|
|
2255
|
-
if (not pgpe_improve) and
|
|
2337
|
+
if (not pgpe_improve) and zero_grads:
|
|
2256
2338
|
status = JaxPlannerStatus.NO_PROGRESS
|
|
2257
2339
|
|
|
2258
2340
|
# constraint satisfaction problem
|
|
@@ -2311,14 +2393,15 @@ r"""
|
|
|
2311
2393
|
|
|
2312
2394
|
# if the progress bar is used
|
|
2313
2395
|
if print_progress:
|
|
2314
|
-
|
|
2315
|
-
iters.set_description(
|
|
2396
|
+
progress_bar.set_description(
|
|
2316
2397
|
f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
|
|
2317
2398
|
f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
|
|
2318
2399
|
f'{status.value} status / {total_pgpe_it:6} pgpe',
|
|
2319
2400
|
refresh=False
|
|
2320
2401
|
)
|
|
2321
|
-
|
|
2402
|
+
progress_bar.set_postfix_str(
|
|
2403
|
+
f"{(it + 1) / (elapsed + 1e-6):.2f}it/s", refresh=False)
|
|
2404
|
+
progress_bar.update(progress_percent - progress_bar.n)
|
|
2322
2405
|
|
|
2323
2406
|
# dash-board
|
|
2324
2407
|
if dashboard is not None:
|
|
@@ -2339,7 +2422,7 @@ r"""
|
|
|
2339
2422
|
|
|
2340
2423
|
# release resources
|
|
2341
2424
|
if print_progress:
|
|
2342
|
-
|
|
2425
|
+
progress_bar.close()
|
|
2343
2426
|
|
|
2344
2427
|
# validate the test return
|
|
2345
2428
|
if log:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: pyRDDLGym-jax
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.3
|
|
4
4
|
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
|
|
5
5
|
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
|
|
6
6
|
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
|
|
@@ -58,8 +58,11 @@ Dynamic: summary
|
|
|
58
58
|
|
|
59
59
|
Purpose:
|
|
60
60
|
|
|
61
|
-
1. automatic translation of
|
|
62
|
-
2.
|
|
61
|
+
1. automatic translation of RDDL description files into differentiable JAX simulators
|
|
62
|
+
2. implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
|
|
63
|
+
3. flexible policy representations and automated Bayesian hyper-parameter tuning
|
|
64
|
+
4. interactive dashboard for dynamic visualization and debugging
|
|
65
|
+
5. hybridization with parameter-exploring policy gradients.
|
|
63
66
|
|
|
64
67
|
Some demos of solved problems by JaxPlan:
|
|
65
68
|
|
|
@@ -235,8 +238,23 @@ More documentation about this and other new features will be coming soon.
|
|
|
235
238
|
|
|
236
239
|
## Tuning the Planner
|
|
237
240
|
|
|
238
|
-
|
|
239
|
-
|
|
241
|
+
A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
|
|
242
|
+
|
|
243
|
+
```shell
|
|
244
|
+
jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
|
|
245
|
+
```
|
|
246
|
+
|
|
247
|
+
where:
|
|
248
|
+
- ``domain`` is the domain identifier as specified in rddlrepository
|
|
249
|
+
- ``instance`` is the instance identifier
|
|
250
|
+
- ``method`` is the planning method to use (i.e. drp, slp, replan)
|
|
251
|
+
- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
|
|
252
|
+
- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
|
|
253
|
+
- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, i.e. the total number of evaluations is ``iters * workers``
|
|
254
|
+
- ``dashboard`` is whether the optimizations are tracked in the dashboard application.
|
|
255
|
+
|
|
256
|
+
It is easy to tune a custom range of the planner's hyper-parameters efficiently.
|
|
257
|
+
First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:
|
|
240
258
|
|
|
241
259
|
```ini
|
|
242
260
|
[Model]
|
|
@@ -260,7 +278,7 @@ train_on_reset=True
|
|
|
260
278
|
|
|
261
279
|
would allow tuning the sharpness of model relaxations and the learning rate of the optimizer.
|
|
262
280
|
|
|
263
|
-
Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand:
|
|
281
|
+
Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:
|
|
264
282
|
|
|
265
283
|
```python
|
|
266
284
|
import pyRDDLGym
|
|
@@ -292,22 +310,7 @@ tuning = JaxParameterTuning(env=env,
|
|
|
292
310
|
gp_iters=iters)
|
|
293
311
|
tuning.tune(key=42, log_file='path/to/log.csv')
|
|
294
312
|
```
|
|
295
|
-
|
|
296
|
-
A basic run script is provided to run the automatic hyper-parameter tuning for the most sensitive parameters of JaxPlan:
|
|
297
|
-
|
|
298
|
-
```shell
|
|
299
|
-
jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>
|
|
300
|
-
```
|
|
301
|
-
|
|
302
|
-
where:
|
|
303
|
-
- ``domain`` is the domain identifier as specified in rddlrepository
|
|
304
|
-
- ``instance`` is the instance identifier
|
|
305
|
-
- ``method`` is the planning method to use (i.e. drp, slp, replan)
|
|
306
|
-
- ``trials`` is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
|
|
307
|
-
- ``iters`` is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
|
|
308
|
-
- ``workers`` is the (optional) number of parallel evaluations to be done at each iteration, e.g. the total evaluations = ``iters * workers``
|
|
309
|
-
- ``dashboard`` is whether the optimizations are tracked in the dashboard application.
|
|
310
|
-
|
|
313
|
+
|
|
311
314
|
|
|
312
315
|
## Simulation
|
|
313
316
|
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
pyRDDLGym_jax/__init__.py,sha256=
|
|
1
|
+
pyRDDLGym_jax/__init__.py,sha256=ab_pLSTaKv50-5b6lazl75TqhQi0bNsErQ8JlBepVII,19
|
|
2
2
|
pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
|
|
3
3
|
pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
|
-
pyRDDLGym_jax/core/compiler.py,sha256=
|
|
5
|
-
pyRDDLGym_jax/core/logic.py,sha256=
|
|
6
|
-
pyRDDLGym_jax/core/planner.py,sha256=
|
|
4
|
+
pyRDDLGym_jax/core/compiler.py,sha256=fLOdJED-Cxtm_IT4LRiZ461Alp9Qjr0vBsOnw1s__EY,82612
|
|
5
|
+
pyRDDLGym_jax/core/logic.py,sha256=0NNm0OaeKv46K0VNY6vL0PHOUFZPNxqQLOvQYkHCswM,56093
|
|
6
|
+
pyRDDLGym_jax/core/planner.py,sha256=0rluBXKGNHRPEPfegOWcx9__cJHr8KjZdDJtG7i1JjI,122793
|
|
7
7
|
pyRDDLGym_jax/core/simulator.py,sha256=DnPL93WVCMZqtqMUoiJdfWcH9pEvNgGfDfO4NV0wIS0,9271
|
|
8
8
|
pyRDDLGym_jax/core/tuning.py,sha256=RKKtDZp7unvfbhZEoaunZtcAn5xtzGYqXBB_Ij_Aapc,24205
|
|
9
9
|
pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
|
|
@@ -41,9 +41,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
|
|
|
41
41
|
pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
|
|
42
42
|
pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
|
|
43
43
|
pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
|
|
44
|
-
pyrddlgym_jax-2.
|
|
45
|
-
pyrddlgym_jax-2.
|
|
46
|
-
pyrddlgym_jax-2.
|
|
47
|
-
pyrddlgym_jax-2.
|
|
48
|
-
pyrddlgym_jax-2.
|
|
49
|
-
pyrddlgym_jax-2.
|
|
44
|
+
pyrddlgym_jax-2.3.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
|
|
45
|
+
pyrddlgym_jax-2.3.dist-info/METADATA,sha256=MS6tckyg-bAQBGZJ112VQPZm5at660EfhntCnfrlUbE,17021
|
|
46
|
+
pyrddlgym_jax-2.3.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
|
|
47
|
+
pyrddlgym_jax-2.3.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
|
|
48
|
+
pyrddlgym_jax-2.3.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
|
|
49
|
+
pyrddlgym_jax-2.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|