aleph-rlm 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aleph/__init__.py +49 -0
- aleph/cache/__init__.py +6 -0
- aleph/cache/base.py +20 -0
- aleph/cache/memory.py +27 -0
- aleph/cli.py +1044 -0
- aleph/config.py +154 -0
- aleph/core.py +874 -0
- aleph/mcp/__init__.py +30 -0
- aleph/mcp/local_server.py +3527 -0
- aleph/mcp/server.py +20 -0
- aleph/prompts/__init__.py +5 -0
- aleph/prompts/system.py +45 -0
- aleph/providers/__init__.py +14 -0
- aleph/providers/anthropic.py +253 -0
- aleph/providers/base.py +59 -0
- aleph/providers/openai.py +224 -0
- aleph/providers/registry.py +22 -0
- aleph/repl/__init__.py +5 -0
- aleph/repl/helpers.py +1068 -0
- aleph/repl/sandbox.py +777 -0
- aleph/sub_query/__init__.py +166 -0
- aleph/sub_query/api_backend.py +166 -0
- aleph/sub_query/cli_backend.py +327 -0
- aleph/types.py +216 -0
- aleph/utils/__init__.py +6 -0
- aleph/utils/logging.py +79 -0
- aleph/utils/tokens.py +43 -0
- aleph_rlm-0.6.0.dist-info/METADATA +358 -0
- aleph_rlm-0.6.0.dist-info/RECORD +32 -0
- aleph_rlm-0.6.0.dist-info/WHEEL +4 -0
- aleph_rlm-0.6.0.dist-info/entry_points.txt +3 -0
- aleph_rlm-0.6.0.dist-info/licenses/LICENSE +21 -0
aleph/mcp/server.py
ADDED
@@ -0,0 +1,20 @@
"""Compatibility entry point for Aleph MCP server.

This module now aliases the full-featured MCP server.
"""

from __future__ import annotations

from .local_server import AlephMCPServerLocal, main as _main

AlephMCPServer = AlephMCPServerLocal

__all__ = ["AlephMCPServer", "main"]


def main() -> None:
    _main()


if __name__ == "__main__":
    main()

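For orientation, a minimal usage sketch of this compatibility shim (illustrative only; the constructor signature of `AlephMCPServerLocal` lives in `aleph/mcp/local_server.py`, which is not reproduced here, so the no-argument construction below is an assumption):

```python
# Hypothetical usage sketch -- not part of the package source.
# AlephMCPServer is just an alias for AlephMCPServerLocal.
from aleph.mcp.server import AlephMCPServer, main

server = AlephMCPServer()  # assumes a no-arg constructor

# Alternatively, run the module as a script, which delegates to local_server.main():
#   python -m aleph.mcp.server
```
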
aleph/prompts/system.py
ADDED
@@ -0,0 +1,45 @@
"""Default system prompt for Aleph.

This prompt teaches the model how to interact with the REPL and how to signal a
final answer.

The placeholders are filled by Aleph at runtime.
"""

from __future__ import annotations

DEFAULT_SYSTEM_PROMPT = """You are Aleph, a Recursive Language Model (RLM) assistant.

You have access to a sandboxed Python REPL environment where a potentially massive context is stored in the variable `{context_var}`.

CONTEXT INFORMATION:
- Format: {context_format}
- Size: {context_size_chars:,} characters, {context_size_lines:,} lines, ~{context_size_tokens:,} tokens (estimate)
- Structure: {structure_hint}
- Preview (first 500 chars):
```
{context_preview}
```

AVAILABLE FUNCTIONS (in the REPL):
- `peek(start=0, end=None)` - View characters [start:end] of the context
- `lines(start=0, end=None)` - View lines [start:end] of the context
- `search(pattern, context_lines=2, flags=0, max_results=20)` - Regex search returning matches with surrounding context
- `chunk(chunk_size, overlap=0)` - Split the context into character chunks
- `semantic_search(query, chunk_size=1000, overlap=100, top_k=5)` - Meaning-based search
- `sub_query(prompt, context_slice=None)` - Ask a sub-question to another LLM (cheaper model)
- `sub_aleph(query, context=None)` - Run a recursive Aleph call (higher-level recursion)

WORKFLOW:
1. Decide what you need from the context.
2. Use Python code blocks to explore/process the context.
3. Keep REPL outputs small; summarize or extract only what you need.
4. When you have the final answer, respond with exactly one of:
   - `FINAL(your answer)`
   - `FINAL_VAR(variable_name)`

IMPORTANT:
- Write Python code inside a fenced block: ```python ... ```
- You can iterate: write code, inspect output, then write more code.
- Avoid dumping huge text. Prefer targeted search/slicing.
"""

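A sketch of how these placeholders could be filled (illustrative only; the actual substitution happens in Aleph's core loop, e.g. `aleph/core.py`, which is not shown in this diff). The `:,` format specs suggest plain `str.format()` substitution with integer sizes:

```python
# Hypothetical fill of the prompt template -- not the package's own code.
from aleph.prompts.system import DEFAULT_SYSTEM_PROMPT

context = open("big_log.txt").read()  # example input: any large string

system_prompt = DEFAULT_SYSTEM_PROMPT.format(
    context_var="ctx",
    context_format="plain text",
    context_size_chars=len(context),
    context_size_lines=context.count("\n") + 1,
    context_size_tokens=len(context) // 4,  # rough 4-chars-per-token estimate
    structure_hint="unstructured log lines",
    context_preview=context[:500],
)
```
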
aleph/providers/__init__.py
ADDED
@@ -0,0 +1,14 @@
"""LLM provider implementations."""

from .base import LLMProvider, ProviderError
from .anthropic import AnthropicProvider
from .openai import OpenAIProvider
from .registry import get_provider

__all__ = [
    "LLMProvider",
    "ProviderError",
    "AnthropicProvider",
    "OpenAIProvider",
    "get_provider",
]

aleph/providers/anthropic.py
ADDED
@@ -0,0 +1,253 @@
"""Anthropic provider.

Implements Aleph's provider interface against Anthropic's Messages API.

This module intentionally uses bare HTTP (httpx) to keep dependencies minimal.
"""

from __future__ import annotations

import json
import os
import asyncio

import httpx

from .base import LLMProvider, ModelPricing, ProviderError
from ..utils.tokens import estimate_tokens
from ..types import Message


class AnthropicProvider:
    """Anthropic Claude provider via the Messages API."""

    # Model -> pricing / limits (rough defaults; override in code if needed)
    MODEL_INFO: dict[str, ModelPricing] = {
        # NOTE: Values are approximate and may change; intended for budgeting/telemetry.
        "claude-sonnet-4-20250514": ModelPricing(200_000, 64_000, 0.003, 0.015),
        "claude-opus-4-20250514": ModelPricing(200_000, 32_000, 0.015, 0.075),
        "claude-haiku-3-5-20241022": ModelPricing(200_000, 8_192, 0.0008, 0.004),
    }

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str = "https://api.anthropic.com",
        anthropic_version: str = "2023-06-01",
        http_client: httpx.AsyncClient | None = None,
        max_retries: int = 3,
        backoff_base_seconds: float = 0.8,
    ) -> None:
        self._api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if not self._api_key:
            # Don't raise immediately; allow creating instance and failing on first call.
            self._api_key = ""
        self._base_url = base_url.rstrip("/")
        self._version = anthropic_version
        self._client = http_client
        self._owned_client = http_client is None
        self._max_retries = max_retries
        self._backoff_base = backoff_base_seconds

    @property
    def provider_name(self) -> str:
        return "anthropic"

    async def _get_client(self) -> httpx.AsyncClient:
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
        return self._client

    async def aclose(self) -> None:
        if self._owned_client and self._client is not None:
            await self._client.aclose()
            self._client = None

    def count_tokens(self, text: str, model: str) -> int:
        # Keep it dependency-free by default.
        return estimate_tokens(text)

    def get_context_limit(self, model: str) -> int:
        info = self.MODEL_INFO.get(model)
        return info.context_limit if info else 200_000

    def get_output_limit(self, model: str) -> int:
        info = self.MODEL_INFO.get(model)
        return info.output_limit if info else 8_192

    def _estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        info = self.MODEL_INFO.get(model)
        if not info:
            return 0.0
        return (input_tokens / 1000.0) * info.input_cost_per_1k + (output_tokens / 1000.0) * info.output_cost_per_1k

    @staticmethod
    def _split_system(messages: list[Message]) -> tuple[str | None, list[Message]]:
        system_parts: list[str] = []
        out: list[Message] = []
        for m in messages:
            role = m.get("role", "")
            if role == "system":
                system_parts.append(m.get("content", ""))
            else:
                out.append(m)
        system = "\n\n".join([p for p in system_parts if p.strip()]) or None
        return system, out

    async def complete(
        self,
        messages: list[Message],
        model: str,
        max_tokens: int = 4096,
        temperature: float = 0.0,
        stop_sequences: list[str] | None = None,
        timeout_seconds: float | None = None,
    ) -> tuple[str, int, int, float]:
        if not self._api_key:
            raise ProviderError(
                "Anthropic API key not set. Provide api_key=... or set ANTHROPIC_API_KEY."
            )

        system, filtered = self._split_system(messages)

        # Anthropic Messages API uses roles: user/assistant only.
        anthropic_messages: list[dict[str, str]] = []
        for m in filtered:
            role = m.get("role")
            if role not in {"user", "assistant"}:
                # Best-effort fallback: treat unknown roles as user content.
                role = "user"
            anthropic_messages.append({"role": role, "content": m.get("content", "")})

        url = f"{self._base_url}/v1/messages"
        headers = {
            "x-api-key": self._api_key,
            "anthropic-version": self._version,
            "content-type": "application/json",
        }

        payload: dict[str, object] = {
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": anthropic_messages,
        }
        if system:
            payload["system"] = system
        if stop_sequences:
            payload["stop_sequences"] = stop_sequences

        client = await self._get_client()
        timeout = httpx.Timeout(timeout_seconds) if timeout_seconds else client.timeout

        def _parse_retry_after_seconds(resp: httpx.Response) -> float | None:
            ra = resp.headers.get("retry-after")
            if not ra:
                return None
            try:
                return max(0.0, float(ra.strip()))
            except ValueError:
                return None

        def _format_http_error(resp: httpx.Response) -> str:
            request_id = resp.headers.get("request-id") or resp.headers.get("x-request-id")
            retry_after = _parse_retry_after_seconds(resp)

            msg = None
            try:
                data = resp.json()
                err = data.get("error") if isinstance(data, dict) else None
                if isinstance(err, dict):
                    raw = err.get("message")
                    if isinstance(raw, str) and raw.strip():
                        msg = raw.strip()
            except Exception:
                msg = None

            if msg is None:
                body = (resp.text or "").strip()
                msg = body[:500] if body else "(no response body)"

            parts = [f"Anthropic API error {resp.status_code}: {msg}"]
            if request_id:
                parts.append(f"request_id={request_id}")
            if retry_after is not None:
                parts.append(f"retry_after_seconds={retry_after:.0f}")
            if len(parts) == 1:
                return parts[0]
            return parts[0] + " (" + ", ".join(parts[1:]) + ")"

        last_err: Exception | None = None
        for attempt in range(1, self._max_retries + 2):
            try:
                resp = await client.post(url, headers=headers, json=payload, timeout=timeout)

                if resp.status_code >= 400:
                    retryable_status = resp.status_code in {408, 409, 429, 500, 502, 503, 504}
                    if retryable_status and attempt <= self._max_retries:
                        retry_after = _parse_retry_after_seconds(resp)
                        delay = retry_after if retry_after is not None else (self._backoff_base * (2 ** (attempt - 1)))
                        await asyncio.sleep(delay)
                        continue
                    raise ProviderError(_format_http_error(resp))

                try:
                    data = resp.json()
                except json.JSONDecodeError as e:
                    raise ProviderError(f"Invalid JSON response from Anthropic: {e}")

                if not isinstance(data, dict):
                    raise ProviderError(f"Anthropic API returned invalid JSON type: {type(data)}")

                # Response content is a list of blocks; typically first is text.
                content_blocks = data.get("content") or []
                text_parts: list[str] = []
                if isinstance(content_blocks, list):
                    for block in content_blocks:
                        if isinstance(block, dict) and block.get("type") == "text":
                            val = block.get("text", "")
                            if isinstance(val, str):
                                text_parts.append(val)
                text = "".join(text_parts).strip()

                usage = data.get("usage") or {}
                if not isinstance(usage, dict):
                    usage = {}
                input_tokens = int(usage.get("input_tokens") or 0)
                output_tokens = int(usage.get("output_tokens") or 0)

                # If usage is missing (rare), estimate.
                if input_tokens == 0:
                    input_tokens = sum(self.count_tokens(m["content"], model) for m in messages)
                if output_tokens == 0:
                    output_tokens = self.count_tokens(text, model)

                cost = self._estimate_cost(model, input_tokens, output_tokens)
                return text, input_tokens, output_tokens, cost

            except ProviderError:
                raise
            except httpx.TimeoutException as e:
                last_err = e
                if attempt <= self._max_retries:
                    await asyncio.sleep(self._backoff_base * (2 ** (attempt - 1)))
                    continue
                break
            except httpx.RequestError as e:
                last_err = e
                if attempt <= self._max_retries:
                    await asyncio.sleep(self._backoff_base * (2 ** (attempt - 1)))
                    continue
                break
            except Exception as e:
                last_err = e
                if attempt <= self._max_retries:
                    await asyncio.sleep(self._backoff_base * (2 ** (attempt - 1)))
                    continue
                break

        if isinstance(last_err, httpx.TimeoutException):
            raise ProviderError(f"Anthropic request timed out: {last_err}")
        if isinstance(last_err, httpx.RequestError):
            raise ProviderError(f"Anthropic request failed: {last_err}")
        raise ProviderError(f"Anthropic provider failed after retries: {last_err}")

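A usage sketch for this provider (illustrative only; it assumes ANTHROPIC_API_KEY is set and uses one of the model names listed in MODEL_INFO above):

```python
# Hypothetical call site -- not part of the package source.
import asyncio

from aleph.providers.anthropic import AnthropicProvider


async def demo() -> None:
    provider = AnthropicProvider()  # reads ANTHROPIC_API_KEY from the environment
    text, in_tok, out_tok, cost = await provider.complete(
        messages=[
            {"role": "system", "content": "You are terse."},  # folded into `system` by _split_system
            {"role": "user", "content": "Say hello."},
        ],
        model="claude-sonnet-4-20250514",
        max_tokens=128,
    )
    print(text, in_tok, out_tok, f"${cost:.4f}")
    await provider.aclose()  # closes the internally owned httpx client


asyncio.run(demo())
```
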
aleph/providers/base.py
ADDED
@@ -0,0 +1,59 @@
"""Provider abstraction.

Aleph's core logic is provider-agnostic. Any provider can be used as long as it
implements the LLMProvider protocol.

The interface is intentionally small:
- complete(): async LLM call
- count_tokens(): token estimate
- get_context_limit()/get_output_limit(): model metadata
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol

from ..types import Message


class ProviderError(RuntimeError):
    """Raised when a provider call fails."""


class LLMProvider(Protocol):
    """Protocol all providers must implement."""

    @property
    def provider_name(self) -> str:
        ...

    async def complete(
        self,
        messages: list[Message],
        model: str,
        max_tokens: int = 4096,
        temperature: float = 0.0,
        stop_sequences: list[str] | None = None,
        timeout_seconds: float | None = None,
    ) -> tuple[str, int, int, float]:
        """Return (response_text, input_tokens, output_tokens, cost_usd)."""

    def count_tokens(self, text: str, model: str) -> int:
        ...

    def get_context_limit(self, model: str) -> int:
        ...

    def get_output_limit(self, model: str) -> int:
        ...


@dataclass(slots=True)
class ModelPricing:
    """Token and cost metadata for a model (rough)."""

    context_limit: int
    output_limit: int
    input_cost_per_1k: float
    output_cost_per_1k: float

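Because LLMProvider is a structural typing.Protocol, any object with these methods satisfies it; no subclassing or registration is required. A minimal test double might look like this (illustrative sketch, not part of the package):

```python
# Hypothetical test double that structurally satisfies LLMProvider.
from aleph.types import Message


class EchoProvider:
    """Returns the last user message verbatim; useful for wiring tests."""

    @property
    def provider_name(self) -> str:
        return "echo"

    async def complete(
        self,
        messages: list[Message],
        model: str,
        max_tokens: int = 4096,
        temperature: float = 0.0,
        stop_sequences: list[str] | None = None,
        timeout_seconds: float | None = None,
    ) -> tuple[str, int, int, float]:
        text = messages[-1].get("content", "") if messages else ""
        tokens = max(1, len(text) // 4)  # same rough 4-chars-per-token heuristic
        return text, tokens, tokens, 0.0

    def count_tokens(self, text: str, model: str) -> int:
        return max(1, len(text) // 4)

    def get_context_limit(self, model: str) -> int:
        return 128_000

    def get_output_limit(self, model: str) -> int:
        return 4_096
```
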
aleph/providers/openai.py
ADDED
@@ -0,0 +1,224 @@
"""OpenAI provider.

Implements Aleph's provider interface against OpenAI's Chat Completions API.

This module uses bare HTTP via httpx for minimal dependencies.
"""

from __future__ import annotations

import asyncio
import json
import os

import httpx

from .base import ModelPricing, ProviderError
from ..utils.tokens import estimate_tokens, try_count_tokens_tiktoken
from ..types import Message


class OpenAIProvider:
    """OpenAI provider via /v1/chat/completions."""

    MODEL_INFO: dict[str, ModelPricing] = {
        # NOTE: Prices/limits change; these defaults are for budgeting/telemetry.
        "gpt-4o": ModelPricing(128_000, 16_384, 0.0025, 0.01),
        "gpt-4o-mini": ModelPricing(128_000, 16_384, 0.00015, 0.0006),
        "gpt-4-turbo": ModelPricing(128_000, 4_096, 0.01, 0.03),
    }

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str = "https://api.openai.com",
        organization: str | None = None,
        http_client: httpx.AsyncClient | None = None,
        max_retries: int = 3,
        backoff_base_seconds: float = 0.8,
    ) -> None:
        self._api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self._api_key:
            self._api_key = ""
        self._base_url = base_url.rstrip("/")
        self._org = organization or os.getenv("OPENAI_ORG_ID")
        self._client = http_client
        self._owned_client = http_client is None
        self._max_retries = max_retries
        self._backoff_base = backoff_base_seconds

    @property
    def provider_name(self) -> str:
        return "openai"

    async def _get_client(self) -> httpx.AsyncClient:
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
        return self._client

    async def aclose(self) -> None:
        if self._owned_client and self._client is not None:
            await self._client.aclose()
            self._client = None

    def count_tokens(self, text: str, model: str) -> int:
        # Best-effort: use tiktoken if installed.
        n = try_count_tokens_tiktoken(text, model)
        if n is not None:
            return n
        return estimate_tokens(text)

    def get_context_limit(self, model: str) -> int:
        info = self.MODEL_INFO.get(model)
        return info.context_limit if info else 128_000

    def get_output_limit(self, model: str) -> int:
        info = self.MODEL_INFO.get(model)
        return info.output_limit if info else 4_096

    def _estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        info = self.MODEL_INFO.get(model)
        if not info:
            return 0.0
        return (input_tokens / 1000.0) * info.input_cost_per_1k + (output_tokens / 1000.0) * info.output_cost_per_1k

    async def complete(
        self,
        messages: list[Message],
        model: str,
        max_tokens: int = 4096,
        temperature: float = 0.0,
        stop_sequences: list[str] | None = None,
        timeout_seconds: float | None = None,
    ) -> tuple[str, int, int, float]:
        if not self._api_key:
            raise ProviderError("OpenAI API key not set. Provide api_key=... or set OPENAI_API_KEY.")

        url = f"{self._base_url}/v1/chat/completions"
        headers = {
            "authorization": f"Bearer {self._api_key}",
            "content-type": "application/json",
        }
        if self._org:
            headers["openai-organization"] = self._org

        payload: dict[str, object] = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if stop_sequences:
            payload["stop"] = stop_sequences

        client = await self._get_client()
        timeout = httpx.Timeout(timeout_seconds) if timeout_seconds else client.timeout

        def _parse_retry_after_seconds(resp: httpx.Response) -> float | None:
            ra = resp.headers.get("retry-after")
            if not ra:
                return None
            try:
                return max(0.0, float(ra.strip()))
            except ValueError:
                return None

        def _format_http_error(resp: httpx.Response) -> str:
            request_id = resp.headers.get("x-request-id")
            retry_after = _parse_retry_after_seconds(resp)

            msg = None
            try:
                data = resp.json()
                err = data.get("error") if isinstance(data, dict) else None
                if isinstance(err, dict):
                    raw = err.get("message")
                    if isinstance(raw, str) and raw.strip():
                        msg = raw.strip()
            except Exception:
                msg = None

            if msg is None:
                body = (resp.text or "").strip()
                msg = body[:500] if body else "(no response body)"

            parts = [f"OpenAI API error {resp.status_code}: {msg}"]
            if request_id:
                parts.append(f"request_id={request_id}")
            if retry_after is not None:
                parts.append(f"retry_after_seconds={retry_after:.0f}")
            if len(parts) == 1:
                return parts[0]
            return parts[0] + " (" + ", ".join(parts[1:]) + ")"

        last_err: Exception | None = None
        for attempt in range(1, self._max_retries + 2):
            try:
                resp = await client.post(url, headers=headers, json=payload, timeout=timeout)

                if resp.status_code >= 400:
                    retryable_status = resp.status_code in {408, 409, 429, 500, 502, 503, 504}
                    if retryable_status and attempt <= self._max_retries:
                        retry_after = _parse_retry_after_seconds(resp)
                        delay = retry_after if retry_after is not None else (self._backoff_base * (2 ** (attempt - 1)))
                        await asyncio.sleep(delay)
                        continue
                    raise ProviderError(_format_http_error(resp))

                try:
                    data = resp.json()
                except json.JSONDecodeError as e:
                    raise ProviderError(f"Invalid JSON response from OpenAI: {e}")

                if not isinstance(data, dict):
                    raise ProviderError(f"OpenAI API returned invalid JSON type: {type(data)}")

                choices = data.get("choices") or []
                if not choices:
                    raise ProviderError("OpenAI API returned no choices")

                message = choices[0].get("message") if isinstance(choices[0], dict) else None
                if not isinstance(message, dict):
                    message = {}
                text = (message.get("content") or "").strip()

                usage = data.get("usage") or {}
                if not isinstance(usage, dict):
                    usage = {}
                input_tokens = int(usage.get("prompt_tokens") or 0)
                output_tokens = int(usage.get("completion_tokens") or 0)

                if input_tokens == 0:
                    input_tokens = sum(self.count_tokens(m.get("content", ""), model) for m in messages)
                if output_tokens == 0:
                    output_tokens = self.count_tokens(text, model)

                cost = self._estimate_cost(model, input_tokens, output_tokens)
                return text, input_tokens, output_tokens, cost

            except ProviderError:
                raise
            except httpx.TimeoutException as e:
                last_err = e
                if attempt <= self._max_retries:
                    await asyncio.sleep(self._backoff_base * (2 ** (attempt - 1)))
                    continue
                break
            except httpx.RequestError as e:
                last_err = e
                if attempt <= self._max_retries:
                    await asyncio.sleep(self._backoff_base * (2 ** (attempt - 1)))
                    continue
                break
            except Exception as e:
                last_err = e
                if attempt <= self._max_retries:
                    await asyncio.sleep(self._backoff_base * (2 ** (attempt - 1)))
                    continue
                break

        if isinstance(last_err, httpx.TimeoutException):
            raise ProviderError(f"OpenAI request timed out: {last_err}")
        if isinstance(last_err, httpx.RequestError):
            raise ProviderError(f"OpenAI request failed: {last_err}")
        raise ProviderError(f"OpenAI provider failed after retries: {last_err}")

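The constructor accepts an externally managed httpx.AsyncClient and a custom base_url, so callers can reuse a connection pool or point the provider at an OpenAI-compatible endpoint; when a client is injected, the provider does not close it (see the _owned_client flag). A sketch (illustrative only; assumes OPENAI_API_KEY is set):

```python
# Hypothetical call site -- not part of the package source.
import asyncio

import httpx

from aleph.providers.openai import OpenAIProvider


async def demo() -> None:
    async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
        provider = OpenAIProvider(http_client=client)  # provider will not close this client
        text, *_ = await provider.complete(
            messages=[{"role": "user", "content": "One-line summary of what a REPL is."}],
            model="gpt-4o-mini",
        )
        print(text)


asyncio.run(demo())
```
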
aleph/providers/registry.py
ADDED
@@ -0,0 +1,22 @@
"""Provider factory."""

from __future__ import annotations

from .anthropic import AnthropicProvider
from .openai import OpenAIProvider
from .base import LLMProvider


PROVIDERS: dict[str, type[LLMProvider]] = {
    "anthropic": AnthropicProvider,
    "openai": OpenAIProvider,
}


def get_provider(name: str, **kwargs: object) -> LLMProvider:
    """Instantiate a provider by name."""

    key = name.lower().strip()
    if key not in PROVIDERS:
        raise ValueError(f"Unknown provider: {name}. Available: {sorted(PROVIDERS.keys())}")
    return PROVIDERS[key](**kwargs)

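Tying the provider modules together, the factory is the string-keyed entry point re-exported from aleph/providers/__init__.py; keyword arguments are forwarded to the selected provider's constructor. A short usage sketch (illustrative only):

```python
# Hypothetical usage of the factory -- not part of the package source.
from aleph.providers import get_provider

provider = get_provider("anthropic")                      # picks up ANTHROPIC_API_KEY
openai_provider = get_provider("openai", max_retries=5)   # kwargs pass through to __init__

# An unknown name raises, listing the registered keys:
#   get_provider("mistral")
#   ValueError: Unknown provider: mistral. Available: ['anthropic', 'openai']
```
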