pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
pymc_extras/model/marginal/distributions.py
@@ -0,0 +1,276 @@
from collections.abc import Sequence

import numpy as np
import pytensor.tensor as pt

from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor import Variable
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
from pytensor.graph import Op, vectorize_graph
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.scan import map as scan_map
from pytensor.scan import scan
from pytensor.tensor import TensorVariable

from pymc_extras.distributions import DiscreteMarkovChain


class MarginalRV(OpFromGraph, MeasurableOp):
    """Base class for Marginalized RVs"""

    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
        self.dims_connections = dims_connections
        super().__init__(*args, **kwargs)

    @property
    def support_axes(self) -> tuple[tuple[int]]:
        """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
        marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
        support_axes_vars = []
        for dims_connection in self.dims_connections:
            ndim = len(dims_connection)
            marginalized_supp_axes = ndim - marginalized_ndim_supp
            support_axes_vars.append(
                tuple(
                    -i
                    for i, dim in enumerate(reversed(dims_connection), start=1)
                    if (dim is None or dim > marginalized_supp_axes)
                )
            )
        return tuple(support_axes_vars)


class MarginalFiniteDiscreteRV(MarginalRV):
    """Base class for Marginalized Finite Discrete RVs"""


class MarginalDiscreteMarkovChainRV(MarginalRV):
    """Base class for Marginalized Discrete Markov Chain RVs"""


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
    op = rv.owner.op
    dist_params = rv.owner.op.dist_params(rv.owner)
    if isinstance(op, Bernoulli):
        return (0, 1)
    elif isinstance(op, Categorical):
        [p_param] = dist_params
        [p_param_length] = constant_fold([p_param.shape[-1]])
        return tuple(range(p_param_length))
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(dist_params)
        return tuple(np.arange(lower, upper + 1))
    elif isinstance(op, DiscreteMarkovChain):
        P, *_ = dist_params
        return tuple(range(pt.get_vector_length(P[-1])))

    raise NotImplementedError(f"Cannot compute domain for op {op}")


def reduce_batch_dependent_logps(
    dependent_dims_connections: Sequence[tuple[int | None, ...]],
    dependent_ops: Sequence[Op],
    dependent_logps: Sequence[TensorVariable],
) -> TensorVariable:
    """Combine the logps of dependent RVs and align them with the marginalized logp.

    This requires reducing extra batch dims and transposing when they are not aligned.

       idx = pm.Bernoulli("idx", shape=(3, 2))  # 0, 1
       pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
       pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))

       marginalize(idx)

    The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)]
    which tells us we need to reduce the last axis of dep1 logp and the first of dep2 logp,
    as well as transpose the remaining axis of dep1 logp before adding the two element-wise.

    """
    from pymc_extras.model.marginal.graph_analysis import get_support_axes

    reduced_logps = []
    for dependent_op, dependent_logp, dependent_dims_connection in zip(
        dependent_ops, dependent_logps, dependent_dims_connections
    ):
        if dependent_logp.type.ndim > 0:
            # Find which support axes implied by the MarginalRV need to be reduced
            # Some may have already been reduced by the logp expression of the dependent RV (e.g., multivariate RVs)
            dep_supp_axes = get_support_axes(dependent_op)[0]

            # Dependent RV support axes are already collapsed in the logp, so we ignore them
            supp_axes = [
                -i
                for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
                if (dim is None and -i not in dep_supp_axes)
            ]
            dependent_logp = dependent_logp.sum(supp_axes)

            # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
            dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
            dependent_logp = dependent_logp.transpose(*dims_alignment)

        reduced_logps.append(dependent_logp)

    reduced_logp = pt.add(*reduced_logps)
    return reduced_logp


def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
    """Align the logp with the order specified in dims."""
    dims_alignment = [dim for dim in dims if dim is not None]
    return logp.transpose(*dims_alignment)


def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
    """Inline the inner graph (outputs) of an OpFromGraph Op.

    Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
    the inner graph.
    """
    return clone_replace(
        op.inner_outputs,
        replace=tuple(zip(op.inner_inputs, inputs)),
    )


DUMMY_ZERO = pt.constant(0, name="dummy_zero")


@_logprob.register(MarginalFiniteDiscreteRV)
def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs)

    # Obtain the joint_logp graph of the inner RV graph
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
    joint_logp = marginalized_logp + reduce_batch_dependent_logps(
        dependent_dims_connections=op.dims_connections,
        dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
        dependent_logps=[logps_dict[value] for value in values],
    )

    # Compute the joint_logp for all possible n values of the marginalized RV. We assume
    # each original dimension is independent so that it suffices to evaluate the graph
    # n times, once with each possible value of the marginalized RV replicated across
    # batched dimensions of the marginalized RV

    # PyMC does not allow RVs in the logp graph, even if we are just using the shape
    marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
    marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
    marginalized_rv_domain_tensor = pt.moveaxis(
        pt.full(
            (*marginalized_rv_shape, len(marginalized_rv_domain)),
            marginalized_rv_domain,
            dtype=marginalized_rv.dtype,
        ),
        -1,
        0,
    )

    try:
        joint_logps = vectorize_graph(
            joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
        )
    except Exception:
        # Fallback to Scan
        def logp_fn(marginalized_rv_const, *non_sequences):
            return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

        joint_logps, _ = scan_map(
            fn=logp_fn,
            sequences=marginalized_rv_domain_tensor,
            non_sequences=[*values, *inputs],
            mode=Mode().including("local_remove_check_parameter"),
        )

    joint_logp = pt.logsumexp(joint_logps, axis=0)

    # Align logp with non-collapsed batch dimensions of first RV
    joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
    dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
    return joint_logp, *dummy_logps


@_logprob.register(MarginalDiscreteMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
    chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs)

    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
    domain = pt.arange(P.shape[-1], dtype="int32")

    # Construct logp in two steps
    # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)

    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
    # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
    # PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step.
    chain_value = chain_rv.clone()
    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

    # Reduce and add the batch dims beyond the chain dimension
    reduced_logp_emissions = reduce_batch_dependent_logps(
        dependent_dims_connections=op.dims_connections,
        dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
        dependent_logps=[logp_emissions_dict[value] for value in values],
    )

    # Add a batch dimension for the domain of the chain
    chain_shape = constant_fold(tuple(chain_rv.shape))
    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})

    # Step 2: Compute the transition probabilities
    # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
    # We do it entirely in logs, though.

    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
    # under the initial distribution. This is robust to everything the user can throw at it.
    init_dist_value = init_dist_.type()
    logp_init_dist = logp(init_dist_, init_dist_value)
    # There is a degenerate batch dim for lags=1 (the only supported case),
    # that we have to work around, by expanding the batch value and then squeezing it out of the logp
    batch_logp_init_dist = vectorize_graph(
        logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
    ).squeeze(1)
    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]

    def step_alpha(logp_emission, log_alpha, log_P):
        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
        return logp_emission + step_log_prob

    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
    log_alpha_seq, _ = scan(
        step_alpha,
        non_sequences=[log_P],
        outputs_info=[log_alpha_init],
        # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
    )
    # Final logp is just the sum of the last scan state
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

    # Align logp with non-collapsed batch dimensions of first RV
    remaining_dims_first_emission = list(op.dims_connections[0])
    # The last dim of chain_rv was removed when computing the logp
    remaining_dims_first_emission.remove(chain_rv.type.ndim - 1)
    joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp)

    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
    # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
    dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
    return joint_logp, *dummy_logps
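The `marginal_hmm_logp` function above builds the log-space forward algorithm symbolically in PyTensor. The standalone NumPy sketch below (not part of the package diff; names, shapes, and numbers are illustrative) shows the same recursion, alpha_t = p(y_t | s_t) * sum over s_{t-1} of p(s_t | s_{t-1}) * alpha_{t-1}, carried out in logs:

import numpy as np
from scipy.special import logsumexp


def forward_loglik(log_P, log_init, log_emissions):
    """log p(observations) for a discrete HMM.

    log_P:         (S, S) log transition matrix, log_P[i, j] = log p(s_t = j | s_{t-1} = i)
    log_init:      (S,)   log prior over the initial state
    log_emissions: (T, S) log p(y_t | s_t = s) for every time step and state
    """
    # Initial alpha combines the state prior with the first emission (mirrors log_alpha_init above)
    log_alpha = log_init + log_emissions[0]
    for t in range(1, len(log_emissions)):
        # Same update as step_alpha: logsumexp over the previous state axis
        log_alpha = log_emissions[t] + logsumexp(log_alpha[:, None] + log_P, axis=0)
    # Marginalize the final state (mirrors pt.logsumexp(log_alpha_seq[-1], axis=0))
    return logsumexp(log_alpha)


# Two-state example with made-up numbers
log_P = np.log(np.array([[0.9, 0.1], [0.2, 0.8]]))
log_init = np.log(np.array([0.5, 0.5]))
log_emissions = np.log(np.array([[0.7, 0.3], [0.4, 0.6], [0.1, 0.9]]))
print(forward_loglik(log_P, log_init, log_emissions))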
pymc_extras/model/marginal/graph_analysis.py
@@ -0,0 +1,372 @@
import itertools

from collections.abc import Sequence
from itertools import zip_longest

from pymc import SymbolicRandomVariable
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, Variable, ancestors
from pytensor.graph.basic import io_toposort
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
from pytensor.tensor.type_other import NoneTypeT

from pymc_extras.model.marginal.distributions import MarginalRV


def static_shape_ancestors(vars):
    """Identify ancestor Shape Ops of static shapes (therefore constant in a valid graph)."""
    return [
        var
        for var in ancestors(vars)
        if (
            var.owner
            and isinstance(var.owner.op, Shape)
            # All static dims lengths of Shape input are known
            and None not in var.owner.inputs[0].type.shape
        )
    ]


def find_conditional_input_rvs(output_rvs, all_rvs):
    """Find conditionally independent input RVs."""
    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
    blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
    return [
        var
        for var in ancestors(output_rvs, blockers=blockers)
        if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
    ]


def is_conditional_dependent(
    dependent_rv: TensorVariable, dependable_rv: TensorVariable, all_rvs
) -> bool:
    """Check if dependent_rv is conditionally dependent on dependable_rv,
    given all conditionally independent all_rvs"""

    return dependable_rv in find_conditional_input_rvs((dependent_rv,), all_rvs)


def find_conditional_dependent_rvs(dependable_rv, all_rvs):
    """Find RVs that depend on dependable"""
    return [
        rv
        for rv in all_rvs
        if (rv is not dependable_rv and is_conditional_dependent(rv, dependable_rv, all_rvs))
    ]


def get_support_axes(op) -> tuple[tuple[int, ...], ...]:
    if isinstance(op, MarginalRV):
        return op.support_axes
    else:
        # For vanilla RVs, the support axes are the last ndim_supp
        return (tuple(range(-op.ndim_supp, 0)),)


def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]:
    """Find the output axis and dimensionality of the advanced indexing group (i.e., array indexing).

    There is a special case: when there are non-consecutive advanced indexing groups, the advanced indexing
    group is always moved to the front.

    See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    """
    adv_group_axis = None
    simple_group_after_adv = False
    for axis, idx in enumerate(idxs):
        if isinstance(idx.type, TensorType):
            if simple_group_after_adv:
                # Special non-consecutive case
                adv_group_axis = 0
                break
            elif adv_group_axis is None:
                adv_group_axis = axis
        elif adv_group_axis is not None:
            # Special non-consecutive case
            simple_group_after_adv = True

    adv_group_ndim = max(idx.type.ndim for idx in idxs if isinstance(idx.type, TensorType))
    return adv_group_axis, adv_group_ndim


DIMS = tuple[int | None, ...]
VAR_DIMS = dict[Variable, DIMS]


def _broadcast_dims(
    inputs_dims: Sequence[DIMS],
) -> DIMS:
    output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0)

    # Add missing dims
    inputs_dims = [
        (None,) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims
    ]

    # Find which known dims show in the output, while checking no mixing
    output_dims = []
    for inputs_dim in zip(*inputs_dims):
        output_dim = None
        for input_dim in inputs_dim:
            if input_dim is None:
                continue
            if output_dim is not None and output_dim != input_dim:
                raise ValueError("Different known dimensions mixed via broadcasting")
            output_dim = input_dim
        output_dims.append(output_dim)

    # Check for duplicates
    known_dims = [dim for dim in output_dims if dim is not None]
    if len(known_dims) > len(set(known_dims)):
        raise ValueError("Same known dimension used in different axis after broadcasting")

    return tuple(output_dims)


def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR_DIMS:
    for node in io_toposort(input_vars, output_vars):
        inputs_dims = [
            var_dims.get(inp, ((None,) * inp.type.ndim) if hasattr(inp.type, "ndim") else ())
            for inp in node.inputs
        ]

        if all(dim is None for input_dims in inputs_dims for dim in input_dims):
            # None of the inputs are related to the batch_axes of the input_vars
            continue

        elif isinstance(node.op, DimShuffle):
            [input_dims] = inputs_dims
            output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
            var_dims[node.outputs[0]] = output_dims

        elif isinstance(node.op, MarginalRV) or (
            isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None
        ):
            # MarginalRV and SymbolicRandomVariables without signature are a wild-card,
            # so we need to introspect the inner graph.
            op = node.op
            inner_inputs = op.inner_inputs
            inner_outputs = op.inner_outputs

            inner_var_dims = _subgraph_batch_dim_connection(
                dict(zip(inner_inputs, inputs_dims)), inner_inputs, inner_outputs
            )

            support_axes = iter(get_support_axes(op))
            if isinstance(op, MarginalRV):
                # The first output is the marginalized variable for which we don't compute support axes
                support_axes = itertools.chain(((),), support_axes)
            for i, (out, inner_out) in enumerate(zip(node.outputs, inner_outputs)):
                if not isinstance(out.type, TensorType):
                    continue
                support_axes_out = next(support_axes)

                if inner_out in inner_var_dims:
                    out_dims = inner_var_dims[inner_out]
                    if any(
                        dim is not None for dim in (out_dims[axis] for axis in support_axes_out)
                    ):
                        raise ValueError(f"Known dim corresponds to core dimension of {node.op}")
                    var_dims[out] = out_dims

        elif isinstance(node.op, Elemwise | Blockwise | RandomVariable | SymbolicRandomVariable):
            # NOTE: User-provided CustomDist may not respect core dimensions on the left.

            if isinstance(node.op, Elemwise):
                op_batch_ndim = node.outputs[0].type.ndim
            else:
                op_batch_ndim = node.op.batch_ndim(node)

            if isinstance(node.op, SymbolicRandomVariable):
                # SymbolicRandomVariables don't have explicit expand_dims unlike the other Ops considered in this
                # branch, so we pad the parameter dims up to the op batch ndim first
                [_, _, param_idxs], _ = node.op.get_input_output_type_idxs(
                    node.op.extended_signature
                )
                for param_idx, param_core_ndim in zip(param_idxs, node.op.ndims_params):
                    param_dims = inputs_dims[param_idx]
                    missing_ndim = op_batch_ndim - (len(param_dims) - param_core_ndim)
                    inputs_dims[param_idx] = (None,) * missing_ndim + param_dims

            if any(
                dim is not None for input_dim in inputs_dims for dim in input_dim[op_batch_ndim:]
            ):
                raise ValueError(
                    f"Use of known dimensions as core dimensions of op {node.op} not supported."
                )

            batch_dims = _broadcast_dims(
                tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims)
            )
            for out in node.outputs:
                if isinstance(out.type, TensorType):
                    core_ndim = out.type.ndim - op_batch_ndim
                    output_dims = batch_dims + (None,) * core_ndim
                    var_dims[out] = output_dims

        elif isinstance(node.op, CAReduce):
            [input_dims] = inputs_dims

            axes = node.op.axis
            if isinstance(axes, int):
                axes = (axes,)
            elif axes is None:
                axes = tuple(range(node.inputs[0].type.ndim))

            if any(input_dims[axis] for axis in axes):
                raise ValueError(
                    f"Use of known dimensions as reduced dimensions of op {node.op} not supported."
                )

            output_dims = [dims for i, dims in enumerate(input_dims) if i not in axes]
            var_dims[node.outputs[0]] = tuple(output_dims)

        elif isinstance(node.op, Subtensor):
            value_dims, *keys_dims = inputs_dims
            # Dims in basic indexing must belong to the value variable, since indexing keys are always scalar
            assert not any(dim is None for dim in keys_dims)
            keys = get_idx_list(node.inputs, node.op.idx_list)

            output_dims = []
            for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)):
                if idx == slice(None):
                    # Dim is kept
                    output_dims.append(value_dim)
                elif value_dim is not None:
                    raise ValueError(
                        "Partial slicing or indexing of known dimensions not supported."
                    )
                elif isinstance(idx, slice):
                    # Unknown dimensions kept by partial slice.
                    output_dims.append(None)

            var_dims[node.outputs[0]] = tuple(output_dims)

        elif isinstance(node.op, AdvancedSubtensor):
            # AdvancedSubtensor dimensions can show up in both the indexed variable and the indexing variables
            value, *keys = node.inputs
            value_dims, *keys_dims = inputs_dims

            # Just to stay sane, we forbid any boolean indexing...
            if any(isinstance(idx.type, TensorType) and idx.type.dtype == "bool" for idx in keys):
                raise NotImplementedError(
                    f"Array indexing with boolean variables in node {node} not supported."
                )

            if any(dim is not None for dim in value_dims) and any(
                dim is not None for key_dims in keys_dims for dim in key_dims
            ):
                # Both indexed variable and indexing variables have known dimensions
                # I am too lazy to think through these, so we raise for now.
                raise NotImplementedError(
                    f"Simultaneous use of known dimensions in indexed and indexing variables in node {node} not supported."
                )

            adv_group_axis, adv_group_ndim = _advanced_indexing_axis_and_ndim(keys)

            if any(dim is not None for dim in value_dims):
                # Indexed variable has known dimensions

                if any(isinstance(idx.type, NoneTypeT) for idx in keys):
                    # Corresponds to an expand_dims, for now not supported
                    raise NotImplementedError(
                        f"Advanced indexing in node {node} which introduces new axis is not supported."
                    )

                non_adv_dims = []
                for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)):
                    if is_full_slice(idx):
                        non_adv_dims.append(value_dim)
                    elif value_dim is not None:
                        # We are trying to partially slice or index a known dimension
                        raise ValueError(
                            "Partial slicing or advanced integer indexing of known dimensions not supported."
                        )
                    elif isinstance(idx, slice):
                        # Unknown dimensions kept by partial slice.
                        non_adv_dims.append(None)

                # Insert unknown dimensions corresponding to advanced indexing
                output_dims = tuple(
                    non_adv_dims[:adv_group_axis]
                    + [None] * adv_group_ndim
                    + non_adv_dims[adv_group_axis:]
                )

            else:
                # Indexing keys have known dimensions.
                # Only array indices can have dimensions, the rest are just slices or newaxis

                # Advanced indexing variables broadcast together, so we apply same rules as in Elemwise
                adv_dims = _broadcast_dims(keys_dims)

                start_non_adv_dims = (None,) * adv_group_axis
                end_non_adv_dims = (None,) * (
                    node.outputs[0].type.ndim - adv_group_axis - adv_group_ndim
                )
                output_dims = start_non_adv_dims + adv_dims + end_non_adv_dims

            var_dims[node.outputs[0]] = output_dims

        else:
            raise NotImplementedError(f"Marginalization through operation {node} not supported.")

    return var_dims


def subgraph_batch_dim_connection(input_var, output_vars) -> list[DIMS]:
    """Identify how the batch dims of input map to the batch dimensions of the output_rvs.

    Example:
    -------
    In the example below `idx` has two batch dimensions (indexed 0, 1 from left to right).
    The two uncommented dependent variables each have 2 batch dimensions where each entry
    results from a mapping of a single entry from one of these batch dimensions.

    This mapping is transposed in the case of the first dependent variable, and shows up in
    the same order for the second dependent variable. Each of the variables has a further
    batch dimension encoded as `None`.

    The commented out third dependent variable combines information from the batch dimensions
    of `idx` via the `sum` operation. A `ValueError` would be raised if we requested the
    connection of batch dims.

    .. code-block:: python

        import pymc as pm

        idx = pm.Bernoulli.dist(shape=(3, 2))
        dep1 = pm.Normal.dist(mu=idx.T[..., None] * 2, shape=(3, 2, 5))
        dep2 = pm.Normal.dist(mu=idx * 2, shape=(7, 2, 3))
        # dep3 = pm.Normal.dist(mu=idx.sum())  # Would raise if requested

        print(subgraph_batch_dim_connection(idx, [dep1, dep2]))
        # [(1, 0, None), (None, 0, 1)]

    Returns:
    -------
    list of tuples
        Each tuple corresponds to the batch dimensions of the output_rv in the order they are found in the output.
        None is used to indicate a batch dimension that is not mapped from the input.

    Raises:
    ------
    ValueError
        If input batch dimensions are mixed in the graph leading to output_vars.

    NotImplementedError
        If a variable related to marginalized batch_dims is used in an operation that is not yet supported.
    """
    var_dims = {input_var: tuple(range(input_var.type.ndim))}
    var_dims = _subgraph_batch_dim_connection(var_dims, [input_var], output_vars)
    ret = []
    for output_var in output_vars:
        output_dims = var_dims.get(output_var, (None,) * output_var.type.ndim)
        assert len(output_dims) == output_var.type.ndim
        ret.append(output_dims)
    return ret
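The sketch below (not part of the package diff) exercises `subgraph_batch_dim_connection` along the lines of its docstring example. It assumes pymc and pymc-extras 0.2.0 are installed; `p=0.5` is added here because `pm.Bernoulli.dist` requires a probability parameter that the docstring omits.

import pymc as pm

from pymc_extras.model.marginal.graph_analysis import subgraph_batch_dim_connection

# idx has two batch dimensions (0 and 1, left to right)
idx = pm.Bernoulli.dist(p=0.5, shape=(3, 2))
# dep1 sees those dims transposed plus an extra unmapped trailing dim
dep1 = pm.Normal.dist(mu=idx.T[..., None] * 2, shape=(3, 2, 5))
# dep2 sees them in order, with an extra unmapped leading dim
dep2 = pm.Normal.dist(mu=idx * 2, shape=(7, 2, 3))

print(subgraph_batch_dim_connection(idx, [dep1, dep2]))
# Expected, per the docstring: [(1, 0, None), (None, 0, 1)]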