dycw-utilities 0.135.0__py3-none-any.whl → 0.178.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dycw-utilities might be problematic. Click here for more details.

Files changed (97)
  1. dycw_utilities-0.178.1.dist-info/METADATA +34 -0
  2. dycw_utilities-0.178.1.dist-info/RECORD +105 -0
  3. dycw_utilities-0.178.1.dist-info/WHEEL +4 -0
  4. dycw_utilities-0.178.1.dist-info/entry_points.txt +4 -0
  5. utilities/__init__.py +1 -1
  6. utilities/altair.py +13 -10
  7. utilities/asyncio.py +312 -787
  8. utilities/atomicwrites.py +18 -6
  9. utilities/atools.py +64 -4
  10. utilities/cachetools.py +9 -6
  11. utilities/click.py +195 -77
  12. utilities/concurrent.py +1 -1
  13. utilities/contextlib.py +216 -17
  14. utilities/contextvars.py +20 -1
  15. utilities/cryptography.py +3 -3
  16. utilities/dataclasses.py +15 -28
  17. utilities/docker.py +387 -0
  18. utilities/enum.py +2 -2
  19. utilities/errors.py +17 -3
  20. utilities/fastapi.py +28 -59
  21. utilities/fpdf2.py +2 -2
  22. utilities/functions.py +24 -269
  23. utilities/git.py +9 -30
  24. utilities/grp.py +28 -0
  25. utilities/gzip.py +31 -0
  26. utilities/http.py +3 -2
  27. utilities/hypothesis.py +513 -159
  28. utilities/importlib.py +17 -1
  29. utilities/inflect.py +12 -4
  30. utilities/iterables.py +33 -58
  31. utilities/jinja2.py +148 -0
  32. utilities/json.py +70 -0
  33. utilities/libcst.py +38 -17
  34. utilities/lightweight_charts.py +4 -7
  35. utilities/logging.py +136 -93
  36. utilities/math.py +8 -4
  37. utilities/more_itertools.py +43 -45
  38. utilities/operator.py +27 -27
  39. utilities/orjson.py +189 -36
  40. utilities/os.py +61 -4
  41. utilities/packaging.py +115 -0
  42. utilities/parse.py +8 -5
  43. utilities/pathlib.py +269 -40
  44. utilities/permissions.py +298 -0
  45. utilities/platform.py +7 -6
  46. utilities/polars.py +1205 -413
  47. utilities/polars_ols.py +1 -1
  48. utilities/postgres.py +408 -0
  49. utilities/pottery.py +43 -19
  50. utilities/pqdm.py +3 -3
  51. utilities/psutil.py +5 -57
  52. utilities/pwd.py +28 -0
  53. utilities/pydantic.py +4 -52
  54. utilities/pydantic_settings.py +240 -0
  55. utilities/pydantic_settings_sops.py +76 -0
  56. utilities/pyinstrument.py +7 -7
  57. utilities/pytest.py +104 -143
  58. utilities/pytest_plugins/__init__.py +1 -0
  59. utilities/pytest_plugins/pytest_randomly.py +23 -0
  60. utilities/pytest_plugins/pytest_regressions.py +56 -0
  61. utilities/pytest_regressions.py +26 -46
  62. utilities/random.py +11 -6
  63. utilities/re.py +1 -1
  64. utilities/redis.py +220 -343
  65. utilities/sentinel.py +10 -0
  66. utilities/shelve.py +4 -1
  67. utilities/shutil.py +25 -0
  68. utilities/slack_sdk.py +35 -104
  69. utilities/sqlalchemy.py +496 -471
  70. utilities/sqlalchemy_polars.py +29 -54
  71. utilities/string.py +2 -3
  72. utilities/subprocess.py +1977 -0
  73. utilities/tempfile.py +112 -4
  74. utilities/testbook.py +50 -0
  75. utilities/text.py +174 -42
  76. utilities/throttle.py +158 -0
  77. utilities/timer.py +2 -2
  78. utilities/traceback.py +70 -35
  79. utilities/types.py +102 -30
  80. utilities/typing.py +479 -19
  81. utilities/uuid.py +42 -5
  82. utilities/version.py +27 -26
  83. utilities/whenever.py +1559 -361
  84. utilities/zoneinfo.py +80 -22
  85. dycw_utilities-0.135.0.dist-info/METADATA +0 -39
  86. dycw_utilities-0.135.0.dist-info/RECORD +0 -96
  87. dycw_utilities-0.135.0.dist-info/WHEEL +0 -4
  88. dycw_utilities-0.135.0.dist-info/licenses/LICENSE +0 -21
  89. utilities/aiolimiter.py +0 -25
  90. utilities/arq.py +0 -216
  91. utilities/eventkit.py +0 -388
  92. utilities/luigi.py +0 -183
  93. utilities/period.py +0 -152
  94. utilities/pudb.py +0 -62
  95. utilities/python_dotenv.py +0 -101
  96. utilities/streamlit.py +0 -105
  97. utilities/typed_settings.py +0 -123
utilities/importlib.py CHANGED
@@ -1,7 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import importlib.resources
3
4
  from importlib import import_module
4
5
  from importlib.util import find_spec
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING
8
+
9
+ from utilities.errors import ImpossibleCaseError
10
+
11
+ if TYPE_CHECKING:
12
+ from importlib.resources import Anchor
13
+
14
+
15
+ def files(*, anchor: Anchor | None = None) -> Path:
16
+ """Get the path for an anchor."""
17
+ path = importlib.resources.files(anchor)
18
+ if isinstance(path, Path):
19
+ return path
20
+ raise ImpossibleCaseError(case=[f"{path=}"]) # pragma: no cover
5
21
 
6
22
 
7
23
  def is_valid_import(module: str, /, *, name: str | None = None) -> bool:
@@ -15,4 +31,4 @@ def is_valid_import(module: str, /, *, name: str | None = None) -> bool:
15
31
  return hasattr(mod, name)
16
32
 
17
33
 
18
- __all__ = ["is_valid_import"]
34
+ __all__ = ["files", "is_valid_import"]
utilities/inflect.py CHANGED
@@ -1,17 +1,25 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import cast
3
+ from collections.abc import Sized
4
+ from typing import assert_never, cast
4
5
 
5
6
  from inflect import Word, engine
6
7
 
7
8
  _ENGINE = engine()
8
9
 
9
10
 
10
- def counted_noun(num: int, noun: str, /) -> str:
11
+ def counted_noun(obj: int | Sized, noun: str, /) -> str:
11
12
  """Construct a counted noun."""
13
+ match obj:
14
+ case int() as count:
15
+ ...
16
+ case Sized() as sized:
17
+ count = len(sized)
18
+ case never:
19
+ assert_never(never)
12
20
  word = cast("Word", noun)
13
- sin_or_plu = _ENGINE.plural_noun(word, count=num)
14
- return f"{num} {sin_or_plu}"
21
+ sin_or_plu = _ENGINE.plural_noun(word, count=count)
22
+ return f"{count} {sin_or_plu}"
15
23
 
16
24
 
17
25
  __all__ = ["counted_noun"]
utilities/iterables.py CHANGED
@@ -18,7 +18,7 @@ from enum import Enum
18
18
  from functools import cmp_to_key, partial, reduce
19
19
  from itertools import accumulate, chain, groupby, islice, pairwise, product
20
20
  from math import isnan
21
- from operator import add, itemgetter, or_
21
+ from operator import add, or_
22
22
  from typing import (
23
23
  TYPE_CHECKING,
24
24
  Any,
@@ -31,7 +31,6 @@ from typing import (
31
31
  )
32
32
 
33
33
  from utilities.errors import ImpossibleCaseError
34
- from utilities.functions import ensure_hashable, ensure_str
35
34
  from utilities.math import (
36
35
  _CheckIntegerEqualError,
37
36
  _CheckIntegerEqualOrApproxError,
@@ -40,13 +39,13 @@ from utilities.math import (
40
39
  check_integer,
41
40
  )
42
41
  from utilities.reprlib import get_repr
43
- from utilities.sentinel import Sentinel, sentinel
42
+ from utilities.sentinel import Sentinel, is_sentinel, sentinel
44
43
  from utilities.types import SupportsAdd, SupportsLT
45
44
 
46
45
  if TYPE_CHECKING:
47
46
  from types import NoneType
48
47
 
49
- from utilities.types import MaybeIterable, MaybeIterableHashable, Sign, StrMapping
48
+ from utilities.types import MaybeIterable, Sign, StrMapping
50
49
 
51
50
 
52
51
  ##
@@ -66,16 +65,6 @@ def always_iterable[T](obj: MaybeIterable[T], /) -> Iterable[T]:
66
65
  ##
67
66
 
68
67
 
69
- def always_iterable_hashable[T](
70
- obj: MaybeIterable[T] | None, /
71
- ) -> MaybeIterableHashable[T] | None:
72
- """Ensure an object is always hashable."""
73
- return None if obj is None else tuple(always_iterable(obj))
74
-
75
-
76
- ##
77
-
78
-
79
68
  def apply_bijection[T, U](
80
69
  func: Callable[[T], U], iterable: Iterable[T], /
81
70
  ) -> Mapping[T, U]:
@@ -246,8 +235,7 @@ def check_iterables_equal(left: Iterable[Any], right: Iterable[Any], /) -> None:
246
235
  if lv != rv:
247
236
  errors.append((i, lv, rv))
248
237
  except ValueError as error:
249
- msg = ensure_str(one(error.args))
250
- match msg:
238
+ match one(error.args):
251
239
  case "zip() argument 2 is longer than argument 1":
252
240
  state = "right_longer"
253
241
  case "zip() argument 2 is shorter than argument 1":
@@ -295,7 +283,7 @@ class CheckIterablesEqualError[T](Exception):
295
283
  yield "right was longer"
296
284
  case None:
297
285
  pass
298
- case _ as never:
286
+ case never:
299
287
  assert_never(never)
300
288
 
301
289
 
@@ -690,7 +678,7 @@ def cmp_nullable[T: SupportsLT](x: T | None, y: T | None, /) -> Sign:
690
678
  return 1
691
679
  case _, _:
692
680
  return cast("Sign", (x > y) - (x < y))
693
- case _ as never:
681
+ case never:
694
682
  assert_never(never)
695
683
 
696
684
 
@@ -705,18 +693,6 @@ def chunked[T](iterable: Iterable[T], n: int, /) -> Iterator[Sequence[T]]:
705
693
  ##
706
694
 
707
695
 
708
- def ensure_hashables(
709
- *args: Any, **kwargs: Any
710
- ) -> tuple[list[Hashable], dict[str, Hashable]]:
711
- """Ensure a set of positional & keyword arguments are all hashable."""
712
- hash_args = list(map(ensure_hashable, args))
713
- hash_kwargs = {k: ensure_hashable(v) for k, v in kwargs.items()}
714
- return hash_args, hash_kwargs
715
-
716
-
717
- ##
718
-
719
-
720
696
  def ensure_iterable(obj: Any, /) -> Iterable[Any]:
721
697
  """Ensure an object is iterable."""
722
698
  if is_iterable(obj):
@@ -755,6 +731,26 @@ class EnsureIterableNotStrError(Exception):
755
731
  ##
756
732
 
757
733
 
734
+ _EDGE: int = 5
735
+
736
+
737
+ def enumerate_with_edge[T](
738
+ iterable: Iterable[T], /, *, start: int = 0, edge: int = _EDGE
739
+ ) -> Iterator[tuple[int, int, bool, T]]:
740
+ """Enumerate an iterable, with the edge items marked."""
741
+ as_list = list(iterable)
742
+ total = len(as_list)
743
+ indices = set(range(edge)) | set(range(total)[-edge:])
744
+ is_edge = (i in indices for i in range(total))
745
+ for (i, value), is_edge_i in zip(
746
+ enumerate(as_list, start=start), is_edge, strict=True
747
+ ):
748
+ yield i, total, is_edge_i, value
749
+
750
+
751
+ ##
752
+
753
+
758
754
  def expanding_window[T](iterable: Iterable[T], /) -> islice[list[T]]:
759
755
  """Yield an expanding window over an iterable."""
760
756
 
@@ -811,24 +807,6 @@ def filter_include_and_exclude[T, U](
811
807
  ##
812
808
 
813
809
 
814
- def group_consecutive_integers(iterable: Iterable[int], /) -> Iterable[tuple[int, int]]:
815
- """Group consecutive integers."""
816
- integers = sorted(iterable)
817
- for _, group in groupby(enumerate(integers), key=lambda x: x[1] - x[0]):
818
- as_list = list(map(itemgetter(1), group))
819
- yield as_list[0], as_list[-1]
820
-
821
-
822
- def ungroup_consecutive_integers(
823
- iterable: Iterable[tuple[int, int]], /
824
- ) -> Iterable[int]:
825
- """Ungroup consecutive integers."""
826
- return chain.from_iterable(range(start, end + 1) for start, end in iterable)
827
-
828
-
829
- ##
830
-
831
-
832
810
  @overload
833
811
  def groupby_lists[T](
834
812
  iterable: Iterable[T], /, *, key: None = None
@@ -955,7 +933,7 @@ class MergeStrMappingsError(Exception):
955
933
 
956
934
  def one[T](*iterables: Iterable[T]) -> T:
957
935
  """Return the unique value in a set of iterables."""
958
- it = iter(chain(*iterables))
936
+ it = chain(*iterables)
959
937
  try:
960
938
  first = next(it)
961
939
  except StopIteration:
@@ -1048,7 +1026,7 @@ def one_str(
1048
1026
  it = (t for t in as_list if t.startswith(text))
1049
1027
  case True, False:
1050
1028
  it = (t for t in as_list if t.lower().startswith(text.lower()))
1051
- case _ as never:
1029
+ case never:
1052
1030
  assert_never(never)
1053
1031
  try:
1054
1032
  return one(it)
@@ -1089,7 +1067,7 @@ class OneStrEmptyError(OneStrError):
1089
1067
  tail = f"any string starting with {self.text!r}"
1090
1068
  case True, False:
1091
1069
  tail = f"any string starting with {self.text!r} (modulo case)"
1092
- case _ as never:
1070
+ case never:
1093
1071
  assert_never(never)
1094
1072
  return f"{head} {tail}"
1095
1073
 
@@ -1111,7 +1089,7 @@ class OneStrNonUniqueError(OneStrError):
1111
1089
  mid = f"exactly one string starting with {self.text!r}"
1112
1090
  case True, False:
1113
1091
  mid = f"exactly one string starting with {self.text!r} (modulo case)"
1114
- case _ as never:
1092
+ case never:
1115
1093
  assert_never(never)
1116
1094
  return f"{head} {mid}; got {self.first!r}, {self.second!r} and perhaps more"
1117
1095
 
@@ -1245,7 +1223,7 @@ def reduce_mappings[K, V, W](
1245
1223
  ) -> Mapping[K, V | W]:
1246
1224
  """Reduce a function over the values of a set of mappings."""
1247
1225
  chained = chain_mappings(*sequence)
1248
- if isinstance(initial, Sentinel):
1226
+ if is_sentinel(initial):
1249
1227
  func2 = cast("Callable[[V, V], V]", func)
1250
1228
  return {k: reduce(func2, v) for k, v in chained.items()}
1251
1229
  func2 = cast("Callable[[W, V], W]", func)
@@ -1362,7 +1340,7 @@ def _sort_iterable_cmp_floats(x: float, y: float, /) -> Sign:
1362
1340
  return -1
1363
1341
  case False, False:
1364
1342
  return cast("Sign", (x > y) - (x < y))
1365
- case _ as never:
1343
+ case never:
1366
1344
  assert_never(never)
1367
1345
 
1368
1346
 
@@ -1469,7 +1447,6 @@ __all__ = [
1469
1447
  "ResolveIncludeAndExcludeError",
1470
1448
  "SortIterableError",
1471
1449
  "always_iterable",
1472
- "always_iterable_hashable",
1473
1450
  "apply_bijection",
1474
1451
  "apply_to_tuple",
1475
1452
  "apply_to_varargs",
@@ -1489,12 +1466,11 @@ __all__ = [
1489
1466
  "check_unique_modulo_case",
1490
1467
  "chunked",
1491
1468
  "cmp_nullable",
1492
- "ensure_hashables",
1493
1469
  "ensure_iterable",
1494
1470
  "ensure_iterable_not_str",
1471
+ "enumerate_with_edge",
1495
1472
  "expanding_window",
1496
1473
  "filter_include_and_exclude",
1497
- "group_consecutive_integers",
1498
1474
  "groupby_lists",
1499
1475
  "hashable_to_iterable",
1500
1476
  "is_iterable",
@@ -1517,6 +1493,5 @@ __all__ = [
1517
1493
  "sum_mappings",
1518
1494
  "take",
1519
1495
  "transpose",
1520
- "ungroup_consecutive_integers",
1521
1496
  "unique_everseen",
1522
1497
  ]
utilities/jinja2.py ADDED
@@ -0,0 +1,148 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Any, Literal, assert_never, override
5
+
6
+ from jinja2 import BaseLoader, BytecodeCache, Environment, FileSystemLoader, Undefined
7
+ from jinja2.defaults import (
8
+ BLOCK_END_STRING,
9
+ BLOCK_START_STRING,
10
+ COMMENT_END_STRING,
11
+ COMMENT_START_STRING,
12
+ KEEP_TRAILING_NEWLINE,
13
+ LINE_COMMENT_PREFIX,
14
+ LINE_STATEMENT_PREFIX,
15
+ LSTRIP_BLOCKS,
16
+ NEWLINE_SEQUENCE,
17
+ TRIM_BLOCKS,
18
+ VARIABLE_END_STRING,
19
+ VARIABLE_START_STRING,
20
+ )
21
+
22
+ from utilities.atomicwrites import writer
23
+ from utilities.text import kebab_case, pascal_case, snake_case
24
+
25
+ if TYPE_CHECKING:
26
+ from collections.abc import Callable, Sequence
27
+ from pathlib import Path
28
+
29
+ from jinja2.ext import Extension
30
+
31
+ from utilities.types import StrMapping
32
+
33
+
34
+ class EnhancedEnvironment(Environment):
35
+ """Environment with enhanced features."""
36
+
37
+ @override
38
+ def __init__(
39
+ self,
40
+ block_start_string: str = BLOCK_START_STRING,
41
+ block_end_string: str = BLOCK_END_STRING,
42
+ variable_start_string: str = VARIABLE_START_STRING,
43
+ variable_end_string: str = VARIABLE_END_STRING,
44
+ comment_start_string: str = COMMENT_START_STRING,
45
+ comment_end_string: str = COMMENT_END_STRING,
46
+ line_statement_prefix: str | None = LINE_STATEMENT_PREFIX,
47
+ line_comment_prefix: str | None = LINE_COMMENT_PREFIX,
48
+ trim_blocks: bool = TRIM_BLOCKS,
49
+ lstrip_blocks: bool = LSTRIP_BLOCKS,
50
+ newline_sequence: Literal["\n", "\r\n", "\r"] = NEWLINE_SEQUENCE,
51
+ keep_trailing_newline: bool = KEEP_TRAILING_NEWLINE,
52
+ extensions: Sequence[str | type[Extension]] = (),
53
+ optimized: bool = True,
54
+ undefined: type[Undefined] = Undefined,
55
+ finalize: Callable[..., Any] | None = None,
56
+ autoescape: bool | Callable[[str | None], bool] = False,
57
+ loader: BaseLoader | None = None,
58
+ cache_size: int = 400,
59
+ auto_reload: bool = True,
60
+ bytecode_cache: BytecodeCache | None = None,
61
+ enable_async: bool = False,
62
+ ) -> None:
63
+ super().__init__(
64
+ block_start_string,
65
+ block_end_string,
66
+ variable_start_string,
67
+ variable_end_string,
68
+ comment_start_string,
69
+ comment_end_string,
70
+ line_statement_prefix,
71
+ line_comment_prefix,
72
+ trim_blocks,
73
+ lstrip_blocks,
74
+ newline_sequence,
75
+ keep_trailing_newline,
76
+ extensions,
77
+ optimized,
78
+ undefined,
79
+ finalize,
80
+ autoescape,
81
+ loader,
82
+ cache_size,
83
+ auto_reload,
84
+ bytecode_cache,
85
+ enable_async,
86
+ )
87
+ self.filters["kebab"] = kebab_case
88
+ self.filters["pascal"] = pascal_case
89
+ self.filters["snake"] = snake_case
90
+
91
+
92
+ @dataclass(order=True, unsafe_hash=True, kw_only=True, slots=True)
93
+ class TemplateJob:
94
+ """A template with an associated rendering job."""
95
+
96
+ template: Path
97
+ kwargs: StrMapping
98
+ target: Path
99
+ mode: Literal["write", "append"] = "write"
100
+
101
+ def __post_init__(self) -> None:
102
+ if not self.template.exists():
103
+ raise _TemplateJobTemplateDoesNotExistError(path=self.template)
104
+ if (self.mode == "append") and not self.target.exists():
105
+ raise _TemplateJobTargetDoesNotExistError(path=self.template)
106
+
107
+ def run(self) -> None:
108
+ """Run the job."""
109
+ match self.mode:
110
+ case "write":
111
+ with writer(self.target, overwrite=True) as temp:
112
+ _ = temp.write_text(self.rendered)
113
+ case "append":
114
+ with self.target.open(mode="a") as fh:
115
+ _ = fh.write(self.rendered)
116
+ case never:
117
+ assert_never(never)
118
+
119
+ @property
120
+ def rendered(self) -> str:
121
+ """The template, rendered."""
122
+ env = EnhancedEnvironment(loader=FileSystemLoader(self.template.parent))
123
+ return env.get_template(self.template.name).render(self.kwargs)
124
+
125
+
126
+ @dataclass(kw_only=True, slots=True)
127
+ class TemplateJobError(Exception): ...
128
+
129
+
130
+ @dataclass(kw_only=True, slots=True)
131
+ class _TemplateJobTemplateDoesNotExistError(TemplateJobError):
132
+ path: Path
133
+
134
+ @override
135
+ def __str__(self) -> str:
136
+ return f"Template {str(self.path)!r} does not exist"
137
+
138
+
139
+ @dataclass(kw_only=True, slots=True)
140
+ class _TemplateJobTargetDoesNotExistError(TemplateJobError):
141
+ path: Path
142
+
143
+ @override
144
+ def __str__(self) -> str:
145
+ return f"Target {str(self.path)!r} does not exist"
146
+
147
+
148
+ __all__ = ["EnhancedEnvironment", "TemplateJob", "TemplateJobError"]
utilities/json.py ADDED
@@ -0,0 +1,70 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import suppress
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from subprocess import check_output
7
+ from typing import TYPE_CHECKING, assert_never, overload, override
8
+
9
+ from utilities.atomicwrites import writer
10
+ from utilities.gzip import write_binary
11
+
12
+ if TYPE_CHECKING:
13
+ from utilities.types import PathLike
14
+
15
+
16
+ ##
17
+
18
+
19
+ @overload
20
+ def run_prettier(source: bytes, /) -> bytes: ...
21
+ @overload
22
+ def run_prettier(source: str, /) -> str: ...
23
+ @overload
24
+ def run_prettier(source: Path, /) -> None: ...
25
+ def run_prettier(source: bytes | str | Path, /) -> bytes | str | None:
26
+ """Run `prettier` on a string/path."""
27
+ match source: # skipif-ci
28
+ case bytes() as data:
29
+ return _run_prettier_core(data, text=False)
30
+ case str() as text:
31
+ if (path := Path(text)).is_file():
32
+ return run_prettier(path)
33
+ return _run_prettier_core(text, text=True)
34
+ case Path() as path:
35
+ result = run_prettier(path.read_bytes())
36
+ with writer(path, overwrite=True) as temp:
37
+ _ = temp.write_bytes(result)
38
+ return None
39
+ case never:
40
+ assert_never(never)
41
+
42
+
43
+ def _run_prettier_core(data: bytes | str, /, *, text: bool) -> bytes | str:
44
+ """Run `prettier` on a string/path."""
45
+ try: # skipif-ci
46
+ return check_output(["prettier", "--parser=json"], input=data, text=text)
47
+ except FileNotFoundError: # pragma: no cover
48
+ raise RunPrettierError from None
49
+
50
+
51
+ @dataclass(kw_only=True, slots=True)
52
+ class RunPrettierError(Exception):
53
+ @override
54
+ def __str__(self) -> str:
55
+ return "Unable to find 'prettier'" # pragma: no cover
56
+
57
+
58
+ ##
59
+
60
+
61
+ def write_formatted_json(
62
+ data: bytes, path: PathLike, /, *, compress: bool = False, overwrite: bool = False
63
+ ) -> None:
64
+ """Write a formatted byte string to disk."""
65
+ with suppress(RunPrettierError):
66
+ data = run_prettier(data)
67
+ write_binary(data, path, compress=compress, overwrite=overwrite)
68
+
69
+
70
+ __all__ = ["RunPrettierError", "run_prettier", "write_formatted_json"]
utilities/libcst.py CHANGED
@@ -23,16 +23,6 @@ from libcst import (
23
23
  from utilities.errors import ImpossibleCaseError
24
24
 
25
25
 
26
- def generate_from_import(
27
- module: str, name: str, /, *, asname: str | None = None
28
- ) -> ImportFrom:
29
- """Generate an `ImportFrom` object."""
30
- alias = ImportAlias(
31
- name=Name(name), asname=AsName(Name(asname)) if asname else None
32
- )
33
- return ImportFrom(module=split_dotted_str(module), names=[alias])
34
-
35
-
36
26
  def generate_f_string(var: str, suffix: str, /) -> FormattedString:
37
27
  """Generate an f-string."""
38
28
  return FormattedString([
@@ -49,6 +39,36 @@ def generate_import(module: str, /, *, asname: str | None = None) -> Import:
49
39
  return Import(names=[alias])
50
40
 
51
41
 
42
+ def generate_import_from(
43
+ module: str, name: str, /, *, asname: str | None = None
44
+ ) -> ImportFrom:
45
+ """Generate an `ImportFrom` object."""
46
+ match name, asname:
47
+ case "*", None:
48
+ names = ImportStar()
49
+ case "*", str():
50
+ raise GenerateImportFromError(module=module, asname=asname)
51
+ case _, None:
52
+ alias = ImportAlias(name=Name(name))
53
+ names = [alias]
54
+ case _, str():
55
+ alias = ImportAlias(name=Name(name), asname=AsName(Name(asname)))
56
+ names = [alias]
57
+ case never:
58
+ assert_never(never)
59
+ return ImportFrom(module=split_dotted_str(module), names=names)
60
+
61
+
62
+ @dataclass(kw_only=True, slots=True)
63
+ class GenerateImportFromError(Exception):
64
+ module: str
65
+ asname: str | None = None
66
+
67
+ @override
68
+ def __str__(self) -> str:
69
+ return f"Invalid import: 'from {self.module} import * as {self.asname}'"
70
+
71
+
52
72
  ##
53
73
 
54
74
 
@@ -72,9 +92,9 @@ def parse_import(import_: Import | ImportFrom, /) -> Sequence[_ParseImportOutput
72
92
  return [_parse_import_from_one(module, n) for n in names]
73
93
  case ImportStar():
74
94
  return [_ParseImportOutput(module=module, name="*")]
75
- case _ as never:
95
+ case never:
76
96
  assert_never(never)
77
- case _ as never:
97
+ case never:
78
98
  assert_never(never)
79
99
 
80
100
 
@@ -88,7 +108,7 @@ def _parse_import_from_one(module: str, alias: ImportAlias, /) -> _ParseImportOu
88
108
  return _ParseImportOutput(module=module, name=name)
89
109
  case Attribute() as attr:
90
110
  raise _ParseImportAliasError(module=module, attr=attr)
91
- case _ as never:
111
+ case never:
92
112
  assert_never(never)
93
113
 
94
114
 
@@ -130,7 +150,7 @@ def split_dotted_str(dotted: str, /) -> Name | Attribute:
130
150
 
131
151
  def join_dotted_str(name_or_attr: Name | Attribute, /) -> str:
132
152
  """Join a dotted from from a name/attribute."""
133
- parts: Sequence[str] = []
153
+ parts: list[str] = []
134
154
  curr: BaseExpression | Name | Attribute = name_or_attr
135
155
  while True:
136
156
  match curr:
@@ -142,7 +162,7 @@ def join_dotted_str(name_or_attr: Name | Attribute, /) -> str:
142
162
  curr = value
143
163
  case BaseExpression(): # pragma: no cover
144
164
  raise ImpossibleCaseError(case=[f"{curr=}"])
145
- case _ as never:
165
+ case never:
146
166
  assert_never(never)
147
167
  return ".".join(reversed(parts))
148
168
 
@@ -160,7 +180,7 @@ def render_module(source: str | Module, /) -> str:
160
180
  return text
161
181
  case Module() as module:
162
182
  return render_module(module.code)
163
- case _ as never:
183
+ case never:
164
184
  assert_never(never)
165
185
 
166
186
 
@@ -168,10 +188,11 @@ def render_module(source: str | Module, /) -> str:
168
188
 
169
189
 
170
190
  __all__ = [
191
+ "GenerateImportFromError",
171
192
  "ParseImportError",
172
193
  "generate_f_string",
173
- "generate_from_import",
174
194
  "generate_import",
195
+ "generate_import_from",
175
196
  "join_dotted_str",
176
197
  "parse_import",
177
198
  "render_module",
utilities/lightweight_charts.py CHANGED
@@ -1,10 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- from contextlib import asynccontextmanager
4
3
  from dataclasses import dataclass
5
4
  from typing import TYPE_CHECKING, override
6
5
 
7
6
  from utilities.atomicwrites import writer # pragma: no cover
7
+ from utilities.contextlib import enhanced_async_context_manager
8
8
  from utilities.iterables import OneEmptyError, OneNonUniqueError, one
9
9
  from utilities.reprlib import get_repr
10
10
 
@@ -25,11 +25,8 @@ if TYPE_CHECKING:
25
25
  def save_chart(chart: Chart, path: PathLike, /, *, overwrite: bool = False) -> None:
26
26
  """Atomically save a chart to disk."""
27
27
  chart.show(block=False) # pragma: no cover
28
- with ( # pragma: no cover
29
- writer(path, overwrite=overwrite) as temp,
30
- temp.open(mode="wb") as fh,
31
- ):
32
- _ = fh.write(chart.screenshot())
28
+ with writer(path, overwrite=overwrite) as temp: # pragma: no cover
29
+ _ = temp.write_bytes(chart.screenshot())
33
30
  chart.exit() # pragma: no cover
34
31
 
35
32
 
@@ -81,7 +78,7 @@ class _SetDataFrameNonUniqueError(SetDataFrameError):
81
78
  ##
82
79
 
83
80
 
84
- @asynccontextmanager
81
+ @enhanced_async_context_manager
85
82
  async def yield_chart(chart: Chart, /) -> AsyncIterator[None]:
86
83
  """Yield a chart for visualization in a notebook."""
87
84
  try: # pragma: no cover