arbiter-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,133 @@
1
+ """Anthropic provider - Claude models via the Anthropic API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from typing import AsyncIterator, Optional
7
+
8
+ import httpx
9
+
10
+ from arbiter.core.providers.base import LLMProvider, StreamChunk
11
+
12
+
13
class AnthropicProvider(LLMProvider):
    """Provider for Anthropic Claude models via the Messages API."""

    provider_name = "anthropic"

    def __init__(self, api_key: str):
        """Store the API key; the base URL is fixed to the public API."""
        self.api_key = api_key
        self.base_url = "https://api.anthropic.com/v1"

    async def stream_generate(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        image_path: Optional[str] = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream tokens from the Anthropic API.

        Args:
            model: Anthropic model name to request.
            prompt: User prompt text.
            system: Optional system prompt (sent as the top-level "system" field).
            image_path: Optional path to an image attached to the user message.

        Yields:
            StreamChunk objects. Text arrives in intermediate chunks; the
            final chunk has done=True with output_tokens / stop_reason in
            meta. A message_start event may also yield an empty, not-done
            chunk whose meta carries the input token count.
        """
        messages = []

        if image_path:
            import base64
            import mimetypes
            from pathlib import Path

            # Fix: detect the actual media type from the file name instead of
            # always claiming image/jpeg (a PNG sent with the wrong media_type
            # is rejected by the API). Unknown or non-image extensions fall
            # back to image/jpeg, preserving the previous behavior.
            media_type, _ = mimetypes.guess_type(image_path)
            if not media_type or not media_type.startswith("image/"):
                media_type = "image/jpeg"

            img_data = base64.b64encode(Path(image_path).read_bytes()).decode()
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": media_type,
                                "data": img_data,
                            },
                        },
                        {"type": "text", "text": prompt},
                    ],
                }
            )
        else:
            messages.append({"role": "user", "content": prompt})

        payload: dict = {
            "model": model,
            "messages": messages,
            "max_tokens": 4096,
            "stream": True,
        }
        if system:
            payload["system"] = system

        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "Content-Type": "application/json",
        }

        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
            async with client.stream(
                "POST",
                f"{self.base_url}/messages",
                json=payload,
                headers=headers,
            ) as response:
                response.raise_for_status()
                # The endpoint emits server-sent events; only "data: " lines
                # carry JSON payloads.
                async for line in response.aiter_lines():
                    line = line.strip()
                    if not line or not line.startswith("data: "):
                        continue
                    data_str = line[6:]
                    if not data_str:
                        continue
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        # Skip malformed/partial SSE payloads rather than
                        # aborting the whole stream.
                        continue

                    event_type = data.get("type", "")

                    if event_type == "content_block_delta":
                        delta = data.get("delta", {})
                        yield StreamChunk(text=delta.get("text", ""), done=False)

                    elif event_type == "message_delta":
                        # Carries final usage and the stop reason.
                        usage = data.get("usage", {})
                        yield StreamChunk(
                            text="",
                            done=True,
                            meta={
                                "output_tokens": usage.get("output_tokens"),
                                "stop_reason": data.get("delta", {}).get(
                                    "stop_reason"
                                ),
                            },
                        )

                    elif event_type == "message_start":
                        # Capture the input token count reported up front.
                        usage = data.get("message", {}).get("usage", {})
                        if usage.get("input_tokens"):
                            yield StreamChunk(
                                text="",
                                done=False,
                                meta={"input_tokens": usage["input_tokens"]},
                            )

    async def list_models(self) -> list[dict]:
        """List available Anthropic models.

        Anthropic has no public list-models endpoint, so a static list of
        known model names is returned; "size" is always None.
        """
        return [
            {"name": "claude-sonnet-4-20250514", "size": None},
            {"name": "claude-haiku-4-20250414", "size": None},
            {"name": "claude-opus-4-20250514", "size": None},
        ]

    async def check_connection(self) -> bool:
        """Return True if an API key is configured.

        NOTE: only presence of the key is checked — no network request is
        made, so an invalid key still reports True here.
        """
        return bool(self.api_key)
@@ -0,0 +1,62 @@
1
+ """Base provider interface for all LLM backends."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from dataclasses import dataclass, field
7
+ from typing import AsyncIterator, Optional
8
+
9
+
10
@dataclass
class StreamChunk:
    """A single chunk from a streaming LLM response.

    Providers yield a sequence of these from ``stream_generate``; the final
    chunk of a stream is marked with ``done=True``.
    """

    # Text content of this chunk (may be "" for metadata-only chunks).
    text: str
    # True on the last chunk of a stream.
    done: bool = False
    # Provider-specific metadata (eval_count, eval_duration, etc.)
    meta: dict = field(default_factory=dict)
18
+
19
+
20
@dataclass
class GenerationResult:
    """Complete result from a model generation."""

    # Name of the model that produced the output.
    model: str
    # Provider identifier (e.g. "ollama", "anthropic").
    provider: str
    # Full generated text.
    output: str
    # Number of generated (output) tokens.
    total_tokens: int
    eval_duration_ns: Optional[int] = None  # nanoseconds
    # End-to-end wall time in nanoseconds, when the provider reports it.
    total_duration_ns: Optional[int] = None
    # Tokens consumed by the prompt, when the provider reports it.
    prompt_tokens: Optional[int] = None
    # Raw provider response metadata
    raw_meta: dict = field(default_factory=dict)
33
+
34
+
35
class LLMProvider(ABC):
    """Abstract base class for LLM providers.

    Concrete subclasses set ``provider_name`` and implement the three
    async methods below.
    """

    # Short identifier for the backend; overridden by each subclass.
    provider_name: str = "base"

    @abstractmethod
    async def stream_generate(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        image_path: Optional[str] = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream tokens from the model. Yields StreamChunk objects.

        Args:
            model: Provider-specific model name.
            prompt: User prompt text.
            system: Optional system prompt.
            image_path: Optional path to an image to attach to the request.
        """
        ...

    @abstractmethod
    async def list_models(self) -> list[dict]:
        """List available models for this provider.

        Returns a list of dicts with at least {"name": str, "size": int|None}.
        """
        ...

    @abstractmethod
    async def check_connection(self) -> bool:
        """Check if the provider is reachable and configured."""
        ...
@@ -0,0 +1,79 @@
1
+ """Factory for creating provider instances from config."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from arbiter.core.config import (
6
+ PROVIDER_ANTHROPIC,
7
+ PROVIDER_GOOGLE,
8
+ PROVIDER_OLLAMA,
9
+ PROVIDER_OPENAI,
10
+ PROVIDER_OPENAI_COMPAT,
11
+ ProviderConfig,
12
+ resolve_model,
13
+ )
14
+ from arbiter.core.providers.anthropic_provider import AnthropicProvider
15
+ from arbiter.core.providers.base import LLMProvider
16
+ from arbiter.core.providers.google_provider import GoogleProvider
17
+ from arbiter.core.providers.ollama import OllamaProvider
18
+ from arbiter.core.providers.openai_provider import OpenAIProvider
19
+
20
+
21
def create_provider(config: ProviderConfig) -> tuple[LLMProvider, str]:
    """Create a provider instance and return (provider, clean_model_name).

    Args:
        config: ProviderConfig from resolve_model()

    Returns:
        Tuple of (provider_instance, model_name_to_use)

    Raises:
        ValueError: If API key is required but missing, or the provider
            name is not recognized.
    """
    model_name = config.extra["model"]
    kind = config.provider

    # Ollama runs locally and never needs an API key.
    if kind == PROVIDER_OLLAMA:
        local = OllamaProvider(base_url=config.base_url or "http://localhost:11434")
        return local, model_name

    # Every remote provider requires a key; the message names the
    # environment variable the user should set.
    key_errors = {
        PROVIDER_OPENAI: (
            "OpenAI API key required. Set OPENAI_API_KEY environment variable."
        ),
        PROVIDER_OPENAI_COMPAT: (
            "API key required for OpenAI-compatible endpoint. Set OPENAI_API_KEY."
        ),
        PROVIDER_ANTHROPIC: (
            "Anthropic API key required. Set ANTHROPIC_API_KEY environment variable."
        ),
        PROVIDER_GOOGLE: (
            "Google API key required. Set GOOGLE_API_KEY environment variable."
        ),
    }
    if kind in key_errors and not config.api_key:
        raise ValueError(key_errors[kind])

    if kind == PROVIDER_OPENAI:
        return OpenAIProvider(api_key=config.api_key), model_name

    if kind == PROVIDER_OPENAI_COMPAT:
        compat = OpenAIProvider(api_key=config.api_key, base_url=config.base_url or "")
        return compat, model_name

    if kind == PROVIDER_ANTHROPIC:
        return AnthropicProvider(api_key=config.api_key), model_name

    if kind == PROVIDER_GOOGLE:
        return GoogleProvider(api_key=config.api_key), model_name

    raise ValueError(f"Unknown provider: {config.provider}")
74
+
75
+
76
def provider_from_spec(model_spec: str) -> tuple[LLMProvider, str]:
    """Convenience: resolve a model spec string directly to (provider, model_name)."""
    return create_provider(resolve_model(model_spec))
@@ -0,0 +1,126 @@
1
+ """Google Gemini provider via the Generative Language API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from typing import AsyncIterator, Optional
7
+
8
+ import httpx
9
+
10
+ from arbiter.core.providers.base import LLMProvider, StreamChunk
11
+
12
+
13
class GoogleProvider(LLMProvider):
    """Provider for Google Gemini models via the Generative Language API."""

    provider_name = "google"

    def __init__(self, api_key: str):
        """Store the API key; the base URL is fixed to the v1beta API."""
        self.api_key = api_key
        self.base_url = "https://generativelanguage.googleapis.com/v1beta"

    async def stream_generate(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        image_path: Optional[str] = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream tokens from the Google Gemini API.

        Args:
            model: Gemini model name (without the "models/" prefix).
            prompt: User prompt text.
            system: Optional system instruction.
            image_path: Optional path to an image attached to the request.

        Yields:
            StreamChunk objects; the chunk carrying a finishReason has
            done=True and usage/finish metadata in meta.
        """
        contents = []

        if image_path:
            import base64
            import mimetypes
            from pathlib import Path

            # Fix: derive the MIME type from the file name instead of always
            # claiming image/jpeg; unknown or non-image extensions fall back
            # to image/jpeg (the previous behavior).
            mime_type, _ = mimetypes.guess_type(image_path)
            if not mime_type or not mime_type.startswith("image/"):
                mime_type = "image/jpeg"

            img_data = base64.b64encode(Path(image_path).read_bytes()).decode()
            contents.append(
                {
                    "parts": [
                        {"inline_data": {"mime_type": mime_type, "data": img_data}},
                        {"text": prompt},
                    ]
                }
            )
        else:
            contents.append({"parts": [{"text": prompt}]})

        payload: dict = {"contents": contents}
        if system:
            payload["system_instruction"] = {"parts": [{"text": system}]}

        # alt=sse makes the endpoint emit server-sent events instead of a
        # single JSON array.
        url = (
            f"{self.base_url}/models/{model}:streamGenerateContent"
            f"?key={self.api_key}&alt=sse"
        )

        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
            async with client.stream("POST", url, json=payload) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    line = line.strip()
                    if not line or not line.startswith("data: "):
                        continue
                    try:
                        data = json.loads(line[6:])
                    except json.JSONDecodeError:
                        continue

                    candidates = data.get("candidates", [])
                    if not candidates:
                        continue

                    candidate = candidates[0]
                    parts = candidate.get("content", {}).get("parts", [])
                    text = "".join(p.get("text", "") for p in parts)

                    finish_reason = candidate.get("finishReason")
                    # Fix: ANY finishReason terminates the stream. Previously a
                    # computed `done` flag was never used and only "STOP" set
                    # chunk.done, so streams ended by MAX_TOKENS / SAFETY /
                    # RECITATION never produced a final chunk and callers
                    # waiting for done=True hung.
                    chunk = StreamChunk(text=text, done=finish_reason is not None)

                    if finish_reason is not None:
                        usage = data.get("usageMetadata", {})
                        chunk.meta = {
                            "prompt_tokens": usage.get("promptTokenCount"),
                            "output_tokens": usage.get("candidatesTokenCount"),
                            "total_tokens": usage.get("totalTokenCount"),
                            "finish_reason": finish_reason,
                        }

                    yield chunk

    async def list_models(self) -> list[dict]:
        """List available Gemini models that support generateContent.

        Returns an empty list on any HTTP error or unexpected payload shape
        (best-effort; callers treat [] as "nothing available").
        """
        try:
            url = f"{self.base_url}/models?key={self.api_key}"
            async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
                resp = await client.get(url)
                resp.raise_for_status()
                data = resp.json()
                return [
                    {
                        # Strip the "models/" prefix so names match what
                        # stream_generate expects.
                        "name": m["name"].replace("models/", ""),
                        "size": None,
                        "display_name": m.get("displayName"),
                    }
                    for m in data.get("models", [])
                    if "generateContent" in m.get("supportedGenerationMethods", [])
                ]
        except (httpx.HTTPError, KeyError):
            return []

    async def check_connection(self) -> bool:
        """Check if the Google API key is configured and valid.

        Makes a lightweight GET to the models endpoint; a 200 response
        confirms both reachability and key validity.
        """
        if not self.api_key:
            return False
        try:
            url = f"{self.base_url}/models?key={self.api_key}"
            async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
                resp = await client.get(url)
                return resp.status_code == 200
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
@@ -0,0 +1,103 @@
1
+ """Ollama provider - local model execution."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import json
7
+ from pathlib import Path
8
+ from typing import AsyncIterator, Optional
9
+
10
+ import httpx
11
+
12
+ from arbiter.core.providers.base import GenerationResult, LLMProvider, StreamChunk
13
+
14
+
15
class OllamaProvider(LLMProvider):
    """Provider for locally running Ollama models."""

    provider_name = "ollama"

    def __init__(self, base_url: str = "http://localhost:11434"):
        """Normalize and store the Ollama server base URL."""
        self.base_url = base_url.rstrip("/")

    async def stream_generate(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        image_path: Optional[str] = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream tokens from an Ollama model via /api/generate."""
        request_body: dict = {"model": model, "prompt": prompt, "stream": True}
        if system:
            request_body["system"] = system
        if image_path:
            raw_image = Path(image_path).read_bytes()
            request_body["images"] = [base64.b64encode(raw_image).decode()]

        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
            async with client.stream(
                "POST", f"{self.base_url}/api/generate", json=request_body
            ) as response:
                response.raise_for_status()
                # Ollama streams newline-delimited JSON objects.
                async for raw_line in response.aiter_lines():
                    if not raw_line.strip():
                        continue
                    event = json.loads(raw_line)
                    finished = event.get("done", False)
                    piece = StreamChunk(
                        text=event.get("response", ""), done=finished
                    )
                    if finished:
                        # The terminal object carries timing/token statistics.
                        piece.meta = {
                            key: event.get(key)
                            for key in (
                                "total_duration",
                                "eval_count",
                                "eval_duration",
                                "prompt_eval_count",
                                "prompt_eval_duration",
                            )
                        }
                    yield piece

    async def list_models(self) -> list[dict]:
        """List installed Ollama models from /api/tags."""
        async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
            resp = await client.get(f"{self.base_url}/api/tags")
            resp.raise_for_status()
            listing = resp.json()

        def _summarize(record: dict) -> dict:
            # Flatten the nested "details" object into one summary dict.
            info = record.get("details", {})
            return {
                "name": record["name"],
                "size": record.get("size"),
                "parameter_size": info.get("parameter_size"),
                "quantization": info.get("quantization_level"),
                "family": info.get("family"),
                "families": info.get("families"),
                "format": info.get("format"),
                "modified_at": record.get("modified_at"),
            }

        return [_summarize(m) for m in listing.get("models", [])]

    async def get_model_info(self, model: str) -> dict:
        """Get detailed info for a specific model via /api/show."""
        async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
            resp = await client.post(
                f"{self.base_url}/api/show", json={"name": model}
            )
            resp.raise_for_status()
            return resp.json()

    async def check_connection(self) -> bool:
        """Return True if the Ollama server answers /api/tags with 200."""
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client:
                resp = await client.get(f"{self.base_url}/api/tags")
            return resp.status_code == 200
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
@@ -0,0 +1,120 @@
1
+ """OpenAI-compatible provider - works with OpenAI, Together, Groq, any OpenAI-compatible API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from typing import AsyncIterator, Optional
7
+
8
+ import httpx
9
+
10
+ from arbiter.core.providers.base import LLMProvider, StreamChunk
11
+
12
+
13
+ class OpenAIProvider(LLMProvider):
14
+ """Provider for OpenAI and any OpenAI-compatible API."""
15
+
16
+ provider_name = "openai"
17
+
18
+ def __init__(self, api_key: str, base_url: str = "https://api.openai.com/v1"):
19
+ self.api_key = api_key
20
+ self.base_url = base_url.rstrip("/")
21
+
22
+ async def stream_generate(
23
+ self,
24
+ model: str,
25
+ prompt: str,
26
+ system: Optional[str] = None,
27
+ image_path: Optional[str] = None,
28
+ ) -> AsyncIterator[StreamChunk]:
29
+ """Stream tokens from an OpenAI-compatible API."""
30
+ messages = []
31
+ if system:
32
+ messages.append({"role": "system", "content": system})
33
+
34
+ if image_path:
35
+ import base64
36
+ from pathlib import Path
37
+
38
+ img_data = base64.b64encode(Path(image_path).read_bytes()).decode()
39
+ messages.append(
40
+ {
41
+ "role": "user",
42
+ "content": [
43
+ {"type": "text", "text": prompt},
44
+ {
45
+ "type": "image_url",
46
+ "image_url": {"url": f"data:image/jpeg;base64,{img_data}"},
47
+ },
48
+ ],
49
+ }
50
+ )
51
+ else:
52
+ messages.append({"role": "user", "content": prompt})
53
+
54
+ payload = {"model": model, "messages": messages, "stream": True}
55
+ headers = {
56
+ "Authorization": f"Bearer {self.api_key}",
57
+ "Content-Type": "application/json",
58
+ }
59
+
60
+ async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
61
+ async with client.stream(
62
+ "POST",
63
+ f"{self.base_url}/chat/completions",
64
+ json=payload,
65
+ headers=headers,
66
+ ) as response:
67
+ response.raise_for_status()
68
+ async for line in response.aiter_lines():
69
+ line = line.strip()
70
+ if not line or not line.startswith("data: "):
71
+ continue
72
+ data_str = line[6:]
73
+ if data_str == "[DONE]":
74
+ yield StreamChunk(text="", done=True)
75
+ return
76
+ data = json.loads(data_str)
77
+ choices = data.get("choices", [])
78
+ if not choices:
79
+ continue
80
+ delta = choices[0].get("delta", {})
81
+ text = delta.get("content", "")
82
+ finish = choices[0].get("finish_reason")
83
+ chunk = StreamChunk(text=text, done=finish is not None)
84
+ if finish:
85
+ chunk.meta = {
86
+ "usage": data.get("usage", {}),
87
+ "finish_reason": finish,
88
+ }
89
+ yield chunk
90
+
91
+ async def list_models(self) -> list[dict]:
92
+ """List available models from the API."""
93
+ headers = {"Authorization": f"Bearer {self.api_key}"}
94
+ try:
95
+ async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
96
+ resp = await client.get(
97
+ f"{self.base_url}/models", headers=headers
98
+ )
99
+ resp.raise_for_status()
100
+ data = resp.json()
101
+ return [
102
+ {"name": m["id"], "size": None, "owned_by": m.get("owned_by")}
103
+ for m in data.get("data", [])
104
+ ]
105
+ except (httpx.HTTPError, KeyError):
106
+ return []
107
+
108
+ async def check_connection(self) -> bool:
109
+ """Check if the API is reachable and the key is valid."""
110
+ if not self.api_key:
111
+ return False
112
+ headers = {"Authorization": f"Bearer {self.api_key}"}
113
+ try:
114
+ async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
115
+ resp = await client.get(
116
+ f"{self.base_url}/models", headers=headers
117
+ )
118
+ return resp.status_code == 200
119
+ except (httpx.ConnectError, httpx.TimeoutException):
120
+ return False