pyRDDLGym-jax 1.2-py3-none-any.whl → 2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +16 -1
- pyRDDLGym_jax/core/logic.py +36 -9
- pyRDDLGym_jax/core/planner.py +517 -129
- pyRDDLGym_jax/core/simulator.py +20 -0
- pyRDDLGym_jax/core/tuning.py +15 -0
- pyRDDLGym_jax/core/visualization.py +48 -0
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +1 -0
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-2.0.dist-info}/METADATA +1 -1
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-2.0.dist-info}/RECORD +23 -23
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-2.0.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-2.0.dist-info}/WHEEL +0 -0
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-2.0.dist-info}/entry_points.txt +0 -0
- {pyRDDLGym_jax-1.2.dist-info → pyRDDLGym_jax-2.0.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/planner.py
CHANGED
@@ -1,12 +1,43 @@
+# ***********************************************************************
+# JAXPLAN
+#
+# Author: Michael Gimelfarb
+#
+# RELEVANT SOURCES:
+#
+# [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
+# Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
+# Probabilistic Domains." Proceedings of the International Conference on Automated
+# Planning and Scheduling. Vol. 34. 2024.
+#
+# [2] Patton, Noah, Jihwan Jeong, Mike Gimelfarb, and Scott Sanner. "A Distributional
+# Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs." In Proceedings of
+# the AAAI Conference on Artificial Intelligence, vol. 36, no. 9, pp. 9894-9901. 2022.
+#
+# [3] Bueno, Thiago P., Leliane N. de Barros, Denis D. Mauá, and Scott Sanner. "Deep
+# reactive policies for planning in stochastic nonlinear domains." In Proceedings of the
+# AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 7530-7537. 2019.
+#
+# [4] Wu, Ga, Buser Say, and Scott Sanner. "Scalable planning with tensorflow for hybrid
+# nonlinear domains." Advances in Neural Information Processing Systems 30 (2017).
+#
+# [5] Sehnke, Frank, and Tingting Zhao. "Baseline-free sampling in parameter exploring
+# policy gradients: Super symmetric pgpe." Artificial Neural Networks: Methods and
+# Applications in Bio-/Neuroinformatics. Springer International Publishing, 2015.
+#
+# ***********************************************************************
+
+
 from ast import literal_eval
 from collections import deque
 import configparser
 from enum import Enum
+from functools import partial
 import os
 import sys
 import time
 import traceback
-from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Type, Tuple, Union
 
 import haiku as hk
 import jax
@@ -163,7 +194,20 @@ def _load_config(config, args):
         del planner_args['optimizer']
     else:
         planner_args['optimizer'] = optimizer
-
+
+    # pgpe optimizer
+    pgpe_method = planner_args.get('pgpe', 'GaussianPGPE')
+    pgpe_kwargs = planner_args.pop('pgpe_kwargs', {})
+    if pgpe_method is not None:
+        if 'optimizer' in pgpe_kwargs:
+            pgpe_optimizer = _getattr_any(packages=[optax], item=pgpe_kwargs['optimizer'])
+            if pgpe_optimizer is None:
+                raise_warning(f'Ignoring invalid optimizer <{pgpe_optimizer}>.', 'red')
+                del pgpe_kwargs['optimizer']
+            else:
+                pgpe_kwargs['optimizer'] = pgpe_optimizer
+        planner_args['pgpe'] = getattr(sys.modules[__name__], pgpe_method)(**pgpe_kwargs)
+
     # optimize call RNG key
     planner_key = train_args.get('key', None)
    if planner_key is not None:
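The hunk above lets a planner configuration pick a PGPE strategy by class name and pass keyword arguments to it. A minimal sketch of the two keys the loader consumes is shown below; the [Optimizer] section name and value style are assumptions based on the package's example .cfg files, and only pgpe / pgpe_kwargs are actually read by the code above (an 'optimizer' entry inside pgpe_kwargs is resolved against optax, and pgpe=None disables PGPE entirely):

    [Optimizer]
    pgpe='GaussianPGPE'
    pgpe_kwargs={'batch_size': 1, 'init_sigma': 1.0, 'optimizer': 'rmsprop'}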
@@ -469,16 +513,16 @@ class JaxStraightLinePlan(JaxPlan):
         bounds = '\n '.join(
             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         return (f'policy hyper-parameters:\n'
-                f' initializer
-                f'constraint-sat strategy (simple):\n'
-                f'
-                f'
-                f'
-                f'
-                f'constraint-sat strategy (complex):\n'
-                f'
-                f'
-                f'
+                f' initializer={self._initializer_base}\n'
+                f' constraint-sat strategy (simple):\n'
+                f' parsed_action_bounds =\n {bounds}\n'
+                f' wrap_sigmoid ={self._wrap_sigmoid}\n'
+                f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
+                f' wrap_non_bool ={self._wrap_non_bool}\n'
+                f' constraint-sat strategy (complex):\n'
+                f' wrap_softmax ={self._wrap_softmax}\n'
+                f' use_new_projection ={self._use_new_projection}\n'
+                f' max_projection_iters={self._max_constraint_iter}\n')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -655,7 +699,10 @@ class JaxStraightLinePlan(JaxPlan):
             if ranges[var] == 'bool':
                 param_flat = jnp.ravel(param)
                 if noop[var]:
-
+                    if wrap_sigmoid:
+                        param_flat = -param_flat
+                    else:
+                        param_flat = 1.0 - param_flat
                 scores.append(param_flat)
         scores = jnp.concatenate(scores)
         descending = jnp.sort(scores)[::-1]
@@ -666,7 +713,10 @@ class JaxStraightLinePlan(JaxPlan):
             new_params = {}
             for (var, param) in params.items():
                 if ranges[var] == 'bool':
-
+                    if noop[var]:
+                        new_param = param + surplus
+                    else:
+                        new_param = param - surplus
                     new_param = _jax_project_bool_to_box(var, new_param, hyperparams)
                 else:
                     new_param = param
@@ -687,57 +737,73 @@ class JaxStraightLinePlan(JaxPlan):
         elif use_constraint_satisfaction and not self._use_new_projection:
 
             # calculate the surplus of actions above max-nondef-actions
-            def _jax_wrapped_sogbofa_surplus(
-                sum_action,
-                for (var,
+            def _jax_wrapped_sogbofa_surplus(actions):
+                sum_action, k = 0.0, 0
+                for (var, action) in actions.items():
                     if ranges[var] == 'bool':
-                        action = _jax_bool_param_to_action(var, param, hyperparams)
                         if noop[var]:
-
-
-
-                        sum_action += jnp.sum(action)
-                        count += jnp.sum(action > 0)
+                            action = 1 - action
+                        sum_action += jnp.sum(action)
+                        k += jnp.count_nonzero(action)
                 surplus = jnp.maximum(sum_action - allowed_actions, 0.0)
-
-                return surplus / count
+                return surplus, k
 
             # return whether the surplus is positive or reached compute limit
             max_constraint_iter = self._max_constraint_iter
 
             def _jax_wrapped_sogbofa_continue(values):
-                it, _,
-                return jnp.logical_and(
+                it, _, surplus, k = values
+                return jnp.logical_and(
+                    it < max_constraint_iter, jnp.logical_and(surplus > 0, k > 0))
 
             # reduce all bool action values by the surplus clipping at minimum
             # for no-op = True, do the opposite, i.e. increase all
             # bool action values by surplus clipping at maximum
             def _jax_wrapped_sogbofa_subtract_surplus(values):
-                it,
-
-
+                it, actions, surplus, k = values
+                amount = surplus / k
+                new_actions = {}
+                for (var, action) in actions.items():
                     if ranges[var] == 'bool':
-
-
-
-
+                        if noop[var]:
+                            new_actions[var] = jnp.minimum(action + amount, 1)
+                        else:
+                            new_actions[var] = jnp.maximum(action - amount, 0)
                     else:
-
-
-                new_surplus = _jax_wrapped_sogbofa_surplus(new_params, hyperparams)
+                        new_actions[var] = action
+                new_surplus, new_k = _jax_wrapped_sogbofa_surplus(new_actions)
                 new_it = it + 1
-                return new_it,
+                return new_it, new_actions, new_surplus, new_k
 
             # apply the surplus to the actions until it becomes zero
             def _jax_wrapped_sogbofa_project(params, hyperparams):
-
-
+
+                # convert parameters to actions
+                actions = {}
+                for (var, param) in params.items():
+                    if ranges[var] == 'bool':
+                        actions[var] = _jax_bool_param_to_action(var, param, hyperparams)
+                    else:
+                        actions[var] = param
+
+                # run SOGBOFA loop on the actions to get adjusted actions
+                surplus, k = _jax_wrapped_sogbofa_surplus(actions)
+                _, actions, surplus, k = jax.lax.while_loop(
                     cond_fun=_jax_wrapped_sogbofa_continue,
                     body_fun=_jax_wrapped_sogbofa_subtract_surplus,
-                    init_val=(0,
+                    init_val=(0, actions, surplus, k)
                 )
                 converged = jnp.logical_not(surplus > 0)
-
+
+                # convert the adjusted actions back to parameters
+                new_params = {}
+                for (var, action) in actions.items():
+                    if ranges[var] == 'bool':
+                        action = jnp.clip(action, min_action, max_action)
+                        new_params[var] = _jax_bool_action_to_param(var, action, hyperparams)
+                    else:
+                        new_params[var] = action
+                return new_params, converged
 
             # clip actions to valid bounds and satisfy constraint on max actions
             def _jax_wrapped_slp_project_to_max_constraint(params, hyperparams):
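To make the rewritten projection concrete, here is a worked example with illustrative numbers (not taken from the package). Suppose three boolean action fluents have relaxed values 0.9, 0.8 and 0.7, their no-op value is False, and max-nondef-actions is 1. Then sum_action = 2.4, k = 3 and the surplus is max(2.4 - 1, 0) = 1.4, so the loop subtracts 1.4 / 3 ≈ 0.467 from every action (clipping at 0) and recomputes the surplus until it is zero or the iteration cap is hit. A minimal NumPy sketch of that loop:

    import numpy as np

    actions = np.array([0.9, 0.8, 0.7])   # relaxed boolean actions, noop = False
    allowed = 1.0                          # max-nondef-actions

    surplus = max(actions.sum() - allowed, 0.0)
    k = np.count_nonzero(actions)
    while surplus > 0 and k > 0:           # the real code also caps the iteration count
        actions = np.maximum(actions - surplus / k, 0.0)
        surplus = max(actions.sum() - allowed, 0.0)
        k = np.count_nonzero(actions)
    # one pass suffices here: actions ≈ [0.433, 0.333, 0.233], summing to 1.0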
@@ -834,15 +900,16 @@ class JaxDeepReactivePolicy(JaxPlan):
         bounds = '\n '.join(
             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
         return (f'policy hyper-parameters:\n'
-                f' topology
-                f' activation_fn
-                f' initializer
-                f'
-                f'
-                f'
-                f'
-                f'
-                f'
+                f' topology ={self._topology}\n'
+                f' activation_fn={self._activations[0].__name__}\n'
+                f' initializer ={type(self._initializer_base).__name__}\n'
+                f' input norm:\n'
+                f' apply_input_norm ={self._normalize}\n'
+                f' input_norm_layerwise={self._normalize_per_layer}\n'
+                f' input_norm_args ={self._normalizer_kwargs}\n'
+                f' constraint-sat strategy:\n'
+                f' parsed_action_bounds=\n {bounds}\n'
+                f' wrap_non_bool ={self._wrap_non_bool}\n')
 
     def compile(self, compiled: JaxRDDLCompilerWithGrad,
                 _bounds: Bounds,
@@ -1068,10 +1135,11 @@ class JaxDeepReactivePolicy(JaxPlan):
 
 
 # ***********************************************************************
-#
+# SUPPORTING FUNCTIONS
 #
-# -
-# -
+# - smoothed mean calculation
+# - planner status
+# - stopping criteria
 #
 # ***********************************************************************
 
@@ -1145,6 +1213,264 @@ class NoImprovementStoppingRule(JaxPlannerStoppingRule):
         return f'No improvement for {self.patience} iterations'
 
 
+# ***********************************************************************
+# PARAMETER EXPLORING POLICY GRADIENTS (PGPE)
+#
+# - simple Gaussian PGPE
+#
+# ***********************************************************************
+
+
+class PGPE:
+    """Base class for all PGPE strategies."""
+
+    def __init__(self) -> None:
+        self._initializer = None
+        self._update = None
+
+    @property
+    def initialize(self):
+        return self._initializer
+
+    @property
+    def update(self):
+        return self._update
+
+    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
+        raise NotImplementedError
+
+
+class GaussianPGPE(PGPE):
+    '''PGPE with a Gaussian parameter distribution.'''
+
+    def __init__(self, batch_size: int=1,
+                 init_sigma: float=1.0,
+                 sigma_range: Tuple[float, float]=(1e-5, 1e5),
+                 scale_reward: bool=True,
+                 super_symmetric: bool=True,
+                 super_symmetric_accurate: bool=True,
+                 optimizer: Callable[..., optax.GradientTransformation]=optax.adam,
+                 optimizer_kwargs_mu: Optional[Kwargs]=None,
+                 optimizer_kwargs_sigma: Optional[Kwargs]=None) -> None:
+        '''Creates a new Gaussian PGPE planner.
+
+        :param batch_size: how many policy parameters to sample per optimization step
+        :param init_sigma: initial standard deviation of Gaussian
+        :param sigma_range: bounds to constrain standard deviation
+        :param scale_reward: whether to apply reward scaling as in the paper
+        :param super_symmetric: whether to use super-symmetric sampling as in the paper
+        :param super_symmetric_accurate: whether to use the accurate formula for super-
+        symmetric sampling or the simplified but biased formula
+        :param optimizer: a factory for an optax SGD algorithm
+        :param optimizer_kwargs_mu: a dictionary of parameters to pass to the SGD
+        factory for the mean optimizer
+        :param optimizer_kwargs_sigma: a dictionary of parameters to pass to the SGD
+        factory for the standard deviation optimizer
+        '''
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.init_sigma = init_sigma
+        self.sigma_range = sigma_range
+        self.scale_reward = scale_reward
+        self.super_symmetric = super_symmetric
+        self.super_symmetric_accurate = super_symmetric_accurate
+
+        # set optimizers
+        if optimizer_kwargs_mu is None:
+            optimizer_kwargs_mu = {'learning_rate': 0.1}
+        self.optimizer_kwargs_mu = optimizer_kwargs_mu
+        if optimizer_kwargs_sigma is None:
+            optimizer_kwargs_sigma = {'learning_rate': 0.1}
+        self.optimizer_kwargs_sigma = optimizer_kwargs_sigma
+        self.optimizer_name = optimizer
+        mu_optimizer = optimizer(**optimizer_kwargs_mu)
+        sigma_optimizer = optimizer(**optimizer_kwargs_sigma)
+        self.optimizers = (mu_optimizer, sigma_optimizer)
+
+    def __str__(self) -> str:
+        return (f'PGPE hyper-parameters:\n'
+                f' method ={self.__class__.__name__}\n'
+                f' batch_size ={self.batch_size}\n'
+                f' init_sigma ={self.init_sigma}\n'
+                f' sigma_range ={self.sigma_range}\n'
+                f' scale_reward ={self.scale_reward}\n'
+                f' super_symmetric={self.super_symmetric}\n'
+                f' accurate ={self.super_symmetric_accurate}\n'
+                f' optimizer ={self.optimizer_name}\n'
+                f' optimizer_kwargs:\n'
+                f' mu ={self.optimizer_kwargs_mu}\n'
+                f' sigma={self.optimizer_kwargs_sigma}\n'
+                )
+
+    def compile(self, loss_fn: Callable, projection: Callable, real_dtype: Type) -> None:
+        MIN_NORM = 1e-5
+        sigma0 = self.init_sigma
+        sigma_range = self.sigma_range
+        scale_reward = self.scale_reward
+        super_symmetric = self.super_symmetric
+        super_symmetric_accurate = self.super_symmetric_accurate
+        batch_size = self.batch_size
+        optimizers = (mu_optimizer, sigma_optimizer) = self.optimizers
+
+        # initializer
+        def _jax_wrapped_pgpe_init(key, policy_params):
+            mu = policy_params
+            sigma = jax.tree_map(lambda x: sigma0 * jnp.ones_like(x), mu)
+            pgpe_params = (mu, sigma)
+            pgpe_opt_state = tuple(opt.init(param)
+                                   for (opt, param) in zip(optimizers, pgpe_params))
+            return pgpe_params, pgpe_opt_state
+
+        self._initializer = jax.jit(_jax_wrapped_pgpe_init)
+
+        # parameter sampling functions
+        def _jax_wrapped_mu_noise(key, sigma):
+            return sigma * random.normal(key, shape=jnp.shape(sigma), dtype=real_dtype)
+
+        def _jax_wrapped_epsilon_star(sigma, epsilon):
+            c1, c2, c3 = -0.06655, -0.9706, 0.124
+            phi = 0.67449 * sigma
+            a = (sigma - jnp.abs(epsilon)) / sigma
+            if super_symmetric_accurate:
+                aa = jnp.abs(a)
+                epsilon_star = jnp.sign(epsilon) * phi * jnp.where(
+                    a <= 0,
+                    jnp.exp(c1 * aa * (aa * aa - 1) / jnp.log(aa + 1e-10) + c2 * aa),
+                    jnp.exp(aa - c3 * aa * jnp.log(1.0 - jnp.power(aa, 3) + 1e-10))
+                )
+            else:
+                epsilon_star = jnp.sign(epsilon) * phi * jnp.exp(a)
+            return epsilon_star
+
+        def _jax_wrapped_sample_params(key, mu, sigma):
+            keys = random.split(key, num=len(jax.tree_util.tree_leaves(mu)))
+            keys_pytree = jax.tree_util.tree_unflatten(
+                treedef=jax.tree_util.tree_structure(mu), leaves=keys)
+            epsilon = jax.tree_map(_jax_wrapped_mu_noise, keys_pytree, sigma)
+            p1 = jax.tree_map(jnp.add, mu, epsilon)
+            p2 = jax.tree_map(jnp.subtract, mu, epsilon)
+            if super_symmetric:
+                epsilon_star = jax.tree_map(_jax_wrapped_epsilon_star, sigma, epsilon)
+                p3 = jax.tree_map(jnp.add, mu, epsilon_star)
+                p4 = jax.tree_map(jnp.subtract, mu, epsilon_star)
+            else:
+                epsilon_star, p3, p4 = epsilon, p1, p2
+            return (p1, p2, p3, p4), (epsilon, epsilon_star)
+
+        # policy gradient update functions
+        def _jax_wrapped_mu_grad(epsilon, epsilon_star, r1, r2, r3, r4, m):
+            if super_symmetric:
+                if scale_reward:
+                    scale1 = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
+                    scale2 = jnp.maximum(MIN_NORM, m - (r3 + r4) / 2)
+                else:
+                    scale1 = scale2 = 1.0
+                r_mu1 = (r1 - r2) / (2 * scale1)
+                r_mu2 = (r3 - r4) / (2 * scale2)
+                grad = -(r_mu1 * epsilon + r_mu2 * epsilon_star)
+            else:
+                if scale_reward:
+                    scale = jnp.maximum(MIN_NORM, m - (r1 + r2) / 2)
+                else:
+                    scale = 1.0
+                r_mu = (r1 - r2) / (2 * scale)
+                grad = -r_mu * epsilon
+            return grad
+
+        def _jax_wrapped_sigma_grad(epsilon, epsilon_star, sigma, r1, r2, r3, r4, m):
+            if super_symmetric:
+                mask = r1 + r2 >= r3 + r4
+                epsilon_tau = mask * epsilon + (1 - mask) * epsilon_star
+                s = epsilon_tau * epsilon_tau / sigma - sigma
+                if scale_reward:
+                    scale = jnp.maximum(MIN_NORM, m - (r1 + r2 + r3 + r4) / 4)
+                else:
+                    scale = 1.0
+                r_sigma = ((r1 + r2) - (r3 + r4)) / (4 * scale)
+            else:
+                s = epsilon * epsilon / sigma - sigma
+                if scale_reward:
+                    scale = jnp.maximum(MIN_NORM, jnp.abs(m))
+                else:
+                    scale = 1.0
+                r_sigma = (r1 + r2) / (2 * scale)
+            grad = -r_sigma * s
+            return grad
+
+        def _jax_wrapped_pgpe_grad(key, mu, sigma, r_max,
+                                   policy_hyperparams, subs, model_params):
+            key, subkey = random.split(key)
+            (p1, p2, p3, p4), (epsilon, epsilon_star) = _jax_wrapped_sample_params(
+                key, mu, sigma)
+            r1 = -loss_fn(subkey, p1, policy_hyperparams, subs, model_params)[0]
+            r2 = -loss_fn(subkey, p2, policy_hyperparams, subs, model_params)[0]
+            r_max = jnp.maximum(r_max, r1)
+            r_max = jnp.maximum(r_max, r2)
+            if super_symmetric:
+                r3 = -loss_fn(subkey, p3, policy_hyperparams, subs, model_params)[0]
+                r4 = -loss_fn(subkey, p4, policy_hyperparams, subs, model_params)[0]
+                r_max = jnp.maximum(r_max, r3)
+                r_max = jnp.maximum(r_max, r4)
+            else:
+                r3, r4 = r1, r2
+            grad_mu = jax.tree_map(
+                partial(_jax_wrapped_mu_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
+                epsilon, epsilon_star
+            )
+            grad_sigma = jax.tree_map(
+                partial(_jax_wrapped_sigma_grad, r1=r1, r2=r2, r3=r3, r4=r4, m=r_max),
+                epsilon, epsilon_star, sigma
+            )
+            return grad_mu, grad_sigma, r_max
+
+        def _jax_wrapped_pgpe_grad_batched(key, pgpe_params, r_max,
+                                           policy_hyperparams, subs, model_params):
+            mu, sigma = pgpe_params
+            if batch_size == 1:
+                mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad(
+                    key, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+            else:
+                keys = random.split(key, num=batch_size)
+                mu_grads, sigma_grads, r_maxs = jax.vmap(
+                    _jax_wrapped_pgpe_grad,
+                    in_axes=(0, None, None, None, None, None, None)
+                )(keys, mu, sigma, r_max, policy_hyperparams, subs, model_params)
+                mu_grad = jax.tree_map(partial(jnp.mean, axis=0), mu_grads)
+                sigma_grad = jax.tree_map(partial(jnp.mean, axis=0), sigma_grads)
+                new_r_max = jnp.max(r_maxs)
+            return mu_grad, sigma_grad, new_r_max
+
+        def _jax_wrapped_pgpe_update(key, pgpe_params, r_max,
+                                     policy_hyperparams, subs, model_params,
+                                     pgpe_opt_state):
+            mu, sigma = pgpe_params
+            mu_state, sigma_state = pgpe_opt_state
+            mu_grad, sigma_grad, new_r_max = _jax_wrapped_pgpe_grad_batched(
+                key, pgpe_params, r_max, policy_hyperparams, subs, model_params)
+            mu_updates, new_mu_state = mu_optimizer.update(mu_grad, mu_state, params=mu)
+            sigma_updates, new_sigma_state = sigma_optimizer.update(
+                sigma_grad, sigma_state, params=sigma)
+            new_mu = optax.apply_updates(mu, mu_updates)
+            new_mu, converged = projection(new_mu, policy_hyperparams)
+            new_sigma = optax.apply_updates(sigma, sigma_updates)
+            new_sigma = jax.tree_map(lambda x: jnp.clip(x, *sigma_range), new_sigma)
+            new_pgpe_params = (new_mu, new_sigma)
+            new_pgpe_opt_state = (new_mu_state, new_sigma_state)
+            policy_params = new_mu
+            return new_pgpe_params, new_r_max, new_pgpe_opt_state, policy_params, converged
+
+        self._update = jax.jit(_jax_wrapped_pgpe_update)
+
+
+# ***********************************************************************
+# ALL VERSIONS OF JAX PLANNER
+#
+# - simple gradient descent based planner
+#
+# ***********************************************************************
+
+
 class JaxBackpropPlanner:
     '''A class for optimizing an action sequence in the given RDDL MDP using
     gradient descent.'''
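The GaussianPGPE class added above implements parameter-exploring policy gradients in the sense of reference [5]: instead of backpropagating through the model, it perturbs the policy parameters with Gaussian noise and estimates gradients of the mean and standard deviation from the returns of symmetric parameter pairs. The toy sketch below shows only the mean update from symmetric sampling on a single array; the sigma update, super-symmetric pairs, reward scaling by r_max, optax optimizers and pytree handling used in the class above are omitted, and every name in it is illustrative:

    import jax
    import jax.numpy as jnp

    def reward(theta):                                       # toy objective to maximize
        return -jnp.sum((theta - 3.0) ** 2)

    def pgpe_mu_step(key, mu, sigma, lr=0.1):
        epsilon = sigma * jax.random.normal(key, mu.shape)   # perturbation ~ N(0, sigma^2)
        r_plus = reward(mu + epsilon)                        # score the symmetric pair
        r_minus = reward(mu - epsilon)
        r_mu = (r_plus - r_minus) / 2.0                      # baseline-free signal from the pair
        return mu + lr * r_mu * epsilon                      # ascend the estimated return gradient

    key = jax.random.PRNGKey(0)
    mu, sigma = jnp.zeros(3), jnp.ones(3)
    for _ in range(300):
        key, subkey = jax.random.split(key)
        mu = pgpe_mu_step(subkey, mu, sigma)
    print(mu)                                                # approaches [3.0, 3.0, 3.0]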
@@ -1161,6 +1487,7 @@ class JaxBackpropPlanner:
                  clip_grad: Optional[float]=None,
                  line_search_kwargs: Optional[Kwargs]=None,
                  noise_kwargs: Optional[Kwargs]=None,
+                 pgpe: Optional[PGPE]=GaussianPGPE(),
                  logic: Logic=FuzzyLogic(),
                  use_symlog_reward: bool=False,
                  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1191,6 +1518,7 @@ class JaxBackpropPlanner:
         :param line_search_kwargs: parameters to pass to optional line search
         method to scale learning rate
         :param noise_kwargs: parameters of optional gradient noise
+        :param pgpe: optional policy gradient to run alongside the planner
         :param logic: a subclass of Logic for mapping exact mathematical
         operations to their differentiable counterparts
         :param use_symlog_reward: whether to use the symlog transform on the
@@ -1229,6 +1557,8 @@ class JaxBackpropPlanner:
         self.clip_grad = clip_grad
         self.line_search_kwargs = line_search_kwargs
         self.noise_kwargs = noise_kwargs
+        self.pgpe = pgpe
+        self.use_pgpe = pgpe is not None
 
         # set optimizer
         try:
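Because the constructor now accepts pgpe with a GaussianPGPE() default, callers can configure or disable it directly in code. A hedged usage sketch follows; apart from the pgpe argument itself, the other argument names and the availability of a compiled RDDL model are assumptions, not something this diff shows:

    import optax
    from pyRDDLGym_jax.core.planner import (
        JaxBackpropPlanner, JaxStraightLinePlan, GaussianPGPE)

    # model: a lifted RDDL model, assumed to be built elsewhere (e.g. from pyRDDLGym)
    planner = JaxBackpropPlanner(
        rddl=model,
        plan=JaxStraightLinePlan(),
        optimizer=optax.rmsprop,
        pgpe=GaussianPGPE(batch_size=1, init_sigma=1.0, optimizer=optax.adam))

    # pgpe=None switches the feature off (use_pgpe becomes False above).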
@@ -1333,24 +1663,25 @@ r"""
                   f' line_search_kwargs={self.line_search_kwargs}\n'
                   f' noise_kwargs ={self.noise_kwargs}\n'
                   f' batch_size_train ={self.batch_size_train}\n'
-                  f' batch_size_test ={self.batch_size_test}')
-        result +=
-
+                  f' batch_size_test ={self.batch_size_test}\n')
+        result += str(self.plan)
+        if self.use_pgpe:
+            result += str(self.pgpe)
+        result += str(self.logic)
 
         # print model relaxation information
-        if
-
-
-
-
-
-
-
-
-
-
-
-            f' init_values={values_by_rddl_op[rddl_op]}\n')
+        if self.compiled.model_params:
+            result += ('Some RDDL operations are non-differentiable '
+                       'and will be approximated as follows:' + '\n')
+            exprs_by_rddl_op, values_by_rddl_op = {}, {}
+            for info in self.compiled.model_parameter_info().values():
+                rddl_op = info['rddl_op']
+                exprs_by_rddl_op.setdefault(rddl_op, []).append(info['id'])
+                values_by_rddl_op.setdefault(rddl_op, []).append(info['init_value'])
+            for rddl_op in sorted(exprs_by_rddl_op.keys()):
+                result += (f' {rddl_op}:\n'
+                           f' addresses ={exprs_by_rddl_op[rddl_op]}\n'
+                           f' init_values={values_by_rddl_op[rddl_op]}\n')
         return result
 
     def summarize_hyperparameters(self) -> None:
@@ -1415,6 +1746,16 @@ r"""
 
         # optimization
         self.update = self._jax_update(train_loss)
+        self.check_zero_grad = self._jax_check_zero_gradients()
+
+        # pgpe option
+        if self.use_pgpe:
+            loss_fn = self._jax_loss(rollouts=test_rollouts)
+            self.pgpe.compile(
+                loss_fn=loss_fn,
+                projection=self.plan.projection,
+                real_dtype=self.test_compiled.REAL
+            )
 
     def _jax_return(self, use_symlog):
         gamma = self.rddl.discount
@@ -1497,6 +1838,18 @@ r"""
 
         return jax.jit(_jax_wrapped_plan_update)
 
+    def _jax_check_zero_gradients(self):
+
+        def _jax_wrapped_zero_gradient(grad):
+            return jnp.allclose(grad, 0)
+
+        def _jax_wrapped_zero_gradients(grad):
+            leaves, _ = jax.tree_util.tree_flatten(
+                jax.tree_map(_jax_wrapped_zero_gradient, grad))
+            return jnp.all(jnp.asarray(leaves))
+
+        return jax.jit(_jax_wrapped_zero_gradients)
+
     def _batched_init_subs(self, subs):
         rddl = self.rddl
         n_train, n_test = self.batch_size_train, self.batch_size_test
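The helper added above replaces the per-leaf NumPy check in the old training loop with a single jitted test over the whole gradient pytree. A small standalone illustration of the same tree-flatten-and-allclose pattern (names here are illustrative, not the package API):

    import jax
    import jax.numpy as jnp

    def all_gradients_zero(grad):
        # True only if every leaf of the pytree is numerically all zeros
        leaves = jax.tree_util.tree_leaves(
            jax.tree_util.tree_map(lambda g: jnp.allclose(g, 0), grad))
        return jnp.all(jnp.asarray(leaves))

    grad = {'w': jnp.zeros((2, 3)), 'b': jnp.array([0.0, 1e-9])}
    print(all_gradients_zero(grad))    # True: 1e-9 is within allclose tolerance
    grad['b'] = jnp.array([0.0, 0.5])
    print(all_gradients_zero(grad))    # False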
@@ -1611,7 +1964,7 @@ r"""
             return grad
 
         return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
-
+
     # ===========================================================================
     # OPTIMIZE API
     # ===========================================================================
@@ -1784,7 +2137,17 @@ r"""
         policy_params = guess
         opt_state = self.optimizer.init(policy_params)
         opt_aux = {}
-
+
+        # initialize pgpe parameters
+        if self.use_pgpe:
+            pgpe_params, pgpe_opt_state = self.pgpe.initialize(key, policy_params)
+            rolling_pgpe_loss = RollingMean(test_rolling_window)
+        else:
+            pgpe_params, pgpe_opt_state = None, None
+            rolling_pgpe_loss = None
+        total_pgpe_it = 0
+        r_max = -jnp.inf
+
         # ======================================================================
         # INITIALIZATION OF RUNNING STATISTICS
         # ======================================================================
@@ -1795,7 +2158,6 @@ r"""
         rolling_test_loss = RollingMean(test_rolling_window)
         log = {}
         status = JaxPlannerStatus.NORMAL
-        is_all_zero_fn = lambda x: np.allclose(x, 0)
 
         # initialize stopping criterion
         if stopping_rule is not None:
@@ -1826,19 +2188,47 @@ r"""
 
             # update the parameters of the plan
             key, subkey = random.split(key)
-            (policy_params, converged, opt_state, opt_aux,
-
-
-
-
+            (policy_params, converged, opt_state, opt_aux, train_loss, train_log,
+             model_params) = self.update(subkey, policy_params, policy_hyperparams,
+                                         train_subs, model_params, opt_state, opt_aux)
+            test_loss, (test_log, model_params_test) = self.test_loss(
+                subkey, policy_params, policy_hyperparams, test_subs, model_params_test)
+            test_loss_smooth = rolling_test_loss.update(test_loss)
+
+            # pgpe update of the plan
+            pgpe_improve = False
+            if self.use_pgpe:
+                key, subkey = random.split(key)
+                pgpe_params, r_max, pgpe_opt_state, pgpe_param, pgpe_converged = \
+                    self.pgpe.update(subkey, pgpe_params, r_max, policy_hyperparams,
+                                     test_subs, model_params, pgpe_opt_state)
+                pgpe_loss, _ = self.test_loss(
+                    subkey, pgpe_param, policy_hyperparams, test_subs, model_params_test)
+                pgpe_loss_smooth = rolling_pgpe_loss.update(pgpe_loss)
+                pgpe_return = -pgpe_loss_smooth
+
+                # replace with PGPE if it reaches a new minimum or train loss invalid
+                if pgpe_loss_smooth < best_loss or not np.isfinite(train_loss):
+                    policy_params = pgpe_param
+                    test_loss, test_loss_smooth = pgpe_loss, pgpe_loss_smooth
+                    converged = pgpe_converged
+                    pgpe_improve = True
+                    total_pgpe_it += 1
+            else:
+                pgpe_loss, pgpe_loss_smooth, pgpe_return = None, None, None
+
+            # evaluate test losses and record best plan so far
+            if test_loss_smooth < best_loss:
+                best_params, best_loss, best_grad = \
+                    policy_params, test_loss_smooth, train_log['grad']
+                last_iter_improve = it
+
             # ==================================================================
             # STATUS CHECKS AND LOGGING
             # ==================================================================
 
             # no progress
-
-            jax.tree_map(is_all_zero_fn, train_log['grad']))
-            if np.all(grad_norm_zero):
+            if (not pgpe_improve) and self.check_zero_grad(train_log['grad']):
                 status = JaxPlannerStatus.NO_PROGRESS
 
             # constraint satisfaction problem
@@ -1850,21 +2240,14 @@ r"""
                 status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
 
             # numerical error
-            if
-
-
+            if self.use_pgpe:
+                invalid_loss = not (np.isfinite(train_loss) or np.isfinite(pgpe_loss))
+            else:
+                invalid_loss = not np.isfinite(train_loss)
+            if invalid_loss:
+                raise_warning(f'Planner aborted due to invalid loss {train_loss}.', 'red')
                 status = JaxPlannerStatus.INVALID_GRADIENT
 
-            # evaluate test losses and record best plan so far
-            test_loss, (log, model_params_test) = self.test_loss(
-                subkey, policy_params, policy_hyperparams,
-                test_subs, model_params_test)
-            test_loss = rolling_test_loss.update(test_loss)
-            if test_loss < best_loss:
-                best_params, best_loss, best_grad = \
-                    policy_params, test_loss, train_log['grad']
-                last_iter_improve = it
-
             # reached computation budget
             elapsed = time.time() - start_time - elapsed_outside_loop
             if elapsed >= train_seconds:
@@ -1878,11 +2261,14 @@ r"""
                 'status': status,
                 'iteration': it,
                 'train_return':-train_loss,
-                'test_return':-
+                'test_return':-test_loss_smooth,
                 'best_return':-best_loss,
+                'pgpe_return': pgpe_return,
                 'params': policy_params,
                 'best_params': best_params,
+                'pgpe_params': pgpe_params,
                 'last_iteration_improved': last_iter_improve,
+                'pgpe_improved': pgpe_improve,
                 'grad': train_log['grad'],
                 'best_grad': best_grad,
                 'updates': train_log['updates'],
@@ -1891,7 +2277,7 @@ r"""
                 'model_params': model_params,
                 'progress': progress_percent,
                 'train_log': train_log,
-                **
+                **test_log
             }
 
             # stopping condition reached
@@ -1902,9 +2288,9 @@ r"""
             if print_progress:
                 iters.n = progress_percent
                 iters.set_description(
-                    f'{position_str} {it:6} it / {-train_loss:14.
-                    f'{-
-                    f'{status.value} status'
+                    f'{position_str} {it:6} it / {-train_loss:14.5f} train / '
+                    f'{-test_loss_smooth:14.5f} test / {-best_loss:14.5f} best / '
+                    f'{status.value} status / {total_pgpe_it:6} pgpe'
                 )
 
             # dash-board
@@ -1923,7 +2309,7 @@ r"""
         # ======================================================================
         # POST-PROCESSING AND CLEANUP
         # ======================================================================
-
+
         # release resources
         if print_progress:
             iters.close()
@@ -1935,7 +2321,7 @@ r"""
             messages.update(JaxRDDLCompiler.get_error_messages(error_code))
         if messages:
             messages = '\n'.join(messages)
-            raise_warning('
+            raise_warning('JAX compiler encountered the following '
                           'error(s) in the original RDDL formulation '
                           f'during test evaluation:\n{messages}', 'red')
 
@@ -1943,14 +2329,14 @@ r"""
         if print_summary:
             grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
             diagnosis = self._perform_diagnosis(
-                last_iter_improve, -train_loss, -
+                last_iter_improve, -train_loss, -test_loss_smooth, -best_loss, grad_norm)
             print(f'summary of optimization:\n'
-                  f'
-                  f'
+                  f' status ={status}\n'
+                  f' time ={elapsed:.6f} sec.\n'
                   f' iterations ={it}\n'
-                  f'
-                  f'
-                  f'
+                  f' best objective={-best_loss:.6f}\n'
+                  f' best grad norm={grad_norm}\n'
+                  f'diagnosis: {diagnosis}\n')
 
     def _perform_diagnosis(self, last_iter_improve,
                            train_return, test_return, best_return, grad_norm):
@@ -1970,23 +2356,24 @@ r"""
         if last_iter_improve <= 1:
             if grad_is_zero:
                 return termcolor.colored(
-                    '[FAILURE] no progress was made
-                    f'and max grad norm {max_grad_norm:.6f}
-                    'solver likely stuck in a plateau.', 'red')
+                    '[FAILURE] no progress was made '
+                    f'and max grad norm {max_grad_norm:.6f} was zero: '
+                    'the solver was likely stuck in a plateau.', 'red')
             else:
                 return termcolor.colored(
-                    '[FAILURE] no progress was made
-                    f'but max grad norm {max_grad_norm:.6f}
-                    '
+                    '[FAILURE] no progress was made '
+                    f'but max grad norm {max_grad_norm:.6f} was non-zero: '
+                    'the learning rate or other hyper-parameters were likely suboptimal.',
+                    'red')
 
         # model is likely poor IF:
         # 1. the train and test return disagree
         if not (validation_error < 20):
             return termcolor.colored(
-                '[WARNING] progress was made
-                f'but relative train-test error {validation_error:.6f}
-                '
-                'or the batch size
+                '[WARNING] progress was made '
+                f'but relative train-test error {validation_error:.6f} was high: '
+                'model relaxation around the solution was poor '
+                'or the batch size was too small.', 'yellow')
 
         # model likely did not converge IF:
         # 1. the max grad relative to the return is high
@@ -1994,15 +2381,15 @@ r"""
         return_to_grad_norm = abs(best_return) / max_grad_norm
         if not (return_to_grad_norm > 1):
             return termcolor.colored(
-                '[WARNING] progress was made
-                f'but max grad norm {max_grad_norm:.6f}
-                '
-                'or the relaxed model
-                'or the batch size
+                '[WARNING] progress was made '
+                f'but max grad norm {max_grad_norm:.6f} was high: '
+                'the solution was likely locally suboptimal, '
+                'or the relaxed model was not smooth around the solution, '
+                'or the batch size was too small.', 'yellow')
 
         # likely successful
         return termcolor.colored(
-            '[SUCCESS]
+            '[SUCCESS] solver converged successfully '
             '(note: not all potential problems can be ruled out).', 'green')
 
     def get_action(self, key: random.PRNGKey,
@@ -2035,8 +2422,8 @@ r"""
             # must be numeric array
             # exception is for POMDPs at 1st epoch when observ-fluents are None
            dtype = np.atleast_1d(values).dtype
-            if not
-            and not
+            if not np.issubdtype(dtype, np.number) \
+                    and not np.issubdtype(dtype, np.bool_):
                 if step == 0 and var in self.rddl.observ_fluents:
                     subs[var] = self.test_compiled.init_values[var]
                 else:
@@ -2077,10 +2464,11 @@ def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
 
 @jax.jit
 def cvar_utility(returns: jnp.ndarray, alpha: float) -> float:
-
-
-
-
+    var = jnp.percentile(returns, q=100 * alpha)
+    mask = returns <= var
+    weights = mask / jnp.maximum(1, jnp.sum(mask))
+    return jnp.sum(returns * weights)
+
 
 # ***********************************************************************
 # ALL VERSIONS OF CONTROLLERS
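The new cvar_utility body computes the conditional value at risk of the sampled returns: the alpha-quantile (the value at risk) and then the average of all returns at or below it, so the planner optimizes the mean of the worst alpha fraction of outcomes. A quick check with illustrative numbers (not from the package):

    import jax.numpy as jnp

    returns = jnp.array([-10., -5., 0., 5., 10., 15., 20., 25., 30., 35.])
    alpha = 0.2
    var = jnp.percentile(returns, q=100 * alpha)     # 20th percentile = -1.0
    mask = returns <= var                            # keeps the two worst returns
    cvar = jnp.sum(returns * mask) / jnp.maximum(1, jnp.sum(mask))
    print(var, cvar)                                 # -1.0  -7.5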