pyRDDLGym-jax 1.2-py3-none-any.whl → 1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyRDDLGym_jax/__init__.py CHANGED
@@ -1 +1 @@
-__version__ = '1.2'
+__version__ = '1.3'
pyRDDLGym_jax/core/planner.py CHANGED
@@ -655,7 +655,10 @@ class JaxStraightLinePlan(JaxPlan):
                 if ranges[var] == 'bool':
                     param_flat = jnp.ravel(param)
                     if noop[var]:
-                        param_flat = (-param_flat) if wrap_sigmoid else 1.0 - param_flat
+                        if wrap_sigmoid:
+                            param_flat = -param_flat
+                        else:
+                            param_flat = 1.0 - param_flat
                     scores.append(param_flat)
             scores = jnp.concatenate(scores)
             descending = jnp.sort(scores)[::-1]
@@ -666,7 +669,10 @@ class JaxStraightLinePlan(JaxPlan):
                 new_params = {}
                 for (var, param) in params.items():
                     if ranges[var] == 'bool':
-                        new_param = param + (surplus if noop[var] else -surplus)
+                        if noop[var]:
+                            new_param = param + surplus
+                        else:
+                            new_param = param - surplus
                         new_param = _jax_project_bool_to_box(var, new_param, hyperparams)
                     else:
                         new_param = param
@@ -687,57 +693,73 @@ class JaxStraightLinePlan(JaxPlan):
         elif use_constraint_satisfaction and not self._use_new_projection:
 
             # calculate the surplus of actions above max-nondef-actions
-            def _jax_wrapped_sogbofa_surplus(params, hyperparams):
-                sum_action, count = 0.0, 0
-                for (var, param) in params.items():
+            def _jax_wrapped_sogbofa_surplus(actions):
+                sum_action, k = 0.0, 0
+                for (var, action) in actions.items():
                     if ranges[var] == 'bool':
-                        action = _jax_bool_param_to_action(var, param, hyperparams)
                         if noop[var]:
-                            sum_action += jnp.size(action) - jnp.sum(action)
-                            count += jnp.sum(action < 1)
-                        else:
-                            sum_action += jnp.sum(action)
-                            count += jnp.sum(action > 0)
+                            action = 1 - action
+                        sum_action += jnp.sum(action)
+                        k += jnp.count_nonzero(action)
                 surplus = jnp.maximum(sum_action - allowed_actions, 0.0)
-                count = jnp.maximum(count, 1)
-                return surplus / count
+                return surplus, k
 
             # return whether the surplus is positive or reached compute limit
             max_constraint_iter = self._max_constraint_iter
 
             def _jax_wrapped_sogbofa_continue(values):
-                it, _, _, surplus = values
-                return jnp.logical_and(it < max_constraint_iter, surplus > 0)
+                it, _, surplus, k = values
+                return jnp.logical_and(
+                    it < max_constraint_iter, jnp.logical_and(surplus > 0, k > 0))
 
             # reduce all bool action values by the surplus clipping at minimum
             # for no-op = True, do the opposite, i.e. increase all
             # bool action values by surplus clipping at maximum
             def _jax_wrapped_sogbofa_subtract_surplus(values):
-                it, params, hyperparams, surplus = values
-                new_params = {}
-                for (var, param) in params.items():
+                it, actions, surplus, k = values
+                amount = surplus / k
+                new_actions = {}
+                for (var, action) in actions.items():
                     if ranges[var] == 'bool':
-                        action = _jax_bool_param_to_action(var, param, hyperparams)
-                        new_action = action + (surplus if noop[var] else -surplus)
-                        new_action = jnp.clip(new_action, min_action, max_action)
-                        new_param = _jax_bool_action_to_param(var, new_action, hyperparams)
+                        if noop[var]:
+                            new_actions[var] = jnp.minimum(action + amount, 1)
+                        else:
+                            new_actions[var] = jnp.maximum(action - amount, 0)
                     else:
-                        new_param = param
-                    new_params[var] = new_param
-                new_surplus = _jax_wrapped_sogbofa_surplus(new_params, hyperparams)
+                        new_actions[var] = action
+                new_surplus, new_k = _jax_wrapped_sogbofa_surplus(new_actions)
                 new_it = it + 1
-                return new_it, new_params, hyperparams, new_surplus
+                return new_it, new_actions, new_surplus, new_k
 
             # apply the surplus to the actions until it becomes zero
             def _jax_wrapped_sogbofa_project(params, hyperparams):
-                surplus = _jax_wrapped_sogbofa_surplus(params, hyperparams)
-                _, params, _, surplus = jax.lax.while_loop(
+
+                # convert parameters to actions
+                actions = {}
+                for (var, param) in params.items():
+                    if ranges[var] == 'bool':
+                        actions[var] = _jax_bool_param_to_action(var, param, hyperparams)
+                    else:
+                        actions[var] = param
+
+                # run SOGBOFA loop on the actions to get adjusted actions
+                surplus, k = _jax_wrapped_sogbofa_surplus(actions)
+                _, actions, surplus, k = jax.lax.while_loop(
                     cond_fun=_jax_wrapped_sogbofa_continue,
                     body_fun=_jax_wrapped_sogbofa_subtract_surplus,
-                    init_val=(0, params, hyperparams, surplus)
+                    init_val=(0, actions, surplus, k)
                 )
                 converged = jnp.logical_not(surplus > 0)
-                return params, converged
+
+                # convert the adjusted actions back to parameters
+                new_params = {}
+                for (var, action) in actions.items():
+                    if ranges[var] == 'bool':
+                        action = jnp.clip(action, min_action, max_action)
+                        new_params[var] = _jax_bool_action_to_param(var, action, hyperparams)
+                    else:
+                        new_params[var] = action
+                return new_params, converged
 
             # clip actions to valid bounds and satisfy constraint on max actions
             def _jax_wrapped_slp_project_to_max_constraint(params, hyperparams):
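The reworked projection above operates directly in action space: it measures the surplus of total action mass over the budget, spreads that surplus evenly across the k currently active actions, subtracts, and repeats inside a `jax.lax.while_loop` until the surplus vanishes or no active actions remain. A minimal self-contained sketch of this pattern, simplified to the noop = False direction only (the function and variable names here are illustrative, not the library's API):

```python
import jax
import jax.numpy as jnp

def _surplus(actions, allowed_actions):
    # total action mass above the budget, and the number of active entries
    total = sum(jnp.sum(a) for a in actions.values())
    k = sum(jnp.count_nonzero(a) for a in actions.values())
    return jnp.maximum(total - allowed_actions, 0.0), k

def project_max_actions(actions, allowed_actions, max_iter=100):
    # iterate while surplus remains, some actions are still active,
    # and the compute budget is not exhausted
    def _continue(val):
        it, _, surplus, k = val
        return jnp.logical_and(
            it < max_iter, jnp.logical_and(surplus > 0, k > 0))

    def _subtract_surplus(val):
        it, acts, surplus, k = val
        amount = surplus / k   # spread the surplus evenly over active actions
        acts = {var: jnp.maximum(a - amount, 0.0) for (var, a) in acts.items()}
        surplus, k = _surplus(acts, allowed_actions)
        return it + 1, acts, surplus, k

    surplus, k = _surplus(actions, allowed_actions)
    _, actions, surplus, _ = jax.lax.while_loop(
        _continue, _subtract_surplus, (0, actions, surplus, k))
    return actions, jnp.logical_not(surplus > 0)

actions = {'put': jnp.array([0.7, 0.6, 0.2]), 'take': jnp.array([0.9])}
projected, converged = project_max_actions(actions, allowed_actions=2)
# total mass 2.4 exceeds the budget by 0.4; one pass subtracts 0.4 / 4 = 0.1
# from each active entry, giving [0.6, 0.5, 0.1] and [0.8], which sum to 2.0
```

Compared with the 1.2 code, the loop now carries actions rather than sigmoid parameters, so the parameter-to-action conversion happens once per projection instead of once per loop iteration.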
@@ -1415,6 +1437,7 @@ r"""
 
         # optimization
         self.update = self._jax_update(train_loss)
+        self.check_zero_grad = self._jax_check_zero_gradients()
 
     def _jax_return(self, use_symlog):
         gamma = self.rddl.discount
@@ -1497,6 +1520,18 @@ r"""
 
         return jax.jit(_jax_wrapped_plan_update)
 
+    def _jax_check_zero_gradients(self):
+
+        def _jax_wrapped_zero_gradient(grad):
+            return jnp.allclose(grad, 0)
+
+        def _jax_wrapped_zero_gradients(grad):
+            leaves, _ = jax.tree_util.tree_flatten(
+                jax.tree_map(_jax_wrapped_zero_gradient, grad))
+            return jnp.all(jnp.asarray(leaves))
+
+        return jax.jit(_jax_wrapped_zero_gradients)
+
     def _batched_init_subs(self, subs):
         rddl = self.rddl
         n_train, n_test = self.batch_size_train, self.batch_size_test
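The new `check_zero_grad` replaces the per-leaf NumPy test that version 1.2 ran inline in the optimization loop (see the hunks below) with a single jitted reduction over the gradient pytree. A minimal sketch of the same pattern on a toy gradient dictionary, assuming nothing beyond standard JAX pytree utilities:

```python
import jax
import jax.numpy as jnp

@jax.jit
def all_gradients_zero(grads):
    # map every leaf to a scalar bool, then reduce across all leaves
    leaves = jax.tree_util.tree_leaves(
        jax.tree_util.tree_map(lambda g: jnp.allclose(g, 0), grads))
    return jnp.all(jnp.asarray(leaves))

grads = {'w': jnp.zeros((3, 2)), 'b': jnp.zeros(2)}
print(all_gradients_zero(grads))        # True: every leaf is zero
grads['b'] = grads['b'].at[0].set(1e-3)
print(all_gradients_zero(grads))        # False: one leaf is nonzero
```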
@@ -1795,7 +1830,6 @@ r"""
         rolling_test_loss = RollingMean(test_rolling_window)
         log = {}
         status = JaxPlannerStatus.NORMAL
-        is_all_zero_fn = lambda x: np.allclose(x, 0)
 
         # initialize stopping criterion
         if stopping_rule is not None:
@@ -1836,9 +1870,7 @@ r"""
            # ==================================================================
 
            # no progress
-            grad_norm_zero, _ = jax.tree_util.tree_flatten(
-                jax.tree_map(is_all_zero_fn, train_log['grad']))
-            if np.all(grad_norm_zero):
+            if self.check_zero_grad(train_log['grad']):
                status = JaxPlannerStatus.NO_PROGRESS
 
            # constraint satisfaction problem
@@ -2035,8 +2067,8 @@ r"""
                # must be numeric array
                # exception is for POMDPs at 1st epoch when observ-fluents are None
                dtype = np.atleast_1d(values).dtype
-                if not jnp.issubdtype(dtype, jnp.number) \
-                and not jnp.issubdtype(dtype, jnp.bool_):
+                if not np.issubdtype(dtype, np.number) \
+                and not np.issubdtype(dtype, np.bool_):
                    if step == 0 and var in self.rddl.observ_fluents:
                        subs[var] = self.test_compiled.init_values[var]
                    else:
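The dtype being tested comes from `np.atleast_1d`, so it is a NumPy dtype and NumPy's own `issubdtype` is the natural check; in particular, an object dtype (such as an array holding `None` observ-fluents) is neither numeric nor boolean. A small illustration of the check in isolation:

```python
import numpy as np

numeric = np.atleast_1d([1.0, 2.0]).dtype    # float64
print(np.issubdtype(numeric, np.number))     # True: keep the values as-is

unset = np.atleast_1d([None]).dtype          # object dtype
print(np.issubdtype(unset, np.number)
      or np.issubdtype(unset, np.bool_))     # False: fall back to defaults
```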
@@ -2077,10 +2109,11 @@ def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
 
 @jax.jit
 def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
-    alpha_mask = jax.lax.stop_gradient(
-        returns <= jnp.percentile(returns, q=100 * alpha))
-    return jnp.sum(returns * alpha_mask) / jnp.sum(alpha_mask)
-
+    var = jnp.percentile(returns, q=100 * alpha)
+    mask = returns <= var
+    weights = mask / jnp.maximum(1, jnp.sum(mask))
+    return jnp.sum(returns * weights)
+
 
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
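The rewritten `cvar_utility` averages the worst alpha-fraction of returns, i.e. the returns at or below the empirical value-at-risk, and guards the division with `jnp.maximum(1, ...)` so an empty mask cannot divide by zero. A worked example on made-up returns:

```python
import jax.numpy as jnp

returns = jnp.array([-10., -2., 0., 3., 5., 8., 9., 12., 15., 20.])
alpha = 0.2

var = jnp.percentile(returns, q=100 * alpha)    # about -0.4 for this sample
mask = returns <= var                           # selects -10.0 and -2.0
weights = mask / jnp.maximum(1, jnp.sum(mask))  # each gets weight 1/2
print(jnp.sum(returns * weights))               # -6.0, mean of the worst 20%
```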
pyRDDLGym_jax-1.2.dist-info/METADATA → pyRDDLGym_jax-1.3.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: pyRDDLGym-jax
-Version: 1.2
+Version: 1.3
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
pyRDDLGym_jax-1.2.dist-info/RECORD → pyRDDLGym_jax-1.3.dist-info/RECORD RENAMED
@@ -1,9 +1,9 @@
-pyRDDLGym_jax/__init__.py,sha256=LTT-ZpL6vrKdC5t0O71pJnk3zMhDf1eXkNmoLoIRupo,19
+pyRDDLGym_jax/__init__.py,sha256=p_veRZMP15-djJyMuDHT7Ul1RbCCHpYsZ9LO0GD1URo,19
 pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
 pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/core/compiler.py,sha256=qy1TSivdpuZxWecDl5HEM0PXX45JB7DHzV7uAB8kmbE,88696
 pyRDDLGym_jax/core/logic.py,sha256=iYvLgWyQd_mrkwwoeRWao9NzjmhsObQnPq4DphILw1Q,38425
-pyRDDLGym_jax/core/planner.py,sha256=oKs9js7xyIc9-bxQFZSQNBw9s1nWQlz4DjENwEgSojY,100672
+pyRDDLGym_jax/core/planner.py,sha256=TFFy91aCzRW600k_eP-7i2Gvp9wpNVjXlXtBnt9x03M,101744
 pyRDDLGym_jax/core/simulator.py,sha256=JpmwfPqYPBfEhmQ04ufBeclZOQ-U1ZiyAtLf1AIwO2M,8462
 pyRDDLGym_jax/core/tuning.py,sha256=LBhoVQZWWhYQj89gpM2B4xVHlYlKDt4psw4Be9cBbSY,23685
 pyRDDLGym_jax/core/visualization.py,sha256=uKhC8z0TeX9BklPNoxSVt0g5pkqhgxrQClQAih78ybY,68292
@@ -41,9 +41,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
 pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
 pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
 pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
-pyRDDLGym_jax-1.2.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
-pyRDDLGym_jax-1.2.dist-info/METADATA,sha256=oWVOtC5AvAm2Xvdd507gXr3b6_aZLaH7LnOj6hADdgQ,15090
-pyRDDLGym_jax-1.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-pyRDDLGym_jax-1.2.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
-pyRDDLGym_jax-1.2.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
-pyRDDLGym_jax-1.2.dist-info/RECORD,,
+pyRDDLGym_jax-1.3.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+pyRDDLGym_jax-1.3.dist-info/METADATA,sha256=Colu-byYJ4RF5sr1qOVKg9VhCbrLnv32OvHt_A9KtLE,15090
+pyRDDLGym_jax-1.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+pyRDDLGym_jax-1.3.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+pyRDDLGym_jax-1.3.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyRDDLGym_jax-1.3.dist-info/RECORD,,