router-maestro 0.1.2__py3-none-any.whl
This diff shows the contents of a publicly released package version as it appears in its public registry. It is provided for informational purposes only and reflects the changes between published versions.
- router_maestro/__init__.py +3 -0
- router_maestro/__main__.py +6 -0
- router_maestro/auth/__init__.py +18 -0
- router_maestro/auth/github_oauth.py +181 -0
- router_maestro/auth/manager.py +136 -0
- router_maestro/auth/storage.py +91 -0
- router_maestro/cli/__init__.py +1 -0
- router_maestro/cli/auth.py +167 -0
- router_maestro/cli/client.py +322 -0
- router_maestro/cli/config.py +132 -0
- router_maestro/cli/context.py +146 -0
- router_maestro/cli/main.py +42 -0
- router_maestro/cli/model.py +288 -0
- router_maestro/cli/server.py +117 -0
- router_maestro/cli/stats.py +76 -0
- router_maestro/config/__init__.py +72 -0
- router_maestro/config/contexts.py +29 -0
- router_maestro/config/paths.py +50 -0
- router_maestro/config/priorities.py +93 -0
- router_maestro/config/providers.py +34 -0
- router_maestro/config/server.py +115 -0
- router_maestro/config/settings.py +76 -0
- router_maestro/providers/__init__.py +31 -0
- router_maestro/providers/anthropic.py +203 -0
- router_maestro/providers/base.py +123 -0
- router_maestro/providers/copilot.py +346 -0
- router_maestro/providers/openai.py +188 -0
- router_maestro/providers/openai_compat.py +175 -0
- router_maestro/routing/__init__.py +5 -0
- router_maestro/routing/router.py +526 -0
- router_maestro/server/__init__.py +5 -0
- router_maestro/server/app.py +87 -0
- router_maestro/server/middleware/__init__.py +11 -0
- router_maestro/server/middleware/auth.py +66 -0
- router_maestro/server/oauth_sessions.py +159 -0
- router_maestro/server/routes/__init__.py +8 -0
- router_maestro/server/routes/admin.py +358 -0
- router_maestro/server/routes/anthropic.py +228 -0
- router_maestro/server/routes/chat.py +142 -0
- router_maestro/server/routes/models.py +34 -0
- router_maestro/server/schemas/__init__.py +57 -0
- router_maestro/server/schemas/admin.py +87 -0
- router_maestro/server/schemas/anthropic.py +246 -0
- router_maestro/server/schemas/openai.py +107 -0
- router_maestro/server/translation.py +636 -0
- router_maestro/stats/__init__.py +14 -0
- router_maestro/stats/heatmap.py +154 -0
- router_maestro/stats/storage.py +228 -0
- router_maestro/stats/tracker.py +73 -0
- router_maestro/utils/__init__.py +16 -0
- router_maestro/utils/logging.py +81 -0
- router_maestro/utils/tokens.py +51 -0
- router_maestro-0.1.2.dist-info/METADATA +383 -0
- router_maestro-0.1.2.dist-info/RECORD +57 -0
- router_maestro-0.1.2.dist-info/WHEEL +4 -0
- router_maestro-0.1.2.dist-info/entry_points.txt +2 -0
- router_maestro-0.1.2.dist-info/licenses/LICENSE +21 -0
router_maestro/config/priorities.py
@@ -0,0 +1,93 @@
+"""Model priority configuration."""
+
+from enum import Enum
+
+from pydantic import BaseModel, Field
+
+
+class FallbackStrategy(str, Enum):
+    """Fallback strategy options."""
+
+    PRIORITY = "priority"  # Fallback to next model in priorities list
+    SAME_MODEL = "same-model"  # Only fallback to providers with the same model
+    NONE = "none"  # Disable fallback, fail immediately
+
+
+class FallbackConfig(BaseModel):
+    """Fallback configuration."""
+
+    strategy: FallbackStrategy = Field(
+        default=FallbackStrategy.PRIORITY,
+        description="Fallback strategy",
+    )
+    maxRetries: int = Field(  # noqa: N815
+        default=2,
+        ge=0,
+        le=10,
+        description="Maximum number of fallback retries",
+    )
+
+
+class PrioritiesConfig(BaseModel):
+    """Configuration for model priorities and fallback."""
+
+    priorities: list[str] = Field(
+        default_factory=list,
+        description="Model priorities in format 'provider/model', highest to lowest",
+    )
+    fallback: FallbackConfig = Field(default_factory=FallbackConfig)
+
+    @classmethod
+    def get_default(cls) -> "PrioritiesConfig":
+        """Get default empty priorities configuration."""
+        return cls(priorities=[])
+
+    def get_priority(self, provider: str, model: str) -> int:
+        """Get priority for a model.
+
+        Args:
+            provider: Provider name
+            model: Model ID
+
+        Returns:
+            Priority index (lower = higher priority), or 999999 if not in list
+        """
+        key = f"{provider}/{model}"
+        try:
+            return self.priorities.index(key)
+        except ValueError:
+            return 999999
+
+    def add_priority(self, provider: str, model: str, position: int | None = None) -> None:
+        """Add a model to priorities.
+
+        Args:
+            provider: Provider name
+            model: Model ID
+            position: Position to insert (None = append to end)
+        """
+        key = f"{provider}/{model}"
+        # Remove if already exists
+        if key in self.priorities:
+            self.priorities.remove(key)
+        # Insert at position
+        if position is None:
+            self.priorities.append(key)
+        else:
+            self.priorities.insert(position, key)
+
+    def remove_priority(self, provider: str, model: str) -> bool:
+        """Remove a model from priorities.
+
+        Args:
+            provider: Provider name
+            model: Model ID
+
+        Returns:
+            True if removed, False if not found
+        """
+        key = f"{provider}/{model}"
+        if key in self.priorities:
+            self.priorities.remove(key)
+            return True
+        return False
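To make the ordering semantics concrete, here is a minimal usage sketch of `PrioritiesConfig`; the provider/model names are hypothetical, chosen only for illustration:

```python
from router_maestro.config.priorities import FallbackStrategy, PrioritiesConfig

cfg = PrioritiesConfig.get_default()
cfg.add_priority("copilot", "gpt-4o")                       # appended -> index 0
cfg.add_priority("anthropic", "claude-3-5-haiku-20241022")  # appended -> index 1
# Re-adding an existing key removes it first, so this moves it to the front.
cfg.add_priority("anthropic", "claude-3-5-haiku-20241022", position=0)

assert cfg.get_priority("anthropic", "claude-3-5-haiku-20241022") == 0
assert cfg.get_priority("copilot", "gpt-4o") == 1
assert cfg.get_priority("openai", "unlisted-model") == 999999  # sentinel: not listed

cfg.fallback.strategy = FallbackStrategy.SAME_MODEL  # restrict fallback to the same model
```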
router_maestro/config/providers.py
@@ -0,0 +1,34 @@
+"""Provider and model configuration models."""
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+
+class ModelConfig(BaseModel):
+    """Configuration for a single model."""
+
+    name: str = Field(default="", description="Display name for the model")
+
+
+class CustomProviderConfig(BaseModel):
+    """Configuration for a custom (OpenAI-compatible) provider."""
+
+    type: str = Field(default="openai-compatible", description="Provider type")
+    baseURL: str = Field(..., description="Base URL for API requests")  # noqa: N815
+    models: dict[str, ModelConfig] = Field(default_factory=dict, description="Model configurations")
+    options: dict[str, Any] = Field(default_factory=dict, description="Additional provider options")
+
+
+class ProvidersConfig(BaseModel):
+    """Root configuration for custom providers only."""
+
+    providers: dict[str, CustomProviderConfig] = Field(
+        default_factory=dict,
+        description="Custom provider configurations (not including built-in providers)",
+    )
+
+    @classmethod
+    def get_default(cls) -> "ProvidersConfig":
+        """Get default empty configuration."""
+        return cls(providers={})
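For reference, a sketch of the JSON shape `ProvidersConfig` validates; the `ollama` entry, endpoint URL, and model ID below are invented for illustration:

```python
from router_maestro.config.providers import ProvidersConfig

data = {
    "providers": {
        "ollama": {  # hypothetical custom provider name
            "type": "openai-compatible",
            "baseURL": "http://localhost:11434/v1",  # hypothetical endpoint
            "models": {"llama3": {"name": "Llama 3"}},
        }
    }
}
cfg = ProvidersConfig.model_validate(data)
assert cfg.providers["ollama"].models["llama3"].name == "Llama 3"
assert cfg.providers["ollama"].options == {}  # defaults applied
```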
router_maestro/config/server.py
@@ -0,0 +1,115 @@
+"""Server configuration management.
+
+API keys are stored in contexts.json under context.
+This module provides utilities to manage API keys.
+"""
+
+import secrets
+
+from router_maestro.config.contexts import ContextConfig
+from router_maestro.config.settings import load_contexts_config, save_contexts_config
+
+
+def generate_api_key() -> str:
+    """Generate a random API key."""
+    return f"sk-rm-{secrets.token_urlsafe(32)}"
+
+
+def get_local_api_key() -> str | None:
+    """Get API key for local context.
+
+    Returns:
+        The API key if configured, None otherwise.
+    """
+    config = load_contexts_config()
+    local_ctx = config.contexts.get("local")
+    if local_ctx:
+        return local_ctx.api_key
+    return None
+
+
+def get_current_context_api_key() -> str | None:
+    """Get API key for current context.
+
+    Returns:
+        The API key if configured, None otherwise.
+    """
+    config = load_contexts_config()
+    ctx_name = config.current
+    ctx = config.contexts.get(ctx_name)
+    if ctx:
+        return ctx.api_key
+    return None
+
+
+def set_local_api_key(api_key: str) -> None:
+    """Set API key for local context.
+
+    Args:
+        api_key: The API key to set.
+    """
+    config = load_contexts_config()
+
+    # Ensure local context exists
+    if "local" not in config.contexts:
+        config.contexts["local"] = ContextConfig(endpoint="http://localhost:8080")
+
+    config.contexts["local"].api_key = api_key
+    save_contexts_config(config)
+
+
+def get_or_create_api_key(api_key: str | None = None) -> tuple[str, bool]:
+    """Get or create an API key for local server.
+
+    Priority order:
+    1. Provided api_key argument (from CLI --api-key)
+    2. ROUTER_MAESTRO_API_KEY environment variable
+    3. Existing key in contexts.json
+    4. Generate new key
+
+    Args:
+        api_key: Optional API key to use.
+
+    Returns:
+        Tuple of (api_key, was_generated)
+    """
+    import os
+
+    if api_key:
+        # User provided API key via CLI, save it to local context
+        set_local_api_key(api_key)
+        return api_key, False
+
+    # Check environment variable
+    env_key = os.environ.get("ROUTER_MAESTRO_API_KEY")
+    if env_key:
+        # Save to local context for persistence
+        set_local_api_key(env_key)
+        return env_key, False
+
+    # Try to load from local context
+    existing_key = get_local_api_key()
+    if existing_key:
+        return existing_key, False
+
+    # Generate new key and save to local context
+    new_key = generate_api_key()
+    set_local_api_key(new_key)
+    return new_key, True
+
+
+# Legacy compatibility - ServerConfig is no longer used but kept for reference
+class ServerConfig:
+    """Deprecated: Server configuration is now stored in contexts.json."""
+
+    def __init__(self, api_key: str = "") -> None:
+        self.api_key = api_key
+
+
+def load_server_config() -> ServerConfig:
+    """Load server configuration (for backward compatibility).
+
+    Now reads from contexts.json local context.
+    """
+    api_key = get_local_api_key() or ""
+    return ServerConfig(api_key=api_key)
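A short sketch of how the resolution order plays out in practice; the key value is hypothetical, and note that every resolved key is persisted to contexts.json:

```python
import os

from router_maestro.config.server import get_or_create_api_key

# An explicit argument wins over everything and is saved to the local context.
key, generated = get_or_create_api_key("sk-rm-example")  # hypothetical key
assert (key, generated) == ("sk-rm-example", False)

# With no argument and no ROUTER_MAESTRO_API_KEY in the environment,
# the previously stored key is reused; only if none exists is one generated.
os.environ.pop("ROUTER_MAESTRO_API_KEY", None)
key, generated = get_or_create_api_key()
assert key == "sk-rm-example" and generated is False
```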
router_maestro/config/settings.py
@@ -0,0 +1,76 @@
+"""Global settings and configuration management."""
+
+import json
+from pathlib import Path
+from typing import TypeVar
+
+from pydantic import BaseModel
+
+from router_maestro.config.contexts import ContextsConfig
+from router_maestro.config.paths import CONTEXTS_FILE, PRIORITIES_FILE, PROVIDERS_FILE
+from router_maestro.config.priorities import PrioritiesConfig
+from router_maestro.config.providers import ProvidersConfig
+
+T = TypeVar("T", bound=BaseModel)
+
+
+def load_config(path: Path, model: type[T], default_factory: callable) -> T:
+    """Load configuration from JSON file.
+
+    Args:
+        path: Path to configuration file
+        model: Pydantic model class to parse into
+        default_factory: Function to create default configuration
+
+    Returns:
+        Parsed configuration object
+    """
+    if not path.exists():
+        config = default_factory()
+        save_config(path, config)
+        return config
+    with open(path, encoding="utf-8") as f:
+        data = json.load(f)
+    return model.model_validate(data)
+
+
+def save_config(path: Path, config: BaseModel) -> None:
+    """Save configuration to JSON file.
+
+    Args:
+        path: Path to configuration file
+        config: Configuration object to save
+    """
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with open(path, "w", encoding="utf-8") as f:
+        json.dump(config.model_dump(mode="json"), f, indent=2, ensure_ascii=False)
+
+
+def load_providers_config() -> ProvidersConfig:
+    """Load providers configuration."""
+    return load_config(PROVIDERS_FILE, ProvidersConfig, ProvidersConfig.get_default)
+
+
+def save_providers_config(config: ProvidersConfig) -> None:
+    """Save providers configuration."""
+    save_config(PROVIDERS_FILE, config)
+
+
+def load_priorities_config() -> PrioritiesConfig:
+    """Load priorities configuration."""
+    return load_config(PRIORITIES_FILE, PrioritiesConfig, PrioritiesConfig.get_default)
+
+
+def save_priorities_config(config: PrioritiesConfig) -> None:
+    """Save priorities configuration."""
+    save_config(PRIORITIES_FILE, config)
+
+
+def load_contexts_config() -> ContextsConfig:
+    """Load contexts configuration."""
+    return load_config(CONTEXTS_FILE, ContextsConfig, ContextsConfig.get_default)
+
+
+def save_contexts_config(config: ContextsConfig) -> None:
+    """Save contexts configuration."""
+    save_config(CONTEXTS_FILE, config)
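These helpers give each config file a create-on-first-read round trip; a minimal sketch (the priority entry is hypothetical):

```python
from router_maestro.config.settings import load_priorities_config, save_priorities_config

# First call writes PRIORITIES_FILE with defaults if it does not exist yet.
cfg = load_priorities_config()
cfg.add_priority("copilot", "gpt-4o")  # hypothetical entry
save_priorities_config(cfg)

# Subsequent loads parse the JSON back into the Pydantic model.
assert "copilot/gpt-4o" in load_priorities_config().priorities
```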
router_maestro/providers/__init__.py
@@ -0,0 +1,31 @@
+"""Providers module for router-maestro."""
+
+from router_maestro.providers.anthropic import AnthropicProvider
+from router_maestro.providers.base import (
+    BaseProvider,
+    ChatRequest,
+    ChatResponse,
+    ChatStreamChunk,
+    Message,
+    ModelInfo,
+    ProviderError,
+)
+from router_maestro.providers.copilot import CopilotProvider
+from router_maestro.providers.openai import OpenAIProvider
+from router_maestro.providers.openai_compat import OpenAICompatibleProvider
+
+__all__ = [
+    # Base classes
+    "BaseProvider",
+    "ProviderError",
+    "Message",
+    "ChatRequest",
+    "ChatResponse",
+    "ChatStreamChunk",
+    "ModelInfo",
+    # Providers
+    "CopilotProvider",
+    "OpenAIProvider",
+    "AnthropicProvider",
+    "OpenAICompatibleProvider",
+]
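These re-exports let callers import the provider classes and shared data types from the package root rather than the submodules:

```python
from router_maestro.providers import AnthropicProvider, ChatRequest, Message, ProviderError
```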
router_maestro/providers/anthropic.py
@@ -0,0 +1,203 @@
+"""Anthropic provider implementation."""
+
+from collections.abc import AsyncIterator
+
+import httpx
+
+from router_maestro.auth import AuthManager, AuthType
+from router_maestro.providers.base import (
+    BaseProvider,
+    ChatRequest,
+    ChatResponse,
+    ChatStreamChunk,
+    ModelInfo,
+    ProviderError,
+)
+from router_maestro.utils import get_logger
+
+logger = get_logger("providers.anthropic")
+
+ANTHROPIC_API_URL = "https://api.anthropic.com/v1"
+
+
+class AnthropicProvider(BaseProvider):
+    """Anthropic Claude provider."""
+
+    name = "anthropic"
+
+    def __init__(self, base_url: str = ANTHROPIC_API_URL) -> None:
+        self.base_url = base_url.rstrip("/")
+        self.auth_manager = AuthManager()
+
+    def is_authenticated(self) -> bool:
+        """Check if authenticated with Anthropic."""
+        cred = self.auth_manager.get_credential("anthropic")
+        return cred is not None and cred.type == AuthType.API_KEY
+
+    def _get_api_key(self) -> str:
+        """Get the API key."""
+        cred = self.auth_manager.get_credential("anthropic")
+        if not cred or cred.type != AuthType.API_KEY:
+            logger.error("Not authenticated with Anthropic")
+            raise ProviderError("Not authenticated with Anthropic", status_code=401)
+        return cred.key
+
+    def _get_headers(self) -> dict[str, str]:
+        """Get headers for Anthropic API requests."""
+        return {
+            "x-api-key": self._get_api_key(),
+            "Content-Type": "application/json",
+            "anthropic-version": "2023-06-01",
+        }
+
+    def _convert_messages(self, messages: list) -> tuple[str | None, list[dict]]:
+        """Convert OpenAI-style messages to Anthropic format.
+
+        Returns:
+            Tuple of (system_prompt, messages)
+        """
+        system_prompt = None
+        converted = []
+
+        for msg in messages:
+            if msg.role == "system":
+                system_prompt = msg.content
+            else:
+                converted.append({"role": msg.role, "content": msg.content})
+
+        return system_prompt, converted
+
+    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
+        """Generate a chat completion via Anthropic."""
+        system_prompt, messages = self._convert_messages(request.messages)
+
+        payload = {
+            "model": request.model,
+            "messages": messages,
+            "max_tokens": request.max_tokens or 4096,
+        }
+        if system_prompt:
+            payload["system"] = system_prompt
+        if request.temperature != 1.0:
+            payload["temperature"] = request.temperature
+
+        logger.debug("Anthropic chat completion: model=%s", request.model)
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/messages",
+                    json=payload,
+                    headers=self._get_headers(),
+                    timeout=120.0,
+                )
+                response.raise_for_status()
+                data = response.json()
+
+                # Extract content from Anthropic response
+                content = ""
+                for block in data.get("content", []):
+                    if block.get("type") == "text":
+                        content += block.get("text", "")
+
+                logger.debug("Anthropic chat completion successful")
+                return ChatResponse(
+                    content=content,
+                    model=data.get("model", request.model),
+                    finish_reason=data.get("stop_reason", "stop"),
+                    usage={
+                        "prompt_tokens": data.get("usage", {}).get("input_tokens", 0),
+                        "completion_tokens": data.get("usage", {}).get("output_tokens", 0),
+                        "total_tokens": (
+                            data.get("usage", {}).get("input_tokens", 0)
+                            + data.get("usage", {}).get("output_tokens", 0)
+                        ),
+                    },
+                )
+            except httpx.HTTPStatusError as e:
+                retryable = e.response.status_code in (429, 500, 502, 503, 504, 529)
+                logger.error("Anthropic API error: %d", e.response.status_code)
+                raise ProviderError(
+                    f"Anthropic API error: {e.response.status_code}",
+                    status_code=e.response.status_code,
+                    retryable=retryable,
+                )
+            except httpx.HTTPError as e:
+                logger.error("Anthropic HTTP error: %s", e)
+                raise ProviderError(f"HTTP error: {e}", retryable=True)
+
+    async def chat_completion_stream(self, request: ChatRequest) -> AsyncIterator[ChatStreamChunk]:
+        """Generate a streaming chat completion via Anthropic."""
+        system_prompt, messages = self._convert_messages(request.messages)
+
+        payload = {
+            "model": request.model,
+            "messages": messages,
+            "max_tokens": request.max_tokens or 4096,
+            "stream": True,
+        }
+        if system_prompt:
+            payload["system"] = system_prompt
+        if request.temperature != 1.0:
+            payload["temperature"] = request.temperature
+
+        logger.debug("Anthropic streaming chat: model=%s", request.model)
+        async with httpx.AsyncClient() as client:
+            try:
+                async with client.stream(
+                    "POST",
+                    f"{self.base_url}/messages",
+                    json=payload,
+                    headers=self._get_headers(),
+                    timeout=120.0,
+                ) as response:
+                    response.raise_for_status()
+
+                    async for line in response.aiter_lines():
+                        if not line or not line.startswith("data: "):
+                            continue
+
+                        data_str = line[6:]
+                        if not data_str:
+                            continue
+
+                        import json
+
+                        data = json.loads(data_str)
+                        event_type = data.get("type")
+
+                        if event_type == "content_block_delta":
+                            delta = data.get("delta", {})
+                            if delta.get("type") == "text_delta":
+                                yield ChatStreamChunk(
+                                    content=delta.get("text", ""),
+                                    finish_reason=None,
+                                )
+                        elif event_type == "message_stop":
+                            yield ChatStreamChunk(
+                                content="",
+                                finish_reason="stop",
+                            )
+            except httpx.HTTPStatusError as e:
+                retryable = e.response.status_code in (429, 500, 502, 503, 504, 529)
+                logger.error("Anthropic stream API error: %d", e.response.status_code)
+                raise ProviderError(
+                    f"Anthropic API error: {e.response.status_code}",
+                    status_code=e.response.status_code,
+                    retryable=retryable,
+                )
+            except httpx.HTTPError as e:
+                logger.error("Anthropic stream HTTP error: %s", e)
+                raise ProviderError(f"HTTP error: {e}", retryable=True)
+
+    async def list_models(self) -> list[ModelInfo]:
+        """List available Anthropic models."""
+        # Anthropic doesn't have a models endpoint, return known models
+        logger.debug("Returning known Anthropic models")
+        return [
+            ModelInfo(id="claude-sonnet-4-20250514", name="Claude Sonnet 4", provider=self.name),
+            ModelInfo(
+                id="claude-3-5-sonnet-20241022", name="Claude 3.5 Sonnet", provider=self.name
+            ),
+            ModelInfo(id="claude-3-5-haiku-20241022", name="Claude 3.5 Haiku", provider=self.name),
+            ModelInfo(id="claude-3-opus-20240229", name="Claude 3 Opus", provider=self.name),
+        ]
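A minimal end-to-end sketch of calling the provider, assuming an Anthropic API key has already been stored via the auth flow; the prompt and token limit are illustrative:

```python
import asyncio

from router_maestro.providers import AnthropicProvider, ChatRequest, Message


async def main() -> None:
    provider = AnthropicProvider()
    if not provider.is_authenticated():
        raise SystemExit("No Anthropic API key stored; authenticate first.")

    request = ChatRequest(
        model="claude-3-5-haiku-20241022",
        messages=[
            Message(role="system", content="Be terse."),  # extracted into payload["system"]
            Message(role="user", content="Say hello."),
        ],
        max_tokens=32,
    )
    # Non-streaming path; chat_completion_stream yields ChatStreamChunk objects instead.
    response = await provider.chat_completion(request)
    print(response.content, response.usage)


asyncio.run(main())
```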
router_maestro/providers/base.py
@@ -0,0 +1,123 @@
+"""Base provider interface."""
+
+from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator
+from dataclasses import dataclass, field
+
+
+@dataclass
+class Message:
+    """A message in the conversation."""
+
+    role: str  # "system", "user", "assistant", "tool"
+    content: str | list  # Can be str or list for multimodal content (images)
+    tool_call_id: str | None = None  # Required for tool role messages
+    tool_calls: list[dict] | None = None  # For assistant messages with tool calls
+
+
+@dataclass
+class ChatRequest:
+    """Request for chat completion."""
+
+    model: str
+    messages: list[Message]
+    temperature: float = 1.0
+    max_tokens: int | None = None
+    stream: bool = False
+    tools: list[dict] | None = None  # OpenAI format tool definitions
+    # "auto", "none", "required", or {"type": "function", "function": {"name": "..."}}
+    tool_choice: str | dict | None = None
+    extra: dict = field(default_factory=dict)
+
+
+@dataclass
+class ChatResponse:
+    """Response from chat completion."""
+
+    content: str
+    model: str
+    finish_reason: str = "stop"
+    usage: dict | None = None  # {"prompt_tokens": X, "completion_tokens": Y, "total_tokens": Z}
+
+
+@dataclass
+class ChatStreamChunk:
+    """A chunk from streaming chat completion."""
+
+    content: str
+    finish_reason: str | None = None
+    usage: dict | None = None  # Token usage info (typically in final chunk)
+    tool_calls: list[dict] | None = None  # Tool call deltas for streaming
+
+
+@dataclass
+class ModelInfo:
+    """Information about an available model."""
+
+    id: str
+    name: str
+    provider: str
+
+
+class ProviderError(Exception):
+    """Error from a provider."""
+
+    def __init__(self, message: str, status_code: int = 500, retryable: bool = False):
+        super().__init__(message)
+        self.status_code = status_code
+        self.retryable = retryable
+
+
+class BaseProvider(ABC):
+    """Abstract base class for model providers."""
+
+    name: str = "base"
+
+    @abstractmethod
+    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
+        """Generate a chat completion.
+
+        Args:
+            request: Chat completion request
+
+        Returns:
+            Chat completion response
+        """
+        pass
+
+    @abstractmethod
+    async def chat_completion_stream(self, request: ChatRequest) -> AsyncIterator[ChatStreamChunk]:
+        """Generate a streaming chat completion.
+
+        Args:
+            request: Chat completion request
+
+        Yields:
+            Chat completion chunks
+        """
+        pass
+
+    @abstractmethod
+    async def list_models(self) -> list[ModelInfo]:
+        """List available models.
+
+        Returns:
+            List of available models
+        """
+        pass
+
+    @abstractmethod
+    def is_authenticated(self) -> bool:
+        """Check if the provider is authenticated.
+
+        Returns:
+            True if authenticated
+        """
+        pass
+
+    async def ensure_token(self) -> None:
+        """Ensure the provider has a valid token.
+
+        Override this for providers that need token refresh.
+        """
+        pass
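To show what the abstract interface requires, here is a toy subclass that satisfies all four abstract methods (illustration only, not part of the package):

```python
from collections.abc import AsyncIterator

from router_maestro.providers.base import (
    BaseProvider,
    ChatRequest,
    ChatResponse,
    ChatStreamChunk,
    ModelInfo,
)


class EchoProvider(BaseProvider):
    """Hypothetical provider that echoes the last message back."""

    name = "echo"

    def is_authenticated(self) -> bool:
        return True  # no credentials needed

    async def list_models(self) -> list[ModelInfo]:
        return [ModelInfo(id="echo-1", name="Echo", provider=self.name)]

    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
        last = request.messages[-1].content if request.messages else ""
        return ChatResponse(content=str(last), model=request.model)

    async def chat_completion_stream(
        self, request: ChatRequest
    ) -> AsyncIterator[ChatStreamChunk]:
        # Degenerate stream: one chunk carrying the whole response.
        response = await self.chat_completion(request)
        yield ChatStreamChunk(content=response.content, finish_reason="stop")
```

`ensure_token` is intentionally non-abstract, so providers with static API keys (like this one) can ignore it while OAuth-based providers override it to refresh tokens.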