indent 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of indent might be problematic. Click here for more details.
- exponent/__init__.py +1 -0
- exponent/cli.py +112 -0
- exponent/commands/cloud_commands.py +85 -0
- exponent/commands/common.py +434 -0
- exponent/commands/config_commands.py +581 -0
- exponent/commands/github_app_commands.py +211 -0
- exponent/commands/listen_commands.py +96 -0
- exponent/commands/run_commands.py +208 -0
- exponent/commands/settings.py +56 -0
- exponent/commands/shell_commands.py +2840 -0
- exponent/commands/theme.py +246 -0
- exponent/commands/types.py +111 -0
- exponent/commands/upgrade.py +29 -0
- exponent/commands/utils.py +236 -0
- exponent/core/config.py +180 -0
- exponent/core/graphql/__init__.py +0 -0
- exponent/core/graphql/client.py +59 -0
- exponent/core/graphql/cloud_config_queries.py +77 -0
- exponent/core/graphql/get_chats_query.py +47 -0
- exponent/core/graphql/github_config_queries.py +56 -0
- exponent/core/graphql/mutations.py +75 -0
- exponent/core/graphql/queries.py +110 -0
- exponent/core/graphql/subscriptions.py +452 -0
- exponent/core/remote_execution/checkpoints.py +212 -0
- exponent/core/remote_execution/cli_rpc_types.py +214 -0
- exponent/core/remote_execution/client.py +545 -0
- exponent/core/remote_execution/code_execution.py +58 -0
- exponent/core/remote_execution/command_execution.py +105 -0
- exponent/core/remote_execution/error_info.py +45 -0
- exponent/core/remote_execution/exceptions.py +10 -0
- exponent/core/remote_execution/file_write.py +410 -0
- exponent/core/remote_execution/files.py +415 -0
- exponent/core/remote_execution/git.py +268 -0
- exponent/core/remote_execution/languages/python_execution.py +239 -0
- exponent/core/remote_execution/languages/shell_streaming.py +221 -0
- exponent/core/remote_execution/languages/types.py +20 -0
- exponent/core/remote_execution/session.py +128 -0
- exponent/core/remote_execution/system_context.py +54 -0
- exponent/core/remote_execution/tool_execution.py +289 -0
- exponent/core/remote_execution/truncation.py +284 -0
- exponent/core/remote_execution/types.py +670 -0
- exponent/core/remote_execution/utils.py +600 -0
- exponent/core/types/__init__.py +0 -0
- exponent/core/types/command_data.py +206 -0
- exponent/core/types/event_types.py +89 -0
- exponent/core/types/generated/__init__.py +0 -0
- exponent/core/types/generated/strategy_info.py +225 -0
- exponent/migration-docs/login.md +112 -0
- exponent/py.typed +4 -0
- exponent/utils/__init__.py +0 -0
- exponent/utils/colors.py +92 -0
- exponent/utils/version.py +289 -0
- indent-0.0.8.dist-info/METADATA +36 -0
- indent-0.0.8.dist-info/RECORD +56 -0
- indent-0.0.8.dist-info/WHEEL +4 -0
- indent-0.0.8.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,600 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from collections.abc import Awaitable, Callable
|
|
4
|
+
from functools import wraps
|
|
5
|
+
from typing import (
|
|
6
|
+
Any,
|
|
7
|
+
NoReturn,
|
|
8
|
+
TypeVar,
|
|
9
|
+
cast,
|
|
10
|
+
overload,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
import websockets
|
|
14
|
+
import websockets.exceptions
|
|
15
|
+
from anyio import Path as AsyncPath
|
|
16
|
+
from bs4 import UnicodeDammit
|
|
17
|
+
from httpx import Response
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from sentry_sdk.serializer import serialize
|
|
20
|
+
from sentry_sdk.utils import (
|
|
21
|
+
event_from_exception,
|
|
22
|
+
exc_info_from_error,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from exponent.core.remote_execution.types import (
|
|
26
|
+
SUPPORTED_LANGUAGES,
|
|
27
|
+
CLIErrorLog,
|
|
28
|
+
CodeExecutionRequest,
|
|
29
|
+
CodeExecutionResponse,
|
|
30
|
+
CommandRequest,
|
|
31
|
+
CommandResponse,
|
|
32
|
+
CreateCheckpointRequest,
|
|
33
|
+
CreateCheckpointResponse,
|
|
34
|
+
ErrorResponse,
|
|
35
|
+
FilePath,
|
|
36
|
+
FileWriteRequest,
|
|
37
|
+
FileWriteResponse,
|
|
38
|
+
GetAllTrackedFilesRequest,
|
|
39
|
+
GetAllTrackedFilesResponse,
|
|
40
|
+
GetFileAttachmentRequest,
|
|
41
|
+
GetFileAttachmentResponse,
|
|
42
|
+
GetFileAttachmentsRequest,
|
|
43
|
+
GetFileAttachmentsResponse,
|
|
44
|
+
GetMatchingFilesRequest,
|
|
45
|
+
GetMatchingFilesResponse,
|
|
46
|
+
HaltRequest,
|
|
47
|
+
HaltResponse,
|
|
48
|
+
ListFilesRequest,
|
|
49
|
+
ListFilesResponse,
|
|
50
|
+
RemoteExecutionMessage,
|
|
51
|
+
RemoteExecutionMessageData,
|
|
52
|
+
RemoteExecutionRequest,
|
|
53
|
+
RemoteExecutionRequestData,
|
|
54
|
+
RemoteExecutionRequestType,
|
|
55
|
+
RemoteExecutionResponse,
|
|
56
|
+
RemoteExecutionResponseData,
|
|
57
|
+
RemoteExecutionResponseType,
|
|
58
|
+
RollbackToCheckpointRequest,
|
|
59
|
+
RollbackToCheckpointResponse,
|
|
60
|
+
StreamingCodeExecutionRequest,
|
|
61
|
+
StreamingCodeExecutionResponse,
|
|
62
|
+
StreamingCodeExecutionResponseChunk,
|
|
63
|
+
SupportedLanguage,
|
|
64
|
+
SwitchCLIChatRequest,
|
|
65
|
+
SwitchCLIChatResponse,
|
|
66
|
+
SystemContextRequest,
|
|
67
|
+
SystemContextResponse,
|
|
68
|
+
)
|
|
69
|
+
from exponent.core.types.command_data import NaturalEditContent
|
|
70
|
+
from exponent.core.types.event_types import (
|
|
71
|
+
CodeBlockEvent,
|
|
72
|
+
CommandEvent,
|
|
73
|
+
FileWriteEvent,
|
|
74
|
+
LocalEventType,
|
|
75
|
+
)
|
|
76
|
+
from exponent.utils.version import get_installed_version
|
|
77
|
+
|
|
78
|
+
### Serde
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def deserialize_message_data(
    message_data: RemoteExecutionMessageData | str,
) -> RemoteExecutionMessage:
    """Parse a wire envelope into its concrete request or response model.

    Args:
        message_data: A ``RemoteExecutionMessageData`` envelope, or its JSON
            string form.

    Returns:
        The result of ``deserialize_request_data`` or
        ``deserialize_response_data``, chosen by the envelope's ``direction``.
    """
    envelope = (
        RemoteExecutionMessageData.model_validate_json(message_data)
        if isinstance(message_data, str)
        else message_data
    )
    direction = envelope.direction
    if direction == "request":
        return deserialize_request_data(cast(RemoteExecutionRequestData, envelope))
    if direction == "response":
        return deserialize_response_data(
            cast(RemoteExecutionResponseData, envelope)
        )
    # Exhaustiveness check: a type checker rejects this call if ``direction``
    # can still hold a value that was not handled above.
    assert_unreachable(direction)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def deserialize_request_data(
    request_data: RemoteExecutionRequestData | str,
) -> RemoteExecutionRequestType:
    """Deserialize a request envelope into its concrete request model.

    Args:
        request_data: A ``RemoteExecutionRequestData`` envelope, or its JSON
            string form.

    Returns:
        The request model selected by the envelope's ``namespace``, passed
        through ``truncate_message``.

    Raises:
        ValueError: If the envelope's direction is not ``"request"``, or its
            namespace only exists for responses.
    """
    request: RemoteExecutionRequestType
    if isinstance(request_data, str):
        request_data = RemoteExecutionRequestData.model_validate_json(request_data)
    if request_data.direction != "request":
        raise ValueError(f"Expected request, but got {request_data.direction}")

    namespace = request_data.namespace
    raw = request_data.message_data
    if namespace == "code_execution":
        request = CodeExecutionRequest.model_validate_json(raw)
    elif namespace == "file_write":
        request = FileWriteRequest.model_validate_json(raw)
    elif namespace == "list_files":
        request = ListFilesRequest.model_validate_json(raw)
    elif namespace == "get_file_attachment":
        request = GetFileAttachmentRequest.model_validate_json(raw)
    elif namespace == "get_file_attachments":
        request = GetFileAttachmentsRequest.model_validate_json(raw)
    elif namespace == "get_matching_files":
        request = GetMatchingFilesRequest.model_validate_json(raw)
    elif namespace == "system_context":
        request = SystemContextRequest.model_validate_json(raw)
    elif namespace == "get_all_tracked_files":
        request = GetAllTrackedFilesRequest.model_validate_json(raw)
    elif namespace == "command":
        request = CommandRequest.model_validate_json(raw)
    elif namespace == "halt":
        request = HaltRequest.model_validate_json(raw)
    elif namespace == "streaming_code_execution":
        request = StreamingCodeExecutionRequest.model_validate_json(raw)
    elif namespace == "switch_cli_chat":
        request = SwitchCLIChatRequest.model_validate_json(raw)
    elif namespace == "streaming_code_execution_chunk":
        # Explicit raise rather than ``assert False``: asserts are stripped
        # under ``python -O``. Execution would then fall through to
        # ``truncate_message(request)`` with ``request`` unbound.
        raise ValueError("Streaming code execution chunk is a response, not a request")
    elif namespace == "error":
        raise ValueError("Error is a response, not a request")
    elif namespace == "create_checkpoint":
        request = CreateCheckpointRequest.model_validate_json(raw)
    elif namespace == "rollback_to_checkpoint":
        request = RollbackToCheckpointRequest.model_validate_json(raw)
    else:
        # Type checking trick: if a namespace is missed above, this call
        # won't typecheck because ``namespace`` is not narrowed to NoReturn.
        request = assert_unreachable(namespace)
    return truncate_message(request)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def deserialize_response_data(
|
|
158
|
+
response_data: RemoteExecutionResponseData | str,
|
|
159
|
+
) -> RemoteExecutionResponseType:
|
|
160
|
+
response: RemoteExecutionResponseType
|
|
161
|
+
if isinstance(response_data, str):
|
|
162
|
+
response_data = RemoteExecutionResponseData.model_validate_json(response_data)
|
|
163
|
+
if response_data.direction != "response":
|
|
164
|
+
raise ValueError(f"Expected response, but got {response_data.direction}")
|
|
165
|
+
if response_data.namespace == "code_execution":
|
|
166
|
+
response = CodeExecutionResponse.model_validate_json(response_data.message_data)
|
|
167
|
+
elif response_data.namespace == "streaming_code_execution":
|
|
168
|
+
response = StreamingCodeExecutionResponse.model_validate_json(
|
|
169
|
+
response_data.message_data
|
|
170
|
+
)
|
|
171
|
+
elif response_data.namespace == "streaming_code_execution_chunk":
|
|
172
|
+
response = StreamingCodeExecutionResponseChunk.model_validate_json(
|
|
173
|
+
response_data.message_data
|
|
174
|
+
)
|
|
175
|
+
elif response_data.namespace == "file_write":
|
|
176
|
+
response = FileWriteResponse.model_validate_json(response_data.message_data)
|
|
177
|
+
elif response_data.namespace == "list_files":
|
|
178
|
+
response = ListFilesResponse.model_validate_json(response_data.message_data)
|
|
179
|
+
elif response_data.namespace == "get_matching_files":
|
|
180
|
+
response = GetMatchingFilesResponse.model_validate_json(
|
|
181
|
+
response_data.message_data
|
|
182
|
+
)
|
|
183
|
+
elif response_data.namespace == "get_file_attachment":
|
|
184
|
+
response = GetFileAttachmentResponse.model_validate_json(
|
|
185
|
+
response_data.message_data
|
|
186
|
+
)
|
|
187
|
+
elif response_data.namespace == "get_file_attachments":
|
|
188
|
+
response = GetFileAttachmentsResponse.model_validate_json(
|
|
189
|
+
response_data.message_data
|
|
190
|
+
)
|
|
191
|
+
elif response_data.namespace == "system_context":
|
|
192
|
+
response = SystemContextResponse.model_validate_json(response_data.message_data)
|
|
193
|
+
elif response_data.namespace == "get_all_tracked_files":
|
|
194
|
+
response = GetAllTrackedFilesResponse.model_validate_json(
|
|
195
|
+
response_data.message_data
|
|
196
|
+
)
|
|
197
|
+
elif response_data.namespace == "command":
|
|
198
|
+
response = CommandResponse.model_validate_json(response_data.message_data)
|
|
199
|
+
elif response_data.namespace == "halt":
|
|
200
|
+
response = HaltResponse.model_validate_json(response_data.message_data)
|
|
201
|
+
elif response_data.namespace == "switch_cli_chat":
|
|
202
|
+
response = SwitchCLIChatResponse.model_validate_json(response_data.message_data)
|
|
203
|
+
elif response_data.namespace == "error":
|
|
204
|
+
response = ErrorResponse.model_validate_json(response_data.message_data)
|
|
205
|
+
elif response_data.namespace == "create_checkpoint":
|
|
206
|
+
response = CreateCheckpointResponse.model_validate_json(
|
|
207
|
+
response_data.message_data
|
|
208
|
+
)
|
|
209
|
+
elif response_data.namespace == "rollback_to_checkpoint":
|
|
210
|
+
response = RollbackToCheckpointResponse.model_validate_json(
|
|
211
|
+
response_data.message_data
|
|
212
|
+
)
|
|
213
|
+
else:
|
|
214
|
+
# type checking trick, if you miss a namespace then
|
|
215
|
+
# this won't typecheck due to the input parameter
|
|
216
|
+
# having a potential type other than no-return
|
|
217
|
+
response = assert_unreachable(response_data.namespace)
|
|
218
|
+
return truncate_message(response)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def serialize_message(response: RemoteExecutionMessage) -> str:
    """Encode a request/response model as a JSON wire envelope.

    The message is passed through ``truncate_message`` first (which may modify
    it in place). Its JSON dump is then embedded as the envelope's
    ``message_data``, alongside its ``namespace`` and ``direction``.
    """
    payload = truncate_message(response).model_dump_json()
    envelope = RemoteExecutionMessageData(
        namespace=response.namespace,
        direction=response.direction,
        message_data=payload,
    )
    return envelope.model_dump_json()
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
### API Serde
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# Any pydantic model class a JSON API response can be validated into.
TModel = TypeVar("TModel", bound=BaseModel)


async def deserialize_api_response(
    response: Response,
    data_model: type[TModel],
) -> TModel:
    """Validate a successful HTTP response's JSON body against ``data_model``.

    Args:
        response: The httpx response to decode.
        data_model: Pydantic model class the JSON body is validated into.

    Returns:
        An instance of ``data_model`` built from ``response.json()``.

    Raises:
        ValueError: If the response has an error status. The message is the
            body's ``detail`` field when available, otherwise the raw body
            text, followed by the status code in parentheses.
    """
    if response.is_error:
        # NOTE(review): prints the raw error body to stdout before raising;
        # looks like leftover debugging. Consider routing through logging.
        print(response.text)
        try:
            error_message = response.json()["detail"]
        except (ValueError, KeyError, TypeError):
            # ValueError: body is not valid JSON (json.JSONDecodeError is a
            # subclass). KeyError: no "detail" key. TypeError: the JSON body
            # is not an object (e.g. a list or a bare string).
            error_message = response.text
        raise ValueError(f"{error_message} ({response.status_code})")

    return data_model.model_validate(response.json())
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def get_file_write_content(event: FileWriteEvent) -> str:
    """Return the full file text that a ``FileWriteEvent`` should write.

    For ``NaturalEditContent`` writes this is ``new_file``; for any other
    write content it is ``content``.

    Raises:
        ValueError: If a ``NaturalEditContent`` write has no ``new_file``.
    """
    write_content = event.write_content
    if isinstance(write_content, NaturalEditContent):
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let ``None`` escape a ``-> str`` function.
        if write_content.new_file is None:
            raise ValueError("NaturalEditContent write has no new_file content")
        return write_content.new_file
    return write_content.content
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
@overload
|
|
263
|
+
def convert_event_to_execution_request(
|
|
264
|
+
request: CodeBlockEvent,
|
|
265
|
+
) -> CodeExecutionRequest: ...
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@overload
|
|
269
|
+
def convert_event_to_execution_request(
|
|
270
|
+
request: FileWriteEvent,
|
|
271
|
+
) -> FileWriteRequest: ...
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
@overload
|
|
275
|
+
def convert_event_to_execution_request(
|
|
276
|
+
request: CommandEvent,
|
|
277
|
+
) -> CommandRequest: ...
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def convert_event_to_execution_request(
|
|
281
|
+
request: LocalEventType,
|
|
282
|
+
) -> CodeExecutionRequest | FileWriteRequest | CommandRequest:
|
|
283
|
+
if isinstance(request, CodeBlockEvent):
|
|
284
|
+
language = assert_supported_language(request.language)
|
|
285
|
+
|
|
286
|
+
return CodeExecutionRequest(
|
|
287
|
+
language=language,
|
|
288
|
+
content=request.content,
|
|
289
|
+
timeout=request.timeout,
|
|
290
|
+
correlation_id=request.event_uuid,
|
|
291
|
+
)
|
|
292
|
+
elif isinstance(request, FileWriteEvent):
|
|
293
|
+
return FileWriteRequest(
|
|
294
|
+
file_path=request.file_path,
|
|
295
|
+
language=request.language,
|
|
296
|
+
write_strategy=request.write_strategy,
|
|
297
|
+
content=get_file_write_content(request),
|
|
298
|
+
correlation_id=request.event_uuid,
|
|
299
|
+
)
|
|
300
|
+
elif isinstance(request, CommandEvent):
|
|
301
|
+
return CommandRequest(
|
|
302
|
+
data=request.data,
|
|
303
|
+
correlation_id=request.event_uuid,
|
|
304
|
+
)
|
|
305
|
+
else:
|
|
306
|
+
assert_unreachable(request)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
### Validation
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
# The concrete response model a given request type expects back.
ResponseT = TypeVar("ResponseT", bound=RemoteExecutionResponse)


def assert_valid_response_type(
    response: RemoteExecutionResponseType, request: RemoteExecutionRequest[ResponseT]
) -> ResponseT | ErrorResponse:
    """Check that ``response`` answers ``request`` and narrow its type.

    An ``ErrorResponse`` is passed through unchanged. Otherwise the response
    must share the request's namespace and have direction ``"response"``.

    Raises:
        ValueError: If the namespace or direction does not match.
    """
    if isinstance(response, ErrorResponse):
        return response
    if request.namespace == response.namespace and response.direction == "response":
        return cast(ResponseT, response)
    raise ValueError(
        f"Expected {request.namespace}.response, but got {response.namespace}.{response.direction}"
    )
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def assert_unreachable(x: NoReturn) -> NoReturn:
    """Fail loudly if a branch the type checker considers impossible runs.

    Because ``x`` is annotated ``NoReturn``, a static type checker rejects any
    call site where ``x`` could still hold a value. Callers use this to get
    exhaustiveness checking over literal/union types.

    Raises:
        AssertionError: Always, naming the runtime type of ``x``.
    """
    # ``raise`` rather than ``assert False``: asserts are stripped under
    # ``python -O``, which would make this return None despite ``NoReturn``.
    raise AssertionError(f"Unhandled type: {type(x).__name__}")
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def assert_supported_language(language: str) -> SupportedLanguage:
    """Narrow ``language`` to ``SupportedLanguage``, or reject it.

    Raises:
        ValueError: If ``language`` is not in ``SUPPORTED_LANGUAGES``.
    """
    if language in SUPPORTED_LANGUAGES:
        return cast(SupportedLanguage, language)
    raise ValueError(f"Unsupported language: {language}")
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
### Truncation
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
# Default character budget for truncate_output.
# NOTE(review): the original comment said "a tad over ~8k tokens"; at a
# typical ~4 characters/token, 90k characters is closer to ~22k tokens.
# Confirm the intended budget.
OUTPUT_CHARACTER_MAX = 90_000
TRUNCATION_MESSAGE_CHARS = (
    "(Output truncated, only showing the first {remaining_chars} characters)"
)
TRUNCATION_MESSAGE_LINES = (
    "(Output truncated, only showing the first {remaining_lines} lines)"
)
# Length of the longest character-truncation notice, plus 1 for a separating
# newline. truncate_output below does not use it.
LONGEST_TRUNCATION_MESSAGE_LEN = (
    len(TRUNCATION_MESSAGE_CHARS.format(remaining_chars=OUTPUT_CHARACTER_MAX)) + 1
)

# Default line budget for truncate_output.
MAX_LINES = 10_000


def truncate_output(
    output: str,
    character_limit: int = OUTPUT_CHARACTER_MAX,
    line_limit: int = MAX_LINES,
) -> tuple[str, bool]:
    """Trim ``output`` to at most ``character_limit`` chars and ``line_limit`` lines.

    Whole trailing lines are dropped until the text fits the character
    budget. If even the first line is too long on its own, that line is cut
    to ``character_limit`` characters. Lines beyond ``line_limit`` are then
    dropped. No truncation notice is appended; callers learn about truncation
    from the returned flag.

    Args:
        output: Text to truncate.
        character_limit: Maximum number of characters to keep. Negative
            values are treated as 0.
        line_limit: Maximum number of lines to keep.

    Returns:
        ``(text, truncated)``. ``truncated`` is True exactly when something
        was removed.
    """
    # NOTE(review): an earlier comment here said the limit is padded by the
    # truncation-message length (see LONGEST_TRUNCATION_MESSAGE_LEN) to avoid
    # double truncation when both client and server truncate. No such padding
    # is applied; confirm which behavior is intended.
    character_limit = max(character_limit, 0)
    output_length = len(output)
    lines = output.split("\n")

    if output_length <= character_limit and len(lines) <= line_limit:
        return output, False

    # Drop whole lines from the end until the remainder fits the budget.
    # The ``lines`` guard keeps pop() from ever running on an empty list.
    while lines and output_length > character_limit:
        # +1 for the newline that separated this line from the previous one.
        output_length -= len(lines.pop()) + 1

    if not lines:
        # Only reached when the first line alone exceeds the budget: keep a
        # character-level prefix so that something is retained.
        output = output[:character_limit]
    else:
        output = "\n".join(lines[:line_limit])

    return output, True
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _truncate_content_field(item: Any) -> None:
    """Truncate ``item.content`` in place and set ``item.truncated`` if cut.

    ``item`` is any model with mutable ``content`` and ``truncated``
    attributes (message models and individual file attachments alike).
    """
    trimmed, was_cut = truncate_output(item.content)
    item.content = trimmed
    if was_cut:
        item.truncated = True


@overload
def truncate_message(response: CodeExecutionRequest) -> CodeExecutionRequest: ...
@overload
def truncate_message(response: CodeExecutionResponse) -> CodeExecutionResponse: ...
@overload
def truncate_message(
    response: StreamingCodeExecutionResponse,
) -> StreamingCodeExecutionResponse: ...
@overload
def truncate_message(
    response: StreamingCodeExecutionResponseChunk,
) -> StreamingCodeExecutionResponseChunk: ...
@overload
def truncate_message(response: FileWriteRequest) -> FileWriteRequest: ...
@overload
def truncate_message(response: FileWriteResponse) -> FileWriteResponse: ...
@overload
def truncate_message(
    response: GetFileAttachmentRequest,
) -> GetFileAttachmentRequest: ...
@overload
def truncate_message(
    response: GetFileAttachmentResponse,
) -> GetFileAttachmentResponse: ...
@overload
def truncate_message(response: ListFilesRequest) -> ListFilesRequest: ...
@overload
def truncate_message(response: ListFilesResponse) -> ListFilesResponse: ...
@overload
def truncate_message(response: GetMatchingFilesRequest) -> GetMatchingFilesRequest: ...
@overload
def truncate_message(
    response: GetMatchingFilesResponse,
) -> GetMatchingFilesResponse: ...
@overload
def truncate_message(response: SystemContextRequest) -> SystemContextRequest: ...
@overload
def truncate_message(response: SystemContextResponse) -> SystemContextResponse: ...
@overload
def truncate_message(
    response: RemoteExecutionRequestType,
) -> RemoteExecutionRequestType: ...
@overload
def truncate_message(
    response: RemoteExecutionResponseType,
) -> RemoteExecutionResponseType: ...
@overload
def truncate_message(response: RemoteExecutionMessage) -> RemoteExecutionMessage: ...


def truncate_message(
    response: RemoteExecutionMessage,
) -> RemoteExecutionMessage:
    """Apply ``truncate_output`` to a message's text payload, in place.

    Handled cases:

    * code execution responses (plain and streaming, including chunks) and
      single-file-attachment responses: ``content`` is truncated;
    * command responses, except the ``codebase_context`` subcommand:
      ``content`` is truncated;
    * multi-file-attachment responses: each attachment's ``content`` is
      truncated.

    ``truncated`` is set to True on whatever was cut. All other message types
    pass through untouched. The same object is returned.
    """
    truncates_own_content = isinstance(
        response,
        (
            CodeExecutionResponse,
            GetFileAttachmentResponse,
            StreamingCodeExecutionResponse,
            StreamingCodeExecutionResponseChunk,
        ),
    ) or (
        isinstance(response, CommandResponse)
        and response.subcommand != "codebase_context"
    )

    if truncates_own_content:
        _truncate_content_field(response)
    elif isinstance(response, GetFileAttachmentsResponse):
        for attachment in response.file_attachments:
            _truncate_content_field(attachment)
    return response
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
### Error Handling
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
def format_attachment_data(
|
|
480
|
+
attachment_lines: list[str] | None = None,
|
|
481
|
+
) -> str | None:
|
|
482
|
+
if not attachment_lines:
|
|
483
|
+
return None
|
|
484
|
+
log_attachment_str = "\n".join(attachment_lines)
|
|
485
|
+
return log_attachment_str
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def format_error_log(
    exc: Exception,
    chat_uuid: str | None = None,
    attachment_lines: list[str] | None = None,
) -> CLIErrorLog | None:
    """Build a ``CLIErrorLog`` describing ``exc``.

    The exception is converted to a Sentry event (``event_from_exception``),
    passed through Sentry's ``serialize`` and JSON-encoded.

    Args:
        exc: The exception to report.
        chat_uuid: Optional chat identifier attached to the log.
        attachment_lines: Optional extra lines, joined with newlines into
            ``attachment_data``.

    Returns:
        The populated ``CLIErrorLog``, or None if the serialized event cannot
        be JSON-encoded.
    """
    exc_info = exc_info_from_error(exc)
    event, _ = event_from_exception(exc_info)
    attachment_data = format_attachment_data(attachment_lines)
    version = get_installed_version()

    try:
        event_data = json.dumps(serialize(event))  # type: ignore
    except (TypeError, ValueError):
        # json.dumps raises TypeError for unserializable objects and
        # ValueError (e.g. circular references). It never raises
        # JSONDecodeError, so the previous handler could not fire. Since
        # JSONDecodeError subclasses ValueError, it is still covered.
        return None

    return CLIErrorLog(
        event_data=event_data,
        attachment_data=attachment_data,
        version=version,
        chat_uuid=chat_uuid,
    )
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
### Websockets
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
ws_logger = logging.getLogger("WebsocketUtils")


def ws_retry(
    connection_name: str,
    max_retries: int = 5,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Build a decorator that re-runs a websocket coroutine after disconnects.

    The wrapped coroutine is re-invoked with the same arguments whenever it
    raises ``websockets.exceptions.ConnectionClosed`` or a timeout error. At
    most ``max_retries`` re-invocations happen per call of the wrapper. The
    count is cumulative across re-invocations and is never reset. Any other
    exception propagates immediately.

    Args:
        connection_name: Label used in log messages (passed through
            ``str.capitalize``).
        max_retries: Number of re-invocations allowed before giving up.

    Returns:
        A decorator for ``async`` functions returning None.

    Raises:
        The last ``ConnectionClosed`` / timeout error, once retries are
        exhausted.
    """
    # Local import: the module does not import asyncio at top level, and it
    # is only needed for the timeout alias below.
    import asyncio

    connection_name = connection_name.capitalize()
    reconnect_msg = f"{connection_name} reconnecting."
    disconnect_msg = f"{connection_name} connection closed."
    # The initial call plus each of the max_retries re-invocations has failed
    # by the time we give up, i.e. max_retries + 1 closures in total.
    max_disconnect_msg = (
        f"{connection_name} connection closed {max_retries + 1} times, exiting."
    )

    def decorator(
        f: Callable[..., Awaitable[None]],
    ) -> Callable[..., Awaitable[None]]:
        @wraps(f)
        async def wrapped(*args: Any, **kwargs: Any) -> None:
            retries = 0

            while True:
                try:
                    return await f(*args, **kwargs)
                # asyncio.TimeoutError is a distinct class from the builtin
                # TimeoutError on Python 3.10; they were unified in 3.11.
                except (
                    websockets.exceptions.ConnectionClosed,
                    TimeoutError,
                    asyncio.TimeoutError,
                ):
                    ws_logger.warning(disconnect_msg)

                    if retries >= max_retries:
                        # Out of retries: log an error and re-raise the
                        # current exception with its original traceback.
                        ws_logger.error(max_disconnect_msg)
                        raise

                    retries += 1
                    ws_logger.warning(reconnect_msg)

        return wrapped

    return decorator
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
async def safe_read_file(path: FilePath) -> str:
    """Read a text file, tolerating non-UTF-8 and mixed encodings.

    A strict UTF-8 read is tried first. On ``UnicodeDecodeError`` the raw
    bytes are handed to ``smart_decode``, which repairs cp1252-in-UTF-8
    mixtures and detects the encoding. If ``smart_decode`` returns None, the
    original ``UnicodeDecodeError`` is re-raised.
    """
    # Separate name so the FilePath parameter isn't rebound to another type.
    async_path = AsyncPath(path)

    try:
        return await async_path.read_text(encoding="utf-8")
    except UnicodeDecodeError:
        # Potentially a wacky encoding or mixture of encodings. smart_decode
        # already runs UnicodeDammit.detwingle, so pass the bytes through as
        # read instead of detwingling them twice.
        decode_result = smart_decode(await async_path.read_bytes())

        if decode_result:
            decoded_text, _encoding = decode_result
            return decoded_text

        raise
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
async def safe_write_file(path: FilePath, content: str) -> None:
    """Write ``content`` to ``path`` as UTF-8 text using anyio's async Path."""
    target = AsyncPath(path)
    await target.write_text(content, encoding="utf-8")
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def smart_decode(b: bytes) -> tuple[str, str] | None:
    """Decode bytes of unknown encoding.

    First repairs UTF-8 text that has cp1252 (Windows) bytes mixed into it,
    using bs4's ``UnicodeDammit.detwingle``. This can happen on Windows when
    a user pastes the special Windows smart quotes into a UTF-8 file.
    ``UnicodeDammit`` then picks an encoding, trying utf-8 and cp1252 first.

    Returns:
        ``(decoded_str, detected_encoding)`` on success. None when no
        encoding is detected, or when the bytes do not strictly decode with
        the detected one.
    """
    b = UnicodeDammit.detwingle(b)

    encoding = UnicodeDammit(
        b, known_definite_encodings=["utf-8", "cp1252"]
    ).original_encoding

    if not encoding:
        return None

    try:
        return (b.decode(encoding=encoding), encoding)
    except UnicodeDecodeError:
        # NOTE(review): UnicodeDammit may fall back to decoding with
        # replacement characters and still report that encoding, in which
        # case this strict decode fails. Honour the Optional return contract
        # instead of raising.
        return None
|
|
File without changes
|