dycw-utilities 0.133.7__py3-none-any.whl → 0.134.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dycw_utilities-0.133.7.dist-info → dycw_utilities-0.134.0.dist-info}/METADATA +1 -1
- {dycw_utilities-0.133.7.dist-info → dycw_utilities-0.134.0.dist-info}/RECORD +39 -39
- utilities/__init__.py +1 -1
- utilities/arq.py +13 -20
- utilities/asyncio.py +59 -74
- utilities/atools.py +10 -13
- utilities/cachetools.py +11 -17
- utilities/click.py +31 -51
- utilities/concurrent.py +7 -10
- utilities/dataclasses.py +69 -91
- utilities/enum.py +24 -21
- utilities/eventkit.py +34 -48
- utilities/functions.py +133 -168
- utilities/functools.py +14 -18
- utilities/hypothesis.py +34 -44
- utilities/iterables.py +165 -179
- utilities/luigi.py +3 -15
- utilities/memory_profiler.py +11 -15
- utilities/more_itertools.py +85 -94
- utilities/operator.py +5 -7
- utilities/optuna.py +6 -6
- utilities/pathlib.py +1 -0
- utilities/period.py +7 -9
- utilities/polars.py +5 -16
- utilities/pqdm.py +7 -8
- utilities/pydantic.py +2 -4
- utilities/pytest.py +14 -23
- utilities/python_dotenv.py +5 -9
- utilities/random.py +2 -3
- utilities/redis.py +163 -181
- utilities/slack_sdk.py +2 -2
- utilities/sqlalchemy.py +4 -14
- utilities/timer.py +6 -0
- utilities/typed_settings.py +7 -10
- utilities/types.py +10 -94
- utilities/typing.py +32 -43
- utilities/uuid.py +1 -0
- {dycw_utilities-0.133.7.dist-info → dycw_utilities-0.134.0.dist-info}/WHEEL +0 -0
- {dycw_utilities-0.133.7.dist-info → dycw_utilities-0.134.0.dist-info}/licenses/LICENSE +0 -0
utilities/luigi.py
CHANGED
@@ -2,16 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import
|
6
|
-
TYPE_CHECKING,
|
7
|
-
Any,
|
8
|
-
Literal,
|
9
|
-
TypeVar,
|
10
|
-
assert_never,
|
11
|
-
cast,
|
12
|
-
overload,
|
13
|
-
override,
|
14
|
-
)
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal, assert_never, cast, overload, override
|
15
6
|
|
16
7
|
import luigi
|
17
8
|
from luigi import Parameter, PathParameter, Target, Task
|
@@ -27,9 +18,6 @@ if TYPE_CHECKING:
|
|
27
18
|
from utilities.types import DateTimeRoundUnit, LogLevel, PathLike, ZonedDateTimeLike
|
28
19
|
|
29
20
|
|
30
|
-
_T = TypeVar("_T")
|
31
|
-
|
32
|
-
|
33
21
|
# parameters
|
34
22
|
|
35
23
|
|
@@ -40,7 +28,7 @@ class ZonedDateTimeParameter(Parameter):
|
|
40
28
|
_increment: int
|
41
29
|
|
42
30
|
@override
|
43
|
-
def __init__(
|
31
|
+
def __init__[T](
|
44
32
|
self,
|
45
33
|
default: Any = _no_value,
|
46
34
|
is_global: bool = False,
|
@@ -49,7 +37,7 @@ class ZonedDateTimeParameter(Parameter):
|
|
49
37
|
config_path: None = None,
|
50
38
|
positional: bool = True,
|
51
39
|
always_in_help: bool = False,
|
52
|
-
batch_method: Callable[[Iterable[
|
40
|
+
batch_method: Callable[[Iterable[T]], T] | None = None,
|
53
41
|
visibility: ParameterVisibility = ParameterVisibility.PUBLIC,
|
54
42
|
*,
|
55
43
|
unit: DateTimeRoundUnit = "second",
|
utilities/memory_profiler.py
CHANGED
@@ -2,31 +2,19 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from dataclasses import dataclass
|
4
4
|
from functools import wraps
|
5
|
-
from typing import TYPE_CHECKING, Any,
|
5
|
+
from typing import TYPE_CHECKING, Any, cast
|
6
6
|
|
7
7
|
from memory_profiler import memory_usage
|
8
|
-
from typing_extensions import ParamSpec
|
9
8
|
|
10
9
|
if TYPE_CHECKING:
|
11
10
|
from collections.abc import Callable
|
12
11
|
|
13
|
-
_P = ParamSpec("_P")
|
14
|
-
_T = TypeVar("_T")
|
15
12
|
|
16
|
-
|
17
|
-
@dataclass(kw_only=True, slots=True)
|
18
|
-
class Output(Generic[_T]):
|
19
|
-
"""A function output, and its memory usage."""
|
20
|
-
|
21
|
-
value: _T
|
22
|
-
memory: float
|
23
|
-
|
24
|
-
|
25
|
-
def memory_profiled(func: Callable[_P, _T], /) -> Callable[_P, Output[_T]]:
|
13
|
+
def memory_profiled[**P, T](func: Callable[P, T], /) -> Callable[P, Output[T]]:
|
26
14
|
"""Call a function, but also profile its maximum memory usage."""
|
27
15
|
|
28
16
|
@wraps(func)
|
29
|
-
def wrapped(*args:
|
17
|
+
def wrapped(*args: P.args, **kwargs: P.kwargs) -> Output[T]:
|
30
18
|
memory, value = memory_usage(
|
31
19
|
cast("Any", (func, args, kwargs)), max_usage=True, retval=True
|
32
20
|
)
|
@@ -35,4 +23,12 @@ def memory_profiled(func: Callable[_P, _T], /) -> Callable[_P, Output[_T]]:
|
|
35
23
|
return wrapped
|
36
24
|
|
37
25
|
|
26
|
+
@dataclass(kw_only=True, slots=True)
|
27
|
+
class Output[T]:
|
28
|
+
"""A function output, and its memory usage."""
|
29
|
+
|
30
|
+
value: T
|
31
|
+
memory: float
|
32
|
+
|
33
|
+
|
38
34
|
__all__ = ["Output", "memory_profiled"]
|
utilities/more_itertools.py
CHANGED
@@ -1,16 +1,15 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import builtins
|
4
|
+
from collections.abc import Hashable
|
4
5
|
from dataclasses import dataclass
|
5
6
|
from itertools import islice
|
6
7
|
from textwrap import indent
|
7
8
|
from typing import (
|
8
9
|
TYPE_CHECKING,
|
9
10
|
Any,
|
10
|
-
Generic,
|
11
11
|
Literal,
|
12
12
|
TypeGuard,
|
13
|
-
TypeVar,
|
14
13
|
assert_never,
|
15
14
|
cast,
|
16
15
|
overload,
|
@@ -24,107 +23,99 @@ from utilities.functions import get_class_name
|
|
24
23
|
from utilities.iterables import OneNonUniqueError, one
|
25
24
|
from utilities.reprlib import get_repr
|
26
25
|
from utilities.sentinel import Sentinel, sentinel
|
27
|
-
from utilities.types import THashable
|
28
26
|
|
29
27
|
if TYPE_CHECKING:
|
30
28
|
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
|
31
29
|
|
32
30
|
|
33
|
-
_T = TypeVar("_T")
|
34
|
-
_U = TypeVar("_U")
|
35
|
-
|
36
|
-
|
37
|
-
##
|
38
|
-
|
39
|
-
|
40
31
|
@overload
|
41
|
-
def bucket_mapping(
|
42
|
-
iterable: Iterable[
|
43
|
-
func: Callable[[
|
32
|
+
def bucket_mapping[T, U, UH: Hashable](
|
33
|
+
iterable: Iterable[T],
|
34
|
+
func: Callable[[T], UH],
|
44
35
|
/,
|
45
36
|
*,
|
46
|
-
transform: Callable[[
|
37
|
+
transform: Callable[[T], U],
|
47
38
|
list: bool = False,
|
48
39
|
unique: Literal[True],
|
49
|
-
) -> Mapping[
|
40
|
+
) -> Mapping[UH, U]: ...
|
50
41
|
@overload
|
51
|
-
def bucket_mapping(
|
52
|
-
iterable: Iterable[
|
53
|
-
func: Callable[[
|
42
|
+
def bucket_mapping[T, U, UH: Hashable](
|
43
|
+
iterable: Iterable[T],
|
44
|
+
func: Callable[[T], UH],
|
54
45
|
/,
|
55
46
|
*,
|
56
|
-
transform: Callable[[
|
47
|
+
transform: Callable[[T], U] | None = None,
|
57
48
|
list: bool = False,
|
58
49
|
unique: Literal[True],
|
59
|
-
) -> Mapping[
|
50
|
+
) -> Mapping[UH, T]: ...
|
60
51
|
@overload
|
61
|
-
def bucket_mapping(
|
62
|
-
iterable: Iterable[
|
63
|
-
func: Callable[[
|
52
|
+
def bucket_mapping[T, U, UH: Hashable](
|
53
|
+
iterable: Iterable[T],
|
54
|
+
func: Callable[[T], UH],
|
64
55
|
/,
|
65
56
|
*,
|
66
|
-
transform: Callable[[
|
57
|
+
transform: Callable[[T], U],
|
67
58
|
list: Literal[True],
|
68
|
-
) -> Mapping[
|
59
|
+
) -> Mapping[UH, Sequence[U]]: ...
|
69
60
|
@overload
|
70
|
-
def bucket_mapping(
|
71
|
-
iterable: Iterable[
|
72
|
-
func: Callable[[
|
61
|
+
def bucket_mapping[T, U, UH: Hashable](
|
62
|
+
iterable: Iterable[T],
|
63
|
+
func: Callable[[T], UH],
|
73
64
|
/,
|
74
65
|
*,
|
75
|
-
transform: Callable[[
|
66
|
+
transform: Callable[[T], U],
|
76
67
|
list: bool = False,
|
77
|
-
) -> Mapping[
|
68
|
+
) -> Mapping[UH, Iterator[U]]: ...
|
78
69
|
@overload
|
79
|
-
def bucket_mapping(
|
80
|
-
iterable: Iterable[
|
81
|
-
func: Callable[[
|
70
|
+
def bucket_mapping[T, U, UH: Hashable](
|
71
|
+
iterable: Iterable[T],
|
72
|
+
func: Callable[[T], UH],
|
82
73
|
/,
|
83
74
|
*,
|
84
|
-
transform: Callable[[
|
75
|
+
transform: Callable[[T], U] | None = None,
|
85
76
|
list: Literal[True],
|
86
|
-
) -> Mapping[
|
77
|
+
) -> Mapping[UH, Sequence[T]]: ...
|
87
78
|
@overload
|
88
|
-
def bucket_mapping(
|
89
|
-
iterable: Iterable[
|
90
|
-
func: Callable[[
|
79
|
+
def bucket_mapping[T, U, UH: Hashable](
|
80
|
+
iterable: Iterable[T],
|
81
|
+
func: Callable[[T], UH],
|
91
82
|
/,
|
92
83
|
*,
|
93
|
-
transform: Callable[[
|
84
|
+
transform: Callable[[T], U] | None = None,
|
94
85
|
list: bool = False,
|
95
|
-
) -> Mapping[
|
86
|
+
) -> Mapping[UH, Iterator[T]]: ...
|
96
87
|
@overload
|
97
|
-
def bucket_mapping(
|
98
|
-
iterable: Iterable[
|
99
|
-
func: Callable[[
|
88
|
+
def bucket_mapping[T, U, UH: Hashable](
|
89
|
+
iterable: Iterable[T],
|
90
|
+
func: Callable[[T], UH],
|
100
91
|
/,
|
101
92
|
*,
|
102
|
-
transform: Callable[[
|
93
|
+
transform: Callable[[T], U] | None = None,
|
103
94
|
list: bool = False,
|
104
95
|
unique: bool = False,
|
105
96
|
) -> (
|
106
|
-
Mapping[
|
107
|
-
| Mapping[
|
108
|
-
| Mapping[
|
109
|
-
| Mapping[
|
110
|
-
| Mapping[
|
111
|
-
| Mapping[
|
97
|
+
Mapping[UH, Iterator[T]]
|
98
|
+
| Mapping[UH, Iterator[U]]
|
99
|
+
| Mapping[UH, Sequence[T]]
|
100
|
+
| Mapping[UH, Sequence[U]]
|
101
|
+
| Mapping[UH, T]
|
102
|
+
| Mapping[UH, U]
|
112
103
|
): ...
|
113
|
-
def bucket_mapping(
|
114
|
-
iterable: Iterable[
|
115
|
-
func: Callable[[
|
104
|
+
def bucket_mapping[T, U, UH: Hashable](
|
105
|
+
iterable: Iterable[T],
|
106
|
+
func: Callable[[T], UH],
|
116
107
|
/,
|
117
108
|
*,
|
118
|
-
transform: Callable[[
|
109
|
+
transform: Callable[[T], U] | None = None,
|
119
110
|
list: bool = False, # noqa: A002
|
120
111
|
unique: bool = False,
|
121
112
|
) -> (
|
122
|
-
Mapping[
|
123
|
-
| Mapping[
|
124
|
-
| Mapping[
|
125
|
-
| Mapping[
|
126
|
-
| Mapping[
|
127
|
-
| Mapping[
|
113
|
+
Mapping[UH, Iterator[T]]
|
114
|
+
| Mapping[UH, Iterator[U]]
|
115
|
+
| Mapping[UH, Sequence[T]]
|
116
|
+
| Mapping[UH, Sequence[U]]
|
117
|
+
| Mapping[UH, T]
|
118
|
+
| Mapping[UH, U]
|
128
119
|
):
|
129
120
|
"""Bucket the values of iterable into a mapping."""
|
130
121
|
b = bucket(iterable, func)
|
@@ -143,7 +134,7 @@ def bucket_mapping(
|
|
143
134
|
if not unique:
|
144
135
|
return mapping
|
145
136
|
results = {}
|
146
|
-
error_no_transform: dict[
|
137
|
+
error_no_transform: dict[UH, tuple[T, T]] = {}
|
147
138
|
for key, value in mapping.items():
|
148
139
|
try:
|
149
140
|
results[key] = one(value)
|
@@ -155,8 +146,8 @@ def bucket_mapping(
|
|
155
146
|
|
156
147
|
|
157
148
|
@dataclass(kw_only=True, slots=True)
|
158
|
-
class BucketMappingError
|
159
|
-
errors: Mapping[
|
149
|
+
class BucketMappingError[K: Hashable, V](Exception):
|
150
|
+
errors: Mapping[K, tuple[V, V]]
|
160
151
|
|
161
152
|
@override
|
162
153
|
def __str__(self) -> str:
|
@@ -171,9 +162,9 @@ class BucketMappingError(Exception, Generic[THashable, _U]):
|
|
171
162
|
##
|
172
163
|
|
173
164
|
|
174
|
-
def partition_list(
|
175
|
-
pred: Callable[[
|
176
|
-
) -> tuple[list[
|
165
|
+
def partition_list[T](
|
166
|
+
pred: Callable[[T], bool], iterable: Iterable[T], /
|
167
|
+
) -> tuple[list[T], list[T]]:
|
177
168
|
"""Partition with lists."""
|
178
169
|
false, true = partition(pred, iterable)
|
179
170
|
return list(false), list(true)
|
@@ -182,48 +173,48 @@ def partition_list(
|
|
182
173
|
##
|
183
174
|
|
184
175
|
|
185
|
-
def partition_typeguard(
|
186
|
-
pred: Callable[[
|
187
|
-
) -> tuple[Iterator[
|
176
|
+
def partition_typeguard[T, U](
|
177
|
+
pred: Callable[[T], TypeGuard[U]], iterable: Iterable[T], /
|
178
|
+
) -> tuple[Iterator[T], Iterator[U]]:
|
188
179
|
"""Partition with a typeguarded function."""
|
189
180
|
false, true = partition(pred, iterable)
|
190
|
-
true = cast("Iterator[
|
181
|
+
true = cast("Iterator[U]", true)
|
191
182
|
return false, true
|
192
183
|
|
193
184
|
|
194
185
|
##
|
195
186
|
|
196
187
|
|
197
|
-
class peekable(_peekable
|
188
|
+
class peekable[T](_peekable): # noqa: N801
|
198
189
|
"""Peekable which supports dropwhile/takewhile methods."""
|
199
190
|
|
200
|
-
def __init__(self, iterable: Iterable[
|
191
|
+
def __init__(self, iterable: Iterable[T], /) -> None:
|
201
192
|
super().__init__(iterable)
|
202
193
|
|
203
194
|
@override
|
204
|
-
def __iter__(self) -> Iterator[
|
195
|
+
def __iter__(self) -> Iterator[T]: # pyright: ignore[reportIncompatibleMethodOverride]
|
205
196
|
while bool(self):
|
206
197
|
yield next(self)
|
207
198
|
|
208
199
|
@override
|
209
|
-
def __next__(self) ->
|
200
|
+
def __next__(self) -> T:
|
210
201
|
return super().__next__()
|
211
202
|
|
212
|
-
def dropwhile(self, predicate: Callable[[
|
203
|
+
def dropwhile(self, predicate: Callable[[T], bool], /) -> None:
|
213
204
|
while bool(self) and predicate(self.peek()):
|
214
205
|
_ = next(self)
|
215
206
|
|
216
207
|
@overload
|
217
|
-
def peek(self, *, default: Sentinel = sentinel) ->
|
208
|
+
def peek(self, *, default: Sentinel = sentinel) -> T: ...
|
218
209
|
@overload
|
219
|
-
def peek(self, *, default:
|
210
|
+
def peek[U](self, *, default: U) -> T | U: ...
|
220
211
|
@override
|
221
212
|
def peek(self, *, default: Any = sentinel) -> Any: # pyright: ignore[reportIncompatibleMethodOverride]
|
222
213
|
if isinstance(default, Sentinel):
|
223
214
|
return super().peek()
|
224
215
|
return super().peek(default=default)
|
225
216
|
|
226
|
-
def takewhile(self, predicate: Callable[[
|
217
|
+
def takewhile(self, predicate: Callable[[T], bool], /) -> Iterator[T]:
|
227
218
|
while bool(self) and predicate(self.peek()):
|
228
219
|
yield next(self)
|
229
220
|
|
@@ -232,11 +223,11 @@ class peekable(_peekable, Generic[_T]): # noqa: N801
|
|
232
223
|
|
233
224
|
|
234
225
|
@dataclass(kw_only=True, slots=True)
|
235
|
-
class Split
|
226
|
+
class Split[T]:
|
236
227
|
"""An iterable split into head/tail."""
|
237
228
|
|
238
|
-
head:
|
239
|
-
tail:
|
229
|
+
head: T
|
230
|
+
tail: T
|
240
231
|
|
241
232
|
@override
|
242
233
|
def __repr__(self) -> str:
|
@@ -250,15 +241,15 @@ class Split(Generic[_T]):
|
|
250
241
|
return f"{cls}(\n{joined}\n)"
|
251
242
|
|
252
243
|
|
253
|
-
def yield_splits(
|
254
|
-
iterable: Iterable[
|
244
|
+
def yield_splits[T](
|
245
|
+
iterable: Iterable[T],
|
255
246
|
head: int,
|
256
247
|
tail: int,
|
257
248
|
/,
|
258
249
|
*,
|
259
250
|
min_frac: float | None = None,
|
260
251
|
freq: int | None = None,
|
261
|
-
) -> Iterator[Split[Sequence[
|
252
|
+
) -> Iterator[Split[Sequence[T]]]:
|
262
253
|
"""Yield the splits of an iterable."""
|
263
254
|
it1 = _yield_splits1(iterable, head + tail)
|
264
255
|
it2 = _yield_splits2(it1, head, tail, min_frac=min_frac)
|
@@ -267,9 +258,9 @@ def yield_splits(
|
|
267
258
|
return islice(it3, 0, None, freq_use)
|
268
259
|
|
269
260
|
|
270
|
-
def _yield_splits1(
|
271
|
-
iterable: Iterable[
|
272
|
-
) -> Iterator[tuple[Literal["head", "body"], Sequence[
|
261
|
+
def _yield_splits1[T](
|
262
|
+
iterable: Iterable[T], total: int, /
|
263
|
+
) -> Iterator[tuple[Literal["head", "body"], Sequence[T]]]:
|
273
264
|
peek = peekable(iterable)
|
274
265
|
for i in range(1, total + 1):
|
275
266
|
if len(result := peek[:i]) < i:
|
@@ -283,14 +274,14 @@ def _yield_splits1(
|
|
283
274
|
break
|
284
275
|
|
285
276
|
|
286
|
-
def _yield_splits2(
|
287
|
-
iterable: Iterable[tuple[Literal["head", "body"], Sequence[
|
277
|
+
def _yield_splits2[T](
|
278
|
+
iterable: Iterable[tuple[Literal["head", "body"], Sequence[T]],],
|
288
279
|
head: int,
|
289
280
|
tail: int,
|
290
281
|
/,
|
291
282
|
*,
|
292
283
|
min_frac: float | None = None,
|
293
|
-
) -> Iterator[tuple[Iterable[
|
284
|
+
) -> Iterator[tuple[Iterable[T], int, int]]:
|
294
285
|
min_length = head if min_frac is None else min_frac * head
|
295
286
|
for kind, window in iterable:
|
296
287
|
len_win = len(window)
|
@@ -307,13 +298,13 @@ def _yield_splits2(
|
|
307
298
|
assert_never(never)
|
308
299
|
|
309
300
|
|
310
|
-
def _yield_splits3(
|
311
|
-
iterable: Iterable[tuple[Iterable[
|
312
|
-
) -> Iterator[Split[Sequence[
|
301
|
+
def _yield_splits3[T](
|
302
|
+
iterable: Iterable[tuple[Iterable[T], int, int]], /
|
303
|
+
) -> Iterator[Split[Sequence[T]]]:
|
313
304
|
for window, len_head, len_tail in iterable:
|
314
305
|
head_win, tail_win = split_into(window, [len_head, len_tail])
|
315
306
|
yield cast(
|
316
|
-
"Split[Sequence[
|
307
|
+
"Split[Sequence[T]]", Split(head=list(head_win), tail=list(tail_win))
|
317
308
|
)
|
318
309
|
|
319
310
|
|
utilities/operator.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
from collections.abc import Callable, Mapping, Sequence
|
4
4
|
from collections.abc import Set as AbstractSet
|
5
5
|
from dataclasses import asdict, dataclass
|
6
|
-
from typing import TYPE_CHECKING, Any,
|
6
|
+
from typing import TYPE_CHECKING, Any, cast, override
|
7
7
|
|
8
8
|
import utilities.math
|
9
9
|
from utilities.functions import is_dataclass_instance
|
@@ -13,17 +13,15 @@ from utilities.reprlib import get_repr
|
|
13
13
|
if TYPE_CHECKING:
|
14
14
|
from utilities.types import Dataclass, Number
|
15
15
|
|
16
|
-
_T = TypeVar("_T")
|
17
16
|
|
18
|
-
|
19
|
-
def is_equal(
|
17
|
+
def is_equal[T](
|
20
18
|
x: Any,
|
21
19
|
y: Any,
|
22
20
|
/,
|
23
21
|
*,
|
24
22
|
rel_tol: float | None = None,
|
25
23
|
abs_tol: float | None = None,
|
26
|
-
extra: Mapping[type[
|
24
|
+
extra: Mapping[type[T], Callable[[T, T], bool]] | None = None,
|
27
25
|
) -> bool:
|
28
26
|
"""Check if two objects are equal."""
|
29
27
|
if type(x) is type(y):
|
@@ -34,8 +32,8 @@ def is_equal(
|
|
34
32
|
except StopIteration:
|
35
33
|
pass
|
36
34
|
else:
|
37
|
-
x = cast("
|
38
|
-
y = cast("
|
35
|
+
x = cast("T", x)
|
36
|
+
y = cast("T", y)
|
39
37
|
return cmp(x, y)
|
40
38
|
|
41
39
|
# singletons
|
utilities/optuna.py
CHANGED
@@ -6,6 +6,8 @@ from typing import TYPE_CHECKING
|
|
6
6
|
from optuna import create_study
|
7
7
|
from sqlalchemy import URL
|
8
8
|
|
9
|
+
from utilities.types import Dataclass
|
10
|
+
|
9
11
|
if TYPE_CHECKING:
|
10
12
|
from collections.abc import Callable, Sequence
|
11
13
|
|
@@ -14,7 +16,7 @@ if TYPE_CHECKING:
|
|
14
16
|
from optuna.samplers import BaseSampler
|
15
17
|
from optuna.study import StudyDirection
|
16
18
|
|
17
|
-
from utilities.types import PathLike
|
19
|
+
from utilities.types import PathLike
|
18
20
|
|
19
21
|
|
20
22
|
def create_sqlite_study(
|
@@ -45,7 +47,7 @@ def create_sqlite_study(
|
|
45
47
|
##
|
46
48
|
|
47
49
|
|
48
|
-
def get_best_params(study: Study, cls: type[
|
50
|
+
def get_best_params[T: Dataclass](study: Study, cls: type[T], /) -> T:
|
49
51
|
"""Get the best params as a dataclass."""
|
50
52
|
return cls(**study.best_params)
|
51
53
|
|
@@ -53,10 +55,8 @@ def get_best_params(study: Study, cls: type[TDataclass], /) -> TDataclass:
|
|
53
55
|
##
|
54
56
|
|
55
57
|
|
56
|
-
def make_objective(
|
57
|
-
suggest_params: Callable[[Trial],
|
58
|
-
objective: Callable[[TDataclass], float],
|
59
|
-
/,
|
58
|
+
def make_objective[T: Dataclass](
|
59
|
+
suggest_params: Callable[[Trial], T], objective: Callable[[T], float], /
|
60
60
|
) -> Callable[[Trial], float]:
|
61
61
|
"""Make an objective given separate trialling & evaluating functions."""
|
62
62
|
|
utilities/pathlib.py
CHANGED
utilities/period.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from dataclasses import dataclass
|
4
|
-
from typing import TYPE_CHECKING,
|
4
|
+
from typing import TYPE_CHECKING, Self, TypedDict, override
|
5
5
|
from zoneinfo import ZoneInfo
|
6
6
|
|
7
7
|
from whenever import Date, DateDelta, TimeDelta, ZonedDateTime
|
@@ -14,12 +14,10 @@ from utilities.zoneinfo import get_time_zone_name
|
|
14
14
|
if TYPE_CHECKING:
|
15
15
|
from utilities.types import TimeZoneLike
|
16
16
|
|
17
|
-
_TPeriod = TypeVar("_TPeriod", Date, ZonedDateTime)
|
18
17
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
end: _TPeriod
|
18
|
+
class _PeriodAsDict[T: (Date, ZonedDateTime)](TypedDict):
|
19
|
+
start: T
|
20
|
+
end: T
|
23
21
|
|
24
22
|
|
25
23
|
@dataclass(repr=False, order=True, unsafe_hash=True, kw_only=False)
|
@@ -132,9 +130,9 @@ class PeriodError(Exception): ...
|
|
132
130
|
|
133
131
|
|
134
132
|
@dataclass(kw_only=True, slots=True)
|
135
|
-
class _PeriodInvalidError(
|
136
|
-
start:
|
137
|
-
end:
|
133
|
+
class _PeriodInvalidError[T: (Date, ZonedDateTime)](PeriodError):
|
134
|
+
start: T
|
135
|
+
end: T
|
138
136
|
|
139
137
|
@override
|
140
138
|
def __str__(self) -> str:
|
utilities/polars.py
CHANGED
@@ -10,17 +10,7 @@ from functools import partial, reduce
|
|
10
10
|
from itertools import chain, product
|
11
11
|
from math import ceil, log
|
12
12
|
from pathlib import Path
|
13
|
-
from typing import
|
14
|
-
TYPE_CHECKING,
|
15
|
-
Any,
|
16
|
-
Generic,
|
17
|
-
Literal,
|
18
|
-
TypeVar,
|
19
|
-
assert_never,
|
20
|
-
cast,
|
21
|
-
overload,
|
22
|
-
override,
|
23
|
-
)
|
13
|
+
from typing import TYPE_CHECKING, Any, Literal, assert_never, cast, overload, override
|
24
14
|
from uuid import UUID
|
25
15
|
from zoneinfo import ZoneInfo
|
26
16
|
|
@@ -127,7 +117,6 @@ if TYPE_CHECKING:
|
|
127
117
|
from utilities.types import Dataclass, MaybeIterable, StrMapping, TimeZoneLike
|
128
118
|
|
129
119
|
|
130
|
-
_T = TypeVar("_T")
|
131
120
|
type ExprLike = MaybeStr[Expr]
|
132
121
|
DatetimeHongKong = Datetime(time_zone="Asia/Hong_Kong")
|
133
122
|
DatetimeTokyo = Datetime(time_zone="Asia/Tokyo")
|
@@ -286,10 +275,10 @@ def append_dataclass(df: DataFrame, obj: Dataclass, /) -> DataFrame:
|
|
286
275
|
|
287
276
|
|
288
277
|
@dataclass(kw_only=True, slots=True)
|
289
|
-
class AppendDataClassError(Exception
|
290
|
-
left: AbstractSet[
|
291
|
-
right: AbstractSet[
|
292
|
-
extra: AbstractSet[
|
278
|
+
class AppendDataClassError[T](Exception):
|
279
|
+
left: AbstractSet[T]
|
280
|
+
right: AbstractSet[T]
|
281
|
+
extra: AbstractSet[T]
|
293
282
|
|
294
283
|
@override
|
295
284
|
def __str__(self) -> str:
|
utilities/pqdm.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from functools import partial
|
4
|
-
from typing import TYPE_CHECKING, Any, Literal,
|
4
|
+
from typing import TYPE_CHECKING, Any, Literal, assert_never
|
5
5
|
|
6
6
|
from pqdm import processes, threads
|
7
7
|
from tqdm.auto import tqdm as tqdm_auto
|
@@ -20,12 +20,11 @@ if TYPE_CHECKING:
|
|
20
20
|
from utilities.types import Parallelism
|
21
21
|
|
22
22
|
|
23
|
-
_T = TypeVar("_T")
|
24
23
|
type _ExceptionBehaviour = Literal["ignore", "immediate", "deferred"]
|
25
24
|
|
26
25
|
|
27
|
-
def pqdm_map(
|
28
|
-
func: Callable[...,
|
26
|
+
def pqdm_map[T](
|
27
|
+
func: Callable[..., T],
|
29
28
|
/,
|
30
29
|
*iterables: Iterable[Any],
|
31
30
|
parallelism: Parallelism = "processes",
|
@@ -35,7 +34,7 @@ def pqdm_map(
|
|
35
34
|
tqdm_class: tqdm_type = tqdm_auto, # pyright: ignore[reportArgumentType]
|
36
35
|
desc: str | None | Sentinel = sentinel,
|
37
36
|
**kwargs: Any,
|
38
|
-
) -> list[
|
37
|
+
) -> list[T]:
|
39
38
|
"""Parallel map, powered by `pqdm`."""
|
40
39
|
return pqdm_starmap(
|
41
40
|
func,
|
@@ -50,8 +49,8 @@ def pqdm_map(
|
|
50
49
|
)
|
51
50
|
|
52
51
|
|
53
|
-
def pqdm_starmap(
|
54
|
-
func: Callable[...,
|
52
|
+
def pqdm_starmap[T](
|
53
|
+
func: Callable[..., T],
|
55
54
|
iterable: Iterable[tuple[Any, ...]],
|
56
55
|
/,
|
57
56
|
*,
|
@@ -62,7 +61,7 @@ def pqdm_starmap(
|
|
62
61
|
tqdm_class: tqdm_type = tqdm_auto, # pyright: ignore[reportArgumentType]
|
63
62
|
desc: str | None | Sentinel = sentinel,
|
64
63
|
**kwargs: Any,
|
65
|
-
) -> list[
|
64
|
+
) -> list[T]:
|
66
65
|
"""Parallel starmap, powered by `pqdm`."""
|
67
66
|
apply = partial(apply_to_varargs, func)
|
68
67
|
n_jobs_use = get_cpu_use(n=n_jobs)
|