axion-code 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- axion/__init__.py +3 -0
- axion/api/__init__.py +0 -0
- axion/api/anthropic.py +460 -0
- axion/api/client.py +259 -0
- axion/api/error.py +161 -0
- axion/api/ollama.py +597 -0
- axion/api/openai_compat.py +805 -0
- axion/api/openai_responses.py +627 -0
- axion/api/prompt_cache.py +31 -0
- axion/api/sse.py +98 -0
- axion/api/types.py +451 -0
- axion/cli/__init__.py +0 -0
- axion/cli/init_cmd.py +50 -0
- axion/cli/input.py +290 -0
- axion/cli/main.py +2953 -0
- axion/cli/render.py +489 -0
- axion/cli/tui.py +766 -0
- axion/commands/__init__.py +0 -0
- axion/commands/handlers/__init__.py +0 -0
- axion/commands/handlers/agents.py +51 -0
- axion/commands/handlers/builtin_commands.py +367 -0
- axion/commands/handlers/mcp.py +59 -0
- axion/commands/handlers/models.py +75 -0
- axion/commands/handlers/plugins.py +55 -0
- axion/commands/handlers/skills.py +61 -0
- axion/commands/parsing.py +317 -0
- axion/commands/registry.py +166 -0
- axion/compat_harness/__init__.py +0 -0
- axion/compat_harness/extractor.py +145 -0
- axion/plugins/__init__.py +0 -0
- axion/plugins/hooks.py +22 -0
- axion/plugins/manager.py +391 -0
- axion/plugins/manifest.py +270 -0
- axion/runtime/__init__.py +0 -0
- axion/runtime/bash.py +388 -0
- axion/runtime/bootstrap.py +39 -0
- axion/runtime/claude_subscription.py +300 -0
- axion/runtime/compact.py +233 -0
- axion/runtime/config.py +397 -0
- axion/runtime/conversation.py +1073 -0
- axion/runtime/file_ops.py +613 -0
- axion/runtime/git.py +213 -0
- axion/runtime/hooks.py +235 -0
- axion/runtime/image.py +212 -0
- axion/runtime/lanes.py +282 -0
- axion/runtime/lsp.py +425 -0
- axion/runtime/mcp/__init__.py +0 -0
- axion/runtime/mcp/client.py +76 -0
- axion/runtime/mcp/lifecycle.py +96 -0
- axion/runtime/mcp/stdio.py +318 -0
- axion/runtime/mcp/tool_bridge.py +79 -0
- axion/runtime/memory.py +196 -0
- axion/runtime/oauth.py +329 -0
- axion/runtime/openai_subscription.py +346 -0
- axion/runtime/permissions.py +247 -0
- axion/runtime/plan_mode.py +96 -0
- axion/runtime/policy_engine.py +259 -0
- axion/runtime/prompt.py +586 -0
- axion/runtime/recovery.py +261 -0
- axion/runtime/remote.py +28 -0
- axion/runtime/sandbox.py +68 -0
- axion/runtime/scheduler.py +231 -0
- axion/runtime/session.py +365 -0
- axion/runtime/sharing.py +159 -0
- axion/runtime/skills.py +124 -0
- axion/runtime/tasks.py +258 -0
- axion/runtime/usage.py +241 -0
- axion/runtime/workers.py +186 -0
- axion/telemetry/__init__.py +0 -0
- axion/telemetry/events.py +67 -0
- axion/telemetry/profile.py +49 -0
- axion/telemetry/sink.py +60 -0
- axion/telemetry/tracer.py +95 -0
- axion/tools/__init__.py +0 -0
- axion/tools/lane_completion.py +33 -0
- axion/tools/registry.py +853 -0
- axion/tools/tool_search.py +226 -0
- axion_code-1.0.0.dist-info/METADATA +709 -0
- axion_code-1.0.0.dist-info/RECORD +82 -0
- axion_code-1.0.0.dist-info/WHEEL +4 -0
- axion_code-1.0.0.dist-info/entry_points.txt +2 -0
- axion_code-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,627 @@
|
|
|
1
|
+
"""OpenAI Responses API client (the /v1/responses endpoint, used by Codex).
|
|
2
|
+
|
|
3
|
+
Maps to: rust/crates/api/src/providers/openai_responses.rs (no equivalent yet)
|
|
4
|
+
|
|
5
|
+
The Responses API differs from Chat Completions:
|
|
6
|
+
- Single `input` array instead of `messages` (each item has type+content blocks)
|
|
7
|
+
- Stateful by default (`previous_response_id`); we use stateless mode (`store: false`)
|
|
8
|
+
- Reasoning settings via `reasoning: {effort: "low|medium|high"}`
|
|
9
|
+
- Different streaming events (`response.created`, `response.output_text.delta`,
|
|
10
|
+
`response.output_item.added`, `response.completed`, etc.)
|
|
11
|
+
- Tools format is similar to Chat Completions but slightly different field names
|
|
12
|
+
|
|
13
|
+
This client translates between Anthropic-style MessageRequest/StreamEvent
|
|
14
|
+
(what the rest of axion uses) and the Responses API wire format.
|
|
15
|
+
|
|
16
|
+
Used by Codex models (gpt-5-codex, gpt-5-codex-mini, codex, codex-mini).
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import asyncio
|
|
22
|
+
import json
|
|
23
|
+
import logging
|
|
24
|
+
import os
|
|
25
|
+
from dataclasses import dataclass, field
|
|
26
|
+
from typing import Any, AsyncIterator
|
|
27
|
+
|
|
28
|
+
import httpx
|
|
29
|
+
|
|
30
|
+
from axion.api.error import (
|
|
31
|
+
ApiError,
|
|
32
|
+
ApiResponseError,
|
|
33
|
+
BackoffOverflowError,
|
|
34
|
+
HttpError,
|
|
35
|
+
MissingCredentialsError,
|
|
36
|
+
RetriesExhaustedError,
|
|
37
|
+
)
|
|
38
|
+
from axion.api.openai_compat import OpenAiSseParser
|
|
39
|
+
from axion.api.types import (
|
|
40
|
+
ContentBlockDeltaEvent,
|
|
41
|
+
ContentBlockStartEvent,
|
|
42
|
+
ContentBlockStopEvent,
|
|
43
|
+
ImageInputBlock,
|
|
44
|
+
InputJsonDelta,
|
|
45
|
+
InputMessage,
|
|
46
|
+
MessageDelta,
|
|
47
|
+
MessageDeltaEvent,
|
|
48
|
+
MessageRequest,
|
|
49
|
+
MessageResponse,
|
|
50
|
+
MessageStartEvent,
|
|
51
|
+
MessageStopEvent,
|
|
52
|
+
OutputContentBlock,
|
|
53
|
+
StreamEvent,
|
|
54
|
+
TextDelta,
|
|
55
|
+
TextInputBlock,
|
|
56
|
+
TextOutputBlock,
|
|
57
|
+
ToolDefinition,
|
|
58
|
+
ToolResultBlock,
|
|
59
|
+
ToolUseInputBlock,
|
|
60
|
+
ToolUseOutputBlock,
|
|
61
|
+
Usage,
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
logger = logging.getLogger(__name__)
|
|
65
|
+
|
|
66
|
+
# Base URL used when neither the constructor argument nor the OPENAI_BASE_URL
# environment variable supplies one; "/responses" is appended per request.
DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
|
67
|
+
# Default number of retries after the first attempt for retryable errors.
DEFAULT_MAX_RETRIES = 3
|
|
68
|
+
# Default delay (milliseconds) before the first retry; doubled on each
# subsequent attempt by the client's backoff calculation.
DEFAULT_INITIAL_BACKOFF_MS = 1000
|
|
69
|
+
# Default upper bound (milliseconds) applied to the exponential retry delay.
DEFAULT_MAX_BACKOFF_MS = 30000
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# ---------------------------------------------------------------------------
|
|
73
|
+
# Client
|
|
74
|
+
# ---------------------------------------------------------------------------
|
|
75
|
+
|
|
76
|
+
class OpenAiResponsesClient:
    """HTTP client for OpenAI's /v1/responses endpoint (Codex models).

    The client accepts Anthropic-style ``MessageRequest`` objects, translates
    them to the Responses API wire format and converts the answer back into
    ``MessageResponse`` / ``StreamEvent`` objects.

    It owns an ``httpx.AsyncClient``; call :meth:`close` (or use the instance
    as an ``async with`` context manager) to release its connections.
    """

    def __init__(
        self,
        api_key: str,
        *,
        base_url: str = DEFAULT_OPENAI_BASE_URL,
        max_retries: int = DEFAULT_MAX_RETRIES,
        initial_backoff_ms: int = DEFAULT_INITIAL_BACKOFF_MS,
        max_backoff_ms: int = DEFAULT_MAX_BACKOFF_MS,
        bearer_override: str | None = None,
    ) -> None:
        """Create a client.

        Args:
            api_key: OpenAI API key, sent as the Bearer token unless
                ``bearer_override`` is set.
            base_url: API root; a trailing slash is stripped.
            max_retries: Number of retries after the first attempt.
            initial_backoff_ms: Delay before the first retry, in milliseconds.
            max_backoff_ms: Upper bound for the retry delay, in milliseconds.
            bearer_override: Optional Bearer token that takes precedence over
                ``api_key`` (ChatGPT subscription OAuth).
        """
        self._http = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
        self._api_key = api_key
        self._base_url = base_url.rstrip("/")
        self._max_retries = max_retries
        self._initial_backoff_ms = initial_backoff_ms
        self._max_backoff_ms = max_backoff_ms
        # Optional Bearer token override (for ChatGPT subscription OAuth)
        self._bearer_override = bearer_override

    async def __aenter__(self) -> OpenAiResponsesClient:
        """Return the client itself so it can be used with ``async with``."""
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the underlying HTTP client when leaving the ``async with`` block."""
        await self.close()

    @classmethod
    def from_env(cls) -> OpenAiResponsesClient:
        """Create a client using ChatGPT subscription if present, else API key.

        Resolution order (unless AXION_AUTH_MODE=api forces API):
          1. ChatGPT subscription OAuth token (~/.axion/credentials/openai-oauth.json)
          2. OPENAI_API_KEY env var
          3. ~/.axion/credentials/openai.key file

        A key loaded from the key file is also exported to ``OPENAI_API_KEY``
        in this process's environment.

        Raises:
            MissingCredentialsError: If none of the sources yields a credential.
        """
        base = os.environ.get("OPENAI_BASE_URL", DEFAULT_OPENAI_BASE_URL)
        auth_mode = os.environ.get("AXION_AUTH_MODE", "").lower()

        # 1. Try ChatGPT subscription unless explicitly forced to API mode.
        if auth_mode != "api":
            try:
                from axion.runtime.openai_subscription import (
                    SUBSCRIPTION_PROVIDER,
                    has_openai_subscription_credentials,
                    load_oauth_credentials,
                )
                if has_openai_subscription_credentials():
                    creds = load_oauth_credentials(SUBSCRIPTION_PROVIDER)
                    if creds and creds.access_token:
                        # api_key still required by ctor; pass empty since
                        # bearer_override takes precedence.
                        return cls(
                            api_key="",
                            base_url=base,
                            bearer_override=creds.access_token,
                        )
            except Exception as exc:
                # Deliberately best-effort: any failure here simply falls
                # through to API-key authentication.
                logger.debug("Subscription auth lookup failed: %s", exc)

        # 2. Fall back to API key (environment first, then the key file).
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            from pathlib import Path

            key_path = Path.home() / ".axion" / "credentials" / "openai.key"
            try:
                saved = key_path.read_text(encoding="utf-8").strip()
            except FileNotFoundError:
                saved = ""
            except (OSError, UnicodeDecodeError) as exc:
                # An unreadable or corrupt key file is treated as "no key"
                # so the caller gets MissingCredentialsError, not a raw OSError.
                logger.warning("Could not read OpenAI key file %s: %s", key_path, exc)
                saved = ""
            if saved:
                api_key = saved
                os.environ["OPENAI_API_KEY"] = saved

        if not api_key:
            raise MissingCredentialsError(provider="OpenAI", env_vars=["OPENAI_API_KEY"])

        return cls(api_key=api_key, base_url=base)

    # -- Public API ----------------------------------------------------------

    async def send_message(self, request: MessageRequest) -> MessageResponse:
        """Send a non-streaming Responses API request and normalize the reply."""
        req = _clone_request(request, stream=False)
        response = await self._send_with_retry(req)
        payload = response.json()
        return _normalize_response(req.model, payload)

    async def stream_message(
        self, request: MessageRequest
    ) -> AsyncIterator[StreamEvent]:
        """Stream a Responses API request, converting events to the standard format.

        The HTTP body is consumed incrementally, so events are yielded as the
        server produces them. The response is always closed when iteration
        ends, fails, or is abandoned.

        Raises:
            HttpError: If the connection fails while the body is being read.
        """
        req = _clone_request(request, stream=True)
        response = await self._send_with_retry(req)

        parser = OpenAiSseParser()
        state = _ResponsesStreamState(model=req.model)

        try:
            async for raw_chunk in response.aiter_bytes():
                for chunk in parser.push(raw_chunk):
                    for event in state.ingest_event(chunk):
                        yield event
        except httpx.HTTPError as exc:
            # Keep the module's error contract for mid-stream transport failures.
            raise HttpError(str(exc), cause=exc) from exc
        finally:
            await response.aclose()

        for event in state.finish():
            yield event

    async def close(self) -> None:
        """Close the underlying HTTP client and its connection pool."""
        await self._http.aclose()

    # -- Retry / send -------------------------------------------------------

    async def _send_with_retry(self, request: MessageRequest) -> httpx.Response:
        """Send ``request``, retrying retryable failures with exponential backoff.

        Makes at most ``max_retries + 1`` attempts.

        Raises:
            ApiError: Immediately, for a non-retryable failure.
            RetriesExhaustedError: When a retryable failure persists after the
                last allowed attempt.
        """
        attempts = 0
        while True:
            attempts += 1
            try:
                response = await self._send_raw_request(request)
                await self._ensure_success(response)
                return response
            except ApiError as err:
                if not err.is_retryable():
                    raise
                if attempts > self._max_retries:
                    raise RetriesExhaustedError(attempts=attempts, last_error=err) from err
                await asyncio.sleep(self._backoff_for_attempt(attempts))

    async def _ensure_success(self, response: httpx.Response) -> None:
        """Raise the API error for a non-2xx response, releasing it first.

        A streamed response has no body loaded yet, so the body is read (and
        the connection closed) before ``_expect_success`` inspects it.
        """
        if 200 <= response.status_code < 300:
            return
        try:
            await response.aread()
        except httpx.HTTPError as exc:
            raise HttpError(str(exc), cause=exc) from exc
        finally:
            await response.aclose()
        _expect_success(response)

    async def _refresh_subscription_token_if_needed(self) -> None:
        """If using a subscription bearer, refresh it when near-expired."""
        if not self._bearer_override:
            return
        try:
            from axion.runtime.openai_subscription import (
                get_valid_openai_subscription_token,
            )
            new_token = await get_valid_openai_subscription_token()
            if new_token and new_token != self._bearer_override:
                self._bearer_override = new_token
                logger.info("Refreshed ChatGPT subscription token")
        except Exception as exc:
            # Best-effort: keep using the current token if the check fails.
            logger.debug("Subscription token refresh check failed: %s", exc)

    async def _send_raw_request(self, request: MessageRequest) -> httpx.Response:
        """POST one request to ``/responses`` without retrying.

        For ``request.stream`` the response is returned with its body still
        unread so the caller can iterate it; the caller must close it.

        Raises:
            HttpError: On any ``httpx`` transport-level failure.
        """
        await self._refresh_subscription_token_if_needed()
        url = f"{self._base_url}/responses"
        body = _build_responses_request(request)
        token = self._bearer_override or self._api_key
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {token}",
        }
        # Codex CLI sends an originator header so OpenAI knows which client
        # the request came from. Required for subscription-billed requests.
        if self._bearer_override:
            headers["openai-beta"] = "responses-2024-09-30"
            headers["originator"] = "codex_cli_rs"
        try:
            if request.stream:
                # client.post() would buffer the whole body before returning,
                # which defeats streaming; send with stream=True instead.
                http_request = self._http.build_request(
                    "POST", url, json=body, headers=headers
                )
                return await self._http.send(http_request, stream=True)
            return await self._http.post(url, json=body, headers=headers)
        except httpx.HTTPError as exc:
            raise HttpError(str(exc), cause=exc) from exc

    def _backoff_for_attempt(self, attempt: int) -> float:
        """Return the delay in seconds to wait after failed attempt ``attempt``.

        The delay is ``initial_backoff_ms * 2 ** (attempt - 1)``, capped at
        ``max_backoff_ms``.

        Raises:
            BackoffOverflowError: If the multiplier cannot be computed
                (for example ``attempt`` below 1).
        """
        try:
            multiplier = 1 << (attempt - 1)
        except (OverflowError, ValueError) as exc:
            raise BackoffOverflowError(
                attempt=attempt, base_delay_ms=self._initial_backoff_ms
            ) from exc
        delay_ms = min(self._initial_backoff_ms * multiplier, self._max_backoff_ms)
        return delay_ms / 1000.0
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# ---------------------------------------------------------------------------
|
|
248
|
+
# Request builder: MessageRequest -> /v1/responses body
|
|
249
|
+
# ---------------------------------------------------------------------------
|
|
250
|
+
|
|
251
|
+
def _build_responses_request(request: MessageRequest) -> dict[str, Any]:
|
|
252
|
+
"""Translate an Anthropic-style MessageRequest to a Responses API body."""
|
|
253
|
+
input_items: list[dict[str, Any]] = []
|
|
254
|
+
|
|
255
|
+
# System prompt becomes the Responses API "instructions" field.
|
|
256
|
+
# MessageRequest.system is typed as str|None, but at runtime it can also
|
|
257
|
+
# be a list of text blocks (Anthropic prompt-caching format), so check both.
|
|
258
|
+
instructions: str | None = None
|
|
259
|
+
sys_value: Any = request.system
|
|
260
|
+
if sys_value:
|
|
261
|
+
if isinstance(sys_value, str):
|
|
262
|
+
instructions = sys_value
|
|
263
|
+
elif isinstance(sys_value, list):
|
|
264
|
+
parts: list[str] = []
|
|
265
|
+
for block in sys_value:
|
|
266
|
+
if isinstance(block, dict):
|
|
267
|
+
parts.append(str(block.get("text", "")))
|
|
268
|
+
else:
|
|
269
|
+
parts.append(str(block))
|
|
270
|
+
instructions = "\n\n".join(p for p in parts if p)
|
|
271
|
+
|
|
272
|
+
# Conversation messages
|
|
273
|
+
for msg in request.messages:
|
|
274
|
+
for item in _translate_message_to_input_items(msg):
|
|
275
|
+
input_items.append(item)
|
|
276
|
+
|
|
277
|
+
body: dict[str, Any] = {
|
|
278
|
+
"model": request.model,
|
|
279
|
+
"input": input_items,
|
|
280
|
+
"stream": request.stream,
|
|
281
|
+
"store": False, # stateless — we manage history ourselves
|
|
282
|
+
}
|
|
283
|
+
if instructions:
|
|
284
|
+
body["instructions"] = instructions
|
|
285
|
+
if request.max_tokens:
|
|
286
|
+
body["max_output_tokens"] = request.max_tokens
|
|
287
|
+
|
|
288
|
+
# Tools
|
|
289
|
+
if request.tools:
|
|
290
|
+
body["tools"] = [_translate_tool_definition(t) for t in request.tools]
|
|
291
|
+
|
|
292
|
+
# Codex models support reasoning effort
|
|
293
|
+
if "codex" in request.model.lower() or request.model.startswith(("o1", "o3", "o4", "gpt-5")):
|
|
294
|
+
body["reasoning"] = {"effort": "medium"}
|
|
295
|
+
|
|
296
|
+
# Tool choice (ToolChoice.type is "auto" | "any" | "tool")
|
|
297
|
+
if request.tool_choice is not None:
|
|
298
|
+
tc = request.tool_choice
|
|
299
|
+
if tc.type == "auto":
|
|
300
|
+
body["tool_choice"] = "auto"
|
|
301
|
+
elif tc.type == "any":
|
|
302
|
+
body["tool_choice"] = "required"
|
|
303
|
+
elif tc.type == "tool" and tc.name:
|
|
304
|
+
body["tool_choice"] = {"type": "function", "name": tc.name}
|
|
305
|
+
|
|
306
|
+
return body
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _translate_message_to_input_items(msg: InputMessage) -> list[dict[str, Any]]:
    """Translate one InputMessage to one or more Responses API input items.

    Assistant messages: consecutive text blocks are grouped into a
    ``message`` item with ``output_text`` parts, and every tool-use block
    becomes its own top-level ``function_call`` item, in the original block
    order. ``function_call`` is an input item type of its own, not a message
    content part, so it must not be nested inside a message's ``content``.

    Other (user) messages: tool results become ``function_call_output`` items,
    followed by a single ``message`` item holding any text/image parts.

    Returns:
        The input items for this message; empty when it has no translatable
        content.
    """
    if msg.role == "assistant":
        items: list[dict[str, Any]] = []
        pending_text: list[dict[str, Any]] = []
        for block in msg.content:
            if isinstance(block, TextInputBlock):
                pending_text.append({"type": "output_text", "text": block.text})
            elif isinstance(block, ToolUseInputBlock):
                # Flush text gathered so far so relative ordering is kept.
                if pending_text:
                    items.append(
                        {"type": "message", "role": "assistant", "content": pending_text}
                    )
                    pending_text = []
                # Responses API uses function_call for assistant tool requests;
                # arguments travel as a JSON string.
                items.append({
                    "type": "function_call",
                    "call_id": block.id,
                    "name": block.name,
                    "arguments": (
                        block.input
                        if isinstance(block.input, str)
                        else json.dumps(block.input)
                    ),
                })
        if pending_text:
            items.append(
                {"type": "message", "role": "assistant", "content": pending_text}
            )
        return items

    # User / tool-result messages
    results: list[dict[str, Any]] = []
    user_content: list[dict[str, Any]] = []

    for block in msg.content:
        if isinstance(block, TextInputBlock):
            user_content.append({"type": "input_text", "text": block.text})
        elif isinstance(block, ImageInputBlock):
            user_content.append({
                "type": "input_image",
                "image_url": f"data:{block.media_type};base64,{block.data}",
            })
        elif isinstance(block, ToolResultBlock):
            # Tool results become function_call_output items
            results.append({
                "type": "function_call_output",
                "call_id": block.tool_use_id,
                "output": _flatten_tool_result_content(block),
            })

    # Tool outputs are emitted before the user's own text/image content.
    if user_content:
        results.append({
            "type": "message",
            "role": "user",
            "content": user_content,
        })

    return results
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def _flatten_tool_result_content(block: ToolResultBlock) -> str:
|
|
364
|
+
"""Concatenate tool result content blocks into a single string."""
|
|
365
|
+
parts: list[str] = []
|
|
366
|
+
for c in block.content:
|
|
367
|
+
text = getattr(c, "text", None)
|
|
368
|
+
if text is not None:
|
|
369
|
+
parts.append(text)
|
|
370
|
+
else:
|
|
371
|
+
value = getattr(c, "value", None)
|
|
372
|
+
if value is not None:
|
|
373
|
+
parts.append(json.dumps(value) if not isinstance(value, str) else value)
|
|
374
|
+
return "\n".join(parts) if parts else ""
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def _translate_tool_definition(tool: ToolDefinition) -> dict[str, Any]:
|
|
378
|
+
"""Translate an Anthropic ToolDefinition to Responses API tool format."""
|
|
379
|
+
return {
|
|
380
|
+
"type": "function",
|
|
381
|
+
"name": tool.name,
|
|
382
|
+
"description": tool.description,
|
|
383
|
+
"parameters": tool.input_schema,
|
|
384
|
+
}
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# ---------------------------------------------------------------------------
|
|
388
|
+
# Response normalization (non-streaming)
|
|
389
|
+
# ---------------------------------------------------------------------------
|
|
390
|
+
|
|
391
|
+
def _normalize_response(model: str, data: dict[str, Any]) -> MessageResponse:
    """Convert a /v1/responses payload to an Anthropic-style MessageResponse.

    Args:
        model: Model name to report when the payload does not carry one.
        data: Decoded JSON body of a non-streaming Responses API reply.

    Returns:
        A ``MessageResponse`` whose content holds one text block per
        ``output_text`` part and one tool-use block per ``function_call``
        item. ``stop_reason`` is ``"tool_use"`` when the model requested at
        least one tool and the response otherwise finished normally.
    """
    content_blocks: list[OutputContentBlock] = []
    saw_tool_call = False

    for item in data.get("output") or []:
        item_type = item.get("type", "")
        if item_type == "message":
            for part in item.get("content") or []:
                if part.get("type") == "output_text":
                    content_blocks.append(TextOutputBlock(text=part.get("text", "")))
        elif item_type == "function_call":
            saw_tool_call = True
            args_str = item.get("arguments", "")
            try:
                args = json.loads(args_str) if args_str else {}
            except json.JSONDecodeError:
                args = {"raw": args_str}
            if not isinstance(args, dict):
                # Tool input must be an object; keep any other JSON value
                # (list, string, number, null) under "raw" instead.
                args = {"raw": args_str}
            content_blocks.append(
                ToolUseOutputBlock(
                    id=item.get("call_id") or item.get("id", ""),
                    name=item.get("name", ""),
                    input=args,
                )
            )

    usage_data = data.get("usage") or {}
    # `or 0` guards against explicit nulls in the usage object.
    usage = Usage(
        input_tokens=usage_data.get("input_tokens") or 0,
        output_tokens=usage_data.get("output_tokens") or 0,
        cache_creation_input_tokens=0,
        cache_read_input_tokens=(
            (usage_data.get("input_tokens_details") or {}).get("cached_tokens") or 0
        ),
    )

    stop_reason = _map_stop_reason(data.get("status"))
    # The Responses API reports "completed" even when the turn ended with tool
    # calls; Anthropic-style consumers expect "tool_use" in that case.
    if saw_tool_call and stop_reason in (None, "end_turn"):
        stop_reason = "tool_use"

    return MessageResponse(
        id=data.get("id", ""),
        type="message",
        role="assistant",
        content=content_blocks,
        model=data.get("model", model),
        usage=usage,
        stop_reason=stop_reason,
    )
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def _map_stop_reason(status: str | None) -> str | None:
|
|
438
|
+
if status == "completed":
|
|
439
|
+
return "end_turn"
|
|
440
|
+
if status == "incomplete":
|
|
441
|
+
return "max_tokens"
|
|
442
|
+
if status == "failed":
|
|
443
|
+
return "error"
|
|
444
|
+
return status
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
# ---------------------------------------------------------------------------
|
|
448
|
+
# Streaming: convert /v1/responses SSE events to the standard StreamEvent format
|
|
449
|
+
# ---------------------------------------------------------------------------
|
|
450
|
+
|
|
451
|
+
@dataclass
class _ResponsesStreamState:
    """Tracks the streaming state for a single response.

    Feed every parsed SSE event to :meth:`ingest_event`; once the stream is
    exhausted call :meth:`finish` to obtain the closing events.
    """

    # Model name reported in the emitted MessageStartEvent.
    model: str
    # Response id taken from the first response.created/in_progress event.
    response_id: str = ""
    # True once the MessageStartEvent has been emitted.
    started: bool = False
    # Maps output_index -> (block_index, kind, name) for both text & function_call
    output_index_to_block: dict[int, tuple[int, str, str | None]] = field(default_factory=dict)
    # Index that the next opened content block will receive.
    next_block_index: int = 0
    # Usage figures copied from the terminal event.
    accumulated_input_tokens: int = 0
    accumulated_output_tokens: int = 0
    cache_read_tokens: int = 0
    # Stop reason derived from the terminal event's status, if one arrived.
    final_stop_reason: str | None = None
    # True once at least one function_call output item has been opened.
    saw_tool_use: bool = False

    def ingest_event(self, event: dict[str, Any]) -> list[StreamEvent]:
        """Process one parsed Responses API SSE event, emit StreamEvents.

        Returns:
            Zero or more standard stream events. Unknown event types, and
            deltas for output items that were never announced, yield an
            empty list.
        """
        evt_type = event.get("type", "")
        out: list[StreamEvent] = []

        if evt_type in ("response.created", "response.in_progress"):
            # Both events may arrive; only the first starts the message.
            if not self.started:
                self.started = True
                self.response_id = (event.get("response") or {}).get("id", "")
                out.append(MessageStartEvent(
                    message=MessageResponse(
                        id=self.response_id,
                        type="message",
                        role="assistant",
                        content=[],
                        model=self.model,
                        usage=Usage(),
                    ),
                ))
            return out

        # A new output item starts (text message OR function call)
        if evt_type == "response.output_item.added":
            item = event.get("item") or {}
            output_index = event.get("output_index", -1)
            item_type = item.get("type", "")

            if item_type == "function_call":
                block_index = self.next_block_index
                self.next_block_index += 1
                self.saw_tool_use = True
                self.output_index_to_block[output_index] = (block_index, "tool_use", item.get("name"))
                out.append(ContentBlockStartEvent(
                    index=block_index,
                    content_block=ToolUseOutputBlock(
                        id=item.get("call_id") or item.get("id", ""),
                        name=item.get("name", ""),
                        input={},
                    ),
                ))
            elif item_type == "message":
                # Will get text deltas next; allocate a block now
                block_index = self.next_block_index
                self.next_block_index += 1
                self.output_index_to_block[output_index] = (block_index, "text", None)
                out.append(ContentBlockStartEvent(
                    index=block_index,
                    content_block=TextOutputBlock(text=""),
                ))
            return out

        # Text delta inside a message
        if evt_type == "response.output_text.delta":
            mapping = self.output_index_to_block.get(event.get("output_index", -1))
            if mapping is None:
                return out
            block_index, _kind, _name = mapping
            delta_text = event.get("delta", "")
            if delta_text:
                out.append(ContentBlockDeltaEvent(
                    index=block_index,
                    delta=TextDelta(text=delta_text),
                ))
            return out

        # Function call argument delta
        if evt_type == "response.function_call_arguments.delta":
            mapping = self.output_index_to_block.get(event.get("output_index", -1))
            if mapping is None:
                return out
            block_index, _kind, _name = mapping
            delta_str = event.get("delta", "")
            if delta_str:
                out.append(ContentBlockDeltaEvent(
                    index=block_index,
                    delta=InputJsonDelta(partial_json=delta_str),
                ))
            return out

        # Output item complete
        if evt_type == "response.output_item.done":
            mapping = self.output_index_to_block.get(event.get("output_index", -1))
            if mapping is None:
                return out
            block_index, _kind, _name = mapping
            out.append(ContentBlockStopEvent(index=block_index))
            return out

        # Terminal events. A stream ends with exactly one of these; all three
        # carry the final response object, so a truncated ("incomplete") or
        # failed response must be recorded too, not only a completed one.
        if evt_type in ("response.completed", "response.incomplete", "response.failed"):
            response = event.get("response") or {}
            usage = response.get("usage") or {}
            # `or 0` guards against explicit nulls in the usage object.
            self.accumulated_input_tokens = usage.get("input_tokens") or 0
            self.accumulated_output_tokens = usage.get("output_tokens") or 0
            self.cache_read_tokens = (
                (usage.get("input_tokens_details") or {}).get("cached_tokens") or 0
            )
            self.final_stop_reason = _map_stop_reason(response.get("status"))
            return out

        # Other events we don't care about: response.reasoning_summary.*, etc.
        return out

    def finish(self) -> list[StreamEvent]:
        """Emit final MessageDelta + MessageStop events.

        The stop reason defaults to ``"end_turn"`` when no terminal event was
        seen, and is reported as ``"tool_use"`` when the response otherwise
        ended normally but contained at least one function call.
        """
        stop_reason = self.final_stop_reason or "end_turn"
        if self.saw_tool_use and stop_reason == "end_turn":
            stop_reason = "tool_use"
        return [
            MessageDeltaEvent(
                delta=MessageDelta(
                    stop_reason=stop_reason,
                    stop_sequence=None,
                ),
                usage=Usage(
                    input_tokens=self.accumulated_input_tokens,
                    output_tokens=self.accumulated_output_tokens,
                    cache_creation_input_tokens=0,
                    cache_read_input_tokens=self.cache_read_tokens,
                ),
            ),
            MessageStopEvent(),
        ]
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
# ---------------------------------------------------------------------------
|
|
588
|
+
# HTTP helpers
|
|
589
|
+
# ---------------------------------------------------------------------------
|
|
590
|
+
|
|
591
|
+
def _expect_success(response: httpx.Response) -> None:
|
|
592
|
+
"""Raise an ApiResponseError on non-2xx responses."""
|
|
593
|
+
if 200 <= response.status_code < 300:
|
|
594
|
+
return
|
|
595
|
+
body = response.text
|
|
596
|
+
error_type = None
|
|
597
|
+
message = None
|
|
598
|
+
try:
|
|
599
|
+
data = json.loads(body)
|
|
600
|
+
if "error" in data:
|
|
601
|
+
err_obj = data["error"]
|
|
602
|
+
error_type = err_obj.get("type")
|
|
603
|
+
message = err_obj.get("message")
|
|
604
|
+
except (json.JSONDecodeError, KeyError):
|
|
605
|
+
pass
|
|
606
|
+
|
|
607
|
+
retryable = response.status_code in (429, 500, 502, 503, 529)
|
|
608
|
+
raise ApiResponseError(
|
|
609
|
+
status=response.status_code,
|
|
610
|
+
error_type=error_type,
|
|
611
|
+
message=message,
|
|
612
|
+
request_id_val=response.headers.get("x-request-id"),
|
|
613
|
+
body=body,
|
|
614
|
+
retryable=retryable,
|
|
615
|
+
)
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
def _clone_request(req: MessageRequest, *, stream: bool) -> MessageRequest:
    """Return a new MessageRequest mirroring ``req`` with ``stream`` overridden.

    Only the model, max_tokens, messages, system, tools and tool_choice
    fields are carried over; they are shared with ``req``, not deep-copied.
    """
    carried_over = {
        "model": req.model,
        "max_tokens": req.max_tokens,
        "messages": req.messages,
        "system": req.system,
        "tools": req.tools,
        "tool_choice": req.tool_choice,
    }
    return MessageRequest(**carried_over, stream=stream)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Prompt caching support.
|
|
2
|
+
|
|
3
|
+
Maps to: rust/crates/api/src/prompt_cache.rs
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class PromptCacheConfig:
    """Settings that control prompt caching.

    Both fields have defaults, so ``PromptCacheConfig()`` describes an
    enabled cache with the ``"session"`` scope.
    """

    # Whether prompt caching is switched on.
    enabled: bool = True
    # Name of the caching scope.
    scope: str = "session"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class PromptCache:
    """Tracks prompt cache state across requests.

    Counts cache hits and misses and accumulates the number of tokens that
    were reported as served from the cache.
    """

    # Caching settings; each instance gets its own default config object.
    config: PromptCacheConfig = field(default_factory=PromptCacheConfig)
    # Number of recorded cache hits.
    cache_hits: int = 0
    # Number of recorded cache misses.
    cache_misses: int = 0
    # Running total of the `tokens` values passed to record_hit().
    cached_tokens: int = 0

    def record_hit(self, tokens: int) -> None:
        """Record one cache hit that served ``tokens`` tokens from the cache."""
        self.cache_hits += 1
        # Previously the argument was accepted but silently discarded.
        self.cached_tokens += tokens

    def record_miss(self) -> None:
        """Record one cache miss."""
        self.cache_misses += 1
|