dycw-utilities 0.135.0__py3-none-any.whl → 0.178.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dycw-utilities might be problematic.
- dycw_utilities-0.178.1.dist-info/METADATA +34 -0
- dycw_utilities-0.178.1.dist-info/RECORD +105 -0
- dycw_utilities-0.178.1.dist-info/WHEEL +4 -0
- dycw_utilities-0.178.1.dist-info/entry_points.txt +4 -0
- utilities/__init__.py +1 -1
- utilities/altair.py +13 -10
- utilities/asyncio.py +312 -787
- utilities/atomicwrites.py +18 -6
- utilities/atools.py +64 -4
- utilities/cachetools.py +9 -6
- utilities/click.py +195 -77
- utilities/concurrent.py +1 -1
- utilities/contextlib.py +216 -17
- utilities/contextvars.py +20 -1
- utilities/cryptography.py +3 -3
- utilities/dataclasses.py +15 -28
- utilities/docker.py +387 -0
- utilities/enum.py +2 -2
- utilities/errors.py +17 -3
- utilities/fastapi.py +28 -59
- utilities/fpdf2.py +2 -2
- utilities/functions.py +24 -269
- utilities/git.py +9 -30
- utilities/grp.py +28 -0
- utilities/gzip.py +31 -0
- utilities/http.py +3 -2
- utilities/hypothesis.py +513 -159
- utilities/importlib.py +17 -1
- utilities/inflect.py +12 -4
- utilities/iterables.py +33 -58
- utilities/jinja2.py +148 -0
- utilities/json.py +70 -0
- utilities/libcst.py +38 -17
- utilities/lightweight_charts.py +4 -7
- utilities/logging.py +136 -93
- utilities/math.py +8 -4
- utilities/more_itertools.py +43 -45
- utilities/operator.py +27 -27
- utilities/orjson.py +189 -36
- utilities/os.py +61 -4
- utilities/packaging.py +115 -0
- utilities/parse.py +8 -5
- utilities/pathlib.py +269 -40
- utilities/permissions.py +298 -0
- utilities/platform.py +7 -6
- utilities/polars.py +1205 -413
- utilities/polars_ols.py +1 -1
- utilities/postgres.py +408 -0
- utilities/pottery.py +43 -19
- utilities/pqdm.py +3 -3
- utilities/psutil.py +5 -57
- utilities/pwd.py +28 -0
- utilities/pydantic.py +4 -52
- utilities/pydantic_settings.py +240 -0
- utilities/pydantic_settings_sops.py +76 -0
- utilities/pyinstrument.py +7 -7
- utilities/pytest.py +104 -143
- utilities/pytest_plugins/__init__.py +1 -0
- utilities/pytest_plugins/pytest_randomly.py +23 -0
- utilities/pytest_plugins/pytest_regressions.py +56 -0
- utilities/pytest_regressions.py +26 -46
- utilities/random.py +11 -6
- utilities/re.py +1 -1
- utilities/redis.py +220 -343
- utilities/sentinel.py +10 -0
- utilities/shelve.py +4 -1
- utilities/shutil.py +25 -0
- utilities/slack_sdk.py +35 -104
- utilities/sqlalchemy.py +496 -471
- utilities/sqlalchemy_polars.py +29 -54
- utilities/string.py +2 -3
- utilities/subprocess.py +1977 -0
- utilities/tempfile.py +112 -4
- utilities/testbook.py +50 -0
- utilities/text.py +174 -42
- utilities/throttle.py +158 -0
- utilities/timer.py +2 -2
- utilities/traceback.py +70 -35
- utilities/types.py +102 -30
- utilities/typing.py +479 -19
- utilities/uuid.py +42 -5
- utilities/version.py +27 -26
- utilities/whenever.py +1559 -361
- utilities/zoneinfo.py +80 -22
- dycw_utilities-0.135.0.dist-info/METADATA +0 -39
- dycw_utilities-0.135.0.dist-info/RECORD +0 -96
- dycw_utilities-0.135.0.dist-info/WHEEL +0 -4
- dycw_utilities-0.135.0.dist-info/licenses/LICENSE +0 -21
- utilities/aiolimiter.py +0 -25
- utilities/arq.py +0 -216
- utilities/eventkit.py +0 -388
- utilities/luigi.py +0 -183
- utilities/period.py +0 -152
- utilities/pudb.py +0 -62
- utilities/python_dotenv.py +0 -101
- utilities/streamlit.py +0 -105
- utilities/typed_settings.py +0 -123
utilities/polars.py
CHANGED
@@ -1,25 +1,24 @@
 from __future__ import annotations

-import datetime as dt
 import enum
 from collections.abc import Callable, Iterator, Sequence
 from collections.abc import Set as AbstractSet
-from contextlib import suppress
 from dataclasses import asdict, dataclass
 from functools import partial, reduce
-from itertools import chain, product
-from math import ceil, log
+from itertools import chain, pairwise, product
+from math import ceil, log, pi, sqrt
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, assert_never, cast, overload, override
 from uuid import UUID
 from zoneinfo import ZoneInfo

 import polars as pl
+import whenever
 from polars import (
     Boolean,
     DataFrame,
-    Date,
     Datetime,
+    Duration,
     Expr,
     Float64,
     Int64,
@@ -33,12 +32,16 @@ from polars import (
     any_horizontal,
     col,
     concat,
+    concat_list,
+    datetime_range,
     int_range,
     lit,
+    max_horizontal,
     struct,
     sum_horizontal,
     when,
 )
+from polars._typing import PolarsDataType
 from polars.datatypes import DataType, DataTypeClass
 from polars.exceptions import (
     ColumnNotFoundError,
@@ -46,58 +49,62 @@ from polars.exceptions import (
     OutOfBoundsError,
     PolarsInefficientMapWarning,
 )
-from polars.
+from polars.schema import Schema
+from polars.testing import assert_frame_equal, assert_series_equal
+from whenever import DateDelta, DateTimeDelta, PlainDateTime, TimeDelta, ZonedDateTime

-
+import utilities.math
+from utilities.dataclasses import yield_fields
 from utilities.errors import ImpossibleCaseError
-from utilities.functions import
-
-    ensure_int,
-    is_dataclass_class,
-    is_dataclass_instance,
-    is_iterable_of,
-    make_isinstance,
-)
+from utilities.functions import get_class_name
+from utilities.gzip import read_binary
 from utilities.iterables import (
     CheckIterablesEqualError,
     CheckMappingsEqualError,
-    CheckSubSetError,
     CheckSuperMappingError,
     OneEmptyError,
     OneNonUniqueError,
     always_iterable,
     check_iterables_equal,
     check_mappings_equal,
-    check_subset,
     check_supermapping,
     is_iterable_not_str,
     one,
+    resolve_include_and_exclude,
 )
+from utilities.json import write_formatted_json
 from utilities.math import (
+    MAX_DECIMALS,
     CheckIntegerError,
     check_integer,
     ewm_parameters,
     is_less_than,
     is_non_negative,
-    number_of_decimals,
 )
 from utilities.reprlib import get_repr
-from utilities.types import MaybeStr, Number, WeekDay
+from utilities.types import MaybeStr, Number, PathLike, WeekDay
 from utilities.typing import (
     get_args,
-
+    is_dataclass_class,
+    is_dataclass_instance,
     is_frozenset_type,
-    is_instance_gen,
     is_list_type,
     is_literal_type,
     is_optional_type,
     is_set_type,
-
+    make_isinstance,
 )
 from utilities.warnings import suppress_warnings
-from utilities.
+from utilities.whenever import (
+    DatePeriod,
+    TimePeriod,
+    ZonedDateTimePeriod,
+    to_py_time_delta,
+)
+from utilities.zoneinfo import UTC, to_time_zone_name

 if TYPE_CHECKING:
+    import datetime as dt
     from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
     from collections.abc import Set as AbstractSet

@@ -108,6 +115,7 @@ if TYPE_CHECKING:
         JoinValidation,
         PolarsDataType,
         QuantileMethod,
+        RoundMode,
         SchemaDict,
         TimeUnit,
     )
@@ -118,13 +126,19 @@ if TYPE_CHECKING:


 type ExprLike = MaybeStr[Expr]
+type ExprOrSeries = Expr | Series
 DatetimeHongKong = Datetime(time_zone="Asia/Hong_Kong")
 DatetimeTokyo = Datetime(time_zone="Asia/Tokyo")
 DatetimeUSCentral = Datetime(time_zone="US/Central")
 DatetimeUSEastern = Datetime(time_zone="US/Eastern")
 DatetimeUTC = Datetime(time_zone="UTC")
+DatePeriodDType = Struct({"start": pl.Date, "end": pl.Date})
+TimePeriodDType = Struct({"start": pl.Time, "end": pl.Time})
+
+
 _FINITE_EWM_MIN_WEIGHT = 0.9999

+
 ##


@@ -204,7 +218,7 @@ def acf(
             df_confints = _acf_process_confints(confints)
             df_qstats_pvalues = _acf_process_qstats_pvalues(qstats, pvalues)
             return join(df_acfs, df_confints, df_qstats_pvalues, on=["lag"], how="left")
-        case
+        case never:
             assert_never(never)


@@ -234,11 +248,6 @@ def _acf_process_qstats_pvalues(qstats: NDArrayF, pvalues: NDArrayF, /) -> DataF
 ##


-# def acf_halflife(series: Series,/)
-
-##
-
-
 def adjust_frequencies(
     series: Series,
     /,
@@ -260,29 +269,108 @@ def adjust_frequencies(
 ##


-def
-
-
-
-
-
-
-
-
-
-
-
+def all_dataframe_columns(
+    df: DataFrame, expr: IntoExprColumn, /, *exprs: IntoExprColumn
+) -> Series:
+    """Return a DataFrame column with `AND` applied to additional exprs/series."""
+    name = get_expr_name(df, expr)
+    return df.select(all_horizontal(expr, *exprs).alias(name))[name]
+
+
+def any_dataframe_columns(
+    df: DataFrame, expr: IntoExprColumn, /, *exprs: IntoExprColumn
+) -> Series:
+    """Return a DataFrame column with `OR` applied to additional exprs/series."""
+    name = get_expr_name(df, expr)
+    return df.select(any_horizontal(expr, *exprs).alias(name))[name]
+
+
+def all_series(series: Series, /, *columns: ExprOrSeries) -> Series:
+    """Return a Series with `AND` applied to additional exprs/series."""
+    return all_dataframe_columns(series.to_frame(), series.name, *columns)
+
+
+def any_series(series: Series, /, *columns: ExprOrSeries) -> Series:
+    """Return a Series with `OR` applied to additional exprs/series."""
+    df = series.to_frame()
+    name = series.name
+    return df.select(any_horizontal(name, *columns).alias(name))[name]
+
+
+##
+
+
+def append_row(
+    df: DataFrame,
+    row: StrMapping,
+    /,
+    *,
+    predicate: Callable[[StrMapping], bool] | None = None,
+    disallow_extra: bool = False,
+    disallow_missing: bool | MaybeIterable[str] = False,
+    disallow_null: bool | MaybeIterable[str] = False,
+    in_place: bool = False,
+) -> DataFrame:
+    """Append a row to a DataFrame."""
+    if (predicate is not None) and not predicate(row):
+        raise _AppendRowPredicateError(df=df, row=row)
+    if disallow_extra and (len(extra := set(row) - set(df.columns)) >= 1):
+        raise _AppendRowExtraKeysError(df=df, row=row, extra=extra)
+    if disallow_missing is not False:
+        missing = set(df.columns) - set(row)
+        if disallow_missing is not True:
+            missing &= set(always_iterable(disallow_missing))
+        if len(missing) >= 1:
+            raise _AppendRowMissingKeysError(df=df, row=row, missing=missing)
+    other = DataFrame(data=[row], schema=df.schema)
+    if disallow_null:
+        other_null = other.select(col(c).is_null().any() for c in other.columns)
+        null = {k for k, v in other_null.row(0, named=True).items() if v}
+        if disallow_null is not True:
+            null &= set(always_iterable(disallow_null))
+        if len(null) >= 1:
+            raise _AppendRowNullColumnsError(df=df, row=row, columns=null)
+    return df.extend(other) if in_place else df.vstack(other)
+
+
+@dataclass(kw_only=True, slots=True)
+class AppendRowError(Exception):
+    df: DataFrame
+    row: StrMapping


 @dataclass(kw_only=True, slots=True)
-class
-
-
-
+class _AppendRowPredicateError(AppendRowError):
+    @override
+    def __str__(self) -> str:
+        return f"Predicate failed; got {get_repr(self.row)}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _AppendRowExtraKeysError(AppendRowError):
+    extra: AbstractSet[str]

     @override
     def __str__(self) -> str:
-        return f"
+        return f"Extra key(s) found; got {get_repr(self.extra)}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _AppendRowMissingKeysError(AppendRowError):
+    missing: AbstractSet[str]
+
+    @override
+    def __str__(self) -> str:
+        return f"Missing key(s) found; got {get_repr(self.missing)}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _AppendRowNullColumnsError(AppendRowError):
+    columns: AbstractSet[str]
+
+    @override
+    def __str__(self) -> str:
+        return f"Null column(s) found; got {get_repr(self.columns)}"


 ##
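The hunk above introduces `all_dataframe_columns`/`any_dataframe_columns`/`all_series`/`any_series` and a validating `append_row`. A minimal sketch of how they might be used, assuming dycw-utilities 0.178.1 and polars are installed; the frame and column names are illustrative:

```python
import polars as pl
from utilities.polars import all_series, append_row

df = pl.DataFrame({"x": [1, 2], "y": [True, False]})

# Append a row, rejecting unknown keys and a null "x".
df2 = append_row(df, {"x": 3, "y": True}, disallow_extra=True, disallow_null="x")

# AND a boolean Series together with further boolean series, row-wise.
both = all_series(df2["y"], df2["x"] > 1)
```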
@@ -297,8 +385,8 @@ def are_frames_equal(
     check_column_order: bool = True,
     check_dtypes: bool = True,
     check_exact: bool = False,
-
-
+    rel_tol: float = 1e-5,
+    abs_tol: float = 1e-8,
     categorical_as_str: bool = False,
 ) -> bool:
     """Check if two DataFrames are equal."""
@@ -310,8 +398,8 @@ def are_frames_equal(
             check_column_order=check_column_order,
             check_dtypes=check_dtypes,
             check_exact=check_exact,
-
-
+            rel_tol=rel_tol,
+            abs_tol=abs_tol,
             categorical_as_str=categorical_as_str,
         )
     except AssertionError:
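Per the signature above, the tolerance keywords of `are_frames_equal` are now `rel_tol`/`abs_tol`. A small sketch (the data is illustrative):

```python
import polars as pl
from utilities.polars import are_frames_equal

left = pl.DataFrame({"x": [1.0, 2.0]})
right = pl.DataFrame({"x": [1.0, 2.0 + 1e-9]})

# Equal up to the default relative/absolute tolerances shown above.
assert are_frames_equal(left, right, rel_tol=1e-5, abs_tol=1e-8)
```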
@@ -341,7 +429,7 @@ def bernoulli(
             return bernoulli(series.len(), true=true, seed=seed, name=name)
         case DataFrame() as df:
             return bernoulli(df.height, true=true, seed=seed, name=name)
-        case
+        case never:
             assert_never(never)


@@ -375,7 +463,7 @@ def boolean_value_counts(
                 (false / total).alias("false (%)"),
                 (null / total).alias("null (%)"),
             )
-        case
+        case never:
             assert_never(never)


@@ -418,29 +506,6 @@ class BooleanValueCountsError(Exception):
 ##


-@overload
-def ceil_datetime(column: ExprLike, every: ExprLike, /) -> Expr: ...
-@overload
-def ceil_datetime(column: Series, every: ExprLike, /) -> Series: ...
-@overload
-def ceil_datetime(column: IntoExprColumn, every: ExprLike, /) -> Expr | Series: ...
-def ceil_datetime(column: IntoExprColumn, every: ExprLike, /) -> Expr | Series:
-    """Compute the `ceil` of a datetime column."""
-    column = ensure_expr_or_series(column)
-    rounded = column.dt.round(every)
-    ceil = (
-        when(column <= rounded)
-        .then(rounded)
-        .otherwise(column.dt.offset_by(every).dt.round(every))
-    )
-    if isinstance(column, Expr):
-        return ceil
-    return DataFrame().with_columns(ceil.alias(column.name))[column.name]
-
-
-##
-
-
 def check_polars_dataframe(
     df: DataFrame,
     /,
@@ -500,7 +565,7 @@ def _check_polars_dataframe_columns(df: DataFrame, columns: Iterable[str], /) ->

 @dataclass(kw_only=True, slots=True)
 class _CheckPolarsDataFrameColumnsError(CheckPolarsDataFrameError):
-    columns:
+    columns: list[str]

     @override
     def __str__(self) -> str:
@@ -759,29 +824,22 @@ def choice(
                 name=name,
                 dtype=dtype,
             )
-        case
+        case never:
             assert_never(never)


 ##


-def
-
-
-    return data[one(data.columns)]
-
-
-##
-
-
-def columns_to_dict(df: DataFrame, key: str, value: str, /) -> dict[Any, Any]:
+def columns_to_dict(
+    df: DataFrame, key: IntoExprColumn, value: IntoExprColumn, /
+) -> dict[Any, Any]:
     """Map a pair of columns into a dictionary. Must be unique on `key`."""
-
-
-
-
-    return dict(zip(
+    df = df.select(key, value)
+    key_col, value_col = [df[get_expr_name(df, expr)] for expr in [key, value]]
+    if key_col.is_duplicated().any():
+        raise ColumnsToDictError(df=df, key=key_col.name)
+    return dict(zip(key_col, value_col, strict=True))


 @dataclass(kw_only=True, slots=True)
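`columns_to_dict` now accepts expressions for `key` and `value` and raises on duplicated keys. A sketch under the new signature (data illustrative):

```python
import polars as pl
from utilities.polars import columns_to_dict

df = pl.DataFrame({"key": ["a", "b"], "value": [1, 2]})

# Expressions are resolved against the frame before zipping into a dict.
mapping = columns_to_dict(df, "key", pl.col("value") * 10)
assert mapping == {"a": 10, "b": 20}
```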
@@ -824,7 +882,7 @@ def convert_time_zone(

 def _convert_time_zone_one(sr: Series, /, *, time_zone: TimeZoneLike = UTC) -> Series:
     if isinstance(sr.dtype, Datetime):
-        return sr.dt.convert_time_zone(
+        return sr.dt.convert_time_zone(to_time_zone_name(time_zone))
     return sr


@@ -845,13 +903,13 @@ def cross(
     up_or_down: Literal["up", "down"],
     other: Number | IntoExprColumn,
     /,
-) ->
+) -> ExprOrSeries: ...
 def cross(
     expr: IntoExprColumn,
     up_or_down: Literal["up", "down"],
     other: Number | IntoExprColumn,
     /,
-) ->
+) -> ExprOrSeries:
     """Compute when a cross occurs."""
     return _cross_or_touch(expr, "cross", up_or_down, other)

@@ -870,13 +928,13 @@ def touch(
     up_or_down: Literal["up", "down"],
     other: Number | IntoExprColumn,
     /,
-) ->
+) -> ExprOrSeries: ...
 def touch(
     expr: IntoExprColumn,
     up_or_down: Literal["up", "down"],
     other: Number | IntoExprColumn,
     /,
-) ->
+) -> ExprOrSeries:
     """Compute when a touch occurs."""
     return _cross_or_touch(expr, "touch", up_or_down, other)

@@ -887,7 +945,7 @@ def _cross_or_touch(
     up_or_down: Literal["up", "down"],
     other: Number | IntoExprColumn,
     /,
-) ->
+) -> ExprOrSeries:
     """Compute when a column crosses/touches a threshold."""
     expr = ensure_expr_or_series(expr)
     match other:
@@ -895,7 +953,7 @@ def _cross_or_touch(
             ...
         case str() | Expr() | Series():
             other = ensure_expr_or_series(other)
-        case
+        case never:
             assert_never(never)
     enough = int_range(end=pl.len()) >= 1
     match cross_or_touch, up_or_down:
@@ -907,7 +965,7 @@ def _cross_or_touch(
             current = expr >= other
         case "touch", "down":
             current = expr <= other
-        case
+        case never:
             assert_never(never)
     prev = current.shift()
     result = when(enough & expr.is_finite()).then(current & ~prev)
@@ -959,7 +1017,7 @@ def cross_rolling_quantile(
     weights: list[float] | None = None,
     min_samples: int | None = None,
     center: bool = False,
-) ->
+) -> ExprOrSeries: ...
 def cross_rolling_quantile(
     expr: IntoExprColumn,
     up_or_down: Literal["up", "down"],
@@ -971,7 +1029,7 @@ def cross_rolling_quantile(
     weights: list[float] | None = None,
     min_samples: int | None = None,
     center: bool = False,
-) ->
+) -> ExprOrSeries:
     """Compute when a column crosses its rolling quantile."""
     expr = ensure_expr_or_series(expr)
     rolling = expr.rolling_quantile(
@@ -1016,16 +1074,43 @@ def dataclass_to_dataframe(


 def _dataclass_to_dataframe_cast(series: Series, /) -> Series:
-    if series.dtype
-
-
-
-
-
-
-
-
-
+    if series.dtype != Object:
+        return series
+    if series.map_elements(make_isinstance(whenever.Date), return_dtype=Boolean).all():
+        return series.map_elements(lambda x: x.py_date(), return_dtype=pl.Date)
+    if series.map_elements(make_isinstance(DateDelta), return_dtype=Boolean).all():
+        return series.map_elements(to_py_time_delta, return_dtype=Duration)
+    if series.map_elements(make_isinstance(DateTimeDelta), return_dtype=Boolean).all():
+        return series.map_elements(to_py_time_delta, return_dtype=Duration)
+    is_path = series.map_elements(make_isinstance(Path), return_dtype=Boolean).all()
+    is_uuid = series.map_elements(make_isinstance(UUID), return_dtype=Boolean).all()
+    if is_path or is_uuid:
+        with suppress_warnings(
+            category=cast("type[Warning]", PolarsInefficientMapWarning)
+        ):
+            return series.map_elements(str, return_dtype=String)
+    if series.map_elements(make_isinstance(whenever.Time), return_dtype=Boolean).all():
+        return series.map_elements(lambda x: x.py_time(), return_dtype=pl.Time)
+    if series.map_elements(make_isinstance(TimeDelta), return_dtype=Boolean).all():
+        return series.map_elements(to_py_time_delta, return_dtype=Duration)
+    if series.map_elements(make_isinstance(ZonedDateTime), return_dtype=Boolean).all():
+        return_dtype = zoned_date_time_dtype(time_zone=one({dt.tz for dt in series}))
+        return series.map_elements(lambda x: x.py_datetime(), return_dtype=return_dtype)
+    if series.map_elements(
+        lambda x: isinstance(x, dict) and (set(x) == {"start", "end"}),
+        return_dtype=Boolean,
+    ).all():
+        start = _dataclass_to_dataframe_cast(
+            series.map_elements(lambda x: x["start"], return_dtype=Object)
+        ).alias("start")
+        end = _dataclass_to_dataframe_cast(
+            series.map_elements(lambda x: x["end"], return_dtype=Object)
+        ).alias("end")
+        name = series.name
+        return concat_series(start, end).select(
+            struct(start=start, end=end).alias(name)
+        )[name]
+    raise NotImplementedError(series)  # pragma: no cover


 @dataclass(kw_only=True, slots=True)
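`_dataclass_to_dataframe_cast` above maps `whenever` dates/times/deltas, `Path`, `UUID` and `{"start", "end"}` periods onto polars dtypes. A sketch of the public entry point; the exact `dataclass_to_dataframe` call shape is an assumption, since its signature sits outside this hunk:

```python
from dataclasses import dataclass

from whenever import ZonedDateTime
from utilities.polars import dataclass_to_dataframe

@dataclass(kw_only=True)
class Trade:
    qty: int
    at: ZonedDateTime

# Assumed call shape: an iterable of dataclass instances in, a DataFrame out;
# per the cast helper above, `at` should come back as a UTC-zoned Datetime.
df = dataclass_to_dataframe([Trade(qty=1, at=ZonedDateTime(2024, 1, 1, tz="UTC"))])
```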
@@ -1066,20 +1151,14 @@ def dataclass_to_schema(
     for field in yield_fields(
         obj, globalns=globalns, localns=localns, warn_name_errors=warn_name_errors
     ):
-        if is_dataclass_instance(field.value)
+        if is_dataclass_instance(field.value) and not (
+            isinstance(field.type_, type)
+            and issubclass(field.type_, (DatePeriod, TimePeriod, ZonedDateTimePeriod))
+        ):
             dtypes = dataclass_to_schema(
                 field.value, globalns=globalns, localns=localns
             )
             dtype = struct_dtype(**dtypes)
-        elif field.type_ is dt.datetime:
-            dtype = _dataclass_to_schema_datetime(field)
-        elif is_union_type(field.type_) and set(
-            get_args(field.type_, optional_drop_none=True)
-        ) == {dt.date, dt.datetime}:
-            if is_instance_gen(field.value, dt.date):
-                dtype = Date
-            else:
-                dtype = _dataclass_to_schema_datetime(field)
         else:
             dtype = _dataclass_to_schema_one(
                 field.type_, globalns=globalns, localns=localns
@@ -1088,14 +1167,6 @@ def dataclass_to_schema(
     return out


-def _dataclass_to_schema_datetime(
-    field: _YieldFieldsInstance[dt.datetime], /
-) -> PolarsDataType:
-    if field.value.tzinfo is None:
-        return Datetime
-    return zoned_datetime(time_zone=ensure_time_zone(field.value.tzinfo))
-
-
 def _dataclass_to_schema_one(
     obj: Any,
     /,
@@ -1103,20 +1174,35 @@ def _dataclass_to_schema_one(
     globalns: StrMapping | None = None,
     localns: StrMapping | None = None,
 ) -> PolarsDataType:
-    if obj
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if isinstance(obj, type):
+        if issubclass(obj, bool):
+            return Boolean
+        if issubclass(obj, int):
+            return Int64
+        if issubclass(obj, float):
+            return Float64
+        if issubclass(obj, str):
+            return String
+        if issubclass(
+            obj,
+            (
+                DateDelta,
+                DatePeriod,
+                DateTimeDelta,
+                Path,
+                PlainDateTime,
+                TimeDelta,
+                TimePeriod,
+                UUID,
+                ZonedDateTime,
+                ZonedDateTimePeriod,
+                whenever.Date,
+                whenever.Time,
+            ),
+        ):
+            return Object
+        if issubclass(obj, enum.Enum):
+            return pl.Enum([e.name for e in obj])
     if is_dataclass_class(obj):
         out: dict[str, Any] = {}
         for field in yield_fields(obj, globalns=globalns, localns=localns):
@@ -1142,27 +1228,6 @@ def _dataclass_to_schema_one(
 ##


-def drop_null_struct_series(series: Series, /) -> Series:
-    """Drop nulls in a struct-dtype Series as per the <= 1.1 definition."""
-    try:
-        is_not_null = is_not_null_struct_series(series)
-    except IsNotNullStructSeriesError as error:
-        raise DropNullStructSeriesError(series=error.series) from None
-    return series.filter(is_not_null)
-
-
-@dataclass(kw_only=True, slots=True)
-class DropNullStructSeriesError(Exception):
-    series: Series
-
-    @override
-    def __str__(self) -> str:
-        return f"Series must have Struct-dtype; got {self.series.dtype}"
-
-
-##
-
-
 def ensure_data_type(dtype: PolarsDataType, /) -> DataType:
     """Ensure a data type is returned."""
     return dtype if isinstance(dtype, DataType) else dtype()
@@ -1176,8 +1241,8 @@ def ensure_expr_or_series(column: ExprLike, /) -> Expr: ...
 @overload
 def ensure_expr_or_series(column: Series, /) -> Series: ...
 @overload
-def ensure_expr_or_series(column: IntoExprColumn, /) ->
-def ensure_expr_or_series(column: IntoExprColumn, /) ->
+def ensure_expr_or_series(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def ensure_expr_or_series(column: IntoExprColumn, /) -> ExprOrSeries:
     """Ensure a column expression or Series is returned."""
     return col(column) if isinstance(column, str) else column

@@ -1187,7 +1252,7 @@ def ensure_expr_or_series(column: IntoExprColumn, /) -> Expr | Series:

 def ensure_expr_or_series_many(
     *columns: IntoExprColumn, **named_columns: IntoExprColumn
-) -> Sequence[
+) -> Sequence[ExprOrSeries]:
     """Ensure a set of column expressions and/or Series are returned."""
     args = map(ensure_expr_or_series, columns)
     kwargs = (ensure_expr_or_series(v).alias(k) for k, v in named_columns.items())
@@ -1197,6 +1262,119 @@ def ensure_expr_or_series_many(
 ##


+def expr_to_series(expr: Expr, /) -> Series:
+    """Collect a column expression into a Series."""
+    return one_column(DataFrame().with_columns(expr))
+
+
+##
+
+
+@overload
+def filter_date(
+    column: ExprLike = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[whenever.Date] | None = None,
+    exclude: MaybeIterable[whenever.Date] | None = None,
+) -> Expr: ...
+@overload
+def filter_date(
+    column: Series,
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[whenever.Date] | None = None,
+    exclude: MaybeIterable[whenever.Date] | None = None,
+) -> Series: ...
+@overload
+def filter_date(
+    column: IntoExprColumn = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[whenever.Date] | None = None,
+    exclude: MaybeIterable[whenever.Date] | None = None,
+) -> ExprOrSeries: ...
+def filter_date(
+    column: IntoExprColumn = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[whenever.Date] | None = None,
+    exclude: MaybeIterable[whenever.Date] | None = None,
+) -> ExprOrSeries:
+    """Compute the filter based on a set of dates."""
+    column = ensure_expr_or_series(column)
+    if time_zone is not None:
+        column = column.dt.convert_time_zone(time_zone.key)
+    keep = true_like(column)
+    date = column.dt.date()
+    include, exclude = resolve_include_and_exclude(include=include, exclude=exclude)
+    if include is not None:
+        keep &= date.is_in([d.py_date() for d in include])
+    if exclude is not None:
+        keep &= ~date.is_in([d.py_date() for d in exclude])
+    return try_reify_expr(keep, column)
+
+
+@overload
+def filter_time(
+    column: ExprLike = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+    exclude: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+) -> Expr: ...
+@overload
+def filter_time(
+    column: Series,
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+    exclude: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+) -> Series: ...
+@overload
+def filter_time(
+    column: IntoExprColumn = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+    exclude: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+) -> ExprOrSeries: ...
+def filter_time(
+    column: IntoExprColumn = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+    exclude: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+) -> ExprOrSeries:
+    """Compute the filter based on a set of times."""
+    column = ensure_expr_or_series(column)
+    if time_zone is not None:
+        column = column.dt.convert_time_zone(time_zone.key)
+    keep = true_like(column)
+    time = column.dt.time()
+    include, exclude = resolve_include_and_exclude(include=include, exclude=exclude)
+    if include is not None:
+        keep &= any_horizontal(
+            time.is_between(s.py_time(), e.py_time()) for s, e in include
+        )
+    if exclude is not None:
+        keep &= ~any_horizontal(
+            time.is_between(s.py_time(), e.py_time()) for s, e in exclude
+        )
+    return try_reify_expr(keep, column)
+
+
+##
+
+
 @overload
 def finite_ewm_mean(
     column: ExprLike,
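A sketch of the new `filter_time` (and, analogously, `filter_date`) from the hunk above; the frame is illustrative:

```python
import polars as pl
from whenever import Time
from utilities.polars import filter_time

df = pl.select(
    datetime=pl.datetime_range(
        pl.datetime(2024, 1, 1, time_zone="UTC"),
        pl.datetime(2024, 1, 2, time_zone="UTC"),
        interval="1h",
    )
)
# Keep rows whose time-of-day lies in [09:30, 16:00].
open_hours = df.filter(filter_time("datetime", include=[(Time(9, 30), Time(16, 0))]))
```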
@@ -1229,7 +1407,7 @@ def finite_ewm_mean(
     half_life: float | None = None,
     alpha: float | None = None,
     min_weight: float = _FINITE_EWM_MIN_WEIGHT,
-) ->
+) -> ExprOrSeries: ...
 def finite_ewm_mean(
     column: IntoExprColumn,
     /,
@@ -1239,7 +1417,7 @@ def finite_ewm_mean(
     half_life: float | None = None,
     alpha: float | None = None,
     min_weight: float = _FINITE_EWM_MIN_WEIGHT,
-) ->
+) -> ExprOrSeries:
     """Compute a finite EWMA."""
     try:
         weights = _finite_ewm_weights(
@@ -1301,23 +1479,14 @@ class _FiniteEWMWeightsError(Exception):


 @overload
-def
+def first_true_horizontal(*columns: Series) -> Series: ...
 @overload
-def
-
-
-
-
-
-    rounded = column.dt.round(every)
-    floor = (
-        when(column >= rounded)
-        .then(rounded)
-        .otherwise(column.dt.offset_by("-" + every).dt.round(every))
-    )
-    if isinstance(column, Expr):
-        return floor
-    return DataFrame().with_columns(floor.alias(column.name))[column.name]
+def first_true_horizontal(*columns: IntoExprColumn) -> ExprOrSeries: ...
+def first_true_horizontal(*columns: IntoExprColumn) -> ExprOrSeries:
+    """Get the index of the first true in each row."""
+    columns2 = ensure_expr_or_series_many(*columns)
+    expr = when(any_horizontal(*columns2)).then(concat_list(*columns2).list.arg_max())
+    return try_reify_expr(expr, *columns2)


 ##
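A sketch of `first_true_horizontal` from the hunk above (data illustrative):

```python
import polars as pl
from utilities.polars import first_true_horizontal

df = pl.DataFrame({"a": [False, True, False], "b": [True, False, False]})

# Per-row index of the first True column; null where none are True.
df.select(idx=first_true_horizontal("a", "b"))  # [1, 0, null]
```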
@@ -1334,13 +1503,24 @@ def get_data_type_or_series_time_zone(
             dtype = dtype_cls()
         case Series() as series:
             dtype = series.dtype
-        case
+        case never:
             assert_never(never)
-
-
-
-
-
+    match dtype:
+        case Datetime() as datetime:
+            if datetime.time_zone is None:
+                raise _GetDataTypeOrSeriesTimeZoneNotZonedError(dtype=datetime)
+            return ZoneInfo(datetime.time_zone)
+        case Struct() as struct:
+            try:
+                return one({
+                    get_data_type_or_series_time_zone(f.dtype) for f in struct.fields
+                })
+            except OneNonUniqueError as error:
+                raise _GetDataTypeOrSeriesTimeZoneStructNonUniqueError(
+                    dtype=struct, first=error.first, second=error.second
+                ) from None
+        case _:
+            raise _GetDataTypeOrSeriesTimeZoneNotDateTimeError(dtype=dtype)


 @dataclass(kw_only=True, slots=True)
@@ -1362,6 +1542,18 @@ class _GetDataTypeOrSeriesTimeZoneNotZonedError(GetDataTypeOrSeriesTimeZoneError
         return f"Data type must be zoned; got {self.dtype}"


+@dataclass(kw_only=True, slots=True)
+class _GetDataTypeOrSeriesTimeZoneStructNonUniqueError(
+    GetDataTypeOrSeriesTimeZoneError
+):
+    first: ZoneInfo
+    second: ZoneInfo
+
+    @override
+    def __str__(self) -> str:
+        return f"Struct data type must contain exactly one time zone; got {self.first}, {self.second} and perhaps more"
+
+
 ##


@@ -1371,9 +1563,8 @@ def get_expr_name(obj: Series | DataFrame, expr: IntoExprColumn, /) -> str:
         case Series() as series:
             return get_expr_name(series.to_frame(), expr)
         case DataFrame() as df:
-
-
-        case _ as never:
+            return one_column(df.select(expr)).name
+        case never:
             assert_never(never)


@@ -1395,50 +1586,31 @@ def get_frequency_spectrum(series: Series, /, *, d: int = 1) -> DataFrame:


 @overload
-def
-    series: Series, /, *, nullable: Literal[True]
-) -> int | None: ...
+def increasing_horizontal(*columns: ExprLike) -> Expr: ...
 @overload
-def
-    series: Series, /, *, nullable: Literal[False] = False
-) -> int: ...
+def increasing_horizontal(*columns: Series) -> Series: ...
 @overload
-def
-
-
-
-
-
-
-    if not isinstance(dtype := series.dtype, Float64):
-        raise _GetSeriesNumberOfDecimalsNotFloatError(dtype=dtype)
-    decimals = series.map_elements(number_of_decimals, return_dtype=Int64).max()
-    try:
-        return ensure_int(decimals, nullable=nullable)
-    except EnsureIntError:
-        raise _GetSeriesNumberOfDecimalsAllNullError(series=series) from None
+def increasing_horizontal(*columns: IntoExprColumn) -> ExprOrSeries: ...
+def increasing_horizontal(*columns: IntoExprColumn) -> ExprOrSeries:
+    """Check if a set of columns are increasing."""
+    columns2 = ensure_expr_or_series_many(*columns)
+    if len(columns2) == 0:
+        return lit(value=True, dtype=Boolean)
+    return all_horizontal(prev < curr for prev, curr in pairwise(columns2))


-@
-
-
-
-@
-
-
-
-
-
-    return
-
-
-@dataclass(kw_only=True, slots=True)
-class _GetSeriesNumberOfDecimalsAllNullError(GetSeriesNumberOfDecimalsError):
-    series: Series
-
-    @override
-    def __str__(self) -> str:
-        return f"Series must not be all-null; got {self.series}"
+@overload
+def decreasing_horizontal(*columns: ExprLike) -> Expr: ...
+@overload
+def decreasing_horizontal(*columns: Series) -> Series: ...
+@overload
+def decreasing_horizontal(*columns: IntoExprColumn) -> ExprOrSeries: ...
+def decreasing_horizontal(*columns: IntoExprColumn) -> ExprOrSeries:
+    """Check if a set of columns are decreasing."""
+    columns2 = ensure_expr_or_series_many(*columns)
+    if len(columns2) == 0:
+        return lit(value=True, dtype=Boolean)
+    return all_horizontal(prev > curr for prev, curr in pairwise(columns2))


 ##
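A sketch of the new `increasing_horizontal`/`decreasing_horizontal` (data illustrative):

```python
import polars as pl
from utilities.polars import increasing_horizontal

df = pl.DataFrame({"lo": [1, 5], "mid": [2, 4], "hi": [3, 3]})

# Row-wise strict ordering: lo < mid < hi.
df.select(inc=increasing_horizontal("lo", "mid", "hi"))  # [true, false]
```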
@@ -1571,13 +1743,49 @@ def integers(
                 name=name,
                 dtype=dtype,
             )
-        case
+        case never:
             assert_never(never)


 ##


+@overload
+def is_close(
+    x: ExprLike, y: ExprLike, /, *, rel_tol: float = 1e-9, abs_tol: float = 0
+) -> Expr: ...
+@overload
+def is_close(
+    x: Series, y: Series, /, *, rel_tol: float = 1e-9, abs_tol: float = 0
+) -> Series: ...
+@overload
+def is_close(
+    x: IntoExprColumn,
+    y: IntoExprColumn,
+    /,
+    *,
+    rel_tol: float = 1e-9,
+    abs_tol: float = 0,
+) -> ExprOrSeries: ...
+def is_close(
+    x: IntoExprColumn,
+    y: IntoExprColumn,
+    /,
+    *,
+    rel_tol: float = 1e-9,
+    abs_tol: float = 0,
+) -> ExprOrSeries:
+    """Check if two columns are close."""
+    x, y = map(ensure_expr_or_series, [x, y])
+    result = (x - y).abs() <= max_horizontal(
+        rel_tol * max_horizontal(x.abs(), y.abs()), abs_tol
+    )
+    return try_reify_expr(result, x, y)
+
+
+##
+
+
 @overload
 def is_near_event(
     *exprs: ExprLike, before: int = 0, after: int = 0, **named_exprs: ExprLike
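`is_close` above mirrors `math.isclose` semantics column-wise. A sketch (data illustrative):

```python
import polars as pl
from utilities.polars import is_close

df = pl.DataFrame({"x": [1.0, 2.0], "y": [1.0 + 1e-12, 2.5]})

# |x - y| <= max(rel_tol * max(|x|, |y|), abs_tol), per the body above.
df.select(close=is_close("x", "y"))  # [true, false]
```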
@@ -1592,13 +1800,13 @@ def is_near_event(
     before: int = 0,
     after: int = 0,
     **named_exprs: IntoExprColumn,
-) ->
+) -> ExprOrSeries: ...
 def is_near_event(
     *exprs: IntoExprColumn,
     before: int = 0,
     after: int = 0,
     **named_exprs: IntoExprColumn,
-) ->
+) -> ExprOrSeries:
     """Compute the rows near any event."""
     if before <= -1:
         raise _IsNearEventBeforeError(before=before)
@@ -1641,87 +1849,177 @@ class _IsNearEventAfterError(IsNearEventError):
 ##


-
-
-
-
-
-
+@overload
+def is_true(column: ExprLike, /) -> Expr: ...
+@overload
+def is_true(column: Series, /) -> Series: ...
+@overload
+def is_true(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def is_true(column: IntoExprColumn, /) -> ExprOrSeries:
+    """Compute when a boolean series is True."""
+    column = ensure_expr_or_series(column)
+    return (column.is_not_null()) & column


-@
-
-
+@overload
+def is_false(column: ExprLike, /) -> Expr: ...
+@overload
+def is_false(column: Series, /) -> Series: ...
+@overload
+def is_false(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def is_false(column: IntoExprColumn, /) -> ExprOrSeries:
+    """Compute when a boolean series is False."""
+    column = ensure_expr_or_series(column)
+    return (column.is_not_null()) & (~column)

-
-
-
+
+##
+
+
+def join(
+    df: DataFrame,
+    *dfs: DataFrame,
+    on: MaybeIterable[str | Expr],
+    how: JoinStrategy = "inner",
+    validate: JoinValidation = "m:m",
+) -> DataFrame:
+    """Join a set of DataFrames."""
+    on_use = on if isinstance(on, str | Expr) else list(on)
+
+    def inner(left: DataFrame, right: DataFrame, /) -> DataFrame:
+        return left.join(right, on=on_use, how=how, validate=validate)
+
+    return reduce(inner, chain([df], dfs))


 ##


-def
-
-
-
-
-
-
-
-
-
+def join_into_periods(
+    left: DataFrame,
+    right: DataFrame,
+    /,
+    *,
+    on: str | None = None,
+    left_on: str | None = None,
+    right_on: str | None = None,
+    suffix: str = "_right",
+) -> DataFrame:
+    """Join a pair of DataFrames on their periods; left in right."""
+    match on, left_on, right_on:
+        case None, None, None:
+            return _join_into_periods_core(
+                left, right, "datetime", "datetime", suffix=suffix
+            )
+        case str(), None, None:
+            return _join_into_periods_core(left, right, on, on, suffix=suffix)
+        case None, str(), str():
+            return _join_into_periods_core(
+                left, right, left_on, right_on, suffix=suffix
+            )
+        case _:
+            raise _JoinIntoPeriodsArgumentsError(
+                on=on, left_on=left_on, right_on=right_on
+            )
+
+
+def _join_into_periods_core(
+    left: DataFrame,
+    right: DataFrame,
+    left_on: str,
+    right_on: str,
+    /,
+    *,
+    suffix: str = "_right",
+) -> DataFrame:
+    """Join a pair of DataFrames on their periods; left in right."""
+    _join_into_periods_check(left, left_on, "left")
+    _join_into_periods_check(right, right_on, "right")
+    joined = left.join_asof(
+        right,
+        left_on=col(left_on).struct["start"],
+        right_on=col(right_on).struct["start"],
+        strategy="backward",
+        suffix=suffix,
+        coalesce=False,
     )
+    new = f"{left_on}{suffix}" if left_on == right_on else right_on
+    new_col = col(new)
+    is_correct = (new_col.struct["start"] <= col(left_on).struct["start"]) & (
+        col(left_on).struct["end"] <= new_col.struct["end"]
+    )
+    return joined.with_columns(when(is_correct).then(new_col))


-def
-
-) ->
-
-
-
-
-
-
-
-
-
+def _join_into_periods_check(
+    df: DataFrame, column: str, left_or_right: Literal["left", "right"], /
+) -> None:
+    start = df[column].struct["start"]
+    end = df[column].struct["end"]
+    if not (start <= end).all():
+        raise _JoinIntoPeriodsPeriodError(left_or_right=left_or_right, column=column)
+    try:
+        assert_series_equal(start, start.sort())
+    except AssertionError:
+        raise _JoinIntoPeriodsSortedError(
+            left_or_right=left_or_right, column=column, start_or_end="start"
+        ) from None
+    try:
+        assert_series_equal(end, end.sort())
+    except AssertionError:
+        raise _JoinIntoPeriodsSortedError(
+            left_or_right=left_or_right, column=column, start_or_end="end"
+        ) from None
+    if (df.height >= 2) and (end[:-1] > start[1:]).any():
+        raise _JoinIntoPeriodsOverlappingError(
+            left_or_right=left_or_right, column=column
+        )


-
-
-
+@dataclass(kw_only=True, slots=True)
+class JoinIntoPeriodsError(Exception): ...
+

+@dataclass(kw_only=True, slots=True)
+class _JoinIntoPeriodsArgumentsError(JoinIntoPeriodsError):
+    on: str | None
+    left_on: str | None
+    right_on: str | None

-
-
+    @override
+    def __str__(self) -> str:
+        return f"Either 'on' must be given or 'left_on' and 'right_on' must be given; got {self.on!r}, {self.left_on!r} and {self.right_on!r}"


 @dataclass(kw_only=True, slots=True)
-class
-
+class _JoinIntoPeriodsPeriodError(JoinIntoPeriodsError):
+    left_or_right: Literal["left", "right"]
+    column: str

     @override
     def __str__(self) -> str:
-        return f"
+        return f"{self.left_or_right.title()} DataFrame column {self.column!r} must contain valid periods"


-
+@dataclass(kw_only=True, slots=True)
+class _JoinIntoPeriodsSortedError(JoinIntoPeriodsError):
+    left_or_right: Literal["left", "right"]
+    column: str
+    start_or_end: Literal["start", "end"]

+    @override
+    def __str__(self) -> str:
+        return f"{self.left_or_right.title()} DataFrame column '{self.column}/{self.start_or_end}' must be sorted"

-def join(
-    df: DataFrame,
-    *dfs: DataFrame,
-    on: MaybeIterable[str | Expr],
-    how: JoinStrategy = "inner",
-    validate: JoinValidation = "m:m",
-) -> DataFrame:
-    """Join a set of DataFrames."""
-    on_use = on if isinstance(on, str | Expr) else list(on)

-
-
+@dataclass(kw_only=True, slots=True)
+class _JoinIntoPeriodsOverlappingError(JoinIntoPeriodsError):
+    left_or_right: Literal["left", "right"]
+    column: str

-
+    @override
+    def __str__(self) -> str:
+        return f"{self.left_or_right.title()} DataFrame column {self.column!r} must not contain overlaps"


 ##
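A sketch of `join_into_periods` from the hunk above: both frames carry a `{"start", "end"}` struct column, and a left period is kept only when it nests inside the matched right period. The data is illustrative:

```python
import datetime as dt

import polars as pl
from utilities.polars import join_into_periods

UTC = dt.timezone.utc

def p(h0: int, h1: int) -> dict[str, dt.datetime]:
    # Build a {"start", "end"} period struct for hour h0..h1 on one day.
    return {
        "start": dt.datetime(2024, 1, 1, h0, tzinfo=UTC),
        "end": dt.datetime(2024, 1, 1, h1, tzinfo=UTC),
    }

left = pl.DataFrame({"datetime": [p(9, 10), p(13, 14)]})
right = pl.DataFrame({"datetime": [p(8, 12)]})

# 09:00-10:00 nests inside 08:00-12:00; 13:00-14:00 matches nothing, so its
# joined period column is nulled out.
joined = join_into_periods(left, right)
```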
@@ -1746,7 +2044,7 @@ def map_over_columns(
             return _map_over_series_one(func, series)
         case DataFrame() as df:
             return df.select(*(_map_over_series_one(func, df[c]) for c in df.columns))
-        case
+        case never:
             assert_never(never)


@@ -1761,46 +2059,74 @@ def _map_over_series_one(func: Callable[[Series], Series], series: Series, /) ->
 ##


-def nan_sum_agg(column: str | Expr,
+def nan_sum_agg(column: str | Expr, /) -> Expr:
     """Nan sum aggregation."""
     col_use = col(column) if isinstance(column, str) else column
-    return (
-        when(col_use.is_not_null().any())
-        .then(col_use.sum())
-        .otherwise(lit(None, dtype=dtype))
-    )
+    return when(col_use.is_not_null().any()).then(col_use.sum())


 ##


-
-
-
+@overload
+def nan_sum_horizontal(*columns: Series) -> Series: ...
+@overload
+def nan_sum_horizontal(*columns: IntoExprColumn) -> ExprOrSeries: ...
+def nan_sum_horizontal(*columns: IntoExprColumn) -> ExprOrSeries:
     """Nan sum across columns."""
-
-
-
+    columns2 = ensure_expr_or_series_many(*columns)
+    expr = when(any_horizontal(*(c.is_not_null() for c in columns2))).then(
+        sum_horizontal(*columns2)
     )
+    return try_reify_expr(expr, *columns2)

-    def func(x: Expr, y: Expr, /) -> Expr:
-        return (
-            when(x.is_not_null() & y.is_not_null())
-            .then(x + y)
-            .when(x.is_not_null() & y.is_null())
-            .then(x)
-            .when(x.is_null() & y.is_not_null())
-            .then(y)
-            .otherwise(lit(None, dtype=dtype))
-        )

-
+##
+
+
+@overload
+def normal_pdf(
+    x: ExprLike,
+    /,
+    *,
+    loc: float | IntoExprColumn = 0.0,
+    scale: float | IntoExprColumn = 1.0,
+) -> Expr: ...
+@overload
+def normal_pdf(
+    x: Series,
+    /,
+    *,
+    loc: float | IntoExprColumn = 0.0,
+    scale: float | IntoExprColumn = 1.0,
+) -> Series: ...
+@overload
+def normal_pdf(
+    x: IntoExprColumn,
+    /,
+    *,
+    loc: float | IntoExprColumn = 0.0,
+    scale: float | IntoExprColumn = 1.0,
+) -> ExprOrSeries: ...
+def normal_pdf(
+    x: IntoExprColumn,
+    /,
+    *,
+    loc: float | IntoExprColumn = 0.0,
+    scale: float | IntoExprColumn = 1.0,
+) -> ExprOrSeries:
+    """Compute the PDF of a normal distribution."""
+    x = ensure_expr_or_series(x)
+    loc = loc if isinstance(loc, int | float) else ensure_expr_or_series(loc)
+    scale = scale if isinstance(scale, int | float) else ensure_expr_or_series(scale)
+    expr = (1 / (scale * sqrt(2 * pi))) * (-(1 / 2) * ((x - loc) / scale) ** 2).exp()
+    return try_reify_expr(expr, x)


 ##


-def
+def normal_rv(
     obj: int | Series | DataFrame,
     /,
     *,
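A sketch of `nan_sum_horizontal` and `normal_pdf` from the hunk above (data illustrative):

```python
import polars as pl
from utilities.polars import nan_sum_horizontal, normal_pdf

df = pl.DataFrame({"x": [-1.0, 0.0, 1.0], "y": [1.0, None, None]})

df.select(
    pdf=normal_pdf("x"),                 # standard-normal density at x
    total=nan_sum_horizontal("x", "y"),  # null only where every input is null
)
```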
@@ -1819,15 +2145,186 @@ def normal(
|
|
|
1819
2145
|
values = rng.normal(loc=loc, scale=scale, size=height)
|
|
1820
2146
|
return Series(name=name, values=values, dtype=dtype)
|
|
1821
2147
|
case Series() as series:
|
|
1822
|
-
return
|
|
2148
|
+
return normal_rv(
|
|
1823
2149
|
series.len(), loc=loc, scale=scale, seed=seed, name=name, dtype=dtype
|
|
1824
2150
|
)
|
|
1825
2151
|
case DataFrame() as df:
|
|
1826
|
-
return
|
|
2152
|
+
return normal_rv(
|
|
1827
2153
|
df.height, loc=loc, scale=scale, seed=seed, name=name, dtype=dtype
|
|
1828
2154
|
)
|
|
1829
|
-
case
|
|
2155
|
+
case never:
|
|
2156
|
+
assert_never(never)
|
|
+
+
+##
+
+
+@overload
+def number_of_decimals(
+    column: ExprLike, /, *, max_decimals: int = MAX_DECIMALS
+) -> Expr: ...
+@overload
+def number_of_decimals(
+    column: Series, /, *, max_decimals: int = MAX_DECIMALS
+) -> Series: ...
+@overload
+def number_of_decimals(
+    column: IntoExprColumn, /, *, max_decimals: int = MAX_DECIMALS
+) -> ExprOrSeries: ...
+def number_of_decimals(
+    column: IntoExprColumn, /, *, max_decimals: int = MAX_DECIMALS
+) -> ExprOrSeries:
+    """Get the number of decimals."""
+    column = ensure_expr_or_series(column)
+    frac = column - column.floor()
+    results = (
+        _number_of_decimals_check_scale(frac, s) for s in range(max_decimals + 1)
+    )
+    return first_true_horizontal(*results)
+
+
+def _number_of_decimals_check_scale(frac: ExprOrSeries, scale: int, /) -> ExprOrSeries:
+    scaled = 10**scale * frac
+    return is_close(scaled, scaled.round()).alias(str(scale))
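
A sketch of `number_of_decimals` on hypothetical data; the exact return representation comes from `first_true_horizontal`, so the expected counts below are indicative only:

```python
# Hypothetical data. The helper checks, for each scale s up to
# max_decimals, whether 10**s * frac(x) is integral (via is_close), and
# first_true_horizontal picks the first scale that passes.
from polars import Series

from utilities.polars import number_of_decimals

sr = Series(values=[1.0, 2.5, 3.125])
result = number_of_decimals(sr)  # expected to identify 0, 1 and 3 decimals
```
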
+
+
+##
+
+
+def offset_datetime(
+    datetime: ZonedDateTime, offset: str, /, *, n: int = 1
+) -> ZonedDateTime:
+    """Offset a datetime as `polars` would."""
+    sr = Series(values=[datetime.py_datetime()])
+    for _ in range(n):
+        sr = sr.dt.offset_by(offset)
+    return ZonedDateTime.from_py_datetime(sr.item())
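
A sketch of `offset_datetime`, assuming the `whenever` package for `ZonedDateTime`:

```python
# Hypothetical values; assumes the `whenever` package. Because the
# offset goes through Series.dt.offset_by, "1mo" saturates at month
# ends just as polars does.
from whenever import ZonedDateTime

from utilities.polars import offset_datetime

start = ZonedDateTime(2024, 1, 31, tz="UTC")
offset_datetime(start, "1mo")      # expected 2024-02-29 (clamped)
offset_datetime(start, "1d", n=3)  # expected 2024-02-03
```
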
+
+
+##
+
+
+def one_column(df: DataFrame, /) -> Series:
+    """Return the unique column in a DataFrame."""
+    try:
+        return df[one(df.columns)]
+    except OneEmptyError:
+        raise OneColumnEmptyError(df=df) from None
+    except OneNonUniqueError as error:
+        raise OneColumnNonUniqueError(
+            df=df, first=error.first, second=error.second
+        ) from None
+
+
+@dataclass(kw_only=True, slots=True)
+class OneColumnError(Exception):
+    df: DataFrame
+
+
+@dataclass(kw_only=True, slots=True)
+class OneColumnEmptyError(OneColumnError):
+    @override
+    def __str__(self) -> str:
+        return "DataFrame must not be empty"
+
+
+@dataclass(kw_only=True, slots=True)
+class OneColumnNonUniqueError(OneColumnError):
+    first: str
+    second: str
+
+    @override
+    def __str__(self) -> str:
+        return f"DataFrame must contain exactly one column; got {self.first!r}, {self.second!r} and perhaps more"
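
A sketch of `one_column` and its typed errors, on hypothetical frames:

```python
# Hypothetical frames, not from the diff.
from polars import DataFrame

from utilities.polars import OneColumnNonUniqueError, one_column

sr = one_column(DataFrame({"x": [1, 2]}))  # the single column, as a Series
try:
    one_column(DataFrame({"x": [1], "y": [2]}))
except OneColumnNonUniqueError:
    pass  # "DataFrame must contain exactly one column; got 'x', 'y' ..."
```
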
+
+
+##
+
+
+@overload
+def order_of_magnitude(column: ExprLike, /, *, round_: bool = False) -> Expr: ...
+@overload
+def order_of_magnitude(column: Series, /, *, round_: bool = False) -> Series: ...
+@overload
+def order_of_magnitude(
+    column: IntoExprColumn, /, *, round_: bool = False
+) -> ExprOrSeries: ...
+def order_of_magnitude(
+    column: IntoExprColumn, /, *, round_: bool = False
+) -> ExprOrSeries:
+    """Compute the order of magnitude of a column."""
+    column = ensure_expr_or_series(column)
+    result = column.abs().log10()
+    return result.round().cast(Int64) if round_ else result
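
A sketch of `order_of_magnitude` on hypothetical values:

```python
# Hypothetical values: order_of_magnitude is log10(|x|), optionally
# rounded and cast to Int64.
from polars import Series

from utilities.polars import order_of_magnitude

sr = Series(values=[0.1, 1.0, 20.0, 300.0])
order_of_magnitude(sr)               # expected ~[-1.0, 0.0, 1.30, 2.48]
order_of_magnitude(sr, round_=True)  # expected [-1, 0, 1, 2] as Int64
```
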
+
+
+##
+
+
+@overload
+def period_range(
+    start: ZonedDateTime,
+    end_or_length: ZonedDateTime | int,
+    /,
+    *,
+    interval: str = "1d",
+    time_unit: TimeUnit | None = None,
+    time_zone: TimeZoneLike | None = None,
+    eager: Literal[True],
+) -> Series: ...
+@overload
+def period_range(
+    start: ZonedDateTime,
+    end_or_length: ZonedDateTime | int,
+    /,
+    *,
+    interval: str = "1d",
+    time_unit: TimeUnit | None = None,
+    time_zone: TimeZoneLike | None = None,
+    eager: Literal[False] = False,
+) -> Expr: ...
+@overload
+def period_range(
+    start: ZonedDateTime,
+    end_or_length: ZonedDateTime | int,
+    /,
+    *,
+    interval: str = "1d",
+    time_unit: TimeUnit | None = None,
+    time_zone: TimeZoneLike | None = None,
+    eager: bool = False,
+) -> Series | Expr: ...
+def period_range(
+    start: ZonedDateTime,
+    end_or_length: ZonedDateTime | int,
+    /,
+    *,
+    interval: str = "1d",
+    time_unit: TimeUnit | None = None,
+    time_zone: TimeZoneLike | None = None,
+    eager: bool = False,
+) -> Series | Expr:
+    """Construct a period range."""
+    time_zone_use = None if time_zone is None else to_time_zone_name(time_zone)
+    match end_or_length:
+        case ZonedDateTime() as end:
+            ...
+        case int() as length:
+            end = offset_datetime(start, interval, n=length)
+        case never:
             assert_never(never)
+    starts = datetime_range(
+        start.py_datetime(),
+        end.py_datetime(),
+        interval,
+        closed="left",
+        time_unit=time_unit,
+        time_zone=time_zone_use,
+        eager=eager,
+    ).alias("start")
+    ends = starts.dt.offset_by(interval).alias("end")
+    period = struct(starts, ends)
+    return try_reify_expr(period, starts, ends)
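
A sketch of `period_range`, assuming the `whenever` package; the second argument is either an end date-time or an integer period count:

```python
# Hypothetical values; assumes the `whenever` package. Builds half-open
# {start, end} struct periods, left-closed over [start, end).
from whenever import ZonedDateTime

from utilities.polars import period_range

start = ZonedDateTime(2024, 1, 1, tz="UTC")
periods = period_range(start, 3, interval="1d", eager=True)
# expected: a struct Series of 3 periods, each with end == start + 1 day
```
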
 
 
 ##
@@ -1865,13 +2362,10 @@ def reify_exprs(
         .with_columns(*all_exprs)
         .drop("_index")
     )
-
-
-
-
-            return df[one(df.columns)]
-        case _:
-            return df
+    try:
+        return one_column(df)
+    except OneColumnNonUniqueError:
+        return df
 
 
 @dataclass(kw_only=True, slots=True)
@@ -1921,7 +2415,7 @@ def _replace_time_zone_one(
     sr: Series, /, *, time_zone: TimeZoneLike | None = UTC
 ) -> Series:
     if isinstance(sr.dtype, Datetime):
-        time_zone_use = None if time_zone is None else
+        time_zone_use = None if time_zone is None else to_time_zone_name(time_zone)
         return sr.dt.replace_time_zone(time_zone_use)
     return sr
 
@@ -1929,6 +2423,254 @@ def _replace_time_zone_one(
 ##
 
 
+def read_series(path: PathLike, /, *, decompress: bool = False) -> Series:
+    """Read a Series from disk."""
+    data = read_binary(path, decompress=decompress)
+    return deserialize_series(data)
+
+
+def write_series(
+    series: Series,
+    path: PathLike,
+    /,
+    *,
+    compress: bool = False,
+    overwrite: bool = False,
+) -> None:
+    """Write a Series to disk."""
+    data = serialize_series(series)
+    write_formatted_json(data, path, compress=compress, overwrite=overwrite)
+
+
+def read_dataframe(path: PathLike, /, *, decompress: bool = False) -> DataFrame:
+    """Read a DataFrame from disk."""
+    data = read_binary(path, decompress=decompress)
+    return deserialize_dataframe(data)
+
+
+def write_dataframe(
+    df: DataFrame, path: PathLike, /, *, compress: bool = False, overwrite: bool = False
+) -> None:
+    """Write a DataFrame to disk."""
+    data = serialize_dataframe(df)
+    write_formatted_json(data, path, compress=compress, overwrite=overwrite)
+
+
+def serialize_series(series: Series, /) -> bytes:
+    """Serialize a Series."""
+    from utilities.orjson import serialize
+
+    values = series.to_list()
+    decon = _deconstruct_dtype(series.dtype)
+    return serialize((series.name, values, decon))
+
+
+def deserialize_series(data: bytes, /) -> Series:
+    """Deserialize a Series."""
+    from utilities.orjson import deserialize
+
+    name, values, decon = deserialize(data)
+    dtype = _reconstruct_dtype(decon)
+    return Series(name=name, values=values, dtype=dtype)
+
+
+def serialize_dataframe(df: DataFrame, /) -> bytes:
+    """Serialize a DataFrame."""
+    from utilities.orjson import serialize
+
+    rows = df.rows()
+    decon = _deconstruct_schema(df.schema)
+    return serialize((rows, decon))
+
+
+def deserialize_dataframe(data: bytes, /) -> DataFrame:
+    """Deserialize a DataFrame."""
+    from utilities.orjson import deserialize
+
+    rows, decon = deserialize(data)
+    schema = _reconstruct_schema(decon)
+    return DataFrame(data=rows, schema=schema, orient="row")
+
+
+type _DeconSchema = Sequence[tuple[str, _DeconDType]]
+type _DeconDType = (
+    str
+    | tuple[Literal["Datetime"], str, str | None]
+    | tuple[Literal["List"], _DeconDType]
+    | tuple[Literal["Struct"], _DeconSchema]
+)
+
+
+def _deconstruct_schema(schema: Schema, /) -> _DeconSchema:
+    return [(k, _deconstruct_dtype(v)) for k, v in schema.items()]
+
+
+def _deconstruct_dtype(dtype: PolarsDataType, /) -> _DeconDType:
+    match dtype:
+        case List() as list_:
+            return "List", _deconstruct_dtype(list_.inner)
+        case Struct() as struct:
+            inner = Schema({f.name: f.dtype for f in struct.fields})
+            return "Struct", _deconstruct_schema(inner)
+        case Datetime() as datetime:
+            return "Datetime", datetime.time_unit, datetime.time_zone
+        case _:
+            return repr(dtype)
+
+
+def _reconstruct_schema(schema: _DeconSchema, /) -> Schema:
+    return Schema({k: _reconstruct_dtype(v) for k, v in schema})
+
+
+def _reconstruct_dtype(obj: _DeconDType, /) -> PolarsDataType:
+    match obj:
+        case str() as name:
+            return getattr(pl, name)
+        case "Datetime", str() as time_unit, str() | None as time_zone:
+            return Datetime(time_unit=cast("TimeUnit", time_unit), time_zone=time_zone)
+        case "List", inner:
+            return List(_reconstruct_dtype(inner))
+        case "Struct", inner:
+            return Struct(_reconstruct_schema(inner))
+        case never:
+            assert_never(never)
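
A round-trip sketch of the new (de)serializers, on hypothetical data; these import `utilities.orjson` internally, so the `orjson` dependency is assumed:

```python
# Hypothetical data. serialize_dataframe stores the rows plus a
# deconstructed schema as JSON bytes; deserialize_dataframe rebuilds
# the frame, so simple frames round-trip exactly.
from polars import DataFrame

from utilities.polars import deserialize_dataframe, serialize_dataframe

df = DataFrame({"x": [1, 2], "y": ["a", "b"]})
data = serialize_dataframe(df)
assert deserialize_dataframe(data).equals(df)
```
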
+
+
+##
+
+
+@overload
+def round_to_float(
+    x: ExprLike, y: float, /, *, mode: RoundMode = "half_to_even"
+) -> Expr: ...
+@overload
+def round_to_float(
+    x: Series, y: float | ExprOrSeries, /, *, mode: RoundMode = "half_to_even"
+) -> Series: ...
+@overload
+def round_to_float(
+    x: ExprLike, y: Series, /, *, mode: RoundMode = "half_to_even"
+) -> Series: ...
+@overload
+def round_to_float(
+    x: ExprLike, y: Expr, /, *, mode: RoundMode = "half_to_even"
+) -> Expr: ...
+@overload
+def round_to_float(
+    x: IntoExprColumn, y: float | Series, /, *, mode: RoundMode = "half_to_even"
+) -> ExprOrSeries: ...
+def round_to_float(
+    x: IntoExprColumn, y: float | IntoExprColumn, /, *, mode: RoundMode = "half_to_even"
+) -> ExprOrSeries:
+    """Round a column to the nearest multiple of another float."""
+    x = ensure_expr_or_series(x)
+    y = y if isinstance(y, int | float) else ensure_expr_or_series(y)
+    match x, y:
+        case Expr() | Series(), int() | float():
+            z = (x / y).round(mode=mode) * y
+            return z.round(decimals=utilities.math.number_of_decimals(y) + 1)
+        case Series(), Expr() | Series():
+            df = (
+                x
+                .to_frame()
+                .with_columns(y)
+                .with_columns(number_of_decimals(y).alias("_decimals"))
+                .with_row_index(name="_index")
+                .group_by("_decimals")
+                .map_groups(_round_to_float_one)
+                .sort("_index")
+            )
+            return df[df.columns[1]]
+        case Expr(), Series():
+            df = y.to_frame().with_columns(x)
+            return round_to_float(df[df.columns[1]], df[df.columns[0]], mode=mode)
+        case Expr(), Expr() | str():
+            raise RoundToFloatError(x=x, y=y)
+        case never:
+            assert_never(never)
+
+
+def _round_to_float_one(df: DataFrame, /) -> DataFrame:
+    decimals: int | None = df["_decimals"].unique().item()
+    name = df.columns[1]
+    match decimals:
+        case int():
+            expr = col(name).round(decimals=decimals)
+        case None:
+            expr = lit(None, dtype=Float64).alias(name)
+        case never:
+            assert_never(never)
+    return df.with_columns(expr)
+
+
+@dataclass(kw_only=True, slots=True)
+class RoundToFloatError(Exception):
+    x: IntoExprColumn
+    y: IntoExprColumn
+
+    @override
+    def __str__(self) -> str:
+        return f"At least 1 of the dividend and/or divisor must be a Series; got {get_class_name(self.x)!r} and {get_class_name(self.y)!r}"
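
A sketch of `round_to_float` with a float divisor, on hypothetical values:

```python
# Hypothetical values. (x / y) is rounded to the nearest integer and
# scaled back; the final .round(decimals=...) pass trims binary-float
# noise.
from polars import Series

from utilities.polars import round_to_float

sr = Series(values=[1.23, 4.56, 7.89])
round_to_float(sr, 0.25)  # expected [1.25, 4.5, 8.0]
```
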
+
+
+##
+
+
+def search_period(
+    series: Series,
+    date_time: ZonedDateTime,
+    /,
+    *,
+    start_or_end: Literal["start", "end"] = "end",
+) -> int | None:
+    """Search a series of periods for the one containing a given date-time."""
+    end = series.struct["end"]
+    py_date_time = date_time.py_datetime()
+    match start_or_end:
+        case "start":
+            index = end.search_sorted(py_date_time, side="right")
+            if index >= len(series):
+                return None
+            item: dt.datetime = series[index]["start"]
+            return index if py_date_time >= item else None
+        case "end":
+            index = end.search_sorted(py_date_time, side="left")
+            if index >= len(series):
+                return None
+            item: dt.datetime = series[index]["start"]
+            return index if py_date_time > item else None
+        case never:
+            assert_never(never)
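
A sketch of `search_period`, reusing `period_range` and assuming the `whenever` package:

```python
# Hypothetical values; assumes the `whenever` package.
from whenever import ZonedDateTime

from utilities.polars import period_range, search_period

start = ZonedDateTime(2024, 1, 1, tz="UTC")
periods = period_range(start, 3, eager=True)  # three 1-day periods
idx = search_period(periods, ZonedDateTime(2024, 1, 2, hour=12, tz="UTC"))
# expected idx == 1: the period 2024-01-02 .. 2024-01-03
```
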
+
+
+##
+
+
+def select_exact(
+    df: DataFrame, /, *columns: IntoExprColumn, drop: MaybeIterable[str] | None = None
+) -> DataFrame:
+    """Select an exact set of columns from a DataFrame."""
+    names = [get_expr_name(df, c) for c in columns]
+    drop = set() if drop is None else set(always_iterable(drop))
+    union = set(names) | drop
+    extra = [c for c in df.columns if c not in union]
+    if len(extra) >= 1:
+        raise SelectExactError(columns=extra)
+    return df.select(*columns)
+
+
+@dataclass(kw_only=True, slots=True)
+class SelectExactError(Exception):
+    columns: list[str]
+
+    @override
+    def __str__(self) -> str:
+        return f"All columns must be selected; got {get_repr(self.columns)} remaining"
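
A sketch of `select_exact`, on a hypothetical frame:

```python
# Hypothetical frame: every column must be either selected or listed in
# `drop`, otherwise SelectExactError reports the leftovers.
from polars import DataFrame, col

from utilities.polars import SelectExactError, select_exact

df = DataFrame({"x": [1], "y": [2], "z": [3]})
select_exact(df, col("x"), col("y"), drop="z")  # ok: all 3 accounted for
try:
    select_exact(df, col("x"))
except SelectExactError:
    pass  # 'y' and 'z' were left over
```
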
+
+
+##
+
+
 def set_first_row_as_columns(df: DataFrame, /) -> DataFrame:
     """Set the first row of a DataFrame as its columns."""
     try:
@@ -1959,79 +2701,79 @@ def struct_dtype(**kwargs: PolarsDataType) -> Struct:
 ##
 
 
-
-
-
-
-
-
-
-
-
-
-    if not is_dataclass_class(cls):
-        raise _StructFromDataClassNotADataclassError(cls=cls)
-    anns = get_type_hints(
-        cls, globalns=globalns, localns=localns, warn_name_errors=warn_name_errors
-    )
-    data_types = {
-        k: _struct_from_dataclass_one(v, time_zone=time_zone) for k, v in anns.items()
-    }
-    return Struct(data_types)
+@overload
+def to_true(column: ExprLike, /) -> Expr: ...
+@overload
+def to_true(column: Series, /) -> Series: ...
+@overload
+def to_true(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def to_true(column: IntoExprColumn, /) -> ExprOrSeries:
+    """Compute when a boolean series turns True."""
+    t = is_true(column)
+    return ((~t).shift() & t).fill_null(value=False)
 
 
-
-
-
-
-
-
-
-
-
-    if is_dataclass_class(ann):
-        return struct_from_dataclass(ann, time_zone=time_zone)
-    if (isinstance(ann, type) and issubclass(ann, enum.Enum)) or (
-        is_literal_type(ann) and is_iterable_of(get_args(ann), str)
-    ):
-        return String
-    if is_optional_type(ann):
-        return _struct_from_dataclass_one(
-            one(get_args(ann, optional_drop_none=True)), time_zone=time_zone
-        )
-    if is_frozenset_type(ann) or is_list_type(ann) or is_set_type(ann):
-        return List(_struct_from_dataclass_one(one(get_args(ann)), time_zone=time_zone))
-    raise _StructFromDataClassTypeError(ann=ann)
+@overload
+def to_not_true(column: ExprLike, /) -> Expr: ...
+@overload
+def to_not_true(column: Series, /) -> Series: ...
+@overload
+def to_not_true(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def to_not_true(column: IntoExprColumn, /) -> ExprOrSeries:
+    """Compute when a boolean series turns non-True."""
+    t = is_true(column)
+    return (t.shift() & (~t)).fill_null(value=False)
 
 
-@
-
+@overload
+def to_false(column: ExprLike, /) -> Expr: ...
+@overload
+def to_false(column: Series, /) -> Series: ...
+@overload
+def to_false(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def to_false(column: IntoExprColumn, /) -> ExprOrSeries:
+    """Compute when a boolean series turns False."""
+    f = is_false(column)
+    return ((~f).shift() & f).fill_null(value=False)
 
 
-@
-
-
+@overload
+def to_not_false(column: ExprLike, /) -> Expr: ...
+@overload
+def to_not_false(column: Series, /) -> Series: ...
+@overload
+def to_not_false(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def to_not_false(column: IntoExprColumn, /) -> ExprOrSeries:
+    """Compute when a boolean series turns non-False."""
+    f = is_false(column)
+    return (f.shift() & (~f)).fill_null(value=False)
 
-    @override
-    def __str__(self) -> str:
-        return f"Object must be a dataclass; got {self.cls}"
 
+##
 
-@dataclass(kw_only=True, slots=True)
-class _StructFromDataClassTimeZoneMissingError(StructFromDataClassError):
-    @override
-    def __str__(self) -> str:
-        return "Time-zone must be given"
 
+@overload
+def true_like(column: ExprLike, /) -> Expr: ...
+@overload
+def true_like(column: Series, /) -> Series: ...
+@overload
+def true_like(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def true_like(column: IntoExprColumn, /) -> ExprOrSeries:
+    """Compute a column of `True` values."""
+    column = ensure_expr_or_series(column)
+    return column.is_null() | column.is_not_null()
 
-@dataclass(kw_only=True, slots=True)
-class _StructFromDataClassTypeError(StructFromDataClassError):
-    ann: Any
 
-
-
+@overload
+def false_like(column: ExprLike, /) -> Expr: ...
+@overload
+def false_like(column: Series, /) -> Series: ...
+@overload
+def false_like(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def false_like(column: IntoExprColumn, /) -> ExprOrSeries:
+    """Compute a column of `False` values."""
+    column = ensure_expr_or_series(column)
+    return column.is_null() & column.is_not_null()
|
|
|
2036
2778
|
|
|
2037
2779
|
##
|
|
@@ -2039,7 +2781,7 @@ class _StructFromDataClassTypeError(StructFromDataClassError):
|
|
|
2039
2781
|
|
|
2040
2782
|
def try_reify_expr(
|
|
2041
2783
|
expr: IntoExprColumn, /, *exprs: IntoExprColumn, **named_exprs: IntoExprColumn
|
|
2042
|
-
) ->
|
|
2784
|
+
) -> ExprOrSeries:
|
|
2043
2785
|
"""Try reify an expression."""
|
|
2044
2786
|
expr = ensure_expr_or_series(expr)
|
|
2045
2787
|
all_exprs = ensure_expr_or_series_many(*exprs, **named_exprs)
|
|
@@ -2052,7 +2794,7 @@ def try_reify_expr(
|
|
|
2052
2794
|
return series
|
|
2053
2795
|
case DataFrame() as df:
|
|
2054
2796
|
return df[get_expr_name(df, expr)]
|
|
2055
|
-
case
|
|
2797
|
+
case never:
|
|
2056
2798
|
assert_never(never)
|
|
2057
2799
|
|
|
2058
2800
|
|
|
@@ -2085,7 +2827,7 @@ def uniform(
|
|
|
2085
2827
|
return uniform(
|
|
2086
2828
|
df.height, low=low, high=high, seed=seed, name=name, dtype=dtype
|
|
2087
2829
|
)
|
|
2088
|
-
case
|
|
2830
|
+
case never:
|
|
2089
2831
|
assert_never(never)
|
|
2090
2832
|
|
|
2091
2833
|
|
|
@@ -2106,8 +2848,8 @@ def week_num(column: ExprLike, /, *, start: WeekDay = "mon") -> Expr: ...
|
|
|
2106
2848
|
@overload
|
|
2107
2849
|
def week_num(column: Series, /, *, start: WeekDay = "mon") -> Series: ...
|
|
2108
2850
|
@overload
|
|
2109
|
-
def week_num(column: IntoExprColumn, /, *, start: WeekDay = "mon") ->
|
|
2110
|
-
def week_num(column: IntoExprColumn, /, *, start: WeekDay = "mon") ->
|
|
2851
|
+
def week_num(column: IntoExprColumn, /, *, start: WeekDay = "mon") -> ExprOrSeries: ...
|
|
2852
|
+
def week_num(column: IntoExprColumn, /, *, start: WeekDay = "mon") -> ExprOrSeries:
|
|
2111
2853
|
"""Compute the week number of a date column."""
|
|
2112
2854
|
column = ensure_expr_or_series(column)
|
|
2113
2855
|
epoch = column.dt.epoch(time_unit="d").alias("epoch")
|
|
@@ -2118,79 +2860,129 @@ def week_num(column: IntoExprColumn, /, *, start: WeekDay = "mon") -> Expr | Ser
|
|
|
2118
2860
|
##
|
|
2119
2861
|
|
|
2120
2862
|
|
|
2121
|
-
def
|
|
2863
|
+
def zoned_date_time_dtype(
|
|
2122
2864
|
*, time_unit: TimeUnit = "us", time_zone: TimeZoneLike = UTC
|
|
2123
2865
|
) -> Datetime:
|
|
2124
|
-
"""Create a zoned
|
|
2125
|
-
return Datetime(time_unit=time_unit, time_zone=
|
|
2866
|
+
"""Create a zoned date-time data type."""
|
|
2867
|
+
return Datetime(time_unit=time_unit, time_zone=to_time_zone_name(time_zone))
|
|
2868
|
+
|
|
2869
|
+
|
|
2870
|
+
def zoned_date_time_period_dtype(
|
|
2871
|
+
*,
|
|
2872
|
+
time_unit: TimeUnit = "us",
|
|
2873
|
+
time_zone: TimeZoneLike | tuple[TimeZoneLike, TimeZoneLike] = UTC,
|
|
2874
|
+
) -> Struct:
|
|
2875
|
+
"""Create a zoned date-time period data type."""
|
|
2876
|
+
match time_zone:
|
|
2877
|
+
case start, end:
|
|
2878
|
+
return struct_dtype(
|
|
2879
|
+
start=zoned_date_time_dtype(time_unit=time_unit, time_zone=start),
|
|
2880
|
+
end=zoned_date_time_dtype(time_unit=time_unit, time_zone=end),
|
|
2881
|
+
)
|
|
2882
|
+
case _:
|
|
2883
|
+
dtype = zoned_date_time_dtype(time_unit=time_unit, time_zone=time_zone)
|
|
2884
|
+
return struct_dtype(start=dtype, end=dtype)
|
|
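
A sketch of the new dtype constructors, with hypothetical arguments:

```python
# Hypothetical arguments. A tuple time_zone gives the period's start
# and end fields distinct zones; a single zone is used for both.
from utilities.polars import zoned_date_time_dtype, zoned_date_time_period_dtype

zoned_date_time_dtype(time_zone="Asia/Tokyo")
# expected: Datetime(time_unit="us", time_zone="Asia/Tokyo")
zoned_date_time_period_dtype(time_zone=("UTC", "Asia/Tokyo"))
# expected: Struct with 'start' in UTC and 'end' in Asia/Tokyo
```
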
 
 
 __all__ = [
+    "AppendRowError",
     "BooleanValueCountsError",
     "CheckPolarsDataFrameError",
     "ColumnsToDictError",
     "DataClassToDataFrameError",
+    "DatePeriodDType",
     "DatetimeHongKong",
     "DatetimeTokyo",
     "DatetimeUSCentral",
     "DatetimeUSEastern",
     "DatetimeUTC",
-    "
+    "ExprOrSeries",
     "FiniteEWMMeanError",
     "GetDataTypeOrSeriesTimeZoneError",
-    "GetSeriesNumberOfDecimalsError",
     "InsertAfterError",
     "InsertBeforeError",
     "InsertBetweenError",
     "IsNearEventError",
-    "
+    "OneColumnEmptyError",
+    "OneColumnError",
+    "OneColumnNonUniqueError",
+    "RoundToFloatError",
+    "SelectExactError",
     "SetFirstRowAsColumnsError",
-    "
+    "TimePeriodDType",
     "acf",
     "adjust_frequencies",
-    "
+    "all_dataframe_columns",
+    "all_series",
+    "any_dataframe_columns",
+    "any_series",
+    "append_row",
     "are_frames_equal",
    "bernoulli",
     "boolean_value_counts",
-    "ceil_datetime",
     "check_polars_dataframe",
     "choice",
-    "collect_series",
     "columns_to_dict",
     "concat_series",
     "convert_time_zone",
     "cross",
     "dataclass_to_dataframe",
     "dataclass_to_schema",
-    "
+    "decreasing_horizontal",
+    "deserialize_dataframe",
     "ensure_data_type",
     "ensure_expr_or_series",
     "ensure_expr_or_series_many",
+    "expr_to_series",
+    "false_like",
+    "filter_date",
+    "filter_time",
     "finite_ewm_mean",
-    "
+    "first_true_horizontal",
     "get_data_type_or_series_time_zone",
     "get_expr_name",
     "get_frequency_spectrum",
-    "
+    "increasing_horizontal",
     "insert_after",
     "insert_before",
     "insert_between",
     "integers",
+    "is_close",
+    "is_false",
     "is_near_event",
-    "
-    "is_null_struct_series",
+    "is_true",
     "join",
+    "join_into_periods",
     "map_over_columns",
     "nan_sum_agg",
-    "
-    "
+    "nan_sum_horizontal",
+    "normal_pdf",
+    "normal_rv",
+    "number_of_decimals",
+    "offset_datetime",
+    "one_column",
+    "order_of_magnitude",
+    "period_range",
+    "read_dataframe",
+    "read_series",
     "replace_time_zone",
+    "round_to_float",
+    "search_period",
+    "select_exact",
+    "serialize_dataframe",
     "set_first_row_as_columns",
     "struct_dtype",
-    "
+    "to_false",
+    "to_not_false",
+    "to_not_true",
+    "to_true",
     "touch",
+    "true_like",
     "try_reify_expr",
     "uniform",
     "unique_element",
-    "
+    "write_dataframe",
+    "write_series",
+    "zoned_date_time_dtype",
+    "zoned_date_time_period_dtype",
 ]