copex 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- copex/__init__.py +69 -0
- copex/checkpoint.py +445 -0
- copex/cli.py +1106 -0
- copex/client.py +725 -0
- copex/config.py +311 -0
- copex/mcp.py +561 -0
- copex/metrics.py +383 -0
- copex/models.py +50 -0
- copex/persistence.py +324 -0
- copex/plan.py +358 -0
- copex/ralph.py +247 -0
- copex/tools.py +404 -0
- copex/ui.py +971 -0
- copex-0.8.4.dist-info/METADATA +511 -0
- copex-0.8.4.dist-info/RECORD +18 -0
- copex-0.8.4.dist-info/WHEEL +4 -0
- copex-0.8.4.dist-info/entry_points.txt +2 -0
- copex-0.8.4.dist-info/licenses/LICENSE +21 -0
copex/client.py
ADDED
|
@@ -0,0 +1,725 @@
|
|
|
1
|
+
"""Core Copex client with retry logic and stuck detection."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import random
|
|
7
|
+
from collections.abc import AsyncIterator, Callable
|
|
8
|
+
from contextlib import asynccontextmanager
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from copilot import CopilotClient
|
|
13
|
+
from copilot.session import CopilotSession
|
|
14
|
+
|
|
15
|
+
from copex.config import CopexConfig
|
|
16
|
+
from copex.metrics import MetricsCollector, get_collector
|
|
17
|
+
from copex.models import EventType, Model, ReasoningEffort
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class Response:
    """Result of one completed prompt round-trip.

    Only ``content`` is required; every other field has a default so the
    object can be built from the text alone and enriched afterwards.
    """

    # Final assistant text for the prompt.
    content: str

    # Reasoning text, or None when none was captured.
    reasoning: str | None = None

    # One ``{"type": ..., "data": ...}`` dict per event seen while waiting.
    raw_events: list[dict[str, Any]] = field(default_factory=list)

    # Bookkeeping counters filled in by the caller that drove the request.
    retries: int = 0
    auto_continues: int = 0
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class StreamChunk:
    """One incremental piece of output handed to an ``on_chunk`` callback.

    The same shape is used for message text, reasoning text, tool activity
    and system notices; ``type`` says which one this chunk is.
    """

    # One of: "message", "reasoning", "tool_call", "tool_result", "system".
    type: str

    # Incremental text carried by this chunk (empty for final/tool chunks).
    delta: str = ""

    # True on the closing chunk of a message or reasoning block.
    is_final: bool = False

    # Full content when is_final=True.
    content: str | None = None

    # --- Tool call info -------------------------------------------------
    tool_name: str | None = None
    tool_args: dict[str, Any] | None = None
    tool_result: str | None = None
    tool_success: bool | None = None
    tool_duration: float | None = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class _SendState:
    """Mutable scratch state shared by the event handlers of one send call."""

    # Set when the send should stop waiting (completion or error).
    done: asyncio.Event

    # Errors captured while handling events; the first one is raised later.
    error_holder: list[Exception] = field(default_factory=list)

    # Streamed fragments, joined when no final event supplied the full text.
    content_parts: list[str] = field(default_factory=list)
    reasoning_parts: list[str] = field(default_factory=list)

    # Full texts taken from the non-delta message / reasoning events.
    final_content: str | None = None
    final_reasoning: str | None = None

    # Every event seen, as ``{"type": ..., "data": ...}``.
    raw_events: list[dict[str, Any]] = field(default_factory=list)

    # Event-loop timestamp of the most recent event (inactivity timeout).
    last_activity: float = 0.0

    # True once any non-empty assistant text has arrived.
    received_content: bool = False

    # Tool bookkeeping used to decide whether a turn end is really the end.
    pending_tools: int = 0
    awaiting_post_tool_response: bool = False
    tool_execution_seen: bool = False
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class Copex:
    """Copilot Extended - Resilient wrapper with automatic retry and stuck detection.

    Wraps a ``CopilotClient`` and a single session. ``send`` retries failed
    prompts with exponential backoff and, when retries are exhausted or a
    tool-state mismatch is reported, replaces the session with a fresh one
    and continues from a summary of the previous conversation.
    """

    # (session option key, JSON-RPC wire key) pairs that are copied into the
    # ``session.create`` payload verbatim whenever the option is truthy.
    _PASSTHROUGH_SESSION_OPTIONS = (
        ("system_message", "systemMessage"),
        ("available_tools", "availableTools"),
        ("excluded_tools", "excludedTools"),
        ("working_directory", "workingDirectory"),
        ("mcp_servers", "mcpServers"),
        ("skill_directories", "skillDirectories"),
        ("disabled_skills", "disabledSkills"),
    )

    def __init__(self, config: CopexConfig | None = None):
        """Store the configuration; nothing is started until ``start``.

        Args:
            config: Configuration to use. A default ``CopexConfig()`` is
                created when omitted.
        """
        self.config = config if config is not None else CopexConfig()
        self._client: CopilotClient | None = None
        self._session: Any = None
        self._started = False
        # Strong references to fire-and-forget cleanup tasks so they are not
        # garbage-collected before they finish (see ``new_session``).
        self._background_tasks: set[asyncio.Task[Any]] = set()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def start(self) -> None:
        """Start the Copilot client. Does nothing if already started."""
        if self._started:
            return
        self._client = CopilotClient(self.config.to_client_options())
        await self._client.start()
        self._started = True

    async def stop(self) -> None:
        """Destroy the current session (best effort) and stop the client.

        Errors from destroying the session are ignored. An error from
        stopping the client propagates, but the instance is still reset so
        it never keeps pointing at a client that failed to shut down.
        """
        if self._session:
            try:
                await self._session.destroy()
            except Exception:
                # Best effort: the session is being discarded anyway.
                pass
            finally:
                self._session = None
        try:
            if self._client:
                await self._client.stop()
        finally:
            self._client = None
            self._started = False

    async def __aenter__(self) -> "Copex":
        await self.start()
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.stop()

    # ------------------------------------------------------------------
    # Retry policy helpers
    # ------------------------------------------------------------------

    def _should_retry(self, error: str | Exception) -> bool:
        """Check if error should trigger a retry.

        Returns True when ``retry_on_any_error`` is set, otherwise when any
        configured ``retry_on_errors`` pattern occurs (case-insensitively)
        in ``str(error)``.
        """
        if self.config.retry.retry_on_any_error:
            return True
        error_str = str(error).lower()
        return any(
            pattern.lower() in error_str for pattern in self.config.retry.retry_on_errors
        )

    def _is_tool_state_error(self, error: str | Exception) -> bool:
        """Detect tool-state mismatch errors that require session recovery.

        The check is textual: both ``tool_use_id`` and ``tool_result`` must
        appear in the lower-cased error message.
        """
        error_str = str(error).lower()
        return "tool_use_id" in error_str and "tool_result" in error_str

    def _calculate_delay(self, attempt: int) -> float:
        """Calculate delay with exponential backoff and jitter.

        Args:
            attempt: Zero-based retry attempt number.

        Returns:
            ``base_delay * exponential_base ** attempt`` capped at
            ``max_delay``, then shifted by a random jitter of up to ±25%.
            The jitter is applied after the cap, so the result may exceed
            ``max_delay`` by up to 25%.
        """
        delay = self.config.retry.base_delay * (self.config.retry.exponential_base ** attempt)
        delay = min(delay, self.config.retry.max_delay)
        # Add jitter (±25%)
        jitter = delay * 0.25 * (2 * random.random() - 1)
        return delay + jitter

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _type_value(type_obj: Any) -> str:
        """Return ``type_obj.value`` when present, else ``str(type_obj)``."""
        return type_obj.value if hasattr(type_obj, "value") else str(type_obj)

    @staticmethod
    def _tool_label(data: Any) -> str:
        """Return the tool name from ``tool_name`` or ``name``, else "unknown"."""
        tool_name = getattr(data, "tool_name", None) or getattr(data, "name", None)
        return str(tool_name) if tool_name else "unknown"

    @staticmethod
    def _note_content_arrived(state: _SendState) -> None:
        """Record that assistant text arrived.

        Once tool execution has been seen and no tool is still pending, the
        arriving text counts as the post-tool response, so the turn is no
        longer held open waiting for it.
        """
        state.received_content = True
        if state.awaiting_post_tool_response and state.tool_execution_seen and state.pending_tools == 0:
            state.awaiting_post_tool_response = False

    # ------------------------------------------------------------------
    # Event handlers (called synchronously from the session callback)
    # ------------------------------------------------------------------

    def _handle_message_delta(
        self,
        event: Any,
        state: _SendState,
        on_chunk: Callable[[StreamChunk], None] | None,
    ) -> None:
        """Accumulate a streamed message fragment and forward it."""
        delta = (
            getattr(event.data, "delta_content", "")
            or getattr(event.data, "transformed_content", "")
            or ""
        )
        if not delta:
            return
        state.content_parts.append(delta)
        self._note_content_arrived(state)
        if on_chunk:
            on_chunk(StreamChunk(type="message", delta=delta))

    def _handle_reasoning_delta(
        self,
        event: Any,
        state: _SendState,
        on_chunk: Callable[[StreamChunk], None] | None,
    ) -> None:
        """Accumulate a streamed reasoning fragment and forward it.

        Empty fragments are ignored, matching ``_handle_message_delta``;
        otherwise a lone empty delta would make the joined reasoning an
        empty string instead of leaving it absent.
        """
        delta = getattr(event.data, "delta_content", "") or ""
        if not delta:
            return
        state.reasoning_parts.append(delta)
        if on_chunk:
            on_chunk(StreamChunk(type="reasoning", delta=delta))

    def _handle_message(
        self,
        event: Any,
        state: _SendState,
        on_chunk: Callable[[StreamChunk], None] | None,
    ) -> None:
        """Record the full assistant message and emit the final chunk."""
        content = (
            getattr(event.data, "content", "")
            or getattr(event.data, "transformed_content", "")
            or ""
        )
        state.final_content = content
        if content:
            self._note_content_arrived(state)
        if on_chunk:
            on_chunk(StreamChunk(
                type="message",
                delta="",
                is_final=True,
                content=state.final_content,
            ))

    def _handle_reasoning(
        self,
        event: Any,
        state: _SendState,
        on_chunk: Callable[[StreamChunk], None] | None,
    ) -> None:
        """Record the full reasoning text and emit the final chunk."""
        state.final_reasoning = getattr(event.data, "content", "") or ""
        if on_chunk:
            on_chunk(StreamChunk(
                type="reasoning",
                delta="",
                is_final=True,
                content=state.final_reasoning,
            ))

    def _handle_tool_execution_start(
        self,
        event: Any,
        state: _SendState,
        on_chunk: Callable[[StreamChunk], None] | None,
    ) -> None:
        """Count a newly started tool and announce it as a tool_call chunk."""
        tool_args = getattr(event.data, "arguments", None)
        state.pending_tools += 1
        state.awaiting_post_tool_response = True
        state.tool_execution_seen = True
        if on_chunk:
            on_chunk(StreamChunk(
                type="tool_call",
                tool_name=self._tool_label(event.data),
                tool_args=tool_args if isinstance(tool_args, dict) else {},
            ))

    def _handle_tool_execution_partial_result(
        self,
        event: Any,
        state: _SendState,
        on_chunk: Callable[[StreamChunk], None] | None,
    ) -> None:
        """Forward partial tool output (if any) as a tool_result chunk."""
        partial = getattr(event.data, "partial_output", None)
        state.awaiting_post_tool_response = True
        state.tool_execution_seen = True
        if on_chunk and partial:
            on_chunk(StreamChunk(
                type="tool_result",
                tool_name=self._tool_label(event.data),
                tool_result=str(partial),
            ))

    def _handle_tool_execution_complete(
        self,
        event: Any,
        state: _SendState,
        on_chunk: Callable[[StreamChunk], None] | None,
    ) -> None:
        """Mark one tool as finished and forward its result."""
        result_obj = getattr(event.data, "result", None)
        result_text = ""
        if result_obj is not None:
            # Prefer the result's ``content``; fall back to its string form.
            result_text = getattr(result_obj, "content", "") or str(result_obj)
        # Never go negative if a completion arrives without a matching start.
        state.pending_tools = max(0, state.pending_tools - 1)
        state.awaiting_post_tool_response = True
        state.tool_execution_seen = True
        if on_chunk:
            on_chunk(StreamChunk(
                type="tool_result",
                tool_name=self._tool_label(event.data),
                tool_result=result_text,
                tool_success=getattr(event.data, "success", None),
                tool_duration=getattr(event.data, "duration", None),
            ))

    def _handle_error_event(self, event: Any, state: _SendState) -> None:
        """Store the reported error as a RuntimeError and stop waiting."""
        error_msg = str(getattr(event.data, "message", event.data))
        state.error_holder.append(RuntimeError(error_msg))
        state.done.set()

    def _handle_tool_call(
        self,
        event: Any,
        state: _SendState,
        on_chunk: Callable[[StreamChunk], None] | None,
    ) -> None:
        """Announce a requested tool call, decoding string arguments as JSON.

        NOTE(review): this sets ``awaiting_post_tool_response`` without
        setting ``tool_execution_seen``, so unless a tool-execution event
        follows, only a session-idle event (not a turn end) completes the
        send. Confirm this is intended.
        """
        data = event.data
        tool_name = getattr(data, "name", None) or getattr(data, "tool", None) or "unknown"
        tool_args = getattr(data, "arguments", None) or getattr(data, "args", {})
        state.awaiting_post_tool_response = True
        if isinstance(tool_args, str):
            import json

            try:
                tool_args = json.loads(tool_args)
            except (ValueError, RecursionError):
                # Not valid JSON: keep the text so it is not lost.
                tool_args = {"raw": tool_args}
        if on_chunk:
            on_chunk(StreamChunk(
                type="tool_call",
                tool_name=str(tool_name),
                tool_args=tool_args if isinstance(tool_args, dict) else {},
            ))

    def _handle_assistant_turn_end(self, state: _SendState) -> None:
        """Finish the send unless a post-tool response is still expected."""
        if not state.awaiting_post_tool_response:
            state.done.set()

    def _handle_session_idle(self, state: _SendState) -> None:
        """Finish the send unconditionally."""
        state.done.set()

    # ------------------------------------------------------------------
    # Session management
    # ------------------------------------------------------------------

    async def _ensure_session(self) -> Any:
        """Ensure a session exists, creating one if needed."""
        if not self._started:
            await self.start()
        if self._session is None:
            self._session = await self._create_session_with_reasoning()
        return self._session

    async def _create_session_with_reasoning(self) -> CopilotSession:
        """Create a session with reasoning effort support.

        The GitHub Copilot SDK's create_session() ignores model_reasoning_effort,
        so we bypass it and call the JSON-RPC directly to inject this parameter.

        Falls back to SDK's create_session() in test environments where the
        internal JSON-RPC client isn't accessible.
        """
        opts = self.config.to_session_options()

        # Check if we can access the internal JSON-RPC client
        # If not (e.g., in tests with mocked clients), fall back to SDK's create_session
        if getattr(self._client, "_client", None) is None:
            return await self._client.create_session(opts)

        # Build the wire payload with proper camelCase keys
        payload: dict[str, Any] = {}

        if opts.get("model"):
            payload["model"] = opts["model"]
        if opts.get("streaming") is not None:
            payload["streaming"] = opts["streaming"]

        # The key fix: inject modelReasoningEffort directly into the wire payload
        # The SDK's create_session() drops this, but the server accepts it!
        reasoning_effort = opts.get("model_reasoning_effort")
        if reasoning_effort and reasoning_effort != "none":
            payload["modelReasoningEffort"] = reasoning_effort

        # Map other session options
        for option_key, wire_key in self._PASSTHROUGH_SESSION_OPTIONS:
            value = opts.get(option_key)
            if value:
                payload[wire_key] = value

        instructions = opts.get("instructions")
        if instructions:
            # Instructions go into system message
            if "systemMessage" not in payload:
                payload["systemMessage"] = {"mode": "append", "content": instructions}
            elif isinstance(payload["systemMessage"], dict):
                existing = payload["systemMessage"].get("content", "")
                payload["systemMessage"]["content"] = (
                    f"{existing}\n\n{instructions}" if existing else instructions
                )
            # NOTE(review): a non-dict systemMessage leaves the instructions
            # unused; confirm whether that case can occur.

        # Call the JSON-RPC directly, bypassing the SDK's create_session
        response = await self._client._client.request("session.create", payload)

        session_id = response["sessionId"]
        workspace_path = response.get("workspacePath")

        # Create a CopilotSession using the SDK's class
        session = CopilotSession(session_id, self._client._client, workspace_path)

        # Register the session with the client for event dispatch
        # Note: we access the internal _sessions dict since we bypassed create_session
        with self._client._sessions_lock:
            self._client._sessions[session_id] = session

        return session

    async def _get_session_context(self, session: Any) -> str | None:
        """Extract conversation context from session for recovery.

        Returns:
            Up to the last 10 user/assistant entries joined by blank lines
            (user text cut to 500 characters, assistant text to 1000), or
            None when there is nothing usable or reading the history fails.
        """
        try:
            messages = await session.get_messages()
            if not messages:
                return None

            # Build a summary of the conversation
            context_parts = []
            for msg in messages:
                msg_value = self._type_value(getattr(msg, "type", None))
                data = getattr(msg, "data", None)

                if msg_value == EventType.USER_MESSAGE.value:
                    content = getattr(data, "content", "") or getattr(data, "prompt", "")
                    if content:
                        context_parts.append(f"User: {content[:500]}")
                elif msg_value == EventType.ASSISTANT_MESSAGE.value:
                    content = getattr(data, "content", "") or ""
                    if content:
                        # Truncate long responses
                        truncated = content[:1000] + "..." if len(content) > 1000 else content
                        context_parts.append(f"Assistant: {truncated}")

            if not context_parts:
                return None

            return "\n\n".join(context_parts[-10:])  # Last 10 messages max
        except Exception:
            # Recovery must proceed even if the old session cannot be read.
            return None

    async def _recover_session(self, on_chunk: Callable[[StreamChunk], None] | None) -> tuple[Any, str]:
        """Destroy bad session and create new one, preserving context.

        Returns:
            ``(session, prompt)``: the fresh session and the prompt to send
            to it. The prompt embeds the old conversation summary when one
            could be extracted; otherwise it is just ``continue_prompt``.
        """
        context = None
        if self._session:
            # Read the history before the session is destroyed.
            context = await self._get_session_context(self._session)
            try:
                await self._session.destroy()
            except Exception:
                # Best effort: the session is already considered broken.
                pass
            self._session = None

        # Create fresh session
        session = await self._ensure_session()

        # Build recovery prompt with context
        if context:
            recovery_prompt = (
                f"[Session recovered. Previous conversation context:]\n\n"
                f"{context}\n\n"
                f"[End of context. {self.config.continue_prompt}]"
            )
        else:
            recovery_prompt = self.config.continue_prompt

        if on_chunk:
            on_chunk(StreamChunk(
                type="system",
                delta="\n[Session recovered with fresh connection]\n",
            ))

        return session, recovery_prompt

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def send(
        self,
        prompt: str,
        *,
        tools: list[Any] | None = None,
        on_chunk: Callable[[StreamChunk], None] | None = None,
        metrics: MetricsCollector | None = None,
    ) -> Response:
        """
        Send a prompt with automatic retry on errors.

        Args:
            prompt: The prompt to send
            tools: Optional list of tools. It is passed to ``_send_once``,
                which currently sends only the prompt to the session.
            on_chunk: Optional callback for streaming chunks
            metrics: Collector to record the request in; the collector from
                ``get_collector()`` is used when omitted.

        Returns:
            Response object with content and metadata

        Raises:
            Exception: The last error, once it is not retryable or both the
                retry and auto-continue budgets are exhausted.
        """
        session = await self._ensure_session()
        retries = 0
        auto_continues = 0
        collector = metrics if metrics is not None else get_collector()
        request = collector.start_request(
            model=self.config.model.value,
            reasoning_effort=self.config.reasoning_effort.value,
            prompt=prompt,
        )

        def _record_failure(message: str) -> None:
            # Reads ``retries`` at call time, i.e. the current attempt count.
            collector.complete_request(
                request.request_id,
                success=False,
                error=message,
                retries=retries,
            )

        while True:
            try:
                result = await self._send_once(session, prompt, tools, on_chunk)
            except Exception as e:
                error_str = str(e)

                if self._is_tool_state_error(e) and self.config.auto_continue:
                    # The session's tool state is inconsistent: retrying on
                    # it cannot help, so go straight to a fresh session.
                    auto_continues += 1
                    if auto_continues > self.config.retry.max_auto_continues:
                        _record_failure(error_str)
                        raise
                    retries = 0
                    session, prompt = await self._recover_session(on_chunk)
                    if on_chunk:
                        on_chunk(StreamChunk(
                            type="system",
                            delta="\n[Tool state mismatch detected; recovered session]\n",
                        ))
                    await asyncio.sleep(self._calculate_delay(0))
                    continue

                if not self._should_retry(e):
                    _record_failure(error_str)
                    raise

                retries += 1
                if retries <= self.config.retry.max_retries:
                    # Normal retry with exponential backoff (same session)
                    delay = self._calculate_delay(retries - 1)
                    if on_chunk:
                        on_chunk(StreamChunk(
                            type="system",
                            delta=f"\n[Retry {retries}/{self.config.retry.max_retries} after error: {error_str[:50]}...]\n",
                        ))
                    await asyncio.sleep(delay)
                elif self.config.auto_continue and auto_continues < self.config.retry.max_auto_continues:
                    # Retries exhausted - session may be in bad state
                    # Recover with fresh session, preserving context
                    auto_continues += 1
                    retries = 0
                    session, prompt = await self._recover_session(on_chunk)
                    delay = self._calculate_delay(0)
                    if on_chunk:
                        on_chunk(StreamChunk(
                            type="system",
                            delta=f"\n[Auto-continue #{auto_continues}/{self.config.retry.max_auto_continues} with fresh session]\n",
                        ))
                    await asyncio.sleep(delay)
                else:
                    _record_failure(error_str)
                    raise
            else:
                # Kept outside the ``try`` so a bookkeeping failure after a
                # successful send is never mistaken for a send failure and
                # does not cause the prompt to be sent again.
                result.retries = retries
                result.auto_continues = auto_continues
                collector.complete_request(
                    request.request_id,
                    success=True,
                    response=result.content,
                    retries=retries,
                )
                return result

    async def _send_once(
        self,
        session: Any,
        prompt: str,
        tools: list[Any] | None,
        on_chunk: Callable[[StreamChunk], None] | None,
    ) -> Response:
        """Send a single prompt and collect the response.

        Subscribes to session events, sends the prompt, then waits until a
        handler signals completion or no event has arrived for
        ``config.timeout`` seconds.

        Raises:
            TimeoutError: No event was received within the timeout window.
            Exception: The first error captured from an event or handler.
        """
        loop = asyncio.get_running_loop()
        state = _SendState(done=asyncio.Event())
        state.last_activity = loop.time()

        def on_event(event: Any) -> None:
            # Any event counts as activity for the inactivity timeout.
            state.last_activity = loop.time()
            try:
                event_type = self._type_value(event.type)
                state.raw_events.append({"type": event_type, "data": getattr(event, "data", None)})

                if event_type == EventType.ASSISTANT_MESSAGE_DELTA.value:
                    self._handle_message_delta(event, state, on_chunk)

                elif event_type == EventType.ASSISTANT_REASONING_DELTA.value:
                    self._handle_reasoning_delta(event, state, on_chunk)

                elif event_type == EventType.ASSISTANT_MESSAGE.value:
                    self._handle_message(event, state, on_chunk)

                elif event_type == EventType.ASSISTANT_REASONING.value:
                    self._handle_reasoning(event, state, on_chunk)

                elif event_type == EventType.TOOL_EXECUTION_START.value:
                    self._handle_tool_execution_start(event, state, on_chunk)

                elif event_type == EventType.TOOL_EXECUTION_PARTIAL_RESULT.value:
                    self._handle_tool_execution_partial_result(event, state, on_chunk)

                elif event_type == EventType.TOOL_EXECUTION_COMPLETE.value:
                    self._handle_tool_execution_complete(event, state, on_chunk)

                elif event_type in (EventType.ERROR.value, EventType.SESSION_ERROR.value):
                    self._handle_error_event(event, state)

                elif event_type == EventType.TOOL_CALL.value:
                    self._handle_tool_call(event, state, on_chunk)

                elif event_type == EventType.ASSISTANT_TURN_END.value:
                    self._handle_assistant_turn_end(state)

                elif event_type == EventType.SESSION_IDLE.value:
                    self._handle_session_idle(state)

            except Exception as e:
                # Includes errors raised by the caller's on_chunk callback.
                state.error_holder.append(e)
                state.done.set()

        unsubscribe = session.on(on_event)

        try:
            await session.send({"prompt": prompt})
            # Activity-based timeout: only timeout if no events received for timeout period
            # NOTE(review): each wait restarts a full timeout window, so an
            # idle session may be detected up to roughly 2x timeout late.
            while not state.done.is_set():
                try:
                    await asyncio.wait_for(state.done.wait(), timeout=self.config.timeout)
                except asyncio.TimeoutError:
                    # Check if we've had activity within the timeout window
                    idle_time = loop.time() - state.last_activity
                    if idle_time >= self.config.timeout:
                        raise TimeoutError(
                            f"Response timed out after {idle_time:.1f}s of inactivity"
                        )
                    # Had recent activity, keep waiting
        finally:
            # Remove event handler to avoid duplicates
            try:
                unsubscribe()
            except Exception:
                pass

        # Fail first: there is no point reading history for a failed send.
        if state.error_holder:
            raise state.error_holder[0]

        # If we never got explicit content events and NOT streaming, try to extract from history
        # When streaming (on_chunk provided), we trust the streamed chunks and don't use history
        # fallback which could return stale content from previous turns
        if not state.received_content and on_chunk is None:
            try:
                messages = await session.get_messages()
                for message in reversed(messages):
                    message_value = self._type_value(getattr(message, "type", None))
                    if message_value == EventType.ASSISTANT_MESSAGE.value:
                        state.final_content = getattr(message.data, "content", "") or state.final_content
                        if state.final_content:
                            break
            except Exception:
                # Best effort: fall through with whatever was collected.
                pass

        return Response(
            content=state.final_content or "".join(state.content_parts),
            reasoning=state.final_reasoning or (
                "".join(state.reasoning_parts) if state.reasoning_parts else None
            ),
            raw_events=state.raw_events,
        )

    async def stream(
        self,
        prompt: str,
        *,
        tools: list[Any] | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """
        Stream a response with automatic retry.

        Yields StreamChunk objects as they arrive. ``send`` runs in a
        background task that feeds a queue; an error from ``send`` is
        re-raised here after the chunks queued before it were yielded.
        """
        queue: asyncio.Queue[StreamChunk | None | Exception] = asyncio.Queue()

        def on_chunk(chunk: StreamChunk) -> None:
            queue.put_nowait(chunk)

        async def sender() -> None:
            try:
                await self.send(prompt, tools=tools, on_chunk=on_chunk)
                queue.put_nowait(None)  # Signal completion
            except Exception as e:
                queue.put_nowait(e)

        task = asyncio.create_task(sender())

        try:
            while True:
                item = await queue.get()
                if item is None:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            # Also reached when the consumer stops iterating early: cancel
            # the in-flight send and wait for it to unwind.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def chat(self, prompt: str) -> str:
        """Simple interface - send prompt, get response content."""
        response = await self.send(prompt)
        return response.content

    def new_session(self) -> None:
        """Start a fresh session (clears conversation history).

        The current session is detached immediately, so the next send
        creates a new one, and is destroyed on a best-effort basis: inside
        a running event loop in a background task, otherwise synchronously
        via ``asyncio.run``. Destroy errors are ignored on both paths, as
        they are in ``stop``.
        """
        session = self._session
        if not session:
            return
        self._session = None

        async def _destroy_quietly() -> None:
            try:
                await session.destroy()
            except Exception:
                # Best effort: the session has already been detached.
                pass

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(_destroy_quietly())
            return

        # The event loop keeps only weak references to tasks; hold a strong
        # one until the task finishes so it cannot be collected mid-flight.
        task = loop.create_task(_destroy_quietly())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
@asynccontextmanager
async def copex(
    model: Model | str = Model.GPT_5_2_CODEX,
    reasoning: ReasoningEffort | str = ReasoningEffort.XHIGH,
    **kwargs: Any,
) -> AsyncIterator[Copex]:
    """
    Async context manager that yields a started ``Copex`` client.

    ``model`` and ``reasoning`` may be given as enum members or as strings,
    which are converted with ``Model(...)`` / ``ReasoningEffort(...)``.
    Remaining keyword arguments are passed through to ``CopexConfig``.
    The client is stopped on exit, including when startup or the body fails.

    Example:
        async with copex() as c:
            response = await c.chat("Hello!")
            print(response)
    """
    chosen_model = model
    if isinstance(model, str):
        chosen_model = Model(model)

    chosen_effort = reasoning
    if isinstance(reasoning, str):
        chosen_effort = ReasoningEffort(reasoning)

    settings = CopexConfig(
        model=chosen_model,
        reasoning_effort=chosen_effort,
        **kwargs,
    )
    wrapper = Copex(settings)
    try:
        await wrapper.start()
        yield wrapper
    finally:
        # Runs even if start() raised, so a half-started client is stopped.
        await wrapper.stop()
|