pyRDDLGym-jax 0.2__py3-none-any.whl → 0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,53 +1,52 @@
- __version__ = '0.2'
-
  from ast import literal_eval
  from collections import deque
  import configparser
  from enum import Enum
+ import os
+ import sys
+ import time
+ import traceback
+ from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
+
  import haiku as hk
  import jax
+ import jax.nn.initializers as initializers
  import jax.numpy as jnp
  import jax.random as random
- import jax.nn.initializers as initializers
  import numpy as np
  import optax
- import os
- import sys
  import termcolor
- import time
  from tqdm import tqdm
- from typing import Any, Callable, Dict, Generator, Optional, Set, Sequence, Tuple, Union
-
- Activation = Callable[[jnp.ndarray], jnp.ndarray]
- Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
- Kwargs = Dict[str, Any]
- Pytree = Any
 
- from pyRDDLGym.core.debug.exception import raise_warning
-
- # try to import matplotlib, if failed then skip plotting
- try:
-     import matplotlib
-     import matplotlib.pyplot as plt
-     matplotlib.use('TkAgg')
- except Exception:
-     raise_warning('matplotlib is not installed, '
-                   'plotting functionality is disabled.', 'red')
-     plt = None
-
  from pyRDDLGym.core.compiler.model import RDDLPlanningModel, RDDLLiftedModel
  from pyRDDLGym.core.debug.logger import Logger
  from pyRDDLGym.core.debug.exception import (
+     raise_warning,
      RDDLNotImplementedError,
      RDDLUndefinedVariableError,
      RDDLTypeError
  )
  from pyRDDLGym.core.policy import BaseAgent
 
- from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
+ from pyRDDLGym_jax import __version__
  from pyRDDLGym_jax.core import logic
+ from pyRDDLGym_jax.core.compiler import JaxRDDLCompiler
  from pyRDDLGym_jax.core.logic import FuzzyLogic
 
+ # try to import matplotlib, if failed then skip plotting
+ try:
+     import matplotlib.pyplot as plt
+ except Exception:
+     raise_warning('failed to import matplotlib: '
+                   'plotting functionality will be disabled.', 'red')
+     traceback.print_exc()
+     plt = None
+
+ Activation = Callable[[jnp.ndarray], jnp.ndarray]
+ Bounds = Dict[str, Tuple[np.ndarray, np.ndarray]]
+ Kwargs = Dict[str, Any]
+ Pytree = Any
+
 
  # ***********************************************************************
  # CONFIG FILE MANAGEMENT
@@ -102,9 +101,12 @@ def _load_config(config, args):
      comp_kwargs = model_args.get('complement_kwargs', {})
      compare_name = model_args.get('comparison', 'SigmoidComparison')
      compare_kwargs = model_args.get('comparison_kwargs', {})
+     sampling_name = model_args.get('sampling', 'GumbelSoftmax')
+     sampling_kwargs = model_args.get('sampling_kwargs', {})
      logic_kwargs['tnorm'] = getattr(logic, tnorm_name)(**tnorm_kwargs)
      logic_kwargs['complement'] = getattr(logic, comp_name)(**comp_kwargs)
      logic_kwargs['comparison'] = getattr(logic, compare_name)(**compare_kwargs)
+     logic_kwargs['sampling'] = getattr(logic, sampling_name)(**sampling_kwargs)
 
      # read the policy settings
      plan_method = planner_args.pop('method')
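The model section of the config now wires in a relaxed sampling scheme (default GumbelSoftmax) alongside the t-norm, complement, and comparison relaxations. A minimal sketch of the equivalent construction in Python, assuming these keyword arguments are forwarded to the FuzzyLogic constructor exactly as _load_config assembles them:

# Illustrative sketch only: apart from SigmoidComparison and GumbelSoftmax,
# which appear in the hunk above, nothing here is asserted about the API.
from pyRDDLGym_jax.core import logic
from pyRDDLGym_jax.core.logic import FuzzyLogic

relaxed_logic = FuzzyLogic(
    comparison=logic.SigmoidComparison(),   # default 'comparison' in the config
    sampling=logic.GumbelSoftmax())         # new in 0.4: relaxed discrete sampling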
@@ -113,7 +115,8 @@ def _load_config(config, args):
      # policy initialization
      plan_initializer = plan_kwargs.get('initializer', None)
      if plan_initializer is not None:
-         initializer = _getattr_any(packages=[initializers], item=plan_initializer)
+         initializer = _getattr_any(
+             packages=[initializers, hk.initializers], item=plan_initializer)
          if initializer is None:
              raise_warning(
                  f'Ignoring invalid initializer <{plan_initializer}>.', 'red')
@@ -130,7 +133,8 @@ def _load_config(config, args):
      # policy activation
      plan_activation = plan_kwargs.get('activation', None)
      if plan_activation is not None:
-         activation = _getattr_any(packages=[jax.nn, jax.numpy], item=plan_activation)
+         activation = _getattr_any(
+             packages=[jax.nn, jax.numpy], item=plan_activation)
          if activation is None:
              raise_warning(
                  f'Ignoring invalid activation <{plan_activation}>.', 'red')
@@ -180,18 +184,6 @@ def load_config_from_string(value: str) -> Tuple[Kwargs, ...]:
  #
  # ***********************************************************************
 
- def _function_discrete_approx_named(logic):
-     jax_discrete, jax_param = logic.discrete()
-
-     def _jax_wrapped_discrete_calc_approx(key, prob, params):
-         sample = jax_discrete(key, prob, params)
-         out_of_bounds = jnp.logical_not(jnp.logical_and(
-             jnp.all(prob >= 0),
-             jnp.allclose(jnp.sum(prob, axis=-1), 1.0)))
-         return sample, out_of_bounds
-
-     return _jax_wrapped_discrete_calc_approx, jax_param
-
 
  class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
      '''Compiles a RDDL AST representation to an equivalent JAX representation.
@@ -217,6 +209,7 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
          :param *kwargs: keyword arguments to pass to base compiler
          '''
          super(JaxRDDLCompilerWithGrad, self).__init__(*args, **kwargs)
+
          self.logic = logic
          self.logic.set_use64bit(self.use64bit)
          if cpfs_without_grad is None:
@@ -224,9 +217,14 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
          self.cpfs_without_grad = cpfs_without_grad
 
          # actions and CPFs must be continuous
-         raise_warning('Initial values of pvariables will be cast to real.')
+         pvars_cast = set()
          for (var, values) in self.init_values.items():
              self.init_values[var] = np.asarray(values, dtype=self.REAL)
+             if not np.issubdtype(np.atleast_1d(values).dtype, np.floating):
+                 pvars_cast.add(var)
+         if pvars_cast:
+             raise_warning(f'JAX gradient compiler requires that initial values '
+                           f'of p-variables {pvars_cast} be cast to float.')
 
          # overwrite basic operations with fuzzy ones
          self.RELATIONAL_OPS = {
@@ -261,7 +259,9 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
          self.IF_HELPER = logic.control_if()
          self.SWITCH_HELPER = logic.control_switch()
          self.BERNOULLI_HELPER = logic.bernoulli()
-         self.DISCRETE_HELPER = _function_discrete_approx_named(logic)
+         self.DISCRETE_HELPER = logic.discrete()
+         self.POISSON_HELPER = logic.poisson()
+         self.GEOMETRIC_HELPER = logic.geometric()
 
      def _jax_stop_grad(self, jax_expr):
 
@@ -273,20 +273,29 @@ class JaxRDDLCompilerWithGrad(JaxRDDLCompiler):
          return _jax_wrapped_stop_grad
 
      def _compile_cpfs(self, info):
-         raise_warning('CPFs outputs will be cast to real.')
+         cpfs_cast = set()
          jax_cpfs = {}
          for (_, cpfs) in self.levels.items():
              for cpf in cpfs:
                  _, expr = self.rddl.cpfs[cpf]
                  jax_cpfs[cpf] = self._jax(expr, info, dtype=self.REAL)
+                 if self.rddl.variable_ranges[cpf] != 'real':
+                     cpfs_cast.add(cpf)
                  if cpf in self.cpfs_without_grad:
-                     raise_warning(f'CPF <{cpf}> stops gradient.')
                      jax_cpfs[cpf] = self._jax_stop_grad(jax_cpfs[cpf])
+
+         if cpfs_cast:
+             raise_warning(f'JAX gradient compiler requires that outputs of CPFs '
+                           f'{cpfs_cast} be cast to float.')
+         if self.cpfs_without_grad:
+             raise_warning(f'User requested that gradients not flow '
+                           f'through CPFs {self.cpfs_without_grad}.')
          return jax_cpfs
 
      def _jax_kron(self, expr, info):
          if self.logic.verbose:
-             raise_warning('KronDelta will be ignored.')
+             raise_warning('JAX gradient compiler ignores KronDelta '
+                           'during compilation.')
          arg, = expr.args
          arg = self._jax(arg, info)
          return arg
@@ -308,7 +317,8 @@ class JaxPlan:
          self._train_policy = None
          self._test_policy = None
          self._projection = None
-
+         self.bounds = None
+
      def summarize_hyperparameters(self) -> None:
          pass
 
@@ -363,7 +373,7 @@ class JaxPlan:
              # check invalid type
              if prange not in compiled.JAX_TYPES:
                  raise RDDLTypeError(
-                     f'Invalid range <{prange}. of action-fluent <{name}>, '
+                     f'Invalid range <{prange}> of action-fluent <{name}>, '
                      f'must be one of {set(compiled.JAX_TYPES.keys())}.')
 
              # clip boolean to (0, 1), otherwise use the RDDL action bounds
@@ -385,7 +395,7 @@ class JaxPlan:
                  ~lower_finite & upper_finite,
                  ~lower_finite & ~upper_finite]
              bounds[name] = (lower, upper)
-             raise_warning(f'Bounds of action fluent <{name}> set to {bounds[name]}.')
+             raise_warning(f'Bounds of action-fluent <{name}> set to {bounds[name]}.')
          return shapes, bounds, bounds_safe, cond_lists
 
      def _count_bool_actions(self, rddl: RDDLLiftedModel):
@@ -427,6 +437,7 @@ class JaxStraightLinePlan(JaxPlan):
          use_new_projection = True
          '''
          super(JaxStraightLinePlan, self).__init__()
+
          self._initializer_base = initializer
          self._initializer = initializer
          self._wrap_sigmoid = wrap_sigmoid
@@ -437,15 +448,19 @@ class JaxStraightLinePlan(JaxPlan):
          self._max_constraint_iter = max_constraint_iter
 
      def summarize_hyperparameters(self) -> None:
+         bounds = '\n '.join(
+             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
          print(f'policy hyper-parameters:\n'
-               f' initializer ={type(self._initializer_base).__name__}\n'
+               f' initializer ={self._initializer_base}\n'
                f'constraint-sat strategy (simple):\n'
+               f' parsed_action_bounds =\n {bounds}\n'
                f' wrap_sigmoid ={self._wrap_sigmoid}\n'
                f' wrap_sigmoid_min_prob={self._min_action_prob}\n'
                f' wrap_non_bool ={self._wrap_non_bool}\n'
                f'constraint-sat strategy (complex):\n'
                f' wrap_softmax ={self._wrap_softmax}\n'
-               f' use_new_projection ={self._use_new_projection}')
+               f' use_new_projection ={self._use_new_projection}\n'
+               f' max_projection_iters ={self._max_constraint_iter}')
 
      def compile(self, compiled: JaxRDDLCompilerWithGrad,
                  _bounds: Bounds,
@@ -603,7 +618,7 @@ class JaxStraightLinePlan(JaxPlan):
          if 1 < allowed_actions < bool_action_count:
              raise RDDLNotImplementedError(
                  f'Straight-line plans with wrap_softmax currently '
-                 f'do not support max-nondef-actions = {allowed_actions} > 1.')
+                 f'do not support max-nondef-actions {allowed_actions} > 1.')
 
          # potentially apply projection but to non-bool actions only
          self.projection = _jax_wrapped_slp_project_to_box
@@ -734,14 +749,14 @@ class JaxStraightLinePlan(JaxPlan):
            for (var, shape) in shapes.items():
                if ranges[var] != 'bool' or not stack_bool_params:
                    key, subkey = random.split(key)
-                   param = init(subkey, shape, dtype=compiled.REAL)
+                   param = init(key=subkey, shape=shape, dtype=compiled.REAL)
                    if ranges[var] == 'bool':
                        param += bool_threshold
                    params[var] = param
            if stack_bool_params:
                key, subkey = random.split(key)
                bool_shape = (horizon, bool_action_count)
-               bool_param = init(subkey, bool_shape, dtype=compiled.REAL)
+               bool_param = init(key=subkey, shape=bool_shape, dtype=compiled.REAL)
                params[bool_key] = bool_param
            params, _ = _jax_wrapped_slp_project_to_box(params, hyperparams)
            return params
@@ -765,7 +780,8 @@ class JaxDeepReactivePolicy(JaxPlan):
      def __init__(self, topology: Optional[Sequence[int]]=None,
                   activation: Activation=jnp.tanh,
                   initializer: hk.initializers.Initializer=hk.initializers.VarianceScaling(scale=2.0),
-                  normalize: bool=True,
+                  normalize: bool=False,
+                  normalize_per_layer: bool=False,
                   normalizer_kwargs: Optional[Kwargs]=None,
                   wrap_non_bool: bool=False) -> None:
          '''Creates a new deep reactive policy in JAX.
@@ -775,12 +791,15 @@ class JaxDeepReactivePolicy(JaxPlan):
          :param activation: function to apply after each layer of the policy
          :param initializer: weight initialization
          :param normalize: whether to apply layer norm to the inputs
+         :param normalize_per_layer: whether to apply layer norm to each input
+         individually (only active if normalize is True)
          :param normalizer_kwargs: if normalize is True, apply additional arguments
          to layer norm
          :param wrap_non_bool: whether to wrap real or int action fluent parameters
          with non-linearity (e.g. sigmoid or ELU) to satisfy box constraints
          '''
          super(JaxDeepReactivePolicy, self).__init__()
+
          if topology is None:
              topology = [128, 64]
          self._topology = topology
@@ -788,22 +807,25 @@ class JaxDeepReactivePolicy(JaxPlan):
          self._initializer_base = initializer
          self._initializer = initializer
          self._normalize = normalize
+         self._normalize_per_layer = normalize_per_layer
          if normalizer_kwargs is None:
-             normalizer_kwargs = {
-                 'create_offset': True, 'create_scale': True,
-                 'name': 'input_norm'
-             }
+             normalizer_kwargs = {'create_offset': True, 'create_scale': True}
          self._normalizer_kwargs = normalizer_kwargs
          self._wrap_non_bool = wrap_non_bool
 
      def summarize_hyperparameters(self) -> None:
+         bounds = '\n '.join(
+             map(lambda kv: f'{kv[0]}: {kv[1]}', self.bounds.items()))
          print(f'policy hyper-parameters:\n'
-               f' topology ={self._topology}\n'
-               f' activation_fn ={self._activations[0].__name__}\n'
-               f' initializer ={type(self._initializer_base).__name__}\n'
-               f' apply_layer_norm={self._normalize}\n'
-               f' layer_norm_args ={self._normalizer_kwargs}\n'
-               f' wrap_non_bool ={self._wrap_non_bool}')
+               f' topology ={self._topology}\n'
+               f' activation_fn ={self._activations[0].__name__}\n'
+               f' initializer ={type(self._initializer_base).__name__}\n'
+               f' apply_input_norm ={self._normalize}\n'
+               f' input_norm_layerwise={self._normalize_per_layer}\n'
+               f' input_norm_args ={self._normalizer_kwargs}\n'
+               f'constraint-sat strategy:\n'
+               f' parsed_action_bounds=\n {bounds}\n'
+               f' wrap_non_bool ={self._wrap_non_bool}')
 
      def compile(self, compiled: JaxRDDLCompilerWithGrad,
                  _bounds: Bounds,
@@ -821,7 +843,7 @@ class JaxDeepReactivePolicy(JaxPlan):
          if 1 < allowed_actions < bool_action_count:
              raise RDDLNotImplementedError(
                  f'Deep reactive policies currently do not support '
-                 f'max-nondef-actions = {allowed_actions} > 1.')
+                 f'max-nondef-actions {allowed_actions} > 1.')
          use_constraint_satisfaction = allowed_actions < bool_action_count
 
          noop = {var: (values[0] if isinstance(values, list) else values)
@@ -835,6 +857,7 @@ class JaxDeepReactivePolicy(JaxPlan):
 
          ranges = rddl.variable_ranges
          normalize = self._normalize
+         normalize_per_layer = self._normalize_per_layer
          wrap_non_bool = self._wrap_non_bool
          init = self._initializer
          layers = list(enumerate(zip(self._topology, self._activations)))
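The deep reactive policy now defaults to no input normalization and adds a per-fluent variant. A hedged construction sketch using only arguments visible in the __init__ signature above:

# Sketch: one layer norm over all non-boolean inputs, without the per-fluent
# variant (argument names follow the signature shown above).
policy = JaxDeepReactivePolicy(
    topology=[64, 32],
    normalize=True,
    normalize_per_layer=False)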
@@ -842,14 +865,67 @@ class JaxDeepReactivePolicy(JaxPlan):
842
865
  for (var, shape) in shapes.items()}
843
866
  layer_names = {var: f'output_{var}'.replace('-', '_') for var in shapes}
844
867
 
845
- # predict actions from the policy network for current state
846
- def _jax_wrapped_policy_network_predict(state):
868
+ # inputs for the policy network
869
+ if rddl.observ_fluents:
870
+ observed_vars = rddl.observ_fluents
871
+ else:
872
+ observed_vars = rddl.state_fluents
873
+ input_names = {var: f'{var}'.replace('-', '_') for var in observed_vars}
874
+
875
+ # catch if input norm is applied to size 1 tensor
876
+ if normalize:
877
+ non_bool_dims = 0
878
+ for (var, values) in observed_vars.items():
879
+ if ranges[var] != 'bool':
880
+ value_size = np.atleast_1d(values).size
881
+ if normalize_per_layer and value_size == 1:
882
+ raise_warning(
883
+ f'Cannot apply layer norm to state-fluent <{var}> '
884
+ f'of size 1: setting normalize_per_layer = False.',
885
+ 'red')
886
+ normalize_per_layer = False
887
+ non_bool_dims += value_size
888
+ if not normalize_per_layer and non_bool_dims == 1:
889
+ raise_warning(
890
+ 'Cannot apply layer norm to state-fluents of total size 1: '
891
+ 'setting normalize = False.', 'red')
892
+ normalize = False
893
+
894
+ # convert subs dictionary into a state vector to feed to the MLP
895
+ def _jax_wrapped_policy_input(subs):
847
896
 
848
- # apply layer norm
849
- if normalize:
897
+ # concatenate all state variables into a single vector
898
+ # optionally apply layer norm to each input tensor
899
+ states_bool, states_non_bool = [], []
900
+ non_bool_dims = 0
901
+ for (var, value) in subs.items():
902
+ if var in observed_vars:
903
+ state = jnp.ravel(value)
904
+ if ranges[var] == 'bool':
905
+ states_bool.append(state)
906
+ else:
907
+ if normalize and normalize_per_layer:
908
+ normalizer = hk.LayerNorm(
909
+ axis=-1, param_axis=-1,
910
+ name=f'input_norm_{input_names[var]}',
911
+ **self._normalizer_kwargs)
912
+ state = normalizer(state)
913
+ states_non_bool.append(state)
914
+ non_bool_dims += state.size
915
+ state = jnp.concatenate(states_non_bool + states_bool)
916
+
917
+ # optionally perform layer normalization on the non-bool inputs
918
+ if normalize and not normalize_per_layer and non_bool_dims:
850
919
  normalizer = hk.LayerNorm(
851
- axis=-1, param_axis=-1, **self._normalizer_kwargs)
852
- state = normalizer(state)
920
+ axis=-1, param_axis=-1, name='input_norm',
921
+ **self._normalizer_kwargs)
922
+ normalized = normalizer(state[:non_bool_dims])
923
+ state = state.at[:non_bool_dims].set(normalized)
924
+ return state
925
+
926
+ # predict actions from the policy network for current state
927
+ def _jax_wrapped_policy_network_predict(subs):
928
+ state = _jax_wrapped_policy_input(subs)
853
929
 
854
930
  # feed state vector through hidden layers
855
931
  hidden = state
@@ -913,25 +989,9 @@ class JaxDeepReactivePolicy(JaxPlan):
913
989
  start += size
914
990
  return actions
915
991
 
916
- if rddl.observ_fluents:
917
- observed_vars = rddl.observ_fluents
918
- else:
919
- observed_vars = rddl.state_fluents
920
-
921
- # state is concatenated into single tensor
922
- def _jax_wrapped_subs_to_state(subs):
923
- subs = {var: value
924
- for (var, value) in subs.items()
925
- if var in observed_vars}
926
- flat_subs = jax.tree_map(jnp.ravel, subs)
927
- states = list(flat_subs.values())
928
- state = jnp.concatenate(states)
929
- return state
930
-
931
992
  # train action prediction
932
993
  def _jax_wrapped_drp_predict_train(key, params, hyperparams, step, subs):
933
- state = _jax_wrapped_subs_to_state(subs)
934
- actions = predict_fn.apply(params, state)
994
+ actions = predict_fn.apply(params, subs)
935
995
  if not wrap_non_bool:
936
996
  for (var, action) in actions.items():
937
997
  if var != bool_key and ranges[var] != 'bool':
@@ -982,8 +1042,7 @@ class JaxDeepReactivePolicy(JaxPlan):
982
1042
  subs = {var: value[0, ...]
983
1043
  for (var, value) in subs.items()
984
1044
  if var in observed_vars}
985
- state = _jax_wrapped_subs_to_state(subs)
986
- params = predict_fn.init(key, state)
1045
+ params = predict_fn.init(key, subs)
987
1046
  return params
988
1047
 
989
1048
  self.initializer = _jax_wrapped_drp_init
@@ -1021,46 +1080,72 @@ class RollingMean:
1021
1080
  class JaxPlannerPlot:
1022
1081
  '''Supports plotting and visualization of a JAX policy in real time.'''
1023
1082
 
1024
- def __init__(self, rddl: RDDLPlanningModel, horizon: int) -> None:
1025
- self._fig, axes = plt.subplots(1 + len(rddl.action_fluents))
1083
+ def __init__(self, rddl: RDDLPlanningModel, horizon: int,
1084
+ show_violin: bool=True, show_action: bool=True) -> None:
1085
+ '''Creates a new planner visualizer.
1086
+
1087
+ :param rddl: the planning model to optimize
1088
+ :param horizon: the lookahead or planning horizon
1089
+ :param show_violin: whether to show the distribution of batch losses
1090
+ :param show_action: whether to show heatmaps of the action fluents
1091
+ '''
1092
+ num_plots = 1
1093
+ if show_violin:
1094
+ num_plots += 1
1095
+ if show_action:
1096
+ num_plots += len(rddl.action_fluents)
1097
+ self._fig, axes = plt.subplots(num_plots)
1098
+ if num_plots == 1:
1099
+ axes = [axes]
1026
1100
 
1027
1101
  # prepare the loss plot
1028
1102
  self._loss_ax = axes[0]
1029
1103
  self._loss_ax.autoscale(enable=True)
1030
- self._loss_ax.set_xlabel('decision epoch')
1104
+ self._loss_ax.set_xlabel('training time')
1031
1105
  self._loss_ax.set_ylabel('loss value')
1032
1106
  self._loss_plot = self._loss_ax.plot(
1033
1107
  [], [], linestyle=':', marker='o', markersize=2)[0]
1034
1108
  self._loss_back = self._fig.canvas.copy_from_bbox(self._loss_ax.bbox)
1035
1109
 
1110
+ # prepare the violin plot
1111
+ if show_violin:
1112
+ self._hist_ax = axes[1]
1113
+ else:
1114
+ self._hist_ax = None
1115
+
1036
1116
  # prepare the action plots
1037
- self._action_ax = {name: axes[idx + 1]
1038
- for (idx, name) in enumerate(rddl.action_fluents)}
1039
- self._action_plots = {}
1040
- for name in rddl.action_fluents:
1041
- ax = self._action_ax[name]
1042
- if rddl.variable_ranges[name] == 'bool':
1043
- vmin, vmax = 0.0, 1.0
1044
- else:
1045
- vmin, vmax = None, None
1046
- action_dim = 1
1047
- for dim in rddl.object_counts(rddl.variable_params[name]):
1048
- action_dim *= dim
1049
- action_plot = ax.pcolormesh(
1050
- np.zeros((action_dim, horizon)),
1051
- cmap='seismic', vmin=vmin, vmax=vmax)
1052
- ax.set_aspect('auto')
1053
- ax.set_xlabel('decision epoch')
1054
- ax.set_ylabel(name)
1055
- plt.colorbar(action_plot, ax=ax)
1056
- self._action_plots[name] = action_plot
1057
- self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
1058
- for (name, ax) in self._action_ax.items()}
1059
-
1117
+ if show_action:
1118
+ self._action_ax = {name: axes[idx + (2 if show_violin else 1)]
1119
+ for (idx, name) in enumerate(rddl.action_fluents)}
1120
+ self._action_plots = {}
1121
+ for name in rddl.action_fluents:
1122
+ ax = self._action_ax[name]
1123
+ if rddl.variable_ranges[name] == 'bool':
1124
+ vmin, vmax = 0.0, 1.0
1125
+ else:
1126
+ vmin, vmax = None, None
1127
+ action_dim = 1
1128
+ for dim in rddl.object_counts(rddl.variable_params[name]):
1129
+ action_dim *= dim
1130
+ action_plot = ax.pcolormesh(
1131
+ np.zeros((action_dim, horizon)),
1132
+ cmap='seismic', vmin=vmin, vmax=vmax)
1133
+ ax.set_aspect('auto')
1134
+ ax.set_xlabel('decision epoch')
1135
+ ax.set_ylabel(name)
1136
+ plt.colorbar(action_plot, ax=ax)
1137
+ self._action_plots[name] = action_plot
1138
+ self._action_back = {name: self._fig.canvas.copy_from_bbox(ax.bbox)
1139
+ for (name, ax) in self._action_ax.items()}
1140
+ else:
1141
+ self._action_ax = None
1142
+ self._action_plots = None
1143
+ self._action_back = None
1144
+
1060
1145
  plt.tight_layout()
1061
1146
  plt.show(block=False)
1062
1147
 
1063
- def redraw(self, xticks, losses, actions) -> None:
1148
+ def redraw(self, xticks, losses, actions, returns) -> None:
1064
1149
 
1065
1150
  # draw the loss curve
1066
1151
  self._fig.canvas.restore_region(self._loss_back)
@@ -1071,21 +1156,30 @@ class JaxPlannerPlot:
1071
1156
  self._loss_ax.draw_artist(self._loss_plot)
1072
1157
  self._fig.canvas.blit(self._loss_ax.bbox)
1073
1158
 
1159
+ # draw the violin plot
1160
+ if self._hist_ax is not None:
1161
+ self._hist_ax.clear()
1162
+ self._hist_ax.set_xlabel('loss value')
1163
+ self._hist_ax.set_ylabel('density')
1164
+ self._hist_ax.violinplot(returns, vert=False, showmeans=True)
1165
+
1074
1166
  # draw the actions
1075
- for (name, values) in actions.items():
1076
- values = np.mean(values, axis=0, dtype=float)
1077
- values = np.reshape(values, newshape=(values.shape[0], -1)).T
1078
- self._fig.canvas.restore_region(self._action_back[name])
1079
- self._action_plots[name].set_array(values)
1080
- self._action_ax[name].draw_artist(self._action_plots[name])
1081
- self._fig.canvas.blit(self._action_ax[name].bbox)
1082
- self._action_plots[name].set_clim([np.min(values), np.max(values)])
1167
+ if self._action_ax is not None:
1168
+ for (name, values) in actions.items():
1169
+ values = np.mean(values, axis=0, dtype=float)
1170
+ values = np.reshape(values, newshape=(values.shape[0], -1)).T
1171
+ self._fig.canvas.restore_region(self._action_back[name])
1172
+ self._action_plots[name].set_array(values)
1173
+ self._action_ax[name].draw_artist(self._action_plots[name])
1174
+ self._fig.canvas.blit(self._action_ax[name].bbox)
1175
+ self._action_plots[name].set_clim([np.min(values), np.max(values)])
1176
+
1083
1177
  self._fig.canvas.draw()
1084
1178
  self._fig.canvas.flush_events()
1085
1179
 
1086
1180
  def close(self) -> None:
1087
1181
  plt.close(self._fig)
1088
- del self._loss_ax, self._action_ax, \
1182
+ del self._loss_ax, self._hist_ax, self._action_ax, \
1089
1183
  self._loss_plot, self._action_plots, self._fig, \
1090
1184
  self._loss_back, self._action_back
1091
1185
 
@@ -1099,9 +1193,9 @@ class JaxPlannerStatus(Enum):
      NORMAL = 0
      NO_PROGRESS = 1
      PRECONDITION_POSSIBLY_UNSATISFIED = 2
-     TIME_BUDGET_REACHED = 3
-     ITER_BUDGET_REACHED = 4
-     INVALID_GRADIENT = 5
+     INVALID_GRADIENT = 3
+     TIME_BUDGET_REACHED = 4
+     ITER_BUDGET_REACHED = 5
 
      def is_failure(self) -> bool:
          return self.value >= 3
@@ -1245,30 +1339,41 @@ class JaxBackpropPlanner:
1245
1339
  map(str, jax._src.xla_bridge.devices())).replace('\n', '')
1246
1340
  except Exception as _:
1247
1341
  devices_short = 'N/A'
1342
+ LOGO = \
1343
+ """
1344
+ __ ______ __ __ ______ __ ______ __ __
1345
+ /\ \ /\ __ \ /\_\_\_\ /\ == \/\ \ /\ __ \ /\ "-.\ \
1346
+ _\_\ \ \ \ __ \ \/_/\_\/_ \ \ _-/\ \ \____ \ \ __ \ \ \ \-. \
1347
+ /\_____\ \ \_\ \_\ /\_\/\_\ \ \_\ \ \_____\ \ \_\ \_\ \ \_\\"\_\
1348
+ \/_____/ \/_/\/_/ \/_/\/_/ \/_/ \/_____/ \/_/\/_/ \/_/ \/_/
1349
+ """
1350
+
1248
1351
  print('\n'
1249
- f'JAX Planner version {__version__}\n'
1352
+ f'{LOGO}\n'
1353
+ f'Version {__version__}\n'
1250
1354
  f'Python {sys.version}\n'
1251
1355
  f'jax {jax.version.__version__}, jaxlib {jaxlib_version}, '
1356
+ f'optax {optax.__version__}, haiku {hk.__version__}, '
1252
1357
  f'numpy {np.__version__}\n'
1253
1358
  f'devices: {devices_short}\n')
1254
1359
 
1255
1360
  def summarize_hyperparameters(self) -> None:
1256
1361
  print(f'objective hyper-parameters:\n'
1257
- f' utility_fn ={self.utility.__name__}\n'
1258
- f' utility args ={self.utility_kwargs}\n'
1259
- f' use_symlog ={self.use_symlog_reward}\n'
1260
- f' lookahead ={self.horizon}\n'
1261
- f' action_bounds ={self._action_bounds}\n'
1262
- f' fuzzy logic type={type(self.logic).__name__}\n'
1263
- f' nonfluents exact={self.compile_non_fluent_exact}\n'
1264
- f' cpfs_no_gradient={self.cpfs_without_grad}\n'
1362
+ f' utility_fn ={self.utility.__name__}\n'
1363
+ f' utility args ={self.utility_kwargs}\n'
1364
+ f' use_symlog ={self.use_symlog_reward}\n'
1365
+ f' lookahead ={self.horizon}\n'
1366
+ f' user_action_bounds={self._action_bounds}\n'
1367
+ f' fuzzy logic type ={type(self.logic).__name__}\n'
1368
+ f' nonfluents exact ={self.compile_non_fluent_exact}\n'
1369
+ f' cpfs_no_gradient ={self.cpfs_without_grad}\n'
1265
1370
  f'optimizer hyper-parameters:\n'
1266
- f' use_64_bit ={self.use64bit}\n'
1267
- f' optimizer ={self._optimizer_name.__name__}\n'
1268
- f' optimizer args ={self._optimizer_kwargs}\n'
1269
- f' clip_gradient ={self.clip_grad}\n'
1270
- f' batch_size_train={self.batch_size_train}\n'
1271
- f' batch_size_test ={self.batch_size_test}')
1371
+ f' use_64_bit ={self.use64bit}\n'
1372
+ f' optimizer ={self._optimizer_name.__name__}\n'
1373
+ f' optimizer args ={self._optimizer_kwargs}\n'
1374
+ f' clip_gradient ={self.clip_grad}\n'
1375
+ f' batch_size_train ={self.batch_size_train}\n'
1376
+ f' batch_size_test ={self.batch_size_test}')
1272
1377
  self.plan.summarize_hyperparameters()
1273
1378
  self.logic.summarize_hyperparameters()
1274
1379
 
@@ -1310,6 +1415,7 @@ class JaxBackpropPlanner:
              policy=self.plan.train_policy,
              n_steps=self.horizon,
              n_batch=self.batch_size_train)
+         self.train_rollouts = train_rollouts
 
          test_rollouts = self.test_compiled.compile_rollouts(
              policy=self.plan.test_policy,
@@ -1417,17 +1523,106 @@ class JaxBackpropPlanner:
1417
1523
 
1418
1524
  return init_train, init_test
1419
1525
 
1526
+ def as_optimization_problem(
1527
+ self, key: Optional[random.PRNGKey]=None,
1528
+ policy_hyperparams: Optional[Pytree]=None,
1529
+ loss_function_updates_key: bool=True,
1530
+ grad_function_updates_key: bool=False) -> Tuple[Callable, Callable, np.ndarray, Callable]:
1531
+ '''Returns a function that computes the loss and a function that
1532
+ computes gradient of the return as a 1D vector given a 1D representation
1533
+ of policy parameters. These functions are designed to be compatible with
1534
+ off-the-shelf optimizers such as scipy.
1535
+
1536
+ Also returns the initial parameter vector to seed an optimizer,
1537
+ as well as a mapping that recovers the parameter pytree from the vector.
1538
+ The PRNG key is updated internally starting from the optional given key.
1539
+
1540
+ Constraints on actions, if they are required, cannot be constructed
1541
+ automatically in the general case. The user should build constraints
1542
+ for each problem in the format required by the downstream optimizer.
1543
+
1544
+ :param key: JAX PRNG key (derived from clock if not provided)
1545
+ :param policy_hyperparameters: hyper-parameters for the policy/plan,
1546
+ such as weights for sigmoid wrapping boolean actions (defaults to 1
1547
+ for all action-fluents if not provided)
1548
+ :param loss_function_updates_key: if True, the loss function
1549
+ updates the PRNG key internally independently of the grad function
1550
+ :param grad_function_updates_key: if True, the gradient function
1551
+ updates the PRNG key internally independently of the loss function.
1552
+ '''
1553
+
1554
+ # if PRNG key is not provided
1555
+ if key is None:
1556
+ key = random.PRNGKey(round(time.time() * 1000))
1557
+
1558
+ # initialize the initial fluents, model parameters, policy hyper-params
1559
+ subs = self.test_compiled.init_values
1560
+ train_subs, _ = self._batched_init_subs(subs)
1561
+ model_params = self.compiled.model_params
1562
+ if policy_hyperparams is None:
1563
+ raise_warning('policy_hyperparams is not set, setting 1.0 for '
1564
+ 'all action-fluents which could be suboptimal.')
1565
+ policy_hyperparams = {action: 1.0
1566
+ for action in self.rddl.action_fluents}
1567
+
1568
+ # initialize the policy parameters
1569
+ params_guess, *_ = self.initialize(key, policy_hyperparams, train_subs)
1570
+ guess_1d, unravel_fn = jax.flatten_util.ravel_pytree(params_guess)
1571
+ guess_1d = np.asarray(guess_1d)
1572
+
1573
+ # computes the training loss function and its 1D gradient
1574
+ loss_fn = self._jax_loss(self.train_rollouts)
1575
+
1576
+ @jax.jit
1577
+ def _loss_with_key(key, params_1d):
1578
+ policy_params = unravel_fn(params_1d)
1579
+ loss_val, _ = loss_fn(key, policy_params, policy_hyperparams,
1580
+ train_subs, model_params)
1581
+ return loss_val
1582
+
1583
+ @jax.jit
1584
+ def _grad_with_key(key, params_1d):
1585
+ policy_params = unravel_fn(params_1d)
1586
+ grad_fn = jax.grad(loss_fn, argnums=1, has_aux=True)
1587
+ grad_val, _ = grad_fn(key, policy_params, policy_hyperparams,
1588
+ train_subs, model_params)
1589
+ grad_1d = jax.flatten_util.ravel_pytree(grad_val)[0]
1590
+ return grad_1d
1591
+
1592
+ def _loss_function(params_1d):
1593
+ nonlocal key
1594
+ if loss_function_updates_key:
1595
+ key, subkey = random.split(key)
1596
+ else:
1597
+ subkey = key
1598
+ loss_val = _loss_with_key(subkey, params_1d)
1599
+ loss_val = float(loss_val)
1600
+ return loss_val
1601
+
1602
+ def _grad_function(params_1d):
1603
+ nonlocal key
1604
+ if grad_function_updates_key:
1605
+ key, subkey = random.split(key)
1606
+ else:
1607
+ subkey = key
1608
+ grad = _grad_with_key(subkey, params_1d)
1609
+ grad = np.asarray(grad)
1610
+ return grad
1611
+
1612
+ return _loss_function, _grad_function, guess_1d, jax.jit(unravel_fn)
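The new as_optimization_problem() exposes the training loss and its gradient as flat, NumPy-friendly callables so the planner can be driven by an off-the-shelf optimizer. A minimal usage sketch (the planner instance and the choice of scipy method are illustrative, not part of the package):

# Sketch: the four return values follow the return statement above.
from scipy.optimize import minimize

loss_fn, grad_fn, x0, unravel_fn = planner.as_optimization_problem()
result = minimize(loss_fn, x0, jac=grad_fn, method='L-BFGS-B')
best_policy_params = unravel_fn(result.x)  # back to the policy parameter pytree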
1613
+
1420
1614
  # ===========================================================================
1421
1615
  # OPTIMIZE API
1422
1616
  # ===========================================================================
1423
1617
 
1424
1618
  def optimize(self, *args, **kwargs) -> Dict[str, Any]:
1425
- ''' Compute an optimal policy or plan. Return the callback from training.
1619
+ '''Compute an optimal policy or plan. Return the callback from training.
1426
1620
 
1427
1621
  :param key: JAX PRNG key (derived from clock if not provided)
1428
1622
  :param epochs: the maximum number of steps of gradient descent
1429
1623
  :param train_seconds: total time allocated for gradient descent
1430
1624
  :param plot_step: frequency to plot the plan and save result to disk
1625
+ :param plot_kwargs: additional arguments to pass to the plotter
1431
1626
  :param model_params: optional model-parameters to override default
1432
1627
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1433
1628
  weights for sigmoid wrapping boolean actions
@@ -1435,7 +1630,9 @@ class JaxBackpropPlanner:
          their values: if None initializes all variables from the RDDL instance
          :param guess: initial policy parameters: if None will use the initializer
          specified in this instance
-         :param verbose: not print (0), print summary (1), print progress (2)
+         :param print_summary: whether to print planner header, parameter
+         summary, and diagnosis
+         :param print_progress: whether to print the progress bar during training
          :param test_rolling_window: the test return is averaged on a rolling
          window of the past test_rolling_window returns when updating the best
          parameters found so far
@@ -1461,11 +1658,13 @@ class JaxBackpropPlanner:
                   epochs: int=999999,
                   train_seconds: float=120.,
                   plot_step: Optional[int]=None,
+                  plot_kwargs: Optional[Dict[str, Any]]=None,
                   model_params: Optional[Dict[str, Any]]=None,
                   policy_hyperparams: Optional[Dict[str, Any]]=None,
                   subs: Optional[Dict[str, Any]]=None,
                   guess: Optional[Pytree]=None,
-                  verbose: int=2,
+                  print_summary: bool=True,
+                  print_progress: bool=True,
                   test_rolling_window: int=10,
                   tqdm_position: Optional[int]=None) -> Generator[Dict[str, Any], None, None]:
          '''Returns a generator for computing an optimal policy or plan.
@@ -1476,20 +1675,22 @@ class JaxBackpropPlanner:
1476
1675
  :param epochs: the maximum number of steps of gradient descent
1477
1676
  :param train_seconds: total time allocated for gradient descent
1478
1677
  :param plot_step: frequency to plot the plan and save result to disk
1678
+ :param plot_kwargs: additional arguments to pass to the plotter
1479
1679
  :param model_params: optional model-parameters to override default
1480
1680
  :param policy_hyperparams: hyper-parameters for the policy/plan, such as
1481
1681
  weights for sigmoid wrapping boolean actions
1482
1682
  :param subs: dictionary mapping initial state and non-fluents to
1483
1683
  their values: if None initializes all variables from the RDDL instance
1484
1684
  :param guess: initial policy parameters: if None will use the initializer
1485
- specified in this instance
1486
- :param verbose: not print (0), print summary (1), print progress (2)
1685
+ specified in this instance
1686
+ :param print_summary: whether to print planner header, parameter
1687
+ summary, and diagnosis
1688
+ :param print_progress: whether to print the progress bar during training
1487
1689
  :param test_rolling_window: the test return is averaged on a rolling
1488
1690
  window of the past test_rolling_window returns when updating the best
1489
1691
  parameters found so far
1490
1692
  :param tqdm_position: position of tqdm progress bar (for multiprocessing)
1491
1693
  '''
1492
- verbose = int(verbose)
1493
1694
  start_time = time.time()
1494
1695
  elapsed_outside_loop = 0
1495
1696
 
@@ -1511,9 +1712,17 @@ class JaxBackpropPlanner:
1511
1712
  hyperparam_value = float(policy_hyperparams)
1512
1713
  policy_hyperparams = {action: hyperparam_value
1513
1714
  for action in self.rddl.action_fluents}
1715
+
1716
+ # fill in missing entries
1717
+ elif isinstance(policy_hyperparams, dict):
1718
+ for action in self.rddl.action_fluents:
1719
+ if action not in policy_hyperparams:
1720
+ raise_warning(f'policy_hyperparams[{action}] is not set, '
1721
+ 'setting 1.0 which could be suboptimal.')
1722
+ policy_hyperparams[action] = 1.0
1514
1723
 
1515
1724
  # print summary of parameters:
1516
- if verbose >= 1:
1725
+ if print_summary:
1517
1726
  self._summarize_system()
1518
1727
  self.summarize_hyperparameters()
1519
1728
  print(f'optimize() call hyper-parameters:\n'
@@ -1526,8 +1735,10 @@ class JaxBackpropPlanner:
1526
1735
  f' provide_param_guess={guess is not None}\n'
1527
1736
  f' test_rolling_window={test_rolling_window}\n'
1528
1737
  f' plot_frequency ={plot_step}\n'
1529
- f' verbose ={verbose}\n')
1530
- if verbose >= 2 and self.compiled.relaxations:
1738
+ f' plot_kwargs ={plot_kwargs}\n'
1739
+ f' print_summary ={print_summary}\n'
1740
+ f' print_progress ={print_progress}\n')
1741
+ if self.compiled.relaxations:
1531
1742
  print('Some RDDL operations are non-differentiable, '
1532
1743
  'replacing them with differentiable relaxations:')
1533
1744
  print(self.compiled.summarize_model_relaxations())
@@ -1549,7 +1760,7 @@ class JaxBackpropPlanner:
1549
1760
  'from the RDDL files.')
1550
1761
  train_subs, test_subs = self._batched_init_subs(subs)
1551
1762
 
1552
- # initialize, model parameters
1763
+ # initialize model parameters
1553
1764
  if model_params is None:
1554
1765
  model_params = self.compiled.model_params
1555
1766
  model_params_test = self.test_compiled.model_params
@@ -1570,27 +1781,40 @@ class JaxBackpropPlanner:
1570
1781
  rolling_test_loss = RollingMean(test_rolling_window)
1571
1782
  log = {}
1572
1783
  status = JaxPlannerStatus.NORMAL
1784
+ is_all_zero_fn = lambda x: np.allclose(x, 0)
1573
1785
 
1574
1786
  # initialize plot area
1575
1787
  if plot_step is None or plot_step <= 0 or plt is None:
1576
1788
  plot = None
1577
1789
  else:
1578
- plot = JaxPlannerPlot(self.rddl, self.horizon)
1790
+ if plot_kwargs is None:
1791
+ plot_kwargs = {}
1792
+ plot = JaxPlannerPlot(self.rddl, self.horizon, **plot_kwargs)
1579
1793
  xticks, loss_values = [], []
1580
1794
 
1581
1795
  # training loop
1582
1796
  iters = range(epochs)
1583
- if verbose >= 2:
1797
+ if print_progress:
1584
1798
  iters = tqdm(iters, total=100, position=tqdm_position)
1799
+ position_str = '' if tqdm_position is None else f'[{tqdm_position}]'
1585
1800
 
1586
1801
  for it in iters:
1587
1802
  status = JaxPlannerStatus.NORMAL
1588
1803
 
1589
1804
  # update the parameters of the plan
1590
1805
  key, subkey = random.split(key)
1591
- policy_params, converged, opt_state, opt_aux, train_loss, train_log = \
1806
+ policy_params, converged, opt_state, opt_aux, \
1807
+ train_loss, train_log = \
1592
1808
  self.update(subkey, policy_params, policy_hyperparams,
1593
1809
  train_subs, model_params, opt_state, opt_aux)
1810
+
1811
+ # no progress
1812
+ grad_norm_zero, _ = jax.tree_util.tree_flatten(
1813
+ jax.tree_map(is_all_zero_fn, train_log['grad']))
1814
+ if np.all(grad_norm_zero):
1815
+ status = JaxPlannerStatus.NO_PROGRESS
1816
+
1817
+ # constraint satisfaction problem
1594
1818
  if not np.all(converged):
1595
1819
  raise_warning(
1596
1820
  'Projected gradient method for satisfying action concurrency '
@@ -1598,13 +1822,18 @@ class JaxBackpropPlanner:
1598
1822
  'invalid for the current instance.', 'red')
1599
1823
  status = JaxPlannerStatus.PRECONDITION_POSSIBLY_UNSATISFIED
1600
1824
 
1601
- # evaluate losses
1825
+ # numerical error
1826
+ if not np.isfinite(train_loss):
1827
+ raise_warning(
1828
+ f'Aborting JAX planner due to invalid train loss {train_loss}.',
1829
+ 'red')
1830
+ status = JaxPlannerStatus.INVALID_GRADIENT
1831
+
1832
+ # evaluate test losses and record best plan so far
1602
1833
  test_loss, log = self.test_loss(
1603
1834
  subkey, policy_params, policy_hyperparams,
1604
1835
  test_subs, model_params_test)
1605
1836
  test_loss = rolling_test_loss.update(test_loss)
1606
-
1607
- # record the best plan so far
1608
1837
  if test_loss < best_loss:
1609
1838
  best_params, best_loss, best_grad = \
1610
1839
  policy_params, test_loss, train_log['grad']
@@ -1617,15 +1846,17 @@ class JaxBackpropPlanner:
1617
1846
  action_values = {name: values
1618
1847
  for (name, values) in log['fluents'].items()
1619
1848
  if name in self.rddl.action_fluents}
1620
- plot.redraw(xticks, loss_values, action_values)
1849
+ returns = -np.sum(np.asarray(log['reward']), axis=1)
1850
+ plot.redraw(xticks, loss_values, action_values, returns)
1621
1851
 
1622
1852
  # if the progress bar is used
1623
1853
  elapsed = time.time() - start_time - elapsed_outside_loop
1624
- if verbose >= 2:
1854
+ if print_progress:
1625
1855
  iters.n = int(100 * min(1, max(elapsed / train_seconds, it / epochs)))
1626
1856
  iters.set_description(
1627
- f'[{tqdm_position}] {it:6} it / {-train_loss:14.6f} train / '
1628
- f'{-test_loss:14.6f} test / {-best_loss:14.6f} best')
1857
+ f'{position_str} {it:6} it / {-train_loss:14.6f} train / '
1858
+ f'{-test_loss:14.6f} test / {-best_loss:14.6f} best / '
1859
+ f'{status.value} status')
1629
1860
 
1630
1861
  # reached computation budget
1631
1862
  if elapsed >= train_seconds:
@@ -1633,19 +1864,6 @@ class JaxBackpropPlanner:
1633
1864
  if it >= epochs - 1:
1634
1865
  status = JaxPlannerStatus.ITER_BUDGET_REACHED
1635
1866
 
1636
- # numerical error
1637
- if not np.isfinite(train_loss):
1638
- raise_warning(
1639
- f'Aborting JAX planner due to invalid train loss {train_loss}.',
1640
- 'red')
1641
- status = JaxPlannerStatus.INVALID_GRADIENT
1642
-
1643
- # no progress
1644
- grad_norm_zero, _ = jax.tree_util.tree_flatten(
1645
- jax.tree_map(lambda x: np.allclose(x, 0), train_log['grad']))
1646
- if np.all(grad_norm_zero):
1647
- status = JaxPlannerStatus.NO_PROGRESS
1648
-
1649
1867
  # return a callback
1650
1868
  start_time_outside = time.time()
1651
1869
  yield {
@@ -1671,7 +1889,7 @@ class JaxBackpropPlanner:
                  break
 
          # release resources
-         if verbose >= 2:
+         if print_progress:
              iters.close()
          if plot is not None:
              plot.close()
@@ -1688,7 +1906,7 @@ class JaxBackpropPlanner:
                  f'during test evaluation:\n{messages}', 'red')
 
          # summarize and test for convergence
-         if verbose >= 1:
+         if print_summary:
              grad_norm = jax.tree_map(lambda x: np.linalg.norm(x).item(), best_grad)
              diagnosis = self._perform_diagnosis(
                  last_iter_improve, -train_loss, -test_loss, -best_loss, grad_norm)
@@ -1698,7 +1916,7 @@ class JaxBackpropPlanner:
                    f' iterations ={it}\n'
                    f' best_objective={-best_loss}\n'
                    f' best_grad_norm={grad_norm}\n'
-                   f'diagnosis: {diagnosis}\n')
+                   f' diagnosis: {diagnosis}\n')
 
      def _perform_diagnosis(self, last_iter_improve,
                             train_return, test_return, best_return, grad_norm):
@@ -1778,17 +1996,19 @@ class JaxBackpropPlanner:
1778
1996
  raise ValueError(f'State dictionary passed to the JAX policy is '
1779
1997
  f'grounded, since it contains the key <{var}>, '
1780
1998
  f'but a vectorized environment is required: '
1781
- f'please make sure vectorized=True in the RDDLEnv.')
1999
+ f'make sure vectorized = True in the RDDLEnv.')
1782
2000
 
1783
2001
  # must be numeric array
1784
2002
  # exception is for POMDPs at 1st epoch when observ-fluents are None
1785
- if not jnp.issubdtype(values.dtype, jnp.number) \
1786
- and not jnp.issubdtype(values.dtype, jnp.bool_):
2003
+ dtype = np.atleast_1d(values).dtype
2004
+ if not jnp.issubdtype(dtype, jnp.number) \
2005
+ and not jnp.issubdtype(dtype, jnp.bool_):
1787
2006
  if step == 0 and var in self.rddl.observ_fluents:
1788
2007
  subs[var] = self.test_compiled.init_values[var]
1789
2008
  else:
1790
- raise ValueError(f'Values assigned to pvariable {var} are '
1791
- f'non-numeric of type {values.dtype}: {values}.')
2009
+ raise ValueError(
2010
+ f'Values {values} assigned to p-variable <{var}> are '
2011
+ f'non-numeric of type {dtype}.')
1792
2012
 
1793
2013
  # cast device arrays to numpy
1794
2014
  actions = self.test_policy(key, params, policy_hyperparams, step, subs)
@@ -1801,8 +2021,6 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
      linear search gradient descent, with the Armijo condition.'''
 
      def __init__(self, *args,
-                  optimizer: Callable[..., optax.GradientTransformation]=optax.sgd,
-                  optimizer_kwargs: Kwargs={'learning_rate': 1.0},
                   decay: float=0.8,
                   c: float=0.1,
                   step_max: float=1.0,
@@ -1825,11 +2043,7 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
              raise_warning('clip_grad parameter conflicts with '
                            'line search planner and will be ignored.', 'red')
              del kwargs['clip_grad']
-         super(JaxLineSearchPlanner, self).__init__(
-             *args,
-             optimizer=optimizer,
-             optimizer_kwargs=optimizer_kwargs,
-             **kwargs)
+         super(JaxLineSearchPlanner, self).__init__(*args, **kwargs)
 
      def summarize_hyperparameters(self) -> None:
          super(JaxLineSearchPlanner, self).summarize_hyperparameters()
@@ -1878,7 +2092,8 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
              step = lrmax / decay
              f_step = np.inf
              best_f, best_step, best_params, best_state = np.inf, None, None, None
-             while f_step > f - c * step * gnorm2 and step * decay >= lrmin:
+             while (f_step > f - c * step * gnorm2 and step * decay >= lrmin) \
+                     or not trials:
                  trials += 1
                  step *= decay
                  f_step, new_params, new_state = _jax_wrapped_line_search_trial(
@@ -1913,12 +2128,12 @@ class JaxLineSearchPlanner(JaxBackpropPlanner):
  @jax.jit
  def entropic_utility(returns: jnp.ndarray, beta: float) -> float:
      return (-1.0 / beta) * jax.scipy.special.logsumexp(
-         -beta * returns, b=1.0 / returns.size)
+         -beta * returns, b=1.0 / returns.size)
 
 
  @jax.jit
  def mean_variance_utility(returns: jnp.ndarray, beta: float) -> float:
-     return jnp.mean(returns) - (beta / 2.0) * jnp.var(returns)
+     return jnp.mean(returns) - 0.5 * beta * jnp.var(returns)
 
 
  @jax.jit
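These utilities aggregate the batch of sampled returns into the scalar training objective. A hedged sketch of selecting one when constructing a planner, assuming the constructor accepts utility and utility_kwargs as suggested by summarize_hyperparameters earlier in this diff:

# Illustrative only: the constructor keywords are inferred, not asserted.
planner = JaxBackpropPlanner(
    rddl=model,                        # a compiled RDDL model (placeholder)
    plan=JaxStraightLinePlan(),
    utility=mean_variance_utility,     # risk-sensitive objective
    utility_kwargs={'beta': 0.1})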
@@ -1986,7 +2201,8 @@ class JaxOfflineController(BaseAgent):
      def reset(self) -> None:
          self.step = 0
          if self.train_on_reset and not self.params_given:
-             self.params = self.planner.optimize(key=self.key, **self.train_kwargs)
+             callback = self.planner.optimize(key=self.key, **self.train_kwargs)
+             self.params = callback['best_params']
 
 
  class JaxOnlineController(BaseAgent):
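As this last hunk shows, optimize() now returns the training callback rather than raw parameters, so calling code should unpack it. A minimal sketch with placeholder planner and key:

# Sketch: reading the best parameters out of the callback returned by optimize().
callback = planner.optimize(key=jax.random.PRNGKey(42), epochs=1000)
best_params = callback['best_params']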