pyRDDLGym-jax 2.0-py3-none-any.whl → 2.2-py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
pyRDDLGym_jax/__init__.py CHANGED
@@ -1 +1 @@
- __version__ = '2.0'
+ __version__ = '2.2'
@@ -20,7 +20,6 @@ from typing import Any, Callable, Dict, List, Optional
  import jax
  import jax.numpy as jnp
  import jax.random as random
- import jax.scipy as scipy

  from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
  from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
@@ -43,8 +42,7 @@ try:
  from tensorflow_probability.substrates import jax as tfp
  except Exception:
  raise_warning('Failed to import tensorflow-probability: '
- 'compilation of some complex distributions '
- '(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
+ 'compilation of some probability distributions will fail.', 'red')
  traceback.print_exc()
  tfp = None

@@ -54,102 +52,6 @@ class JaxRDDLCompiler:
  All operations are identical to their numpy equivalents.
  '''

- MODEL_PARAM_TAG_SEPARATOR = '___'
-
- # ===========================================================================
- # EXACT RDDL TO JAX COMPILATION RULES BY DEFAULT
- # ===========================================================================
-
- @staticmethod
- def wrap_logic(func):
- def exact_func(id, init_params):
- return func
- return exact_func
-
- EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.negative))
- EXACT_RDDL_TO_JAX_ARITHMETIC = {
- '+': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.add)),
- '-': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.subtract)),
- '*': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.multiply)),
- '/': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.divide))
- }
-
- EXACT_RDDL_TO_JAX_RELATIONAL = {
- '>=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater_equal)),
- '<=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less_equal)),
- '<': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less)),
- '>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater)),
- '==': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal)),
- '~=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.not_equal))
- }
-
- EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.logical_not))
- EXACT_RDDL_TO_JAX_LOGICAL = {
- '^': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
- '&': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
- '|': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_or)),
- '~': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_xor)),
- '=>': wrap_logic.__func__(ExactLogic.exact_binary_implies),
- '<=>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal))
- }
-
- EXACT_RDDL_TO_JAX_AGGREGATION = {
- 'sum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.sum)),
- 'avg': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.mean)),
- 'prod': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.prod)),
- 'minimum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.min)),
- 'maximum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.max)),
- 'forall': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.all)),
- 'exists': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.any)),
- 'argmin': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmin)),
- 'argmax': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmax))
- }
-
- EXACT_RDDL_TO_JAX_UNARY = {
- 'abs': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.abs)),
- 'sgn': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sign)),
- 'round': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.round)),
- 'floor': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.floor)),
- 'ceil': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.ceil)),
- 'cos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cos)),
- 'sin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sin)),
- 'tan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tan)),
- 'acos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arccos)),
- 'asin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arcsin)),
- 'atan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arctan)),
- 'cosh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cosh)),
- 'sinh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sinh)),
- 'tanh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tanh)),
- 'exp': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.exp)),
- 'ln': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.log)),
- 'sqrt': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sqrt)),
- 'lngamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gammaln)),
- 'gamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gamma))
- }
-
- @staticmethod
- def _jax_wrapped_calc_log_exact(x, y, params):
- return jnp.log(x) / jnp.log(y), params
-
- EXACT_RDDL_TO_JAX_BINARY = {
- 'div': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.floor_divide)),
- 'mod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
- 'fmod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
- 'min': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.minimum)),
- 'max': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.maximum)),
- 'pow': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.power)),
- 'log': wrap_logic.__func__(_jax_wrapped_calc_log_exact.__func__),
- 'hypot': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.hypot)),
- }
-
- EXACT_RDDL_TO_JAX_IF = wrap_logic.__func__(ExactLogic.exact_if_then_else)
- EXACT_RDDL_TO_JAX_SWITCH = wrap_logic.__func__(ExactLogic.exact_switch)
-
- EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic.__func__(ExactLogic.exact_bernoulli)
- EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic.__func__(ExactLogic.exact_discrete)
- EXACT_RDDL_TO_JAX_POISSON = wrap_logic.__func__(ExactLogic.exact_poisson)
- EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic.__func__(ExactLogic.exact_geometric)
-
  def __init__(self, rddl: RDDLLiftedModel,
  allow_synchronous_state: bool=True,
  logger: Optional[Logger]=None,
@@ -189,8 +91,7 @@ class JaxRDDLCompiler:
  self.init_values = initializer.initialize()

  # compute dependency graph for CPFs and sort them by evaluation order
- sorter = RDDLLevelAnalysis(
- rddl, allow_synchronous_state=allow_synchronous_state)
+ sorter = RDDLLevelAnalysis(rddl, allow_synchronous_state=allow_synchronous_state)
  self.levels = sorter.compute_levels()

  # trace expressions to cache information to be used later
@@ -202,28 +103,17 @@ class JaxRDDLCompiler:
  rddl=self.rddl,
  init_values=self.init_values,
  levels=self.levels,
- trace_info=self.traced)
+ trace_info=self.traced
+ )
  constraints = RDDLConstraints(simulator, vectorized=True)
  self.constraints = constraints

  # basic operations - these can be override in subclasses
  self.compile_non_fluent_exact = compile_non_fluent_exact
- self.NEGATIVE = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
- self.ARITHMETIC_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC.copy()
- self.RELATIONAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL.copy()
- self.LOGICAL_NOT = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
- self.LOGICAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL.copy()
- self.AGGREGATION_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION.copy()
  self.AGGREGATION_BOOL = {'forall', 'exists'}
- self.KNOWN_UNARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY.copy()
- self.KNOWN_BINARY = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY.copy()
- self.IF_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
- self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
- self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
- self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
- self.POISSON_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
- self.GEOMETRIC_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
-
+ self.EXACT_OPS = ExactLogic(use64bit=self.use64bit).get_operator_dicts()
+ self.OPS = self.EXACT_OPS
+
  # ===========================================================================
  # main compilation subroutines
  # ===========================================================================
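Version 2.2 replaces the class-level EXACT_RDDL_TO_JAX_* tables with two instance-level dictionaries: self.EXACT_OPS, built by ExactLogic(use64bit=...).get_operator_dicts(), and self.OPS, which the compilation routines below actually consult and which a subclass can rebind. The sketch below is illustrative only: it assumes, based on the removed wrap_logic pattern, that every entry is a factory (expr_id, init_params) -> jax_callable; the subclass name and the sigmoid surrogate are hypothetical and not part of the package.

import jax

class SmoothRelationalCompiler(JaxRDDLCompiler):  # hypothetical subclass
    '''Swaps the exact '>=' comparison for a sigmoid surrogate so that
    relational expressions admit nonzero gradients.'''

    def __init__(self, *args, sharpness: float = 10.0, **kwargs):
        super().__init__(*args, **kwargs)

        def smooth_geq(expr_id, init_params):
            # mirrors the assumed operator signature: (x, y, params) -> (value, params)
            def _jax_wrapped_geq(x, y, params):
                return jax.nn.sigmoid(sharpness * (x - y)), params
            return _jax_wrapped_geq

        # keep the exact tables intact and override a shallow copy
        self.OPS = dict(self.EXACT_OPS)
        self.OPS['relational'] = dict(self.EXACT_OPS['relational'])
        self.OPS['relational']['>='] = smooth_geq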
@@ -392,7 +282,8 @@ class JaxRDDLCompiler:

  # compile constraint information
  if constraint_func:
- inequality_fns, equality_fns = self._jax_nonlinear_constraints(init_params_constr)
+ inequality_fns, equality_fns = self._jax_nonlinear_constraints(
+ init_params_constr)
  else:
  inequality_fns, equality_fns = None, None

@@ -586,7 +477,11 @@ class JaxRDDLCompiler:
  for (id, value) in self.model_params.items():
  expr_id = int(str(id).split('_')[0])
  expr = self.traced.lookup(expr_id)
- result[id] = {'id': expr_id, 'rddl_op': ' '.join(expr.etype), 'init_value': value}
+ result[id] = {
+ 'id': expr_id,
+ 'rddl_op': ' '.join(expr.etype),
+ 'init_value': value
+ }
  return result

  @staticmethod
@@ -737,7 +632,7 @@ class JaxRDDLCompiler:
  return _jax_wrapped_cast

  def _fix_dtype(self, value):
- dtype = jnp.atleast_1d(value).dtype
+ dtype = jnp.result_type(value)
  if jnp.issubdtype(dtype, jnp.integer):
  return self.INT
  elif jnp.issubdtype(dtype, jnp.floating):
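The _fix_dtype change swaps jnp.atleast_1d(value).dtype for jnp.result_type(value). Both report the same promoted dtype for typical scalars and arrays under JAX's default configuration, but result_type applies the promotion rules directly instead of materializing a throwaway array first. A small stand-alone check (not from the package):

import jax.numpy as jnp

for value in (True, 3, 2.5, jnp.arange(4)):
    # result_type inspects the dtype without allocating; atleast_1d builds an
    # array and then reads .dtype from it
    print(value, jnp.result_type(value), jnp.atleast_1d(value).dtype)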
@@ -885,11 +780,11 @@ class JaxRDDLCompiler:

  # if expression is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
- valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC
- negative_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
+ valid_ops = self.EXACT_OPS['arithmetic']
+ negative_op = self.EXACT_OPS['negative']
  else:
- valid_ops = self.ARITHMETIC_OPS
- negative_op = self.NEGATIVE
+ valid_ops = self.OPS['arithmetic']
+ negative_op = self.OPS['negative']
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)

  # recursively compile arguments
@@ -916,9 +811,9 @@ class JaxRDDLCompiler:

  # if expression is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
- valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL
+ valid_ops = self.EXACT_OPS['relational']
  else:
- valid_ops = self.RELATIONAL_OPS
+ valid_ops = self.OPS['relational']
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)

  # recursively compile arguments
@@ -934,11 +829,11 @@ class JaxRDDLCompiler:

  # if expression is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
- valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL
- logical_not_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
+ valid_ops = self.EXACT_OPS['logical']
+ logical_not_op = self.EXACT_OPS['logical_not']
  else:
- valid_ops = self.LOGICAL_OPS
- logical_not_op = self.LOGICAL_NOT
+ valid_ops = self.OPS['logical']
+ logical_not_op = self.OPS['logical_not']
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)

  # recursively compile arguments
@@ -966,9 +861,9 @@ class JaxRDDLCompiler:

  # if expression is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
- valid_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION
+ valid_ops = self.EXACT_OPS['aggregation']
  else:
- valid_ops = self.AGGREGATION_OPS
+ valid_ops = self.OPS['aggregation']
  JaxRDDLCompiler._check_valid_op(expr, valid_ops)
  is_floating = op not in self.AGGREGATION_BOOL

@@ -995,11 +890,11 @@ class JaxRDDLCompiler:

  # if expression is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
- unary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_UNARY
- binary_ops = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BINARY
+ unary_ops = self.EXACT_OPS['unary']
+ binary_ops = self.EXACT_OPS['binary']
  else:
- unary_ops = self.KNOWN_UNARY
- binary_ops = self.KNOWN_BINARY
+ unary_ops = self.OPS['unary']
+ binary_ops = self.OPS['binary']

  # recursively compile arguments
  if op in unary_ops:
@@ -1041,9 +936,9 @@ class JaxRDDLCompiler:

  # if predicate is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
- if_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_IF
+ if_op = self.EXACT_OPS['control']['if']
  else:
- if_op = self.IF_HELPER
+ if_op = self.OPS['control']['if']
  jax_op = if_op(expr.id, init_params)

  # recursively compile arguments
@@ -1069,9 +964,9 @@ class JaxRDDLCompiler:
  # if predicate is non-fluent, always use the exact operation
  # case conditions are currently only literals so they are non-fluent
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
- switch_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
+ switch_op = self.EXACT_OPS['control']['switch']
  else:
- switch_op = self.SWITCH_HELPER
+ switch_op = self.OPS['control']['switch']
  jax_op = switch_op(expr.id, init_params)

  # recursively compile predicate
@@ -1093,8 +988,7 @@ class JaxRDDLCompiler:
  for (i, jax_case) in enumerate(jax_cases):
  sample_cases[i], key, err_case, params = jax_case(x, params, key)
  err |= err_case
- sample_cases = jnp.asarray(
- sample_cases, dtype=self._fix_dtype(sample_cases))
+ sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))

  # predicate (enum) is an integer - use it to extract from case array
  sample, params = jax_op(sample_pred, sample_cases, params)
@@ -1113,6 +1007,7 @@ class JaxRDDLCompiler:
  # Bernoulli: complete (subclass uses Gumbel-softmax)
  # Normal: complete
  # Exponential: complete
+ # Geometric: complete
  # Weibull: complete
  # Pareto: complete
  # Gumbel: complete
@@ -1125,14 +1020,18 @@ class JaxRDDLCompiler:
  # Discrete(p): complete (subclass uses Gumbel-softmax)
  # UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)

+ # distributions which seem to support backpropagation (need more testing):
+ # Beta
+ # Student
+ # Gamma
+ # ChiSquare
+ # Dirichlet
+ # Poisson (subclass uses Gumbel-softmax or Poisson process trick)
+
  # distributions with incomplete reparameterization support (TODO):
- # Binomial: (use truncation and Gumbel-softmax)
- # NegativeBinomial: (no reparameterization)
- # Poisson: (use truncation and Gumbel-softmax)
- # Gamma, ChiSquare: (no shape reparameterization)
- # Beta: (no reparameterization)
- # Geometric: (implement safe floor)
- # Student: (no reparameterization)
+ # Binomial
+ # NegativeBinomial
+ # Multinomial

  def _jax_random(self, expr, init_params):
  _, name = expr.etype
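The comments above repeatedly mention that a planner subclass relaxes discrete draws with Gumbel-softmax. For reference, here is a generic Gumbel-softmax (Concrete) relaxation of a Bernoulli sample; it only illustrates the named technique and is not the package's implementation, whose temperature handling and parameterization may differ.

import jax
import jax.numpy as jnp

def relaxed_bernoulli(key, prob, temperature=0.1):
    # Gumbel-max samples exactly: 1[log p + g1 > log(1 - p) + g0]; replacing
    # the hard threshold with a tempered sigmoid keeps the draw differentiable
    g1, g0 = jax.random.gumbel(key, (2,) + jnp.shape(prob))
    logits = (jnp.log(prob) + g1) - (jnp.log1p(-prob) + g0)
    return jax.nn.sigmoid(logits / temperature)

key = jax.random.PRNGKey(42)
soft_sample = relaxed_bernoulli(key, jnp.array(0.3))
grad_wrt_prob = jax.grad(lambda p: relaxed_bernoulli(key, p))(0.3)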
@@ -1188,8 +1087,7 @@ class JaxRDDLCompiler:
  return self._jax_discrete_pvar(expr, init_params, unnorm=True)
  else:
  raise RDDLNotImplementedError(
- f'Distribution {name} is not supported.\n' +
- print_stack_trace(expr))
+ f'Distribution {name} is not supported.\n' + print_stack_trace(expr))

  def _jax_kron(self, expr, init_params):
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KRON_DELTA']
@@ -1266,8 +1164,7 @@ class JaxRDDLCompiler:
  def _jax_wrapped_distribution_exp(x, params, key):
  scale, key, err, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
- Exp1 = random.exponential(
- key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
+ Exp1 = random.exponential(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
  sample = scale * Exp1
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
  err |= (out_of_bounds * ERR)
@@ -1288,8 +1185,8 @@ class JaxRDDLCompiler:
  shape, key, err1, params = jax_shape(x, params, key)
  scale, key, err2, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
- U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
- sample = scale * jnp.power(-jnp.log(U), 1.0 / shape)
+ sample = random.weibull_min(
+ key=subkey, scale=scale, concentration=shape, dtype=self.REAL)
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
  err = err1 | err2 | (out_of_bounds * ERR)
  return sample, key, err, params
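The Weibull sampler now uses jax.random.weibull_min rather than the hand-rolled inverse-CDF transform scale * (-log U)^(1/shape). Both target the same Weibull(shape, scale) law; a quick stand-alone comparison (not part of the package):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
concentration, scale = 2.0, 3.0   # shape parameter k and scale lambda

u = jax.random.uniform(key, (100_000,))
manual = scale * jnp.power(-jnp.log(u), 1.0 / concentration)              # removed style
builtin = jax.random.weibull_min(key, scale, concentration, (100_000,))   # new style

# both sample means approach scale * Gamma(1 + 1/k)
print(manual.mean(), builtin.mean())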
@@ -1303,9 +1200,9 @@ class JaxRDDLCompiler:

  # if probability is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
- bern_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
+ bern_op = self.EXACT_OPS['sampling']['Bernoulli']
  else:
- bern_op = self.BERNOULLI_HELPER
+ bern_op = self.OPS['sampling']['Bernoulli']
  jax_op = bern_op(expr.id, init_params)

  # recursively compile arguments
@@ -1328,9 +1225,9 @@ class JaxRDDLCompiler:

  # if rate is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_rate):
- poisson_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
+ poisson_op = self.EXACT_OPS['sampling']['Poisson']
  else:
- poisson_op = self.POISSON_HELPER
+ poisson_op = self.OPS['sampling']['Poisson']
  jax_op = poisson_op(expr.id, init_params)

  # recursively compile arguments
@@ -1341,7 +1238,6 @@ class JaxRDDLCompiler:
  rate, key, err, params = jax_rate(x, params, key)
  key, subkey = random.split(key)
  sample, params = jax_op(subkey, rate, params)
- sample = sample.astype(self.INT)
  out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
  err |= (out_of_bounds * ERR)
  return sample, key, err, params
@@ -1373,20 +1269,26 @@ class JaxRDDLCompiler:
  def _jax_binomial(self, expr, init_params):
  ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BINOMIAL']
  JaxRDDLCompiler._check_num_args(expr, 2)
-
  arg_trials, arg_prob = expr.args
+
+ # if prob is non-fluent, always use the exact operation
+ if self.compile_non_fluent_exact \
+ and not self.traced.cached_is_fluent(arg_trials) \
+ and not self.traced.cached_is_fluent(arg_prob):
+ bin_op = self.EXACT_OPS['sampling']['Binomial']
+ else:
+ bin_op = self.OPS['sampling']['Binomial']
+ jax_op = bin_op(expr.id, init_params)
+
  jax_trials = self._jax(arg_trials, init_params)
  jax_prob = self._jax(arg_prob, init_params)
-
- # uses the JAX substrate of tensorflow-probability
+
+ # uses reduction for constant trials
  def _jax_wrapped_distribution_binomial(x, params, key):
  trials, key, err2, params = jax_trials(x, params, key)
  prob, key, err1, params = jax_prob(x, params, key)
- trials = jnp.asarray(trials, dtype=self.REAL)
- prob = jnp.asarray(prob, dtype=self.REAL)
  key, subkey = random.split(key)
- dist = tfp.distributions.Binomial(total_count=trials, probs=prob)
- sample = dist.sample(seed=subkey).astype(self.INT)
+ sample, params = jax_op(subkey, trials, prob, params)
  out_of_bounds = jnp.logical_not(jnp.all(
  (prob >= 0) & (prob <= 1) & (trials >= 0)))
  err = err1 | err2 | (out_of_bounds * ERR)
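The Binomial sampler no longer routes through tensorflow-probability; it now delegates to the compiled 'Binomial' operator from the sampling table, and the new comment mentions a "reduction for constant trials". One plausible reading of that reduction, shown purely as an illustration (the actual ExactLogic operator may be implemented differently), is to draw Binomial(n, p) as a sum of n Bernoulli(p) indicators when the trial count n is a compile-time constant:

import jax
import jax.numpy as jnp

def binomial_via_reduction(key, n_trials: int, prob):
    # n_trials must be a Python int here so the sample shape is static under jit
    u = jax.random.uniform(key, shape=(n_trials,) + jnp.shape(prob))
    return jnp.sum(u < prob, axis=0)

key = jax.random.PRNGKey(0)
print(binomial_via_reduction(key, 10, jnp.array([0.2, 0.8])))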
@@ -1410,10 +1312,9 @@ class JaxRDDLCompiler:
  prob = jnp.asarray(prob, dtype=self.REAL)
  key, subkey = random.split(key)
  dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
- sample = dist.sample(seed=subkey).astype(self.INT)
+ sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
  out_of_bounds = jnp.logical_not(jnp.all(
- (prob >= 0) & (prob <= 1) & (trials > 0))
- )
+ (prob >= 0) & (prob <= 1) & (trials > 0)))
  err = err1 | err2 | (out_of_bounds * ERR)
  return sample, key, err, params

@@ -1446,9 +1347,9 @@ class JaxRDDLCompiler:

  # if prob is non-fluent, always use the exact operation
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
- geom_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
+ geom_op = self.EXACT_OPS['sampling']['Geometric']
  else:
- geom_op = self.GEOMETRIC_HELPER
+ geom_op = self.OPS['sampling']['Geometric']
  jax_op = geom_op(expr.id, init_params)

  # recursively compile arguments
@@ -1458,7 +1359,6 @@ class JaxRDDLCompiler:
  prob, key, err, params = jax_prob(x, params, key)
  key, subkey = random.split(key)
  sample, params = jax_op(subkey, prob, params)
- sample = sample.astype(self.INT)
  out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
  err |= (out_of_bounds * ERR)
  return sample, key, err, params
@@ -1497,8 +1397,7 @@ class JaxRDDLCompiler:
  def _jax_wrapped_distribution_t(x, params, key):
  df, key, err, params = jax_df(x, params, key)
  key, subkey = random.split(key)
- sample = random.t(
- key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
+ sample = random.t(key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
  out_of_bounds = jnp.logical_not(jnp.all(df > 0))
  err |= (out_of_bounds * ERR)
  return sample, key, err, params
@@ -1518,8 +1417,7 @@ class JaxRDDLCompiler:
  mean, key, err1, params = jax_mean(x, params, key)
  scale, key, err2, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
- Gumbel01 = random.gumbel(
- key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
+ Gumbel01 = random.gumbel(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
  sample = mean + scale * Gumbel01
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
  err = err1 | err2 | (out_of_bounds * ERR)
@@ -1540,8 +1438,7 @@ class JaxRDDLCompiler:
  mean, key, err1, params = jax_mean(x, params, key)
  scale, key, err2, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
- Laplace01 = random.laplace(
- key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
+ Laplace01 = random.laplace(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
  sample = mean + scale * Laplace01
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
  err = err1 | err2 | (out_of_bounds * ERR)
@@ -1562,8 +1459,7 @@ class JaxRDDLCompiler:
  mean, key, err1, params = jax_mean(x, params, key)
  scale, key, err2, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
- Cauchy01 = random.cauchy(
- key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
+ Cauchy01 = random.cauchy(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
  sample = mean + scale * Cauchy01
  out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
  err = err1 | err2 | (out_of_bounds * ERR)
@@ -1585,7 +1481,7 @@ class JaxRDDLCompiler:
  scale, key, err2, params = jax_scale(x, params, key)
  key, subkey = random.split(key)
  U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
- sample = jnp.log(1.0 - jnp.log(U) / shape) / scale
+ sample = jnp.log(1.0 - jnp.log1p(-U) / shape) / scale
  out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
  err = err1 | err2 | (out_of_bounds * ERR)
  return sample, key, err, params
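A likely motivation for replacing log(U) with log1p(-U) in this inverse-CDF draw: jax.random.uniform samples U in [0, 1), so log(U) can return -inf when U lands exactly on 0, whereas log1p(-U) = log(1 - U) stays finite on that range and -log(1 - U) follows the same Exponential(1) law as -log(U). A tiny stand-alone check (not from the package):

import jax.numpy as jnp

u = jnp.array([0.0, 0.5, 0.999999])
print(jnp.log(u))       # [-inf, -0.693..., ~0.0]
print(jnp.log1p(-u))    # [0.0, -0.693..., ~-13.8]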
@@ -1646,9 +1542,9 @@ class JaxRDDLCompiler:
  has_fluent_arg = any(self.traced.cached_is_fluent(arg)
  for arg in ordered_args)
  if self.compile_non_fluent_exact and not has_fluent_arg:
- discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
+ discrete_op = self.EXACT_OPS['sampling']['Discrete']
  else:
- discrete_op = self.DISCRETE_HELPER
+ discrete_op = self.OPS['sampling']['Discrete']
  jax_op = discrete_op(expr.id, init_params)

  # compile probability expressions
@@ -1687,9 +1583,9 @@ class JaxRDDLCompiler:

  # if probabilities are non-fluent, then always sample exact
  if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg):
- discrete_op = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
+ discrete_op = self.EXACT_OPS['sampling']['Discrete']
  else:
- discrete_op = self.DISCRETE_HELPER
+ discrete_op = self.OPS['sampling']['Discrete']
  jax_op = discrete_op(expr.id, init_params)

  # compile probability function
@@ -1731,8 +1627,7 @@ class JaxRDDLCompiler:
  return self._jax_multinomial(expr, init_params)
  else:
  raise RDDLNotImplementedError(
- f'Distribution {name} is not supported.\n' +
- print_stack_trace(expr))
+ f'Distribution {name} is not supported.\n' + print_stack_trace(expr))

  def _jax_multivariate_normal(self, expr, init_params):
  _, args = expr.args
@@ -1786,7 +1681,7 @@ class JaxRDDLCompiler:

  # sample StudentT(0, 1, df) -- broadcast df to same shape as cov
  sample_df = sample_df[..., jnp.newaxis, jnp.newaxis]
- sample_df = jnp.broadcast_to(sample_df, shape=sample_mean.shape + (1,))
+ sample_df = jnp.broadcast_to(sample_df, shape=jnp.shape(sample_mean) + (1,))
  key, subkey = random.split(key)
  Z = random.t(
  key=subkey,
@@ -1841,7 +1736,7 @@ class JaxRDDLCompiler:
  prob = jnp.asarray(prob, dtype=self.REAL)
  key, subkey = random.split(key)
  dist = tfp.distributions.Multinomial(total_count=trials, probs=prob)
- sample = dist.sample(seed=subkey).astype(self.INT)
+ sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
  sample = jnp.moveaxis(sample, source=-1, destination=index)
  out_of_bounds = jnp.logical_not(jnp.all(
  (prob >= 0)