pymc_extras-0.2.0-py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
pymc_extras/statespace/filters/__init__.py
@@ -0,0 +1,15 @@
+from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace
+from pymc_extras.statespace.filters.kalman_filter import (
+    SquareRootFilter,
+    StandardFilter,
+    UnivariateFilter,
+)
+from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
+
+__all__ = [
+    "StandardFilter",
+    "UnivariateFilter",
+    "KalmanSmoother",
+    "SquareRootFilter",
+    "LinearGaussianStateSpace",
+]
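Editorial note (not part of the diff): the re-exports above let downstream code pull the filters and smoother straight from the subpackage, e.g.

from pymc_extras.statespace.filters import KalmanSmoother, StandardFilter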
pymc_extras/statespace/filters/distributions.py
@@ -0,0 +1,453 @@
+import numpy as np
+import pymc as pm
+import pytensor
+import pytensor.tensor as pt
+
+from pymc import intX
+from pymc.distributions.dist_math import check_parameters
+from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
+from pymc.distributions.multivariate import MvNormal
+from pymc.distributions.shape_utils import get_support_shape_1d
+from pymc.logprob.abstract import _logprob
+from pytensor.graph.basic import Node
+from pytensor.tensor.random.basic import MvNormalRV
+
+floatX = pytensor.config.floatX
+COV_ZERO_TOL = 0
+
+lgss_shape_message = (
+    "The LinearGaussianStateSpace distribution needs shape information to be constructed. "
+    "Ensure that all input matrices have shape information specified."
+)
+
+
+def make_signature(sequence_names):
+    states = "s"
+    obs = "p"
+    exog = "r"
+    time = "t"
+    state_and_obs = "n"
+
+    matrix_to_shape = {
+        "x0": (states,),
+        "P0": (states, states),
+        "c": (states,),
+        "d": (obs,),
+        "T": (states, states),
+        "Z": (obs, states),
+        "R": (states, exog),
+        "H": (obs, obs),
+        "Q": (exog, exog),
+    }
+
+    for matrix in sequence_names:
+        base_shape = matrix_to_shape[matrix]
+        matrix_to_shape[matrix] = (time, *base_shape)
+
+    signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in matrix_to_shape.values()])
+
+    return f"{signature},[rng]->[rng],({time},{state_and_obs})"
+
+
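Editorial note (not part of the diff): make_signature builds a gufunc-style signature over the state ("s"), observed ("p"), exogenous ("r"), and time ("t") dimensions, plus the combined state-and-observed dimension ("n"), promoting any matrix named in sequence_names to carry a leading time axis. Illustrative output:

make_signature([])
# -> '(s),(s,s),(s),(p),(s,s),(p,s),(s,r),(p,p),(r,r),[rng]->[rng],(t,n)'
make_signature(["Q"])
# -> '(s),(s,s),(s),(p),(s,s),(p,s),(s,r),(p,p),(t,r,r),[rng]->[rng],(t,n)'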
+class MvNormalSVDRV(MvNormalRV):
+    name = "multivariate_normal"
+    signature = "(n),(n,n)->(n)"
+    dtype = "floatX"
+    _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
+
+
+class MvNormalSVD(MvNormal):
+    """Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".
+
+    A JAX MvNormal robust to low-rank covariance matrices.
+    """
+
+    rv_op = MvNormalSVDRV()
+
+
+try:
+    import jax.random
+
+    from pytensor.link.jax.dispatch.random import jax_sample_fn
+
+    @jax_sample_fn.register(MvNormalSVDRV)
+    def jax_sample_fn_mvnormal_svd(op, node):
+        def sample_fn(rng, size, dtype, *parameters):
+            rng_key = rng["jax_state"]
+            rng_key, sampling_key = jax.random.split(rng_key, 2)
+            sample = jax.random.multivariate_normal(
+                sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
+            )
+            rng["jax_state"] = rng_key
+            return (rng, sample)
+
+        return sample_fn
+
+except ImportError:
+    pass
+
+
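Editorial note (not part of the diff): registering this sampler makes JAX-mode draws from MvNormalSVD use method="svd", which tolerates positive semi-definite covariances that would yield NaNs under the default Cholesky method. A minimal illustration, assuming JAX is installed:

import jax
import jax.numpy as jnp

cov = jnp.array([[1.0, 1.0], [1.0, 1.0]])  # rank-1 covariance; not Cholesky-decomposable
key = jax.random.PRNGKey(0)
# method="svd" still produces a valid (degenerate) Gaussian draw
draw = jax.random.multivariate_normal(key, jnp.zeros(2), cov, method="svd")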
+class LinearGaussianStateSpaceRV(SymbolicRandomVariable):
+    default_output = 1
+    _print_name = ("LinearGaussianStateSpace", "\\operatorname{LinearGaussianStateSpace}")
+
+    def update(self, node: Node):
+        return {node.inputs[-1]: node.outputs[0]}
+
+
+class _LinearGaussianStateSpace(Continuous):
+    def __new__(
+        cls,
+        name,
+        a0,
+        P0,
+        c,
+        d,
+        T,
+        Z,
+        R,
+        H,
+        Q,
+        steps=None,
+        mode=None,
+        sequence_names=None,
+        append_x0=True,
+        **kwargs,
+    ):
+        # Ignore dims in support shape because they are just passed along to the "observed" and "latent" distributions
+        # created by LinearGaussianStateSpace. This "combined" distribution shouldn't ever be directly used.
+        steps = get_support_shape_1d(
+            support_shape=steps,
+            shape=None,
+            dims=None,
+            observed=kwargs.get("observed", None),
+            support_shape_offset=0,
+        )
+
+        return super().__new__(
+            cls,
+            name,
+            a0,
+            P0,
+            c,
+            d,
+            T,
+            Z,
+            R,
+            H,
+            Q,
+            steps=steps,
+            mode=mode,
+            sequence_names=sequence_names,
+            append_x0=append_x0,
+            **kwargs,
+        )
+
+    @classmethod
+    def dist(
+        cls,
+        a0,
+        P0,
+        c,
+        d,
+        T,
+        Z,
+        R,
+        H,
+        Q,
+        steps=None,
+        mode=None,
+        sequence_names=None,
+        append_x0=True,
+        **kwargs,
+    ):
+        steps = get_support_shape_1d(
+            support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=0
+        )
+
+        if steps is None:
+            raise ValueError("Must specify steps or shape parameter")
+
+        steps = pt.as_tensor_variable(intX(steps), ndim=0)
+
+        return super().dist(
+            [a0, P0, c, d, T, Z, R, H, Q, steps],
+            mode=mode,
+            sequence_names=sequence_names,
+            append_x0=append_x0,
+            **kwargs,
+        )
+
+    @classmethod
+    def rv_op(
+        cls,
+        a0,
+        P0,
+        c,
+        d,
+        T,
+        Z,
+        R,
+        H,
+        Q,
+        steps,
+        size=None,
+        mode=None,
+        sequence_names=None,
+        append_x0=True,
+    ):
+        if sequence_names is None:
+            sequence_names = []
+
+        a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_ = map(
+            lambda x: x.type(), [a0, P0, c, d, T, Z, R, H, Q]
+        )
+
+        c_.name = "c"
+        d_.name = "d"
+        T_.name = "T"
+        Z_.name = "Z"
+        R_.name = "R"
+        H_.name = "H"
+        Q_.name = "Q"
+
+        sequences = [
+            x
+            for x, name in zip([c_, d_, T_, Z_, R_, H_, Q_], ["c", "d", "T", "Z", "R", "H", "Q"])
+            if name in sequence_names
+        ]
+        non_sequences = [x for x in [c_, d_, T_, Z_, R_, H_, Q_] if x not in sequences]
+
+        rng = pytensor.shared(np.random.default_rng())
+
+        def sort_args(args):
+            sorted_args = []
+
+            # Inside the scan, outputs_info variables get a time step appended to their name
+            # e.g. x -> x[t]. Remove this so we can identify variables by name.
+            arg_names = [x.name.replace("[t]", "") for x in args]
+
+            # c, d, T, Z, R, H, Q is the "canonical" ordering
+            for name in ["c", "d", "T", "Z", "R", "H", "Q"]:
+                idx = arg_names.index(name)
+                sorted_args.append(args[idx])
+
+            return sorted_args
+
+        n_seq = len(sequence_names)
+
+        def step_fn(*args):
+            seqs, state, non_seqs = args[:n_seq], args[n_seq], args[n_seq + 1 :]
+            non_seqs, rng = non_seqs[:-1], non_seqs[-1]
+
+            c, d, T, Z, R, H, Q = sort_args(seqs + non_seqs)
+            k = T.shape[0]
+            a = state[:k]
+
+            middle_rng, a_innovation = MvNormalSVD.dist(mu=0, cov=Q, rng=rng).owner.outputs
+            next_rng, y_innovation = MvNormalSVD.dist(mu=0, cov=H, rng=middle_rng).owner.outputs
+
+            a_mu = c + T @ a
+            a_next = a_mu + R @ a_innovation
+
+            y_mu = d + Z @ a_next
+            y_next = y_mu + y_innovation
+
+            next_state = pt.concatenate([a_next, y_next], axis=0)
+
+            return next_state, {rng: next_rng}
+
+        Z_init = Z_ if Z_ in non_sequences else Z_[0]
+        H_init = H_ if H_ in non_sequences else H_[0]
+
+        init_x_ = MvNormalSVD.dist(a0_, P0_, rng=rng)
+        init_y_ = MvNormalSVD.dist(Z_init @ init_x_, H_init, rng=rng)
+
+        init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
+
+        statespace, updates = pytensor.scan(
+            step_fn,
+            outputs_info=[init_dist_],
+            sequences=None if len(sequences) == 0 else sequences,
+            non_sequences=[*non_sequences, rng],
+            n_steps=steps,
+            mode=mode,
+            strict=True,
+        )
+
+        if append_x0:
+            statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0)
+            statespace_ = pt.specify_shape(statespace_, (steps + 1, None))
+        else:
+            statespace_ = statespace
+            statespace_ = pt.specify_shape(statespace_, (steps, None))
+
+        (ss_rng,) = tuple(updates.values())
+        linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
+            inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng],
+            outputs=[ss_rng, statespace_],
+            extended_signature=make_signature(sequence_names),
+        )
+
+        linear_gaussian_ss = linear_gaussian_ss_op(a0, P0, c, d, T, Z, R, H, Q, steps, rng)
+        return linear_gaussian_ss
+
+
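Editorial note (not part of the diff): each scan step in rv_op above advances the linear Gaussian model by drawing one transition and one observation innovation,

    a[t+1] = c + T @ a[t] + R @ eps[t],    eps[t] ~ MvNormal(0, Q)
    y[t+1] = d + Z @ a[t+1] + eta[t],      eta[t] ~ MvNormal(0, H)

and concatenating (a[t+1], y[t+1]) into a single vector per step. That is why LinearGaussianStateSpace below recovers the latent and observed trajectories by slicing the last axis.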
+class LinearGaussianStateSpace(Continuous):
+    """
+    Linear Gaussian Statespace distribution
+
+    """
+
+    def __new__(
+        cls,
+        name,
+        a0,
+        P0,
+        c,
+        d,
+        T,
+        Z,
+        R,
+        H,
+        Q,
+        *,
+        steps,
+        k_endog=None,
+        sequence_names=None,
+        mode=None,
+        append_x0=True,
+        **kwargs,
+    ):
+        dims = kwargs.pop("dims", None)
+        latent_dims = None
+        obs_dims = None
+        if dims is not None:
+            if len(dims) != 3:
+                raise ValueError(
+                    "LinearGaussianStateSpace expects 3 dims: time, all_states, and observed_states"
+                )
+            time_dim, state_dim, obs_dim = dims
+            latent_dims = [time_dim, state_dim]
+            obs_dims = [time_dim, obs_dim]
+
+        latent_obs_combined = _LinearGaussianStateSpace(
+            f"{name}_combined",
+            a0,
+            P0,
+            c,
+            d,
+            T,
+            Z,
+            R,
+            H,
+            Q,
+            steps=steps,
+            mode=mode,
+            sequence_names=sequence_names,
+            append_x0=append_x0,
+            **kwargs,
+        )
+        latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None))
+        if k_endog is None:
+            k_endog = cls._get_k_endog(H)
+        latent_slice = slice(None, -k_endog)
+        obs_slice = slice(-k_endog, None)
+
+        latent_states = latent_obs_combined[..., latent_slice]
+        obs_states = latent_obs_combined[..., obs_slice]
+
+        latent_states = pm.Deterministic(f"{name}_latent", latent_states, dims=latent_dims)
+        obs_states = pm.Deterministic(f"{name}_observed", obs_states, dims=obs_dims)
+
+        return latent_states, obs_states
+
+    @classmethod
+    def dist(cls, a0, P0, c, d, T, Z, R, H, Q, *, steps=None, **kwargs):
+        latent_obs_combined = _LinearGaussianStateSpace.dist(
+            a0, P0, c, d, T, Z, R, H, Q, steps=steps, **kwargs
+        )
+        k_states = T.type.shape[0]
+
+        latent_states = latent_obs_combined[..., :k_states]
+        obs_states = latent_obs_combined[..., k_states:]
+
+        return latent_states, obs_states
+
+    @classmethod
+    def _get_k_states(cls, T):
+        k_states = T.type.shape[0]
+        if k_states is None:
+            raise ValueError(lgss_shape_message)
+        return k_states
+
+    @classmethod
+    def _get_k_endog(cls, H):
+        k_endog = H.type.shape[0]
+        if k_endog is None:
+            raise ValueError(lgss_shape_message)
+
+        return k_endog
+
+
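Editorial note (not part of the diff): a minimal usage sketch with hypothetical priors and a one-state local level structure. With append_x0=True (the default), the time coordinate needs steps + 1 entries, since the initial state is prepended.

import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pymc_extras.statespace.filters import LinearGaussianStateSpace

coords = {"time": np.arange(101), "state": ["level"], "obs": ["y"]}
with pm.Model(coords=coords) as model:
    sigma_obs = pm.HalfNormal("sigma_obs")
    latent, observed = LinearGaussianStateSpace(
        "ss",
        pt.as_tensor(np.zeros(1)),               # a0: initial state mean
        pt.as_tensor(np.eye(1)),                 # P0: initial state covariance
        pt.as_tensor(np.zeros(1)),               # c: state intercept
        pt.as_tensor(np.zeros(1)),               # d: observation intercept
        pt.as_tensor(np.eye(1)),                 # T: transition matrix
        pt.as_tensor(np.eye(1)),                 # Z: design matrix
        pt.as_tensor(np.eye(1)),                 # R: shock selection matrix
        pt.as_tensor(np.eye(1)) * sigma_obs**2,  # H: observation noise covariance
        pt.as_tensor(np.eye(1) * 0.1),           # Q: state noise covariance
        steps=100,
        dims=("time", "state", "obs"),
    )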
+class KalmanFilterRV(SymbolicRandomVariable):
+    default_output = 1
+    _print_name = ("KalmanFilter", "\\operatorname{KalmanFilter}")
+    extended_signature = "(t,s),(t,s,s),(t),[rng]->[rng],(t,s)"
+
+    def update(self, node: Node):
+        return {node.inputs[-1]: node.outputs[0]}
+
+
+class SequenceMvNormal(Continuous):
+    def __new__(cls, *args, **kwargs):
+        return super().__new__(cls, *args, **kwargs)
+
+    @classmethod
+    def dist(cls, mus, covs, logp, **kwargs):
+        return super().dist([mus, covs, logp], **kwargs)
+
+    @classmethod
+    def rv_op(cls, mus, covs, logp, size=None):
+        # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
+        if mus.ndim > 2:
+            mus = pt.moveaxis(mus, -2, 0)
+        if covs.ndim > 3:
+            covs = pt.moveaxis(covs, -3, 0)
+
+        mus_, covs_ = mus.type(), covs.type()
+
+        logp_ = logp.type()
+        rng = pytensor.shared(np.random.default_rng())
+
+        def step(mu, cov, rng):
+            new_rng, mvn = MvNormalSVD.dist(mu=mu, cov=cov, rng=rng).owner.outputs
+            return mvn, {rng: new_rng}
+
+        mvn_seq, updates = pytensor.scan(
+            step, sequences=[mus_, covs_], non_sequences=[rng], strict=True, n_steps=mus_.shape[0]
+        )
+        mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
+
+        # Move time axis back to position -2 so batches are on the left
+        if mvn_seq.ndim > 2:
+            mvn_seq = pt.moveaxis(mvn_seq, 0, -2)
+
+        (seq_mvn_rng,) = tuple(updates.values())
+
+        mvn_seq_op = KalmanFilterRV(
+            inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
+        )
+
+        mvn_seq = mvn_seq_op(mus, covs, logp, rng)
+        return mvn_seq
+
+
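Editorial note (not part of the diff): SequenceMvNormal draws one multivariate normal per time step from a sequence of means and covariances (e.g. Kalman-filtered state estimates); its log-probability, registered just below, passes through the logp term already computed by the Kalman filter rather than re-evaluating a density. A rough sketch, assuming the definitions above are importable:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

mus = pt.as_tensor(np.zeros((10, 2)))                               # (time, state) means
covs = pt.as_tensor(np.broadcast_to(np.eye(2), (10, 2, 2)).copy())  # (time, state, state) covariances
logp = pt.as_tensor(np.zeros(10))                                   # stand-in for the filter's logp term
draws = pm.draw(SequenceMvNormal.dist(mus, covs, logp))             # array with shape (10, 2)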
+@_logprob.register(KalmanFilterRV)
+def sequence_mvnormal_logp(op, values, mus, covs, logp, rng, **kwargs):
+    return check_parameters(
+        logp,
+        pt.eq(values[0].shape[0], mus.shape[0]),
+        pt.eq(covs.shape[0], mus.shape[0]),
+        msg="Observed data and parameters must have the same number of timesteps (dimension 0)",
+    )