indent 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of indent might be problematic. Click here for more details.
- exponent/__init__.py +1 -0
- exponent/cli.py +112 -0
- exponent/commands/cloud_commands.py +85 -0
- exponent/commands/common.py +434 -0
- exponent/commands/config_commands.py +581 -0
- exponent/commands/github_app_commands.py +211 -0
- exponent/commands/listen_commands.py +96 -0
- exponent/commands/run_commands.py +208 -0
- exponent/commands/settings.py +56 -0
- exponent/commands/shell_commands.py +2840 -0
- exponent/commands/theme.py +246 -0
- exponent/commands/types.py +111 -0
- exponent/commands/upgrade.py +29 -0
- exponent/commands/utils.py +236 -0
- exponent/core/config.py +180 -0
- exponent/core/graphql/__init__.py +0 -0
- exponent/core/graphql/client.py +59 -0
- exponent/core/graphql/cloud_config_queries.py +77 -0
- exponent/core/graphql/get_chats_query.py +47 -0
- exponent/core/graphql/github_config_queries.py +56 -0
- exponent/core/graphql/mutations.py +75 -0
- exponent/core/graphql/queries.py +110 -0
- exponent/core/graphql/subscriptions.py +452 -0
- exponent/core/remote_execution/checkpoints.py +212 -0
- exponent/core/remote_execution/cli_rpc_types.py +214 -0
- exponent/core/remote_execution/client.py +545 -0
- exponent/core/remote_execution/code_execution.py +58 -0
- exponent/core/remote_execution/command_execution.py +105 -0
- exponent/core/remote_execution/error_info.py +45 -0
- exponent/core/remote_execution/exceptions.py +10 -0
- exponent/core/remote_execution/file_write.py +410 -0
- exponent/core/remote_execution/files.py +415 -0
- exponent/core/remote_execution/git.py +268 -0
- exponent/core/remote_execution/languages/python_execution.py +239 -0
- exponent/core/remote_execution/languages/shell_streaming.py +221 -0
- exponent/core/remote_execution/languages/types.py +20 -0
- exponent/core/remote_execution/session.py +128 -0
- exponent/core/remote_execution/system_context.py +54 -0
- exponent/core/remote_execution/tool_execution.py +289 -0
- exponent/core/remote_execution/truncation.py +284 -0
- exponent/core/remote_execution/types.py +670 -0
- exponent/core/remote_execution/utils.py +600 -0
- exponent/core/types/__init__.py +0 -0
- exponent/core/types/command_data.py +206 -0
- exponent/core/types/event_types.py +89 -0
- exponent/core/types/generated/__init__.py +0 -0
- exponent/core/types/generated/strategy_info.py +225 -0
- exponent/migration-docs/login.md +112 -0
- exponent/py.typed +4 -0
- exponent/utils/__init__.py +0 -0
- exponent/utils/colors.py +92 -0
- exponent/utils/version.py +289 -0
- indent-0.0.8.dist-info/METADATA +36 -0
- indent-0.0.8.dist-info/RECORD +56 -0
- indent-0.0.8.dist-info/WHEEL +4 -0
- indent-0.0.8.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,545 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
from collections.abc import AsyncGenerator, Callable
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any, TypeVar, Union, cast
|
|
10
|
+
|
|
11
|
+
import msgspec
|
|
12
|
+
import websockets.client
|
|
13
|
+
import websockets.exceptions
|
|
14
|
+
from httpx import (
|
|
15
|
+
AsyncClient,
|
|
16
|
+
codes as http_status,
|
|
17
|
+
)
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
|
|
20
|
+
from exponent.commands.utils import ConnectionTracker
|
|
21
|
+
from exponent.core.config import is_editable_install
|
|
22
|
+
from exponent.core.remote_execution import files, system_context
|
|
23
|
+
from exponent.core.remote_execution.cli_rpc_types import (
|
|
24
|
+
BashToolInput,
|
|
25
|
+
BatchToolExecutionRequest,
|
|
26
|
+
BatchToolExecutionResponse,
|
|
27
|
+
CliRpcRequest,
|
|
28
|
+
CliRpcResponse,
|
|
29
|
+
ErrorResponse,
|
|
30
|
+
GetAllFilesRequest,
|
|
31
|
+
GetAllFilesResponse,
|
|
32
|
+
TerminateRequest,
|
|
33
|
+
TerminateResponse,
|
|
34
|
+
ToolExecutionRequest,
|
|
35
|
+
ToolExecutionResponse,
|
|
36
|
+
ToolResultType,
|
|
37
|
+
)
|
|
38
|
+
from exponent.core.remote_execution.code_execution import (
|
|
39
|
+
execute_code_streaming,
|
|
40
|
+
)
|
|
41
|
+
from exponent.core.remote_execution.files import file_walk
|
|
42
|
+
from exponent.core.remote_execution.session import (
|
|
43
|
+
RemoteExecutionClientSession,
|
|
44
|
+
get_session,
|
|
45
|
+
)
|
|
46
|
+
from exponent.core.remote_execution.tool_execution import (
|
|
47
|
+
execute_bash_tool,
|
|
48
|
+
execute_tool,
|
|
49
|
+
truncate_result,
|
|
50
|
+
)
|
|
51
|
+
from exponent.core.remote_execution.types import (
|
|
52
|
+
ChatSource,
|
|
53
|
+
CLIConnectedState,
|
|
54
|
+
CreateChatResponse,
|
|
55
|
+
GitInfo,
|
|
56
|
+
HeartbeatInfo,
|
|
57
|
+
RemoteExecutionResponseType,
|
|
58
|
+
RunWorkflowRequest,
|
|
59
|
+
StreamingCodeExecutionRequest,
|
|
60
|
+
)
|
|
61
|
+
from exponent.core.remote_execution.utils import (
|
|
62
|
+
deserialize_api_response,
|
|
63
|
+
)
|
|
64
|
+
from exponent.utils.version import get_installed_version
|
|
65
|
+
|
|
66
|
+
# Module-level logger, named after this module's import path (``__name__``).
logger = logging.getLogger(__name__)


# Type variable bound to pydantic ``BaseModel``.
# NOTE(review): not referenced anywhere in the code visible in this module —
# confirm it is still needed before relying on or removing it.
TModel = TypeVar("TModel", bound=BaseModel)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class WSDisconnected:
    """Exit reason produced when the websocket session ends.

    A normal (code 1000) closure yields an instance with ``error_message``
    left as ``None``; the other terminal paths populate it with a readable
    explanation of why the connection could not continue.
    """

    # ``None`` means a clean shutdown; otherwise a human-readable reason.
    error_message: str | None = None


@dataclass()
class SwitchCLIChat:
    """Exit reason carrying the UUID of another chat.

    NOTE(review): not constructed in the code visible here; presumably tells
    the caller to reconnect to ``new_chat_uuid`` — confirm against callers.
    """

    # UUID of the chat to move to.
    new_chat_uuid: str


# The closed set of values the connection loop hands back when it stops.
REMOTE_EXECUTION_CLIENT_EXIT_INFO = Union[WSDisconnected, SwitchCLIChat]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class RemoteExecutionClient:
    """Executes server-issued requests locally and reports the results back.

    Requests arrive over a per-chat websocket (``/api/ws/chat/{chat_uuid}``,
    see ``run_connection``); results and periodic heartbeats are sent back
    over the same socket. Other server interactions (chat creation,
    workflows, HTTP heartbeats, GitHub token exchange) use the session's
    HTTP ``api_client``.
    """

    def __init__(
        self,
        session: RemoteExecutionClientSession,
        file_cache: files.FileCache | None = None,
    ):
        self.current_session = session
        # Compare against None explicitly so a caller-supplied cache is never
        # replaced merely because it evaluates as falsy.
        self.file_cache = (
            file_cache
            if file_cache is not None
            else files.FileCache(session.working_directory)
        )

        # For active bash executions, track whether they should be halted.
        # request/correlation id -> should_halt
        self._halt_states: dict[str, bool] = {}
        self._halt_lock = asyncio.Lock()

    @property
    def working_directory(self) -> str:
        """Working directory of the current session."""
        return self.current_session.working_directory

    @property
    def api_client(self) -> AsyncClient:
        """HTTP client of the current session."""
        return self.current_session.api_client

    @property
    def ws_client(self) -> AsyncClient:
        """Client whose ``base_url`` is used to build websocket URLs."""
        return self.current_session.ws_client

    async def add_code_execution_to_halt_states(self, correlation_id: str) -> None:
        """Track ``correlation_id`` as running and not (yet) asked to halt."""
        async with self._halt_lock:
            self._halt_states[correlation_id] = False

    async def halt_all_code_executions(self) -> None:
        """Flag every tracked execution as halted."""
        logger.info(f"Halting all code executions: {self._halt_states}")
        async with self._halt_lock:
            # A fresh dict is swapped in rather than mutating in place.
            self._halt_states = {
                correlation_id: True for correlation_id in self._halt_states
            }

    async def clear_halt_state(self, correlation_id: str) -> None:
        """Stop tracking ``correlation_id``; unknown ids are ignored."""
        async with self._halt_lock:
            self._halt_states.pop(correlation_id, None)

    def get_halt_check(self, correlation_id: str) -> Callable[[], bool]:
        """Return a zero-argument callable reporting whether to halt.

        The callable re-reads ``self._halt_states`` on every call, so it sees
        the dict installed by ``halt_all_code_executions``. Untracked ids
        report ``False``.
        """

        def should_halt() -> bool:
            # Don't need to lock here, since just reading from dict
            return self._halt_states.get(correlation_id, False)

        return should_halt

    @staticmethod
    def _uses_halt_state(request: CliRpcRequest) -> bool:
        """Whether ``request`` runs bash and therefore gets a halt-state entry.

        True for a single bash tool request, or a batch containing at least
        one bash tool input.
        """
        inner = request.request
        if isinstance(inner, ToolExecutionRequest):
            return isinstance(inner.tool_input, BashToolInput)
        if isinstance(inner, BatchToolExecutionRequest):
            return any(
                isinstance(tool_input, BashToolInput)
                for tool_input in inner.tool_inputs
            )
        return False

    @staticmethod
    async def _send_result(
        websocket: websockets.client.WebSocketClientProtocol,
        response: CliRpcResponse,
    ) -> None:
        """Serialize ``response`` and send it as a ``"result"`` frame."""
        await websocket.send(
            json.dumps({"type": "result", "data": msgspec.to_builtins(response)})
        )

    async def _handle_websocket_message(
        self,
        msg: str,
        websocket: websockets.client.WebSocketClientProtocol,
        requests: asyncio.Queue[CliRpcRequest],
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
        """Handle an incoming websocket message.

        Returns None to continue processing, or a REMOTE_EXECUTION_CLIENT_EXIT_INFO to exit.

        Only frames of the form ``{"type": "request", "data": ...}`` are acted
        on. A ``TerminateRequest`` halts every tracked execution and is
        answered immediately; any other request is queued for the executor
        tasks. Malformed frames are logged and skipped instead of raising,
        which would otherwise tear down the whole connection loop.
        """
        try:
            msg_data = json.loads(msg)
        except json.JSONDecodeError:
            logger.error(f"Ignoring non-JSON websocket message: {msg[:200]!r}")
            return None

        if not isinstance(msg_data, dict) or msg_data.get("type") != "request":
            return None

        payload = msg_data.get("data")
        try:
            request = msgspec.json.decode(json.dumps(payload), type=CliRpcRequest)
        except msgspec.DecodeError as e:  # ValidationError subclasses DecodeError
            logger.error(f"Could not decode request {payload!r}: {e}")
            request_id = (
                payload.get("request_id") if isinstance(payload, dict) else None
            )
            if isinstance(request_id, str):
                # Report the failure the same way executors report failed
                # requests, rather than dropping it silently.
                await self._send_result(
                    websocket,
                    CliRpcResponse(
                        request_id=request_id,
                        response=ErrorResponse(error_message=str(e)),
                    ),
                )
            return None

        if isinstance(request.request, TerminateRequest):
            await self.halt_all_code_executions()
            await self._send_result(
                websocket,
                CliRpcResponse(
                    request_id=request.request_id,
                    response=TerminateResponse(),
                ),
            )
            return None

        # Register bash work before queueing so a TerminateRequest arriving
        # while the request is still queued already flags it.
        if self._uses_halt_state(request):
            await self.add_code_execution_to_halt_states(request.request_id)

        await requests.put(request)
        return None

    async def _setup_tasks(
        self,
        beats: asyncio.Queue[HeartbeatInfo],
        requests: asyncio.Queue[CliRpcRequest],
        results: asyncio.Queue[CliRpcResponse],
    ) -> list[asyncio.Task[None]]:
        """Setup heartbeat and executor tasks.

        Starts one task that pushes heartbeat info into ``beats`` every 3
        seconds and three executor tasks that turn ``requests`` into
        ``results``. Returns all four tasks so the caller can cancel them.
        """

        async def beat() -> None:
            while True:
                try:
                    info = await self.get_heartbeat_info()
                except Exception as e:  # noqa: BLE001
                    # One failed probe must not stop heartbeats for the rest of
                    # the connection; otherwise the task would die and its
                    # exception would only be discarded at teardown.
                    logger.warning(f"Failed to collect heartbeat info: {e}")
                else:
                    await beats.put(info)
                await asyncio.sleep(3)

        # Lock to ensure that only one executor can grab a
        # request at a time.
        requests_lock = asyncio.Lock()

        # Lock to ensure that only one executor can put a
        # result in the results queue at a time.
        results_lock = asyncio.Lock()

        async def executor() -> None:
            while True:
                async with requests_lock:
                    request = await requests.get()

                # The request lock is released before handling so other
                # executors can grab requests while this one is busy.
                try:
                    logger.info(f"Handling request {request}")
                    response = await self.handle_request(request)
                    async with results_lock:
                        logger.info(f"Putting response {response}")
                        await results.put(response)
                except Exception as e:  # noqa: BLE001
                    logger.info(f"Error handling request {request}:\n\n{e}")
                    async with results_lock:
                        await results.put(
                            CliRpcResponse(
                                request_id=request.request_id,
                                response=ErrorResponse(
                                    error_message=str(e),
                                ),
                            )
                        )

        beat_task = asyncio.create_task(beat())
        # Three parallel executors to handle requests.
        executor_tasks = [asyncio.create_task(executor()) for _ in range(3)]

        return [beat_task, *executor_tasks]

    async def _process_websocket_messages(
        self,
        websocket: websockets.client.WebSocketClientProtocol,
        beats: asyncio.Queue[HeartbeatInfo],
        requests: asyncio.Queue[CliRpcRequest],
        results: asyncio.Queue[CliRpcResponse],
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO:
        """Process messages from the websocket connection.

        Multiplexes three sources with ``asyncio.wait``: incoming frames,
        queued heartbeats and queued results. Each completed source is
        re-armed with a fresh task. Outstanding tasks are cancelled and
        awaited on exit. Errors from the websocket (e.g. ``ConnectionClosed``)
        propagate to the caller.
        """
        recv = asyncio.create_task(websocket.recv())
        get_beat = asyncio.create_task(beats.get())
        get_result = asyncio.create_task(results.get())
        pending = {recv, get_beat, get_result}

        try:
            while True:
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )

                if recv in done:
                    msg = str(recv.result())
                    exit_info = await self._handle_websocket_message(
                        msg, websocket, requests
                    )
                    if exit_info is not None:
                        return exit_info

                    recv = asyncio.create_task(websocket.recv())
                    pending.add(recv)

                if get_beat in done:
                    info = get_beat.result()
                    data = json.loads(info.model_dump_json())
                    msg = json.dumps({"type": "heartbeat", "data": data})
                    await websocket.send(msg)

                    get_beat = asyncio.create_task(beats.get())
                    pending.add(get_beat)

                if get_result in done:
                    await self._send_result(websocket, get_result.result())

                    get_result = asyncio.create_task(results.get())
                    pending.add(get_result)
        finally:
            for task in pending:
                task.cancel()

            await asyncio.gather(*pending, return_exceptions=True)

    async def _handle_websocket_connection(
        self,
        websocket: websockets.client.WebSocketClientProtocol,
        connection_tracker: ConnectionTracker | None,
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
        """Handle a single websocket connection.

        Returns None to continue with reconnection attempts, or an exit info to terminate.

        Close code 1000 ends the session cleanly; close code 1008 ends it with
        an error message. Any other closure, or a timeout, returns ``None`` so
        the caller reconnects.
        """
        if connection_tracker is not None:
            await connection_tracker.set_connected(True)

        beats: asyncio.Queue[HeartbeatInfo] = asyncio.Queue()
        requests: asyncio.Queue[CliRpcRequest] = asyncio.Queue()
        results: asyncio.Queue[CliRpcResponse] = asyncio.Queue()

        tasks = await self._setup_tasks(beats, requests, results)

        try:
            return await self._process_websocket_messages(
                websocket, beats, requests, results
            )
        except websockets.exceptions.ConnectionClosed as e:
            close_frame = e.rcvd
            if close_frame is None:
                # No close frame received: allow reconnection attempt.
                return None
            if close_frame.code == 1000:
                # Normal closure, exit completely.
                return WSDisconnected()
            if close_frame.code == 1008:
                # The close reason is an empty string (not None) when the
                # server sent none, so fall back on falsiness.
                return WSDisconnected(
                    error_message=close_frame.reason
                    or "Error connecting to websocket"
                )
            # Otherwise, allow reconnection attempt.
            return None
        except (TimeoutError, asyncio.TimeoutError):
            # Before Python 3.11 asyncio.TimeoutError is distinct from the
            # builtin TimeoutError; catch both so timeouts always reconnect.
            # TODO: investigate if this is needed, possibly scope it down
            return None
        finally:
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            if connection_tracker is not None:
                await connection_tracker.set_connected(False)

    async def run_connection(
        self,
        chat_uuid: str,
        connection_tracker: ConnectionTracker | None = None,
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO:
        """Run the websocket connection loop.

        Reconnects (via the ``ws_connect`` iterator) until a connection handler
        returns an exit info.
        """
        self.current_session.set_chat_uuid(chat_uuid)

        async for websocket in self.ws_connect(f"/api/ws/chat/{chat_uuid}"):
            result = await self._handle_websocket_connection(
                websocket, connection_tracker
            )
            if result is not None:
                return result
            # If we get None, we'll try to reconnect

        # If we exit the websocket connection loop without returning,
        # it means we couldn't establish a connection
        return WSDisconnected(error_message="Could not establish websocket connection")

    async def create_chat(self, chat_source: ChatSource) -> CreateChatResponse:
        """Create a chat on the server and return the decoded response."""
        response = await self.api_client.post(
            "/api/remote_execution/create_chat",
            params={"chat_source": chat_source.value},
        )
        return await deserialize_api_response(response, CreateChatResponse)

    async def get_gh_installation_token(self, git_info: GitInfo) -> dict[str, Any]:
        """Exchange ``git_info`` for a GitHub App token; returns the raw JSON body.

        NOTE(review): the status code is not checked here.
        """
        response = await self.api_client.post(
            "/github_app/exchange_token",
            json=git_info.model_dump(),
        )
        return cast(dict[str, Any], response.json())

    async def run_workflow(self, chat_uuid: str, workflow_id: str) -> dict[str, Any]:
        """Start ``workflow_id`` in ``chat_uuid``.

        Raises:
            Exception: if the server does not answer 200 OK.
        """
        response = await self.api_client.post(
            "/api/remote_execution/run_workflow",
            json=RunWorkflowRequest(
                chat_uuid=chat_uuid,
                workflow_id=workflow_id,
            ).model_dump(),
            timeout=60,
        )
        if response.status_code != http_status.OK:
            raise Exception(
                f"Failed to run workflow with status code {response.status_code} and response {response.text}"
            )
        return cast(dict[str, Any], response.json())

    async def get_heartbeat_info(self) -> HeartbeatInfo:
        """Collect system info plus installed-version details for a heartbeat."""
        return HeartbeatInfo(
            system_info=await system_context.get_system_info(self.working_directory),
            exponent_version=get_installed_version(),
            editable_installation=is_editable_install(),
        )

    async def send_heartbeat(self, chat_uuid: str) -> CLIConnectedState:
        """POST a heartbeat for ``chat_uuid`` over HTTP.

        Raises:
            Exception: if the server does not answer 200 OK.
        """
        logger.info(f"Sending heartbeat for chat_uuid {chat_uuid}")
        heartbeat_info = await self.get_heartbeat_info()
        response = await self.api_client.post(
            f"/api/remote_execution/{chat_uuid}/heartbeat",
            content=heartbeat_info.model_dump_json(),
            timeout=60,
        )
        if response.status_code != http_status.OK:
            raise Exception(
                f"Heartbeat failed with status code {response.status_code} and response {response.text}"
            )
        connected_state = await deserialize_api_response(response, CLIConnectedState)
        logger.info(f"Heartbeat response: {connected_state}")
        return connected_state

    async def _run_tool(self, tool_input: Any, request_id: str) -> ToolResultType:
        """Execute one tool input and return its truncated result.

        Bash inputs receive a halt check keyed on ``request_id`` so that
        ``halt_all_code_executions`` can signal them; all other inputs go to
        ``execute_tool``.
        """
        if isinstance(tool_input, BashToolInput):
            raw_result = await execute_bash_tool(
                tool_input,
                self.working_directory,
                should_halt=self.get_halt_check(request_id),
            )
        else:
            raw_result = await execute_tool(  # type: ignore[assignment]
                tool_input, self.working_directory
            )
        return truncate_result(raw_result)

    async def handle_request(self, request: CliRpcRequest) -> CliRpcResponse:
        """Execute a queued request and build its response.

        Batch requests report per-tool failures as ``ErrorToolResult`` entries.
        Any other failure is logged and re-raised. The request's halt-state
        entry (if any) is always cleared afterwards.

        Raises:
            ValueError: for a ``TerminateRequest`` (answered directly by
                ``_handle_websocket_message``) or an unknown request type.
        """
        inner = request.request
        try:
            if isinstance(inner, ToolExecutionRequest):
                tool_result = await self._run_tool(
                    inner.tool_input, request.request_id
                )
                return CliRpcResponse(
                    request_id=request.request_id,
                    response=ToolExecutionResponse(
                        tool_result=tool_result,
                    ),
                )

            if isinstance(inner, GetAllFilesRequest):
                all_files = await file_walk(self.working_directory)
                return CliRpcResponse(
                    request_id=request.request_id,
                    response=GetAllFilesResponse(files=all_files),
                )

            if isinstance(inner, BatchToolExecutionRequest):
                results: list[ToolResultType] = []
                for tool_input in inner.tool_inputs:
                    try:
                        results.append(
                            await self._run_tool(tool_input, request.request_id)
                        )
                    except Exception as e:  # noqa: BLE001
                        logger.error(f"Error executing tool {tool_input}: {e}")
                        from exponent.core.remote_execution.cli_rpc_types import (
                            ErrorToolResult,
                        )

                        results.append(ErrorToolResult(error_message=str(e)))

                return CliRpcResponse(
                    request_id=request.request_id,
                    response=BatchToolExecutionResponse(
                        tool_results=results,
                    ),
                )

            if isinstance(inner, TerminateRequest):
                raise ValueError(
                    "TerminateRequest should not be handled by handle_request"
                )

            # Report the concrete inner type: ``type(request)`` is always the
            # CliRpcRequest wrapper and carries no useful information.
            raise ValueError(f"Unhandled request type: {type(inner)}")

        except Exception as e:
            logger.error(f"Error handling request {request}:\n\n{e}")
            raise
        finally:
            # Clean up halt state after request is complete
            if self._uses_halt_state(request):
                await self.clear_halt_state(request.request_id)

    async def handle_streaming_request(
        self, request: StreamingCodeExecutionRequest
    ) -> AsyncGenerator[RemoteExecutionResponseType, None]:
        """Stream the output of a code execution request.

        The halt check is keyed on ``request.correlation_id``.
        """
        if not isinstance(request, StreamingCodeExecutionRequest):
            assert False, f"{type(request)} should be sent to handle_streaming_request"
        async for output in execute_code_streaming(
            request,
            self.current_session,
            working_directory=self.working_directory,
            should_halt=self.get_halt_check(request.correlation_id),
        ):
            yield output

    @staticmethod
    def _build_ws_url(base_url: str, path: str) -> str:
        """Turn an http(s) base URL plus ``path`` into a ws(s) URL.

        Only the leading scheme is rewritten, so "http://" occurring later in
        the URL is left intact. Exactly one "/" joins base and path, which
        avoids "//" when ``base_url`` already ends with a slash.
        """
        for http_scheme, ws_scheme in (("https://", "wss://"), ("http://", "ws://")):
            if base_url.startswith(http_scheme):
                base_url = ws_scheme + base_url[len(http_scheme) :]
                break
        return f"{base_url.rstrip('/')}/{path.lstrip('/')}"

    def ws_connect(self, path: str) -> websockets.client.connect:
        """Return a websockets connector for ``path`` on the ws client's host.

        The connector authenticates with the API client's ``api-key`` header.
        ``run_connection`` iterates it with ``async for`` to (re)connect.
        """
        url = self._build_ws_url(str(self.ws_client.base_url), path)
        headers = {"api-key": self.api_client.headers["api-key"]}

        conn = websockets.client.connect(
            url, extra_headers=headers, timeout=10, ping_timeout=10
        )

        # Stop exponential backoff from blowing up
        # the wait time between connection attempts
        conn.BACKOFF_MAX = 2
        conn.BACKOFF_INITIAL = 1  # pyright: ignore

        return conn

    @staticmethod
    @asynccontextmanager
    async def session(
        api_key: str,
        base_url: str,
        base_ws_url: str,
        working_directory: str,
        file_cache: files.FileCache | None = None,
    ) -> AsyncGenerator[RemoteExecutionClient, None]:
        """Open a session and yield a client bound to it for the block's duration."""
        async with get_session(
            working_directory, base_url, base_ws_url, api_key
        ) as session:
            yield RemoteExecutionClient(session, file_cache)
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from collections.abc import AsyncGenerator, Callable
|
|
2
|
+
|
|
3
|
+
from exponent.core.remote_execution.languages.python_execution import (
|
|
4
|
+
execute_python_streaming,
|
|
5
|
+
)
|
|
6
|
+
from exponent.core.remote_execution.languages.shell_streaming import (
|
|
7
|
+
execute_shell_streaming,
|
|
8
|
+
)
|
|
9
|
+
from exponent.core.remote_execution.languages.types import StreamedOutputPiece
|
|
10
|
+
from exponent.core.remote_execution.session import RemoteExecutionClientSession
|
|
11
|
+
from exponent.core.remote_execution.types import (
|
|
12
|
+
StreamingCodeExecutionRequest,
|
|
13
|
+
StreamingCodeExecutionResponse,
|
|
14
|
+
StreamingCodeExecutionResponseChunk,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Placeholder content used when an execution finishes without any output.
EMPTY_OUTPUT_STRING = "(No output)"


async def execute_code_streaming(
    request: StreamingCodeExecutionRequest,
    session: RemoteExecutionClientSession,
    working_directory: str,
    should_halt: Callable[[], bool] | None = None,
) -> AsyncGenerator[
    StreamingCodeExecutionResponseChunk | StreamingCodeExecutionResponse, None
]:
    """Run ``request.content`` in the requested language and stream its output.

    Each ``StreamedOutputPiece`` from the language runner is yielded as a
    ``StreamingCodeExecutionResponseChunk``. Any other item the runner yields
    (presumably its final summary) becomes a ``StreamingCodeExecutionResponse``,
    with empty output replaced by ``EMPTY_OUTPUT_STRING``.

    Args:
        request: Code, language ("python" or "shell"), correlation id and,
            for shell, a timeout.
        session: Supplies the Python kernel used for "python" requests.
        working_directory: Directory shell commands run in.
        should_halt: Optional callable polled by the runner to stop early.

    Raises:
        ValueError: if ``request.language`` is not supported. Previously such
            requests silently produced no output at all, so no final response
            was ever emitted.
    """
    if request.language == "python":
        async for output in execute_python_streaming(
            request.content, session.kernel, user_interrupted=should_halt
        ):
            if isinstance(output, StreamedOutputPiece):
                yield StreamingCodeExecutionResponseChunk(
                    content=output.content, correlation_id=request.correlation_id
                )
            else:
                yield StreamingCodeExecutionResponse(
                    correlation_id=request.correlation_id,
                    content=output.output or EMPTY_OUTPUT_STRING,
                    halted=output.halted,
                )

    elif request.language == "shell":
        async for shell_output in execute_shell_streaming(
            request.content, working_directory, request.timeout, should_halt
        ):
            if isinstance(shell_output, StreamedOutputPiece):
                yield StreamingCodeExecutionResponseChunk(
                    content=shell_output.content, correlation_id=request.correlation_id
                )
            else:
                yield StreamingCodeExecutionResponse(
                    correlation_id=request.correlation_id,
                    content=shell_output.output or EMPTY_OUTPUT_STRING,
                    halted=shell_output.halted,
                    exit_code=shell_output.exit_code,
                    cancelled_for_timeout=shell_output.cancelled_for_timeout,
                )

    else:
        raise ValueError(
            f"Unsupported language for streaming execution: {request.language!r}"
        )
|