indent 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of indent might be problematic. See the registry page for more details.

Files changed (56)
  1. exponent/__init__.py +1 -0
  2. exponent/cli.py +112 -0
  3. exponent/commands/cloud_commands.py +85 -0
  4. exponent/commands/common.py +434 -0
  5. exponent/commands/config_commands.py +581 -0
  6. exponent/commands/github_app_commands.py +211 -0
  7. exponent/commands/listen_commands.py +96 -0
  8. exponent/commands/run_commands.py +208 -0
  9. exponent/commands/settings.py +56 -0
  10. exponent/commands/shell_commands.py +2840 -0
  11. exponent/commands/theme.py +246 -0
  12. exponent/commands/types.py +111 -0
  13. exponent/commands/upgrade.py +29 -0
  14. exponent/commands/utils.py +236 -0
  15. exponent/core/config.py +180 -0
  16. exponent/core/graphql/__init__.py +0 -0
  17. exponent/core/graphql/client.py +59 -0
  18. exponent/core/graphql/cloud_config_queries.py +77 -0
  19. exponent/core/graphql/get_chats_query.py +47 -0
  20. exponent/core/graphql/github_config_queries.py +56 -0
  21. exponent/core/graphql/mutations.py +75 -0
  22. exponent/core/graphql/queries.py +110 -0
  23. exponent/core/graphql/subscriptions.py +452 -0
  24. exponent/core/remote_execution/checkpoints.py +212 -0
  25. exponent/core/remote_execution/cli_rpc_types.py +214 -0
  26. exponent/core/remote_execution/client.py +545 -0
  27. exponent/core/remote_execution/code_execution.py +58 -0
  28. exponent/core/remote_execution/command_execution.py +105 -0
  29. exponent/core/remote_execution/error_info.py +45 -0
  30. exponent/core/remote_execution/exceptions.py +10 -0
  31. exponent/core/remote_execution/file_write.py +410 -0
  32. exponent/core/remote_execution/files.py +415 -0
  33. exponent/core/remote_execution/git.py +268 -0
  34. exponent/core/remote_execution/languages/python_execution.py +239 -0
  35. exponent/core/remote_execution/languages/shell_streaming.py +221 -0
  36. exponent/core/remote_execution/languages/types.py +20 -0
  37. exponent/core/remote_execution/session.py +128 -0
  38. exponent/core/remote_execution/system_context.py +54 -0
  39. exponent/core/remote_execution/tool_execution.py +289 -0
  40. exponent/core/remote_execution/truncation.py +284 -0
  41. exponent/core/remote_execution/types.py +670 -0
  42. exponent/core/remote_execution/utils.py +600 -0
  43. exponent/core/types/__init__.py +0 -0
  44. exponent/core/types/command_data.py +206 -0
  45. exponent/core/types/event_types.py +89 -0
  46. exponent/core/types/generated/__init__.py +0 -0
  47. exponent/core/types/generated/strategy_info.py +225 -0
  48. exponent/migration-docs/login.md +112 -0
  49. exponent/py.typed +4 -0
  50. exponent/utils/__init__.py +0 -0
  51. exponent/utils/colors.py +92 -0
  52. exponent/utils/version.py +289 -0
  53. indent-0.0.8.dist-info/METADATA +36 -0
  54. indent-0.0.8.dist-info/RECORD +56 -0
  55. indent-0.0.8.dist-info/WHEEL +4 -0
  56. indent-0.0.8.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,545 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ from collections.abc import AsyncGenerator, Callable
7
+ from contextlib import asynccontextmanager
8
+ from dataclasses import dataclass
9
+ from typing import Any, TypeVar, Union, cast
10
+
11
+ import msgspec
12
+ import websockets.client
13
+ import websockets.exceptions
14
+ from httpx import (
15
+ AsyncClient,
16
+ codes as http_status,
17
+ )
18
+ from pydantic import BaseModel
19
+
20
+ from exponent.commands.utils import ConnectionTracker
21
+ from exponent.core.config import is_editable_install
22
+ from exponent.core.remote_execution import files, system_context
23
+ from exponent.core.remote_execution.cli_rpc_types import (
24
+ BashToolInput,
25
+ BatchToolExecutionRequest,
26
+ BatchToolExecutionResponse,
27
+ CliRpcRequest,
28
+ CliRpcResponse,
29
+ ErrorResponse,
30
+ GetAllFilesRequest,
31
+ GetAllFilesResponse,
32
+ TerminateRequest,
33
+ TerminateResponse,
34
+ ToolExecutionRequest,
35
+ ToolExecutionResponse,
36
+ ToolResultType,
37
+ )
38
+ from exponent.core.remote_execution.code_execution import (
39
+ execute_code_streaming,
40
+ )
41
+ from exponent.core.remote_execution.files import file_walk
42
+ from exponent.core.remote_execution.session import (
43
+ RemoteExecutionClientSession,
44
+ get_session,
45
+ )
46
+ from exponent.core.remote_execution.tool_execution import (
47
+ execute_bash_tool,
48
+ execute_tool,
49
+ truncate_result,
50
+ )
51
+ from exponent.core.remote_execution.types import (
52
+ ChatSource,
53
+ CLIConnectedState,
54
+ CreateChatResponse,
55
+ GitInfo,
56
+ HeartbeatInfo,
57
+ RemoteExecutionResponseType,
58
+ RunWorkflowRequest,
59
+ StreamingCodeExecutionRequest,
60
+ )
61
+ from exponent.core.remote_execution.utils import (
62
+ deserialize_api_response,
63
+ )
64
+ from exponent.utils.version import get_installed_version
65
+
66
# Module-level logger for the remote execution client.
logger = logging.getLogger(__name__)


# Type variable bounded to pydantic models.
# NOTE(review): not referenced anywhere in this module — confirm it is still needed.
TModel = TypeVar("TModel", bound=BaseModel)
70
+
71
+
72
@dataclass()
class WSDisconnected:
    """Exit info: the websocket connection loop has ended.

    ``error_message`` stays None for a normal (code 1000) closure and holds a
    human-readable reason when the connection failed or was refused.
    """

    error_message: str | None = None
75
+
76
+
77
@dataclass()
class SwitchCLIChat:
    """Exit info carrying the UUID of another chat.

    Presumably the caller reconnects the CLI to ``new_chat_uuid``.
    NOTE(review): nothing in this module constructs it — confirm it is
    produced elsewhere.
    """

    new_chat_uuid: str
80
+
81
+
82
# The reasons RemoteExecutionClient.run_connection can stop with.
REMOTE_EXECUTION_CLIENT_EXIT_INFO = Union[
    WSDisconnected,
    SwitchCLIChat,
]
83
+
84
+
85
class RemoteExecutionClient:
    """Execute CLI RPC requests that arrive over a chat websocket.

    Wraps a ``RemoteExecutionClientSession`` (HTTP API client, websocket
    client and working directory). While connected it runs one heartbeat
    task and three executor tasks and relays their output over the
    websocket. Bash executions are registered in ``_halt_states`` so that a
    ``TerminateRequest`` can ask them to stop.
    """

    def __init__(
        self,
        session: RemoteExecutionClientSession,
        file_cache: files.FileCache | None = None,
    ):
        self.current_session = session
        # Default to a cache rooted at the session's working directory.
        self.file_cache = file_cache or files.FileCache(session.working_directory)

        # For active code executions, track whether they should be halted.
        # request id -> should_halt
        self._halt_states: dict[str, bool] = {}
        self._halt_lock = asyncio.Lock()

    @property
    def working_directory(self) -> str:
        """Working directory of the current session."""
        return self.current_session.working_directory

    @property
    def api_client(self) -> AsyncClient:
        """HTTP client used for REST API calls."""
        return self.current_session.api_client

    @property
    def ws_client(self) -> AsyncClient:
        """HTTP client whose ``base_url`` is used to build websocket URLs."""
        return self.current_session.ws_client

    async def add_code_execution_to_halt_states(self, correlation_id: str) -> None:
        """Register ``correlation_id`` as a running, not-yet-halted execution."""
        async with self._halt_lock:
            self._halt_states[correlation_id] = False

    async def halt_all_code_executions(self) -> None:
        """Mark every registered execution as halted."""
        logger.info(f"Halting all code executions: {self._halt_states}")
        async with self._halt_lock:
            self._halt_states = {
                correlation_id: True for correlation_id in self._halt_states
            }

    async def clear_halt_state(self, correlation_id: str) -> None:
        """Forget ``correlation_id``; unknown ids are ignored."""
        async with self._halt_lock:
            self._halt_states.pop(correlation_id, None)

    def get_halt_check(self, correlation_id: str) -> Callable[[], bool]:
        """Return a callable reporting whether ``correlation_id`` should halt.

        The callable re-reads ``self._halt_states`` on every call, so it also
        sees the dict being replaced by ``halt_all_code_executions``.
        Unknown ids report False.
        """

        def should_halt() -> bool:
            # Don't need to lock here, since just reading from dict
            return self._halt_states.get(correlation_id, False)

        return should_halt

    @staticmethod
    def _uses_halt_state(request: CliRpcRequest) -> bool:
        """Return True if ``request`` runs bash and so needs a halt-state entry.

        Shared by the enqueue path and the cleanup path so that registering
        and clearing halt states cannot drift apart.
        """
        inner = request.request
        if isinstance(inner, ToolExecutionRequest) and isinstance(
            inner.tool_input, BashToolInput
        ):
            return True
        if isinstance(inner, BatchToolExecutionRequest):
            # Any bash command in the batch makes the whole batch haltable.
            return any(
                isinstance(tool_input, BashToolInput)
                for tool_input in inner.tool_inputs
            )
        return False

    @staticmethod
    def _policy_violation_message(reason: str | None) -> str:
        """Return the error message for a 1008 (policy violation) close.

        websockets reports a missing close reason as an empty string rather
        than None, so both fall back to the generic message.
        """
        return reason or "Error connecting to websocket"

    async def _handle_websocket_message(
        self,
        msg: str,
        websocket: websockets.client.WebSocketClientProtocol,
        requests: asyncio.Queue[CliRpcRequest],
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
        """Handle an incoming websocket message.

        Messages whose ``type`` is not ``"request"`` are ignored. A
        ``TerminateRequest`` halts all registered executions and is answered
        immediately with a ``TerminateResponse``. Any other request is
        registered for halting if it runs bash, then queued for the executors.

        Returns None to continue processing, or a
        REMOTE_EXECUTION_CLIENT_EXIT_INFO to exit.
        """
        msg_data = json.loads(msg)
        if msg_data["type"] != "request":
            return None

        # Re-encode the payload so msgspec can validate it against the
        # CliRpcRequest schema.
        data = json.dumps(msg_data["data"])
        request = msgspec.json.decode(data, type=CliRpcRequest)

        if isinstance(request.request, TerminateRequest):
            await self.halt_all_code_executions()
            await websocket.send(
                json.dumps(
                    {
                        "type": "result",
                        "data": msgspec.to_builtins(
                            CliRpcResponse(
                                request_id=request.request_id,
                                response=TerminateResponse(),
                            )
                        ),
                    }
                )
            )
            return None

        if self._uses_halt_state(request):
            await self.add_code_execution_to_halt_states(request.request_id)

        await requests.put(request)
        return None

    async def _setup_tasks(
        self,
        beats: asyncio.Queue[HeartbeatInfo],
        requests: asyncio.Queue[CliRpcRequest],
        results: asyncio.Queue[CliRpcResponse],
    ) -> list[asyncio.Task[None]]:
        """Start the heartbeat task and three executor tasks.

        The heartbeat task pushes a ``HeartbeatInfo`` into ``beats`` every
        3 seconds. Each executor pulls from ``requests``, runs
        ``handle_request`` and pushes the response (or an ``ErrorResponse``
        if it raised) into ``results``. The caller owns cancellation of the
        returned tasks.
        """

        async def beat() -> None:
            while True:
                info = await self.get_heartbeat_info()
                await beats.put(info)
                await asyncio.sleep(3)

        # Lock to ensure that only one executor can grab a
        # request at a time.
        requests_lock = asyncio.Lock()

        # Lock to ensure that only one executor can put a
        # result in the results queue at a time.
        results_lock = asyncio.Lock()

        async def executor() -> None:
            while True:
                async with requests_lock:
                    request = await requests.get()

                # The request lock is released before handling so that other
                # executors can pick up requests in the meantime.
                try:
                    logger.info(f"Handling request {request}")
                    response = await self.handle_request(request)
                    async with results_lock:
                        logger.info(f"Putting response {response}")
                        await results.put(response)
                except Exception as e:  # noqa: BLE001
                    logger.info(f"Error handling request {request}:\n\n{e}")
                    async with results_lock:
                        await results.put(
                            CliRpcResponse(
                                request_id=request.request_id,
                                response=ErrorResponse(
                                    error_message=str(e),
                                ),
                            )
                        )

        beat_task = asyncio.create_task(beat())

        # Three parallel executors to handle requests
        executor_tasks = [
            asyncio.create_task(executor()),
            asyncio.create_task(executor()),
            asyncio.create_task(executor()),
        ]

        return [beat_task, *executor_tasks]

    async def _process_websocket_messages(
        self,
        websocket: websockets.client.WebSocketClientProtocol,
        beats: asyncio.Queue[HeartbeatInfo],
        requests: asyncio.Queue[CliRpcRequest],
        results: asyncio.Queue[CliRpcResponse],
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO:
        """Multiplex incoming messages, heartbeats and results on one socket.

        Waits on three tasks at once: receiving a websocket message, taking a
        heartbeat from ``beats`` and taking a response from ``results``.
        Whichever completes is handled and re-armed. Returns only when a
        message handler yields exit info. Exceptions (e.g. ``ConnectionClosed``
        from ``recv``) propagate after the outstanding tasks are cancelled.
        """
        # Initialised up front so the ``finally`` below can never hit an
        # unbound name.
        pending: set[asyncio.Task[Any]] = set()
        try:
            recv = asyncio.create_task(websocket.recv())
            get_beat = asyncio.create_task(beats.get())
            get_result = asyncio.create_task(results.get())
            pending = {recv, get_beat, get_result}

            while True:
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )

                if recv in done:
                    msg = str(recv.result())
                    exit_info = await self._handle_websocket_message(
                        msg, websocket, requests
                    )
                    if exit_info is not None:
                        return exit_info

                    recv = asyncio.create_task(websocket.recv())
                    pending.add(recv)

                if get_beat in done:
                    info = get_beat.result()
                    data = json.loads(info.model_dump_json())
                    msg = json.dumps({"type": "heartbeat", "data": data})
                    await websocket.send(msg)

                    get_beat = asyncio.create_task(beats.get())
                    pending.add(get_beat)

                if get_result in done:
                    response = get_result.result()
                    data = msgspec.to_builtins(response)
                    msg = json.dumps({"type": "result", "data": data})
                    await websocket.send(msg)

                    get_result = asyncio.create_task(results.get())
                    pending.add(get_result)
        finally:
            for task in pending:
                task.cancel()

            await asyncio.gather(*pending, return_exceptions=True)

    async def _handle_websocket_connection(
        self,
        websocket: websockets.client.WebSocketClientProtocol,
        connection_tracker: ConnectionTracker | None,
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
        """Handle a single websocket connection.

        Returns None to continue with reconnection attempts, or exit info to
        terminate. A close with code 1000 yields ``WSDisconnected()``; code
        1008 yields ``WSDisconnected`` carrying the close reason (or a
        generic message when the reason is empty). Other closes and timeouts
        return None.
        """
        if connection_tracker is not None:
            await connection_tracker.set_connected(True)

        beats: asyncio.Queue[HeartbeatInfo] = asyncio.Queue()
        requests: asyncio.Queue[CliRpcRequest] = asyncio.Queue()
        results: asyncio.Queue[CliRpcResponse] = asyncio.Queue()

        tasks = await self._setup_tasks(beats, requests, results)

        try:
            return await self._process_websocket_messages(
                websocket, beats, requests, results
            )
        except websockets.exceptions.ConnectionClosed as e:
            if e.rcvd is not None:
                if e.rcvd.code == 1000:
                    # Normal closure, exit completely
                    return WSDisconnected()
                if e.rcvd.code == 1008:
                    return WSDisconnected(
                        error_message=self._policy_violation_message(e.rcvd.reason)
                    )
            # Otherwise, allow reconnection attempt
            return None
        except (TimeoutError, asyncio.TimeoutError):
            # Before Python 3.11 asyncio.TimeoutError is not the builtin
            # TimeoutError, so both are caught.
            # TODO: investigate if this is needed, possibly scope it down
            return None
        finally:
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            if connection_tracker is not None:
                await connection_tracker.set_connected(False)

    async def run_connection(
        self,
        chat_uuid: str,
        connection_tracker: ConnectionTracker | None = None,
    ) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO:
        """Run the websocket connection loop for ``chat_uuid``.

        Iterates over the connections produced by ``ws_connect``, handling
        each until one ends with exit info. Connections that end without
        exit info lead to another reconnection attempt.
        """
        self.current_session.set_chat_uuid(chat_uuid)

        async for websocket in self.ws_connect(f"/api/ws/chat/{chat_uuid}"):
            result = await self._handle_websocket_connection(
                websocket, connection_tracker
            )
            if result is not None:
                return result
            # If we get None, we'll try to reconnect

        # If we exit the websocket connection loop without returning,
        # it means we couldn't establish a connection
        return WSDisconnected(error_message="Could not establish websocket connection")

    async def create_chat(self, chat_source: ChatSource) -> CreateChatResponse:
        """Create a chat on the server and return the parsed response."""
        response = await self.api_client.post(
            "/api/remote_execution/create_chat",
            params={"chat_source": chat_source.value},
        )
        return await deserialize_api_response(response, CreateChatResponse)

    async def get_gh_installation_token(self, git_info: GitInfo) -> dict[str, Any]:
        """Exchange ``git_info`` for a GitHub app token; returns the raw JSON.

        NOTE(review): the HTTP status is not checked here.
        """
        response = await self.api_client.post(
            "/github_app/exchange_token",
            json=git_info.model_dump(),
        )
        return cast(dict[str, Any], response.json())

    async def run_workflow(self, chat_uuid: str, workflow_id: str) -> dict[str, Any]:
        """Start ``workflow_id`` in ``chat_uuid``; returns the raw JSON body.

        Raises:
            Exception: if the server responds with a non-200 status.
        """
        response = await self.api_client.post(
            "/api/remote_execution/run_workflow",
            json=RunWorkflowRequest(
                chat_uuid=chat_uuid,
                workflow_id=workflow_id,
            ).model_dump(),
            timeout=60,
        )
        if response.status_code != http_status.OK:
            raise Exception(
                f"Failed to run workflow with status code {response.status_code} and response {response.text}"
            )
        return cast(dict[str, Any], response.json())

    async def get_heartbeat_info(self) -> HeartbeatInfo:
        """Collect system info plus installed version / editable-install flag."""
        return HeartbeatInfo(
            system_info=await system_context.get_system_info(self.working_directory),
            exponent_version=get_installed_version(),
            editable_installation=is_editable_install(),
        )

    async def send_heartbeat(self, chat_uuid: str) -> CLIConnectedState:
        """POST a heartbeat for ``chat_uuid`` over HTTP and parse the reply.

        Raises:
            Exception: if the server responds with a non-200 status.
        """
        logger.info(f"Sending heartbeat for chat_uuid {chat_uuid}")
        heartbeat_info = await self.get_heartbeat_info()
        response = await self.api_client.post(
            f"/api/remote_execution/{chat_uuid}/heartbeat",
            content=heartbeat_info.model_dump_json(),
            timeout=60,
        )
        if response.status_code != http_status.OK:
            raise Exception(
                f"Heartbeat failed with status code {response.status_code} and response {response.text}"
            )
        connected_state = await deserialize_api_response(response, CLIConnectedState)
        logger.info(f"Heartbeat response: {connected_state}")
        return connected_state

    async def handle_request(self, request: CliRpcRequest) -> CliRpcResponse:
        """Execute a non-terminate CLI RPC request and build its response.

        Supports single tool execution, listing all files and batch tool
        execution. In a batch, a failing tool produces an ``ErrorToolResult``
        in its slot instead of failing the whole batch. Any halt state
        registered for the request is cleared when handling finishes.

        Raises:
            ValueError: for a ``TerminateRequest`` (handled on receipt instead)
                or an unrecognised request type. Errors from tool execution
                outside a batch are logged and re-raised.
        """
        try:
            if isinstance(request.request, ToolExecutionRequest):
                if isinstance(request.request.tool_input, BashToolInput):
                    raw_result = await execute_bash_tool(
                        request.request.tool_input,
                        self.working_directory,
                        should_halt=self.get_halt_check(request.request_id),
                    )
                else:
                    raw_result = await execute_tool(  # type: ignore[assignment]
                        request.request.tool_input, self.working_directory
                    )
                tool_result = truncate_result(raw_result)
                return CliRpcResponse(
                    request_id=request.request_id,
                    response=ToolExecutionResponse(
                        tool_result=tool_result,
                    ),
                )
            elif isinstance(request.request, GetAllFilesRequest):
                # Named so it does not shadow the imported ``files`` module.
                all_files = await file_walk(self.working_directory)
                return CliRpcResponse(
                    request_id=request.request_id,
                    response=GetAllFilesResponse(files=all_files),
                )
            elif isinstance(request.request, BatchToolExecutionRequest):
                results: list[ToolResultType] = []
                for tool_input in request.request.tool_inputs:
                    try:
                        if isinstance(tool_input, BashToolInput):
                            raw_result = await execute_bash_tool(
                                tool_input,
                                self.working_directory,
                                should_halt=self.get_halt_check(request.request_id),
                            )
                        else:
                            raw_result = await execute_tool(  # type: ignore[assignment]
                                tool_input, self.working_directory
                            )
                        tool_result = truncate_result(raw_result)
                        results.append(tool_result)
                    except Exception as e:  # noqa: BLE001
                        logger.error(f"Error executing tool {tool_input}: {e}")
                        from exponent.core.remote_execution.cli_rpc_types import (
                            ErrorToolResult,
                        )

                        results.append(ErrorToolResult(error_message=str(e)))

                return CliRpcResponse(
                    request_id=request.request_id,
                    response=BatchToolExecutionResponse(
                        tool_results=results,
                    ),
                )
            elif isinstance(request.request, TerminateRequest):
                raise ValueError(
                    "TerminateRequest should not be handled by handle_request"
                )

            # Report the inner request type; ``type(request)`` would always
            # be CliRpcRequest and tell the reader nothing.
            raise ValueError(f"Unhandled request type: {type(request.request)}")

        except Exception as e:
            logger.error(f"Error handling request {request}:\n\n{e}")
            raise
        finally:
            # Clean up halt state after request is complete
            if self._uses_halt_state(request):
                await self.clear_halt_state(request.request_id)

    async def handle_streaming_request(
        self, request: StreamingCodeExecutionRequest
    ) -> AsyncGenerator[RemoteExecutionResponseType, None]:
        """Stream the outputs of a streaming code execution request.

        Raises:
            TypeError: if ``request`` is not a StreamingCodeExecutionRequest.
        """
        if not isinstance(request, StreamingCodeExecutionRequest):
            # An explicit raise, unlike ``assert``, survives ``python -O``.
            raise TypeError(
                f"{type(request)} should be sent to handle_streaming_request"
            )
        async for output in execute_code_streaming(
            request,
            self.current_session,
            working_directory=self.working_directory,
            should_halt=self.get_halt_check(request.correlation_id),
        ):
            yield output

    def ws_connect(self, path: str) -> websockets.client.connect:
        """Build a reconnecting websocket ``connect`` object for ``path``.

        The URL is derived from ``ws_client.base_url`` with http(s) swapped
        for ws(s), and the API key from ``api_client`` is sent as a header.
        """
        base_url = (
            str(self.ws_client.base_url)
            .replace("http://", "ws://")
            .replace("https://", "wss://")
        )

        url = f"{base_url}{path}"
        headers = {"api-key": self.api_client.headers["api-key"]}

        conn = websockets.client.connect(
            url, extra_headers=headers, timeout=10, ping_timeout=10
        )

        # Stop exponential backoff from blowing up
        # the wait time between connection attempts
        conn.BACKOFF_MAX = 2
        conn.BACKOFF_INITIAL = 1  # pyright: ignore

        return conn

    @staticmethod
    @asynccontextmanager
    async def session(
        api_key: str,
        base_url: str,
        base_ws_url: str,
        working_directory: str,
        file_cache: files.FileCache | None = None,
    ) -> AsyncGenerator[RemoteExecutionClient, None]:
        """Async context manager yielding a client bound to a fresh session."""
        async with get_session(
            working_directory, base_url, base_ws_url, api_key
        ) as session:
            yield RemoteExecutionClient(session, file_cache)
@@ -0,0 +1,58 @@
1
+ from collections.abc import AsyncGenerator, Callable
2
+
3
+ from exponent.core.remote_execution.languages.python_execution import (
4
+ execute_python_streaming,
5
+ )
6
+ from exponent.core.remote_execution.languages.shell_streaming import (
7
+ execute_shell_streaming,
8
+ )
9
+ from exponent.core.remote_execution.languages.types import StreamedOutputPiece
10
+ from exponent.core.remote_execution.session import RemoteExecutionClientSession
11
+ from exponent.core.remote_execution.types import (
12
+ StreamingCodeExecutionRequest,
13
+ StreamingCodeExecutionResponse,
14
+ StreamingCodeExecutionResponseChunk,
15
+ )
16
+
17
# Placeholder content for a final execution response whose output is empty.
EMPTY_OUTPUT_STRING = "(No output)"
18
+
19
+
20
async def execute_code_streaming(
    request: StreamingCodeExecutionRequest,
    session: RemoteExecutionClientSession,
    working_directory: str,
    should_halt: Callable[[], bool] | None = None,
) -> AsyncGenerator[
    StreamingCodeExecutionResponseChunk | StreamingCodeExecutionResponse, None
]:
    """Execute ``request.content`` and stream its output.

    ``"python"`` code runs in ``session.kernel``; ``"shell"`` code runs in
    ``working_directory`` with ``request.timeout``. Every
    ``StreamedOutputPiece`` is yielded as a
    ``StreamingCodeExecutionResponseChunk``. Any other output object is
    yielded as a ``StreamingCodeExecutionResponse``, with
    ``EMPTY_OUTPUT_STRING`` substituted for empty output.

    Args:
        request: The execution request (language, content, correlation id).
        session: Session providing the Python kernel.
        working_directory: Directory in which shell commands run.
        should_halt: Optional callable polled by the executors to stop early.

    Raises:
        ValueError: if ``request.language`` is neither "python" nor "shell".
            Because this is an async generator, it is raised on the first
            iteration.
    """
    if request.language == "python":
        async for output in execute_python_streaming(
            request.content, session.kernel, user_interrupted=should_halt
        ):
            if isinstance(output, StreamedOutputPiece):
                yield StreamingCodeExecutionResponseChunk(
                    content=output.content, correlation_id=request.correlation_id
                )
            else:
                yield StreamingCodeExecutionResponse(
                    correlation_id=request.correlation_id,
                    content=output.output or EMPTY_OUTPUT_STRING,
                    halted=output.halted,
                )

    elif request.language == "shell":
        async for shell_output in execute_shell_streaming(
            request.content, working_directory, request.timeout, should_halt
        ):
            if isinstance(shell_output, StreamedOutputPiece):
                yield StreamingCodeExecutionResponseChunk(
                    content=shell_output.content, correlation_id=request.correlation_id
                )
            else:
                yield StreamingCodeExecutionResponse(
                    correlation_id=request.correlation_id,
                    content=shell_output.output or EMPTY_OUTPUT_STRING,
                    halted=shell_output.halted,
                    exit_code=shell_output.exit_code,
                    cancelled_for_timeout=shell_output.cancelled_for_timeout,
                )

    else:
        # Previously an unknown language yielded nothing, so the caller never
        # received a final response. Failing loudly lets the caller's error
        # handling report the problem instead.
        raise ValueError(
            f"Unsupported code execution language: {request.language!r}"
        )