tsagentkit 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. tsagentkit/__init__.py +126 -0
  2. tsagentkit/anomaly/__init__.py +130 -0
  3. tsagentkit/backtest/__init__.py +48 -0
  4. tsagentkit/backtest/engine.py +788 -0
  5. tsagentkit/backtest/metrics.py +244 -0
  6. tsagentkit/backtest/report.py +342 -0
  7. tsagentkit/calibration/__init__.py +136 -0
  8. tsagentkit/contracts/__init__.py +133 -0
  9. tsagentkit/contracts/errors.py +275 -0
  10. tsagentkit/contracts/results.py +418 -0
  11. tsagentkit/contracts/schema.py +44 -0
  12. tsagentkit/contracts/task_spec.py +300 -0
  13. tsagentkit/covariates/__init__.py +340 -0
  14. tsagentkit/eval/__init__.py +285 -0
  15. tsagentkit/features/__init__.py +20 -0
  16. tsagentkit/features/covariates.py +328 -0
  17. tsagentkit/features/extra/__init__.py +5 -0
  18. tsagentkit/features/extra/native.py +179 -0
  19. tsagentkit/features/factory.py +187 -0
  20. tsagentkit/features/matrix.py +159 -0
  21. tsagentkit/features/tsfeatures_adapter.py +115 -0
  22. tsagentkit/features/versioning.py +203 -0
  23. tsagentkit/hierarchy/__init__.py +39 -0
  24. tsagentkit/hierarchy/aggregation.py +62 -0
  25. tsagentkit/hierarchy/evaluator.py +400 -0
  26. tsagentkit/hierarchy/reconciliation.py +232 -0
  27. tsagentkit/hierarchy/structure.py +453 -0
  28. tsagentkit/models/__init__.py +182 -0
  29. tsagentkit/models/adapters/__init__.py +83 -0
  30. tsagentkit/models/adapters/base.py +321 -0
  31. tsagentkit/models/adapters/chronos.py +387 -0
  32. tsagentkit/models/adapters/moirai.py +256 -0
  33. tsagentkit/models/adapters/registry.py +171 -0
  34. tsagentkit/models/adapters/timesfm.py +440 -0
  35. tsagentkit/models/baselines.py +207 -0
  36. tsagentkit/models/sktime.py +307 -0
  37. tsagentkit/monitoring/__init__.py +51 -0
  38. tsagentkit/monitoring/alerts.py +302 -0
  39. tsagentkit/monitoring/coverage.py +203 -0
  40. tsagentkit/monitoring/drift.py +330 -0
  41. tsagentkit/monitoring/report.py +214 -0
  42. tsagentkit/monitoring/stability.py +275 -0
  43. tsagentkit/monitoring/triggers.py +423 -0
  44. tsagentkit/qa/__init__.py +347 -0
  45. tsagentkit/router/__init__.py +37 -0
  46. tsagentkit/router/bucketing.py +489 -0
  47. tsagentkit/router/fallback.py +132 -0
  48. tsagentkit/router/plan.py +23 -0
  49. tsagentkit/router/router.py +271 -0
  50. tsagentkit/series/__init__.py +26 -0
  51. tsagentkit/series/alignment.py +206 -0
  52. tsagentkit/series/dataset.py +449 -0
  53. tsagentkit/series/sparsity.py +261 -0
  54. tsagentkit/series/validation.py +393 -0
  55. tsagentkit/serving/__init__.py +39 -0
  56. tsagentkit/serving/orchestration.py +943 -0
  57. tsagentkit/serving/packaging.py +73 -0
  58. tsagentkit/serving/provenance.py +317 -0
  59. tsagentkit/serving/tsfm_cache.py +214 -0
  60. tsagentkit/skill/README.md +135 -0
  61. tsagentkit/skill/__init__.py +8 -0
  62. tsagentkit/skill/recipes.md +429 -0
  63. tsagentkit/skill/tool_map.md +21 -0
  64. tsagentkit/time/__init__.py +134 -0
  65. tsagentkit/utils/__init__.py +20 -0
  66. tsagentkit/utils/quantiles.py +83 -0
  67. tsagentkit/utils/signature.py +47 -0
  68. tsagentkit/utils/temporal.py +41 -0
  69. tsagentkit-1.0.2.dist-info/METADATA +371 -0
  70. tsagentkit-1.0.2.dist-info/RECORD +72 -0
  71. tsagentkit-1.0.2.dist-info/WHEEL +4 -0
  72. tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,182 @@
1
+ """Models module for tsagentkit.
2
+
3
+ Provides model fitting, prediction, and TSFM (Time-Series Foundation Model)
4
+ adapters for various forecasting backends.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from collections.abc import Callable
10
+ from datetime import UTC
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ from tsagentkit.contracts import ForecastResult, ModelArtifact, Provenance
14
+
15
+ # Import adapters submodules
16
+ from tsagentkit.models import adapters
17
+ from tsagentkit.models.baselines import fit_baseline, is_baseline_model, predict_baseline
18
+ from tsagentkit.models.sktime import SktimeModelBundle, fit_sktime, predict_sktime
19
+ from tsagentkit.utils import normalize_quantile_columns
20
+
21
+ if TYPE_CHECKING:
22
+ from tsagentkit.contracts import TaskSpec
23
+ from tsagentkit.router import PlanSpec
24
+ from tsagentkit.series import TSDataset
25
+
26
+
27
def _is_tsfm_model(model_name: str) -> bool:
    """Return True when ``model_name`` begins with ``tsfm-``, ignoring case."""
    prefix = "tsfm-"
    lowered = model_name.lower()
    return lowered[: len(prefix)] == prefix
29
+
30
+
31
def _is_sktime_model(model_name: str) -> bool:
    """Return True when ``model_name`` begins with ``sktime-``, ignoring case."""
    prefix = "sktime-"
    lowered = model_name.lower()
    return lowered[: len(prefix)] == prefix
33
+
34
+
35
def _strip_tsfm_prefix(model_name: str) -> str:
    """Return the adapter name encoded in a ``tsfm-<adapter>`` model name.

    ``_is_tsfm_model`` recognises the prefix case-insensitively, so the prefix
    is removed case-insensitively here as well; otherwise a name such as
    ``"TSFM-chronos"`` would be routed to the TSFM branch but keep its prefix.

    Args:
        model_name: Model name as it appears in the plan.

    Returns:
        The part of ``model_name`` after the leading ``tsfm-`` prefix. Names
        without a leading prefix fall back to the previous behaviour of
        returning whatever follows the first ``"tsfm-"`` occurrence (or the
        whole name when there is none).
    """
    prefix = "tsfm-"
    if model_name.lower().startswith(prefix):
        return model_name[len(prefix):]
    # Historical behaviour for names that do not start with the prefix.
    return model_name.split(prefix, 1)[-1]


def _build_adapter_config(model_name: str, config: dict[str, Any]) -> adapters.AdapterConfig:
    """Build an ``AdapterConfig`` for a TSFM model name.

    Args:
        model_name: Model name, normally of the form ``tsfm-<adapter>``.
        config: Optional overrides. Recognised keys are ``model_size``,
            ``device``, ``cache_dir``, ``batch_size``, ``prediction_batch_size``,
            ``quantile_method``, ``num_samples`` and ``max_context_length``;
            missing keys fall back to the defaults written below.

    Returns:
        The adapter configuration, with ``model_name`` set to the adapter name
        (the ``tsfm-`` prefix removed).
    """
    return adapters.AdapterConfig(
        model_name=_strip_tsfm_prefix(model_name),
        model_size=config.get("model_size", "base"),
        device=config.get("device"),
        cache_dir=config.get("cache_dir"),
        batch_size=config.get("batch_size", 32),
        prediction_batch_size=config.get("prediction_batch_size", 100),
        quantile_method=config.get("quantile_method", "sample"),
        num_samples=config.get("num_samples", 100),
        max_context_length=config.get("max_context_length"),
    )


def _fit_model_name(
    model_name: str,
    dataset: TSDataset,
    plan: PlanSpec,
    covariates: Any | None = None,
) -> ModelArtifact:
    """Fit a single model, dispatching on its name.

    Dispatch order: ``tsfm-*`` names go to a TSFM adapter, names accepted by
    ``is_baseline_model`` go to ``fit_baseline``, and ``sktime-*`` names go to
    ``fit_sktime``.

    Args:
        model_name: Name of the model to fit.
        dataset: Dataset to fit on; its ``task_spec`` supplies the horizon and
            season length.
        plan: Plan supplying the quantile levels (and passed to ``fit_sktime``).
        covariates: Optional covariates; only forwarded to ``fit_sktime``.

    Returns:
        The fitted ``ModelArtifact``.

    Raises:
        ValueError: If ``model_name`` matches none of the known families.
    """
    config: dict[str, Any] = {
        "horizon": dataset.task_spec.horizon,
        # A missing/zero season length is stored as 1.
        "season_length": dataset.task_spec.season_length or 1,
        "quantiles": plan.quantiles,
    }

    if _is_tsfm_model(model_name):
        adapter_name = _strip_tsfm_prefix(model_name)
        # An empty override dict means the adapter runs with default settings.
        adapter_config = _build_adapter_config(model_name, {})
        adapter = adapters.AdapterRegistry.create(adapter_name, adapter_config)
        # The adapter instance itself is stored as the model; the value
        # returned by adapter.fit() is not used.
        adapter.fit(
            dataset=dataset,
            prediction_length=dataset.task_spec.horizon,
            quantiles=plan.quantiles,
        )
        return ModelArtifact(
            model=adapter,
            model_name=model_name,
            config=config,
            metadata={"adapter": adapter_name},
        )

    if is_baseline_model(model_name):
        return fit_baseline(model_name, dataset, config)

    if _is_sktime_model(model_name):
        return fit_sktime(model_name, dataset, plan, covariates=covariates)

    raise ValueError(f"Unknown model name: {model_name}")
86
+
87
+
88
def fit(
    dataset: TSDataset,
    plan: PlanSpec,
    on_fallback: Callable[[str, str, Exception], None] | None = None,
    covariates: Any | None = None,
) -> ModelArtifact:
    """Fit a model for ``dataset`` by walking the plan's fallback ladder.

    The per-candidate fitting is delegated to ``_fit_model_name``; candidate
    selection and fallback handling are delegated to
    ``tsagentkit.router.execute_with_fallback``.

    Args:
        dataset: Dataset to fit on.
        plan: Plan describing the candidates to try.
        on_fallback: Optional callback, forwarded unchanged to
            ``execute_with_fallback``.
        covariates: Optional covariates forwarded to every fit attempt.

    Returns:
        The fitted ``ModelArtifact`` (first element of the pair returned by
        ``execute_with_fallback``; the second element is discarded).
    """
    # Imported here rather than at module level, as in the original code.
    from tsagentkit.router import execute_with_fallback

    def _fit_candidate(model_name: str, ds: TSDataset) -> ModelArtifact:
        # Bind the plan and covariates so the router only supplies name + data.
        return _fit_model_name(model_name, ds, plan, covariates=covariates)

    fitted_artifact, _ = execute_with_fallback(
        fit_func=_fit_candidate,
        dataset=dataset,
        plan=plan,
        on_fallback=on_fallback,
    )
    return fitted_artifact
107
+
108
+
109
def _basic_provenance(
    dataset: TSDataset,
    spec: TaskSpec,
    artifact: ModelArtifact,
) -> Provenance:
    """Build a minimal ``Provenance`` record for a forecast.

    The record is flagged with ``metadata["provenance_incomplete"] = True``:
    no plan is available here, so the artifact signature is used for both the
    plan and the model signature.

    Args:
        dataset: Dataset whose ``df`` is hashed into the data signature.
        spec: Task specification; ``spec.model_hash()`` is the task signature.
        artifact: Fitted artifact supplying the signature.

    Returns:
        A ``Provenance`` whose ``run_id`` and ``timestamp`` describe the same
        UTC instant.
    """
    from datetime import datetime

    from tsagentkit.utils import compute_data_signature

    # Read the clock once so run_id and timestamp cannot straddle a second
    # boundary and disagree with each other.
    now = datetime.now(UTC)
    signature = artifact.signature

    return Provenance(
        run_id=f"model_{now.strftime('%Y%m%d_%H%M%S')}",
        timestamp=now.isoformat(),
        data_signature=compute_data_signature(dataset.df),
        task_signature=spec.model_hash(),
        plan_signature=signature,
        model_signature=signature,
        metadata={"provenance_incomplete": True},
    )
127
+
128
+
129
def predict(
    dataset: TSDataset,
    artifact: ModelArtifact,
    spec: TaskSpec,
    covariates: Any | None = None,
) -> ForecastResult:
    """Produce a forecast from a fitted artifact.

    The artifact is dispatched in this order: TSFM adapter instances, then
    baseline model names, then sktime model bundles.

    Args:
        dataset: Dataset passed to the underlying predictor.
        artifact: Fitted artifact; ``artifact.config["quantiles"]`` (if any)
            selects the quantile levels for TSFM and baseline models.
        spec: Task specification supplying the horizon.
        covariates: Optional covariates; only forwarded to ``predict_sktime``.

    Returns:
        The forecast. For TSFM and baseline models the frame has normalized
        quantile columns and a ``model`` column.

    Raises:
        ValueError: If the artifact matches none of the known model kinds.
    """
    model = artifact.model
    name = artifact.model_name

    # --- TSFM adapters: the adapter supplies its own provenance -------------
    if isinstance(model, adapters.TSFMAdapter):
        tsfm_result = model.predict(
            dataset=dataset,
            horizon=spec.horizon,
            quantiles=artifact.config.get("quantiles"),
        )
        frame = normalize_quantile_columns(tsfm_result.df)
        if "model" not in frame.columns:
            # Copy before tagging so the normalized frame is not mutated.
            frame = frame.copy()
            frame["model"] = name
        return ForecastResult(
            df=frame,
            provenance=tsfm_result.provenance,
            model_name=name,
            horizon=spec.horizon,
        )

    # --- Baselines: provenance is synthesized locally -----------------------
    if is_baseline_model(name):
        frame = predict_baseline(
            model_artifact=artifact,
            dataset=dataset,
            horizon=spec.horizon,
            quantiles=artifact.config.get("quantiles"),
        )
        if "model" not in frame.columns:
            frame = frame.copy()
            frame["model"] = name
        baseline_provenance = _basic_provenance(dataset, spec, artifact)
        return ForecastResult(
            df=normalize_quantile_columns(frame),
            provenance=baseline_provenance,
            model_name=name,
            horizon=spec.horizon,
        )

    # --- sktime bundles: fully delegated -----------------------------------
    if isinstance(model, SktimeModelBundle):
        return predict_sktime(
            dataset=dataset,
            artifact=artifact,
            spec=spec,
            covariates=covariates,
        )

    raise ValueError(f"Unknown model type for prediction: {artifact.model_name}")
180
+
181
+
182
+ __all__ = ["fit", "predict", "adapters"]
@@ -0,0 +1,83 @@
1
+ """TSFM (Time-Series Foundation Model) adapters for tsagentkit.
2
+
3
+ This module provides unified adapters for popular time-series foundation models:
4
+ - Amazon Chronos
5
+ - Salesforce Moirai
6
+ - Google TimesFM
7
+
8
+ Example:
9
+ >>> from tsagentkit.models.adapters import AdapterConfig, AdapterRegistry
10
+ >>> from tsagentkit.models.adapters import ChronosAdapter
11
+ >>>
12
+ >>> # Register adapters
13
+ >>> AdapterRegistry.register("chronos", ChronosAdapter)
14
+ >>>
15
+ >>> # Create and use adapter
16
+ >>> config = AdapterConfig(model_name="chronos", model_size="base")
17
+ >>> adapter = AdapterRegistry.create("chronos", config)
18
+ >>> adapter.load_model()
19
+ >>> result = adapter.predict(dataset, horizon=30)
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ # Base classes
25
+ from .base import AdapterConfig, TSFMAdapter
26
+ from .registry import AdapterRegistry
27
+
28
+ __all__ = [
29
+ # Base classes
30
+ "AdapterConfig",
31
+ "TSFMAdapter",
32
+ "AdapterRegistry",
33
+ ]
34
+
35
+ # Optionally import specific adapters: each import is attempted when this module loads
36
+ # and is skipped (instead of raising ImportError) when its dependencies are not installed
37
+
38
+
39
def _try_import_chronos():
    """Return the ``ChronosAdapter`` class, or ``None`` if importing it fails.

    Only ``ImportError`` is treated as "not available"; any other exception
    raised while importing the adapter module propagates.
    """
    try:
        from .chronos import ChronosAdapter
    except ImportError:
        return None
    else:
        return ChronosAdapter
47
+
48
+
49
def _try_import_moirai():
    """Return the ``MoiraiAdapter`` class, or ``None`` if importing it fails.

    Only ``ImportError`` is treated as "not available"; any other exception
    raised while importing the adapter module propagates.
    """
    try:
        from .moirai import MoiraiAdapter
    except ImportError:
        return None
    else:
        return MoiraiAdapter
57
+
58
+
59
def _try_import_timesfm():
    """Return the ``TimesFMAdapter`` class, or ``None`` if importing it fails.

    Only ``ImportError`` is treated as "not available"; any other exception
    raised while importing the adapter module propagates.
    """
    try:
        from .timesfm import TimesFMAdapter
    except ImportError:
        return None
    else:
        return TimesFMAdapter
67
+
68
+
69
# Auto-register available adapters.
#
# Each adapter class is also bound to its public name in this module. The
# names are appended to ``__all__``, so they must exist as module attributes:
# otherwise ``from tsagentkit.models.adapters import ChronosAdapter`` fails
# and ``from tsagentkit.models.adapters import *`` raises AttributeError.
_chronos = _try_import_chronos()
if _chronos:
    ChronosAdapter = _chronos
    AdapterRegistry.register("chronos", _chronos)
    __all__.append("ChronosAdapter")

_moirai = _try_import_moirai()
if _moirai:
    MoiraiAdapter = _moirai
    AdapterRegistry.register("moirai", _moirai)
    __all__.append("MoiraiAdapter")

_timesfm = _try_import_timesfm()
if _timesfm:
    TimesFMAdapter = _timesfm
    AdapterRegistry.register("timesfm", _timesfm)
    __all__.append("TimesFMAdapter")
@@ -0,0 +1,321 @@
1
+ """Base adapter class for Time-Series Foundation Models.
2
+
3
+ Provides a unified interface for integrating external TSFMs like Chronos,
4
+ Moirai, and TimesFM with the tsagentkit pipeline.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from abc import ABC, abstractmethod
10
+ from dataclasses import dataclass
11
+ from datetime import UTC
12
+ from typing import TYPE_CHECKING, Any, Literal
13
+
14
+ if TYPE_CHECKING:
15
+
16
+ from tsagentkit.contracts import ForecastResult, ModelArtifact, Provenance
17
+ from tsagentkit.series import TSDataset
18
+
19
+
20
@dataclass(frozen=True)
class AdapterConfig:
    """Immutable configuration for a TSFM adapter.

    Attributes:
        model_name: Name of the model/adapter
        model_size: Model size variant (tiny, small, base, large)
        device: Compute device (cuda, mps, cpu, or None for auto)
        cache_dir: Directory for model caching
        batch_size: Batch size for training/fitting
        prediction_batch_size: Batch size for prediction
        quantile_method: Method for quantile prediction (sample, direct)
        num_samples: Number of samples for probabilistic forecasting
        max_context_length: Maximum context length the model accepts

    Raises:
        ValueError: If ``model_size`` or ``quantile_method`` is not one of the
            supported values.
    """

    model_name: str
    model_size: Literal["tiny", "small", "base", "large"] = "base"
    device: str | None = None  # Auto-detect if None
    cache_dir: str | None = None
    batch_size: int = 32
    prediction_batch_size: int = 100
    quantile_method: Literal["sample", "direct"] = "sample"
    num_samples: int = 100
    max_context_length: int | None = None

    def __post_init__(self) -> None:
        """Validate the enumerated fields.

        The allowed values are kept in tuples so the error messages list them
        in a fixed order; formatting a ``set`` of strings would produce a
        different order from run to run under hash randomization.
        """
        valid_sizes = ("tiny", "small", "base", "large")
        if self.model_size not in valid_sizes:
            raise ValueError(
                f"Invalid model_size '{self.model_size}'. "
                f"Must be one of: {', '.join(valid_sizes)}"
            )

        valid_methods = ("sample", "direct")
        if self.quantile_method not in valid_methods:
            raise ValueError(
                f"Invalid quantile_method '{self.quantile_method}'. "
                f"Must be one of: {', '.join(valid_methods)}"
            )
61
+
62
+
63
+ class TSFMAdapter(ABC):
64
+ """Abstract base class for Time-Series Foundation Model adapters.
65
+
66
+ Provides unified interface for different TSFMs while handling
67
+ model-specific quirks internally. Subclasses implement specific
68
+ adapters for models like Chronos, Moirai, and TimesFM.
69
+
70
+ Example:
71
+ >>> config = AdapterConfig(model_name="chronos", model_size="base")
72
+ >>> adapter = ChronosAdapter(config)
73
+ >>> adapter.load_model()
74
+ >>> result = adapter.predict(dataset, horizon=30)
75
+
76
+ Attributes:
77
+ config: Adapter configuration
78
+ _model: The underlying model instance (lazy loaded)
79
+ _device: Resolved compute device
80
+ """
81
+
82
+ def __init__(self, config: AdapterConfig):
83
+ """Initialize adapter with configuration.
84
+
85
+ Args:
86
+ config: Adapter configuration
87
+ """
88
+ self.config = config
89
+ self._model: Any | None = None
90
+ self._device = self._resolve_device()
91
+
92
+ def _resolve_device(self) -> str:
93
+ """Resolve compute device (cuda/mps/cpu).
94
+
95
+ Returns the best available device based on hardware support.
96
+ Priority: cuda > mps > cpu
97
+
98
+ Returns:
99
+ Device string for PyTorch
100
+ """
101
+ if self.config.device:
102
+ return self.config.device
103
+
104
+ # Try to import torch for device detection
105
+ try:
106
+ import torch
107
+
108
+ if torch.cuda.is_available():
109
+ return "cuda"
110
+ elif torch.backends.mps.is_available():
111
+ return "mps"
112
+ except ImportError:
113
+ pass
114
+
115
+ return "cpu"
116
+
117
+ @abstractmethod
118
+ def load_model(self) -> None:
119
+ """Load the foundation model with caching.
120
+
121
+ Downloads and caches the model if not already present.
122
+ Should be called before fit() or predict().
123
+
124
+ Raises:
125
+ ImportError: If required dependencies are not installed
126
+ RuntimeError: If model loading fails
127
+ """
128
+ pass
129
+
130
+ @abstractmethod
131
+ def fit(
132
+ self,
133
+ dataset: TSDataset,
134
+ prediction_length: int,
135
+ quantiles: list[float] | None = None,
136
+ ) -> ModelArtifact:
137
+ """Prepare model for prediction on the dataset.
138
+
139
+ Note: Most TSFMs are zero-shot and don't require traditional fitting.
140
+ This method validates compatibility and may perform preprocessing.
141
+
142
+ Args:
143
+ dataset: Training dataset
144
+ prediction_length: Forecast horizon
145
+ quantiles: Optional quantile levels for probabilistic forecasts
146
+
147
+ Returns:
148
+ ModelArtifact with model reference and configuration
149
+
150
+ Raises:
151
+ ValueError: If dataset is incompatible with model
152
+ RuntimeError: If preparation fails
153
+ """
154
+ pass
155
+
156
+ @abstractmethod
157
+ def predict(
158
+ self,
159
+ dataset: TSDataset,
160
+ horizon: int,
161
+ quantiles: list[float] | None = None,
162
+ ) -> ForecastResult:
163
+ """Generate forecasts using the TSFM.
164
+
165
+ Args:
166
+ dataset: Historical data for context
167
+ horizon: Number of steps to forecast
168
+ quantiles: Optional quantile levels (e.g., [0.1, 0.5, 0.9])
169
+
170
+ Returns:
171
+ ForecastResult with predictions and provenance
172
+
173
+ Raises:
174
+ RuntimeError: If prediction fails
175
+ ValueError: If horizon exceeds model limits
176
+ """
177
+ pass
178
+
179
+ @abstractmethod
180
+ def get_model_signature(self) -> str:
181
+ """Return unique signature for this model configuration.
182
+
183
+ Used for provenance tracking to identify the exact model
184
+ configuration used for a forecast.
185
+
186
+ Returns:
187
+ Unique signature string (e.g., "chronos-base-cuda-v1.0")
188
+ """
189
+ pass
190
+
191
+ @property
192
+ def is_loaded(self) -> bool:
193
+ """Check if model is loaded in memory.
194
+
195
+ Returns:
196
+ True if model has been loaded, False otherwise
197
+ """
198
+ return self._model is not None
199
+
200
+ def unload_model(self) -> None:
201
+ """Unload model from memory to free resources.
202
+
203
+ Useful for managing memory when using multiple large models.
204
+ """
205
+ self._model = None
206
+
207
+ def _validate_dataset(self, dataset: TSDataset) -> None:
208
+ """Validate dataset compatibility.
209
+
210
+ Args:
211
+ dataset: Dataset to validate
212
+
213
+ Raises:
214
+ ValueError: If dataset is incompatible
215
+ """
216
+ # Check minimum length requirements
217
+ min_length = 1 # Most TSFMs need at least some context
218
+
219
+ for uid in dataset.series_ids:
220
+ series = dataset.get_series(uid)
221
+ if len(series) < min_length:
222
+ raise ValueError(
223
+ f"Series '{uid}' has only {len(series)} observations. "
224
+ f"Minimum required: {min_length}"
225
+ )
226
+
227
+ def _create_provenance(
228
+ self,
229
+ dataset: TSDataset,
230
+ horizon: int,
231
+ quantiles: list[float] | None = None,
232
+ ) -> Provenance:
233
+ """Create provenance for forecast.
234
+
235
+ Args:
236
+ dataset: Input dataset
237
+ horizon: Forecast horizon
238
+ quantiles: Quantile levels used
239
+
240
+ Returns:
241
+ Provenance instance
242
+ """
243
+ # Create data signature from dataset
244
+ import hashlib
245
+ from datetime import datetime
246
+
247
+ from tsagentkit.contracts import Provenance
248
+
249
+ data_hash = hashlib.sha256(
250
+ str(dataset.n_series).encode() +
251
+ str(dataset.n_observations).encode() +
252
+ dataset.date_range[0].isoformat().encode()
253
+ ).hexdigest()[:16]
254
+
255
+ return Provenance(
256
+ run_id=f"tsfm_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}",
257
+ timestamp=datetime.now(UTC).isoformat(),
258
+ data_signature=data_hash,
259
+ task_signature=f"horizon={horizon}",
260
+ plan_signature=self.get_model_signature(),
261
+ model_signature=self.get_model_signature(),
262
+ metadata={
263
+ "adapter": self.__class__.__name__,
264
+ "device": self._device,
265
+ "quantiles": quantiles,
266
+ },
267
+ )
268
+
269
+ def _batch_iterator(
270
+ self,
271
+ data: list[Any],
272
+ batch_size: int,
273
+ ):
274
+ """Iterate over data in batches.
275
+
276
+ Args:
277
+ data: List of data items
278
+ batch_size: Batch size
279
+
280
+ Yields:
281
+ Batches of data
282
+ """
283
+ for i in range(0, len(data), batch_size):
284
+ yield data[i : i + batch_size]
285
+
286
+ def _compute_quantiles_from_samples(
287
+ self,
288
+ samples: Any, # numpy array or tensor
289
+ quantiles: list[float],
290
+ ) -> dict[float, float]:
291
+ """Compute quantiles from sampled predictions.
292
+
293
+ Args:
294
+ samples: Array of samples with shape (n_samples, horizon)
295
+ quantiles: Quantile levels to compute
296
+
297
+ Returns:
298
+ Dictionary mapping quantile levels to values
299
+ """
300
+ import numpy as np
301
+
302
+ results = {}
303
+ for q in quantiles:
304
+ results[q] = float(np.quantile(samples, q, axis=0))
305
+ return results
306
+
307
+ @classmethod
308
+ def _check_dependencies(cls) -> None:
309
+ """Check if required dependencies are installed.
310
+
311
+ Raises:
312
+ ImportError: If dependencies are missing
313
+ """
314
+ # Base class checks for torch only
315
+ try:
316
+ import torch # noqa: F401
317
+ except ImportError as e:
318
+ raise ImportError(
319
+ "PyTorch is required for TSFM adapters. "
320
+ "Install with: pip install torch"
321
+ ) from e