indent 0.1.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- exponent/__init__.py +34 -0
- exponent/cli.py +110 -0
- exponent/commands/cloud_commands.py +585 -0
- exponent/commands/common.py +411 -0
- exponent/commands/config_commands.py +334 -0
- exponent/commands/run_commands.py +222 -0
- exponent/commands/settings.py +56 -0
- exponent/commands/types.py +111 -0
- exponent/commands/upgrade.py +29 -0
- exponent/commands/utils.py +146 -0
- exponent/core/config.py +180 -0
- exponent/core/graphql/__init__.py +0 -0
- exponent/core/graphql/client.py +61 -0
- exponent/core/graphql/get_chats_query.py +47 -0
- exponent/core/graphql/mutations.py +160 -0
- exponent/core/graphql/queries.py +146 -0
- exponent/core/graphql/subscriptions.py +16 -0
- exponent/core/remote_execution/checkpoints.py +212 -0
- exponent/core/remote_execution/cli_rpc_types.py +499 -0
- exponent/core/remote_execution/client.py +999 -0
- exponent/core/remote_execution/code_execution.py +77 -0
- exponent/core/remote_execution/default_env.py +31 -0
- exponent/core/remote_execution/error_info.py +45 -0
- exponent/core/remote_execution/exceptions.py +10 -0
- exponent/core/remote_execution/file_write.py +35 -0
- exponent/core/remote_execution/files.py +330 -0
- exponent/core/remote_execution/git.py +268 -0
- exponent/core/remote_execution/http_fetch.py +94 -0
- exponent/core/remote_execution/languages/python_execution.py +239 -0
- exponent/core/remote_execution/languages/shell_streaming.py +226 -0
- exponent/core/remote_execution/languages/types.py +20 -0
- exponent/core/remote_execution/port_utils.py +73 -0
- exponent/core/remote_execution/session.py +128 -0
- exponent/core/remote_execution/system_context.py +26 -0
- exponent/core/remote_execution/terminal_session.py +375 -0
- exponent/core/remote_execution/terminal_types.py +29 -0
- exponent/core/remote_execution/tool_execution.py +595 -0
- exponent/core/remote_execution/tool_type_utils.py +39 -0
- exponent/core/remote_execution/truncation.py +296 -0
- exponent/core/remote_execution/types.py +635 -0
- exponent/core/remote_execution/utils.py +477 -0
- exponent/core/types/__init__.py +0 -0
- exponent/core/types/command_data.py +206 -0
- exponent/core/types/event_types.py +89 -0
- exponent/core/types/generated/__init__.py +0 -0
- exponent/core/types/generated/strategy_info.py +213 -0
- exponent/migration-docs/login.md +112 -0
- exponent/py.typed +4 -0
- exponent/utils/__init__.py +0 -0
- exponent/utils/colors.py +92 -0
- exponent/utils/version.py +289 -0
- indent-0.1.26.dist-info/METADATA +38 -0
- indent-0.1.26.dist-info/RECORD +55 -0
- indent-0.1.26.dist-info/WHEEL +4 -0
- indent-0.1.26.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import stat
|
|
4
|
+
from collections.abc import Awaitable, Callable
|
|
5
|
+
from functools import wraps
|
|
6
|
+
from typing import (
|
|
7
|
+
Any,
|
|
8
|
+
NoReturn,
|
|
9
|
+
TypeVar,
|
|
10
|
+
cast,
|
|
11
|
+
overload,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
import websockets
|
|
15
|
+
import websockets.exceptions
|
|
16
|
+
from anyio import Path as AsyncPath
|
|
17
|
+
from bs4 import UnicodeDammit
|
|
18
|
+
from httpx import Response
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
from sentry_sdk.serializer import serialize
|
|
21
|
+
from sentry_sdk.utils import (
|
|
22
|
+
event_from_exception,
|
|
23
|
+
exc_info_from_error,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
from exponent.core.remote_execution.cli_rpc_types import FileMetadata
|
|
27
|
+
from exponent.core.remote_execution.types import (
|
|
28
|
+
SUPPORTED_LANGUAGES,
|
|
29
|
+
CLIErrorLog,
|
|
30
|
+
CodeExecutionRequest,
|
|
31
|
+
CodeExecutionResponse,
|
|
32
|
+
CommandRequest,
|
|
33
|
+
CommandResponse,
|
|
34
|
+
CreateCheckpointResponse,
|
|
35
|
+
ErrorResponse,
|
|
36
|
+
FilePath,
|
|
37
|
+
FileWriteRequest,
|
|
38
|
+
FileWriteResponse,
|
|
39
|
+
ListFilesRequest,
|
|
40
|
+
ListFilesResponse,
|
|
41
|
+
RemoteExecutionMessage,
|
|
42
|
+
RemoteExecutionMessageData,
|
|
43
|
+
RemoteExecutionRequest,
|
|
44
|
+
RemoteExecutionRequestType,
|
|
45
|
+
RemoteExecutionResponse,
|
|
46
|
+
RemoteExecutionResponseData,
|
|
47
|
+
RemoteExecutionResponseType,
|
|
48
|
+
RollbackToCheckpointResponse,
|
|
49
|
+
StreamingCodeExecutionResponse,
|
|
50
|
+
StreamingCodeExecutionResponseChunk,
|
|
51
|
+
SupportedLanguage,
|
|
52
|
+
)
|
|
53
|
+
from exponent.core.types.command_data import NaturalEditContent
|
|
54
|
+
from exponent.core.types.event_types import (
|
|
55
|
+
CodeBlockEvent,
|
|
56
|
+
CommandEvent,
|
|
57
|
+
FileWriteEvent,
|
|
58
|
+
LocalEventType,
|
|
59
|
+
)
|
|
60
|
+
from exponent.utils.version import get_installed_version
|
|
61
|
+
|
|
62
|
+
logger = logging.getLogger(__name__)
|
|
63
|
+
|
|
64
|
+
### Serde
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def deserialize_response_data(
|
|
68
|
+
response_data: RemoteExecutionResponseData | str,
|
|
69
|
+
) -> RemoteExecutionResponseType:
|
|
70
|
+
response: RemoteExecutionResponseType
|
|
71
|
+
if isinstance(response_data, str):
|
|
72
|
+
response_data = RemoteExecutionResponseData.model_validate_json(response_data)
|
|
73
|
+
if response_data.direction != "response":
|
|
74
|
+
raise ValueError(f"Expected response, but got {response_data.direction}")
|
|
75
|
+
if response_data.namespace == "code_execution":
|
|
76
|
+
response = CodeExecutionResponse.model_validate_json(response_data.message_data)
|
|
77
|
+
elif response_data.namespace == "streaming_code_execution":
|
|
78
|
+
response = StreamingCodeExecutionResponse.model_validate_json(
|
|
79
|
+
response_data.message_data
|
|
80
|
+
)
|
|
81
|
+
elif response_data.namespace == "streaming_code_execution_chunk":
|
|
82
|
+
response = StreamingCodeExecutionResponseChunk.model_validate_json(
|
|
83
|
+
response_data.message_data
|
|
84
|
+
)
|
|
85
|
+
elif response_data.namespace == "file_write":
|
|
86
|
+
response = FileWriteResponse.model_validate_json(response_data.message_data)
|
|
87
|
+
elif response_data.namespace == "list_files":
|
|
88
|
+
response = ListFilesResponse.model_validate_json(response_data.message_data)
|
|
89
|
+
elif response_data.namespace == "command":
|
|
90
|
+
response = CommandResponse.model_validate_json(response_data.message_data)
|
|
91
|
+
elif response_data.namespace == "error":
|
|
92
|
+
response = ErrorResponse.model_validate_json(response_data.message_data)
|
|
93
|
+
elif response_data.namespace == "create_checkpoint":
|
|
94
|
+
response = CreateCheckpointResponse.model_validate_json(
|
|
95
|
+
response_data.message_data
|
|
96
|
+
)
|
|
97
|
+
elif response_data.namespace == "rollback_to_checkpoint":
|
|
98
|
+
response = RollbackToCheckpointResponse.model_validate_json(
|
|
99
|
+
response_data.message_data
|
|
100
|
+
)
|
|
101
|
+
else:
|
|
102
|
+
# type checking trick, if you miss a namespace then
|
|
103
|
+
# this won't typecheck due to the input parameter
|
|
104
|
+
# having a potential type other than no-return
|
|
105
|
+
response = assert_unreachable(response_data.namespace)
|
|
106
|
+
return truncate_message(response)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def serialize_message(response: RemoteExecutionMessage) -> str:
    """Wrap ``response`` in a ``RemoteExecutionMessageData`` envelope as JSON.

    The message is first passed through ``truncate_message``. The payload is
    then JSON-encoded into ``message_data``, and the envelope itself is
    JSON-encoded as well.
    """
    trimmed = truncate_message(response)
    return RemoteExecutionMessageData(
        namespace=response.namespace,
        direction=response.direction,
        message_data=trimmed.model_dump_json(),
    ).model_dump_json()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
### API Serdes
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
TModel = TypeVar("TModel", bound=BaseModel)


async def deserialize_api_response(
    response: Response,
    data_model: type[TModel],
) -> TModel:
    """Validate an HTTP response's JSON body into ``data_model``.

    Declared ``async`` although it performs no awaits; callers must await it.

    Args:
        response: An already-received httpx response.
        data_model: Pydantic model class used to validate the JSON body.

    Returns:
        A ``data_model`` instance built from ``response.json()``.

    Raises:
        ValueError: for error responses. The message is the body's JSON
            ``detail`` field when available (otherwise the raw body text),
            followed by the status code in parentheses.
    """
    if response.is_error:
        # Log via the module logger rather than the module-level
        # ``logging.error``: the latter targets the root logger and implicitly
        # calls ``logging.basicConfig()`` when no handlers are configured.
        logger.error(response.text)
        try:
            error_message = response.json()["detail"]
        except Exception:
            # Best effort: the body may not be JSON, or may lack "detail".
            error_message = response.text
        raise ValueError(f"{error_message} ({response.status_code})")

    response_json = response.json()
    return data_model.model_validate(response_json)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def get_file_write_content(event: FileWriteEvent) -> str:
    """Return the full text a file-write event should put on disk.

    For ``NaturalEditContent`` this is the resolved ``new_file``. For any
    other write content it is ``write_content.content``.

    Raises:
        AssertionError: if a natural edit has no resolved ``new_file`` yet.
    """
    write_content = event.write_content
    if isinstance(write_content, NaturalEditContent):
        # Explicit raise instead of ``assert``: ``python -O`` strips asserts,
        # which would let None flow into FileWriteRequest.content.
        if write_content.new_file is None:
            raise AssertionError("Natural edit has no resolved new_file content")
        return write_content.new_file
    return write_content.content
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@overload
|
|
151
|
+
def convert_event_to_execution_request(
|
|
152
|
+
request: CodeBlockEvent,
|
|
153
|
+
) -> CodeExecutionRequest: ...
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@overload
|
|
157
|
+
def convert_event_to_execution_request(
|
|
158
|
+
request: FileWriteEvent,
|
|
159
|
+
) -> FileWriteRequest: ...
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@overload
|
|
163
|
+
def convert_event_to_execution_request(
|
|
164
|
+
request: CommandEvent,
|
|
165
|
+
) -> CommandRequest: ...
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def convert_event_to_execution_request(
|
|
169
|
+
request: LocalEventType,
|
|
170
|
+
) -> CodeExecutionRequest | FileWriteRequest | CommandRequest:
|
|
171
|
+
if isinstance(request, CodeBlockEvent):
|
|
172
|
+
language = assert_supported_language(request.language)
|
|
173
|
+
|
|
174
|
+
return CodeExecutionRequest(
|
|
175
|
+
language=language,
|
|
176
|
+
content=request.content,
|
|
177
|
+
timeout=request.timeout,
|
|
178
|
+
correlation_id=request.event_uuid,
|
|
179
|
+
)
|
|
180
|
+
elif isinstance(request, FileWriteEvent):
|
|
181
|
+
return FileWriteRequest(
|
|
182
|
+
file_path=request.file_path,
|
|
183
|
+
language=request.language,
|
|
184
|
+
write_strategy=request.write_strategy,
|
|
185
|
+
content=get_file_write_content(request),
|
|
186
|
+
correlation_id=request.event_uuid,
|
|
187
|
+
)
|
|
188
|
+
elif isinstance(request, CommandEvent):
|
|
189
|
+
return CommandRequest(
|
|
190
|
+
data=request.data,
|
|
191
|
+
correlation_id=request.event_uuid,
|
|
192
|
+
)
|
|
193
|
+
else:
|
|
194
|
+
assert_unreachable(request)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
### Validation
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
ResponseT = TypeVar("ResponseT", bound=RemoteExecutionResponse)


def assert_valid_response_type(
    response: RemoteExecutionResponseType, request: RemoteExecutionRequest[ResponseT]
) -> ResponseT | ErrorResponse:
    """Check that ``response`` answers ``request`` and narrow its type.

    Error responses are returned untouched. Any other response must share the
    request's namespace and have direction ``"response"``.

    Raises:
        ValueError: on a namespace or direction mismatch.
    """
    if isinstance(response, ErrorResponse):
        return response

    if request.namespace == response.namespace and response.direction == "response":
        return cast(ResponseT, response)

    raise ValueError(
        f"Expected {request.namespace}.response, but got {response.namespace}.{response.direction}"
    )
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def assert_unreachable(x: NoReturn) -> NoReturn:
    """Fail loudly when control reaches a branch proven impossible by typing.

    Annotating ``x`` as ``NoReturn`` makes a type checker reject any call site
    where ``x`` could still have a real type. This is how the exhaustive
    if/elif chains in this module are checked statically.

    Raises:
        AssertionError: always, naming the unexpected runtime type.
    """
    # Explicit raise instead of ``assert False``: ``python -O`` strips assert
    # statements, which would let this NoReturn function return None.
    raise AssertionError(f"Unhandled type: {type(x).__name__}")
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def assert_supported_language(language: str) -> SupportedLanguage:
    """Narrow ``language`` to ``SupportedLanguage``.

    Raises:
        ValueError: if ``language`` is not in ``SUPPORTED_LANGUAGES``.
    """
    if language in SUPPORTED_LANGUAGES:
        return cast(SupportedLanguage, language)
    raise ValueError(f"Unsupported language: {language}")
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
### Truncation
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
OUTPUT_CHARACTER_MAX = 90_000  # A tad over ~8k tokens
TRUNCATION_MESSAGE_CHARS = (
    "(Output truncated, only showing the first {remaining_chars} characters)"
)
TRUNCATION_MESSAGE_LINES = (
    "(Output truncated, only showing the first {remaining_lines} lines)"
)
# Length of the longest character-truncation notice plus one separating
# newline. truncate_output does not use it; presumably whoever appends the
# notice reserves room with it (TODO confirm against callers).
LONGEST_TRUNCATION_MESSAGE_LEN = (
    len(TRUNCATION_MESSAGE_CHARS.format(remaining_chars=OUTPUT_CHARACTER_MAX)) + 1
)

MAX_LINES = 10_000


def truncate_output(
    output: str,
    character_limit: int = OUTPUT_CHARACTER_MAX,
    line_limit: int = MAX_LINES,
) -> tuple[str, bool]:
    """Trim ``output`` to at most ``character_limit`` chars and ``line_limit`` lines.

    Whole trailing lines are dropped until the text fits the character limit.
    If even the first line is too long, that line is cut to
    ``character_limit`` characters instead. The line limit is applied
    afterwards. No truncation notice is appended here.

    Args:
        output: Text to trim.
        character_limit: Maximum number of characters to keep.
        line_limit: Maximum number of newline-separated lines to keep.

    Returns:
        ``(text, truncated)``, where ``truncated`` is True if anything was
        removed.
    """
    lines = output.split("\n")
    output_length = len(output)

    if output_length <= character_limit and len(lines) <= line_limit:
        return output, False

    # Drop whole lines from the end. Each popped line also frees the "\n"
    # that joined it to the previous one, so output_length keeps tracking
    # len("\n".join(lines)).
    while lines and output_length > character_limit:
        output_length -= len(lines.pop()) + 1

    if not lines:
        # Even the first line alone exceeded the limit: fall back to a hard
        # character cut so some output is retained.
        return output[:character_limit], True

    return "\n".join(lines[:line_limit]), True
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
@overload
def truncate_message(response: CodeExecutionRequest) -> CodeExecutionRequest: ...
@overload
def truncate_message(response: CodeExecutionResponse) -> CodeExecutionResponse: ...
@overload
def truncate_message(
    response: StreamingCodeExecutionResponse,
) -> StreamingCodeExecutionResponse: ...
@overload
def truncate_message(
    response: StreamingCodeExecutionResponseChunk,
) -> StreamingCodeExecutionResponseChunk: ...
@overload
def truncate_message(response: FileWriteRequest) -> FileWriteRequest: ...
@overload
def truncate_message(response: FileWriteResponse) -> FileWriteResponse: ...
@overload
def truncate_message(response: ListFilesRequest) -> ListFilesRequest: ...
@overload
def truncate_message(response: ListFilesResponse) -> ListFilesResponse: ...


@overload
def truncate_message(
    response: RemoteExecutionRequestType,
) -> RemoteExecutionRequestType: ...
@overload
def truncate_message(
    response: RemoteExecutionResponseType,
) -> RemoteExecutionResponseType: ...
@overload
def truncate_message(response: RemoteExecutionMessage) -> RemoteExecutionMessage: ...


def truncate_message(
    response: RemoteExecutionMessage,
) -> RemoteExecutionMessage:
    """Apply ``truncate_output`` to a message's ``content`` where applicable.

    Only code-execution responses (regular, streaming and chunks) and command
    responses whose subcommand is not ``"codebase_context"`` are trimmed. All
    other messages are returned unchanged. The message is mutated in place
    and returned; ``truncated`` is set to True only when content was cut.
    """
    target: (
        CodeExecutionResponse
        | StreamingCodeExecutionResponse
        | StreamingCodeExecutionResponseChunk
        | CommandResponse
    )
    if isinstance(
        response,
        CodeExecutionResponse
        | StreamingCodeExecutionResponse
        | StreamingCodeExecutionResponseChunk,
    ):
        target = response
    elif (
        isinstance(response, CommandResponse)
        and response.subcommand != "codebase_context"
    ):
        target = response
    else:
        return response

    trimmed, was_truncated = truncate_output(target.content)
    target.content = trimmed
    if was_truncated:
        target.truncated = True
    return target
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
### Error Handling
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def format_attachment_data(
|
|
343
|
+
attachment_lines: list[str] | None = None,
|
|
344
|
+
) -> str | None:
|
|
345
|
+
if not attachment_lines:
|
|
346
|
+
return None
|
|
347
|
+
log_attachment_str = "\n".join(attachment_lines)
|
|
348
|
+
return log_attachment_str
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def format_error_log(
    exc: Exception,
    chat_uuid: str | None = None,
    attachment_lines: list[str] | None = None,
) -> CLIErrorLog | None:
    """Package ``exc`` as a ``CLIErrorLog`` built from a Sentry event payload.

    Args:
        exc: Exception to report.
        chat_uuid: Optional chat identifier attached to the log.
        attachment_lines: Optional extra lines, newline-joined into
            ``attachment_data``.

    Returns:
        The ``CLIErrorLog``, or None if the serialized event could not be
        encoded as JSON.
    """
    exc_info = exc_info_from_error(exc)
    event, _ = event_from_exception(exc_info)
    attachment_data = format_attachment_data(attachment_lines)
    version = get_installed_version()

    try:
        event_data = json.dumps(serialize(event))  # type: ignore
    except (TypeError, ValueError):
        # json.dumps reports unserializable values with TypeError and
        # circular references with ValueError. It never raises
        # JSONDecodeError, which is what was caught here previously.
        return None

    return CLIErrorLog(
        event_data=event_data,
        attachment_data=attachment_data,
        version=version,
        chat_uuid=chat_uuid,
    )
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
### Websockets
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
ws_logger = logging.getLogger("WebsocketUtils")


def ws_retry(
    connection_name: str,
    max_retries: int = 5,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Build a decorator that re-invokes a websocket coroutine after disconnects.

    When the wrapped coroutine raises ``websockets.exceptions.ConnectionClosed``
    or a timeout error, it is called again with the same arguments, up to
    ``max_retries`` extra times and with no delay between attempts. When the
    budget is exhausted, the last exception is re-raised. Other exceptions
    propagate immediately.

    Args:
        connection_name: Label used in warning messages (capitalized).
        max_retries: Number of re-invocations allowed after the first call.
    """
    import asyncio

    connection_name = connection_name.capitalize()
    reconnect_msg = f"{connection_name} reconnecting."
    disconnect_msg = f"{connection_name} connection closed."
    max_disconnect_msg = (
        f"{connection_name} connection closed {max_retries} times, exiting."
    )

    def decorator(
        f: Callable[..., Awaitable[None]],
    ) -> Callable[..., Awaitable[None]]:
        @wraps(f)
        async def wrapped(*args: Any, **kwargs: Any) -> None:
            retries_used = 0

            while True:
                try:
                    return await f(*args, **kwargs)
                except (
                    websockets.exceptions.ConnectionClosed,
                    TimeoutError,
                    # Distinct from the builtin TimeoutError before Python
                    # 3.11 (an alias afterwards, so listing both is harmless).
                    asyncio.TimeoutError,
                ):
                    ws_logger.warning(disconnect_msg)

                    if retries_used >= max_retries:
                        # Retry budget exhausted: warn and re-raise the
                        # current exception with its traceback intact.
                        ws_logger.warning(max_disconnect_msg)
                        raise

                    retries_used += 1
                    ws_logger.warning(reconnect_msg)

        return wrapped

    return decorator
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
async def safe_read_file(path: FilePath) -> str:
    """Read a text file as UTF-8, falling back to encoding detection.

    Raises:
        UnicodeDecodeError: the original error, if the file is not valid
            UTF-8 and ``smart_decode`` cannot determine a working encoding.
    """
    path = AsyncPath(path)

    try:
        return await path.read_text(encoding="utf-8")
    except UnicodeDecodeError:
        # Potentially a wacky encoding or mixture of encodings; attempt to
        # correct it. smart_decode already applies UnicodeDammit.detwingle to
        # repair mixed utf-8/cp1252 bytes, so it is not repeated here.
        decode_result = smart_decode(await path.read_bytes())

        if decode_result is not None:
            # First item in tuple is the decoded str
            return decode_result[0]

        raise
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
async def safe_get_file_metadata(path: FilePath) -> FileMetadata | None:
    """Return the file's mtime and ``ls``-style mode string, or None on error.

    Any exception raised while stat-ing the path is logged at error level and
    swallowed, so callers can treat metadata as best-effort.
    """
    target = AsyncPath(path)
    try:
        file_stat = await target.stat()
    except Exception as e:
        logger.error(f"Error getting file metadata: {e!s}")
        return None
    else:
        return FileMetadata(
            modified_timestamp=file_stat.st_mtime,
            file_mode=stat.filemode(file_stat.st_mode),
        )
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
async def safe_write_file(path: FilePath, content: str) -> None:
    """Write ``content`` to ``path`` as UTF-8 text via anyio's async Path."""
    target = AsyncPath(path)
    await target.write_text(content, encoding="utf-8")
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def smart_decode(b: bytes) -> tuple[str, str] | None:
    """Decode bytes by detecting their source encoding.

    First repairs mixed cp1252 + utf-8 content with bs4's
    ``UnicodeDammit.detwingle``. This can happen on Windows, e.g. when
    "smart quotes" are pasted into a utf-8 file. It then lets UnicodeDammit
    pick an encoding, trying utf-8 and cp1252 first.

    Returns:
        ``(decoded_str, detected_encoding)`` on success, or None if no
        encoding decodes the (detwingled) bytes strictly.
    """
    b = UnicodeDammit.detwingle(b)

    encoding = UnicodeDammit(
        b, known_definite_encodings=["utf-8", "cp1252"]
    ).original_encoding

    if not encoding:
        return None

    try:
        return (b.decode(encoding=encoding), encoding)
    except UnicodeDecodeError:
        # As a last resort UnicodeDammit decodes with errors="replace" and
        # still reports original_encoding, so a strict decode can fail here.
        # Honour the documented "None on failure" contract instead of raising.
        return None
|
|
File without changes
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Annotated, Any, ClassVar, Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
# Individual write-strategy names, each typed as its own Literal so it can be
# passed wherever a FileWriteStrategyName is expected.
WRITE_STRATEGY_FULL_FILE_REWRITE: Literal["FULL_FILE_REWRITE"] = "FULL_FILE_REWRITE"
WRITE_STRATEGY_NATURAL_EDIT: Literal["NATURAL_EDIT"] = "NATURAL_EDIT"
WRITE_STRATEGY_SEARCH_REPLACE: Literal["SEARCH_REPLACE"] = "SEARCH_REPLACE"
WRITE_STRATEGY_UDIFF: Literal["UDIFF"] = "UDIFF"

# Every strategy name accepted by FileWriteCommandData.write_strategy.
FileWriteStrategyName = Literal[
    "FULL_FILE_REWRITE",
    "UDIFF",
    "SEARCH_REPLACE",
    "NATURAL_EDIT",
]

# Default for ShellCommandData.timeout (units not stated in this module).
DEFAULT_CODE_BLOCK_TIMEOUT = 30
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CommandType(str, Enum):
    """Values used as the ``type`` discriminator of each CommandData subclass.

    Mixes in ``str``, so members compare equal to their plain string values.
    """

    THINKING = "thinking"
    FILE_READ = "file_read"
    SUMMARIZE = "summarize"
    STEP_OUTPUT = "step_output"
    PROTOTYPE = "prototype"
    DB_QUERY = "db_query"
    DB_GET_TABLE_NAMES = "db_get_table_names"
    DB_GET_TABLE_SCHEMA = "db_get_table_schema"
    ANSWER = "answer"
    ASK = "ask"
    SHELL = "shell"
    PYTHON = "python"
    FILE_WRITE = "file_write"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class CommandData(BaseModel):
    """Base pydantic model for a single command's payload.

    Subclasses set ``executable`` and declare a ``type`` Literal that
    ``CommandDataType`` uses as its discriminator.
    """

    # Class-level flag, not a model field. Presumably marks commands that are
    # handed to an executor; confirm with the consumers of this flag.
    executable: ClassVar[bool]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class FileReadCommandData(CommandData):
    """Executable command requesting the contents of ``file_path``."""

    executable: ClassVar[bool] = True
    type: Literal[CommandType.FILE_READ] = CommandType.FILE_READ

    file_path: str
    language: str
    # Optional window into the file. Presumably a line count and a starting
    # line; TODO confirm units against the reader implementation.
    limit: int | None = None
    offset: int | None = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ThinkingCommandData(CommandData):
    """Non-executable command carrying "thinking" text in ``content``."""

    executable: ClassVar[bool] = False
    type: Literal[CommandType.THINKING] = CommandType.THINKING

    content: str
    # Optional signature accompanying the content; its format is not defined
    # in this module.
    signature: str | None = None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class PrototypeCommandData(CommandData):
    """Executable command whose payload was extracted from LLM output.

    The same command is kept in three forms: parsed JSON, raw text, and a
    rendered string for display.
    """

    executable: ClassVar[bool] = True
    type: Literal[CommandType.PROTOTYPE] = CommandType.PROTOTYPE

    command_name: str
    # Structured data extracted from LLM output
    content_json: dict[str, Any]
    # Raw text extracted from LLM output
    content_raw: str
    # Rendered LLM output for frontend display
    content_rendered: str

    # When truthy, reported by llm_command_name instead of command_name.
    llm_command_name_override: str | None = None

    @property
    def llm_command_name(self) -> str:
        """Return ``llm_command_name_override`` if truthy, else ``command_name``."""
        return self.llm_command_name_override or self.command_name
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# deprecated, use StepOutputCommandData instead
class SummarizeCommandData(CommandData):
    """Deprecated executable command carrying a ``summary`` string.

    Use StepOutputCommandData for new code.
    """

    executable: ClassVar[bool] = True
    type: Literal[CommandType.SUMMARIZE] = CommandType.SUMMARIZE

    summary: str
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class StepOutputCommandData(CommandData):
    """Executable command carrying a step's raw output text.

    Replaces the deprecated SummarizeCommandData.
    """

    executable: ClassVar[bool] = True
    type: Literal[CommandType.STEP_OUTPUT] = CommandType.STEP_OUTPUT

    step_output_raw: str
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class DBQueryCommandData(CommandData):
    """Executable command carrying a database ``query`` string."""

    # NOTE(review): this override only forwards **kwargs to BaseModel.__init__
    # and adds no behavior. It may exist to loosen static checking of
    # constructor arguments; confirm before removing.
    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

    executable: ClassVar[bool] = True
    type: Literal[CommandType.DB_QUERY] = CommandType.DB_QUERY

    query: str
    max_gigabytes_billed: float | None = None  # BigQuery only
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class DBGetTableNamesCommandData(CommandData):
    """Executable command with no payload beyond its ``type`` (table-name listing)."""

    executable: ClassVar[bool] = True
    type: Literal[CommandType.DB_GET_TABLE_NAMES] = CommandType.DB_GET_TABLE_NAMES
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class DBGetTableSchemaCommandData(CommandData):
    """Executable command naming the table whose schema is requested."""

    executable: ClassVar[bool] = True
    type: Literal[CommandType.DB_GET_TABLE_SCHEMA] = CommandType.DB_GET_TABLE_SCHEMA

    table_name: str
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class AnswerCommandData(CommandData):
    """Non-executable command carrying raw answer text."""

    executable: ClassVar[bool] = False
    type: Literal[CommandType.ANSWER] = CommandType.ANSWER

    answer_raw: str
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class AskCommandData(CommandData):
    """Non-executable command carrying raw question text."""

    executable: ClassVar[bool] = False
    type: Literal[CommandType.ASK] = CommandType.ASK

    ask_raw: str
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class ShellCommandData(CommandData):
    """Executable shell command; the command text is in ``content``."""

    # Presumably excludes this model from generated schema output; confirm
    # with the schema generator.
    exclude_from_schema_gen: ClassVar[bool] = True

    executable: ClassVar[bool] = True
    type: Literal[CommandType.SHELL] = CommandType.SHELL

    # Defaults to DEFAULT_CODE_BLOCK_TIMEOUT; units are not stated here.
    timeout: int = DEFAULT_CODE_BLOCK_TIMEOUT
    content: str
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class PythonCommandData(CommandData):
    """Executable Python command; the source text is in ``content``."""

    # Presumably excludes this model from generated schema output; confirm
    # with the schema generator.
    exclude_from_schema_gen: ClassVar[bool] = True

    executable: ClassVar[bool] = True
    type: Literal[CommandType.PYTHON] = CommandType.PYTHON

    content: str
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class EditContent(BaseModel):
    """Direct edit payload: the ``content`` to write, optionally with the prior file text."""

    content: str
    original_file: str | None = None
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class NaturalEditContent(BaseModel):
    """Edit described in natural language, plus its resolution state.

    Every field is required (none has a default). All fields except
    ``natural_edit`` accept None.
    """

    natural_edit: str
    intermediate_edit: str | None
    original_file: str | None
    new_file: str | None
    error_content: str | None

    @property
    def is_resolved(self) -> bool:
        """True once either a ``new_file`` or an ``error_content`` is set."""
        return self.new_file is not None or self.error_content is not None

    @property
    def is_noop(self) -> bool:
        """True when both file versions are known and identical."""
        return bool(
            self.new_file is not None
            and self.original_file is not None
            and self.new_file == self.original_file
        )
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class FileWriteCommandData(CommandData):
    """Executable command writing to ``file_path`` with ``write_strategy``."""

    # Presumably excludes this model from generated schema output; confirm
    # with the schema generator.
    exclude_from_schema_gen: ClassVar[bool] = True

    executable: ClassVar[bool] = True
    type: Literal[CommandType.FILE_WRITE] = CommandType.FILE_WRITE

    file_path: str
    language: str
    write_strategy: FileWriteStrategyName
    # Either a natural-language edit or a direct content edit.
    write_content: NaturalEditContent | EditContent
    content: str
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# Discriminated union of every command payload model: pydantic picks the
# concrete class from the value of the ``type`` field.
CommandDataType = Annotated[
    FileReadCommandData
    | ThinkingCommandData
    | PrototypeCommandData
    | SummarizeCommandData
    | DBQueryCommandData
    | DBGetTableNamesCommandData
    | DBGetTableSchemaCommandData
    | StepOutputCommandData
    | AnswerCommandData
    | AskCommandData
    | ShellCommandData
    | PythonCommandData
    | FileWriteCommandData,
    Field(discriminator="type"),
]
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class CommandImpl(ABC):
    """Abstract base for command implementations.

    Subclasses set ``command_data_type``, presumably the CommandData model
    they handle; confirm with the concrete implementations.
    """

    command_data_type: ClassVar[type[CommandData]]
|