llmwire 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmwire/__init__.py +22 -0
- llmwire/client.py +270 -0
- llmwire/config.py +41 -0
- llmwire/exceptions.py +23 -0
- llmwire/models.py +42 -0
- llmwire/provider.py +36 -0
- llmwire/providers/__init__.py +6 -0
- llmwire/providers/anthropic.py +228 -0
- llmwire/providers/ollama.py +202 -0
- llmwire/providers/openai.py +200 -0
- llmwire/retry.py +46 -0
- llmwire-0.1.0.dist-info/METADATA +172 -0
- llmwire-0.1.0.dist-info/RECORD +15 -0
- llmwire-0.1.0.dist-info/WHEEL +4 -0
- llmwire-0.1.0.dist-info/licenses/LICENSE +21 -0
llmwire/__init__.py
ADDED
@@ -0,0 +1,22 @@
+"""LLMWire — Lightweight multi-provider LLM client."""
+
+from llmwire.client import LLMClient
+from llmwire.config import LLMConfig, ProviderConfig
+from llmwire.exceptions import AllProvidersFailedError, LLMWireError, ProviderError
+from llmwire.models import ChatResponse, Message, StreamChunk, Usage
+from llmwire.provider import Provider
+
+__all__ = [
+    "AllProvidersFailedError",
+    "ChatResponse",
+    "LLMClient",
+    "LLMConfig",
+    "LLMWireError",
+    "Message",
+    "Provider",
+    "ProviderConfig",
+    "ProviderError",
+    "StreamChunk",
+    "Usage",
+]
+__version__ = "0.1.0"
llmwire/client.py
ADDED
@@ -0,0 +1,270 @@
+"""LLMClient — the main orchestrator for multi-provider LLM access."""
+from __future__ import annotations
+
+import inspect
+import json
+from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
+
+from llmwire.exceptions import AllProvidersFailedError, ProviderError
+from llmwire.models import ChatResponse, Message, StreamChunk
+from llmwire.providers import AnthropicProvider, OllamaProvider, OpenAIProvider
+from llmwire.retry import retry_with_backoff
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+
+    from pydantic import BaseModel
+
+    from llmwire.config import LLMConfig, ProviderConfig
+
+T = TypeVar("T", bound="BaseModel")
+
+
+_schema_cache: dict[type[Any], str] = {}
+
+
+def _cached_schema_json(model_cls: type[Any]) -> str:
+    """Return the JSON schema string for a Pydantic model, cached per class."""
+    if model_cls not in _schema_cache:
+        _schema_cache[model_cls] = json.dumps(model_cls.model_json_schema(), indent=2)
+    return _schema_cache[model_cls]
+
+_PROVIDER_MAP: dict[str, type[Any]] = {
+    "openai": OpenAIProvider,
+    "anthropic": AnthropicProvider,
+    "ollama": OllamaProvider,
+}
+
+
+def _create_provider(
+    config: ProviderConfig, timeout: float
+) -> OpenAIProvider | AnthropicProvider | OllamaProvider:
+    """Instantiate a provider from its configuration.
+
+    Args:
+        config: Provider configuration including name, api_key, model, base_url.
+        timeout: Request timeout in seconds forwarded to the provider.
+
+    Returns:
+        A concrete provider instance.
+
+    Raises:
+        ValueError: If the provider name is not recognised.
+    """
+    provider_cls = _PROVIDER_MAP.get(config.name)
+    if provider_cls is None:
+        known = ", ".join(_PROVIDER_MAP)
+        raise ValueError(
+            f"Unknown provider '{config.name}'. Known providers: {known}"
+        )
+
+    kwargs: dict[str, Any] = {"model": config.model, "timeout": timeout}
+
+    sig = inspect.signature(provider_cls.__init__)
+    if "api_key" in sig.parameters and config.api_key is not None:
+        kwargs["api_key"] = config.api_key
+
+    if config.base_url is not None:
+        kwargs["base_url"] = config.base_url
+
+    instance = provider_cls(**kwargs)
+    return cast("OpenAIProvider | AnthropicProvider | OllamaProvider", instance)
+
+
+class LLMClient:
+    """Orchestrates multi-provider LLM access with fallback and retry logic.
+
+    Args:
+        config: Full LLMWire configuration including provider list, fallback
+            behaviour, retry count, and timeout.
+
+    Raises:
+        ValueError: If ``config.providers`` is empty.
+
+    Example::
+
+        config = LLMConfig(providers=[ProviderConfig(name="openai", api_key="...", model="gpt-4o")])
+        client = LLMClient(config)
+        response = await client.chat("Hello!")
+        print(response.content)
+    """
+
+    def __init__(self, config: LLMConfig) -> None:
+        if not config.providers:
+            raise ValueError("At least one provider must be configured")
+        self._config = config
+        self._providers = [_create_provider(p, config.timeout) for p in config.providers]
+
+    async def close(self) -> None:
+        """Close all underlying provider HTTP clients."""
+        for provider in self._providers:
+            await provider.close()
+
+    async def __aenter__(self) -> LLMClient:
+        return self
+
+    async def __aexit__(self, *exc: object) -> None:
+        await self.close()
+
+    def _normalize_messages(self, prompt: str | list[Message]) -> list[Message]:
+        """Convert a bare string prompt into a single-item user message list.
+
+        Args:
+            prompt: Either a plain string or an already-formed list of Messages.
+
+        Returns:
+            A list of Message objects ready to send to a provider.
+        """
+        if isinstance(prompt, str):
+            return [Message(role="user", content=prompt)]
+        return list(prompt)
+
+    def _build_schema_system_message(self, model_cls: type[Any]) -> Message:
+        """Build a system message instructing the LLM to respond with valid JSON.
+
+        Args:
+            model_cls: A Pydantic ``BaseModel`` subclass whose schema to embed.
+
+        Returns:
+            A system ``Message`` containing the JSON schema.
+        """
+        schema = _cached_schema_json(model_cls)
+        content = (
+            "You must respond with valid JSON that matches the following JSON schema. "
+            "Do not include any explanation or markdown — only the raw JSON object.\n\n"
+            f"Schema:\n{schema}"
+        )
+        return Message(role="system", content=content)
+
+    @overload
+    async def chat(
+        self,
+        prompt: str | list[Message],
+        *,
+        temperature: float = ...,
+        max_tokens: int | None = ...,
+        response_model: None = ...,
+    ) -> ChatResponse: ...
+
+    @overload
+    async def chat(
+        self,
+        prompt: str | list[Message],
+        *,
+        temperature: float = ...,
+        max_tokens: int | None = ...,
+        response_model: type[T],
+    ) -> T: ...
+
+    async def chat(
+        self,
+        prompt: str | list[Message],
+        *,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+        response_model: type[T] | None = None,
+    ) -> ChatResponse | T:
+        """Send a chat completion request, with optional fallback across providers.
+
+        Args:
+            prompt: A string or list of ``Message`` objects. A bare string is
+                wrapped in a single ``user`` message.
+            temperature: Sampling temperature forwarded to the provider.
+            max_tokens: Maximum tokens to generate; ``None`` uses provider default.
+            response_model: Optional Pydantic model class. When supplied, a system
+                message with the JSON schema is prepended and the response content
+                is parsed into an instance of this model.
+
+        Returns:
+            A ``ChatResponse`` when ``response_model`` is ``None``, otherwise an
+            instance of ``response_model``.
+
+        Raises:
+            AllProvidersFailedError: When all attempted providers raise
+                ``ProviderError``.
+        """
+        messages = self._normalize_messages(prompt)
+
+        if response_model is not None:
+            schema_msg = self._build_schema_system_message(response_model)
+            messages = [schema_msg, *messages]
+
+        providers_to_try = self._providers if self._config.fallback else self._providers[:1]
+
+        errors: list[ProviderError] = []
+        for provider in providers_to_try:
+
+            async def _call(
+                _provider: OpenAIProvider | AnthropicProvider | OllamaProvider = provider,
+            ) -> ChatResponse:
+                return await _provider.chat(
+                    messages,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                )
+
+            try:
+                response = await retry_with_backoff(
+                    _call,
+                    max_retries=self._config.max_retries,
+                    retryable_exceptions=(ProviderError,),
+                )
+            except ProviderError as exc:
+                errors.append(exc)
+                continue
+
+            if response_model is not None:
+                return response_model.model_validate_json(response.content)
+            return response
+
+        raise AllProvidersFailedError(errors)
+
+    async def stream(
+        self,
+        prompt: str | list[Message],
+        *,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream a chat completion response, with optional fallback across providers.
+
+        Yields chunks from the first provider that responds successfully.
+        If a provider fails before streaming starts, the next provider is tried
+        (when ``fallback=True``).
+
+        Args:
+            prompt: A string or list of ``Message`` objects.
+            temperature: Sampling temperature forwarded to the provider.
+            max_tokens: Maximum tokens to generate; ``None`` uses provider default.
+
+        Yields:
+            ``StreamChunk`` objects as tokens arrive.
+
+        Raises:
+            AllProvidersFailedError: When all attempted providers raise
+                ``ProviderError``.
+        """
+        messages = self._normalize_messages(prompt)
+        providers_to_try = self._providers if self._config.fallback else self._providers[:1]
+
+        errors: list[ProviderError] = []
+        for provider in providers_to_try:
+            has_yielded = False
+            try:
+                async for chunk in provider.stream(
+                    messages,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                ):
+                    has_yielded = True
+                    yield chunk
+            except ProviderError as exc:
+                if has_yielded:
+                    # Already started yielding — cannot fall back; re-raise.
+                    raise
+                errors.append(exc)
+                continue
+            else:
+                return
+
+        raise AllProvidersFailedError(errors)
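
Taken together, `chat()` gives per-provider retry via `retry_with_backoff` plus ordered fallback across the configured providers, and `response_model` turns the reply into a validated Pydantic object. A minimal usage sketch of the client above (the keys and model names are placeholders, not values shipped with the package):

```python
# Hedged usage sketch for LLMClient; keys and models below are placeholders.
import asyncio

from pydantic import BaseModel

from llmwire import LLMClient, LLMConfig, ProviderConfig


class City(BaseModel):
    name: str
    country: str


async def main() -> None:
    config = LLMConfig(
        providers=[
            # Tried first; on ProviderError the client falls back to ollama.
            ProviderConfig(name="openai", api_key="sk-...", model="gpt-4o"),
            ProviderConfig(name="ollama", model="llama3.2"),
        ],
    )
    async with LLMClient(config) as client:
        # Plain chat — returns a ChatResponse.
        response = await client.chat("Name one European capital.")
        print(response.provider, response.content)

        # Structured output — the City JSON schema is prepended as a system
        # message and the reply is parsed into a City instance.
        city = await client.chat("Give me a city as JSON.", response_model=City)
        print(city.name, city.country)


asyncio.run(main())
```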
llmwire/config.py
ADDED
@@ -0,0 +1,41 @@
+"""LLMWire configuration."""
+from __future__ import annotations
+
+from pydantic import BaseModel, field_validator
+from pydantic_settings import BaseSettings
+
+
+class ProviderConfig(BaseModel):
+    """Configuration for a single LLM provider."""
+
+    name: str
+    api_key: str | None = None
+    model: str = ""
+    base_url: str | None = None
+
+    @field_validator("name", mode="before")
+    @classmethod
+    def normalize_name(cls, v: object) -> object:
+        """Normalize provider name to lowercase."""
+        if isinstance(v, str):
+            return v.strip().lower()
+        return v
+
+
+class LLMConfig(BaseSettings):
+    """Main LLMWire configuration."""
+
+    model_config = {"env_prefix": "LLMKIT_", "env_nested_delimiter": "__"}
+
+    providers: list[ProviderConfig] = []
+    fallback: bool = True
+    max_retries: int = 3
+    timeout: float = 30.0
+
+    @field_validator("providers", mode="before")
+    @classmethod
+    def coerce_providers(cls, v: object) -> object:
+        """Convert indexed-dict form (from env vars) to a list."""
+        if isinstance(v, dict):
+            return [v[k] for k in sorted(v.keys(), key=lambda x: int(x))]
+        return v
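
The `coerce_providers` validator exists because `pydantic-settings` parses indexed variables such as `LLMKIT_PROVIDERS__0__NAME` into a dict keyed by `"0"`, `"1"`, ..., which must be sorted back into a list. A sketch of that round trip, using the `LLMKIT_*` form documented in the README (the key values are placeholders):

```python
# Sketch: building LLMConfig from environment variables; values are placeholders.
import os

from llmwire import LLMConfig

os.environ["LLMKIT_PROVIDERS__0__NAME"] = "OpenAI"  # normalize_name lowercases this
os.environ["LLMKIT_PROVIDERS__0__API_KEY"] = "sk-..."
os.environ["LLMKIT_PROVIDERS__0__MODEL"] = "gpt-4o"
os.environ["LLMKIT_PROVIDERS__1__NAME"] = "ollama"
os.environ["LLMKIT_PROVIDERS__1__MODEL"] = "llama3.2"

config = LLMConfig()  # reads LLMKIT_* variables
print([p.name for p in config.providers])  # ['openai', 'ollama']
```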
llmwire/exceptions.py
ADDED
@@ -0,0 +1,23 @@
+"""LLMWire exceptions."""
+from __future__ import annotations
+
+
+class LLMWireError(Exception):
+    """Base exception for LLMWire."""
+
+
+class ProviderError(LLMWireError):
+    """Raised when a single provider fails."""
+
+    def __init__(self, provider: str, message: str) -> None:
+        self.provider = provider
+        super().__init__(f"[{provider}] {message}")
+
+
+class AllProvidersFailedError(LLMWireError):
+    """Raised when all providers fail during fallback."""
+
+    def __init__(self, errors: list[ProviderError]) -> None:
+        self.errors = errors
+        providers = ", ".join(e.provider for e in errors)
+        super().__init__(f"All providers failed: {providers}")
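
`AllProvidersFailedError.errors` keeps one `ProviderError` per failed provider, so callers can report which backends failed and why. A sketch of handling the exhausted-fallback case (the deliberately bad key is illustrative):

```python
# Sketch: inspecting per-provider failures after fallback is exhausted.
import asyncio

from llmwire import AllProvidersFailedError, LLMClient, LLMConfig, ProviderConfig


async def main() -> None:
    config = LLMConfig(
        providers=[ProviderConfig(name="openai", api_key="bad-key", model="gpt-4o")]
    )
    async with LLMClient(config) as client:
        try:
            await client.chat("Hello")
        except AllProvidersFailedError as exc:
            # Each entry is a ProviderError tagged with its provider name.
            for err in exc.errors:
                print(err.provider, "->", err)


asyncio.run(main())
```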
llmwire/models.py
ADDED
@@ -0,0 +1,42 @@
+"""Data models for LLMWire."""
+from __future__ import annotations
+
+from typing import Literal
+
+from pydantic import BaseModel
+
+Role = Literal["system", "user", "assistant"]
+
+
+class Message(BaseModel):
+    """A chat message."""
+
+    role: Role
+    content: str
+
+
+class Usage(BaseModel):
+    """Token usage statistics."""
+
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    total_tokens: int = 0
+
+
+class _LLMResponseBase(BaseModel):
+    """Shared fields for all LLM response types."""
+
+    content: str
+    provider: str
+    model: str
+    usage: Usage | None = None
+
+
+class ChatResponse(_LLMResponseBase):
+    """Response from an LLM chat completion."""
+
+
+class StreamChunk(_LLMResponseBase):
+    """A single chunk from a streaming response."""
+
+    done: bool = False
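
Because every model is a Pydantic `BaseModel`, instances validate on construction (an unknown `role`, for example, is rejected) and round-trip through `model_dump()`. A small sketch:

```python
# Sketch: constructing and serializing the data models directly.
from llmwire import ChatResponse, Message, Usage

msg = Message(role="user", content="Hi")
resp = ChatResponse(
    content="Hello!",
    provider="openai",
    model="gpt-4o",
    usage=Usage(prompt_tokens=3, completion_tokens=2, total_tokens=5),
)
print(msg.model_dump())         # {'role': 'user', 'content': 'Hi'}
print(resp.usage.total_tokens)  # 5
```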
llmwire/provider.py
ADDED
@@ -0,0 +1,36 @@
+"""Provider protocol — the interface all LLM providers implement."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Protocol
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+
+    from llmwire.models import ChatResponse, Message, StreamChunk
+
+
+class Provider(Protocol):
+    """Interface for LLM provider adapters."""
+
+    @property
+    def name(self) -> str: ...
+
+    async def close(self) -> None: ...
+
+    async def chat(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> ChatResponse: ...
+
+    async def stream(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> AsyncIterator[StreamChunk]: ...
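
Because `Provider` is a `Protocol`, any class with matching methods satisfies it structurally; no inheritance is required. A hypothetical `EchoProvider` sketch mirroring the shape of the built-in providers (not part of the package):

```python
# Sketch: a custom adapter satisfying the Provider protocol structurally.
# EchoProvider is a hypothetical example, not shipped with llmwire.
from collections.abc import AsyncIterator

from llmwire import Provider
from llmwire.models import ChatResponse, Message, StreamChunk


class EchoProvider:
    @property
    def name(self) -> str:
        return "echo"

    async def close(self) -> None:
        pass

    async def chat(
        self,
        messages: list[Message],
        *,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> ChatResponse:
        # Echo the last message back, like a trivial LLM.
        return ChatResponse(content=messages[-1].content, provider=self.name, model="echo")

    async def stream(
        self,
        messages: list[Message],
        *,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> AsyncIterator[StreamChunk]:
        yield StreamChunk(
            content=messages[-1].content, provider=self.name, model="echo", done=True
        )


provider: Provider = EchoProvider()  # accepted wherever a Provider is expected
```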
llmwire/providers/__init__.py
ADDED
@@ -0,0 +1,6 @@
+"""LLM provider implementations."""
+from llmwire.providers.anthropic import AnthropicProvider
+from llmwire.providers.ollama import OllamaProvider
+from llmwire.providers.openai import OpenAIProvider
+
+__all__ = ["AnthropicProvider", "OllamaProvider", "OpenAIProvider"]
llmwire/providers/anthropic.py
ADDED
@@ -0,0 +1,228 @@
+"""Anthropic provider implementation."""
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+import httpx
+
+from llmwire.exceptions import ProviderError
+from llmwire.models import ChatResponse, Message, StreamChunk, Usage
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+
+_ANTHROPIC_VERSION = "2023-06-01"
+
+
+class AnthropicProvider:
+    """LLM provider for Anthropic's Messages API.
+
+    Uses httpx directly. Handles system message extraction into the
+    top-level ``system`` field as required by the Anthropic API.
+    Supports both chat completions and streaming via Server-Sent Events.
+
+    Args:
+        api_key: Anthropic API key.
+        model: Model identifier (e.g. "claude-3-5-sonnet-20241022").
+        base_url: API base URL. Defaults to the official Anthropic endpoint.
+        timeout: Request timeout in seconds. Defaults to 30.
+    """
+
+    def __init__(
+        self,
+        api_key: str,
+        model: str,
+        base_url: str = "https://api.anthropic.com/v1",
+        timeout: float = 30.0,
+    ) -> None:
+        self._api_key = api_key
+        self._model = model
+        self._base_url = base_url.rstrip("/")
+        self._timeout = timeout
+        self._client = httpx.AsyncClient(timeout=self._timeout)
+
+    @property
+    def name(self) -> str:
+        """Provider identifier."""
+        return "anthropic"
+
+    async def close(self) -> None:
+        """Close the underlying HTTP client."""
+        await self._client.aclose()
+
+    def _headers(self) -> dict[str, str]:
+        return {
+            "x-api-key": self._api_key,
+            "anthropic-version": _ANTHROPIC_VERSION,
+        }
+
+    def _build_payload(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None,
+        temperature: float,
+        max_tokens: int | None,
+        stream: bool = False,
+    ) -> dict[str, Any]:
+        """Build the request payload, extracting system messages."""
+        system_parts: list[str] = []
+        user_messages: list[dict[str, str]] = []
+
+        for msg in messages:
+            if msg.role == "system":
+                system_parts.append(msg.content)
+            else:
+                user_messages.append({"role": msg.role, "content": msg.content})
+
+        payload: dict[str, Any] = {
+            "model": model or self._model,
+            "messages": user_messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens if max_tokens is not None else 4096,
+            "stream": stream,
+        }
+        if system_parts:
+            payload["system"] = "\n".join(system_parts)
+
+        return payload
+
+    async def chat(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> ChatResponse:
+        """Send a chat completion request.
+
+        Args:
+            messages: Conversation history. System messages are extracted
+                and placed in the top-level ``system`` field.
+            model: Override the default model.
+            temperature: Sampling temperature.
+            max_tokens: Maximum tokens to generate. Defaults to 4096.
+
+        Returns:
+            ChatResponse with content, provider, model, and usage.
+
+        Raises:
+            ProviderError: On any non-200 HTTP response.
+        """
+        url = f"{self._base_url}/messages"
+        payload = self._build_payload(
+            messages,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+
+        try:
+            response = await self._client.post(url, headers=self._headers(), json=payload)
+        except httpx.TransportError as exc:
+            raise ProviderError(self.name, f"Connection failed: {exc}") from exc
+
+        if response.status_code != 200:
+            raise ProviderError(self.name, f"HTTP {response.status_code}: {response.text}")
+
+        data = response.json()
+        content: str = data["content"][0]["text"]
+        resolved_model: str = data.get("model", model or self._model)
+
+        usage: Usage | None = None
+        if raw_usage := data.get("usage"):
+            prompt_tokens: int = raw_usage.get("input_tokens", 0)
+            completion_tokens: int = raw_usage.get("output_tokens", 0)
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            )
+
+        return ChatResponse(
+            content=content,
+            provider=self.name,
+            model=resolved_model,
+            usage=usage,
+        )
+
+    async def stream(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream a chat completion response via SSE.
+
+        Anthropic uses typed events. Only ``content_block_delta`` events
+        with ``text_delta`` deltas produce text. ``message_stop`` signals
+        the end of the stream.
+
+        Args:
+            messages: Conversation history.
+            model: Override the default model.
+            temperature: Sampling temperature.
+            max_tokens: Maximum tokens to generate. Defaults to 4096.
+
+        Yields:
+            StreamChunk for each text delta received.
+
+        Raises:
+            ProviderError: On any non-200 HTTP response.
+        """
+        url = f"{self._base_url}/messages"
+        payload = self._build_payload(
+            messages,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            stream=True,
+        )
+        resolved_model = model or self._model
+
+        try:
+            async with self._client.stream(
+                "POST", url, headers=self._headers(), json=payload
+            ) as response:
+                if response.status_code != 200:
+                    body = await response.aread()
+                    raise ProviderError(
+                        self.name, f"HTTP {response.status_code}: {body.decode()}"
+                    )
+
+                async for line in response.aiter_lines():
+                    if not line.startswith("data: "):
+                        continue
+                    raw = line[len("data: "):]
+                    try:
+                        event = json.loads(raw)
+                    except json.JSONDecodeError:
+                        continue
+
+                    event_type: str = event.get("type", "")
+
+                    if event_type == "message_stop":
+                        yield StreamChunk(
+                            content="",
+                            provider=self.name,
+                            model=resolved_model,
+                            done=True,
+                        )
+                        return
+
+                    if event_type == "content_block_delta":
+                        delta = event.get("delta", {})
+                        if delta.get("type") == "text_delta":
+                            text: str = delta.get("text", "")
+                            if text:
+                                yield StreamChunk(
+                                    content=text,
+                                    provider=self.name,
+                                    model=resolved_model,
+                                )
+        except httpx.TransportError as exc:
+            raise ProviderError(self.name, f"Connection failed: {exc}") from exc
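
A sketch of using `AnthropicProvider` on its own, bypassing `LLMClient` (key and model are placeholders); note how the system message is lifted into the top-level `system` field by `_build_payload`:

```python
# Sketch: driving AnthropicProvider directly; key and model are placeholders.
import asyncio

from llmwire.models import Message
from llmwire.providers import AnthropicProvider


async def main() -> None:
    provider = AnthropicProvider(api_key="sk-ant-...", model="claude-3-5-sonnet-20241022")
    try:
        response = await provider.chat(
            [
                # Extracted into the payload's top-level `system` field.
                Message(role="system", content="Answer in one word."),
                Message(role="user", content="Capital of France?"),
            ],
            max_tokens=16,
        )
        print(response.content, response.usage)
    finally:
        await provider.close()


asyncio.run(main())
```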
llmwire/providers/ollama.py
ADDED
@@ -0,0 +1,202 @@
+"""Ollama provider implementation for local LLM inference."""
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+import httpx
+
+from llmwire.exceptions import ProviderError
+from llmwire.models import ChatResponse, Message, StreamChunk, Usage
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+
+_DEFAULT_BASE_URL = "http://localhost:11434"
+_DEFAULT_TIMEOUT = 120.0  # local models are slower
+
+
+class OllamaProvider:
+    """LLM provider for Ollama local inference server.
+
+    Connects to a locally running Ollama instance. No API key is needed.
+    Uses NDJSON for streaming (one JSON object per line).
+
+    Args:
+        model: Model identifier (e.g. "llama3.2").
+        base_url: Ollama server URL. Defaults to ``http://localhost:11434``.
+        timeout: Request timeout in seconds. Defaults to 120.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        base_url: str = _DEFAULT_BASE_URL,
+        timeout: float = _DEFAULT_TIMEOUT,
+    ) -> None:
+        self._model = model
+        self._base_url = base_url.rstrip("/")
+        self._timeout = timeout
+        self._client = httpx.AsyncClient(timeout=self._timeout)
+
+    @property
+    def name(self) -> str:
+        """Provider identifier."""
+        return "ollama"
+
+    async def close(self) -> None:
+        """Close the underlying HTTP client."""
+        await self._client.aclose()
+
+    @property
+    def base_url(self) -> str:
+        """Ollama server base URL."""
+        return self._base_url
+
+    def _headers(self) -> dict[str, str]:
+        return {}
+
+    def _build_payload(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None,
+        temperature: float,
+        max_tokens: int | None,
+        stream: bool = False,
+    ) -> dict[str, Any]:
+        options: dict[str, Any] = {"temperature": temperature}
+        if max_tokens is not None:
+            options["num_predict"] = max_tokens
+
+        return {
+            "model": model or self._model,
+            "messages": [{"role": m.role, "content": m.content} for m in messages],
+            "stream": stream,
+            "options": options,
+        }
+
+    async def chat(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> ChatResponse:
+        """Send a chat completion request to the Ollama server.
+
+        Args:
+            messages: Conversation history.
+            model: Override the default model.
+            temperature: Sampling temperature.
+            max_tokens: Maximum tokens to generate (``num_predict``).
+
+        Returns:
+            ChatResponse with content, provider, model, and usage.
+
+        Raises:
+            ProviderError: On connection errors or non-200 HTTP responses.
+        """
+        url = f"{self._base_url}/api/chat"
+        payload = self._build_payload(
+            messages,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+
+        try:
+            response = await self._client.post(url, headers=self._headers(), json=payload)
+        except httpx.TransportError as exc:
+            raise ProviderError(self.name, f"Connection failed: {exc}") from exc
+
+        if response.status_code != 200:
+            raise ProviderError(self.name, f"HTTP {response.status_code}: {response.text}")
+
+        data = response.json()
+        content: str = data["message"]["content"]
+        resolved_model: str = data.get("model", model or self._model)
+
+        prompt_tokens: int = data.get("prompt_eval_count", 0)
+        completion_tokens: int = data.get("eval_count", 0)
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+
+        return ChatResponse(
+            content=content,
+            provider=self.name,
+            model=resolved_model,
+            usage=usage,
+        )
+
+    async def stream(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream a chat completion response via NDJSON.
+
+        Ollama streams one JSON object per line. The ``done`` field signals
+        the end of the stream.
+
+        Args:
+            messages: Conversation history.
+            model: Override the default model.
+            temperature: Sampling temperature.
+            max_tokens: Maximum tokens to generate (``num_predict``).
+
+        Yields:
+            StreamChunk for each non-empty content token.
+
+        Raises:
+            ProviderError: On connection errors or non-200 HTTP responses.
+        """
+        url = f"{self._base_url}/api/chat"
+        payload = self._build_payload(
+            messages,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            stream=True,
+        )
+        resolved_model = model or self._model
+
+        try:
+            async with self._client.stream(
+                "POST", url, headers=self._headers(), json=payload
+            ) as response:
+                if response.status_code != 200:
+                    body = await response.aread()
+                    raise ProviderError(self.name, f"HTTP {response.status_code}: {body.decode()}")
+
+                async for line in response.aiter_lines():
+                    if not line.strip():
+                        continue
+                    try:
+                        event = json.loads(line)
+                    except json.JSONDecodeError:
+                        continue
+
+                    chunk_content: str = event.get("message", {}).get("content", "")
+                    event_model: str = event.get("model", resolved_model)
+                    is_done: bool = event.get("done", False)
+
+                    if chunk_content:
+                        yield StreamChunk(
+                            content=chunk_content,
+                            provider=self.name,
+                            model=event_model,
+                            done=is_done,
+                        )
+
+                    if is_done:
+                        break
+        except httpx.TransportError as exc:
+            raise ProviderError(self.name, f"Connection failed: {exc}") from exc
llmwire/providers/openai.py
ADDED
@@ -0,0 +1,200 @@
+"""OpenAI provider implementation."""
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+import httpx
+
+from llmwire.exceptions import ProviderError
+from llmwire.models import ChatResponse, Message, StreamChunk, Usage
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+
+
+class OpenAIProvider:
+    """LLM provider for OpenAI-compatible APIs.
+
+    Uses httpx directly (no SDK dependency). Supports both chat completions
+    and streaming via Server-Sent Events.
+
+    Args:
+        api_key: OpenAI API key.
+        model: Model identifier (e.g. "gpt-4o").
+        base_url: API base URL. Defaults to the official OpenAI endpoint.
+        timeout: Request timeout in seconds. Defaults to 30.
+    """
+
+    def __init__(
+        self,
+        api_key: str,
+        model: str,
+        base_url: str = "https://api.openai.com/v1",
+        timeout: float = 30.0,
+    ) -> None:
+        self._api_key = api_key
+        self._model = model
+        self._base_url = base_url.rstrip("/")
+        self._timeout = timeout
+        self._client = httpx.AsyncClient(timeout=self._timeout)
+
+    @property
+    def name(self) -> str:
+        """Provider identifier."""
+        return "openai"
+
+    async def close(self) -> None:
+        """Close the underlying HTTP client."""
+        await self._client.aclose()
+
+    def _headers(self) -> dict[str, str]:
+        return {
+            "Authorization": f"Bearer {self._api_key}",
+        }
+
+    def _build_payload(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None,
+        temperature: float,
+        max_tokens: int | None,
+        stream: bool = False,
+    ) -> dict[str, Any]:
+        payload: dict[str, Any] = {
+            "model": model or self._model,
+            "messages": [{"role": m.role, "content": m.content} for m in messages],
+            "temperature": temperature,
+            "stream": stream,
+        }
+        if max_tokens is not None:
+            payload["max_tokens"] = max_tokens
+        return payload
+
+    async def chat(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> ChatResponse:
+        """Send a chat completion request.
+
+        Args:
+            messages: Conversation history.
+            model: Override the default model.
+            temperature: Sampling temperature.
+            max_tokens: Maximum tokens to generate.
+
+        Returns:
+            ChatResponse with content, provider, model, and usage.
+
+        Raises:
+            ProviderError: On any non-200 HTTP response.
+        """
+        url = f"{self._base_url}/chat/completions"
+        payload = self._build_payload(
+            messages,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+
+        try:
+            response = await self._client.post(url, headers=self._headers(), json=payload)
+        except httpx.TransportError as exc:
+            raise ProviderError(self.name, f"Connection failed: {exc}") from exc
+
+        if response.status_code != 200:
+            raise ProviderError(self.name, f"HTTP {response.status_code}: {response.text}")
+
+        data = response.json()
+        content: str = data["choices"][0]["message"]["content"]
+        resolved_model: str = data.get("model", model or self._model)
+
+        usage: Usage | None = None
+        if raw_usage := data.get("usage"):
+            usage = Usage(
+                prompt_tokens=raw_usage.get("prompt_tokens", 0),
+                completion_tokens=raw_usage.get("completion_tokens", 0),
+                total_tokens=raw_usage.get("total_tokens", 0),
+            )
+
+        return ChatResponse(
+            content=content,
+            provider=self.name,
+            model=resolved_model,
+            usage=usage,
+        )
+
+    async def stream(
+        self,
+        messages: list[Message],
+        *,
+        model: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream a chat completion response via SSE.
+
+        Args:
+            messages: Conversation history.
+            model: Override the default model.
+            temperature: Sampling temperature.
+            max_tokens: Maximum tokens to generate.
+
+        Yields:
+            StreamChunk for each token delta received.
+
+        Raises:
+            ProviderError: On any non-200 HTTP response.
+        """
+        url = f"{self._base_url}/chat/completions"
+        payload = self._build_payload(
+            messages,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            stream=True,
+        )
+
+        try:
+            async with self._client.stream(
+                "POST", url, headers=self._headers(), json=payload
+            ) as response:
+                if response.status_code != 200:
+                    body = await response.aread()
+                    raise ProviderError(
+                        self.name, f"HTTP {response.status_code}: {body.decode()}"
+                    )
+
+                async for line in response.aiter_lines():
+                    if not line.startswith("data: "):
+                        continue
+                    raw = line[len("data: "):]
+                    if raw == "[DONE]":
+                        yield StreamChunk(
+                            content="",
+                            provider=self.name,
+                            model=model or self._model,
+                            done=True,
+                        )
+                        return
+                    try:
+                        event = json.loads(raw)
+                    except json.JSONDecodeError:
+                        continue
+
+                    delta_content: str = event["choices"][0]["delta"].get("content", "")
+                    event_model: str = event.get("model", model or self._model)
+
+                    if delta_content:
+                        yield StreamChunk(
+                            content=delta_content,
+                            provider=self.name,
+                            model=event_model,
+                        )
+        except httpx.TransportError as exc:
+            raise ProviderError(self.name, f"Connection failed: {exc}") from exc
llmwire/retry.py
ADDED
@@ -0,0 +1,46 @@
+"""Retry logic with exponential backoff."""
+from __future__ import annotations
+
+import asyncio
+import random
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+
+async def retry_with_backoff[T](
+    fn: Callable[[], Awaitable[T]],
+    *,
+    max_retries: int = 3,
+    base_delay: float = 1.0,
+    retryable_exceptions: tuple[type[Exception], ...] = (ConnectionError, TimeoutError, OSError),
+) -> T:
+    """Execute an async function with exponential backoff retry.
+
+    Args:
+        fn: The async callable to execute.
+        max_retries: Maximum number of attempts before raising.
+        base_delay: Base delay in seconds for the first backoff interval.
+        retryable_exceptions: Exception types that trigger a retry.
+
+    Returns:
+        The return value of ``fn`` on success.
+
+    Raises:
+        Exception: The last retryable exception after all attempts are exhausted,
+            or any non-retryable exception immediately.
+    """
+    last_exception: Exception | None = None
+    for attempt in range(max_retries):
+        try:
+            return await fn()
+        except retryable_exceptions as exc:
+            last_exception = exc
+            if attempt < max_retries - 1:
+                delay = base_delay * (2**attempt) + random.uniform(0, base_delay * 0.1)
+                await asyncio.sleep(delay)
+        except Exception:
+            raise
+    assert last_exception is not None
+    raise last_exception
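
A sketch of `retry_with_backoff` around a contrived flaky coroutine: delays grow as `base_delay * 2**attempt` plus up to 10% of `base_delay` in jitter, and the last retryable exception is re-raised once `max_retries` attempts are exhausted:

```python
# Sketch: retry_with_backoff wrapping a deliberately flaky coroutine.
import asyncio

from llmwire.retry import retry_with_backoff

attempts = 0


async def flaky() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise ConnectionError("transient failure")  # retryable by default
    return "ok"


result = asyncio.run(retry_with_backoff(flaky, max_retries=3, base_delay=0.1))
print(result, "after", attempts, "attempts")  # ok after 3 attempts
```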
llmwire-0.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,172 @@
+Metadata-Version: 2.4
+Name: llmwire
+Version: 0.1.0
+Summary: Lightweight multi-provider LLM client for Python
+Project-URL: Homepage, https://github.com/alexmar07/llmwire
+Project-URL: Documentation, https://alexmar07.github.io/llmwire
+Project-URL: Repository, https://github.com/alexmar07/llmwire
+Author-email: Alessandro Marotta <alessand.marotta@gmail.com>
+License-Expression: MIT
+License-File: LICENSE
+Keywords: ai,anthropic,async,llm,ollama,openai
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
+Classifier: Topic :: Software Development :: Libraries
+Classifier: Typing :: Typed
+Requires-Python: >=3.12
+Requires-Dist: httpx>=0.27
+Requires-Dist: pydantic-settings>=2.0
+Requires-Dist: pydantic>=2.0
+Requires-Dist: pyyaml>=6.0
+Provides-Extra: dev
+Requires-Dist: mypy>=1.10; extra == 'dev'
+Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
+Requires-Dist: pytest-cov>=5.0; extra == 'dev'
+Requires-Dist: pytest>=8.0; extra == 'dev'
+Requires-Dist: respx>=0.21; extra == 'dev'
+Requires-Dist: ruff>=0.4; extra == 'dev'
+Provides-Extra: docs
+Requires-Dist: mkdocs-material>=9.5; extra == 'docs'
+Requires-Dist: mkdocs>=1.6; extra == 'docs'
+Requires-Dist: mkdocstrings[python]>=0.25; extra == 'docs'
+Description-Content-Type: text/markdown
+
+# LLMWire
+
+[![CI](https://github.com/alexmar07/llmwire/actions/workflows/ci.yml/badge.svg)](https://github.com/alexmar07/llmwire/actions/workflows/ci.yml)
+[![PyPI](https://img.shields.io/pypi/v/llmwire)](https://pypi.org/project/llmwire/)
+[![Python](https://img.shields.io/pypi/pyversions/llmwire)](https://pypi.org/project/llmwire/)
+[![License: MIT](https://img.shields.io/badge/license-MIT-green)](LICENSE)
+
+Lightweight multi-provider LLM client for Python. A single async interface to
+OpenAI, Anthropic, and Ollama — with automatic fallback, exponential-backoff retry,
+streaming, and structured Pydantic output. No provider SDK dependencies; all requests
+go over plain `httpx`.
+
+## Features
+
+- **Unified API** — one `LLMClient` for all supported providers
+- **Async-first** — built entirely on `asyncio` and `httpx`
+- **Automatic fallback** — on provider failure, tries the next provider in the list
+- **Exponential backoff** — configurable retry with full jitter
+- **Streaming** — token-by-token via `client.stream()`, async generator interface
+- **Structured output** — pass any Pydantic `BaseModel` as `response_model`
+- **No provider SDKs** — runtime deps are only `httpx`, `pydantic`, `pydantic-settings`, `pyyaml`
+- **Environment variable config** — all settings readable from `LLMKIT_*` env vars
+
+## Quick Start
+
+```bash
+pip install llmwire
+```
+
+### Chat
+
+```python
+import asyncio
+from llmwire import LLMClient, LLMConfig, ProviderConfig
+
+config = LLMConfig(
+    providers=[
+        ProviderConfig(name="openai", api_key="sk-...", model="gpt-4o"),
+        ProviderConfig(name="anthropic", api_key="sk-ant-...", model="claude-3-5-sonnet-20241022"),
+    ]
+)
+
+async def main():
+    async with LLMClient(config) as client:
+        response = await client.chat("What is the capital of France?")
+        print(response.content)
+        # Provider: openai | Model: gpt-4o | Tokens: 42
+
+asyncio.run(main())
+```
+
+### Streaming
+
+```python
+async def main():
+    async with LLMClient(config) as client:
+        async for chunk in client.stream("Write a haiku about async programming."):
+            print(chunk.content, end="", flush=True)
+        print()
+```
+
+### Structured Output
+
+```python
+from pydantic import BaseModel
+
+class Sentiment(BaseModel):
+    label: str  # "positive", "negative", or "neutral"
+    confidence: float
+
+async def main():
+    async with LLMClient(config) as client:
+        result: Sentiment = await client.chat(
+            "Classify: 'I love this library!'",
+            response_model=Sentiment,
+        )
+        print(result.label, result.confidence)  # positive 0.97
+```
+
+## Configuration
+
+### Direct
+
+```python
+from llmwire import LLMConfig, ProviderConfig
+
+config = LLMConfig(
+    providers=[
+        ProviderConfig(name="openai", api_key="sk-...", model="gpt-4o"),
+        ProviderConfig(name="ollama", model="llama3.2"),  # no key needed
+    ],
+    fallback=True,   # try next provider on failure (default: True)
+    max_retries=3,   # per-provider retry attempts (default: 3)
+    timeout=30.0,    # request timeout in seconds (default: 30.0)
+)
+```
+
+### Environment Variables
+
+```bash
+export LLMKIT_PROVIDERS__0__NAME=openai
+export LLMKIT_PROVIDERS__0__API_KEY=sk-...
+export LLMKIT_PROVIDERS__0__MODEL=gpt-4o
+
+export LLMKIT_PROVIDERS__1__NAME=anthropic
+export LLMKIT_PROVIDERS__1__API_KEY=sk-ant-...
+export LLMKIT_PROVIDERS__1__MODEL=claude-3-5-sonnet-20241022
+
+export LLMKIT_FALLBACK=true
+export LLMKIT_MAX_RETRIES=3
+```
+
+```python
+config = LLMConfig()  # reads from environment
+```
+
+## Provider Support
+
+| Provider | Chat | Streaming | Auth | Default endpoint |
+|----------|------|-----------|------|-----------------|
+| OpenAI | yes | yes | API key | `https://api.openai.com/v1` |
+| Anthropic | yes | yes | API key | `https://api.anthropic.com/v1` |
+| Ollama | yes | yes | none | `http://localhost:11434` |
+
+The `base_url` field on `ProviderConfig` lets you point any provider at a compatible
+endpoint (e.g. Azure OpenAI, local OpenAI-compatible servers).
+
+## Further Reading
+
+- [ARCHITECTURE.md](ARCHITECTURE.md) — design decisions, component overview, and provider protocol
+- [CONTRIBUTING.md](CONTRIBUTING.md) — dev setup, code style, and how to add a new provider
+- [Documentation](https://alexmar07.github.io/llmwire) — full API reference and guides
+
+## License
+
+MIT. See [LICENSE](LICENSE).
llmwire-0.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,15 @@
+llmwire/__init__.py,sha256=iuzC_QMFDqlQBQ-y1NiAvXwsS-zcNawnWQCYvCo_5uU,587
+llmwire/client.py,sha256=gL_PTXuexyg2IlGijc_Hq5gaWfp1Hi2du9IYNTmoBxU,9176
+llmwire/config.py,sha256=SuMh8XmZ1afJRul-mt9UMxlTHsguxd9MXFbzqa9B_Rc,1164
+llmwire/exceptions.py,sha256=EyAfD6I3XW3BVNCEUMrUHP8ySUsdKvwxNA9yrfQki_E,680
+llmwire/models.py,sha256=hsEELSMRCipmAZoowac-anQ0Zs5lfgcnjnerJV4-qUY,779
+llmwire/provider.py,sha256=f8cYkSW0XU4JEATVXIYAR45rtJ67nALeadeFzTLWJSo,885
+llmwire/retry.py,sha256=copsUVKgkSsr_0yxACFJMV7-2tkdYQYcFLNVCMWaS1s,1480
+llmwire/providers/__init__.py,sha256=E2OsYM2Pr6X0880XFOtru3B9peW-SnJ61MRnPyumTVw,267
+llmwire/providers/anthropic.py,sha256=RBmsm-AKD6floj67gFRGclac0CR87x6JXABDWBy8-80,7591
+llmwire/providers/ollama.py,sha256=bYqdMkzX4BLdwQi1eq9sanOkn14QAevsykQoE62XlGU,6462
+llmwire/providers/openai.py,sha256=CWKIN72wprosC3YhwysljCpl0acVSO5ok18rRUJcW7Q,6400
+llmwire-0.1.0.dist-info/METADATA,sha256=M19uBPz8r-RTf73rWoDGbdOuiZzLcyrZk6fqW811tb4,5806
+llmwire-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
+llmwire-0.1.0.dist-info/licenses/LICENSE,sha256=7C2NtwSu9KtyzsvrXhVkHID3RFuva1Dyj8UhvcTAWjE,1075
+llmwire-0.1.0.dist-info/RECORD,,
llmwire-0.1.0.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 Alessandro Marotta
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.