pymc-extras 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,445 @@
1
+ import numpy as np
2
+
3
+ from pytensor import tensor as pt
4
+
5
+ from pymc_extras.statespace.models.structural.core import Component
6
+ from pymc_extras.statespace.models.structural.utils import _frequency_transition_block
7
+
8
+
9
class TimeSeasonality(Component):
    r"""
    Seasonal component, modeled in the time domain

    Parameters
    ----------
    season_length: int
        The number of periods in a single seasonal cycle, e.g. 12 for monthly data with annual seasonal pattern, 7 for
        daily data with weekly seasonal pattern, etc.

    innovations: bool, default True
        Whether to include stochastic innovations in the strength of the seasonal effect

    name: str, default None
        A name for this seasonal component. Used to label dimensions and coordinates. Useful when multiple seasonal
        components are included in the same model. Default is ``f"Seasonal[s={season_length}]"``

    state_names: list of str, default None
        List of strings for seasonal effect labels. If provided, it must be of length ``season_length``. An example
        would be ``state_names = ['Mon', 'Tue', 'Wed', 'Thur', 'Fri', 'Sat', 'Sun']`` when data is daily with a weekly
        seasonal pattern (``season_length = 7``).

        If None, states will be named ``[f"{name}_0", ..., f"{name}_{season_length - 1}"]``

    remove_first_state: bool, default True
        If True, the first state will be removed from the model. This is done because there are only n-1 degrees of
        freedom in the seasonal component, and one state is not identified. If False, the first state will be
        included in the model, but it will not be identified -- you will need to handle this in the priors (e.g. with
        ZeroSumNormal).

    observed_state_names: list[str] | None, default None
        List of strings for observed state labels. If None, defaults to ["data"].

    Notes
    -----
    A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
    model seasonal effects, the implementation used here is the one described by [1] as the "canonical" time domain
    representation. The seasonal component can be expressed:

    .. math::
        \gamma_t = -\sum_{i=1}^{s-1} \gamma_{t-i} + \omega_t, \quad \omega_t \sim N(0, \sigma_\gamma)

    Where :math:`s` is the ``season_length`` parameter and :math:`\omega_t` is the (optional) stochastic innovation.
    To give interpretation to the :math:`\gamma` terms, it is helpful to work through the algebra for a simple
    example. Let :math:`s=4`, and omit the shock term. Define initial conditions :math:`\gamma_0, \gamma_{-1},
    \gamma_{-2}`. The value of the seasonal component for the first 5 timesteps will be:

    .. math::
        \begin{align}
            \gamma_1 &= -\gamma_0 - \gamma_{-1} - \gamma_{-2} \\
            \gamma_2 &= -\gamma_1 - \gamma_0 - \gamma_{-1} \\
                     &= -(-\gamma_0 - \gamma_{-1} - \gamma_{-2}) - \gamma_0 - \gamma_{-1}  \\
                     &= (\gamma_0 - \gamma_0 )+ (\gamma_{-1} - \gamma_{-1}) + \gamma_{-2} \\
                     &= \gamma_{-2} \\
            \gamma_3 &= -\gamma_2 - \gamma_1 - \gamma_0  \\
                     &= -\gamma_{-2} - (-\gamma_0 - \gamma_{-1} - \gamma_{-2}) - \gamma_0 \\
                     &=  (\gamma_{-2} - \gamma_{-2}) + \gamma_{-1} + (\gamma_0 - \gamma_0) \\
                     &= \gamma_{-1} \\
            \gamma_4 &= -\gamma_3 - \gamma_2 - \gamma_1 \\
                     &= -\gamma_{-1} - \gamma_{-2} -(-\gamma_0 - \gamma_{-1} - \gamma_{-2}) \\
                     &= (\gamma_{-2} - \gamma_{-2}) + (\gamma_{-1} - \gamma_{-1}) + \gamma_0 \\
                     &= \gamma_0 \\
            \gamma_5 &= -\gamma_4 - \gamma_3 - \gamma_2 \\
                     &= -\gamma_0 - \gamma_{-1} - \gamma_{-2} \\
                     &= \gamma_1
        \end{align}

    This exercise shows that, given a list ``initial_conditions`` of length ``s-1``, the effects of this model will be:

        - Period 1: ``-sum(initial_conditions)``
        - Period 2: ``initial_conditions[-1]``
        - Period 3: ``initial_conditions[-2]``
        - ...
        - Period s: ``initial_conditions[0]``
        - Period s+1: ``-sum(initial_conditions)``

    And so on. So for interpretation, the ``season_length - 1`` initial states are, when reversed, the coefficients
    associated with ``state_names[1:]``.

    .. warning::
        Although the ``state_names`` argument expects a list of length ``season_length``, only ``state_names[1:]``
        will be saved as model dimensions when ``remove_first_state=True``, since the 1st coefficient is not
        identified (it is defined as :math:`-\sum_{i=1}^{s-1} \gamma_{t-i}`).

    Examples
    --------
    Estimate a model with a Gaussian random walk trend and monthly seasonality:

    .. code:: python

        from pymc_extras.statespace import structural as st
        import pymc as pm
        import pytensor.tensor as pt
        import pandas as pd

        # Get month names
        state_names = pd.date_range('1900-01-01', '1900-12-31', freq='MS').month_name().tolist()

        # Build the structural model
        grw = st.LevelTrendComponent(order=1, innovations_order=1)
        annual_season = st.TimeSeasonality(
            season_length=12, name="annual", state_names=state_names, innovations=False
        )
        ss_mod = (grw + annual_season).build()

        with pm.Model(coords=ss_mod.coords) as model:
            P0 = pm.Deterministic('P0', pt.eye(ss_mod.k_states) * 10, dims=ss_mod.param_dims['P0'])

            initial_level_trend = pm.Deterministic(
                "initial_level_trend", pt.zeros(1), dims=ss_mod.param_dims["initial_level_trend"]
            )
            sigma_level_trend = pm.HalfNormal(
                "sigma_level_trend", sigma=1e-6, dims=ss_mod.param_dims["sigma_level_trend"]
            )
            coefs_annual = pm.Normal("coefs_annual", sigma=1e-2, dims=ss_mod.param_dims["coefs_annual"])

            ss_mod.build_statespace_graph(data)
            idata = pm.sample(
                nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "JAX", "gradient_backend": "JAX"}
            )

    References
    ----------
    .. [1] Durbin, James, and Siem Jan Koopman. 2012.
        Time Series Analysis by State Space Methods: Second Edition.
        Oxford University Press.
    """

    def __init__(
        self,
        season_length: int,
        innovations: bool = True,
        name: str | None = None,
        state_names: list | None = None,
        remove_first_state: bool = True,
        observed_state_names: list[str] | None = None,
    ):
        if observed_state_names is None:
            observed_state_names = ["data"]

        if name is None:
            name = f"Seasonal[s={season_length}]"
        if state_names is None:
            state_names = [f"{name}_{i}" for i in range(season_length)]
        else:
            if len(state_names) != season_length:
                raise ValueError(
                    f"state_names must be a list of length season_length, got {len(state_names)}"
                )
            # Work on a private list so the caller's sequence is never mutated by the pop below. list() also
            # accepts tuples and other sequences, which have no .copy() method.
            state_names = list(state_names)

        self.innovations = innovations
        self.remove_first_state = remove_first_state

        if self.remove_first_state:
            # In traditional models, the first state isn't identified, so we can help out the user by automatically
            # discarding it.
            # TODO: Can this be stashed and reconstructed automatically somehow?
            state_names.pop(0)

        self.provided_state_names = state_names

        # Per-series sizes; the totals handed to the base class are multiplied by the number of observed series.
        k_states = season_length - int(self.remove_first_state)
        k_endog = len(observed_state_names)
        k_posdef = int(innovations)

        super().__init__(
            name=name,
            k_endog=k_endog,
            k_states=k_states * k_endog,
            k_posdef=k_posdef * k_endog,
            observed_state_names=observed_state_names,
            measurement_error=False,
            combine_hidden_states=True,
            # Only the first hidden state of each series' block is flagged.
            obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog),
        )

    def populate_component_properties(self):
        """Fill in state/parameter/shock names, shapes, dims and coords for this component."""
        k_states = self.k_states // self.k_endog
        k_endog = self.k_endog
        multivariate = k_endog != 1

        state_dim = f"state_{self.name}"
        endog_dim = f"endog_{self.name}"
        coef_name = f"coefs_{self.name}"
        coef_dims = (endog_dim, state_dim) if multivariate else (state_dim,)

        self.state_names = [
            f"{state_name}[{endog_name}]"
            for endog_name in self.observed_state_names
            for state_name in self.provided_state_names
        ]
        self.param_names = [coef_name]

        self.param_info = {
            coef_name: {
                "shape": (k_endog, k_states) if multivariate else (k_states,),
                "constraints": None,
                "dims": coef_dims,
            }
        }

        self.param_dims = {coef_name: coef_dims}

        self.coords = {state_dim: self.provided_state_names}
        if multivariate:
            self.coords[endog_dim] = self.observed_state_names

        if self.innovations:
            sigma_name = f"sigma_{self.name}"
            self.param_names += [sigma_name]
            # Shape/dims must agree with the variable registered in make_symbolic_graph: a scalar for a single
            # observed series, otherwise one innovation scale per observed series.
            self.param_info[sigma_name] = {
                "shape": (k_endog,) if multivariate else (),
                "constraints": "Positive",
                "dims": (endog_dim,) if multivariate else None,
            }
            if multivariate:
                self.param_dims[sigma_name] = (endog_dim,)
            self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]

    def make_symbolic_graph(self) -> None:
        """Write the transition, design, initial state, selection and state covariance into ``self.ssm``."""
        k_states = self.k_states // self.k_endog
        k_posdef = self.k_posdef // self.k_endog
        k_endog = self.k_endog

        if self.remove_first_state:
            # In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
            # all previous states.
            T = np.eye(k_states, k=-1)
            T[0, :] = -1
        else:
            # In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
            # circulant matrix that cycles between the states.
            T = np.eye(k_states, k=1)
            T[-1, 0] = 1

        # Each observed series gets its own copy of the per-series matrices, stacked block-diagonally.
        self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])

        Z = pt.zeros((1, k_states))[0, 0].set(1)
        self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])

        initial_states = self.make_and_register_variable(
            f"coefs_{self.name}", shape=(k_states,) if k_endog == 1 else (k_endog, k_states)
        )
        self.ssm["initial_state", :] = initial_states.ravel()

        if self.innovations:
            R = pt.zeros((k_states, k_posdef))[0, 0].set(1.0)
            # Block-diagonal so that the shock of series i only enters the first state of series i. The full
            # selection matrix is (k_states * k_endog, k_posdef * k_endog); stacking R along axis 0 would give a
            # single column instead. For k_endog == 1 this reduces to R itself.
            self.ssm["selection", :, :] = pt.linalg.block_diag(*[R for _ in range(k_endog)])
            season_sigma = self.make_and_register_variable(
                f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
            )
            cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog))
            self.ssm[cov_idx] = season_sigma**2
265
+
266
+
267
class FrequencySeasonality(Component):
    r"""
    Seasonal component, modeled in the frequency domain

    Parameters
    ----------
    season_length: float
        The number of periods in a single seasonal cycle, e.g. 12 for monthly data with annual seasonal pattern, 7 for
        daily data with weekly seasonal pattern, etc. Non-integer season_length is also permitted, for example
        365.2422 days in a (solar) year.

    n: int
        Number of fourier features to include in the seasonal component. Default is ``int(season_length / 2)``, which
        is the maximum possible. A smaller number can be used for a more wave-like seasonal pattern. Must be at
        least 1.

    name: str, default None
        A name for this seasonal component. Used to label dimensions and coordinates. Useful when multiple seasonal
        components are included in the same model. Default is ``f"Frequency[s={season_length}, n={n}]"``

    innovations: bool, default True
        Whether to include stochastic innovations in the strength of the seasonal effect

    observed_state_names: list[str] | None, default None
        List of strings for observed state labels. If None, defaults to ["data"].

    Notes
    -----
    A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
    model seasonal effects, the implementation used here is the one described by [1] as the "canonical" frequency domain
    representation. The seasonal component can be expressed:

    .. math::
        \begin{align}
            \gamma_t &= \sum_{j=1}^{n} \gamma_{j,t} \\
            \gamma_{j, t+1} &= \gamma_{j,t} \cos \lambda_j + \gamma_{j,t}^\star \sin \lambda_j + \omega_{j, t} \\
            \gamma_{j, t+1}^\star &= -\gamma_{j,t} \sin \lambda_j + \gamma_{j,t}^\star \cos \lambda_j + \omega_{j,t}^\star \\
            \lambda_j &= \frac{2\pi j}{s}
        \end{align}

    Where :math:`s` is the ``season_length``.

    Unlike a ``TimeSeasonality`` component, a ``FrequencySeasonality`` component does not require integer season
    length. In addition, for long seasonal periods, it is possible to obtain a more compact state space representation
    by choosing ``n << s // 2``. Using ``TimeSeasonality``, an annual seasonal pattern in daily data requires 364
    states, whereas ``FrequencySeasonality`` always requires ``2 * n`` states, regardless of the ``season_length``.
    The price of this compactness is less representational power. At ``n = 1``, the seasonal pattern will be a pure
    sine wave. At ``n = s // 2``, any arbitrary pattern can be represented.

    One cost of the added flexibility of ``FrequencySeasonality`` is reduced interpretability. States of this model are
    coefficients :math:`\gamma_1, \gamma^\star_1, \gamma_2, \gamma_2^\star ..., \gamma_n, \gamma^\star_n` associated
    with different frequencies in the fourier representation of the seasonal pattern. As a result, it is not possible
    to isolate and identify a "Monday" effect, for instance.

    References
    ----------
    .. [1] Durbin, James, and Siem Jan Koopman. 2012.
        Time Series Analysis by State Space Methods: Second Edition.
        Oxford University Press.
    """

    def __init__(
        self,
        season_length: float,
        n: int | None = None,
        name: str | None = None,
        innovations: bool = True,
        observed_state_names: list[str] | None = None,
    ):
        if observed_state_names is None:
            observed_state_names = ["data"]

        k_endog = len(observed_state_names)

        if n is None:
            n = int(season_length / 2)
        if n < 1:
            # Without this guard the saturation test below fails with an uninformative ZeroDivisionError.
            raise ValueError(
                f"n must be at least 1, got n={n} (season_length={season_length}). "
                "season_length must be at least 2 when n is not provided."
            )
        if name is None:
            name = f"Frequency[s={season_length}, n={n}]"

        # Each fourier feature contributes a (Cos, Sin) pair of hidden states.
        k_states = n * 2
        self.n = n
        self.season_length = season_length
        self.innovations = innovations

        # If the model is completely saturated (n = s // 2), the last state will not be identified, so it shouldn't
        # get a parameter assigned to it and should just be fixed to zero.
        # Test this way (rather than n == s // 2) to catch cases when season_length is non-integer.
        self.last_state_not_identified = self.season_length / self.n == 2.0
        self.n_coefs = k_states - int(self.last_state_not_identified)

        # Flag every even-indexed (Cos) state, repeated once per observed series.
        obs_state_idx = np.zeros(k_states)
        obs_state_idx[slice(0, k_states, 2)] = 1
        obs_state_idx = np.tile(obs_state_idx, k_endog)

        super().__init__(
            name=name,
            k_endog=k_endog,
            k_states=k_states * k_endog,
            k_posdef=k_states * int(self.innovations) * k_endog,
            observed_state_names=observed_state_names,
            measurement_error=False,
            combine_hidden_states=True,
            obs_state_idxs=obs_state_idx,
        )

    def make_symbolic_graph(self) -> None:
        """Write the design, initial state, transition, selection and state covariance into ``self.ssm``."""
        k_endog = self.k_endog
        k_states = self.k_states // k_endog
        k_posdef = self.k_posdef // k_endog
        n_coefs = self.n_coefs

        # The design row has a one at every even-indexed (Cos) state.
        Z = pt.zeros((1, k_states))[0, slice(0, k_states, 2)].set(1.0)

        self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])

        init_state = self.make_and_register_variable(
            f"{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
        )

        # Positions of the parameterized states: the first n_coefs states of each series' block. When the basis is
        # saturated, the last state of each block is skipped and keeps its default initial value.
        init_state_idx = np.concatenate(
            [
                np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs]
                for i in range(k_endog)
            ],
            axis=0,
        )

        self.ssm["initial_state", init_state_idx] = init_state.ravel()

        T_mats = [_frequency_transition_block(self.season_length, j + 1) for j in range(self.n)]
        T = pt.linalg.block_diag(*T_mats)
        self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])

        if self.innovations:
            sigma_season = self.make_and_register_variable(
                f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
            )
            self.ssm["selection", :, :] = pt.eye(self.k_states)
            # Every state of a given series shares that series' innovation variance.
            self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * pt.repeat(
                sigma_season**2, k_posdef
            )

    def populate_component_properties(self):
        """Fill in state/parameter/shock names, shapes, dims and coords for this component."""
        k_endog = self.k_endog
        n_coefs = self.n_coefs
        multivariate = k_endog != 1

        state_dim = f"state_{self.name}"
        endog_dim = f"endog_{self.name}"
        coef_dims = (endog_dim, state_dim) if multivariate else (state_dim,)

        self.state_names = [
            f"{f}_{self.name}_{i}[{obs_state_name}]"
            for obs_state_name in self.observed_state_names
            for i in range(self.n)
            for f in ["Cos", "Sin"]
        ]
        self.param_names = [f"{self.name}"]

        # param_dims and param_info must describe the same layout as the variable registered in
        # make_symbolic_graph: (n_coefs,) for one observed series, (k_endog, n_coefs) otherwise.
        self.param_dims = {self.name: coef_dims}
        self.param_info = {
            f"{self.name}": {
                "shape": (k_endog, n_coefs) if multivariate else (n_coefs,),
                "constraints": None,
                "dims": coef_dims,
            }
        }

        # Regardless of whether the fourier basis are saturated, there will always be one symbolic state per basis.
        # That's why the self.states is just a simple loop over everything. But when saturated, one of those states
        # doesn't have an associated **parameter**, so the coords need to be adjusted to reflect this.
        if multivariate:
            # The state dimension indexes coefficients within a single series, so its labels carry no series
            # suffix and there are exactly n_coefs of them; the series are labelled by the endog dimension.
            coef_labels = [f"{f}_{self.name}_{i}" for i in range(self.n) for f in ["Cos", "Sin"]]
            self.coords = {
                endog_dim: self.observed_state_names,
                state_dim: coef_labels[:n_coefs],
            }
        else:
            self.coords = {state_dim: self.state_names[:n_coefs]}

        if self.innovations:
            sigma_name = f"sigma_{self.name}"
            self.shock_names = self.state_names.copy()
            self.param_names += [sigma_name]
            # One innovation scale per observed series (scalar when there is a single series).
            self.param_info[sigma_name] = {
                "shape": (k_endog,) if multivariate else (),
                "constraints": "Positive",
                "dims": (endog_dim,) if multivariate else None,
            }
            if multivariate:
                self.param_dims[sigma_name] = (endog_dim,)