indent 0.1.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. exponent/__init__.py +34 -0
  2. exponent/cli.py +110 -0
  3. exponent/commands/cloud_commands.py +585 -0
  4. exponent/commands/common.py +411 -0
  5. exponent/commands/config_commands.py +334 -0
  6. exponent/commands/run_commands.py +222 -0
  7. exponent/commands/settings.py +56 -0
  8. exponent/commands/types.py +111 -0
  9. exponent/commands/upgrade.py +29 -0
  10. exponent/commands/utils.py +146 -0
  11. exponent/core/config.py +180 -0
  12. exponent/core/graphql/__init__.py +0 -0
  13. exponent/core/graphql/client.py +61 -0
  14. exponent/core/graphql/get_chats_query.py +47 -0
  15. exponent/core/graphql/mutations.py +160 -0
  16. exponent/core/graphql/queries.py +146 -0
  17. exponent/core/graphql/subscriptions.py +16 -0
  18. exponent/core/remote_execution/checkpoints.py +212 -0
  19. exponent/core/remote_execution/cli_rpc_types.py +499 -0
  20. exponent/core/remote_execution/client.py +999 -0
  21. exponent/core/remote_execution/code_execution.py +77 -0
  22. exponent/core/remote_execution/default_env.py +31 -0
  23. exponent/core/remote_execution/error_info.py +45 -0
  24. exponent/core/remote_execution/exceptions.py +10 -0
  25. exponent/core/remote_execution/file_write.py +35 -0
  26. exponent/core/remote_execution/files.py +330 -0
  27. exponent/core/remote_execution/git.py +268 -0
  28. exponent/core/remote_execution/http_fetch.py +94 -0
  29. exponent/core/remote_execution/languages/python_execution.py +239 -0
  30. exponent/core/remote_execution/languages/shell_streaming.py +226 -0
  31. exponent/core/remote_execution/languages/types.py +20 -0
  32. exponent/core/remote_execution/port_utils.py +73 -0
  33. exponent/core/remote_execution/session.py +128 -0
  34. exponent/core/remote_execution/system_context.py +26 -0
  35. exponent/core/remote_execution/terminal_session.py +375 -0
  36. exponent/core/remote_execution/terminal_types.py +29 -0
  37. exponent/core/remote_execution/tool_execution.py +595 -0
  38. exponent/core/remote_execution/tool_type_utils.py +39 -0
  39. exponent/core/remote_execution/truncation.py +296 -0
  40. exponent/core/remote_execution/types.py +635 -0
  41. exponent/core/remote_execution/utils.py +477 -0
  42. exponent/core/types/__init__.py +0 -0
  43. exponent/core/types/command_data.py +206 -0
  44. exponent/core/types/event_types.py +89 -0
  45. exponent/core/types/generated/__init__.py +0 -0
  46. exponent/core/types/generated/strategy_info.py +213 -0
  47. exponent/migration-docs/login.md +112 -0
  48. exponent/py.typed +4 -0
  49. exponent/utils/__init__.py +0 -0
  50. exponent/utils/colors.py +92 -0
  51. exponent/utils/version.py +289 -0
  52. indent-0.1.26.dist-info/METADATA +38 -0
  53. indent-0.1.26.dist-info/RECORD +55 -0
  54. indent-0.1.26.dist-info/WHEEL +4 -0
  55. indent-0.1.26.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,477 @@
1
+ import json
2
+ import logging
3
+ import stat
4
+ from collections.abc import Awaitable, Callable
5
+ from functools import wraps
6
+ from typing import (
7
+ Any,
8
+ NoReturn,
9
+ TypeVar,
10
+ cast,
11
+ overload,
12
+ )
13
+
14
+ import websockets
15
+ import websockets.exceptions
16
+ from anyio import Path as AsyncPath
17
+ from bs4 import UnicodeDammit
18
+ from httpx import Response
19
+ from pydantic import BaseModel
20
+ from sentry_sdk.serializer import serialize
21
+ from sentry_sdk.utils import (
22
+ event_from_exception,
23
+ exc_info_from_error,
24
+ )
25
+
26
+ from exponent.core.remote_execution.cli_rpc_types import FileMetadata
27
+ from exponent.core.remote_execution.types import (
28
+ SUPPORTED_LANGUAGES,
29
+ CLIErrorLog,
30
+ CodeExecutionRequest,
31
+ CodeExecutionResponse,
32
+ CommandRequest,
33
+ CommandResponse,
34
+ CreateCheckpointResponse,
35
+ ErrorResponse,
36
+ FilePath,
37
+ FileWriteRequest,
38
+ FileWriteResponse,
39
+ ListFilesRequest,
40
+ ListFilesResponse,
41
+ RemoteExecutionMessage,
42
+ RemoteExecutionMessageData,
43
+ RemoteExecutionRequest,
44
+ RemoteExecutionRequestType,
45
+ RemoteExecutionResponse,
46
+ RemoteExecutionResponseData,
47
+ RemoteExecutionResponseType,
48
+ RollbackToCheckpointResponse,
49
+ StreamingCodeExecutionResponse,
50
+ StreamingCodeExecutionResponseChunk,
51
+ SupportedLanguage,
52
+ )
53
+ from exponent.core.types.command_data import NaturalEditContent
54
+ from exponent.core.types.event_types import (
55
+ CodeBlockEvent,
56
+ CommandEvent,
57
+ FileWriteEvent,
58
+ LocalEventType,
59
+ )
60
+ from exponent.utils.version import get_installed_version
61
+
62
+ logger = logging.getLogger(__name__)
63
+
64
+ ### Serde
65
+
66
+
67
+ def deserialize_response_data(
68
+ response_data: RemoteExecutionResponseData | str,
69
+ ) -> RemoteExecutionResponseType:
70
+ response: RemoteExecutionResponseType
71
+ if isinstance(response_data, str):
72
+ response_data = RemoteExecutionResponseData.model_validate_json(response_data)
73
+ if response_data.direction != "response":
74
+ raise ValueError(f"Expected response, but got {response_data.direction}")
75
+ if response_data.namespace == "code_execution":
76
+ response = CodeExecutionResponse.model_validate_json(response_data.message_data)
77
+ elif response_data.namespace == "streaming_code_execution":
78
+ response = StreamingCodeExecutionResponse.model_validate_json(
79
+ response_data.message_data
80
+ )
81
+ elif response_data.namespace == "streaming_code_execution_chunk":
82
+ response = StreamingCodeExecutionResponseChunk.model_validate_json(
83
+ response_data.message_data
84
+ )
85
+ elif response_data.namespace == "file_write":
86
+ response = FileWriteResponse.model_validate_json(response_data.message_data)
87
+ elif response_data.namespace == "list_files":
88
+ response = ListFilesResponse.model_validate_json(response_data.message_data)
89
+ elif response_data.namespace == "command":
90
+ response = CommandResponse.model_validate_json(response_data.message_data)
91
+ elif response_data.namespace == "error":
92
+ response = ErrorResponse.model_validate_json(response_data.message_data)
93
+ elif response_data.namespace == "create_checkpoint":
94
+ response = CreateCheckpointResponse.model_validate_json(
95
+ response_data.message_data
96
+ )
97
+ elif response_data.namespace == "rollback_to_checkpoint":
98
+ response = RollbackToCheckpointResponse.model_validate_json(
99
+ response_data.message_data
100
+ )
101
+ else:
102
+ # type checking trick, if you miss a namespace then
103
+ # this won't typecheck due to the input parameter
104
+ # having a potential type other than no-return
105
+ response = assert_unreachable(response_data.namespace)
106
+ return truncate_message(response)
107
+
108
+
109
def serialize_message(response: RemoteExecutionMessage) -> str:
    """Encode a message as the JSON envelope sent over the wire.

    The message is first passed through ``truncate_message``. Its JSON dump
    becomes the envelope's ``message_data``, alongside its namespace and
    direction.
    """
    payload_json = truncate_message(response).model_dump_json()
    envelope = RemoteExecutionMessageData(
        namespace=response.namespace,
        direction=response.direction,
        message_data=payload_json,
    )
    return envelope.model_dump_json()
118
+
119
+
120
+ ### API Serdes
121
+
122
+
123
TModel = TypeVar("TModel", bound=BaseModel)


async def deserialize_api_response(
    response: Response,
    data_model: type[TModel],
) -> TModel:
    """Validate an HTTP API response body into ``data_model``.

    Declared ``async`` for caller compatibility; it performs no awaits itself.

    Args:
        response: The HTTP response to decode.
        data_model: Pydantic model class used to validate the JSON body.

    Returns:
        The validated model instance.

    Raises:
        ValueError: if the response has an error status. The message is the
            body's JSON ``detail`` field when available (otherwise the raw
            body text), followed by the status code in parentheses.
    """
    if response.is_error:
        # Log through this module's logger. The module-level
        # ``logging.error`` writes to the root logger and implicitly calls
        # ``logging.basicConfig()`` when the root logger has no handlers.
        logger.error(response.text)
        try:
            error_message = response.json()["detail"]
        except (ValueError, KeyError, TypeError):
            # ValueError: body is not valid JSON. KeyError: no "detail" key.
            # TypeError: the JSON body is not an object (e.g. a list/string).
            error_message = response.text
        raise ValueError(f"{error_message} ({response.status_code})")

    return data_model.model_validate(response.json())
140
+
141
+
142
def get_file_write_content(event: FileWriteEvent) -> str:
    """Return the full file text that a file-write event should produce.

    For a ``NaturalEditContent`` payload this is its resolved ``new_file``.
    For any other payload it is the payload's ``content``.

    Raises:
        ValueError: if a natural edit has not been resolved to a new file.
    """
    write_content = event.write_content
    if isinstance(write_content, NaturalEditContent):
        # An explicit check rather than ``assert``: asserts are stripped
        # under ``python -O``, which would silently return None here.
        if write_content.new_file is None:
            raise ValueError("Natural edit has no resolved new_file content")
        return write_content.new_file
    return write_content.content
148
+
149
+
150
+ @overload
151
+ def convert_event_to_execution_request(
152
+ request: CodeBlockEvent,
153
+ ) -> CodeExecutionRequest: ...
154
+
155
+
156
+ @overload
157
+ def convert_event_to_execution_request(
158
+ request: FileWriteEvent,
159
+ ) -> FileWriteRequest: ...
160
+
161
+
162
+ @overload
163
+ def convert_event_to_execution_request(
164
+ request: CommandEvent,
165
+ ) -> CommandRequest: ...
166
+
167
+
168
+ def convert_event_to_execution_request(
169
+ request: LocalEventType,
170
+ ) -> CodeExecutionRequest | FileWriteRequest | CommandRequest:
171
+ if isinstance(request, CodeBlockEvent):
172
+ language = assert_supported_language(request.language)
173
+
174
+ return CodeExecutionRequest(
175
+ language=language,
176
+ content=request.content,
177
+ timeout=request.timeout,
178
+ correlation_id=request.event_uuid,
179
+ )
180
+ elif isinstance(request, FileWriteEvent):
181
+ return FileWriteRequest(
182
+ file_path=request.file_path,
183
+ language=request.language,
184
+ write_strategy=request.write_strategy,
185
+ content=get_file_write_content(request),
186
+ correlation_id=request.event_uuid,
187
+ )
188
+ elif isinstance(request, CommandEvent):
189
+ return CommandRequest(
190
+ data=request.data,
191
+ correlation_id=request.event_uuid,
192
+ )
193
+ else:
194
+ assert_unreachable(request)
195
+
196
+
197
+ ### Validation
198
+
199
+
200
ResponseT = TypeVar("ResponseT", bound=RemoteExecutionResponse)


def assert_valid_response_type(
    response: RemoteExecutionResponseType, request: RemoteExecutionRequest[ResponseT]
) -> ResponseT | ErrorResponse:
    """Check that ``response`` answers ``request`` and narrow its static type.

    ``ErrorResponse`` values are returned unchanged. Any other response must
    share the request's namespace and have direction ``"response"``.

    Raises:
        ValueError: on a namespace or direction mismatch.
    """
    if isinstance(response, ErrorResponse):
        return response
    if request.namespace == response.namespace and response.direction == "response":
        return cast(ResponseT, response)
    raise ValueError(
        f"Expected {request.namespace}.response, but got {response.namespace}.{response.direction}"
    )
213
+
214
+
215
def assert_unreachable(x: NoReturn) -> NoReturn:
    """Fail loudly when supposedly unreachable code is reached.

    Annotating ``x`` as ``NoReturn`` makes a type checker reject any call
    whose argument has not been narrowed away completely, e.g. an unhandled
    union member.

    Raises:
        AssertionError: always, naming the type of the unexpected value.
    """
    # Raise explicitly instead of ``assert False``. Assert statements are
    # removed under ``python -O``, which would make this silently return None
    # despite the NoReturn annotation.
    raise AssertionError(f"Unhandled type: {type(x).__name__}")
217
+
218
+
219
def assert_supported_language(language: str) -> SupportedLanguage:
    """Validate ``language`` against ``SUPPORTED_LANGUAGES`` and narrow its type.

    Raises:
        ValueError: if ``language`` is not supported.
    """
    if language in SUPPORTED_LANGUAGES:
        return cast(SupportedLanguage, language)
    raise ValueError(f"Unsupported language: {language}")
224
+
225
+
226
+ ### Truncation
227
+
228
+
229
OUTPUT_CHARACTER_MAX = 90_000  # A tad over ~8k tokens
TRUNCATION_MESSAGE_CHARS = (
    "(Output truncated, only showing the first {remaining_chars} characters)"
)
TRUNCATION_MESSAGE_LINES = (
    "(Output truncated, only showing the first {remaining_lines} lines)"
)
# Length of the character-count truncation notice at the maximum limit, plus
# one for the newline that would separate it from the output.
LONGEST_TRUNCATION_MESSAGE_LEN = (
    len(TRUNCATION_MESSAGE_CHARS.format(remaining_chars=OUTPUT_CHARACTER_MAX)) + 1
)

MAX_LINES = 10_000


def truncate_output(
    output: str,
    character_limit: int = OUTPUT_CHARACTER_MAX,
    *,
    line_limit: int = MAX_LINES,
) -> tuple[str, bool]:
    """Trim ``output`` to at most ``character_limit`` chars and ``line_limit`` lines.

    Whole trailing lines are dropped until the text fits within the
    character limit. If even the first line is too long, that line is cut
    to ``character_limit`` characters so that something is retained.
    Otherwise the kept lines are capped at ``line_limit``. No truncation
    notice is appended here; ``character_limit`` is applied as given.

    Args:
        output: Text to truncate.
        character_limit: Maximum number of characters to keep.
        line_limit: Maximum number of lines to keep.

    Returns:
        ``(text, truncated)``, where ``truncated`` is True if anything was cut.
    """
    lines = output.split("\n")
    if len(output) <= character_limit and len(lines) <= line_limit:
        return output, False

    # len(output) is the sum of the line lengths plus one "\n" per line after
    # the first. Dropping a trailing line therefore removes its length plus
    # one separator.
    remaining = len(output)
    while lines and remaining > character_limit:
        remaining -= len(lines.pop()) + 1

    if not lines:
        # The first line alone exceeds the limit; keep a prefix of it rather
        # than returning nothing.
        return output[:character_limit], True

    return "\n".join(lines[:line_limit]), True
279
+
280
+
281
@overload
def truncate_message(response: CodeExecutionRequest) -> CodeExecutionRequest: ...
@overload
def truncate_message(response: CodeExecutionResponse) -> CodeExecutionResponse: ...
@overload
def truncate_message(
    response: StreamingCodeExecutionResponse,
) -> StreamingCodeExecutionResponse: ...
@overload
def truncate_message(
    response: StreamingCodeExecutionResponseChunk,
) -> StreamingCodeExecutionResponseChunk: ...
@overload
def truncate_message(response: FileWriteRequest) -> FileWriteRequest: ...
@overload
def truncate_message(response: FileWriteResponse) -> FileWriteResponse: ...
@overload
def truncate_message(response: ListFilesRequest) -> ListFilesRequest: ...
@overload
def truncate_message(response: ListFilesResponse) -> ListFilesResponse: ...


@overload
def truncate_message(
    response: RemoteExecutionRequestType,
) -> RemoteExecutionRequestType: ...
@overload
def truncate_message(
    response: RemoteExecutionResponseType,
) -> RemoteExecutionResponseType: ...
@overload
def truncate_message(response: RemoteExecutionMessage) -> RemoteExecutionMessage: ...


def truncate_message(
    response: RemoteExecutionMessage,
) -> RemoteExecutionMessage:
    """Truncate the ``content`` of output-carrying messages, in place.

    The following messages have ``content`` passed through
    ``truncate_output``:

    - code-execution responses, including streaming responses and chunks;
    - command responses whose subcommand is not ``"codebase_context"``.

    ``truncated`` is set to True when anything was cut. All other messages
    are left untouched. The same object that was passed in is returned.
    """
    code_output_types = (
        CodeExecutionResponse,
        StreamingCodeExecutionResponse,
        StreamingCodeExecutionResponseChunk,
    )
    if isinstance(response, code_output_types) or (
        isinstance(response, CommandResponse)
        and response.subcommand != "codebase_context"
    ):
        new_content, was_cut = truncate_output(response.content)
        response.content = new_content
        if was_cut:
            response.truncated = True
    return response
337
+
338
+
339
+ ### Error Handling
340
+
341
+
342
+ def format_attachment_data(
343
+ attachment_lines: list[str] | None = None,
344
+ ) -> str | None:
345
+ if not attachment_lines:
346
+ return None
347
+ log_attachment_str = "\n".join(attachment_lines)
348
+ return log_attachment_str
349
+
350
+
351
def format_error_log(
    exc: Exception,
    chat_uuid: str | None = None,
    attachment_lines: list[str] | None = None,
) -> CLIErrorLog | None:
    """Build a ``CLIErrorLog`` report for ``exc``.

    The exception is converted into a Sentry event and serialized to JSON.
    The event is bundled with the joined attachment lines, the installed CLI
    version, and ``chat_uuid``.

    Returns:
        The error log, or None if the event cannot be encoded as JSON.
    """
    exc_info = exc_info_from_error(exc)
    event, _ = event_from_exception(exc_info)
    attachment_data = format_attachment_data(attachment_lines)
    version = get_installed_version()

    try:
        event_data = json.dumps(serialize(event))  # type: ignore
    except (TypeError, ValueError):
        # json.dumps reports unserializable objects with TypeError and
        # circular references / disallowed float values with ValueError.
        # json.JSONDecodeError is only raised when decoding, never here.
        logger.warning("Failed to encode CLI error event as JSON", exc_info=True)
        return None

    return CLIErrorLog(
        event_data=event_data,
        attachment_data=attachment_data,
        version=version,
        chat_uuid=chat_uuid,
    )
372
+
373
+
374
+ ### Websockets
375
+
376
+
377
ws_logger = logging.getLogger("WebsocketUtils")


def ws_retry(
    connection_name: str,
    max_retries: int = 5,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Decorate a coroutine function so it is re-run when its connection drops.

    The wrapped coroutine is re-invoked immediately, with no backoff, when it
    raises ``websockets.exceptions.ConnectionClosed`` or ``TimeoutError``.
    This repeats for up to ``max_retries`` retries, i.e. at most
    ``max_retries + 1`` attempts in total. After that the last exception is
    re-raised. Any other exception propagates immediately.

    Args:
        connection_name: Human-readable name used in log messages; it is
            capitalized.
        max_retries: Number of re-invocations allowed after the first failure.
    """
    name = connection_name.capitalize()
    reconnect_msg = f"{name} reconnecting."
    disconnect_msg = f"{name} connection closed."
    max_disconnect_msg = f"{name} connection closed {max_retries} times, exiting."

    def decorator(
        f: Callable[..., Awaitable[None]],
    ) -> Callable[..., Awaitable[None]]:
        @wraps(f)
        async def wrapped(*args: Any, **kwargs: Any) -> None:
            retries = 0
            while True:
                try:
                    return await f(*args, **kwargs)
                except (websockets.exceptions.ConnectionClosed, TimeoutError):
                    ws_logger.warning(disconnect_msg)
                    if retries >= max_retries:
                        # Out of retries: report at error level and propagate
                        # the original exception with its traceback intact.
                        ws_logger.error(max_disconnect_msg)
                        raise
                    retries += 1
                    ws_logger.warning(reconnect_msg)

        return wrapped

    return decorator
420
+
421
+
422
async def safe_read_file(path: FilePath) -> str:
    """Read a text file, tolerating non-UTF-8 or mixed encodings.

    The file is first read as UTF-8. If that fails, the raw bytes are decoded
    with ``smart_decode``, which also repairs UTF-8/cp1252 mixtures.

    Raises:
        UnicodeDecodeError: the original UTF-8 error, if ``smart_decode``
            cannot detect a usable encoding either.
    """
    async_path = AsyncPath(path)
    try:
        return await async_path.read_text(encoding="utf-8")
    except UnicodeDecodeError:
        # Possibly an unusual encoding or a mixture of encodings; try to
        # detect and repair it. smart_decode already applies
        # UnicodeDammit.detwingle, so the raw bytes are passed in as-is.
        decode_result = smart_decode(await async_path.read_bytes())
        if decode_result is not None:
            # The first item of the tuple is the decoded text.
            return decode_result[0]
        raise
441
+
442
+
443
async def safe_get_file_metadata(path: FilePath) -> FileMetadata | None:
    """Return the modification time and permission string for ``path``.

    Best effort: any exception raised while stat-ing the path is logged and
    None is returned instead.
    """
    target = AsyncPath(path)
    try:
        file_stat = await target.stat()
    except Exception as e:
        logger.error(f"Error getting file metadata: {e!s}")
        return None

    # stat.filemode renders st_mode in `ls -l` style, e.g. "-rw-r--r--".
    mode_str = stat.filemode(file_stat.st_mode)
    return FileMetadata(modified_timestamp=file_stat.st_mtime, file_mode=mode_str)
455
+
456
+
457
async def safe_write_file(path: FilePath, content: str) -> None:
    """Write ``content`` to ``path`` as UTF-8 text."""
    target = AsyncPath(path)
    await target.write_text(content, encoding="utf-8")
459
+
460
+
461
def smart_decode(b: bytes) -> tuple[str, str] | None:
    """Decode bytes whose encoding is unknown.

    The bytes are first passed through ``UnicodeDammit.detwingle``. That
    repairs windows-1252 characters embedded in UTF-8 text, which happens on
    Windows when e.g. smart quotes are pasted into a UTF-8 file.
    ``UnicodeDammit`` then detects the encoding, trying UTF-8 and cp1252
    first.

    Returns:
        ``(decoded_text, detected_encoding)``, or None if no encoding was found.
    """
    repaired = UnicodeDammit.detwingle(b)
    detected = UnicodeDammit(
        repaired, known_definite_encodings=["utf-8", "cp1252"]
    ).original_encoding
    if detected:
        return (repaired.decode(encoding=detected), detected)
    return None
File without changes
@@ -0,0 +1,206 @@
1
+ from abc import ABC
2
+ from enum import Enum
3
+ from typing import Annotated, Any, ClassVar, Literal
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
# Default timeout for code blocks (used by ShellCommandData.timeout);
# presumably in seconds — confirm against the executor.
DEFAULT_CODE_BLOCK_TIMEOUT = 30

# Individual file-write strategy names, each typed as its own literal.
WRITE_STRATEGY_FULL_FILE_REWRITE: Literal["FULL_FILE_REWRITE"] = "FULL_FILE_REWRITE"
WRITE_STRATEGY_UDIFF: Literal["UDIFF"] = "UDIFF"
WRITE_STRATEGY_SEARCH_REPLACE: Literal["SEARCH_REPLACE"] = "SEARCH_REPLACE"
WRITE_STRATEGY_NATURAL_EDIT: Literal["NATURAL_EDIT"] = "NATURAL_EDIT"

# Union of every write-strategy name, for annotating fields and parameters.
FileWriteStrategyName = Literal[
    "FULL_FILE_REWRITE", "UDIFF", "SEARCH_REPLACE", "NATURAL_EDIT"
]
16
+
17
+
18
class CommandType(str, Enum):
    # Discriminator values for the CommandData subclasses below. Member order
    # is significant for iteration and is kept stable. (Comments rather than a
    # docstring, so that generated JSON schema descriptions are unaffected.)

    # Non-executed / informational commands.
    THINKING = "thinking"
    # File and summary commands.
    FILE_READ = "file_read"
    SUMMARIZE = "summarize"
    STEP_OUTPUT = "step_output"
    PROTOTYPE = "prototype"
    # Database commands.
    DB_QUERY = "db_query"
    DB_GET_TABLE_NAMES = "db_get_table_names"
    DB_GET_TABLE_SCHEMA = "db_get_table_schema"
    # Conversational commands.
    ANSWER = "answer"
    ASK = "ask"
    # Code execution and file writes.
    SHELL = "shell"
    PYTHON = "python"
    FILE_WRITE = "file_write"
32
+
33
+
34
class CommandData(BaseModel):
    # Base for all command payloads. Every subclass assigns this class-level
    # flag; presumably it marks whether the command is sent for execution
    # (True) or is display-only (False) — confirm against the executor.
    executable: ClassVar[bool]
36
+
37
+
38
class FileReadCommandData(CommandData):
    executable: ClassVar[bool] = True
    type: Literal[CommandType.FILE_READ] = CommandType.FILE_READ

    file_path: str
    language: str
    # Optional window into the file; both default to None (unset).
    limit: int | None = None
    offset: int | None = None
46
+
47
+
48
class ThinkingCommandData(CommandData):
    executable: ClassVar[bool] = False
    type: Literal[CommandType.THINKING] = CommandType.THINKING

    content: str
    # Optional; None when absent.
    signature: str | None = None
54
+
55
+
56
class PrototypeCommandData(CommandData):
    executable: ClassVar[bool] = True
    type: Literal[CommandType.PROTOTYPE] = CommandType.PROTOTYPE

    command_name: str
    # Structured data extracted from LLM output
    content_json: dict[str, Any]
    # Raw text extracted from LLM output
    content_raw: str
    # Rendered LLM output for frontend display
    content_rendered: str

    # When set to a non-empty string, replaces command_name in llm_command_name.
    llm_command_name_override: str | None = None

    @property
    def llm_command_name(self) -> str:
        """Return the override if it is set and non-empty, else ``command_name``."""
        override = self.llm_command_name_override
        return override if override else self.command_name
73
+
74
+
75
# Deprecated: use StepOutputCommandData instead.
class SummarizeCommandData(CommandData):
    executable: ClassVar[bool] = True
    type: Literal[CommandType.SUMMARIZE] = CommandType.SUMMARIZE

    summary: str
81
+
82
+
83
class StepOutputCommandData(CommandData):
    # Successor to the deprecated SummarizeCommandData.
    executable: ClassVar[bool] = True
    type: Literal[CommandType.STEP_OUTPUT] = CommandType.STEP_OUTPUT

    step_output_raw: str
88
+
89
+
90
class DBQueryCommandData(CommandData):
    executable: ClassVar[bool] = True
    type: Literal[CommandType.DB_QUERY] = CommandType.DB_QUERY

    query: str
    max_gigabytes_billed: float | None = None  # BigQuery only

    # NOTE(review): this override only forwards to BaseModel.__init__. It may
    # exist so that type checkers accept arbitrary keyword arguments — confirm
    # before removing.
    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
102
+
103
+
104
class DBGetTableNamesCommandData(CommandData):
    # No payload fields beyond the discriminator.
    executable: ClassVar[bool] = True
    type: Literal[CommandType.DB_GET_TABLE_NAMES] = CommandType.DB_GET_TABLE_NAMES
107
+
108
+
109
class DBGetTableSchemaCommandData(CommandData):
    executable: ClassVar[bool] = True
    type: Literal[CommandType.DB_GET_TABLE_SCHEMA] = CommandType.DB_GET_TABLE_SCHEMA

    # Name of the table whose schema is requested.
    table_name: str
114
+
115
+
116
class AnswerCommandData(CommandData):
    executable: ClassVar[bool] = False
    type: Literal[CommandType.ANSWER] = CommandType.ANSWER

    answer_raw: str
121
+
122
+
123
class AskCommandData(CommandData):
    executable: ClassVar[bool] = False
    type: Literal[CommandType.ASK] = CommandType.ASK

    ask_raw: str
128
+
129
+
130
class ShellCommandData(CommandData):
    # Class-level flags (not model fields).
    exclude_from_schema_gen: ClassVar[bool] = True
    executable: ClassVar[bool] = True

    type: Literal[CommandType.SHELL] = CommandType.SHELL
    # Field order is kept: timeout precedes content.
    timeout: int = DEFAULT_CODE_BLOCK_TIMEOUT
    content: str
138
+
139
+
140
class PythonCommandData(CommandData):
    # Class-level flags (not model fields).
    exclude_from_schema_gen: ClassVar[bool] = True
    executable: ClassVar[bool] = True

    type: Literal[CommandType.PYTHON] = CommandType.PYTHON
    content: str
147
+
148
+
149
class EditContent(BaseModel):
    # Content for non-natural-edit write strategies.
    content: str
    # Optional; None when absent.
    original_file: str | None = None
152
+
153
+
154
class NaturalEditContent(BaseModel):
    # Every field is required (no defaults), although most may be None.
    natural_edit: str
    intermediate_edit: str | None
    original_file: str | None
    new_file: str | None
    error_content: str | None

    @property
    def is_resolved(self) -> bool:
        """True once either ``new_file`` or ``error_content`` is set."""
        return not (self.new_file is None and self.error_content is None)

    @property
    def is_noop(self) -> bool:
        """True when both files are set and ``new_file`` equals ``original_file``."""
        if self.new_file is None or self.original_file is None:
            return False
        return self.new_file == self.original_file
172
+
173
+
174
class FileWriteCommandData(CommandData):
    # Class-level flags (not model fields).
    exclude_from_schema_gen: ClassVar[bool] = True
    executable: ClassVar[bool] = True

    type: Literal[CommandType.FILE_WRITE] = CommandType.FILE_WRITE
    file_path: str
    language: str
    write_strategy: FileWriteStrategyName
    # Union order is kept: NaturalEditContent is listed before EditContent.
    write_content: NaturalEditContent | EditContent
    content: str
185
+
186
+
187
# Union of every command payload model. Pydantic selects the concrete model
# from each payload's ``type`` field (discriminated union).
CommandDataType = Annotated[
    FileReadCommandData
    | ThinkingCommandData
    | PrototypeCommandData
    | SummarizeCommandData
    | DBQueryCommandData
    | DBGetTableNamesCommandData
    | DBGetTableSchemaCommandData
    | StepOutputCommandData
    | AnswerCommandData
    | AskCommandData
    | ShellCommandData
    | PythonCommandData
    | FileWriteCommandData,
    Field(discriminator="type"),
]
203
+
204
+
205
+ class CommandImpl(ABC):
206
+ command_data_type: ClassVar[type[CommandData]]