dycw-utilities 0.148.5__py3-none-any.whl → 0.174.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dycw-utilities might be problematic.

Files changed (83)
  1. dycw_utilities-0.174.12.dist-info/METADATA +41 -0
  2. dycw_utilities-0.174.12.dist-info/RECORD +104 -0
  3. dycw_utilities-0.174.12.dist-info/WHEEL +4 -0
  4. {dycw_utilities-0.148.5.dist-info → dycw_utilities-0.174.12.dist-info}/entry_points.txt +3 -0
  5. utilities/__init__.py +1 -1
  6. utilities/{eventkit.py → aeventkit.py} +12 -11
  7. utilities/altair.py +7 -6
  8. utilities/asyncio.py +113 -64
  9. utilities/atomicwrites.py +1 -1
  10. utilities/atools.py +64 -4
  11. utilities/cachetools.py +9 -6
  12. utilities/click.py +145 -49
  13. utilities/concurrent.py +1 -1
  14. utilities/contextlib.py +4 -2
  15. utilities/contextvars.py +20 -1
  16. utilities/cryptography.py +3 -3
  17. utilities/dataclasses.py +15 -28
  18. utilities/docker.py +292 -0
  19. utilities/enum.py +2 -2
  20. utilities/errors.py +1 -1
  21. utilities/fastapi.py +8 -3
  22. utilities/fpdf2.py +2 -2
  23. utilities/functions.py +20 -297
  24. utilities/git.py +19 -0
  25. utilities/grp.py +28 -0
  26. utilities/hypothesis.py +360 -78
  27. utilities/inflect.py +1 -1
  28. utilities/iterables.py +12 -58
  29. utilities/jinja2.py +148 -0
  30. utilities/json.py +1 -1
  31. utilities/libcst.py +7 -7
  32. utilities/logging.py +74 -85
  33. utilities/math.py +8 -4
  34. utilities/more_itertools.py +4 -6
  35. utilities/operator.py +1 -1
  36. utilities/orjson.py +86 -34
  37. utilities/os.py +49 -2
  38. utilities/parse.py +2 -2
  39. utilities/pathlib.py +66 -34
  40. utilities/permissions.py +297 -0
  41. utilities/platform.py +5 -5
  42. utilities/polars.py +932 -420
  43. utilities/polars_ols.py +1 -1
  44. utilities/postgres.py +296 -174
  45. utilities/pottery.py +8 -73
  46. utilities/pqdm.py +3 -3
  47. utilities/pwd.py +28 -0
  48. utilities/pydantic.py +11 -0
  49. utilities/pydantic_settings.py +240 -0
  50. utilities/pydantic_settings_sops.py +76 -0
  51. utilities/pyinstrument.py +5 -5
  52. utilities/pytest.py +155 -46
  53. utilities/pytest_plugins/pytest_randomly.py +1 -1
  54. utilities/pytest_plugins/pytest_regressions.py +7 -3
  55. utilities/pytest_regressions.py +2 -3
  56. utilities/random.py +11 -6
  57. utilities/re.py +1 -1
  58. utilities/redis.py +101 -64
  59. utilities/sentinel.py +10 -0
  60. utilities/shelve.py +4 -1
  61. utilities/shutil.py +25 -0
  62. utilities/slack_sdk.py +8 -3
  63. utilities/sqlalchemy.py +422 -352
  64. utilities/sqlalchemy_polars.py +28 -52
  65. utilities/string.py +1 -1
  66. utilities/subprocess.py +864 -0
  67. utilities/tempfile.py +62 -4
  68. utilities/testbook.py +50 -0
  69. utilities/text.py +165 -42
  70. utilities/timer.py +2 -2
  71. utilities/traceback.py +46 -36
  72. utilities/types.py +62 -23
  73. utilities/typing.py +479 -19
  74. utilities/uuid.py +42 -5
  75. utilities/version.py +27 -26
  76. utilities/whenever.py +661 -151
  77. utilities/zoneinfo.py +80 -22
  78. dycw_utilities-0.148.5.dist-info/METADATA +0 -41
  79. dycw_utilities-0.148.5.dist-info/RECORD +0 -95
  80. dycw_utilities-0.148.5.dist-info/WHEEL +0 -4
  81. dycw_utilities-0.148.5.dist-info/licenses/LICENSE +0 -21
  82. utilities/period.py +0 -237
  83. utilities/typed_settings.py +0 -144
utilities/polars.py CHANGED
@@ -1,25 +1,24 @@
 from __future__ import annotations
 
-import datetime as dt
 import enum
 from collections.abc import Callable, Iterator, Sequence
 from collections.abc import Set as AbstractSet
-from contextlib import suppress
 from dataclasses import asdict, dataclass
 from functools import partial, reduce
-from itertools import chain, product
-from math import ceil, log
+from itertools import chain, pairwise, product
+from math import ceil, log, pi, sqrt
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, assert_never, cast, overload, override
 from uuid import UUID
 from zoneinfo import ZoneInfo
 
 import polars as pl
+import whenever
 from polars import (
     Boolean,
     DataFrame,
-    Date,
     Datetime,
+    Duration,
     Expr,
     Float64,
     Int64,
@@ -33,8 +32,11 @@ from polars import (
     any_horizontal,
     col,
     concat,
+    concat_list,
+    datetime_range,
     int_range,
     lit,
+    max_horizontal,
     struct,
     sum_horizontal,
     when,
@@ -49,59 +51,60 @@ from polars.exceptions import (
 )
 from polars.schema import Schema
 from polars.testing import assert_frame_equal, assert_series_equal
+from whenever import DateDelta, DateTimeDelta, PlainDateTime, TimeDelta, ZonedDateTime
 
-from utilities.dataclasses import _YieldFieldsInstance, yield_fields
+import utilities.math
+from utilities.dataclasses import yield_fields
 from utilities.errors import ImpossibleCaseError
-from utilities.functions import (
-    EnsureIntError,
-    ensure_int,
-    is_dataclass_class,
-    is_dataclass_instance,
-    is_iterable_of,
-    make_isinstance,
-)
+from utilities.functions import get_class_name
 from utilities.gzip import read_binary
 from utilities.iterables import (
     CheckIterablesEqualError,
     CheckMappingsEqualError,
-    CheckSubSetError,
     CheckSuperMappingError,
     OneEmptyError,
     OneNonUniqueError,
     always_iterable,
     check_iterables_equal,
     check_mappings_equal,
-    check_subset,
     check_supermapping,
     is_iterable_not_str,
     one,
+    resolve_include_and_exclude,
 )
 from utilities.json import write_formatted_json
 from utilities.math import (
+    MAX_DECIMALS,
     CheckIntegerError,
     check_integer,
     ewm_parameters,
     is_less_than,
     is_non_negative,
-    number_of_decimals,
 )
 from utilities.reprlib import get_repr
 from utilities.types import MaybeStr, Number, PathLike, WeekDay
 from utilities.typing import (
     get_args,
-    get_type_hints,
+    is_dataclass_class,
+    is_dataclass_instance,
     is_frozenset_type,
-    is_instance_gen,
     is_list_type,
     is_literal_type,
     is_optional_type,
     is_set_type,
-    is_union_type,
+    make_isinstance,
 )
 from utilities.warnings import suppress_warnings
-from utilities.zoneinfo import UTC, ensure_time_zone, get_time_zone_name
+from utilities.whenever import (
+    DatePeriod,
+    TimePeriod,
+    ZonedDateTimePeriod,
+    to_py_time_delta,
+)
+from utilities.zoneinfo import UTC, to_time_zone_name
 
 if TYPE_CHECKING:
+    import datetime as dt
     from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
     from collections.abc import Set as AbstractSet
 
@@ -112,6 +115,7 @@ if TYPE_CHECKING:
         JoinValidation,
         PolarsDataType,
         QuantileMethod,
+        RoundMode,
         SchemaDict,
         TimeUnit,
     )
@@ -122,13 +126,19 @@ if TYPE_CHECKING:
 
 
 type ExprLike = MaybeStr[Expr]
+type ExprOrSeries = Expr | Series
 DatetimeHongKong = Datetime(time_zone="Asia/Hong_Kong")
 DatetimeTokyo = Datetime(time_zone="Asia/Tokyo")
 DatetimeUSCentral = Datetime(time_zone="US/Central")
 DatetimeUSEastern = Datetime(time_zone="US/Eastern")
 DatetimeUTC = Datetime(time_zone="UTC")
+DatePeriodDType = Struct({"start": pl.Date, "end": pl.Date})
+TimePeriodDType = Struct({"start": pl.Time, "end": pl.Time})
+
+
 _FINITE_EWM_MIN_WEIGHT = 0.9999
 
+
 ##
 
 
@@ -208,7 +218,7 @@ def acf(
             df_confints = _acf_process_confints(confints)
             df_qstats_pvalues = _acf_process_qstats_pvalues(qstats, pvalues)
             return join(df_acfs, df_confints, df_qstats_pvalues, on=["lag"], how="left")
-        case _ as never:
+        case never:
            assert_never(never)
 
 
@@ -238,11 +248,6 @@ def _acf_process_qstats_pvalues(qstats: NDArrayF, pvalues: NDArrayF, /) -> DataF
 ##
 
 
-# def acf_halflife(series: Series,/)
-
-##
-
-
 def adjust_frequencies(
     series: Series,
     /,
@@ -264,29 +269,108 @@ def adjust_frequencies(
 ##
 
 
-def append_dataclass(df: DataFrame, obj: Dataclass, /) -> DataFrame:
-    """Append a dataclass object to a DataFrame."""
-    non_null_fields = {k: v for k, v in asdict(obj).items() if v is not None}
-    try:
-        check_subset(non_null_fields, df.columns)
-    except CheckSubSetError as error:
-        raise AppendDataClassError(
-            left=error.left, right=error.right, extra=error.extra
-        ) from None
-    row_cols = set(df.columns) & set(non_null_fields)
-    row = dataclass_to_dataframe(obj).select(*row_cols)
-    return concat([df, row], how="diagonal")
+def all_dataframe_columns(
+    df: DataFrame, expr: IntoExprColumn, /, *exprs: IntoExprColumn
+) -> Series:
+    """Return a DataFrame column with `AND` applied to additional exprs/series."""
+    name = get_expr_name(df, expr)
+    return df.select(all_horizontal(expr, *exprs).alias(name))[name]
+
+
+def any_dataframe_columns(
+    df: DataFrame, expr: IntoExprColumn, /, *exprs: IntoExprColumn
+) -> Series:
+    """Return a DataFrame column with `OR` applied to additional exprs/series."""
+    name = get_expr_name(df, expr)
+    return df.select(any_horizontal(expr, *exprs).alias(name))[name]
+
+
+def all_series(series: Series, /, *columns: ExprOrSeries) -> Series:
+    """Return a Series with `AND` applied to additional exprs/series."""
+    return all_dataframe_columns(series.to_frame(), series.name, *columns)
+
+
+def any_series(series: Series, /, *columns: ExprOrSeries) -> Series:
+    """Return a Series with `OR` applied to additional exprs/series."""
+    df = series.to_frame()
+    name = series.name
+    return df.select(any_horizontal(name, *columns).alias(name))[name]
+
+
+##
+
+
+def append_row(
+    df: DataFrame,
+    row: StrMapping,
+    /,
+    *,
+    predicate: Callable[[StrMapping], bool] | None = None,
+    disallow_extra: bool = False,
+    disallow_missing: bool | MaybeIterable[str] = False,
+    disallow_null: bool | MaybeIterable[str] = False,
+    in_place: bool = False,
+) -> DataFrame:
+    """Append a row to a DataFrame."""
+    if (predicate is not None) and not predicate(row):
+        raise _AppendRowPredicateError(df=df, row=row)
+    if disallow_extra and (len(extra := set(row) - set(df.columns)) >= 1):
+        raise _AppendRowExtraKeysError(df=df, row=row, extra=extra)
+    if disallow_missing is not False:
+        missing = set(df.columns) - set(row)
+        if disallow_missing is not True:
+            missing &= set(always_iterable(disallow_missing))
+        if len(missing) >= 1:
+            raise _AppendRowMissingKeysError(df=df, row=row, missing=missing)
+    other = DataFrame(data=[row], schema=df.schema)
+    if disallow_null:
+        other_null = other.select(col(c).is_null().any() for c in other.columns)
+        null = {k for k, v in other_null.row(0, named=True).items() if v}
+        if disallow_null is not True:
+            null &= set(always_iterable(disallow_null))
+        if len(null) >= 1:
+            raise _AppendRowNullColumnsError(df=df, row=row, columns=null)
+    return df.extend(other) if in_place else df.vstack(other)
 
 
 @dataclass(kw_only=True, slots=True)
-class AppendDataClassError[T](Exception):
-    left: AbstractSet[T]
-    right: AbstractSet[T]
-    extra: AbstractSet[T]
+class AppendRowError(Exception):
+    df: DataFrame
+    row: StrMapping
+
 
+@dataclass(kw_only=True, slots=True)
+class _AppendRowPredicateError(AppendRowError):
     @override
     def __str__(self) -> str:
-        return f"Dataclass fields {get_repr(self.left)} must be a subset of DataFrame columns {get_repr(self.right)}; dataclass had extra items {get_repr(self.extra)}"
+        return f"Predicate failed; got {get_repr(self.row)}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _AppendRowExtraKeysError(AppendRowError):
+    extra: AbstractSet[str]
+
+    @override
+    def __str__(self) -> str:
+        return f"Extra key(s) found; got {get_repr(self.extra)}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _AppendRowMissingKeysError(AppendRowError):
+    missing: AbstractSet[str]
+
+    @override
+    def __str__(self) -> str:
+        return f"Missing key(s) found; got {get_repr(self.missing)}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _AppendRowNullColumnsError(AppendRowError):
    columns: AbstractSet[str]

    @override
    def __str__(self) -> str:
        return f"Null column(s) found; got {get_repr(self.columns)}"
 
 
 ##
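A minimal usage sketch of the new `append_row`, with a hypothetical frame and row (not part of the diff):

    from polars import DataFrame

    df = DataFrame({"x": [1, 2], "y": ["a", "b"]})
    # rejects unknown keys and a null "x"; appends via vstack otherwise
    df = append_row(df, {"x": 3, "y": "c"}, disallow_extra=True, disallow_null="x")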
@@ -301,8 +385,8 @@ def are_frames_equal(
     check_column_order: bool = True,
     check_dtypes: bool = True,
     check_exact: bool = False,
-    rtol: float = 1e-5,
-    atol: float = 1e-8,
+    rel_tol: float = 1e-5,
+    abs_tol: float = 1e-8,
     categorical_as_str: bool = False,
 ) -> bool:
     """Check if two DataFrames are equal."""
@@ -314,8 +398,8 @@ def are_frames_equal(
             check_column_order=check_column_order,
             check_dtypes=check_dtypes,
             check_exact=check_exact,
-            rtol=rtol,
-            atol=atol,
+            rel_tol=rel_tol,
+            abs_tol=abs_tol,
             categorical_as_str=categorical_as_str,
         )
     except AssertionError:
@@ -345,7 +429,7 @@ def bernoulli(
             return bernoulli(series.len(), true=true, seed=seed, name=name)
         case DataFrame() as df:
             return bernoulli(df.height, true=true, seed=seed, name=name)
-        case _ as never:
+        case never:
             assert_never(never)
 
 
@@ -379,7 +463,7 @@ def boolean_value_counts(
                 (false / total).alias("false (%)"),
                 (null / total).alias("null (%)"),
             )
-        case _ as never:
+        case never:
             assert_never(never)
 
 
@@ -422,29 +506,6 @@ class BooleanValueCountsError(Exception):
 ##
 
 
-@overload
-def ceil_datetime(column: ExprLike, every: ExprLike, /) -> Expr: ...
-@overload
-def ceil_datetime(column: Series, every: ExprLike, /) -> Series: ...
-@overload
-def ceil_datetime(column: IntoExprColumn, every: ExprLike, /) -> Expr | Series: ...
-def ceil_datetime(column: IntoExprColumn, every: ExprLike, /) -> Expr | Series:
-    """Compute the `ceil` of a datetime column."""
-    column = ensure_expr_or_series(column)
-    rounded = column.dt.round(every)
-    ceil = (
-        when(column <= rounded)
-        .then(rounded)
-        .otherwise(column.dt.offset_by(every).dt.round(every))
-    )
-    if isinstance(column, Expr):
-        return ceil
-    return DataFrame().with_columns(ceil.alias(column.name))[column.name]
-
-
-##
-
-
 def check_polars_dataframe(
     df: DataFrame,
     /,
@@ -504,7 +565,7 @@ def _check_polars_dataframe_columns(df: DataFrame, columns: Iterable[str], /) ->
 
 @dataclass(kw_only=True, slots=True)
 class _CheckPolarsDataFrameColumnsError(CheckPolarsDataFrameError):
-    columns: Sequence[str]
+    columns: list[str]
 
     @override
     def __str__(self) -> str:
@@ -763,29 +824,22 @@ def choice(
                 name=name,
                 dtype=dtype,
             )
-        case _ as never:
+        case never:
             assert_never(never)
 
 
 ##
 
 
-def collect_series(expr: Expr, /) -> Series:
-    """Collect a column expression into a Series."""
-    data = DataFrame().with_columns(expr)
-    return data[one(data.columns)]
-
-
-##
-
-
-def columns_to_dict(df: DataFrame, key: str, value: str, /) -> dict[Any, Any]:
+def columns_to_dict(
+    df: DataFrame, key: IntoExprColumn, value: IntoExprColumn, /
+) -> dict[Any, Any]:
     """Map a pair of columns into a dictionary. Must be unique on `key`."""
-    col_key = df[key]
-    if col_key.is_duplicated().any():
-        raise ColumnsToDictError(df=df, key=key)
-    col_value = df[value]
-    return dict(zip(col_key, col_value, strict=True))
+    df = df.select(key, value)
+    key_col, value_col = [df[get_expr_name(df, expr)] for expr in [key, value]]
+    if key_col.is_duplicated().any():
+        raise ColumnsToDictError(df=df, key=key_col.name)
+    return dict(zip(key_col, value_col, strict=True))
 
 
 @dataclass(kw_only=True, slots=True)
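`columns_to_dict` now accepts expressions as well as plain column names; a sketch with hypothetical data:

    from polars import DataFrame, col

    df = DataFrame({"k": [1, 2], "v": [10, 20]})
    columns_to_dict(df, "k", col("v") * 2)  # {1: 20, 2: 40}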
@@ -828,7 +882,7 @@ def convert_time_zone(
 
 def _convert_time_zone_one(sr: Series, /, *, time_zone: TimeZoneLike = UTC) -> Series:
     if isinstance(sr.dtype, Datetime):
-        return sr.dt.convert_time_zone(get_time_zone_name(time_zone))
+        return sr.dt.convert_time_zone(to_time_zone_name(time_zone))
     return sr
 
 
@@ -849,13 +903,13 @@ def cross(
     up_or_down: Literal["up", "down"],
     other: Number | IntoExprColumn,
     /,
-) -> Expr | Series: ...
+) -> ExprOrSeries: ...
 def cross(
     expr: IntoExprColumn,
     up_or_down: Literal["up", "down"],
     other: Number | IntoExprColumn,
     /,
-) -> Expr | Series:
+) -> ExprOrSeries:
     """Compute when a cross occurs."""
     return _cross_or_touch(expr, "cross", up_or_down, other)
 
@@ -874,13 +928,13 @@ def touch(
     up_or_down: Literal["up", "down"],
     other: Number | IntoExprColumn,
     /,
-) -> Expr | Series: ...
+) -> ExprOrSeries: ...
 def touch(
     expr: IntoExprColumn,
     up_or_down: Literal["up", "down"],
     other: Number | IntoExprColumn,
     /,
-) -> Expr | Series:
+) -> ExprOrSeries:
     """Compute when a touch occurs."""
     return _cross_or_touch(expr, "touch", up_or_down, other)
 
@@ -891,7 +945,7 @@ def _cross_or_touch(
     up_or_down: Literal["up", "down"],
     other: Number | IntoExprColumn,
     /,
-) -> Expr | Series:
+) -> ExprOrSeries:
     """Compute when a column crosses/touches a threshold."""
     expr = ensure_expr_or_series(expr)
     match other:
@@ -899,7 +953,7 @@ def _cross_or_touch(
             ...
         case str() | Expr() | Series():
             other = ensure_expr_or_series(other)
-        case _ as never:
+        case never:
             assert_never(never)
     enough = int_range(end=pl.len()) >= 1
     match cross_or_touch, up_or_down:
@@ -911,7 +965,7 @@ def _cross_or_touch(
             current = expr >= other
         case "touch", "down":
             current = expr <= other
-        case _ as never:
+        case never:
             assert_never(never)
     prev = current.shift()
     result = when(enough & expr.is_finite()).then(current & ~prev)
@@ -963,7 +1017,7 @@ def cross_rolling_quantile(
     weights: list[float] | None = None,
     min_samples: int | None = None,
     center: bool = False,
-) -> Expr | Series: ...
+) -> ExprOrSeries: ...
 def cross_rolling_quantile(
     expr: IntoExprColumn,
     up_or_down: Literal["up", "down"],
@@ -975,7 +1029,7 @@ def cross_rolling_quantile(
     weights: list[float] | None = None,
     min_samples: int | None = None,
     center: bool = False,
-) -> Expr | Series:
+) -> ExprOrSeries:
     """Compute when a column crosses its rolling quantile."""
     expr = ensure_expr_or_series(expr)
     rolling = expr.rolling_quantile(
@@ -1020,16 +1074,43 @@ def dataclass_to_dataframe(
 
 
 def _dataclass_to_dataframe_cast(series: Series, /) -> Series:
-    if series.dtype == Object:
-        is_path = series.map_elements(make_isinstance(Path), return_dtype=Boolean).all()
-        is_uuid = series.map_elements(make_isinstance(UUID), return_dtype=Boolean).all()
-        if is_path or is_uuid:
-            with suppress_warnings(category=PolarsInefficientMapWarning):
-                return series.map_elements(str, return_dtype=String)
-        else:  # pragma: no cover
-            msg = f"{is_path=}, f{is_uuid=}"
-            raise NotImplementedError(msg)
-    return series
+    if series.dtype != Object:
+        return series
+    if series.map_elements(make_isinstance(whenever.Date), return_dtype=Boolean).all():
+        return series.map_elements(lambda x: x.py_date(), return_dtype=pl.Date)
+    if series.map_elements(make_isinstance(DateDelta), return_dtype=Boolean).all():
+        return series.map_elements(to_py_time_delta, return_dtype=Duration)
+    if series.map_elements(make_isinstance(DateTimeDelta), return_dtype=Boolean).all():
+        return series.map_elements(to_py_time_delta, return_dtype=Duration)
+    is_path = series.map_elements(make_isinstance(Path), return_dtype=Boolean).all()
+    is_uuid = series.map_elements(make_isinstance(UUID), return_dtype=Boolean).all()
+    if is_path or is_uuid:
+        with suppress_warnings(
+            category=cast("type[Warning]", PolarsInefficientMapWarning)
+        ):
+            return series.map_elements(str, return_dtype=String)
+    if series.map_elements(make_isinstance(whenever.Time), return_dtype=Boolean).all():
+        return series.map_elements(lambda x: x.py_time(), return_dtype=pl.Time)
+    if series.map_elements(make_isinstance(TimeDelta), return_dtype=Boolean).all():
+        return series.map_elements(to_py_time_delta, return_dtype=Duration)
+    if series.map_elements(make_isinstance(ZonedDateTime), return_dtype=Boolean).all():
+        return_dtype = zoned_date_time_dtype(time_zone=one({dt.tz for dt in series}))
+        return series.map_elements(lambda x: x.py_datetime(), return_dtype=return_dtype)
+    if series.map_elements(
+        lambda x: isinstance(x, dict) and (set(x) == {"start", "end"}),
+        return_dtype=Boolean,
+    ).all():
+        start = _dataclass_to_dataframe_cast(
+            series.map_elements(lambda x: x["start"], return_dtype=Object)
+        ).alias("start")
+        end = _dataclass_to_dataframe_cast(
+            series.map_elements(lambda x: x["end"], return_dtype=Object)
+        ).alias("end")
+        name = series.name
+        return concat_series(start, end).select(
+            struct(start=start, end=end).alias(name)
+        )[name]
+    raise NotImplementedError(series)  # pragma: no cover
 
 
 @dataclass(kw_only=True, slots=True)
@@ -1070,20 +1151,14 @@ def dataclass_to_schema(
     for field in yield_fields(
         obj, globalns=globalns, localns=localns, warn_name_errors=warn_name_errors
     ):
-        if is_dataclass_instance(field.value):
+        if is_dataclass_instance(field.value) and not (
+            isinstance(field.type_, type)
+            and issubclass(field.type_, (DatePeriod, TimePeriod, ZonedDateTimePeriod))
+        ):
             dtypes = dataclass_to_schema(
                 field.value, globalns=globalns, localns=localns
             )
             dtype = struct_dtype(**dtypes)
-        elif field.type_ is dt.datetime:
-            dtype = _dataclass_to_schema_datetime(field)
-        elif is_union_type(field.type_) and set(
-            get_args(field.type_, optional_drop_none=True)
-        ) == {dt.date, dt.datetime}:
-            if is_instance_gen(field.value, dt.date):
-                dtype = Date
-            else:
-                dtype = _dataclass_to_schema_datetime(field)
         else:
             dtype = _dataclass_to_schema_one(
                 field.type_, globalns=globalns, localns=localns
@@ -1092,14 +1167,6 @@ def dataclass_to_schema(
     return out
 
 
-def _dataclass_to_schema_datetime(
-    field: _YieldFieldsInstance[dt.datetime], /
-) -> PolarsDataType:
-    if field.value.tzinfo is None:
-        return Datetime
-    return zoned_datetime(time_zone=ensure_time_zone(field.value.tzinfo))
-
-
 def _dataclass_to_schema_one(
     obj: Any,
     /,
@@ -1107,20 +1174,35 @@ def _dataclass_to_schema_one(
     globalns: StrMapping | None = None,
     localns: StrMapping | None = None,
 ) -> PolarsDataType:
-    if obj is bool:
-        return Boolean
-    if obj is int:
-        return Int64
-    if obj is float:
-        return Float64
-    if obj is str:
-        return String
-    if obj is dt.date:
-        return Date
-    if obj in {Path, UUID}:
-        return Object
-    if isinstance(obj, type) and issubclass(obj, enum.Enum):
-        return pl.Enum([e.name for e in obj])
+    if isinstance(obj, type):
+        if issubclass(obj, bool):
+            return Boolean
+        if issubclass(obj, int):
+            return Int64
+        if issubclass(obj, float):
+            return Float64
+        if issubclass(obj, str):
+            return String
+        if issubclass(
+            obj,
+            (
+                DateDelta,
+                DatePeriod,
+                DateTimeDelta,
+                Path,
+                PlainDateTime,
+                TimeDelta,
+                TimePeriod,
+                UUID,
+                ZonedDateTime,
+                ZonedDateTimePeriod,
+                whenever.Date,
+                whenever.Time,
+            ),
+        ):
+            return Object
+        if issubclass(obj, enum.Enum):
+            return pl.Enum([e.name for e in obj])
     if is_dataclass_class(obj):
         out: dict[str, Any] = {}
         for field in yield_fields(obj, globalns=globalns, localns=localns):
@@ -1146,27 +1228,6 @@ def _dataclass_to_schema_one(
 ##
 
 
-def drop_null_struct_series(series: Series, /) -> Series:
-    """Drop nulls in a struct-dtype Series as per the <= 1.1 definition."""
-    try:
-        is_not_null = is_not_null_struct_series(series)
-    except IsNotNullStructSeriesError as error:
-        raise DropNullStructSeriesError(series=error.series) from None
-    return series.filter(is_not_null)
-
-
-@dataclass(kw_only=True, slots=True)
-class DropNullStructSeriesError(Exception):
-    series: Series
-
-    @override
-    def __str__(self) -> str:
-        return f"Series must have Struct-dtype; got {self.series.dtype}"
-
-
-##
-
-
 def ensure_data_type(dtype: PolarsDataType, /) -> DataType:
     """Ensure a data type is returned."""
     return dtype if isinstance(dtype, DataType) else dtype()
@@ -1180,8 +1241,8 @@ def ensure_expr_or_series(column: ExprLike, /) -> Expr: ...
 @overload
 def ensure_expr_or_series(column: Series, /) -> Series: ...
 @overload
-def ensure_expr_or_series(column: IntoExprColumn, /) -> Expr | Series: ...
-def ensure_expr_or_series(column: IntoExprColumn, /) -> Expr | Series:
+def ensure_expr_or_series(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def ensure_expr_or_series(column: IntoExprColumn, /) -> ExprOrSeries:
     """Ensure a column expression or Series is returned."""
     return col(column) if isinstance(column, str) else column
 
@@ -1191,7 +1252,7 @@ def ensure_expr_or_series(column: IntoExprColumn, /) -> Expr | Series:
 
 def ensure_expr_or_series_many(
     *columns: IntoExprColumn, **named_columns: IntoExprColumn
-) -> Sequence[Expr | Series]:
+) -> Sequence[ExprOrSeries]:
     """Ensure a set of column expressions and/or Series are returned."""
     args = map(ensure_expr_or_series, columns)
     kwargs = (ensure_expr_or_series(v).alias(k) for k, v in named_columns.items())
@@ -1201,6 +1262,119 @@ def ensure_expr_or_series_many(
 ##
 
 
+def expr_to_series(expr: Expr, /) -> Series:
+    """Collect a column expression into a Series."""
+    return one_column(DataFrame().with_columns(expr))
+
+
+##
+
+
+@overload
+def filter_date(
+    column: ExprLike = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[whenever.Date] | None = None,
+    exclude: MaybeIterable[whenever.Date] | None = None,
+) -> Expr: ...
+@overload
+def filter_date(
+    column: Series,
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[whenever.Date] | None = None,
+    exclude: MaybeIterable[whenever.Date] | None = None,
+) -> Series: ...
+@overload
+def filter_date(
+    column: IntoExprColumn = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[whenever.Date] | None = None,
+    exclude: MaybeIterable[whenever.Date] | None = None,
+) -> ExprOrSeries: ...
+def filter_date(
+    column: IntoExprColumn = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[whenever.Date] | None = None,
+    exclude: MaybeIterable[whenever.Date] | None = None,
+) -> ExprOrSeries:
+    """Compute the filter based on a set of dates."""
+    column = ensure_expr_or_series(column)
+    if time_zone is not None:
+        column = column.dt.convert_time_zone(time_zone.key)
+    keep = true_like(column)
+    date = column.dt.date()
+    include, exclude = resolve_include_and_exclude(include=include, exclude=exclude)
+    if include is not None:
+        keep &= date.is_in([d.py_date() for d in include])
+    if exclude is not None:
+        keep &= ~date.is_in([d.py_date() for d in exclude])
+    return try_reify_expr(keep, column)
+
+
+@overload
+def filter_time(
+    column: ExprLike = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+    exclude: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+) -> Expr: ...
+@overload
+def filter_time(
+    column: Series,
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+    exclude: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+) -> Series: ...
+@overload
+def filter_time(
+    column: IntoExprColumn = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+    exclude: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+) -> ExprOrSeries: ...
+def filter_time(
+    column: IntoExprColumn = "datetime",
+    /,
+    *,
+    time_zone: ZoneInfo | None = None,
+    include: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+    exclude: MaybeIterable[tuple[whenever.Time, whenever.Time]] | None = None,
+) -> ExprOrSeries:
+    """Compute the filter based on a set of times."""
+    column = ensure_expr_or_series(column)
+    if time_zone is not None:
+        column = column.dt.convert_time_zone(time_zone.key)
+    keep = true_like(column)
+    time = column.dt.time()
+    include, exclude = resolve_include_and_exclude(include=include, exclude=exclude)
+    if include is not None:
+        keep &= any_horizontal(
+            time.is_between(s.py_time(), e.py_time()) for s, e in include
+        )
+    if exclude is not None:
+        keep &= ~any_horizontal(
+            time.is_between(s.py_time(), e.py_time()) for s, e in exclude
+        )
+    return try_reify_expr(keep, column)
+
+
+##
+
+
 @overload
 def finite_ewm_mean(
     column: ExprLike,
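A sketch of the new `filter_date`/`filter_time` filters, assuming a hypothetical frame `df` with a zoned "datetime" column:

    import whenever

    mask = filter_date(df["datetime"], include=[whenever.Date(2024, 1, 2)])
    kept = df.filter(mask)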
@@ -1233,7 +1407,7 @@
     half_life: float | None = None,
     alpha: float | None = None,
     min_weight: float = _FINITE_EWM_MIN_WEIGHT,
-) -> Expr | Series: ...
+) -> ExprOrSeries: ...
 def finite_ewm_mean(
     column: IntoExprColumn,
     /,
@@ -1243,7 +1417,7 @@ def finite_ewm_mean(
     half_life: float | None = None,
     alpha: float | None = None,
     min_weight: float = _FINITE_EWM_MIN_WEIGHT,
-) -> Expr | Series:
+) -> ExprOrSeries:
     """Compute a finite EWMA."""
     try:
         weights = _finite_ewm_weights(
@@ -1305,23 +1479,14 @@ class _FiniteEWMWeightsError(Exception):
 
 
 @overload
-def floor_datetime(column: ExprLike, every: ExprLike, /) -> Expr: ...
-@overload
-def floor_datetime(column: Series, every: ExprLike, /) -> Series: ...
+def first_true_horizontal(*columns: Series) -> Series: ...
 @overload
-def floor_datetime(column: IntoExprColumn, every: ExprLike, /) -> Expr | Series: ...
-def floor_datetime(column: IntoExprColumn, every: ExprLike, /) -> Expr | Series:
-    """Compute the `floor` of a datetime column."""
-    column = ensure_expr_or_series(column)
-    rounded = column.dt.round(every)
-    floor = (
-        when(column >= rounded)
-        .then(rounded)
-        .otherwise(column.dt.offset_by("-" + every).dt.round(every))
-    )
-    if isinstance(column, Expr):
-        return floor
-    return DataFrame().with_columns(floor.alias(column.name))[column.name]
+def first_true_horizontal(*columns: IntoExprColumn) -> ExprOrSeries: ...
+def first_true_horizontal(*columns: IntoExprColumn) -> ExprOrSeries:
+    """Get the index of the first true in each row."""
+    columns2 = ensure_expr_or_series_many(*columns)
+    expr = when(any_horizontal(*columns2)).then(concat_list(*columns2).list.arg_max())
+    return try_reify_expr(expr, *columns2)
 
 
 ##
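`first_true_horizontal` yields the positional index of the first `True` per row, and null where no column is true; a sketch with hypothetical series:

    from polars import Series

    a = Series("a", [False, True, False])
    b = Series("b", [True, True, False])
    first_true_horizontal(a, b)  # [1, 0, null]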
@@ -1338,13 +1503,24 @@ def get_data_type_or_series_time_zone(
             dtype = dtype_cls()
         case Series() as series:
             dtype = series.dtype
-        case _ as never:
+        case never:
             assert_never(never)
-    if not isinstance(dtype, Datetime):
-        raise _GetDataTypeOrSeriesTimeZoneNotDateTimeError(dtype=dtype)
-    if dtype.time_zone is None:
-        raise _GetDataTypeOrSeriesTimeZoneNotZonedError(dtype=dtype)
-    return ZoneInfo(dtype.time_zone)
+    match dtype:
+        case Datetime() as datetime:
+            if datetime.time_zone is None:
+                raise _GetDataTypeOrSeriesTimeZoneNotZonedError(dtype=datetime)
+            return ZoneInfo(datetime.time_zone)
+        case Struct() as struct:
+            try:
+                return one({
+                    get_data_type_or_series_time_zone(f.dtype) for f in struct.fields
+                })
+            except OneNonUniqueError as error:
+                raise _GetDataTypeOrSeriesTimeZoneStructNonUniqueError(
+                    dtype=struct, first=error.first, second=error.second
+                ) from None
+        case _:
+            raise _GetDataTypeOrSeriesTimeZoneNotDateTimeError(dtype=dtype)
 
 
 @dataclass(kw_only=True, slots=True)
@@ -1366,6 +1542,18 @@ class _GetDataTypeOrSeriesTimeZoneNotZonedError(GetDataTypeOrSeriesTimeZoneError
         return f"Data type must be zoned; got {self.dtype}"
 
 
+@dataclass(kw_only=True, slots=True)
+class _GetDataTypeOrSeriesTimeZoneStructNonUniqueError(
+    GetDataTypeOrSeriesTimeZoneError
+):
+    first: ZoneInfo
+    second: ZoneInfo
+
+    @override
+    def __str__(self) -> str:
+        return f"Struct data type must contain exactly one time zone; got {self.first}, {self.second} and perhaps more"
+
+
 ##
 
 
@@ -1375,9 +1563,8 @@ def get_expr_name(obj: Series | DataFrame, expr: IntoExprColumn, /) -> str:
         case Series() as series:
             return get_expr_name(series.to_frame(), expr)
         case DataFrame() as df:
-            selected = df.select(expr)
-            return one(selected.columns)
-        case _ as never:
+            return one_column(df.select(expr)).name
+        case never:
             assert_never(never)
 
 
@@ -1399,50 +1586,31 @@ def get_frequency_spectrum(series: Series, /, *, d: int = 1) -> DataFrame:
 
 
 @overload
-def get_series_number_of_decimals(
-    series: Series, /, *, nullable: Literal[True]
-) -> int | None: ...
+def increasing_horizontal(*columns: ExprLike) -> Expr: ...
 @overload
-def get_series_number_of_decimals(
-    series: Series, /, *, nullable: Literal[False] = False
-) -> int: ...
+def increasing_horizontal(*columns: Series) -> Series: ...
 @overload
-def get_series_number_of_decimals(
-    series: Series, /, *, nullable: bool = False
-) -> int | None: ...
-def get_series_number_of_decimals(
-    series: Series, /, *, nullable: bool = False
-) -> int | None:
-    """Get the number of decimals of a series."""
-    if not isinstance(dtype := series.dtype, Float64):
-        raise _GetSeriesNumberOfDecimalsNotFloatError(dtype=dtype)
-    decimals = series.map_elements(number_of_decimals, return_dtype=Int64).max()
-    try:
-        return ensure_int(decimals, nullable=nullable)
-    except EnsureIntError:
-        raise _GetSeriesNumberOfDecimalsAllNullError(series=series) from None
-
-
-@dataclass(kw_only=True, slots=True)
-class GetSeriesNumberOfDecimalsError(Exception): ...
+def increasing_horizontal(*columns: IntoExprColumn) -> ExprOrSeries: ...
+def increasing_horizontal(*columns: IntoExprColumn) -> ExprOrSeries:
+    """Check if a set of columns are increasing."""
+    columns2 = ensure_expr_or_series_many(*columns)
+    if len(columns2) == 0:
+        return lit(value=True, dtype=Boolean)
+    return all_horizontal(prev < curr for prev, curr in pairwise(columns2))
 
 
-@dataclass(kw_only=True, slots=True)
-class _GetSeriesNumberOfDecimalsNotFloatError(GetSeriesNumberOfDecimalsError):
-    dtype: DataType
-
-    @override
-    def __str__(self) -> str:
-        return f"Data type must be Float64; got {self.dtype}"
-
-
-@dataclass(kw_only=True, slots=True)
-class _GetSeriesNumberOfDecimalsAllNullError(GetSeriesNumberOfDecimalsError):
-    series: Series
-
-    @override
-    def __str__(self) -> str:
-        return f"Series must not be all-null; got {self.series}"
+@overload
+def decreasing_horizontal(*columns: ExprLike) -> Expr: ...
+@overload
+def decreasing_horizontal(*columns: Series) -> Series: ...
+@overload
+def decreasing_horizontal(*columns: IntoExprColumn) -> ExprOrSeries: ...
+def decreasing_horizontal(*columns: IntoExprColumn) -> ExprOrSeries:
+    """Check if a set of columns are decreasing."""
+    columns2 = ensure_expr_or_series_many(*columns)
+    if len(columns2) == 0:
+        return lit(value=True, dtype=Boolean)
+    return all_horizontal(prev > curr for prev, curr in pairwise(columns2))
 
 
 ##
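The monotonicity helpers compare adjacent columns pairwise; a sketch with hypothetical data:

    from polars import DataFrame

    df = DataFrame({"a": [1, 5], "b": [2, 4], "c": [3, 3]})
    df.select(increasing_horizontal("a", "b", "c"))  # [true, false]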
@@ -1575,13 +1743,49 @@ def integers(
                 name=name,
                 dtype=dtype,
             )
-        case _ as never:
+        case never:
             assert_never(never)
 
 
 ##
 
 
+@overload
+def is_close(
+    x: ExprLike, y: ExprLike, /, *, rel_tol: float = 1e-9, abs_tol: float = 0
+) -> Expr: ...
+@overload
+def is_close(
+    x: Series, y: Series, /, *, rel_tol: float = 1e-9, abs_tol: float = 0
+) -> Series: ...
+@overload
+def is_close(
+    x: IntoExprColumn,
+    y: IntoExprColumn,
+    /,
+    *,
+    rel_tol: float = 1e-9,
+    abs_tol: float = 0,
+) -> ExprOrSeries: ...
+def is_close(
+    x: IntoExprColumn,
+    y: IntoExprColumn,
+    /,
+    *,
+    rel_tol: float = 1e-9,
+    abs_tol: float = 0,
+) -> ExprOrSeries:
+    """Check if two columns are close."""
+    x, y = map(ensure_expr_or_series, [x, y])
+    result = (x - y).abs() <= max_horizontal(
+        rel_tol * max_horizontal(x.abs(), y.abs()), abs_tol
+    )
+    return try_reify_expr(result, x, y)
+
+
+##
+
+
 @overload
 def is_near_event(
     *exprs: ExprLike, before: int = 0, after: int = 0, **named_exprs: ExprLike
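`is_close` mirrors `math.isclose`, testing |x - y| <= max(rel_tol * max(|x|, |y|), abs_tol) element-wise; a sketch with hypothetical series:

    from polars import Series

    x = Series("x", [1.0, 1.0])
    y = Series("y", [1.0 + 1e-12, 1.1])
    is_close(x, y)  # [true, false]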
@@ -1596,13 +1800,13 @@ def is_near_event(
     before: int = 0,
     after: int = 0,
     **named_exprs: IntoExprColumn,
-) -> Expr | Series: ...
+) -> ExprOrSeries: ...
 def is_near_event(
     *exprs: IntoExprColumn,
     before: int = 0,
     after: int = 0,
     **named_exprs: IntoExprColumn,
-) -> Expr | Series:
+) -> ExprOrSeries:
     """Compute the rows near any event."""
     if before <= -1:
         raise _IsNearEventBeforeError(before=before)
@@ -1645,68 +1849,28 @@ class _IsNearEventAfterError(IsNearEventError):
 ##
 
 
-def is_not_null_struct_series(series: Series, /) -> Series:
-    """Check if a struct-dtype Series is not null as per the <= 1.1 definition."""
-    try:
-        return ~is_null_struct_series(series)
-    except IsNullStructSeriesError as error:
-        raise IsNotNullStructSeriesError(series=error.series) from None
-
-
-@dataclass(kw_only=True, slots=True)
-class IsNotNullStructSeriesError(Exception):
-    series: Series
-
-    @override
-    def __str__(self) -> str:
-        return f"Series must have Struct-dtype; got {self.series.dtype}"
-
-
-##
-
-
-def is_null_struct_series(series: Series, /) -> Series:
-    """Check if a struct-dtype Series is null as per the <= 1.1 definition."""
-    if not isinstance(series.dtype, Struct):
-        raise IsNullStructSeriesError(series=series)
-    paths = _is_null_struct_series_one(series.dtype)
-    paths = list(paths)
-    exprs = map(_is_null_struct_to_expr, paths)
-    expr = all_horizontal(*exprs)
-    return (
-        series.struct.unnest().with_columns(_result=expr)["_result"].rename(series.name)
-    )
-
-
-def _is_null_struct_series_one(
-    dtype: Struct, /, *, root: Iterable[str] = ()
-) -> Iterator[Sequence[str]]:
-    for field in dtype.fields:
-        name = field.name
-        inner = field.dtype
-        path = list(chain(root, [name]))
-        if isinstance(inner, Struct):
-            yield from _is_null_struct_series_one(inner, root=path)
-        else:
-            yield path
-
-
-def _is_null_struct_to_expr(path: Iterable[str], /) -> Expr:
-    head, *tail = path
-    return reduce(_is_null_struct_to_expr_reducer, tail, col(head)).is_null()
-
-
-def _is_null_struct_to_expr_reducer(expr: Expr, path: str, /) -> Expr:
-    return expr.struct[path]
-
+@overload
+def is_true(column: ExprLike, /) -> Expr: ...
+@overload
+def is_true(column: Series, /) -> Series: ...
+@overload
+def is_true(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def is_true(column: IntoExprColumn, /) -> ExprOrSeries:
+    """Compute when a boolean series is True."""
+    column = ensure_expr_or_series(column)
+    return (column.is_not_null()) & column
 
-@dataclass(kw_only=True, slots=True)
-class IsNullStructSeriesError(Exception):
-    series: Series
 
-    @override
-    def __str__(self) -> str:
-        return f"Series must have Struct-dtype; got {self.series.dtype}"
+@overload
+def is_false(column: ExprLike, /) -> Expr: ...
+@overload
+def is_false(column: Series, /) -> Series: ...
+@overload
+def is_false(column: IntoExprColumn, /) -> ExprOrSeries: ...
+def is_false(column: IntoExprColumn, /) -> ExprOrSeries:
+    """Compute when a boolean series is False."""
+    column = ensure_expr_or_series(column)
+    return (column.is_not_null()) & (~column)
 
 
 ##
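Unlike a plain comparison, `is_true`/`is_false` map nulls to False rather than propagating them; a sketch with a hypothetical series:

    from polars import Series

    sr = Series("flag", [True, False, None])
    is_true(sr)   # [true, false, false]
    is_false(sr)  # [false, true, false]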
@@ -1880,7 +2044,7 @@ def map_over_columns(
             return _map_over_series_one(func, series)
         case DataFrame() as df:
             return df.select(*(_map_over_series_one(func, df[c]) for c in df.columns))
-        case _ as never:
+        case never:
             assert_never(never)
 
 
@@ -1895,46 +2059,74 @@ def _map_over_series_one(func: Callable[[Series], Series], series: Series, /) ->
 ##
 
 
-def nan_sum_agg(column: str | Expr, /, *, dtype: PolarsDataType | None = None) -> Expr:
+def nan_sum_agg(column: str | Expr, /) -> Expr:
     """Nan sum aggregation."""
     col_use = col(column) if isinstance(column, str) else column
-    return (
-        when(col_use.is_not_null().any())
-        .then(col_use.sum())
-        .otherwise(lit(None, dtype=dtype))
-    )
+    return when(col_use.is_not_null().any()).then(col_use.sum())
 
 
 ##
 
 
-def nan_sum_cols(
-    column: str | Expr, *columns: str | Expr, dtype: PolarsDataType | None = None
-) -> Expr:
+@overload
+def nan_sum_horizontal(*columns: Series) -> Series: ...
+@overload
+def nan_sum_horizontal(*columns: IntoExprColumn) -> ExprOrSeries: ...
+def nan_sum_horizontal(*columns: IntoExprColumn) -> ExprOrSeries:
     """Nan sum across columns."""
-    all_columns = chain([column], columns)
-    all_exprs = (
-        col(column) if isinstance(column, str) else column for column in all_columns
+    columns2 = ensure_expr_or_series_many(*columns)
+    expr = when(any_horizontal(*(c.is_not_null() for c in columns2))).then(
+        sum_horizontal(*columns2)
     )
+    return try_reify_expr(expr, *columns2)
 
-    def func(x: Expr, y: Expr, /) -> Expr:
-        return (
-            when(x.is_not_null() & y.is_not_null())
-            .then(x + y)
-            .when(x.is_not_null() & y.is_null())
-            .then(x)
-            .when(x.is_null() & y.is_not_null())
-            .then(y)
-            .otherwise(lit(None, dtype=dtype))
-        )
 
-    return reduce(func, all_exprs)
+##
+
+
+@overload
+def normal_pdf(
+    x: ExprLike,
+    /,
+    *,
+    loc: float | IntoExprColumn = 0.0,
+    scale: float | IntoExprColumn = 1.0,
+) -> Expr: ...
+@overload
+def normal_pdf(
+    x: Series,
+    /,
+    *,
+    loc: float | IntoExprColumn = 0.0,
+    scale: float | IntoExprColumn = 1.0,
+) -> Series: ...
+@overload
+def normal_pdf(
+    x: IntoExprColumn,
+    /,
+    *,
+    loc: float | IntoExprColumn = 0.0,
+    scale: float | IntoExprColumn = 1.0,
+) -> ExprOrSeries: ...
+def normal_pdf(
+    x: IntoExprColumn,
+    /,
+    *,
+    loc: float | IntoExprColumn = 0.0,
+    scale: float | IntoExprColumn = 1.0,
+) -> ExprOrSeries:
+    """Compute the PDF of a normal distribution."""
+    x = ensure_expr_or_series(x)
+    loc = loc if isinstance(loc, int | float) else ensure_expr_or_series(loc)
+    scale = scale if isinstance(scale, int | float) else ensure_expr_or_series(scale)
+    expr = (1 / (scale * sqrt(2 * pi))) * (-(1 / 2) * ((x - loc) / scale) ** 2).exp()
+    return try_reify_expr(expr, x)
 
 
 ##
 
 
-def normal(
+def normal_rv(
     obj: int | Series | DataFrame,
     /,
     *,
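`normal_pdf` evaluates the usual density (1 / (scale * sqrt(2 * pi))) * exp(-((x - loc) / scale) ** 2 / 2); a quick check at hypothetical points:

    from polars import Series

    x = Series("x", [0.0])
    normal_pdf(x).item()           # 1 / sqrt(2 * pi) ≈ 0.3989
    normal_pdf(x, loc=1.0).item()  # exp(-1 / 2) / sqrt(2 * pi) ≈ 0.2420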
@@ -1953,20 +2145,102 @@
             values = rng.normal(loc=loc, scale=scale, size=height)
             return Series(name=name, values=values, dtype=dtype)
         case Series() as series:
-            return normal(
+            return normal_rv(
                 series.len(), loc=loc, scale=scale, seed=seed, name=name, dtype=dtype
             )
         case DataFrame() as df:
-            return normal(
+            return normal_rv(
                 df.height, loc=loc, scale=scale, seed=seed, name=name, dtype=dtype
             )
-        case _ as never:
+        case never:
             assert_never(never)
 
 
 ##
 
 
+@overload
+def number_of_decimals(
+    column: ExprLike, /, *, max_decimals: int = MAX_DECIMALS
+) -> Expr: ...
+@overload
+def number_of_decimals(
+    column: Series, /, *, max_decimals: int = MAX_DECIMALS
+) -> Series: ...
+@overload
+def number_of_decimals(
+    column: IntoExprColumn, /, *, max_decimals: int = MAX_DECIMALS
+) -> ExprOrSeries: ...
+def number_of_decimals(
+    column: IntoExprColumn, /, *, max_decimals: int = MAX_DECIMALS
+) -> ExprOrSeries:
+    """Get the number of decimals."""
+    column = ensure_expr_or_series(column)
+    frac = column - column.floor()
+    results = (
+        _number_of_decimals_check_scale(frac, s) for s in range(max_decimals + 1)
+    )
+    return first_true_horizontal(*results)
+
+
+def _number_of_decimals_check_scale(frac: ExprOrSeries, scale: int, /) -> ExprOrSeries:
+    scaled = 10**scale * frac
+    return is_close(scaled, scaled.round()).alias(str(scale))
+
+
+##
+
+
+def offset_datetime(
+    datetime: ZonedDateTime, offset: str, /, *, n: int = 1
+) -> ZonedDateTime:
+    """Offset a datetime as `polars` would."""
+    sr = Series(values=[datetime.py_datetime()])
+    for _ in range(n):
+        sr = sr.dt.offset_by(offset)
+    return ZonedDateTime.from_py_datetime(sr.item())
+
+
+##
+
+
+def one_column(df: DataFrame, /) -> Series:
+    """Return the unique column in a DataFrame."""
+    try:
+        return df[one(df.columns)]
+    except OneEmptyError:
+        raise OneColumnEmptyError(df=df) from None
+    except OneNonUniqueError as error:
+        raise OneColumnNonUniqueError(
+            df=df, first=error.first, second=error.second
+        ) from None
+
+
+@dataclass(kw_only=True, slots=True)
+class OneColumnError(Exception):
+    df: DataFrame
+
+
+@dataclass(kw_only=True, slots=True)
+class OneColumnEmptyError(OneColumnError):
+    @override
+    def __str__(self) -> str:
+        return "DataFrame must not be empty"
+
+
+@dataclass(kw_only=True, slots=True)
+class OneColumnNonUniqueError(OneColumnError):
+    first: str
+    second: str
+
+    @override
+    def __str__(self) -> str:
+        return f"DataFrame must contain exactly one column; got {self.first!r}, {self.second!r} and perhaps more"
+
+
+##
+
+
 @overload
 def order_of_magnitude(column: ExprLike, /, *, round_: bool = False) -> Expr: ...
 @overload
@@ -1974,10 +2248,10 @@ def order_of_magnitude(column: Series, /, *, round_: bool = False) -> Series: ...
 @overload
 def order_of_magnitude(
     column: IntoExprColumn, /, *, round_: bool = False
-) -> Expr | Series: ...
+) -> ExprOrSeries: ...
 def order_of_magnitude(
     column: IntoExprColumn, /, *, round_: bool = False
-) -> Expr | Series:
+) -> ExprOrSeries:
     """Compute the order of magnitude of a column."""
     column = ensure_expr_or_series(column)
     result = column.abs().log10()
@@ -1987,6 +2261,75 @@
 ##
 
 
+@overload
+def period_range(
+    start: ZonedDateTime,
+    end_or_length: ZonedDateTime | int,
+    /,
+    *,
+    interval: str = "1d",
+    time_unit: TimeUnit | None = None,
+    time_zone: TimeZoneLike | None = None,
+    eager: Literal[True],
+) -> Series: ...
+@overload
+def period_range(
+    start: ZonedDateTime,
+    end_or_length: ZonedDateTime | int,
+    /,
+    *,
+    interval: str = "1d",
+    time_unit: TimeUnit | None = None,
+    time_zone: TimeZoneLike | None = None,
+    eager: Literal[False] = False,
+) -> Expr: ...
+@overload
+def period_range(
+    start: ZonedDateTime,
+    end_or_length: ZonedDateTime | int,
+    /,
+    *,
+    interval: str = "1d",
+    time_unit: TimeUnit | None = None,
+    time_zone: TimeZoneLike | None = None,
+    eager: bool = False,
+) -> Series | Expr: ...
+def period_range(
+    start: ZonedDateTime,
+    end_or_length: ZonedDateTime | int,
+    /,
+    *,
+    interval: str = "1d",
+    time_unit: TimeUnit | None = None,
+    time_zone: TimeZoneLike | None = None,
+    eager: bool = False,
+) -> Series | Expr:
+    """Construct a period range."""
+    time_zone_use = None if time_zone is None else to_time_zone_name(time_zone)
+    match end_or_length:
+        case ZonedDateTime() as end:
+            ...
+        case int() as length:
+            end = offset_datetime(start, interval, n=length)
+        case never:
+            assert_never(never)
+    starts = datetime_range(
+        start.py_datetime(),
+        end.py_datetime(),
+        interval,
+        closed="left",
+        time_unit=time_unit,
+        time_zone=time_zone_use,
+        eager=eager,
+    ).alias("start")
+    ends = starts.dt.offset_by(interval).alias("end")
+    period = struct(starts, ends)
+    return try_reify_expr(period, starts, ends)
+
+
+##
+
+
 def reify_exprs(
     *exprs: IntoExprColumn, **named_exprs: IntoExprColumn
 ) -> Expr | Series | DataFrame:
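`period_range` builds half-open [start, end) periods as {start, end} structs; a sketch with hypothetical timestamps:

    from whenever import ZonedDateTime

    start = ZonedDateTime(2024, 1, 1, tz="UTC")
    period_range(start, 3, interval="1h", eager=True)
    # three one-hour periods: 00:00-01:00, 01:00-02:00, 02:00-03:00 UTC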
@@ -2019,13 +2362,10 @@ def reify_exprs(
2019
2362
  .with_columns(*all_exprs)
2020
2363
  .drop("_index")
2021
2364
  )
2022
- match len(df.columns):
2023
- case 0:
2024
- raise ImpossibleCaseError(case=[f"{df.columns=}"]) # pragma: no cover
2025
- case 1:
2026
- return df[one(df.columns)]
2027
- case _:
2028
- return df
2365
+ try:
2366
+ return one_column(df)
2367
+ except OneColumnNonUniqueError:
2368
+ return df
2029
2369
 
2030
2370
 
2031
2371
  @dataclass(kw_only=True, slots=True)
@@ -2075,7 +2415,7 @@ def _replace_time_zone_one(
2075
2415
  sr: Series, /, *, time_zone: TimeZoneLike | None = UTC
2076
2416
  ) -> Series:
2077
2417
  if isinstance(sr.dtype, Datetime):
2078
- time_zone_use = None if time_zone is None else get_time_zone_name(time_zone)
2418
+ time_zone_use = None if time_zone is None else to_time_zone_name(time_zone)
2079
2419
  return sr.dt.replace_time_zone(time_zone_use)
2080
2420
  return sr
2081
2421
 
@@ -2192,8 +2532,138 @@ def _reconstruct_dtype(obj: _DeconDType, /) -> PolarsDataType:
2192
2532
  return List(_reconstruct_dtype(inner))
2193
2533
  case "Struct", inner:
2194
2534
  return Struct(_reconstruct_schema(inner))
2195
- case _ as never:
2535
+ case never:
2536
+ assert_never(never)
2537
+
2538
+
2539
+ ##
2540
+
2541
+
+ @overload
+ def round_to_float(
+     x: ExprLike, y: float, /, *, mode: RoundMode = "half_to_even"
+ ) -> Expr: ...
+ @overload
+ def round_to_float(
+     x: Series, y: float | ExprOrSeries, /, *, mode: RoundMode = "half_to_even"
+ ) -> Series: ...
+ @overload
+ def round_to_float(
+     x: ExprLike, y: Series, /, *, mode: RoundMode = "half_to_even"
+ ) -> Series: ...
+ @overload
+ def round_to_float(
+     x: ExprLike, y: Expr, /, *, mode: RoundMode = "half_to_even"
+ ) -> Expr: ...
+ @overload
+ def round_to_float(
+     x: IntoExprColumn, y: float | Series, /, *, mode: RoundMode = "half_to_even"
+ ) -> ExprOrSeries: ...
+ def round_to_float(
+     x: IntoExprColumn, y: float | IntoExprColumn, /, *, mode: RoundMode = "half_to_even"
+ ) -> ExprOrSeries:
+     """Round a column to the nearest multiple of another float."""
+     x = ensure_expr_or_series(x)
+     y = y if isinstance(y, int | float) else ensure_expr_or_series(y)
+     match x, y:
+         case Expr() | Series(), int() | float():
+             z = (x / y).round(mode=mode) * y
+             return z.round(decimals=utilities.math.number_of_decimals(y) + 1)
+         case Series(), Expr() | Series():
+             df = (
+                 x
+                 .to_frame()
+                 .with_columns(y)
+                 .with_columns(number_of_decimals(y).alias("_decimals"))
+                 .with_row_index(name="_index")
+                 .group_by("_decimals")
+                 .map_groups(_round_to_float_one)
+                 .sort("_index")
+             )
+             return df[df.columns[1]]
+         case Expr(), Series():
+             df = y.to_frame().with_columns(x)
+             return round_to_float(df[df.columns[1]], df[df.columns[0]], mode=mode)
+         case Expr(), Expr() | str():
+             raise RoundToFloatError(x=x, y=y)
+         case never:
+             assert_never(never)
+
+
+ def _round_to_float_one(df: DataFrame, /) -> DataFrame:
+     decimals: int | None = df["_decimals"].unique().item()
+     name = df.columns[1]
+     match decimals:
+         case int():
+             expr = col(name).round(decimals=decimals)
+         case None:
+             expr = lit(None, dtype=Float64).alias(name)
+         case never:
              assert_never(never)
+     return df.with_columns(expr)
+
+
+ @dataclass(kw_only=True, slots=True)
+ class RoundToFloatError(Exception):
+     x: IntoExprColumn
+     y: IntoExprColumn
+
+     @override
+     def __str__(self) -> str:
+         return f"At least 1 of the dividend and/or divisor must be a Series; got {get_class_name(self.x)!r} and {get_class_name(self.y)!r}"
+
+
+ ##
+
+
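A usage sketch for `round_to_float` with a scalar divisor; per the signature above, the default `mode` is round-half-to-even (values illustrative):

    import polars as pl
    from utilities.polars import round_to_float

    sr = pl.Series("x", [0.12, 0.37, 0.55])
    # Round each value to the nearest multiple of 0.25
    rounded = round_to_float(sr, 0.25)  # -> 0.0, 0.25, 0.5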
+ def search_period(
+     series: Series,
+     date_time: ZonedDateTime,
+     /,
+     *,
+     start_or_end: Literal["start", "end"] = "end",
+ ) -> int | None:
+     """Search a series of periods for the one containing a given date-time."""
+     end = series.struct["end"]
+     py_date_time = date_time.py_datetime()
+     match start_or_end:
+         case "start":
+             index = end.search_sorted(py_date_time, side="right")
+             if index >= len(series):
+                 return None
+             item: dt.datetime = series[index]["start"]
+             return index if py_date_time >= item else None
+         case "end":
+             index = end.search_sorted(py_date_time, side="left")
+             if index >= len(series):
+                 return None
+             item: dt.datetime = series[index]["start"]
+             return index if py_date_time > item else None
+
+
+ ##
+
+
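A sketch combining `period_range` and `search_period`; reading the code above, under the default `start_or_end="end"` a period matches when start < probe <= end (same `whenever.ZonedDateTime` assumption as before):

    from whenever import ZonedDateTime
    from utilities.polars import period_range, search_period

    start = ZonedDateTime(2024, 1, 1, tz="UTC")
    periods = period_range(start, 3, interval="1d", eager=True)
    probe = ZonedDateTime(2024, 1, 2, 12, tz="UTC")
    index = search_period(periods, probe)  # -> 1, the Jan 2 to Jan 3 period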
+ def select_exact(
+     df: DataFrame, /, *columns: IntoExprColumn, drop: MaybeIterable[str] | None = None
+ ) -> DataFrame:
+     """Select an exact set of columns from a DataFrame."""
+     names = [get_expr_name(df, c) for c in columns]
+     drop = set() if drop is None else set(always_iterable(drop))
+     union = set(names) | drop
+     extra = [c for c in df.columns if c not in union]
+     if len(extra) >= 1:
+         raise SelectExactError(columns=extra)
+     return df.select(*columns)
+
+
+ @dataclass(kw_only=True, slots=True)
+ class SelectExactError(Exception):
+     columns: list[str]
+
+     @override
+     def __str__(self) -> str:
+         return f"All columns must be selected; got {get_repr(self.columns)} remaining"
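A quick sketch of `select_exact`, which raises instead of silently discarding columns (column names illustrative):

    import polars as pl
    from utilities.polars import select_exact

    df = pl.DataFrame({"a": [1], "b": [2], "c": [3]})
    out = select_exact(df, "a", "b", drop="c")  # OK: every column accounted for
    # select_exact(df, "a")  # raises SelectExactError; 'b' and 'c' remain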
 
 
 
  ##
@@ -2229,79 +2699,79 @@ def struct_dtype(**kwargs: PolarsDataType) -> Struct:
  ##


- def struct_from_dataclass(
-     cls: type[Dataclass],
-     /,
-     *,
-     globalns: StrMapping | None = None,
-     localns: StrMapping | None = None,
-     warn_name_errors: bool = False,
-     time_zone: TimeZoneLike | None = None,
- ) -> Struct:
-     """Construct the Struct data type for a dataclass."""
-     if not is_dataclass_class(cls):
-         raise _StructFromDataClassNotADataclassError(cls=cls)
-     anns = get_type_hints(
-         cls, globalns=globalns, localns=localns, warn_name_errors=warn_name_errors
-     )
-     data_types = {
-         k: _struct_from_dataclass_one(v, time_zone=time_zone) for k, v in anns.items()
-     }
-     return Struct(data_types)
+ @overload
+ def to_true(column: ExprLike, /) -> Expr: ...
+ @overload
+ def to_true(column: Series, /) -> Series: ...
+ @overload
+ def to_true(column: IntoExprColumn, /) -> ExprOrSeries: ...
+ def to_true(column: IntoExprColumn, /) -> ExprOrSeries:
+     """Compute when a boolean series turns True."""
+     t = is_true(column)
+     return ((~t).shift() & t).fill_null(value=False)
 
 
- def _struct_from_dataclass_one(
-     ann: Any, /, *, time_zone: TimeZoneLike | None = None
- ) -> PolarsDataType:
-     mapping = {bool: Boolean, dt.date: Date, float: Float64, int: Int64, str: String}
-     with suppress(KeyError):
-         return mapping[ann]
-     if ann is dt.datetime:
-         if time_zone is None:
-             raise _StructFromDataClassTimeZoneMissingError
-         return zoned_datetime(time_zone=time_zone)
-     if is_dataclass_class(ann):
-         return struct_from_dataclass(ann, time_zone=time_zone)
-     if (isinstance(ann, type) and issubclass(ann, enum.Enum)) or (
-         is_literal_type(ann) and is_iterable_of(get_args(ann), str)
-     ):
-         return String
-     if is_optional_type(ann):
-         return _struct_from_dataclass_one(
-             one(get_args(ann, optional_drop_none=True)), time_zone=time_zone
-         )
-     if is_frozenset_type(ann) or is_list_type(ann) or is_set_type(ann):
-         return List(_struct_from_dataclass_one(one(get_args(ann)), time_zone=time_zone))
-     raise _StructFromDataClassTypeError(ann=ann)
+ @overload
+ def to_not_true(column: ExprLike, /) -> Expr: ...
+ @overload
+ def to_not_true(column: Series, /) -> Series: ...
+ @overload
+ def to_not_true(column: IntoExprColumn, /) -> ExprOrSeries: ...
+ def to_not_true(column: IntoExprColumn, /) -> ExprOrSeries:
+     """Compute when a boolean series turns non-True."""
+     t = is_true(column)
+     return (t.shift() & (~t)).fill_null(value=False)
 
 
- @dataclass(kw_only=True, slots=True)
- class StructFromDataClassError(Exception): ...
+ @overload
+ def to_false(column: ExprLike, /) -> Expr: ...
+ @overload
+ def to_false(column: Series, /) -> Series: ...
+ @overload
+ def to_false(column: IntoExprColumn, /) -> ExprOrSeries: ...
+ def to_false(column: IntoExprColumn, /) -> ExprOrSeries:
+     """Compute when a boolean series turns False."""
+     f = is_false(column)
+     return ((~f).shift() & f).fill_null(value=False)
 
 
- @dataclass(kw_only=True, slots=True)
- class _StructFromDataClassNotADataclassError(StructFromDataClassError):
-     cls: type[Dataclass]
+ @overload
+ def to_not_false(column: ExprLike, /) -> Expr: ...
+ @overload
+ def to_not_false(column: Series, /) -> Series: ...
+ @overload
+ def to_not_false(column: IntoExprColumn, /) -> ExprOrSeries: ...
+ def to_not_false(column: IntoExprColumn, /) -> ExprOrSeries:
+     """Compute when a boolean series turns non-False."""
+     f = is_false(column)
+     return (f.shift() & (~f)).fill_null(value=False)
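The four edge-detectors above are shifted boolean comparisons; a sketch of `to_true` on a null-free series (so the exact null-handling of `is_true` does not matter):

    import polars as pl
    from utilities.polars import to_true

    sr = pl.Series("flag", [False, True, True, False, True])
    to_true(sr)  # [False, True, False, False, True]: marks flips into True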
 
-     @override
-     def __str__(self) -> str:
-         return f"Object must be a dataclass; got {self.cls}"
 
+ ##
 
- @dataclass(kw_only=True, slots=True)
- class _StructFromDataClassTimeZoneMissingError(StructFromDataClassError):
-     @override
-     def __str__(self) -> str:
-         return "Time-zone must be given"
 
+ @overload
+ def true_like(column: ExprLike, /) -> Expr: ...
+ @overload
+ def true_like(column: Series, /) -> Series: ...
+ @overload
+ def true_like(column: IntoExprColumn, /) -> ExprOrSeries: ...
+ def true_like(column: IntoExprColumn, /) -> ExprOrSeries:
+     """Compute a column of `True` values."""
+     column = ensure_expr_or_series(column)
+     return column.is_null() | column.is_not_null()
 
- @dataclass(kw_only=True, slots=True)
- class _StructFromDataClassTypeError(StructFromDataClassError):
-     ann: Any
 
-     @override
-     def __str__(self) -> str:
-         return f"Unsupported type: {self.ann}"
+ @overload
+ def false_like(column: ExprLike, /) -> Expr: ...
+ @overload
+ def false_like(column: Series, /) -> Series: ...
+ @overload
+ def false_like(column: IntoExprColumn, /) -> ExprOrSeries: ...
+ def false_like(column: IntoExprColumn, /) -> ExprOrSeries:
+     """Compute a column of `False` values."""
+     column = ensure_expr_or_series(column)
+     return column.is_null() & column.is_not_null()
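`true_like` and `false_like` lean on null-checks so the output keeps the input's length and name, nulls included; e.g.:

    import polars as pl
    from utilities.polars import true_like

    sr = pl.Series("x", [1, None, 3])
    true_like(sr)  # Series "x": [True, True, True]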
 
 
  ##
@@ -2309,7 +2779,7 @@ class _StructFromDataClassTypeError(StructFromDataClassError):
 
  def try_reify_expr(
      expr: IntoExprColumn, /, *exprs: IntoExprColumn, **named_exprs: IntoExprColumn
- ) -> Expr | Series:
+ ) -> ExprOrSeries:
      """Try reify an expression."""
      expr = ensure_expr_or_series(expr)
      all_exprs = ensure_expr_or_series_many(*exprs, **named_exprs)
@@ -2322,7 +2792,7 @@ def try_reify_expr(
              return series
          case DataFrame() as df:
              return df[get_expr_name(df, expr)]
-         case _ as never:
+         case never:
              assert_never(never)
 
 
@@ -2355,7 +2825,7 @@ def uniform(
              return uniform(
                  df.height, low=low, high=high, seed=seed, name=name, dtype=dtype
              )
-         case _ as never:
+         case never:
              assert_never(never)
 
 
@@ -2376,8 +2846,8 @@ def week_num(column: ExprLike, /, *, start: WeekDay = "mon") -> Expr: ...
  @overload
  def week_num(column: Series, /, *, start: WeekDay = "mon") -> Series: ...
  @overload
- def week_num(column: IntoExprColumn, /, *, start: WeekDay = "mon") -> Expr | Series: ...
- def week_num(column: IntoExprColumn, /, *, start: WeekDay = "mon") -> Expr | Series:
+ def week_num(column: IntoExprColumn, /, *, start: WeekDay = "mon") -> ExprOrSeries: ...
+ def week_num(column: IntoExprColumn, /, *, start: WeekDay = "mon") -> ExprOrSeries:
      """Compute the week number of a date column."""
      column = ensure_expr_or_series(column)
      epoch = column.dt.epoch(time_unit="d").alias("epoch")
@@ -2388,87 +2858,129 @@ def week_num(column: IntoExprColumn, /, *, start: WeekDay = "mon") -> Expr | Ser
  ##
 
 
- def zoned_datetime(
+ def zoned_date_time_dtype(
      *, time_unit: TimeUnit = "us", time_zone: TimeZoneLike = UTC
  ) -> Datetime:
-     """Create a zoned datetime data type."""
-     return Datetime(time_unit=time_unit, time_zone=get_time_zone_name(time_zone))
+     """Create a zoned date-time data type."""
+     return Datetime(time_unit=time_unit, time_zone=to_time_zone_name(time_zone))
+
+
+ def zoned_date_time_period_dtype(
+     *,
+     time_unit: TimeUnit = "us",
+     time_zone: TimeZoneLike | tuple[TimeZoneLike, TimeZoneLike] = UTC,
+ ) -> Struct:
+     """Create a zoned date-time period data type."""
+     match time_zone:
+         case start, end:
+             return struct_dtype(
+                 start=zoned_date_time_dtype(time_unit=time_unit, time_zone=start),
+                 end=zoned_date_time_dtype(time_unit=time_unit, time_zone=end),
+             )
+         case _:
+             dtype = zoned_date_time_dtype(time_unit=time_unit, time_zone=time_zone)
+             return struct_dtype(start=dtype, end=dtype)
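A sketch of the new dtype helpers; a 2-tuple gives the `start` and `end` fields distinct zones (zone names illustrative):

    from utilities.polars import zoned_date_time_period_dtype

    same = zoned_date_time_period_dtype(time_zone="UTC")
    mixed = zoned_date_time_period_dtype(time_zone=("Asia/Tokyo", "US/Eastern"))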
 
 
  __all__ = [
+     "AppendRowError",
      "BooleanValueCountsError",
      "CheckPolarsDataFrameError",
      "ColumnsToDictError",
      "DataClassToDataFrameError",
+     "DatePeriodDType",
      "DatetimeHongKong",
      "DatetimeTokyo",
      "DatetimeUSCentral",
      "DatetimeUSEastern",
      "DatetimeUTC",
-     "DropNullStructSeriesError",
+     "ExprOrSeries",
      "FiniteEWMMeanError",
      "GetDataTypeOrSeriesTimeZoneError",
-     "GetSeriesNumberOfDecimalsError",
      "InsertAfterError",
      "InsertBeforeError",
      "InsertBetweenError",
      "IsNearEventError",
-     "IsNullStructSeriesError",
+     "OneColumnEmptyError",
+     "OneColumnError",
+     "OneColumnNonUniqueError",
+     "RoundToFloatError",
+     "SelectExactError",
      "SetFirstRowAsColumnsError",
-     "StructFromDataClassError",
+     "TimePeriodDType",
      "acf",
      "adjust_frequencies",
-     "append_dataclass",
+     "all_dataframe_columns",
+     "all_series",
+     "any_dataframe_columns",
+     "any_series",
+     "append_row",
      "are_frames_equal",
      "bernoulli",
      "boolean_value_counts",
-     "ceil_datetime",
      "check_polars_dataframe",
      "choice",
-     "collect_series",
      "columns_to_dict",
      "concat_series",
      "convert_time_zone",
      "cross",
      "dataclass_to_dataframe",
      "dataclass_to_schema",
+     "decreasing_horizontal",
      "deserialize_dataframe",
-     "drop_null_struct_series",
      "ensure_data_type",
      "ensure_expr_or_series",
      "ensure_expr_or_series_many",
+     "expr_to_series",
+     "false_like",
+     "filter_date",
+     "filter_time",
      "finite_ewm_mean",
-     "floor_datetime",
+     "first_true_horizontal",
      "get_data_type_or_series_time_zone",
      "get_expr_name",
      "get_frequency_spectrum",
-     "get_series_number_of_decimals",
+     "increasing_horizontal",
      "insert_after",
      "insert_before",
      "insert_between",
      "integers",
+     "is_close",
+     "is_false",
      "is_near_event",
-     "is_not_null_struct_series",
-     "is_null_struct_series",
+     "is_true",
      "join",
      "join_into_periods",
      "map_over_columns",
      "nan_sum_agg",
-     "nan_sum_cols",
-     "normal",
+     "nan_sum_horizontal",
+     "normal_pdf",
+     "normal_rv",
+     "number_of_decimals",
+     "offset_datetime",
+     "one_column",
      "order_of_magnitude",
+     "period_range",
      "read_dataframe",
      "read_series",
      "replace_time_zone",
+     "round_to_float",
+     "search_period",
+     "select_exact",
      "serialize_dataframe",
      "set_first_row_as_columns",
      "struct_dtype",
-     "struct_from_dataclass",
+     "to_false",
+     "to_not_false",
+     "to_not_true",
+     "to_true",
      "touch",
+     "true_like",
      "try_reify_expr",
      "uniform",
      "unique_element",
      "write_dataframe",
      "write_series",
-     "zoned_datetime",
+     "zoned_date_time_dtype",
+     "zoned_date_time_period_dtype",
  ]