pyRDDLGym-jax 1.2__py3-none-any.whl → 1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/planner.py +73 -40
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-1.3.dist-info}/METADATA +1 -1
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-1.3.dist-info}/RECORD +8 -8
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-1.3.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-1.3.dist-info}/WHEEL +0 -0
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-1.3.dist-info}/entry_points.txt +0 -0
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-1.3.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '1.2'
|
|
1
|
+
__version__ = '1.3'
|
pyRDDLGym_jax/core/planner.py
CHANGED
|
@@ -655,7 +655,10 @@ class JaxStraightLinePlan(JaxPlan):
|
|
|
655
655
|
if ranges[var] == 'bool':
|
|
656
656
|
param_flat = jnp.ravel(param)
|
|
657
657
|
if noop[var]:
|
|
658
|
-
|
|
658
|
+
if wrap_sigmoid:
|
|
659
|
+
param_flat = -param_flat
|
|
660
|
+
else:
|
|
661
|
+
param_flat = 1.0 - param_flat
|
|
659
662
|
scores.append(param_flat)
|
|
660
663
|
scores = jnp.concatenate(scores)
|
|
661
664
|
descending = jnp.sort(scores)[::-1]
|
|
@@ -666,7 +669,10 @@ class JaxStraightLinePlan(JaxPlan):
|
|
|
666
669
|
new_params = {}
|
|
667
670
|
for (var, param) in params.items():
|
|
668
671
|
if ranges[var] == 'bool':
|
|
669
|
-
|
|
672
|
+
if noop[var]:
|
|
673
|
+
new_param = param + surplus
|
|
674
|
+
else:
|
|
675
|
+
new_param = param - surplus
|
|
670
676
|
new_param = _jax_project_bool_to_box(var, new_param, hyperparams)
|
|
671
677
|
else:
|
|
672
678
|
new_param = param
|
|
@@ -687,57 +693,73 @@ class JaxStraightLinePlan(JaxPlan):
|
|
|
687
693
|
elif use_constraint_satisfaction and not self._use_new_projection:
|
|
688
694
|
|
|
689
695
|
# calculate the surplus of actions above max-nondef-actions
|
|
690
|
-
def _jax_wrapped_sogbofa_surplus(
|
|
691
|
-
sum_action,
|
|
692
|
-
for (var,
|
|
696
|
+
def _jax_wrapped_sogbofa_surplus(actions):
|
|
697
|
+
sum_action, k = 0.0, 0
|
|
698
|
+
for (var, action) in actions.items():
|
|
693
699
|
if ranges[var] == 'bool':
|
|
694
|
-
action = _jax_bool_param_to_action(var, param, hyperparams)
|
|
695
700
|
if noop[var]:
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
sum_action += jnp.sum(action)
|
|
700
|
-
count += jnp.sum(action > 0)
|
|
701
|
+
action = 1 - action
|
|
702
|
+
sum_action += jnp.sum(action)
|
|
703
|
+
k += jnp.count_nonzero(action)
|
|
701
704
|
surplus = jnp.maximum(sum_action - allowed_actions, 0.0)
|
|
702
|
-
|
|
703
|
-
return surplus / count
|
|
705
|
+
return surplus, k
|
|
704
706
|
|
|
705
707
|
# return whether the surplus is positive or reached compute limit
|
|
706
708
|
max_constraint_iter = self._max_constraint_iter
|
|
707
709
|
|
|
708
710
|
def _jax_wrapped_sogbofa_continue(values):
|
|
709
|
-
it, _,
|
|
710
|
-
return jnp.logical_and(
|
|
711
|
+
it, _, surplus, k = values
|
|
712
|
+
return jnp.logical_and(
|
|
713
|
+
it < max_constraint_iter, jnp.logical_and(surplus > 0, k > 0))
|
|
711
714
|
|
|
712
715
|
# reduce all bool action values by the surplus clipping at minimum
|
|
713
716
|
# for no-op = True, do the opposite, i.e. increase all
|
|
714
717
|
# bool action values by surplus clipping at maximum
|
|
715
718
|
def _jax_wrapped_sogbofa_subtract_surplus(values):
|
|
716
|
-
it,
|
|
717
|
-
|
|
718
|
-
|
|
719
|
+
it, actions, surplus, k = values
|
|
720
|
+
amount = surplus / k
|
|
721
|
+
new_actions = {}
|
|
722
|
+
for (var, action) in actions.items():
|
|
719
723
|
if ranges[var] == 'bool':
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
+
if noop[var]:
|
|
725
|
+
new_actions[var] = jnp.minimum(action + amount, 1)
|
|
726
|
+
else:
|
|
727
|
+
new_actions[var] = jnp.maximum(action - amount, 0)
|
|
724
728
|
else:
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
new_surplus = _jax_wrapped_sogbofa_surplus(new_params, hyperparams)
|
|
729
|
+
new_actions[var] = action
|
|
730
|
+
new_surplus, new_k = _jax_wrapped_sogbofa_surplus(new_actions)
|
|
728
731
|
new_it = it + 1
|
|
729
|
-
return new_it,
|
|
732
|
+
return new_it, new_actions, new_surplus, new_k
|
|
730
733
|
|
|
731
734
|
# apply the surplus to the actions until it becomes zero
|
|
732
735
|
def _jax_wrapped_sogbofa_project(params, hyperparams):
|
|
733
|
-
|
|
734
|
-
|
|
736
|
+
|
|
737
|
+
# convert parameters to actions
|
|
738
|
+
actions = {}
|
|
739
|
+
for (var, param) in params.items():
|
|
740
|
+
if ranges[var] == 'bool':
|
|
741
|
+
actions[var] = _jax_bool_param_to_action(var, param, hyperparams)
|
|
742
|
+
else:
|
|
743
|
+
actions[var] = param
|
|
744
|
+
|
|
745
|
+
# run SOGBOFA loop on the actions to get adjusted actions
|
|
746
|
+
surplus, k = _jax_wrapped_sogbofa_surplus(actions)
|
|
747
|
+
_, actions, surplus, k = jax.lax.while_loop(
|
|
735
748
|
cond_fun=_jax_wrapped_sogbofa_continue,
|
|
736
749
|
body_fun=_jax_wrapped_sogbofa_subtract_surplus,
|
|
737
|
-
init_val=(0,
|
|
750
|
+
init_val=(0, actions, surplus, k)
|
|
738
751
|
)
|
|
739
752
|
converged = jnp.logical_not(surplus > 0)
|
|
740
|
-
|
|
753
|
+
|
|
754
|
+
# convert the adjusted actions back to parameters
|
|
755
|
+
new_params = {}
|
|
756
|
+
for (var, action) in actions.items():
|
|
757
|
+
if ranges[var] == 'bool':
|
|
758
|
+
action = jnp.clip(action, min_action, max_action)
|
|
759
|
+
new_params[var] = _jax_bool_action_to_param(var, action, hyperparams)
|
|
760
|
+
else:
|
|
761
|
+
new_params[var] = action
|
|
762
|
+
return new_params, converged
|
|
741
763
|
|
|
742
764
|
# clip actions to valid bounds and satisfy constraint on max actions
|
|
743
765
|
def _jax_wrapped_slp_project_to_max_constraint(params, hyperparams):
|
|
@@ -1415,6 +1437,7 @@ r"""
|
|
|
1415
1437
|
|
|
1416
1438
|
# optimization
|
|
1417
1439
|
self.update = self._jax_update(train_loss)
|
|
1440
|
+
self.check_zero_grad = self._jax_check_zero_gradients()
|
|
1418
1441
|
|
|
1419
1442
|
def _jax_return(self, use_symlog):
|
|
1420
1443
|
gamma = self.rddl.discount
|
|
@@ -1497,6 +1520,18 @@ r"""
|
|
|
1497
1520
|
|
|
1498
1521
|
return jax.jit(_jax_wrapped_plan_update)
|
|
1499
1522
|
|
|
1523
|
+
def _jax_check_zero_gradients(self):
|
|
1524
|
+
|
|
1525
|
+
def _jax_wrapped_zero_gradient(grad):
|
|
1526
|
+
return jnp.allclose(grad, 0)
|
|
1527
|
+
|
|
1528
|
+
def _jax_wrapped_zero_gradients(grad):
|
|
1529
|
+
leaves, _ = jax.tree_util.tree_flatten(
|
|
1530
|
+
jax.tree_map(_jax_wrapped_zero_gradient, grad))
|
|
1531
|
+
return jnp.all(jnp.asarray(leaves))
|
|
1532
|
+
|
|
1533
|
+
return jax.jit(_jax_wrapped_zero_gradients)
|
|
1534
|
+
|
|
1500
1535
|
def _batched_init_subs(self, subs):
|
|
1501
1536
|
rddl = self.rddl
|
|
1502
1537
|
n_train, n_test = self.batch_size_train, self.batch_size_test
|
|
@@ -1795,7 +1830,6 @@ r"""
|
|
|
1795
1830
|
rolling_test_loss = RollingMean(test_rolling_window)
|
|
1796
1831
|
log = {}
|
|
1797
1832
|
status = JaxPlannerStatus.NORMAL
|
|
1798
|
-
is_all_zero_fn = lambda x: np.allclose(x, 0)
|
|
1799
1833
|
|
|
1800
1834
|
# initialize stopping criterion
|
|
1801
1835
|
if stopping_rule is not None:
|
|
@@ -1836,9 +1870,7 @@ r"""
|
|
|
1836
1870
|
# ==================================================================
|
|
1837
1871
|
|
|
1838
1872
|
# no progress
|
|
1839
|
-
|
|
1840
|
-
jax.tree_map(is_all_zero_fn, train_log['grad']))
|
|
1841
|
-
if np.all(grad_norm_zero):
|
|
1873
|
+
if self.check_zero_grad(train_log['grad']):
|
|
1842
1874
|
status = JaxPlannerStatus.NO_PROGRESS
|
|
1843
1875
|
|
|
1844
1876
|
# constraint satisfaction problem
|
|
@@ -2035,8 +2067,8 @@ r"""
|
|
|
2035
2067
|
# must be numeric array
|
|
2036
2068
|
# exception is for POMDPs at 1st epoch when observ-fluents are None
|
|
2037
2069
|
dtype = np.atleast_1d(values).dtype
|
|
2038
|
-
if not
|
|
2039
|
-
and not
|
|
2070
|
+
if not np.issubdtype(dtype, np.number) \
|
|
2071
|
+
and not np.issubdtype(dtype, np.bool_):
|
|
2040
2072
|
if step == 0 and var in self.rddl.observ_fluents:
|
|
2041
2073
|
subs[var] = self.test_compiled.init_values[var]
|
|
2042
2074
|
else:
|
|
@@ -2077,10 +2109,11 @@ def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
|
|
|
2077
2109
|
|
|
2078
2110
|
@jax.jit
|
|
2079
2111
|
def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
|
|
2080
|
-
|
|
2081
|
-
|
|
2082
|
-
|
|
2083
|
-
|
|
2112
|
+
var = jnp.percentile(returns, q=100 * alpha)
|
|
2113
|
+
mask = returns <= var
|
|
2114
|
+
weights = mask / jnp.maximum(1, jnp.sum(mask))
|
|
2115
|
+
return jnp.sum(returns * weights)
|
|
2116
|
+
|
|
2084
2117
|
|
|
2085
2118
|
# ***********************************************************************
|
|
2086
2119
|
# ALL VERSIONS OF CONTROLLERS
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: pyRDDLGym-jax
|
|
3
|
-
Version: 1.2
|
|
3
|
+
Version: 1.3
|
|
4
4
|
Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
|
|
5
5
|
Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
|
|
6
6
|
Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
pyRDDLGym_jax/__init__.py,sha256=
|
|
1
|
+
pyRDDLGym_jax/__init__.py,sha256=p_veRZMP15-djJyMuDHT7Ul1RbCCHpYsZ9LO0GD1URo,19
|
|
2
2
|
pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
|
|
3
3
|
pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
pyRDDLGym_jax/core/compiler.py,sha256=qy1TSivdpuZxWecDl5HEM0PXX45JB7DHzV7uAB8kmbE,88696
|
|
5
5
|
pyRDDLGym_jax/core/logic.py,sha256=iYvLgWyQd_mrkwwoeRWao9NzjmhsObQnPq4DphILw1Q,38425
|
|
6
|
-
pyRDDLGym_jax/core/planner.py,sha256=
|
|
6
|
+
pyRDDLGym_jax/core/planner.py,sha256=TFFy91aCzRW600k_eP-7i2Gvp9wpNVjXlXtBnt9x03M,101744
|
|
7
7
|
pyRDDLGym_jax/core/simulator.py,sha256=JpmwfPqYPBfEhmQ04ufBeclZOQ-U1ZiyAtLf1AIwO2M,8462
|
|
8
8
|
pyRDDLGym_jax/core/tuning.py,sha256=LBhoVQZWWhYQj89gpM2B4xVHlYlKDt4psw4Be9cBbSY,23685
|
|
9
9
|
pyRDDLGym_jax/core/visualization.py,sha256=uKhC8z0TeX9BklPNoxSVt0g5pkqhgxrQClQAih78ybY,68292
|
|
@@ -41,9 +41,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
|
|
|
41
41
|
pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
|
|
42
42
|
pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
|
|
43
43
|
pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
|
|
44
|
-
pyRDDLGym_jax-1.
|
|
45
|
-
pyRDDLGym_jax-1.
|
|
46
|
-
pyRDDLGym_jax-1.
|
|
47
|
-
pyRDDLGym_jax-1.
|
|
48
|
-
pyRDDLGym_jax-1.
|
|
49
|
-
pyRDDLGym_jax-1.
|
|
44
|
+
pyRDDLGym_jax-1.3.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
|
|
45
|
+
pyRDDLGym_jax-1.3.dist-info/METADATA,sha256=Colu-byYJ4RF5sr1qOVKg9VhCbrLnv32OvHt_A9KtLE,15090
|
|
46
|
+
pyRDDLGym_jax-1.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
47
|
+
pyRDDLGym_jax-1.3.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
|
|
48
|
+
pyRDDLGym_jax-1.3.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
|
|
49
|
+
pyRDDLGym_jax-1.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|