pymc-extras 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2047,6 +2047,69 @@ class PyMCStateSpace:
2047
2047
 
2048
2048
  return scenario
2049
2049
 
2050
+ def _build_forecast_model(
2051
+ self, time_index, t0, forecast_index, scenario, filter_output, mvn_method
2052
+ ):
2053
+ filter_time_dim = TIME_DIM
2054
+ temp_coords = self._fit_coords.copy()
2055
+
2056
+ dims = None
2057
+ if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
2058
+ dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
2059
+
2060
+ t0_idx = np.flatnonzero(time_index == t0)[0]
2061
+
2062
+ temp_coords["data_time"] = time_index
2063
+ temp_coords[TIME_DIM] = forecast_index
2064
+
2065
+ mu_dims, cov_dims = None, None
2066
+ if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
2067
+ mu_dims = ["data_time", ALL_STATE_DIM]
2068
+ cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
2069
+
2070
+ with pm.Model(coords=temp_coords) as forecast_model:
2071
+ (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2072
+ data_dims=["data_time", OBS_STATE_DIM],
2073
+ )
2074
+
2075
+ group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
2076
+ mu, cov = grouped_outputs[group_idx]
2077
+
2078
+ sub_dict = {
2079
+ data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
2080
+ for data_var in forecast_model.data_vars
2081
+ }
2082
+
2083
+ missing_data_vars = np.setdiff1d(
2084
+ ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()]
2085
+ )
2086
+ if missing_data_vars.size > 0:
2087
+ raise ValueError(f"{missing_data_vars} data used for fitting not found!")
2088
+
2089
+ mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True)
2090
+
2091
+ x0 = pm.Deterministic(
2092
+ "x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
2093
+ )
2094
+ P0 = pm.Deterministic(
2095
+ "P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
2096
+ )
2097
+
2098
+ _ = LinearGaussianStateSpace(
2099
+ "forecast",
2100
+ x0,
2101
+ P0,
2102
+ *matrices,
2103
+ steps=len(forecast_index),
2104
+ dims=dims,
2105
+ sequence_names=self.kalman_filter.seq_names,
2106
+ k_endog=self.k_endog,
2107
+ append_x0=False,
2108
+ method=mvn_method,
2109
+ )
2110
+
2111
+ return forecast_model
2112
+
2050
2113
  def forecast(
2051
2114
  self,
2052
2115
  idata: InferenceData,
@@ -2139,8 +2202,6 @@ class PyMCStateSpace:
2139
2202
  the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
2140
2203
 
2141
2204
  """
2142
- filter_time_dim = TIME_DIM
2143
-
2144
2205
  _validate_filter_arg(filter_output)
2145
2206
 
2146
2207
  compile_kwargs = kwargs.pop("compile_kwargs", {})
@@ -2185,58 +2246,23 @@ class PyMCStateSpace:
2185
2246
  use_scenario_index=use_scenario_index,
2186
2247
  )
2187
2248
  scenario = self._finalize_scenario_initialization(scenario, forecast_index)
2188
- temp_coords = self._fit_coords.copy()
2189
-
2190
- dims = None
2191
- if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
2192
- dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
2193
-
2194
- t0_idx = np.flatnonzero(time_index == t0)[0]
2195
-
2196
- temp_coords["data_time"] = time_index
2197
- temp_coords[TIME_DIM] = forecast_index
2198
-
2199
- mu_dims, cov_dims = None, None
2200
- if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
2201
- mu_dims = ["data_time", ALL_STATE_DIM]
2202
- cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
2203
-
2204
- with pm.Model(coords=temp_coords) as forecast_model:
2205
- (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2206
- scenario=scenario,
2207
- data_dims=["data_time", OBS_STATE_DIM],
2208
- )
2209
-
2210
- for name in self.data_names:
2211
- if name in scenario.keys():
2212
- pm.set_data(
2213
- {"data": np.zeros((len(forecast_index), self.k_endog))},
2214
- coords={"data_time": np.arange(len(forecast_index))},
2215
- )
2216
- break
2217
2249
 
2218
- group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
2219
- mu, cov = grouped_outputs[group_idx]
2220
-
2221
- x0 = pm.Deterministic(
2222
- "x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
2223
- )
2224
- P0 = pm.Deterministic(
2225
- "P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
2226
- )
2250
+ forecast_model = self._build_forecast_model(
2251
+ time_index=time_index,
2252
+ t0=t0,
2253
+ forecast_index=forecast_index,
2254
+ scenario=scenario,
2255
+ filter_output=filter_output,
2256
+ mvn_method=mvn_method,
2257
+ )
2227
2258
 
2228
- _ = LinearGaussianStateSpace(
2229
- "forecast",
2230
- x0,
2231
- P0,
2232
- *matrices,
2233
- steps=len(forecast_index),
2234
- dims=dims,
2235
- sequence_names=self.kalman_filter.seq_names,
2236
- k_endog=self.k_endog,
2237
- append_x0=False,
2238
- method=mvn_method,
2239
- )
2259
+ with forecast_model:
2260
+ if scenario is not None:
2261
+ dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
2262
+ pm.set_data(
2263
+ scenario | {"data": dummy_obs_data},
2264
+ coords={"data_time": np.arange(len(forecast_index))},
2265
+ )
2240
2266
 
2241
2267
  forecast_model.rvs_to_initial_values = {
2242
2268
  k: None for k in forecast_model.rvs_to_initial_values.keys()
@@ -105,7 +105,7 @@ class KalmanSmoother:
105
105
  a_hat, P_hat = self.predict(a, P, T, R, Q)
106
106
 
107
107
  # Use pinv, otherwise P_hat is singular when there is missing data
108
- smoother_gain = matrix_dot(pt.linalg.pinv(P_hat), T, P).T
108
+ smoother_gain = matrix_dot(pt.linalg.pinv(P_hat, hermitian=True), T, P).T
109
109
  a_smooth_next = a + smoother_gain @ (a_smooth - a_hat)
110
110
 
111
111
  P_smooth_next = P + quad_form_sym(smoother_gain, P_smooth - P_hat)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pymc-extras
3
- Version: 0.2.6
3
+ Version: 0.3.1
4
4
  Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
5
5
  Project-URL: Documentation, https://pymc-extras.readthedocs.io/
6
6
  Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git
@@ -226,14 +226,14 @@ Classifier: License :: OSI Approved :: Apache Software License
226
226
  Classifier: Operating System :: OS Independent
227
227
  Classifier: Programming Language :: Python
228
228
  Classifier: Programming Language :: Python :: 3
229
- Classifier: Programming Language :: Python :: 3.10
230
229
  Classifier: Programming Language :: Python :: 3.11
231
230
  Classifier: Programming Language :: Python :: 3.12
232
231
  Classifier: Programming Language :: Python :: 3.13
233
232
  Classifier: Topic :: Scientific/Engineering
234
233
  Classifier: Topic :: Scientific/Engineering :: Mathematics
235
- Requires-Python: >=3.10
236
- Requires-Dist: better-optimize>=0.1.2
234
+ Requires-Python: >=3.11
235
+ Requires-Dist: better-optimize>=0.1.4
236
+ Requires-Dist: pydantic>=2.0.0
237
237
  Requires-Dist: pymc>=5.21.1
238
238
  Requires-Dist: scikit-learn
239
239
  Provides-Extra: complete
@@ -245,6 +245,8 @@ Requires-Dist: xhistogram; extra == 'dask-histogram'
245
245
  Provides-Extra: dev
246
246
  Requires-Dist: blackjax; extra == 'dev'
247
247
  Requires-Dist: dask[all]<2025.1.1; extra == 'dev'
248
+ Requires-Dist: preliz>=0.5.0; extra == 'dev'
249
+ Requires-Dist: pytest-mock; extra == 'dev'
248
250
  Requires-Dist: pytest>=6.0; extra == 'dev'
249
251
  Requires-Dist: statsmodels; extra == 'dev'
250
252
  Provides-Extra: docs
@@ -1,7 +1,9 @@
1
1
  pymc_extras/__init__.py,sha256=YsR6OG72aW73y6dGS7w3nGGMV-V-ImHkmUOXKMPfMRA,1230
2
- pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
2
+ pymc_extras/deserialize.py,sha256=dktK5gsR96X3zAUoRF5udrTiconknH3uupiAWqkZi0M,5937
3
+ pymc_extras/linearmodel.py,sha256=KkvZ_DBXOD6myPgVNzu742YV0OzDK449_pDqNC5yae4,3975
3
4
  pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
4
5
  pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
6
+ pymc_extras/prior.py,sha256=0XbyRRVuS7aKY5gmvJr_iq4fGyHrRDeI_OjWu_O7CTA,39449
5
7
  pymc_extras/distributions/__init__.py,sha256=fDbrBt9mxEVp2CDPwnyCW3oiutzZ0PduB8EUH3fUrjI,1377
6
8
  pymc_extras/distributions/continuous.py,sha256=530wvcO-QcYVdiVN-iQRveImWfyJzzmxiZLMVShP7w4,11251
7
9
  pymc_extras/distributions/discrete.py,sha256=HNi-K0_hnNWTcfyBkWGh26sc71FwBgukQ_EjGAaAOjY,13036
@@ -13,21 +15,24 @@ pymc_extras/distributions/transforms/__init__.py,sha256=FUp2vyRE6_2eUcQ_FVt5Dn0-
13
15
  pymc_extras/distributions/transforms/partial_order.py,sha256=oEZlc9WgnGR46uFEjLzKEUxlhzIo2vrUUbBE3vYrsfQ,8404
14
16
  pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
15
17
  pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
16
- pymc_extras/inference/__init__.py,sha256=UH6S0bGfQKKyTSuqf7yezdy9PeE2bDU8U1v4eIRv4ZI,887
17
- pymc_extras/inference/find_map.py,sha256=g_qXZbMz6w-De9wCMbBx8yLNkQANdPVWxLN7nJ0O17I,18523
18
- pymc_extras/inference/fit.py,sha256=oe20RAajImZ-VD9Ucbzri8Bof4Y2KHNhNRG19v9O3lI,1336
19
- pymc_extras/inference/laplace.py,sha256=Rq_D6veUYmW93GEyU8UZXiQquvJw-lK1np7NPxKCFqU,22064
18
+ pymc_extras/inference/__init__.py,sha256=YJIBqHoJnjglof7SVESH3u67li_ETmMy24zajld0DNE,917
19
+ pymc_extras/inference/fit.py,sha256=U_jfzuyjk5bV6AvOxtOKzBg-q4z-_BOR06Hn38T0W6E,1328
20
+ pymc_extras/inference/laplace_approx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
+ pymc_extras/inference/laplace_approx/find_map.py,sha256=DihZO3XgiIs7JFKgriluSQDdinMQZG1xjkrQ4OUREno,13859
22
+ pymc_extras/inference/laplace_approx/idata.py,sha256=Ldkzny9qgYg9My2ydIaQfzKJ73y-xcgC7VJiGvbZhWE,13221
23
+ pymc_extras/inference/laplace_approx/laplace.py,sha256=M8s8GfHE5SgNLOiyiHsDuMDvYGEaBMjjqthVnrxY248,18724
24
+ pymc_extras/inference/laplace_approx/scipy_interface.py,sha256=qMxYodmmxaUGsOp1jc7HxBJc6L8NnmFT2Fd4UNNXu2c,8835
20
25
  pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
21
26
  pymc_extras/inference/pathfinder/importance_sampling.py,sha256=NwxepXOFit3cA5zEebniKdlnJ1rZWg56aMlH4MEOcG4,6264
22
27
  pymc_extras/inference/pathfinder/lbfgs.py,sha256=GOoJBil5Kft_iFwGNUGKSeqzI5x_shA4KQWDwgGuQtQ,7110
23
- pymc_extras/inference/pathfinder/pathfinder.py,sha256=GW04HQurj_3Nlo1C6_K2tEIeigo8x0buV3FqDLA88PQ,64439
28
+ pymc_extras/inference/pathfinder/pathfinder.py,sha256=yme_wBHnREaT5gSOD6CZ0nb87oScmXjplERiQb0mcAg,64454
24
29
  pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
25
30
  pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
26
31
  pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
32
  pymc_extras/model/model_api.py,sha256=UHMfQXxWBujeSiUySU0fDUC5Sd_BjT8FoVz3iBxQH_4,2400
28
33
  pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
34
  pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
30
- pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
35
+ pymc_extras/model/marginal/graph_analysis.py,sha256=l_WSZHivm82297zMIm8i3G_h2F-4Tq397pQlcuEP-0I,15874
31
36
  pymc_extras/model/marginal/marginal_model.py,sha256=oIdikaSnefCkyMxmzAe222qGXNucxZpHYk7548fK6iA,23631
32
37
  pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
38
  pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
@@ -37,11 +42,11 @@ pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53C
37
42
  pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
38
43
  pymc_extras/statespace/core/compile.py,sha256=9FZfE8Bi3VfElxujfOIKRVvmyL9M5R0WfNEqPc5kbVQ,1603
39
44
  pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
40
- pymc_extras/statespace/core/statespace.py,sha256=9jCQ4odmLK3S33tQzKMqck2gsgVoo-C3hOCBX5dc9lA,104674
45
+ pymc_extras/statespace/core/statespace.py,sha256=mr0jDeCEfclCAGfEfEzMsogiuonGO9k2BV67wH_YKio,105627
41
46
  pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
42
47
  pymc_extras/statespace/filters/distributions.py,sha256=-s1c5s2zm6FMc0UqKSrWnJzIF4U5bvJT_3mMNTyV_ak,11927
43
48
  pymc_extras/statespace/filters/kalman_filter.py,sha256=Z6kxsbW8_VQ6ZcPjDMA5d_XPfdUY1-4GfRwKbBNfVZs,31438
44
- pymc_extras/statespace/filters/kalman_smoother.py,sha256=SAjnqtiDdvV79Pp4jp6UzrdIMmH1lqXhCj5WLeHusr8,4167
49
+ pymc_extras/statespace/filters/kalman_smoother.py,sha256=5jlSZAPveJzD5Q8omnpn7Gb1jgElBMgixGR7H9zoH8U,4183
45
50
  pymc_extras/statespace/filters/utilities.py,sha256=iwdaYnO1cO06t_XUjLLRmqb8vwzzVH6Nx1iyZcbJL2k,1584
46
51
  pymc_extras/statespace/models/ETS.py,sha256=08sbiuNvKdxcgKzS7jWj-z4jf-su73WFkYc8sKkGdEs,28538
47
52
  pymc_extras/statespace/models/SARIMAX.py,sha256=aXR6KYuqtSBOk-jvm9NvnOX5vu4QesBgCIL-KR89SXs,22207
@@ -56,10 +61,9 @@ pymc_extras/statespace/utils/data_tools.py,sha256=01sz6XDtLYK9I5xghxYpD-PuDzGXv9
56
61
  pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
57
62
  pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
58
63
  pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
59
- pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
60
64
  pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
61
65
  pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
62
- pymc_extras-0.2.6.dist-info/METADATA,sha256=zzdhVkdzXhL7MQH3R0uiCsrcl5i5uh1JLVdRBG6jJyY,18813
63
- pymc_extras-0.2.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
64
- pymc_extras-0.2.6.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
65
- pymc_extras-0.2.6.dist-info/RECORD,,
66
+ pymc_extras-0.3.1.dist-info/METADATA,sha256=QxE1LJTwDeUfZzYZhDjhLA4sy0SX6UcXGDyeVC1dkmM,18881
67
+ pymc_extras-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
68
+ pymc_extras-0.3.1.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
69
+ pymc_extras-0.3.1.dist-info/RECORD,,