numpyro-forecast 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- numpyro_forecast/__init__.py +45 -0
- numpyro_forecast/datasets.py +104 -0
- numpyro_forecast/evaluate.py +279 -0
- numpyro_forecast/forecaster.py +308 -0
- numpyro_forecast/functional.py +521 -0
- numpyro_forecast/metrics.py +65 -0
- numpyro_forecast/py.typed +0 -0
- numpyro_forecast/typing.py +40 -0
- numpyro_forecast/util.py +225 -0
- numpyro_forecast-0.1.0.dist-info/METADATA +266 -0
- numpyro_forecast-0.1.0.dist-info/RECORD +13 -0
- numpyro_forecast-0.1.0.dist-info/WHEEL +4 -0
- numpyro_forecast-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
"""Forecasting model base class and SVI/MCMC forecasters.
|
|
2
|
+
|
|
3
|
+
This is the JAX/NumPyro port of Pyro's ``pyro.contrib.forecast.forecaster``.
|
|
4
|
+
The classes here are thin object-oriented shims over the functional core in
|
|
5
|
+
:mod:`numpyro_forecast.functional`: :class:`ForecastingModel` threads the
|
|
6
|
+
train/forecast :class:`~numpyro_forecast.functional.Horizon` for you, and the
|
|
7
|
+
forecaster classes wrap :func:`~numpyro_forecast.functional.fit_svi` /
|
|
8
|
+
:func:`~numpyro_forecast.functional.fit_mcmc` plus
|
|
9
|
+
:func:`~numpyro_forecast.functional.forecast`. The two styles are fully
|
|
10
|
+
interchangeable: both consume a model callable ``(covariates, data=None)`` and a
|
|
11
|
+
posterior dict of latent draws.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import abc
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
|
|
17
|
+
import numpyro.distributions as dist
|
|
18
|
+
from jax import random
|
|
19
|
+
from jaxtyping import Float
|
|
20
|
+
from numpyro.infer.autoguide import AutoGuide
|
|
21
|
+
from numpyro.infer.reparam import Reparam
|
|
22
|
+
from numpyro.optim import _NumPyroOptim
|
|
23
|
+
|
|
24
|
+
from numpyro_forecast.functional import (
|
|
25
|
+
Horizon,
|
|
26
|
+
draw_posterior,
|
|
27
|
+
fit_mcmc,
|
|
28
|
+
fit_svi,
|
|
29
|
+
)
|
|
30
|
+
from numpyro_forecast.functional import forecast as _forecast
|
|
31
|
+
from numpyro_forecast.functional import predict as _predict
|
|
32
|
+
from numpyro_forecast.functional import time_series as _time_series
|
|
33
|
+
from numpyro_forecast.typing import Array, ForecastModel
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ForecastingModel(abc.ABC):
    """Abstract base class for forecasting models.

    Subclasses implement :meth:`model`, which must call :meth:`predict` exactly
    once. The instance itself is the (pure) NumPyro model function with signature
    ``model_instance(covariates, data=None)``: the forecast horizon is inferred
    from the shapes (``future = covariates.shape[-2] - data.shape[-2]``).

    This is the object-oriented façade over the functional API: :meth:`time_series`
    and :meth:`predict` delegate to the free functions in
    :mod:`numpyro_forecast.functional`, passing the current
    :class:`~numpyro_forecast.functional.Horizon`.

    The horizon is only set while :meth:`__call__` is running. Outside a model
    call, the :attr:`duration`, :attr:`t_obs` and :attr:`future` properties and
    the :meth:`time_series` / :meth:`predict` helpers raise
    :class:`RuntimeError`.
    """

    def __init__(self) -> None:
        # Horizon of the model call currently in progress; ``None`` otherwise.
        self._horizon: Horizon | None = None

    @abc.abstractmethod
    def model(self, zero_data: Array | None, covariates: Array) -> None:
        """Define the generative model and call :meth:`predict` exactly once.

        Parameters
        ----------
        zero_data
            Zeros shaped like the data extended to the covariate duration
            (shape/dtype only; ``None`` during pure prior sampling).
        covariates
            Covariates with time at axis ``-2`` and shape
            ``(*batch, duration, cov)``.
        """
        raise NotImplementedError

    def _require_horizon(self) -> Horizon:
        """Return the horizon of the active model call.

        Raises
        ------
        RuntimeError
            If no model call is in progress.
        """
        # ``getattr`` keeps the intended RuntimeError (rather than an
        # AttributeError) for subclasses whose ``__init__`` does not call
        # ``super().__init__()``.
        horizon = getattr(self, "_horizon", None)
        if horizon is None:
            msg = "model state is only available during a model call"
            raise RuntimeError(msg)
        return horizon

    @property
    def duration(self) -> int:
        """Total horizon length ``t + future`` (in time steps)."""
        return self._require_horizon().duration

    @property
    def t_obs(self) -> int:
        """Number of observed (in-sample) time steps ``t``."""
        return self._require_horizon().t_obs

    @property
    def future(self) -> int:
        """Number of forecast time steps ``f`` (``0`` while training)."""
        return self._require_horizon().future

    def time_series(
        self,
        name: str,
        dist_fn: Callable[[], dist.Distribution],
        *,
        reparam: Reparam | None = None,
    ) -> Array:
        """Sample a time-varying latent over the full horizon.

        Thin wrapper over :func:`numpyro_forecast.functional.time_series`.

        Parameters
        ----------
        name
            Base sample-site name for the in-sample latent.
        dist_fn
            Zero-argument callable returning the per-step prior distribution.
        reparam
            Optional reparameterization (e.g. ``LocScaleReparam``) applied to
            both the in-sample and forecast sites.

        Returns
        -------
        Array
            The latent over the full horizon with time at axis ``-2``.

        Raises
        ------
        RuntimeError
            If called outside a model call.
        """
        return _time_series(self._require_horizon(), name, dist_fn, reparam=reparam)

    def predict(self, noise_dist: dist.Distribution, prediction: Array) -> None:
        """Register the observation/forecast sites for the model.

        Thin wrapper over :func:`numpyro_forecast.functional.predict`.

        Parameters
        ----------
        noise_dist
            Zero-centered observation noise (e.g. ``Normal(0, sigma)``).
        prediction
            Deterministic mean with time at axis ``-2``, shape
            ``(*batch, duration, obs)``.

        Raises
        ------
        RuntimeError
            If called outside a model call.
        """
        _predict(self._require_horizon(), noise_dist, prediction)

    def __call__(self, covariates: Array, data: Array | None = None) -> None:
        """Run the model as a NumPyro model function.

        Parameters
        ----------
        covariates
            Covariates with time at axis ``-2`` spanning the full horizon.
        data
            Observed data with time at axis ``-2`` (``None`` for prior sampling).
        """
        horizon = Horizon.from_data(covariates, data)
        # Save and restore the previous horizon instead of resetting it to
        # ``None``. A re-entrant call on the same instance therefore does not
        # clobber the outer call's state. An exception inside ``model`` still
        # leaves the instance as it was before the call.
        previous = getattr(self, "_horizon", None)
        self._horizon = horizon
        try:
            self.model(horizon.zero_data, covariates)
        finally:
            self._horizon = previous
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class _BaseForecaster(abc.ABC):
    """Shared forecasting logic over a fitted posterior.

    Concrete subclasses implement :meth:`_draw_posterior`. Calling the instance
    draws posterior samples of the latent sites and passes them, together with
    ``self.model``, to :func:`numpyro_forecast.functional.forecast`.
    """

    def __init__(self, model: ForecastModel) -> None:
        self.model = model

    @abc.abstractmethod
    def _draw_posterior(self, rng_key: Array, num_samples: int) -> dict[str, Array]:
        """Return ``num_samples`` posterior draws of the latent sites."""
        raise NotImplementedError

    def __call__(
        self,
        rng_key: Array,
        data: Array,
        covariates: Array,
        num_samples: int,
        *,
        batch_size: int | None = None,
    ) -> Float[Array, " sample *batch future obs"]:
        """Sample forecasts for the steps in ``[t, duration)``.

        Parameters
        ----------
        rng_key
            PRNG key.
        data
            Observed data with time at axis ``-2`` and length ``t``.
        covariates
            Covariates with time at axis ``-2`` and length ``duration > t``.
        num_samples
            Number of forecast samples to draw.
        batch_size
            Optional chunk size for sampling (caps peak memory).

        Returns
        -------
        Float[Array, " sample *batch future obs"]
            Forecast samples over the ``future = duration - t`` horizon.

        Raises
        ------
        ValueError
            If ``data`` or ``covariates`` has no time axis at ``-2`` (fewer
            than two dimensions), if ``covariates`` does not extend beyond
            ``data`` along the time axis, if ``num_samples`` is not positive,
            or if ``batch_size`` is given and is not positive.
        """
        # Without this check a 1-D input would fail on ``shape[-2]`` below with
        # an opaque IndexError instead of a descriptive ValueError.
        if len(data.shape) < 2 or len(covariates.shape) < 2:
            msg = "data and covariates must have a time axis at -2 (ndim >= 2)"
            raise ValueError(msg)
        if data.shape[-2] >= covariates.shape[-2]:
            msg = "covariates must extend beyond data along the time axis"
            raise ValueError(msg)
        if num_samples <= 0:
            msg = "num_samples must be positive"
            raise ValueError(msg)
        # Validate the chunk size the same way as ``num_samples``; ``None``
        # means no chunking.
        if batch_size is not None and batch_size <= 0:
            msg = "batch_size must be positive"
            raise ValueError(msg)
        # Use independent keys for posterior draws and for predictive sampling.
        key_post, key_pred = random.split(rng_key)
        posterior = self._draw_posterior(key_post, num_samples)
        return _forecast(key_pred, self.model, posterior, data, covariates, batch_size=batch_size)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class Forecaster(_BaseForecaster):
    """Forecaster fitted with stochastic variational inference (SVI).

    The model is fitted eagerly in the constructor through
    :func:`numpyro_forecast.functional.fit_svi`. Calling the instance afterwards
    forecasts from that fit.

    Parameters
    ----------
    rng_key
        PRNG key consumed by the fitting procedure.
    model
        Forecasting model to fit, either a :class:`ForecastingModel` instance
        or a functional model callable.
    data
        Observed in-sample data; time runs along axis ``-2``.
    covariates
        Covariates aligned with ``data``: time along axis ``-2``, same duration.
    guide
        Variational guide; ``AutoNormal(model)`` is used when omitted.
    optim
        NumPyro optimizer; ``Adam(0.01)`` is used when omitted.
    num_steps
        How many SVI steps to run.
    num_particles
        How many particles the ELBO estimate uses.
    progress_bar
        Show the SVI progress bar when ``True``.

    Attributes
    ----------
    guide
        The guide reported by :func:`fit_svi`.
    params
        The parameters reported by :func:`fit_svi`.
    losses
        The losses reported by :func:`fit_svi`.
    """

    def __init__(
        self,
        rng_key: Array,
        model: ForecastModel,
        data: Array,
        covariates: Array,
        *,
        guide: AutoGuide | None = None,
        optim: _NumPyroOptim | None = None,
        num_steps: int = 1_001,
        num_particles: int = 1,
        progress_bar: bool = False,
    ) -> None:
        super().__init__(model)
        svi_options = {
            "guide": guide,
            "optim": optim,
            "num_steps": num_steps,
            "num_particles": num_particles,
            "progress_bar": progress_bar,
        }
        fitted = fit_svi(rng_key, model, data, covariates, **svi_options)
        self._fit = fitted
        # Mirror the fit results as public attributes for convenient access.
        self.guide: AutoGuide = fitted.guide
        self.params: dict[str, Array] = fitted.params
        self.losses: Array = fitted.losses

    def _draw_posterior(self, rng_key: Array, num_samples: int) -> dict[str, Array]:
        """Delegate to :func:`draw_posterior` using the stored SVI fit."""
        fitted = self._fit
        return draw_posterior(rng_key, fitted, num_samples)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class HMCForecaster(_BaseForecaster):
    """Forecaster fitted with NUTS (Hamiltonian Monte Carlo).

    The model is fitted eagerly in the constructor through
    :func:`numpyro_forecast.functional.fit_mcmc`. Calling the instance afterwards
    forecasts from the stored MCMC fit.

    Parameters
    ----------
    rng_key
        PRNG key consumed by the sampler.
    model
        Forecasting model to fit, either a :class:`ForecastingModel` instance
        or a functional model callable.
    data
        Observed in-sample data; time runs along axis ``-2``.
    covariates
        Covariates aligned with ``data``: time along axis ``-2``, same duration.
    num_warmup
        How many warmup steps to run.
    num_samples
        How many posterior samples to collect.
    num_chains
        How many MCMC chains to run.
    progress_bar
        Show the MCMC progress bar when ``True``.

    Attributes
    ----------
    posterior_samples
        The ``samples`` mapping reported by :func:`fit_mcmc`.
    """

    def __init__(
        self,
        rng_key: Array,
        model: ForecastModel,
        data: Array,
        covariates: Array,
        *,
        num_warmup: int = 1_000,
        num_samples: int = 1_000,
        num_chains: int = 1,
        progress_bar: bool = False,
    ) -> None:
        super().__init__(model)
        mcmc_options = {
            "num_warmup": num_warmup,
            "num_samples": num_samples,
            "num_chains": num_chains,
            "progress_bar": progress_bar,
        }
        fitted = fit_mcmc(rng_key, model, data, covariates, **mcmc_options)
        self._fit = fitted
        self.posterior_samples: dict[str, Array] = fitted.samples

    def _draw_posterior(self, rng_key: Array, num_samples: int) -> dict[str, Array]:
        """Delegate to :func:`draw_posterior` using the stored MCMC fit."""
        fitted = self._fit
        return draw_posterior(rng_key, fitted, num_samples)
|