tsagentkit-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
tsagentkit/models/adapters/chronos.py
@@ -0,0 +1,387 @@
"""Amazon Chronos2 TSFM adapter.

Adapter for Amazon's Chronos2 time series forecasting models.
Chronos2 is a family of pretrained models based on T5 architecture
that supports zero-shot forecasting.

Reference: https://github.com/amazon-science/chronos-forecasting
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

from tsagentkit.utils import quantile_col_name

from .base import TSFMAdapter

if TYPE_CHECKING:
    from tsagentkit.contracts import ForecastResult, ModelArtifact
    from tsagentkit.series import TSDataset


class ChronosAdapter(TSFMAdapter):
    """Adapter for Amazon Chronos2 time series models.

    Chronos2 is a family of pretrained time series forecasting models
    based on T5 architecture. It supports zero-shot forecasting on
    unseen time series data.

    Available model sizes:
    - small: AutoGluon Chronos-2-Small (fast, lightweight)
    - base: Amazon Chronos-2 (best accuracy)

    Example:
        >>> config = AdapterConfig(model_name="chronos", model_size="small")
        >>> adapter = ChronosAdapter(config)
        >>> adapter.load_model()
        >>> result = adapter.predict(dataset, horizon=30)

    Reference:
        https://github.com/amazon-science/chronos-forecasting
    """

    # HuggingFace model IDs for Chronos2
    MODEL_SIZES = {
        "small": "autogluon/chronos-2-small",
        "base": "amazon/chronos-2",
    }

    def load_model(self) -> None:
        """Load Chronos model from HuggingFace.

        Downloads and caches the model if not already present.

        Raises:
            ImportError: If chronos-forecasting is not installed
            RuntimeError: If model loading fails
        """
        try:
            from chronos import Chronos2Pipeline
        except ImportError as e:
            raise ImportError(
                "chronos-forecasting>=2.0.0 is required for ChronosAdapter. "
                "Install with: pip install 'chronos-forecasting>=2.0.0'"
            ) from e

        model_id = self.MODEL_SIZES.get(
            self.config.model_size,
            self.MODEL_SIZES["small"]
        )

        try:
            self._model = Chronos2Pipeline.from_pretrained(
                model_id,
                device_map=self._device,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load Chronos model: {e}") from e

    def fit(
        self,
        dataset: TSDataset,
        prediction_length: int,
        quantiles: list[float] | None = None,
    ) -> ModelArtifact:
        """Prepare Chronos for prediction.

        Chronos is a zero-shot model and doesn't require training.

        Args:
            dataset: Dataset to validate
            prediction_length: Forecast horizon
            quantiles: Optional quantile levels

        Returns:
            ModelArtifact with model reference
        """
        from tsagentkit.contracts import ModelArtifact

        if not self.is_loaded:
            self.load_model()

        self._validate_dataset(dataset)

        return ModelArtifact(
            model=self._model,
            model_name=f"chronos-{self.config.model_size}",
            config={
                "model_size": self.config.model_size,
                "device": self._device,
                "prediction_length": prediction_length,
                "quantiles": quantiles,
            },
        )

    def predict(
        self,
        dataset: TSDataset,
        horizon: int,
        quantiles: list[float] | None = None,
    ) -> ForecastResult:
        """Generate forecasts using Chronos.

        Supports covariate-informed forecasting when dataset contains
        past_covariates (past_x) and/or future_covariates (future_x).

        Args:
            dataset: Historical data for context, optionally with covariates
            horizon: Number of steps to forecast
            quantiles: Quantile levels for probabilistic forecasts

        Returns:
            ForecastResult with predictions and provenance
        """
        if not self.is_loaded:
            self.load_model()

        context_df, future_df = self._to_chronos_df(dataset, horizon)

        # Use predict_df for pandas-friendly API
        # quantile_levels must not be None for Chronos 2.0
        quantile_levels = quantiles if quantiles is not None else [0.1, 0.5, 0.9]
        pred_df = self._model.predict_df(
            context_df,
            future_df=future_df if future_df is not None else None,
            id_column="item_id",
            timestamp_column="timestamp",
            target="target",
            prediction_length=horizon,
            quantile_levels=quantile_levels,
        )

        return self._to_forecast_result(pred_df, dataset, horizon, quantiles)

    def _to_chronos_df(
        self, dataset: TSDataset, horizon: int
    ) -> tuple[pd.DataFrame, pd.DataFrame | None]:
        """Convert TSDataset to Chronos DataFrame format with covariates.

        Args:
            dataset: Input dataset with optional covariates
            horizon: Forecast horizon for generating future timestamps

        Returns:
            Tuple of (context_df, future_df):
            - context_df: Historical data with target and covariates
            - future_df: Future covariates only (no target), or None if no future covariates
        """
        # Start with base columns
        df = dataset.df[["unique_id", "ds", "y"]].copy()
        df = df.sort_values(["unique_id", "ds"]).reset_index(drop=True)

        # Handle missing values in target
        if df["y"].isna().any():
            df["y"] = df.groupby("unique_id")["y"].transform(
                self._handle_missing_values
            )

        # Merge past covariates if available
        if dataset.past_x is not None:
            df = self._merge_covariates(df, dataset.past_x, "past")

        # Limit context length if specified
        if self.config.max_context_length:
            df = df.groupby("unique_id", as_index=False).tail(
                self.config.max_context_length
            )

        # Rename columns for Chronos format
        context_df = df.rename(
            columns={"unique_id": "item_id", "ds": "timestamp", "y": "target"}
        )

        # Prepare future_df if future covariates are available
        future_df = None
        if dataset.future_x is not None:
            future_df = self._prepare_future_df(dataset, horizon)

        return context_df, future_df

    def _merge_covariates(
        self, df: pd.DataFrame, covariates: pd.DataFrame, cov_type: str
    ) -> pd.DataFrame:
        """Merge covariates into main DataFrame.

        Args:
            df: Main DataFrame with unique_id, ds, y
            covariates: Covariate DataFrame
            cov_type: Type of covariate ("past" or "future")

        Returns:
            Merged DataFrame
        """
        # Ensure covariates have proper index columns
        cov_df = covariates.copy()

        # Reset index if needed to get unique_id and ds as columns
        if isinstance(cov_df.index, pd.MultiIndex):
            cov_df = cov_df.reset_index()

        # Ensure required columns exist
        if "unique_id" not in cov_df.columns and "id" in cov_df.columns:
            cov_df = cov_df.rename(columns={"id": "unique_id"})
        if "ds" not in cov_df.columns and "timestamp" in cov_df.columns:
            cov_df = cov_df.rename(columns={"timestamp": "ds"})

        # Merge on unique_id and ds
        merge_cols = ["unique_id", "ds"]
        available_cols = [c for c in merge_cols if c in cov_df.columns]

        if len(available_cols) == 2:
            # Handle missing values in covariates before merging
            cov_cols = [c for c in cov_df.columns if c not in merge_cols]
            for col in cov_cols:
                if cov_df[col].isna().any():
                    cov_df[col] = self._handle_missing_values(cov_df[col])

            df = df.merge(cov_df, on=merge_cols, how="left")

        return df

    def _prepare_future_df(
        self, dataset: TSDataset, horizon: int
    ) -> pd.DataFrame | None:
        """Prepare future covariates DataFrame for Chronos.

        Args:
            dataset: Input dataset with future covariates
            horizon: Forecast horizon

        Returns:
            Future covariates DataFrame without target column
        """
        if dataset.future_x is None:
            return None

        future_df = dataset.future_x.copy()

        # Reset index if needed
        if isinstance(future_df.index, pd.MultiIndex):
            future_df = future_df.reset_index()

        # Rename columns to Chronos format
        if "unique_id" not in future_df.columns and "id" in future_df.columns:
            future_df = future_df.rename(columns={"id": "unique_id"})
        if "ds" not in future_df.columns and "timestamp" in future_df.columns:
            future_df = future_df.rename(columns={"timestamp": "ds"})

        # Rename to Chronos expected column names
        future_df = future_df.rename(
            columns={"unique_id": "item_id", "ds": "timestamp"}
        )

        # Ensure timestamp column exists - generate if needed
        if "timestamp" not in future_df.columns and dataset.future_index is not None:
            future_df = future_df.reset_index()
            if "ds" in future_df.columns:
                future_df = future_df.rename(columns={"ds": "timestamp"})

        # Handle missing values in future covariates
        for col in future_df.columns:
            if col not in ["item_id", "timestamp"] and future_df[col].isna().any():
                future_df[col] = self._handle_missing_values(future_df[col])

        return future_df

    def _handle_missing_values(
        self, values: pd.Series | np.ndarray
    ) -> pd.Series | np.ndarray:
        """Fill missing values using linear interpolation.

        Args:
            values: Series or array that may contain NaNs

        Returns:
            Values with NaNs filled
        """
        is_array = isinstance(values, np.ndarray)
        s = pd.Series(values).astype(float)
        s = s.interpolate(method="linear", limit_direction="both")
        if s.isna().any():
            fill_val = 0.0 if pd.isna(s.mean()) else s.mean()
            s = s.fillna(fill_val)
        return s.values if is_array else s

    def _to_forecast_result(
        self,
        pred_df: pd.DataFrame,
        dataset: TSDataset,
        horizon: int,
        quantiles: list[float] | None,
    ) -> ForecastResult:
        """Convert Chronos predictions to ForecastResult.

        Args:
            pred_df: DataFrame from Chronos predict_df
            dataset: Original dataset
            horizon: Forecast horizon
            quantiles: Quantile levels

        Returns:
            ForecastResult with predictions
        """
        from tsagentkit.contracts import ForecastResult

        # Map column names
        result_df = pred_df.rename(
            columns={
                "item_id": "unique_id",
                "timestamp": "ds",
                "predictions": "yhat",
            }
        )

        # Handle quantile columns (Chronos returns them as strings like "0.1", "0.5")
        if quantiles:
            quantile_cols = {}
            for col in result_df.columns:
                if col not in ["unique_id", "ds", "yhat"]:
                    try:
                        q_val = float(col)
                        if 0 < q_val < 1:
                            quantile_cols[q_val] = col
                    except (TypeError, ValueError):
                        continue

            for q in quantiles:
                if quantile_cols:
                    nearest = min(quantile_cols, key=lambda v: abs(v - q))
                    result_df[quantile_col_name(q)] = result_df[quantile_cols[nearest]]

        # Select and order columns
        keep_cols = ["unique_id", "ds", "yhat"]
        for q in quantiles or []:
            col = quantile_col_name(q)
            if col in result_df.columns:
                keep_cols.append(col)

        result_df = result_df[keep_cols].copy()
        result_df["model"] = f"chronos-{self.config.model_size}"

        provenance = self._create_provenance(dataset, horizon, quantiles)

        return ForecastResult(
            df=result_df,
            provenance=provenance,
            model_name=f"chronos-{self.config.model_size}",
            horizon=horizon,
        )

    def get_model_signature(self) -> str:
        """Return model signature for provenance."""
        return f"chronos-{self.config.model_size}-{self._device}"

    @classmethod
    def _check_dependencies(cls) -> None:
        """Check if Chronos dependencies are installed."""
        try:
            import chronos  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "chronos-forecasting is required. "
                "Install with: pip install chronos-forecasting"
            ) from e
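For orientation, here is a minimal end-to-end sketch of driving the adapter above. The AdapterConfig, load_model, fit, and predict calls mirror the class docstring and method signatures; the import paths and the TSDataset construction are assumptions for illustration (the real constructor lives in tsagentkit/series/dataset.py and may take different arguments).

import pandas as pd

# Import paths are assumptions based on the package layout listed above.
from tsagentkit.models.adapters import AdapterConfig, ChronosAdapter
from tsagentkit.series import TSDataset

# Long-format history: one row per (series, timestamp) observation.
history = pd.DataFrame(
    {
        "unique_id": ["store_1"] * 60,
        "ds": pd.date_range("2024-01-01", periods=60, freq="D"),
        "y": [float(i) for i in range(60)],
    }
)

# Hypothetical construction; the real TSDataset API may differ.
dataset = TSDataset(df=history, freq="D")

config = AdapterConfig(model_name="chronos", model_size="small")
adapter = ChronosAdapter(config)
adapter.load_model()  # pulls autogluon/chronos-2-small on first use

# Zero-shot: fit() only validates the dataset and wraps the pipeline in a ModelArtifact.
artifact = adapter.fit(dataset, prediction_length=30, quantiles=[0.1, 0.5, 0.9])

# predict() returns a ForecastResult whose df holds unique_id, ds, yhat,
# one column per requested quantile, and a model tag ("chronos-small").
result = adapter.predict(dataset, horizon=30, quantiles=[0.1, 0.5, 0.9])
print(result.df.head())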
tsagentkit/models/adapters/moirai.py
@@ -0,0 +1,256 @@
"""Salesforce Moirai 2.0 TSFM adapter.

Adapter for Salesforce's Moirai 2.0 universal time series forecasting model.
Moirai 2.0 is a transformer-based model with improved architecture.

Reference: https://github.com/SalesforceAIResearch/uni2ts
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

from tsagentkit.time import normalize_pandas_freq
from tsagentkit.utils import quantile_col_name

from .base import TSFMAdapter

if TYPE_CHECKING:
    from tsagentkit.contracts import ForecastResult, ModelArtifact
    from tsagentkit.series import TSDataset


class MoiraiAdapter(TSFMAdapter):
    """Adapter for Salesforce Moirai 2.0 foundation model.

    Moirai 2.0 is a universal time series forecasting transformer with
    improved architecture over Moirai 1.x.

    Available model sizes:
    - small: 384d model, 6 layers (recommended for fast inference)

    Example:
        >>> config = AdapterConfig(model_name="moirai", model_size="small")
        >>> adapter = MoiraiAdapter(config)
        >>> adapter.load_model()
        >>> result = adapter.predict(dataset, horizon=30)

    Reference:
        https://github.com/SalesforceAIResearch/uni2ts
    """

    # HuggingFace model ID for Moirai 2.0 (only small available currently)
    MODEL_ID = "Salesforce/moirai-2.0-R-small"

    # Default context length for Moirai 2.0
    DEFAULT_CONTEXT_LENGTH = 512

    def load_model(self) -> None:
        """Load Moirai 2.0 model from HuggingFace.

        Downloads and caches the model if not already present.

        Raises:
            ImportError: If uni2ts is not installed
            RuntimeError: If model loading fails
        """
        try:
            from uni2ts.model.moirai2 import Moirai2Module
        except ImportError as e:
            raise ImportError(
                "uni2ts>=2.0.0 is required for MoiraiAdapter. "
                "Install with: pip install 'uni2ts @ git+https://github.com/SalesforceAIResearch/uni2ts.git'"
            ) from e

        self._module = Moirai2Module.from_pretrained(self.MODEL_ID)

    def fit(
        self,
        dataset: TSDataset,
        prediction_length: int,
        quantiles: list[float] | None = None,
    ) -> ModelArtifact:
        """Prepare Moirai for prediction.

        Moirai is a zero-shot model and doesn't require training.

        Args:
            dataset: Dataset to validate
            prediction_length: Forecast horizon
            quantiles: Optional quantile levels

        Returns:
            ModelArtifact with model reference
        """
        from tsagentkit.contracts import ModelArtifact

        if not self.is_loaded:
            self.load_model()

        self._validate_dataset(dataset)

        return ModelArtifact(
            model=self._module,
            model_name="moirai-2.0",
            config={
                "model_size": "small",
                "device": self._device,
                "prediction_length": prediction_length,
                "quantiles": quantiles,
            },
        )

    def predict(
        self,
        dataset: TSDataset,
        horizon: int,
        quantiles: list[float] | None = None,
    ) -> ForecastResult:
        """Generate forecasts using Moirai 2.0.

        Args:
            dataset: Historical data for context
            horizon: Number of steps to forecast
            quantiles: Quantile levels for probabilistic forecasts

        Returns:
            ForecastResult with predictions and provenance
        """
        if not self.is_loaded:
            self.load_model()

        from gluonts.dataset.common import ListDataset
        from uni2ts.model.moirai2 import Moirai2Forecast

        from tsagentkit.contracts import ForecastResult

        freq = normalize_pandas_freq(dataset.freq)
        context_length = self._get_context_length(dataset, horizon)

        model = Moirai2Forecast(
            module=self._module,
            prediction_length=horizon,
            context_length=context_length,
            target_dim=1,
            feat_dynamic_real_dim=0,
            past_feat_dynamic_real_dim=0,
        )
        predictor = model.create_predictor(
            batch_size=self.config.prediction_batch_size or 32
        )

        entries = []
        meta = []
        for uid in dataset.series_ids:
            series_df = dataset.get_series(uid).sort_values("ds")
            values = series_df["y"].values.astype(np.float32)
            if np.any(np.isnan(values)):
                values = self._handle_missing_values(values)
            entries.append(
                {
                    "item_id": uid,
                    "start": series_df["ds"].iloc[0],
                    "target": values,
                }
            )
            meta.append({"uid": uid, "last_date": series_df["ds"].max()})

        gluonts_ds = ListDataset(entries, freq=freq)
        forecast_it = predictor.predict(gluonts_ds)

        offset = pd.tseries.frequencies.to_offset(freq)
        result_rows = []
        for meta_item, forecast in zip(meta, forecast_it, strict=False):
            uid = meta_item["uid"]
            last_date = meta_item["last_date"]
            future_dates = pd.date_range(
                start=last_date + offset,
                periods=horizon,
                freq=freq,
            )

            point_forecast = (
                forecast.quantile(0.5)
                if quantiles and 0.5 in quantiles
                else forecast.mean
            )
            point_forecast = np.asarray(point_forecast).flatten()

            quantile_arrays: dict[float, np.ndarray] = {}
            if quantiles:
                for q in quantiles:
                    try:
                        quantile_arrays[q] = np.asarray(forecast.quantile(q)).flatten()
                    except Exception:
                        quantile_arrays[q] = point_forecast

            for h in range(horizon):
                row = {
                    "unique_id": uid,
                    "ds": future_dates[h],
                    "yhat": float(point_forecast[h]),
                }
                for q in quantiles or []:
                    row[quantile_col_name(q)] = float(quantile_arrays[q][h])
                result_rows.append(row)

        result_df = pd.DataFrame(result_rows)
        result_df["model"] = "moirai-2.0"
        provenance = self._create_provenance(dataset, horizon, quantiles)

        return ForecastResult(
            df=result_df,
            provenance=provenance,
            model_name="moirai-2.0",
            horizon=horizon,
        )

    def _get_context_length(self, dataset: TSDataset, horizon: int) -> int:
        """Get appropriate context length for prediction.

        Args:
            dataset: Input dataset
            horizon: Forecast horizon

        Returns:
            Context length (capped at model max)
        """
        max_series_len = int(
            dataset.df.groupby("unique_id").size().max()
        ) if not dataset.df.empty else 0

        context_length = self.config.max_context_length or max_series_len
        context_length = max(context_length, horizon)
        context_length = min(context_length, self.DEFAULT_CONTEXT_LENGTH)
        return max(1, int(context_length))

    def _handle_missing_values(self, values: np.ndarray) -> np.ndarray:
        """Handle missing values in series.

        Args:
            values: Array that may contain NaNs

        Returns:
            Array with NaNs filled
        """
        s = pd.Series(values)
        s = s.interpolate(method="linear", limit_direction="both")
        return s.fillna(s.mean()).values.astype(np.float32)

    def get_model_signature(self) -> str:
        """Return model signature for provenance."""
        return f"moirai-2.0-{self._device}"

    @classmethod
    def _check_dependencies(cls) -> None:
        """Check if Moirai dependencies are installed."""
        try:
            from uni2ts.model.moirai2 import Moirai2Module  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "uni2ts>=2.0.0 is required. "
                "Install with: pip install 'uni2ts @ git+https://github.com/SalesforceAIResearch/uni2ts.git'"
            ) from e