numpyro-forecast 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- numpyro_forecast/__init__.py +45 -0
- numpyro_forecast/datasets.py +104 -0
- numpyro_forecast/evaluate.py +279 -0
- numpyro_forecast/forecaster.py +308 -0
- numpyro_forecast/functional.py +521 -0
- numpyro_forecast/metrics.py +65 -0
- numpyro_forecast/py.typed +0 -0
- numpyro_forecast/typing.py +40 -0
- numpyro_forecast/util.py +225 -0
- numpyro_forecast-0.1.0.dist-info/METADATA +266 -0
- numpyro_forecast-0.1.0.dist-info/RECORD +13 -0
- numpyro_forecast-0.1.0.dist-info/WHEEL +4 -0
- numpyro_forecast-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""numpyro_forecast: a JAX/NumPyro port of Pyro's forecasting module."""
|
|
2
|
+
|
|
3
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
4
|
+
|
|
5
|
+
from jaxtyping import install_import_hook
|
|
6
|
+
|
|
7
|
+
with install_import_hook("numpyro_forecast", "beartype.beartype"):
|
|
8
|
+
from numpyro_forecast import ( # noqa: F401
|
|
9
|
+
datasets,
|
|
10
|
+
evaluate,
|
|
11
|
+
forecaster,
|
|
12
|
+
functional,
|
|
13
|
+
metrics,
|
|
14
|
+
util,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from numpyro_forecast.evaluate import (
|
|
18
|
+
BacktestResult,
|
|
19
|
+
backtest,
|
|
20
|
+
eval_coverage,
|
|
21
|
+
eval_crps,
|
|
22
|
+
eval_mae,
|
|
23
|
+
eval_rmse,
|
|
24
|
+
)
|
|
25
|
+
from numpyro_forecast.forecaster import Forecaster, ForecastingModel, HMCForecaster
|
|
26
|
+
from numpyro_forecast.functional import forecasting_model
|
|
27
|
+
|
|
28
|
+
# Resolve the version from installed distribution metadata. When running from
# a source tree without installed metadata, use a PEP 440 local-version
# placeholder so ``__version__`` is always a string.
try:
    _dist_version = version("numpyro_forecast")
except PackageNotFoundError:  # pragma: no cover - package not installed
    _dist_version = "0.0.0+unknown"
__version__ = _dist_version
del _dist_version

# Public names re-exported at package level (kept in sorted order).
__all__ = [
    "BacktestResult",
    "Forecaster",
    "ForecastingModel",
    "HMCForecaster",
    "__version__",
    "backtest",
    "eval_coverage",
    "eval_crps",
    "eval_mae",
    "eval_rmse",
    "forecasting_model",
]
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""BART ridership dataset helpers.
|
|
2
|
+
|
|
3
|
+
Thin wrappers around :func:`numpyro.examples.datasets.load_bart_od` that return
|
|
4
|
+
arrays in the package convention (time at axis ``-2``) for the two examples.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
from jaxtyping import Float
|
|
9
|
+
|
|
10
|
+
from numpyro_forecast.typing import Array
|
|
11
|
+
|
|
12
|
+
HOURS_PER_WEEK = 7 * 24  # 168: number of hourly samples making up one week
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def bart_available() -> bool:
    """Report whether the BART origin-destination data can be loaded.

    Any exception raised while loading (for example a failed download) is
    treated as "unavailable" rather than propagated.

    Returns
    -------
    bool
        ``False`` if loading raised an exception, ``True`` otherwise.
    """
    available = True
    try:
        _load_counts()
    except Exception:  # deliberate best-effort probe: any failure means "no"
        available = False
    return available
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _load_counts() -> tuple[Array, list[str]]:
    """Fetch the raw hourly BART origin-destination counts and station names.

    Returns
    -------
    tuple[Array, list[str]]
        The counts as a JAX array ordered ``(time, origin, destin)`` and the
        station names converted to ``str``.
    """
    # Imported lazily so numpyro's example-dataset machinery is only touched
    # when the data is actually requested.
    from numpyro.examples.datasets import load_bart_od

    raw = load_bart_od()
    hourly_counts = jnp.asarray(raw["counts"])
    station_names = list(map(str, raw["stations"]))
    return hourly_counts, station_names
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def load_bart_weekly() -> Float[Array, " weeks 1"]:
    """Return log total BART ridership per week, shaped for univariate models.

    For each hour, all origin-destination counts are summed. The hourly totals
    are then grouped into consecutive, non-overlapping blocks of
    ``HOURS_PER_WEEK`` hours, dropping any trailing partial week. Each block is
    summed, and the natural log is taken.

    Returns
    -------
    Float[Array, " weeks 1"]
        Log weekly totals. Time is on axis ``-2`` and the trailing axis is a
        single observation dimension.
    """
    counts, _ = _load_counts()
    hourly = jnp.sum(counts, axis=(1, 2))  # collapse origin and destination
    n_weeks = len(hourly) // HOURS_PER_WEEK
    # Truncate to whole weeks, then lay the hours out as (week, hour-of-week).
    weeks = jnp.reshape(hourly[: n_weeks * HOURS_PER_WEEK], (n_weeks, HOURS_PER_WEEK))
    return jnp.expand_dims(jnp.log(weeks.sum(axis=1)), axis=-1)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def load_bart_hierarchical(
    train_days: int = 90,
    test_weeks: int = 2,
) -> tuple[Float[Array, " origin time destin"], int, list[str]]:
    """Load the windowed hierarchical BART panel for the hierarchical example.

    The final ``train_days * 24 + test_weeks * HOURS_PER_WEEK`` hours of the
    history are selected: a ``train_days`` training window followed by a
    ``test_weeks`` test window. They are transposed to the
    ``(origin, time, destin)`` convention and ``log1p``-transformed.

    Parameters
    ----------
    train_days
        Number of training days (24 hours each); must be at least 1.
    test_weeks
        Number of test weeks (``24 * 7`` hours each); must be non-negative.

    Returns
    -------
    y : Float[Array, " origin time destin"]
        Log counts over the train+test window with time at axis ``-2``.
    split : int
        Index along the time axis separating train from test.
    stations : list[str]
        Station names.

    Raises
    ------
    ValueError
        If ``train_days < 1`` or ``test_weeks < 0``. Also raised if the
        requested window exceeds the available history, which would otherwise
        wrap a negative slice index.
    """
    # Validate before loading, so bad arguments fail without fetching the data.
    # A negative test_weeks would otherwise put the split past the end of y.
    if train_days < 1 or test_weeks < 0:
        msg = (
            "train_days must be >= 1 and test_weeks must be >= 0, "
            f"got train_days={train_days}, test_weeks={test_weeks}"
        )
        raise ValueError(msg)
    counts, stations = _load_counts()
    t_total = counts.shape[0]  # raw counts are ordered (time, origin, destin)
    t1 = t_total - test_weeks * HOURS_PER_WEEK
    t0 = t1 - train_days * 24
    if t0 < 0:
        msg = (
            f"requested window (train_days={train_days}, test_weeks={test_weeks}) "
            f"exceeds available history of {t_total} hours"
        )
        raise ValueError(msg)
    # Slice the time window first, so the transpose and log1p only touch the
    # returned hours rather than the entire history.
    y = jnp.log1p(jnp.transpose(counts[t0:t_total], (1, 0, 2)))
    split = t1 - t0
    return y, split, stations
|
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
"""Backtesting and evaluation metrics.
|
|
2
|
+
|
|
3
|
+
This is the JAX/NumPyro port of ``pyro.contrib.forecast.evaluate``. Unlike Pyro
|
|
4
|
+
there is no global parameter store, so each backtest window is a pure call that
|
|
5
|
+
fits its own forecaster.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from collections.abc import Callable, Mapping
|
|
9
|
+
from dataclasses import asdict, dataclass, field
|
|
10
|
+
from time import perf_counter
|
|
11
|
+
from typing import Any, cast
|
|
12
|
+
|
|
13
|
+
import jax.numpy as jnp
|
|
14
|
+
from jax import random
|
|
15
|
+
|
|
16
|
+
from numpyro_forecast.forecaster import Forecaster
|
|
17
|
+
from numpyro_forecast.metrics import crps_empirical
|
|
18
|
+
from numpyro_forecast.typing import Array, ForecasterFactory, Metric, ModelFactory
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def eval_mae(pred: Array, truth: Array) -> float:
    """Compute the mean absolute error of the median forecast.

    The element-wise median over the leading sample axis of ``pred`` is used
    as the point forecast. Absolute errors are averaged over all elements.

    Parameters
    ----------
    pred
        Forecast samples; axis 0 indexes samples.
    truth
        Observed values, shaped like ``pred`` without its sample axis.

    Returns
    -------
    float
        The mean absolute error.
    """
    abs_err = jnp.abs(jnp.median(pred, axis=0) - truth)
    return float(jnp.mean(abs_err))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def eval_rmse(pred: Array, truth: Array) -> float:
    """Compute the root mean squared error of the mean forecast.

    The element-wise mean over the leading sample axis of ``pred`` is used as
    the point forecast. Squared residuals are averaged over all elements
    before taking the square root.

    Parameters
    ----------
    pred
        Forecast samples; axis 0 indexes samples.
    truth
        Observed values, shaped like ``pred`` without its sample axis.

    Returns
    -------
    float
        The root mean squared error.
    """
    residual = jnp.mean(pred, axis=0) - truth
    mse = jnp.mean(jnp.square(residual))
    return float(jnp.sqrt(mse))
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def eval_crps(pred: Array, truth: Array) -> float:
    """Average the empirical CRPS of ``pred`` against ``truth``.

    Delegates the per-element score to :func:`crps_empirical` and averages the
    result over every element.

    Parameters
    ----------
    pred
        Forecast samples; axis 0 indexes samples.
    truth
        Observed values, shaped like ``pred`` without its sample axis.

    Returns
    -------
    float
        The mean empirical CRPS.
    """
    scores = crps_empirical(pred, truth)
    return float(scores.mean())
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def eval_coverage(pred: Array, truth: Array, *, alpha: float = 0.9) -> float:
    """Empirical coverage of the central ``alpha`` prediction interval.

    The central ``alpha`` interval is bounded by the ``(1 - alpha) / 2`` and
    ``1 - (1 - alpha) / 2`` quantiles of the forecast samples. The metric is
    the fraction of ground-truth values that fall inside it, bounds included.
    A well-calibrated forecast has coverage close to ``alpha``.

    Parameters
    ----------
    pred
        Forecast samples with the sample axis first.
    truth
        Ground-truth values (matching ``pred`` without the sample axis).
    alpha
        Nominal interval level in ``[0, 1]`` (defaults to ``0.9``). ``1``
        spans the sample min to max, and ``0`` collapses to the median.

    Returns
    -------
    float
        The fraction of ground-truth values inside the central ``alpha``
        interval.

    Raises
    ------
    ValueError
        If ``alpha`` lies outside ``[0, 1]`` (or is NaN). The implied quantile
        levels would be invalid and silently yield a meaningless coverage.
    """
    if not 0.0 <= alpha <= 1.0:
        msg = f"alpha must lie in [0, 1], got {alpha}"
        raise ValueError(msg)
    tail = (1.0 - alpha) / 2.0
    # One quantile call for both bounds. The two levels land on a new leading
    # axis, which is unpacked into the lower and upper bound.
    lo, hi = jnp.quantile(pred, jnp.array([tail, 1.0 - tail]), axis=0)
    inside = (truth >= lo) & (truth <= hi)
    return float(inside.mean())
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# Metric name -> fn(pred, truth) -> float. backtest evaluates these in
# insertion order when filling each window's ``metrics`` dict.
DEFAULT_METRICS: dict[str, Metric] = dict(
    mae=eval_mae,
    rmse=eval_rmse,
    crps=eval_crps,
    coverage=eval_coverage,
)
"""Default metrics used by :func:`backtest`."""
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@dataclass(frozen=True)
class BacktestResult:
    """Per-window result of a :func:`backtest` run.

    Attributes
    ----------
    t0, t1, t2
        Train-begin, train/test split, and test-end time indices.
    num_samples
        Number of forecast samples drawn.
    train_walltime, test_walltime
        Wall-clock seconds for fitting and forecasting.
    metrics
        Mapping of metric name to value for the window.
    params
        Mapping of scalar parameter name to value (when available).
    """

    t0: int
    t1: int
    t2: int
    num_samples: int
    train_walltime: float
    test_walltime: float
    metrics: dict[str, float]
    params: dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a dictionary view with Pyro-style top-level metric keys.

        The result holds every dataclass field, with ``metrics`` and
        ``params`` kept as nested dictionaries as produced by
        :func:`dataclasses.asdict`. Each metric is also copied to the top
        level under its own name, so ``result["crps"]`` works like the dicts
        returned by Pyro's ``backtest``. A metric whose name collides with a
        field name (e.g. ``"t0"``) is not copied, so field values are never
        overwritten.

        Returns
        -------
        dict[str, Any]
            A fresh dictionary; mutating it does not affect this result,
            because ``asdict`` deep-copies the nested containers.
        """
        out = asdict(self)
        for name, value in self.metrics.items():
            out.setdefault(name, value)
        return out
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _scalar_params(forecaster: object) -> dict[str, float]:
|
|
153
|
+
"""Extract scalar variational parameters from a fitted forecaster, if any."""
|
|
154
|
+
params = getattr(forecaster, "params", None)
|
|
155
|
+
if not isinstance(params, Mapping):
|
|
156
|
+
return {}
|
|
157
|
+
return {name: float(value) for name, value in params.items() if jnp.size(value) == 1}
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _backtest_windows(
    duration: int,
    train_window: int | None,
    min_train_window: int,
    test_window: int | None,
    min_test_window: int,
    stride: int,
) -> list[tuple[int, int, int]]:
    """Return the ``(t0, t1, t2)`` index triples visited by :func:`backtest`.

    ``t1`` runs from the first admissible split to the last split that still
    leaves room for a test window, in steps of ``stride``.

    - ``train_window=None``: training expands from ``t0 = 0``; otherwise it
      slides with a fixed length.
    - ``test_window=None``: each test window ends at ``duration``.
    """
    start = min_train_window if train_window is None else train_window
    stop = duration - (min_test_window if test_window is None else test_window) + 1
    return [
        (
            0 if train_window is None else t1 - train_window,
            t1,
            duration if test_window is None else t1 + test_window,
        )
        for t1 in range(start, stop, stride)
    ]


def backtest(
    rng_key: Array,
    data: Array,
    covariates: Array,
    model_fn: ModelFactory,
    *,
    forecaster_fn: ForecasterFactory = Forecaster,
    metrics: Mapping[str, Metric] | None = None,
    transform: Callable[[Array, Array], tuple[Array, Array]] | None = None,
    train_window: int | None = None,
    min_train_window: int = 1,
    test_window: int | None = None,
    min_test_window: int = 1,
    stride: int = 1,
    num_samples: int = 100,
    batch_size: int | None = None,
    forecaster_options: Mapping[str, Any] | Callable[..., Mapping[str, Any]] | None = None,
) -> list[BacktestResult]:
    """Backtest a forecasting model on a moving window of ``(train, test)`` data.

    Parameters
    ----------
    rng_key
        Base PRNG key (used for every window, matching Pyro).
    data
        Dataset with time at axis ``-2``.
    covariates
        Covariates with time at axis ``-2`` (same duration as ``data``).
    model_fn
        Factory returning a fresh :class:`ForecastingModel` per window.
    forecaster_fn
        Factory returning a fitted forecaster (defaults to :class:`Forecaster`).
    metrics
        Mapping of metric name to function; defaults to :data:`DEFAULT_METRICS`.
    transform
        Optional ``(pred, truth) -> (pred, truth)`` applied before metrics.
    train_window
        Training window size; if ``None`` the window expands from the start.
    min_train_window
        Minimum training window size when ``train_window`` is ``None``.
    test_window
        Test window size; if ``None`` forecasts to the end of the data.
    min_test_window
        Minimum test window size when ``test_window`` is ``None``.
    stride
        Step between successive train/test splits; must be positive.
    num_samples
        Number of forecast samples per window.
    batch_size
        Optional forecast-sampling chunk size.
    forecaster_options
        Options dict passed to ``forecaster_fn``, or a callable
        ``(t0, t1, t2) -> dict`` returning per-window options.

    Returns
    -------
    list[BacktestResult]
        One result per backtest window.

    Raises
    ------
    ValueError
        If ``data`` and ``covariates`` differ in time-axis length, or if
        ``stride`` is not positive.

    Notes
    -----
    JAX dispatches work asynchronously. The wall times therefore wait for the
    forecast samples, and any array-valued ``forecaster.params``, to be ready
    before the clock is read. Otherwise only dispatch time would be measured.
    """
    import jax  # function-scope: only needed for jax.block_until_ready

    if data.shape[-2] != covariates.shape[-2]:
        msg = "data and covariates must share the time axis length"
        raise ValueError(msg)
    if stride < 1:
        msg = f"stride must be a positive integer, got {stride}"
        raise ValueError(msg)
    metrics = DEFAULT_METRICS if metrics is None else metrics

    def options_for(t0: int, t1: int, t2: int) -> Mapping[str, Any]:
        """Resolve ``forecaster_options`` for one window (static or per-window)."""
        if forecaster_options is None:
            return {}
        if callable(forecaster_options) and not isinstance(forecaster_options, Mapping):
            return forecaster_options(t0=t0, t1=t1, t2=t2)
        return cast("Mapping[str, Any]", forecaster_options)

    # Every window reuses the same base key, so the split is loop-invariant.
    key_fit, key_forecast = random.split(rng_key)

    windows = _backtest_windows(
        data.shape[-2],
        train_window,
        min_train_window,
        test_window,
        min_test_window,
        stride,
    )

    results: list[BacktestResult] = []
    for t0, t1, t2 in windows:
        train_data = data[..., t0:t1, :]
        train_covariates = covariates[..., t0:t1, :]
        # The forecaster is given covariates spanning train + test.
        test_covariates = covariates[..., t0:t2, :]
        truth = data[..., t1:t2, :]

        options = options_for(t0, t1, t2)

        fit_start = perf_counter()
        model = model_fn()
        forecaster = forecaster_fn(key_fit, model, train_data, train_covariates, **options)
        # Wait for any pending array work in fitted params (None is a no-op),
        # so it is not billed to the forecast timing below.
        jax.block_until_ready(getattr(forecaster, "params", None))
        train_walltime = perf_counter() - fit_start

        forecast_start = perf_counter()
        pred = forecaster(
            key_forecast,
            train_data,
            test_covariates,
            num_samples,
            batch_size=batch_size,
        )
        # Without this, test_walltime would only measure asynchronous dispatch.
        pred = jax.block_until_ready(pred)
        test_walltime = perf_counter() - forecast_start

        if transform is not None:
            pred, truth = transform(pred, truth)

        results.append(
            BacktestResult(
                t0=t0,
                t1=t1,
                t2=t2,
                num_samples=num_samples,
                train_walltime=train_walltime,
                test_walltime=test_walltime,
                metrics={name: fn(pred, truth) for name, fn in metrics.items()},
                params=_scalar_params(forecaster),
            )
        )

    return results
|