indent 0.1.13__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- exponent/__init__.py +2 -2
- exponent/cli.py +0 -2
- exponent/commands/cloud_commands.py +2 -87
- exponent/commands/common.py +25 -40
- exponent/commands/config_commands.py +0 -87
- exponent/commands/run_commands.py +5 -2
- exponent/core/config.py +1 -1
- exponent/core/container_build/__init__.py +0 -0
- exponent/core/container_build/types.py +25 -0
- exponent/core/graphql/mutations.py +2 -31
- exponent/core/graphql/queries.py +0 -3
- exponent/core/remote_execution/cli_rpc_types.py +201 -5
- exponent/core/remote_execution/client.py +355 -92
- exponent/core/remote_execution/code_execution.py +26 -7
- exponent/core/remote_execution/default_env.py +31 -0
- exponent/core/remote_execution/languages/shell_streaming.py +11 -6
- exponent/core/remote_execution/port_utils.py +73 -0
- exponent/core/remote_execution/system_context.py +2 -0
- exponent/core/remote_execution/terminal_session.py +517 -0
- exponent/core/remote_execution/terminal_types.py +29 -0
- exponent/core/remote_execution/tool_execution.py +228 -18
- exponent/core/remote_execution/tool_type_utils.py +39 -0
- exponent/core/remote_execution/truncation.py +9 -1
- exponent/core/remote_execution/types.py +71 -19
- exponent/utils/version.py +8 -7
- {indent-0.1.13.dist-info → indent-0.1.28.dist-info}/METADATA +5 -2
- {indent-0.1.13.dist-info → indent-0.1.28.dist-info}/RECORD +29 -24
- exponent/commands/workflow_commands.py +0 -111
- exponent/core/graphql/github_config_queries.py +0 -56
- {indent-0.1.13.dist-info → indent-0.1.28.dist-info}/WHEEL +0 -0
- {indent-0.1.13.dist-info → indent-0.1.28.dist-info}/entry_points.txt +0 -0
|
@@ -5,7 +5,7 @@ import json
|
|
|
5
5
|
import logging
|
|
6
6
|
import time
|
|
7
7
|
import uuid
|
|
8
|
-
from collections.abc import AsyncGenerator, Callable, Generator
|
|
8
|
+
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
|
|
9
9
|
from contextlib import asynccontextmanager
|
|
10
10
|
from dataclasses import dataclass
|
|
11
11
|
from typing import Any, TypeVar, cast
|
|
@@ -30,13 +30,24 @@ from exponent.core.remote_execution.cli_rpc_types import (
|
|
|
30
30
|
CliRpcRequest,
|
|
31
31
|
CliRpcResponse,
|
|
32
32
|
ErrorResponse,
|
|
33
|
+
ErrorToolResult,
|
|
34
|
+
GenerateUploadUrlRequest,
|
|
35
|
+
GenerateUploadUrlResponse,
|
|
33
36
|
GetAllFilesRequest,
|
|
34
37
|
GetAllFilesResponse,
|
|
35
38
|
HttpRequest,
|
|
36
39
|
KeepAliveCliChatRequest,
|
|
37
40
|
KeepAliveCliChatResponse,
|
|
41
|
+
StartTerminalRequest,
|
|
42
|
+
StartTerminalResponse,
|
|
43
|
+
StopTerminalRequest,
|
|
44
|
+
StopTerminalResponse,
|
|
38
45
|
SwitchCLIChatRequest,
|
|
39
46
|
SwitchCLIChatResponse,
|
|
47
|
+
TerminalInputRequest,
|
|
48
|
+
TerminalInputResponse,
|
|
49
|
+
TerminalResizeRequest,
|
|
50
|
+
TerminalResizeResponse,
|
|
40
51
|
TerminateRequest,
|
|
41
52
|
TerminateResponse,
|
|
42
53
|
ToolExecutionRequest,
|
|
@@ -51,7 +62,10 @@ from exponent.core.remote_execution.http_fetch import fetch_http_content
|
|
|
51
62
|
from exponent.core.remote_execution.session import (
|
|
52
63
|
RemoteExecutionClientSession,
|
|
53
64
|
get_session,
|
|
65
|
+
send_exception_log,
|
|
54
66
|
)
|
|
67
|
+
from exponent.core.remote_execution.terminal_session import TerminalSessionManager
|
|
68
|
+
from exponent.core.remote_execution.terminal_types import TerminalMessage
|
|
55
69
|
from exponent.core.remote_execution.tool_execution import (
|
|
56
70
|
execute_bash_tool,
|
|
57
71
|
execute_tool,
|
|
@@ -61,12 +75,9 @@ from exponent.core.remote_execution.types import (
|
|
|
61
75
|
ChatSource,
|
|
62
76
|
CLIConnectedState,
|
|
63
77
|
CreateChatResponse,
|
|
64
|
-
GitInfo,
|
|
65
78
|
HeartbeatInfo,
|
|
66
|
-
PrReviewWorkflowInput,
|
|
67
|
-
RemoteExecutionResponseType,
|
|
68
79
|
RunWorkflowRequest,
|
|
69
|
-
|
|
80
|
+
WorkflowInput,
|
|
70
81
|
WorkflowTriggerRequest,
|
|
71
82
|
WorkflowTriggerResponse,
|
|
72
83
|
)
|
|
@@ -114,6 +125,13 @@ class RemoteExecutionClient:
|
|
|
114
125
|
# Track last request time for timeout functionality
|
|
115
126
|
self._last_request_time: float | None = None
|
|
116
127
|
|
|
128
|
+
# Track pending upload URL requests
|
|
129
|
+
self._pending_upload_requests: dict[
|
|
130
|
+
str, asyncio.Future[GenerateUploadUrlResponse]
|
|
131
|
+
] = {}
|
|
132
|
+
self._upload_request_lock = asyncio.Lock()
|
|
133
|
+
self._websocket: ClientConnection | None = None
|
|
134
|
+
|
|
117
135
|
@property
|
|
118
136
|
def working_directory(self) -> str:
|
|
119
137
|
return self.current_session.working_directory
|
|
@@ -173,23 +191,40 @@ class RemoteExecutionClient:
|
|
|
173
191
|
# Handle cancellation gracefully
|
|
174
192
|
return None
|
|
175
193
|
|
|
176
|
-
async def _handle_websocket_message(
|
|
194
|
+
async def _handle_websocket_message( # noqa: PLR0911, PLR0915
|
|
177
195
|
self,
|
|
178
196
|
msg: str,
|
|
179
197
|
websocket: ClientConnection,
|
|
180
198
|
requests: asyncio.Queue[CliRpcRequest],
|
|
199
|
+
terminal_session_manager: TerminalSessionManager,
|
|
181
200
|
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
|
|
182
201
|
"""Handle an incoming websocket message.
|
|
183
202
|
Returns None to continue processing, or a REMOTE_EXECUTION_CLIENT_EXIT_INFO to exit."""
|
|
184
203
|
|
|
204
|
+
self._last_request_time = time.time()
|
|
205
|
+
|
|
185
206
|
msg_data = json.loads(msg)
|
|
186
|
-
if msg_data["type"]
|
|
207
|
+
if msg_data["type"] == "result":
|
|
208
|
+
data = json.dumps(msg_data["data"])
|
|
209
|
+
try:
|
|
210
|
+
response = msgspec.json.decode(data, type=CliRpcResponse)
|
|
211
|
+
if isinstance(response.response, GenerateUploadUrlResponse):
|
|
212
|
+
async with self._upload_request_lock:
|
|
213
|
+
if response.request_id in self._pending_upload_requests:
|
|
214
|
+
future = self._pending_upload_requests.pop(
|
|
215
|
+
response.request_id
|
|
216
|
+
)
|
|
217
|
+
future.set_result(response.response)
|
|
218
|
+
except Exception as e:
|
|
219
|
+
logger.error(f"Error handling upload URL response: {e}")
|
|
220
|
+
return None
|
|
221
|
+
elif msg_data["type"] != "request":
|
|
187
222
|
return None
|
|
188
223
|
|
|
189
224
|
data = json.dumps(msg_data["data"])
|
|
190
225
|
try:
|
|
191
226
|
request = msgspec.json.decode(data, type=CliRpcRequest)
|
|
192
|
-
except msgspec.DecodeError as e:
|
|
227
|
+
except (msgspec.DecodeError, msgspec.ValidationError) as e:
|
|
193
228
|
# Try and decode to get request_id if possible
|
|
194
229
|
request = msgspec.json.decode(data)
|
|
195
230
|
if isinstance(request, dict) and "request_id" in request:
|
|
@@ -291,6 +326,101 @@ class RemoteExecutionClient:
|
|
|
291
326
|
)
|
|
292
327
|
)
|
|
293
328
|
return None
|
|
329
|
+
elif isinstance(request.request, StartTerminalRequest):
|
|
330
|
+
# Start a new terminal session
|
|
331
|
+
session_id = await terminal_session_manager.start_session(
|
|
332
|
+
websocket=websocket,
|
|
333
|
+
session_id=request.request.session_id,
|
|
334
|
+
command=request.request.command,
|
|
335
|
+
cols=request.request.cols,
|
|
336
|
+
rows=request.request.rows,
|
|
337
|
+
env=request.request.env,
|
|
338
|
+
)
|
|
339
|
+
await websocket.send(
|
|
340
|
+
json.dumps(
|
|
341
|
+
{
|
|
342
|
+
"type": "result",
|
|
343
|
+
"data": msgspec.to_builtins(
|
|
344
|
+
CliRpcResponse(
|
|
345
|
+
request_id=request.request_id,
|
|
346
|
+
response=StartTerminalResponse(
|
|
347
|
+
session_id=session_id, success=True
|
|
348
|
+
),
|
|
349
|
+
)
|
|
350
|
+
),
|
|
351
|
+
}
|
|
352
|
+
)
|
|
353
|
+
)
|
|
354
|
+
return None
|
|
355
|
+
elif isinstance(request.request, TerminalInputRequest):
|
|
356
|
+
# Send input to terminal session
|
|
357
|
+
success = await terminal_session_manager.send_input(
|
|
358
|
+
session_id=request.request.session_id,
|
|
359
|
+
data=request.request.data,
|
|
360
|
+
)
|
|
361
|
+
await websocket.send(
|
|
362
|
+
json.dumps(
|
|
363
|
+
{
|
|
364
|
+
"type": "result",
|
|
365
|
+
"data": msgspec.to_builtins(
|
|
366
|
+
CliRpcResponse(
|
|
367
|
+
request_id=request.request_id,
|
|
368
|
+
response=TerminalInputResponse(
|
|
369
|
+
session_id=request.request.session_id,
|
|
370
|
+
success=success,
|
|
371
|
+
),
|
|
372
|
+
)
|
|
373
|
+
),
|
|
374
|
+
}
|
|
375
|
+
)
|
|
376
|
+
)
|
|
377
|
+
return None
|
|
378
|
+
elif isinstance(request.request, TerminalResizeRequest):
|
|
379
|
+
# Resize terminal session
|
|
380
|
+
success = await terminal_session_manager.resize_terminal(
|
|
381
|
+
session_id=request.request.session_id,
|
|
382
|
+
rows=request.request.rows,
|
|
383
|
+
cols=request.request.cols,
|
|
384
|
+
)
|
|
385
|
+
await websocket.send(
|
|
386
|
+
json.dumps(
|
|
387
|
+
{
|
|
388
|
+
"type": "result",
|
|
389
|
+
"data": msgspec.to_builtins(
|
|
390
|
+
CliRpcResponse(
|
|
391
|
+
request_id=request.request_id,
|
|
392
|
+
response=TerminalResizeResponse(
|
|
393
|
+
session_id=request.request.session_id,
|
|
394
|
+
success=success,
|
|
395
|
+
),
|
|
396
|
+
)
|
|
397
|
+
),
|
|
398
|
+
}
|
|
399
|
+
)
|
|
400
|
+
)
|
|
401
|
+
return None
|
|
402
|
+
elif isinstance(request.request, StopTerminalRequest):
|
|
403
|
+
# Stop terminal session
|
|
404
|
+
success = await terminal_session_manager.stop_session(
|
|
405
|
+
session_id=request.request.session_id
|
|
406
|
+
)
|
|
407
|
+
await websocket.send(
|
|
408
|
+
json.dumps(
|
|
409
|
+
{
|
|
410
|
+
"type": "result",
|
|
411
|
+
"data": msgspec.to_builtins(
|
|
412
|
+
CliRpcResponse(
|
|
413
|
+
request_id=request.request_id,
|
|
414
|
+
response=StopTerminalResponse(
|
|
415
|
+
session_id=request.request.session_id,
|
|
416
|
+
success=success,
|
|
417
|
+
),
|
|
418
|
+
)
|
|
419
|
+
),
|
|
420
|
+
}
|
|
421
|
+
)
|
|
422
|
+
)
|
|
423
|
+
return None
|
|
294
424
|
else:
|
|
295
425
|
if isinstance(request.request, ToolExecutionRequest) and isinstance(
|
|
296
426
|
request.request.tool_input, BashToolInput
|
|
@@ -337,33 +467,64 @@ class RemoteExecutionClient:
|
|
|
337
467
|
request = await requests.get()
|
|
338
468
|
|
|
339
469
|
try:
|
|
340
|
-
# if
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
470
|
+
# Check if this is a streaming request
|
|
471
|
+
from exponent.core.remote_execution.cli_rpc_types import (
|
|
472
|
+
StreamingCodeExecutionRequest,
|
|
473
|
+
)
|
|
474
|
+
|
|
475
|
+
if isinstance(request.request, StreamingCodeExecutionRequest):
|
|
476
|
+
async for streaming_response in self.handle_streaming_request(
|
|
477
|
+
request.request
|
|
478
|
+
):
|
|
479
|
+
async with results_lock:
|
|
480
|
+
await results.put(
|
|
481
|
+
CliRpcResponse(
|
|
482
|
+
request_id=request.request_id,
|
|
483
|
+
response=streaming_response,
|
|
484
|
+
)
|
|
485
|
+
)
|
|
486
|
+
else:
|
|
487
|
+
# Note that we don't want to hold the lock here
|
|
488
|
+
# because we want other executors to be able to
|
|
489
|
+
# grab requests while we're handling a request.
|
|
490
|
+
logger.info(f"Handling request {request}")
|
|
491
|
+
response = await self.handle_request(request)
|
|
492
|
+
async with results_lock:
|
|
493
|
+
logger.info(f"Putting response {response}")
|
|
494
|
+
await results.put(response)
|
|
355
495
|
except Exception as e:
|
|
356
496
|
logger.info(f"Error handling request {request}:\n\n{e}")
|
|
497
|
+
try:
|
|
498
|
+
await send_exception_log(e, session=self.current_session)
|
|
499
|
+
except Exception:
|
|
500
|
+
pass
|
|
357
501
|
async with results_lock:
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
response=ErrorResponse(
|
|
362
|
-
error_message=str(e),
|
|
363
|
-
),
|
|
364
|
-
)
|
|
502
|
+
from exponent.core.remote_execution.cli_rpc_types import (
|
|
503
|
+
StreamingCodeExecutionRequest,
|
|
504
|
+
StreamingErrorResponse,
|
|
365
505
|
)
|
|
366
506
|
|
|
507
|
+
if isinstance(request.request, StreamingCodeExecutionRequest):
|
|
508
|
+
# For streaming requests, send a streaming error response
|
|
509
|
+
await results.put(
|
|
510
|
+
CliRpcResponse(
|
|
511
|
+
request_id=request.request_id,
|
|
512
|
+
response=StreamingErrorResponse(
|
|
513
|
+
correlation_id=request.request.correlation_id,
|
|
514
|
+
error_message=str(e),
|
|
515
|
+
),
|
|
516
|
+
)
|
|
517
|
+
)
|
|
518
|
+
else:
|
|
519
|
+
await results.put(
|
|
520
|
+
CliRpcResponse(
|
|
521
|
+
request_id=request.request_id,
|
|
522
|
+
response=ErrorResponse(
|
|
523
|
+
error_message=str(e),
|
|
524
|
+
),
|
|
525
|
+
)
|
|
526
|
+
)
|
|
527
|
+
|
|
367
528
|
beat_task = asyncio.create_task(beat())
|
|
368
529
|
# Three parallel executors to handle requests
|
|
369
530
|
|
|
@@ -381,13 +542,17 @@ class RemoteExecutionClient:
|
|
|
381
542
|
beats: asyncio.Queue[HeartbeatInfo],
|
|
382
543
|
requests: asyncio.Queue[CliRpcRequest],
|
|
383
544
|
results: asyncio.Queue[CliRpcResponse],
|
|
545
|
+
terminal_output_queue: asyncio.Queue[TerminalMessage],
|
|
546
|
+
terminal_session_manager: TerminalSessionManager,
|
|
384
547
|
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO:
|
|
385
548
|
"""Process messages from the websocket connection."""
|
|
549
|
+
pending: set[asyncio.Task[object]] = set()
|
|
386
550
|
try:
|
|
387
551
|
recv = asyncio.create_task(websocket.recv())
|
|
388
552
|
get_beat = asyncio.create_task(beats.get())
|
|
389
553
|
get_result = asyncio.create_task(results.get())
|
|
390
|
-
|
|
554
|
+
get_terminal_output = asyncio.create_task(terminal_output_queue.get())
|
|
555
|
+
pending = {recv, get_beat, get_result, get_terminal_output}
|
|
391
556
|
|
|
392
557
|
while True:
|
|
393
558
|
done, pending = await asyncio.wait(
|
|
@@ -397,7 +562,7 @@ class RemoteExecutionClient:
|
|
|
397
562
|
if recv in done:
|
|
398
563
|
msg = str(recv.result())
|
|
399
564
|
exit_info = await self._handle_websocket_message(
|
|
400
|
-
msg, websocket, requests
|
|
565
|
+
msg, websocket, requests, terminal_session_manager
|
|
401
566
|
)
|
|
402
567
|
if exit_info is not None:
|
|
403
568
|
return exit_info
|
|
@@ -416,12 +581,24 @@ class RemoteExecutionClient:
|
|
|
416
581
|
|
|
417
582
|
if get_result in done:
|
|
418
583
|
response = get_result.result()
|
|
584
|
+
# All responses are now CliRpcResponse with msgspec
|
|
419
585
|
data = msgspec.to_builtins(response)
|
|
420
586
|
msg = json.dumps({"type": "result", "data": data})
|
|
421
587
|
await websocket.send(msg)
|
|
422
588
|
|
|
423
589
|
get_result = asyncio.create_task(results.get())
|
|
424
590
|
pending.add(get_result)
|
|
591
|
+
|
|
592
|
+
if get_terminal_output in done:
|
|
593
|
+
terminal_message = get_terminal_output.result()
|
|
594
|
+
data = msgspec.to_builtins(terminal_message)
|
|
595
|
+
msg = json.dumps({"type": "terminal_message", "data": data})
|
|
596
|
+
await websocket.send(msg)
|
|
597
|
+
|
|
598
|
+
get_terminal_output = asyncio.create_task(
|
|
599
|
+
terminal_output_queue.get()
|
|
600
|
+
)
|
|
601
|
+
pending.add(get_terminal_output)
|
|
425
602
|
finally:
|
|
426
603
|
for task in pending:
|
|
427
604
|
task.cancel()
|
|
@@ -432,21 +609,27 @@ class RemoteExecutionClient:
|
|
|
432
609
|
self,
|
|
433
610
|
websocket: ClientConnection,
|
|
434
611
|
connection_tracker: ConnectionTracker | None,
|
|
612
|
+
beats: asyncio.Queue[HeartbeatInfo],
|
|
613
|
+
requests: asyncio.Queue[CliRpcRequest],
|
|
614
|
+
results: asyncio.Queue[CliRpcResponse],
|
|
615
|
+
terminal_output_queue: asyncio.Queue[TerminalMessage],
|
|
616
|
+
terminal_session_manager: TerminalSessionManager,
|
|
435
617
|
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
|
|
436
618
|
"""Handle a single websocket connection.
|
|
437
619
|
Returns None to continue with reconnection attempts, or an exit info to terminate."""
|
|
438
620
|
if connection_tracker is not None:
|
|
439
621
|
await connection_tracker.set_connected(True)
|
|
440
622
|
|
|
441
|
-
|
|
442
|
-
requests: asyncio.Queue[CliRpcRequest] = asyncio.Queue()
|
|
443
|
-
results: asyncio.Queue[CliRpcResponse] = asyncio.Queue()
|
|
444
|
-
|
|
445
|
-
tasks = await self._setup_tasks(beats, requests, results)
|
|
623
|
+
self._websocket = websocket
|
|
446
624
|
|
|
447
625
|
try:
|
|
448
626
|
return await self._process_websocket_messages(
|
|
449
|
-
websocket,
|
|
627
|
+
websocket,
|
|
628
|
+
beats,
|
|
629
|
+
requests,
|
|
630
|
+
results,
|
|
631
|
+
terminal_output_queue,
|
|
632
|
+
terminal_session_manager,
|
|
450
633
|
)
|
|
451
634
|
except websockets.exceptions.ConnectionClosed as e:
|
|
452
635
|
if e.rcvd is not None:
|
|
@@ -461,15 +644,13 @@ class RemoteExecutionClient:
|
|
|
461
644
|
)
|
|
462
645
|
return WSDisconnected(error_message=error_message)
|
|
463
646
|
# Otherwise, allow reconnection attempt
|
|
647
|
+
logger.debug("Websocket connection closed by remote.")
|
|
464
648
|
return None
|
|
465
649
|
except TimeoutError:
|
|
466
650
|
# Timeout, allow reconnection attempt
|
|
467
651
|
# TODO: investigate if this is needed, possibly scope it down
|
|
468
652
|
return None
|
|
469
653
|
finally:
|
|
470
|
-
for task in tasks:
|
|
471
|
-
task.cancel()
|
|
472
|
-
await asyncio.gather(*tasks, return_exceptions=True)
|
|
473
654
|
if connection_tracker is not None:
|
|
474
655
|
await connection_tracker.set_connected(False)
|
|
475
656
|
|
|
@@ -485,33 +666,64 @@ class RemoteExecutionClient:
|
|
|
485
666
|
# Initialize last request time for timeout monitoring
|
|
486
667
|
self._last_request_time = time.time()
|
|
487
668
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
asyncio.create_task(
|
|
494
|
-
self._handle_websocket_connection(websocket, connection_tracker)
|
|
495
|
-
),
|
|
496
|
-
asyncio.create_task(self._timeout_monitor(timeout_seconds)),
|
|
497
|
-
],
|
|
498
|
-
return_when=asyncio.FIRST_COMPLETED,
|
|
499
|
-
)
|
|
669
|
+
# Create queues ONCE - persist across reconnections
|
|
670
|
+
beats: asyncio.Queue[HeartbeatInfo] = asyncio.Queue()
|
|
671
|
+
requests: asyncio.Queue[CliRpcRequest] = asyncio.Queue()
|
|
672
|
+
results: asyncio.Queue[CliRpcResponse] = asyncio.Queue()
|
|
673
|
+
terminal_output_queue: asyncio.Queue[TerminalMessage] = asyncio.Queue()
|
|
500
674
|
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
task.cancel()
|
|
675
|
+
# Create terminal session manager ONCE - persist across reconnections
|
|
676
|
+
terminal_session_manager = TerminalSessionManager(terminal_output_queue)
|
|
504
677
|
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
result = await task
|
|
508
|
-
# If we get None, we'll try to reconnect
|
|
509
|
-
if result is not None:
|
|
510
|
-
return result
|
|
678
|
+
# Create tasks ONCE - persist across reconnections
|
|
679
|
+
executors = await self._setup_tasks(beats, requests, results)
|
|
511
680
|
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
681
|
+
try:
|
|
682
|
+
async for websocket in self.ws_connect(f"/api/ws/chat/{chat_uuid}"):
|
|
683
|
+
# Always run connection and timeout monitor concurrently
|
|
684
|
+
# If timeout_seconds is None, timeout monitor will loop indefinitely
|
|
685
|
+
done, pending = await asyncio.wait(
|
|
686
|
+
[
|
|
687
|
+
asyncio.create_task(
|
|
688
|
+
self._handle_websocket_connection(
|
|
689
|
+
websocket,
|
|
690
|
+
connection_tracker,
|
|
691
|
+
beats,
|
|
692
|
+
requests,
|
|
693
|
+
results,
|
|
694
|
+
terminal_output_queue,
|
|
695
|
+
terminal_session_manager,
|
|
696
|
+
)
|
|
697
|
+
),
|
|
698
|
+
asyncio.create_task(self._timeout_monitor(timeout_seconds)),
|
|
699
|
+
],
|
|
700
|
+
return_when=asyncio.FIRST_COMPLETED,
|
|
701
|
+
)
|
|
702
|
+
|
|
703
|
+
# Cancel pending tasks
|
|
704
|
+
for task in pending:
|
|
705
|
+
task.cancel()
|
|
706
|
+
|
|
707
|
+
# Return result from completed task
|
|
708
|
+
for task in done:
|
|
709
|
+
result = await task
|
|
710
|
+
# If we get None, we'll try to reconnect
|
|
711
|
+
if result is not None:
|
|
712
|
+
return result
|
|
713
|
+
|
|
714
|
+
# If we exit the websocket connection loop without returning,
|
|
715
|
+
# it means we couldn't establish a connection
|
|
716
|
+
return WSDisconnected(
|
|
717
|
+
error_message="Could not establish websocket connection"
|
|
718
|
+
)
|
|
719
|
+
finally:
|
|
720
|
+
# Stop all terminal sessions to clean up PTY processes
|
|
721
|
+
await terminal_session_manager.stop_all_sessions()
|
|
722
|
+
|
|
723
|
+
# Cancel all background tasks when exiting
|
|
724
|
+
for task in executors:
|
|
725
|
+
task.cancel()
|
|
726
|
+
await asyncio.gather(*executors, return_exceptions=True)
|
|
515
727
|
|
|
516
728
|
async def create_chat(self, chat_source: ChatSource) -> CreateChatResponse:
|
|
517
729
|
response = await self.api_client.post(
|
|
@@ -520,13 +732,6 @@ class RemoteExecutionClient:
|
|
|
520
732
|
)
|
|
521
733
|
return await deserialize_api_response(response, CreateChatResponse)
|
|
522
734
|
|
|
523
|
-
async def get_gh_installation_token(self, git_info: GitInfo) -> dict[str, Any]:
|
|
524
|
-
response = await self.api_client.post(
|
|
525
|
-
"/github_app/exchange_token",
|
|
526
|
-
json=git_info.model_dump(),
|
|
527
|
-
)
|
|
528
|
-
return cast(dict[str, Any], response.json())
|
|
529
|
-
|
|
530
735
|
# deprecated
|
|
531
736
|
async def run_workflow(self, chat_uuid: str, workflow_id: str) -> dict[str, Any]:
|
|
532
737
|
response = await self.api_client.post(
|
|
@@ -544,7 +749,7 @@ class RemoteExecutionClient:
|
|
|
544
749
|
return cast(dict[str, Any], response.json())
|
|
545
750
|
|
|
546
751
|
async def trigger_workflow(
|
|
547
|
-
self, workflow_name: str, workflow_input:
|
|
752
|
+
self, workflow_name: str, workflow_input: WorkflowInput
|
|
548
753
|
) -> WorkflowTriggerResponse:
|
|
549
754
|
response = await self.api_client.post(
|
|
550
755
|
"/api/remote_execution/trigger_workflow",
|
|
@@ -579,6 +784,38 @@ class RemoteExecutionClient:
|
|
|
579
784
|
logger.info(f"Heartbeat response: {connected_state}")
|
|
580
785
|
return connected_state
|
|
581
786
|
|
|
787
|
+
async def request_upload_url(
|
|
788
|
+
self, s3_key: str, content_type: str
|
|
789
|
+
) -> GenerateUploadUrlResponse:
|
|
790
|
+
if self._websocket is None:
|
|
791
|
+
raise RuntimeError("No active websocket connection")
|
|
792
|
+
|
|
793
|
+
request_id = str(uuid.uuid4())
|
|
794
|
+
request = CliRpcRequest(
|
|
795
|
+
request_id=request_id,
|
|
796
|
+
request=GenerateUploadUrlRequest(s3_key=s3_key, content_type=content_type),
|
|
797
|
+
)
|
|
798
|
+
|
|
799
|
+
future: asyncio.Future[GenerateUploadUrlResponse] = asyncio.Future()
|
|
800
|
+
async with self._upload_request_lock:
|
|
801
|
+
self._pending_upload_requests[request_id] = future
|
|
802
|
+
|
|
803
|
+
try:
|
|
804
|
+
await self._websocket.send(
|
|
805
|
+
json.dumps({"type": "request", "data": msgspec.to_builtins(request)})
|
|
806
|
+
)
|
|
807
|
+
|
|
808
|
+
response = await asyncio.wait_for(future, timeout=30)
|
|
809
|
+
return response
|
|
810
|
+
except TimeoutError:
|
|
811
|
+
async with self._upload_request_lock:
|
|
812
|
+
self._pending_upload_requests.pop(request_id, None)
|
|
813
|
+
raise RuntimeError("Timeout waiting for upload URL response")
|
|
814
|
+
except Exception as e:
|
|
815
|
+
async with self._upload_request_lock:
|
|
816
|
+
self._pending_upload_requests.pop(request_id, None)
|
|
817
|
+
raise e
|
|
818
|
+
|
|
582
819
|
async def handle_request(self, request: CliRpcRequest) -> CliRpcResponse:
|
|
583
820
|
# Update last request time for timeout functionality
|
|
584
821
|
self._last_request_time = time.time()
|
|
@@ -593,7 +830,7 @@ class RemoteExecutionClient:
|
|
|
593
830
|
)
|
|
594
831
|
else:
|
|
595
832
|
raw_result = await execute_tool( # type: ignore[assignment]
|
|
596
|
-
request.request.tool_input, self.working_directory
|
|
833
|
+
request.request.tool_input, self.working_directory, self
|
|
597
834
|
)
|
|
598
835
|
tool_result = truncate_result(raw_result)
|
|
599
836
|
return CliRpcResponse(
|
|
@@ -609,33 +846,38 @@ class RemoteExecutionClient:
|
|
|
609
846
|
response=GetAllFilesResponse(files=files),
|
|
610
847
|
)
|
|
611
848
|
elif isinstance(request.request, BatchToolExecutionRequest):
|
|
612
|
-
|
|
849
|
+
coros: list[Coroutine[Any, Any, ToolResultType]] = []
|
|
613
850
|
for tool_input in request.request.tool_inputs:
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
851
|
+
if isinstance(tool_input, BashToolInput):
|
|
852
|
+
coros.append(
|
|
853
|
+
execute_bash_tool(
|
|
617
854
|
tool_input,
|
|
618
855
|
self.working_directory,
|
|
619
856
|
should_halt=self.get_halt_check(request.request_id),
|
|
620
857
|
)
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
)
|
|
625
|
-
tool_result = truncate_result(raw_result)
|
|
626
|
-
results.append(tool_result)
|
|
627
|
-
except Exception as e:
|
|
628
|
-
logger.error(f"Error executing tool {tool_input}: {e}")
|
|
629
|
-
from exponent.core.remote_execution.cli_rpc_types import (
|
|
630
|
-
ErrorToolResult,
|
|
858
|
+
)
|
|
859
|
+
else:
|
|
860
|
+
coros.append(
|
|
861
|
+
execute_tool(tool_input, self.working_directory, self)
|
|
631
862
|
)
|
|
632
863
|
|
|
633
|
-
|
|
864
|
+
results: list[ToolResultType | BaseException] = await asyncio.gather(
|
|
865
|
+
*coros, return_exceptions=True
|
|
866
|
+
)
|
|
867
|
+
|
|
868
|
+
processed_results: list[ToolResultType] = []
|
|
869
|
+
for result in results:
|
|
870
|
+
if not isinstance(result, BaseException):
|
|
871
|
+
processed_results.append(truncate_result(result))
|
|
872
|
+
else:
|
|
873
|
+
processed_results.append(
|
|
874
|
+
ErrorToolResult(error_message=str(result))
|
|
875
|
+
)
|
|
634
876
|
|
|
635
877
|
return CliRpcResponse(
|
|
636
878
|
request_id=request.request_id,
|
|
637
879
|
response=BatchToolExecutionResponse(
|
|
638
|
-
tool_results=
|
|
880
|
+
tool_results=processed_results,
|
|
639
881
|
),
|
|
640
882
|
)
|
|
641
883
|
elif isinstance(request.request, HttpRequest):
|
|
@@ -657,6 +899,22 @@ class RemoteExecutionClient:
|
|
|
657
899
|
raise ValueError(
|
|
658
900
|
"KeepAliveCliChatRequest should not be handled by handle_request"
|
|
659
901
|
)
|
|
902
|
+
elif isinstance(request.request, StartTerminalRequest):
|
|
903
|
+
raise ValueError(
|
|
904
|
+
"StartTerminalRequest should not be handled by handle_request"
|
|
905
|
+
)
|
|
906
|
+
elif isinstance(request.request, TerminalInputRequest):
|
|
907
|
+
raise ValueError(
|
|
908
|
+
"TerminalInputRequest should not be handled by handle_request"
|
|
909
|
+
)
|
|
910
|
+
elif isinstance(request.request, TerminalResizeRequest):
|
|
911
|
+
raise ValueError(
|
|
912
|
+
"TerminalResizeRequest should not be handled by handle_request"
|
|
913
|
+
)
|
|
914
|
+
elif isinstance(request.request, StopTerminalRequest):
|
|
915
|
+
raise ValueError(
|
|
916
|
+
"StopTerminalRequest should not be handled by handle_request"
|
|
917
|
+
)
|
|
660
918
|
|
|
661
919
|
raise ValueError(f"Unhandled request type: {type(request)}")
|
|
662
920
|
|
|
@@ -678,8 +936,13 @@ class RemoteExecutionClient:
|
|
|
678
936
|
await self.clear_halt_state(request.request_id)
|
|
679
937
|
|
|
680
938
|
async def handle_streaming_request(
|
|
681
|
-
self,
|
|
682
|
-
|
|
939
|
+
self,
|
|
940
|
+
request: Any,
|
|
941
|
+
) -> AsyncGenerator[Any, None]:
|
|
942
|
+
from exponent.core.remote_execution.cli_rpc_types import (
|
|
943
|
+
StreamingCodeExecutionRequest,
|
|
944
|
+
)
|
|
945
|
+
|
|
683
946
|
if not isinstance(request, StreamingCodeExecutionRequest):
|
|
684
947
|
assert False, f"{type(request)} should be sent to handle_streaming_request"
|
|
685
948
|
async for output in execute_code_streaming(
|