tsagentkit 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tsagentkit/__init__.py +126 -0
- tsagentkit/anomaly/__init__.py +130 -0
- tsagentkit/backtest/__init__.py +48 -0
- tsagentkit/backtest/engine.py +788 -0
- tsagentkit/backtest/metrics.py +244 -0
- tsagentkit/backtest/report.py +342 -0
- tsagentkit/calibration/__init__.py +136 -0
- tsagentkit/contracts/__init__.py +133 -0
- tsagentkit/contracts/errors.py +275 -0
- tsagentkit/contracts/results.py +418 -0
- tsagentkit/contracts/schema.py +44 -0
- tsagentkit/contracts/task_spec.py +300 -0
- tsagentkit/covariates/__init__.py +340 -0
- tsagentkit/eval/__init__.py +285 -0
- tsagentkit/features/__init__.py +20 -0
- tsagentkit/features/covariates.py +328 -0
- tsagentkit/features/extra/__init__.py +5 -0
- tsagentkit/features/extra/native.py +179 -0
- tsagentkit/features/factory.py +187 -0
- tsagentkit/features/matrix.py +159 -0
- tsagentkit/features/tsfeatures_adapter.py +115 -0
- tsagentkit/features/versioning.py +203 -0
- tsagentkit/hierarchy/__init__.py +39 -0
- tsagentkit/hierarchy/aggregation.py +62 -0
- tsagentkit/hierarchy/evaluator.py +400 -0
- tsagentkit/hierarchy/reconciliation.py +232 -0
- tsagentkit/hierarchy/structure.py +453 -0
- tsagentkit/models/__init__.py +182 -0
- tsagentkit/models/adapters/__init__.py +83 -0
- tsagentkit/models/adapters/base.py +321 -0
- tsagentkit/models/adapters/chronos.py +387 -0
- tsagentkit/models/adapters/moirai.py +256 -0
- tsagentkit/models/adapters/registry.py +171 -0
- tsagentkit/models/adapters/timesfm.py +440 -0
- tsagentkit/models/baselines.py +207 -0
- tsagentkit/models/sktime.py +307 -0
- tsagentkit/monitoring/__init__.py +51 -0
- tsagentkit/monitoring/alerts.py +302 -0
- tsagentkit/monitoring/coverage.py +203 -0
- tsagentkit/monitoring/drift.py +330 -0
- tsagentkit/monitoring/report.py +214 -0
- tsagentkit/monitoring/stability.py +275 -0
- tsagentkit/monitoring/triggers.py +423 -0
- tsagentkit/qa/__init__.py +347 -0
- tsagentkit/router/__init__.py +37 -0
- tsagentkit/router/bucketing.py +489 -0
- tsagentkit/router/fallback.py +132 -0
- tsagentkit/router/plan.py +23 -0
- tsagentkit/router/router.py +271 -0
- tsagentkit/series/__init__.py +26 -0
- tsagentkit/series/alignment.py +206 -0
- tsagentkit/series/dataset.py +449 -0
- tsagentkit/series/sparsity.py +261 -0
- tsagentkit/series/validation.py +393 -0
- tsagentkit/serving/__init__.py +39 -0
- tsagentkit/serving/orchestration.py +943 -0
- tsagentkit/serving/packaging.py +73 -0
- tsagentkit/serving/provenance.py +317 -0
- tsagentkit/serving/tsfm_cache.py +214 -0
- tsagentkit/skill/README.md +135 -0
- tsagentkit/skill/__init__.py +8 -0
- tsagentkit/skill/recipes.md +429 -0
- tsagentkit/skill/tool_map.md +21 -0
- tsagentkit/time/__init__.py +134 -0
- tsagentkit/utils/__init__.py +20 -0
- tsagentkit/utils/quantiles.py +83 -0
- tsagentkit/utils/signature.py +47 -0
- tsagentkit/utils/temporal.py +41 -0
- tsagentkit-1.0.2.dist-info/METADATA +371 -0
- tsagentkit-1.0.2.dist-info/RECORD +72 -0
- tsagentkit-1.0.2.dist-info/WHEEL +4 -0
- tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Run artifact packaging.
|
|
2
|
+
|
|
3
|
+
Bundles all outputs from a forecasting run into a comprehensive artifact.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
from tsagentkit.contracts import RunArtifact
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from tsagentkit.backtest import BacktestReport
|
|
14
|
+
from tsagentkit.contracts import ForecastResult, ModelArtifact, Provenance
|
|
15
|
+
from tsagentkit.qa import QAReport
|
|
16
|
+
from tsagentkit.router import PlanSpec
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _serialize_report(report: Any) -> dict[str, Any] | None:
    """Convert an optional report into a plain dict suitable for ``RunArtifact``.

    Args:
        report: ``None``, an object exposing ``to_dict()`` or ``model_dump()``,
            or an already-serialized ``dict``.

    Returns:
        The serialized report, a shallow copy of a ``dict`` input, or ``None``
        when *report* is ``None`` or offers no known serialization.
    """
    if report is None:
        return None
    if hasattr(report, "to_dict"):
        return report.to_dict()
    if hasattr(report, "model_dump"):
        return report.model_dump()
    if isinstance(report, dict):
        # Copy so later mutation by the caller does not leak into the artifact.
        return dict(report)
    return None


def package_run(
    forecast: ForecastResult,
    plan: PlanSpec | dict[str, Any],
    task_spec: Any | None = None,
    plan_spec: dict[str, Any] | None = None,
    validation_report: dict[str, Any] | None = None,
    backtest_report: BacktestReport | None = None,
    qa_report: QAReport | None = None,
    model_artifact: ModelArtifact | None = None,
    provenance: Provenance | None = None,
    calibration_artifact: dict[str, Any] | None = None,
    anomaly_report: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> RunArtifact:
    """Package all run outputs into a comprehensive artifact.

    Args:
        forecast: ForecastResult with predictions + provenance
        plan: Execution plan (PlanSpec or dict). Objects are converted via
            ``to_dict()`` or ``model_dump()``; dicts are used as-is.
        task_spec: Optional task specification, stored unchanged
        plan_spec: Optional plan spec dict; defaults to the converted ``plan``
        validation_report: Optional validation report dict, stored unchanged
        backtest_report: Optional backtest results (report object or dict)
        qa_report: Optional QA report (report object or dict)
        model_artifact: Optional fitted model
        provenance: Optional provenance information; when omitted (or falsy)
            the forecast's own ``provenance`` is used
        calibration_artifact: Optional calibration artifact
        anomaly_report: Optional anomaly report
        metadata: Optional metadata (defaults to an empty dict)

    Returns:
        RunArtifact containing all run outputs
    """
    if hasattr(plan, "to_dict"):
        plan_dict = plan.to_dict()
    elif hasattr(plan, "model_dump"):
        plan_dict = plan.model_dump()
    else:
        plan_dict = plan
    if plan_spec is None:
        plan_spec = plan_dict

    return RunArtifact(
        forecast=forecast,
        plan=plan_dict,
        task_spec=task_spec,
        plan_spec=plan_spec,
        validation_report=validation_report,
        # Both reports go through the same serializer so an already-serialized
        # dict is preserved instead of being dropped or raising AttributeError.
        backtest_report=_serialize_report(backtest_report),
        qa_report=_serialize_report(qa_report),
        model_artifact=model_artifact,
        provenance=provenance or forecast.provenance,
        calibration_artifact=calibration_artifact,
        anomaly_report=anomaly_report,
        metadata=metadata or {},
    )
|
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
"""Provenance tracking for forecasting runs.
|
|
2
|
+
|
|
3
|
+
Provides utilities for tracking data lineage and reproducibility.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
from uuid import uuid4
|
|
11
|
+
|
|
12
|
+
import pandas as pd
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
pass
|
|
16
|
+
|
|
17
|
+
from tsagentkit.contracts import Provenance, TaskSpec
|
|
18
|
+
from tsagentkit.router import PlanSpec
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
from datetime import UTC
|
|
22
|
+
|
|
23
|
+
from tsagentkit.utils import compute_config_signature, compute_data_signature
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def create_provenance(
    data: pd.DataFrame,
    task_spec: TaskSpec,
    plan: PlanSpec,
    model_config: dict[str, Any] | None = None,
    qa_repairs: list[Any] | None = None,
    fallbacks_triggered: list[dict[str, Any]] | None = None,
    feature_matrix: Any | None = None,
    drift_report: Any | None = None,
    column_map: dict[str, str] | None = None,
    original_panel_contract: dict[str, Any] | None = None,
    route_decision: Any | None = None,
) -> Provenance:
    """Build the provenance record that describes one forecasting run.

    Args:
        data: Input data; hashed into ``data_signature``.
        task_spec: Task specification; ``task_spec.model_hash()`` becomes
            ``task_signature``.
        plan: Execution plan; hashed with ``compute_plan_signature``.
        model_config: Model configuration hashed into ``model_signature``
            (an empty dict is hashed when omitted).
        qa_repairs: QA repairs applied. Items exposing ``to_dict()`` are
            converted; anything else is stored as-is.
        fallbacks_triggered: Fallback events recorded during the run.
        feature_matrix: Optional FeatureMatrix; contributes
            ``feature_signature``, ``feature_config_hash`` and ``n_features``
            to the metadata (v0.2).
        drift_report: Optional DriftReport; contributes ``drift_detected``,
            ``drift_score``, ``drift_threshold`` and, only when drift was
            detected, ``drifting_features`` (v0.2).
        column_map: Optional column renaming map; stored when non-empty.
        original_panel_contract: Optional original panel contract; stored
            when non-empty.
        route_decision: Optional RouteDecision; its ``buckets``, ``reasons``
            and ``stats`` are stored under ``route_decision`` (v1.0).

    Returns:
        Provenance with a fresh UUID4 ``run_id``, a UTC ISO-8601 timestamp,
        the computed signatures, serialized repairs, fallbacks and metadata.
    """
    from datetime import datetime

    from tsagentkit.contracts import Provenance
    from tsagentkit.router import compute_plan_signature

    metadata: dict[str, Any] = {}

    if feature_matrix is not None:
        metadata.update(
            feature_signature=feature_matrix.signature,
            feature_config_hash=feature_matrix.config_hash,
            n_features=len(feature_matrix.feature_cols),
        )

    if drift_report is not None:
        drift_detected = drift_report.drift_detected
        metadata.update(
            drift_detected=drift_detected,
            drift_score=drift_report.overall_drift_score,
            drift_threshold=drift_report.threshold_used,
        )
        if drift_detected:
            metadata["drifting_features"] = drift_report.get_drifting_features()

    optional_entries = (
        ("column_map", column_map),
        ("original_panel_contract", original_panel_contract),
    )
    for entry_name, entry_value in optional_entries:
        if entry_value:
            metadata[entry_name] = entry_value

    # Routing audit trail: keep only the serializable parts of the decision.
    if route_decision is not None:
        metadata["route_decision"] = {
            field_name: getattr(route_decision, field_name)
            for field_name in ("buckets", "reasons", "stats")
        }

    repairs_serialized: list[dict[str, Any]] = [
        item.to_dict() if hasattr(item, "to_dict") else item
        for item in qa_repairs or ()
    ]

    return Provenance(
        run_id=str(uuid4()),
        timestamp=datetime.now(UTC).isoformat(),
        data_signature=compute_data_signature(data),
        task_signature=task_spec.model_hash(),
        plan_signature=compute_plan_signature(plan),
        model_signature=compute_config_signature(model_config or {}),
        qa_repairs=repairs_serialized,
        fallbacks_triggered=fallbacks_triggered or [],
        metadata=metadata,
    )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def log_event(
|
|
112
|
+
step_name: str,
|
|
113
|
+
status: str,
|
|
114
|
+
duration_ms: float,
|
|
115
|
+
error_code: str | None = None,
|
|
116
|
+
artifacts_generated: list[str] | None = None,
|
|
117
|
+
context: dict[str, Any] | None = None,
|
|
118
|
+
) -> dict[str, Any]:
|
|
119
|
+
"""Log a structured event.
|
|
120
|
+
|
|
121
|
+
Creates an event dictionary with all required fields per PRD section 6.2:
|
|
122
|
+
- step_name: Pipeline step name
|
|
123
|
+
- status: Execution status
|
|
124
|
+
- duration_ms: Execution duration
|
|
125
|
+
- error_code: Error code if applicable
|
|
126
|
+
- artifacts_generated: List of generated artifacts
|
|
127
|
+
- timestamp: ISO 8601 timestamp
|
|
128
|
+
- context: Additional context
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
step_name: Name of the pipeline step
|
|
132
|
+
status: Status (e.g., "success", "failed")
|
|
133
|
+
duration_ms: Duration in milliseconds
|
|
134
|
+
error_code: Error code if failed
|
|
135
|
+
artifacts_generated: List of artifact names generated
|
|
136
|
+
context: Additional context dictionary
|
|
137
|
+
|
|
138
|
+
Returns:
|
|
139
|
+
Event dictionary with all structured logging fields
|
|
140
|
+
"""
|
|
141
|
+
from datetime import datetime
|
|
142
|
+
|
|
143
|
+
event = {
|
|
144
|
+
"step_name": step_name,
|
|
145
|
+
"status": status,
|
|
146
|
+
"duration_ms": round(duration_ms, 3),
|
|
147
|
+
"timestamp": datetime.now(UTC).isoformat(),
|
|
148
|
+
"error_code": error_code,
|
|
149
|
+
"artifacts_generated": artifacts_generated or [],
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
if context:
|
|
153
|
+
event["context"] = context
|
|
154
|
+
|
|
155
|
+
return event
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def format_event_json(event: dict[str, Any]) -> str:
    """Serialize an event to a compact JSON string with sorted keys.

    Sorting keys and dropping whitespace makes the output deterministic and
    single-line, which suits line-oriented structured log sinks.

    Args:
        event: Event dictionary from log_event()

    Returns:
        JSON string representation
    """
    compact_separators = (",", ":")
    return json.dumps(event, separators=compact_separators, sort_keys=True)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class StructuredLogger:
    """Structured logger for tsagentkit pipeline events.

    Collects event dictionaries built by ``log_event`` in memory and exports
    them as JSON or plain dicts, per PRD section 6.2 (Observability & Error
    Codes).

    Step durations are measured with ``time.perf_counter()``, a monotonic
    high-resolution clock, so system clock changes cannot produce negative
    or skewed durations.

    Example:
        >>> logger = StructuredLogger()
        >>> logger.start_step("fit")
        >>> # ... do work ...
        >>> event = logger.end_step("fit", status="success")
        >>> print(logger.to_json())

    Attributes:
        events: List of logged events, in the order they were recorded
    """

    def __init__(self) -> None:
        """Initialize an empty structured logger."""
        self.events: list[dict[str, Any]] = []
        # Step name -> time.perf_counter() reading taken by start_step().
        self._start_times: dict[str, float] = {}

    def start_step(self, step_name: str) -> None:
        """Record the start time for a step.

        Starting a step that is already running overwrites its start time.

        Args:
            step_name: Name of the step
        """
        import time

        # Monotonic clock: wall-clock adjustments (NTP, manual changes) must
        # not affect measured durations.
        self._start_times[step_name] = time.perf_counter()

    def end_step(
        self,
        step_name: str,
        status: str = "success",
        error_code: str | None = None,
        artifacts_generated: list[str] | None = None,
        context: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """End a step and log the event.

        The duration is measured from the matching ``start_step`` call. If the
        step was never started (or was already ended), the duration is 0.0.

        Args:
            step_name: Name of the step
            status: Execution status
            error_code: Error code if failed
            artifacts_generated: List of artifacts generated
            context: Additional context

        Returns:
            The logged event
        """
        import time

        start_time = self._start_times.pop(step_name, None)
        if start_time is not None:
            duration_ms = (time.perf_counter() - start_time) * 1000
        else:
            duration_ms = 0.0

        event = log_event(
            step_name=step_name,
            status=status,
            duration_ms=duration_ms,
            error_code=error_code,
            artifacts_generated=artifacts_generated,
            context=context,
        )

        self.events.append(event)
        return event

    def log(
        self,
        step_name: str,
        status: str,
        duration_ms: float,
        error_code: str | None = None,
        artifacts_generated: list[str] | None = None,
        context: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Log an event directly with a caller-supplied duration.

        Args:
            step_name: Name of the step
            status: Execution status
            duration_ms: Duration in milliseconds
            error_code: Error code if failed
            artifacts_generated: List of artifacts generated
            context: Additional context

        Returns:
            The logged event
        """
        event = log_event(
            step_name=step_name,
            status=status,
            duration_ms=duration_ms,
            error_code=error_code,
            artifacts_generated=artifacts_generated,
            context=context,
        )

        self.events.append(event)
        return event

    def to_json(self) -> str:
        """Export all events as a compact, key-sorted JSON array.

        Returns:
            JSON array string
        """
        return json.dumps(self.events, sort_keys=True, separators=(",", ":"))

    def get_events(self) -> list[dict[str, Any]]:
        """Get all logged events.

        Returns:
            Shallow copy of the event list (the event dicts themselves are shared)
        """
        return self.events.copy()

    def to_dict(self) -> list[dict[str, Any]]:
        """Export all events as a list of dictionaries.

        Returns:
            Shallow copy of the event list
        """
        return self.events.copy()

    def get_summary(self) -> dict[str, Any]:
        """Get summary statistics of logged events.

        Returns:
            Summary dictionary with total/success/failed counts, the summed
            duration (rounded to 3 decimals) and the step names in order
        """
        total_duration = sum(e.get("duration_ms", 0) for e in self.events)
        success_count = sum(1 for e in self.events if e.get("status") == "success")
        failed_count = sum(1 for e in self.events if e.get("status") == "failed")

        return {
            "total_events": len(self.events),
            "success_count": success_count,
            "failed_count": failed_count,
            "total_duration_ms": round(total_duration, 3),
            "steps": [e.get("step_name") for e in self.events],
        }
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""TSFM (Time-Series Foundation Model) caching for serving.
|
|
2
|
+
|
|
3
|
+
Provides model caching and lazy loading for TSFM adapters to enable
|
|
4
|
+
efficient inference in production serving environments.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import threading
|
|
10
|
+
from typing import TYPE_CHECKING, Any
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from tsagentkit.models.adapters import TSFMAdapter
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TSFMModelCache:
    """Thread-safe, process-wide cache for TSFM model instances.

    Implements a singleton: every ``TSFMModelCache()`` call returns the same
    object, so large foundation models are not reloaded on each request.
    Cached adapters are held by strong references in a plain dict and stay
    alive until removed with ``clear_cache``.

    Attributes:
        _cache: Maps cache keys (model name + sorted kwargs) to adapter instances
        _lock: Re-entrant lock guarding ``_cache`` and ``_metadata``
        _metadata: Per-key metadata (model name, load time, access count, kwargs)

    Example:
        >>> cache = TSFMModelCache()
        >>> model = cache.get_model("chronos", pipeline="large")
        >>> # Model is loaded once and cached for subsequent calls
    """

    _instance: TSFMModelCache | None = None
    _instance_lock = threading.Lock()

    def __new__(cls) -> TSFMModelCache:
        """Return the singleton, creating and fully initializing it on first use.

        The instance state is built while ``_instance_lock`` is held and
        *before* the instance is published in ``cls._instance``. Concurrent
        first calls therefore never see, or re-create, half-initialized state.
        """
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._init_state()
                    cls._instance = instance
        return cls._instance

    def __init__(self) -> None:
        """Initialize state only if ``__new__`` has not already done so.

        For the singleton this is a no-op. Re-running it must never reset the
        shared cache, so the check-and-initialize happens under the class lock.
        """
        if getattr(self, "_initialized", False):
            return
        with type(self)._instance_lock:
            if not getattr(self, "_initialized", False):
                self._init_state()

    def _init_state(self) -> None:
        """Create the empty cache, metadata and lock, and mark the instance ready."""
        self._cache: dict[str, Any] = {}
        self._lock = threading.RLock()
        self._metadata: dict[str, dict[str, Any]] = {}
        self._initialized = True

    def get_model(
        self,
        model_name: str,
        **model_kwargs,
    ) -> TSFMAdapter:
        """Get a TSFM model from cache or load it.

        Loading happens while the cache lock is held. Concurrent requests for
        the same key therefore load the model only once, but they also wait
        for each other.

        Args:
            model_name: Name of the TSFM (e.g., "chronos", "moirai", "timesfm")
            **model_kwargs: Arguments passed to model initialization
                (e.g., pipeline="large", device="cuda"); they are part of the
                cache key

        Returns:
            Cached or newly loaded TSFMAdapter instance

        Raises:
            EModelLoadFailed: If model loading fails
            EAdapterNotAvailable: If the adapter is not registered
        """
        cache_key = self._make_cache_key(model_name, **model_kwargs)

        with self._lock:
            # Cache hit: count the access and return the shared instance.
            if cache_key in self._cache:
                model = self._cache[cache_key]
                self._metadata[cache_key]["access_count"] += 1
                return model

            model = self._load_model(model_name, **model_kwargs)

            self._cache[cache_key] = model
            self._metadata[cache_key] = {
                "model_name": model_name,
                "load_time": self._get_timestamp(),
                "access_count": 1,
                "kwargs": model_kwargs,
            }

            return model

    def clear_cache(self, model_name: str | None = None) -> None:
        """Clear cached models.

        Args:
            model_name: If specified, clear every cached entry for this model
                name (all kwargs variants). Otherwise clear all.
        """
        with self._lock:
            if model_name is None:
                self._cache.clear()
                self._metadata.clear()
            else:
                keys_to_remove = [
                    k for k, v in self._metadata.items()
                    if v.get("model_name") == model_name
                ]
                for key in keys_to_remove:
                    del self._cache[key]
                    del self._metadata[key]

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with cache statistics:
            - num_models: Number of cached entries
            - models: Model name of each cached entry (may repeat across kwargs)
            - total_accesses: Total access count across all entries
            - details: Per cache key: model_name, access_count, load_time
        """
        with self._lock:
            return {
                "num_models": len(self._cache),
                "models": [
                    v["model_name"] for v in self._metadata.values()
                ],
                "total_accesses": sum(
                    v["access_count"] for v in self._metadata.values()
                ),
                "details": {
                    k: {
                        "model_name": v["model_name"],
                        "access_count": v["access_count"],
                        "load_time": v["load_time"],
                    }
                    for k, v in self._metadata.items()
                },
            }

    def _make_cache_key(self, model_name: str, **kwargs) -> str:
        """Create a cache key from the model name and sorted ``key=value`` kwargs."""
        # Sort kwargs so keyword order does not produce distinct keys.
        kv_pairs = sorted(kwargs.items())
        kv_str = ",".join(f"{k}={v}" for k, v in kv_pairs)
        return f"{model_name}:{kv_str}"

    def _load_model(self, model_name: str, **kwargs) -> TSFMAdapter:
        """Load a TSFM model via the adapter registry.

        Raises:
            EAdapterNotAvailable: If the registry has no adapter for ``model_name``
            EModelLoadFailed: If building the config or adapter raises
        """
        from tsagentkit.models.adapters import AdapterConfig, AdapterRegistry

        try:
            adapter_class = AdapterRegistry.get(model_name)
        except ValueError as exc:
            from tsagentkit.contracts import EAdapterNotAvailable

            raise EAdapterNotAvailable(
                f"TSFM adapter '{model_name}' not found. "
                "Ensure the required package is installed.",
                context={"adapter_name": model_name, "error": str(exc)},
            ) from exc

        try:
            config = AdapterConfig(model_name=model_name, **kwargs)
            return adapter_class(config)
        except Exception as exc:
            from tsagentkit.contracts import EModelLoadFailed

            raise EModelLoadFailed(
                f"Failed to load TSFM model '{model_name}': {exc}",
                context={"adapter_name": model_name, "error": str(exc)},
            ) from exc

    def _get_timestamp(self) -> float:
        """Get the current wall-clock time in seconds since the epoch."""
        import time

        return time.time()
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def get_tsfm_model(model_name: str, **kwargs) -> TSFMAdapter:
    """Fetch a TSFM adapter through the process-wide ``TSFMModelCache``.

    Args:
        model_name: Name of the TSFM (e.g., "chronos", "moirai", "timesfm")
        **kwargs: Model initialization arguments; they also form part of the
            cache key, so different kwargs yield separately cached instances

    Returns:
        The cached adapter for this name/kwargs combination, or a newly
        loaded one when none is cached yet

    Example:
        >>> model = get_tsfm_model("chronos", pipeline="large")
    """
    shared_cache = TSFMModelCache()
    model = shared_cache.get_model(model_name, **kwargs)
    return model
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def clear_tsfm_cache(model_name: str | None = None) -> None:
    """Evict models from the process-wide TSFM cache.

    Args:
        model_name: When given, evict every cached entry for that model name;
            when omitted, empty the whole cache.

    Example:
        >>> clear_tsfm_cache("chronos")  # Clear only Chronos
        >>> clear_tsfm_cache()  # Clear all cached models
    """
    TSFMModelCache().clear_cache(model_name)
|