stravinsky 0.2.67__py3-none-any.whl → 0.4.66__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of stravinsky might be problematic. Click here for more details.
- mcp_bridge/__init__.py +1 -1
- mcp_bridge/auth/__init__.py +16 -6
- mcp_bridge/auth/cli.py +202 -11
- mcp_bridge/auth/oauth.py +1 -2
- mcp_bridge/auth/openai_oauth.py +4 -7
- mcp_bridge/auth/token_store.py +112 -11
- mcp_bridge/cli/__init__.py +1 -1
- mcp_bridge/cli/install_hooks.py +503 -107
- mcp_bridge/cli/session_report.py +0 -3
- mcp_bridge/config/MANIFEST_SCHEMA.md +305 -0
- mcp_bridge/config/README.md +276 -0
- mcp_bridge/config/__init__.py +2 -2
- mcp_bridge/config/hook_config.py +247 -0
- mcp_bridge/config/hooks_manifest.json +138 -0
- mcp_bridge/config/rate_limits.py +317 -0
- mcp_bridge/config/skills_manifest.json +128 -0
- mcp_bridge/hooks/HOOKS_SETTINGS.json +17 -4
- mcp_bridge/hooks/__init__.py +19 -4
- mcp_bridge/hooks/agent_reminder.py +4 -4
- mcp_bridge/hooks/auto_slash_command.py +5 -5
- mcp_bridge/hooks/budget_optimizer.py +2 -2
- mcp_bridge/hooks/claude_limits_hook.py +114 -0
- mcp_bridge/hooks/comment_checker.py +3 -4
- mcp_bridge/hooks/compaction.py +2 -2
- mcp_bridge/hooks/context.py +2 -1
- mcp_bridge/hooks/context_monitor.py +2 -2
- mcp_bridge/hooks/delegation_policy.py +85 -0
- mcp_bridge/hooks/directory_context.py +3 -3
- mcp_bridge/hooks/edit_recovery.py +3 -2
- mcp_bridge/hooks/edit_recovery_policy.py +49 -0
- mcp_bridge/hooks/empty_message_sanitizer.py +2 -2
- mcp_bridge/hooks/events.py +160 -0
- mcp_bridge/hooks/git_noninteractive.py +4 -4
- mcp_bridge/hooks/keyword_detector.py +8 -10
- mcp_bridge/hooks/manager.py +43 -22
- mcp_bridge/hooks/notification_hook.py +13 -6
- mcp_bridge/hooks/parallel_enforcement_policy.py +67 -0
- mcp_bridge/hooks/parallel_enforcer.py +5 -5
- mcp_bridge/hooks/parallel_execution.py +22 -10
- mcp_bridge/hooks/post_tool/parallel_validation.py +103 -0
- mcp_bridge/hooks/pre_compact.py +8 -9
- mcp_bridge/hooks/pre_tool/agent_spawn_validator.py +115 -0
- mcp_bridge/hooks/preemptive_compaction.py +2 -3
- mcp_bridge/hooks/routing_notifications.py +80 -0
- mcp_bridge/hooks/rules_injector.py +11 -19
- mcp_bridge/hooks/session_idle.py +4 -4
- mcp_bridge/hooks/session_notifier.py +4 -4
- mcp_bridge/hooks/session_recovery.py +4 -5
- mcp_bridge/hooks/stravinsky_mode.py +1 -1
- mcp_bridge/hooks/subagent_stop.py +1 -3
- mcp_bridge/hooks/task_validator.py +2 -2
- mcp_bridge/hooks/tmux_manager.py +7 -8
- mcp_bridge/hooks/todo_delegation.py +4 -1
- mcp_bridge/hooks/todo_enforcer.py +180 -10
- mcp_bridge/hooks/tool_messaging.py +113 -10
- mcp_bridge/hooks/truncation_policy.py +37 -0
- mcp_bridge/hooks/truncator.py +1 -2
- mcp_bridge/metrics/cost_tracker.py +115 -0
- mcp_bridge/native_search.py +93 -0
- mcp_bridge/native_watcher.py +118 -0
- mcp_bridge/notifications.py +150 -0
- mcp_bridge/orchestrator/enums.py +11 -0
- mcp_bridge/orchestrator/router.py +165 -0
- mcp_bridge/orchestrator/state.py +32 -0
- mcp_bridge/orchestrator/visualization.py +14 -0
- mcp_bridge/orchestrator/wisdom.py +34 -0
- mcp_bridge/prompts/__init__.py +1 -8
- mcp_bridge/prompts/dewey.py +1 -1
- mcp_bridge/prompts/planner.py +2 -4
- mcp_bridge/prompts/stravinsky.py +53 -31
- mcp_bridge/proxy/__init__.py +0 -0
- mcp_bridge/proxy/client.py +70 -0
- mcp_bridge/proxy/model_server.py +157 -0
- mcp_bridge/routing/__init__.py +43 -0
- mcp_bridge/routing/config.py +250 -0
- mcp_bridge/routing/model_tiers.py +135 -0
- mcp_bridge/routing/provider_state.py +261 -0
- mcp_bridge/routing/task_classifier.py +190 -0
- mcp_bridge/server.py +542 -59
- mcp_bridge/server_tools.py +738 -6
- mcp_bridge/tools/__init__.py +40 -25
- mcp_bridge/tools/agent_manager.py +616 -697
- mcp_bridge/tools/background_tasks.py +13 -17
- mcp_bridge/tools/code_search.py +70 -53
- mcp_bridge/tools/continuous_loop.py +0 -1
- mcp_bridge/tools/dashboard.py +19 -0
- mcp_bridge/tools/find_code.py +296 -0
- mcp_bridge/tools/init.py +1 -0
- mcp_bridge/tools/list_directory.py +42 -0
- mcp_bridge/tools/lsp/__init__.py +12 -5
- mcp_bridge/tools/lsp/manager.py +471 -0
- mcp_bridge/tools/lsp/tools.py +723 -207
- mcp_bridge/tools/model_invoke.py +1195 -273
- mcp_bridge/tools/mux_client.py +75 -0
- mcp_bridge/tools/project_context.py +1 -2
- mcp_bridge/tools/query_classifier.py +406 -0
- mcp_bridge/tools/read_file.py +84 -0
- mcp_bridge/tools/replace.py +45 -0
- mcp_bridge/tools/run_shell_command.py +38 -0
- mcp_bridge/tools/search_enhancements.py +347 -0
- mcp_bridge/tools/semantic_search.py +3627 -0
- mcp_bridge/tools/session_manager.py +0 -2
- mcp_bridge/tools/skill_loader.py +0 -1
- mcp_bridge/tools/task_runner.py +5 -7
- mcp_bridge/tools/templates.py +3 -3
- mcp_bridge/tools/tool_search.py +331 -0
- mcp_bridge/tools/write_file.py +29 -0
- mcp_bridge/update_manager.py +585 -0
- mcp_bridge/update_manager_pypi.py +297 -0
- mcp_bridge/utils/cache.py +82 -0
- mcp_bridge/utils/process.py +71 -0
- mcp_bridge/utils/session_state.py +51 -0
- mcp_bridge/utils/truncation.py +76 -0
- stravinsky-0.4.66.dist-info/METADATA +517 -0
- stravinsky-0.4.66.dist-info/RECORD +198 -0
- {stravinsky-0.2.67.dist-info → stravinsky-0.4.66.dist-info}/entry_points.txt +1 -0
- stravinsky_claude_assets/HOOKS_INTEGRATION.md +316 -0
- stravinsky_claude_assets/agents/HOOKS.md +437 -0
- stravinsky_claude_assets/agents/code-reviewer.md +210 -0
- stravinsky_claude_assets/agents/comment_checker.md +580 -0
- stravinsky_claude_assets/agents/debugger.md +254 -0
- stravinsky_claude_assets/agents/delphi.md +495 -0
- stravinsky_claude_assets/agents/dewey.md +248 -0
- stravinsky_claude_assets/agents/explore.md +1198 -0
- stravinsky_claude_assets/agents/frontend.md +472 -0
- stravinsky_claude_assets/agents/implementation-lead.md +164 -0
- stravinsky_claude_assets/agents/momus.md +464 -0
- stravinsky_claude_assets/agents/research-lead.md +141 -0
- stravinsky_claude_assets/agents/stravinsky.md +730 -0
- stravinsky_claude_assets/commands/delphi.md +9 -0
- stravinsky_claude_assets/commands/dewey.md +54 -0
- stravinsky_claude_assets/commands/git-master.md +112 -0
- stravinsky_claude_assets/commands/index.md +49 -0
- stravinsky_claude_assets/commands/publish.md +86 -0
- stravinsky_claude_assets/commands/review.md +73 -0
- stravinsky_claude_assets/commands/str/agent_cancel.md +70 -0
- stravinsky_claude_assets/commands/str/agent_list.md +56 -0
- stravinsky_claude_assets/commands/str/agent_output.md +92 -0
- stravinsky_claude_assets/commands/str/agent_progress.md +74 -0
- stravinsky_claude_assets/commands/str/agent_retry.md +94 -0
- stravinsky_claude_assets/commands/str/cancel.md +51 -0
- stravinsky_claude_assets/commands/str/clean.md +97 -0
- stravinsky_claude_assets/commands/str/continue.md +38 -0
- stravinsky_claude_assets/commands/str/index.md +199 -0
- stravinsky_claude_assets/commands/str/list_watchers.md +96 -0
- stravinsky_claude_assets/commands/str/search.md +205 -0
- stravinsky_claude_assets/commands/str/start_filewatch.md +136 -0
- stravinsky_claude_assets/commands/str/stats.md +71 -0
- stravinsky_claude_assets/commands/str/stop_filewatch.md +89 -0
- stravinsky_claude_assets/commands/str/unwatch.md +42 -0
- stravinsky_claude_assets/commands/str/watch.md +45 -0
- stravinsky_claude_assets/commands/strav.md +53 -0
- stravinsky_claude_assets/commands/stravinsky.md +292 -0
- stravinsky_claude_assets/commands/verify.md +60 -0
- stravinsky_claude_assets/commands/version.md +5 -0
- stravinsky_claude_assets/hooks/README.md +248 -0
- stravinsky_claude_assets/hooks/comment_checker.py +193 -0
- stravinsky_claude_assets/hooks/context.py +38 -0
- stravinsky_claude_assets/hooks/context_monitor.py +153 -0
- stravinsky_claude_assets/hooks/dependency_tracker.py +73 -0
- stravinsky_claude_assets/hooks/edit_recovery.py +46 -0
- stravinsky_claude_assets/hooks/execution_state_tracker.py +68 -0
- stravinsky_claude_assets/hooks/notification_hook.py +103 -0
- stravinsky_claude_assets/hooks/notification_hook_v2.py +96 -0
- stravinsky_claude_assets/hooks/parallel_execution.py +241 -0
- stravinsky_claude_assets/hooks/parallel_reinforcement.py +106 -0
- stravinsky_claude_assets/hooks/parallel_reinforcement_v2.py +112 -0
- stravinsky_claude_assets/hooks/pre_compact.py +123 -0
- stravinsky_claude_assets/hooks/ralph_loop.py +173 -0
- stravinsky_claude_assets/hooks/session_recovery.py +263 -0
- stravinsky_claude_assets/hooks/stop_hook.py +89 -0
- stravinsky_claude_assets/hooks/stravinsky_metrics.py +164 -0
- stravinsky_claude_assets/hooks/stravinsky_mode.py +146 -0
- stravinsky_claude_assets/hooks/subagent_stop.py +98 -0
- stravinsky_claude_assets/hooks/todo_continuation.py +111 -0
- stravinsky_claude_assets/hooks/todo_delegation.py +96 -0
- stravinsky_claude_assets/hooks/tool_messaging.py +281 -0
- stravinsky_claude_assets/hooks/truncator.py +23 -0
- stravinsky_claude_assets/rules/deployment_safety.md +51 -0
- stravinsky_claude_assets/rules/integration_wiring.md +89 -0
- stravinsky_claude_assets/rules/pypi_deployment.md +220 -0
- stravinsky_claude_assets/rules/stravinsky_orchestrator.md +32 -0
- stravinsky_claude_assets/settings.json +152 -0
- stravinsky_claude_assets/skills/chrome-devtools/SKILL.md +81 -0
- stravinsky_claude_assets/skills/sqlite/SKILL.md +77 -0
- stravinsky_claude_assets/skills/supabase/SKILL.md +74 -0
- stravinsky_claude_assets/task_dependencies.json +34 -0
- stravinsky-0.2.67.dist-info/METADATA +0 -284
- stravinsky-0.2.67.dist-info/RECORD +0 -76
- {stravinsky-0.2.67.dist-info → stravinsky-0.4.66.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Model tier definitions and cross-provider fallback planning.
|
|
2
|
+
|
|
3
|
+
This module centralizes a simple, two-tier model architecture and provides
|
|
4
|
+
a deterministic fallback chain when an OAuth call fails or is unavailable.
|
|
5
|
+
|
|
6
|
+
The fallback chain is ordered to prefer:
|
|
7
|
+
1) Same-tier OAuth models on *other* providers
|
|
8
|
+
2) Lower-tier OAuth models (if available)
|
|
9
|
+
3) Same-tier models via API key auth (current provider first, then the others)
|
|
10
|
+
|
|
11
|
+
The boolean in the returned tuples indicates whether OAuth should be used.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from enum import Enum
|
|
18
|
+
from typing import Final
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ModelTier(str, Enum):
    """Capability tier of a model.

    Because of the ``str`` mixin, each member compares equal to its lowercase
    string value (``ModelTier.PREMIUM == "premium"``).
    """

    # Highest tier; listed first in TIER_ORDER.
    PREMIUM = "premium"
    # The tier ranked immediately below PREMIUM in TIER_ORDER.
    STANDARD = "standard"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(frozen=True)
class TierModel:
    """Immutable ``(model, thinking)`` entry of the ``MODEL_TIERS`` table.

    Frozen, so instances are hashable and cannot be mutated after creation.
    """

    # Provider-specific model identifier, e.g. "claude-4.5-opus".
    model: str
    # Per-entry "thinking" flag. How it is applied is decided by consumers
    # outside this module (not visible here).
    thinking: bool
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Every provider this module can plan fallbacks for; _require_known_provider()
# rejects any other name with ValueError.
KNOWN_PROVIDERS: Final[tuple[str, ...]] = ("claude", "openai", "gemini")

# Provider preference order mirrors existing routing fallback chains.
# Each value lists the *other* providers to try, most preferred first.
# NOTE(review): these chains are identical to
# ProviderStateTracker.DEFAULT_FALLBACK_CHAINS in routing/provider_state.py;
# the two tables are maintained separately and must be kept in sync by hand.
PROVIDER_FALLBACK_ORDER: Final[dict[str, list[str]]] = {
    "claude": ["openai", "gemini"],
    "openai": ["gemini", "claude"],
    "gemini": ["openai", "claude"],
}

# Ordered best -> worst.
TIER_ORDER: Final[tuple[ModelTier, ...]] = (ModelTier.PREMIUM, ModelTier.STANDARD)

# Concrete model for each (tier, provider) pair.
# Invariants relied on elsewhere in this module:
# - every tier defines an entry for every provider in KNOWN_PROVIDERS, because
#   get_oauth_fallback_chain() indexes MODEL_TIERS[tier][provider] directly
#   (a gap would raise KeyError);
# - a model id should appear under at most one tier per provider, because
#   _tier_for() maps (provider, model) back to the first tier that lists it.
MODEL_TIERS: Final[dict[ModelTier, dict[str, TierModel]]] = {
    ModelTier.PREMIUM: {
        "claude": TierModel(model="claude-4.5-opus", thinking=True),
        "openai": TierModel(model="gpt-5.2-codex", thinking=False),
        "gemini": TierModel(model="gemini-3-pro", thinking=False),
    },
    ModelTier.STANDARD: {
        "claude": TierModel(model="claude-4.5-sonnet", thinking=False),
        "openai": TierModel(model="gpt-5.2", thinking=False),
        "gemini": TierModel(model="gemini-3-flash-preview", thinking=False),
    },
}
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _require_known_provider(provider: str) -> None:
    """Validate that *provider* is one of ``KNOWN_PROVIDERS``.

    Raises:
        ValueError: naming the bad provider and listing the accepted ones.
    """
    if provider in KNOWN_PROVIDERS:
        return
    raise ValueError(f"Unknown provider: {provider!r}. Expected one of {KNOWN_PROVIDERS!r}.")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _tier_for(provider: str, model: str) -> ModelTier:
    """Return the tier whose entry for *provider* is exactly *model*.

    Tiers are scanned in ``MODEL_TIERS`` iteration order; the first match wins.

    Raises:
        ValueError: if *provider* is not in ``KNOWN_PROVIDERS``, or no tier
            lists *model* for that provider.
    """
    _require_known_provider(provider)

    for tier, by_provider in MODEL_TIERS.items():
        entry = by_provider.get(provider)
        if entry and entry.model == model:
            break
    else:
        # Loop finished without a break: the model is not registered.
        raise ValueError(
            f"Unknown model for provider {provider!r}: {model!r}. "
            "Expected a model present in MODEL_TIERS."
        )
    return tier
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _providers_other_first(provider: str) -> list[str]:
    """Return every provider except *provider*, in fallback-preference order.

    The order comes from ``PROVIDER_FALLBACK_ORDER`` when it has an entry for
    *provider*, otherwise from ``KNOWN_PROVIDERS``.

    Raises:
        ValueError: if *provider* is not in ``KNOWN_PROVIDERS``.
    """
    _require_known_provider(provider)
    configured = PROVIDER_FALLBACK_ORDER.get(provider)
    candidates = KNOWN_PROVIDERS if configured is None else configured
    return [name for name in candidates if name != provider]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _lower_tiers(tier: ModelTier) -> list[ModelTier]:
    """Return the tiers ranked below *tier* in ``TIER_ORDER``, best first.

    A tier that is not part of ``TIER_ORDER`` yields an empty list rather than
    an error.
    """
    if tier not in TIER_ORDER:
        return []
    position = TIER_ORDER.index(tier)
    return list(TIER_ORDER[position + 1 :])
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def get_oauth_fallback_chain(provider: str, model: str) -> list[tuple[str, str, bool]]:
    """Return ordered (provider, model, use_oauth) fallbacks.

    Args:
        provider: Current provider (e.g. "openai")
        model: Current model identifier within that provider

    Returns:
        A list of candidate (provider, model, use_oauth) tuples, without
        duplicates. The (provider, model) pair passed in is never offered
        with use_oauth=True.

    Ordering rules:
        - Same-tier models on OTHER providers first (OAuth)
        - Then lower-tier models (OAuth): other providers first, the
          current provider last
        - Then same-tier models via API key (non-OAuth): the current
          provider first, then the others

    Raises:
        ValueError: if the provider is unknown or the model is not listed in
            ``MODEL_TIERS`` for it.
    """
    tier = _tier_for(provider, model)
    others = _providers_other_first(provider)

    candidates: list[tuple[str, str, bool]] = []

    # 1) Same tier on every *other* provider, via OAuth.
    candidates += [(name, MODEL_TIERS[tier][name].model, True) for name in others]

    # 2) Each lower tier, via OAuth.
    for lower in _lower_tiers(tier):
        candidates += [
            (name, MODEL_TIERS[lower][name].model, True) for name in (*others, provider)
        ]

    # 3) Same tier again, via API key (non-OAuth).
    candidates += [
        (name, MODEL_TIERS[tier][name].model, False) for name in (provider, *others)
    ]

    # Drop repeats, keeping each candidate at its first position.
    return list(dict.fromkeys(candidates))
|
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Provider State Tracking for Multi-Provider Routing.
|
|
3
|
+
|
|
4
|
+
Tracks the availability and health of each provider (Claude, OpenAI, Gemini)
|
|
5
|
+
to enable intelligent fallback when providers are rate-limited or unavailable.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import sys
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)


@dataclass
class ProviderState:
    """Tracks the state of a single provider.

    A plain mutable record with no locking of its own. ProviderStateTracker
    wraps its calls into these methods in its own lock.
    """

    name: str
    is_available: bool = True
    # Epoch seconds (time.time()) at which the current cooldown ends; None if
    # the provider is not cooling down.
    cooldown_until: float | None = None
    # Consecutive failures; cleared by mark_success() and reset().
    error_count: int = 0
    # Epoch seconds of the most recent successful request.
    last_success: float | None = None
    last_error: str | None = None
    # NOTE(review): only mark_success() increments this, so it counts
    # successful requests rather than all requests. Failures are tallied
    # separately in total_failures.
    total_requests: int = 0
    total_failures: int = 0

    def mark_rate_limited(self, duration: int = 300, reason: str = "429 rate limit") -> None:
        """
        Mark provider as rate-limited with cooldown.

        Args:
            duration: Cooldown duration in seconds (default 5 minutes)
            reason: Reason for rate limiting (for logging)
        """
        self.cooldown_until = time.time() + duration
        self.is_available = False
        self.error_count += 1
        self.total_failures += 1
        self.last_error = reason

        logger.warning(
            f"[ProviderState] {self.name} rate-limited: {reason}. "
            f"Cooldown until {time.strftime('%H:%M:%S', time.localtime(self.cooldown_until))}"
        )

        # User-visible notification
        print(
            f"⚠️ {self.name.upper()}: Rate-limited ({reason}). Cooldown for {duration}s.",
            file=sys.stderr,
        )

    def mark_success(self) -> None:
        """Record a successful request and clear any active cooldown."""
        self.last_success = time.time()
        self.error_count = 0  # Reset consecutive error count
        self.total_requests += 1

        # A success during cooldown means the provider has recovered early.
        if self.cooldown_until is not None:
            logger.info(f"[ProviderState] {self.name} recovered from cooldown")
            self.cooldown_until = None
            self.is_available = True

    def mark_error(self, error: str) -> None:
        """Record a non-rate-limit error (does not start a cooldown)."""
        self.error_count += 1
        self.total_failures += 1
        self.last_error = error
        logger.warning(f"[ProviderState] {self.name} error ({self.error_count}): {error}")

    def check_availability(self) -> bool:
        """
        Check if provider is available, clearing an expired cooldown.

        Returns:
            True if provider is available, False if in cooldown
        """
        if self.cooldown_until is None:
            return True

        # Read the clock once so the expiry decision and the reported
        # remaining time are computed from the same instant.
        now = time.time()
        if now > self.cooldown_until:
            # Cooldown expired - reset
            logger.info(f"[ProviderState] {self.name} cooldown expired. Marking available.")
            self.cooldown_until = None
            self.is_available = True
            return True

        # Still in cooldown
        remaining = self.cooldown_until - now
        logger.debug(f"[ProviderState] {self.name} still in cooldown ({remaining:.0f}s remaining)")
        return False

    def get_cooldown_remaining(self) -> float | None:
        """Get remaining cooldown time in seconds (never negative), or None if not in cooldown."""
        if self.cooldown_until is None:
            return None
        # 0.0 (not int 0) so the result always matches the float annotation.
        return max(0.0, self.cooldown_until - time.time())

    def reset(self) -> None:
        """Reset provider state (clear cooldown and errors).

        Cumulative counters (total_requests, total_failures) and
        last_success are intentionally left untouched.
        """
        self.is_available = True
        self.cooldown_until = None
        self.error_count = 0
        self.last_error = None
        logger.info(f"[ProviderState] {self.name} state reset")
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class ProviderStateTracker:
    """
    Tracks availability of all providers with thread-safe access.

    Provides fallback chain logic when primary providers are unavailable.

    Every public method holds ``self._lock`` for its whole body. While the
    lock is held, provider lookups go through ``_state_for``, which does not
    lock. ``threading.Lock`` is not reentrant, so calling the locking
    ``get_provider`` from inside another locked method would deadlock.
    """

    # Default fallback chains for each provider
    DEFAULT_FALLBACK_CHAINS: dict[str, list[str]] = {
        "claude": ["openai", "gemini"],
        "openai": ["gemini", "claude"],
        "gemini": ["openai", "claude"],
    }

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.providers: dict[str, ProviderState] = {
            "claude": ProviderState("claude"),
            "openai": ProviderState("openai"),
            "gemini": ProviderState("gemini"),
        }
        # Copy each chain list as well as the outer dict, so mutating this
        # instance's chains can never alter the class-level defaults.
        self._fallback_chains: dict[str, list[str]] = {
            name: list(chain) for name, chain in self.DEFAULT_FALLBACK_CHAINS.items()
        }

    def _state_for(self, name: str) -> ProviderState:
        """Return the state for *name*, creating it on first use.

        Does not lock: the caller must already hold ``self._lock``.
        """
        if name not in self.providers:
            # Create new provider state if not exists
            self.providers[name] = ProviderState(name)
        return self.providers[name]

    def get_provider(self, name: str) -> ProviderState:
        """Get the state for a specific provider (created on first use)."""
        with self._lock:
            return self._state_for(name)

    def mark_rate_limited(
        self, provider: str, duration: int = 300, reason: str = "429 rate limit"
    ) -> None:
        """Mark a provider as rate-limited for *duration* seconds."""
        with self._lock:
            self._state_for(provider).mark_rate_limited(duration, reason)

    def mark_success(self, provider: str) -> None:
        """Mark a successful request to a provider."""
        with self._lock:
            self._state_for(provider).mark_success()

    def mark_error(self, provider: str, error: str) -> None:
        """Mark an error for a provider."""
        with self._lock:
            self._state_for(provider).mark_error(error)

    def is_available(self, provider: str) -> bool:
        """Check if a provider is available (expired cooldowns are cleared)."""
        with self._lock:
            return self._state_for(provider).check_availability()

    def get_fallback_provider(self, preferred: str) -> str:
        """
        Get best available provider, falling back as needed.

        Args:
            preferred: The preferred provider to use

        Returns:
            The best available provider: *preferred* if available, otherwise
            the first available provider in its fallback chain, otherwise
            *preferred* anyway.
        """
        with self._lock:
            # Check preferred first
            if self._state_for(preferred).check_availability():
                return preferred

            # Try fallback chain
            for fallback in self._fallback_chains.get(preferred, []):
                if self._state_for(fallback).check_availability():
                    logger.info(
                        f"[ProviderStateTracker] Falling back from {preferred} to {fallback}"
                    )
                    # Notify user
                    print(
                        f"⚠️ {preferred.title()} unavailable → Routing to {fallback.title()}",
                        file=sys.stderr,
                    )
                    return fallback

            # All providers unavailable, return preferred anyway (will likely fail)
            logger.warning(
                f"[ProviderStateTracker] All providers unavailable. Using {preferred} anyway."
            )
            return preferred

    def get_status(self) -> dict[str, dict[str, Any]]:
        """Get status of all providers for dashboard/CLI."""
        with self._lock:
            status: dict[str, dict[str, Any]] = {}
            for name, state in self.providers.items():
                state.check_availability()  # Update availability
                status[name] = {
                    "available": state.is_available,
                    "cooldown_remaining": state.get_cooldown_remaining(),
                    "error_count": state.error_count,
                    "last_success": state.last_success,
                    "last_error": state.last_error,
                    "total_requests": state.total_requests,
                    "total_failures": state.total_failures,
                }
            return status

    def reset_all(self) -> None:
        """Reset all provider states."""
        with self._lock:
            for state in self.providers.values():
                state.reset()
            logger.info("[ProviderStateTracker] All provider states reset")

    def reset_provider(self, provider: str) -> None:
        """Reset a specific provider's state (no-op for unknown providers)."""
        with self._lock:
            if provider in self.providers:
                self.providers[provider].reset()

    def set_fallback_chain(self, provider: str, chain: list[str]) -> None:
        """Set custom fallback chain for a provider."""
        with self._lock:
            # Store a copy so later changes to the caller's list cannot
            # silently alter routing.
            self._fallback_chains[provider] = list(chain)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
# Process-wide singleton, created lazily by get_provider_tracker().
_provider_tracker: ProviderStateTracker | None = None
_tracker_lock = threading.Lock()


def get_provider_tracker() -> ProviderStateTracker:
    """Return the process-wide ProviderStateTracker, creating it on first call.

    Double-checked locking: once the instance exists the lock is skipped.
    The re-check under ``_tracker_lock`` ensures only one thread constructs it.
    """
    global _provider_tracker
    if _provider_tracker is not None:
        return _provider_tracker
    with _tracker_lock:
        if _provider_tracker is None:
            _provider_tracker = ProviderStateTracker()
            logger.info("[ProviderStateTracker] Created global instance")
    return _provider_tracker
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def reset_provider_tracker() -> None:
    """Reset the global provider tracker's state, if it exists (mainly for testing).

    The singleton object itself is kept, so references handed out earlier stay
    valid. Only its per-provider state is reset, via ``reset_all()``.
    """
    with _tracker_lock:
        tracker = _provider_tracker
        if tracker is None:
            return
        tracker.reset_all()
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task Classifier for Intelligent Routing.
|
|
3
|
+
|
|
4
|
+
Classifies incoming tasks to determine the optimal provider and model.
|
|
5
|
+
Uses pattern matching and heuristics to categorize tasks.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import re
|
|
12
|
+
from enum import Enum, auto
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)


class TaskType(Enum):
    """Task categories used to pick a provider and model for a request.

    Values are auto-assigned integers. Routing config refers to members by
    their lower-cased ``name`` (see get_routing_for_task).
    """

    # Creating new code.
    CODE_GENERATION = auto()
    # Improving existing code structure.
    CODE_REFACTORING = auto()
    # Fixing bugs and errors.
    DEBUGGING = auto()
    # System design and planning.
    ARCHITECTURE = auto()
    # Writing docs, comments, READMEs.
    DOCUMENTATION = auto()
    # Finding code patterns.
    CODE_SEARCH = auto()
    # Security analysis.
    SECURITY_REVIEW = auto()
    # Default fallback when nothing more specific applies.
    GENERAL = auto()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Patterns for task classification.
# classify_task() lower-cases the prompt and runs re.search() with
# re.IGNORECASE for each pattern. Task types are tried in a fixed priority
# order (DEBUGGING first, CODE_GENERATION last), and the first matching
# pattern decides the type.
# NOTE(review): several patterns are very broad. Because DEBUGGING is tried
# first, prompts such as "add error handling" or "fix the docs" classify as
# DEBUGGING. ARCHITECTURE's "system"/"design"/"pattern" likewise capture many
# general prompts. Confirm this is the intended trade-off.
TASK_PATTERNS: dict[TaskType, list[str]] = {
    # Verb + code-artifact phrasing, plus bare "implement" / "create a".
    TaskType.CODE_GENERATION: [
        r"\b(generate|create|implement|build|write|add|make|develop)\b.*\b(code|function|class|module|component|api|endpoint|feature)\b",
        r"\b(new|fresh)\b.*\b(implementation|feature|module)\b",
        r"\bimplement\b",
        r"\bcreate a\b",
    ],
    # Restructuring / clean-up vocabulary and named refactorings.
    TaskType.CODE_REFACTORING: [
        r"\b(refactor|restructure|reorganize|clean\s*up|simplify|optimize|improve)\b",
        r"\b(extract|inline|rename|move)\b.*\b(method|function|class|variable)\b",
        r"\bcode\s*(cleanup|quality)\b",
        r"\breduce\s*(complexity|duplication)\b",
    ],
    # Fix verbs, failure nouns, and "why does X fail"-style questions.
    TaskType.DEBUGGING: [
        r"\b(debug|fix|solve|resolve|troubleshoot|diagnose)\b",
        r"\b(bug|error|issue|problem|failing|broken|crash)\b",
        r"\b(not\s*working|doesn't\s*work|won't\s*work)\b",
        r"\b(exception|traceback|stack\s*trace)\b",
        r"\bwhy\s*(is|does|doesn't)\b.*\b(fail|error|crash)\b",
    ],
    # Design / big-picture / trade-off vocabulary.
    TaskType.ARCHITECTURE: [
        r"\b(architect|design|structure|pattern|system)\b",
        r"\b(high\s*level|overall|big\s*picture)\b",
        r"\b(scalability|maintainability|extensibility)\b",
        r"\b(trade\s*off|decision|approach|strategy)\b",
        r"\bhow\s*should\s*(we|i)\s*(design|structure|organize)\b",
    ],
    # Docs, comments, and "explain/describe" requests.
    TaskType.DOCUMENTATION: [
        r"\b(document|readme|docstring|comment|explain|describe)\b",
        r"\b(api\s*docs|documentation|jsdoc|pydoc)\b",
        r"\bwrite\s*(up|docs|documentation)\b",
        r"\badd\s*comments?\b",
    ],
    # Locating code or asking how something is implemented/used.
    TaskType.CODE_SEARCH: [
        r"\b(find|search|locate|where\s*is|look\s*for)\b.*\b(code|function|class|implementation)\b",
        r"\b(grep|ripgrep|search)\b",
        r"\bhow\s*is\b.*\b(implemented|used|called)\b",
        r"\bshow\s*me\b.*\b(code|implementation)\b",
    ],
    # Security vocabulary and common vulnerability acronyms.
    TaskType.SECURITY_REVIEW: [
        r"\b(security|vulnerability|exploit|attack|injection)\b",
        r"\b(auth|authentication|authorization|permission)\b.*\b(check|review|audit)\b",
        r"\b(secure|harden|protect)\b",
        r"\b(xss|csrf|sql\s*injection|rce)\b",
    ],
}
|
|
78
|
+
|
|
79
|
+
# Default routing for each task type: (provider, model).
# A model of None means "no specific model". classify_and_route() logs it as
# 'default'; which concrete model is then used is decided outside this module.
# get_routing_for_task() returns these entries unless a config override
# exists for the task type.
# NOTE(review): these model ids ("gpt-5-codex", "gpt-5.2-medium",
# "gemini-3-flash") differ from the ids in MODEL_TIERS in
# routing/model_tiers.py ("gpt-5.2-codex", "gpt-5.2",
# "gemini-3-flash-preview"). Confirm they are accepted aliases downstream.
DEFAULT_TASK_ROUTING: dict[TaskType, tuple[str, str | None]] = {
    TaskType.CODE_GENERATION: ("openai", "gpt-5-codex"),
    TaskType.CODE_REFACTORING: ("openai", "gpt-5-codex"),
    TaskType.DEBUGGING: ("openai", "gpt-5-codex"),
    TaskType.ARCHITECTURE: ("openai", "gpt-5.2-medium"),  # Delphi-style
    TaskType.DOCUMENTATION: ("gemini", "gemini-3-flash"),
    TaskType.CODE_SEARCH: ("gemini", "gemini-3-flash"),
    TaskType.SECURITY_REVIEW: ("claude", None),  # Keep in Claude
    TaskType.GENERAL: ("claude", None),  # Default to Claude
}
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def classify_task(prompt: str, context: dict[str, Any] | None = None) -> TaskType:
    """
    Classify a task based on prompt content and optional context.

    The prompt is matched against ``TASK_PATTERNS`` in a fixed priority order;
    the first matching pattern wins. If nothing matches, or the prompt is
    empty/None, the optional *context* is consulted for coarser signals before
    falling back to ``TaskType.GENERAL``.

    Args:
        prompt: The user's prompt or request. May be empty, in which case
            only *context* is used.
        context: Optional context dict with additional signals. Recognised
            keys: a truthy ``"error"`` or ``"exception"`` (-> DEBUGGING), and a
            truthy ``"existing_code"`` with a falsy ``"create_new"``
            (-> CODE_REFACTORING).

    Returns:
        TaskType enum indicating the classification
    """
    # Priority order matters - first match wins. GENERAL has no patterns and
    # is only reached as the final fallback.
    priority_order = (
        TaskType.DEBUGGING,  # Most specific - error fixing
        TaskType.SECURITY_REVIEW,  # Security concerns
        TaskType.CODE_REFACTORING,  # Improvement tasks
        TaskType.ARCHITECTURE,  # Design decisions
        TaskType.DOCUMENTATION,  # Doc writing
        TaskType.CODE_SEARCH,  # Finding code
        TaskType.CODE_GENERATION,  # Creating code (broad)
    )

    # An empty prompt skips pattern matching but must NOT skip the context
    # signals below. Previously it returned GENERAL immediately, ignoring
    # e.g. an "error" entry in context.
    if prompt:
        prompt_lower = prompt.lower()
        for task_type in priority_order:
            for pattern in TASK_PATTERNS.get(task_type, []):
                if re.search(pattern, prompt_lower, re.IGNORECASE):
                    logger.debug(f"[TaskClassifier] Matched {task_type.name} with pattern: {pattern}")
                    return task_type

    # Check context for additional signals
    if context:
        # If there's an error in context, likely debugging
        if context.get("error") or context.get("exception"):
            return TaskType.DEBUGGING

        # If there's existing code to modify
        if context.get("existing_code") and not context.get("create_new"):
            return TaskType.CODE_REFACTORING

    return TaskType.GENERAL
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def get_routing_for_task(
    task_type: TaskType,
    config: dict[str, Any] | None = None,
) -> tuple[str, str | None]:
    """
    Look up the recommended (provider, model) for a task type.

    A non-empty *config* can override the default. It is keyed by the
    lower-cased task type name (e.g. ``"debugging"``). Each value is a
    mapping with an optional ``"provider"`` (default ``"claude"``) and an
    optional ``"model"`` (default ``None``). Without an applicable override,
    the ``DEFAULT_TASK_ROUTING`` entry is returned, or ``("claude", None)``
    if the task type has none.

    Args:
        task_type: The classified task type
        config: Optional routing config override

    Returns:
        Tuple of (provider, model) where model may be None
    """
    if config:
        section = task_type.name.lower()
        if section in config:
            override = config[section]
            return override.get("provider", "claude"), override.get("model")

    # No applicable override: use the built-in table.
    return DEFAULT_TASK_ROUTING.get(task_type, ("claude", None))
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def classify_and_route(
    prompt: str,
    context: dict[str, Any] | None = None,
    config: dict[str, Any] | None = None,
) -> tuple[TaskType, str, str | None]:
    """
    Classify *prompt* and return the routing decision in a single call.

    Combines classify_task() and get_routing_for_task(), and logs the result
    at INFO level.

    Args:
        prompt: The user's prompt
        context: Optional context dict passed to classify_task()
        config: Optional routing config override passed to get_routing_for_task()

    Returns:
        Tuple of (task_type, provider, model)
    """
    task_type = classify_task(prompt, context)
    route = get_routing_for_task(task_type, config)
    provider, model = route

    logger.info(
        f"[TaskClassifier] Classified as {task_type.name} → {provider}/{model or 'default'}"
    )

    return (task_type, *route)
|