indent 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of indent might be problematic. Click here for more details.

Files changed (56) hide show
  1. exponent/__init__.py +1 -0
  2. exponent/cli.py +112 -0
  3. exponent/commands/cloud_commands.py +85 -0
  4. exponent/commands/common.py +434 -0
  5. exponent/commands/config_commands.py +581 -0
  6. exponent/commands/github_app_commands.py +211 -0
  7. exponent/commands/listen_commands.py +96 -0
  8. exponent/commands/run_commands.py +208 -0
  9. exponent/commands/settings.py +56 -0
  10. exponent/commands/shell_commands.py +2840 -0
  11. exponent/commands/theme.py +246 -0
  12. exponent/commands/types.py +111 -0
  13. exponent/commands/upgrade.py +29 -0
  14. exponent/commands/utils.py +236 -0
  15. exponent/core/config.py +180 -0
  16. exponent/core/graphql/__init__.py +0 -0
  17. exponent/core/graphql/client.py +59 -0
  18. exponent/core/graphql/cloud_config_queries.py +77 -0
  19. exponent/core/graphql/get_chats_query.py +47 -0
  20. exponent/core/graphql/github_config_queries.py +56 -0
  21. exponent/core/graphql/mutations.py +75 -0
  22. exponent/core/graphql/queries.py +110 -0
  23. exponent/core/graphql/subscriptions.py +452 -0
  24. exponent/core/remote_execution/checkpoints.py +212 -0
  25. exponent/core/remote_execution/cli_rpc_types.py +214 -0
  26. exponent/core/remote_execution/client.py +545 -0
  27. exponent/core/remote_execution/code_execution.py +58 -0
  28. exponent/core/remote_execution/command_execution.py +105 -0
  29. exponent/core/remote_execution/error_info.py +45 -0
  30. exponent/core/remote_execution/exceptions.py +10 -0
  31. exponent/core/remote_execution/file_write.py +410 -0
  32. exponent/core/remote_execution/files.py +415 -0
  33. exponent/core/remote_execution/git.py +268 -0
  34. exponent/core/remote_execution/languages/python_execution.py +239 -0
  35. exponent/core/remote_execution/languages/shell_streaming.py +221 -0
  36. exponent/core/remote_execution/languages/types.py +20 -0
  37. exponent/core/remote_execution/session.py +128 -0
  38. exponent/core/remote_execution/system_context.py +54 -0
  39. exponent/core/remote_execution/tool_execution.py +289 -0
  40. exponent/core/remote_execution/truncation.py +284 -0
  41. exponent/core/remote_execution/types.py +670 -0
  42. exponent/core/remote_execution/utils.py +600 -0
  43. exponent/core/types/__init__.py +0 -0
  44. exponent/core/types/command_data.py +206 -0
  45. exponent/core/types/event_types.py +89 -0
  46. exponent/core/types/generated/__init__.py +0 -0
  47. exponent/core/types/generated/strategy_info.py +225 -0
  48. exponent/migration-docs/login.md +112 -0
  49. exponent/py.typed +4 -0
  50. exponent/utils/__init__.py +0 -0
  51. exponent/utils/colors.py +92 -0
  52. exponent/utils/version.py +289 -0
  53. indent-0.0.8.dist-info/METADATA +36 -0
  54. indent-0.0.8.dist-info/RECORD +56 -0
  55. indent-0.0.8.dist-info/WHEEL +4 -0
  56. indent-0.0.8.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,600 @@
1
+ import json
2
+ import logging
3
+ from collections.abc import Awaitable, Callable
4
+ from functools import wraps
5
+ from typing import (
6
+ Any,
7
+ NoReturn,
8
+ TypeVar,
9
+ cast,
10
+ overload,
11
+ )
12
+
13
+ import websockets
14
+ import websockets.exceptions
15
+ from anyio import Path as AsyncPath
16
+ from bs4 import UnicodeDammit
17
+ from httpx import Response
18
+ from pydantic import BaseModel
19
+ from sentry_sdk.serializer import serialize
20
+ from sentry_sdk.utils import (
21
+ event_from_exception,
22
+ exc_info_from_error,
23
+ )
24
+
25
+ from exponent.core.remote_execution.types import (
26
+ SUPPORTED_LANGUAGES,
27
+ CLIErrorLog,
28
+ CodeExecutionRequest,
29
+ CodeExecutionResponse,
30
+ CommandRequest,
31
+ CommandResponse,
32
+ CreateCheckpointRequest,
33
+ CreateCheckpointResponse,
34
+ ErrorResponse,
35
+ FilePath,
36
+ FileWriteRequest,
37
+ FileWriteResponse,
38
+ GetAllTrackedFilesRequest,
39
+ GetAllTrackedFilesResponse,
40
+ GetFileAttachmentRequest,
41
+ GetFileAttachmentResponse,
42
+ GetFileAttachmentsRequest,
43
+ GetFileAttachmentsResponse,
44
+ GetMatchingFilesRequest,
45
+ GetMatchingFilesResponse,
46
+ HaltRequest,
47
+ HaltResponse,
48
+ ListFilesRequest,
49
+ ListFilesResponse,
50
+ RemoteExecutionMessage,
51
+ RemoteExecutionMessageData,
52
+ RemoteExecutionRequest,
53
+ RemoteExecutionRequestData,
54
+ RemoteExecutionRequestType,
55
+ RemoteExecutionResponse,
56
+ RemoteExecutionResponseData,
57
+ RemoteExecutionResponseType,
58
+ RollbackToCheckpointRequest,
59
+ RollbackToCheckpointResponse,
60
+ StreamingCodeExecutionRequest,
61
+ StreamingCodeExecutionResponse,
62
+ StreamingCodeExecutionResponseChunk,
63
+ SupportedLanguage,
64
+ SwitchCLIChatRequest,
65
+ SwitchCLIChatResponse,
66
+ SystemContextRequest,
67
+ SystemContextResponse,
68
+ )
69
+ from exponent.core.types.command_data import NaturalEditContent
70
+ from exponent.core.types.event_types import (
71
+ CodeBlockEvent,
72
+ CommandEvent,
73
+ FileWriteEvent,
74
+ LocalEventType,
75
+ )
76
+ from exponent.utils.version import get_installed_version
77
+
78
+ ### Serde
79
+
80
+
81
def deserialize_message_data(
    message_data: RemoteExecutionMessageData | str,
) -> RemoteExecutionMessage:
    """Decode a wire envelope into a concrete request or response model.

    Args:
        message_data: An already-parsed ``RemoteExecutionMessageData``
            envelope, or its JSON string form.

    Returns:
        The result of ``deserialize_request_data`` or
        ``deserialize_response_data``, chosen by the envelope's ``direction``.
    """
    envelope = (
        RemoteExecutionMessageData.model_validate_json(message_data)
        if isinstance(message_data, str)
        else message_data
    )
    direction = envelope.direction

    if direction == "request":
        return deserialize_request_data(cast(RemoteExecutionRequestData, envelope))
    if direction == "response":
        return deserialize_response_data(
            cast(RemoteExecutionResponseData, envelope)
        )

    # Exhaustiveness check for the type checker: once every direction literal
    # is handled above, ``direction`` narrows to NoReturn here. Forgetting a
    # case makes this call fail to type check.
    assert_unreachable(direction)
97
+
98
+
99
+ def deserialize_request_data(
100
+ request_data: RemoteExecutionRequestData | str,
101
+ ) -> RemoteExecutionRequestType:
102
+ request: RemoteExecutionRequestType
103
+ if isinstance(request_data, str):
104
+ request_data = RemoteExecutionRequestData.model_validate_json(request_data)
105
+ if request_data.direction != "request":
106
+ raise ValueError(f"Expected request, but got {request_data.direction}")
107
+ if request_data.namespace == "code_execution":
108
+ request = CodeExecutionRequest.model_validate_json(request_data.message_data)
109
+ elif request_data.namespace == "file_write":
110
+ request = FileWriteRequest.model_validate_json(request_data.message_data)
111
+ elif request_data.namespace == "list_files":
112
+ request = ListFilesRequest.model_validate_json(request_data.message_data)
113
+ elif request_data.namespace == "get_file_attachment":
114
+ request = GetFileAttachmentRequest.model_validate_json(
115
+ request_data.message_data
116
+ )
117
+ elif request_data.namespace == "get_file_attachments":
118
+ request = GetFileAttachmentsRequest.model_validate_json(
119
+ request_data.message_data
120
+ )
121
+ elif request_data.namespace == "get_matching_files":
122
+ request = GetMatchingFilesRequest.model_validate_json(request_data.message_data)
123
+ elif request_data.namespace == "system_context":
124
+ request = SystemContextRequest.model_validate_json(request_data.message_data)
125
+ elif request_data.namespace == "get_all_tracked_files":
126
+ request = GetAllTrackedFilesRequest.model_validate_json(
127
+ request_data.message_data
128
+ )
129
+ elif request_data.namespace == "command":
130
+ request = CommandRequest.model_validate_json(request_data.message_data)
131
+ elif request_data.namespace == "halt":
132
+ request = HaltRequest.model_validate_json(request_data.message_data)
133
+ elif request_data.namespace == "streaming_code_execution":
134
+ request = StreamingCodeExecutionRequest.model_validate_json(
135
+ request_data.message_data
136
+ )
137
+ elif request_data.namespace == "switch_cli_chat":
138
+ request = SwitchCLIChatRequest.model_validate_json(request_data.message_data)
139
+ elif request_data.namespace == "streaming_code_execution_chunk":
140
+ assert False, "Streaming code execution chunk is a response, not a request"
141
+ elif request_data.namespace == "error":
142
+ assert False, "Error is a response, not a request"
143
+ elif request_data.namespace == "create_checkpoint":
144
+ request = CreateCheckpointRequest.model_validate_json(request_data.message_data)
145
+ elif request_data.namespace == "rollback_to_checkpoint":
146
+ request = RollbackToCheckpointRequest.model_validate_json(
147
+ request_data.message_data
148
+ )
149
+ else:
150
+ # type checking trick, if you miss a namespace then
151
+ # this won't typecheck due to the input parameter
152
+ # having a potential type other than no-return
153
+ request = assert_unreachable(request_data.namespace)
154
+ return truncate_message(request)
155
+
156
+
157
+ def deserialize_response_data(
158
+ response_data: RemoteExecutionResponseData | str,
159
+ ) -> RemoteExecutionResponseType:
160
+ response: RemoteExecutionResponseType
161
+ if isinstance(response_data, str):
162
+ response_data = RemoteExecutionResponseData.model_validate_json(response_data)
163
+ if response_data.direction != "response":
164
+ raise ValueError(f"Expected response, but got {response_data.direction}")
165
+ if response_data.namespace == "code_execution":
166
+ response = CodeExecutionResponse.model_validate_json(response_data.message_data)
167
+ elif response_data.namespace == "streaming_code_execution":
168
+ response = StreamingCodeExecutionResponse.model_validate_json(
169
+ response_data.message_data
170
+ )
171
+ elif response_data.namespace == "streaming_code_execution_chunk":
172
+ response = StreamingCodeExecutionResponseChunk.model_validate_json(
173
+ response_data.message_data
174
+ )
175
+ elif response_data.namespace == "file_write":
176
+ response = FileWriteResponse.model_validate_json(response_data.message_data)
177
+ elif response_data.namespace == "list_files":
178
+ response = ListFilesResponse.model_validate_json(response_data.message_data)
179
+ elif response_data.namespace == "get_matching_files":
180
+ response = GetMatchingFilesResponse.model_validate_json(
181
+ response_data.message_data
182
+ )
183
+ elif response_data.namespace == "get_file_attachment":
184
+ response = GetFileAttachmentResponse.model_validate_json(
185
+ response_data.message_data
186
+ )
187
+ elif response_data.namespace == "get_file_attachments":
188
+ response = GetFileAttachmentsResponse.model_validate_json(
189
+ response_data.message_data
190
+ )
191
+ elif response_data.namespace == "system_context":
192
+ response = SystemContextResponse.model_validate_json(response_data.message_data)
193
+ elif response_data.namespace == "get_all_tracked_files":
194
+ response = GetAllTrackedFilesResponse.model_validate_json(
195
+ response_data.message_data
196
+ )
197
+ elif response_data.namespace == "command":
198
+ response = CommandResponse.model_validate_json(response_data.message_data)
199
+ elif response_data.namespace == "halt":
200
+ response = HaltResponse.model_validate_json(response_data.message_data)
201
+ elif response_data.namespace == "switch_cli_chat":
202
+ response = SwitchCLIChatResponse.model_validate_json(response_data.message_data)
203
+ elif response_data.namespace == "error":
204
+ response = ErrorResponse.model_validate_json(response_data.message_data)
205
+ elif response_data.namespace == "create_checkpoint":
206
+ response = CreateCheckpointResponse.model_validate_json(
207
+ response_data.message_data
208
+ )
209
+ elif response_data.namespace == "rollback_to_checkpoint":
210
+ response = RollbackToCheckpointResponse.model_validate_json(
211
+ response_data.message_data
212
+ )
213
+ else:
214
+ # type checking trick, if you miss a namespace then
215
+ # this won't typecheck due to the input parameter
216
+ # having a potential type other than no-return
217
+ response = assert_unreachable(response_data.namespace)
218
+ return truncate_message(response)
219
+
220
+
221
def serialize_message(response: RemoteExecutionMessage) -> str:
    """Encode a request or response into its JSON wire envelope.

    The message is first passed through ``truncate_message``, which may
    shorten its text content in place. The result is dumped to JSON and
    wrapped in a ``RemoteExecutionMessageData`` carrying the message's
    ``namespace`` and ``direction``.

    Returns:
        The envelope serialized as a JSON string.
    """
    inner_json = truncate_message(response).model_dump_json()
    envelope = RemoteExecutionMessageData(
        namespace=response.namespace,
        direction=response.direction,
        message_data=inner_json,
    )
    return envelope.model_dump_json()
230
+
231
+
232
+ ### API Serde
233
+
234
+
235
# Any pydantic model class that an API response body can be validated into.
TModel = TypeVar("TModel", bound=BaseModel)


async def deserialize_api_response(
    response: Response,
    data_model: type[TModel],
) -> TModel:
    """Validate a successful HTTP response body into ``data_model``.

    Args:
        response: An already-received ``httpx.Response``.
        data_model: Pydantic model class used to validate the JSON body.

    Returns:
        An instance of ``data_model`` built from ``response.json()``.

    Raises:
        ValueError: If the response has an error status. The message is the
            body's ``"detail"`` field when present, otherwise the raw body
            text, followed by the status code in parentheses.
    """
    if response.is_error:
        # Keep the raw body available for debugging without printing it to
        # stdout, where it would interleave with the CLI's own output.
        logging.getLogger(__name__).debug(
            "API error response (%s): %s", response.status_code, response.text
        )
        try:
            error_message = response.json()["detail"]
        except (ValueError, KeyError, TypeError):
            # ValueError: body is not JSON (JSONDecodeError subclasses it).
            # KeyError / TypeError: JSON lacks "detail" or is not an object.
            error_message = response.text
        raise ValueError(f"{error_message} ({response.status_code})")

    return data_model.model_validate(response.json())
252
+
253
+
254
def get_file_write_content(event: FileWriteEvent) -> str:
    """Return the full file text a ``FileWriteEvent`` asks to write.

    For ``NaturalEditContent`` payloads this is the ``new_file`` text.
    For any other payload it is the payload's ``content``.

    Raises:
        ValueError: If a ``NaturalEditContent`` payload has no ``new_file``.
            This is raised explicitly rather than asserted, so the check
            still runs under ``python -O``.
    """
    write_content = event.write_content
    if not isinstance(write_content, NaturalEditContent):
        return write_content.content

    if write_content.new_file is None:
        raise ValueError("NaturalEditContent is missing its new_file text")
    return write_content.new_file
260
+
261
+
262
+ @overload
263
+ def convert_event_to_execution_request(
264
+ request: CodeBlockEvent,
265
+ ) -> CodeExecutionRequest: ...
266
+
267
+
268
+ @overload
269
+ def convert_event_to_execution_request(
270
+ request: FileWriteEvent,
271
+ ) -> FileWriteRequest: ...
272
+
273
+
274
+ @overload
275
+ def convert_event_to_execution_request(
276
+ request: CommandEvent,
277
+ ) -> CommandRequest: ...
278
+
279
+
280
+ def convert_event_to_execution_request(
281
+ request: LocalEventType,
282
+ ) -> CodeExecutionRequest | FileWriteRequest | CommandRequest:
283
+ if isinstance(request, CodeBlockEvent):
284
+ language = assert_supported_language(request.language)
285
+
286
+ return CodeExecutionRequest(
287
+ language=language,
288
+ content=request.content,
289
+ timeout=request.timeout,
290
+ correlation_id=request.event_uuid,
291
+ )
292
+ elif isinstance(request, FileWriteEvent):
293
+ return FileWriteRequest(
294
+ file_path=request.file_path,
295
+ language=request.language,
296
+ write_strategy=request.write_strategy,
297
+ content=get_file_write_content(request),
298
+ correlation_id=request.event_uuid,
299
+ )
300
+ elif isinstance(request, CommandEvent):
301
+ return CommandRequest(
302
+ data=request.data,
303
+ correlation_id=request.event_uuid,
304
+ )
305
+ else:
306
+ assert_unreachable(request)
307
+
308
+
309
+ ### Validation
310
+
311
+
312
# The concrete response model a given request type expects back.
ResponseT = TypeVar("ResponseT", bound=RemoteExecutionResponse)


def assert_valid_response_type(
    response: RemoteExecutionResponseType, request: RemoteExecutionRequest[ResponseT]
) -> ResponseT | ErrorResponse:
    """Check that ``response`` answers ``request`` and narrow its type.

    ``ErrorResponse`` is always accepted as-is. Any other response must share
    the request's ``namespace`` and have direction ``"response"``.

    Raises:
        ValueError: If the namespace or direction does not match.
    """
    if isinstance(response, ErrorResponse):
        return response
    if request.namespace == response.namespace and response.direction == "response":
        return cast(ResponseT, response)
    raise ValueError(
        f"Expected {request.namespace}.response, but got {response.namespace}.{response.direction}"
    )
325
+
326
+
327
def assert_unreachable(x: NoReturn) -> NoReturn:
    """Fail when supposedly exhaustive branching receives an unhandled value.

    Statically, annotating ``x`` as ``NoReturn`` makes the type checker reject
    any call site where ``x`` could still hold a value, i.e. a forgotten case.
    At runtime, reaching this function means such a value slipped through.

    Raises:
        AssertionError: Always. It is raised explicitly rather than with an
            ``assert`` statement. ``python -O`` strips asserts, which would
            let this ``NoReturn`` function silently return ``None``.
    """
    # Include the value itself: for string-literal namespaces the type name
    # alone is always "str" and says nothing about what was unhandled.
    raise AssertionError(f"Unhandled type: {type(x).__name__} ({x!r})")
329
+
330
+
331
def assert_supported_language(language: str) -> SupportedLanguage:
    """Narrow ``language`` to ``SupportedLanguage`` if it is in ``SUPPORTED_LANGUAGES``.

    Raises:
        ValueError: If ``language`` is not a supported language.
    """
    if language in SUPPORTED_LANGUAGES:
        return cast(SupportedLanguage, language)
    raise ValueError(f"Unsupported language: {language}")
336
+
337
+
338
+ ### Truncation
339
+
340
+
341
OUTPUT_CHARACTER_MAX = 90_000  # A tad over ~8k tokens
TRUNCATION_MESSAGE_CHARS = (
    "(Output truncated, only showing the first {remaining_chars} characters)"
)
TRUNCATION_MESSAGE_LINES = (
    "(Output truncated, only showing the first {remaining_lines} lines)"
)
# Length of the character-based truncation notice at the maximum limit, plus
# one for the newline separating it from the output.
LONGEST_TRUNCATION_MESSAGE_LEN = (
    len(TRUNCATION_MESSAGE_CHARS.format(remaining_chars=OUTPUT_CHARACTER_MAX)) + 1
)

MAX_LINES = 10_000


def truncate_output(
    output: str,
    character_limit: int = OUTPUT_CHARACTER_MAX,
    max_lines: int = MAX_LINES,
) -> tuple[str, bool]:
    """Clamp ``output`` to ``character_limit`` characters and ``max_lines`` lines.

    Whole trailing lines are dropped first, so the kept text ends on a line
    boundary. If even the first line exceeds the limit, the text is instead
    cut to its first ``character_limit`` characters. The line cap is applied
    after the character trimming.

    This function does not append a truncation notice itself.

    NOTE(review): an earlier comment claimed the limit was padded by the
    notice length (``LONGEST_TRUNCATION_MESSAGE_LEN``). The code never did
    that, and that behavior is kept.

    Args:
        output: Text to clamp.
        character_limit: Maximum number of characters to keep.
        max_lines: Maximum number of lines to keep.

    Returns:
        ``(text, truncated)``, where ``truncated`` is True if anything was
        removed.
    """
    lines = output.split("\n")
    if len(output) <= character_limit and len(lines) <= max_lines:
        return output, False

    # Running length of "\n".join(lines). Removing a line also removes the
    # newline that separated it from its predecessor, hence the +1.
    remaining_length = len(output)
    while remaining_length > character_limit and lines:
        remaining_length -= len(lines.pop()) + 1

    if not lines:
        # Even the first line alone is too long: keep a character prefix so
        # something useful survives.
        return output[: max(character_limit, 0)], True

    return "\n".join(lines[:max_lines]), True
391
+
392
+
393
@overload
def truncate_message(response: CodeExecutionRequest) -> CodeExecutionRequest: ...
@overload
def truncate_message(response: CodeExecutionResponse) -> CodeExecutionResponse: ...
@overload
def truncate_message(
    response: StreamingCodeExecutionResponse,
) -> StreamingCodeExecutionResponse: ...
@overload
def truncate_message(
    response: StreamingCodeExecutionResponseChunk,
) -> StreamingCodeExecutionResponseChunk: ...
@overload
def truncate_message(response: FileWriteRequest) -> FileWriteRequest: ...
@overload
def truncate_message(response: FileWriteResponse) -> FileWriteResponse: ...
@overload
def truncate_message(
    response: GetFileAttachmentRequest,
) -> GetFileAttachmentRequest: ...
@overload
def truncate_message(
    response: GetFileAttachmentResponse,
) -> GetFileAttachmentResponse: ...
@overload
def truncate_message(response: ListFilesRequest) -> ListFilesRequest: ...
@overload
def truncate_message(response: ListFilesResponse) -> ListFilesResponse: ...
@overload
def truncate_message(response: GetMatchingFilesRequest) -> GetMatchingFilesRequest: ...
@overload
def truncate_message(
    response: GetMatchingFilesResponse,
) -> GetMatchingFilesResponse: ...
@overload
def truncate_message(response: SystemContextRequest) -> SystemContextRequest: ...
@overload
def truncate_message(response: SystemContextResponse) -> SystemContextResponse: ...
@overload
def truncate_message(
    response: RemoteExecutionRequestType,
) -> RemoteExecutionRequestType: ...
@overload
def truncate_message(
    response: RemoteExecutionResponseType,
) -> RemoteExecutionResponseType: ...
@overload
def truncate_message(response: RemoteExecutionMessage) -> RemoteExecutionMessage: ...


def truncate_message(
    response: RemoteExecutionMessage,
) -> RemoteExecutionMessage:
    """Clamp oversized text on a message in place and return the same object.

    The following messages are clamped with ``truncate_output``:

    - The ``content`` of code-execution, streaming code-execution (and
      chunk), and single file-attachment responses.
    - The ``content`` of command responses whose subcommand is not
      ``"codebase_context"``.
    - The ``content`` of each attachment in a ``GetFileAttachmentsResponse``.

    Whenever text is cut, the corresponding ``truncated`` flag is set to
    True. All other messages are returned unchanged.
    """
    if isinstance(
        response,
        (
            CodeExecutionResponse,
            GetFileAttachmentResponse,
            StreamingCodeExecutionResponse,
            StreamingCodeExecutionResponseChunk,
        ),
    ) or (
        isinstance(response, CommandResponse)
        and response.subcommand != "codebase_context"
    ):
        clipped, was_clipped = truncate_output(response.content)
        response.content = clipped
        if was_clipped:
            response.truncated = True
    elif isinstance(response, GetFileAttachmentsResponse):
        for attachment in response.file_attachments:
            clipped, was_clipped = truncate_output(attachment.content)
            attachment.content = clipped
            if was_clipped:
                attachment.truncated = True
    return response
474
+
475
+
476
+ ### Error Handling
477
+
478
+
479
+ def format_attachment_data(
480
+ attachment_lines: list[str] | None = None,
481
+ ) -> str | None:
482
+ if not attachment_lines:
483
+ return None
484
+ log_attachment_str = "\n".join(attachment_lines)
485
+ return log_attachment_str
486
+
487
+
488
def format_error_log(
    exc: Exception,
    chat_uuid: str | None = None,
    attachment_lines: list[str] | None = None,
) -> CLIErrorLog | None:
    """Build a ``CLIErrorLog`` for ``exc`` from its Sentry event.

    Args:
        exc: The exception to report.
        chat_uuid: Optional chat the error occurred in.
        attachment_lines: Optional extra lines, joined with newlines into
            ``attachment_data``.

    Returns:
        The log record, carrying the JSON-encoded Sentry event and the
        installed CLI version. Returns ``None`` if the serialized event
        cannot be encoded as JSON.
    """
    exc_info = exc_info_from_error(exc)
    event, _ = event_from_exception(exc_info)
    attachment_data = format_attachment_data(attachment_lines)
    version = get_installed_version()

    try:
        event_data = json.dumps(serialize(event))  # type: ignore
    except (TypeError, ValueError):
        # json.dumps raises TypeError for objects it cannot encode and
        # ValueError for e.g. circular references. It never raises
        # JSONDecodeError, which is a *decoding* error, so catching only that
        # left this fallback dead.
        return None

    return CLIErrorLog(
        event_data=event_data,
        attachment_data=attachment_data,
        version=version,
        chat_uuid=chat_uuid,
    )
509
+
510
+
511
+ ### Websockets
512
+
513
+
514
ws_logger = logging.getLogger("WebsocketUtils")


def ws_retry(
    connection_name: str,
    max_retries: int = 5,
) -> Callable[[Callable[..., Awaitable[None]]], Callable[..., Awaitable[None]]]:
    """Return a decorator that re-runs an async function after websocket drops.

    The wrapped coroutine function is re-invoked immediately, with the same
    arguments, whenever it raises a websocket ``ConnectionClosed`` or a
    timeout. After ``max_retries`` re-invocations, the next such failure is
    logged at error level and re-raised. Any other exception propagates
    unchanged.

    NOTE(review): the retry counter lives for one call of the wrapped
    function and is never reset. For a long-lived listener, disconnects
    spread over its whole lifetime count toward ``max_retries``. There is
    also no delay between attempts.

    Args:
        connection_name: Human-readable name used in log messages
            (capitalized).
        max_retries: Number of re-invocations allowed before giving up.
    """
    # Local import: on Python 3.10 ``asyncio.TimeoutError`` is a different
    # class from the builtin ``TimeoutError``. The two were unified in 3.11,
    # where listing both is harmless.
    import asyncio

    connection_name = connection_name.capitalize()
    reconnect_msg = f"{connection_name} reconnecting."
    disconnect_msg = f"{connection_name} connection closed."
    max_disconnect_msg = (
        f"{connection_name} connection closed {max_retries} times, exiting."
    )

    def decorator(
        f: Callable[..., Awaitable[None]],
    ) -> Callable[..., Awaitable[None]]:
        @wraps(f)
        async def wrapped(*args: Any, **kwargs: Any) -> None:
            retries = 0
            while True:
                try:
                    return await f(*args, **kwargs)
                except (
                    websockets.exceptions.ConnectionClosed,
                    TimeoutError,
                    asyncio.TimeoutError,
                ):
                    ws_logger.warning(disconnect_msg)
                    if retries >= max_retries:
                        # Out of retries: log an error and re-raise the
                        # original exception with its traceback intact.
                        ws_logger.error(max_disconnect_msg)
                        raise
                    retries += 1
                    ws_logger.warning(reconnect_msg)

        return wrapped

    return decorator
557
+
558
+
559
async def safe_read_file(path: FilePath) -> str:
    """Read a text file as UTF-8, recovering from mixed or legacy encodings.

    The fast path is a strict UTF-8 read. If that fails, the raw bytes go to
    ``smart_decode``, which repairs embedded cp1252 bytes and detects the
    encoding.

    Raises:
        UnicodeDecodeError: The original UTF-8 error, if ``smart_decode``
            cannot determine an encoding.
    """
    file_path = AsyncPath(path)

    try:
        return await file_path.read_text(encoding="utf-8")
    except UnicodeDecodeError:
        # Potentially a wacky encoding or mixture of encodings; try to recover.
        raw = await file_path.read_bytes()
        # smart_decode already applies UnicodeDammit.detwingle (mixed utf-8 +
        # cp1252), so the bytes are passed as-is rather than detwingled twice.
        decode_result = smart_decode(raw)
        if decode_result:
            # First item in the tuple is the decoded text.
            return decode_result[0]
        raise
578
+
579
+
580
async def safe_write_file(path: FilePath, content: str) -> None:
    """Write ``content`` to ``path`` as UTF-8 text using anyio's async ``Path``."""
    destination = AsyncPath(path)
    await destination.write_text(content, encoding="utf-8")
582
+
583
+
584
def smart_decode(b: bytes) -> tuple[str, str] | None:
    """Decode bytes of unknown encoding into ``(text, encoding)``, or None.

    Mixed cp1252 + UTF-8 content is repaired first with
    ``UnicodeDammit.detwingle``. This can happen on Windows, e.g. when a user
    pastes Windows "smart quotes" into a UTF-8 file. ``UnicodeDammit`` then
    chooses an encoding, trying UTF-8 and cp1252 first.

    Returns:
        ``(decoded_text, detected_encoding)``, or ``None`` if no encoding
        decodes the bytes cleanly.
    """
    b = UnicodeDammit.detwingle(b)

    encoding = UnicodeDammit(
        b, known_definite_encodings=["utf-8", "cp1252"]
    ).original_encoding

    if not encoding:
        return None

    try:
        return (b.decode(encoding=encoding), encoding)
    except UnicodeDecodeError:
        # When no candidate decodes strictly, bs4's UnicodeDammit can fall
        # back to decoding with replacement characters and still report that
        # encoding. A strict decode then fails. Report "could not decode" as
        # this function's None return promises, instead of raising.
        return None
File without changes