coderouter-cli 2.0.0__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -48,6 +48,13 @@ Event inventory (dispatch table in :meth:`MetricsCollector._dispatch`)
48
48
  + per-profile counter + latest_ratio gauge
49
49
  ``context-budget-trimmed`` (v2.0-F)→ ``context_budget_trims_total``
50
50
  + per-profile counter
51
+ ``drift-detected`` (v2.0-G) → ``drift_detected_total`` + per-provider
52
+ ``drift-promoted`` (v2.0-G) → ``drift_promoted_total``
53
+ ``drift-reload-attempted`` → ``drift_reload_total`` / success
54
+ ``partial-stitch-surfaced`` → ``partial_stitch_surfaced_total`` (v2.0-H)
55
+ ``probe-completed`` (v2.0-I) → ``probe_total`` / ``probe_success`` / ``probe_failure``
56
+ + per-provider latency gauge
57
+ ``probe-round-completed`` → ``probe_rounds_total`` (v2.0-I)
51
58
  ``coderouter-startup`` → ``startup_info`` (stored for the UI header)
52
59
 
53
60
  Unrecognized events are ignored (forward-compat: adding a new log
@@ -194,6 +201,29 @@ class MetricsCollector(logging.Handler):
194
201
  self._context_budget_trims_by_profile: Counter[str] = Counter()
195
202
  self._context_budget_latest_ratio: dict[str, float] = {}
196
203
 
204
+ # v2.0-G (L4): drift detection counters. Per-provider counts of
205
+ # drift events, promotions (rank demotions), and reload attempts.
206
+ self._drift_detected_total: int = 0
207
+ self._drift_detected_by_provider: Counter[str] = Counter()
208
+ self._drift_promoted_total: int = 0
209
+ self._drift_reload_total: int = 0
210
+ self._drift_reload_success_total: int = 0
211
+
212
+ # v2.0-H (L6): partial stitch surfaced counter. Tracks how often
213
+ # the mid-stream failure recovery delivered partial content to the
214
+ # client instead of a generic error event.
215
+ self._partial_stitch_surfaced_total: int = 0
216
+
217
+ # v2.0-I: continuous probe counters. Per-provider probe attempts
218
+ # and outcomes, plus a round counter for the dashboard's
219
+ # "probes/min" panel.
220
+ self._probe_total: Counter[str] = Counter() # per-provider total probes
221
+ self._probe_success: Counter[str] = Counter() # per-provider successes
222
+ self._probe_failure: Counter[str] = Counter() # per-provider failures
223
+ self._probe_rounds_total: int = 0
224
+ self._probe_latency_ms: dict[str, float] = {} # per-provider latest
225
+ self._probe_drift_detected: Counter[str] = Counter() # per-provider drift
226
+
197
227
  # Last-error snapshot per provider (overwrites previous). Enables the
198
228
  # dashboard's "last error" column without scanning the ring.
199
229
  self._last_error: dict[str, dict[str, Any]] = {}
@@ -372,6 +402,43 @@ class MetricsCollector(logging.Handler):
372
402
  profile = _str(extras.get("profile"))
373
403
  self._context_budget_trims_total += 1
374
404
  self._context_budget_trims_by_profile[profile] += 1
405
+ elif event == "drift-detected":
406
+ # v2.0-G (L4): drift detection fired.
407
+ provider = _str(extras.get("provider"))
408
+ self._drift_detected_total += 1
409
+ self._drift_detected_by_provider[provider] += 1
410
+ self._push_recent(event, extras, record)
411
+ elif event == "drift-promoted":
412
+ # v2.0-G (L4): drifted provider was demoted.
413
+ self._drift_promoted_total += 1
414
+ self._push_recent(event, extras, record)
415
+ elif event == "drift-reload-attempted":
416
+ # v2.0-G (L4): Ollama KV cache flush attempted.
417
+ self._drift_reload_total += 1
418
+ if extras.get("success"):
419
+ self._drift_reload_success_total += 1
420
+ elif event == "partial-stitch-surfaced":
421
+ # v2.0-H (L6): mid-stream failure gracefully surfaced.
422
+ self._partial_stitch_surfaced_total += 1
423
+ self._push_recent(event, extras, record)
424
+ elif event == "probe-completed":
425
+ # v2.0-I: per-provider probe outcome.
426
+ provider = _str(extras.get("provider"))
427
+ self._probe_total[provider] += 1
428
+ if extras.get("success"):
429
+ self._probe_success[provider] += 1
430
+ else:
431
+ self._probe_failure[provider] += 1
432
+ latency_raw = extras.get("latency_ms")
433
+ if isinstance(latency_raw, int | float):
434
+ self._probe_latency_ms[provider] = float(latency_raw)
435
+ elif event == "probe-round-completed":
436
+ # v2.0-I: round counter for the dashboard.
437
+ self._probe_rounds_total += 1
438
+ elif event == "probe-capabilities-drift":
439
+ # v2.0-I: model mismatch detected by probe.
440
+ provider = _str(extras.get("provider"))
441
+ self._probe_drift_detected[provider] += 1
375
442
  elif event == "coderouter-startup":
376
443
  # Snapshot a subset — startup payload contains lists that are
377
444
  # safe to surface to /metrics.json. Version / providers /
@@ -534,11 +601,103 @@ class MetricsCollector(logging.Handler):
534
601
  "context_budget_latest_ratio": dict(
535
602
  self._context_budget_latest_ratio
536
603
  ),
604
+ # v2.0-G (L4): drift detection aggregate counters.
605
+ "drift_detected_total": self._drift_detected_total,
606
+ "drift_detected_by_provider": dict(
607
+ self._drift_detected_by_provider
608
+ ),
609
+ "drift_promoted_total": self._drift_promoted_total,
610
+ "drift_reload_total": self._drift_reload_total,
611
+ "drift_reload_success_total": self._drift_reload_success_total,
612
+ # v2.0-H (L6): partial stitch surfaced.
613
+ "partial_stitch_surfaced_total": self._partial_stitch_surfaced_total,
614
+ # v2.0-I: continuous probe counters.
615
+ "probe_total": dict(self._probe_total),
616
+ "probe_success": dict(self._probe_success),
617
+ "probe_failure": dict(self._probe_failure),
618
+ "probe_rounds_total": self._probe_rounds_total,
619
+ "probe_latency_ms": dict(self._probe_latency_ms),
620
+ "probe_drift_detected": dict(self._probe_drift_detected),
537
621
  },
538
622
  "providers": provider_rows,
539
623
  "recent": list(self._recent),
540
624
  }
541
625
 
626
+ # ------------------------------------------------------------------
627
+ # v2.0-K: Persistence
628
+ # ------------------------------------------------------------------
629
+
630
def save_state(self) -> dict[str, object]:
    """Snapshot the persistable counters as a JSON-safe dict.

    Only the counters named below are exported. The ``recent`` ring and
    the per-provider ``last_error`` map are left out because they are
    ephemeral. Every mapping-valued counter is copied into a plain
    ``dict`` so the caller never holds a reference to live collector
    state.

    Returns:
        A new dict holding the request total, per-provider attempt and
        outcome tallies, the cost maps and their aggregates, the four
        chain-gate counters, and the probe round counter.
    """
    # Scalar counters stored on the instance as ``_<key>``, listed in
    # the order they appear in the exported dict.
    trailing_scalars = (
        "cost_total_usd_aggregate",
        "cost_savings_usd_aggregate",
        "chain_paid_gate_blocked_total",
        "chain_budget_exceeded_total",
        "chain_memory_pressure_blocked_total",
        "chain_uniform_auth_failure_total",
        "probe_rounds_total",
    )
    with self._lock:
        exported: dict[str, object] = {"requests_total": self._requests_total}
        exported["provider_attempts"] = dict(self._provider_attempts)
        # Nested tallies: copy each inner counter into a plain dict too.
        exported["provider_outcomes"] = {
            prov: dict(tally) for prov, tally in self._provider_outcomes.items()
        }
        exported["cost_total_usd"] = dict(self._cost_total_usd)
        exported["cost_savings_usd"] = dict(self._cost_savings_usd)
        for key in trailing_scalars:
            exported[key] = getattr(self, "_" + key)
        return exported
654
+
655
+ def load_state(self, state: dict[str, object]) -> None:
656
+ """Restore counters from a previously saved dict.
657
+
658
+ Additive: values from ``state`` are *added* to the current
659
+ (zeroed) counters, so calling ``load_state`` on a fresh
660
+ collector restores the prior session's totals.
661
+ """
662
+ if not isinstance(state, dict):
663
+ return
664
+ with self._lock:
665
+ self._requests_total += int(state.get("requests_total", 0))
666
+ for k, v in (state.get("provider_attempts") or {}).items():
667
+ self._provider_attempts[k] += int(v)
668
+ for prov, outcomes in (state.get("provider_outcomes") or {}).items():
669
+ if not isinstance(outcomes, dict):
670
+ continue
671
+ if prov not in self._provider_outcomes:
672
+ self._provider_outcomes[prov] = Counter()
673
+ for k, v in outcomes.items():
674
+ self._provider_outcomes[prov][k] += int(v)
675
+ for k, v in (state.get("cost_total_usd") or {}).items():
676
+ self._cost_total_usd[k] = self._cost_total_usd.get(k, 0.0) + float(v)
677
+ for k, v in (state.get("cost_savings_usd") or {}).items():
678
+ self._cost_savings_usd[k] = self._cost_savings_usd.get(k, 0.0) + float(v)
679
+ self._cost_total_usd_aggregate += float(
680
+ state.get("cost_total_usd_aggregate", 0.0)
681
+ )
682
+ self._cost_savings_usd_aggregate += float(
683
+ state.get("cost_savings_usd_aggregate", 0.0)
684
+ )
685
+ self._chain_paid_gate_blocked_total += int(
686
+ state.get("chain_paid_gate_blocked_total", 0)
687
+ )
688
+ self._chain_budget_exceeded_total += int(
689
+ state.get("chain_budget_exceeded_total", 0)
690
+ )
691
+ self._chain_memory_pressure_blocked_total += int(
692
+ state.get("chain_memory_pressure_blocked_total", 0)
693
+ )
694
+ self._chain_uniform_auth_failure_total += int(
695
+ state.get("chain_uniform_auth_failure_total", 0)
696
+ )
697
+ self._probe_rounds_total += int(
698
+ state.get("probe_rounds_total", 0)
699
+ )
700
+
542
701
  # ------------------------------------------------------------------
543
702
  # Test hook
544
703
  # ------------------------------------------------------------------
@@ -578,6 +737,15 @@ class MetricsCollector(logging.Handler):
578
737
  self._cost_savings_usd.clear()
579
738
  self._cost_total_usd_aggregate = 0.0
580
739
  self._cost_savings_usd_aggregate = 0.0
740
+ # v2.0-H (L6)
741
+ self._partial_stitch_surfaced_total = 0
742
+ # v2.0-I
743
+ self._probe_total.clear()
744
+ self._probe_success.clear()
745
+ self._probe_failure.clear()
746
+ self._probe_rounds_total = 0
747
+ self._probe_latency_ms.clear()
748
+ self._probe_drift_detected.clear()
581
749
  # v2.0-F (L1)
582
750
  self._context_budget_warnings_total = 0
583
751
  self._context_budget_trims_total = 0
@@ -381,6 +381,147 @@ def format_prometheus(snapshot: dict[str, Any]) -> str:
381
381
  samples=ratio_samples,
382
382
  )
383
383
  )
384
+
385
+ # ---- v2.0-G (L4): drift detection metrics ------------------------------
386
+ lines.extend(
387
+ _counter(
388
+ name="drift_detected_total",
389
+ help_text=(
390
+ "Drift detection events (quality degradation detected), "
391
+ "by provider."
392
+ ),
393
+ samples=[
394
+ ((("provider", p),), v)
395
+ for p, v in sorted(
396
+ counters.get("drift_detected_by_provider", {}).items()
397
+ )
398
+ ],
399
+ )
400
+ )
401
+ drift_promoted = counters.get("drift_promoted_total", 0)
402
+ if drift_promoted:
403
+ lines.extend(
404
+ _counter(
405
+ name="drift_promoted_total",
406
+ help_text=(
407
+ "Number of times a drifted provider was demoted "
408
+ "(promote/reload action fired)."
409
+ ),
410
+ samples=[(((),), drift_promoted)],
411
+ )
412
+ )
413
+ drift_reload = counters.get("drift_reload_total", 0)
414
+ if drift_reload:
415
+ lines.extend(
416
+ _counter(
417
+ name="drift_reload_total",
418
+ help_text="Ollama KV cache flush attempts (reload action).",
419
+ samples=[(((),), drift_reload)],
420
+ )
421
+ )
422
+ drift_reload_ok = counters.get("drift_reload_success_total", 0)
423
+ if drift_reload_ok:
424
+ lines.extend(
425
+ _counter(
426
+ name="drift_reload_success_total",
427
+ help_text="Successful Ollama KV cache flush attempts.",
428
+ samples=[(((),), drift_reload_ok)],
429
+ )
430
+ )
431
+
432
+ # ---- v2.0-H (L6): partial stitch surfaced metric -----------------------
433
+ partial_stitch = counters.get("partial_stitch_surfaced_total", 0)
434
+ if partial_stitch:
435
+ lines.extend(
436
+ _counter(
437
+ name="partial_stitch_surfaced_total",
438
+ help_text=(
439
+ "Mid-stream failures where partial content was delivered "
440
+ "to the client (partial_stitch_action=surface)."
441
+ ),
442
+ samples=[((), partial_stitch)],
443
+ )
444
+ )
445
+
446
+ # ---- v2.0-I: continuous probe metrics ------------------------------------
447
+ probe_total_samples: list[tuple[tuple[tuple[str, str], ...], int]] = []
448
+ for provider, count in sorted(counters.get("probe_total", {}).items()):
449
+ probe_total_samples.append(
450
+ ((("provider", provider),), count)
451
+ )
452
+ if probe_total_samples:
453
+ lines.extend(
454
+ _counter(
455
+ name="probe_total",
456
+ help_text=(
457
+ "Continuous health probe attempts, by provider. "
458
+ "Each probe sends a 1-token completion to verify "
459
+ "the full model pipeline."
460
+ ),
461
+ samples=probe_total_samples,
462
+ )
463
+ )
464
+ probe_outcome_samples: list[tuple[tuple[tuple[str, str], ...], int]] = []
465
+ for provider, count in sorted(counters.get("probe_success", {}).items()):
466
+ probe_outcome_samples.append(
467
+ ((("provider", provider), ("outcome", "success")), count)
468
+ )
469
+ for provider, count in sorted(counters.get("probe_failure", {}).items()):
470
+ probe_outcome_samples.append(
471
+ ((("provider", provider), ("outcome", "failure")), count)
472
+ )
473
+ if probe_outcome_samples:
474
+ lines.extend(
475
+ _counter(
476
+ name="probe_outcomes_total",
477
+ help_text=(
478
+ "Continuous probe outcomes by provider and result "
479
+ "(success | failure)."
480
+ ),
481
+ samples=probe_outcome_samples,
482
+ )
483
+ )
484
+ probe_rounds = counters.get("probe_rounds_total", 0)
485
+ if probe_rounds:
486
+ lines.extend(
487
+ _counter(
488
+ name="probe_rounds_total",
489
+ help_text="Completed probe sweep rounds (one round probes all eligible providers).",
490
+ samples=[((), probe_rounds)],
491
+ )
492
+ )
493
+ probe_drift_samples: list[tuple[tuple[tuple[str, str], ...], int]] = [
494
+ ((("provider", p),), v)
495
+ for p, v in sorted(counters.get("probe_drift_detected", {}).items())
496
+ ]
497
+ if probe_drift_samples:
498
+ lines.extend(
499
+ _counter(
500
+ name="probe_drift_detected_total",
501
+ help_text=(
502
+ "Model-name mismatches detected by continuous probing "
503
+ "(configured model != response model), by provider."
504
+ ),
505
+ samples=probe_drift_samples,
506
+ )
507
+ )
508
+ # Gauge: latest probe latency per provider (ms).
509
+ latency_samples: list[tuple[tuple[tuple[str, str], ...], float]] = [
510
+ ((("provider", p),), round(v, 1))
511
+ for p, v in sorted(counters.get("probe_latency_ms", {}).items())
512
+ ]
513
+ if latency_samples:
514
+ lines.extend(
515
+ _gauge_float(
516
+ name="probe_latency_ms",
517
+ help_text=(
518
+ "Latest probe round-trip latency in milliseconds, by "
519
+ "provider. Gauge (most recent value, not cumulative)."
520
+ ),
521
+ samples=latency_samples,
522
+ )
523
+ )
524
+
384
525
  return "\n".join(lines) + "\n"
385
526
 
386
527
 
@@ -52,6 +52,7 @@ __all__ = [
52
52
  "OutputFilterChain",
53
53
  "StripStopMarkersFilter",
54
54
  "StripThinkingFilter",
55
+ "StripToolCallXmlFilter",
55
56
  "apply_output_filters",
56
57
  "validate_output_filters",
57
58
  ]
@@ -63,20 +64,28 @@ __all__ = [
63
64
 
64
65
 
65
66
  DEFAULT_STOP_MARKERS: tuple[str, ...] = (
67
+ # v1.0-A originals
66
68
  "<|turn|>",
67
69
  "<|end|>",
68
70
  "<|python_tag|>",
69
71
  "<|im_end|>",
70
72
  "<|eot_id|>",
71
73
  "<|channel>thought",
74
+ # v2.2: tool-call XML tags leaked by Qwen / Hermes / Llama tool-call
75
+ # formats. These appear when the model writes tool calls as XML
76
+ # instead of structured JSON, or when the tokenizer's special-token
77
+ # handling leaks through.
78
+ "<|tool▁call|>",
79
+ "<|tool▁sep|>",
72
80
  )
73
81
  """Default stop/harness markers stripped by ``strip_stop_markers``.
74
82
 
75
83
  Covers Llama 3.x (``<|python_tag|>``, ``<|eot_id|>``), ChatML / Qwen
76
- (``<|im_end|>``, ``<|end|>``), Gemma-ish (``<|turn|>``) and OpenAI-
77
- harmony (``<|channel>thought``). Extending this tuple is an ABI change
78
- users who need a bespoke set can add a dedicated filter entry in
79
- a later minor; for v1.0-A the fixed list covers observed leaks.
84
+ (``<|im_end|>``, ``<|end|>``), Gemma-ish (``<|turn|>``), OpenAI-
85
+ harmony (``<|channel>thought``), and Qwen / Hermes tool-call markers
86
+ (``<|tool▁call|>``, ``<|tool▁sep|>``). Extending this tuple is an ABI
87
+ change; users who need a bespoke set can add a dedicated filter entry
88
+ in a later minor.
80
89
  """
81
90
 
82
91
 
@@ -292,6 +301,87 @@ class StripStopMarkersFilter:
292
301
  return "".join(out_parts)
293
302
 
294
303
 
304
+ # ---------------------------------------------------------------------------
305
+ # strip_tool_call_xml (v2.2)
306
+ # ---------------------------------------------------------------------------
307
+
308
+
309
_TOOL_CALL_OPEN = "<tool_call>"
_TOOL_CALL_CLOSE = "</tool_call>"


class StripToolCallXmlFilter:
    """Streaming filter that deletes ``<tool_call>...</tool_call>`` spans.

    Text outside a block passes through unchanged; the opening tag, the
    closing tag and everything between them is discarded. Input may be
    split across chunks at arbitrary positions, including inside a tag:
    a trailing fragment that could still grow into a tag is held back
    until a later chunk (or ``eof``) settles it.

    Per the original author's notes, Qwen / Hermes / Llama tool-call
    formats can leak such XML into the content stream, and this filter
    is meant to run after ``tool_repair`` has extracted the structured
    JSON, which the adapter-boundary filter chain is said to ensure.

    Attributes are documented inline in ``__init__``.
    """

    name = "strip_tool_call_xml"

    def __init__(self) -> None:
        """Start outside any block, with nothing buffered or stripped."""
        # Text received but not yet emitted or discarded.
        self._buffer: str = ""
        # True while between an opening tag and its closing tag.
        self._in_block: bool = False
        # Becomes True once an opening tag has been consumed.
        self.modified: bool = False

    def feed(self, text: str, *, eof: bool = False) -> str:
        """Accept the next chunk and return the text that is safe to emit.

        Args:
            text: The next piece of the stream (may be empty).
            eof: Set on the final call. Held-back text is then flushed
                if the filter is outside a block; an unterminated block
                is dropped.

        Returns:
            The pass-through text released by this call.
        """
        pending = self._buffer + text
        emitted: list[str] = []

        scanning = True
        while scanning:
            if self._in_block:
                # Inside a block: discard until the closing tag shows up.
                close_at = pending.find(_TOOL_CALL_CLOSE)
                if close_at == -1:
                    # Keep only a tail that might be the start of the
                    # closing tag; the rest of the block body is dropped.
                    keep = _max_suffix_overlap(pending, _TOOL_CALL_CLOSE)
                    pending = pending[-keep:] if keep else ""
                    scanning = False
                else:
                    pending = pending[close_at + len(_TOOL_CALL_CLOSE) :]
                    self._in_block = False
            else:
                open_at = pending.find(_TOOL_CALL_OPEN)
                if open_at == -1:
                    # No opening tag: release everything except a tail
                    # that might be the start of one.
                    keep = _max_suffix_overlap(pending, _TOOL_CALL_OPEN)
                    safe_len = len(pending) - keep
                    emitted.append(pending[:safe_len])
                    pending = pending[safe_len:]
                    scanning = False
                else:
                    emitted.append(pending[:open_at])
                    pending = pending[open_at + len(_TOOL_CALL_OPEN) :]
                    self._in_block = True
                    self.modified = True

        if eof:
            # Outside a block the held-back tail was ordinary text after
            # all; inside a block the partial block is silently dropped.
            if not self._in_block:
                emitted.append(pending)
            pending = ""

        self._buffer = pending
        return "".join(emitted)
383
+
384
+
295
385
  # ---------------------------------------------------------------------------
296
386
  # Registry + chain
297
387
  # ---------------------------------------------------------------------------
@@ -300,6 +390,7 @@ class StripStopMarkersFilter:
300
390
  KNOWN_FILTERS: dict[str, type[OutputFilter]] = {
301
391
  StripThinkingFilter.name: StripThinkingFilter,
302
392
  StripStopMarkersFilter.name: StripStopMarkersFilter,
393
+ StripToolCallXmlFilter.name: StripToolCallXmlFilter,
303
394
  }
304
395
  """Registry of string-name → filter class.
305
396
 
@@ -234,6 +234,29 @@ class AdaptiveAdjuster:
234
234
  while entry.observations and entry.observations[0].ts_monotonic < cutoff:
235
235
  entry.observations.popleft()
236
236
 
237
def demote(self, provider: str, *, steps: int = 2) -> None:
    """Record ``steps`` synthetic failures against ``provider``.

    Each synthetic observation is stamped with the current monotonic
    time, marked ``success=False`` and carries no latency
    (``latency_ms=None``). A state entry is created for the provider if
    none exists yet.

    Per the original author's note, v2.0-G drift detection uses this to
    push the provider's error rate over the demotion threshold so it
    ranks lower on the next ``compute_effective_order``. The threshold
    itself is not visible in this block, so how many steps are needed
    depends on the real observations already recorded. The synthetic
    entries are ordinary observations and are presumably pruned by the
    same rolling-window expiry as real ones.

    Args:
        provider: Name of the provider to penalize.
        steps: Number of failure observations to inject.
    """
    now = time.monotonic()
    with self._lock:
        state = self._state.setdefault(provider, _AdjusterState())
        state.observations.extend(
            _ProviderObservation(
                ts_monotonic=now,
                latency_ms=None,
                success=False,
            )
            for _ in range(steps)
        )
259
+
237
260
  # ------------------------------------------------------------------
238
261
  # Stats
239
262
  # ------------------------------------------------------------------
@@ -187,5 +187,40 @@ class BudgetTracker:
187
187
  self._totals.clear()
188
188
  self._month = current
189
189
 
190
+ # ------------------------------------------------------------------
191
+ # v2.0-K: Persistence
192
+ # ------------------------------------------------------------------
193
+
194
def save_state(self) -> dict[str, object]:
    """Return the tracker's month key and totals as a JSON-safe dict.

    The totals mapping is copied while the lock is held, so the returned
    dict is detached from the tracker's live state. Per the original
    docstring, the engine calls this to persist budget totals across
    restarts.

    Returns:
        ``{"month": <month key>, "totals": <copy of the totals map>}``.
    """
    with self._lock:
        month_key = self._month
        totals_copy = dict(self._totals)
    return {"month": month_key, "totals": totals_copy}
204
+
205
def load_state(self, state: dict[str, object]) -> None:
    """Restore totals from a dict produced by ``save_state``.

    The totals are only restored when the saved month equals the current
    month key; last month's spend is of no use in a new month. A
    non-dict ``state`` or a stale month leaves the tracker untouched.

    Saved values are validated before use, because the dict normally
    comes from persisted JSON. Only real numbers are accepted: ``bool``
    values, NaN, infinities and integers too large for a float are
    skipped. A NaN total is the dangerous case, since every later
    numeric comparison against it evaluates to false.

    Args:
        state: Dict with a ``"month"`` key and a ``"totals"`` mapping of
            name to amount.
    """
    import math  # function-scope import keeps this block self-contained

    if not isinstance(state, dict):
        return
    saved_month = state.get("month", "")

    # Validate outside the lock; ``None`` means "no usable totals field".
    restored: dict[str, float] | None = None
    totals = state.get("totals", {})
    if isinstance(totals, dict):
        restored = {}
        for name, raw in totals.items():
            # bool is an int subclass; True must not be read as 1.0 of spend.
            if isinstance(raw, bool) or not isinstance(raw, (int, float)):
                continue
            try:
                amount = float(raw)
            except OverflowError:
                continue  # integer too large to represent as a float
            if math.isfinite(amount):
                restored[name] = amount

    with self._lock:
        current = _utc_month_key()
        if saved_month != current:
            return  # stale month — skip
        if restored is not None:
            # Refill the existing container instead of rebinding it, so
            # whatever mapping type the tracker uses is preserved.
            self._totals.clear()
            self._totals.update(restored)
        self._month = current
224
+
190
225
 
191
226
  __all__ = ["BudgetTracker"]