dycw-utilities 0.112.9__py3-none-any.whl → 0.112.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dycw_utilities-0.112.9.dist-info → dycw_utilities-0.112.11.dist-info}/METADATA +12 -12
- {dycw_utilities-0.112.9.dist-info → dycw_utilities-0.112.11.dist-info}/RECORD +7 -7
- utilities/__init__.py +1 -1
- utilities/numpy.py +20 -1
- utilities/polars.py +321 -9
- {dycw_utilities-0.112.9.dist-info → dycw_utilities-0.112.11.dist-info}/WHEEL +0 -0
- {dycw_utilities-0.112.9.dist-info → dycw_utilities-0.112.11.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: dycw-utilities
|
3
|
-
Version: 0.112.
|
3
|
+
Version: 0.112.11
|
4
4
|
Author-email: Derek Wan <d.wan@icloud.com>
|
5
5
|
License-File: LICENSE
|
6
6
|
Requires-Python: >=3.12
|
@@ -24,7 +24,7 @@ Provides-Extra: zzz-test-altair
|
|
24
24
|
Requires-Dist: altair<5.6,>=5.5.0; extra == 'zzz-test-altair'
|
25
25
|
Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-altair'
|
26
26
|
Requires-Dist: img2pdf<0.7,>=0.6.0; extra == 'zzz-test-altair'
|
27
|
-
Requires-Dist: polars-lts-cpu<1.
|
27
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-altair'
|
28
28
|
Requires-Dist: vl-convert-python<1.8,>=1.7.0; extra == 'zzz-test-altair'
|
29
29
|
Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-altair'
|
30
30
|
Provides-Extra: zzz-test-astor
|
@@ -48,7 +48,7 @@ Provides-Extra: zzz-test-cvxpy
|
|
48
48
|
Requires-Dist: cvxpy<1.7,>=1.6.5; extra == 'zzz-test-cvxpy'
|
49
49
|
Provides-Extra: zzz-test-dataclasses
|
50
50
|
Requires-Dist: orjson<3.11,>=3.10.15; extra == 'zzz-test-dataclasses'
|
51
|
-
Requires-Dist: polars-lts-cpu<1.
|
51
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-dataclasses'
|
52
52
|
Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-dataclasses'
|
53
53
|
Provides-Extra: zzz-test-datetime
|
54
54
|
Requires-Dist: tzlocal<5.4,>=5.3.1; extra == 'zzz-test-datetime'
|
@@ -70,7 +70,7 @@ Provides-Extra: zzz-test-getpass
|
|
70
70
|
Provides-Extra: zzz-test-git
|
71
71
|
Provides-Extra: zzz-test-hashlib
|
72
72
|
Requires-Dist: orjson<3.11,>=3.10.15; extra == 'zzz-test-hashlib'
|
73
|
-
Requires-Dist: polars-lts-cpu<1.
|
73
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-hashlib'
|
74
74
|
Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-hashlib'
|
75
75
|
Provides-Extra: zzz-test-http
|
76
76
|
Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-http'
|
@@ -92,12 +92,12 @@ Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-hypothesis'
|
|
92
92
|
Provides-Extra: zzz-test-ipython
|
93
93
|
Requires-Dist: ipython<9.1,>=9.0.1; extra == 'zzz-test-ipython'
|
94
94
|
Provides-Extra: zzz-test-iterables
|
95
|
-
Requires-Dist: polars-lts-cpu<1.
|
95
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-iterables'
|
96
96
|
Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-iterables'
|
97
97
|
Provides-Extra: zzz-test-jupyter
|
98
98
|
Requires-Dist: jupyterlab<4.3,>=4.2.0; extra == 'zzz-test-jupyter'
|
99
99
|
Requires-Dist: pandas<2.3,>=2.2.2; extra == 'zzz-test-jupyter'
|
100
|
-
Requires-Dist: polars-lts-cpu<1.
|
100
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-jupyter'
|
101
101
|
Provides-Extra: zzz-test-logging
|
102
102
|
Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-logging'
|
103
103
|
Requires-Dist: coloredlogs<15.1,>=15.0.1; extra == 'zzz-test-logging'
|
@@ -121,13 +121,13 @@ Requires-Dist: more-itertools<10.8,>=10.7.0; extra == 'zzz-test-more-itertools'
|
|
121
121
|
Provides-Extra: zzz-test-numpy
|
122
122
|
Requires-Dist: numpy<2.3,>=2.2.5; extra == 'zzz-test-numpy'
|
123
123
|
Provides-Extra: zzz-test-operator
|
124
|
-
Requires-Dist: polars-lts-cpu<1.
|
124
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-operator'
|
125
125
|
Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-operator'
|
126
126
|
Provides-Extra: zzz-test-optuna
|
127
127
|
Requires-Dist: optuna<4.4,>=4.3.0; extra == 'zzz-test-optuna'
|
128
128
|
Provides-Extra: zzz-test-orjson
|
129
129
|
Requires-Dist: orjson<3.11,>=3.10.15; extra == 'zzz-test-orjson'
|
130
|
-
Requires-Dist: polars-lts-cpu<1.
|
130
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-orjson'
|
131
131
|
Requires-Dist: rich<14.1,>=14.0.0; extra == 'zzz-test-orjson'
|
132
132
|
Requires-Dist: tzlocal<5.4,>=5.3.1; extra == 'zzz-test-orjson'
|
133
133
|
Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-orjson'
|
@@ -138,7 +138,7 @@ Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-pickle'
|
|
138
138
|
Provides-Extra: zzz-test-platform
|
139
139
|
Provides-Extra: zzz-test-polars
|
140
140
|
Requires-Dist: dacite<1.10,>=1.9.2; extra == 'zzz-test-polars'
|
141
|
-
Requires-Dist: polars-lts-cpu<1.
|
141
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-polars'
|
142
142
|
Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-polars'
|
143
143
|
Provides-Extra: zzz-test-pqdm
|
144
144
|
Requires-Dist: pqdm<0.3,>=0.2.0; extra == 'zzz-test-pqdm'
|
@@ -164,7 +164,7 @@ Provides-Extra: zzz-test-random
|
|
164
164
|
Provides-Extra: zzz-test-re
|
165
165
|
Provides-Extra: zzz-test-redis
|
166
166
|
Requires-Dist: orjson<3.11,>=3.10.15; extra == 'zzz-test-redis'
|
167
|
-
Requires-Dist: polars-lts-cpu<1.
|
167
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-redis'
|
168
168
|
Requires-Dist: redis<5.3,>=5.2.1; extra == 'zzz-test-redis'
|
169
169
|
Requires-Dist: rich<14.1,>=14.0.0; extra == 'zzz-test-redis'
|
170
170
|
Requires-Dist: tenacity<9.0,>=8.5.0; extra == 'zzz-test-redis'
|
@@ -192,7 +192,7 @@ Requires-Dist: aiosqlite<0.22,>=0.21.0; extra == 'zzz-test-sqlalchemy-polars'
|
|
192
192
|
Requires-Dist: asyncpg<0.31,>=0.30.0; extra == 'zzz-test-sqlalchemy-polars'
|
193
193
|
Requires-Dist: greenlet<3.3,>=3.2.0; extra == 'zzz-test-sqlalchemy-polars'
|
194
194
|
Requires-Dist: nest-asyncio<1.7,>=1.6.0; extra == 'zzz-test-sqlalchemy-polars'
|
195
|
-
Requires-Dist: polars-lts-cpu<1.
|
195
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-sqlalchemy-polars'
|
196
196
|
Requires-Dist: sqlalchemy<2.1,>=2.0.40; extra == 'zzz-test-sqlalchemy-polars'
|
197
197
|
Requires-Dist: tenacity<9.0,>=8.5.0; extra == 'zzz-test-sqlalchemy-polars'
|
198
198
|
Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-sqlalchemy-polars'
|
@@ -217,7 +217,7 @@ Requires-Dist: tzlocal<5.4,>=5.3.1; extra == 'zzz-test-traceback'
|
|
217
217
|
Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-traceback'
|
218
218
|
Provides-Extra: zzz-test-types
|
219
219
|
Provides-Extra: zzz-test-typing
|
220
|
-
Requires-Dist: polars-lts-cpu<1.
|
220
|
+
Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-typing'
|
221
221
|
Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-typing'
|
222
222
|
Provides-Extra: zzz-test-tzlocal
|
223
223
|
Requires-Dist: tzlocal<5.4,>=5.3.1; extra == 'zzz-test-tzlocal'
|
@@ -1,4 +1,4 @@
|
|
1
|
-
utilities/__init__.py,sha256=
|
1
|
+
utilities/__init__.py,sha256=UdqUImTa6tBT-IqBR4s6iXfh38IrLnn774Bkpt5kz_g,61
|
2
2
|
utilities/altair.py,sha256=Gpja-flOo-Db0PIPJLJsgzAlXWoKUjPU1qY-DQ829ek,9156
|
3
3
|
utilities/astor.py,sha256=xuDUkjq0-b6fhtwjhbnebzbqQZAjMSHR1IIS5uOodVg,777
|
4
4
|
utilities/asyncio.py,sha256=41oQUurWMvadFK5gFnaG21hMM0Vmfn2WS6OpC0R9mas,14757
|
@@ -36,7 +36,7 @@ utilities/math.py,sha256=TexfvLCI12d9Sw5_W4pKVBZ3nRr3zk2iPkcEU7xdEWU,26771
|
|
36
36
|
utilities/memory_profiler.py,sha256=tf2C51P2lCujPGvRt2Rfc7VEw5LDXmVPCG3z_AvBmbU,962
|
37
37
|
utilities/modules.py,sha256=iuvLluJya-hvl1Q25-Jk3dLgx2Es3ck4SjJiEkAlVTs,3195
|
38
38
|
utilities/more_itertools.py,sha256=CPUxrMAcTwRxbzbhiqPKi3Xx9hxqI0t6gkWjutaibGk,5534
|
39
|
-
utilities/numpy.py,sha256=
|
39
|
+
utilities/numpy.py,sha256=Xn23sA2ZbVNqwUYEgNJD3XBYH6IbCri_WkHSNhg3NkY,26122
|
40
40
|
utilities/operator.py,sha256=0M2yZJ0PODH47ogFEnkGMBe_cfxwZR02T_92LZVZvHo,3715
|
41
41
|
utilities/optuna.py,sha256=loyJGWTzljgdJaoLhP09PT8Jz6o_pwBOwehY33lHkhw,1923
|
42
42
|
utilities/orjson.py,sha256=DBm2zPP04kcHpY3l1etL24ksNynu-R3duFyx3U-RjqQ,36948
|
@@ -46,7 +46,7 @@ utilities/pathlib.py,sha256=31WPMXdLIyXgYOMMl_HOI2wlo66MGSE-cgeelk-Lias,1410
|
|
46
46
|
utilities/period.py,sha256=RWfcNVoNlW07RNdU47g_zuLZMKbtgfK4bE6G-9tVjY8,11024
|
47
47
|
utilities/pickle.py,sha256=Bhvd7cZl-zQKQDFjUerqGuSKlHvnW1K2QXeU5UZibtg,657
|
48
48
|
utilities/platform.py,sha256=NU7ycTvAXAG-fdYmDXaM1m4EOml2cGiaYwaUzfzSqyU,1767
|
49
|
-
utilities/polars.py,sha256=
|
49
|
+
utilities/polars.py,sha256=bo2Rhukk2eXxe1RMfu2uvEjTQTd9SmOi8mGW4BRG82c,67288
|
50
50
|
utilities/polars_ols.py,sha256=efhXf0gjrHUpQrvS6a7g8yJQJWf_ATKtJnqqF2inCOU,5680
|
51
51
|
utilities/pqdm.py,sha256=foRytQybmOQ05pjt5LF7ANyzrIa--4ScDE3T2wd31a4,3118
|
52
52
|
utilities/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -87,7 +87,7 @@ utilities/warnings.py,sha256=un1LvHv70PU-LLv8RxPVmugTzDJkkGXRMZTE2-fTQHw,1771
|
|
87
87
|
utilities/whenever.py,sha256=iLRP_-8CZtBpHKbGZGu-kjSMg1ZubJ-VSmgSy7Eudxw,17787
|
88
88
|
utilities/zipfile.py,sha256=24lQc9ATcJxHXBPc_tBDiJk48pWyRrlxO2fIsFxU0A8,699
|
89
89
|
utilities/zoneinfo.py,sha256=-Xm57PMMwDTYpxJdkiJG13wnbwK--I7XItBh5WVhD-o,1874
|
90
|
-
dycw_utilities-0.112.
|
91
|
-
dycw_utilities-0.112.
|
92
|
-
dycw_utilities-0.112.
|
93
|
-
dycw_utilities-0.112.
|
90
|
+
dycw_utilities-0.112.11.dist-info/METADATA,sha256=ICb-HGQkFLIEQZoD2MayfokOJhntweh8-YFElfTf7aU,13005
|
91
|
+
dycw_utilities-0.112.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
92
|
+
dycw_utilities-0.112.11.dist-info/licenses/LICENSE,sha256=gppZp16M6nSVpBbUBrNL6JuYfvKwZiKgV7XoKKsHzqo,1066
|
93
|
+
dycw_utilities-0.112.11.dist-info/RECORD,,
|
utilities/__init__.py
CHANGED
utilities/numpy.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
from collections.abc import Sequence
|
3
4
|
from dataclasses import dataclass
|
4
5
|
from functools import partial, reduce
|
5
6
|
from itertools import repeat
|
6
|
-
from typing import TYPE_CHECKING, Any, overload, override
|
7
|
+
from typing import TYPE_CHECKING, Any, SupportsIndex, overload, override
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
from numpy import (
|
@@ -46,6 +47,9 @@ if TYPE_CHECKING:
|
|
46
47
|
from utilities.types import MaybeIterable
|
47
48
|
|
48
49
|
|
50
|
+
type ShapeLike = SupportsIndex | Sequence[SupportsIndex]
|
51
|
+
|
52
|
+
|
49
53
|
##
|
50
54
|
|
51
55
|
|
@@ -136,6 +140,19 @@ class AsIntError(Exception): ...
|
|
136
140
|
##
|
137
141
|
|
138
142
|
|
143
|
+
def bernoulli(
|
144
|
+
*, true: float = 0.5, seed: int | None = None, size: ShapeLike = ()
|
145
|
+
) -> NDArrayB:
|
146
|
+
"""Return a set of Bernoulli random variates."""
|
147
|
+
from numpy.random import default_rng
|
148
|
+
|
149
|
+
rng = default_rng(seed=seed)
|
150
|
+
return rng.binomial(1, true, size=size).astype(bool)
|
151
|
+
|
152
|
+
|
153
|
+
##
|
154
|
+
|
155
|
+
|
139
156
|
def boxcar(
|
140
157
|
array: NDArrayF,
|
141
158
|
/,
|
@@ -980,11 +997,13 @@ __all__ = [
|
|
980
997
|
"NDArrayF",
|
981
998
|
"NDArrayI",
|
982
999
|
"NDArrayO",
|
1000
|
+
"ShapeLike",
|
983
1001
|
"ShiftError",
|
984
1002
|
"SigmoidError",
|
985
1003
|
"adjust_frequencies",
|
986
1004
|
"array_indexer",
|
987
1005
|
"as_int",
|
1006
|
+
"bernoulli",
|
988
1007
|
"boxcar",
|
989
1008
|
"datetime64D",
|
990
1009
|
"datetime64M",
|
utilities/polars.py
CHANGED
@@ -7,7 +7,7 @@ from collections.abc import Set as AbstractSet
|
|
7
7
|
from contextlib import suppress
|
8
8
|
from dataclasses import asdict, dataclass
|
9
9
|
from functools import partial, reduce
|
10
|
-
from itertools import chain
|
10
|
+
from itertools import chain, product
|
11
11
|
from math import ceil, log
|
12
12
|
from pathlib import Path
|
13
13
|
from typing import (
|
@@ -39,17 +39,21 @@ from polars import (
|
|
39
39
|
Series,
|
40
40
|
String,
|
41
41
|
Struct,
|
42
|
+
UInt32,
|
42
43
|
all_horizontal,
|
44
|
+
any_horizontal,
|
43
45
|
col,
|
44
46
|
concat,
|
45
47
|
int_range,
|
46
48
|
lit,
|
47
49
|
struct,
|
50
|
+
sum_horizontal,
|
48
51
|
when,
|
49
52
|
)
|
50
53
|
from polars.datatypes import DataType, DataTypeClass
|
51
54
|
from polars.exceptions import (
|
52
55
|
ColumnNotFoundError,
|
56
|
+
NoRowsReturnedError,
|
53
57
|
OutOfBoundsError,
|
54
58
|
PolarsInefficientMapWarning,
|
55
59
|
)
|
@@ -337,6 +341,102 @@ def are_frames_equal(
|
|
337
341
|
##
|
338
342
|
|
339
343
|
|
344
|
+
def bernoulli(
|
345
|
+
obj: int | Series | DataFrame,
|
346
|
+
/,
|
347
|
+
*,
|
348
|
+
true: float = 0.5,
|
349
|
+
seed: int | None = None,
|
350
|
+
name: str | None = None,
|
351
|
+
) -> Series:
|
352
|
+
"""Construct a series of Bernoulli-random variables."""
|
353
|
+
match obj:
|
354
|
+
case int() as height:
|
355
|
+
import utilities.numpy
|
356
|
+
|
357
|
+
values = utilities.numpy.bernoulli(true=true, seed=seed, size=height)
|
358
|
+
return Series(name=name, values=values)
|
359
|
+
case Series() as series:
|
360
|
+
return bernoulli(series.len(), true=true, seed=seed, name=name)
|
361
|
+
case DataFrame() as df:
|
362
|
+
return bernoulli(df.height, true=true, seed=seed, name=name)
|
363
|
+
case _ as never:
|
364
|
+
assert_never(never)
|
365
|
+
|
366
|
+
|
367
|
+
##
|
368
|
+
|
369
|
+
|
370
|
+
def boolean_value_counts(
|
371
|
+
obj: Series | DataFrame, /, *exprs: IntoExprColumn, **named_exprs: IntoExprColumn
|
372
|
+
) -> DataFrame:
|
373
|
+
"""Conduct a set of boolean value counts."""
|
374
|
+
match obj:
|
375
|
+
case Series() as series:
|
376
|
+
return boolean_value_counts(series.to_frame(), *exprs, **named_exprs)
|
377
|
+
case DataFrame() as df:
|
378
|
+
all_exprs = ensure_expr_or_series_many(*exprs, **named_exprs)
|
379
|
+
rows = [_boolean_value_counts_one(df, expr) for expr in all_exprs]
|
380
|
+
true, false, null = [col(c) for c in ["true", "false", "null"]]
|
381
|
+
total = sum_horizontal(true, false, null).alias("total")
|
382
|
+
return DataFrame(
|
383
|
+
rows,
|
384
|
+
schema={
|
385
|
+
"name": String,
|
386
|
+
"true": UInt32,
|
387
|
+
"false": UInt32,
|
388
|
+
"null": UInt32,
|
389
|
+
},
|
390
|
+
orient="row",
|
391
|
+
).with_columns(
|
392
|
+
total,
|
393
|
+
(true / total).alias("true (%)"),
|
394
|
+
(false / total).alias("false (%)"),
|
395
|
+
(null / total).alias("null (%)"),
|
396
|
+
)
|
397
|
+
case _ as never:
|
398
|
+
assert_never(never)
|
399
|
+
|
400
|
+
|
401
|
+
def _boolean_value_counts_one(
|
402
|
+
df: DataFrame, expr: IntoExprColumn, /
|
403
|
+
) -> Mapping[str, Any]:
|
404
|
+
name = get_expr_name(df, expr)
|
405
|
+
sr = df.select(expr)[name]
|
406
|
+
if not isinstance(sr.dtype, Boolean):
|
407
|
+
raise BooleanValueCountsError(name=name, dtype=sr.dtype)
|
408
|
+
counts = sr.value_counts()
|
409
|
+
truth = col(name)
|
410
|
+
try:
|
411
|
+
true = counts.row(by_predicate=truth.is_not_null() & truth, named=True)["count"]
|
412
|
+
except NoRowsReturnedError:
|
413
|
+
true = 0
|
414
|
+
try:
|
415
|
+
false = counts.row(by_predicate=(truth.is_not_null() & ~truth), named=True)[
|
416
|
+
"count"
|
417
|
+
]
|
418
|
+
except NoRowsReturnedError:
|
419
|
+
false = 0
|
420
|
+
try:
|
421
|
+
null = counts.row(by_predicate=truth.is_null(), named=True)["count"]
|
422
|
+
except NoRowsReturnedError:
|
423
|
+
null = 0
|
424
|
+
return {"name": name, "true": true, "false": false, "null": null}
|
425
|
+
|
426
|
+
|
427
|
+
@dataclass(kw_only=True, slots=True)
|
428
|
+
class BooleanValueCountsError(Exception):
|
429
|
+
name: str
|
430
|
+
dtype: DataType
|
431
|
+
|
432
|
+
@override
|
433
|
+
def __str__(self) -> str:
|
434
|
+
return f"Column {self.name!r} must be Boolean; got {self.dtype!r}"
|
435
|
+
|
436
|
+
|
437
|
+
##
|
438
|
+
|
439
|
+
|
340
440
|
@overload
|
341
441
|
def ceil_datetime(column: ExprLike, every: ExprLike, /) -> Expr: ...
|
342
442
|
@overload
|
@@ -637,6 +737,54 @@ class _CheckPolarsDataFrameWidthError(CheckPolarsDataFrameError):
|
|
637
737
|
##
|
638
738
|
|
639
739
|
|
740
|
+
def choice(
|
741
|
+
obj: int | Series | DataFrame,
|
742
|
+
elements: Iterable[Any],
|
743
|
+
/,
|
744
|
+
*,
|
745
|
+
replace: bool = True,
|
746
|
+
p: Iterable[float] | None = None,
|
747
|
+
seed: int | None = None,
|
748
|
+
name: str | None = None,
|
749
|
+
dtype: PolarsDataType = Float64,
|
750
|
+
) -> Series:
|
751
|
+
"""Construct a series of random samples."""
|
752
|
+
match obj:
|
753
|
+
case int() as height:
|
754
|
+
from numpy.random import default_rng
|
755
|
+
|
756
|
+
rng = default_rng(seed=seed)
|
757
|
+
elements = list(elements)
|
758
|
+
p = None if p is None else list(p)
|
759
|
+
values = rng.choice(elements, size=height, replace=replace, p=p)
|
760
|
+
return Series(name=name, values=values.tolist(), dtype=dtype)
|
761
|
+
case Series() as series:
|
762
|
+
return choice(
|
763
|
+
series.len(),
|
764
|
+
elements,
|
765
|
+
replace=replace,
|
766
|
+
p=p,
|
767
|
+
seed=seed,
|
768
|
+
name=name,
|
769
|
+
dtype=dtype,
|
770
|
+
)
|
771
|
+
case DataFrame() as df:
|
772
|
+
return choice(
|
773
|
+
df.height,
|
774
|
+
elements,
|
775
|
+
replace=replace,
|
776
|
+
p=p,
|
777
|
+
seed=seed,
|
778
|
+
name=name,
|
779
|
+
dtype=dtype,
|
780
|
+
)
|
781
|
+
case _ as never:
|
782
|
+
assert_never(never)
|
783
|
+
|
784
|
+
|
785
|
+
##
|
786
|
+
|
787
|
+
|
640
788
|
def collect_series(expr: Expr, /) -> Series:
|
641
789
|
"""Collect a column expression into a Series."""
|
642
790
|
data = DataFrame().with_columns(expr)
|
@@ -1053,6 +1201,18 @@ def ensure_expr_or_series(column: IntoExprColumn, /) -> Expr | Series:
|
|
1053
1201
|
##
|
1054
1202
|
|
1055
1203
|
|
1204
|
+
def ensure_expr_or_series_many(
|
1205
|
+
*columns: IntoExprColumn, **named_columns: IntoExprColumn
|
1206
|
+
) -> Sequence[Expr | Series]:
|
1207
|
+
"""Ensure a set of column expressions and/or Series are returned."""
|
1208
|
+
args = map(ensure_expr_or_series, columns)
|
1209
|
+
kwargs = (ensure_expr_or_series(v).alias(k) for k, v in named_columns.items())
|
1210
|
+
return list(chain(args, kwargs))
|
1211
|
+
|
1212
|
+
|
1213
|
+
##
|
1214
|
+
|
1215
|
+
|
1056
1216
|
@overload
|
1057
1217
|
def finite_ewm_mean(
|
1058
1218
|
column: ExprLike,
|
@@ -1106,13 +1266,7 @@ def finite_ewm_mean(
|
|
1106
1266
|
column = ensure_expr_or_series(column)
|
1107
1267
|
mean = column.fill_null(value=0.0).rolling_mean(len(weights), weights=list(weights))
|
1108
1268
|
expr = when(column.is_not_null()).then(mean)
|
1109
|
-
|
1110
|
-
case Expr():
|
1111
|
-
return expr
|
1112
|
-
case Series() as series:
|
1113
|
-
return series.to_frame().with_columns(expr.alias(series.name))[series.name]
|
1114
|
-
case _ as never:
|
1115
|
-
assert_never(never)
|
1269
|
+
return try_reify_expr(expr, column)
|
1116
1270
|
|
1117
1271
|
|
1118
1272
|
@dataclass(kw_only=True)
|
@@ -1227,7 +1381,7 @@ class _GetDataTypeOrSeriesTimeZoneNotZonedError(GetDataTypeOrSeriesTimeZoneError
|
|
1227
1381
|
##
|
1228
1382
|
|
1229
1383
|
|
1230
|
-
def get_expr_name(obj: Series | DataFrame, expr:
|
1384
|
+
def get_expr_name(obj: Series | DataFrame, expr: IntoExprColumn, /) -> str:
|
1231
1385
|
"""Get the name of an expression."""
|
1232
1386
|
match obj:
|
1233
1387
|
case Series() as series:
|
@@ -1440,6 +1594,69 @@ def integers(
|
|
1440
1594
|
##
|
1441
1595
|
|
1442
1596
|
|
1597
|
+
@overload
|
1598
|
+
def is_near_event(
|
1599
|
+
*exprs: ExprLike, before: int = 0, after: int = 0, **named_exprs: ExprLike
|
1600
|
+
) -> Expr: ...
|
1601
|
+
@overload
|
1602
|
+
def is_near_event(
|
1603
|
+
*exprs: Series, before: int = 0, after: int = 0, **named_exprs: Series
|
1604
|
+
) -> Series: ...
|
1605
|
+
@overload
|
1606
|
+
def is_near_event(
|
1607
|
+
*exprs: IntoExprColumn,
|
1608
|
+
before: int = 0,
|
1609
|
+
after: int = 0,
|
1610
|
+
**named_exprs: IntoExprColumn,
|
1611
|
+
) -> Expr | Series: ...
|
1612
|
+
def is_near_event(
|
1613
|
+
*exprs: IntoExprColumn,
|
1614
|
+
before: int = 0,
|
1615
|
+
after: int = 0,
|
1616
|
+
**named_exprs: IntoExprColumn,
|
1617
|
+
) -> Expr | Series:
|
1618
|
+
"""Compute the rows near any event."""
|
1619
|
+
if before <= -1:
|
1620
|
+
raise _IsNearEventBeforeError(before=before)
|
1621
|
+
if after <= -1:
|
1622
|
+
raise _IsNearEventAfterError(after=after)
|
1623
|
+
all_exprs = ensure_expr_or_series_many(*exprs, **named_exprs)
|
1624
|
+
shifts = range(-before, after + 1)
|
1625
|
+
if len(all_exprs) == 0:
|
1626
|
+
near = lit(value=False, dtype=Boolean)
|
1627
|
+
else:
|
1628
|
+
near_exprs = (
|
1629
|
+
e.shift(s).fill_null(value=False) for e, s in product(all_exprs, shifts)
|
1630
|
+
)
|
1631
|
+
near = any_horizontal(*near_exprs)
|
1632
|
+
return try_reify_expr(near, *exprs, **named_exprs)
|
1633
|
+
|
1634
|
+
|
1635
|
+
@dataclass(kw_only=True, slots=True)
|
1636
|
+
class IsNearEventError(Exception): ...
|
1637
|
+
|
1638
|
+
|
1639
|
+
@dataclass(kw_only=True, slots=True)
|
1640
|
+
class _IsNearEventBeforeError(IsNearEventError):
|
1641
|
+
before: int
|
1642
|
+
|
1643
|
+
@override
|
1644
|
+
def __str__(self) -> str:
|
1645
|
+
return f"'Before' must be non-negative; got {self.before}"
|
1646
|
+
|
1647
|
+
|
1648
|
+
@dataclass(kw_only=True, slots=True)
|
1649
|
+
class _IsNearEventAfterError(IsNearEventError):
|
1650
|
+
after: int
|
1651
|
+
|
1652
|
+
@override
|
1653
|
+
def __str__(self) -> str:
|
1654
|
+
return f"'After' must be non-negative; got {self.after}"
|
1655
|
+
|
1656
|
+
|
1657
|
+
##
|
1658
|
+
|
1659
|
+
|
1443
1660
|
def is_not_null_struct_series(series: Series, /) -> Series:
|
1444
1661
|
"""Check if a struct-dtype Series is not null as per the <= 1.1 definition."""
|
1445
1662
|
try:
|
@@ -1632,6 +1849,71 @@ def normal(
|
|
1632
1849
|
##
|
1633
1850
|
|
1634
1851
|
|
1852
|
+
def reify_exprs(
|
1853
|
+
*exprs: IntoExprColumn, **named_exprs: IntoExprColumn
|
1854
|
+
) -> Expr | Series | DataFrame:
|
1855
|
+
"""Reify a set of expressions."""
|
1856
|
+
all_exprs = ensure_expr_or_series_many(*exprs, **named_exprs)
|
1857
|
+
if len(all_exprs) == 0:
|
1858
|
+
raise _ReifyExprsEmptyError from None
|
1859
|
+
series = [s for s in all_exprs if isinstance(s, Series)]
|
1860
|
+
lengths = {s.len() for s in series}
|
1861
|
+
try:
|
1862
|
+
length = one(lengths)
|
1863
|
+
except OneEmptyError:
|
1864
|
+
match len(all_exprs):
|
1865
|
+
case 0:
|
1866
|
+
raise ImpossibleCaseError(
|
1867
|
+
case=[f"{all_exprs=}"]
|
1868
|
+
) from None # pragma: no cover
|
1869
|
+
case 1:
|
1870
|
+
return one(all_exprs)
|
1871
|
+
case _:
|
1872
|
+
return struct(*all_exprs)
|
1873
|
+
except OneNonUniqueError as error:
|
1874
|
+
raise _ReifyExprsSeriesNonUniqueError(
|
1875
|
+
first=error.first, second=error.second
|
1876
|
+
) from None
|
1877
|
+
df = (
|
1878
|
+
int_range(end=length, eager=True)
|
1879
|
+
.alias("_index")
|
1880
|
+
.to_frame()
|
1881
|
+
.with_columns(*all_exprs)
|
1882
|
+
.drop("_index")
|
1883
|
+
)
|
1884
|
+
match len(df.columns):
|
1885
|
+
case 0:
|
1886
|
+
raise ImpossibleCaseError(case=[f"{df.columns=}"]) # pragma: no cover
|
1887
|
+
case 1:
|
1888
|
+
return df[one(df.columns)]
|
1889
|
+
case _:
|
1890
|
+
return df
|
1891
|
+
|
1892
|
+
|
1893
|
+
@dataclass(kw_only=True, slots=True)
|
1894
|
+
class ReifyExprsError(Exception): ...
|
1895
|
+
|
1896
|
+
|
1897
|
+
@dataclass(kw_only=True, slots=True)
|
1898
|
+
class _ReifyExprsEmptyError(ReifyExprsError):
|
1899
|
+
@override
|
1900
|
+
def __str__(self) -> str:
|
1901
|
+
return "At least 1 Expression or Series must be given"
|
1902
|
+
|
1903
|
+
|
1904
|
+
@dataclass
|
1905
|
+
class _ReifyExprsSeriesNonUniqueError(ReifyExprsError):
|
1906
|
+
first: int
|
1907
|
+
second: int
|
1908
|
+
|
1909
|
+
@override
|
1910
|
+
def __str__(self) -> str:
|
1911
|
+
return f"Series must contain exactly one length; got {self.first}, {self.second} and perhaps more"
|
1912
|
+
|
1913
|
+
|
1914
|
+
##
|
1915
|
+
|
1916
|
+
|
1635
1917
|
@overload
|
1636
1918
|
def replace_time_zone(
|
1637
1919
|
obj: Series, /, *, time_zone: TimeZoneLike | None = UTC
|
@@ -1771,6 +2053,28 @@ class _StructFromDataClassTypeError(StructFromDataClassError):
|
|
1771
2053
|
##
|
1772
2054
|
|
1773
2055
|
|
2056
|
+
def try_reify_expr(
|
2057
|
+
expr: IntoExprColumn, /, *exprs: IntoExprColumn, **named_exprs: IntoExprColumn
|
2058
|
+
) -> Expr | Series:
|
2059
|
+
"""Try reify an expression."""
|
2060
|
+
expr = ensure_expr_or_series(expr)
|
2061
|
+
all_exprs = ensure_expr_or_series_many(*exprs, **named_exprs)
|
2062
|
+
all_exprs = [e.alias(f"_{i}") for i, e in enumerate(all_exprs)]
|
2063
|
+
result = reify_exprs(expr, *all_exprs)
|
2064
|
+
match result:
|
2065
|
+
case Expr():
|
2066
|
+
return expr
|
2067
|
+
case Series() as series:
|
2068
|
+
return series
|
2069
|
+
case DataFrame() as df:
|
2070
|
+
return df[get_expr_name(df, expr)]
|
2071
|
+
case _ as never:
|
2072
|
+
assert_never(never)
|
2073
|
+
|
2074
|
+
|
2075
|
+
##
|
2076
|
+
|
2077
|
+
|
1774
2078
|
def uniform(
|
1775
2079
|
obj: int | Series | DataFrame,
|
1776
2080
|
/,
|
@@ -1948,6 +2252,7 @@ def zoned_datetime(
|
|
1948
2252
|
|
1949
2253
|
|
1950
2254
|
__all__ = [
|
2255
|
+
"BooleanValueCountsError",
|
1951
2256
|
"CheckPolarsDataFrameError",
|
1952
2257
|
"ColumnsToDictError",
|
1953
2258
|
"DataClassToDataFrameError",
|
@@ -1963,6 +2268,7 @@ __all__ = [
|
|
1963
2268
|
"InsertAfterError",
|
1964
2269
|
"InsertBeforeError",
|
1965
2270
|
"InsertBetweenError",
|
2271
|
+
"IsNearEventError",
|
1966
2272
|
"IsNullStructSeriesError",
|
1967
2273
|
"SetFirstRowAsColumnsError",
|
1968
2274
|
"StructFromDataClassError",
|
@@ -1971,8 +2277,11 @@ __all__ = [
|
|
1971
2277
|
"adjust_frequencies",
|
1972
2278
|
"append_dataclass",
|
1973
2279
|
"are_frames_equal",
|
2280
|
+
"bernoulli",
|
2281
|
+
"boolean_value_counts",
|
1974
2282
|
"ceil_datetime",
|
1975
2283
|
"check_polars_dataframe",
|
2284
|
+
"choice",
|
1976
2285
|
"collect_series",
|
1977
2286
|
"columns_to_dict",
|
1978
2287
|
"concat_series",
|
@@ -1983,6 +2292,7 @@ __all__ = [
|
|
1983
2292
|
"drop_null_struct_series",
|
1984
2293
|
"ensure_data_type",
|
1985
2294
|
"ensure_expr_or_series",
|
2295
|
+
"ensure_expr_or_series_many",
|
1986
2296
|
"finite_ewm_mean",
|
1987
2297
|
"floor_datetime",
|
1988
2298
|
"get_data_type_or_series_time_zone",
|
@@ -1993,6 +2303,7 @@ __all__ = [
|
|
1993
2303
|
"insert_before",
|
1994
2304
|
"insert_between",
|
1995
2305
|
"integers",
|
2306
|
+
"is_near_event",
|
1996
2307
|
"is_not_null_struct_series",
|
1997
2308
|
"is_null_struct_series",
|
1998
2309
|
"join",
|
@@ -2005,6 +2316,7 @@ __all__ = [
|
|
2005
2316
|
"struct_dtype",
|
2006
2317
|
"struct_from_dataclass",
|
2007
2318
|
"touch",
|
2319
|
+
"try_reify_expr",
|
2008
2320
|
"uniform",
|
2009
2321
|
"unique_element",
|
2010
2322
|
"yield_struct_series_dataclasses",
|
File without changes
|
File without changes
|