tsagentkit 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tsagentkit/__init__.py +126 -0
  2. tsagentkit/anomaly/__init__.py +130 -0
  3. tsagentkit/backtest/__init__.py +48 -0
  4. tsagentkit/backtest/engine.py +788 -0
  5. tsagentkit/backtest/metrics.py +244 -0
  6. tsagentkit/backtest/report.py +342 -0
  7. tsagentkit/calibration/__init__.py +136 -0
  8. tsagentkit/contracts/__init__.py +133 -0
  9. tsagentkit/contracts/errors.py +275 -0
  10. tsagentkit/contracts/results.py +418 -0
  11. tsagentkit/contracts/schema.py +44 -0
  12. tsagentkit/contracts/task_spec.py +300 -0
  13. tsagentkit/covariates/__init__.py +340 -0
  14. tsagentkit/eval/__init__.py +285 -0
  15. tsagentkit/features/__init__.py +20 -0
  16. tsagentkit/features/covariates.py +328 -0
  17. tsagentkit/features/extra/__init__.py +5 -0
  18. tsagentkit/features/extra/native.py +179 -0
  19. tsagentkit/features/factory.py +187 -0
  20. tsagentkit/features/matrix.py +159 -0
  21. tsagentkit/features/tsfeatures_adapter.py +115 -0
  22. tsagentkit/features/versioning.py +203 -0
  23. tsagentkit/hierarchy/__init__.py +39 -0
  24. tsagentkit/hierarchy/aggregation.py +62 -0
  25. tsagentkit/hierarchy/evaluator.py +400 -0
  26. tsagentkit/hierarchy/reconciliation.py +232 -0
  27. tsagentkit/hierarchy/structure.py +453 -0
  28. tsagentkit/models/__init__.py +182 -0
  29. tsagentkit/models/adapters/__init__.py +83 -0
  30. tsagentkit/models/adapters/base.py +321 -0
  31. tsagentkit/models/adapters/chronos.py +387 -0
  32. tsagentkit/models/adapters/moirai.py +256 -0
  33. tsagentkit/models/adapters/registry.py +171 -0
  34. tsagentkit/models/adapters/timesfm.py +440 -0
  35. tsagentkit/models/baselines.py +207 -0
  36. tsagentkit/models/sktime.py +307 -0
  37. tsagentkit/monitoring/__init__.py +51 -0
  38. tsagentkit/monitoring/alerts.py +302 -0
  39. tsagentkit/monitoring/coverage.py +203 -0
  40. tsagentkit/monitoring/drift.py +330 -0
  41. tsagentkit/monitoring/report.py +214 -0
  42. tsagentkit/monitoring/stability.py +275 -0
  43. tsagentkit/monitoring/triggers.py +423 -0
  44. tsagentkit/qa/__init__.py +347 -0
  45. tsagentkit/router/__init__.py +37 -0
  46. tsagentkit/router/bucketing.py +489 -0
  47. tsagentkit/router/fallback.py +132 -0
  48. tsagentkit/router/plan.py +23 -0
  49. tsagentkit/router/router.py +271 -0
  50. tsagentkit/series/__init__.py +26 -0
  51. tsagentkit/series/alignment.py +206 -0
  52. tsagentkit/series/dataset.py +449 -0
  53. tsagentkit/series/sparsity.py +261 -0
  54. tsagentkit/series/validation.py +393 -0
  55. tsagentkit/serving/__init__.py +39 -0
  56. tsagentkit/serving/orchestration.py +943 -0
  57. tsagentkit/serving/packaging.py +73 -0
  58. tsagentkit/serving/provenance.py +317 -0
  59. tsagentkit/serving/tsfm_cache.py +214 -0
  60. tsagentkit/skill/README.md +135 -0
  61. tsagentkit/skill/__init__.py +8 -0
  62. tsagentkit/skill/recipes.md +429 -0
  63. tsagentkit/skill/tool_map.md +21 -0
  64. tsagentkit/time/__init__.py +134 -0
  65. tsagentkit/utils/__init__.py +20 -0
  66. tsagentkit/utils/quantiles.py +83 -0
  67. tsagentkit/utils/signature.py +47 -0
  68. tsagentkit/utils/temporal.py +41 -0
  69. tsagentkit-1.0.2.dist-info/METADATA +371 -0
  70. tsagentkit-1.0.2.dist-info/RECORD +72 -0
  71. tsagentkit-1.0.2.dist-info/WHEEL +4 -0
  72. tsagentkit-1.0.2.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,489 @@
1
+ """Data bucketing for advanced router strategies.
2
+
3
+ Provides Head vs Tail and Short vs Long history bucketing for
4
+ series-specific model selection.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from enum import Enum
11
+ from typing import TYPE_CHECKING
12
+
13
+ import pandas as pd
14
+
15
+ if TYPE_CHECKING:
16
+ from tsagentkit.series import SparsityProfile, TSDataset
17
+
18
+
19
class SeriesBucket(Enum):
    """Buckets for series classification.

    Volume buckets are assigned from a series' rank by total value; history
    buckets from its number of observations. Thresholds are inclusive.

    - HEAD: High-volume series (volume quantile >= 0.8 by default, ~top 20%)
    - TAIL: Low-volume series (volume quantile <= 0.2 by default, ~bottom 20%)
    - SHORT_HISTORY: Few observations (<= 30 by default)
    - LONG_HISTORY: Many observations (>= 365 by default)
    """

    HEAD = "head"
    TAIL = "tail"
    SHORT_HISTORY = "short_history"
    LONG_HISTORY = "long_history"
32
+
33
+
34
@dataclass(frozen=True)
class BucketStatistics:
    """Summary statistics for one bucket.

    Attributes:
        series_count: Number of series assigned to the bucket.
        total_observations: Total observations across the bucket's series.
        avg_observations: Mean number of observations per series.
        avg_value: Mean of the per-series average values (e.g. mean sales).
        value_percentile: Share of the dataset's total value volume held by
            the bucket's series, as a fraction (normally within [0, 1]) rather
            than a 0-100 percentile. The field name is kept for compatibility.
    """

    series_count: int
    total_observations: int
    avg_observations: float
    avg_value: float
    value_percentile: float
51
+
52
+
53
@dataclass(frozen=True)
class BucketConfig:
    """Configuration for data bucketing thresholds.

    All thresholds are inclusive. A series whose volume quantile equals
    ``head_quantile_threshold`` is HEAD. A series with exactly
    ``short_history_max_obs`` observations is SHORT_HISTORY.

    Attributes:
        head_quantile_threshold: Volume quantile at or above which a series is
            HEAD (default: 0.8, roughly the top 20%). Must be in (0, 1].
        tail_quantile_threshold: Volume quantile at or below which a series is
            TAIL (default: 0.2, roughly the bottom 20%). Must be in [0, 1) and
            strictly below ``head_quantile_threshold``.
        short_history_max_obs: Max observations for SHORT_HISTORY (default: 30).
            Must be >= 1.
        long_history_min_obs: Min observations for LONG_HISTORY (default: 365).
            Must be > ``short_history_max_obs``.
        prefer_sparsity: Whether sparsity classification overrides volume and
            history buckets when a sparsity profile is supplied (default: True).
        value_col: Column to use for volume calculations (default: "y").

    Raises:
        ValueError: From ``__post_init__`` when a constraint above is violated.

    Example:
        >>> config = BucketConfig(
        ...     head_quantile_threshold=0.8,  # volume quantile >= 0.8 = HEAD
        ...     short_history_max_obs=50,  # <= 50 obs = short history
        ... )
    """

    head_quantile_threshold: float = 0.8
    tail_quantile_threshold: float = 0.2
    short_history_max_obs: int = 30
    long_history_min_obs: int = 365
    prefer_sparsity: bool = True
    value_col: str = "y"

    def __post_init__(self) -> None:
        """Validate configuration."""
        if not 0 < self.head_quantile_threshold <= 1:
            raise ValueError("head_quantile_threshold must be in (0, 1]")
        if not 0 <= self.tail_quantile_threshold < 1:
            raise ValueError("tail_quantile_threshold must be in [0, 1)")
        # The HEAD and TAIL ranges must not overlap.
        if self.tail_quantile_threshold >= self.head_quantile_threshold:
            raise ValueError("tail_quantile_threshold must be < head_quantile_threshold")
        if self.short_history_max_obs < 1:
            raise ValueError("short_history_max_obs must be >= 1")
        # The SHORT and LONG history ranges must not overlap.
        if self.long_history_min_obs <= self.short_history_max_obs:
            raise ValueError("long_history_min_obs must be > short_history_max_obs")
91
+
92
+
93
@dataclass(frozen=True)
class BucketProfile:
    """Bucketing profile for a dataset.

    The dataclass is frozen, but its dict fields are ordinary mutable
    containers. ``get_bucket_for_series`` therefore returns a copy, so callers
    cannot accidentally modify the profile through the returned set.

    Attributes:
        bucket_assignments: Mapping of unique_id to the set of buckets assigned
            to that series (may be empty for series in the "middle").
        bucket_stats: Mapping of bucket to its BucketStatistics.

    Example:
        >>> profile = BucketProfile(
        ...     bucket_assignments={"A": {SeriesBucket.HEAD, SeriesBucket.LONG_HISTORY}},
        ...     bucket_stats={SeriesBucket.HEAD: stats},
        ... )
        >>> sorted(b.value for b in profile.get_bucket_for_series("A"))
        ['head', 'long_history']
    """

    bucket_assignments: dict[str, set[SeriesBucket]] = field(default_factory=dict)
    bucket_stats: dict[SeriesBucket, BucketStatistics] = field(default_factory=dict)

    def get_bucket_for_series(self, unique_id: str) -> set[SeriesBucket]:
        """Get all buckets assigned to a series.

        Args:
            unique_id: Series identifier

        Returns:
            A new set of buckets for the series (empty if not found). Mutating
            it does not affect the profile.
        """
        return set(self.bucket_assignments.get(unique_id, ()))

    def get_series_in_bucket(self, bucket: SeriesBucket) -> list[str]:
        """Get all series IDs in a specific bucket.

        Args:
            bucket: Bucket to query

        Returns:
            List of unique_id values in the bucket, in assignment order
        """
        return [
            uid for uid, buckets in self.bucket_assignments.items()
            if bucket in buckets
        ]

    def get_bucket_counts(self) -> dict[SeriesBucket, int]:
        """Get count of series per bucket.

        Returns:
            Dict mapping every SeriesBucket member to its series count
            (0 for empty buckets)
        """
        counts: dict[SeriesBucket, int] = dict.fromkeys(SeriesBucket, 0)
        for buckets in self.bucket_assignments.values():
            for bucket in buckets:
                counts[bucket] = counts.get(bucket, 0) + 1
        return counts

    def summary(self) -> str:
        """Generate a human-readable summary with one line per bucket."""
        counts = self.get_bucket_counts()
        lines = ["Bucket Profile Summary:"]
        for bucket in SeriesBucket:
            count = counts.get(bucket, 0)
            lines.append(f" {bucket.value}: {count} series")
        return "\n".join(lines)
158
+
159
+
160
class DataBucketer:
    """Bucket series by volume and history characteristics.

    Each series is classified along two independent dimensions:

    - Volume: HEAD (large total value) vs TAIL (small total value)
    - History: SHORT_HISTORY (few rows) vs LONG_HISTORY (many rows)

    A series in the middle of a dimension gets no bucket for that dimension.
    A series can therefore carry zero, one or two buckets.

    Example:
        >>> config = BucketConfig(
        ...     head_quantile_threshold=0.8,
        ...     short_history_max_obs=50,
        ... )
        >>> bucketer = DataBucketer(config)
        >>> profile = bucketer.create_bucket_profile(dataset)
        >>> print(profile.summary())  # counts depend on the data
        Bucket Profile Summary:
         head: 20 series
         tail: 20 series
         short_history: 15 series
         long_history: 25 series
    """

    def __init__(self, config: BucketConfig | None = None):
        """Initialize the bucketer with configuration.

        Args:
            config: Bucketing thresholds (a default ``BucketConfig()`` is used
                when None)
        """
        self.config = config if config is not None else BucketConfig()

    def bucket_by_volume(
        self,
        df: pd.DataFrame,
        value_col: str | None = None,
    ) -> dict[str, SeriesBucket]:
        """Bucket series into HEAD/TAIL based on total value volume.

        Series are ranked by the sum of ``value_col``. The series at rank ``r``
        (0 = smallest) out of ``n`` gets the quantile ``r / (n - 1)``, or 0 when
        ``n == 1``. Quantiles ``>= head_quantile_threshold`` map to HEAD, and
        quantiles ``<= tail_quantile_threshold`` map to TAIL. Series in between
        are left unassigned. Ties in volume are broken by ascending
        ``unique_id`` so the result is deterministic.

        Args:
            df: DataFrame with ``unique_id`` and ``value_col`` columns
            value_col: Column whose values are summed per series (falls back
                to ``config.value_col`` when falsy)

        Returns:
            Dict mapping unique_id to HEAD or TAIL (middle series are omitted)

        Raises:
            ValueError: If ``value_col`` is not a column of ``df``
        """
        value_col = value_col or self.config.value_col

        if value_col not in df.columns:
            raise ValueError(f"Value column '{value_col}' not found in data")

        # groupby() returns series sorted by unique_id. A *stable* sort on
        # volume keeps tied series in that order, instead of quicksort's
        # arbitrary order, so tied series are bucketed deterministically.
        volume = df.groupby("unique_id")[value_col].sum().sort_values(kind="mergesort")

        n_series = len(volume)
        if n_series == 0:
            return {}
        denominator = n_series - 1

        buckets: dict[str, SeriesBucket] = {}
        # Iterate the index directly. DataFrame.iterrows() upcasts each row to
        # a common dtype, which turned integer ids into float keys.
        for rank, uid in enumerate(volume.index):
            quantile = rank / denominator if denominator else 0.0
            if quantile >= self.config.head_quantile_threshold:
                buckets[uid] = SeriesBucket.HEAD
            elif quantile <= self.config.tail_quantile_threshold:
                buckets[uid] = SeriesBucket.TAIL

        return buckets

    def bucket_by_history_length(
        self,
        df: pd.DataFrame,
    ) -> dict[str, SeriesBucket]:
        """Bucket series into SHORT_HISTORY/LONG_HISTORY by observation count.

        Rows are counted per series with ``size()``, which includes rows whose
        values are NaN. Counts ``<= short_history_max_obs`` map to
        SHORT_HISTORY, counts ``>= long_history_min_obs`` map to LONG_HISTORY,
        and anything in between is left unassigned.

        Args:
            df: DataFrame with [unique_id] column

        Returns:
            Dict mapping unique_id to SHORT_HISTORY or LONG_HISTORY bucket
        """
        obs_counts = df.groupby("unique_id").size()
        short_max = self.config.short_history_max_obs
        long_min = self.config.long_history_min_obs

        buckets: dict[str, SeriesBucket] = {}
        for uid, n_obs in obs_counts.items():
            if n_obs <= short_max:
                buckets[uid] = SeriesBucket.SHORT_HISTORY
            elif n_obs >= long_min:
                buckets[uid] = SeriesBucket.LONG_HISTORY

        return buckets

    def create_bucket_profile(
        self,
        dataset: TSDataset,
        sparsity_profile: SparsityProfile | None = None,
    ) -> BucketProfile:
        """Create a bucket profile combining all bucketing strategies.

        A series can belong to multiple buckets (e.g., HEAD + LONG_HISTORY).

        Args:
            dataset: A DataFrame, or an object exposing one as ``.df`` or
                ``.data`` (e.g. a TSDataset)
            sparsity_profile: Optional sparsity profile. It is applied only
                when ``config.prefer_sparsity`` is true.

        Returns:
            BucketProfile with per-series assignments and per-bucket statistics

        Raises:
            ValueError: If the data has no ``unique_id`` column or lacks the
                configured value column
        """
        # Check for a raw DataFrame first: a DataFrame with a column named
        # "df" or "data" would otherwise be mistaken for a dataset wrapper.
        if isinstance(dataset, pd.DataFrame):
            df = dataset
        elif hasattr(dataset, "df"):
            df = dataset.df
        elif hasattr(dataset, "data"):
            df = dataset.data
        else:
            df = dataset

        if "unique_id" not in df.columns:
            raise ValueError("Dataset must have 'unique_id' column")

        volume_buckets = self.bucket_by_volume(df)
        history_buckets = self.bucket_by_history_length(df)

        # Combine both dimensions. Series in the "middle" of both ranges are
        # still recorded, with an empty bucket set.
        bucket_assignments: dict[str, set[SeriesBucket]] = {}
        for uid in df["unique_id"].unique():
            assigned: set[SeriesBucket] = set()
            if uid in volume_buckets:
                assigned.add(volume_buckets[uid])
            if uid in history_buckets:
                assigned.add(history_buckets[uid])
            bucket_assignments[uid] = assigned

        if sparsity_profile is not None and self.config.prefer_sparsity:
            bucket_assignments = self._apply_sparsity_overrides(
                bucket_assignments, sparsity_profile
            )

        # Statistics are computed after overrides so they match the final sets.
        bucket_stats = self._compute_bucket_stats(df, bucket_assignments)

        return BucketProfile(
            bucket_assignments=bucket_assignments,
            bucket_stats=bucket_stats,
        )

    def _apply_sparsity_overrides(
        self,
        bucket_assignments: dict[str, set[SeriesBucket]],
        sparsity_profile: SparsityProfile,
    ) -> dict[str, set[SeriesBucket]]:
        """Apply sparsity-based overrides to bucket assignments.

        Intermittent and sparse series are treated as TAIL (HEAD is removed).
        Cold-start series are treated as SHORT_HISTORY (LONG_HISTORY is
        removed). The input mapping and its sets are not mutated; new sets are
        built for overridden series.

        Args:
            bucket_assignments: Current bucket assignments
            sparsity_profile: Sparsity profile from series module. Assumes
                ``get_classification(uid)`` returns an enum-like object with a
                string ``.value``; TODO confirm for unknown ids.

        Returns:
            Updated bucket assignments
        """
        updated = dict(bucket_assignments)

        for uid in updated:
            classification = sparsity_profile.get_classification(uid)

            if classification.value in ("intermittent", "sparse"):
                updated[uid] = (updated[uid] | {SeriesBucket.TAIL}) - {SeriesBucket.HEAD}

            if classification.value == "cold_start":
                updated[uid] = (
                    updated[uid] | {SeriesBucket.SHORT_HISTORY}
                ) - {SeriesBucket.LONG_HISTORY}

        return updated

    def _compute_bucket_stats(
        self,
        df: pd.DataFrame,
        bucket_assignments: dict[str, set[SeriesBucket]],
    ) -> dict[SeriesBucket, BucketStatistics]:
        """Compute statistics for each bucket.

        Every SeriesBucket member gets an entry; empty buckets get all-zero
        statistics.

        Args:
            df: Source DataFrame
            bucket_assignments: Bucket assignments per series

        Returns:
            Dict mapping bucket to BucketStatistics
        """
        stats: dict[SeriesBucket, BucketStatistics] = {}
        value_col = self.config.value_col

        # Per-series metrics.
        # NOTE(review): "count" ignores NaN values, while
        # bucket_by_history_length counts rows with size(). The two can
        # disagree when value_col has gaps.
        series_metrics = df.groupby("unique_id").agg({
            value_col: ["sum", "mean", "count"],
        }).reset_index()
        series_metrics.columns = ["unique_id", "total_value", "avg_value", "n_obs"]

        total_volume = series_metrics["total_value"].sum()

        for bucket in SeriesBucket:
            series_in_bucket = [
                uid for uid, buckets in bucket_assignments.items()
                if bucket in buckets
            ]

            if not series_in_bucket:
                stats[bucket] = BucketStatistics(
                    series_count=0,
                    total_observations=0,
                    avg_observations=0.0,
                    avg_value=0.0,
                    value_percentile=0.0,
                )
                continue

            bucket_metrics = series_metrics[series_metrics["unique_id"].isin(series_in_bucket)]

            # Share of the overall volume; 0.0 when the total is not positive.
            bucket_volume = bucket_metrics["total_value"].sum()
            value_percentile = bucket_volume / total_volume if total_volume > 0 else 0.0

            stats[bucket] = BucketStatistics(
                series_count=len(series_in_bucket),
                total_observations=int(bucket_metrics["n_obs"].sum()),
                avg_observations=float(bucket_metrics["n_obs"].mean()),
                avg_value=float(bucket_metrics["avg_value"].mean()),
                value_percentile=float(value_percentile),
            )

        return stats

    def get_model_for_bucket(
        self,
        bucket: SeriesBucket,
        sparsity_class: str | None = None,
    ) -> str:
        """Get the recommended model name for a single bucket.

        Mapping:
            - sparsity_class == "intermittent": "Croston" (regardless of bucket)
            - TAIL or SHORT_HISTORY: "HistoricAverage"
            - HEAD: "SeasonalNaive" (placeholder until a TSFM is wired in)
            - LONG_HISTORY or anything else: "SeasonalNaive"

        Args:
            bucket: Bucket to get recommendation for
            sparsity_class: Optional sparsity classification

        Returns:
            Recommended model name
        """
        if sparsity_class == "intermittent":
            return "Croston"  # ADIDA would also suit intermittent demand
        if bucket in (SeriesBucket.TAIL, SeriesBucket.SHORT_HISTORY):
            return "HistoricAverage"
        # HEAD (placeholder for a TSFM), LONG_HISTORY and any other value.
        return "SeasonalNaive"

    def get_bucket_specific_plan_config(
        self,
        bucket: SeriesBucket,
    ) -> dict:
        """Get bucket-specific configuration for model training.

        Args:
            bucket: Bucket to get config for

        Returns:
            A freshly built dict with bucket-specific configuration (empty for
            unknown buckets)
        """
        configs = {
            SeriesBucket.HEAD: {
                "model_complexity": "high",
                "hyperparameter_tuning": True,
                "ensemble_size": 5,
            },
            SeriesBucket.TAIL: {
                "model_complexity": "low",
                "hyperparameter_tuning": False,
                "ensemble_size": 1,
            },
            SeriesBucket.SHORT_HISTORY: {
                "min_observations": 10,
                "use_global_model": True,
            },
            SeriesBucket.LONG_HISTORY: {
                "min_observations": 100,
                "use_global_model": False,
            },
        }

        return configs.get(bucket, {})
@@ -0,0 +1,132 @@
1
+ """Fallback ladder implementation.
2
+
3
+ Provides automatic model degradation when primary models fail.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from collections.abc import Callable
9
+ from typing import TYPE_CHECKING, TypeVar
10
+
11
+ from tsagentkit.contracts.errors import EFallbackExhausted
12
+ from tsagentkit.router.plan import PlanSpec, get_candidate_models
13
+
14
+ if TYPE_CHECKING:
15
+ from tsagentkit.series import TSDataset
16
+
17
# Generic result type: whatever fit_func returns is passed back unchanged by
# execute_with_fallback.
T = TypeVar("T")
18
+
19
+
20
def execute_with_fallback(
    fit_func: Callable[[str, TSDataset], T],
    dataset: TSDataset,
    plan: PlanSpec,
    on_fallback: Callable[[str, str, Exception], None] | None = None,
) -> tuple[T, str]:
    """Execute fit function with fallback ladder.

    Tries each candidate model of ``plan`` in order (primary first, then
    fallbacks) and returns as soon as one succeeds. Any ``Exception`` raised by
    ``fit_func`` moves on to the next candidate.

    Args:
        fit_func: Callable invoked as ``fit_func(model_name, dataset)``
        dataset: TSDataset to fit on
        plan: Execution plan whose candidate models form the ladder
        on_fallback: Optional callback ``(failed_model, next_model, error)``.
            It is called before moving to the next candidate, but not after
            the last candidate fails. Exceptions it raises propagate.

    Returns:
        Tuple of (result, model_name_that_succeeded)

    Raises:
        EFallbackExhausted: If the plan has no candidates or every candidate
            fails. The last failure is chained as ``__cause__``. The context
            lists the attempted models, the last error, and every per-model
            error in order.
    """
    models = get_candidate_models(plan)
    last_error: Exception | None = None
    # Keep only error messages, not exception objects, so earlier failures'
    # tracebacks (and any data their frames reference) can be freed.
    error_log: list[dict[str, str]] = []

    for i, model_name in enumerate(models):
        try:
            result = fit_func(model_name, dataset)
        except Exception as e:  # any model failure moves down the ladder
            last_error = e
            error_log.append({"model": model_name, "error": str(e)})
            if on_fallback is not None and i < len(models) - 1:
                on_fallback(model_name, models[i + 1], e)
        else:
            return result, model_name

    if last_error is None:
        error_msg = "All models failed. Plan has no candidate models to attempt."
    else:
        error_msg = f"All models failed. Last error: {last_error}"
    raise EFallbackExhausted(
        error_msg,
        context={
            "models_attempted": models,
            "last_error": str(last_error),
            "errors": error_log,
        },
    ) from last_error
68
+
69
+
70
class FallbackLadder:
    """Manages fallback chains for different scenarios.

    Provides predefined fallback ladders for common use cases. The ladder
    constants are class-level lists shared by every caller, so the accessors
    always return new lists that callers may modify freely.
    """

    #: Standard fallback: SeasonalNaive -> HistoricAverage -> Naive
    STANDARD_LADDER: list[str] = ["SeasonalNaive", "HistoricAverage", "Naive"]

    #: For intermittent demand: Croston -> Naive
    INTERMITTENT_LADDER: list[str] = ["Croston", "Naive"]

    #: For cold-start series: HistoricAverage -> Naive
    COLD_START_LADDER: list[str] = ["HistoricAverage", "Naive"]

    @classmethod
    def get_ladder(
        cls,
        is_intermittent: bool = False,
        is_cold_start: bool = False,
    ) -> list[str]:
        """Get appropriate fallback ladder for scenario.

        Intermittency takes precedence over cold start when both are set.

        Args:
            is_intermittent: Whether series is intermittent
            is_cold_start: Whether series is cold-start

        Returns:
            A new ordered list of fallback model names. Mutating it does not
            affect the class-level ladders.
        """
        if is_intermittent:
            ladder = cls.INTERMITTENT_LADDER
        elif is_cold_start:
            ladder = cls.COLD_START_LADDER
        else:
            ladder = cls.STANDARD_LADDER
        return list(ladder)

    @classmethod
    def with_primary(
        cls,
        primary: str,
        fallbacks: list[str] | None = None,
        is_intermittent: bool = False,
        is_cold_start: bool = False,
    ) -> list[str]:
        """Create full model chain with primary and fallbacks.

        Args:
            primary: Primary model name
            fallbacks: Optional explicit fallback list (defaults to the ladder
                chosen by ``get_ladder``)
            is_intermittent: Whether series is intermittent
            is_cold_start: Whether series is cold-start

        Returns:
            New list with primary first, then the fallbacks minus any entry
            equal to primary
        """
        if fallbacks is None:
            fallbacks = cls.get_ladder(is_intermittent, is_cold_start)

        # Drop the primary from the fallbacks so it is not attempted twice.
        return [primary, *(f for f in fallbacks if f != primary)]
@@ -0,0 +1,23 @@
1
+ """PlanSpec helpers for routing and provenance."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import json
7
+
8
+ from tsagentkit.contracts import PlanSpec
9
+
10
+
11
def compute_plan_signature(plan: PlanSpec) -> str:
    """Return a short, deterministic fingerprint of ``plan``.

    The plan's ``model_dump()`` output is serialized to canonical JSON (sorted
    keys, compact separators) and hashed with SHA-256. The first 16 hex
    characters of the digest are returned.
    """
    canonical = json.dumps(plan.model_dump(), sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(canonical.encode()).hexdigest()
    return digest[:16]
16
+
17
+
18
def get_candidate_models(plan: PlanSpec) -> list[str]:
    """Return the plan's candidate models, in order, as a new list.

    The result is a copy, so callers may mutate it without touching ``plan``.
    """
    return [*plan.candidate_models]
21
+
22
+
23
# Public API of this module; PlanSpec is re-exported from tsagentkit.contracts.
__all__ = ["PlanSpec", "compute_plan_signature", "get_candidate_models"]