tsagentkit-1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tsagentkit/__init__.py +126 -0
  2. tsagentkit/anomaly/__init__.py +130 -0
  3. tsagentkit/backtest/__init__.py +48 -0
  4. tsagentkit/backtest/engine.py +788 -0
  5. tsagentkit/backtest/metrics.py +244 -0
  6. tsagentkit/backtest/report.py +342 -0
  7. tsagentkit/calibration/__init__.py +136 -0
  8. tsagentkit/contracts/__init__.py +133 -0
  9. tsagentkit/contracts/errors.py +275 -0
  10. tsagentkit/contracts/results.py +418 -0
  11. tsagentkit/contracts/schema.py +44 -0
  12. tsagentkit/contracts/task_spec.py +300 -0
  13. tsagentkit/covariates/__init__.py +340 -0
  14. tsagentkit/eval/__init__.py +285 -0
  15. tsagentkit/features/__init__.py +20 -0
  16. tsagentkit/features/covariates.py +328 -0
  17. tsagentkit/features/extra/__init__.py +5 -0
  18. tsagentkit/features/extra/native.py +179 -0
  19. tsagentkit/features/factory.py +187 -0
  20. tsagentkit/features/matrix.py +159 -0
  21. tsagentkit/features/tsfeatures_adapter.py +115 -0
  22. tsagentkit/features/versioning.py +203 -0
  23. tsagentkit/hierarchy/__init__.py +39 -0
  24. tsagentkit/hierarchy/aggregation.py +62 -0
  25. tsagentkit/hierarchy/evaluator.py +400 -0
  26. tsagentkit/hierarchy/reconciliation.py +232 -0
  27. tsagentkit/hierarchy/structure.py +453 -0
  28. tsagentkit/models/__init__.py +182 -0
  29. tsagentkit/models/adapters/__init__.py +83 -0
  30. tsagentkit/models/adapters/base.py +321 -0
  31. tsagentkit/models/adapters/chronos.py +387 -0
  32. tsagentkit/models/adapters/moirai.py +256 -0
  33. tsagentkit/models/adapters/registry.py +171 -0
  34. tsagentkit/models/adapters/timesfm.py +440 -0
  35. tsagentkit/models/baselines.py +207 -0
  36. tsagentkit/models/sktime.py +307 -0
  37. tsagentkit/monitoring/__init__.py +51 -0
  38. tsagentkit/monitoring/alerts.py +302 -0
  39. tsagentkit/monitoring/coverage.py +203 -0
  40. tsagentkit/monitoring/drift.py +330 -0
  41. tsagentkit/monitoring/report.py +214 -0
  42. tsagentkit/monitoring/stability.py +275 -0
  43. tsagentkit/monitoring/triggers.py +423 -0
  44. tsagentkit/qa/__init__.py +347 -0
  45. tsagentkit/router/__init__.py +37 -0
  46. tsagentkit/router/bucketing.py +489 -0
  47. tsagentkit/router/fallback.py +132 -0
  48. tsagentkit/router/plan.py +23 -0
  49. tsagentkit/router/router.py +271 -0
  50. tsagentkit/series/__init__.py +26 -0
  51. tsagentkit/series/alignment.py +206 -0
  52. tsagentkit/series/dataset.py +449 -0
  53. tsagentkit/series/sparsity.py +261 -0
  54. tsagentkit/series/validation.py +393 -0
  55. tsagentkit/serving/__init__.py +39 -0
  56. tsagentkit/serving/orchestration.py +943 -0
  57. tsagentkit/serving/packaging.py +73 -0
  58. tsagentkit/serving/provenance.py +317 -0
  59. tsagentkit/serving/tsfm_cache.py +214 -0
  60. tsagentkit/skill/README.md +135 -0
  61. tsagentkit/skill/__init__.py +8 -0
  62. tsagentkit/skill/recipes.md +429 -0
  63. tsagentkit/skill/tool_map.md +21 -0
  64. tsagentkit/time/__init__.py +134 -0
  65. tsagentkit/utils/__init__.py +20 -0
  66. tsagentkit/utils/quantiles.py +83 -0
  67. tsagentkit/utils/signature.py +47 -0
  68. tsagentkit/utils/temporal.py +41 -0
  69. tsagentkit-1.0.2.dist-info/METADATA +371 -0
  70. tsagentkit-1.0.2.dist-info/RECORD +72 -0
  71. tsagentkit-1.0.2.dist-info/WHEEL +4 -0
  72. tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,73 @@
+ """Run artifact packaging.
+
+ Bundles all outputs from a forecasting run into a comprehensive artifact.
+ """
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any
+
+ from tsagentkit.contracts import RunArtifact
+
+ if TYPE_CHECKING:
+     from tsagentkit.backtest import BacktestReport
+     from tsagentkit.contracts import ForecastResult, ModelArtifact, Provenance
+     from tsagentkit.qa import QAReport
+     from tsagentkit.router import PlanSpec
+
+
+ def package_run(
+     forecast: ForecastResult,
+     plan: PlanSpec | dict[str, Any],
+     task_spec: Any | None = None,
+     plan_spec: dict[str, Any] | None = None,
+     validation_report: dict[str, Any] | None = None,
+     backtest_report: BacktestReport | None = None,
+     qa_report: QAReport | None = None,
+     model_artifact: ModelArtifact | None = None,
+     provenance: Provenance | None = None,
+     calibration_artifact: dict[str, Any] | None = None,
+     anomaly_report: dict[str, Any] | None = None,
+     metadata: dict[str, Any] | None = None,
+ ) -> RunArtifact:
+     """Package all run outputs into a comprehensive artifact.
+
+     Args:
+         forecast: ForecastResult with predictions + provenance
+         plan: Execution plan (PlanSpec or dict)
+         backtest_report: Optional backtest results
+         qa_report: Optional QA report
+         model_artifact: Optional fitted model
+         provenance: Optional provenance information (overrides forecast provenance)
+         calibration_artifact: Optional calibration artifact
+         anomaly_report: Optional anomaly report
+         metadata: Optional metadata
+
+     Returns:
+         RunArtifact containing all run outputs
+     """
+     if hasattr(plan, "to_dict"):
+         plan_dict = plan.to_dict()
+     elif hasattr(plan, "model_dump"):
+         plan_dict = plan.model_dump()
+     else:
+         plan_dict = plan
+     if plan_spec is None:
+         plan_spec = plan_dict
+     backtest_dict = backtest_report.to_dict() if backtest_report else None
+     qa_dict = qa_report.to_dict() if qa_report and hasattr(qa_report, "to_dict") else None
+
+     return RunArtifact(
+         forecast=forecast,
+         plan=plan_dict,
+         task_spec=task_spec,
+         plan_spec=plan_spec,
+         validation_report=validation_report,
+         backtest_report=backtest_dict,
+         qa_report=qa_dict,
+         model_artifact=model_artifact,
+         provenance=provenance or forecast.provenance,
+         calibration_artifact=calibration_artifact,
+         anomaly_report=anomaly_report,
+         metadata=metadata or {},
+     )
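
For orientation, a minimal usage sketch of package_run as defined above. The forecast, plan, and backtest_report objects are assumed to come from earlier pipeline steps (their construction is not shown), and only keyword arguments that appear in the signature above are used:

    from tsagentkit.serving.packaging import package_run

    # `forecast`, `plan`, and `backtest_report` are assumed to be produced by
    # earlier pipeline steps (model fit/predict, routing, backtesting).
    artifact = package_run(
        forecast=forecast,
        plan=plan,  # PlanSpec or plain dict; normalized via to_dict()/model_dump()
        backtest_report=backtest_report,
        metadata={"run_label": "nightly"},
    )
    # The resulting RunArtifact carries the forecast, the normalized plan, and
    # provenance taken from the forecast when none is passed explicitly.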
@@ -0,0 +1,317 @@
+ """Provenance tracking for forecasting runs.
+
+ Provides utilities for tracking data lineage and reproducibility.
+ """
+
+ from __future__ import annotations
+
+ import json
+ from typing import TYPE_CHECKING, Any
+ from uuid import uuid4
+
+ import pandas as pd
+
+ if TYPE_CHECKING:
+     pass
+
+ from tsagentkit.contracts import Provenance, TaskSpec
+ from tsagentkit.router import PlanSpec
+
+
+ from datetime import UTC
+
+ from tsagentkit.utils import compute_config_signature, compute_data_signature
+
+
+ def create_provenance(
+     data: pd.DataFrame,
+     task_spec: TaskSpec,
+     plan: PlanSpec,
+     model_config: dict[str, Any] | None = None,
+     qa_repairs: list[Any] | None = None,
+     fallbacks_triggered: list[dict[str, Any]] | None = None,
+     feature_matrix: Any | None = None,
+     drift_report: Any | None = None,
+     column_map: dict[str, str] | None = None,
+     original_panel_contract: dict[str, Any] | None = None,
+     route_decision: Any | None = None,
+ ) -> Provenance:
+     """Create a provenance record for a forecasting run.
+
+     Args:
+         data: Input data
+         task_spec: Task specification
+         plan: Execution plan
+         model_config: Model configuration
+         qa_repairs: List of QA repairs applied (RepairReport objects)
+         fallbacks_triggered: List of fallback events
+         feature_matrix: Optional FeatureMatrix for feature signature (v0.2)
+         drift_report: Optional DriftReport for drift info (v0.2)
+         route_decision: Optional RouteDecision for routing audit trail (v1.0)
+
+     Returns:
+         Provenance object with signatures and metadata
+     """
+     from datetime import datetime
+
+     from tsagentkit.contracts import Provenance
+
+     metadata: dict[str, Any] = {}
+
+     # v0.2: Add feature signature if available
+     if feature_matrix is not None:
+         metadata["feature_signature"] = feature_matrix.signature
+         metadata["feature_config_hash"] = feature_matrix.config_hash
+         metadata["n_features"] = len(feature_matrix.feature_cols)
+
+     # v0.2: Add drift info if available
+     if drift_report is not None:
+         metadata["drift_detected"] = drift_report.drift_detected
+         metadata["drift_score"] = drift_report.overall_drift_score
+         metadata["drift_threshold"] = drift_report.threshold_used
+         if drift_report.drift_detected:
+             metadata["drifting_features"] = drift_report.get_drifting_features()
+
+     if column_map:
+         metadata["column_map"] = column_map
+     if original_panel_contract:
+         metadata["original_panel_contract"] = original_panel_contract
+
+     # v1.0: Add route decision for audit trail
+     if route_decision is not None:
+         metadata["route_decision"] = {
+             "buckets": route_decision.buckets,
+             "reasons": route_decision.reasons,
+             "stats": route_decision.stats,
+         }
+
+     from tsagentkit.router import compute_plan_signature
+
+     # Convert RepairReport objects to dicts for serialization
+     repairs_serialized: list[dict[str, Any]] = []
+     for repair in (qa_repairs or []):
+         if hasattr(repair, "to_dict"):
+             repairs_serialized.append(repair.to_dict())
+         else:
+             repairs_serialized.append(repair)
+
+     return Provenance(
+         run_id=str(uuid4()),
+         timestamp=datetime.now(UTC).isoformat(),
+         data_signature=compute_data_signature(data),
+         task_signature=task_spec.model_hash(),
+         plan_signature=compute_plan_signature(plan),
+         model_signature=compute_config_signature(model_config or {}),
+         qa_repairs=repairs_serialized,
+         fallbacks_triggered=fallbacks_triggered or [],
+         metadata=metadata,
+     )
+
+
+ def log_event(
+     step_name: str,
+     status: str,
+     duration_ms: float,
+     error_code: str | None = None,
+     artifacts_generated: list[str] | None = None,
+     context: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+     """Log a structured event.
+
+     Creates an event dictionary with all required fields per PRD section 6.2:
+     - step_name: Pipeline step name
+     - status: Execution status
+     - duration_ms: Execution duration
+     - error_code: Error code if applicable
+     - artifacts_generated: List of generated artifacts
+     - timestamp: ISO 8601 timestamp
+     - context: Additional context
+
+     Args:
+         step_name: Name of the pipeline step
+         status: Status (e.g., "success", "failed")
+         duration_ms: Duration in milliseconds
+         error_code: Error code if failed
+         artifacts_generated: List of artifact names generated
+         context: Additional context dictionary
+
+     Returns:
+         Event dictionary with all structured logging fields
+     """
+     from datetime import datetime
+
+     event = {
+         "step_name": step_name,
+         "status": status,
+         "duration_ms": round(duration_ms, 3),
+         "timestamp": datetime.now(UTC).isoformat(),
+         "error_code": error_code,
+         "artifacts_generated": artifacts_generated or [],
+     }
+
+     if context:
+         event["context"] = context
+
+     return event
+
+
+ def format_event_json(event: dict[str, Any]) -> str:
+     """Format an event as JSON string for structured logging.
+
+     Args:
+         event: Event dictionary from log_event()
+
+     Returns:
+         JSON string representation
+     """
+     return json.dumps(event, sort_keys=True, separators=(",", ":"))
+
+
+ class StructuredLogger:
+     """Structured logger for tsagentkit pipeline events.
+
+     Provides consistent JSON-formatted logging with all required fields
+     per PRD section 6.2 (Observability & Error Codes).
+
+     Example:
+         >>> logger = StructuredLogger()
+         >>> logger.start_step("fit")
+         >>> # ... do work ...
+         >>> event = logger.end_step("fit", status="success")
+         >>> print(logger.to_json())
+
+     Attributes:
+         events: List of logged events
+     """
+
+     def __init__(self) -> None:
+         """Initialize the structured logger."""
+         self.events: list[dict[str, Any]] = []
+         self._start_times: dict[str, float] = {}
+
+     def start_step(self, step_name: str) -> None:
+         """Record the start time for a step.
+
+         Args:
+             step_name: Name of the step
+         """
+         import time
+
+         self._start_times[step_name] = time.time()
+
+     def end_step(
+         self,
+         step_name: str,
+         status: str = "success",
+         error_code: str | None = None,
+         artifacts_generated: list[str] | None = None,
+         context: dict[str, Any] | None = None,
+     ) -> dict[str, Any]:
+         """End a step and log the event.
+
+         Args:
+             step_name: Name of the step
+             status: Execution status
+             error_code: Error code if failed
+             artifacts_generated: List of artifacts generated
+             context: Additional context
+
+         Returns:
+             The logged event
+         """
+         import time
+
+         start_time = self._start_times.get(step_name)
+         if start_time is not None:
+             duration_ms = (time.time() - start_time) * 1000
+             del self._start_times[step_name]
+         else:
+             duration_ms = 0.0
+
+         event = log_event(
+             step_name=step_name,
+             status=status,
+             duration_ms=duration_ms,
+             error_code=error_code,
+             artifacts_generated=artifacts_generated,
+             context=context,
+         )
+
+         self.events.append(event)
+         return event
+
+     def log(
+         self,
+         step_name: str,
+         status: str,
+         duration_ms: float,
+         error_code: str | None = None,
+         artifacts_generated: list[str] | None = None,
+         context: dict[str, Any] | None = None,
+     ) -> dict[str, Any]:
+         """Log an event directly.
+
+         Args:
+             step_name: Name of the step
+             status: Execution status
+             duration_ms: Duration in milliseconds
+             error_code: Error code if failed
+             artifacts_generated: List of artifacts generated
+             context: Additional context
+
+         Returns:
+             The logged event
+         """
+         event = log_event(
+             step_name=step_name,
+             status=status,
+             duration_ms=duration_ms,
+             error_code=error_code,
+             artifacts_generated=artifacts_generated,
+             context=context,
+         )
+
+         self.events.append(event)
+         return event
+
+     def to_json(self) -> str:
+         """Export all events as JSON.
+
+         Returns:
+             JSON array string
+         """
+         return json.dumps(self.events, sort_keys=True, separators=(",", ":"))
+
+     def get_events(self) -> list[dict[str, Any]]:
+         """Get all logged events.
+
+         Returns:
+             List of event dictionaries (copy)
+         """
+         return self.events.copy()
+
+     def to_dict(self) -> list[dict[str, Any]]:
+         """Export all events as list of dictionaries.
+
+         Returns:
+             List of event dictionaries
+         """
+         return self.events.copy()
+
+     def get_summary(self) -> dict[str, Any]:
+         """Get summary statistics of logged events.
+
+         Returns:
+             Summary dictionary with counts and timing
+         """
+         total_duration = sum(e.get("duration_ms", 0) for e in self.events)
+         success_count = sum(1 for e in self.events if e.get("status") == "success")
+         failed_count = sum(1 for e in self.events if e.get("status") == "failed")
+
+         return {
+             "total_events": len(self.events),
+             "success_count": success_count,
+             "failed_count": failed_count,
+             "total_duration_ms": round(total_duration, 3),
+             "steps": [e.get("step_name") for e in self.events],
+         }
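
As an illustration of how the logging helpers above compose, a small sketch using only names defined in this file; the "fit" step name and the artifact label are made up for the example:

    from tsagentkit.serving.provenance import StructuredLogger, format_event_json

    logger = StructuredLogger()
    logger.start_step("fit")
    # ... fit a model ...
    event = logger.end_step("fit", status="success", artifacts_generated=["model_artifact"])
    print(format_event_json(event))  # one JSON object per event, keys sorted
    print(logger.get_summary())      # counts, step names, and total duration_ms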
@@ -0,0 +1,214 @@
+ """TSFM (Time-Series Foundation Model) caching for serving.
+
+ Provides model caching and lazy loading for TSFM adapters to enable
+ efficient inference in production serving environments.
+ """
+
+ from __future__ import annotations
+
+ import threading
+ from typing import TYPE_CHECKING, Any
+
+ if TYPE_CHECKING:
+     from tsagentkit.models.adapters import TSFMAdapter
+
+
+ class TSFMModelCache:
+     """Thread-safe cache for TSFM model instances.
+
+     Implements a singleton cache pattern for TSFM adapters to avoid
+     reloading large foundation models on each request. Uses weak
+     references to allow garbage collection when memory is constrained.
+
+     Attributes:
+         _cache: Dictionary mapping model names to cached instances
+         _lock: Threading lock for concurrent access
+         _metadata: Cache metadata (load time, access count, etc.)
+
+     Example:
+         >>> cache = TSFMModelCache()
+         >>> model = cache.get_model("chronos", pipeline="large")
+         >>> # Model is loaded once and cached for subsequent calls
+     """
+
+     _instance: TSFMModelCache | None = None
+     _instance_lock = threading.Lock()
+
+     def __new__(cls) -> TSFMModelCache:
+         """Ensure singleton pattern for global cache."""
+         if cls._instance is None:
+             with cls._instance_lock:
+                 if cls._instance is None:
+                     cls._instance = super().__new__(cls)
+                     cls._instance._initialized = False
+         return cls._instance
+
+     def __init__(self) -> None:
+         """Initialize the cache (only runs once due to singleton)."""
+         if getattr(self, "_initialized", False):
+             return
+
+         self._cache: dict[str, Any] = {}
+         self._lock = threading.RLock()
+         self._metadata: dict[str, dict[str, Any]] = {}
+         self._initialized = True
+
+     def get_model(
+         self,
+         model_name: str,
+         **model_kwargs,
+     ) -> TSFMAdapter:
+         """Get a TSFM model from cache or load it.
+
+         Args:
+             model_name: Name of the TSFM (e.g., "chronos", "moirai", "timesfm")
+             **model_kwargs: Arguments passed to model initialization
+                 (e.g., pipeline="large", device="cuda")
+
+         Returns:
+             Cached or newly loaded TSFMAdapter instance
+
+         Raises:
+             EModelLoadFailed: If model loading fails
+             EAdapterNotAvailable: If adapter is not installed
+         """
+         cache_key = self._make_cache_key(model_name, **model_kwargs)
+
+         with self._lock:
+             # Check cache
+             if cache_key in self._cache:
+                 model = self._cache[cache_key]
+                 self._metadata[cache_key]["access_count"] += 1
+                 return model
+
+             # Load model
+             model = self._load_model(model_name, **model_kwargs)
+
+             # Cache model
+             self._cache[cache_key] = model
+             self._metadata[cache_key] = {
+                 "model_name": model_name,
+                 "load_time": self._get_timestamp(),
+                 "access_count": 1,
+                 "kwargs": model_kwargs,
+             }
+
+             return model
+
+     def clear_cache(self, model_name: str | None = None) -> None:
+         """Clear cached models.
+
+         Args:
+             model_name: If specified, only clear this model. Otherwise clear all.
+         """
+         with self._lock:
+             if model_name is None:
+                 self._cache.clear()
+                 self._metadata.clear()
+             else:
+                 keys_to_remove = [
+                     k for k, v in self._metadata.items()
+                     if v.get("model_name") == model_name
+                 ]
+                 for key in keys_to_remove:
+                     del self._cache[key]
+                     del self._metadata[key]
+
+     def get_cache_stats(self) -> dict[str, Any]:
+         """Get cache statistics.
+
+         Returns:
+             Dictionary with cache statistics:
+             - num_models: Number of models in cache
+             - models: List of cached model names
+             - total_accesses: Total access count across all models
+         """
+         with self._lock:
+             return {
+                 "num_models": len(self._cache),
+                 "models": [
+                     v["model_name"] for v in self._metadata.values()
+                 ],
+                 "total_accesses": sum(
+                     v["access_count"] for v in self._metadata.values()
+                 ),
+                 "details": {
+                     k: {
+                         "model_name": v["model_name"],
+                         "access_count": v["access_count"],
+                         "load_time": v["load_time"],
+                     }
+                     for k, v in self._metadata.items()
+                 },
+             }
+
+     def _make_cache_key(self, model_name: str, **kwargs) -> str:
+         """Create a unique cache key from model name and kwargs."""
+         # Sort kwargs for consistent keys
+         kv_pairs = sorted(kwargs.items())
+         kv_str = ",".join(f"{k}={v}" for k, v in kv_pairs)
+         return f"{model_name}:{kv_str}"
+
+     def _load_model(self, model_name: str, **kwargs) -> TSFMAdapter:
+         """Load a TSFM model via the adapter registry."""
+         from tsagentkit.models.adapters import AdapterConfig, AdapterRegistry
+
+         try:
+             adapter_class = AdapterRegistry.get(model_name)
+         except ValueError as exc:
+             from tsagentkit.contracts import EAdapterNotAvailable
+
+             raise EAdapterNotAvailable(
+                 f"TSFM adapter '{model_name}' not found. "
+                 f"Ensure the required package is installed.",
+                 context={"adapter_name": model_name, "error": str(exc)},
+             ) from exc
+
+         try:
+             config = AdapterConfig(model_name=model_name, **kwargs)
+             return adapter_class(config)
+         except Exception as exc:
+             from tsagentkit.contracts import EModelLoadFailed
+
+             raise EModelLoadFailed(
+                 f"Failed to load TSFM model '{model_name}': {exc}",
+                 context={"adapter_name": model_name, "error": str(exc)},
+             ) from exc
+
+     def _get_timestamp(self) -> float:
+         """Get current timestamp."""
+         import time
+
+         return time.time()
+
+
+ def get_tsfm_model(model_name: str, **kwargs) -> TSFMAdapter:
+     """Convenience function to get a cached TSFM model.
+
+     Args:
+         model_name: Name of the TSFM (e.g., "chronos", "moirai", "timesfm")
+         **kwargs: Model initialization arguments
+
+     Returns:
+         TSFMAdapter instance (cached or newly loaded)
+
+     Example:
+         >>> model = get_tsfm_model("chronos", pipeline="large")
+         >>> forecast = model.predict(series, horizon=7)
+     """
+     cache = TSFMModelCache()
+     return cache.get_model(model_name, **kwargs)
+
+
+ def clear_tsfm_cache(model_name: str | None = None) -> None:
+     """Clear TSFM model cache.
+
+     Args:
+         model_name: If specified, only clear this model. Otherwise clear all.
+
+     Example:
+         >>> clear_tsfm_cache("chronos")  # Clear only Chronos
+         >>> clear_tsfm_cache()  # Clear all cached models
+     """
+     cache = TSFMModelCache()
+     cache.clear_cache(model_name)
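
To round out the caching module, a brief usage sketch; the "chronos" adapter name and the pipeline keyword argument mirror the docstring examples above, and whether they apply depends on which adapters are installed:

    from tsagentkit.serving.tsfm_cache import clear_tsfm_cache, get_tsfm_model

    model = get_tsfm_model("chronos", pipeline="large")  # loaded once, then served from cache
    model_again = get_tsfm_model("chronos", pipeline="large")
    assert model is model_again  # identical kwargs hit the same cache entry

    clear_tsfm_cache("chronos")  # evict only Chronos entries; clear_tsfm_cache() evicts everything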