pyRDDLGym-jax 1.3-py3-none-any.whl → 2.1-py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
pyRDDLGym_jax/__init__.py CHANGED
@@ -1 +1 @@
- __version__ = '1.3'
+ __version__ = '2.1'
@@ -1,3 +1,18 @@
+ # ***********************************************************************
+ # JAXPLAN
+ #
+ # Author: Michael Gimelfarb
+ #
+ # REFERENCES:
+ #
+ # [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
+ # Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
+ # Probabilistic Domains." Proceedings of the International Conference on Automated
+ # Planning and Scheduling. Vol. 34. 2024.
+ #
+ # ***********************************************************************
+
+
  from functools import partial
  import traceback
  from typing import Any, Callable, Dict, List, Optional
@@ -5,7 +20,6 @@ from typing import Any, Callable, Dict, List, Optional
  import jax
  import jax.numpy as jnp
  import jax.random as random
- import jax.scipy as scipy

  from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
  from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
@@ -28,8 +42,7 @@ try:
  from tensorflow_probability.substrates import jax as tfp
  except Exception:
  raise_warning('Failed to import tensorflow-probability: '
- 'compilation of some complex distributions '
- '(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
+ 'compilation of some probability distributions will fail.', 'red')
  traceback.print_exc()
  tfp = None
@@ -39,102 +52,6 @@ class JaxRDDLCompiler:
  All operations are identical to their numpy equivalents.
  '''

- MODEL_PARAM_TAG_SEPARATOR = '___'
-
- # ===========================================================================
- # EXACT RDDL TO JAX COMPILATION RULES BY DEFAULT
- # ===========================================================================
-
- @staticmethod
- def wrap_logic(func):
- def exact_func(id, init_params):
- return func
- return exact_func
-
- EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.negative))
- EXACT_RDDL_TO_JAX_ARITHMETIC = {
- '+': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.add)),
- '-': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.subtract)),
- '*': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.multiply)),
- '/': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.divide))
- }
-
- EXACT_RDDL_TO_JAX_RELATIONAL = {
- '>=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater_equal)),
- '<=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less_equal)),
- '<': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less)),
- '>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater)),
- '==': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal)),
- '~=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.not_equal))
- }
-
- EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.logical_not))
- EXACT_RDDL_TO_JAX_LOGICAL = {
- '^': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
- '&': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
- '|': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_or)),
- '~': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_xor)),
- '=>': wrap_logic.__func__(ExactLogic.exact_binary_implies),
- '<=>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal))
- }
-
- EXACT_RDDL_TO_JAX_AGGREGATION = {
- 'sum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.sum)),
- 'avg': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.mean)),
- 'prod': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.prod)),
- 'minimum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.min)),
- 'maximum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.max)),
- 'forall': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.all)),
- 'exists': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.any)),
- 'argmin': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmin)),
- 'argmax': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmax))
- }
-
- EXACT_RDDL_TO_JAX_UNARY = {
- 'abs': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.abs)),
- 'sgn': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sign)),
- 'round': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.round)),
- 'floor': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.floor)),
- 'ceil': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.ceil)),
- 'cos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cos)),
- 'sin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sin)),
- 'tan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tan)),
- 'acos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arccos)),
- 'asin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arcsin)),
- 'atan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arctan)),
- 'cosh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cosh)),
- 'sinh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sinh)),
- 'tanh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tanh)),
- 'exp': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.exp)),
- 'ln': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.log)),
- 'sqrt': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sqrt)),
- 'lngamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gammaln)),
- 'gamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gamma))
- }
-
- @staticmethod
- def _jax_wrapped_calc_log_exact(x, y, params):
- return jnp.log(x) / jnp.log(y), params
-
- EXACT_RDDL_TO_JAX_BINARY = {
- 'div': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.floor_divide)),
- 'mod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
- 'fmod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
- 'min': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.minimum)),
- 'max': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.maximum)),
- 'pow': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.power)),
- 'log': wrap_logic.__func__(_jax_wrapped_calc_log_exact.__func__),
- 'hypot': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.hypot)),
- }
-
- EXACT_RDDL_TO_JAX_IF = wrap_logic.__func__(ExactLogic.exact_if_then_else)
- EXACT_RDDL_TO_JAX_SWITCH = wrap_logic.__func__(ExactLogic.exact_switch)
-
- EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic.__func__(ExactLogic.exact_bernoulli)
- EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic.__func__(ExactLogic.exact_discrete)
- EXACT_RDDL_TO_JAX_POISSON = wrap_logic.__func__(ExactLogic.exact_poisson)
- EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic.__func__(ExactLogic.exact_geometric)
-
  def __init__(self, rddl: RDDLLiftedModel,
  allow_synchronous_state: bool=True,
  logger: Optional[Logger]=None,
@@ -174,8 +91,7 @@ class JaxRDDLCompiler:
  self.init_values = initializer.initialize()

  # compute dependency graph for CPFs and sort them by evaluation order
- sorter = RDDLLevelAnalysis(
- rddl, allow_synchronous_state=allow_synchronous_state)
+ sorter = RDDLLevelAnalysis(rddl, allow_synchronous_state=allow_synchronous_state)
  self.levels = sorter.compute_levels()

  # trace expressions to cache information to be used later
@@ -187,28 +103,17 @@ class JaxRDDLCompiler:
  rddl=self.rddl,
  init_values=self.init_values,
  levels=self.levels,
- trace_info=self.traced)
+ trace_info=self.traced
+ )
  constraints = RDDLConstraints(simulator, vectorized=True)
  self.constraints = constraints

  # basic operations - these can be override in subclasses
  self.compile_non_fluent_exact = compile_non_fluent_exact
- self.NEGATIVE = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
- self.ARITHMETIC_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC.copy()
- self.RELATIONAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL.copy()
- self.LOGICAL_NOT = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
- self.LOGICAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL.copy()
- self.AGGREGATION_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION.copy()
  self.AGGREGATION_BOOL = {'forall', 'exists'}
- self.KNOWN_UNARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY.copy()
- self.KNOWN_BINARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY.copy()
- self.IF_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
- self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
- self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
- self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
- self.POISSON_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
- self.GEOMETRIC_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
-
+ self.EXACT_OPS = ExactLogic(use64bit=self.use64bit).get_operator_dicts()
+ self.OPS = self.EXACT_OPS
+
  # ===========================================================================
  # main compilation subroutines
  # ===========================================================================
@@ -377,7 +282,8 @@ class JaxRDDLCompiler:

  # compile constraint information
  if constraint_func:
- inequality_fns, equality_fns = self._jax_nonlinear_constraints(init_params_constr)
+ inequality_fns, equality_fns = self._jax_nonlinear_constraints(
+ init_params_constr)
  else:
  inequality_fns, equality_fns = None, None
@@ -524,7 +430,7 @@ class JaxRDDLCompiler:
  _jax_wrapped_single_step_policy,
  in_axes=(0, None, None, None, 0, None)
  )(keys, policy_params, hyperparams, step, subs, model_params)
- model_params = jax.tree_map(lambda x: jnp.mean(x, axis=0), model_params)
+ model_params = jax.tree_map(partial(jnp.mean, axis=0), model_params)
  carry = (key, policy_params, hyperparams, subs, model_params)
  return carry, log
@@ -571,7 +477,11 @@ class JaxRDDLCompiler:
  for (id, value) in self.model_params.items():
  expr_id = int(str(id).split('_')[0])
  expr = self.traced.lookup(expr_id)
- result[id] = {'id': expr_id, 'rddl_op': ' '.join(expr.etype), 'init_value': value}
+ result[id] = {
+ 'id': expr_id,
+ 'rddl_op': ' '.join(expr.etype),
+ 'init_value': value
+ }
  return result

  @staticmethod
@@ -722,7 +632,7 @@ class JaxRDDLCompiler:
  return _jax_wrapped_cast

  def _fix_dtype(self, value):
- dtype = jnp.atleast_1d(value).dtype
+ dtype = jnp.result_type(value)
  if jnp.issubdtype(dtype, jnp.integer):
  return self.INT
  elif jnp.issubdtype(dtype, jnp.floating):
@@ -870,11 +780,11 @@ class JaxRDDLCompiler:

  # if expression is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
- valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC
- negative_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
+ valid_ops = self.EXACT_OPS['arithmetic']
+ negative_op = self.EXACT_OPS['negative']
  else:
- valid_ops = self.ARITHMETIC_OPS
- negative_op = self.NEGATIVE
+ valid_ops = self.OPS['arithmetic']
+ negative_op = self.OPS['negative']
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)

  # recursively compile arguments
@@ -901,9 +811,9 @@ class JaxRDDLCompiler:

  # if expression is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
- valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL
+ valid_ops = self.EXACT_OPS['relational']
  else:
- valid_ops = self.RELATIONAL_OPS
+ valid_ops = self.OPS['relational']
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)

  # recursively compile arguments
@@ -919,11 +829,11 @@ class JaxRDDLCompiler:

  # if expression is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
- valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL
- logical_not_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
+ valid_ops = self.EXACT_OPS['logical']
+ logical_not_op = self.EXACT_OPS['logical_not']
  else:
- valid_ops = self.LOGICAL_OPS
- logical_not_op = self.LOGICAL_NOT
+ valid_ops = self.OPS['logical']
+ logical_not_op = self.OPS['logical_not']
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)

  # recursively compile arguments
@@ -951,9 +861,9 @@ class JaxRDDLCompiler:

  # if expression is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
- valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION
+ valid_ops = self.EXACT_OPS['aggregation']
  else:
- valid_ops = self.AGGREGATION_OPS
+ valid_ops = self.OPS['aggregation']
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
  is_floating = op not in self.AGGREGATION_BOOL
@@ -980,11 +890,11 @@ class JaxRDDLCompiler:

  # if expression is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
- unary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY
- binary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY
+ unary_ops = self.EXACT_OPS['unary']
+ binary_ops = self.EXACT_OPS['binary']
  else:
- unary_ops = self.KNOWN_UNARY
- binary_ops = self.KNOWN_BINARY
+ unary_ops = self.OPS['unary']
+ binary_ops = self.OPS['binary']

  # recursively compile arguments
  if op in unary_ops:
@@ -1026,9 +936,9 @@ class JaxRDDLCompiler:

  # if predicate is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
- if_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
+ if_op = self.EXACT_OPS['control']['if']
  else:
- if_op = self.IF_HELPER
+ if_op = self.OPS['control']['if']
  jax_op = if_op(expr.id, init_params)

  # recursively compile arguments
@@ -1054,9 +964,9 @@ class JaxRDDLCompiler:
  # if predicate is non-fluent, always use the exact operation
  # case conditions are currently only literals so they are non-fluent
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
- switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
+ switch_op = self.EXACT_OPS['control']['switch']
  else:
- switch_op = self.SWITCH_HELPER
+ switch_op = self.OPS['control']['switch']
  jax_op = switch_op(expr.id, init_params)

  # recursively compile predicate
@@ -1078,8 +988,7 @@ class JaxRDDLCompiler:
  for (i, jax_case) in enumerate(jax_cases):
  sample_cases[i], key, err_case, params = jax_case(x, params, key)
  err |= err_case
- sample_cases = jnp.asarray(
- sample_cases, dtype=self._fix_dtype(sample_cases))
+ sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))

  # predicate (enum) is an integer - use it to extract from case array
  sample, params = jax_op(sample_pred, sample_cases, params)
@@ -1098,6 +1007,7 @@ class JaxRDDLCompiler:
  # Bernoulli: complete (subclass uses Gumbel-softmax)
  # Normal: complete
  # Exponential: complete
+ # Geometric: complete
  # Weibull: complete
  # Pareto: complete
  # Gumbel: complete
@@ -1110,14 +1020,18 @@ class JaxRDDLCompiler:
  # Discrete(p): complete (subclass uses Gumbel-softmax)
  # UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)

+ # distributions which seem to support backpropagation (need more testing):
+ # Beta
+ # Student
+ # Gamma
+ # ChiSquare
+ # Dirichlet
+ # Poisson (subclass uses Gumbel-softmax or Poisson process trick)
+
  # distributions with incomplete reparameterization support (TODO):
- # Binomial: (use truncation and Gumbel-softmax)
- # NegativeBinomial: (no reparameterization)
- # Poisson: (use truncation and Gumbel-softmax)
- # Gamma, ChiSquare: (no shape reparameterization)
- # Beta: (no reparameterization)
- # Geometric: (implement safe floor)
- # Student: (no reparameterization)
+ # Binomial
+ # NegativeBinomial
+ # Multinomial

  def _jax_random(self, expr, init_params):
  _, name = expr.etype
@@ -1173,8 +1087,7 @@ class JaxRDDLCompiler:
  return self._jax_discrete_pvar(expr, init_params, unnorm=True)
  else:
  raise RDDLNotImplementedError(
- f'Distribution {name} is not supported.\n' +
- print_stack_trace(expr))
+ f'Distribution {name} is not supported.\n' + print_stack_trace(expr))

  def _jax_kron(self, expr, init_params):
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KRON_DELTA']
@@ -1251,8 +1164,7 @@ class JaxRDDLCompiler:
  def _jax_wrapped_distribution_exp(x, params, key):
  scale, key, err, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
- Exp1 = random.exponential(
- key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
+ Exp1 = random.exponential(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
  sample = scale * Exp1
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
  err |= (out_of_bounds * ERR)
@@ -1273,8 +1185,8 @@ class JaxRDDLCompiler:
  shape, key, err1, params = jax_shape(x, params, key)
  scale, key, err2, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
- U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
- sample = scale * jnp.power(-jnp.log(U), 1.0 / shape)
+ sample = random.weibull_min(
+ key=subkey, scale=scale, concentration=shape, dtype=self.REAL)
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
  err = err1 | err2 | (out_of_bounds * ERR)
  return sample, key, err, params
@@ -1288,9 +1200,9 @@ class JaxRDDLCompiler:

  # if probability is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
- bern_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
+ bern_op = self.EXACT_OPS['sampling']['Bernoulli']
  else:
- bern_op = self.BERNOULLI_HELPER
+ bern_op = self.OPS['sampling']['Bernoulli']
  jax_op = bern_op(expr.id, init_params)

  # recursively compile arguments
@@ -1313,9 +1225,9 @@ class JaxRDDLCompiler:

  # if rate is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_rate):
- poisson_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
+ poisson_op = self.EXACT_OPS['sampling']['Poisson']
  else:
- poisson_op = self.POISSON_HELPER
+ poisson_op = self.OPS['sampling']['Poisson']
  jax_op = poisson_op(expr.id, init_params)

  # recursively compile arguments
@@ -1326,7 +1238,6 @@ class JaxRDDLCompiler:
  rate, key, err, params = jax_rate(x, params, key)
  key, subkey = random.split(key)
  sample, params = jax_op(subkey, rate, params)
- sample = sample.astype(self.INT)
  out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
  err |= (out_of_bounds * ERR)
  return sample, key, err, params
@@ -1358,20 +1269,26 @@ class JaxRDDLCompiler:
  def _jax_binomial(self, expr, init_params):
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BINOMIAL']
  JaxRDDLCompiler._check_num_args(expr, 2)
-
  arg_trials, arg_prob = expr.args
+
+ # if prob is non-fluent, always use the exact operation
+ if self.compile_non_fluent_exact \
+ and not self.traced.cached_is_fluent(arg_trials) \
+ and not self.traced.cached_is_fluent(arg_prob):
+ bin_op = self.EXACT_OPS['sampling']['Binomial']
+ else:
+ bin_op = self.OPS['sampling']['Binomial']
+ jax_op = bin_op(expr.id, init_params)
+
  jax_trials = self._jax(arg_trials, init_params)
  jax_prob = self._jax(arg_prob, init_params)
-
- # uses the JAX substrate of tensorflow-probability
+
+ # uses reduction for constant trials
  def _jax_wrapped_distribution_binomial(x, params, key):
  trials, key, err2, params = jax_trials(x, params, key)
  prob, key, err1, params = jax_prob(x, params, key)
- trials = jnp.asarray(trials, dtype=self.REAL)
- prob = jnp.asarray(prob, dtype=self.REAL)
  key, subkey = random.split(key)
- dist = tfp.distributions.Binomial(total_count=trials, probs=prob)
- sample = dist.sample(seed=subkey).astype(self.INT)
+ sample, params = jax_op(subkey, trials, prob, params)
  out_of_bounds = jnp.logical_not(jnp.all(
  (prob >= 0) & (prob <= 1) & (trials >= 0)))
  err = err1 | err2 | (out_of_bounds * ERR)
@@ -1395,10 +1312,9 @@ class JaxRDDLCompiler:
  prob = jnp.asarray(prob, dtype=self.REAL)
  key, subkey = random.split(key)
  dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
- sample = dist.sample(seed=subkey).astype(self.INT)
+ sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
  out_of_bounds = jnp.logical_not(jnp.all(
- (prob >= 0) & (prob <= 1) & (trials > 0))
- )
+ (prob >= 0) & (prob <= 1) & (trials > 0)))
  err = err1 | err2 | (out_of_bounds * ERR)
  return sample, key, err, params
@@ -1431,9 +1347,9 @@ class JaxRDDLCompiler:

  # if prob is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
- geom_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
+ geom_op = self.EXACT_OPS['sampling']['Geometric']
  else:
- geom_op = self.GEOMETRIC_HELPER
+ geom_op = self.OPS['sampling']['Geometric']
  jax_op = geom_op(expr.id, init_params)

  # recursively compile arguments
@@ -1443,7 +1359,6 @@ class JaxRDDLCompiler:
  prob, key, err, params = jax_prob(x, params, key)
  key, subkey = random.split(key)
  sample, params = jax_op(subkey, prob, params)
- sample = sample.astype(self.INT)
  out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
  err |= (out_of_bounds * ERR)
  return sample, key, err, params
@@ -1482,8 +1397,7 @@ class JaxRDDLCompiler:
  def _jax_wrapped_distribution_t(x, params, key):
  df, key, err, params = jax_df(x, params, key)
  key, subkey = random.split(key)
- sample = random.t(
- key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
+ sample = random.t(key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
  out_of_bounds = jnp.logical_not(jnp.all(df > 0))
  err |= (out_of_bounds * ERR)
  return sample, key, err, params
@@ -1503,8 +1417,7 @@ class JaxRDDLCompiler:
  mean, key, err1, params = jax_mean(x, params, key)
  scale, key, err2, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
- Gumbel01 = random.gumbel(
- key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
+ Gumbel01 = random.gumbel(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
  sample = mean + scale * Gumbel01
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
  err = err1 | err2 | (out_of_bounds * ERR)
@@ -1525,8 +1438,7 @@ class JaxRDDLCompiler:
  mean, key, err1, params = jax_mean(x, params, key)
  scale, key, err2, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
- Laplace01 = random.laplace(
- key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
+ Laplace01 = random.laplace(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
  sample = mean + scale * Laplace01
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
  err = err1 | err2 | (out_of_bounds * ERR)
@@ -1547,8 +1459,7 @@ class JaxRDDLCompiler:
  mean, key, err1, params = jax_mean(x, params, key)
  scale, key, err2, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
- Cauchy01 = random.cauchy(
- key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
+ Cauchy01 = random.cauchy(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
  sample = mean + scale * Cauchy01
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
  err = err1 | err2 | (out_of_bounds * ERR)
@@ -1570,7 +1481,7 @@ class JaxRDDLCompiler:
  scale, key, err2, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
  U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
- sample = jnp.log(1.0 - jnp.log(U) / shape) / scale
+ sample = jnp.log(1.0 - jnp.log1p(-U) / shape) / scale
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
  err = err1 | err2 | (out_of_bounds * ERR)
  return sample, key, err, params
@@ -1631,9 +1542,9 @@ class JaxRDDLCompiler:
  has_fluent_arg = any(self.traced.cached_is_fluent(arg)
  for arg in ordered_args)
  if self.compile_non_fluent_exact and not has_fluent_arg:
- discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
+ discrete_op = self.EXACT_OPS['sampling']['Discrete']
  else:
- discrete_op = self.DISCRETE_HELPER
+ discrete_op = self.OPS['sampling']['Discrete']
  jax_op = discrete_op(expr.id, init_params)

  # compile probability expressions
@@ -1672,9 +1583,9 @@ class JaxRDDLCompiler:

  # if probabilities are non-fluent, then always sample exact
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg):
- discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
+ discrete_op = self.EXACT_OPS['sampling']['Discrete']
  else:
- discrete_op = self.DISCRETE_HELPER
+ discrete_op = self.OPS['sampling']['Discrete']
  jax_op = discrete_op(expr.id, init_params)

  # compile probability function
@@ -1716,8 +1627,7 @@ class JaxRDDLCompiler:
  return self._jax_multinomial(expr, init_params)
  else:
  raise RDDLNotImplementedError(
- f'Distribution {name} is not supported.\n' +
- print_stack_trace(expr))
+ f'Distribution {name} is not supported.\n' + print_stack_trace(expr))

  def _jax_multivariate_normal(self, expr, init_params):
  _, args = expr.args
@@ -1771,7 +1681,7 @@ class JaxRDDLCompiler:

  # sample StudentT(0, 1, df) -- broadcast df to same shape as cov
  sample_df = sample_df[..., jnp.newaxis, jnp.newaxis]
- sample_df = jnp.broadcast_to(sample_df, shape=sample_mean.shape + (1,))
+ sample_df = jnp.broadcast_to(sample_df, shape=jnp.shape(sample_mean) + (1,))
  key, subkey = random.split(key)
  Z = random.t(
  key=subkey,
@@ -1826,7 +1736,7 @@ class JaxRDDLCompiler:
  prob = jnp.asarray(prob, dtype=self.REAL)
  key, subkey = random.split(key)
  dist = tfp.distributions.Multinomial(total_count=trials, probs=prob)
- sample = dist.sample(seed=subkey).astype(self.INT)
+ sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
  sample = jnp.moveaxis(sample, source=-1, destination=index)
  out_of_bounds = jnp.logical_not(jnp.all(
  (prob >= 0)