tsagentkit 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""Models module for tsagentkit.
|
|
2
|
+
|
|
3
|
+
Provides model fitting, prediction, and TSFM (Time-Series Foundation Model)
|
|
4
|
+
adapters for various forecasting backends.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from datetime import UTC
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
|
+
|
|
13
|
+
from tsagentkit.contracts import ForecastResult, ModelArtifact, Provenance
|
|
14
|
+
|
|
15
|
+
# Import adapters submodules
|
|
16
|
+
from tsagentkit.models import adapters
|
|
17
|
+
from tsagentkit.models.baselines import fit_baseline, is_baseline_model, predict_baseline
|
|
18
|
+
from tsagentkit.models.sktime import SktimeModelBundle, fit_sktime, predict_sktime
|
|
19
|
+
from tsagentkit.utils import normalize_quantile_columns
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from tsagentkit.contracts import TaskSpec
|
|
23
|
+
from tsagentkit.router import PlanSpec
|
|
24
|
+
from tsagentkit.series import TSDataset
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _is_tsfm_model(model_name: str) -> bool:
|
|
28
|
+
return model_name.lower().startswith("tsfm-")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _is_sktime_model(model_name: str) -> bool:
|
|
32
|
+
return model_name.lower().startswith("sktime-")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _build_adapter_config(model_name: str, config: dict[str, Any]) -> adapters.AdapterConfig:
    """Translate a ``tsfm-<adapter>`` model name plus overrides into an AdapterConfig.

    The ``tsfm-`` prefix is stripped case-insensitively, mirroring how
    ``_is_tsfm_model`` recognises TSFM names. Names without the prefix are
    used unchanged.

    Args:
        model_name: Model identifier such as ``"tsfm-chronos"``.
        config: Optional overrides. Any key that is missing falls back to the
            default shown in the constructor call below.

    Returns:
        AdapterConfig whose ``model_name`` is the bare adapter name.
    """
    prefix = "tsfm-"
    # Slice rather than str.split: the prefix is detected case-insensitively
    # elsewhere, so "TSFM-chronos" must also resolve to "chronos".
    if model_name.lower().startswith(prefix):
        adapter_name = model_name[len(prefix):]
    else:
        adapter_name = model_name
    return adapters.AdapterConfig(
        model_name=adapter_name,
        model_size=config.get("model_size", "base"),
        device=config.get("device"),
        cache_dir=config.get("cache_dir"),
        batch_size=config.get("batch_size", 32),
        prediction_batch_size=config.get("prediction_batch_size", 100),
        quantile_method=config.get("quantile_method", "sample"),
        num_samples=config.get("num_samples", 100),
        max_context_length=config.get("max_context_length"),
    )
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _fit_model_name(
    model_name: str,
    dataset: TSDataset,
    plan: PlanSpec,
    covariates: Any | None = None,
) -> ModelArtifact:
    """Fit a single model by name, dispatching to the TSFM, baseline, or sktime backend.

    Args:
        model_name: Candidate model name. A ``tsfm-`` prefix selects a TSFM
            adapter, baseline names go to ``fit_baseline``, and an ``sktime-``
            prefix goes to ``fit_sktime``.
        dataset: Training dataset; its ``task_spec`` supplies horizon and season length.
        plan: Plan providing the quantile levels.
        covariates: Forwarded only to the sktime backend.

    Returns:
        ModelArtifact for the fitted model.

    Raises:
        ValueError: If ``model_name`` matches none of the supported families.
    """
    config: dict[str, Any] = {
        "horizon": dataset.task_spec.horizon,
        # A missing or zero season length falls back to 1 (no seasonality).
        "season_length": dataset.task_spec.season_length or 1,
        "quantiles": plan.quantiles,
    }

    if _is_tsfm_model(model_name):
        # Take the registry key from the adapter config so the prefix-stripping
        # rule lives in one place (_build_adapter_config).
        # NOTE(review): an empty override dict is passed, so adapter settings
        # (device, model_size, ...) always use the defaults.
        adapter_config = _build_adapter_config(model_name, {})
        adapter_name = adapter_config.model_name
        adapter = adapters.AdapterRegistry.create(adapter_name, adapter_config)
        # The adapter's own returned artifact is not used; a pipeline-level
        # artifact wrapping the adapter is built instead.
        adapter.fit(
            dataset=dataset,
            prediction_length=dataset.task_spec.horizon,
            quantiles=plan.quantiles,
        )
        return ModelArtifact(
            model=adapter,
            model_name=model_name,
            config=config,
            metadata={"adapter": adapter_name},
        )

    if is_baseline_model(model_name):
        return fit_baseline(model_name, dataset, config)

    if _is_sktime_model(model_name):
        return fit_sktime(model_name, dataset, plan, covariates=covariates)

    raise ValueError(f"Unknown model name: {model_name}")
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def fit(
    dataset: TSDataset,
    plan: PlanSpec,
    on_fallback: Callable[[str, str, Exception], None] | None = None,
    covariates: Any | None = None,
) -> ModelArtifact:
    """Fit a model by running the plan's fallback ladder through the router.

    Each candidate name that ``execute_with_fallback`` tries is fitted with
    ``_fit_model_name``, using this plan and these covariates. ``on_fallback``
    is forwarded to the router unchanged.

    Returns:
        The artifact element of the pair returned by ``execute_with_fallback``.
        The second element is discarded.
    """
    # Imported here rather than at module level; presumably this avoids an
    # import cycle with the router.
    from tsagentkit.router import execute_with_fallback

    def _fit_candidate(candidate: str, candidate_ds: TSDataset) -> ModelArtifact:
        return _fit_model_name(candidate, candidate_ds, plan, covariates=covariates)

    fitted, _fallback_info = execute_with_fallback(
        fit_func=_fit_candidate,
        dataset=dataset,
        plan=plan,
        on_fallback=on_fallback,
    )
    return fitted
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _basic_provenance(
    dataset: TSDataset,
    spec: TaskSpec,
    artifact: ModelArtifact,
) -> Provenance:
    """Build a minimal Provenance record for a forecast produced outside a TSFM adapter.

    The record is flagged ``provenance_incomplete`` because no plan signature
    is available here. The artifact signature is reused for both plan and
    model signatures.

    Args:
        dataset: Dataset the forecast was made from (hashed for ``data_signature``).
        spec: Task specification (hashed for ``task_signature``).
        artifact: Fitted model artifact.

    Returns:
        Provenance instance.
    """
    from datetime import datetime

    from tsagentkit.utils import compute_data_signature

    # Read the clock once so run_id and timestamp describe the same instant.
    now = datetime.now(UTC)
    return Provenance(
        run_id=f"model_{now.strftime('%Y%m%d_%H%M%S')}",
        timestamp=now.isoformat(),
        data_signature=compute_data_signature(dataset.df),
        task_signature=spec.model_hash(),
        plan_signature=artifact.signature,
        model_signature=artifact.signature,
        metadata={"provenance_incomplete": True},
    )
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _with_model_column(frame: Any, model_name: str) -> Any:
    """Return *frame* with a ``model`` column, copying only when one must be added."""
    if "model" in frame.columns:
        return frame
    labelled = frame.copy()
    labelled["model"] = model_name
    return labelled


def predict(
    dataset: TSDataset,
    artifact: ModelArtifact,
    spec: TaskSpec,
    covariates: Any | None = None,
) -> ForecastResult:
    """Generate a forecast from a fitted artifact.

    Dispatch order:
    1. TSFM adapters: ``artifact.model`` is a ``TSFMAdapter``.
    2. Baseline models: selected by name.
    3. sktime bundles: ``artifact.model`` is a ``SktimeModelBundle``.

    TSFM and baseline outputs get their quantile columns normalized and a
    ``model`` column added when it is missing. sktime prediction is delegated
    entirely to ``predict_sktime``, the only path that receives ``covariates``.

    Raises:
        ValueError: If the artifact matches none of the supported model types.
    """
    horizon = spec.horizon

    if isinstance(artifact.model, adapters.TSFMAdapter):
        tsfm_result = artifact.model.predict(
            dataset=dataset,
            horizon=horizon,
            quantiles=artifact.config.get("quantiles"),
        )
        # Normalize first, then label (same order as the baseline path reversed on purpose).
        tsfm_df = _with_model_column(
            normalize_quantile_columns(tsfm_result.df), artifact.model_name
        )
        return ForecastResult(
            df=tsfm_df,
            provenance=tsfm_result.provenance,
            model_name=artifact.model_name,
            horizon=horizon,
        )

    if is_baseline_model(artifact.model_name):
        raw_baseline = predict_baseline(
            model_artifact=artifact,
            dataset=dataset,
            horizon=horizon,
            quantiles=artifact.config.get("quantiles"),
        )
        labelled_baseline = _with_model_column(raw_baseline, artifact.model_name)
        baseline_provenance = _basic_provenance(dataset, spec, artifact)
        return ForecastResult(
            df=normalize_quantile_columns(labelled_baseline),
            provenance=baseline_provenance,
            model_name=artifact.model_name,
            horizon=horizon,
        )

    if not isinstance(artifact.model, SktimeModelBundle):
        raise ValueError(f"Unknown model type for prediction: {artifact.model_name}")

    return predict_sktime(
        dataset=dataset,
        artifact=artifact,
        spec=spec,
        covariates=covariates,
    )
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
# Names re-exported by `from tsagentkit.models import *`. Other names imported
# above (e.g. fit_baseline, SktimeModelBundle) stay importable explicitly.
__all__ = ["fit", "predict", "adapters"]
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""TSFM (Time-Series Foundation Model) adapters for tsagentkit.
|
|
2
|
+
|
|
3
|
+
This module provides unified adapters for popular time-series foundation models:
|
|
4
|
+
- Amazon Chronos
|
|
5
|
+
- Salesforce Moirai
|
|
6
|
+
- Google TimesFM
|
|
7
|
+
|
|
8
|
+
Example:
|
|
9
|
+
>>> from tsagentkit.models.adapters import AdapterConfig, AdapterRegistry
|
|
10
|
+
>>> from tsagentkit.models.adapters import ChronosAdapter
|
|
11
|
+
>>>
|
|
12
|
+
>>> # Register adapters
|
|
13
|
+
>>> AdapterRegistry.register("chronos", ChronosAdapter)
|
|
14
|
+
>>>
|
|
15
|
+
>>> # Create and use adapter
|
|
16
|
+
>>> config = AdapterConfig(model_name="chronos", model_size="base")
|
|
17
|
+
>>> adapter = AdapterRegistry.create("chronos", config)
|
|
18
|
+
>>> adapter.load_model()
|
|
19
|
+
>>> result = adapter.predict(dataset, horizon=30)
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
# Base classes
|
|
25
|
+
from .base import AdapterConfig, TSFMAdapter
|
|
26
|
+
from .registry import AdapterRegistry
|
|
27
|
+
|
|
28
|
+
# Public API of the adapters package. The auto-registration code at the bottom
# of this module appends the class name of each concrete adapter whose import
# succeeds.
__all__ = [
    # Base classes
    "AdapterConfig",
    "TSFMAdapter",
    "AdapterRegistry",
]
|
|
34
|
+
|
|
35
|
+
# Concrete adapters are imported inside try/except when this module loads (not lazily).
|
|
36
|
+
# This prevents import errors when dependencies are not installed
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _try_import_chronos():
|
|
40
|
+
"""Try to import ChronosAdapter if dependencies are available."""
|
|
41
|
+
try:
|
|
42
|
+
from .chronos import ChronosAdapter
|
|
43
|
+
|
|
44
|
+
return ChronosAdapter
|
|
45
|
+
except ImportError:
|
|
46
|
+
return None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _try_import_moirai():
|
|
50
|
+
"""Try to import MoiraiAdapter if dependencies are available."""
|
|
51
|
+
try:
|
|
52
|
+
from .moirai import MoiraiAdapter
|
|
53
|
+
|
|
54
|
+
return MoiraiAdapter
|
|
55
|
+
except ImportError:
|
|
56
|
+
return None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _try_import_timesfm():
|
|
60
|
+
"""Try to import TimesFMAdapter if dependencies are available."""
|
|
61
|
+
try:
|
|
62
|
+
from .timesfm import TimesFMAdapter
|
|
63
|
+
|
|
64
|
+
return TimesFMAdapter
|
|
65
|
+
except ImportError:
|
|
66
|
+
return None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Auto-register available adapters.
# Each class is also bound as a module attribute. The _try_import_* helpers
# bind it only in their own local scope, so without this the names appended
# to __all__ would not resolve: both `from tsagentkit.models.adapters import
# ChronosAdapter` and `import *` would fail.
_chronos = _try_import_chronos()
if _chronos is not None:
    ChronosAdapter = _chronos
    AdapterRegistry.register("chronos", _chronos)
    __all__.append("ChronosAdapter")

_moirai = _try_import_moirai()
if _moirai is not None:
    MoiraiAdapter = _moirai
    AdapterRegistry.register("moirai", _moirai)
    __all__.append("MoiraiAdapter")

_timesfm = _try_import_timesfm()
if _timesfm is not None:
    TimesFMAdapter = _timesfm
    AdapterRegistry.register("timesfm", _timesfm)
    __all__.append("TimesFMAdapter")
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
"""Base adapter class for Time-Series Foundation Models.
|
|
2
|
+
|
|
3
|
+
Provides a unified interface for integrating external TSFMs like Chronos,
|
|
4
|
+
Moirai, and TimesFM with the tsagentkit pipeline.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from datetime import UTC
|
|
12
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
|
|
16
|
+
from tsagentkit.contracts import ForecastResult, ModelArtifact, Provenance
|
|
17
|
+
from tsagentkit.series import TSDataset
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
|
|
21
|
+
class AdapterConfig:
|
|
22
|
+
"""Configuration for TSFM adapter.
|
|
23
|
+
|
|
24
|
+
Attributes:
|
|
25
|
+
model_name: Name of the model/adapter
|
|
26
|
+
model_size: Model size variant (small, base, large)
|
|
27
|
+
device: Compute device (cuda, mps, cpu, or None for auto)
|
|
28
|
+
cache_dir: Directory for model caching
|
|
29
|
+
batch_size: Batch size for training/fitting
|
|
30
|
+
prediction_batch_size: Batch size for prediction
|
|
31
|
+
quantile_method: Method for quantile prediction (sample, direct)
|
|
32
|
+
num_samples: Number of samples for probabilistic forecasting
|
|
33
|
+
max_context_length: Maximum context length the model accepts
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
model_name: str
|
|
37
|
+
model_size: Literal["tiny", "small", "base", "large"] = "base"
|
|
38
|
+
device: str | None = None # Auto-detect if None
|
|
39
|
+
cache_dir: str | None = None
|
|
40
|
+
batch_size: int = 32
|
|
41
|
+
prediction_batch_size: int = 100
|
|
42
|
+
quantile_method: Literal["sample", "direct"] = "sample"
|
|
43
|
+
num_samples: int = 100
|
|
44
|
+
max_context_length: int | None = None
|
|
45
|
+
|
|
46
|
+
def __post_init__(self) -> None:
|
|
47
|
+
"""Validate configuration."""
|
|
48
|
+
valid_sizes = {"tiny", "small", "base", "large"}
|
|
49
|
+
if self.model_size not in valid_sizes:
|
|
50
|
+
raise ValueError(
|
|
51
|
+
f"Invalid model_size '{self.model_size}'. "
|
|
52
|
+
f"Must be one of: {valid_sizes}"
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
valid_methods = {"sample", "direct"}
|
|
56
|
+
if self.quantile_method not in valid_methods:
|
|
57
|
+
raise ValueError(
|
|
58
|
+
f"Invalid quantile_method '{self.quantile_method}'. "
|
|
59
|
+
f"Must be one of: {valid_methods}"
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class TSFMAdapter(ABC):
|
|
64
|
+
"""Abstract base class for Time-Series Foundation Model adapters.
|
|
65
|
+
|
|
66
|
+
Provides unified interface for different TSFMs while handling
|
|
67
|
+
model-specific quirks internally. Subclasses implement specific
|
|
68
|
+
adapters for models like Chronos, Moirai, and TimesFM.
|
|
69
|
+
|
|
70
|
+
Example:
|
|
71
|
+
>>> config = AdapterConfig(model_name="chronos", model_size="base")
|
|
72
|
+
>>> adapter = ChronosAdapter(config)
|
|
73
|
+
>>> adapter.load_model()
|
|
74
|
+
>>> result = adapter.predict(dataset, horizon=30)
|
|
75
|
+
|
|
76
|
+
Attributes:
|
|
77
|
+
config: Adapter configuration
|
|
78
|
+
_model: The underlying model instance (lazy loaded)
|
|
79
|
+
_device: Resolved compute device
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
def __init__(self, config: AdapterConfig):
|
|
83
|
+
"""Initialize adapter with configuration.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
config: Adapter configuration
|
|
87
|
+
"""
|
|
88
|
+
self.config = config
|
|
89
|
+
self._model: Any | None = None
|
|
90
|
+
self._device = self._resolve_device()
|
|
91
|
+
|
|
92
|
+
def _resolve_device(self) -> str:
|
|
93
|
+
"""Resolve compute device (cuda/mps/cpu).
|
|
94
|
+
|
|
95
|
+
Returns the best available device based on hardware support.
|
|
96
|
+
Priority: cuda > mps > cpu
|
|
97
|
+
|
|
98
|
+
Returns:
|
|
99
|
+
Device string for PyTorch
|
|
100
|
+
"""
|
|
101
|
+
if self.config.device:
|
|
102
|
+
return self.config.device
|
|
103
|
+
|
|
104
|
+
# Try to import torch for device detection
|
|
105
|
+
try:
|
|
106
|
+
import torch
|
|
107
|
+
|
|
108
|
+
if torch.cuda.is_available():
|
|
109
|
+
return "cuda"
|
|
110
|
+
elif torch.backends.mps.is_available():
|
|
111
|
+
return "mps"
|
|
112
|
+
except ImportError:
|
|
113
|
+
pass
|
|
114
|
+
|
|
115
|
+
return "cpu"
|
|
116
|
+
|
|
117
|
+
@abstractmethod
|
|
118
|
+
def load_model(self) -> None:
|
|
119
|
+
"""Load the foundation model with caching.
|
|
120
|
+
|
|
121
|
+
Downloads and caches the model if not already present.
|
|
122
|
+
Should be called before fit() or predict().
|
|
123
|
+
|
|
124
|
+
Raises:
|
|
125
|
+
ImportError: If required dependencies are not installed
|
|
126
|
+
RuntimeError: If model loading fails
|
|
127
|
+
"""
|
|
128
|
+
pass
|
|
129
|
+
|
|
130
|
+
@abstractmethod
|
|
131
|
+
def fit(
|
|
132
|
+
self,
|
|
133
|
+
dataset: TSDataset,
|
|
134
|
+
prediction_length: int,
|
|
135
|
+
quantiles: list[float] | None = None,
|
|
136
|
+
) -> ModelArtifact:
|
|
137
|
+
"""Prepare model for prediction on the dataset.
|
|
138
|
+
|
|
139
|
+
Note: Most TSFMs are zero-shot and don't require traditional fitting.
|
|
140
|
+
This method validates compatibility and may perform preprocessing.
|
|
141
|
+
|
|
142
|
+
Args:
|
|
143
|
+
dataset: Training dataset
|
|
144
|
+
prediction_length: Forecast horizon
|
|
145
|
+
quantiles: Optional quantile levels for probabilistic forecasts
|
|
146
|
+
|
|
147
|
+
Returns:
|
|
148
|
+
ModelArtifact with model reference and configuration
|
|
149
|
+
|
|
150
|
+
Raises:
|
|
151
|
+
ValueError: If dataset is incompatible with model
|
|
152
|
+
RuntimeError: If preparation fails
|
|
153
|
+
"""
|
|
154
|
+
pass
|
|
155
|
+
|
|
156
|
+
@abstractmethod
|
|
157
|
+
def predict(
|
|
158
|
+
self,
|
|
159
|
+
dataset: TSDataset,
|
|
160
|
+
horizon: int,
|
|
161
|
+
quantiles: list[float] | None = None,
|
|
162
|
+
) -> ForecastResult:
|
|
163
|
+
"""Generate forecasts using the TSFM.
|
|
164
|
+
|
|
165
|
+
Args:
|
|
166
|
+
dataset: Historical data for context
|
|
167
|
+
horizon: Number of steps to forecast
|
|
168
|
+
quantiles: Optional quantile levels (e.g., [0.1, 0.5, 0.9])
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
ForecastResult with predictions and provenance
|
|
172
|
+
|
|
173
|
+
Raises:
|
|
174
|
+
RuntimeError: If prediction fails
|
|
175
|
+
ValueError: If horizon exceeds model limits
|
|
176
|
+
"""
|
|
177
|
+
pass
|
|
178
|
+
|
|
179
|
+
@abstractmethod
|
|
180
|
+
def get_model_signature(self) -> str:
|
|
181
|
+
"""Return unique signature for this model configuration.
|
|
182
|
+
|
|
183
|
+
Used for provenance tracking to identify the exact model
|
|
184
|
+
configuration used for a forecast.
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
Unique signature string (e.g., "chronos-base-cuda-v1.0")
|
|
188
|
+
"""
|
|
189
|
+
pass
|
|
190
|
+
|
|
191
|
+
@property
|
|
192
|
+
def is_loaded(self) -> bool:
|
|
193
|
+
"""Check if model is loaded in memory.
|
|
194
|
+
|
|
195
|
+
Returns:
|
|
196
|
+
True if model has been loaded, False otherwise
|
|
197
|
+
"""
|
|
198
|
+
return self._model is not None
|
|
199
|
+
|
|
200
|
+
def unload_model(self) -> None:
|
|
201
|
+
"""Unload model from memory to free resources.
|
|
202
|
+
|
|
203
|
+
Useful for managing memory when using multiple large models.
|
|
204
|
+
"""
|
|
205
|
+
self._model = None
|
|
206
|
+
|
|
207
|
+
def _validate_dataset(self, dataset: TSDataset) -> None:
|
|
208
|
+
"""Validate dataset compatibility.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
dataset: Dataset to validate
|
|
212
|
+
|
|
213
|
+
Raises:
|
|
214
|
+
ValueError: If dataset is incompatible
|
|
215
|
+
"""
|
|
216
|
+
# Check minimum length requirements
|
|
217
|
+
min_length = 1 # Most TSFMs need at least some context
|
|
218
|
+
|
|
219
|
+
for uid in dataset.series_ids:
|
|
220
|
+
series = dataset.get_series(uid)
|
|
221
|
+
if len(series) < min_length:
|
|
222
|
+
raise ValueError(
|
|
223
|
+
f"Series '{uid}' has only {len(series)} observations. "
|
|
224
|
+
f"Minimum required: {min_length}"
|
|
225
|
+
)
|
|
226
|
+
|
|
227
|
+
def _create_provenance(
|
|
228
|
+
self,
|
|
229
|
+
dataset: TSDataset,
|
|
230
|
+
horizon: int,
|
|
231
|
+
quantiles: list[float] | None = None,
|
|
232
|
+
) -> Provenance:
|
|
233
|
+
"""Create provenance for forecast.
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
dataset: Input dataset
|
|
237
|
+
horizon: Forecast horizon
|
|
238
|
+
quantiles: Quantile levels used
|
|
239
|
+
|
|
240
|
+
Returns:
|
|
241
|
+
Provenance instance
|
|
242
|
+
"""
|
|
243
|
+
# Create data signature from dataset
|
|
244
|
+
import hashlib
|
|
245
|
+
from datetime import datetime
|
|
246
|
+
|
|
247
|
+
from tsagentkit.contracts import Provenance
|
|
248
|
+
|
|
249
|
+
data_hash = hashlib.sha256(
|
|
250
|
+
str(dataset.n_series).encode() +
|
|
251
|
+
str(dataset.n_observations).encode() +
|
|
252
|
+
dataset.date_range[0].isoformat().encode()
|
|
253
|
+
).hexdigest()[:16]
|
|
254
|
+
|
|
255
|
+
return Provenance(
|
|
256
|
+
run_id=f"tsfm_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}",
|
|
257
|
+
timestamp=datetime.now(UTC).isoformat(),
|
|
258
|
+
data_signature=data_hash,
|
|
259
|
+
task_signature=f"horizon={horizon}",
|
|
260
|
+
plan_signature=self.get_model_signature(),
|
|
261
|
+
model_signature=self.get_model_signature(),
|
|
262
|
+
metadata={
|
|
263
|
+
"adapter": self.__class__.__name__,
|
|
264
|
+
"device": self._device,
|
|
265
|
+
"quantiles": quantiles,
|
|
266
|
+
},
|
|
267
|
+
)
|
|
268
|
+
|
|
269
|
+
def _batch_iterator(
|
|
270
|
+
self,
|
|
271
|
+
data: list[Any],
|
|
272
|
+
batch_size: int,
|
|
273
|
+
):
|
|
274
|
+
"""Iterate over data in batches.
|
|
275
|
+
|
|
276
|
+
Args:
|
|
277
|
+
data: List of data items
|
|
278
|
+
batch_size: Batch size
|
|
279
|
+
|
|
280
|
+
Yields:
|
|
281
|
+
Batches of data
|
|
282
|
+
"""
|
|
283
|
+
for i in range(0, len(data), batch_size):
|
|
284
|
+
yield data[i : i + batch_size]
|
|
285
|
+
|
|
286
|
+
def _compute_quantiles_from_samples(
|
|
287
|
+
self,
|
|
288
|
+
samples: Any, # numpy array or tensor
|
|
289
|
+
quantiles: list[float],
|
|
290
|
+
) -> dict[float, float]:
|
|
291
|
+
"""Compute quantiles from sampled predictions.
|
|
292
|
+
|
|
293
|
+
Args:
|
|
294
|
+
samples: Array of samples with shape (n_samples, horizon)
|
|
295
|
+
quantiles: Quantile levels to compute
|
|
296
|
+
|
|
297
|
+
Returns:
|
|
298
|
+
Dictionary mapping quantile levels to values
|
|
299
|
+
"""
|
|
300
|
+
import numpy as np
|
|
301
|
+
|
|
302
|
+
results = {}
|
|
303
|
+
for q in quantiles:
|
|
304
|
+
results[q] = float(np.quantile(samples, q, axis=0))
|
|
305
|
+
return results
|
|
306
|
+
|
|
307
|
+
@classmethod
|
|
308
|
+
def _check_dependencies(cls) -> None:
|
|
309
|
+
"""Check if required dependencies are installed.
|
|
310
|
+
|
|
311
|
+
Raises:
|
|
312
|
+
ImportError: If dependencies are missing
|
|
313
|
+
"""
|
|
314
|
+
# Base class checks for torch only
|
|
315
|
+
try:
|
|
316
|
+
import torch # noqa: F401
|
|
317
|
+
except ImportError as e:
|
|
318
|
+
raise ImportError(
|
|
319
|
+
"PyTorch is required for TSFM adapters. "
|
|
320
|
+
"Install with: pip install torch"
|
|
321
|
+
) from e
|