stravinsky 0.4.18__py3-none-any.whl → 0.4.66__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of stravinsky might be problematic. Click here for more details.
- mcp_bridge/__init__.py +1 -1
- mcp_bridge/auth/__init__.py +16 -6
- mcp_bridge/auth/cli.py +202 -11
- mcp_bridge/auth/oauth.py +1 -2
- mcp_bridge/auth/openai_oauth.py +4 -7
- mcp_bridge/auth/token_store.py +0 -1
- mcp_bridge/cli/__init__.py +1 -1
- mcp_bridge/cli/install_hooks.py +503 -107
- mcp_bridge/cli/session_report.py +0 -3
- mcp_bridge/config/__init__.py +2 -2
- mcp_bridge/config/hook_config.py +3 -5
- mcp_bridge/config/rate_limits.py +108 -13
- mcp_bridge/hooks/HOOKS_SETTINGS.json +17 -4
- mcp_bridge/hooks/__init__.py +14 -4
- mcp_bridge/hooks/agent_reminder.py +4 -4
- mcp_bridge/hooks/auto_slash_command.py +5 -5
- mcp_bridge/hooks/budget_optimizer.py +2 -2
- mcp_bridge/hooks/claude_limits_hook.py +114 -0
- mcp_bridge/hooks/comment_checker.py +3 -4
- mcp_bridge/hooks/compaction.py +2 -2
- mcp_bridge/hooks/context.py +2 -1
- mcp_bridge/hooks/context_monitor.py +2 -2
- mcp_bridge/hooks/delegation_policy.py +85 -0
- mcp_bridge/hooks/directory_context.py +3 -3
- mcp_bridge/hooks/edit_recovery.py +3 -2
- mcp_bridge/hooks/edit_recovery_policy.py +49 -0
- mcp_bridge/hooks/empty_message_sanitizer.py +2 -2
- mcp_bridge/hooks/events.py +160 -0
- mcp_bridge/hooks/git_noninteractive.py +4 -4
- mcp_bridge/hooks/keyword_detector.py +8 -10
- mcp_bridge/hooks/manager.py +35 -22
- mcp_bridge/hooks/notification_hook.py +13 -6
- mcp_bridge/hooks/parallel_enforcement_policy.py +67 -0
- mcp_bridge/hooks/parallel_enforcer.py +5 -5
- mcp_bridge/hooks/parallel_execution.py +22 -10
- mcp_bridge/hooks/post_tool/parallel_validation.py +103 -0
- mcp_bridge/hooks/pre_compact.py +8 -9
- mcp_bridge/hooks/pre_tool/agent_spawn_validator.py +115 -0
- mcp_bridge/hooks/preemptive_compaction.py +2 -3
- mcp_bridge/hooks/routing_notifications.py +80 -0
- mcp_bridge/hooks/rules_injector.py +11 -19
- mcp_bridge/hooks/session_idle.py +4 -4
- mcp_bridge/hooks/session_notifier.py +4 -4
- mcp_bridge/hooks/session_recovery.py +4 -5
- mcp_bridge/hooks/stravinsky_mode.py +1 -1
- mcp_bridge/hooks/subagent_stop.py +1 -3
- mcp_bridge/hooks/task_validator.py +2 -2
- mcp_bridge/hooks/tmux_manager.py +7 -8
- mcp_bridge/hooks/todo_delegation.py +4 -1
- mcp_bridge/hooks/todo_enforcer.py +180 -10
- mcp_bridge/hooks/truncation_policy.py +37 -0
- mcp_bridge/hooks/truncator.py +1 -2
- mcp_bridge/metrics/cost_tracker.py +115 -0
- mcp_bridge/native_search.py +93 -0
- mcp_bridge/native_watcher.py +118 -0
- mcp_bridge/notifications.py +3 -4
- mcp_bridge/orchestrator/enums.py +11 -0
- mcp_bridge/orchestrator/router.py +165 -0
- mcp_bridge/orchestrator/state.py +32 -0
- mcp_bridge/orchestrator/visualization.py +14 -0
- mcp_bridge/orchestrator/wisdom.py +34 -0
- mcp_bridge/prompts/__init__.py +1 -8
- mcp_bridge/prompts/dewey.py +1 -1
- mcp_bridge/prompts/planner.py +2 -4
- mcp_bridge/prompts/stravinsky.py +53 -31
- mcp_bridge/proxy/__init__.py +0 -0
- mcp_bridge/proxy/client.py +70 -0
- mcp_bridge/proxy/model_server.py +157 -0
- mcp_bridge/routing/__init__.py +43 -0
- mcp_bridge/routing/config.py +250 -0
- mcp_bridge/routing/model_tiers.py +135 -0
- mcp_bridge/routing/provider_state.py +261 -0
- mcp_bridge/routing/task_classifier.py +190 -0
- mcp_bridge/server.py +363 -34
- mcp_bridge/server_tools.py +298 -6
- mcp_bridge/tools/__init__.py +19 -8
- mcp_bridge/tools/agent_manager.py +549 -799
- mcp_bridge/tools/background_tasks.py +13 -17
- mcp_bridge/tools/code_search.py +54 -51
- mcp_bridge/tools/continuous_loop.py +0 -1
- mcp_bridge/tools/dashboard.py +19 -0
- mcp_bridge/tools/find_code.py +296 -0
- mcp_bridge/tools/init.py +1 -0
- mcp_bridge/tools/list_directory.py +42 -0
- mcp_bridge/tools/lsp/__init__.py +8 -8
- mcp_bridge/tools/lsp/manager.py +51 -28
- mcp_bridge/tools/lsp/tools.py +98 -65
- mcp_bridge/tools/model_invoke.py +1047 -152
- mcp_bridge/tools/mux_client.py +75 -0
- mcp_bridge/tools/project_context.py +1 -2
- mcp_bridge/tools/query_classifier.py +132 -49
- mcp_bridge/tools/read_file.py +84 -0
- mcp_bridge/tools/replace.py +45 -0
- mcp_bridge/tools/run_shell_command.py +38 -0
- mcp_bridge/tools/search_enhancements.py +347 -0
- mcp_bridge/tools/semantic_search.py +677 -92
- mcp_bridge/tools/session_manager.py +0 -2
- mcp_bridge/tools/skill_loader.py +0 -1
- mcp_bridge/tools/task_runner.py +5 -7
- mcp_bridge/tools/templates.py +3 -3
- mcp_bridge/tools/tool_search.py +331 -0
- mcp_bridge/tools/write_file.py +29 -0
- mcp_bridge/update_manager.py +33 -37
- mcp_bridge/update_manager_pypi.py +6 -8
- mcp_bridge/utils/cache.py +82 -0
- mcp_bridge/utils/process.py +71 -0
- mcp_bridge/utils/session_state.py +51 -0
- mcp_bridge/utils/truncation.py +76 -0
- {stravinsky-0.4.18.dist-info → stravinsky-0.4.66.dist-info}/METADATA +84 -35
- stravinsky-0.4.66.dist-info/RECORD +198 -0
- {stravinsky-0.4.18.dist-info → stravinsky-0.4.66.dist-info}/entry_points.txt +1 -0
- stravinsky_claude_assets/HOOKS_INTEGRATION.md +316 -0
- stravinsky_claude_assets/agents/HOOKS.md +437 -0
- stravinsky_claude_assets/agents/code-reviewer.md +210 -0
- stravinsky_claude_assets/agents/comment_checker.md +580 -0
- stravinsky_claude_assets/agents/debugger.md +254 -0
- stravinsky_claude_assets/agents/delphi.md +495 -0
- stravinsky_claude_assets/agents/dewey.md +248 -0
- stravinsky_claude_assets/agents/explore.md +1198 -0
- stravinsky_claude_assets/agents/frontend.md +472 -0
- stravinsky_claude_assets/agents/implementation-lead.md +164 -0
- stravinsky_claude_assets/agents/momus.md +464 -0
- stravinsky_claude_assets/agents/research-lead.md +141 -0
- stravinsky_claude_assets/agents/stravinsky.md +730 -0
- stravinsky_claude_assets/commands/delphi.md +9 -0
- stravinsky_claude_assets/commands/dewey.md +54 -0
- stravinsky_claude_assets/commands/git-master.md +112 -0
- stravinsky_claude_assets/commands/index.md +49 -0
- stravinsky_claude_assets/commands/publish.md +86 -0
- stravinsky_claude_assets/commands/review.md +73 -0
- stravinsky_claude_assets/commands/str/agent_cancel.md +70 -0
- stravinsky_claude_assets/commands/str/agent_list.md +56 -0
- stravinsky_claude_assets/commands/str/agent_output.md +92 -0
- stravinsky_claude_assets/commands/str/agent_progress.md +74 -0
- stravinsky_claude_assets/commands/str/agent_retry.md +94 -0
- stravinsky_claude_assets/commands/str/cancel.md +51 -0
- stravinsky_claude_assets/commands/str/clean.md +97 -0
- stravinsky_claude_assets/commands/str/continue.md +38 -0
- stravinsky_claude_assets/commands/str/index.md +199 -0
- stravinsky_claude_assets/commands/str/list_watchers.md +96 -0
- stravinsky_claude_assets/commands/str/search.md +205 -0
- stravinsky_claude_assets/commands/str/start_filewatch.md +136 -0
- stravinsky_claude_assets/commands/str/stats.md +71 -0
- stravinsky_claude_assets/commands/str/stop_filewatch.md +89 -0
- stravinsky_claude_assets/commands/str/unwatch.md +42 -0
- stravinsky_claude_assets/commands/str/watch.md +45 -0
- stravinsky_claude_assets/commands/strav.md +53 -0
- stravinsky_claude_assets/commands/stravinsky.md +292 -0
- stravinsky_claude_assets/commands/verify.md +60 -0
- stravinsky_claude_assets/commands/version.md +5 -0
- stravinsky_claude_assets/hooks/README.md +248 -0
- stravinsky_claude_assets/hooks/comment_checker.py +193 -0
- stravinsky_claude_assets/hooks/context.py +38 -0
- stravinsky_claude_assets/hooks/context_monitor.py +153 -0
- stravinsky_claude_assets/hooks/dependency_tracker.py +73 -0
- stravinsky_claude_assets/hooks/edit_recovery.py +46 -0
- stravinsky_claude_assets/hooks/execution_state_tracker.py +68 -0
- stravinsky_claude_assets/hooks/notification_hook.py +103 -0
- stravinsky_claude_assets/hooks/notification_hook_v2.py +96 -0
- stravinsky_claude_assets/hooks/parallel_execution.py +241 -0
- stravinsky_claude_assets/hooks/parallel_reinforcement.py +106 -0
- stravinsky_claude_assets/hooks/parallel_reinforcement_v2.py +112 -0
- stravinsky_claude_assets/hooks/pre_compact.py +123 -0
- stravinsky_claude_assets/hooks/ralph_loop.py +173 -0
- stravinsky_claude_assets/hooks/session_recovery.py +263 -0
- stravinsky_claude_assets/hooks/stop_hook.py +89 -0
- stravinsky_claude_assets/hooks/stravinsky_metrics.py +164 -0
- stravinsky_claude_assets/hooks/stravinsky_mode.py +146 -0
- stravinsky_claude_assets/hooks/subagent_stop.py +98 -0
- stravinsky_claude_assets/hooks/todo_continuation.py +111 -0
- stravinsky_claude_assets/hooks/todo_delegation.py +96 -0
- stravinsky_claude_assets/hooks/tool_messaging.py +281 -0
- stravinsky_claude_assets/hooks/truncator.py +23 -0
- stravinsky_claude_assets/rules/deployment_safety.md +51 -0
- stravinsky_claude_assets/rules/integration_wiring.md +89 -0
- stravinsky_claude_assets/rules/pypi_deployment.md +220 -0
- stravinsky_claude_assets/rules/stravinsky_orchestrator.md +32 -0
- stravinsky_claude_assets/settings.json +152 -0
- stravinsky_claude_assets/skills/chrome-devtools/SKILL.md +81 -0
- stravinsky_claude_assets/skills/sqlite/SKILL.md +77 -0
- stravinsky_claude_assets/skills/supabase/SKILL.md +74 -0
- stravinsky_claude_assets/task_dependencies.json +34 -0
- stravinsky-0.4.18.dist-info/RECORD +0 -88
- {stravinsky-0.4.18.dist-info → stravinsky-0.4.66.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Provider State Tracking for Multi-Provider Routing.
|
|
3
|
+
|
|
4
|
+
Tracks the availability and health of each provider (Claude, OpenAI, Gemini)
|
|
5
|
+
to enable intelligent fallback when providers are rate-limited or unavailable.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import sys
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class ProviderState:
    """Tracks the state of a single provider.

    ``is_available`` is only refreshed by ``check_availability`` (and by the
    ``mark_success``/``mark_rate_limited``/``reset`` methods), so callers that
    care about cooldown expiry should call ``check_availability`` instead of
    reading the attribute directly.
    """

    name: str
    is_available: bool = True
    # time.time() timestamp at which the current cooldown ends; None when not cooling down.
    cooldown_until: float | None = None
    # Consecutive failures since the last success (cleared by mark_success/reset).
    error_count: int = 0
    # time.time() timestamp of the most recent success.
    last_success: float | None = None
    # Reason/message recorded by the most recent failure.
    last_error: str | None = None
    # Every recorded outcome: successes, errors and rate limits.
    total_requests: int = 0
    # Errors plus rate limits (not cleared by reset).
    total_failures: int = 0

    def mark_rate_limited(self, duration: int = 300, reason: str = "429 rate limit") -> None:
        """
        Mark provider as rate-limited with cooldown.

        Any existing cooldown is replaced (not extended) by the new one.

        Args:
            duration: Cooldown duration in seconds (default 5 minutes)
            reason: Reason for rate limiting (for logging)
        """
        self.cooldown_until = time.time() + duration
        self.is_available = False
        self.error_count += 1
        # A rate-limited call is still a request; counting it keeps
        # total_failures <= total_requests in get_status() output.
        self.total_requests += 1
        self.total_failures += 1
        self.last_error = reason

        logger.warning(
            f"[ProviderState] {self.name} rate-limited: {reason}. "
            f"Cooldown until {time.strftime('%H:%M:%S', time.localtime(self.cooldown_until))}"
        )

        # User-visible notification
        print(
            f"⚠️ {self.name.upper()}: Rate-limited ({reason}). Cooldown for {duration}s.",
            file=sys.stderr,
        )

    def mark_success(self) -> None:
        """Mark a successful request to this provider and clear any cooldown."""
        self.last_success = time.time()
        self.error_count = 0  # Reset consecutive error count
        self.total_requests += 1

        # If we were in cooldown but succeeded, clear it
        if self.cooldown_until is not None:
            logger.info(f"[ProviderState] {self.name} recovered from cooldown")
            self.cooldown_until = None
        self.is_available = True

    def mark_error(self, error: str) -> None:
        """Mark a non-rate-limit error (does not start a cooldown).

        Args:
            error: Error description stored in ``last_error``.
        """
        self.error_count += 1
        # Count failed calls as requests too (see mark_rate_limited).
        self.total_requests += 1
        self.total_failures += 1
        self.last_error = error
        logger.warning(f"[ProviderState] {self.name} error ({self.error_count}): {error}")

    def check_availability(self) -> bool:
        """
        Check if provider is available (cooldown expired?).

        Clears an expired cooldown as a side effect.

        Returns:
            True if provider is available, False if in cooldown
        """
        if self.cooldown_until is None:
            return True

        if time.time() > self.cooldown_until:
            # Cooldown expired - reset
            logger.info(f"[ProviderState] {self.name} cooldown expired. Marking available.")
            self.cooldown_until = None
            self.is_available = True
            return True

        # Still in cooldown
        remaining = self.cooldown_until - time.time()
        logger.debug(f"[ProviderState] {self.name} still in cooldown ({remaining:.0f}s remaining)")
        return False

    def get_cooldown_remaining(self) -> float | None:
        """Get remaining cooldown time in seconds (never negative), or None if not in cooldown."""
        if self.cooldown_until is None:
            return None
        # Clamp at 0.0 (a float, matching the annotation) for an expired-but-not-yet-cleared cooldown.
        return max(0.0, self.cooldown_until - time.time())

    def reset(self) -> None:
        """Reset provider state (clear cooldown and errors).

        Lifetime counters (``total_requests``, ``total_failures``) and
        ``last_success`` are intentionally kept.
        """
        self.is_available = True
        self.cooldown_until = None
        self.error_count = 0
        self.last_error = None
        logger.info(f"[ProviderState] {self.name} state reset")
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class ProviderStateTracker:
    """
    Tracks availability of all providers with thread-safe access.

    Provides fallback chain logic when primary providers are unavailable.

    Every public method serializes on ``self._lock``, a non-reentrant
    ``threading.Lock``. Code that already holds the lock must therefore use
    ``_get_or_create`` rather than ``get_provider``: acquiring the lock again
    from the same thread would block forever.
    """

    # Default fallback chains for each provider (tried in list order)
    DEFAULT_FALLBACK_CHAINS: dict[str, list[str]] = {
        "claude": ["openai", "gemini"],
        "openai": ["gemini", "claude"],
        "gemini": ["openai", "claude"],
    }

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.providers: dict[str, ProviderState] = {
            "claude": ProviderState("claude"),
            "openai": ProviderState("openai"),
            "gemini": ProviderState("gemini"),
        }
        # Copy the inner lists too: a plain dict.copy() would share them with the
        # class-level default, so in-place edits would leak across instances.
        self._fallback_chains: dict[str, list[str]] = {
            provider: list(chain) for provider, chain in self.DEFAULT_FALLBACK_CHAINS.items()
        }

    def _get_or_create(self, name: str) -> ProviderState:
        """Return the state for *name*, creating it on first use.

        The caller must already hold ``self._lock``.
        """
        state = self.providers.get(name)
        if state is None:
            state = ProviderState(name)
            self.providers[name] = state
        return state

    def get_provider(self, name: str) -> ProviderState:
        """Get the state for a specific provider, creating it if unknown."""
        with self._lock:
            return self._get_or_create(name)

    def mark_rate_limited(
        self, provider: str, duration: int = 300, reason: str = "429 rate limit"
    ) -> None:
        """Mark a provider as rate-limited for ``duration`` seconds."""
        with self._lock:
            self._get_or_create(provider).mark_rate_limited(duration, reason)

    def mark_success(self, provider: str) -> None:
        """Mark a successful request to a provider."""
        with self._lock:
            self._get_or_create(provider).mark_success()

    def mark_error(self, provider: str, error: str) -> None:
        """Mark an error for a provider."""
        with self._lock:
            self._get_or_create(provider).mark_error(error)

    def is_available(self, provider: str) -> bool:
        """Check if a provider is available (clears an expired cooldown)."""
        with self._lock:
            return self._get_or_create(provider).check_availability()

    def get_fallback_provider(self, preferred: str) -> str:
        """
        Get best available provider, falling back as needed.

        Args:
            preferred: The preferred provider to use

        Returns:
            The best available provider: ``preferred`` if available, otherwise
            the first available entry of its fallback chain, otherwise
            ``preferred`` anyway.
        """
        with self._lock:
            # Check preferred first
            if self._get_or_create(preferred).check_availability():
                return preferred

            # Try fallback chain
            for fallback in self._fallback_chains.get(preferred, []):
                if self._get_or_create(fallback).check_availability():
                    logger.info(
                        f"[ProviderStateTracker] Falling back from {preferred} to {fallback}"
                    )
                    # Notify user
                    print(
                        f"⚠️ {preferred.title()} unavailable → Routing to {fallback.title()}",
                        file=sys.stderr,
                    )
                    return fallback

            # All providers unavailable, return preferred anyway (will likely fail)
            logger.warning(
                f"[ProviderStateTracker] All providers unavailable. Using {preferred} anyway."
            )
            return preferred

    def get_status(self) -> dict[str, dict[str, Any]]:
        """Get status of all providers for dashboard/CLI."""
        with self._lock:
            status: dict[str, dict[str, Any]] = {}
            for name, state in self.providers.items():
                state.check_availability()  # Refresh is_available / clear expired cooldowns
                status[name] = {
                    "available": state.is_available,
                    "cooldown_remaining": state.get_cooldown_remaining(),
                    "error_count": state.error_count,
                    "last_success": state.last_success,
                    "last_error": state.last_error,
                    "total_requests": state.total_requests,
                    "total_failures": state.total_failures,
                }
            return status

    def reset_all(self) -> None:
        """Reset all provider states (cooldowns and error counts)."""
        with self._lock:
            for state in self.providers.values():
                state.reset()
            logger.info("[ProviderStateTracker] All provider states reset")

    def reset_provider(self, provider: str) -> None:
        """Reset a specific provider's state; unknown providers are ignored."""
        with self._lock:
            if provider in self.providers:
                self.providers[provider].reset()

    def set_fallback_chain(self, provider: str, chain: list[str]) -> None:
        """Set custom fallback chain for a provider.

        The list is copied so later mutation by the caller has no effect.
        """
        with self._lock:
            self._fallback_chains[provider] = list(chain)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
# Serializes first-time construction of the shared tracker (see get_provider_tracker).
_tracker_lock = threading.Lock()

# Lazily-created, process-wide tracker instance; None until first requested.
_provider_tracker: ProviderStateTracker | None = None
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def get_provider_tracker() -> ProviderStateTracker:
    """Return the shared ProviderStateTracker, building it on first use.

    The unlocked check is a fast path once the instance exists; the second
    check, made while holding ``_tracker_lock``, ensures only one thread
    ever constructs the tracker.
    """
    global _provider_tracker

    existing = _provider_tracker
    if existing is not None:
        return existing

    with _tracker_lock:
        if _provider_tracker is None:
            _provider_tracker = ProviderStateTracker()
            logger.info("[ProviderStateTracker] Created global instance")
        return _provider_tracker
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def reset_provider_tracker() -> None:
    """Clear cooldowns and error counts on the shared tracker, if one exists.

    The tracker instance itself is kept. Intended mainly for tests.
    """
    with _tracker_lock:
        tracker = _provider_tracker
        if tracker is None:
            return
        tracker.reset_all()
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task Classifier for Intelligent Routing.
|
|
3
|
+
|
|
4
|
+
Classifies incoming tasks to determine the optimal provider and model.
|
|
5
|
+
Uses pattern matching and heuristics to categorize tasks.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import re
|
|
12
|
+
from enum import Enum, auto
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TaskType(Enum):
    """Categories a prompt can be sorted into when choosing a provider/model.

    Member values are generated by ``auto()``.
    """

    # Writing brand-new code.
    CODE_GENERATION = auto()
    # Reshaping existing code without changing what it does.
    CODE_REFACTORING = auto()
    # Tracking down and fixing bugs or errors.
    DEBUGGING = auto()
    # System design, planning and trade-off discussions.
    ARCHITECTURE = auto()
    # Docs, comments and READMEs.
    DOCUMENTATION = auto()
    # Locating code or usage patterns.
    CODE_SEARCH = auto()
    # Security analysis.
    SECURITY_REVIEW = auto()
    # Catch-all when nothing more specific applies.
    GENERAL = auto()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Patterns for task classification.
#
# Each entry is a regex searched (case-insensitively) in the lower-cased prompt.
# Word separators written as ``[\s-]*`` accept the spaced, hyphenated and joined
# spellings alike ("high level", "high-level", "highlevel"), and ``['’]`` accepts
# both straight and typographic apostrophes.
TASK_PATTERNS: dict[TaskType, list[str]] = {
    TaskType.CODE_GENERATION: [
        r"\b(generate|create|implement|build|write|add|make|develop)\b.*\b(code|function|class|module|component|api|endpoint|feature)\b",
        r"\b(new|fresh)\b.*\b(implementation|feature|module)\b",
        r"\bimplement\b",
        r"\bcreate a\b",
    ],
    TaskType.CODE_REFACTORING: [
        r"\b(refactor|restructure|reorganize|clean[\s-]*up|simplify|optimize|improve)\b",
        r"\b(extract|inline|rename|move)\b.*\b(method|function|class|variable)\b",
        r"\bcode[\s-]*(cleanup|quality)\b",
        r"\breduce\s*(complexity|duplication)\b",
    ],
    TaskType.DEBUGGING: [
        r"\b(debug|fix|solve|resolve|troubleshoot|diagnose)\b",
        r"\b(bug|error|issue|problem|failing|broken|crash)\b",
        r"\b(not\s*working|doesn['’]t\s*work|won['’]t\s*work)\b",
        r"\b(exception|traceback|stack[\s-]*trace)\b",
        r"\bwhy\s*(is|does|doesn['’]t)\b.*\b(fail|error|crash)\b",
    ],
    TaskType.ARCHITECTURE: [
        # "architect" alone never matched "architecture" because of the trailing \b.
        r"\b(architect(?:ure|ural)?|design|structure|pattern|system)\b",
        r"\b(high[\s-]*level|overall|big[\s-]*picture)\b",
        r"\b(scalability|maintainability|extensibility)\b",
        r"\b(trade[\s-]*offs?|decision|approach|strategy)\b",
        r"\bhow\s*should\s*(we|i)\s*(design|structure|organize)\b",
    ],
    TaskType.DOCUMENTATION: [
        r"\b(document|readme|docstring|comment|explain|describe)\b",
        r"\b(api[\s-]*docs|documentation|jsdoc|pydoc)\b",
        r"\bwrite[\s-]*(up|docs|documentation)\b",
        r"\badd\s*comments?\b",
    ],
    TaskType.CODE_SEARCH: [
        r"\b(find|search|locate|where\s*is|look\s*for)\b.*\b(code|function|class|implementation)\b",
        r"\b(grep|ripgrep|search)\b",
        r"\bhow\s*is\b.*\b(implemented|used|called)\b",
        r"\bshow\s*me\b.*\b(code|implementation)\b",
    ],
    TaskType.SECURITY_REVIEW: [
        r"\b(security|vulnerabilit(?:y|ies)|exploit|attack|injection)\b",
        r"\b(auth|authentication|authorization|permission)\b.*\b(check|review|audit)\b",
        r"\b(secure|harden|protect)\b",
        r"\b(xss|csrf|sql[\s-]*injection|rce)\b",
    ],
}
|
|
78
|
+
|
|
79
|
+
# Default (provider, model) pair for each task type. A model of None leaves the
# model choice open (classify_and_route logs it as "default").
DEFAULT_TASK_ROUTING: dict[TaskType, tuple[str, str | None]] = {
    # Hands-on code work -> Codex model.
    TaskType.CODE_GENERATION: ("openai", "gpt-5-codex"),
    TaskType.CODE_REFACTORING: ("openai", "gpt-5-codex"),
    TaskType.DEBUGGING: ("openai", "gpt-5-codex"),
    # Design questions -> Delphi-style model.
    TaskType.ARCHITECTURE: ("openai", "gpt-5.2-medium"),
    # Reading/writing prose and searching code -> Gemini Flash.
    TaskType.DOCUMENTATION: ("gemini", "gemini-3-flash"),
    TaskType.CODE_SEARCH: ("gemini", "gemini-3-flash"),
    # Security review and anything unclassified stay with Claude.
    TaskType.SECURITY_REVIEW: ("claude", None),
    TaskType.GENERAL: ("claude", None),
}
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def classify_task(prompt: str, context: dict[str, Any] | None = None) -> TaskType:
    """
    Decide which TaskType a prompt belongs to.

    Task types are tried from most to least specific, and the first
    TASK_PATTERNS regex that matches decides the result. If no pattern
    matches, hints in ``context`` are consulted. Otherwise the result is
    TaskType.GENERAL.

    Args:
        prompt: The user's prompt or request (empty/None -> GENERAL)
        context: Optional extra signals; the keys read are "error",
            "exception", "existing_code" and "create_new"

    Returns:
        The chosen TaskType
    """
    if not prompt:
        return TaskType.GENERAL

    text = prompt.lower()

    # Earlier entries win: specific intents must beat the broad CODE_GENERATION
    # patterns (e.g. a bare "implement").
    ordered_types = (
        TaskType.DEBUGGING,
        TaskType.SECURITY_REVIEW,
        TaskType.CODE_REFACTORING,
        TaskType.ARCHITECTURE,
        TaskType.DOCUMENTATION,
        TaskType.CODE_SEARCH,
        TaskType.CODE_GENERATION,
    )

    for candidate in ordered_types:
        hit = next(
            (
                pattern
                for pattern in TASK_PATTERNS.get(candidate, [])
                if re.search(pattern, text, re.IGNORECASE)
            ),
            None,
        )
        if hit is not None:
            logger.debug(f"[TaskClassifier] Matched {candidate.name} with pattern: {hit}")
            return candidate

    # No textual match: fall back to structured hints, if any were given.
    if not context:
        return TaskType.GENERAL
    if context.get("error") or context.get("exception"):
        return TaskType.DEBUGGING
    if context.get("existing_code") and not context.get("create_new"):
        return TaskType.CODE_REFACTORING
    return TaskType.GENERAL
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def get_routing_for_task(
    task_type: TaskType,
    config: dict[str, Any] | None = None,
) -> tuple[str, str | None]:
    """
    Look up the (provider, model) pair to use for a task type.

    A truthy ``config`` containing the lower-cased task name (e.g.
    "debugging") overrides the defaults. Within that rule, "provider"
    defaults to "claude" and "model" to None.

    Args:
        task_type: The classified task type
        config: Optional mapping of lower-cased task names to rule dicts

    Returns:
        (provider, model) tuple; model may be None
    """
    if config and (name := task_type.name.lower()) in config:
        chosen = config[name]
        return (chosen.get("provider", "claude"), chosen.get("model"))

    return DEFAULT_TASK_ROUTING.get(task_type, ("claude", None))
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def classify_and_route(
    prompt: str,
    context: dict[str, Any] | None = None,
    config: dict[str, Any] | None = None,
) -> tuple[TaskType, str, str | None]:
    """
    Classify a prompt and resolve its routing in a single step.

    Delegates to classify_task() and then get_routing_for_task(), and logs
    the decision at INFO level.

    Args:
        prompt: The user's prompt
        context: Optional context dict passed to classify_task
        config: Optional routing override passed to get_routing_for_task

    Returns:
        (task_type, provider, model)
    """
    kind = classify_task(prompt, context)
    target_provider, target_model = get_routing_for_task(kind, config)

    logger.info(
        f"[TaskClassifier] Classified as {kind.name} → {target_provider}/{target_model or 'default'}"
    )

    return kind, target_provider, target_model
|