axion-code 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- axion/__init__.py +3 -0
- axion/api/__init__.py +0 -0
- axion/api/anthropic.py +460 -0
- axion/api/client.py +259 -0
- axion/api/error.py +161 -0
- axion/api/ollama.py +597 -0
- axion/api/openai_compat.py +805 -0
- axion/api/openai_responses.py +627 -0
- axion/api/prompt_cache.py +31 -0
- axion/api/sse.py +98 -0
- axion/api/types.py +451 -0
- axion/cli/__init__.py +0 -0
- axion/cli/init_cmd.py +50 -0
- axion/cli/input.py +290 -0
- axion/cli/main.py +2953 -0
- axion/cli/render.py +489 -0
- axion/cli/tui.py +766 -0
- axion/commands/__init__.py +0 -0
- axion/commands/handlers/__init__.py +0 -0
- axion/commands/handlers/agents.py +51 -0
- axion/commands/handlers/builtin_commands.py +367 -0
- axion/commands/handlers/mcp.py +59 -0
- axion/commands/handlers/models.py +75 -0
- axion/commands/handlers/plugins.py +55 -0
- axion/commands/handlers/skills.py +61 -0
- axion/commands/parsing.py +317 -0
- axion/commands/registry.py +166 -0
- axion/compat_harness/__init__.py +0 -0
- axion/compat_harness/extractor.py +145 -0
- axion/plugins/__init__.py +0 -0
- axion/plugins/hooks.py +22 -0
- axion/plugins/manager.py +391 -0
- axion/plugins/manifest.py +270 -0
- axion/runtime/__init__.py +0 -0
- axion/runtime/bash.py +388 -0
- axion/runtime/bootstrap.py +39 -0
- axion/runtime/claude_subscription.py +300 -0
- axion/runtime/compact.py +233 -0
- axion/runtime/config.py +397 -0
- axion/runtime/conversation.py +1073 -0
- axion/runtime/file_ops.py +613 -0
- axion/runtime/git.py +213 -0
- axion/runtime/hooks.py +235 -0
- axion/runtime/image.py +212 -0
- axion/runtime/lanes.py +282 -0
- axion/runtime/lsp.py +425 -0
- axion/runtime/mcp/__init__.py +0 -0
- axion/runtime/mcp/client.py +76 -0
- axion/runtime/mcp/lifecycle.py +96 -0
- axion/runtime/mcp/stdio.py +318 -0
- axion/runtime/mcp/tool_bridge.py +79 -0
- axion/runtime/memory.py +196 -0
- axion/runtime/oauth.py +329 -0
- axion/runtime/openai_subscription.py +346 -0
- axion/runtime/permissions.py +247 -0
- axion/runtime/plan_mode.py +96 -0
- axion/runtime/policy_engine.py +259 -0
- axion/runtime/prompt.py +586 -0
- axion/runtime/recovery.py +261 -0
- axion/runtime/remote.py +28 -0
- axion/runtime/sandbox.py +68 -0
- axion/runtime/scheduler.py +231 -0
- axion/runtime/session.py +365 -0
- axion/runtime/sharing.py +159 -0
- axion/runtime/skills.py +124 -0
- axion/runtime/tasks.py +258 -0
- axion/runtime/usage.py +241 -0
- axion/runtime/workers.py +186 -0
- axion/telemetry/__init__.py +0 -0
- axion/telemetry/events.py +67 -0
- axion/telemetry/profile.py +49 -0
- axion/telemetry/sink.py +60 -0
- axion/telemetry/tracer.py +95 -0
- axion/tools/__init__.py +0 -0
- axion/tools/lane_completion.py +33 -0
- axion/tools/registry.py +853 -0
- axion/tools/tool_search.py +226 -0
- axion_code-1.0.0.dist-info/METADATA +709 -0
- axion_code-1.0.0.dist-info/RECORD +82 -0
- axion_code-1.0.0.dist-info/WHEEL +4 -0
- axion_code-1.0.0.dist-info/entry_points.txt +2 -0
- axion_code-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,805 @@
|
|
|
1
|
+
"""OpenAI-compatible API client with streaming support.
|
|
2
|
+
|
|
3
|
+
Maps to: rust/crates/api/src/providers/openai_compat.rs
|
|
4
|
+
|
|
5
|
+
Supports xAI (Grok) and OpenAI providers via the OpenAI chat completions
|
|
6
|
+
API format, translating between Anthropic-style request/response types and
|
|
7
|
+
OpenAI's wire format.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import asyncio
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
from collections import OrderedDict
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import Any, AsyncIterator
|
|
19
|
+
|
|
20
|
+
import httpx
|
|
21
|
+
|
|
22
|
+
from axion.api.error import (
|
|
23
|
+
ApiError,
|
|
24
|
+
ApiResponseError,
|
|
25
|
+
BackoffOverflowError,
|
|
26
|
+
HttpError,
|
|
27
|
+
InvalidSseFrameError,
|
|
28
|
+
MissingCredentialsError,
|
|
29
|
+
RetriesExhaustedError,
|
|
30
|
+
)
|
|
31
|
+
from axion.api.types import (
|
|
32
|
+
ContentBlockDeltaEvent,
|
|
33
|
+
ContentBlockStartEvent,
|
|
34
|
+
ContentBlockStopEvent,
|
|
35
|
+
InputJsonDelta,
|
|
36
|
+
InputMessage,
|
|
37
|
+
MessageDelta,
|
|
38
|
+
MessageDeltaEvent,
|
|
39
|
+
MessageRequest,
|
|
40
|
+
MessageResponse,
|
|
41
|
+
MessageStartEvent,
|
|
42
|
+
MessageStopEvent,
|
|
43
|
+
OutputContentBlock,
|
|
44
|
+
StreamEvent,
|
|
45
|
+
TextDelta,
|
|
46
|
+
TextInputBlock,
|
|
47
|
+
TextOutputBlock,
|
|
48
|
+
ToolChoice,
|
|
49
|
+
ToolDefinition,
|
|
50
|
+
ToolResultBlock,
|
|
51
|
+
ToolResultJsonContent,
|
|
52
|
+
ToolResultTextContent,
|
|
53
|
+
ToolUseInputBlock,
|
|
54
|
+
ToolUseOutputBlock,
|
|
55
|
+
Usage,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
logger = logging.getLogger(__name__)
|
|
59
|
+
|
|
60
|
+
# ---------------------------------------------------------------------------
|
|
61
|
+
# Constants
|
|
62
|
+
# ---------------------------------------------------------------------------
|
|
63
|
+
|
|
64
|
+
# Default API endpoints used when the provider's base-URL env var is unset.
DEFAULT_XAI_BASE_URL = "https://api.x.ai/v1"
DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
# Response headers that may carry the server-assigned request ID
# (checked in this order by _request_id_from_headers).
REQUEST_ID_HEADER = "request-id"
ALT_REQUEST_ID_HEADER = "x-request-id"
# Retry tuning: exponential backoff starting at 200 ms, capped at 2 s,
# with at most 2 retries after the initial attempt.
DEFAULT_INITIAL_BACKOFF_MS = 200
DEFAULT_MAX_BACKOFF_MS = 2000
DEFAULT_MAX_RETRIES = 2
# HTTP status codes treated as transient and therefore retryable.
RETRYABLE_STATUS_CODES = frozenset({408, 409, 429, 500, 502, 503, 504})
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# ---------------------------------------------------------------------------
|
|
75
|
+
# Config presets
|
|
76
|
+
# ---------------------------------------------------------------------------
|
|
77
|
+
|
|
78
|
+
@dataclass(frozen=True)
class OpenAiCompatConfig:
    """Provider configuration for an OpenAI-compatible endpoint."""

    # Human-readable provider label; also used as the saved-key filename
    # stem by OpenAiCompatClient.from_env.
    provider_name: str
    # Environment variable that holds the provider API key.
    api_key_env: str
    # Environment variable that may override the base URL.
    base_url_env: str
    # Base URL used when the env override is unset or empty.
    default_base_url: str

    @classmethod
    def xai(cls) -> OpenAiCompatConfig:
        """Preset configuration for the xAI (Grok) API."""
        return cls(
            provider_name="xAI",
            api_key_env="XAI_API_KEY",
            base_url_env="XAI_BASE_URL",
            default_base_url=DEFAULT_XAI_BASE_URL,
        )

    @classmethod
    def openai(cls) -> OpenAiCompatConfig:
        """Preset configuration for the OpenAI API."""
        return cls(
            provider_name="OpenAI",
            api_key_env="OPENAI_API_KEY",
            base_url_env="OPENAI_BASE_URL",
            default_base_url=DEFAULT_OPENAI_BASE_URL,
        )

    @property
    def credential_env_vars(self) -> list[str]:
        """Environment variables that may supply credentials for this provider."""
        return [self.api_key_env]
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
# ---------------------------------------------------------------------------
|
|
111
|
+
# Client
|
|
112
|
+
# ---------------------------------------------------------------------------
|
|
113
|
+
|
|
114
|
+
class OpenAiCompatClient:
    """HTTP client for OpenAI-compatible chat completion APIs.

    Translates Anthropic-style ``MessageRequest``s into OpenAI chat
    completion calls, with bounded exponential-backoff retries for
    transient failures.

    Maps to: rust OpenAiCompatClient
    """

    def __init__(
        self,
        api_key: str,
        config: OpenAiCompatConfig,
        *,
        base_url: str | None = None,
        max_retries: int = DEFAULT_MAX_RETRIES,
        initial_backoff_ms: int = DEFAULT_INITIAL_BACKOFF_MS,
        max_backoff_ms: int = DEFAULT_MAX_BACKOFF_MS,
    ) -> None:
        # Generous 120s timeout to accommodate slow model generations.
        self._http = httpx.AsyncClient(timeout=httpx.Timeout(120.0))
        self._api_key = api_key
        self._config = config
        # Explicit base_url argument wins over the env-override/default chain.
        self._base_url = base_url or _read_base_url(config)
        self._max_retries = max_retries
        self._initial_backoff_ms = initial_backoff_ms
        self._max_backoff_ms = max_backoff_ms

    @classmethod
    def from_env(cls, config: OpenAiCompatConfig) -> OpenAiCompatClient:
        """Create a client reading the API key from env or saved file.

        Raises:
            MissingCredentialsError: If no key is found in the environment
                or in the on-disk credentials file.
        """
        api_key = _read_env_non_empty(config.api_key_env)

        # Check saved key file (from `axion login --provider <name>`)
        if api_key is None:
            from pathlib import Path
            key_path = Path.home() / ".axion" / "credentials" / f"{config.provider_name}.key"
            if key_path.exists():
                saved = key_path.read_text(encoding="utf-8").strip()
                if saved:
                    api_key = saved
                    # Export so child processes / later lookups see the key too.
                    os.environ[config.api_key_env] = saved

        if api_key is None:
            raise MissingCredentialsError(
                provider=config.provider_name,
                env_vars=config.credential_env_vars,
            )
        return cls(api_key=api_key, config=config)

    @property
    def config(self) -> OpenAiCompatConfig:
        """The provider configuration this client was built with."""
        return self._config

    # -- Public API ----------------------------------------------------------

    async def send_message(self, request: MessageRequest) -> MessageResponse:
        """Send a non-streaming chat completion request."""
        # Rebuild the request with stream forced off so callers cannot
        # accidentally pass a streaming request through this path.
        req = MessageRequest(
            model=request.model,
            max_tokens=request.max_tokens,
            messages=request.messages,
            system=request.system,
            tools=request.tools,
            tool_choice=request.tool_choice,
            stream=False,
        )
        response = await self._send_with_retry(req)
        request_id = _request_id_from_headers(response.headers)
        payload = response.json()
        normalized = _normalize_response(req.model, payload)
        # Prefer an ID from the body; fall back to the transport header.
        if normalized.request_id is None:
            normalized.request_id = request_id
        return normalized

    async def stream_message(
        self, request: MessageRequest
    ) -> AsyncIterator[StreamEvent]:
        """Send a streaming request and yield Anthropic-format StreamEvents."""
        # Mirror of send_message, but with stream forced on.
        req = MessageRequest(
            model=request.model,
            max_tokens=request.max_tokens,
            messages=request.messages,
            system=request.system,
            tools=request.tools,
            tool_choice=request.tool_choice,
            stream=True,
        )
        response = await self._send_with_retry(req)

        parser = OpenAiSseParser()
        state = _StreamState(model=req.model)

        # NOTE(review): _send_raw_request uses a plain POST, so httpx has
        # already buffered the full response body before this loop runs;
        # aiter_bytes() then replays buffered content rather than streaming
        # incrementally off the wire. Confirm whether true incremental
        # delivery (httpx AsyncClient.stream) is required here.
        async for raw_chunk in response.aiter_bytes():
            for chunk in parser.push(raw_chunk):
                for event in state.ingest_chunk(chunk):
                    yield event

        # Finalize the stream
        for event in state.finish():
            yield event

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._http.aclose()

    # -- Retry logic ---------------------------------------------------------

    async def _send_with_retry(self, request: MessageRequest) -> httpx.Response:
        """POST the request, retrying retryable ApiErrors up to max_retries.

        Raises:
            ApiError: Immediately, if the failure is not retryable.
            RetriesExhaustedError: After the retry budget is spent.
        """
        attempts = 0
        last_error: ApiError | None = None

        while True:
            attempts += 1
            try:
                response = await self._send_raw_request(request)
                _expect_success(response)
                return response
            except ApiError as err:
                # Retry only retryable errors while budget remains.
                if err.is_retryable() and attempts <= self._max_retries:
                    last_error = err
                    backoff = self._backoff_for_attempt(attempts)
                    await asyncio.sleep(backoff)
                    continue
                if not err.is_retryable():
                    raise
                # Retryable but budget exhausted: fall through to the
                # RetriesExhaustedError below.
                last_error = err
                break

        raise RetriesExhaustedError(attempts=attempts, last_error=last_error)  # type: ignore[arg-type]

    async def _send_raw_request(self, request: MessageRequest) -> httpx.Response:
        """Issue one POST to the chat/completions endpoint.

        Transport-level failures are wrapped in HttpError; HTTP error
        statuses are NOT raised here (see _expect_success).
        """
        url = _chat_completions_endpoint(self._base_url)
        body = _build_chat_completion_request(request, self._config)
        try:
            response = await self._http.post(
                url,
                json=body,
                headers={
                    "content-type": "application/json",
                    "authorization": f"Bearer {self._api_key}",
                },
            )
            return response
        except httpx.HTTPError as exc:
            raise HttpError(str(exc), cause=exc) from exc

    def _backoff_for_attempt(self, attempt: int) -> float:
        """Exponential backoff in seconds: initial * 2^(attempt-1), capped."""
        try:
            # Python ints do not overflow on shifts; a negative shift count
            # raises ValueError, which is mapped to BackoffOverflowError.
            multiplier = 1 << (attempt - 1)
        except (OverflowError, ValueError):
            raise BackoffOverflowError(attempt=attempt, base_delay_ms=self._initial_backoff_ms)

        delay_ms = self._initial_backoff_ms * multiplier
        delay_ms = min(delay_ms, self._max_backoff_ms)
        return delay_ms / 1000.0
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
# ---------------------------------------------------------------------------
|
|
270
|
+
# SSE parser (OpenAI format)
|
|
271
|
+
# ---------------------------------------------------------------------------
|
|
272
|
+
|
|
273
|
+
class OpenAiSseParser:
    """Incremental SSE parser for OpenAI's streaming format.

    Parses ``data: {...}\\n\\n`` frames and ``data: [DONE]`` terminators.
    """

    def __init__(self) -> None:
        # Bytes received so far that do not yet form a complete frame.
        self._buffer = bytearray()

    def push(self, chunk: bytes) -> list[dict[str, Any]]:
        """Push raw bytes and return any fully-parsed ChatCompletionChunk dicts."""
        self._buffer.extend(chunk)
        chunks: list[dict[str, Any]] = []

        frame = self._next_frame()
        while frame is not None:
            payload = _parse_sse_frame(frame)
            if payload is not None:
                chunks.append(payload)
            frame = self._next_frame()

        return chunks

    def _next_frame(self) -> str | None:
        """Pop the next complete frame off the buffer, or return None.

        LF-LF delimiters are checked before CRLF-CRLF, preserving the
        original lookup priority.
        """
        for delimiter in (b"\n\n", b"\r\n\r\n"):
            cut = self._buffer.find(delimiter)
            if cut != -1:
                frame = bytes(self._buffer[:cut]).decode("utf-8", errors="replace")
                del self._buffer[:cut + len(delimiter)]
                return frame
        return None
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def _parse_sse_frame(frame: str) -> dict[str, Any] | None:
|
|
312
|
+
"""Parse a single SSE frame into a ChatCompletionChunk dict."""
|
|
313
|
+
trimmed = frame.strip()
|
|
314
|
+
if not trimmed:
|
|
315
|
+
return None
|
|
316
|
+
|
|
317
|
+
data_lines: list[str] = []
|
|
318
|
+
for line in trimmed.splitlines():
|
|
319
|
+
if line.startswith(":"):
|
|
320
|
+
continue
|
|
321
|
+
if line.startswith("data:"):
|
|
322
|
+
data_lines.append(line[len("data:"):].lstrip())
|
|
323
|
+
|
|
324
|
+
if not data_lines:
|
|
325
|
+
return None
|
|
326
|
+
|
|
327
|
+
payload = "\n".join(data_lines)
|
|
328
|
+
if payload == "[DONE]":
|
|
329
|
+
return None
|
|
330
|
+
|
|
331
|
+
try:
|
|
332
|
+
return json.loads(payload)
|
|
333
|
+
except json.JSONDecodeError as exc:
|
|
334
|
+
raise InvalidSseFrameError(f"Invalid JSON in SSE data: {exc}") from exc
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# ---------------------------------------------------------------------------
|
|
338
|
+
# Stream state machine
|
|
339
|
+
# ---------------------------------------------------------------------------
|
|
340
|
+
|
|
341
|
+
class _ToolCallState:
|
|
342
|
+
"""Accumulates streamed tool call deltas for a single tool call."""
|
|
343
|
+
|
|
344
|
+
__slots__ = ("openai_index", "id", "name", "arguments", "emitted_len", "started", "stopped")
|
|
345
|
+
|
|
346
|
+
def __init__(self, openai_index: int = 0) -> None:
|
|
347
|
+
self.openai_index = openai_index
|
|
348
|
+
self.id: str | None = None
|
|
349
|
+
self.name: str | None = None
|
|
350
|
+
self.arguments: str = ""
|
|
351
|
+
self.emitted_len: int = 0
|
|
352
|
+
self.started: bool = False
|
|
353
|
+
self.stopped: bool = False
|
|
354
|
+
|
|
355
|
+
def apply(self, tool_call: dict[str, Any]) -> None:
|
|
356
|
+
self.openai_index = tool_call.get("index", self.openai_index)
|
|
357
|
+
if "id" in tool_call and tool_call["id"]:
|
|
358
|
+
self.id = tool_call["id"]
|
|
359
|
+
func = tool_call.get("function", {})
|
|
360
|
+
if func.get("name"):
|
|
361
|
+
self.name = func["name"]
|
|
362
|
+
if func.get("arguments"):
|
|
363
|
+
self.arguments += func["arguments"]
|
|
364
|
+
|
|
365
|
+
@property
|
|
366
|
+
def block_index(self) -> int:
|
|
367
|
+
"""Anthropic block index: tool calls start after the text block at 0."""
|
|
368
|
+
return self.openai_index + 1
|
|
369
|
+
|
|
370
|
+
def start_event(self) -> ContentBlockStartEvent | None:
|
|
371
|
+
if self.name is None:
|
|
372
|
+
return None
|
|
373
|
+
tool_id = self.id or f"tool_call_{self.openai_index}"
|
|
374
|
+
return ContentBlockStartEvent(
|
|
375
|
+
index=self.block_index,
|
|
376
|
+
content_block=ToolUseOutputBlock(id=tool_id, name=self.name, input={}),
|
|
377
|
+
)
|
|
378
|
+
|
|
379
|
+
def delta_event(self) -> ContentBlockDeltaEvent | None:
|
|
380
|
+
if self.emitted_len >= len(self.arguments):
|
|
381
|
+
return None
|
|
382
|
+
delta_text = self.arguments[self.emitted_len:]
|
|
383
|
+
self.emitted_len = len(self.arguments)
|
|
384
|
+
return ContentBlockDeltaEvent(
|
|
385
|
+
index=self.block_index,
|
|
386
|
+
delta=InputJsonDelta(partial_json=delta_text),
|
|
387
|
+
)
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
class _StreamState:
    """Translates a sequence of OpenAI ChatCompletionChunks into Anthropic StreamEvents.

    Block-index convention: the text block always occupies index 0; tool
    call blocks occupy (openai tool_calls index + 1).
    """

    def __init__(self, model: str) -> None:
        # Model name used for the synthetic MessageStart when the first
        # chunk does not carry one.
        self._model = model
        self._message_started = False  # MessageStartEvent emitted?
        self._text_started = False     # text ContentBlockStart emitted?
        self._text_finished = False    # text ContentBlockStop emitted?
        self._finished = False         # finish() already ran?
        self._stop_reason: str | None = None
        self._usage: Usage | None = None
        # Keyed by OpenAI tool_calls index; insertion-ordered so closing
        # events in finish() come out in arrival order.
        self._tool_calls: OrderedDict[int, _ToolCallState] = OrderedDict()

    def ingest_chunk(self, chunk: dict[str, Any]) -> list[StreamEvent]:
        """Process one ChatCompletionChunk dict and return resulting StreamEvents."""
        events: list[StreamEvent] = []

        # Emit MessageStart on first chunk
        if not self._message_started:
            self._message_started = True
            events.append(MessageStartEvent(
                message=MessageResponse(
                    id=chunk.get("id", ""),
                    type="message",
                    role="assistant",
                    content=[],
                    model=chunk.get("model") or self._model,
                    usage=Usage(),
                ),
            ))

        # Track usage if present (later chunks overwrite earlier ones, so
        # the final cumulative totals win).
        if "usage" in chunk and chunk["usage"]:
            u = chunk["usage"]
            self._usage = Usage(
                input_tokens=u.get("prompt_tokens", 0),
                output_tokens=u.get("completion_tokens", 0),
            )

        # Process choices
        for choice in chunk.get("choices", []):
            delta = choice.get("delta", {})

            # Text content: lazily open block 0 on the first text delta.
            content = delta.get("content")
            if content:
                if not self._text_started:
                    self._text_started = True
                    events.append(ContentBlockStartEvent(
                        index=0,
                        content_block=TextOutputBlock(text=""),
                    ))
                events.append(ContentBlockDeltaEvent(
                    index=0,
                    delta=TextDelta(text=content),
                ))

            # Tool calls: accumulate deltas per OpenAI index.
            for tc in delta.get("tool_calls", []):
                idx = tc.get("index", 0)
                if idx not in self._tool_calls:
                    self._tool_calls[idx] = _ToolCallState(openai_index=idx)
                state = self._tool_calls[idx]
                state.apply(tc)

                if not state.started:
                    start_ev = state.start_event()
                    if start_ev is not None:
                        state.started = True
                        events.append(start_ev)
                    else:
                        # No function name yet: defer until a later chunk
                        # (or finish()) supplies it.
                        continue

                delta_ev = state.delta_event()
                if delta_ev is not None:
                    events.append(delta_ev)

                if choice.get("finish_reason") == "tool_calls" and not state.stopped:
                    state.stopped = True
                    events.append(ContentBlockStopEvent(index=state.block_index))

            # Finish reason: record it and, for tool_calls, close any tool
            # blocks still open (e.g. ones updated in earlier chunks).
            finish_reason = choice.get("finish_reason")
            if finish_reason:
                self._stop_reason = _normalize_finish_reason(finish_reason)
                if finish_reason == "tool_calls":
                    for state in self._tool_calls.values():
                        if state.started and not state.stopped:
                            state.stopped = True
                            events.append(ContentBlockStopEvent(index=state.block_index))

        return events

    def finish(self) -> list[StreamEvent]:
        """Finalize the stream, emitting closing events.

        Idempotent: subsequent calls return an empty list.
        """
        if self._finished:
            return []
        self._finished = True

        events: list[StreamEvent] = []

        # Close text block if still open
        if self._text_started and not self._text_finished:
            self._text_finished = True
            events.append(ContentBlockStopEvent(index=0))

        # Flush any un-started or un-stopped tool calls
        for state in self._tool_calls.values():
            if not state.started:
                start_ev = state.start_event()
                if start_ev is not None:
                    state.started = True
                    events.append(start_ev)
            delta_ev = state.delta_event()
            if delta_ev is not None:
                events.append(delta_ev)
            if state.started and not state.stopped:
                state.stopped = True
                events.append(ContentBlockStopEvent(index=state.block_index))

        # MessageDelta and MessageStop
        if self._message_started:
            events.append(MessageDeltaEvent(
                delta=MessageDelta(
                    stop_reason=self._stop_reason or "end_turn",
                ),
                usage=self._usage or Usage(),
            ))
            events.append(MessageStopEvent())

        return events
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
# ---------------------------------------------------------------------------
|
|
524
|
+
# Request translation (Anthropic -> OpenAI)
|
|
525
|
+
# ---------------------------------------------------------------------------
|
|
526
|
+
|
|
527
|
+
def _build_chat_completion_request(
|
|
528
|
+
request: MessageRequest, config: OpenAiCompatConfig
|
|
529
|
+
) -> dict[str, Any]:
|
|
530
|
+
"""Translate an Anthropic-style MessageRequest into an OpenAI chat completion body."""
|
|
531
|
+
messages: list[dict[str, Any]] = []
|
|
532
|
+
|
|
533
|
+
# System message
|
|
534
|
+
if request.system:
|
|
535
|
+
messages.append({"role": "system", "content": request.system})
|
|
536
|
+
|
|
537
|
+
# Conversation messages
|
|
538
|
+
for message in request.messages:
|
|
539
|
+
messages.extend(_translate_message(message))
|
|
540
|
+
|
|
541
|
+
# GPT-5+, o1, o3, o4 models use max_completion_tokens instead of max_tokens
|
|
542
|
+
model_lower = request.model.lower()
|
|
543
|
+
uses_new_param = any(
|
|
544
|
+
model_lower.startswith(p) for p in ("gpt-5", "o1", "o3", "o4")
|
|
545
|
+
)
|
|
546
|
+
|
|
547
|
+
payload: dict[str, Any] = {
|
|
548
|
+
"model": request.model,
|
|
549
|
+
"messages": messages,
|
|
550
|
+
"stream": request.stream,
|
|
551
|
+
}
|
|
552
|
+
if uses_new_param:
|
|
553
|
+
payload["max_completion_tokens"] = request.max_tokens
|
|
554
|
+
else:
|
|
555
|
+
payload["max_tokens"] = request.max_tokens
|
|
556
|
+
|
|
557
|
+
# OpenAI requires stream_options for usage in streaming
|
|
558
|
+
if request.stream and _should_request_stream_usage(config):
|
|
559
|
+
payload["stream_options"] = {"include_usage": True}
|
|
560
|
+
|
|
561
|
+
# Tools
|
|
562
|
+
if request.tools:
|
|
563
|
+
payload["tools"] = [_openai_tool_definition(t) for t in request.tools]
|
|
564
|
+
|
|
565
|
+
# Tool choice
|
|
566
|
+
if request.tool_choice is not None:
|
|
567
|
+
payload["tool_choice"] = _openai_tool_choice(request.tool_choice)
|
|
568
|
+
|
|
569
|
+
return payload
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
def _translate_message(message: InputMessage) -> list[dict[str, Any]]:
    """Translate a single Anthropic InputMessage into one or more OpenAI messages.

    Assistant messages collapse into one OpenAI message (text + tool_calls);
    user-role messages may expand into several (a multi-part user message
    plus one "tool" message per tool result).
    """
    if message.role == "assistant":
        text_parts: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        for block in message.content:
            if isinstance(block, TextInputBlock):
                text_parts.append(block.text)
            elif isinstance(block, ToolUseInputBlock):
                tool_calls.append({
                    "id": block.id,
                    "type": "function",
                    "function": {
                        "name": block.name,
                        # OpenAI expects arguments as a JSON string, not an object.
                        "arguments": json.dumps(block.input) if not isinstance(block.input, str) else block.input,
                    },
                })
        text = "".join(text_parts)
        # An assistant turn with no text and no tool calls is dropped
        # entirely (OpenAI rejects empty assistant messages).
        if not text and not tool_calls:
            return []
        msg: dict[str, Any] = {"role": "assistant"}
        if text:
            msg["content"] = text
        if tool_calls:
            msg["tool_calls"] = tool_calls
        return [msg]

    # User or other roles: expand each block into its own message
    # Check if there are images — if so, combine text+images into one message
    # (imported locally, presumably to avoid an import cycle — TODO confirm)
    from axion.api.types import ImageInputBlock

    has_images = any(isinstance(b, ImageInputBlock) for b in message.content)

    if has_images:
        # OpenAI requires multi-part content array for vision
        parts: list[dict[str, Any]] = []
        tool_results: list[dict[str, Any]] = []
        for block in message.content:
            if isinstance(block, TextInputBlock):
                parts.append({"type": "text", "text": block.text})
            elif isinstance(block, ImageInputBlock):
                parts.append(block.to_openai_dict())
            elif isinstance(block, ToolResultBlock):
                # Tool results cannot live inside a multi-part user message;
                # they become separate "tool"-role messages appended after it.
                tool_results.append({
                    "role": "tool",
                    "tool_call_id": block.tool_use_id,
                    "content": _flatten_tool_result_content(block),
                })
        results: list[dict[str, Any]] = []
        if parts:
            results.append({"role": "user", "content": parts})
        results.extend(tool_results)
        return results

    # No images: each text block becomes its own plain user message and
    # each tool result its own "tool"-role message, in original order.
    results = []
    for block in message.content:
        if isinstance(block, TextInputBlock):
            results.append({"role": "user", "content": block.text})
        elif isinstance(block, ToolResultBlock):
            content = _flatten_tool_result_content(block)
            results.append({
                "role": "tool",
                "tool_call_id": block.tool_use_id,
                "content": content,
            })
    return results
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def _flatten_tool_result_content(block: ToolResultBlock) -> str:
    """Flatten tool result content blocks into a single newline-joined string.

    Text blocks pass through as-is; JSON blocks are serialized unless the
    value is already a string. Unknown content types are skipped.
    """
    rendered: list[str] = []
    for item in block.content:
        if isinstance(item, ToolResultTextContent):
            rendered.append(item.text)
        elif isinstance(item, ToolResultJsonContent):
            value = item.value
            rendered.append(value if isinstance(value, str) else json.dumps(value))
    return "\n".join(rendered)
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def _openai_tool_definition(tool: ToolDefinition) -> dict[str, Any]:
|
|
652
|
+
"""Translate an Anthropic ToolDefinition to OpenAI function format."""
|
|
653
|
+
func: dict[str, Any] = {
|
|
654
|
+
"name": tool.name,
|
|
655
|
+
"parameters": tool.input_schema,
|
|
656
|
+
}
|
|
657
|
+
if tool.description is not None:
|
|
658
|
+
func["description"] = tool.description
|
|
659
|
+
return {"type": "function", "function": func}
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def _openai_tool_choice(tool_choice: ToolChoice) -> Any:
|
|
663
|
+
"""Translate Anthropic ToolChoice to OpenAI tool_choice format."""
|
|
664
|
+
if tool_choice.type == "auto":
|
|
665
|
+
return "auto"
|
|
666
|
+
if tool_choice.type == "any":
|
|
667
|
+
return "required"
|
|
668
|
+
if tool_choice.type == "tool" and tool_choice.name:
|
|
669
|
+
return {"type": "function", "function": {"name": tool_choice.name}}
|
|
670
|
+
return "auto"
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
# ---------------------------------------------------------------------------
|
|
674
|
+
# Response translation (OpenAI -> Anthropic)
|
|
675
|
+
# ---------------------------------------------------------------------------
|
|
676
|
+
|
|
677
|
+
def _normalize_response(model: str, data: dict[str, Any]) -> MessageResponse:
    """Translate an OpenAI ChatCompletion response to Anthropic MessageResponse.

    Only the first choice is consulted. The caller-supplied model name is
    used when the response body omits one.

    Raises:
        InvalidSseFrameError: If the response carries no choices.
    """
    choices = data.get("choices", [])
    if not choices:
        raise InvalidSseFrameError("chat completion response missing choices")

    first = choices[0]
    oai_message = first.get("message", {})

    # Text first, then tool-use blocks, mirroring the streaming layout.
    blocks: list[OutputContentBlock] = []
    text = oai_message.get("content")
    if text:
        blocks.append(TextOutputBlock(text=text))

    for call in oai_message.get("tool_calls", []):
        function = call.get("function", {})
        blocks.append(ToolUseOutputBlock(
            id=call.get("id", ""),
            name=function.get("name", ""),
            input=_parse_tool_arguments(function.get("arguments", "")),
        ))

    raw_usage = data.get("usage", {})
    usage = Usage(
        input_tokens=raw_usage.get("prompt_tokens", 0),
        output_tokens=raw_usage.get("completion_tokens", 0),
    )

    finish = first.get("finish_reason")

    return MessageResponse(
        id=data.get("id", ""),
        type="message",
        role=oai_message.get("role", "assistant"),
        content=blocks,
        model=data.get("model", "") or model,
        usage=usage,
        stop_reason=_normalize_finish_reason(finish) if finish else None,
    )
|
|
724
|
+
|
|
725
|
+
|
|
726
|
+
# ---------------------------------------------------------------------------
|
|
727
|
+
# Helpers
|
|
728
|
+
# ---------------------------------------------------------------------------
|
|
729
|
+
|
|
730
|
+
def _parse_tool_arguments(arguments: str) -> Any:
|
|
731
|
+
"""Parse tool call arguments JSON, falling back to a raw wrapper."""
|
|
732
|
+
try:
|
|
733
|
+
return json.loads(arguments)
|
|
734
|
+
except (json.JSONDecodeError, TypeError):
|
|
735
|
+
return {"raw": arguments}
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def _normalize_finish_reason(value: str) -> str:
|
|
739
|
+
"""Map OpenAI finish reasons to Anthropic stop reasons."""
|
|
740
|
+
mapping = {"stop": "end_turn", "tool_calls": "tool_use"}
|
|
741
|
+
return mapping.get(value, value)
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def _should_request_stream_usage(config: OpenAiCompatConfig) -> bool:
|
|
745
|
+
"""Only OpenAI proper requires the stream_options usage opt-in."""
|
|
746
|
+
return config.provider_name == "OpenAI"
|
|
747
|
+
|
|
748
|
+
|
|
749
|
+
def _chat_completions_endpoint(base_url: str) -> str:
|
|
750
|
+
"""Build the chat/completions URL, handling trailing slashes and full URLs."""
|
|
751
|
+
trimmed = base_url.rstrip("/")
|
|
752
|
+
if trimmed.endswith("/chat/completions"):
|
|
753
|
+
return trimmed
|
|
754
|
+
return f"{trimmed}/chat/completions"
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
def _read_env_non_empty(key: str) -> str | None:
|
|
758
|
+
"""Read an environment variable, returning None if empty or unset."""
|
|
759
|
+
value = os.environ.get(key, "")
|
|
760
|
+
return value if value else None
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
def _read_base_url(config: OpenAiCompatConfig) -> str:
|
|
764
|
+
"""Read the base URL from env or fall back to the default."""
|
|
765
|
+
return os.environ.get(config.base_url_env, "") or config.default_base_url
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
def _request_id_from_headers(headers: httpx.Headers) -> str | None:
    """Extract the request ID from response headers.

    The canonical "request-id" header wins; "x-request-id" is the fallback.
    """
    primary = headers.get(REQUEST_ID_HEADER)
    if primary:
        return primary
    return headers.get(ALT_REQUEST_ID_HEADER)
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
def _expect_success(response: httpx.Response) -> None:
    """Raise ApiResponseError for non-2xx responses.

    Attempts to pull a structured error type/message out of the JSON
    error envelope; malformed bodies still raise, just without those
    fields populated.
    """
    if response.is_success:
        return

    body = response.text
    error_type: str | None = None
    message: str | None = None
    try:
        err_obj = json.loads(body).get("error", {})
        error_type = err_obj.get("type")
        message = err_obj.get("message")
    except (json.JSONDecodeError, AttributeError):
        # Non-JSON or non-dict envelope: raise with the raw body only.
        pass

    raise ApiResponseError(
        status=response.status_code,
        error_type=error_type,
        message=message,
        request_id_val=_request_id_from_headers(response.headers),
        body=body,
        retryable=response.status_code in RETRYABLE_STATUS_CODES,
    )
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
def has_api_key(key: str) -> bool:
    """Check whether an API key environment variable is set and non-empty."""
    # _read_env_non_empty yields a non-empty string or None, so bool()
    # gives exactly the presence check.
    return bool(_read_env_non_empty(key))
|