indent 0.1.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- exponent/__init__.py +34 -0
- exponent/cli.py +110 -0
- exponent/commands/cloud_commands.py +585 -0
- exponent/commands/common.py +411 -0
- exponent/commands/config_commands.py +334 -0
- exponent/commands/run_commands.py +222 -0
- exponent/commands/settings.py +56 -0
- exponent/commands/types.py +111 -0
- exponent/commands/upgrade.py +29 -0
- exponent/commands/utils.py +146 -0
- exponent/core/config.py +180 -0
- exponent/core/graphql/__init__.py +0 -0
- exponent/core/graphql/client.py +61 -0
- exponent/core/graphql/get_chats_query.py +47 -0
- exponent/core/graphql/mutations.py +160 -0
- exponent/core/graphql/queries.py +146 -0
- exponent/core/graphql/subscriptions.py +16 -0
- exponent/core/remote_execution/checkpoints.py +212 -0
- exponent/core/remote_execution/cli_rpc_types.py +499 -0
- exponent/core/remote_execution/client.py +999 -0
- exponent/core/remote_execution/code_execution.py +77 -0
- exponent/core/remote_execution/default_env.py +31 -0
- exponent/core/remote_execution/error_info.py +45 -0
- exponent/core/remote_execution/exceptions.py +10 -0
- exponent/core/remote_execution/file_write.py +35 -0
- exponent/core/remote_execution/files.py +330 -0
- exponent/core/remote_execution/git.py +268 -0
- exponent/core/remote_execution/http_fetch.py +94 -0
- exponent/core/remote_execution/languages/python_execution.py +239 -0
- exponent/core/remote_execution/languages/shell_streaming.py +226 -0
- exponent/core/remote_execution/languages/types.py +20 -0
- exponent/core/remote_execution/port_utils.py +73 -0
- exponent/core/remote_execution/session.py +128 -0
- exponent/core/remote_execution/system_context.py +26 -0
- exponent/core/remote_execution/terminal_session.py +375 -0
- exponent/core/remote_execution/terminal_types.py +29 -0
- exponent/core/remote_execution/tool_execution.py +595 -0
- exponent/core/remote_execution/tool_type_utils.py +39 -0
- exponent/core/remote_execution/truncation.py +296 -0
- exponent/core/remote_execution/types.py +635 -0
- exponent/core/remote_execution/utils.py +477 -0
- exponent/core/types/__init__.py +0 -0
- exponent/core/types/command_data.py +206 -0
- exponent/core/types/event_types.py +89 -0
- exponent/core/types/generated/__init__.py +0 -0
- exponent/core/types/generated/strategy_info.py +213 -0
- exponent/migration-docs/login.md +112 -0
- exponent/py.typed +4 -0
- exponent/utils/__init__.py +0 -0
- exponent/utils/colors.py +92 -0
- exponent/utils/version.py +289 -0
- indent-0.1.26.dist-info/METADATA +38 -0
- indent-0.1.26.dist-info/RECORD +55 -0
- indent-0.1.26.dist-info/WHEEL +4 -0
- indent-0.1.26.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,999 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import time
|
|
7
|
+
import uuid
|
|
8
|
+
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
|
|
9
|
+
from contextlib import asynccontextmanager
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Any, TypeVar, cast
|
|
12
|
+
|
|
13
|
+
import msgspec
|
|
14
|
+
import websockets.exceptions
|
|
15
|
+
from httpx import (
|
|
16
|
+
AsyncClient,
|
|
17
|
+
codes as http_status,
|
|
18
|
+
)
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
from websockets.asyncio import client as asyncio_websockets_client
|
|
21
|
+
from websockets.asyncio.client import ClientConnection, connect
|
|
22
|
+
|
|
23
|
+
from exponent.commands.utils import ConnectionTracker
|
|
24
|
+
from exponent.core.config import is_editable_install
|
|
25
|
+
from exponent.core.remote_execution import files, system_context
|
|
26
|
+
from exponent.core.remote_execution.cli_rpc_types import (
|
|
27
|
+
BashToolInput,
|
|
28
|
+
BatchToolExecutionRequest,
|
|
29
|
+
BatchToolExecutionResponse,
|
|
30
|
+
CliRpcRequest,
|
|
31
|
+
CliRpcResponse,
|
|
32
|
+
ErrorResponse,
|
|
33
|
+
ErrorToolResult,
|
|
34
|
+
GenerateUploadUrlRequest,
|
|
35
|
+
GenerateUploadUrlResponse,
|
|
36
|
+
GetAllFilesRequest,
|
|
37
|
+
GetAllFilesResponse,
|
|
38
|
+
HttpRequest,
|
|
39
|
+
KeepAliveCliChatRequest,
|
|
40
|
+
KeepAliveCliChatResponse,
|
|
41
|
+
StartTerminalRequest,
|
|
42
|
+
StartTerminalResponse,
|
|
43
|
+
StopTerminalRequest,
|
|
44
|
+
StopTerminalResponse,
|
|
45
|
+
SwitchCLIChatRequest,
|
|
46
|
+
SwitchCLIChatResponse,
|
|
47
|
+
TerminalInputRequest,
|
|
48
|
+
TerminalInputResponse,
|
|
49
|
+
TerminalResizeRequest,
|
|
50
|
+
TerminalResizeResponse,
|
|
51
|
+
TerminateRequest,
|
|
52
|
+
TerminateResponse,
|
|
53
|
+
ToolExecutionRequest,
|
|
54
|
+
ToolExecutionResponse,
|
|
55
|
+
ToolResultType,
|
|
56
|
+
)
|
|
57
|
+
from exponent.core.remote_execution.code_execution import (
|
|
58
|
+
execute_code_streaming,
|
|
59
|
+
)
|
|
60
|
+
from exponent.core.remote_execution.files import file_walk
|
|
61
|
+
from exponent.core.remote_execution.http_fetch import fetch_http_content
|
|
62
|
+
from exponent.core.remote_execution.session import (
|
|
63
|
+
RemoteExecutionClientSession,
|
|
64
|
+
get_session,
|
|
65
|
+
send_exception_log,
|
|
66
|
+
)
|
|
67
|
+
from exponent.core.remote_execution.terminal_session import TerminalSessionManager
|
|
68
|
+
from exponent.core.remote_execution.terminal_types import TerminalMessage
|
|
69
|
+
from exponent.core.remote_execution.tool_execution import (
|
|
70
|
+
execute_bash_tool,
|
|
71
|
+
execute_tool,
|
|
72
|
+
truncate_result,
|
|
73
|
+
)
|
|
74
|
+
from exponent.core.remote_execution.types import (
|
|
75
|
+
ChatSource,
|
|
76
|
+
CLIConnectedState,
|
|
77
|
+
CreateChatResponse,
|
|
78
|
+
HeartbeatInfo,
|
|
79
|
+
RunWorkflowRequest,
|
|
80
|
+
WorkflowInput,
|
|
81
|
+
WorkflowTriggerRequest,
|
|
82
|
+
WorkflowTriggerResponse,
|
|
83
|
+
)
|
|
84
|
+
from exponent.core.remote_execution.utils import (
|
|
85
|
+
deserialize_api_response,
|
|
86
|
+
)
|
|
87
|
+
from exponent.utils.version import get_installed_version
|
|
88
|
+
|
|
89
|
+
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)


# Type variable restricted to pydantic BaseModel subclasses.
TModel = TypeVar("TModel", bound=BaseModel)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
class WSDisconnected:
    """Exit marker: the websocket session is over and should not be resumed.

    ``error_message`` stays None for a normal close; otherwise it carries a
    human-readable reason (timeout, policy close, failure to connect).
    """

    # Optional explanation shown to the user; None means a clean shutdown.
    error_message: str | None = None
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@dataclass
class SwitchCLIChat:
    """Exit marker: the server asked the CLI to move to another chat.

    ``new_chat_uuid`` is the identifier of the chat to switch to.
    """

    # UUID (as a string) of the chat the CLI should serve next.
    new_chat_uuid: str
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# Every value RemoteExecutionClient.run_connection can hand back when it
# stops serving a chat.
REMOTE_EXECUTION_CLIENT_EXIT_INFO = WSDisconnected | SwitchCLIChat

# Generated once at import time and reported with every heartbeat, so it
# identifies a single run of the CLI process.
cli_uuid = uuid.uuid4()
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class RemoteExecutionClient:
|
|
112
|
+
def __init__(
    self,
    session: RemoteExecutionClientSession,
    file_cache: files.FileCache | None = None,
):
    """Bind the client to ``session``.

    Args:
        session: Session supplying the working directory and the HTTP and
            websocket clients.
        file_cache: Cache to reuse. When it is not provided (or is falsy),
            a new ``files.FileCache`` rooted at ``session.working_directory``
            is created.
    """
    self.current_session = session
    self.file_cache = (
        file_cache if file_cache else files.FileCache(session.working_directory)
    )

    # Halt flags for in-flight code executions, keyed by correlation id;
    # a True value means the execution should stop.
    self._halt_states: dict[str, bool] = {}
    self._halt_lock = asyncio.Lock()

    # time.time() of the most recent incoming websocket message (None until
    # set); read by the inactivity timeout monitor.
    self._last_request_time: float | None = None

    # Futures for upload-URL requests awaiting their result, keyed by
    # request id.
    self._pending_upload_requests: dict[
        str, asyncio.Future[GenerateUploadUrlResponse]
    ] = {}
    self._upload_request_lock = asyncio.Lock()

    # Websocket connection currently being served, if any.
    self._websocket: ClientConnection | None = None
|
|
134
|
+
|
|
135
|
+
@property
def working_directory(self) -> str:
    """Working directory of the session this client is currently bound to."""
    session = self.current_session
    return session.working_directory
|
|
138
|
+
|
|
139
|
+
@property
def api_client(self) -> AsyncClient:
    """HTTP client of the current session, used for REST API calls."""
    session = self.current_session
    return session.api_client
|
|
142
|
+
|
|
143
|
+
@property
def ws_client(self) -> AsyncClient:
    """``ws_client`` attribute of the current session."""
    session = self.current_session
    return session.ws_client
|
|
146
|
+
|
|
147
|
+
async def add_code_execution_to_halt_states(self, correlation_id: str) -> None:
    """Start tracking ``correlation_id`` with its halt flag cleared (False)."""
    lock = self._halt_lock
    async with lock:
        self._halt_states.update({correlation_id: False})
|
|
150
|
+
|
|
151
|
+
async def halt_all_code_executions(self) -> None:
    """Set the halt flag of every tracked code execution to True.

    A fresh dict is bound to ``self._halt_states``; closures returned by
    ``get_halt_check`` read the attribute on every call, so they observe it.
    """
    logger.info(f"Halting all code executions: {self._halt_states}")
    async with self._halt_lock:
        # dict.fromkeys keeps the existing key order and flags each key.
        self._halt_states = dict.fromkeys(self._halt_states, True)
|
|
157
|
+
|
|
158
|
+
async def clear_halt_state(self, correlation_id: str) -> None:
    """Stop tracking ``correlation_id``; unknown ids are ignored."""
    async with self._halt_lock:
        if correlation_id in self._halt_states:
            del self._halt_states[correlation_id]
|
|
161
|
+
|
|
162
|
+
def get_halt_check(self, correlation_id: str) -> Callable[[], bool]:
    """Return a zero-argument callable reporting whether ``correlation_id`` should halt.

    The callable looks up ``self._halt_states`` afresh on each call, so it
    reflects later changes, and it returns False for ids that are not tracked.
    No lock is taken: it only performs a single dict read.
    """
    return lambda: self._halt_states.get(correlation_id, False)
|
|
168
|
+
|
|
169
|
+
async def _timeout_monitor(
    self, timeout_seconds: int | None
) -> WSDisconnected | None:
    """Poll once per second for inactivity and report a timeout.

    Returns a ``WSDisconnected`` once more than ``timeout_seconds`` have
    passed since ``self._last_request_time``. With ``timeout_seconds`` None
    (or no recorded request time) it never fires and simply loops until
    cancelled. Cancellation is absorbed and reported as ``None``.
    """
    try:
        while True:
            await asyncio.sleep(1)
            if timeout_seconds is None or self._last_request_time is None:
                continue
            idle_for = time.time() - self._last_request_time
            if idle_for <= timeout_seconds:
                continue
            logger.info(
                f"No requests received for {timeout_seconds} seconds. Shutting down..."
            )
            return WSDisconnected(
                error_message=f"Timeout after {timeout_seconds} seconds of inactivity"
            )
    except asyncio.CancelledError:
        # Being cancelled is the normal way this monitor ends.
        return None
|
|
193
|
+
|
|
194
|
+
async def _handle_websocket_message(  # noqa: PLR0911, PLR0915
    self,
    msg: str,
    websocket: ClientConnection,
    requests: asyncio.Queue[CliRpcRequest],
    terminal_session_manager: TerminalSessionManager,
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
    """Handle one incoming websocket message.

    * ``result`` messages may resolve a pending upload-URL future.
    * ``request`` messages are decoded as ``CliRpcRequest``:
      - Control requests (terminate, chat switch, keep-alive and terminal
        start/input/resize/stop) are answered inline over ``websocket``.
      - Every other request is queued on ``requests`` for the executor tasks.
    * Any other message type is ignored.

    Args:
        msg: Raw JSON text received from the websocket.
        websocket: Connection used to send inline responses.
        requests: Queue consumed by the executor tasks.
        terminal_session_manager: Manager owning interactive terminal sessions.

    Returns:
        None to continue processing, or a REMOTE_EXECUTION_CLIENT_EXIT_INFO
        (a ``SwitchCLIChat``) telling the caller to exit.

    Raises:
        msgspec.DecodeError / msgspec.ValidationError: if a request cannot be
            decoded and no ``request_id`` can be recovered to report it.
    """
    self._last_request_time = time.time()

    msg_data = json.loads(msg)
    msg_type = msg_data["type"]
    if msg_type == "result":
        await self._resolve_upload_url_result(json.dumps(msg_data["data"]))
        return None
    if msg_type != "request":
        return None

    data = json.dumps(msg_data["data"])
    try:
        request = msgspec.json.decode(data, type=CliRpcRequest)
    except (msgspec.DecodeError, msgspec.ValidationError):
        if await self._reply_to_undecodable_request(websocket, data):
            return None
        # If we couldn't get a request_id, re-raise and fail noisily
        raise

    request_id = request.request_id
    inner = request.request

    if isinstance(inner, TerminateRequest):
        await self.halt_all_code_executions()
        await self._send_rpc_result(websocket, request_id, TerminateResponse())
        return None

    if isinstance(inner, SwitchCLIChatRequest):
        await self._send_rpc_result(websocket, request_id, SwitchCLIChatResponse())
        return SwitchCLIChat(new_chat_uuid=inner.new_chat_uuid)

    if isinstance(inner, KeepAliveCliChatRequest):
        await self._send_rpc_result(
            websocket, request_id, KeepAliveCliChatResponse()
        )
        return None

    if isinstance(inner, StartTerminalRequest):
        # Start a new terminal session
        session_id = await terminal_session_manager.start_session(
            websocket=websocket,
            session_id=inner.session_id,
            command=inner.command,
            cols=inner.cols,
            rows=inner.rows,
            env=inner.env,
        )
        await self._send_rpc_result(
            websocket,
            request_id,
            StartTerminalResponse(session_id=session_id, success=True),
        )
        return None

    if isinstance(inner, TerminalInputRequest):
        # Send input to terminal session
        success = await terminal_session_manager.send_input(
            session_id=inner.session_id,
            data=inner.data,
        )
        await self._send_rpc_result(
            websocket,
            request_id,
            TerminalInputResponse(session_id=inner.session_id, success=success),
        )
        return None

    if isinstance(inner, TerminalResizeRequest):
        # Resize terminal session
        success = await terminal_session_manager.resize_terminal(
            session_id=inner.session_id,
            rows=inner.rows,
            cols=inner.cols,
        )
        await self._send_rpc_result(
            websocket,
            request_id,
            TerminalResizeResponse(session_id=inner.session_id, success=success),
        )
        return None

    if isinstance(inner, StopTerminalRequest):
        # Stop terminal session
        success = await terminal_session_manager.stop_session(
            session_id=inner.session_id
        )
        await self._send_rpc_result(
            websocket,
            request_id,
            StopTerminalResponse(session_id=inner.session_id, success=success),
        )
        return None

    # Everything else runs on the executor tasks. Bash executions (alone or
    # anywhere in a batch) are registered first so that a later
    # TerminateRequest can flag them to halt.
    if isinstance(inner, ToolExecutionRequest) and isinstance(
        inner.tool_input, BashToolInput
    ):
        await self.add_code_execution_to_halt_states(request_id)
    elif isinstance(inner, BatchToolExecutionRequest) and any(
        isinstance(tool_input, BashToolInput) for tool_input in inner.tool_inputs
    ):
        await self.add_code_execution_to_halt_states(request_id)

    await requests.put(request)
    return None


async def _send_rpc_result(
    self, websocket: ClientConnection, request_id: str, response: Any
) -> None:
    """Send ``response`` for ``request_id`` as a ``{"type": "result"}`` message."""
    payload = msgspec.to_builtins(
        CliRpcResponse(request_id=request_id, response=response)
    )
    await websocket.send(json.dumps({"type": "result", "data": payload}))


async def _resolve_upload_url_result(self, data: str) -> None:
    """Complete the pending upload-URL future matching an incoming result.

    Results that are not ``GenerateUploadUrlResponse`` are ignored. Errors are
    logged rather than raised so one bad result cannot break the connection.
    """
    try:
        response = msgspec.json.decode(data, type=CliRpcResponse)
        if not isinstance(response.response, GenerateUploadUrlResponse):
            return
        async with self._upload_request_lock:
            future = self._pending_upload_requests.pop(response.request_id, None)
        # The requester may already have given up (e.g. cancelled); calling
        # set_result on a finished future would raise InvalidStateError.
        if future is not None and not future.done():
            future.set_result(response.response)
    except Exception as e:
        logger.error(f"Error handling upload URL response: {e}")


async def _reply_to_undecodable_request(
    self, websocket: ClientConnection, data: str
) -> bool:
    """Best-effort ErrorResponse for a request that failed to decode.

    Returns True if a ``request_id`` could be recovered from the raw JSON and
    an error was sent back, False if the caller should re-raise instead.
    """
    raw = msgspec.json.decode(data)
    if not isinstance(raw, dict) or "request_id" not in raw:
        return False
    request_id = raw["request_id"]

    # Guard every nesting level: the payload failed validation, so the shape
    # of "request" / "tool_input" cannot be trusted.
    inner = raw.get("request")
    tool_input = inner.get("tool_input") if isinstance(inner, dict) else None
    if (
        isinstance(inner, dict)
        and inner.get("type") == "tool_execution"
        and isinstance(tool_input, dict)
        and "tool_name" in tool_input
    ):
        tool_name = tool_input["tool_name"]
        logger.error(
            f"Error tool {tool_name} received in a request. "
            "Please ensure you are running the latest version of Indent. If this issue persists, please contact support."
        )
        error_message = f"Unknown tool: {tool_name}. If you are running an older version of Indent, please upgrade to the latest version to ensure compatibility."
    else:
        logger.error(
            "Error decoding cli rpc request. Please ensure you are running the latest version of Indent."
        )
        error_message = f"Unknown cli rpc request type: {raw}"

    await self._send_rpc_result(
        websocket, request_id, ErrorResponse(error_message=error_message)
    )
    return True
|
|
439
|
+
|
|
440
|
+
async def _setup_tasks(
    self,
    beats: asyncio.Queue[HeartbeatInfo],
    requests: asyncio.Queue[CliRpcRequest],
    results: asyncio.Queue[CliRpcResponse],
) -> list[asyncio.Task[None]]:
    """Setup heartbeat and executor tasks.

    * ``beat`` puts a fresh ``HeartbeatInfo`` on ``beats`` every 3 seconds.
    * Three ``executor`` tasks take requests from ``requests``, run them via
      ``handle_streaming_request`` / ``handle_request``, and put every
      ``CliRpcResponse`` (including error responses) on ``results``.

    Returns:
        The heartbeat task followed by the executor tasks. They loop forever;
        the caller is responsible for cancelling them.
    """
    # Imported once here instead of on every request and every error.
    from exponent.core.remote_execution.cli_rpc_types import (
        StreamingCodeExecutionRequest,
        StreamingErrorResponse,
    )

    async def beat() -> None:
        while True:
            try:
                info = await self.get_heartbeat_info()
            except Exception:
                # Without this, a single failure would end the task silently
                # and no heartbeat would be sent for the rest of the run.
                logger.exception("Error collecting heartbeat info")
            else:
                await beats.put(info)
            await asyncio.sleep(3)

    # Lock to ensure that only one executor can grab a
    # request at a time.
    requests_lock = asyncio.Lock()

    # Lock to ensure that only one executor can put a
    # result in the results queue at a time.
    results_lock = asyncio.Lock()

    async def executor() -> None:
        while True:
            async with requests_lock:
                request = await requests.get()

            try:
                if isinstance(request.request, StreamingCodeExecutionRequest):
                    async for streaming_response in self.handle_streaming_request(
                        request.request
                    ):
                        async with results_lock:
                            await results.put(
                                CliRpcResponse(
                                    request_id=request.request_id,
                                    response=streaming_response,
                                )
                            )
                else:
                    # The requests lock is not held here, so other executors
                    # can pick up requests while this one is handled.
                    logger.info(f"Handling request {request}")
                    response = await self.handle_request(request)
                    async with results_lock:
                        logger.info(f"Putting response {response}")
                        await results.put(response)
            except Exception as e:
                logger.error(
                    f"Error handling request {request}:\n\n{e}", exc_info=True
                )
                try:
                    await send_exception_log(e, session=self.current_session)
                except Exception:
                    # Reporting is best-effort; it must never stop the executor.
                    logger.debug("Failed to send exception log", exc_info=True)

                error_response: Any
                if isinstance(request.request, StreamingCodeExecutionRequest):
                    # For streaming requests, send a streaming error response
                    error_response = StreamingErrorResponse(
                        correlation_id=request.request.correlation_id,
                        error_message=str(e),
                    )
                else:
                    error_response = ErrorResponse(error_message=str(e))

                async with results_lock:
                    await results.put(
                        CliRpcResponse(
                            request_id=request.request_id,
                            response=error_response,
                        )
                    )

    beat_task = asyncio.create_task(beat())
    # Three parallel executors to handle requests
    executor_tasks = [asyncio.create_task(executor()) for _ in range(3)]

    return [beat_task, *executor_tasks]
|
|
538
|
+
|
|
539
|
+
async def _process_websocket_messages(
    self,
    websocket: ClientConnection,
    beats: asyncio.Queue[HeartbeatInfo],
    requests: asyncio.Queue[CliRpcRequest],
    results: asyncio.Queue[CliRpcResponse],
    terminal_output_queue: asyncio.Queue[TerminalMessage],
    terminal_session_manager: TerminalSessionManager,
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO:
    """Process messages from the websocket connection.

    Waits concurrently for:

    * an incoming websocket message, handled by ``_handle_websocket_message``;
    * a heartbeat from ``beats``, sent as ``heartbeat``;
    * an executor result from ``results``, sent as ``result``;
    * terminal output from ``terminal_output_queue``, sent as
      ``terminal_message``.

    Returns:
        The exit marker produced by ``_handle_websocket_message``.

    Raises:
        Whatever receiving, sending or message handling raises (e.g.
        ``websockets.exceptions.ConnectionClosed``); the caller decides
        whether to reconnect.

    The queues outlive a single connection. A result or terminal message
    that was already taken off its queue but not sent when this method exits
    is pushed back to the front of that queue, so the next connection
    delivers it instead of it being lost. Heartbeats are not recovered; a
    fresh one is produced every few seconds.
    """

    def put_back_first(queue: asyncio.Queue[Any], item: Any) -> None:
        # asyncio.Queue has no public push-front, so rebuild it with ``item``
        # first. There is no await here, so no other task can interleave.
        queued: list[Any] = []
        while not queue.empty():
            queued.append(queue.get_nowait())
        queue.put_nowait(item)
        for other in queued:
            queue.put_nowait(other)

    recv = asyncio.create_task(websocket.recv())
    get_beat = asyncio.create_task(beats.get())
    get_result = asyncio.create_task(results.get())
    get_terminal_output = asyncio.create_task(terminal_output_queue.get())
    pending: set[asyncio.Task[object]] = {
        recv,
        get_beat,
        get_result,
        get_terminal_output,
    }

    try:
        while True:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )

            if recv in done:
                msg = str(recv.result())
                exit_info = await self._handle_websocket_message(
                    msg, websocket, requests, terminal_session_manager
                )
                if exit_info is not None:
                    return exit_info

                recv = asyncio.create_task(websocket.recv())
                pending.add(recv)

            if get_beat in done:
                info = get_beat.result()
                data = json.loads(info.model_dump_json())
                msg = json.dumps({"type": "heartbeat", "data": data})
                await websocket.send(msg)

                get_beat = asyncio.create_task(beats.get())
                pending.add(get_beat)

            if get_result in done:
                response = get_result.result()
                # All responses are now CliRpcResponse with msgspec
                data = msgspec.to_builtins(response)
                msg = json.dumps({"type": "result", "data": data})
                await websocket.send(msg)

                # Replaced only after a successful send (see finally below).
                get_result = asyncio.create_task(results.get())
                pending.add(get_result)

            if get_terminal_output in done:
                terminal_message = get_terminal_output.result()
                data = msgspec.to_builtins(terminal_message)
                msg = json.dumps({"type": "terminal_message", "data": data})
                await websocket.send(msg)

                # Replaced only after a successful send (see finally below).
                get_terminal_output = asyncio.create_task(
                    terminal_output_queue.get()
                )
                pending.add(get_terminal_output)
    finally:
        for task in pending:
            task.cancel()

        await asyncio.gather(*pending, return_exceptions=True)

        # get_result / get_terminal_output are swapped for a new task only
        # after their item was sent. If either now holds a value (completed
        # in `done`, completed after the wait, or its send failed), that
        # item was never delivered: put it back for the next connection.
        recoverable: list[tuple[asyncio.Task[Any], asyncio.Queue[Any]]] = [
            (get_result, results),
            (get_terminal_output, terminal_output_queue),
        ]
        for task, queue in recoverable:
            if task.done() and not task.cancelled() and task.exception() is None:
                put_back_first(queue, task.result())
|
|
607
|
+
|
|
608
|
+
async def _handle_websocket_connection(
    self,
    websocket: ClientConnection,
    connection_tracker: ConnectionTracker | None,
    beats: asyncio.Queue[HeartbeatInfo],
    requests: asyncio.Queue[CliRpcRequest],
    results: asyncio.Queue[CliRpcResponse],
    terminal_output_queue: asyncio.Queue[TerminalMessage],
    terminal_session_manager: TerminalSessionManager,
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
    """Handle a single websocket connection.

    Marks ``connection_tracker`` (if given) as connected for the lifetime of
    the connection. Stores ``websocket`` on ``self._websocket``, where
    ``request_upload_url`` sends its requests.

    Returns:
        An exit marker to stop for good:
        * whatever ``_process_websocket_messages`` returns;
        * ``WSDisconnected()`` on a normal (1000) close;
        * ``WSDisconnected`` carrying the server's reason on a 1008 close.

        None lets the caller attempt a reconnection. This happens on any
        other close or on a TimeoutError.
    """
    if connection_tracker is not None:
        await connection_tracker.set_connected(True)

    self._websocket = websocket

    try:
        return await self._process_websocket_messages(
            websocket,
            beats,
            requests,
            results,
            terminal_output_queue,
            terminal_session_manager,
        )
    except websockets.exceptions.ConnectionClosed as e:
        if e.rcvd is not None:
            if e.rcvd.code == 1000:
                # Normal closure, exit completely
                return WSDisconnected()
            elif e.rcvd.code == 1008:
                # websockets' Close.reason is a str that defaults to "", so an
                # empty reason must fall back too; otherwise the user would
                # see a blank error message.
                return WSDisconnected(
                    error_message=e.rcvd.reason or "Error connecting to websocket"
                )
        # Otherwise, allow reconnection attempt
        logger.debug("Websocket connection closed by remote.")
        return None
    except TimeoutError:
        # Timeout, allow reconnection attempt
        # TODO: investigate whether this is needed; possibly scope it down
        return None
    finally:
        if connection_tracker is not None:
            await connection_tracker.set_connected(False)
|
|
656
|
+
|
|
657
|
+
async def run_connection(
    self,
    chat_uuid: str,
    connection_tracker: ConnectionTracker | None = None,
    timeout_seconds: int | None = None,
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO:
    """Run the websocket connection loop with optional inactivity timeout.

    For each websocket yielded by ``self.ws_connect`` the connection handler
    and ``_timeout_monitor`` run side by side. Whichever finishes first wins;
    the other is cancelled and awaited before moving on. A None outcome means
    "reconnect".

    Returns:
        The first non-None exit marker. If the connection iterator is
        exhausted, returns ``WSDisconnected`` with "Could not establish
        websocket connection".
    """
    self.current_session.set_chat_uuid(chat_uuid)

    # Initialize last request time for timeout monitoring
    self._last_request_time = time.time()

    # Create queues ONCE - persist across reconnections
    beats: asyncio.Queue[HeartbeatInfo] = asyncio.Queue()
    requests: asyncio.Queue[CliRpcRequest] = asyncio.Queue()
    results: asyncio.Queue[CliRpcResponse] = asyncio.Queue()
    terminal_output_queue: asyncio.Queue[TerminalMessage] = asyncio.Queue()

    # Create terminal session manager ONCE - persist across reconnections
    terminal_session_manager = TerminalSessionManager(terminal_output_queue)

    # Create tasks ONCE - persist across reconnections
    executors = await self._setup_tasks(beats, requests, results)

    try:
        async for websocket in self.ws_connect(f"/api/ws/chat/{chat_uuid}"):
            # Always run connection and timeout monitor concurrently.
            # If timeout_seconds is None, the monitor loops until cancelled.
            connection_task = asyncio.create_task(
                self._handle_websocket_connection(
                    websocket,
                    connection_tracker,
                    beats,
                    requests,
                    results,
                    terminal_output_queue,
                    terminal_session_manager,
                )
            )
            monitor_task = asyncio.create_task(self._timeout_monitor(timeout_seconds))
            try:
                done, _ = await asyncio.wait(
                    [connection_task, monitor_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )
            finally:
                # Cancel whichever task is still running (both, if we are
                # being cancelled ourselves) and wait for it to unwind. This
                # ensures its cleanup (e.g. connection_tracker.set_connected
                # (False)) finishes before we reconnect or return, and no
                # task is left dangling. cancel() on a finished task is a
                # no-op.
                connection_task.cancel()
                monitor_task.cancel()
                await asyncio.gather(
                    connection_task, monitor_task, return_exceptions=True
                )

            for task in done:
                # Re-raises if the task failed; None means "try to reconnect".
                result = task.result()
                if result is not None:
                    return result

        # If we exit the websocket connection loop without returning,
        # it means we couldn't establish a connection
        return WSDisconnected(
            error_message="Could not establish websocket connection"
        )
    finally:
        # Stop all terminal sessions to clean up PTY processes
        await terminal_session_manager.stop_all_sessions()

        # Cancel all background tasks when exiting
        for task in executors:
            task.cancel()
        await asyncio.gather(*executors, return_exceptions=True)
|
|
727
|
+
|
|
728
|
+
async def create_chat(self, chat_source: ChatSource) -> CreateChatResponse:
    """Ask the server to create a chat tagged with ``chat_source``.

    Returns the deserialized ``CreateChatResponse``.
    """
    query = {"chat_source": chat_source.value}
    response = await self.api_client.post(
        "/api/remote_execution/create_chat", params=query
    )
    return await deserialize_api_response(response, CreateChatResponse)
|
|
734
|
+
|
|
735
|
+
# deprecated
async def run_workflow(self, chat_uuid: str, workflow_id: str) -> dict[str, Any]:
    """Start workflow ``workflow_id`` in chat ``chat_uuid`` (deprecated endpoint).

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: if the server answers with a status other than 200.
    """
    body = RunWorkflowRequest(
        chat_uuid=chat_uuid,
        workflow_id=workflow_id,
    ).model_dump()
    response = await self.api_client.post(
        "/api/remote_execution/run_workflow",
        json=body,
        timeout=60,
    )
    if response.status_code == http_status.OK:
        return cast(dict[str, Any], response.json())
    raise Exception(
        f"Failed to run workflow with status code {response.status_code} and response {response.text}"
    )
|
|
750
|
+
|
|
751
|
+
async def trigger_workflow(
    self, workflow_name: str, workflow_input: WorkflowInput
) -> WorkflowTriggerResponse:
    """Trigger the workflow named ``workflow_name`` with ``workflow_input``.

    Returns the deserialized ``WorkflowTriggerResponse``.
    """
    payload = WorkflowTriggerRequest(
        workflow_name=workflow_name,
        workflow_input=workflow_input,
    ).model_dump()
    response = await self.api_client.post(
        "/api/remote_execution/trigger_workflow",
        json=payload,
    )
    return await deserialize_api_response(response, WorkflowTriggerResponse)
|
|
762
|
+
|
|
763
|
+
async def get_heartbeat_info(self) -> HeartbeatInfo:
    """Collect the information reported to the server with each heartbeat.

    Includes system info for the working directory, the installed exponent
    version, whether the install is editable, and the CLI's UUID (as a str).
    """
    working_dir_info = await system_context.get_system_info(self.working_directory)
    return HeartbeatInfo(
        system_info=working_dir_info,
        exponent_version=get_installed_version(),
        editable_installation=is_editable_install(),
        cli_uuid=str(cli_uuid),
    )
|
|
770
|
+
|
|
771
|
+
async def send_heartbeat(self, chat_uuid: str) -> CLIConnectedState:
    """Report this CLI's heartbeat for ``chat_uuid`` and return the server's view.

    The body is ``get_heartbeat_info()`` serialized as JSON, POSTed to
    ``/api/remote_execution/{chat_uuid}/heartbeat`` with a 60-second timeout.

    Returns:
        The ``CLIConnectedState`` decoded from the response.

    Raises:
        Exception: if the server answers with anything other than HTTP 200.
    """
    logger.info(f"Sending heartbeat for chat_uuid {chat_uuid}")
    heartbeat_body = (await self.get_heartbeat_info()).model_dump_json()
    heartbeat_response = await self.api_client.post(
        f"/api/remote_execution/{chat_uuid}/heartbeat",
        content=heartbeat_body,
        timeout=60,
    )
    if heartbeat_response.status_code == http_status.OK:
        state = await deserialize_api_response(heartbeat_response, CLIConnectedState)
        logger.info(f"Heartbeat response: {state}")
        return state
    raise Exception(
        f"Heartbeat failed with status code {heartbeat_response.status_code} and response {heartbeat_response.text}"
    )
|
|
786
|
+
|
|
787
|
+
async def request_upload_url(
    self, s3_key: str, content_type: str
) -> GenerateUploadUrlResponse:
    """Request an upload URL for ``s3_key`` over the active websocket.

    Sends a ``GenerateUploadUrlRequest`` (wrapped in a ``CliRpcRequest`` with
    a fresh request id) as a ``{"type": "request", ...}`` JSON message and
    waits up to 30 seconds for the matching response future to be resolved.
    The future is registered in ``self._pending_upload_requests``; resolving
    it is presumably done by the websocket receive loop (not visible here —
    confirm). The pending entry is always removed when this call finishes,
    whether it succeeds, times out, fails, or is cancelled.

    Args:
        s3_key: Object key to request an upload URL for.
        content_type: Content type to associate with the upload.

    Returns:
        The ``GenerateUploadUrlResponse`` delivered through the future.

    Raises:
        RuntimeError: if there is no websocket connection, or no response
            arrives within 30 seconds.
    """
    if self._websocket is None:
        raise RuntimeError("No active websocket connection")

    request_id = str(uuid.uuid4())
    request = CliRpcRequest(
        request_id=request_id,
        request=GenerateUploadUrlRequest(s3_key=s3_key, content_type=content_type),
    )

    # Bind the future to the running loop explicitly (recommended over Future()).
    future: asyncio.Future[GenerateUploadUrlResponse] = (
        asyncio.get_running_loop().create_future()
    )
    async with self._upload_request_lock:
        self._pending_upload_requests[request_id] = future

    try:
        await self._websocket.send(
            json.dumps({"type": "request", "data": msgspec.to_builtins(request)})
        )
        return await asyncio.wait_for(future, timeout=30)
    except (TimeoutError, asyncio.TimeoutError) as err:
        # Before Python 3.11, wait_for raises asyncio.TimeoutError, which is
        # not the builtin TimeoutError; catch both so the contract holds.
        raise RuntimeError("Timeout waiting for upload URL response") from err
    finally:
        # Runs on success, error and cancellation alike, so the pending map
        # never retains a stale future. pop(..., None) is a no-op if the
        # response handler already removed the entry.
        async with self._upload_request_lock:
            self._pending_upload_requests.pop(request_id, None)
|
|
818
|
+
|
|
819
|
+
async def handle_request(self, request: CliRpcRequest) -> CliRpcResponse:
    """Execute a single non-streaming CLI RPC request and wrap its result.

    Dispatches on the type of ``request.request``:

    * ``ToolExecutionRequest`` -- run one tool; the result is truncated.
    * ``GetAllFilesRequest`` -- walk the working directory.
    * ``BatchToolExecutionRequest`` -- run every tool concurrently; a tool
      that raises becomes an ``ErrorToolResult`` rather than failing the
      whole batch. Each result is truncated.
    * ``HttpRequest`` -- fetch HTTP content.

    Terminal and chat-control request types are expected to be routed
    elsewhere and raise ``ValueError`` here, as does any other type. Every
    exception is logged with its traceback and re-raised. If a bash tool was
    involved, halt state for ``request.request_id`` is cleared afterwards.

    Args:
        request: The incoming RPC request envelope.

    Returns:
        A ``CliRpcResponse`` carrying ``request.request_id``.

    Raises:
        ValueError: if the inner request type is not handled by this method.
    """
    # Update last request time for timeout functionality
    self._last_request_time = time.time()
    inner = request.request
    request_id = request.request_id

    try:
        if isinstance(inner, ToolExecutionRequest):
            raw_result = await self._tool_input_coroutine(inner.tool_input, request_id)
            return CliRpcResponse(
                request_id=request_id,
                response=ToolExecutionResponse(
                    tool_result=truncate_result(raw_result),
                ),
            )

        if isinstance(inner, GetAllFilesRequest):
            # Named all_files to avoid shadowing the `files` module.
            all_files = await file_walk(self.working_directory)
            return CliRpcResponse(
                request_id=request_id,
                response=GetAllFilesResponse(files=all_files),
            )

        if isinstance(inner, BatchToolExecutionRequest):
            coros = [
                self._tool_input_coroutine(tool_input, request_id)
                for tool_input in inner.tool_inputs
            ]
            results: list[ToolResultType | BaseException] = await asyncio.gather(
                *coros, return_exceptions=True
            )
            # Per-tool failures are reported in-band instead of failing the batch.
            processed_results: list[ToolResultType] = [
                ErrorToolResult(error_message=str(result))
                if isinstance(result, BaseException)
                else truncate_result(result)
                for result in results
            ]
            return CliRpcResponse(
                request_id=request_id,
                response=BatchToolExecutionResponse(
                    tool_results=processed_results,
                ),
            )

        if isinstance(inner, HttpRequest):
            http_response = await fetch_http_content(inner)
            return CliRpcResponse(
                request_id=request_id,
                response=http_response,
            )

        if isinstance(
            inner,
            (
                TerminateRequest,
                SwitchCLIChatRequest,
                KeepAliveCliChatRequest,
                StartTerminalRequest,
                TerminalInputRequest,
                TerminalResizeRequest,
                StopTerminalRequest,
            ),
        ):
            raise ValueError(
                f"{type(inner).__name__} should not be handled by handle_request"
            )

        # Report the inner request's type; type(request) is always the envelope.
        raise ValueError(f"Unhandled request type: {type(inner)}")

    except Exception as e:
        logger.exception(f"Error handling request {request}:\n\n{e}")
        raise
    finally:
        # Clean up halt state after request is complete
        if isinstance(inner, ToolExecutionRequest):
            used_bash = isinstance(inner.tool_input, BashToolInput)
        elif isinstance(inner, BatchToolExecutionRequest):
            used_bash = any(
                isinstance(tool_input, BashToolInput)
                for tool_input in inner.tool_inputs
            )
        else:
            used_bash = False
        if used_bash:
            await self.clear_halt_state(request_id)

def _tool_input_coroutine(
    self, tool_input: Any, request_id: str
) -> Coroutine[Any, Any, ToolResultType]:
    """Create, without awaiting, the coroutine that executes ``tool_input``.

    Bash inputs go to ``execute_bash_tool`` with a ``should_halt`` callback
    from ``self.get_halt_check(request_id)``. All other inputs go to
    ``execute_tool`` together with this client.
    """
    if isinstance(tool_input, BashToolInput):
        return execute_bash_tool(
            tool_input,
            self.working_directory,
            should_halt=self.get_halt_check(request_id),
        )
    return execute_tool(tool_input, self.working_directory, self)
|
|
937
|
+
|
|
938
|
+
async def handle_streaming_request(
    self,
    request: Any,
) -> AsyncGenerator[Any, None]:
    """Stream the outputs of a ``StreamingCodeExecutionRequest``.

    Each item produced by ``execute_code_streaming`` for the current session
    is yielded as-is. Execution runs in ``self.working_directory``, with a
    halt check keyed on ``request.correlation_id``.

    Raises:
        AssertionError: if ``request`` is not a ``StreamingCodeExecutionRequest``.
            Because this is an async generator, the error surfaces on the
            first iteration.
    """
    from exponent.core.remote_execution.cli_rpc_types import (
        StreamingCodeExecutionRequest,
    )

    # Raise explicitly rather than `assert`: asserts are stripped under
    # `python -O`, which would let any request type reach the executor.
    if not isinstance(request, StreamingCodeExecutionRequest):
        raise AssertionError(
            f"{type(request)} should not be sent to handle_streaming_request"
        )
    async for output in execute_code_streaming(
        request,
        self.current_session,
        working_directory=self.working_directory,
        should_halt=self.get_halt_check(request.correlation_id),
    ):
        yield output
|
|
955
|
+
|
|
956
|
+
def ws_connect(self, path: str) -> connect:
    """Build a websocket ``connect`` object for ``path`` on the ws host.

    The ``http://``/``https://`` scheme of ``self.ws_client.base_url`` is
    translated to ``ws://``/``wss://``. The ``api-key`` header from
    ``self.api_client`` is forwarded. The returned ``connect`` object is not
    yet opened; it is created with ``open_timeout=10`` and ``ping_timeout=10``.

    Side effect: installs ``custom_backoff`` as the reconnect backoff on the
    ``asyncio_websockets_client`` module. This is process-wide, not
    per-connection.
    """
    base_url = str(self.ws_client.base_url)
    # Translate only the scheme prefix. A blanket str.replace would also
    # rewrite any "http://" that appears later in the URL.
    for http_scheme, ws_scheme in (("http://", "ws://"), ("https://", "wss://")):
        if base_url.startswith(http_scheme):
            base_url = ws_scheme + base_url[len(http_scheme) :]
            break

    url = f"{base_url}{path}"
    headers = {"api-key": self.api_client.headers["api-key"]}

    def custom_backoff() -> Generator[float, None, None]:
        # Reconnect delays (seconds): 0.1 first, then 0.5 growing by x1.5
        # while below 2.0, then 2.0 forever.
        yield 0.1  # short initial delay
        delay = 0.5
        while delay < 2.0:
            yield delay
            delay *= 1.5
        while True:
            yield 2.0

    # Can remove if this is added to public API
    # https://github.com/python-websockets/websockets/issues/1395#issuecomment-3225670409
    asyncio_websockets_client.backoff = custom_backoff  # type: ignore[attr-defined, assignment]

    return connect(
        url, additional_headers=headers, open_timeout=10, ping_timeout=10
    )
|
|
986
|
+
|
|
987
|
+
@staticmethod
@asynccontextmanager
async def session(
    api_key: str,
    base_url: str,
    base_ws_url: str,
    working_directory: str,
    file_cache: files.FileCache | None = None,
) -> AsyncGenerator[RemoteExecutionClient, None]:
    """Provide a ``RemoteExecutionClient`` for the duration of an ``async with``.

    Opens ``get_session(working_directory, base_url, base_ws_url, api_key)``,
    wraps the resulting session (plus the optional ``file_cache``) in a new
    client, and yields it. The underlying session context is exited when the
    caller's block ends.
    """
    session_context = get_session(working_directory, base_url, base_ws_url, api_key)
    async with session_context as remote_session:
        client = RemoteExecutionClient(remote_session, file_cache)
        yield client
|