dycw-utilities 0.112.8__py3-none-any.whl → 0.112.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dycw-utilities
3
- Version: 0.112.8
3
+ Version: 0.112.10
4
4
  Author-email: Derek Wan <d.wan@icloud.com>
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -24,7 +24,7 @@ Provides-Extra: zzz-test-altair
24
24
  Requires-Dist: altair<5.6,>=5.5.0; extra == 'zzz-test-altair'
25
25
  Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-altair'
26
26
  Requires-Dist: img2pdf<0.7,>=0.6.0; extra == 'zzz-test-altair'
27
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-altair'
27
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-altair'
28
28
  Requires-Dist: vl-convert-python<1.8,>=1.7.0; extra == 'zzz-test-altair'
29
29
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-altair'
30
30
  Provides-Extra: zzz-test-astor
@@ -48,7 +48,7 @@ Provides-Extra: zzz-test-cvxpy
48
48
  Requires-Dist: cvxpy<1.7,>=1.6.5; extra == 'zzz-test-cvxpy'
49
49
  Provides-Extra: zzz-test-dataclasses
50
50
  Requires-Dist: orjson<3.11,>=3.10.15; extra == 'zzz-test-dataclasses'
51
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-dataclasses'
51
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-dataclasses'
52
52
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-dataclasses'
53
53
  Provides-Extra: zzz-test-datetime
54
54
  Requires-Dist: tzlocal<5.4,>=5.3.1; extra == 'zzz-test-datetime'
@@ -70,11 +70,11 @@ Provides-Extra: zzz-test-getpass
70
70
  Provides-Extra: zzz-test-git
71
71
  Provides-Extra: zzz-test-hashlib
72
72
  Requires-Dist: orjson<3.11,>=3.10.15; extra == 'zzz-test-hashlib'
73
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-hashlib'
73
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-hashlib'
74
74
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-hashlib'
75
75
  Provides-Extra: zzz-test-http
76
76
  Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-http'
77
- Requires-Dist: orjson<3.11,>=3.10.16; extra == 'zzz-test-http'
77
+ Requires-Dist: orjson<3.11,>=3.10.18; extra == 'zzz-test-http'
78
78
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-http'
79
79
  Provides-Extra: zzz-test-hypothesis
80
80
  Requires-Dist: aiosqlite<0.22,>=0.21.0; extra == 'zzz-test-hypothesis'
@@ -92,12 +92,12 @@ Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-hypothesis'
92
92
  Provides-Extra: zzz-test-ipython
93
93
  Requires-Dist: ipython<9.1,>=9.0.1; extra == 'zzz-test-ipython'
94
94
  Provides-Extra: zzz-test-iterables
95
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-iterables'
95
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-iterables'
96
96
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-iterables'
97
97
  Provides-Extra: zzz-test-jupyter
98
98
  Requires-Dist: jupyterlab<4.3,>=4.2.0; extra == 'zzz-test-jupyter'
99
99
  Requires-Dist: pandas<2.3,>=2.2.2; extra == 'zzz-test-jupyter'
100
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-jupyter'
100
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-jupyter'
101
101
  Provides-Extra: zzz-test-logging
102
102
  Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-logging'
103
103
  Requires-Dist: coloredlogs<15.1,>=15.0.1; extra == 'zzz-test-logging'
@@ -121,13 +121,13 @@ Requires-Dist: more-itertools<10.8,>=10.7.0; extra == 'zzz-test-more-itertools'
121
121
  Provides-Extra: zzz-test-numpy
122
122
  Requires-Dist: numpy<2.3,>=2.2.5; extra == 'zzz-test-numpy'
123
123
  Provides-Extra: zzz-test-operator
124
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-operator'
124
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-operator'
125
125
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-operator'
126
126
  Provides-Extra: zzz-test-optuna
127
127
  Requires-Dist: optuna<4.4,>=4.3.0; extra == 'zzz-test-optuna'
128
128
  Provides-Extra: zzz-test-orjson
129
129
  Requires-Dist: orjson<3.11,>=3.10.15; extra == 'zzz-test-orjson'
130
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-orjson'
130
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-orjson'
131
131
  Requires-Dist: rich<14.1,>=14.0.0; extra == 'zzz-test-orjson'
132
132
  Requires-Dist: tzlocal<5.4,>=5.3.1; extra == 'zzz-test-orjson'
133
133
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-orjson'
@@ -138,13 +138,13 @@ Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-pickle'
138
138
  Provides-Extra: zzz-test-platform
139
139
  Provides-Extra: zzz-test-polars
140
140
  Requires-Dist: dacite<1.10,>=1.9.2; extra == 'zzz-test-polars'
141
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-polars'
141
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-polars'
142
142
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-polars'
143
143
  Provides-Extra: zzz-test-pqdm
144
144
  Requires-Dist: pqdm<0.3,>=0.2.0; extra == 'zzz-test-pqdm'
145
145
  Provides-Extra: zzz-test-pydantic
146
146
  Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-pydantic'
147
- Requires-Dist: pydantic<2.12,>=2.11.2; extra == 'zzz-test-pydantic'
147
+ Requires-Dist: pydantic<2.12,>=2.11.4; extra == 'zzz-test-pydantic'
148
148
  Provides-Extra: zzz-test-pyinstrument
149
149
  Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-pyinstrument'
150
150
  Requires-Dist: pyinstrument<5.1,>=5.0.0; extra == 'zzz-test-pyinstrument'
@@ -153,7 +153,7 @@ Provides-Extra: zzz-test-pyrsistent
153
153
  Requires-Dist: pyrsistent<0.21,>=0.20.0; extra == 'zzz-test-pyrsistent'
154
154
  Provides-Extra: zzz-test-pytest
155
155
  Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-pytest'
156
- Requires-Dist: orjson<3.11,>=3.10.16; extra == 'zzz-test-pytest'
156
+ Requires-Dist: orjson<3.11,>=3.10.18; extra == 'zzz-test-pytest'
157
157
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-pytest'
158
158
  Provides-Extra: zzz-test-pytest-regressions
159
159
  Requires-Dist: pytest-regressions<2.8,>=2.7.0; extra == 'zzz-test-pytest-regressions'
@@ -164,7 +164,7 @@ Provides-Extra: zzz-test-random
164
164
  Provides-Extra: zzz-test-re
165
165
  Provides-Extra: zzz-test-redis
166
166
  Requires-Dist: orjson<3.11,>=3.10.15; extra == 'zzz-test-redis'
167
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-redis'
167
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-redis'
168
168
  Requires-Dist: redis<5.3,>=5.2.1; extra == 'zzz-test-redis'
169
169
  Requires-Dist: rich<14.1,>=14.0.0; extra == 'zzz-test-redis'
170
170
  Requires-Dist: tenacity<9.0,>=8.5.0; extra == 'zzz-test-redis'
@@ -192,12 +192,12 @@ Requires-Dist: aiosqlite<0.22,>=0.21.0; extra == 'zzz-test-sqlalchemy-polars'
192
192
  Requires-Dist: asyncpg<0.31,>=0.30.0; extra == 'zzz-test-sqlalchemy-polars'
193
193
  Requires-Dist: greenlet<3.3,>=3.2.0; extra == 'zzz-test-sqlalchemy-polars'
194
194
  Requires-Dist: nest-asyncio<1.7,>=1.6.0; extra == 'zzz-test-sqlalchemy-polars'
195
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-sqlalchemy-polars'
195
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-sqlalchemy-polars'
196
196
  Requires-Dist: sqlalchemy<2.1,>=2.0.40; extra == 'zzz-test-sqlalchemy-polars'
197
197
  Requires-Dist: tenacity<9.0,>=8.5.0; extra == 'zzz-test-sqlalchemy-polars'
198
198
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-sqlalchemy-polars'
199
199
  Provides-Extra: zzz-test-streamlit
200
- Requires-Dist: streamlit<1.45,>=1.44.1; extra == 'zzz-test-streamlit'
200
+ Requires-Dist: streamlit<1.46,>=1.45.0; extra == 'zzz-test-streamlit'
201
201
  Provides-Extra: zzz-test-sys
202
202
  Requires-Dist: atomicwrites<1.5,>=1.4.1; extra == 'zzz-test-sys'
203
203
  Requires-Dist: rich<14.1,>=14.0.0; extra == 'zzz-test-sys'
@@ -217,7 +217,7 @@ Requires-Dist: tzlocal<5.4,>=5.3.1; extra == 'zzz-test-traceback'
217
217
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-traceback'
218
218
  Provides-Extra: zzz-test-types
219
219
  Provides-Extra: zzz-test-typing
220
- Requires-Dist: polars-lts-cpu<1.29,>=1.28.1; extra == 'zzz-test-typing'
220
+ Requires-Dist: polars-lts-cpu<1.30,>=1.29.0; extra == 'zzz-test-typing'
221
221
  Requires-Dist: whenever<0.8,>=0.7.3; extra == 'zzz-test-typing'
222
222
  Provides-Extra: zzz-test-tzlocal
223
223
  Requires-Dist: tzlocal<5.4,>=5.3.1; extra == 'zzz-test-tzlocal'
@@ -1,4 +1,4 @@
1
- utilities/__init__.py,sha256=XeC2tMugkErM_YQbJ1Iwis-CVp-V14bjMaBeziiWF1s,60
1
+ utilities/__init__.py,sha256=mALrkpjRtE0nW8tAmwclFaEZyShTCa75IjKVwftuFOc,61
2
2
  utilities/altair.py,sha256=Gpja-flOo-Db0PIPJLJsgzAlXWoKUjPU1qY-DQ829ek,9156
3
3
  utilities/astor.py,sha256=xuDUkjq0-b6fhtwjhbnebzbqQZAjMSHR1IIS5uOodVg,777
4
4
  utilities/asyncio.py,sha256=41oQUurWMvadFK5gFnaG21hMM0Vmfn2WS6OpC0R9mas,14757
@@ -36,7 +36,7 @@ utilities/math.py,sha256=TexfvLCI12d9Sw5_W4pKVBZ3nRr3zk2iPkcEU7xdEWU,26771
36
36
  utilities/memory_profiler.py,sha256=tf2C51P2lCujPGvRt2Rfc7VEw5LDXmVPCG3z_AvBmbU,962
37
37
  utilities/modules.py,sha256=iuvLluJya-hvl1Q25-Jk3dLgx2Es3ck4SjJiEkAlVTs,3195
38
38
  utilities/more_itertools.py,sha256=CPUxrMAcTwRxbzbhiqPKi3Xx9hxqI0t6gkWjutaibGk,5534
39
- utilities/numpy.py,sha256=cBgCBet8YfZP_rb4nkCJHZx9_03qPEinVENMk1dGVYQ,25683
39
+ utilities/numpy.py,sha256=Xn23sA2ZbVNqwUYEgNJD3XBYH6IbCri_WkHSNhg3NkY,26122
40
40
  utilities/operator.py,sha256=0M2yZJ0PODH47ogFEnkGMBe_cfxwZR02T_92LZVZvHo,3715
41
41
  utilities/optuna.py,sha256=loyJGWTzljgdJaoLhP09PT8Jz6o_pwBOwehY33lHkhw,1923
42
42
  utilities/orjson.py,sha256=DBm2zPP04kcHpY3l1etL24ksNynu-R3duFyx3U-RjqQ,36948
@@ -46,7 +46,7 @@ utilities/pathlib.py,sha256=31WPMXdLIyXgYOMMl_HOI2wlo66MGSE-cgeelk-Lias,1410
46
46
  utilities/period.py,sha256=RWfcNVoNlW07RNdU47g_zuLZMKbtgfK4bE6G-9tVjY8,11024
47
47
  utilities/pickle.py,sha256=Bhvd7cZl-zQKQDFjUerqGuSKlHvnW1K2QXeU5UZibtg,657
48
48
  utilities/platform.py,sha256=NU7ycTvAXAG-fdYmDXaM1m4EOml2cGiaYwaUzfzSqyU,1767
49
- utilities/polars.py,sha256=q8a2hTX9tyPxtxUarIm2BxlnckIwF1RCpcJz03OoIoY,58458
49
+ utilities/polars.py,sha256=ZdhGu9qMlgbwfdYYDEV3Pk3kfJr1V5FSolfCWHI3Nxc,63283
50
50
  utilities/polars_ols.py,sha256=efhXf0gjrHUpQrvS6a7g8yJQJWf_ATKtJnqqF2inCOU,5680
51
51
  utilities/pqdm.py,sha256=foRytQybmOQ05pjt5LF7ANyzrIa--4ScDE3T2wd31a4,3118
52
52
  utilities/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -58,7 +58,7 @@ utilities/pytest_regressions.py,sha256=-SVT9647Dg6-JcdsiaDKXe3NdOmmrvGevLKWwGjxq
58
58
  utilities/python_dotenv.py,sha256=iWcnpXbH7S6RoXHiLlGgyuH6udCupAcPd_gQ0eAenQ0,3190
59
59
  utilities/random.py,sha256=lYdjgxB7GCfU_fwFVl5U-BIM_HV3q6_urL9byjrwDM8,4157
60
60
  utilities/re.py,sha256=5J4d8VwIPFVrX2Eb8zfoxImDv7IwiN_U7mJ07wR2Wvs,3958
61
- utilities/redis.py,sha256=fAUbfOlCmxcxhh47PXQX63w0CU5iOFKfdUJ7jDn9ntM,22096
61
+ utilities/redis.py,sha256=XYHo1Qne4j5_BSIUorF8n5EWCRwByTdbAL3NLC6e2_4,25261
62
62
  utilities/reprlib.py,sha256=Re9bk3n-kC__9DxQmRlevqFA86pE6TtVfWjUgpbVOv0,1849
63
63
  utilities/rich.py,sha256=t50MwwVBsoOLxzmeVFSVpjno4OW6Ufum32skXbV8-Bs,1911
64
64
  utilities/scipy.py,sha256=X6ROnHwiUhAmPhM0jkfEh0-Fd9iRvwiqtCQMOLmOQF8,945
@@ -87,7 +87,7 @@ utilities/warnings.py,sha256=un1LvHv70PU-LLv8RxPVmugTzDJkkGXRMZTE2-fTQHw,1771
87
87
  utilities/whenever.py,sha256=iLRP_-8CZtBpHKbGZGu-kjSMg1ZubJ-VSmgSy7Eudxw,17787
88
88
  utilities/zipfile.py,sha256=24lQc9ATcJxHXBPc_tBDiJk48pWyRrlxO2fIsFxU0A8,699
89
89
  utilities/zoneinfo.py,sha256=-Xm57PMMwDTYpxJdkiJG13wnbwK--I7XItBh5WVhD-o,1874
90
- dycw_utilities-0.112.8.dist-info/METADATA,sha256=i2cGlKTXS9uPmvL_dC8xIdqS9wlxq5NbsOgmOn6J23o,13004
91
- dycw_utilities-0.112.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
92
- dycw_utilities-0.112.8.dist-info/licenses/LICENSE,sha256=gppZp16M6nSVpBbUBrNL6JuYfvKwZiKgV7XoKKsHzqo,1066
93
- dycw_utilities-0.112.8.dist-info/RECORD,,
90
+ dycw_utilities-0.112.10.dist-info/METADATA,sha256=HXZIn_uHvb-xc5-RuCEn_T8pPkJ58gnElO7ootr182o,13005
91
+ dycw_utilities-0.112.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
92
+ dycw_utilities-0.112.10.dist-info/licenses/LICENSE,sha256=gppZp16M6nSVpBbUBrNL6JuYfvKwZiKgV7XoKKsHzqo,1066
93
+ dycw_utilities-0.112.10.dist-info/RECORD,,
utilities/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from __future__ import annotations
2
2
 
3
- __version__ = "0.112.8"
3
+ __version__ = "0.112.10"
utilities/numpy.py CHANGED
@@ -1,9 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from collections.abc import Sequence
3
4
  from dataclasses import dataclass
4
5
  from functools import partial, reduce
5
6
  from itertools import repeat
6
- from typing import TYPE_CHECKING, Any, overload, override
7
+ from typing import TYPE_CHECKING, Any, SupportsIndex, overload, override
7
8
 
8
9
  import numpy as np
9
10
  from numpy import (
@@ -46,6 +47,9 @@ if TYPE_CHECKING:
46
47
  from utilities.types import MaybeIterable
47
48
 
48
49
 
50
+ type ShapeLike = SupportsIndex | Sequence[SupportsIndex]
51
+
52
+
49
53
  ##
50
54
 
51
55
 
@@ -136,6 +140,19 @@ class AsIntError(Exception): ...
136
140
  ##
137
141
 
138
142
 
143
+ def bernoulli(
144
+ *, true: float = 0.5, seed: int | None = None, size: ShapeLike = ()
145
+ ) -> NDArrayB:
146
+ """Return a set of Bernoulli random variates."""
147
+ from numpy.random import default_rng
148
+
149
+ rng = default_rng(seed=seed)
150
+ return rng.binomial(1, true, size=size).astype(bool)
151
+
152
+
153
+ ##
154
+
155
+
139
156
  def boxcar(
140
157
  array: NDArrayF,
141
158
  /,
@@ -980,11 +997,13 @@ __all__ = [
980
997
  "NDArrayF",
981
998
  "NDArrayI",
982
999
  "NDArrayO",
1000
+ "ShapeLike",
983
1001
  "ShiftError",
984
1002
  "SigmoidError",
985
1003
  "adjust_frequencies",
986
1004
  "array_indexer",
987
1005
  "as_int",
1006
+ "bernoulli",
988
1007
  "boxcar",
989
1008
  "datetime64D",
990
1009
  "datetime64M",
utilities/polars.py CHANGED
@@ -39,17 +39,20 @@ from polars import (
39
39
  Series,
40
40
  String,
41
41
  Struct,
42
+ UInt32,
42
43
  all_horizontal,
43
44
  col,
44
45
  concat,
45
46
  int_range,
46
47
  lit,
47
48
  struct,
49
+ sum_horizontal,
48
50
  when,
49
51
  )
50
52
  from polars.datatypes import DataType, DataTypeClass
51
53
  from polars.exceptions import (
52
54
  ColumnNotFoundError,
55
+ NoRowsReturnedError,
53
56
  OutOfBoundsError,
54
57
  PolarsInefficientMapWarning,
55
58
  )
@@ -337,6 +340,102 @@ def are_frames_equal(
337
340
  ##
338
341
 
339
342
 
343
+ def bernoulli(
344
+ obj: int | Series | DataFrame,
345
+ /,
346
+ *,
347
+ true: float = 0.5,
348
+ seed: int | None = None,
349
+ name: str | None = None,
350
+ ) -> Series:
351
+ """Construct a series of Bernoulli-random variables."""
352
+ match obj:
353
+ case int() as height:
354
+ import utilities.numpy
355
+
356
+ values = utilities.numpy.bernoulli(true=true, seed=seed, size=height)
357
+ return Series(name=name, values=values)
358
+ case Series() as series:
359
+ return bernoulli(series.len(), true=true, seed=seed, name=name)
360
+ case DataFrame() as df:
361
+ return bernoulli(df.height, true=true, seed=seed, name=name)
362
+ case _ as never:
363
+ assert_never(never)
364
+
365
+
366
+ ##
367
+
368
+
369
+ def boolean_value_counts(
370
+ obj: Series | DataFrame, /, *exprs: IntoExprColumn, **named_exprs: IntoExprColumn
371
+ ) -> DataFrame:
372
+ """Conduct a set of boolean value counts."""
373
+ match obj:
374
+ case Series() as series:
375
+ return boolean_value_counts(series.to_frame(), *exprs, **named_exprs)
376
+ case DataFrame() as df:
377
+ all_exprs = ensure_expr_or_series_many(*exprs, **named_exprs)
378
+ rows = [_boolean_value_counts_one(df, expr) for expr in all_exprs]
379
+ true, false, null = [col(c) for c in ["true", "false", "null"]]
380
+ total = sum_horizontal(true, false, null).alias("total")
381
+ return DataFrame(
382
+ rows,
383
+ schema={
384
+ "name": String,
385
+ "true": UInt32,
386
+ "false": UInt32,
387
+ "null": UInt32,
388
+ },
389
+ orient="row",
390
+ ).with_columns(
391
+ total,
392
+ (true / total).alias("true (%)"),
393
+ (false / total).alias("false (%)"),
394
+ (null / total).alias("null (%)"),
395
+ )
396
+ case _ as never:
397
+ assert_never(never)
398
+
399
+
400
+ def _boolean_value_counts_one(
401
+ df: DataFrame, expr: IntoExprColumn, /
402
+ ) -> Mapping[str, Any]:
403
+ name = get_expr_name(df, expr)
404
+ sr = df.select(expr)[name]
405
+ if not isinstance(sr.dtype, Boolean):
406
+ raise BooleanValueCountsError(name=name, dtype=sr.dtype)
407
+ counts = sr.value_counts()
408
+ truth = col(name)
409
+ try:
410
+ true = counts.row(by_predicate=truth.is_not_null() & truth, named=True)["count"]
411
+ except NoRowsReturnedError:
412
+ true = 0
413
+ try:
414
+ false = counts.row(by_predicate=(truth.is_not_null() & ~truth), named=True)[
415
+ "count"
416
+ ]
417
+ except NoRowsReturnedError:
418
+ false = 0
419
+ try:
420
+ null = counts.row(by_predicate=truth.is_null(), named=True)["count"]
421
+ except NoRowsReturnedError:
422
+ null = 0
423
+ return {"name": name, "true": true, "false": false, "null": null}
424
+
425
+
426
+ @dataclass(kw_only=True, slots=True)
427
+ class BooleanValueCountsError(Exception):
428
+ name: str
429
+ dtype: DataType
430
+
431
+ @override
432
+ def __str__(self) -> str:
433
+ return f"Column {self.name!r} must be Boolean; got {self.dtype!r}"
434
+
435
+
436
+ ##
437
+
438
+
340
439
  @overload
341
440
  def ceil_datetime(column: ExprLike, every: ExprLike, /) -> Expr: ...
342
441
  @overload
@@ -637,6 +736,54 @@ class _CheckPolarsDataFrameWidthError(CheckPolarsDataFrameError):
637
736
  ##
638
737
 
639
738
 
739
+ def choice(
740
+ obj: int | Series | DataFrame,
741
+ elements: Iterable[Any],
742
+ /,
743
+ *,
744
+ replace: bool = True,
745
+ p: Iterable[float] | None = None,
746
+ seed: int | None = None,
747
+ name: str | None = None,
748
+ dtype: PolarsDataType = Float64,
749
+ ) -> Series:
750
+ """Construct a series of random samples."""
751
+ match obj:
752
+ case int() as height:
753
+ from numpy.random import default_rng
754
+
755
+ rng = default_rng(seed=seed)
756
+ elements = list(elements)
757
+ p = None if p is None else list(p)
758
+ values = rng.choice(elements, size=height, replace=replace, p=p)
759
+ return Series(name=name, values=values.tolist(), dtype=dtype)
760
+ case Series() as series:
761
+ return choice(
762
+ series.len(),
763
+ elements,
764
+ replace=replace,
765
+ p=p,
766
+ seed=seed,
767
+ name=name,
768
+ dtype=dtype,
769
+ )
770
+ case DataFrame() as df:
771
+ return choice(
772
+ df.height,
773
+ elements,
774
+ replace=replace,
775
+ p=p,
776
+ seed=seed,
777
+ name=name,
778
+ dtype=dtype,
779
+ )
780
+ case _ as never:
781
+ assert_never(never)
782
+
783
+
784
+ ##
785
+
786
+
640
787
  def collect_series(expr: Expr, /) -> Series:
641
788
  """Collect a column expression into a Series."""
642
789
  data = DataFrame().with_columns(expr)
@@ -1053,6 +1200,18 @@ def ensure_expr_or_series(column: IntoExprColumn, /) -> Expr | Series:
1053
1200
  ##
1054
1201
 
1055
1202
 
1203
+ def ensure_expr_or_series_many(
1204
+ *columns: IntoExprColumn, **named_columns: IntoExprColumn
1205
+ ) -> Sequence[Expr | Series]:
1206
+ """Ensure a set of column expressions and/or Series are returned."""
1207
+ args = map(ensure_expr_or_series, columns)
1208
+ kwargs = (ensure_expr_or_series(v).alias(k) for k, v in named_columns.items())
1209
+ return list(chain(args, kwargs))
1210
+
1211
+
1212
+ ##
1213
+
1214
+
1056
1215
  @overload
1057
1216
  def finite_ewm_mean(
1058
1217
  column: ExprLike,
@@ -1227,7 +1386,7 @@ class _GetDataTypeOrSeriesTimeZoneNotZonedError(GetDataTypeOrSeriesTimeZoneError
1227
1386
  ##
1228
1387
 
1229
1388
 
1230
- def get_expr_name(obj: Series | DataFrame, expr: Expr, /) -> str:
1389
+ def get_expr_name(obj: Series | DataFrame, expr: IntoExprColumn, /) -> str:
1231
1390
  """Get the name of an expression."""
1232
1391
  match obj:
1233
1392
  case Series() as series:
@@ -1948,6 +2107,7 @@ def zoned_datetime(
1948
2107
 
1949
2108
 
1950
2109
  __all__ = [
2110
+ "BooleanValueCountsError",
1951
2111
  "CheckPolarsDataFrameError",
1952
2112
  "ColumnsToDictError",
1953
2113
  "DataClassToDataFrameError",
@@ -1971,8 +2131,11 @@ __all__ = [
1971
2131
  "adjust_frequencies",
1972
2132
  "append_dataclass",
1973
2133
  "are_frames_equal",
2134
+ "bernoulli",
2135
+ "boolean_value_counts",
1974
2136
  "ceil_datetime",
1975
2137
  "check_polars_dataframe",
2138
+ "choice",
1976
2139
  "collect_series",
1977
2140
  "columns_to_dict",
1978
2141
  "concat_series",
@@ -1983,6 +2146,7 @@ __all__ = [
1983
2146
  "drop_null_struct_series",
1984
2147
  "ensure_data_type",
1985
2148
  "ensure_expr_or_series",
2149
+ "ensure_expr_or_series_many",
1986
2150
  "finite_ewm_mean",
1987
2151
  "floor_datetime",
1988
2152
  "get_data_type_or_series_time_zone",
utilities/redis.py CHANGED
@@ -31,11 +31,18 @@ from utilities.datetime import (
31
31
  )
32
32
  from utilities.errors import ImpossibleCaseError
33
33
  from utilities.functions import ensure_int
34
- from utilities.iterables import always_iterable
34
+ from utilities.iterables import always_iterable, one
35
35
 
36
36
  if TYPE_CHECKING:
37
37
  import datetime as dt
38
- from collections.abc import AsyncIterator, Awaitable, Callable, Mapping
38
+ from collections.abc import (
39
+ AsyncIterator,
40
+ Awaitable,
41
+ Callable,
42
+ Iterable,
43
+ Mapping,
44
+ Sequence,
45
+ )
39
46
 
40
47
  from redis.asyncio import ConnectionPool
41
48
  from redis.asyncio.client import PubSub
@@ -71,6 +78,7 @@ class RedisHashMapKey(Generic[_K, _V]):
71
78
  name: str
72
79
  key: TypeLike[_K]
73
80
  key_serializer: Callable[[_K], bytes] | None = None
81
+ key_deserializer: Callable[[bytes], _K] | None = None
74
82
  value: TypeLike[_V]
75
83
  value_serializer: Callable[[_V], bytes] | None = None
76
84
  value_deserializer: Callable[[bytes], _V] | None = None
@@ -95,46 +103,102 @@ class RedisHashMapKey(Generic[_K, _V]):
95
103
  "Awaitable[bool]", redis.hexists(self.name, cast("str", key))
96
104
  )
97
105
 
98
- async def get(self, redis: Redis, key: _K, /) -> _V | None:
106
+ async def get(self, redis: Redis, key: _K, /) -> _V:
107
+ """Get a value from a hashmap in `redis`."""
108
+ result = one(await self.get_many(redis, [key])) # skipif-ci-and-not-linux
109
+ if result is None: # skipif-ci-and-not-linux
110
+ raise KeyError(self.name, key)
111
+ return result # skipif-ci-and-not-linux
112
+
113
+ async def get_all(self, redis: Redis, /) -> Mapping[_K, _V]:
99
114
  """Get a value from a hashmap in `redis`."""
100
- ser_key = _serialize( # skipif-ci-and-not-linux
101
- key, serializer=self.key_serializer
102
- )
103
115
  async with timeout_dur( # skipif-ci-and-not-linux
104
116
  duration=self.timeout, error=self.error
105
117
  ):
106
118
  result = await cast( # skipif-ci-and-not-linux
107
- "Awaitable[bytes | None]", redis.hget(self.name, cast("Any", ser_key))
119
+ "Awaitable[Mapping[bytes, bytes]]", redis.hgetall(self.name)
108
120
  )
109
- match result: # skipif-ci-and-not-linux
110
- case None:
111
- return None
112
- case bytes() as data:
113
- return _deserialize(data, deserializer=self.value_deserializer)
114
- case _ as never:
115
- assert_never(never)
121
+ return { # skipif-ci-and-not-linux
122
+ _deserialize(key, deserializer=self.key_deserializer): _deserialize(
123
+ value, deserializer=self.value_deserializer
124
+ )
125
+ for key, value in result.items()
126
+ }
127
+
128
+ async def get_many(
129
+ self, redis: Redis, keys: Iterable[_K], /
130
+ ) -> Sequence[_V | None]:
131
+ """Get multiple values from a hashmap in `redis`."""
132
+ keys = list(keys) # skipif-ci-and-not-linux
133
+ if len(keys) == 0: # skipif-ci-and-not-linux
134
+ return []
135
+ ser = [ # skipif-ci-and-not-linux
136
+ _serialize(key, serializer=self.key_serializer) for key in keys
137
+ ]
138
+ async with timeout_dur( # skipif-ci-and-not-linux
139
+ duration=self.timeout, error=self.error
140
+ ):
141
+ result = await cast( # skipif-ci-and-not-linux
142
+ "Awaitable[Sequence[bytes | None]]", redis.hmget(self.name, ser)
143
+ )
144
+ return [ # skipif-ci-and-not-linux
145
+ None
146
+ if data is None
147
+ else _deserialize(data, deserializer=self.value_deserializer)
148
+ for data in result
149
+ ]
150
+
151
+ async def keys(self, redis: Redis, /) -> Sequence[_K]:
152
+ """Get the keys of a hashmap in `redis`."""
153
+ async with timeout_dur( # skipif-ci-and-not-linux
154
+ duration=self.timeout, error=self.error
155
+ ):
156
+ result = await cast("Awaitable[Sequence[bytes]]", redis.hkeys(self.name))
157
+ return [ # skipif-ci-and-not-linux
158
+ _deserialize(data, deserializer=self.key_deserializer) for data in result
159
+ ]
160
+
161
+ async def length(self, redis: Redis, /) -> int:
162
+ """Get the length of a hashmap in `redis`."""
163
+ async with timeout_dur( # skipif-ci-and-not-linux
164
+ duration=self.timeout, error=self.error
165
+ ):
166
+ return await cast("Awaitable[int]", redis.hlen(self.name))
116
167
 
117
168
  async def set(self, redis: Redis, key: _K, value: _V, /) -> int:
118
169
  """Set a value in a hashmap in `redis`."""
119
- ser_key = _serialize( # skipif-ci-and-not-linux
120
- key, serializer=self.key_serializer
121
- )
122
- ser_value = _serialize( # skipif-ci-and-not-linux
123
- value, serializer=self.value_serializer
124
- )
170
+ return await self.set_many(redis, {key: value}) # skipif-ci-and-not-linux
171
+
172
+ async def set_many(self, redis: Redis, mapping: Mapping[_K, _V], /) -> int:
173
+ """Set multiple value(s) in a hashmap in `redis`."""
174
+ if len(mapping) == 0: # skipif-ci-and-not-linux
175
+ return 0
176
+ ser = { # skipif-ci-and-not-linux
177
+ _serialize(key, serializer=self.key_serializer): _serialize(
178
+ value, serializer=self.value_serializer
179
+ )
180
+ for key, value in mapping.items()
181
+ }
125
182
  async with timeout_dur( # skipif-ci-and-not-linux
126
183
  duration=self.timeout, error=self.error
127
184
  ):
128
185
  result = await cast(
129
- "Awaitable[int]",
130
- redis.hset(
131
- self.name, key=cast("Any", ser_key), value=cast("Any", ser_value)
132
- ),
186
+ "Awaitable[int]", redis.hset(self.name, mapping=cast("Any", ser))
133
187
  )
134
188
  if self.ttl is not None:
135
189
  await redis.pexpire(self.name, datetime_duration_to_timedelta(self.ttl))
136
190
  return result # skipif-ci-and-not-linux
137
191
 
192
+ async def values(self, redis: Redis, /) -> Sequence[_V]:
193
+ """Get the values of a hashmap in `redis`."""
194
+ async with timeout_dur( # skipif-ci-and-not-linux
195
+ duration=self.timeout, error=self.error
196
+ ):
197
+ result = await cast("Awaitable[Sequence[bytes]]", redis.hvals(self.name))
198
+ return [ # skipif-ci-and-not-linux
199
+ _deserialize(data, deserializer=self.value_deserializer) for data in result
200
+ ]
201
+
138
202
 
139
203
  @overload
140
204
  def redis_hash_map_key(
@@ -144,6 +208,7 @@ def redis_hash_map_key(
144
208
  /,
145
209
  *,
146
210
  key_serializer: Callable[[_K], bytes] | None = None,
211
+ key_deserializer: Callable[[bytes], Any] | None = None,
147
212
  value_serializer: Callable[[_V], bytes] | None = None,
148
213
  value_deserializer: Callable[[bytes], _V] | None = None,
149
214
  timeout: Duration | None = None,
@@ -158,6 +223,7 @@ def redis_hash_map_key(
158
223
  /,
159
224
  *,
160
225
  key_serializer: Callable[[_K], bytes] | None = None,
226
+ key_deserializer: Callable[[bytes], Any] | None = None,
161
227
  value_serializer: Callable[[_V1 | _V2], bytes] | None = None,
162
228
  value_deserializer: Callable[[bytes], _V1 | _V2] | None = None,
163
229
  timeout: Duration | None = None,
@@ -172,6 +238,7 @@ def redis_hash_map_key(
172
238
  /,
173
239
  *,
174
240
  key_serializer: Callable[[_K], bytes] | None = None,
241
+ key_deserializer: Callable[[bytes], Any] | None = None,
175
242
  value_serializer: Callable[[_V1 | _V2 | _V3], bytes] | None = None,
176
243
  value_deserializer: Callable[[bytes], _V1 | _V2 | _V3] | None = None,
177
244
  timeout: Duration | None = None,
@@ -186,6 +253,7 @@ def redis_hash_map_key(
186
253
  /,
187
254
  *,
188
255
  key_serializer: Callable[[_K1 | _K2], bytes] | None = None,
256
+ key_deserializer: Callable[[bytes], Any] | None = None,
189
257
  value_serializer: Callable[[_V], bytes] | None = None,
190
258
  value_deserializer: Callable[[bytes], _V] | None = None,
191
259
  timeout: Duration | None = None,
@@ -200,6 +268,7 @@ def redis_hash_map_key(
200
268
  /,
201
269
  *,
202
270
  key_serializer: Callable[[_K1 | _K2], bytes] | None = None,
271
+ key_deserializer: Callable[[bytes], Any] | None = None,
203
272
  value_serializer: Callable[[_V1 | _V2], bytes] | None = None,
204
273
  value_deserializer: Callable[[bytes], _V1 | _V2] | None = None,
205
274
  timeout: Duration | None = None,
@@ -214,6 +283,7 @@ def redis_hash_map_key(
214
283
  /,
215
284
  *,
216
285
  key_serializer: Callable[[_K1 | _K2], bytes] | None = None,
286
+ key_deserializer: Callable[[bytes], Any] | None = None,
217
287
  value_serializer: Callable[[_V1 | _V2 | _V3], bytes] | None = None,
218
288
  value_deserializer: Callable[[bytes], _V1 | _V2 | _V3] | None = None,
219
289
  timeout: Duration | None = None,
@@ -228,6 +298,7 @@ def redis_hash_map_key(
228
298
  /,
229
299
  *,
230
300
  key_serializer: Callable[[_K1 | _K2 | _K3], bytes] | None = None,
301
+ key_deserializer: Callable[[bytes], Any] | None = None,
231
302
  value_serializer: Callable[[_V], bytes] | None = None,
232
303
  value_deserializer: Callable[[bytes], _V] | None = None,
233
304
  timeout: Duration | None = None,
@@ -242,6 +313,7 @@ def redis_hash_map_key(
242
313
  /,
243
314
  *,
244
315
  key_serializer: Callable[[_K1 | _K2 | _K3], bytes] | None = None,
316
+ key_deserializer: Callable[[bytes], Any] | None = None,
245
317
  value_serializer: Callable[[_V1 | _V2], bytes] | None = None,
246
318
  value_deserializer: Callable[[bytes], _V1 | _V2] | None = None,
247
319
  timeout: Duration | None = None,
@@ -256,6 +328,7 @@ def redis_hash_map_key(
256
328
  /,
257
329
  *,
258
330
  key_serializer: Callable[[_K1 | _K2 | _K3], bytes] | None = None,
331
+ key_deserializer: Callable[[bytes], Any] | None = None,
259
332
  value_serializer: Callable[[_V1 | _V2 | _V3], bytes] | None = None,
260
333
  value_deserializer: Callable[[bytes], _V1 | _V2 | _V3] | None = None,
261
334
  timeout: Duration | None = None,
@@ -270,6 +343,7 @@ def redis_hash_map_key(
270
343
  /,
271
344
  *,
272
345
  key_serializer: Callable[[_K1 | _K2 | _K3], bytes] | None = None,
346
+ key_deserializer: Callable[[bytes], Any] | None = None,
273
347
  value_serializer: Callable[[_V1 | _V2 | _V3], bytes] | None = None,
274
348
  value_deserializer: Callable[[bytes], _V1 | _V2 | _V3] | None = None,
275
349
  timeout: Duration | None = None,
@@ -283,6 +357,7 @@ def redis_hash_map_key(
283
357
  /,
284
358
  *,
285
359
  key_serializer: Callable[[Any], bytes] | None = None,
360
+ key_deserializer: Callable[[bytes], Any] | None = None,
286
361
  value_serializer: Callable[[Any], bytes] | None = None,
287
362
  value_deserializer: Callable[[bytes], Any] | None = None,
288
363
  timeout: Duration | None = None,
@@ -294,6 +369,7 @@ def redis_hash_map_key(
294
369
  name=name,
295
370
  key=key,
296
371
  key_serializer=key_serializer,
372
+ key_deserializer=key_deserializer,
297
373
  value=value,
298
374
  value_serializer=value_serializer,
299
375
  value_deserializer=value_deserializer,
@@ -343,23 +419,15 @@ class RedisKey(Generic[_T]):
343
419
  duration=self.timeout, error=self.error
344
420
  ):
345
421
  result = cast("bytes | None", await redis.get(self.name))
346
- match result: # skipif-ci-and-not-linux
347
- case None:
348
- return None
349
- case bytes() as data:
350
- if self.deserializer is None:
351
- from utilities.orjson import deserialize
352
-
353
- return deserialize(data)
354
- return self.deserializer(data)
355
- case _ as never:
356
- assert_never(never)
422
+ if result is None: # skipif-ci-and-not-linux
423
+ raise KeyError(self.name)
424
+ return _deserialize( # skipif-ci-and-not-linux
425
+ result, deserializer=self.deserializer
426
+ )
357
427
 
358
428
  async def set(self, redis: Redis, value: _T, /) -> int:
359
429
  """Set a value in `redis`."""
360
- ser_value = _serialize( # skipif-ci-and-not-linux
361
- value, serializer=self.serializer
362
- )
430
+ ser = _serialize(value, serializer=self.serializer) # skipif-ci-and-not-linux
363
431
  ttl = ( # skipif-ci-and-not-linux
364
432
  None
365
433
  if self.ttl is None
@@ -369,7 +437,7 @@ class RedisKey(Generic[_T]):
369
437
  duration=self.timeout, error=self.error
370
438
  ):
371
439
  result = await redis.set( # skipif-ci-and-not-linux
372
- self.name, ser_value, px=ttl
440
+ self.name, ser, px=ttl
373
441
  )
374
442
  return ensure_int(result) # skipif-ci-and-not-linux
375
443