indent 0.1.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55):
  1. exponent/__init__.py +34 -0
  2. exponent/cli.py +110 -0
  3. exponent/commands/cloud_commands.py +585 -0
  4. exponent/commands/common.py +411 -0
  5. exponent/commands/config_commands.py +334 -0
  6. exponent/commands/run_commands.py +222 -0
  7. exponent/commands/settings.py +56 -0
  8. exponent/commands/types.py +111 -0
  9. exponent/commands/upgrade.py +29 -0
  10. exponent/commands/utils.py +146 -0
  11. exponent/core/config.py +180 -0
  12. exponent/core/graphql/__init__.py +0 -0
  13. exponent/core/graphql/client.py +61 -0
  14. exponent/core/graphql/get_chats_query.py +47 -0
  15. exponent/core/graphql/mutations.py +160 -0
  16. exponent/core/graphql/queries.py +146 -0
  17. exponent/core/graphql/subscriptions.py +16 -0
  18. exponent/core/remote_execution/checkpoints.py +212 -0
  19. exponent/core/remote_execution/cli_rpc_types.py +499 -0
  20. exponent/core/remote_execution/client.py +999 -0
  21. exponent/core/remote_execution/code_execution.py +77 -0
  22. exponent/core/remote_execution/default_env.py +31 -0
  23. exponent/core/remote_execution/error_info.py +45 -0
  24. exponent/core/remote_execution/exceptions.py +10 -0
  25. exponent/core/remote_execution/file_write.py +35 -0
  26. exponent/core/remote_execution/files.py +330 -0
  27. exponent/core/remote_execution/git.py +268 -0
  28. exponent/core/remote_execution/http_fetch.py +94 -0
  29. exponent/core/remote_execution/languages/python_execution.py +239 -0
  30. exponent/core/remote_execution/languages/shell_streaming.py +226 -0
  31. exponent/core/remote_execution/languages/types.py +20 -0
  32. exponent/core/remote_execution/port_utils.py +73 -0
  33. exponent/core/remote_execution/session.py +128 -0
  34. exponent/core/remote_execution/system_context.py +26 -0
  35. exponent/core/remote_execution/terminal_session.py +375 -0
  36. exponent/core/remote_execution/terminal_types.py +29 -0
  37. exponent/core/remote_execution/tool_execution.py +595 -0
  38. exponent/core/remote_execution/tool_type_utils.py +39 -0
  39. exponent/core/remote_execution/truncation.py +296 -0
  40. exponent/core/remote_execution/types.py +635 -0
  41. exponent/core/remote_execution/utils.py +477 -0
  42. exponent/core/types/__init__.py +0 -0
  43. exponent/core/types/command_data.py +206 -0
  44. exponent/core/types/event_types.py +89 -0
  45. exponent/core/types/generated/__init__.py +0 -0
  46. exponent/core/types/generated/strategy_info.py +213 -0
  47. exponent/migration-docs/login.md +112 -0
  48. exponent/py.typed +4 -0
  49. exponent/utils/__init__.py +0 -0
  50. exponent/utils/colors.py +92 -0
  51. exponent/utils/version.py +289 -0
  52. indent-0.1.26.dist-info/METADATA +38 -0
  53. indent-0.1.26.dist-info/RECORD +55 -0
  54. indent-0.1.26.dist-info/WHEEL +4 -0
  55. indent-0.1.26.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,999 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import time
7
+ import uuid
8
+ from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
9
+ from contextlib import asynccontextmanager
10
+ from dataclasses import dataclass
11
+ from typing import Any, TypeVar, cast
12
+
13
+ import msgspec
14
+ import websockets.exceptions
15
+ from httpx import (
16
+ AsyncClient,
17
+ codes as http_status,
18
+ )
19
+ from pydantic import BaseModel
20
+ from websockets.asyncio import client as asyncio_websockets_client
21
+ from websockets.asyncio.client import ClientConnection, connect
22
+
23
+ from exponent.commands.utils import ConnectionTracker
24
+ from exponent.core.config import is_editable_install
25
+ from exponent.core.remote_execution import files, system_context
26
+ from exponent.core.remote_execution.cli_rpc_types import (
27
+ BashToolInput,
28
+ BatchToolExecutionRequest,
29
+ BatchToolExecutionResponse,
30
+ CliRpcRequest,
31
+ CliRpcResponse,
32
+ ErrorResponse,
33
+ ErrorToolResult,
34
+ GenerateUploadUrlRequest,
35
+ GenerateUploadUrlResponse,
36
+ GetAllFilesRequest,
37
+ GetAllFilesResponse,
38
+ HttpRequest,
39
+ KeepAliveCliChatRequest,
40
+ KeepAliveCliChatResponse,
41
+ StartTerminalRequest,
42
+ StartTerminalResponse,
43
+ StopTerminalRequest,
44
+ StopTerminalResponse,
45
+ SwitchCLIChatRequest,
46
+ SwitchCLIChatResponse,
47
+ TerminalInputRequest,
48
+ TerminalInputResponse,
49
+ TerminalResizeRequest,
50
+ TerminalResizeResponse,
51
+ TerminateRequest,
52
+ TerminateResponse,
53
+ ToolExecutionRequest,
54
+ ToolExecutionResponse,
55
+ ToolResultType,
56
+ )
57
+ from exponent.core.remote_execution.code_execution import (
58
+ execute_code_streaming,
59
+ )
60
+ from exponent.core.remote_execution.files import file_walk
61
+ from exponent.core.remote_execution.http_fetch import fetch_http_content
62
+ from exponent.core.remote_execution.session import (
63
+ RemoteExecutionClientSession,
64
+ get_session,
65
+ send_exception_log,
66
+ )
67
+ from exponent.core.remote_execution.terminal_session import TerminalSessionManager
68
+ from exponent.core.remote_execution.terminal_types import TerminalMessage
69
+ from exponent.core.remote_execution.tool_execution import (
70
+ execute_bash_tool,
71
+ execute_tool,
72
+ truncate_result,
73
+ )
74
+ from exponent.core.remote_execution.types import (
75
+ ChatSource,
76
+ CLIConnectedState,
77
+ CreateChatResponse,
78
+ HeartbeatInfo,
79
+ RunWorkflowRequest,
80
+ WorkflowInput,
81
+ WorkflowTriggerRequest,
82
+ WorkflowTriggerResponse,
83
+ )
84
+ from exponent.core.remote_execution.utils import (
85
+ deserialize_api_response,
86
+ )
87
+ from exponent.utils.version import get_installed_version
88
+
89
# Module-wide logger, named after this module.
logger = logging.getLogger(name=__name__)


# Generic type variable bounded by pydantic ``BaseModel``.
# NOTE(review): not referenced in this module — confirm no external importers before removing.
TModel = TypeVar("TModel", bound=BaseModel)
93
+
94
+
95
@dataclass
class WSDisconnected:
    """Exit info returned when the websocket session ends without reconnecting.

    ``error_message`` stays ``None`` for a normal closure and otherwise holds
    a human-readable reason (timeout, policy close, failed connection).
    """

    error_message: str | None = None
99
+
100
@dataclass
class SwitchCLIChat:
    """Exit info returned when the server asks the CLI to switch chats.

    ``new_chat_uuid`` is the UUID of the chat to switch to.
    """

    new_chat_uuid: str
103
+
104
+
105
# Every reason the remote execution client can give for leaving its
# connection loop.
REMOTE_EXECUTION_CLIENT_EXIT_INFO = WSDisconnected | SwitchCLIChat

# Identifies this single run of the CLI; generated once at import time and
# sent with every heartbeat.
cli_uuid: uuid.UUID = uuid.uuid4()
109
+
110
+
111
class RemoteExecutionClient:
    """Executes server-issued RPC requests on the local machine.

    Wraps a ``RemoteExecutionClientSession``. While ``run_connection`` is
    active it holds a websocket to the server. Over it the client receives
    ``CliRpcRequest`` messages and sends back ``CliRpcResponse`` results,
    heartbeats and terminal output.
    """

    def __init__(
        self,
        session: RemoteExecutionClientSession,
        file_cache: files.FileCache | None = None,
    ):
        self.current_session = session
        self.file_cache = file_cache or files.FileCache(session.working_directory)

        # Halt flags for active code executions.
        # Keys are request ids (bash tool executions) or correlation ids
        # (streaming code executions); True means "should halt".
        self._halt_states: dict[str, bool] = {}
        self._halt_lock = asyncio.Lock()

        # Time of the last incoming message, used by the inactivity timeout.
        self._last_request_time: float | None = None

        # Upload-URL requests sent to the server and awaiting a "result" message.
        self._pending_upload_requests: dict[
            str, asyncio.Future[GenerateUploadUrlResponse]
        ] = {}
        self._upload_request_lock = asyncio.Lock()

        # The websocket of the live connection, or None while disconnected.
        self._websocket: ClientConnection | None = None

    @property
    def working_directory(self) -> str:
        return self.current_session.working_directory

    @property
    def api_client(self) -> AsyncClient:
        return self.current_session.api_client

    @property
    def ws_client(self) -> AsyncClient:
        return self.current_session.ws_client

    # ------------------------------------------------------------------
    # Halt-state bookkeeping
    # ------------------------------------------------------------------

    async def add_code_execution_to_halt_states(self, correlation_id: str) -> None:
        """Register an execution so ``halt_all_code_executions`` can flag it."""
        async with self._halt_lock:
            self._halt_states[correlation_id] = False

    async def halt_all_code_executions(self) -> None:
        """Flag every registered execution as halted."""
        logger.info(f"Halting all code executions: {self._halt_states}")
        async with self._halt_lock:
            self._halt_states = dict.fromkeys(self._halt_states, True)

    async def clear_halt_state(self, correlation_id: str) -> None:
        """Forget an execution's halt flag (no-op if it is not registered)."""
        async with self._halt_lock:
            self._halt_states.pop(correlation_id, None)

    def get_halt_check(self, correlation_id: str) -> Callable[[], bool]:
        """Return a callable reporting whether ``correlation_id`` was halted.

        Unregistered ids always report False.
        """

        def should_halt() -> bool:
            # No lock needed: this is a single dict read.
            return self._halt_states.get(correlation_id, False)

        return should_halt

    @staticmethod
    def _needs_halt_state(request: CliRpcRequest) -> bool:
        """Return True if ``request`` runs bash commands tracked by request_id."""
        inner = request.request
        if isinstance(inner, ToolExecutionRequest):
            return isinstance(inner.tool_input, BashToolInput)
        if isinstance(inner, BatchToolExecutionRequest):
            return any(
                isinstance(tool_input, BashToolInput)
                for tool_input in inner.tool_inputs
            )
        return False

    # ------------------------------------------------------------------
    # Connection loop
    # ------------------------------------------------------------------

    async def _timeout_monitor(
        self, timeout_seconds: int | None
    ) -> WSDisconnected | None:
        """Return ``WSDisconnected`` after ``timeout_seconds`` without messages.

        Checks once per second. When ``timeout_seconds`` is None it never
        returns and simply runs until cancelled. Cancellation propagates
        (it is not swallowed) so awaiting callers see a cancelled task.
        """
        while True:
            await asyncio.sleep(1)
            if (
                timeout_seconds is not None
                and self._last_request_time is not None
                and time.time() - self._last_request_time > timeout_seconds
            ):
                logger.info(
                    f"No requests received for {timeout_seconds} seconds. Shutting down..."
                )
                return WSDisconnected(
                    error_message=f"Timeout after {timeout_seconds} seconds of inactivity"
                )

    @staticmethod
    async def _send_result(
        websocket: ClientConnection, request_id: str, response: Any
    ) -> None:
        """Send ``response`` for ``request_id`` as a ``"result"`` websocket message."""
        payload = msgspec.to_builtins(
            CliRpcResponse(request_id=request_id, response=response)
        )
        await websocket.send(json.dumps({"type": "result", "data": payload}))

    @staticmethod
    async def _cancel_and_wait(tasks: list[asyncio.Task[Any]]) -> None:
        """Cancel ``tasks`` and wait until each one has finished its cleanup."""
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

    async def _handle_result_message(self, data: Any) -> None:
        """Resolve the pending upload-URL future answered by a ``"result"`` message.

        Results of any other type are ignored. Errors are logged, not raised.
        """
        try:
            response = msgspec.json.decode(json.dumps(data), type=CliRpcResponse)
            if not isinstance(response.response, GenerateUploadUrlResponse):
                return
            async with self._upload_request_lock:
                future = self._pending_upload_requests.pop(response.request_id, None)
            # The waiter may already have timed out, which cancels its future.
            if future is not None and not future.done():
                future.set_result(response.response)
        except Exception as e:
            logger.error(f"Error handling upload URL response: {e}")

    async def _reply_to_undecodable_request(
        self, data: str, websocket: ClientConnection
    ) -> bool:
        """Send an ``ErrorResponse`` for a request that failed typed decoding.

        Args:
            data: JSON text of the request payload.
            websocket: Connection to reply on.

        Returns:
            False, without sending anything, when the payload has no
            ``request_id`` to reply to. Otherwise True.
        """
        raw = msgspec.json.decode(data)
        if not isinstance(raw, dict) or "request_id" not in raw:
            return False
        request_id = raw["request_id"]

        inner = raw.get("request")
        tool_input = inner.get("tool_input") if isinstance(inner, dict) else None
        if (
            isinstance(inner, dict)
            and inner.get("type") == "tool_execution"
            and isinstance(tool_input, dict)
            and "tool_name" in tool_input
        ):
            # Most likely a tool this CLI version does not know about.
            tool_name = tool_input["tool_name"]
            logger.error(
                f"Error tool {tool_name} received in a request. "
                "Please ensure you are running the latest version of Indent. If this issue persists, please contact support."
            )
            error_message = f"Unknown tool: {tool_name}. If you are running an older version of Indent, please upgrade to the latest version to ensure compatibility."
        else:
            logger.error(
                "Error decoding cli rpc request. Please ensure you are running the latest version of Indent."
            )
            error_message = f"Unknown cli rpc request type: {raw}"

        await self._send_result(
            websocket, request_id, ErrorResponse(error_message=error_message)
        )
        return True

    async def _handle_websocket_message(  # noqa: PLR0911
        self,
        msg: str,
        websocket: ClientConnection,
        requests: asyncio.Queue[CliRpcRequest],
        terminal_session_manager: TerminalSessionManager,
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
        """Handle one incoming websocket message.

        ``"result"`` messages resolve pending upload-URL requests.
        ``"request"`` messages fall into two groups. Terminate, chat
        switch/keep-alive and terminal-control requests are answered inline.
        All others are queued on ``requests`` for the executor tasks. Other
        message types are ignored.

        Returns:
            None to keep processing, or exit info to end the connection loop.

        Raises:
            msgspec.DecodeError / msgspec.ValidationError: a request could not
            be decoded and carries no ``request_id`` to reply to.
        """
        self._last_request_time = time.time()

        msg_data = json.loads(msg)
        msg_type = msg_data.get("type")
        if msg_type == "result":
            await self._handle_result_message(msg_data.get("data"))
            return None
        if msg_type != "request":
            return None

        data = json.dumps(msg_data["data"])
        try:
            request = msgspec.json.decode(data, type=CliRpcRequest)
        except (msgspec.DecodeError, msgspec.ValidationError):
            if await self._reply_to_undecodable_request(data, websocket):
                return None
            # No request_id to reply to: fail noisily.
            raise

        inner = request.request
        request_id = request.request_id

        if isinstance(inner, TerminateRequest):
            await self.halt_all_code_executions()
            await self._send_result(websocket, request_id, TerminateResponse())
            return None

        if isinstance(inner, SwitchCLIChatRequest):
            await self._send_result(websocket, request_id, SwitchCLIChatResponse())
            return SwitchCLIChat(new_chat_uuid=inner.new_chat_uuid)

        if isinstance(inner, KeepAliveCliChatRequest):
            await self._send_result(websocket, request_id, KeepAliveCliChatResponse())
            return None

        if isinstance(inner, StartTerminalRequest):
            session_id = await terminal_session_manager.start_session(
                websocket=websocket,
                session_id=inner.session_id,
                command=inner.command,
                cols=inner.cols,
                rows=inner.rows,
                env=inner.env,
            )
            await self._send_result(
                websocket,
                request_id,
                StartTerminalResponse(session_id=session_id, success=True),
            )
            return None

        if isinstance(inner, TerminalInputRequest):
            success = await terminal_session_manager.send_input(
                session_id=inner.session_id,
                data=inner.data,
            )
            await self._send_result(
                websocket,
                request_id,
                TerminalInputResponse(session_id=inner.session_id, success=success),
            )
            return None

        if isinstance(inner, TerminalResizeRequest):
            success = await terminal_session_manager.resize_terminal(
                session_id=inner.session_id,
                rows=inner.rows,
                cols=inner.cols,
            )
            await self._send_result(
                websocket,
                request_id,
                TerminalResizeResponse(session_id=inner.session_id, success=success),
            )
            return None

        if isinstance(inner, StopTerminalRequest):
            success = await terminal_session_manager.stop_session(
                session_id=inner.session_id
            )
            await self._send_result(
                websocket,
                request_id,
                StopTerminalResponse(session_id=inner.session_id, success=success),
            )
            return None

        # Everything else is executed by the executor tasks. Register bash
        # work first so a TerminateRequest arriving while it is still queued
        # can halt it.
        if self._needs_halt_state(request):
            await self.add_code_execution_to_halt_states(request_id)
        await requests.put(request)
        return None

    async def _setup_tasks(
        self,
        beats: asyncio.Queue[HeartbeatInfo],
        requests: asyncio.Queue[CliRpcRequest],
        results: asyncio.Queue[CliRpcResponse],
    ) -> list[asyncio.Task[None]]:
        """Start the heartbeat task and three request-executor tasks.

        The tasks run until cancelled. The caller owns them and must cancel
        them.
        """
        from exponent.core.remote_execution.cli_rpc_types import (
            StreamingCodeExecutionRequest,
            StreamingErrorResponse,
        )

        async def beat() -> None:
            while True:
                try:
                    info = await self.get_heartbeat_info()
                except Exception as e:
                    # Keep beating after a transient failure instead of letting
                    # this task die and silently stopping heartbeats for good.
                    logger.warning(f"Error collecting heartbeat info: {e}")
                else:
                    await beats.put(info)
                await asyncio.sleep(3)

        # Only one executor waits on the requests queue at a time.
        requests_lock = asyncio.Lock()
        # Only one executor puts into the results queue at a time.
        results_lock = asyncio.Lock()

        def error_response(request: CliRpcRequest, error: Exception) -> CliRpcResponse:
            # Streaming requests get a StreamingErrorResponse, others an ErrorResponse.
            if isinstance(request.request, StreamingCodeExecutionRequest):
                return CliRpcResponse(
                    request_id=request.request_id,
                    response=StreamingErrorResponse(
                        correlation_id=request.request.correlation_id,
                        error_message=str(error),
                    ),
                )
            return CliRpcResponse(
                request_id=request.request_id,
                response=ErrorResponse(error_message=str(error)),
            )

        async def executor() -> None:
            while True:
                async with requests_lock:
                    request = await requests.get()

                # requests_lock is released here so other executors can pick
                # up requests while this one is being handled.
                try:
                    if isinstance(request.request, StreamingCodeExecutionRequest):
                        async for streaming_response in self.handle_streaming_request(
                            request.request
                        ):
                            async with results_lock:
                                await results.put(
                                    CliRpcResponse(
                                        request_id=request.request_id,
                                        response=streaming_response,
                                    )
                                )
                    else:
                        logger.info(f"Handling request {request}")
                        response = await self.handle_request(request)
                        async with results_lock:
                            logger.info(f"Putting response {response}")
                            await results.put(response)
                except Exception as e:
                    logger.info(f"Error handling request {request}:\n\n{e}")
                    try:
                        await send_exception_log(e, session=self.current_session)
                    except Exception as log_error:
                        # Best effort: failing to report must not kill the executor.
                        logger.debug(f"Failed to send exception log: {log_error}")
                    async with results_lock:
                        await results.put(error_response(request, e))

        beat_task = asyncio.create_task(beat())
        # Three parallel executors, so one slow request does not block others.
        executor_tasks = [asyncio.create_task(executor()) for _ in range(3)]
        return [beat_task, *executor_tasks]

    async def _process_websocket_messages(
        self,
        websocket: ClientConnection,
        beats: asyncio.Queue[HeartbeatInfo],
        requests: asyncio.Queue[CliRpcRequest],
        results: asyncio.Queue[CliRpcResponse],
        terminal_output_queue: asyncio.Queue[TerminalMessage],
        terminal_session_manager: TerminalSessionManager,
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO:
        """Pump messages between ``websocket`` and the local queues.

        Waits concurrently for an incoming message, a heartbeat, an executor
        result and terminal output, handling whichever completes first.
        Returns when an incoming message yields exit info. Websocket errors
        propagate to the caller.
        """
        recv = asyncio.create_task(websocket.recv())
        get_beat = asyncio.create_task(beats.get())
        get_result = asyncio.create_task(results.get())
        get_terminal_output = asyncio.create_task(terminal_output_queue.get())
        pending: set[asyncio.Task[Any]] = {
            recv,
            get_beat,
            get_result,
            get_terminal_output,
        }
        try:
            while True:
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )

                if recv in done:
                    raw = recv.result()
                    # Binary frames must be decoded; str(bytes) gives "b'...'".
                    msg = raw.decode() if isinstance(raw, bytes) else str(raw)
                    exit_info = await self._handle_websocket_message(
                        msg, websocket, requests, terminal_session_manager
                    )
                    if exit_info is not None:
                        return exit_info

                    recv = asyncio.create_task(websocket.recv())
                    pending.add(recv)

                if get_beat in done:
                    info = get_beat.result()
                    data = json.loads(info.model_dump_json())
                    await websocket.send(json.dumps({"type": "heartbeat", "data": data}))

                    get_beat = asyncio.create_task(beats.get())
                    pending.add(get_beat)

                if get_result in done:
                    response = get_result.result()
                    # All responses are CliRpcResponse structs (msgspec).
                    data = msgspec.to_builtins(response)
                    await websocket.send(json.dumps({"type": "result", "data": data}))

                    # Replace the task only after the send succeeded, so that
                    # get_result always refers to a not-yet-delivered result.
                    get_result = asyncio.create_task(results.get())
                    pending.add(get_result)

                if get_terminal_output in done:
                    terminal_message = get_terminal_output.result()
                    data = msgspec.to_builtins(terminal_message)
                    await websocket.send(
                        json.dumps({"type": "terminal_message", "data": data})
                    )

                    get_terminal_output = asyncio.create_task(
                        terminal_output_queue.get()
                    )
                    pending.add(get_terminal_output)
        finally:
            await self._cancel_and_wait(list(pending))
            # A result that was dequeued but never delivered goes back on the
            # queue. Causes: the connection dropped, the send failed, or the
            # loop exited first. The queue outlives this connection, so the
            # next connection can still send it.
            if (
                get_result.done()
                and not get_result.cancelled()
                and get_result.exception() is None
            ):
                results.put_nowait(get_result.result())

    async def _handle_websocket_connection(
        self,
        websocket: ClientConnection,
        connection_tracker: ConnectionTracker | None,
        beats: asyncio.Queue[HeartbeatInfo],
        requests: asyncio.Queue[CliRpcRequest],
        results: asyncio.Queue[CliRpcResponse],
        terminal_output_queue: asyncio.Queue[TerminalMessage],
        terminal_session_manager: TerminalSessionManager,
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
        """Serve a single websocket connection.

        Returns None to request a reconnection attempt, or exit info to stop.
        """
        if connection_tracker is not None:
            await connection_tracker.set_connected(True)

        self._websocket = websocket

        try:
            return await self._process_websocket_messages(
                websocket,
                beats,
                requests,
                results,
                terminal_output_queue,
                terminal_session_manager,
            )
        except websockets.exceptions.ConnectionClosed as e:
            if e.rcvd is not None:
                if e.rcvd.code == 1000:
                    # Normal closure, exit completely.
                    return WSDisconnected()
                if e.rcvd.code == 1008:
                    # A close frame's reason is a str, empty when none was given.
                    return WSDisconnected(
                        error_message=e.rcvd.reason or "Error connecting to websocket"
                    )
            # Otherwise, allow a reconnection attempt.
            logger.debug("Websocket connection closed by remote.")
            return None
        except TimeoutError:
            # Timeout: allow a reconnection attempt.
            # TODO: investigate if this is needed, possibly scope it down
            return None
        finally:
            # Stop request_upload_url from sending on a dead connection.
            if self._websocket is websocket:
                self._websocket = None
            if connection_tracker is not None:
                await connection_tracker.set_connected(False)

    async def run_connection(
        self,
        chat_uuid: str,
        connection_tracker: ConnectionTracker | None = None,
        timeout_seconds: int | None = None,
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO:
        """Serve the chat's websocket, reconnecting, until an exit reason occurs.

        Queues, executor tasks and terminal sessions are created once and
        persist across reconnections. Each connection races against the
        inactivity monitor. When ``timeout_seconds`` is None the monitor
        never fires.

        Returns:
            The exit info from the connection or the timeout monitor, or
            ``WSDisconnected`` if the connection iterator stops yielding.
        """
        self.current_session.set_chat_uuid(chat_uuid)

        # Initialize last request time for timeout monitoring.
        self._last_request_time = time.time()

        # Created once; these persist across reconnections.
        beats: asyncio.Queue[HeartbeatInfo] = asyncio.Queue()
        requests: asyncio.Queue[CliRpcRequest] = asyncio.Queue()
        results: asyncio.Queue[CliRpcResponse] = asyncio.Queue()
        terminal_output_queue: asyncio.Queue[TerminalMessage] = asyncio.Queue()
        terminal_session_manager = TerminalSessionManager(terminal_output_queue)
        executors = await self._setup_tasks(beats, requests, results)

        try:
            async for websocket in self.ws_connect(f"/api/ws/chat/{chat_uuid}"):
                connection_task = asyncio.create_task(
                    self._handle_websocket_connection(
                        websocket,
                        connection_tracker,
                        beats,
                        requests,
                        results,
                        terminal_output_queue,
                        terminal_session_manager,
                    )
                )
                monitor_task = asyncio.create_task(
                    self._timeout_monitor(timeout_seconds)
                )
                try:
                    await asyncio.wait(
                        {connection_task, monitor_task},
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                finally:
                    # Cancel whichever task is still running and wait for its
                    # cleanup (e.g. the connection tracker update) to finish.
                    # This also runs if run_connection itself is cancelled.
                    await self._cancel_and_wait(
                        [t for t in (connection_task, monitor_task) if not t.done()]
                    )

                if connection_task.done() and not connection_task.cancelled():
                    # .result() re-raises anything the connection task raised.
                    connection_result = connection_task.result()
                    if connection_result is not None:
                        return connection_result
                if monitor_task.done() and not monitor_task.cancelled():
                    timeout_result = monitor_task.result()
                    if timeout_result is not None:
                        return timeout_result
                # Connection dropped without an exit reason: reconnect.

            # The connection iterator stopped without us returning.
            return WSDisconnected(
                error_message="Could not establish websocket connection"
            )
        finally:
            # Stop all terminal sessions to clean up PTY processes.
            await terminal_session_manager.stop_all_sessions()
            await self._cancel_and_wait(list(executors))

    # ------------------------------------------------------------------
    # HTTP API calls
    # ------------------------------------------------------------------

    async def create_chat(self, chat_source: ChatSource) -> CreateChatResponse:
        """Create a chat on the server and return the deserialized response."""
        response = await self.api_client.post(
            "/api/remote_execution/create_chat",
            params={"chat_source": chat_source.value},
        )
        return await deserialize_api_response(response, CreateChatResponse)

    async def run_workflow(self, chat_uuid: str, workflow_id: str) -> dict[str, Any]:
        """Deprecated: run a workflow for ``chat_uuid``; return the raw JSON body.

        Raises:
            Exception: the server did not answer with HTTP 200.
        """
        response = await self.api_client.post(
            "/api/remote_execution/run_workflow",
            json=RunWorkflowRequest(
                chat_uuid=chat_uuid,
                workflow_id=workflow_id,
            ).model_dump(),
            timeout=60,
        )
        if response.status_code != http_status.OK:
            raise Exception(
                f"Failed to run workflow with status code {response.status_code} and response {response.text}"
            )
        return cast(dict[str, Any], response.json())

    async def trigger_workflow(
        self, workflow_name: str, workflow_input: WorkflowInput
    ) -> WorkflowTriggerResponse:
        """Trigger the named workflow with ``workflow_input``."""
        response = await self.api_client.post(
            "/api/remote_execution/trigger_workflow",
            json=WorkflowTriggerRequest(
                workflow_name=workflow_name,
                workflow_input=workflow_input,
            ).model_dump(),
        )
        return await deserialize_api_response(response, WorkflowTriggerResponse)

    async def get_heartbeat_info(self) -> HeartbeatInfo:
        """Collect system info, CLI version and this run's ``cli_uuid``."""
        return HeartbeatInfo(
            system_info=await system_context.get_system_info(self.working_directory),
            exponent_version=get_installed_version(),
            editable_installation=is_editable_install(),
            cli_uuid=str(cli_uuid),
        )

    async def send_heartbeat(self, chat_uuid: str) -> CLIConnectedState:
        """POST a heartbeat for ``chat_uuid`` over HTTP.

        Raises:
            Exception: the server did not answer with HTTP 200.
        """
        logger.info(f"Sending heartbeat for chat_uuid {chat_uuid}")
        heartbeat_info = await self.get_heartbeat_info()
        response = await self.api_client.post(
            f"/api/remote_execution/{chat_uuid}/heartbeat",
            content=heartbeat_info.model_dump_json(),
            timeout=60,
        )
        if response.status_code != http_status.OK:
            raise Exception(
                f"Heartbeat failed with status code {response.status_code} and response {response.text}"
            )
        connected_state = await deserialize_api_response(response, CLIConnectedState)
        logger.info(f"Heartbeat response: {connected_state}")
        return connected_state

    async def request_upload_url(
        self, s3_key: str, content_type: str
    ) -> GenerateUploadUrlResponse:
        """Ask the server, over the live websocket, for an upload URL.

        Raises:
            RuntimeError: there is no live websocket, or no answer arrived
            within 30 seconds.
        """
        # Capture the connection once; it may be cleared while we await.
        websocket = self._websocket
        if websocket is None:
            raise RuntimeError("No active websocket connection")

        request_id = str(uuid.uuid4())
        request = CliRpcRequest(
            request_id=request_id,
            request=GenerateUploadUrlRequest(s3_key=s3_key, content_type=content_type),
        )

        future: asyncio.Future[GenerateUploadUrlResponse] = (
            asyncio.get_running_loop().create_future()
        )
        async with self._upload_request_lock:
            self._pending_upload_requests[request_id] = future

        try:
            await websocket.send(
                json.dumps({"type": "request", "data": msgspec.to_builtins(request)})
            )
            return await asyncio.wait_for(future, timeout=30)
        except TimeoutError as err:
            raise RuntimeError("Timeout waiting for upload URL response") from err
        finally:
            # No-op if the result handler already popped the entry.
            async with self._upload_request_lock:
                self._pending_upload_requests.pop(request_id, None)

    # ------------------------------------------------------------------
    # Request execution
    # ------------------------------------------------------------------

    async def _handle_batch_tool_request(
        self, request_id: str, batch: BatchToolExecutionRequest
    ) -> CliRpcResponse:
        """Run every tool in ``batch`` concurrently.

        A tool that raises is reported as an ``ErrorToolResult`` in its slot
        instead of failing the whole batch.
        """
        coros: list[Coroutine[Any, Any, ToolResultType]] = []
        for tool_input in batch.tool_inputs:
            if isinstance(tool_input, BashToolInput):
                coros.append(
                    execute_bash_tool(
                        tool_input,
                        self.working_directory,
                        should_halt=self.get_halt_check(request_id),
                    )
                )
            else:
                coros.append(execute_tool(tool_input, self.working_directory, self))

        outcomes: list[ToolResultType | BaseException] = await asyncio.gather(
            *coros, return_exceptions=True
        )
        tool_results: list[ToolResultType] = [
            ErrorToolResult(error_message=str(outcome))
            if isinstance(outcome, BaseException)
            else truncate_result(outcome)
            for outcome in outcomes
        ]
        return CliRpcResponse(
            request_id=request_id,
            response=BatchToolExecutionResponse(tool_results=tool_results),
        )

    async def handle_request(self, request: CliRpcRequest) -> CliRpcResponse:
        """Execute a queued request and return its response.

        Always clears the request's halt state when done.

        Raises:
            ValueError: the request type is answered inline by the websocket
            handler, or is unknown.
        """
        # Update last request time for timeout functionality.
        self._last_request_time = time.time()

        inner = request.request
        try:
            if isinstance(inner, ToolExecutionRequest):
                if isinstance(inner.tool_input, BashToolInput):
                    raw_result = await execute_bash_tool(
                        inner.tool_input,
                        self.working_directory,
                        should_halt=self.get_halt_check(request.request_id),
                    )
                else:
                    raw_result = await execute_tool(  # type: ignore[assignment]
                        inner.tool_input, self.working_directory, self
                    )
                return CliRpcResponse(
                    request_id=request.request_id,
                    response=ToolExecutionResponse(
                        tool_result=truncate_result(raw_result),
                    ),
                )

            if isinstance(inner, GetAllFilesRequest):
                # Named all_files to avoid shadowing the `files` module.
                all_files = await file_walk(self.working_directory)
                return CliRpcResponse(
                    request_id=request.request_id,
                    response=GetAllFilesResponse(files=all_files),
                )

            if isinstance(inner, BatchToolExecutionRequest):
                return await self._handle_batch_tool_request(request.request_id, inner)

            if isinstance(inner, HttpRequest):
                http_response = await fetch_http_content(inner)
                return CliRpcResponse(
                    request_id=request.request_id,
                    response=http_response,
                )

            if isinstance(
                inner,
                (
                    TerminateRequest,
                    SwitchCLIChatRequest,
                    KeepAliveCliChatRequest,
                    StartTerminalRequest,
                    TerminalInputRequest,
                    TerminalResizeRequest,
                    StopTerminalRequest,
                ),
            ):
                # These are answered directly by _handle_websocket_message.
                raise ValueError(
                    f"{type(inner).__name__} should not be handled by handle_request"
                )

            raise ValueError(f"Unhandled request type: {type(inner)}")

        except Exception as e:
            logger.error(f"Error handling request {request}:\n\n{e}")
            raise
        finally:
            # Clean up the halt state registered when the request arrived.
            if self._needs_halt_state(request):
                await self.clear_halt_state(request.request_id)

    async def handle_streaming_request(
        self,
        request: Any,
    ) -> AsyncGenerator[Any, None]:
        """Stream the outputs of a ``StreamingCodeExecutionRequest``.

        Registers the execution's ``correlation_id`` in the halt states while
        it runs. A ``TerminateRequest`` can therefore halt it, because that is
        the key the halt check is queried with.

        Raises:
            TypeError: ``request`` is not a ``StreamingCodeExecutionRequest``.
        """
        from exponent.core.remote_execution.cli_rpc_types import (
            StreamingCodeExecutionRequest,
        )

        if not isinstance(request, StreamingCodeExecutionRequest):
            raise TypeError(
                f"{type(request)} should not be sent to handle_streaming_request"
            )

        await self.add_code_execution_to_halt_states(request.correlation_id)
        try:
            async for output in execute_code_streaming(
                request,
                self.current_session,
                working_directory=self.working_directory,
                should_halt=self.get_halt_check(request.correlation_id),
            ):
                yield output
        finally:
            await self.clear_halt_state(request.correlation_id)

    # ------------------------------------------------------------------
    # Connection factories
    # ------------------------------------------------------------------

    def ws_connect(self, path: str) -> connect:
        """Build a reconnecting websocket ``connect`` for ``path``.

        Derives the ws/wss URL from the websocket client's base URL and
        authenticates with the API client's ``api-key`` header.
        """
        base_url = (
            str(self.ws_client.base_url)
            .replace("http://", "ws://")
            .replace("https://", "wss://")
        )

        url = f"{base_url}{path}"
        headers = {"api-key": self.api_client.headers["api-key"]}

        def custom_backoff() -> Generator[float, None, None]:
            # Delays: 0.1, then 0.5 growing by 1.5x, capped at 2.0 seconds.
            yield 0.1
            delay = 0.5
            while True:
                yield delay
                delay = min(delay * 1.5, 2.0)

        # Patches a module global of websockets (affects every connect).
        # Can remove if this is added to public API
        # https://github.com/python-websockets/websockets/issues/1395#issuecomment-3225670409
        asyncio_websockets_client.backoff = custom_backoff  # type: ignore[attr-defined, assignment]

        return connect(
            url, additional_headers=headers, open_timeout=10, ping_timeout=10
        )

    @staticmethod
    @asynccontextmanager
    async def session(
        api_key: str,
        base_url: str,
        base_ws_url: str,
        working_directory: str,
        file_cache: files.FileCache | None = None,
    ) -> AsyncGenerator[RemoteExecutionClient, None]:
        """Open a client session and yield a client bound to it."""
        async with get_session(
            working_directory, base_url, base_ws_url, api_key
        ) as session:
            yield RemoteExecutionClient(session, file_cache)