numpyro-forecast 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,45 @@
1
+ """numpyro_forecast: a JAX/NumPyro port of Pyro's forecasting module."""
2
+
3
+ from importlib.metadata import PackageNotFoundError, version
4
+
5
+ from jaxtyping import install_import_hook
6
+
7
+ with install_import_hook("numpyro_forecast", "beartype.beartype"):
8
+ from numpyro_forecast import ( # noqa: F401
9
+ datasets,
10
+ evaluate,
11
+ forecaster,
12
+ functional,
13
+ metrics,
14
+ util,
15
+ )
16
+
17
+ from numpyro_forecast.evaluate import (
18
+ BacktestResult,
19
+ backtest,
20
+ eval_coverage,
21
+ eval_crps,
22
+ eval_mae,
23
+ eval_rmse,
24
+ )
25
+ from numpyro_forecast.forecaster import Forecaster, ForecastingModel, HMCForecaster
26
+ from numpyro_forecast.functional import forecasting_model
27
+
28
# Version of the installed distribution; the placeholder below is kept when
# the distribution metadata cannot be found (package not installed).
__version__ = "0.0.0+unknown"
try:
    __version__ = version("numpyro_forecast")
except PackageNotFoundError:  # pragma: no cover - package not installed
    pass

# Public names exported by ``from numpyro_forecast import *`` (sorted).
__all__ = [
    "BacktestResult",
    "Forecaster",
    "ForecastingModel",
    "HMCForecaster",
    "__version__",
    "backtest",
    "eval_coverage",
    "eval_crps",
    "eval_mae",
    "eval_rmse",
    "forecasting_model",
]
@@ -0,0 +1,104 @@
1
+ """BART ridership dataset helpers.
2
+
3
+ Thin wrappers around :func:`numpyro.examples.datasets.load_bart_od` that return
4
+ arrays in the package convention (time at axis ``-2``) for the two examples.
5
+ """
6
+
7
+ import jax.numpy as jnp
8
+ from jaxtyping import Float
9
+
10
+ from numpyro_forecast.typing import Array
11
+
12
#: Number of hourly time steps in one week (7 days of 24 hours).
HOURS_PER_WEEK = 7 * 24
13
+
14
+
15
def bart_available() -> bool:
    """Report whether the BART origin-destination data can be loaded.

    This is a deliberate best-effort probe: any exception raised while
    loading (for instance a failed download) means "not available".

    Returns
    -------
    bool
        ``False`` if loading the counts raised, ``True`` otherwise.
    """
    try:
        _load_counts()
    except Exception:  # noqa: BLE001 - any failure counts as "unavailable"
        available = False
    else:
        available = True
    return available
28
+
29
+
30
def _load_counts() -> tuple[Array, list[str]]:
    """Fetch the hourly BART origin-destination counts and station names.

    Returns
    -------
    counts : Array
        Raw counts laid out as ``(time, origin, destin)``.
    stations : list[str]
        Station names, each coerced to ``str``.
    """
    # NOTE(review): imported inside the function, presumably so numpyro's
    # example-data module is only loaded when BART data is actually requested.
    from numpyro.examples.datasets import load_bart_od

    raw = load_bart_od()
    od_counts = jnp.asarray(raw["counts"])
    station_names = list(map(str, raw["stations"]))
    return od_counts, station_names
38
+
39
+
40
def load_bart_weekly() -> Float[Array, " weeks 1"]:
    """Return log total weekly BART ridership for the univariate example.

    The hourly counts are summed over every origin-destination pair, any
    trailing partial week is discarded, each complete week is summed, and the
    weekly totals are log-transformed.

    Returns
    -------
    Float[Array, " weeks 1"]
        Log weekly totals with time at axis ``-2`` and one observation dim.
    """
    counts, _stations = _load_counts()
    # Collapse origin and destination into a single hourly ridership series.
    per_hour = counts.sum(axis=(1, 2))
    n_weeks = per_hour.shape[0] // HOURS_PER_WEEK
    # Keep only whole weeks so the series reshapes to (weeks, hours_per_week).
    whole_weeks = per_hour[: n_weeks * HOURS_PER_WEEK]
    per_week = whole_weeks.reshape(n_weeks, HOURS_PER_WEEK).sum(axis=1)
    return jnp.expand_dims(jnp.log(per_week), axis=-1)
57
+
58
+
59
def load_bart_hierarchical(
    train_days: int = 90,
    test_weeks: int = 2,
) -> tuple[Float[Array, " origin time destin"], int, list[str]]:
    """Load the windowed hierarchical BART panel for the hierarchical example.

    The hourly counts are restricted to a ``train_days`` training window
    followed by a ``test_weeks`` test window ending at the last available
    hour. The result is then ``log1p``-transformed and transposed to the
    ``(origin, time, destin)`` convention.

    Parameters
    ----------
    train_days
        Number of training days (24 hours each); must be positive.
    test_weeks
        Number of test weeks (``24 * 7`` hours each); must be non-negative.

    Returns
    -------
    y : Float[Array, " origin time destin"]
        Log counts over the train+test window with time at axis ``-2``.
    split : int
        Index along the time axis separating train from test.
    stations : list[str]
        Station names.

    Raises
    ------
    ValueError
        If ``train_days`` is not positive, ``test_weeks`` is negative, or the
        requested ``train_days`` + ``test_weeks`` window exceeds the available
        history (which would otherwise wrap a negative slice index).
    """
    # Validate arguments before loading (and possibly downloading) the data.
    if train_days <= 0:
        msg = f"train_days must be positive, got {train_days}"
        raise ValueError(msg)
    if test_weeks < 0:
        # A negative test window would put ``split`` past the end of ``y``.
        msg = f"test_weeks must be non-negative, got {test_weeks}"
        raise ValueError(msg)

    counts, stations = _load_counts()  # (time, origin, destin)
    t_total = counts.shape[0]
    t1 = t_total - test_weeks * HOURS_PER_WEEK
    t0 = t1 - train_days * 24
    if t0 < 0:
        msg = (
            f"requested window (train_days={train_days}, test_weeks={test_weeks}) "
            f"exceeds available history of {t_total} hours"
        )
        raise ValueError(msg)
    # Slice the time window first so the transpose/log1p only touch the hours
    # that are returned rather than the whole history.
    y = jnp.log1p(jnp.transpose(counts[t0:t_total], (1, 0, 2)))
    split = t1 - t0
    return y, split, stations
@@ -0,0 +1,279 @@
1
+ """Backtesting and evaluation metrics.
2
+
3
+ This is the JAX/NumPyro port of ``pyro.contrib.forecast.evaluate``. Unlike Pyro
4
+ there is no global parameter store, so each backtest window is a pure call that
5
+ fits its own forecaster.
6
+ """
7
+
8
from collections.abc import Callable, Mapping
from dataclasses import asdict, dataclass, field
from time import perf_counter
from typing import Any, cast

import jax
import jax.numpy as jnp
from jax import random

from numpyro_forecast.forecaster import Forecaster
from numpyro_forecast.metrics import crps_empirical
from numpyro_forecast.typing import Array, ForecasterFactory, Metric, ModelFactory
19
+
20
+
21
def eval_mae(pred: Array, truth: Array) -> float:
    """Compute the mean absolute error of the forecast median.

    Parameters
    ----------
    pred
        Forecast samples, with samples along the leading axis.
    truth
        Observed values, shaped like ``pred`` minus its sample axis.

    Returns
    -------
    float
        Mean of ``|median(pred) - truth|`` over all elements.
    """
    abs_err = jnp.abs(jnp.median(pred, axis=0) - truth)
    return float(jnp.mean(abs_err))
38
+
39
+
40
def eval_rmse(pred: Array, truth: Array) -> float:
    """Compute the root mean squared error of the forecast mean.

    Parameters
    ----------
    pred
        Forecast samples, with samples along the leading axis.
    truth
        Observed values, shaped like ``pred`` minus its sample axis.

    Returns
    -------
    float
        Square root of the mean of ``(mean(pred) - truth) ** 2``.
    """
    residual = pred.mean(axis=0) - truth
    mse = jnp.mean(jnp.square(residual))
    return float(jnp.sqrt(mse))
57
+
58
+
59
def eval_crps(pred: Array, truth: Array) -> float:
    """Compute the empirical CRPS averaged over every data element.

    Parameters
    ----------
    pred
        Forecast samples, with samples along the leading axis.
    truth
        Observed values, shaped like ``pred`` minus its sample axis.

    Returns
    -------
    float
        Mean of the element-wise scores returned by ``crps_empirical``.
    """
    per_element = crps_empirical(pred, truth)
    return float(jnp.mean(per_element))
75
+
76
+
77
def eval_coverage(pred: Array, truth: Array, *, alpha: float = 0.9) -> float:
    """Fraction of observations inside the central ``alpha`` forecast interval.

    The interval runs from the ``(1 - alpha) / 2`` quantile to the
    ``1 - (1 - alpha) / 2`` quantile of the samples (endpoints inclusive). A
    well-calibrated forecast yields a value close to ``alpha``.

    Parameters
    ----------
    pred
        Forecast samples, with samples along the leading axis.
    truth
        Observed values, shaped like ``pred`` minus its sample axis.
    alpha
        Nominal interval level in ``(0, 1)``; defaults to ``0.9``.

    Returns
    -------
    float
        Share of ``truth`` elements lying within the interval.
    """
    lower_q = 0.5 * (1.0 - alpha)  # identical to (1 - alpha) / 2 in IEEE floats
    upper_q = 1.0 - lower_q
    lower = jnp.quantile(pred, lower_q, axis=0)
    upper = jnp.quantile(pred, upper_q, axis=0)
    inside = jnp.logical_and(lower <= truth, truth <= upper)
    return float(jnp.mean(inside))
103
+
104
+
105
#: Default metrics used by :func:`backtest` (name -> metric function).
DEFAULT_METRICS: dict[str, Metric] = {
    "mae": eval_mae,  # median point forecast, absolute error
    "rmse": eval_rmse,  # mean point forecast, squared error
    "crps": eval_crps,  # full predictive distribution
    "coverage": eval_coverage,  # central 90% interval by default
}
112
+
113
+
114
@dataclass(frozen=True)
class BacktestResult:
    """Outcome of fitting and scoring a single :func:`backtest` window.

    Instances are immutable (``frozen=True``); :meth:`to_dict` provides the
    flat mapping view for Pyro-style access.

    Attributes
    ----------
    t0, t1, t2
        Window boundaries: training starts at ``t0``, the train/test split is
        at ``t1``, and the test segment ends at ``t2``.
    num_samples
        How many forecast samples were drawn for the window.
    train_walltime, test_walltime
        Elapsed wall-clock seconds spent fitting and forecasting, respectively.
    metrics
        Metric name mapped to its value for this window.
    params
        Scalar parameter name mapped to its value; empty when unavailable.
    """

    # Window boundaries along the time axis.
    t0: int
    t1: int
    t2: int
    # Sample count and timings.
    num_samples: int
    train_walltime: float
    test_walltime: float
    # Per-window scores and (optional) scalar fitted parameters.
    metrics: dict[str, float]
    params: dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return every field as a plain ``dict``.

        :func:`dataclasses.asdict` recurses into field values, so the nested
        ``metrics`` and ``params`` dicts in the result are fresh copies.

        Returns
        -------
        dict[str, Any]
            Field name mapped to value, in declaration order.
        """
        as_mapping: dict[str, Any] = asdict(self)
        return as_mapping
150
+
151
+
152
+ def _scalar_params(forecaster: object) -> dict[str, float]:
153
+ """Extract scalar variational parameters from a fitted forecaster, if any."""
154
+ params = getattr(forecaster, "params", None)
155
+ if not isinstance(params, Mapping):
156
+ return {}
157
+ return {name: float(value) for name, value in params.items() if jnp.size(value) == 1}
158
+
159
+
160
def backtest(
    rng_key: Array,
    data: Array,
    covariates: Array,
    model_fn: ModelFactory,
    *,
    forecaster_fn: ForecasterFactory = Forecaster,
    metrics: Mapping[str, Metric] | None = None,
    transform: Callable[[Array, Array], tuple[Array, Array]] | None = None,
    train_window: int | None = None,
    min_train_window: int = 1,
    test_window: int | None = None,
    min_test_window: int = 1,
    stride: int = 1,
    num_samples: int = 100,
    batch_size: int | None = None,
    forecaster_options: Mapping[str, Any] | Callable[..., Mapping[str, Any]] | None = None,
) -> list[BacktestResult]:
    """Backtest a forecasting model on a moving window of ``(train, test)`` data.

    Every window satisfies ``0 <= t0 < t1 < t2 <= duration``: training data is
    ``data[..., t0:t1, :]`` and the forecast is scored against
    ``data[..., t1:t2, :]``.

    Parameters
    ----------
    rng_key
        Base PRNG key (used for every window, matching Pyro).
    data
        Dataset with time at axis ``-2``.
    covariates
        Covariates with time at axis ``-2`` (same duration as ``data``).
    model_fn
        Factory returning a fresh :class:`ForecastingModel` per window.
    forecaster_fn
        Factory returning a fitted forecaster (defaults to :class:`Forecaster`).
    metrics
        Mapping of metric name to function; defaults to :data:`DEFAULT_METRICS`.
    transform
        Optional ``(pred, truth) -> (pred, truth)`` applied before metrics.
    train_window
        Training window size; if ``None`` the window expands from the start.
    min_train_window
        Minimum training window size when ``train_window`` is ``None``.
    test_window
        Test window size; if ``None`` forecasts to the end of the data.
    min_test_window
        Minimum test window size when ``test_window`` is ``None``.
    stride
        Step between successive train/test splits.
    num_samples
        Number of forecast samples per window.
    batch_size
        Optional forecast-sampling chunk size.
    forecaster_options
        Options dict passed to ``forecaster_fn``, or a callable invoked as
        ``forecaster_options(t0=..., t1=..., t2=...)`` returning per-window
        options.

    Returns
    -------
    list[BacktestResult]
        One result per backtest window (empty if the data is too short).

    Raises
    ------
    ValueError
        If ``data`` and ``covariates`` differ in time length, or if ``stride``
        or the effective train/test window size is smaller than 1.
    """
    if data.shape[-2] != covariates.shape[-2]:
        msg = "data and covariates must share the time axis length"
        raise ValueError(msg)

    # Effective (minimum) window sizes; both must be >= 1 so that no window
    # has empty training data or empty ground truth.
    train_size = min_train_window if train_window is None else train_window
    test_size = min_test_window if test_window is None else test_window
    for name, value in (
        ("stride", stride),
        ("train window", train_size),
        ("test window", test_size),
    ):
        if value < 1:
            msg = f"{name} must be >= 1, got {value}"
            raise ValueError(msg)

    metrics = DEFAULT_METRICS if metrics is None else metrics

    def options_for(t0: int, t1: int, t2: int) -> Mapping[str, Any]:
        """Resolve the forecaster options for the window ``(t0, t1, t2)``."""
        if forecaster_options is None:
            return {}
        if callable(forecaster_options) and not isinstance(forecaster_options, Mapping):
            return forecaster_options(t0=t0, t1=t1, t2=t2)
        return cast("Mapping[str, Any]", forecaster_options)

    duration = data.shape[-2]
    # The last split leaves at least ``test_size`` steps for testing.
    stop = duration - test_size + 1

    results: list[BacktestResult] = []
    for t1 in range(train_size, stop, stride):
        t0 = 0 if train_window is None else t1 - train_window
        t2 = duration if test_window is None else t1 + test_window

        train_data = data[..., t0:t1, :]
        train_covariates = covariates[..., t0:t1, :]
        test_covariates = covariates[..., t0:t2, :]
        truth = data[..., t1:t2, :]

        key_fit, key_forecast = random.split(rng_key)
        options = options_for(t0, t1, t2)

        fit_start = perf_counter()
        model = model_fn()
        forecaster = forecaster_fn(key_fit, model, train_data, train_covariates, **options)
        train_walltime = perf_counter() - fit_start

        forecast_start = perf_counter()
        pred = forecaster(
            key_forecast,
            train_data,
            test_covariates,
            num_samples,
            batch_size=batch_size,
        )
        # JAX dispatches work asynchronously; wait for the samples so the
        # timer measures the forecast itself rather than only its dispatch.
        pred = jax.block_until_ready(pred)
        test_walltime = perf_counter() - forecast_start

        if transform is not None:
            pred, truth = transform(pred, truth)

        results.append(
            BacktestResult(
                t0=t0,
                t1=t1,
                t2=t2,
                num_samples=num_samples,
                train_walltime=train_walltime,
                test_walltime=test_walltime,
                metrics={name: fn(pred, truth) for name, fn in metrics.items()},
                params=_scalar_params(forecaster),
            )
        )

    return results