arbiterx-gate 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbiterx/__init__.py +3 -0
- arbiterx/adapters/__init__.py +47 -0
- arbiterx/adapters/anthropic.py +145 -0
- arbiterx/adapters/base.py +124 -0
- arbiterx/adapters/google.py +133 -0
- arbiterx/adapters/ollama.py +129 -0
- arbiterx/adapters/openai.py +126 -0
- arbiterx/adapters/openrouter.py +143 -0
- arbiterx/cli.py +324 -0
- arbiterx/context/__init__.py +11 -0
- arbiterx/context/assembler.py +144 -0
- arbiterx/context/cache.py +193 -0
- arbiterx/context/compressor.py +199 -0
- arbiterx/gate/__init__.py +9 -0
- arbiterx/gate/efficiency.py +255 -0
- arbiterx/gate/robustness.py +255 -0
- arbiterx/gate/security.py +217 -0
- arbiterx/gate/validator.py +345 -0
- arbiterx/ladder/__init__.py +8 -0
- arbiterx/ladder/interrogator.py +975 -0
- arbiterx/mapper/__init__.py +23 -0
- arbiterx/mapper/graph.py +170 -0
- arbiterx/mapper/hasher.py +98 -0
- arbiterx/mapper/indexer.py +119 -0
- arbiterx/mapper/languages.py +72 -0
- arbiterx/mapper/parser.py +509 -0
- arbiterx/mapper/store.py +172 -0
- arbiterx/plugins/__init__.py +73 -0
- arbiterx/plugins/loader.py +307 -0
- arbiterx/principles.py +203 -0
- arbiterx/router/__init__.py +24 -0
- arbiterx/router/classifier.py +364 -0
- arbiterx/router/handoff.py +215 -0
- arbiterx/router/table.py +217 -0
- arbiterx_gate-0.1.0.dist-info/METADATA +780 -0
- arbiterx_gate-0.1.0.dist-info/RECORD +39 -0
- arbiterx_gate-0.1.0.dist-info/WHEEL +4 -0
- arbiterx_gate-0.1.0.dist-info/entry_points.txt +2 -0
- arbiterx_gate-0.1.0.dist-info/licenses/LICENSE +200 -0
arbiterx/__init__.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Model adapters for LLM providers."""
|
|
2
|
+
|
|
3
|
+
from arbiterx.adapters.anthropic import AnthropicAdapter
|
|
4
|
+
from arbiterx.adapters.base import ModelAdapter
|
|
5
|
+
from arbiterx.adapters.google import GoogleAdapter
|
|
6
|
+
from arbiterx.adapters.ollama import OllamaAdapter
|
|
7
|
+
from arbiterx.adapters.openai import OpenAIAdapter
|
|
8
|
+
from arbiterx.adapters.openrouter import OpenRouterAdapter
|
|
9
|
+
|
|
10
|
+
# Names exported by this package: the abstract base class followed by one
# concrete adapter per supported provider.
__all__ = [
    "ModelAdapter",
    "AnthropicAdapter",
    "OpenAIAdapter",
    "GoogleAdapter",
    "OllamaAdapter",
    "OpenRouterAdapter",
]

# Registry for adapter lookup by provider name.
# Keys are the provider identifiers accepted by get_adapter(); values are the
# adapter classes themselves (not instances) - get_adapter() instantiates them.
ADAPTER_REGISTRY: dict[str, type[ModelAdapter]] = {
    "anthropic": AnthropicAdapter,
    "openai": OpenAIAdapter,
    "google": GoogleAdapter,
    "ollama": OllamaAdapter,
    "openrouter": OpenRouterAdapter,
}
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_adapter(provider: str, model_name: str, **kwargs) -> ModelAdapter:
    """Instantiate the adapter class registered under ``provider``.

    Args:
        provider: Key into ADAPTER_REGISTRY (anthropic, openai, google,
            ollama, openrouter).
        model_name: Model name handed to the adapter constructor.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            adapter constructor.

    Returns:
        A new instance of the adapter class registered for ``provider``.

    Raises:
        ValueError: If ``provider`` has no entry in ADAPTER_REGISTRY. The
            message lists the registered provider names in sorted order.
    """
    adapter_cls = ADAPTER_REGISTRY.get(provider)
    if adapter_cls is not None:
        return adapter_cls(model_name, **kwargs)

    # Unknown provider: report what is registered so the caller can fix it.
    available = ", ".join(sorted(ADAPTER_REGISTRY))
    raise ValueError(f"Unknown provider '{provider}'. Available: {available}")
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Anthropic (Claude) adapter implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from collections.abc import AsyncIterator
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
from arbiterx.adapters.base import ModelAdapter
|
|
13
|
+
from arbiterx.router.handoff import ConversationState
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AnthropicAdapter(ModelAdapter):
    """Adapter for Anthropic's Claude API.

    Handles the Anthropic-specific message format where system prompts
    are passed as a separate top-level parameter rather than as a message.

    Example:
        >>> adapter = AnthropicAdapter("claude-sonnet-4-20250514", api_key="sk-...")
        >>> response = await adapter.complete([{"role": "user", "content": "Hello"}])
    """

    DEFAULT_BASE_URL = "https://api.anthropic.com/v1"

    def __init__(self, model_name: str = "claude-sonnet-4-20250514", **kwargs: Any) -> None:
        """Initialize the adapter.

        Args:
            model_name: Claude model identifier.
            **kwargs: Optional ``api_key`` (falls back to the
                ``ANTHROPIC_API_KEY`` environment variable when missing or
                empty), ``api_version`` (default ``"2023-06-01"``), plus
                anything the base class reads from its keyword arguments.
        """
        api_key = kwargs.pop("api_key", "") or os.environ.get("ANTHROPIC_API_KEY", "")
        super().__init__(model_name, api_key=api_key, **kwargs)
        if not self.base_url:
            self.base_url = self.DEFAULT_BASE_URL
        self.api_version: str = kwargs.get("api_version", "2023-06-01")

    def _headers(self) -> dict[str, str]:
        """Return the HTTP headers sent with every request."""
        return {
            "x-api-key": self.api_key,
            "anthropic-version": self.api_version,
            "content-type": "application/json",
        }

    def _build_request(self, messages: list[dict[str, str]], **kwargs: Any) -> dict[str, Any]:
        """Build the Anthropic API request body.

        Args:
            messages: Message dicts with 'role' and 'content' keys.
            **kwargs: Per-call overrides; ``max_tokens`` and ``temperature``
                are read, falling back to the instance values.

        Returns:
            The JSON-serializable request body.
        """
        # Separate system messages from conversation: system text goes in the
        # top-level "system" field, everything else stays in "messages".
        system_parts: list[str] = []
        api_messages: list[dict[str, str]] = []

        for msg in messages:
            if msg["role"] == "system":
                system_parts.append(msg["content"])
            else:
                api_messages.append(msg)

        body: dict[str, Any] = {
            "model": self.model_name,
            "messages": api_messages,
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }

        if system_parts:
            body["system"] = "\n\n".join(system_parts)

        # An explicit temperature=None suppresses the field entirely.
        temperature = kwargs.get("temperature", self.temperature)
        if temperature is not None:
            body["temperature"] = temperature

        return body

    async def complete(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
        """Send messages to Claude and return the complete response.

        Args:
            messages: Message dicts with 'role' and 'content' keys.
            **kwargs: Per-call overrides forwarded to ``_build_request``.

        Returns:
            The text of all ``"text"`` content blocks joined with newlines.

        Raises:
            httpx.HTTPStatusError: If the API responds with an error status.
        """
        body = self._build_request(messages, **kwargs)

        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                f"{self.base_url}/messages",
                headers=self._headers(),
                json=body,
            )
            response.raise_for_status()
            data = response.json()

        # Extract text from content blocks; non-text blocks are ignored.
        content_blocks = data.get("content", [])
        text_parts = [block["text"] for block in content_blocks if block.get("type") == "text"]
        return "\n".join(text_parts)

    async def stream(self, messages: list[dict[str, str]], **kwargs: Any) -> AsyncIterator[str]:
        """Stream response tokens from Claude using SSE.

        Args:
            messages: Message dicts with 'role' and 'content' keys.
            **kwargs: Per-call overrides forwarded to ``_build_request``.

        Yields:
            The text of each ``text_delta`` found in ``content_block_delta``
            events.

        Raises:
            httpx.HTTPStatusError: If the API responds with an error status.
        """
        body = self._build_request(messages, **kwargs)
        body["stream"] = True

        async with (
            httpx.AsyncClient(timeout=120.0) as client,
            client.stream(
                "POST",
                f"{self.base_url}/messages",
                headers=self._headers(),
                json=body,
            ) as response,
        ):
            response.raise_for_status()
            async for line in response.aiter_lines():
                # Only SSE data lines carry payloads; skip event/blank lines.
                if not line.startswith("data: "):
                    continue
                payload = line[6:]
                if payload == "[DONE]":
                    break
                try:
                    event = json.loads(payload)
                    if event.get("type") == "content_block_delta":
                        delta = event.get("delta", {})
                        if delta.get("type") == "text_delta":
                            yield delta.get("text", "")
                except json.JSONDecodeError:
                    # Best effort: a malformed line is skipped, not fatal.
                    continue

    def format_messages(self, state: ConversationState) -> list[dict[str, str]]:
        """Format state for Anthropic's API.

        Anthropic expects system prompt as a separate parameter, so it's
        stored in a system message that _build_request will extract.

        Args:
            state: The conversation state to format.

        Returns:
            An optional leading system message (system prompt plus context
            snippets wrapped in ``<context>`` tags) followed by the
            conversation messages in order.
        """
        messages: list[dict[str, str]] = []

        # System prompt
        system_parts: list[str] = []
        if state.system_prompt:
            system_parts.append(state.system_prompt)
        if state.context_snippets:
            system_parts.append(
                "<context>\n" + "\n---\n".join(state.context_snippets) + "\n</context>"
            )
        if system_parts:
            messages.append({"role": "system", "content": "\n\n".join(system_parts)})

        # Conversation messages
        for msg in state.messages:
            messages.append({"role": msg.role, "content": msg.content})

        return messages

    def count_tokens(self, text: str) -> int:
        """Estimate tokens using Claude's ~3.5 chars per token average.

        Args:
            text: The text to estimate.

        Returns:
            ``ceil(len(text) / 3.5)``, but never less than 1.
        """
        # ceil(n / 3.5) == ceil(2n / 7), done in integers so the estimate
        # grows monotonically with the text length.
        return max(1, (2 * len(text) + 6) // 7)
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""Abstract base class for all model adapters."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from collections.abc import AsyncIterator
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from arbiterx.router.handoff import ConversationState
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ModelAdapter(ABC):
    """Common interface implemented by every LLM provider adapter.

    Concrete adapters must override complete(), stream(), format_messages()
    and count_tokens(); the class cannot be instantiated until they do.

    Attributes:
        model_name: Identifier of the model this adapter talks to.
        api_key: API key used for authentication (may be empty).
        max_tokens: Output-token limit; defaults to 4096.
        temperature: Sampling temperature; defaults to 0.7.
        base_url: API base URL; defaults to the empty string.
    """

    def __init__(self, model_name: str, api_key: str = "", **kwargs: Any) -> None:
        """Store the model name, API key and shared tuning options.

        Args:
            model_name: Model identifier (e.g., "claude-sonnet-4-20250514").
            api_key: API key for authentication.
            **kwargs: ``max_tokens``, ``temperature`` and ``base_url`` are
                read from here; any other keys are ignored by this class.
        """
        self.model_name, self.api_key = model_name, api_key
        self.base_url: str = kwargs.get("base_url", "")
        self.temperature: float = kwargs.get("temperature", 0.7)
        self.max_tokens: int = kwargs.get("max_tokens", 4096)

    @abstractmethod
    async def complete(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
        """Send ``messages`` to the model and return its full reply.

        Args:
            messages: Message dicts carrying 'role' and 'content' keys.
            **kwargs: Per-call options (temperature, max_tokens, etc.).

        Returns:
            The complete response text from the model.
        """
        ...

    @abstractmethod
    async def stream(self, messages: list[dict[str, str]], **kwargs: Any) -> AsyncIterator[str]:
        """Produce the model's reply incrementally.

        Args:
            messages: Message dicts carrying 'role' and 'content' keys.
            **kwargs: Per-call options.

        Yields:
            Chunks of response text in arrival order.
        """
        ...

    @abstractmethod
    def format_messages(self, state: ConversationState) -> list[dict[str, str]]:
        """Translate a ConversationState into this provider's message list.

        Providers differ in how they treat system messages, multi-turn
        formatting and metadata; each adapter encodes its own convention.

        Args:
            state: The conversation state to format.

        Returns:
            Message dicts formatted for this provider's API.
        """
        ...

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        """Return an estimated token count for ``text``.

        Tokenizers vary by provider, so each adapter supplies its own
        estimate.

        Args:
            text: The text to measure.

        Returns:
            Estimated token count.
        """
        ...

    def _build_messages(self, state: ConversationState) -> list[dict[str, str]]:
        """Flatten ``state`` into a message list (overridable default).

        The system prompt and context snippets, when present, are merged
        into one leading system message; the conversation messages follow
        in their original order.
        """
        # Gather the pieces of the system message first.
        preamble: list[str] = []
        if state.system_prompt:
            preamble.append(state.system_prompt)
        if state.context_snippets:
            snippets = "\n---\n".join(state.context_snippets)
            preamble.append("\n--- Context ---\n" + snippets)

        result: list[dict[str, str]] = []
        if preamble:
            result.append({"role": "system", "content": "\n\n".join(preamble)})

        # Then append each conversation turn as a plain role/content dict.
        result.extend({"role": turn.role, "content": turn.content} for turn in state.messages)
        return result

    def __repr__(self) -> str:
        return "{}(model={!r})".format(self.__class__.__name__, self.model_name)
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""Google Gemini adapter implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from collections.abc import AsyncIterator
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
from arbiterx.adapters.base import ModelAdapter
|
|
13
|
+
from arbiterx.router.handoff import ConversationState
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GoogleAdapter(ModelAdapter):
    """Adapter for Google's Gemini API.

    Supports Gemini 1.5 Pro, 2.0 Pro, and Flash models.

    Example:
        >>> adapter = GoogleAdapter("gemini-2.0-pro", api_key="...")
        >>> response = await adapter.complete([{"role": "user", "content": "Hello"}])
    """

    DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"

    def __init__(self, model_name: str = "gemini-2.0-flash", **kwargs: Any) -> None:
        """Initialize the adapter.

        Args:
            model_name: Gemini model identifier.
            **kwargs: Optional ``api_key`` (falls back to the
                ``GOOGLE_AI_API_KEY`` environment variable when missing or
                empty), plus anything the base class reads from its keyword
                arguments.
        """
        api_key = kwargs.pop("api_key", "") or os.environ.get("GOOGLE_AI_API_KEY", "")
        super().__init__(model_name, api_key=api_key, **kwargs)
        if not self.base_url:
            self.base_url = self.DEFAULT_BASE_URL

    def _headers(self) -> dict[str, str]:
        """Return the HTTP headers sent with every request.

        The API key travels in the ``x-goog-api-key`` header rather than in
        the URL, so it does not end up in request URLs that get logged or
        embedded in HTTP error messages.
        """
        headers = {"content-type": "application/json"}
        if self.api_key:
            headers["x-goog-api-key"] = self.api_key
        return headers

    def _build_request(self, messages: list[dict[str, str]], **kwargs: Any) -> dict[str, Any]:
        """Build the Gemini API request body.

        Gemini uses 'user' and 'model' roles with 'parts' content structure.

        Args:
            messages: Message dicts with 'role' and 'content' keys.
            **kwargs: Per-call overrides; ``max_tokens`` and ``temperature``
                are read, falling back to the instance values.

        Returns:
            The JSON-serializable request body.
        """
        contents: list[dict[str, Any]] = []
        system_parts: list[str] = []

        for msg in messages:
            if msg["role"] == "system":
                # Collect every system message instead of keeping only the
                # last one; they are merged into one instruction below.
                system_parts.append(msg["content"])
            else:
                # Anything that is not an assistant turn is sent as 'user'.
                role = "model" if msg["role"] == "assistant" else "user"
                contents.append(
                    {
                        "role": role,
                        "parts": [{"text": msg["content"]}],
                    }
                )

        body: dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
            },
        }

        system_instruction = "\n\n".join(system_parts)
        if system_instruction:
            body["systemInstruction"] = {"parts": [{"text": system_instruction}]}

        return body

    async def complete(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
        """Send messages to Gemini and return the complete response.

        Args:
            messages: Message dicts with 'role' and 'content' keys.
            **kwargs: Per-call overrides forwarded to ``_build_request``.

        Returns:
            The concatenated text parts of the first candidate, or an empty
            string when the response contains no candidates.

        Raises:
            httpx.HTTPStatusError: If the API responds with an error status.
        """
        body = self._build_request(messages, **kwargs)
        url = f"{self.base_url}/models/{self.model_name}:generateContent"

        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(url, headers=self._headers(), json=body)
            response.raise_for_status()
            data = response.json()

        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            return "".join(p.get("text", "") for p in parts)
        return ""

    async def stream(self, messages: list[dict[str, str]], **kwargs: Any) -> AsyncIterator[str]:
        """Stream response tokens from Gemini.

        Args:
            messages: Message dicts with 'role' and 'content' keys.
            **kwargs: Per-call overrides forwarded to ``_build_request``.

        Yields:
            Each non-empty text part of the first candidate in every event.

        Raises:
            httpx.HTTPStatusError: If the API responds with an error status.
        """
        body = self._build_request(messages, **kwargs)
        # alt=sse asks for server-sent events, parsed line by line below.
        url = f"{self.base_url}/models/{self.model_name}:streamGenerateContent?alt=sse"

        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream(
                "POST", url, headers=self._headers(), json=body
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    payload = line[6:]
                    try:
                        event = json.loads(payload)
                        candidates = event.get("candidates", [])
                        if candidates:
                            parts = candidates[0].get("content", {}).get("parts", [])
                            for part in parts:
                                text = part.get("text", "")
                                if text:
                                    yield text
                    except json.JSONDecodeError:
                        # Best effort: a malformed line is skipped, not fatal.
                        continue

    def format_messages(self, state: ConversationState) -> list[dict[str, str]]:
        """Format state for Gemini's API.

        Gemini uses 'user'/'model' roles. System prompt goes into
        systemInstruction (handled by _build_request).

        Args:
            state: The conversation state to format.

        Returns:
            An optional leading system message (system prompt plus context
            snippets under a ``## Context`` heading) followed by the
            conversation messages in order.
        """
        messages: list[dict[str, str]] = []

        system_parts: list[str] = []
        if state.system_prompt:
            system_parts.append(state.system_prompt)
        if state.context_snippets:
            system_parts.append("## Context\n" + "\n---\n".join(state.context_snippets))
        if system_parts:
            messages.append({"role": "system", "content": "\n\n".join(system_parts)})

        for msg in state.messages:
            messages.append({"role": msg.role, "content": msg.content})

        return messages

    def count_tokens(self, text: str) -> int:
        """Estimate tokens — Gemini uses ~4 chars per token on average.

        Returns:
            ``len(text) // 4``, but never less than 1.
        """
        return max(1, len(text) // 4)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Ollama adapter for local model inference."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from collections.abc import AsyncIterator
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
|
|
11
|
+
from arbiterx.adapters.base import ModelAdapter
|
|
12
|
+
from arbiterx.router.handoff import ConversationState
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class OllamaAdapter(ModelAdapter):
    """Adapter for Ollama local inference server.

    Talks to the server's ``/api/chat`` and ``/api/tags`` endpoints.
    No API key needed — requests are sent without authentication.

    Example:
        >>> adapter = OllamaAdapter("qwen2.5-coder:7b")
        >>> response = await adapter.complete([{"role": "user", "content": "Hello"}])
    """

    DEFAULT_BASE_URL = "http://localhost:11434"

    def __init__(self, model_name: str = "qwen2.5-coder:7b", **kwargs: Any) -> None:
        """Initialize the adapter.

        Args:
            model_name: Name of the model to run on the Ollama server.
            **kwargs: Anything the base class reads from its keyword
                arguments. An ``api_key`` entry is accepted and ignored.
        """
        # Drop a caller-supplied api_key: forwarding it next to the explicit
        # api_key="" below would raise "multiple values for keyword argument"
        # when callers pass the same kwargs to every adapter.
        kwargs.pop("api_key", None)
        super().__init__(model_name, api_key="", **kwargs)
        if not self.base_url:
            self.base_url = self.DEFAULT_BASE_URL

    def _build_request(self, messages: list[dict[str, str]], **kwargs: Any) -> dict[str, Any]:
        """Build the request body for Ollama's ``/api/chat`` endpoint.

        Args:
            messages: Message dicts with 'role' and 'content' keys, sent
                as-is.
            **kwargs: Per-call overrides; ``temperature`` and ``max_tokens``
                are read, falling back to the instance values.

        Returns:
            The JSON-serializable request body.
        """
        body: dict[str, Any] = {
            "model": self.model_name,
            "messages": messages,
            "options": {
                "temperature": kwargs.get("temperature", self.temperature),
                # max_tokens is sent under Ollama's "num_predict" option.
                "num_predict": kwargs.get("max_tokens", self.max_tokens),
            },
        }
        return body

    async def complete(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
        """Send messages to Ollama and return the complete response.

        Args:
            messages: Message dicts with 'role' and 'content' keys.
            **kwargs: Per-call overrides forwarded to ``_build_request``.

        Returns:
            The ``message.content`` field of the response, or an empty
            string when it is absent.

        Raises:
            httpx.HTTPStatusError: If the server responds with an error status.
        """
        body = self._build_request(messages, **kwargs)
        body["stream"] = False

        async with httpx.AsyncClient(timeout=300.0) as client:
            response = await client.post(
                f"{self.base_url}/api/chat",
                json=body,
            )
            response.raise_for_status()
            data = response.json()

        return data.get("message", {}).get("content", "")

    async def stream(self, messages: list[dict[str, str]], **kwargs: Any) -> AsyncIterator[str]:
        """Stream response tokens from Ollama.

        Each non-blank response line is parsed as one JSON object.

        Args:
            messages: Message dicts with 'role' and 'content' keys.
            **kwargs: Per-call overrides forwarded to ``_build_request``.

        Yields:
            The non-empty ``message.content`` of each event, stopping after
            the event whose ``done`` field is true.

        Raises:
            httpx.HTTPStatusError: If the server responds with an error status.
        """
        body = self._build_request(messages, **kwargs)
        body["stream"] = True

        async with (
            httpx.AsyncClient(timeout=300.0) as client,
            client.stream(
                "POST",
                f"{self.base_url}/api/chat",
                json=body,
            ) as response,
        ):
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.strip():
                    continue
                try:
                    event = json.loads(line)
                    content = event.get("message", {}).get("content", "")
                    if content:
                        yield content
                    if event.get("done", False):
                        break
                except json.JSONDecodeError:
                    # Best effort: a malformed line is skipped, not fatal.
                    continue

    def format_messages(self, state: ConversationState) -> list[dict[str, str]]:
        """Format state for Ollama's API.

        Ollama supports standard system/user/assistant roles.

        Args:
            state: The conversation state to format.

        Returns:
            An optional leading system message (system prompt plus context
            snippets under a ``Context:`` label) followed by the
            conversation messages in order.
        """
        messages: list[dict[str, str]] = []

        system_parts: list[str] = []
        if state.system_prompt:
            system_parts.append(state.system_prompt)
        if state.context_snippets:
            system_parts.append("Context:\n" + "\n---\n".join(state.context_snippets))
        if system_parts:
            messages.append({"role": "system", "content": "\n\n".join(system_parts)})

        for msg in state.messages:
            messages.append({"role": msg.role, "content": msg.content})

        return messages

    def count_tokens(self, text: str) -> int:
        """Estimate tokens — most local models use ~4 chars per token.

        Returns:
            ``len(text) // 4``, but never less than 1.
        """
        return max(1, len(text) // 4)

    async def is_available(self) -> bool:
        """Check if Ollama server is running and accessible.

        Returns:
            True when ``GET /api/tags`` answers with status 200; False on any
            other status, a connection error or a timeout.
        """
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.get(f"{self.base_url}/api/tags")
                return response.status_code == 200
        except (httpx.ConnectError, httpx.TimeoutException):
            return False

    async def list_models(self) -> list[str]:
        """List available models on the Ollama server.

        Returns:
            The ``name`` of every entry under ``models`` in the
            ``/api/tags`` response, or an empty list when the server cannot
            be reached, times out or answers with an error status.
        """
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(f"{self.base_url}/api/tags")
                response.raise_for_status()
                data = response.json()
                return [m["name"] for m in data.get("models", [])]
        except (httpx.ConnectError, httpx.TimeoutException, httpx.HTTPStatusError):
            return []
|