pyRDDLGym-jax 1.3__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +101 -191
- pyRDDLGym_jax/core/logic.py +349 -65
- pyRDDLGym_jax/core/planner.py +554 -208
- pyRDDLGym_jax/core/simulator.py +20 -0
- pyRDDLGym_jax/core/tuning.py +15 -0
- pyRDDLGym_jax/core/visualization.py +55 -8
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg +3 -3
- pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_slp.cfg +4 -4
- pyRDDLGym_jax/examples/configs/Quadcopter_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg +4 -3
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Reservoir_Continuous_slp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/UAV_Continuous_slp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_drp.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_replan.cfg +1 -0
- pyRDDLGym_jax/examples/configs/Wildfire_MDP_ippc2014_slp.cfg +1 -0
- pyRDDLGym_jax/examples/run_tune.py +10 -6
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/METADATA +22 -12
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/RECORD +24 -24
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/WHEEL +1 -1
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/LICENSE +0 -0
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/entry_points.txt +0 -0
- {pyRDDLGym_jax-1.3.dist-info → pyrddlgym_jax-2.1.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '1
|
|
1
|
+
__version__ = '2.1'
|
pyRDDLGym_jax/core/compiler.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# ***********************************************************************
|
|
2
|
+
# JAXPLAN
|
|
3
|
+
#
|
|
4
|
+
# Author: Michael Gimelfarb
|
|
5
|
+
#
|
|
6
|
+
# REFERENCES:
|
|
7
|
+
#
|
|
8
|
+
# [1] Gimelfarb, Michael, Ayal Taitler, and Scott Sanner. "JaxPlan and GurobiPlan:
|
|
9
|
+
# Optimization Baselines for Replanning in Discrete and Mixed Discrete-Continuous
|
|
10
|
+
# Probabilistic Domains." Proceedings of the International Conference on Automated
|
|
11
|
+
# Planning and Scheduling. Vol. 34. 2024.
|
|
12
|
+
#
|
|
13
|
+
# ***********************************************************************
|
|
14
|
+
|
|
15
|
+
|
|
1
16
|
from functools import partial
|
|
2
17
|
import traceback
|
|
3
18
|
from typing import Any, Callable, Dict, List, Optional
|
|
@@ -5,7 +20,6 @@ from typing import Any, Callable, Dict, List, Optional
|
|
|
5
20
|
import jax
|
|
6
21
|
import jax.numpy as jnp
|
|
7
22
|
import jax.random as random
|
|
8
|
-
import jax.scipy as scipy
|
|
9
23
|
|
|
10
24
|
from pyRDDLGym.core.compiler.initializer import RDDLValueInitializer
|
|
11
25
|
from pyRDDLGym.core.compiler.levels import RDDLLevelAnalysis
|
|
@@ -28,8 +42,7 @@ try:
|
|
|
28
42
|
from tensorflow_probability.substrates import jax as tfp
|
|
29
43
|
except Exception:
|
|
30
44
|
raise_warning('Failed to import tensorflow-probability: '
|
|
31
|
-
'compilation of some
|
|
32
|
-
'(Binomial, Negative-Binomial, Multinomial) will fail.', 'red')
|
|
45
|
+
'compilation of some probability distributions will fail.', 'red')
|
|
33
46
|
traceback.print_exc()
|
|
34
47
|
tfp = None
|
|
35
48
|
|
|
@@ -39,102 +52,6 @@ class JaxRDDLCompiler:
|
|
|
39
52
|
All operations are identical to their numpy equivalents.
|
|
40
53
|
'''
|
|
41
54
|
|
|
42
|
-
MODEL_PARAM_TAG_SEPARATOR = '___'
|
|
43
|
-
|
|
44
|
-
# ===========================================================================
|
|
45
|
-
# EXACT RDDL TO JAX COMPILATION RULES BY DEFAULT
|
|
46
|
-
# ===========================================================================
|
|
47
|
-
|
|
48
|
-
@staticmethod
|
|
49
|
-
def wrap_logic(func):
|
|
50
|
-
def exact_func(id, init_params):
|
|
51
|
-
return func
|
|
52
|
-
return exact_func
|
|
53
|
-
|
|
54
|
-
EXACT_RDDL_TO_JAX_NEGATIVE = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.negative))
|
|
55
|
-
EXACT_RDDL_TO_JAX_ARITHMETIC = {
|
|
56
|
-
'+': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.add)),
|
|
57
|
-
'-': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.subtract)),
|
|
58
|
-
'*': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.multiply)),
|
|
59
|
-
'/': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.divide))
|
|
60
|
-
}
|
|
61
|
-
|
|
62
|
-
EXACT_RDDL_TO_JAX_RELATIONAL = {
|
|
63
|
-
'>=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater_equal)),
|
|
64
|
-
'<=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less_equal)),
|
|
65
|
-
'<': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.less)),
|
|
66
|
-
'>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.greater)),
|
|
67
|
-
'==': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal)),
|
|
68
|
-
'~=': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.not_equal))
|
|
69
|
-
}
|
|
70
|
-
|
|
71
|
-
EXACT_RDDL_TO_JAX_LOGICAL_NOT = wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.logical_not))
|
|
72
|
-
EXACT_RDDL_TO_JAX_LOGICAL = {
|
|
73
|
-
'^': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
|
|
74
|
-
'&': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_and)),
|
|
75
|
-
'|': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_or)),
|
|
76
|
-
'~': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.logical_xor)),
|
|
77
|
-
'=>': wrap_logic.__func__(ExactLogic.exact_binary_implies),
|
|
78
|
-
'<=>': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.equal))
|
|
79
|
-
}
|
|
80
|
-
|
|
81
|
-
EXACT_RDDL_TO_JAX_AGGREGATION = {
|
|
82
|
-
'sum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.sum)),
|
|
83
|
-
'avg': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.mean)),
|
|
84
|
-
'prod': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.prod)),
|
|
85
|
-
'minimum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.min)),
|
|
86
|
-
'maximum': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.max)),
|
|
87
|
-
'forall': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.all)),
|
|
88
|
-
'exists': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.any)),
|
|
89
|
-
'argmin': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmin)),
|
|
90
|
-
'argmax': wrap_logic.__func__(ExactLogic.exact_aggregation(jnp.argmax))
|
|
91
|
-
}
|
|
92
|
-
|
|
93
|
-
EXACT_RDDL_TO_JAX_UNARY = {
|
|
94
|
-
'abs': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.abs)),
|
|
95
|
-
'sgn': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sign)),
|
|
96
|
-
'round': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.round)),
|
|
97
|
-
'floor': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.floor)),
|
|
98
|
-
'ceil': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.ceil)),
|
|
99
|
-
'cos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cos)),
|
|
100
|
-
'sin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sin)),
|
|
101
|
-
'tan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tan)),
|
|
102
|
-
'acos': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arccos)),
|
|
103
|
-
'asin': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arcsin)),
|
|
104
|
-
'atan': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.arctan)),
|
|
105
|
-
'cosh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.cosh)),
|
|
106
|
-
'sinh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sinh)),
|
|
107
|
-
'tanh': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.tanh)),
|
|
108
|
-
'exp': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.exp)),
|
|
109
|
-
'ln': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.log)),
|
|
110
|
-
'sqrt': wrap_logic.__func__(ExactLogic.exact_unary_function(jnp.sqrt)),
|
|
111
|
-
'lngamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gammaln)),
|
|
112
|
-
'gamma': wrap_logic.__func__(ExactLogic.exact_unary_function(scipy.special.gamma))
|
|
113
|
-
}
|
|
114
|
-
|
|
115
|
-
@staticmethod
|
|
116
|
-
def _jax_wrapped_calc_log_exact(x, y, params):
|
|
117
|
-
return jnp.log(x) / jnp.log(y), params
|
|
118
|
-
|
|
119
|
-
EXACT_RDDL_TO_JAX_BINARY = {
|
|
120
|
-
'div': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.floor_divide)),
|
|
121
|
-
'mod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
|
|
122
|
-
'fmod': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.mod)),
|
|
123
|
-
'min': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.minimum)),
|
|
124
|
-
'max': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.maximum)),
|
|
125
|
-
'pow': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.power)),
|
|
126
|
-
'log': wrap_logic.__func__(_jax_wrapped_calc_log_exact.__func__),
|
|
127
|
-
'hypot': wrap_logic.__func__(ExactLogic.exact_binary_function(jnp.hypot)),
|
|
128
|
-
}
|
|
129
|
-
|
|
130
|
-
EXACT_RDDL_TO_JAX_IF = wrap_logic.__func__(ExactLogic.exact_if_then_else)
|
|
131
|
-
EXACT_RDDL_TO_JAX_SWITCH = wrap_logic.__func__(ExactLogic.exact_switch)
|
|
132
|
-
|
|
133
|
-
EXACT_RDDL_TO_JAX_BERNOULLI = wrap_logic.__func__(ExactLogic.exact_bernoulli)
|
|
134
|
-
EXACT_RDDL_TO_JAX_DISCRETE = wrap_logic.__func__(ExactLogic.exact_discrete)
|
|
135
|
-
EXACT_RDDL_TO_JAX_POISSON = wrap_logic.__func__(ExactLogic.exact_poisson)
|
|
136
|
-
EXACT_RDDL_TO_JAX_GEOMETRIC = wrap_logic.__func__(ExactLogic.exact_geometric)
|
|
137
|
-
|
|
138
55
|
def __init__(self, rddl: RDDLLiftedModel,
|
|
139
56
|
allow_synchronous_state: bool=True,
|
|
140
57
|
logger: Optional[Logger]=None,
|
|
@@ -174,8 +91,7 @@ class JaxRDDLCompiler:
|
|
|
174
91
|
self.init_values = initializer.initialize()
|
|
175
92
|
|
|
176
93
|
# compute dependency graph for CPFs and sort them by evaluation order
|
|
177
|
-
sorter = RDDLLevelAnalysis(
|
|
178
|
-
rddl, allow_synchronous_state=allow_synchronous_state)
|
|
94
|
+
sorter = RDDLLevelAnalysis(rddl, allow_synchronous_state=allow_synchronous_state)
|
|
179
95
|
self.levels = sorter.compute_levels()
|
|
180
96
|
|
|
181
97
|
# trace expressions to cache information to be used later
|
|
@@ -187,28 +103,17 @@ class JaxRDDLCompiler:
|
|
|
187
103
|
rddl=self.rddl,
|
|
188
104
|
init_values=self.init_values,
|
|
189
105
|
levels=self.levels,
|
|
190
|
-
trace_info=self.traced
|
|
106
|
+
trace_info=self.traced
|
|
107
|
+
)
|
|
191
108
|
constraints = RDDLConstraints(simulator, vectorized=True)
|
|
192
109
|
self.constraints = constraints
|
|
193
110
|
|
|
194
111
|
# basic operations - these can be override in subclasses
|
|
195
112
|
self.compile_non_fluent_exact = compile_non_fluent_exact
|
|
196
|
-
self.NEGATIVE = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_NEGATIVE
|
|
197
|
-
self.ARITHMETIC_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_ARITHMETIC.copy()
|
|
198
|
-
self.RELATIONAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_RELATIONAL.copy()
|
|
199
|
-
self.LOGICAL_NOT = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL_NOT
|
|
200
|
-
self.LOGICAL_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_LOGICAL.copy()
|
|
201
|
-
self.AGGREGATION_OPS = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_AGGREGATION.copy()
|
|
202
113
|
self.AGGREGATION_BOOL = {'forall', 'exists'}
|
|
203
|
-
self.
|
|
204
|
-
self.
|
|
205
|
-
|
|
206
|
-
self.SWITCH_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_SWITCH
|
|
207
|
-
self.BERNOULLI_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_BERNOULLI
|
|
208
|
-
self.DISCRETE_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_DISCRETE
|
|
209
|
-
self.POISSON_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_POISSON
|
|
210
|
-
self.GEOMETRIC_HELPER = JaxRDDLCompiler.EXACT_RDDL_TO_JAX_GEOMETRIC
|
|
211
|
-
|
|
114
|
+
self.EXACT_OPS = ExactLogic(use64bit=self.use64bit).get_operator_dicts()
|
|
115
|
+
self.OPS = self.EXACT_OPS
|
|
116
|
+
|
|
212
117
|
# ===========================================================================
|
|
213
118
|
# main compilation subroutines
|
|
214
119
|
# ===========================================================================
|
|
@@ -377,7 +282,8 @@ class JaxRDDLCompiler:
|
|
|
377
282
|
|
|
378
283
|
# compile constraint information
|
|
379
284
|
if constraint_func:
|
|
380
|
-
inequality_fns, equality_fns = self._jax_nonlinear_constraints(
|
|
285
|
+
inequality_fns, equality_fns = self._jax_nonlinear_constraints(
|
|
286
|
+
init_params_constr)
|
|
381
287
|
else:
|
|
382
288
|
inequality_fns, equality_fns = None, None
|
|
383
289
|
|
|
@@ -524,7 +430,7 @@ class JaxRDDLCompiler:
|
|
|
524
430
|
_jax_wrapped_single_step_policy,
|
|
525
431
|
in_axes=(0, None, None, None, 0, None)
|
|
526
432
|
)(keys, policy_params, hyperparams, step, subs, model_params)
|
|
527
|
-
model_params = jax.tree_map(
|
|
433
|
+
model_params = jax.tree_map(partial(jnp.mean, axis=0), model_params)
|
|
528
434
|
carry = (key, policy_params, hyperparams, subs, model_params)
|
|
529
435
|
return carry, log
|
|
530
436
|
|
|
@@ -571,7 +477,11 @@ class JaxRDDLCompiler:
|
|
|
571
477
|
for (id, value) in self.model_params.items():
|
|
572
478
|
expr_id = int(str(id).split('_')[0])
|
|
573
479
|
expr = self.traced.lookup(expr_id)
|
|
574
|
-
result[id] = {
|
|
480
|
+
result[id] = {
|
|
481
|
+
'id': expr_id,
|
|
482
|
+
'rddl_op': ' '.join(expr.etype),
|
|
483
|
+
'init_value': value
|
|
484
|
+
}
|
|
575
485
|
return result
|
|
576
486
|
|
|
577
487
|
@staticmethod
|
|
@@ -722,7 +632,7 @@ class JaxRDDLCompiler:
|
|
|
722
632
|
return _jax_wrapped_cast
|
|
723
633
|
|
|
724
634
|
def _fix_dtype(self, value):
|
|
725
|
-
dtype = jnp.
|
|
635
|
+
dtype = jnp.result_type(value)
|
|
726
636
|
if jnp.issubdtype(dtype, jnp.integer):
|
|
727
637
|
return self.INT
|
|
728
638
|
elif jnp.issubdtype(dtype, jnp.floating):
|
|
@@ -870,11 +780,11 @@ class JaxRDDLCompiler:
|
|
|
870
780
|
|
|
871
781
|
# if expression is non-fluent, always use the exact operation
|
|
872
782
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
873
|
-
valid_ops =
|
|
874
|
-
negative_op =
|
|
783
|
+
valid_ops = self.EXACT_OPS['arithmetic']
|
|
784
|
+
negative_op = self.EXACT_OPS['negative']
|
|
875
785
|
else:
|
|
876
|
-
valid_ops = self.
|
|
877
|
-
negative_op = self.
|
|
786
|
+
valid_ops = self.OPS['arithmetic']
|
|
787
|
+
negative_op = self.OPS['negative']
|
|
878
788
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
879
789
|
|
|
880
790
|
# recursively compile arguments
|
|
@@ -901,9 +811,9 @@ class JaxRDDLCompiler:
|
|
|
901
811
|
|
|
902
812
|
# if expression is non-fluent, always use the exact operation
|
|
903
813
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
904
|
-
valid_ops =
|
|
814
|
+
valid_ops = self.EXACT_OPS['relational']
|
|
905
815
|
else:
|
|
906
|
-
valid_ops = self.
|
|
816
|
+
valid_ops = self.OPS['relational']
|
|
907
817
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
908
818
|
|
|
909
819
|
# recursively compile arguments
|
|
@@ -919,11 +829,11 @@ class JaxRDDLCompiler:
|
|
|
919
829
|
|
|
920
830
|
# if expression is non-fluent, always use the exact operation
|
|
921
831
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
922
|
-
valid_ops =
|
|
923
|
-
logical_not_op =
|
|
832
|
+
valid_ops = self.EXACT_OPS['logical']
|
|
833
|
+
logical_not_op = self.EXACT_OPS['logical_not']
|
|
924
834
|
else:
|
|
925
|
-
valid_ops = self.
|
|
926
|
-
logical_not_op = self.
|
|
835
|
+
valid_ops = self.OPS['logical']
|
|
836
|
+
logical_not_op = self.OPS['logical_not']
|
|
927
837
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
928
838
|
|
|
929
839
|
# recursively compile arguments
|
|
@@ -951,9 +861,9 @@ class JaxRDDLCompiler:
|
|
|
951
861
|
|
|
952
862
|
# if expression is non-fluent, always use the exact operation
|
|
953
863
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
954
|
-
valid_ops =
|
|
864
|
+
valid_ops = self.EXACT_OPS['aggregation']
|
|
955
865
|
else:
|
|
956
|
-
valid_ops = self.
|
|
866
|
+
valid_ops = self.OPS['aggregation']
|
|
957
867
|
JaxRDDLCompiler._check_valid_op(expr, valid_ops)
|
|
958
868
|
is_floating = op not in self.AGGREGATION_BOOL
|
|
959
869
|
|
|
@@ -980,11 +890,11 @@ class JaxRDDLCompiler:
|
|
|
980
890
|
|
|
981
891
|
# if expression is non-fluent, always use the exact operation
|
|
982
892
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(expr):
|
|
983
|
-
unary_ops =
|
|
984
|
-
binary_ops =
|
|
893
|
+
unary_ops = self.EXACT_OPS['unary']
|
|
894
|
+
binary_ops = self.EXACT_OPS['binary']
|
|
985
895
|
else:
|
|
986
|
-
unary_ops = self.
|
|
987
|
-
binary_ops = self.
|
|
896
|
+
unary_ops = self.OPS['unary']
|
|
897
|
+
binary_ops = self.OPS['binary']
|
|
988
898
|
|
|
989
899
|
# recursively compile arguments
|
|
990
900
|
if op in unary_ops:
|
|
@@ -1026,9 +936,9 @@ class JaxRDDLCompiler:
|
|
|
1026
936
|
|
|
1027
937
|
# if predicate is non-fluent, always use the exact operation
|
|
1028
938
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
|
|
1029
|
-
if_op =
|
|
939
|
+
if_op = self.EXACT_OPS['control']['if']
|
|
1030
940
|
else:
|
|
1031
|
-
if_op = self.
|
|
941
|
+
if_op = self.OPS['control']['if']
|
|
1032
942
|
jax_op = if_op(expr.id, init_params)
|
|
1033
943
|
|
|
1034
944
|
# recursively compile arguments
|
|
@@ -1054,9 +964,9 @@ class JaxRDDLCompiler:
|
|
|
1054
964
|
# if predicate is non-fluent, always use the exact operation
|
|
1055
965
|
# case conditions are currently only literals so they are non-fluent
|
|
1056
966
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(pred):
|
|
1057
|
-
switch_op =
|
|
967
|
+
switch_op = self.EXACT_OPS['control']['switch']
|
|
1058
968
|
else:
|
|
1059
|
-
switch_op = self.
|
|
969
|
+
switch_op = self.OPS['control']['switch']
|
|
1060
970
|
jax_op = switch_op(expr.id, init_params)
|
|
1061
971
|
|
|
1062
972
|
# recursively compile predicate
|
|
@@ -1078,8 +988,7 @@ class JaxRDDLCompiler:
|
|
|
1078
988
|
for (i, jax_case) in enumerate(jax_cases):
|
|
1079
989
|
sample_cases[i], key, err_case, params = jax_case(x, params, key)
|
|
1080
990
|
err |= err_case
|
|
1081
|
-
sample_cases = jnp.asarray(
|
|
1082
|
-
sample_cases, dtype=self._fix_dtype(sample_cases))
|
|
991
|
+
sample_cases = jnp.asarray(sample_cases, dtype=self._fix_dtype(sample_cases))
|
|
1083
992
|
|
|
1084
993
|
# predicate (enum) is an integer - use it to extract from case array
|
|
1085
994
|
sample, params = jax_op(sample_pred, sample_cases, params)
|
|
@@ -1098,6 +1007,7 @@ class JaxRDDLCompiler:
|
|
|
1098
1007
|
# Bernoulli: complete (subclass uses Gumbel-softmax)
|
|
1099
1008
|
# Normal: complete
|
|
1100
1009
|
# Exponential: complete
|
|
1010
|
+
# Geometric: complete
|
|
1101
1011
|
# Weibull: complete
|
|
1102
1012
|
# Pareto: complete
|
|
1103
1013
|
# Gumbel: complete
|
|
@@ -1110,14 +1020,18 @@ class JaxRDDLCompiler:
|
|
|
1110
1020
|
# Discrete(p): complete (subclass uses Gumbel-softmax)
|
|
1111
1021
|
# UnnormDiscrete(p): complete (subclass uses Gumbel-softmax)
|
|
1112
1022
|
|
|
1023
|
+
# distributions which seem to support backpropagation (need more testing):
|
|
1024
|
+
# Beta
|
|
1025
|
+
# Student
|
|
1026
|
+
# Gamma
|
|
1027
|
+
# ChiSquare
|
|
1028
|
+
# Dirichlet
|
|
1029
|
+
# Poisson (subclass uses Gumbel-softmax or Poisson process trick)
|
|
1030
|
+
|
|
1113
1031
|
# distributions with incomplete reparameterization support (TODO):
|
|
1114
|
-
# Binomial
|
|
1115
|
-
# NegativeBinomial
|
|
1116
|
-
#
|
|
1117
|
-
# Gamma, ChiSquare: (no shape reparameterization)
|
|
1118
|
-
# Beta: (no reparameterization)
|
|
1119
|
-
# Geometric: (implement safe floor)
|
|
1120
|
-
# Student: (no reparameterization)
|
|
1032
|
+
# Binomial
|
|
1033
|
+
# NegativeBinomial
|
|
1034
|
+
# Multinomial
|
|
1121
1035
|
|
|
1122
1036
|
def _jax_random(self, expr, init_params):
|
|
1123
1037
|
_, name = expr.etype
|
|
@@ -1173,8 +1087,7 @@ class JaxRDDLCompiler:
|
|
|
1173
1087
|
return self._jax_discrete_pvar(expr, init_params, unnorm=True)
|
|
1174
1088
|
else:
|
|
1175
1089
|
raise RDDLNotImplementedError(
|
|
1176
|
-
f'Distribution {name} is not supported.\n' +
|
|
1177
|
-
print_stack_trace(expr))
|
|
1090
|
+
f'Distribution {name} is not supported.\n' + print_stack_trace(expr))
|
|
1178
1091
|
|
|
1179
1092
|
def _jax_kron(self, expr, init_params):
|
|
1180
1093
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_KRON_DELTA']
|
|
@@ -1251,8 +1164,7 @@ class JaxRDDLCompiler:
|
|
|
1251
1164
|
def _jax_wrapped_distribution_exp(x, params, key):
|
|
1252
1165
|
scale, key, err, params = jax_scale(x, params, key)
|
|
1253
1166
|
key, subkey = random.split(key)
|
|
1254
|
-
Exp1 = random.exponential(
|
|
1255
|
-
key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1167
|
+
Exp1 = random.exponential(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1256
1168
|
sample = scale * Exp1
|
|
1257
1169
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1258
1170
|
err |= (out_of_bounds * ERR)
|
|
@@ -1273,8 +1185,8 @@ class JaxRDDLCompiler:
|
|
|
1273
1185
|
shape, key, err1, params = jax_shape(x, params, key)
|
|
1274
1186
|
scale, key, err2, params = jax_scale(x, params, key)
|
|
1275
1187
|
key, subkey = random.split(key)
|
|
1276
|
-
|
|
1277
|
-
|
|
1188
|
+
sample = random.weibull_min(
|
|
1189
|
+
key=subkey, scale=scale, concentration=shape, dtype=self.REAL)
|
|
1278
1190
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1279
1191
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1280
1192
|
return sample, key, err, params
|
|
@@ -1288,9 +1200,9 @@ class JaxRDDLCompiler:
|
|
|
1288
1200
|
|
|
1289
1201
|
# if probability is non-fluent, always use the exact operation
|
|
1290
1202
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
|
|
1291
|
-
bern_op =
|
|
1203
|
+
bern_op = self.EXACT_OPS['sampling']['Bernoulli']
|
|
1292
1204
|
else:
|
|
1293
|
-
bern_op = self.
|
|
1205
|
+
bern_op = self.OPS['sampling']['Bernoulli']
|
|
1294
1206
|
jax_op = bern_op(expr.id, init_params)
|
|
1295
1207
|
|
|
1296
1208
|
# recursively compile arguments
|
|
@@ -1313,9 +1225,9 @@ class JaxRDDLCompiler:
|
|
|
1313
1225
|
|
|
1314
1226
|
# if rate is non-fluent, always use the exact operation
|
|
1315
1227
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_rate):
|
|
1316
|
-
poisson_op =
|
|
1228
|
+
poisson_op = self.EXACT_OPS['sampling']['Poisson']
|
|
1317
1229
|
else:
|
|
1318
|
-
poisson_op = self.
|
|
1230
|
+
poisson_op = self.OPS['sampling']['Poisson']
|
|
1319
1231
|
jax_op = poisson_op(expr.id, init_params)
|
|
1320
1232
|
|
|
1321
1233
|
# recursively compile arguments
|
|
@@ -1326,7 +1238,6 @@ class JaxRDDLCompiler:
|
|
|
1326
1238
|
rate, key, err, params = jax_rate(x, params, key)
|
|
1327
1239
|
key, subkey = random.split(key)
|
|
1328
1240
|
sample, params = jax_op(subkey, rate, params)
|
|
1329
|
-
sample = sample.astype(self.INT)
|
|
1330
1241
|
out_of_bounds = jnp.logical_not(jnp.all(rate >= 0))
|
|
1331
1242
|
err |= (out_of_bounds * ERR)
|
|
1332
1243
|
return sample, key, err, params
|
|
@@ -1358,20 +1269,26 @@ class JaxRDDLCompiler:
|
|
|
1358
1269
|
def _jax_binomial(self, expr, init_params):
|
|
1359
1270
|
ERR = JaxRDDLCompiler.ERROR_CODES['INVALID_PARAM_BINOMIAL']
|
|
1360
1271
|
JaxRDDLCompiler._check_num_args(expr, 2)
|
|
1361
|
-
|
|
1362
1272
|
arg_trials, arg_prob = expr.args
|
|
1273
|
+
|
|
1274
|
+
# if prob is non-fluent, always use the exact operation
|
|
1275
|
+
if self.compile_non_fluent_exact \
|
|
1276
|
+
and not self.traced.cached_is_fluent(arg_trials) \
|
|
1277
|
+
and not self.traced.cached_is_fluent(arg_prob):
|
|
1278
|
+
bin_op = self.EXACT_OPS['sampling']['Binomial']
|
|
1279
|
+
else:
|
|
1280
|
+
bin_op = self.OPS['sampling']['Binomial']
|
|
1281
|
+
jax_op = bin_op(expr.id, init_params)
|
|
1282
|
+
|
|
1363
1283
|
jax_trials = self._jax(arg_trials, init_params)
|
|
1364
1284
|
jax_prob = self._jax(arg_prob, init_params)
|
|
1365
|
-
|
|
1366
|
-
# uses
|
|
1285
|
+
|
|
1286
|
+
# uses reduction for constant trials
|
|
1367
1287
|
def _jax_wrapped_distribution_binomial(x, params, key):
|
|
1368
1288
|
trials, key, err2, params = jax_trials(x, params, key)
|
|
1369
1289
|
prob, key, err1, params = jax_prob(x, params, key)
|
|
1370
|
-
trials = jnp.asarray(trials, dtype=self.REAL)
|
|
1371
|
-
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1372
1290
|
key, subkey = random.split(key)
|
|
1373
|
-
|
|
1374
|
-
sample = dist.sample(seed=subkey).astype(self.INT)
|
|
1291
|
+
sample, params = jax_op(subkey, trials, prob, params)
|
|
1375
1292
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1376
1293
|
(prob >= 0) & (prob <= 1) & (trials >= 0)))
|
|
1377
1294
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
@@ -1395,10 +1312,9 @@ class JaxRDDLCompiler:
|
|
|
1395
1312
|
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1396
1313
|
key, subkey = random.split(key)
|
|
1397
1314
|
dist = tfp.distributions.NegativeBinomial(total_count=trials, probs=prob)
|
|
1398
|
-
sample = dist.sample(seed=subkey)
|
|
1315
|
+
sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
|
|
1399
1316
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1400
|
-
(prob >= 0) & (prob <= 1) & (trials > 0))
|
|
1401
|
-
)
|
|
1317
|
+
(prob >= 0) & (prob <= 1) & (trials > 0)))
|
|
1402
1318
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1403
1319
|
return sample, key, err, params
|
|
1404
1320
|
|
|
@@ -1431,9 +1347,9 @@ class JaxRDDLCompiler:
|
|
|
1431
1347
|
|
|
1432
1348
|
# if prob is non-fluent, always use the exact operation
|
|
1433
1349
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg_prob):
|
|
1434
|
-
geom_op =
|
|
1350
|
+
geom_op = self.EXACT_OPS['sampling']['Geometric']
|
|
1435
1351
|
else:
|
|
1436
|
-
geom_op = self.
|
|
1352
|
+
geom_op = self.OPS['sampling']['Geometric']
|
|
1437
1353
|
jax_op = geom_op(expr.id, init_params)
|
|
1438
1354
|
|
|
1439
1355
|
# recursively compile arguments
|
|
@@ -1443,7 +1359,6 @@ class JaxRDDLCompiler:
|
|
|
1443
1359
|
prob, key, err, params = jax_prob(x, params, key)
|
|
1444
1360
|
key, subkey = random.split(key)
|
|
1445
1361
|
sample, params = jax_op(subkey, prob, params)
|
|
1446
|
-
sample = sample.astype(self.INT)
|
|
1447
1362
|
out_of_bounds = jnp.logical_not(jnp.all((prob >= 0) & (prob <= 1)))
|
|
1448
1363
|
err |= (out_of_bounds * ERR)
|
|
1449
1364
|
return sample, key, err, params
|
|
@@ -1482,8 +1397,7 @@ class JaxRDDLCompiler:
|
|
|
1482
1397
|
def _jax_wrapped_distribution_t(x, params, key):
|
|
1483
1398
|
df, key, err, params = jax_df(x, params, key)
|
|
1484
1399
|
key, subkey = random.split(key)
|
|
1485
|
-
sample = random.t(
|
|
1486
|
-
key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
|
|
1400
|
+
sample = random.t(key=subkey, df=df, shape=jnp.shape(df), dtype=self.REAL)
|
|
1487
1401
|
out_of_bounds = jnp.logical_not(jnp.all(df > 0))
|
|
1488
1402
|
err |= (out_of_bounds * ERR)
|
|
1489
1403
|
return sample, key, err, params
|
|
@@ -1503,8 +1417,7 @@ class JaxRDDLCompiler:
|
|
|
1503
1417
|
mean, key, err1, params = jax_mean(x, params, key)
|
|
1504
1418
|
scale, key, err2, params = jax_scale(x, params, key)
|
|
1505
1419
|
key, subkey = random.split(key)
|
|
1506
|
-
Gumbel01 = random.gumbel(
|
|
1507
|
-
key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1420
|
+
Gumbel01 = random.gumbel(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1508
1421
|
sample = mean + scale * Gumbel01
|
|
1509
1422
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1510
1423
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
@@ -1525,8 +1438,7 @@ class JaxRDDLCompiler:
|
|
|
1525
1438
|
mean, key, err1, params = jax_mean(x, params, key)
|
|
1526
1439
|
scale, key, err2, params = jax_scale(x, params, key)
|
|
1527
1440
|
key, subkey = random.split(key)
|
|
1528
|
-
Laplace01 = random.laplace(
|
|
1529
|
-
key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1441
|
+
Laplace01 = random.laplace(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1530
1442
|
sample = mean + scale * Laplace01
|
|
1531
1443
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1532
1444
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
@@ -1547,8 +1459,7 @@ class JaxRDDLCompiler:
|
|
|
1547
1459
|
mean, key, err1, params = jax_mean(x, params, key)
|
|
1548
1460
|
scale, key, err2, params = jax_scale(x, params, key)
|
|
1549
1461
|
key, subkey = random.split(key)
|
|
1550
|
-
Cauchy01 = random.cauchy(
|
|
1551
|
-
key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1462
|
+
Cauchy01 = random.cauchy(key=subkey, shape=jnp.shape(mean), dtype=self.REAL)
|
|
1552
1463
|
sample = mean + scale * Cauchy01
|
|
1553
1464
|
out_of_bounds = jnp.logical_not(jnp.all(scale > 0))
|
|
1554
1465
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
@@ -1570,7 +1481,7 @@ class JaxRDDLCompiler:
|
|
|
1570
1481
|
scale, key, err2, params = jax_scale(x, params, key)
|
|
1571
1482
|
key, subkey = random.split(key)
|
|
1572
1483
|
U = random.uniform(key=subkey, shape=jnp.shape(scale), dtype=self.REAL)
|
|
1573
|
-
sample = jnp.log(1.0 - jnp.
|
|
1484
|
+
sample = jnp.log(1.0 - jnp.log1p(-U) / shape) / scale
|
|
1574
1485
|
out_of_bounds = jnp.logical_not(jnp.all((shape > 0) & (scale > 0)))
|
|
1575
1486
|
err = err1 | err2 | (out_of_bounds * ERR)
|
|
1576
1487
|
return sample, key, err, params
|
|
@@ -1631,9 +1542,9 @@ class JaxRDDLCompiler:
|
|
|
1631
1542
|
has_fluent_arg = any(self.traced.cached_is_fluent(arg)
|
|
1632
1543
|
for arg in ordered_args)
|
|
1633
1544
|
if self.compile_non_fluent_exact and not has_fluent_arg:
|
|
1634
|
-
discrete_op =
|
|
1545
|
+
discrete_op = self.EXACT_OPS['sampling']['Discrete']
|
|
1635
1546
|
else:
|
|
1636
|
-
discrete_op = self.
|
|
1547
|
+
discrete_op = self.OPS['sampling']['Discrete']
|
|
1637
1548
|
jax_op = discrete_op(expr.id, init_params)
|
|
1638
1549
|
|
|
1639
1550
|
# compile probability expressions
|
|
@@ -1672,9 +1583,9 @@ class JaxRDDLCompiler:
|
|
|
1672
1583
|
|
|
1673
1584
|
# if probabilities are non-fluent, then always sample exact
|
|
1674
1585
|
if self.compile_non_fluent_exact and not self.traced.cached_is_fluent(arg):
|
|
1675
|
-
discrete_op =
|
|
1586
|
+
discrete_op = self.EXACT_OPS['sampling']['Discrete']
|
|
1676
1587
|
else:
|
|
1677
|
-
discrete_op = self.
|
|
1588
|
+
discrete_op = self.OPS['sampling']['Discrete']
|
|
1678
1589
|
jax_op = discrete_op(expr.id, init_params)
|
|
1679
1590
|
|
|
1680
1591
|
# compile probability function
|
|
@@ -1716,8 +1627,7 @@ class JaxRDDLCompiler:
|
|
|
1716
1627
|
return self._jax_multinomial(expr, init_params)
|
|
1717
1628
|
else:
|
|
1718
1629
|
raise RDDLNotImplementedError(
|
|
1719
|
-
f'Distribution {name} is not supported.\n' +
|
|
1720
|
-
print_stack_trace(expr))
|
|
1630
|
+
f'Distribution {name} is not supported.\n' + print_stack_trace(expr))
|
|
1721
1631
|
|
|
1722
1632
|
def _jax_multivariate_normal(self, expr, init_params):
|
|
1723
1633
|
_, args = expr.args
|
|
@@ -1771,7 +1681,7 @@ class JaxRDDLCompiler:
|
|
|
1771
1681
|
|
|
1772
1682
|
# sample StudentT(0, 1, df) -- broadcast df to same shape as cov
|
|
1773
1683
|
sample_df = sample_df[..., jnp.newaxis, jnp.newaxis]
|
|
1774
|
-
sample_df = jnp.broadcast_to(sample_df, shape=
|
|
1684
|
+
sample_df = jnp.broadcast_to(sample_df, shape=jnp.shape(sample_mean) + (1,))
|
|
1775
1685
|
key, subkey = random.split(key)
|
|
1776
1686
|
Z = random.t(
|
|
1777
1687
|
key=subkey,
|
|
@@ -1826,7 +1736,7 @@ class JaxRDDLCompiler:
|
|
|
1826
1736
|
prob = jnp.asarray(prob, dtype=self.REAL)
|
|
1827
1737
|
key, subkey = random.split(key)
|
|
1828
1738
|
dist = tfp.distributions.Multinomial(total_count=trials, probs=prob)
|
|
1829
|
-
sample = dist.sample(seed=subkey)
|
|
1739
|
+
sample = jnp.asarray(dist.sample(seed=subkey), dtype=self.INT)
|
|
1830
1740
|
sample = jnp.moveaxis(sample, source=-1, destination=index)
|
|
1831
1741
|
out_of_bounds = jnp.logical_not(jnp.all(
|
|
1832
1742
|
(prob >= 0)
|