pyRDDLGym-jax 2.0__py3-none-any.whl → 2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +85 -190
- pyRDDLGym_jax/core/logic.py +313 -56
- pyRDDLGym_jax/core/planner.py +274 -200
- pyRDDLGym_jax/core/visualization.py +7 -8
- pyRDDLGym_jax/examples/run_tune.py +10 -6
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/METADATA +43 -30
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/RECORD +12 -12
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/WHEEL +1 -1
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/entry_points.txt +0 -0
- {pyRDDLGym_jax-2.0.dist-info → pyrddlgym_jax-2.2.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.0'
|
|
1
|
+
__version__ = '2.2'
|
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -20,7 +20,6 @@ from typing import Any, Callable, Dict, List, Optional
|
|
|
20
20
|
import jax
|
|
21
21
|
import jax.numpy as jnp
|
|
22
22
|
import jax.random as random
|
|
23
|
-
import jax.scipy as scipy
|
|
24
23
|
|
|
25
24
|
from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
|
|
26
25
|
from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
|
|
@@ -43,8 +42,7 @@ try:
|
|
|
43
42
|
from tensorflow_probability.substrates import jax as tfp
|
|
44
43
|
except Exception:
|
|
45
44
|
raise_warning('Failed to import tensorflow-probability: '
|
|
46
|
-
'compilation of some
|
|
47
|
-
'(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
|
|
45
|
+
'compilation of some probability distributions will fail.', 'red')
|
|
48
46
|
traceback.print_exc()
|
|
49
47
|
tfp = None
|
|
50
48
|
|
|
@@ -54,102 +52,6 @@ class JaxRDDLCompiler:
|
|
|
54
52
|
All operations are identical to their numpy equivalents.
|
|
55
53
|
'''
|
|
56
54
|
|
|
57
|
-
MODEL_PARAM_TAG_SEPARATOR = '___'
|
|
58
|
-
|
|
59
|
-
# ===========================================================================
|
|
60
|
-
# EXACT RDDL TO JAX COMPILATION RULES BY DEFAULT
|
|
61
|
-
# ===========================================================================
|
|
62
|
-
|
|
63
|
-
@staticmethod
|
|
64
|
-
def wrap_logic(func):
|
|
65
|
-
def exact_func(id, init_params):
|
|
66
|
-
return func
|
|
67
|
-
return exact_func
|
|
68
|
-
|
|
69
|
-
EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.negative))
|
|
70
|
-
EXACT_RDDL_TO_JAX_ARITHMETIC = {
|
|
71
|
-
'+': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.add)),
|
|
72
|
-
'-': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.subtract)),
|
|
73
|
-
'*': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.multiply)),
|
|
74
|
-
'/': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.divide))
|
|
75
|
-
}
|
|
76
|
-
|
|
77
|
-
EXACT_RDDL_TO_JAX_RELATIONAL = {
|
|
78
|
-
'>=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater_equal)),
|
|
79
|
-
'<=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less_equal)),
|
|
80
|
-
'<': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less)),
|
|
81
|
-
'>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater)),
|
|
82
|
-
'==': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal)),
|
|
83
|
-
'~=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.not_equal))
|
|
84
|
-
}
|
|
85
|
-
|
|
86
|
-
EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.logical_not))
|
|
87
|
-
EXACT_RDDL_TO_JAX_LOGICAL = {
|
|
88
|
-
'^': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
|
|
89
|
-
'&': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
|
|
90
|
-
'|': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_or)),
|
|
91
|
-
'~': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_xor)),
|
|
92
|
-
'=>': wrap_logic.__func__(ExactLogic.exact_binary_implies),
|
|
93
|
-
'<=>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal))
|
|
94
|
-
}
|
|
95
|
-
|
|
96
|
-
EXACT_RDDL_TO_JAX_AGGREGATION = {
|
|
97
|
-
'sum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.sum)),
|
|
98
|
-
'avg': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.mean)),
|
|
99
|
-
'prod': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.prod)),
|
|
100
|
-
'minimum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.min)),
|
|
101
|
-
'maximum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.max)),
|
|
102
|
-
'forall': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.all)),
|
|
103
|
-
'exists': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.any)),
|
|
104
|
-
'argmin': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmin)),
|
|
105
|
-
'argmax': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmax))
|
|
106
|
-
}
|
|
107
|
-
|
|
108
|
-
EXACT_RDDL_TO_JAX_UNARY = {
|
|
109
|
-
'abs': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.abs)),
|
|
110
|
-
'sgn': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sign)),
|
|
111
|
-
'round': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.round)),
|
|
112
|
-
'floor': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.floor)),
|
|
113
|
-
'ceil': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.ceil)),
|
|
114
|
-
'cos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cos)),
|
|
115
|
-
'sin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sin)),
|
|
116
|
-
'tan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tan)),
|
|
117
|
-
'acos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arccos)),
|
|
118
|
-
'asin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arcsin)),
|
|
119
|
-
'atan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arctan)),
|
|
120
|
-
'cosh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cosh)),
|
|
121
|
-
'sinh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sinh)),
|
|
122
|
-
'tanh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tanh)),
|
|
123
|
-
'exp': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.exp)),
|
|
124
|
-
'ln': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.log)),
|
|
125
|
-
'sqrt': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sqrt)),
|
|
126
|
-
'lngamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gammaln)),
|
|
127
|
-
'gamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gamma))
|
|
128
|
-
}
|
|
129
|
-
|
|
130
|
-
@staticmethod
|
|
131
|
-
def _jax_wrapped_calc_log_exact(x, y, params):
|
|
132
|
-
return jnp.log(x) / jnp.log(y), params
|
|
133
|
-
|
|
134
|
-
EXACT_RDDL_TO_JAX_BINARY = {
|
|
135
|
-
'div': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.floor_divide)),
|
|
136
|
-
'mod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
|
|
137
|
-
'fmod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
|
|
138
|
-
'min': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.minimum)),
|
|
139
|
-
'max': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.maximum)),
|
|
140
|
-
'pow': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.power)),
|
|
141
|
-
'log': wrap_logic.__func__(_jax_wrapped_calc_log_exact.__func__),
|
|
142
|
-
'hypot': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.hypot)),
|
|
143
|
-
}
|
|
144
|
-
|
|
145
|
-
EXACT_RDDL_TO_JAX_IF = wrap_logic.__func__(ExactLogic.exact_if_then_else)
|
|
146
|
-
EXACT_RDDL_TO_JAX_SWITCH = wrap_logic.__func__(ExactLogic.exact_switch)
|
|
147
|
-
|
|
148
|
-
EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic.__func__(ExactLogic.exact_bernoulli)
|
|
149
|
-
EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic.__func__(ExactLogic.exact_discrete)
|
|
150
|
-
EXACT_RDDL_TO_JAX_POISSON = wrap_logic.__func__(ExactLogic.exact_poisson)
|
|
151
|
-
EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic.__func__(ExactLogic.exact_geometric)
|
|
152
|
-
|
|
153
55
|
def __init__(self, rddl: RDDLLiftedModel,
|
|
154
56
|
allow_synchronous_state: bool=True,
|
|
155
57
|
logger: Optional[Logger]=None,
|
|
@@ -189,8 +91,7 @@ class JaxRDDLCompiler:
|
|
|
189
91
|
self.init_values = initializer.initialize()
|
|
190
92
|
|
|
191
93
|
# compute dependency graph for CPFs and sort them by evaluation order
|
|
192
|
-
sorter = RDDLLevelAnalysis(
|
|
193
|
-
rddl, allow_synchronous_state=allow_synchronous_state)
|
|
94
|
+
sorter = RDDLLevelAnalysis(rddl, allow_synchronous_state=allow_synchronous_state)
|
|
194
95
|
self.levels = sorter.compute_levels()
|
|
195
96
|
|
|
196
97
|
# trace expressions to cache information to be used later
|
|
@@ -202,28 +103,17 @@ class JaxRDDLCompiler:
|
|
|
202
103
|
rddl=self.rddl,
|
|
203
104
|
init_values=self.init_values,
|
|
204
105
|
levels=self.levels,
|
|
205
|
-
trace_info=self.traced
|
|
106
|
+
trace_info=self.traced
|
|
107
|
+
)
|
|
206
108
|
constraints = RDDLConstraints(simulator, vectorized=True)
|
|
207
109
|
self.constraints = constraints
|
|
208
110
|
|
|
209
111
|
# basic operations - these can be override in subclasses
|
|
210
112
|
self.compile_non_fluent_exact = compile_non_fluent_exact
|
|
211
|
-
self.NEGATIVE = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
|
|
212
|
-
self.ARITHMETIC_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC.copy()
|
|
213
|
-
self.RELATIONAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL.copy()
|
|
214
|
-
self.LOGICAL_NOT = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
|
|
215
|
-
self.LOGICAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL.copy()
|
|
216
|
-
self.AGGREGATION_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION.copy()
|
|
217
113
|
self.AGGREGATION_BOOL = {'forall', 'exists'}
|
|
218
|
-
self.
|
|
219
|
-
self.
|
|
220
|
-
|
|
221
|
-
self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
|
|
222
|
-
self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
|
|
223
|
-
self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
|
|
224
|
-
self.POISSON_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
|
|
225
|
-
self.GEOMETRIC_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
|
|
226
|
-
|
|
114
|
+
self.EXACT_OPS = ExactLogic(use64bit=self.use64bit).get_operator_dicts()
|
|
115
|
+
self.OPS = self.EXACT_OPS
|
|
116
|
+
|
|
227
117
|
# ===========================================================================
|
|
228
118
|
# main compilation subroutines
|
|
229
119
|
# ===========================================================================
|
|
@@ -392,7 +282,8 @@ class JaxRDDLCompiler:
|
|
|
392
282
|
|
|
393
283
|
# compile constraint information
|
|
394
284
|
if constraint_func:
|
|
395
|
-
inequality_fns, equality_fns = self._jax_nonlinear_constraints(
|
|
285
|
+
inequality_fns, equality_fns = self._jax_nonlinear_constraints(
|
|
286
|
+
init_params_constr)
|
|
396
287
|
else:
|
|
397
288
|
inequality_fns, equality_fns = None, None
|
|
398
289
|
|
|
@@ -586,7 +477,11 @@ class JaxRDDLCompiler:
|
|
|
586
477
|
for (id, value) in self.model_params.items():
|
|
587
478
|
expr_id = int(str(id).split('_')[0])
|
|
588
479
|
expr = self.traced.lookup(expr_id)
|
|
589
|
-
result[id] = {
|
|
480
|
+
result[id] = {
|
|
481
|
+
'id': expr_id,
|
|
482
|
+
'rddl_op': ' '.join(expr.etype),
|
|
483
|
+
'init_value': value
|
|
484
|
+
}
|
|
590
485
|
return result
|
|
591
486
|
|
|
592
487
|
@staticmethod
|
|
@@ -737,7 +632,7 @@ class JaxRDDLCompiler:
|
|
|
737
632
|
return _jax_wrapped_cast
|
|
738
633
|
|
|
739
634
|
def _fix_dtype(self, value):
|
|
740
|
-
dtype = jnp.
|
|
635
|
+
dtype = jnp.result_type(value)
|
|
741
636
|
if jnp.issubdtype(dtype, jnp.integer):
|
|
742
637
|
return self.INT
|
|
743
638
|
elif jnp.issubdtype(dtype, jnp.floating):
|
|
@@ -885,11 +780,11 @@ class JaxRDDLCompiler:
|
|
|
885
780
|
|
|
886
781
|
# if expression is non-fluent, always use the exact operation
|
|
887
782
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
888
|
-
valid_ops =
|
|
889
|
-
negative_op =
|
|
783
|
+
valid_ops = self.EXACT_OPS['arithmetic']
|
|
784
|
+
negative_op = self.EXACT_OPS['negative']
|
|
890
785
|
else:
|
|
891
|
-
valid_ops = self.
|
|
892
|
-
negative_op = self.
|
|
786
|
+
valid_ops = self.OPS['arithmetic']
|
|
787
|
+
negative_op = self.OPS['negative']
|
|
893
788
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
894
789
|
|
|
895
790
|
# recursively compile arguments
|
|
@@ -916,9 +811,9 @@ class JaxRDDLCompiler:
|
|
|
916
811
|
|
|
917
812
|
# if expression is non-fluent, always use the exact operation
|
|
918
813
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
919
|
-
valid_ops =
|
|
814
|
+
valid_ops = self.EXACT_OPS['relational']
|
|
920
815
|
else:
|
|
921
|
-
valid_ops = self.
|
|
816
|
+
valid_ops = self.OPS['relational']
|
|
922
817
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
923
818
|
|
|
924
819
|
# recursively compile arguments
|
|
@@ -934,11 +829,11 @@ class JaxRDDLCompiler:
|
|
|
934
829
|
|
|
935
830
|
# if expression is non-fluent, always use the exact operation
|
|
936
831
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
937
|
-
valid_ops =
|
|
938
|
-
logical_not_op =
|
|
832
|
+
valid_ops = self.EXACT_OPS['logical']
|
|
833
|
+
logical_not_op = self.EXACT_OPS['logical_not']
|
|
939
834
|
else:
|
|
940
|
-
valid_ops = self.
|
|
941
|
-
logical_not_op = self.
|
|
835
|
+
valid_ops = self.OPS['logical']
|
|
836
|
+
logical_not_op = self.OPS['logical_not']
|
|
942
837
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
943
838
|
|
|
944
839
|
# recursively compile arguments
|
|
@@ -966,9 +861,9 @@ class JaxRDDLCompiler:
|
|
|
966
861
|
|
|
967
862
|
# if expression is non-fluent, always use the exact operation
|
|
968
863
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
969
|
-
valid_ops =
|
|
864
|
+
valid_ops = self.EXACT_OPS['aggregation']
|
|
970
865
|
else:
|
|
971
|
-
valid_ops = self.
|
|
866
|
+
valid_ops = self.OPS['aggregation']
|
|
972
867
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
973
868
|
is_floating = op not in self.AGGREGATION_BOOL
|
|
974
869
|
|
|
@@ -995,11 +890,11 @@ class JaxRDDLCompiler:
|
|
|
995
890
|
|
|
996
891
|
# if expression is non-fluent, always use the exact operation
|
|
997
892
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
998
|
-
unary_ops =
|
|
999
|
-
binary_ops =
|
|
893
|
+
unary_ops = self.EXACT_OPS['unary']
|
|
894
|
+
binary_ops = self.EXACT_OPS['binary']
|
|
1000
895
|
else:
|
|
1001
|
-
unary_ops = self.
|
|
1002
|
-
binary_ops = self.
|
|
896
|
+
unary_ops = self.OPS['unary']
|
|
897
|
+
binary_ops = self.OPS['binary']
|
|
1003
898
|
|
|
1004
899
|
# recursively compile arguments
|
|
1005
900
|
if op in unary_ops:
|
|
@@ -1041,9 +936,9 @@ class JaxRDDLCompiler:
|
|
|
1041
936
|
|
|
1042
937
|
# if predicate is non-fluent, always use the exact operation
|
|
1043
938
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
|
|
1044
|
-
if_op =
|
|
939
|
+
if_op = self.EXACT_OPS['control']['if']
|
|
1045
940
|
else:
|
|
1046
|
-
if_op = self.
|
|
941
|
+
if_op = self.OPS['control']['if']
|
|
1047
942
|
jax_op = if_op(expr.id, init_params)
|
|
1048
943
|
|
|
1049
944
|
# recursively compile arguments
|
|
@@ -1069,9 +964,9 @@ class JaxRDDLCompiler:
|
|
|
1069
964
|
# if predicate is non-fluent, always use the exact operation
|
|
1070
965
|
# case conditions are currently only literals so they are non-fluent
|
|
1071
966
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
|
|
1072
|
-
switch_op =
|
|
967
|
+
switch_op = self.EXACT_OPS['control']['switch']
|
|
1073
968
|
else:
|
|
1074
|
-
switch_op = self.
|
|
969
|
+
switch_op = self.OPS['control']['switch']
|
|
1075
970
|
jax_op = switch_op(expr.id, init_params)
|
|
1076
971
|
|
|
1077
972
|
# recursively compile predicate
|
|
@@ -1093,8 +988,7 @@ class JaxRDDLCompiler:
|
|
|
1093
988
|
for (i, jax_case) in enumerate(jax_cases):
|
|
1094
989
|
sample_cases[i], key, err_case, params = jax_case(x, params, key)
|
|
1095
990
|
err |= err_case
|
|
1096
|
-
sample_cases = jnp.asarray(
|
|
1097
|
-
sample_cases, dtype=self._fix_dtype(sample_cases))
|
|
991
|
+
sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))
|
|
1098
992
|
|
|
1099
993
|
# predicate (enum) is an integer - use it to extract from case array
|
|
1100
994
|
sample, params = jax_op(sample_pred, sample_cases, params)
|
|
@@ -1113,6 +1007,7 @@ class JaxRDDLCompiler:
|
|
|
1113
1007
|
# Bernoulli: complete (subclass uses Gumbel-softmax)
|
|
1114
1008
|
# Normal: complete
|
|
1115
1009
|
# Exponential: complete
|
|
1010
|
+
# Geometric: complete
|
|
1116
1011
|
# Weibull: complete
|
|
1117
1012
|
# Pareto: complete
|
|
1118
1013
|
# Gumbel: complete
|
|
@@ -1125,14 +1020,18 @@ class JaxRDDLCompiler:
|
|
|
1125
1020
|
# Discrete(p): complete (subclass uses Gumbel-softmax)
|
|
1126
1021
|
# UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)
|
|
1127
1022
|
|
|
1023
|
+
# distributions which seem to support backpropagation (need more testing):
|
|
1024
|
+
# Beta
|
|
1025
|
+
# Student
|
|
1026
|
+
# Gamma
|
|
1027
|
+
# ChiSquare
|
|
1028
|
+
# Dirichlet
|
|
1029
|
+
# Poisson (subclass uses Gumbel-softmax or Poisson process trick)
|
|
1030
|
+
|
|
1128
1031
|
# distributions with incomplete reparameterization support (TODO):
|
|
1129
|
-
# Binomial
|
|
1130
|
-
# NegativeBinomial
|
|
1131
|
-
#
|
|
1132
|
-
# Gamma, ChiSquare: (no shape reparameterization)
|
|
1133
|
-
# Beta: (no reparameterization)
|
|
1134
|
-
# Geometric: (implement safe floor)
|
|
1135
|
-
# Student: (no reparameterization)
|
|
1032
|
+
# Binomial
|
|
1033
|
+
# NegativeBinomial
|
|
1034
|
+
# Multinomial
|
|
1136
1035
|
|
|
1137
1036
|
def _jax_random(self, expr, init_params):
|
|
1138
1037
|
_, name = expr.etype
|
|
@@ -1188,8 +1087,7 @@ class JaxRDDLCompiler:
|
|
|
1188
1087
|
return self._jax_discrete_pvar(expr, init_params, unnorm=True)
|
|
1189
1088
|
else:
|
|
1190
1089
|
raise RDDLNotImplementedError(
|
|
1191
|
-
f'Distribution {name} is not supported.\n' +
|
|
1192
|
-
print_stack_trace(expr))
|
|
1090
|
+
f'Distribution {name} is not supported.\n' + print_stack_trace(expr))
|
|
1193
1091
|
|
|
1194
1092
|
def _jax_kron(self, expr, init_params):
|
|
1195
1093
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KRON_DELTA']
|
|
@@ -1266,8 +1164,7 @@ class JaxRDDLCompiler:
|
|
|
1266
1164
|
def _jax_wrapped_distribution_exp(x, params, key):
|
|
1267
1165
|
scale, key, err, params = jax_scale(x, params, key)
|
|
1268
1166
|
key, subkey = random.split(key)
|
|
1269
|
-
Exp1 = random.exponential(
|
|
1270
|
-
key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1167
|
+
Exp1 = random.exponential(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1271
1168
|
sample = scale * Exp1
|
|
1272
1169
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1273
1170
|
err |= (out_of_bounds * ERR)
|
|
@@ -1288,8 +1185,8 @@ class JaxRDDLCompiler:
|
|
|
1288
1185
|
shape, key, err1, params = jax_shape(x, params, key)
|
|
1289
1186
|
scale, key, err2, params = jax_scale(x, params, key)
|
|
1290
1187
|
key, subkey = random.split(key)
|
|
1291
|
-
|
|
1292
|
-
|
|
1188
|
+
sample = random.weibull_min(
|
|
1189
|
+
key=subkey, scale=scale, concentration=shape, dtype=self.REAL)
|
|
1293
1190
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1294
1191
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1295
1192
|
return sample, key, err, params
|
|
@@ -1303,9 +1200,9 @@ class JaxRDDLCompiler:
|
|
|
1303
1200
|
|
|
1304
1201
|
# if probability is non-fluent, always use the exact operation
|
|
1305
1202
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
|
|
1306
|
-
bern_op =
|
|
1203
|
+
bern_op = self.EXACT_OPS['sampling']['Bernoulli']
|
|
1307
1204
|
else:
|
|
1308
|
-
bern_op = self.
|
|
1205
|
+
bern_op = self.OPS['sampling']['Bernoulli']
|
|
1309
1206
|
jax_op = bern_op(expr.id, init_params)
|
|
1310
1207
|
|
|
1311
1208
|
# recursively compile arguments
|
|
@@ -1328,9 +1225,9 @@ class JaxRDDLCompiler:
|
|
|
1328
1225
|
|
|
1329
1226
|
# if rate is non-fluent, always use the exact operation
|
|
1330
1227
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_rate):
|
|
1331
|
-
poisson_op =
|
|
1228
|
+
poisson_op = self.EXACT_OPS['sampling']['Poisson']
|
|
1332
1229
|
else:
|
|
1333
|
-
poisson_op = self.
|
|
1230
|
+
poisson_op = self.OPS['sampling']['Poisson']
|
|
1334
1231
|
jax_op = poisson_op(expr.id, init_params)
|
|
1335
1232
|
|
|
1336
1233
|
# recursively compile arguments
|
|
@@ -1341,7 +1238,6 @@ class JaxRDDLCompiler:
|
|
|
1341
1238
|
rate, key, err, params = jax_rate(x, params, key)
|
|
1342
1239
|
key, subkey = random.split(key)
|
|
1343
1240
|
sample, params = jax_op(subkey, rate, params)
|
|
1344
|
-
sample = sample.astype(self.INT)
|
|
1345
1241
|
out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
|
|
1346
1242
|
err |= (out_of_bounds * ERR)
|
|
1347
1243
|
return sample, key, err, params
|
|
@@ -1373,20 +1269,26 @@ class JaxRDDLCompiler:
|
|
|
1373
1269
|
def _jax_binomial(self, expr, init_params):
|
|
1374
1270
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BINOMIAL']
|
|
1375
1271
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1376
|
-
|
|
1377
1272
|
arg_trials, arg_prob = expr.args
|
|
1273
|
+
|
|
1274
|
+
# if prob is non-fluent, always use the exact operation
|
|
1275
|
+
if self.compile_non_fluent_exact \
|
|
1276
|
+
and not self.traced.cached_is_fluent(arg_trials) \
|
|
1277
|
+
and not self.traced.cached_is_fluent(arg_prob):
|
|
1278
|
+
bin_op = self.EXACT_OPS['sampling']['Binomial']
|
|
1279
|
+
else:
|
|
1280
|
+
bin_op = self.OPS['sampling']['Binomial']
|
|
1281
|
+
jax_op = bin_op(expr.id, init_params)
|
|
1282
|
+
|
|
1378
1283
|
jax_trials = self._jax(arg_trials, init_params)
|
|
1379
1284
|
jax_prob = self._jax(arg_prob, init_params)
|
|
1380
|
-
|
|
1381
|
-
# uses
|
|
1285
|
+
|
|
1286
|
+
# uses reduction for constant trials
|
|
1382
1287
|
def _jax_wrapped_distribution_binomial(x, params, key):
|
|
1383
1288
|
trials, key, err2, params = jax_trials(x, params, key)
|
|
1384
1289
|
prob, key, err1, params = jax_prob(x, params, key)
|
|
1385
|
-
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1386
|
-
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1387
1290
|
key, subkey = random.split(key)
|
|
1388
|
-
|
|
1389
|
-
sample = dist.sample(seed=subkey).astype(self.INT)
|
|
1291
|
+
sample, params = jax_op(subkey, trials, prob, params)
|
|
1390
1292
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1391
1293
|
(prob >= 0) & (prob <= 1) & (trials >= 0)))
|
|
1392
1294
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
@@ -1410,10 +1312,9 @@ class JaxRDDLCompiler:
|
|
|
1410
1312
|
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1411
1313
|
key, subkey = random.split(key)
|
|
1412
1314
|
dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
|
|
1413
|
-
sample = dist.sample(seed=subkey)
|
|
1315
|
+
sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
|
|
1414
1316
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1415
|
-
(prob >= 0) & (prob <= 1) & (trials > 0))
|
|
1416
|
-
)
|
|
1317
|
+
(prob >= 0) & (prob <= 1) & (trials > 0)))
|
|
1417
1318
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1418
1319
|
return sample, key, err, params
|
|
1419
1320
|
|
|
@@ -1446,9 +1347,9 @@ class JaxRDDLCompiler:
|
|
|
1446
1347
|
|
|
1447
1348
|
# if prob is non-fluent, always use the exact operation
|
|
1448
1349
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
|
|
1449
|
-
geom_op =
|
|
1350
|
+
geom_op = self.EXACT_OPS['sampling']['Geometric']
|
|
1450
1351
|
else:
|
|
1451
|
-
geom_op = self.
|
|
1352
|
+
geom_op = self.OPS['sampling']['Geometric']
|
|
1452
1353
|
jax_op = geom_op(expr.id, init_params)
|
|
1453
1354
|
|
|
1454
1355
|
# recursively compile arguments
|
|
@@ -1458,7 +1359,6 @@ class JaxRDDLCompiler:
|
|
|
1458
1359
|
prob, key, err, params = jax_prob(x, params, key)
|
|
1459
1360
|
key, subkey = random.split(key)
|
|
1460
1361
|
sample, params = jax_op(subkey, prob, params)
|
|
1461
|
-
sample = sample.astype(self.INT)
|
|
1462
1362
|
out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
|
|
1463
1363
|
err |= (out_of_bounds * ERR)
|
|
1464
1364
|
return sample, key, err, params
|
|
@@ -1497,8 +1397,7 @@ class JaxRDDLCompiler:
|
|
|
1497
1397
|
def _jax_wrapped_distribution_t(x, params, key):
|
|
1498
1398
|
df, key, err, params = jax_df(x, params, key)
|
|
1499
1399
|
key, subkey = random.split(key)
|
|
1500
|
-
sample = random.t(
|
|
1501
|
-
key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
|
|
1400
|
+
sample = random.t(key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
|
|
1502
1401
|
out_of_bounds = jnp.logical_not(jnp.all(df > 0))
|
|
1503
1402
|
err |= (out_of_bounds * ERR)
|
|
1504
1403
|
return sample, key, err, params
|
|
@@ -1518,8 +1417,7 @@ class JaxRDDLCompiler:
|
|
|
1518
1417
|
mean, key, err1, params = jax_mean(x, params, key)
|
|
1519
1418
|
scale, key, err2, params = jax_scale(x, params, key)
|
|
1520
1419
|
key, subkey = random.split(key)
|
|
1521
|
-
Gumbel01 = random.gumbel(
|
|
1522
|
-
key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1420
|
+
Gumbel01 = random.gumbel(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1523
1421
|
sample = mean + scale * Gumbel01
|
|
1524
1422
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1525
1423
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
@@ -1540,8 +1438,7 @@ class JaxRDDLCompiler:
|
|
|
1540
1438
|
mean, key, err1, params = jax_mean(x, params, key)
|
|
1541
1439
|
scale, key, err2, params = jax_scale(x, params, key)
|
|
1542
1440
|
key, subkey = random.split(key)
|
|
1543
|
-
Laplace01 = random.laplace(
|
|
1544
|
-
key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1441
|
+
Laplace01 = random.laplace(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1545
1442
|
sample = mean + scale * Laplace01
|
|
1546
1443
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1547
1444
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
@@ -1562,8 +1459,7 @@ class JaxRDDLCompiler:
|
|
|
1562
1459
|
mean, key, err1, params = jax_mean(x, params, key)
|
|
1563
1460
|
scale, key, err2, params = jax_scale(x, params, key)
|
|
1564
1461
|
key, subkey = random.split(key)
|
|
1565
|
-
Cauchy01 = random.cauchy(
|
|
1566
|
-
key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1462
|
+
Cauchy01 = random.cauchy(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1567
1463
|
sample = mean + scale * Cauchy01
|
|
1568
1464
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1569
1465
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
@@ -1585,7 +1481,7 @@ class JaxRDDLCompiler:
|
|
|
1585
1481
|
scale, key, err2, params = jax_scale(x, params, key)
|
|
1586
1482
|
key, subkey = random.split(key)
|
|
1587
1483
|
U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1588
|
-
sample = jnp.log(1.0 - jnp.
|
|
1484
|
+
sample = jnp.log(1.0 - jnp.log1p(-U) / shape) / scale
|
|
1589
1485
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1590
1486
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1591
1487
|
return sample, key, err, params
|
|
@@ -1646,9 +1542,9 @@ class JaxRDDLCompiler:
|
|
|
1646
1542
|
has_fluent_arg = any(self.traced.cached_is_fluent(arg)
|
|
1647
1543
|
for arg in ordered_args)
|
|
1648
1544
|
if self.compile_non_fluent_exact and not has_fluent_arg:
|
|
1649
|
-
discrete_op =
|
|
1545
|
+
discrete_op = self.EXACT_OPS['sampling']['Discrete']
|
|
1650
1546
|
else:
|
|
1651
|
-
discrete_op = self.
|
|
1547
|
+
discrete_op = self.OPS['sampling']['Discrete']
|
|
1652
1548
|
jax_op = discrete_op(expr.id, init_params)
|
|
1653
1549
|
|
|
1654
1550
|
# compile probability expressions
|
|
@@ -1687,9 +1583,9 @@ class JaxRDDLCompiler:
|
|
|
1687
1583
|
|
|
1688
1584
|
# if probabilities are non-fluent, then always sample exact
|
|
1689
1585
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg):
|
|
1690
|
-
discrete_op =
|
|
1586
|
+
discrete_op = self.EXACT_OPS['sampling']['Discrete']
|
|
1691
1587
|
else:
|
|
1692
|
-
discrete_op = self.
|
|
1588
|
+
discrete_op = self.OPS['sampling']['Discrete']
|
|
1693
1589
|
jax_op = discrete_op(expr.id, init_params)
|
|
1694
1590
|
|
|
1695
1591
|
# compile probability function
|
|
@@ -1731,8 +1627,7 @@ class JaxRDDLCompiler:
|
|
|
1731
1627
|
return self._jax_multinomial(expr, init_params)
|
|
1732
1628
|
else:
|
|
1733
1629
|
raise RDDLNotImplementedError(
|
|
1734
|
-
f'Distribution {name} is not supported.\n' +
|
|
1735
|
-
print_stack_trace(expr))
|
|
1630
|
+
f'Distribution {name} is not supported.\n' + print_stack_trace(expr))
|
|
1736
1631
|
|
|
1737
1632
|
def _jax_multivariate_normal(self, expr, init_params):
|
|
1738
1633
|
_, args = expr.args
|
|
@@ -1786,7 +1681,7 @@ class JaxRDDLCompiler:
|
|
|
1786
1681
|
|
|
1787
1682
|
# sample StudentT(0, 1, df) -- broadcast df to same shape as cov
|
|
1788
1683
|
sample_df = sample_df[..., jnp.newaxis, jnp.newaxis]
|
|
1789
|
-
sample_df = jnp.broadcast_to(sample_df, shape=
|
|
1684
|
+
sample_df = jnp.broadcast_to(sample_df, shape=jnp.shape(sample_mean) + (1,))
|
|
1790
1685
|
key, subkey = random.split(key)
|
|
1791
1686
|
Z = random.t(
|
|
1792
1687
|
key=subkey,
|
|
@@ -1841,7 +1736,7 @@ class JaxRDDLCompiler:
|
|
|
1841
1736
|
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1842
1737
|
key, subkey = random.split(key)
|
|
1843
1738
|
dist = tfp.distributions.Multinomial(total_count=trials, probs=prob)
|
|
1844
|
-
sample = dist.sample(seed=subkey)
|
|
1739
|
+
sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
|
|
1845
1740
|
sample = jnp.moveaxis(sample, source=-1, destination=index)
|
|
1846
1741
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1847
1742
|
(prob >= 0)
|