pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
pymc_extras/model/marginal/distributions.py
@@ -0,0 +1,276 @@
+ from collections.abc import Sequence
+
+ import numpy as np
+ import pytensor.tensor as pt
+
+ from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+ from pymc.logprob.abstract import MeasurableOp, _logprob
+ from pymc.logprob.basic import conditional_logp, logp
+ from pymc.pytensorf import constant_fold
+ from pytensor import Variable
+ from pytensor.compile.builders import OpFromGraph
+ from pytensor.compile.mode import Mode
+ from pytensor.graph import Op, vectorize_graph
+ from pytensor.graph.replace import clone_replace, graph_replace
+ from pytensor.scan import map as scan_map
+ from pytensor.scan import scan
+ from pytensor.tensor import TensorVariable
+
+ from pymc_extras.distributions import DiscreteMarkovChain
+
+
+ class MarginalRV(OpFromGraph, MeasurableOp):
+     """Base class for Marginalized RVs"""
+
+     def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+         self.dims_connections = dims_connections
+         super().__init__(*args, **kwargs)
+
+     @property
+     def support_axes(self) -> tuple[tuple[int]]:
+         """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
+         marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
+         support_axes_vars = []
+         for dims_connection in self.dims_connections:
+             ndim = len(dims_connection)
+             marginalized_supp_axes = ndim - marginalized_ndim_supp
+             support_axes_vars.append(
+                 tuple(
+                     -i
+                     for i, dim in enumerate(reversed(dims_connection), start=1)
+                     if (dim is None or dim > marginalized_supp_axes)
+                 )
+             )
+         return tuple(support_axes_vars)
+
+
+ class MarginalFiniteDiscreteRV(MarginalRV):
+     """Base class for Marginalized Finite Discrete RVs"""
+
+
+ class MarginalDiscreteMarkovChainRV(MarginalRV):
+     """Base class for Marginalized Discrete Markov Chain RVs"""
+
+
+ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
+     op = rv.owner.op
+     dist_params = rv.owner.op.dist_params(rv.owner)
+     if isinstance(op, Bernoulli):
+         return (0, 1)
+     elif isinstance(op, Categorical):
+         [p_param] = dist_params
+         [p_param_length] = constant_fold([p_param.shape[-1]])
+         return tuple(range(p_param_length))
+     elif isinstance(op, DiscreteUniform):
+         lower, upper = constant_fold(dist_params)
+         return tuple(np.arange(lower, upper + 1))
+     elif isinstance(op, DiscreteMarkovChain):
+         P, *_ = dist_params
+         return tuple(range(pt.get_vector_length(P[-1])))
+
+     raise NotImplementedError(f"Cannot compute domain for op {op}")
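
Note: the helper above enumerates the finite support of the marginalized variable. A minimal, purely illustrative sketch of the enumerated domains (plain NumPy, hypothetical bounds, not part of the diff):

    import numpy as np

    # Mirrors the DiscreteUniform branch for hypothetical bounds lower=-1, upper=2
    lower, upper = -1, 2
    assert tuple(np.arange(lower, upper + 1)) == (-1, 0, 1, 2)
    # Mirrors the Categorical branch for a hypothetical probability vector of length 3
    assert tuple(range(3)) == (0, 1, 2)
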
+
+
+ def reduce_batch_dependent_logps(
+     dependent_dims_connections: Sequence[tuple[int | None, ...]],
+     dependent_ops: Sequence[Op],
+     dependent_logps: Sequence[TensorVariable],
+ ) -> TensorVariable:
+     """Combine the logps of dependent RVs and align them with the marginalized logp.
+
+     This requires reducing extra batch dims and transposing when they are not aligned.
+
+         idx = pm.Bernoulli("idx", shape=(3, 2))  # 0, 1
+         pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(2, 3, 5))
+         pm.Normal("dep2", mu=idx * 2, shape=(7, 3, 2))
+
+         marginalize(idx)
+
+     The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)],
+     which tells us we need to reduce the last axis of the dep1 logp and the first axis of
+     the dep2 logp, as well as transpose the remaining axes of the dep1 logp, before adding
+     the two element-wise.
+     """
+     from pymc_extras.model.marginal.graph_analysis import get_support_axes
+
+     reduced_logps = []
+     for dependent_op, dependent_logp, dependent_dims_connection in zip(
+         dependent_ops, dependent_logps, dependent_dims_connections
+     ):
+         if dependent_logp.type.ndim > 0:
+             # Find which support axes implied by the MarginalRV need to be reduced
+             # Some may have already been reduced by the logp expression of the dependent RV (e.g., multivariate RVs)
+             dep_supp_axes = get_support_axes(dependent_op)[0]
+
+             # Dependent RV support axes are already collapsed in the logp, so we ignore them
+             supp_axes = [
+                 -i
+                 for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
+                 if (dim is None and -i not in dep_supp_axes)
+             ]
+             dependent_logp = dependent_logp.sum(supp_axes)
+
+             # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
+             dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
+             dependent_logp = dependent_logp.transpose(*dims_alignment)
+
+         reduced_logps.append(dependent_logp)
+
+     reduced_logp = pt.add(*reduced_logps)
+     return reduced_logp
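
A NumPy-only sanity check of the reduction and alignment described in the docstring above, using the hypothetical shapes from that example (illustrative sketch, not part of the package):

    import numpy as np

    rng = np.random.default_rng(0)
    # dep1 logp has shape (2, 3, 5) with dims_connection (1, 0, None);
    # dep2 logp has shape (7, 3, 2) with dims_connection (None, 0, 1).
    dep1_logp = rng.normal(size=(2, 3, 5))
    dep2_logp = rng.normal(size=(7, 3, 2))

    dep1_aligned = dep1_logp.sum(axis=-1).transpose(1, 0)  # reduce the unknown last axis, then transpose -> (3, 2)
    dep2_aligned = dep2_logp.sum(axis=0)                   # reduce the unknown first axis -> (3, 2)
    combined = dep1_aligned + dep2_aligned                 # element-wise, aligned with idx's (3, 2) batch shape
    assert combined.shape == (3, 2)
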
+
+
+ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
+     """Align the logp with the order specified in dims."""
+     dims_alignment = [dim for dim in dims if dim is not None]
+     return logp.transpose(*dims_alignment)
+
+
+ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
+     """Inline the inner graph (outputs) of an OpFromGraph Op.
+
+     Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
+     the inner graph.
+     """
+     return clone_replace(
+         op.inner_outputs,
+         replace=tuple(zip(op.inner_inputs, inputs)),
+     )
+
+
+ DUMMY_ZERO = pt.constant(0, name="dummy_zero")
+
+
+ @_logprob.register(MarginalFiniteDiscreteRV)
+ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
+     # Clone the inner RV graph of the Marginalized RV
+     marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs)
+
+     # Obtain the joint_logp graph of the inner RV graph
+     inner_rv_values = dict(zip(inner_rvs, values))
+     marginalized_vv = marginalized_rv.clone()
+     rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
+     logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
+
+     # Reduce logp dimensions corresponding to broadcasted variables
+     marginalized_logp = logps_dict.pop(marginalized_vv)
+     joint_logp = marginalized_logp + reduce_batch_dependent_logps(
+         dependent_dims_connections=op.dims_connections,
+         dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
+         dependent_logps=[logps_dict[value] for value in values],
+     )
+
+     # Compute the joint_logp for all n possible values of the marginalized RV. We assume
+     # each original dimension is independent so that it suffices to evaluate the graph
+     # n times, once with each possible value of the marginalized RV replicated across
+     # the batched dimensions of the marginalized RV
+
+     # PyMC does not allow RVs in the logp graph, even if we are just using the shape
+     marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
+     marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
+     marginalized_rv_domain_tensor = pt.moveaxis(
+         pt.full(
+             (*marginalized_rv_shape, len(marginalized_rv_domain)),
+             marginalized_rv_domain,
+             dtype=marginalized_rv.dtype,
+         ),
+         -1,
+         0,
+     )
+
+     try:
+         joint_logps = vectorize_graph(
+             joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
+         )
+     except Exception:
+         # Fallback to Scan
+         def logp_fn(marginalized_rv_const, *non_sequences):
+             return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})
+
+         joint_logps, _ = scan_map(
+             fn=logp_fn,
+             sequences=marginalized_rv_domain_tensor,
+             non_sequences=[*values, *inputs],
+             mode=Mode().including("local_remove_check_parameter"),
+         )
+
+     joint_logp = pt.logsumexp(joint_logps, axis=0)
+
+     # Align logp with non-collapsed batch dimensions of first RV
+     joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
+
+     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
+     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
+     return joint_logp, *dummy_logps
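
The registered logp above enumerates the domain of the marginalized variable, evaluates the joint logp once per value, and collapses the enumerated axis with a logsumexp. A minimal numeric sketch of that identity for a hypothetical two-component model (plain NumPy/SciPy, independent of the PyTensor graph machinery):

    import numpy as np
    from scipy.special import logsumexp

    # Hypothetical model: idx ~ Bernoulli(0.3), y | idx ~ Normal(mu=[0.0, 1.0][idx], sigma=1)
    y = 0.2
    logp_idx = np.log([0.7, 0.3])                                    # logp of idx over its domain (0, 1)
    logp_y_given_idx = -0.5 * (y - np.array([0.0, 1.0])) ** 2 - 0.5 * np.log(2 * np.pi)
    logp_y = logsumexp(logp_idx + logp_y_given_idx)                  # marginal logp via logsumexp over the domain
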
+
+
+ @_logprob.register(MarginalDiscreteMarkovChainRV)
+ def marginal_hmm_logp(op, values, *inputs, **kwargs):
+     chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs)
+
+     P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
+     domain = pt.arange(P.shape[-1], dtype="int32")
+
+     # Construct logp in two steps
+     # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)
+
+     # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
+     # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
+     # PyMC will detect a random variable in the logp graph (init_dist_) that isn't relevant at this step.
+     chain_value = chain_rv.clone()
+     dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
+     logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))
+
+     # Reduce and add the batch dims beyond the chain dimension
+     reduced_logp_emissions = reduce_batch_dependent_logps(
+         dependent_dims_connections=op.dims_connections,
+         dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
+         dependent_logps=[logp_emissions_dict[value] for value in values],
+     )
+
+     # Add a batch dimension for the domain of the chain
+     chain_shape = constant_fold(tuple(chain_rv.shape))
+     batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
+     batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})
+
+     # Step 2: Compute the transition probabilities
+     # This is the "forward algorithm": alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
+     # We do it entirely in logs, though.
+
+     # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
+     # under the initial distribution. This is robust to everything the user can throw at it.
+     init_dist_value = init_dist_.type()
+     logp_init_dist = logp(init_dist_, init_dist_value)
+     # There is a degenerate batch dim for lags=1 (the only supported case) that we have to work around
+     # by expanding the batch value and then squeezing it out of the logp
+     batch_logp_init_dist = vectorize_graph(
+         logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
+     ).squeeze(1)
+     log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]
+
+     def step_alpha(logp_emission, log_alpha, log_P):
+         step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
+         return logp_emission + step_log_prob
+
+     P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
+     log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
+     log_alpha_seq, _ = scan(
+         step_alpha,
+         non_sequences=[log_P],
+         outputs_info=[log_alpha_init],
+         # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
+         sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
+     )
+     # Final logp is just the logsumexp over the states of the last scan step
+     joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
+
+     # Align logp with non-collapsed batch dimensions of first RV
+     remaining_dims_first_emission = list(op.dims_connections[0])
+     # The last dim of chain_rv was removed when computing the logp
+     remaining_dims_first_emission.remove(chain_rv.type.ndim - 1)
+     joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp)
+
+     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
+     # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
+     return joint_logp, *dummy_logps
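
For reference, the log-space forward recursion implemented by step_alpha above, written as a standalone NumPy sketch for a hypothetical 2-state chain (illustrative numbers only, not part of the package):

    import numpy as np
    from scipy.special import logsumexp

    log_pi = np.log([0.6, 0.4])                  # log initial state probabilities
    log_P = np.log([[0.9, 0.1],
                    [0.2, 0.8]])                 # log transition matrix, log_P[s_prev, s_next]
    log_emission = np.log([[0.5, 0.7, 0.1],
                           [0.3, 0.2, 0.6]])     # log p(y_t | state), shape (n_states, n_steps)

    log_alpha = log_pi + log_emission[:, 0]
    for t in range(1, log_emission.shape[1]):
        # alpha_t(s) = p(y_t | s) * sum_{s'} P[s', s] * alpha_{t-1}(s'), computed in log space
        log_alpha = log_emission[:, t] + logsumexp(log_alpha[:, None] + log_P, axis=0)

    log_likelihood = logsumexp(log_alpha)        # marginal logp of the emission sequence
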
pymc_extras/model/marginal/graph_analysis.py
@@ -0,0 +1,372 @@
+ import itertools
+
+ from collections.abc import Sequence
+ from itertools import zip_longest
+
+ from pymc import SymbolicRandomVariable
+ from pytensor.compile import SharedVariable
+ from pytensor.graph import Constant, Variable, ancestors
+ from pytensor.graph.basic import io_toposort
+ from pytensor.tensor import TensorType, TensorVariable
+ from pytensor.tensor.blockwise import Blockwise
+ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
+ from pytensor.tensor.random.op import RandomVariable
+ from pytensor.tensor.rewriting.subtensor import is_full_slice
+ from pytensor.tensor.shape import Shape
+ from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
+ from pytensor.tensor.type_other import NoneTypeT
+
+ from pymc_extras.model.marginal.distributions import MarginalRV
+
+
+ def static_shape_ancestors(vars):
+     """Identify ancestor Shape Ops with static shapes (therefore constant in a valid graph)."""
+     return [
+         var
+         for var in ancestors(vars)
+         if (
+             var.owner
+             and isinstance(var.owner.op, Shape)
+             # All static dim lengths of the Shape input are known
+             and None not in var.owner.inputs[0].type.shape
+         )
+     ]
+
+
+ def find_conditional_input_rvs(output_rvs, all_rvs):
+     """Find conditionally independent input RVs."""
+     blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
+     blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
+     return [
+         var
+         for var in ancestors(output_rvs, blockers=blockers)
+         if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
+     ]
+
+
+ def is_conditional_dependent(
+     dependent_rv: TensorVariable, dependable_rv: TensorVariable, all_rvs
+ ) -> bool:
+     """Check if dependent_rv is conditionally dependent on dependable_rv,
+     given all conditionally independent all_rvs"""
+
+     return dependable_rv in find_conditional_input_rvs((dependent_rv,), all_rvs)
+
+
+ def find_conditional_dependent_rvs(dependable_rv, all_rvs):
+     """Find RVs that depend on dependable_rv"""
+     return [
+         rv
+         for rv in all_rvs
+         if (rv is not dependable_rv and is_conditional_dependent(rv, dependable_rv, all_rvs))
+     ]
+
+
+ def get_support_axes(op) -> tuple[tuple[int, ...], ...]:
+     if isinstance(op, MarginalRV):
+         return op.support_axes
+     else:
+         # For vanilla RVs, the support axes are the last ndim_supp axes
+         return (tuple(range(-op.ndim_supp, 0)),)
+
+
+ def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]:
+     """Find the output axis and dimensionality of the advanced indexing group (i.e., array indexing).
+
+     There is a special case: when there are non-consecutive advanced indexing groups, the advanced
+     indexing group is always moved to the front.
+
+     See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
+     """
+     adv_group_axis = None
+     simple_group_after_adv = False
+     for axis, idx in enumerate(idxs):
+         if isinstance(idx.type, TensorType):
+             if simple_group_after_adv:
+                 # Special non-consecutive case
+                 adv_group_axis = 0
+                 break
+             elif adv_group_axis is None:
+                 adv_group_axis = axis
+         elif adv_group_axis is not None:
+             # Special non-consecutive case
+             simple_group_after_adv = True
+
+     adv_group_ndim = max(idx.type.ndim for idx in idxs if isinstance(idx.type, TensorType))
+     return adv_group_axis, adv_group_ndim
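
The special case handled above follows NumPy's rule that a non-consecutive advanced indexing group is moved to the front of the result. A short NumPy illustration (shapes chosen arbitrarily):

    import numpy as np

    x = np.zeros((2, 3, 4, 5))
    consecutive = x[:, [0, 1], [0, 1], :]      # adjacent advanced indices stay in place -> shape (2, 2, 5)
    non_consecutive = x[:, [0, 1], :, [0, 1]]  # split advanced indices move to the front -> shape (2, 2, 4)
    assert consecutive.shape == (2, 2, 5)
    assert non_consecutive.shape == (2, 2, 4)
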
+
+
+ DIMS = tuple[int | None, ...]
+ VAR_DIMS = dict[Variable, DIMS]
+
+
+ def _broadcast_dims(
+     inputs_dims: Sequence[DIMS],
+ ) -> DIMS:
+     output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0)
+
+     # Add missing dims
+     inputs_dims = [
+         (None,) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims
+     ]
+
+     # Find which known dims show up in the output, while checking that none are mixed
+     output_dims = []
+     for inputs_dim in zip(*inputs_dims):
+         output_dim = None
+         for input_dim in inputs_dim:
+             if input_dim is None:
+                 continue
+             if output_dim is not None and output_dim != input_dim:
+                 raise ValueError("Different known dimensions mixed via broadcasting")
+             output_dim = input_dim
+         output_dims.append(output_dim)
+
+     # Check for duplicates
+     known_dims = [dim for dim in output_dims if dim is not None]
+     if len(known_dims) > len(set(known_dims)):
+         raise ValueError("Same known dimension used in different axis after broadcasting")
+
+     return tuple(output_dims)
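
A couple of hypothetical calls to the helper above, showing how known batch dims propagate through broadcasting (assumes the wheel is installed so the private helper can be imported):

    from pymc_extras.model.marginal.graph_analysis import _broadcast_dims

    assert _broadcast_dims([(0, 1), (None, 1)]) == (0, 1)
    assert _broadcast_dims([(None, 0, None), (1, None, None)]) == (1, 0, None)
    # Mixing two different known dims on the same broadcast axis raises ValueError:
    # _broadcast_dims([(0,), (1,)])
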
+
+
+ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR_DIMS:
+     for node in io_toposort(input_vars, output_vars):
+         inputs_dims = [
+             var_dims.get(inp, ((None,) * inp.type.ndim) if hasattr(inp.type, "ndim") else ())
+             for inp in node.inputs
+         ]
+
+         if all(dim is None for input_dims in inputs_dims for dim in input_dims):
+             # None of the inputs are related to the batch_axes of the input_vars
+             continue
+
+         elif isinstance(node.op, DimShuffle):
+             [input_dims] = inputs_dims
+             output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
+             var_dims[node.outputs[0]] = output_dims
+
+         elif isinstance(node.op, MarginalRV) or (
+             isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None
+         ):
+             # MarginalRVs and SymbolicRandomVariables without a signature are wild-cards,
+             # so we need to introspect the inner graph.
+             op = node.op
+             inner_inputs = op.inner_inputs
+             inner_outputs = op.inner_outputs
+
+             inner_var_dims = _subgraph_batch_dim_connection(
+                 dict(zip(inner_inputs, inputs_dims)), inner_inputs, inner_outputs
+             )
+
+             support_axes = iter(get_support_axes(op))
+             if isinstance(op, MarginalRV):
+                 # The first output is the marginalized variable, for which we don't compute support axes
+                 support_axes = itertools.chain(((),), support_axes)
+             for i, (out, inner_out) in enumerate(zip(node.outputs, inner_outputs)):
+                 if not isinstance(out.type, TensorType):
+                     continue
+                 support_axes_out = next(support_axes)
+
+                 if inner_out in inner_var_dims:
+                     out_dims = inner_var_dims[inner_out]
+                     if any(
+                         dim is not None for dim in (out_dims[axis] for axis in support_axes_out)
+                     ):
+                         raise ValueError(f"Known dim corresponds to core dimension of {node.op}")
+                     var_dims[out] = out_dims
+
+         elif isinstance(node.op, Elemwise | Blockwise | RandomVariable | SymbolicRandomVariable):
+             # NOTE: User-provided CustomDist may not respect core dimensions on the left.
+
+             if isinstance(node.op, Elemwise):
+                 op_batch_ndim = node.outputs[0].type.ndim
+             else:
+                 op_batch_ndim = node.op.batch_ndim(node)
+
+             if isinstance(node.op, SymbolicRandomVariable):
+                 # SymbolicRandomVariables don't have explicit expand_dims, unlike the other Ops considered in this branch
+                 [_, _, param_idxs], _ = node.op.get_input_output_type_idxs(
+                     node.op.extended_signature
+                 )
+                 for param_idx, param_core_ndim in zip(param_idxs, node.op.ndims_params):
+                     param_dims = inputs_dims[param_idx]
+                     missing_ndim = op_batch_ndim - (len(param_dims) - param_core_ndim)
+                     inputs_dims[param_idx] = (None,) * missing_ndim + param_dims
+
+             if any(
+                 dim is not None for input_dim in inputs_dims for dim in input_dim[op_batch_ndim:]
+             ):
+                 raise ValueError(
+                     f"Use of known dimensions as core dimensions of op {node.op} not supported."
+                 )
+
+             batch_dims = _broadcast_dims(
+                 tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims)
+             )
+             for out in node.outputs:
+                 if isinstance(out.type, TensorType):
+                     core_ndim = out.type.ndim - op_batch_ndim
+                     output_dims = batch_dims + (None,) * core_ndim
+                     var_dims[out] = output_dims
+
+         elif isinstance(node.op, CAReduce):
+             [input_dims] = inputs_dims
+
+             axes = node.op.axis
+             if isinstance(axes, int):
+                 axes = (axes,)
+             elif axes is None:
+                 axes = tuple(range(node.inputs[0].type.ndim))
+
+             if any(input_dims[axis] for axis in axes):
+                 raise ValueError(
+                     f"Use of known dimensions as reduced dimensions of op {node.op} not supported."
+                 )
+
+             output_dims = [dims for i, dims in enumerate(input_dims) if i not in axes]
+             var_dims[node.outputs[0]] = tuple(output_dims)
+
+         elif isinstance(node.op, Subtensor):
+             value_dims, *keys_dims = inputs_dims
+             # Dims in basic indexing must belong to the value variable, since indexing keys are always scalar
+             assert not any(dim is None for dim in keys_dims)
+             keys = get_idx_list(node.inputs, node.op.idx_list)
+
+             output_dims = []
+             for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)):
+                 if idx == slice(None):
+                     # Dim is kept
+                     output_dims.append(value_dim)
+                 elif value_dim is not None:
+                     raise ValueError(
+                         "Partial slicing or indexing of known dimensions not supported."
+                     )
+                 elif isinstance(idx, slice):
+                     # Unknown dimensions kept by partial slice.
+                     output_dims.append(None)
+
+             var_dims[node.outputs[0]] = tuple(output_dims)
+
+         elif isinstance(node.op, AdvancedSubtensor):
+             # With AdvancedSubtensor, known dimensions can show up in both the indexed variable and the indexing variables
+             value, *keys = node.inputs
+             value_dims, *keys_dims = inputs_dims
+
+             # Just to stay sane, we forbid any boolean indexing...
+             if any(isinstance(idx.type, TensorType) and idx.type.dtype == "bool" for idx in keys):
+                 raise NotImplementedError(
+                     f"Array indexing with boolean variables in node {node} not supported."
+                 )
+
+             if any(dim is not None for dim in value_dims) and any(
+                 dim is not None for key_dims in keys_dims for dim in key_dims
+             ):
+                 # Both the indexed variable and the indexing variables have known dimensions.
+                 # I am too lazy to think through these, so we raise for now.
+                 raise NotImplementedError(
+                     f"Simultaneous use of known dimensions in indexed and indexing variables in node {node} not supported."
+                 )
+
+             adv_group_axis, adv_group_ndim = _advanced_indexing_axis_and_ndim(keys)
+
+             if any(dim is not None for dim in value_dims):
+                 # Indexed variable has known dimensions
+
+                 if any(isinstance(idx.type, NoneTypeT) for idx in keys):
+                     # Corresponds to an expand_dims, for now not supported
+                     raise NotImplementedError(
+                         f"Advanced indexing in node {node} which introduces new axis is not supported."
+                     )
+
+                 non_adv_dims = []
+                 for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)):
+                     if is_full_slice(idx):
+                         non_adv_dims.append(value_dim)
+                     elif value_dim is not None:
+                         # We are trying to partially slice or index a known dimension
+                         raise ValueError(
+                             "Partial slicing or advanced integer indexing of known dimensions not supported."
+                         )
+                     elif isinstance(idx, slice):
+                         # Unknown dimensions kept by partial slice.
+                         non_adv_dims.append(None)
+
+                 # Insert unknown dimensions corresponding to the advanced indexing group
+                 output_dims = tuple(
+                     non_adv_dims[:adv_group_axis]
+                     + [None] * adv_group_ndim
+                     + non_adv_dims[adv_group_axis:]
+                 )
+
+             else:
+                 # Indexing keys have known dimensions.
+                 # Only array indices can have dimensions; the rest are just slices or newaxis
+
+                 # Advanced indexing variables broadcast together, so we apply the same rules as in Elemwise
+                 adv_dims = _broadcast_dims(keys_dims)
+
+                 start_non_adv_dims = (None,) * adv_group_axis
+                 end_non_adv_dims = (None,) * (
+                     node.outputs[0].type.ndim - adv_group_axis - adv_group_ndim
+                 )
+                 output_dims = start_non_adv_dims + adv_dims + end_non_adv_dims
+
+             var_dims[node.outputs[0]] = output_dims
+
+         else:
+             raise NotImplementedError(f"Marginalization through operation {node} not supported.")
+
+     return var_dims
+
+
+ def subgraph_batch_dim_connection(input_var, output_vars) -> list[DIMS]:
+     """Identify how the batch dims of the input map to the batch dimensions of the output_vars.
+
+     Example:
+     --------
+     In the example below `idx` has two batch dimensions (indexed 0, 1 from left to right).
+     The two uncommented dependent variables each have 2 batch dimensions, where each entry
+     results from a mapping of a single entry from one of these batch dimensions.
+
+     This mapping is transposed in the case of the first dependent variable, and shows up in
+     the same order for the second dependent variable. Each of the variables has a further
+     batch dimension encoded as `None`.
+
+     The commented-out third dependent variable combines information from the batch dimensions
+     of `idx` via the `sum` operation. A `ValueError` would be raised if we requested the
+     connection of its batch dims.
+
+     .. code-block:: python
+
+         import pymc as pm
+
+         idx = pm.Bernoulli.dist(shape=(3, 2))
+         dep1 = pm.Normal.dist(mu=idx.T[..., None] * 2, shape=(2, 3, 5))
+         dep2 = pm.Normal.dist(mu=idx * 2, shape=(7, 3, 2))
+         # dep3 = pm.Normal.dist(mu=idx.sum())  # Would raise if requested
+
+         print(subgraph_batch_dim_connection(idx, [dep1, dep2]))
+         # [(1, 0, None), (None, 0, 1)]
+
+     Returns:
+     --------
+     list of tuples
+         Each tuple corresponds to the batch dimensions of the output_var in the order they are found in the output.
+         None is used to indicate a batch dimension that is not mapped from the input.
+
+     Raises:
+     -------
+     ValueError
+         If input batch dimensions are mixed in the graph leading to output_vars.
+
+     NotImplementedError
+         If a variable related to the marginalized batch_dims is used in an operation that is not yet supported.
+     """
+     var_dims = {input_var: tuple(range(input_var.type.ndim))}
+     var_dims = _subgraph_batch_dim_connection(var_dims, [input_var], output_vars)
+     ret = []
+     for output_var in output_vars:
+         output_dims = var_dims.get(output_var, (None,) * output_var.type.ndim)
+         assert len(output_dims) == output_var.type.ndim
+         ret.append(output_dims)
+     return ret