tsagentkit 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Adapter registry for TSFM discovery and factory creation.
|
|
2
|
+
|
|
3
|
+
Provides centralized registration and discovery of TSFM adapters
|
|
4
|
+
with availability checking and dependency management.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from .base import AdapterConfig, TSFMAdapter
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AdapterRegistry:
    """Registry mapping adapter names to TSFM adapter classes.

    Provides centralized registration, lookup, dependency checking and
    construction of adapters. Registrations are stored in a single class-level
    dict, so every caller in the process shares the same registry; all methods
    are classmethods.

    Example:
        >>> # Register an adapter
        >>> AdapterRegistry.register("chronos", ChronosAdapter)
        >>>
        >>> # List registered adapters (dependencies are not checked here)
        >>> AdapterRegistry.list_available()
        ['chronos', 'moirai']
        >>>
        >>> # Check if adapter can be used
        >>> AdapterRegistry.check_availability("chronos")
        (True, '')
        >>>
        >>> # Create adapter instance
        >>> adapter = AdapterRegistry.create("chronos", config)
    """

    # name -> adapter class; shared by every caller of this class.
    _adapters: dict[str, type[TSFMAdapter]] = {}

    @classmethod
    def register(
        cls,
        name: str,
        adapter_class: type[TSFMAdapter],
    ) -> None:
        """Register an adapter class.

        Re-registering the same class under the same name is a no-op.

        Args:
            name: Unique identifier for the adapter
            adapter_class: The adapter class to register

        Raises:
            ValueError: If name is already registered with a different class
        """
        existing = cls._adapters.get(name)
        if existing is not None and existing is not adapter_class:
            raise ValueError(
                f"Adapter '{name}' is already registered with a different class"
            )
        cls._adapters[name] = adapter_class

    @classmethod
    def unregister(cls, name: str) -> None:
        """Unregister an adapter.

        Args:
            name: Name of the adapter to unregister

        Raises:
            KeyError: If adapter is not registered
        """
        if name not in cls._adapters:
            raise KeyError(f"Adapter '{name}' is not registered")
        del cls._adapters[name]

    @classmethod
    def get(cls, name: str) -> type[TSFMAdapter]:
        """Get adapter class by name.

        Args:
            name: Adapter name

        Returns:
            The adapter class

        Raises:
            ValueError: If adapter is not registered; the message lists the
                registered names.
        """
        if name not in cls._adapters:
            available = ", ".join(sorted(cls._adapters.keys()))
            raise ValueError(
                f"Unknown adapter '{name}'. "
                f"Available adapters: {available or 'none'}"
            )
        return cls._adapters[name]

    @classmethod
    def list_available(cls) -> list[str]:
        """List all registered adapter names.

        Note that this does not check dependencies; use
        ``get_available_adapters`` for adapters that can actually be used.

        Returns:
            Sorted list of registered adapter names
        """
        return sorted(cls._adapters.keys())

    @classmethod
    def create(
        cls,
        name: str,
        config: AdapterConfig | None = None,
    ) -> TSFMAdapter:
        """Factory method to create adapter instance.

        Args:
            name: Name of the adapter to create
            config: Optional configuration. When None, a default
                ``AdapterConfig(model_name=name)`` is used.

        Returns:
            Instantiated adapter

        Raises:
            ValueError: If adapter is not registered
        """
        from .base import AdapterConfig

        adapter_class = cls.get(name)
        # Compare against None explicitly so a config object that happens to
        # be falsy is not silently replaced by the default.
        if config is None:
            config = AdapterConfig(model_name=name)
        return adapter_class(config)

    @classmethod
    def check_availability(cls, name: str) -> tuple[bool, str]:
        """Check if adapter dependencies are installed.

        Never raises: unknown names and dependency failures are reported in
        the returned tuple.

        Args:
            name: Adapter name to check

        Returns:
            Tuple of (is_available, error_message). error_message is empty string
            if the adapter is available.
        """
        try:
            adapter_class = cls.get(name)
        except ValueError as e:
            return False, str(e)

        try:
            adapter_class._check_dependencies()
        except ImportError as e:
            return False, str(e)
        except Exception as e:
            # Deliberate best-effort: any failure while probing dependencies
            # marks the adapter unavailable instead of propagating.
            return False, f"Unexpected error checking dependencies: {e}"
        return True, ""

    @classmethod
    def get_available_adapters(cls) -> dict[str, type[TSFMAdapter]]:
        """Get all adapters that have their dependencies installed.

        Returns:
            Dictionary mapping available adapter names to their classes,
            in registration order
        """
        return {
            name: adapter_class
            for name, adapter_class in cls._adapters.items()
            if cls.check_availability(name)[0]
        }

    @classmethod
    def clear(cls) -> None:
        """Clear all registered adapters.

        Useful for testing.
        """
        cls._adapters.clear()
|
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
"""Google TimesFM TSFM adapter.
|
|
2
|
+
|
|
3
|
+
Adapter for Google's TimesFM (Time Series Foundation Model).
|
|
4
|
+
TimesFM is a pretrained decoder-only model for time series forecasting.
|
|
5
|
+
|
|
6
|
+
Reference: https://github.com/google-research/timesfm
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import TYPE_CHECKING
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
from tsagentkit.time import normalize_pandas_freq
|
|
17
|
+
from tsagentkit.utils import quantile_col_name
|
|
18
|
+
|
|
19
|
+
from .base import TSFMAdapter
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from tsagentkit.contracts import ForecastResult, ModelArtifact
|
|
23
|
+
from tsagentkit.series import TSDataset
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TimesFMAdapter(TSFMAdapter):
    """Adapter for Google TimesFM foundation model (2.5).

    TimesFM is a decoder-only foundation model with 200M parameters
    that achieves strong zero-shot performance on various datasets.
    The adapter turns a ``TSDataset`` into one float32 context array per
    series, runs ``forecast`` on the compiled model, and converts the output
    into a long-format ``ForecastResult``.

    Example:
        >>> adapter = TimesFMAdapter(AdapterConfig(model_name="timesfm"))
        >>> adapter.load_model()
        >>> result = adapter.predict(dataset, horizon=30)

    Reference:
        https://github.com/google-research/timesfm
    """

    # TimesFM 2.5 model checkpoint (200M parameters)
    MODEL_ID = "google/timesfm-2.5-200m-pytorch"

    # Context length the model is compiled with (see _get_compilation_targets).
    MAX_CONTEXT = 128
    # Default compiled horizon. Longer requested horizons trigger a recompile
    # with the larger value rather than an error.
    MAX_HORIZON = 256
    # Inputs shorter than this are left-padded to avoid NaN outputs
    # (https://github.com/google-research/timesfm/issues/321).
    # Derived from MAX_CONTEXT so the two stay in sync (= 97).
    MIN_INPUT_LENGTH = MAX_CONTEXT - 32 + 1

    # Quantile levels produced by the model's quantile head: 0.1, 0.2, ..., 0.9
    SUPPORTED_QUANTILES = [round(q / 10, 1) for q in range(1, 10)]

    def load_model(self) -> None:
        """Load TimesFM model from checkpoint and compile it with defaults.

        Downloads and caches the model if not already present.

        Raises:
            ImportError: If timesfm is not installed
            Exception: Errors raised by ``from_pretrained`` or ``compile``
                (e.g. download failures) propagate unchanged.
        """
        try:
            import timesfm
        except ImportError as e:
            raise ImportError(
                "timesfm is required for TimesFMAdapter. "
                "Install with: pip install timesfm"
            ) from e

        # Reset the compile bookkeeping so the freshly loaded model is always
        # compiled by _ensure_compiled below.
        self._compiled_max_context = 0
        self._compiled_max_horizon = 0

        # Load TimesFM 2.5 model using the new API
        self._model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
            self.MODEL_ID,
            cache_dir=self.config.cache_dir,
        )

        # Compile with default config
        self._ensure_compiled(self.MAX_CONTEXT, self.MAX_HORIZON)

    def fit(
        self,
        dataset: TSDataset,
        prediction_length: int,
        quantiles: list[float] | None = None,
    ) -> ModelArtifact:
        """Prepare TimesFM for prediction.

        TimesFM is a zero-shot model and doesn't require training.
        This method loads the model if needed, validates the dataset, makes
        sure the model is compiled for ``prediction_length``, and returns a
        ModelArtifact.

        Args:
            dataset: Dataset to validate
            prediction_length: Forecast horizon. It is not capped; horizons
                beyond ``MAX_HORIZON`` cause a recompile instead of an error.
            quantiles: Optional quantile levels (recorded in the artifact config)

        Returns:
            ModelArtifact with model reference

        Raises:
            Exception: Whatever ``_validate_dataset`` (defined on the base
                adapter) raises for an unusable dataset.
        """
        from tsagentkit.contracts import ModelArtifact

        if not self.is_loaded:
            self.load_model()

        # Validate dataset
        self._validate_dataset(dataset)

        max_context, max_horizon = self._get_compilation_targets(
            dataset, prediction_length
        )
        self._ensure_compiled(max_context, max_horizon)

        return ModelArtifact(
            model=self._model,
            model_name="timesfm-2.5",
            config={
                "device": self._device,
                "prediction_length": prediction_length,
                "quantiles": quantiles,
            },
        )

    def predict(
        self,
        dataset: TSDataset,
        horizon: int,
        quantiles: list[float] | None = None,
    ) -> ForecastResult:
        """Generate forecasts using TimesFM.

        Args:
            dataset: Historical data for context
            horizon: Number of steps to forecast
            quantiles: Quantile levels for probabilistic forecasts. Each is
                mapped to the nearest level in ``SUPPORTED_QUANTILES``.

        Returns:
            ForecastResult with predictions and provenance
        """
        if not self.is_loaded:
            self.load_model()

        max_context, max_horizon = self._get_compilation_targets(dataset, horizon)
        self._ensure_compiled(max_context, max_horizon)

        inputs, _freq = self._to_timesfm_format(dataset)

        # TimesFM 2.5 forecast API: forecast(horizon, inputs)
        point_forecasts, quantile_forecasts = self._model.forecast(
            horizon=horizon,
            inputs=inputs,
        )
        # Trim to the requested horizon. The quantile output is treated as
        # optional below, so guard it before slicing.
        point_forecasts = point_forecasts[:, :horizon]
        if quantile_forecasts is not None:
            quantile_forecasts = quantile_forecasts[:, :horizon, :]

        # Handle potential NaN in outputs (see: https://github.com/google-research/timesfm/issues/321)
        if np.any(np.isnan(point_forecasts)):
            point_forecasts = self._handle_nan_forecasts(point_forecasts, inputs)
        if quantile_forecasts is not None and np.any(np.isnan(quantile_forecasts)):
            quantile_forecasts = self._handle_nan_quantiles(quantile_forecasts, point_forecasts)

        # Convert to ForecastResult
        return self._to_forecast_result(
            point_forecasts,
            quantile_forecasts,
            dataset,
            horizon,
            quantiles,
        )

    def _handle_nan_forecasts(
        self,
        forecasts: np.ndarray,
        inputs: list[np.ndarray],
    ) -> np.ndarray:
        """Replace NaN forecasts with the last input value of each series.

        Only NaN entries are replaced; finite and infinite values are kept.
        A series with an empty or NaN last input falls back to 0.0.

        Args:
            forecasts: Forecast array (n_series, horizon) that may contain NaN
            inputs: Input arrays fed to the model, one per series

        Returns:
            A copy of ``forecasts`` with NaN replaced
        """
        result = forecasts.copy()
        # Iterating a 2-D array yields row views, so writes land in ``result``.
        for row, series in zip(result, inputs):
            missing = np.isnan(row)
            if not missing.any():
                continue
            fill = float(series[-1]) if len(series) > 0 else 0.0
            row[missing] = 0.0 if np.isnan(fill) else fill
        return result

    def _handle_nan_quantiles(
        self,
        quantiles: np.ndarray,
        point_forecasts: np.ndarray,
    ) -> np.ndarray:
        """Replace NaN quantile forecasts with the matching point forecast.

        Args:
            quantiles: Quantile forecast array (n_series, horizon, n_channels)
                that may contain NaN
            point_forecasts: Point forecasts (n_series, horizon) used as fallback

        Returns:
            A copy of ``quantiles`` (same dtype) with NaN replaced
        """
        result = quantiles.copy()
        missing = np.isnan(result)
        # Broadcast each (series, step) point value across all quantile channels.
        fallback = np.broadcast_to(
            point_forecasts[:, : result.shape[1], np.newaxis], result.shape
        )
        result[missing] = fallback[missing]
        return result

    def _to_timesfm_format(
        self,
        dataset: TSDataset,
    ) -> tuple[list[np.ndarray], None]:
        """Convert TSDataset to TimesFM format.

        Each series' ``y`` column becomes a float32 array. Missing values are
        interpolated, and series shorter than ``MIN_INPUT_LENGTH`` are
        left-padded.

        Args:
            dataset: Input dataset

        Returns:
            Tuple of (list of input arrays in ``dataset.series_ids`` order, None).
            TimesFM 2.5 does not require frequency mapping.
        """
        inputs = []
        for uid in dataset.series_ids:
            series_df = dataset.get_series(uid)
            values = series_df["y"].values.astype(np.float32)

            # Handle NaN values
            if np.any(np.isnan(values)):
                values = self._handle_missing_values(values)

            # Pad short inputs to avoid NaN from attention mask issue
            # See: https://github.com/google-research/timesfm/issues/321
            if len(values) < self.MIN_INPUT_LENGTH:
                values = self._pad_input(values, self.MIN_INPUT_LENGTH)

            inputs.append(values)

        # TimesFM 2.5 does not require frequency mapping
        return inputs, None

    def _pad_input(self, values: np.ndarray, min_length: int) -> np.ndarray:
        """Left-pad input values to ``min_length``.

        Padding is prepended, not appended, so the last real observation stays
        the last element of the context. The model forecasts the steps after
        the end of the context, and ``_to_forecast_result`` labels them
        starting from the series' real last date. Appending synthetic points
        would shift every forecast by the number of padded steps. The pad
        repeats the first observation, so it adds no trend of its own.

        Args:
            values: Original input values
            min_length: Minimum required length

        Returns:
            ``values`` unchanged if already long enough, otherwise a float32
            array of length ``min_length`` (all zeros for an empty input)
        """
        n_pad = min_length - len(values)
        if n_pad <= 0:
            return values
        if len(values) == 0:
            return np.zeros(min_length, dtype=np.float32)
        return np.pad(values, (n_pad, 0), mode="edge").astype(np.float32)

    def _to_forecast_result(
        self,
        point_forecasts: np.ndarray,
        quantile_forecasts: np.ndarray | None,
        dataset: TSDataset,
        horizon: int,
        quantiles: list[float] | None,
    ) -> ForecastResult:
        """Convert TimesFM output to ForecastResult.

        Produces one row per (series, step) with columns ``unique_id``,
        ``ds``, ``yhat``, one column per requested quantile, and ``model``.

        Args:
            point_forecasts: Point predictions (n_series, horizon)
            quantile_forecasts: Optional quantile predictions
                (n_series, horizon, 1 + len(SUPPORTED_QUANTILES))
            dataset: Original dataset
            horizon: Forecast horizon
            quantiles: Quantile levels. Each is mapped to the nearest supported
                level; without quantile output they equal the point forecast.

        Returns:
            ForecastResult with predictions
        """
        from tsagentkit.contracts import ForecastResult

        freq = normalize_pandas_freq(dataset.freq)
        offset = pd.tseries.frequencies.to_offset(freq)

        quantile_values: dict[float, np.ndarray] = {}
        if quantiles:
            if quantile_forecasts is None:
                quantile_values = dict.fromkeys(quantiles, point_forecasts)
            else:
                supported = getattr(self, "_model_quantiles", self.SUPPORTED_QUANTILES)
                for q in quantiles:
                    nearest = min(supported, key=lambda v: abs(v - q))
                    # Channel 0 of the quantile output is not a quantile level
                    # (TimesFM places the mean there), hence the +1 offset.
                    quantile_values[q] = quantile_forecasts[:, :, supported.index(nearest) + 1]

        frames = []
        for i, uid in enumerate(dataset.series_ids):
            last_date = dataset.get_series(uid)["ds"].max()
            columns = {
                "unique_id": uid,
                "ds": pd.date_range(
                    start=last_date + offset,
                    periods=horizon,
                    freq=freq,
                ),
                "yhat": np.asarray(point_forecasts[i, :horizon], dtype=float),
            }
            for q in quantiles or ():
                columns[quantile_col_name(q)] = np.asarray(
                    quantile_values[q][i, :horizon], dtype=float
                )
            frames.append(pd.DataFrame(columns))

        result_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
        result_df["model"] = "timesfm-2.5"
        provenance = self._create_provenance(dataset, horizon, quantiles)

        return ForecastResult(
            df=result_df,
            provenance=provenance,
            model_name="timesfm-2.5",
            horizon=horizon,
        )

    def _handle_missing_values(self, values: np.ndarray) -> np.ndarray:
        """Handle missing values in series.

        Linear interpolation fills interior gaps and extends the nearest
        values to the edges. A series with no valid values at all is filled
        with zeros, so the model never receives NaN context.

        Args:
            values: Array that may contain NaNs

        Returns:
            float32 array with NaNs filled
        """
        filled = pd.Series(values).interpolate(method="linear", limit_direction="both")
        # fillna(mean) is a no-op when mean is NaN (all-NaN input), hence the
        # final zero fallback.
        filled = filled.fillna(filled.mean()).fillna(0.0)
        return filled.values.astype(np.float32)

    def _get_compilation_targets(
        self,
        dataset: TSDataset,
        horizon: int,
    ) -> tuple[int, int]:
        """Get compilation targets for model.

        The context is pinned to ``MAX_CONTEXT`` because ``MIN_INPUT_LENGTH``
        (the short-input padding target) is derived from it. The horizon is
        ``MAX_HORIZON`` or the requested horizon, whichever is larger.

        Args:
            dataset: Input dataset (not used for the targets; kept for
                interface compatibility)
            horizon: Forecast horizon

        Returns:
            Tuple of (max_context, max_horizon)
        """
        return self.MAX_CONTEXT, max(self.MAX_HORIZON, horizon)

    def _ensure_compiled(self, max_context: int, max_horizon: int) -> None:
        """Ensure model is compiled with appropriate context/horizon settings.

        TimesFM 2.5 requires calling compile() with ForecastConfig before forecast().
        Recompiles only when the requested context or horizon exceeds what
        the model was last compiled with.

        Args:
            max_context: Maximum context length
            max_horizon: Maximum forecast horizon
        """
        if (
            getattr(self, "_compiled_max_context", 0) >= max_context
            and getattr(self, "_compiled_max_horizon", 0) >= max_horizon
            and self._model is not None
        ):
            return

        import timesfm

        config = timesfm.ForecastConfig(
            max_context=max_context,
            max_horizon=max_horizon,
            normalize_inputs=True,
            use_continuous_quantile_head=True,
            force_flip_invariance=True,
            infer_is_positive=True,
            fix_quantile_crossing=True,
        )
        self._model.compile(config)
        self._model_quantiles = list(self.SUPPORTED_QUANTILES)
        self._compiled_max_context = max_context
        self._compiled_max_horizon = max_horizon

    def get_model_signature(self) -> str:
        """Return model signature for provenance.

        Returns:
            Signature string combining the model version and device
        """
        return f"timesfm-2.5-{self._device}"

    @classmethod
    def _check_dependencies(cls) -> None:
        """Check if TimesFM dependencies are installed.

        Raises:
            ImportError: If timesfm is not installed
        """
        try:
            import timesfm  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "timesfm is required. "
                "Install with: pip install timesfm"
            ) from e
|