indent 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- exponent/__init__.py +2 -2
- exponent/commands/cloud_commands.py +582 -0
- exponent/commands/common.py +4 -9
- exponent/commands/run_commands.py +20 -9
- exponent/commands/workflow_commands.py +1 -1
- exponent/core/graphql/client.py +5 -3
- exponent/core/graphql/mutations.py +114 -0
- exponent/core/graphql/queries.py +23 -0
- exponent/core/remote_execution/cli_rpc_types.py +17 -0
- exponent/core/remote_execution/client.py +108 -19
- exponent/core/remote_execution/file_write.py +0 -375
- exponent/core/remote_execution/types.py +1 -7
- exponent/core/types/generated/strategy_info.py +0 -12
- {indent-0.1.11.dist-info → indent-0.1.13.dist-info}/METADATA +3 -3
- {indent-0.1.11.dist-info → indent-0.1.13.dist-info}/RECORD +17 -17
- {indent-0.1.11.dist-info → indent-0.1.13.dist-info}/WHEEL +0 -0
- {indent-0.1.11.dist-info → indent-0.1.13.dist-info}/entry_points.txt +0 -0
|
@@ -61,12 +61,19 @@ def run_cli() -> None:
|
|
|
61
61
|
hidden=True,
|
|
62
62
|
required=False,
|
|
63
63
|
)
|
|
64
|
+
@click.option(
|
|
65
|
+
"--timeout-seconds",
|
|
66
|
+
type=int,
|
|
67
|
+
help="Number of seconds without receiving a request before shutting down",
|
|
68
|
+
envvar="INDENT_TIMEOUT_SECONDS",
|
|
69
|
+
)
|
|
64
70
|
@use_settings
|
|
65
71
|
def run(
|
|
66
72
|
settings: Settings,
|
|
67
73
|
chat_id: str | None = None,
|
|
68
74
|
prompt: str | None = None,
|
|
69
75
|
workflow_id: str | None = None,
|
|
76
|
+
timeout_seconds: int | None = None,
|
|
70
77
|
) -> None:
|
|
71
78
|
"""[default] Start or reconnect to an Indent session."""
|
|
72
79
|
check_exponent_version_and_upgrade(settings)
|
|
@@ -104,7 +111,9 @@ def run(
|
|
|
104
111
|
launch_exponent_browser(settings.environment, base_url, chat_uuid)
|
|
105
112
|
|
|
106
113
|
while True:
|
|
107
|
-
result = run_chat(
|
|
114
|
+
result = run_chat(
|
|
115
|
+
loop, api_key, chat_uuid, settings, prompt, workflow_id, timeout_seconds
|
|
116
|
+
)
|
|
108
117
|
if result is None or isinstance(result, WSDisconnected):
|
|
109
118
|
# NOTE: None here means that handle_connection_changes exited
|
|
110
119
|
# first. We should likely have a different message for this.
|
|
@@ -112,11 +121,11 @@ def run(
|
|
|
112
121
|
click.secho(f"Error: {result.error_message}", fg="red")
|
|
113
122
|
sys.exit(10)
|
|
114
123
|
else:
|
|
115
|
-
|
|
124
|
+
click.echo("Disconnected upon user request, shutting down...")
|
|
116
125
|
break
|
|
117
126
|
elif isinstance(result, SwitchCLIChat):
|
|
118
127
|
chat_uuid = result.new_chat_uuid
|
|
119
|
-
|
|
128
|
+
click.echo("\nSwitching chats...")
|
|
120
129
|
else:
|
|
121
130
|
assert_unreachable(result)
|
|
122
131
|
|
|
@@ -128,6 +137,7 @@ def run_chat(
|
|
|
128
137
|
settings: Settings,
|
|
129
138
|
prompt: str | None,
|
|
130
139
|
workflow_id: str | None,
|
|
140
|
+
timeout_seconds: int | None,
|
|
131
141
|
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
|
|
132
142
|
start_ts = time.time()
|
|
133
143
|
base_url = settings.base_url
|
|
@@ -135,7 +145,7 @@ def run_chat(
|
|
|
135
145
|
base_ws_url = settings.get_base_ws_url()
|
|
136
146
|
|
|
137
147
|
print_exponent_message(base_url, chat_uuid)
|
|
138
|
-
|
|
148
|
+
click.echo()
|
|
139
149
|
|
|
140
150
|
connection_tracker = ConnectionTracker()
|
|
141
151
|
|
|
@@ -149,6 +159,7 @@ def run_chat(
|
|
|
149
159
|
prompt=prompt,
|
|
150
160
|
workflow_id=workflow_id,
|
|
151
161
|
connection_tracker=connection_tracker,
|
|
162
|
+
timeout_seconds=timeout_seconds,
|
|
152
163
|
)
|
|
153
164
|
)
|
|
154
165
|
|
|
@@ -179,27 +190,27 @@ async def handle_connection_changes(
|
|
|
179
190
|
try:
|
|
180
191
|
async with timeout(5):
|
|
181
192
|
assert await connection_tracker.next_change()
|
|
182
|
-
|
|
193
|
+
click.echo(ready_message(start_ts))
|
|
183
194
|
except TimeoutError:
|
|
184
195
|
spinner = Spinner("Connecting...")
|
|
185
196
|
spinner.show()
|
|
186
197
|
assert await connection_tracker.next_change()
|
|
187
198
|
spinner.hide()
|
|
188
|
-
|
|
199
|
+
click.echo(ready_message(start_ts))
|
|
189
200
|
|
|
190
201
|
while True:
|
|
191
202
|
assert not await connection_tracker.next_change()
|
|
192
203
|
|
|
193
|
-
|
|
204
|
+
click.echo("Disconnected...", nl=False)
|
|
194
205
|
await asyncio.sleep(1)
|
|
195
206
|
spinner = Spinner("Reconnecting...")
|
|
196
207
|
spinner.show()
|
|
197
208
|
assert await connection_tracker.next_change()
|
|
198
209
|
spinner.hide()
|
|
199
|
-
|
|
210
|
+
click.echo("\x1b[1;32m✓ Reconnected", nl=False)
|
|
200
211
|
sys.stdout.flush()
|
|
201
212
|
await asyncio.sleep(1)
|
|
202
|
-
|
|
213
|
+
click.echo("\r\x1b[0m\x1b[2K", nl=False)
|
|
203
214
|
sys.stdout.flush()
|
|
204
215
|
|
|
205
216
|
|
|
@@ -47,7 +47,7 @@ def trigger(settings: Settings, workflow_type: str) -> None:
|
|
|
47
47
|
|
|
48
48
|
while True:
|
|
49
49
|
result = run_chat(
|
|
50
|
-
loop, settings.api_key, response.chat_uuid, settings, None, None
|
|
50
|
+
loop, settings.api_key, response.chat_uuid, settings, None, None, None
|
|
51
51
|
)
|
|
52
52
|
if result is None or isinstance(result, WSDisconnected):
|
|
53
53
|
# NOTE: None here means that handle_connection_changes exited
|
exponent/core/graphql/client.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from collections.abc import AsyncGenerator
|
|
2
2
|
from typing import Any
|
|
3
3
|
|
|
4
|
-
from gql import Client, gql
|
|
4
|
+
from gql import Client, GraphQLRequest, gql
|
|
5
5
|
from gql.transport.httpx import HTTPXAsyncTransport
|
|
6
6
|
from gql.transport.websockets import WebsocketsTransport
|
|
7
7
|
|
|
@@ -41,8 +41,10 @@ class GraphQLClient:
|
|
|
41
41
|
execute_timeout=timeout,
|
|
42
42
|
) as session:
|
|
43
43
|
# Execute single query
|
|
44
|
-
query =
|
|
45
|
-
|
|
44
|
+
query = GraphQLRequest(
|
|
45
|
+
query_str, variable_values=vars, operation_name=op_name
|
|
46
|
+
)
|
|
47
|
+
result = await session.execute(query)
|
|
46
48
|
return result
|
|
47
49
|
|
|
48
50
|
async def subscribe(
|
|
@@ -73,3 +73,117 @@ mutation CreateCloudChat($configId: String!) {
|
|
|
73
73
|
}
|
|
74
74
|
}
|
|
75
75
|
"""
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
CREATE_CLOUD_CHAT_FROM_REPOSITORY_MUTATION = """
|
|
79
|
+
mutation CreateCloudChatFromRepository($repositoryId: String!) {
|
|
80
|
+
createCloudChat(repositoryId: $repositoryId) {
|
|
81
|
+
__typename
|
|
82
|
+
...on Chat {
|
|
83
|
+
chatUuid
|
|
84
|
+
}
|
|
85
|
+
...on UnauthenticatedError {
|
|
86
|
+
message
|
|
87
|
+
}
|
|
88
|
+
...on ChatNotFoundError {
|
|
89
|
+
message
|
|
90
|
+
}
|
|
91
|
+
...on CloudConfigNotFoundError {
|
|
92
|
+
message
|
|
93
|
+
}
|
|
94
|
+
...on GithubConfigNotFoundError {
|
|
95
|
+
message
|
|
96
|
+
}
|
|
97
|
+
...on CloudSessionError {
|
|
98
|
+
message
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
"""
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
ENABLE_CLOUD_REPOSITORY_MUTATION = """
|
|
106
|
+
mutation EnableCloudRepository($orgName: String!, $repoName: String!) {
|
|
107
|
+
enableCloudRepository(orgName: $orgName, repoName: $repoName) {
|
|
108
|
+
__typename
|
|
109
|
+
...on ContainerImage {
|
|
110
|
+
buildRef
|
|
111
|
+
createdAt
|
|
112
|
+
updatedAt
|
|
113
|
+
}
|
|
114
|
+
...on UnauthenticatedError {
|
|
115
|
+
message
|
|
116
|
+
}
|
|
117
|
+
...on CloudConfigNotFoundError {
|
|
118
|
+
message
|
|
119
|
+
}
|
|
120
|
+
...on GithubConfigNotFoundError {
|
|
121
|
+
message
|
|
122
|
+
}
|
|
123
|
+
...on CloudSessionError {
|
|
124
|
+
message
|
|
125
|
+
}
|
|
126
|
+
...on Error {
|
|
127
|
+
message
|
|
128
|
+
}
|
|
129
|
+
}
|
|
130
|
+
}
|
|
131
|
+
"""
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
INCREMENTAL_BUILD_CLOUD_REPOSITORY_MUTATION = """
|
|
135
|
+
mutation IncrementalBuildCloudRepository($orgName: String!, $repoName: String!) {
|
|
136
|
+
incrementalBuildCloudRepository(orgName: $orgName, repoName: $repoName) {
|
|
137
|
+
__typename
|
|
138
|
+
...on ContainerImage {
|
|
139
|
+
buildRef
|
|
140
|
+
createdAt
|
|
141
|
+
updatedAt
|
|
142
|
+
}
|
|
143
|
+
...on UnauthenticatedError {
|
|
144
|
+
message
|
|
145
|
+
}
|
|
146
|
+
...on CloudConfigNotFoundError {
|
|
147
|
+
message
|
|
148
|
+
}
|
|
149
|
+
...on GithubConfigNotFoundError {
|
|
150
|
+
message
|
|
151
|
+
}
|
|
152
|
+
...on CloudSessionError {
|
|
153
|
+
message
|
|
154
|
+
}
|
|
155
|
+
...on Error {
|
|
156
|
+
message
|
|
157
|
+
}
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
"""
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
REBUILD_CLOUD_REPOSITORY_MUTATION = """
|
|
164
|
+
mutation RebuildCloudRepository($orgName: String!, $repoName: String!) {
|
|
165
|
+
rebuildCloudRepository(orgName: $orgName, repoName: $repoName) {
|
|
166
|
+
__typename
|
|
167
|
+
...on ContainerImage {
|
|
168
|
+
buildRef
|
|
169
|
+
createdAt
|
|
170
|
+
updatedAt
|
|
171
|
+
}
|
|
172
|
+
...on UnauthenticatedError {
|
|
173
|
+
message
|
|
174
|
+
}
|
|
175
|
+
...on CloudConfigNotFoundError {
|
|
176
|
+
message
|
|
177
|
+
}
|
|
178
|
+
...on GithubConfigNotFoundError {
|
|
179
|
+
message
|
|
180
|
+
}
|
|
181
|
+
...on CloudSessionError {
|
|
182
|
+
message
|
|
183
|
+
}
|
|
184
|
+
...on Error {
|
|
185
|
+
message
|
|
186
|
+
}
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
"""
|
exponent/core/graphql/queries.py
CHANGED
|
@@ -1,3 +1,26 @@
|
|
|
1
|
+
GITHUB_REPOSITORIES_QUERY: str = """
|
|
2
|
+
query GithubRepositories {
|
|
3
|
+
githubRepositories {
|
|
4
|
+
__typename
|
|
5
|
+
... on GithubRepositories {
|
|
6
|
+
repositories {
|
|
7
|
+
id
|
|
8
|
+
githubOrgName
|
|
9
|
+
githubRepoName
|
|
10
|
+
baseHost
|
|
11
|
+
containerImageId
|
|
12
|
+
createdAt
|
|
13
|
+
updatedAt
|
|
14
|
+
}
|
|
15
|
+
}
|
|
16
|
+
... on Error {
|
|
17
|
+
message
|
|
18
|
+
}
|
|
19
|
+
}
|
|
20
|
+
}
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
|
|
1
24
|
EVENTS_FOR_CHAT_QUERY: str = """query EventsForChat($chatUuid: UUID!) {
|
|
2
25
|
eventsForChat(chatUuid: $chatUuid) {
|
|
3
26
|
... on EventHistory {
|
|
@@ -27,6 +27,11 @@ class ToolInput(msgspec.Struct, tag_field="tool_name", omit_defaults=True):
|
|
|
27
27
|
"""
|
|
28
28
|
return msgspec.to_builtins(self) # type: ignore[no-any-return]
|
|
29
29
|
|
|
30
|
+
def to_text(self) -> str:
|
|
31
|
+
d = msgspec.to_builtins(self)
|
|
32
|
+
del d["tool_name"]
|
|
33
|
+
return yaml.dump(d)
|
|
34
|
+
|
|
30
35
|
|
|
31
36
|
class ToolResult(msgspec.Struct, tag_field="tool_name", omit_defaults=True):
|
|
32
37
|
"""Concrete subclasses return data from a tool execution."""
|
|
@@ -236,6 +241,12 @@ class SwitchCLIChatRequest(msgspec.Struct, tag="switch_cli_chat"):
|
|
|
236
241
|
new_chat_uuid: str
|
|
237
242
|
|
|
238
243
|
|
|
244
|
+
# This message is sent periodically from the client to keep the connection alive while the chat is turning.
|
|
245
|
+
# This synergizes with CLI-side timeouts to avoid disconnecting during long operations.
|
|
246
|
+
class KeepAliveCliChatRequest(msgspec.Struct, tag="keep_alive_cli_chat"):
|
|
247
|
+
pass
|
|
248
|
+
|
|
249
|
+
|
|
239
250
|
class BatchToolExecutionRequest(msgspec.Struct, tag="batch_tool_execution"):
|
|
240
251
|
tool_inputs: list[ToolInputType]
|
|
241
252
|
|
|
@@ -256,6 +267,10 @@ class SwitchCLIChatResponse(msgspec.Struct, tag="switch_cli_chat"):
|
|
|
256
267
|
pass
|
|
257
268
|
|
|
258
269
|
|
|
270
|
+
class KeepAliveCliChatResponse(msgspec.Struct, tag="keep_alive_cli_chat"):
|
|
271
|
+
pass
|
|
272
|
+
|
|
273
|
+
|
|
259
274
|
class CliRpcRequest(msgspec.Struct):
|
|
260
275
|
request_id: str
|
|
261
276
|
request: (
|
|
@@ -265,6 +280,7 @@ class CliRpcRequest(msgspec.Struct):
|
|
|
265
280
|
| HttpRequest
|
|
266
281
|
| BatchToolExecutionRequest
|
|
267
282
|
| SwitchCLIChatRequest
|
|
283
|
+
| KeepAliveCliChatRequest
|
|
268
284
|
)
|
|
269
285
|
|
|
270
286
|
|
|
@@ -286,4 +302,5 @@ class CliRpcResponse(msgspec.Struct):
|
|
|
286
302
|
| BatchToolExecutionResponse
|
|
287
303
|
| HttpResponse
|
|
288
304
|
| SwitchCLIChatResponse
|
|
305
|
+
| KeepAliveCliChatResponse
|
|
289
306
|
)
|
|
@@ -3,19 +3,22 @@ from __future__ import annotations
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import json
|
|
5
5
|
import logging
|
|
6
|
-
|
|
6
|
+
import time
|
|
7
|
+
import uuid
|
|
8
|
+
from collections.abc import AsyncGenerator, Callable, Generator
|
|
7
9
|
from contextlib import asynccontextmanager
|
|
8
10
|
from dataclasses import dataclass
|
|
9
11
|
from typing import Any, TypeVar, cast
|
|
10
12
|
|
|
11
13
|
import msgspec
|
|
12
|
-
import websockets.client
|
|
13
14
|
import websockets.exceptions
|
|
14
15
|
from httpx import (
|
|
15
16
|
AsyncClient,
|
|
16
17
|
codes as http_status,
|
|
17
18
|
)
|
|
18
19
|
from pydantic import BaseModel
|
|
20
|
+
from websockets.asyncio import client as asyncio_websockets_client
|
|
21
|
+
from websockets.asyncio.client import ClientConnection, connect
|
|
19
22
|
|
|
20
23
|
from exponent.commands.utils import ConnectionTracker
|
|
21
24
|
from exponent.core.config import is_editable_install
|
|
@@ -30,6 +33,8 @@ from exponent.core.remote_execution.cli_rpc_types import (
|
|
|
30
33
|
GetAllFilesRequest,
|
|
31
34
|
GetAllFilesResponse,
|
|
32
35
|
HttpRequest,
|
|
36
|
+
KeepAliveCliChatRequest,
|
|
37
|
+
KeepAliveCliChatResponse,
|
|
33
38
|
SwitchCLIChatRequest,
|
|
34
39
|
SwitchCLIChatResponse,
|
|
35
40
|
TerminateRequest,
|
|
@@ -88,6 +93,9 @@ class SwitchCLIChat:
|
|
|
88
93
|
|
|
89
94
|
REMOTE_EXECUTION_CLIENT_EXIT_INFO = WSDisconnected | SwitchCLIChat
|
|
90
95
|
|
|
96
|
+
# UUID for a single run of the CLI
|
|
97
|
+
cli_uuid = uuid.uuid4()
|
|
98
|
+
|
|
91
99
|
|
|
92
100
|
class RemoteExecutionClient:
|
|
93
101
|
def __init__(
|
|
@@ -103,6 +111,9 @@ class RemoteExecutionClient:
|
|
|
103
111
|
self._halt_states: dict[str, bool] = {}
|
|
104
112
|
self._halt_lock = asyncio.Lock()
|
|
105
113
|
|
|
114
|
+
# Track last request time for timeout functionality
|
|
115
|
+
self._last_request_time: float | None = None
|
|
116
|
+
|
|
106
117
|
@property
|
|
107
118
|
def working_directory(self) -> str:
|
|
108
119
|
return self.current_session.working_directory
|
|
@@ -137,10 +148,35 @@ class RemoteExecutionClient:
|
|
|
137
148
|
|
|
138
149
|
return should_halt
|
|
139
150
|
|
|
151
|
+
async def _timeout_monitor(
|
|
152
|
+
self, timeout_seconds: int | None
|
|
153
|
+
) -> WSDisconnected | None:
|
|
154
|
+
"""Monitor for inactivity timeout and return WSDisconnected if timeout occurs.
|
|
155
|
+
|
|
156
|
+
If timeout_seconds is None, keeps looping indefinitely until cancelled.
|
|
157
|
+
"""
|
|
158
|
+
try:
|
|
159
|
+
while True:
|
|
160
|
+
await asyncio.sleep(1)
|
|
161
|
+
if (
|
|
162
|
+
timeout_seconds is not None
|
|
163
|
+
and self._last_request_time is not None
|
|
164
|
+
and time.time() - self._last_request_time > timeout_seconds
|
|
165
|
+
):
|
|
166
|
+
logger.info(
|
|
167
|
+
f"No requests received for {timeout_seconds} seconds. Shutting down..."
|
|
168
|
+
)
|
|
169
|
+
return WSDisconnected(
|
|
170
|
+
error_message=f"Timeout after {timeout_seconds} seconds of inactivity"
|
|
171
|
+
)
|
|
172
|
+
except asyncio.CancelledError:
|
|
173
|
+
# Handle cancellation gracefully
|
|
174
|
+
return None
|
|
175
|
+
|
|
140
176
|
async def _handle_websocket_message(
|
|
141
177
|
self,
|
|
142
178
|
msg: str,
|
|
143
|
-
websocket:
|
|
179
|
+
websocket: ClientConnection,
|
|
144
180
|
requests: asyncio.Queue[CliRpcRequest],
|
|
145
181
|
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
|
|
146
182
|
"""Handle an incoming websocket message.
|
|
@@ -240,6 +276,21 @@ class RemoteExecutionClient:
|
|
|
240
276
|
)
|
|
241
277
|
)
|
|
242
278
|
return SwitchCLIChat(new_chat_uuid=request.request.new_chat_uuid)
|
|
279
|
+
elif isinstance(request.request, KeepAliveCliChatRequest):
|
|
280
|
+
await websocket.send(
|
|
281
|
+
json.dumps(
|
|
282
|
+
{
|
|
283
|
+
"type": "result",
|
|
284
|
+
"data": msgspec.to_builtins(
|
|
285
|
+
CliRpcResponse(
|
|
286
|
+
request_id=request.request_id,
|
|
287
|
+
response=KeepAliveCliChatResponse(),
|
|
288
|
+
)
|
|
289
|
+
),
|
|
290
|
+
}
|
|
291
|
+
)
|
|
292
|
+
)
|
|
293
|
+
return None
|
|
243
294
|
else:
|
|
244
295
|
if isinstance(request.request, ToolExecutionRequest) and isinstance(
|
|
245
296
|
request.request.tool_input, BashToolInput
|
|
@@ -326,7 +377,7 @@ class RemoteExecutionClient:
|
|
|
326
377
|
|
|
327
378
|
async def _process_websocket_messages(
|
|
328
379
|
self,
|
|
329
|
-
websocket:
|
|
380
|
+
websocket: ClientConnection,
|
|
330
381
|
beats: asyncio.Queue[HeartbeatInfo],
|
|
331
382
|
requests: asyncio.Queue[CliRpcRequest],
|
|
332
383
|
results: asyncio.Queue[CliRpcResponse],
|
|
@@ -379,7 +430,7 @@ class RemoteExecutionClient:
|
|
|
379
430
|
|
|
380
431
|
async def _handle_websocket_connection(
|
|
381
432
|
self,
|
|
382
|
-
websocket:
|
|
433
|
+
websocket: ClientConnection,
|
|
383
434
|
connection_tracker: ConnectionTracker | None,
|
|
384
435
|
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO | None:
|
|
385
436
|
"""Handle a single websocket connection.
|
|
@@ -426,17 +477,37 @@ class RemoteExecutionClient:
|
|
|
426
477
|
self,
|
|
427
478
|
chat_uuid: str,
|
|
428
479
|
connection_tracker: ConnectionTracker | None = None,
|
|
480
|
+
timeout_seconds: int | None = None,
|
|
429
481
|
) -> REMOTE_EXECUTION_CLIENT_EXIT_INFO:
|
|
430
|
-
"""Run the websocket connection loop."""
|
|
482
|
+
"""Run the websocket connection loop with optional inactivity timeout."""
|
|
431
483
|
self.current_session.set_chat_uuid(chat_uuid)
|
|
432
484
|
|
|
485
|
+
# Initialize last request time for timeout monitoring
|
|
486
|
+
self._last_request_time = time.time()
|
|
487
|
+
|
|
433
488
|
async for websocket in self.ws_connect(f"/api/ws/chat/{chat_uuid}"):
|
|
434
|
-
|
|
435
|
-
|
|
489
|
+
# Always run connection and timeout monitor concurrently
|
|
490
|
+
# If timeout_seconds is None, timeout monitor will loop indefinitely
|
|
491
|
+
done, pending = await asyncio.wait(
|
|
492
|
+
[
|
|
493
|
+
asyncio.create_task(
|
|
494
|
+
self._handle_websocket_connection(websocket, connection_tracker)
|
|
495
|
+
),
|
|
496
|
+
asyncio.create_task(self._timeout_monitor(timeout_seconds)),
|
|
497
|
+
],
|
|
498
|
+
return_when=asyncio.FIRST_COMPLETED,
|
|
436
499
|
)
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
500
|
+
|
|
501
|
+
# Cancel pending tasks
|
|
502
|
+
for task in pending:
|
|
503
|
+
task.cancel()
|
|
504
|
+
|
|
505
|
+
# Return result from completed task
|
|
506
|
+
for task in done:
|
|
507
|
+
result = await task
|
|
508
|
+
# If we get None, we'll try to reconnect
|
|
509
|
+
if result is not None:
|
|
510
|
+
return result
|
|
440
511
|
|
|
441
512
|
# If we exit the websocket connection loop without returning,
|
|
442
513
|
# it means we couldn't establish a connection
|
|
@@ -489,6 +560,7 @@ class RemoteExecutionClient:
|
|
|
489
560
|
system_info=await system_context.get_system_info(self.working_directory),
|
|
490
561
|
exponent_version=get_installed_version(),
|
|
491
562
|
editable_installation=is_editable_install(),
|
|
563
|
+
cli_uuid=str(cli_uuid),
|
|
492
564
|
)
|
|
493
565
|
|
|
494
566
|
async def send_heartbeat(self, chat_uuid: str) -> CLIConnectedState:
|
|
@@ -508,6 +580,9 @@ class RemoteExecutionClient:
|
|
|
508
580
|
return connected_state
|
|
509
581
|
|
|
510
582
|
async def handle_request(self, request: CliRpcRequest) -> CliRpcResponse:
|
|
583
|
+
# Update last request time for timeout functionality
|
|
584
|
+
self._last_request_time = time.time()
|
|
585
|
+
|
|
511
586
|
try:
|
|
512
587
|
if isinstance(request.request, ToolExecutionRequest):
|
|
513
588
|
if isinstance(request.request.tool_input, BashToolInput):
|
|
@@ -578,6 +653,10 @@ class RemoteExecutionClient:
|
|
|
578
653
|
raise ValueError(
|
|
579
654
|
"SwitchCLIChatRequest should not be handled by handle_request"
|
|
580
655
|
)
|
|
656
|
+
elif isinstance(request.request, KeepAliveCliChatRequest):
|
|
657
|
+
raise ValueError(
|
|
658
|
+
"KeepAliveCliChatRequest should not be handled by handle_request"
|
|
659
|
+
)
|
|
581
660
|
|
|
582
661
|
raise ValueError(f"Unhandled request type: {type(request)}")
|
|
583
662
|
|
|
@@ -611,7 +690,7 @@ class RemoteExecutionClient:
|
|
|
611
690
|
):
|
|
612
691
|
yield output
|
|
613
692
|
|
|
614
|
-
def ws_connect(self, path: str) ->
|
|
693
|
+
def ws_connect(self, path: str) -> connect:
|
|
615
694
|
base_url = (
|
|
616
695
|
str(self.ws_client.base_url)
|
|
617
696
|
.replace("http://", "ws://")
|
|
@@ -621,14 +700,24 @@ class RemoteExecutionClient:
|
|
|
621
700
|
url = f"{base_url}{path}"
|
|
622
701
|
headers = {"api-key": self.api_client.headers["api-key"]}
|
|
623
702
|
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
)
|
|
703
|
+
def custom_backoff() -> Generator[float, None, None]:
|
|
704
|
+
yield 0.1 # short initial delay
|
|
627
705
|
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
706
|
+
delay = 0.5
|
|
707
|
+
while True:
|
|
708
|
+
if delay < 2.0:
|
|
709
|
+
yield delay
|
|
710
|
+
delay *= 1.5
|
|
711
|
+
else:
|
|
712
|
+
yield 2.0
|
|
713
|
+
|
|
714
|
+
# Can remove if this is added to public API
|
|
715
|
+
# https://github.com/python-websockets/websockets/issues/1395#issuecomment-3225670409
|
|
716
|
+
asyncio_websockets_client.backoff = custom_backoff # type: ignore[attr-defined, assignment]
|
|
717
|
+
|
|
718
|
+
conn = connect(
|
|
719
|
+
url, additional_headers=headers, open_timeout=10, ping_timeout=10
|
|
720
|
+
)
|
|
632
721
|
|
|
633
722
|
return conn
|
|
634
723
|
|