arbiter_cli-0.1.0-py3-none-any.whl
This diff shows the contents of a publicly released package version as it appears in its public registry, and is provided for informational purposes only.
- arbiter/__init__.py +3 -0
- arbiter/cli/__init__.py +0 -0
- arbiter/cli/app.py +699 -0
- arbiter/cli/display.py +381 -0
- arbiter/core/__init__.py +0 -0
- arbiter/core/benchmarks.py +804 -0
- arbiter/core/config.py +137 -0
- arbiter/core/discover.py +184 -0
- arbiter/core/judge.py +193 -0
- arbiter/core/leaderboard.py +197 -0
- arbiter/core/metrics.py +367 -0
- arbiter/core/providers/__init__.py +19 -0
- arbiter/core/providers/anthropic_provider.py +133 -0
- arbiter/core/providers/base.py +62 -0
- arbiter/core/providers/factory.py +79 -0
- arbiter/core/providers/google_provider.py +126 -0
- arbiter/core/providers/ollama.py +103 -0
- arbiter/core/providers/openai_provider.py +120 -0
- arbiter/core/runner.py +257 -0
- arbiter/core/swe/__init__.py +1 -0
- arbiter/core/swe/container.py +158 -0
- arbiter/core/swe/runner.py +220 -0
- arbiter/core/swe/sandbox.py +111 -0
- arbiter/core/swe/test_packs.py +548 -0
- arbiter/dashboard/__init__.py +0 -0
- arbiter/dashboard/frontend/dist/assets/index-1tkxJouQ.css +1 -0
- arbiter/dashboard/frontend/dist/assets/index-dHa4zmvw.js +298 -0
- arbiter/dashboard/frontend/dist/index.html +16 -0
- arbiter/dashboard/server.py +426 -0
- arbiter_cli-0.1.0.dist-info/METADATA +299 -0
- arbiter_cli-0.1.0.dist-info/RECORD +35 -0
- arbiter_cli-0.1.0.dist-info/WHEEL +5 -0
- arbiter_cli-0.1.0.dist-info/entry_points.txt +2 -0
- arbiter_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- arbiter_cli-0.1.0.dist-info/top_level.txt +1 -0

arbiter/core/providers/anthropic_provider.py
@@ -0,0 +1,133 @@
+"""Anthropic provider - Claude models via the Anthropic API."""
+
+from __future__ import annotations
+
+import json
+from typing import AsyncIterator, Optional
+
+import httpx
+
+from arbiter.core.providers.base import LLMProvider, StreamChunk
+
+
+class AnthropicProvider(LLMProvider):
+    """Provider for Anthropic Claude models."""
+
+    provider_name = "anthropic"
+
+    def __init__(self, api_key: str):
+        self.api_key = api_key
+        self.base_url = "https://api.anthropic.com/v1"
+
+    async def stream_generate(
+        self,
+        model: str,
+        prompt: str,
+        system: Optional[str] = None,
+        image_path: Optional[str] = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream tokens from the Anthropic API."""
+        messages = []
+
+        if image_path:
+            import base64
+            from pathlib import Path
+
+            img_data = base64.b64encode(Path(image_path).read_bytes()).decode()
+            messages.append(
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": "image/jpeg",
+                                "data": img_data,
+                            },
+                        },
+                        {"type": "text", "text": prompt},
+                    ],
+                }
+            )
+        else:
+            messages.append({"role": "user", "content": prompt})
+
+        payload: dict = {
+            "model": model,
+            "messages": messages,
+            "max_tokens": 4096,
+            "stream": True,
+        }
+        if system:
+            payload["system"] = system
+
+        headers = {
+            "x-api-key": self.api_key,
+            "anthropic-version": "2023-06-01",
+            "Content-Type": "application/json",
+        }
+
+        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
+            async with client.stream(
+                "POST",
+                f"{self.base_url}/messages",
+                json=payload,
+                headers=headers,
+            ) as response:
+                response.raise_for_status()
+                async for line in response.aiter_lines():
+                    line = line.strip()
+                    if not line or not line.startswith("data: "):
+                        continue
+                    data_str = line[6:]
+                    if not data_str:
+                        continue
+                    try:
+                        data = json.loads(data_str)
+                    except json.JSONDecodeError:
+                        continue
+
+                    event_type = data.get("type", "")
+
+                    if event_type == "content_block_delta":
+                        delta = data.get("delta", {})
+                        text = delta.get("text", "")
+                        yield StreamChunk(text=text, done=False)
+
+                    elif event_type == "message_delta":
+                        usage = data.get("usage", {})
+                        yield StreamChunk(
+                            text="",
+                            done=True,
+                            meta={
+                                "output_tokens": usage.get("output_tokens"),
+                                "stop_reason": data.get("delta", {}).get(
+                                    "stop_reason"
+                                ),
+                            },
+                        )
+
+                    elif event_type == "message_start":
+                        # Capture input token count
+                        msg = data.get("message", {})
+                        usage = msg.get("usage", {})
+                        if usage.get("input_tokens"):
+                            yield StreamChunk(
+                                text="",
+                                done=False,
+                                meta={"input_tokens": usage["input_tokens"]},
+                            )
+
+    async def list_models(self) -> list[dict]:
+        """List available Anthropic models."""
+        # Anthropic doesn't have a list models endpoint, return known models
+        return [
+            {"name": "claude-sonnet-4-20250514", "size": None},
+            {"name": "claude-haiku-4-20250414", "size": None},
+            {"name": "claude-opus-4-20250514", "size": None},
+        ]
+
+    async def check_connection(self) -> bool:
+        """Check if the Anthropic API key is configured."""
+        return bool(self.api_key)
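
For orientation, a minimal sketch of how a caller might consume this streaming interface; the model name, prompt, and environment-variable handling are illustrative assumptions, not something the package prescribes:

import asyncio
import os

from arbiter.core.providers.anthropic_provider import AnthropicProvider

async def demo() -> None:
    # Assumes ANTHROPIC_API_KEY is set; model name taken from list_models() above.
    provider = AnthropicProvider(api_key=os.environ["ANTHROPIC_API_KEY"])
    async for chunk in provider.stream_generate(
        model="claude-sonnet-4-20250514",
        prompt="Explain server-sent events in one sentence.",
    ):
        print(chunk.text, end="", flush=True)
        if chunk.done:
            print(f"\n[stop_reason={chunk.meta.get('stop_reason')}]")

asyncio.run(demo())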

arbiter/core/providers/base.py
@@ -0,0 +1,62 @@
+"""Base provider interface for all LLM backends."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import AsyncIterator, Optional
+
+
+@dataclass
+class StreamChunk:
+    """A single chunk from a streaming LLM response."""
+
+    text: str
+    done: bool = False
+    # Provider-specific metadata (eval_count, eval_duration, etc.)
+    meta: dict = field(default_factory=dict)
+
+
+@dataclass
+class GenerationResult:
+    """Complete result from a model generation."""
+
+    model: str
+    provider: str
+    output: str
+    total_tokens: int
+    eval_duration_ns: Optional[int] = None  # nanoseconds
+    total_duration_ns: Optional[int] = None
+    prompt_tokens: Optional[int] = None
+    # Raw provider response metadata
+    raw_meta: dict = field(default_factory=dict)
+
+
+class LLMProvider(ABC):
+    """Abstract base class for LLM providers."""
+
+    provider_name: str = "base"
+
+    @abstractmethod
+    async def stream_generate(
+        self,
+        model: str,
+        prompt: str,
+        system: Optional[str] = None,
+        image_path: Optional[str] = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream tokens from the model. Yields StreamChunk objects."""
+        ...
+
+    @abstractmethod
+    async def list_models(self) -> list[dict]:
+        """List available models for this provider.
+
+        Returns a list of dicts with at least {"name": str, "size": int|None}.
+        """
+        ...
+
+    @abstractmethod
+    async def check_connection(self) -> bool:
+        """Check if the provider is reachable and configured."""
+        ...
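
To illustrate the contract, a toy provider satisfying the ABC might look like the following; EchoProvider is a hypothetical example for test harnesses, not part of the package:

from typing import AsyncIterator, Optional

from arbiter.core.providers.base import LLMProvider, StreamChunk

class EchoProvider(LLMProvider):
    """Hypothetical provider that streams the prompt back, word by word."""

    provider_name = "echo"

    async def stream_generate(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        image_path: Optional[str] = None,
    ) -> AsyncIterator[StreamChunk]:
        words = prompt.split()
        for word in words:
            yield StreamChunk(text=word + " ", done=False)
        # Final chunk carries the provider-specific metadata, as with real backends.
        yield StreamChunk(text="", done=True, meta={"output_tokens": len(words)})

    async def list_models(self) -> list[dict]:
        return [{"name": "echo-1", "size": None}]

    async def check_connection(self) -> bool:
        return True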

arbiter/core/providers/factory.py
@@ -0,0 +1,79 @@
+"""Factory for creating provider instances from config."""
+
+from __future__ import annotations
+
+from arbiter.core.config import (
+    PROVIDER_ANTHROPIC,
+    PROVIDER_GOOGLE,
+    PROVIDER_OLLAMA,
+    PROVIDER_OPENAI,
+    PROVIDER_OPENAI_COMPAT,
+    ProviderConfig,
+    resolve_model,
+)
+from arbiter.core.providers.anthropic_provider import AnthropicProvider
+from arbiter.core.providers.base import LLMProvider
+from arbiter.core.providers.google_provider import GoogleProvider
+from arbiter.core.providers.ollama import OllamaProvider
+from arbiter.core.providers.openai_provider import OpenAIProvider
+
+
+def create_provider(config: ProviderConfig) -> tuple[LLMProvider, str]:
+    """Create a provider instance and return (provider, clean_model_name).
+
+    Args:
+        config: ProviderConfig from resolve_model()
+
+    Returns:
+        Tuple of (provider_instance, model_name_to_use)
+
+    Raises:
+        ValueError: If API key is required but missing
+    """
+    model = config.extra["model"]
+
+    if config.provider == PROVIDER_OLLAMA:
+        provider = OllamaProvider(base_url=config.base_url or "http://localhost:11434")
+        return provider, model
+
+    if config.provider == PROVIDER_OPENAI:
+        if not config.api_key:
+            raise ValueError(
+                "OpenAI API key required. Set OPENAI_API_KEY environment variable."
+            )
+        provider = OpenAIProvider(api_key=config.api_key)
+        return provider, model
+
+    if config.provider == PROVIDER_OPENAI_COMPAT:
+        if not config.api_key:
+            raise ValueError(
+                "API key required for OpenAI-compatible endpoint. Set OPENAI_API_KEY."
+            )
+        provider = OpenAIProvider(
+            api_key=config.api_key, base_url=config.base_url or ""
+        )
+        return provider, model
+
+    if config.provider == PROVIDER_ANTHROPIC:
+        if not config.api_key:
+            raise ValueError(
+                "Anthropic API key required. Set ANTHROPIC_API_KEY environment variable."
+            )
+        provider = AnthropicProvider(api_key=config.api_key)
+        return provider, model
+
+    if config.provider == PROVIDER_GOOGLE:
+        if not config.api_key:
+            raise ValueError(
+                "Google API key required. Set GOOGLE_API_KEY environment variable."
+            )
+        provider = GoogleProvider(api_key=config.api_key)
+        return provider, model
+
+    raise ValueError(f"Unknown provider: {config.provider}")
+
+
+def provider_from_spec(model_spec: str) -> tuple[LLMProvider, str]:
+    """Convenience: resolve a model spec string directly to (provider, model_name)."""
+    config = resolve_model(model_spec)
+    return create_provider(config)
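
A short sketch of the convenience path. The spec grammar is defined by resolve_model() in arbiter/core/config.py, which is not shown in this section, so the "ollama:llama3" spec string here is an assumed example:

import asyncio

from arbiter.core.providers.factory import provider_from_spec

async def main() -> None:
    # "ollama:llama3" is an assumed spec; the real grammar lives in resolve_model().
    provider, model = provider_from_spec("ollama:llama3")
    if not await provider.check_connection():
        raise SystemExit(f"{provider.provider_name} is not reachable")
    async for chunk in provider.stream_generate(model=model, prompt="ping"):
        print(chunk.text, end="", flush=True)

asyncio.run(main())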

arbiter/core/providers/google_provider.py
@@ -0,0 +1,126 @@
+"""Google Gemini provider via the Generative Language API."""
+
+from __future__ import annotations
+
+import json
+from typing import AsyncIterator, Optional
+
+import httpx
+
+from arbiter.core.providers.base import LLMProvider, StreamChunk
+
+
+class GoogleProvider(LLMProvider):
+    """Provider for Google Gemini models."""
+
+    provider_name = "google"
+
+    def __init__(self, api_key: str):
+        self.api_key = api_key
+        self.base_url = "https://generativelanguage.googleapis.com/v1beta"
+
+    async def stream_generate(
+        self,
+        model: str,
+        prompt: str,
+        system: Optional[str] = None,
+        image_path: Optional[str] = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream tokens from the Google Gemini API."""
+        contents = []
+
+        if image_path:
+            import base64
+            from pathlib import Path
+
+            img_data = base64.b64encode(Path(image_path).read_bytes()).decode()
+            contents.append(
+                {
+                    "parts": [
+                        {"inline_data": {"mime_type": "image/jpeg", "data": img_data}},
+                        {"text": prompt},
+                    ]
+                }
+            )
+        else:
+            contents.append({"parts": [{"text": prompt}]})
+
+        payload: dict = {"contents": contents}
+        if system:
+            payload["system_instruction"] = {"parts": [{"text": system}]}
+
+        url = (
+            f"{self.base_url}/models/{model}:streamGenerateContent"
+            f"?key={self.api_key}&alt=sse"
+        )
+
+        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
+            async with client.stream("POST", url, json=payload) as response:
+                response.raise_for_status()
+                async for line in response.aiter_lines():
+                    line = line.strip()
+                    if not line or not line.startswith("data: "):
+                        continue
+                    data_str = line[6:]
+                    try:
+                        data = json.loads(data_str)
+                    except json.JSONDecodeError:
+                        continue
+
+                    candidates = data.get("candidates", [])
+                    if not candidates:
+                        continue
+
+                    candidate = candidates[0]
+                    content = candidate.get("content", {})
+                    parts = content.get("parts", [])
+                    text = "".join(p.get("text", "") for p in parts)
+
+                    finish_reason = candidate.get("finishReason")
+
+                    chunk = StreamChunk(text=text, done=False)
+
+                    # Treat any terminal finishReason (STOP, MAX_TOKENS, ...) as final.
+                    if finish_reason is not None:
+                        usage = data.get("usageMetadata", {})
+                        chunk.done = True
+                        chunk.meta = {
+                            "prompt_tokens": usage.get("promptTokenCount"),
+                            "output_tokens": usage.get("candidatesTokenCount"),
+                            "total_tokens": usage.get("totalTokenCount"),
+                            "finish_reason": finish_reason,
+                        }
+
+                    yield chunk
+
+    async def list_models(self) -> list[dict]:
+        """List available Gemini models."""
+        try:
+            url = f"{self.base_url}/models?key={self.api_key}"
+            async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
+                resp = await client.get(url)
+                resp.raise_for_status()
+                data = resp.json()
+                return [
+                    {
+                        "name": m["name"].replace("models/", ""),
+                        "size": None,
+                        "display_name": m.get("displayName"),
+                    }
+                    for m in data.get("models", [])
+                    if "generateContent" in m.get("supportedGenerationMethods", [])
+                ]
+        except (httpx.HTTPError, KeyError):
+            return []
+
+    async def check_connection(self) -> bool:
+        """Check if the Google API key is configured and valid."""
+        if not self.api_key:
+            return False
+        try:
+            url = f"{self.base_url}/models?key={self.api_key}"
+            async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
+                resp = await client.get(url)
+                return resp.status_code == 200
+        except (httpx.ConnectError, httpx.TimeoutException):
+            return False
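
The meta keys on the final chunk line up with the GenerationResult fields defined in base.py; one plausible mapping, written as an illustrative helper rather than the package's actual runner logic:

from arbiter.core.providers.base import GenerationResult, StreamChunk

def gemini_result(model: str, text: str, final: StreamChunk) -> GenerationResult:
    # Illustrative: folds the Gemini usage meta into the shared result dataclass.
    return GenerationResult(
        model=model,
        provider="google",
        output=text,
        total_tokens=final.meta.get("total_tokens") or 0,
        prompt_tokens=final.meta.get("prompt_tokens"),
        raw_meta=final.meta,
    )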

arbiter/core/providers/ollama.py
@@ -0,0 +1,103 @@
+"""Ollama provider - local model execution."""
+
+from __future__ import annotations
+
+import base64
+import json
+from pathlib import Path
+from typing import AsyncIterator, Optional
+
+import httpx
+
+from arbiter.core.providers.base import LLMProvider, StreamChunk
+
+
+class OllamaProvider(LLMProvider):
+    """Provider for locally running Ollama models."""
+
+    provider_name = "ollama"
+
+    def __init__(self, base_url: str = "http://localhost:11434"):
+        self.base_url = base_url.rstrip("/")
+
+    async def stream_generate(
+        self,
+        model: str,
+        prompt: str,
+        system: Optional[str] = None,
+        image_path: Optional[str] = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream tokens from an Ollama model."""
+        payload: dict = {"model": model, "prompt": prompt, "stream": True}
+
+        if system:
+            payload["system"] = system
+
+        if image_path:
+            img_data = Path(image_path).read_bytes()
+            payload["images"] = [base64.b64encode(img_data).decode()]
+
+        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
+            async with client.stream(
+                "POST", f"{self.base_url}/api/generate", json=payload
+            ) as response:
+                response.raise_for_status()
+                async for line in response.aiter_lines():
+                    if not line.strip():
+                        continue
+                    data = json.loads(line)
+                    chunk = StreamChunk(
+                        text=data.get("response", ""),
+                        done=data.get("done", False),
+                    )
+                    if chunk.done:
+                        chunk.meta = {
+                            "total_duration": data.get("total_duration"),
+                            "eval_count": data.get("eval_count"),
+                            "eval_duration": data.get("eval_duration"),
+                            "prompt_eval_count": data.get("prompt_eval_count"),
+                            "prompt_eval_duration": data.get("prompt_eval_duration"),
+                        }
+                    yield chunk
+
+    async def list_models(self) -> list[dict]:
+        """List installed Ollama models."""
+        async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
+            resp = await client.get(f"{self.base_url}/api/tags")
+            resp.raise_for_status()
+            data = resp.json()
+
+        models = []
+        for m in data.get("models", []):
+            details = m.get("details", {})
+            models.append(
+                {
+                    "name": m["name"],
+                    "size": m.get("size"),
+                    "parameter_size": details.get("parameter_size"),
+                    "quantization": details.get("quantization_level"),
+                    "family": details.get("family"),
+                    "families": details.get("families"),
+                    "format": details.get("format"),
+                    "modified_at": m.get("modified_at"),
+                }
+            )
+        return models
+
+    async def get_model_info(self, model: str) -> dict:
+        """Get detailed info for a specific model."""
+        async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
+            resp = await client.post(
+                f"{self.base_url}/api/show", json={"name": model}
+            )
+            resp.raise_for_status()
+            return resp.json()
+
+    async def check_connection(self) -> bool:
+        """Check if Ollama is running."""
+        try:
+            async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client:
+                resp = await client.get(f"{self.base_url}/api/tags")
+                return resp.status_code == 200
+        except (httpx.ConnectError, httpx.TimeoutException):
+            return False
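
The meta fields on the final chunk mirror Ollama's /api/generate response, where durations are reported in nanoseconds; a hypothetical helper for the tokens-per-second figure a benchmark harness would typically want:

def tokens_per_second(meta: dict) -> float | None:
    """Hypothetical helper: eval_count tokens over eval_duration nanoseconds."""
    count = meta.get("eval_count")
    duration_ns = meta.get("eval_duration")
    if not count or not duration_ns:
        return None
    return count / (duration_ns / 1e9)

# e.g. {"eval_count": 120, "eval_duration": 2_000_000_000} -> 60.0 tokens/sec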

arbiter/core/providers/openai_provider.py
@@ -0,0 +1,120 @@
+"""OpenAI-compatible provider - works with OpenAI, Together, Groq, any OpenAI-compatible API."""
+
+from __future__ import annotations
+
+import json
+from typing import AsyncIterator, Optional
+
+import httpx
+
+from arbiter.core.providers.base import LLMProvider, StreamChunk
+
+
+class OpenAIProvider(LLMProvider):
+    """Provider for OpenAI and any OpenAI-compatible API."""
+
+    provider_name = "openai"
+
+    def __init__(self, api_key: str, base_url: str = "https://api.openai.com/v1"):
+        self.api_key = api_key
+        self.base_url = base_url.rstrip("/")
+
+    async def stream_generate(
+        self,
+        model: str,
+        prompt: str,
+        system: Optional[str] = None,
+        image_path: Optional[str] = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream tokens from an OpenAI-compatible API."""
+        messages = []
+        if system:
+            messages.append({"role": "system", "content": system})
+
+        if image_path:
+            import base64
+            from pathlib import Path
+
+            img_data = base64.b64encode(Path(image_path).read_bytes()).decode()
+            messages.append(
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": f"data:image/jpeg;base64,{img_data}"},
+                        },
+                    ],
+                }
+            )
+        else:
+            messages.append({"role": "user", "content": prompt})
+
+        payload = {"model": model, "messages": messages, "stream": True}
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
+            async with client.stream(
+                "POST",
+                f"{self.base_url}/chat/completions",
+                json=payload,
+                headers=headers,
+            ) as response:
+                response.raise_for_status()
+                async for line in response.aiter_lines():
+                    line = line.strip()
+                    if not line or not line.startswith("data: "):
+                        continue
+                    data_str = line[6:]
+                    if data_str == "[DONE]":
+                        yield StreamChunk(text="", done=True)
+                        return
+                    data = json.loads(data_str)
+                    choices = data.get("choices", [])
+                    if not choices:
+                        continue
+                    delta = choices[0].get("delta", {})
+                    text = delta.get("content") or ""
+                    finish = choices[0].get("finish_reason")
+                    chunk = StreamChunk(text=text, done=finish is not None)
+                    if finish:
+                        chunk.meta = {
+                            "usage": data.get("usage", {}),
+                            "finish_reason": finish,
+                        }
+                    yield chunk
+
+    async def list_models(self) -> list[dict]:
+        """List available models from the API."""
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        try:
+            async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
+                resp = await client.get(
+                    f"{self.base_url}/models", headers=headers
+                )
+                resp.raise_for_status()
+                data = resp.json()
+                return [
+                    {"name": m["id"], "size": None, "owned_by": m.get("owned_by")}
+                    for m in data.get("data", [])
+                ]
+        except (httpx.HTTPError, KeyError):
+            return []
+
+    async def check_connection(self) -> bool:
+        """Check if the API is reachable and the key is valid."""
+        if not self.api_key:
+            return False
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        try:
+            async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
+                resp = await client.get(
+                    f"{self.base_url}/models", headers=headers
+                )
+                return resp.status_code == 200
+        except (httpx.ConnectError, httpx.TimeoutException):
+            return False
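
Since base_url is injectable, the same class should cover any server that exposes the /chat/completions and /models routes; a sketch against a hypothetical local endpoint, where the URL, key, and model name are all assumptions:

import asyncio

from arbiter.core.providers.openai_provider import OpenAIProvider

async def main() -> None:
    # Hypothetical local OpenAI-compatible server (e.g. vLLM or llama.cpp).
    provider = OpenAIProvider(api_key="sk-local", base_url="http://localhost:8000/v1")
    parts: list[str] = []
    async for chunk in provider.stream_generate(model="local-model", prompt="Hello!"):
        parts.append(chunk.text)
    print("".join(parts))

asyncio.run(main())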