onekit 1.2.0__tar.gz → 1.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: onekit
3
- Version: 1.2.0
3
+ Version: 1.4.0
4
4
  Summary: All-in-One Python Kit.
5
5
  Home-page: https://github.com/estripling/onekit
6
6
  License: BSD 3-Clause
7
7
  Keywords: onekit
8
8
  Author: Eugen Stripling
9
9
  Author-email: estripling042@gmail.com
10
- Requires-Python: >=3.8.1
10
+ Requires-Python: >=3.9
11
11
  Classifier: License :: Other/Proprietary License
12
12
  Classifier: Programming Language :: Python :: 3
13
13
  Classifier: Programming Language :: Python :: 3.9
@@ -15,7 +15,6 @@ Classifier: Programming Language :: Python :: 3.10
15
15
  Classifier: Programming Language :: Python :: 3.11
16
16
  Classifier: Programming Language :: Python :: 3.12
17
17
  Classifier: Programming Language :: Python :: 3 :: Only
18
- Classifier: Programming Language :: Python :: 3.8
19
18
  Requires-Dist: pytz (>=2024.1,<2025.0)
20
19
  Requires-Dist: toolz (>=0.12.0,<0.13.0)
21
20
  Project-URL: Documentation, https://onekit.readthedocs.io/en/stable/
@@ -46,7 +45,7 @@ All-in-One Python Kit:
46
45
 
47
46
  ## Installation
48
47
 
49
- `onekit` is available on [PyPI](https://pypi.org/project/onekit/) for Python 3.8+:
48
+ `onekit` is available on [PyPI](https://pypi.org/project/onekit/) for Python 3.9+:
50
49
 
51
50
  ```console
52
51
  pip install onekit
@@ -22,7 +22,7 @@ All-in-One Python Kit:
22
22
 
23
23
  ## Installation
24
24
 
25
- `onekit` is available on [PyPI](https://pypi.org/project/onekit/) for Python 3.8+:
25
+ `onekit` is available on [PyPI](https://pypi.org/project/onekit/) for Python 3.9+:
26
26
 
27
27
  ```console
28
28
  pip install onekit
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "onekit"
3
- version = "1.2.0"
3
+ version = "1.4.0"
4
4
  description = "All-in-One Python Kit."
5
5
  authors = ["Eugen Stripling <estripling042@gmail.com>"]
6
6
  license = "BSD 3-Clause"
@@ -10,7 +10,6 @@ documentation = "https://onekit.readthedocs.io/en/stable/"
10
10
  keywords = ["onekit"]
11
11
  classifiers = [
12
12
  "Programming Language :: Python :: 3 :: Only",
13
- "Programming Language :: Python :: 3.8",
14
13
  "Programming Language :: Python :: 3.9",
15
14
  "Programming Language :: Python :: 3.10",
16
15
  "Programming Language :: Python :: 3.11",
@@ -18,7 +17,7 @@ classifiers = [
18
17
  ]
19
18
 
20
19
  [tool.poetry.dependencies]
21
- python = ">=3.8.1"
20
+ python = ">=3.9"
22
21
  toolz = "^0.12.0"
23
22
  pytz = "^2024.1"
24
23
 
@@ -50,6 +49,9 @@ python-semantic-release = "^8.3.0"
50
49
  [tool.poetry.group.pandaskit.dependencies]
51
50
  pandas = ">=0.23.2"
52
51
 
52
+ [tool.poetry.group.sklearnkit.dependencies]
53
+ scikit-learn = ">=1.3"
54
+
53
55
  [tool.poetry.group.sparkkit.dependencies]
54
56
  pyspark = "3.1.1"
55
57
 
@@ -82,7 +82,8 @@ def collatz(n: int, /) -> Generator:
82
82
  n = n // 2 if iseven(n) else 3 * n + 1
83
83
 
84
84
 
85
- def digitscale(x: Union[int, float], /) -> float:
85
+ @toolz.curry
86
+ def digitscale(x: Union[int, float], /, *, kind: str = "log") -> Union[int, float]:
86
87
  """Scale :math:`x` such that its mapped integer part is its number of digits.
87
88
 
88
89
  Given a number :math:`x \\in \\mathbb{R}`, the following function
@@ -102,8 +103,24 @@ def digitscale(x: Union[int, float], /) -> float:
102
103
  -----
103
104
  - :math:`\\lfloor \\cdot \\rfloor`: floor function
104
105
  - :math:`\\left[ \\, \\cdot \\, \\right]`: truncation function
105
- - For any positive integer :math:`n`, the number of digits in :math:`n` is
106
- :math:`1 + \\lfloor \\log_{10} n \\rfloor`
106
+ - For any positive integer :math:`k`, the number of digits in :math:`k` is
107
+ :math:`1 + \\lfloor \\log_{10} k \\rfloor`
108
+ - If `kind="int"`, returns :math:`\\lfloor f(x) \\rfloor`
109
+ - If `kind="linear"`, linear interpolation is performed:
110
+
111
+ .. math::
112
+
113
+ f_{linear}(x) =
114
+ \\begin{cases}
115
+ \\frac{y_{0} (x_{1} - x) + y_{1} (x - x_{0})}{x_{1} - x_{0}}
116
+ & \\text{ if } |x| \\ge 0.1 \\\\[6pt]
117
+ 0 & \\text{ otherwise }
118
+ \\end{cases}
119
+
120
+ \\\\[6pt]
121
+
122
+ \\text{ with } n = \\lfloor f(x) \\rfloor, y_{0} = n, y_{1} = n + 1,
123
+ x_{0} = 10^{n - 1}, \\text{ and } x_{1} = 10^{n}
107
124
 
108
125
  See Also
109
126
  --------
@@ -121,8 +138,37 @@ def digitscale(x: Union[int, float], /) -> float:
121
138
 
122
139
  >>> list(map(mk.digitscale, [-0.5, -5, -50, -500]))
123
140
  [0.6989700043360187, 1.6989700043360187, 2.6989700043360187, 3.6989700043360187]
141
+
142
+ >>> # function is curried
143
+ >>> list(map(mk.digitscale(kind="int"), [-0.5, -5, -50, -500]))
144
+ [0, 1, 2, 3]
145
+
146
+ >>> list(map(mk.digitscale(kind="linear"), [0.1, 1, 10, 100, 1_000, 10_000]))
147
+ [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
148
+ >>> list(map(mk.digitscale(kind="linear"), [0.2, 2, 20, 200]))
149
+ [0.11111111111111112, 1.1111111111111112, 2.111111111111111, 3.111111111111111]
150
+ >>> list(map(mk.digitscale(kind="linear"), [-0.5, -5, -50, -500]))
151
+ [0.4444444444444445, 1.4444444444444444, 2.4444444444444446, 3.4444444444444446]
124
152
  """
125
- return 1 + math.log10(abs(x)) if abs(x) >= 0.1 else 0.0
153
+ valid_kind = ["log", "int", "linear"]
154
+
155
+ x = abs(x)
156
+ fx = 1 + math.log10(x) if x >= 0.1 else 0.0
157
+
158
+ if kind == "log":
159
+ return fx
160
+
161
+ elif kind == "int":
162
+ return math.floor(fx)
163
+
164
+ elif kind == "linear":
165
+ n = math.floor(fx)
166
+ y0, y1 = n, n + 1
167
+ x0, x1 = 10 ** (n - 1), 10**n
168
+ return (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0) if x >= 0.1 else 0.0
169
+
170
+ else:
171
+ raise ValueError(f"{kind=} - must be a valid value: {valid_kind}")
126
172
 
127
173
 
128
174
  def fibonacci() -> Generator:
@@ -52,7 +52,7 @@ def check_vector(x: ArrayLike, /, *, n_min: int = 1, n_max: int = np.inf) -> Vec
52
52
  return x
53
53
 
54
54
 
55
- def digitscale(x: ArrayLike, /) -> np.ndarray:
55
+ def digitscale(x: ArrayLike, /, *, kind: str = "log") -> np.ndarray:
56
56
  """NumPy version of digitscale.
57
57
 
58
58
  See Also
@@ -63,10 +63,17 @@ def digitscale(x: ArrayLike, /) -> np.ndarray:
63
63
  Examples
64
64
  --------
65
65
  >>> import onekit.numpykit as npk
66
- >>> npk.digitscale([0.1, 1, 10, 100, 1_000, 10_000, 100_000, 1_000_000])
67
- array([0., 1., 2., 3., 4., 5., 6., 7.])
66
+ >>> npk.digitscale([0.1, 1, 10, 100, 1_000, 10_000, 2_000_000])
67
+ array([0. , 1. , 2. , 3. , 4. , 5. , 7.30103])
68
+
69
+ >>> npk.digitscale([0.1, 1, 10, 100, 1_000, 10_000, 100_000, 2_000_000], kind="int")
70
+ array([0, 1, 2, 3, 4, 5, 6, 7])
71
+
72
+ >>> npk.digitscale([0.2, 2, 20], kind="linear")
73
+ array([0.11111111, 1.11111111, 2.11111111])
68
74
  """
69
- return np.vectorize(mk.digitscale, otypes=[float])(x)
75
+ otypes = [int] if kind == "int" else [float]
76
+ return np.vectorize(mk.digitscale(kind=kind), otypes=otypes)(x)
70
77
 
71
78
 
72
79
  def stderr(x: ArrayLike, /) -> float:
@@ -1,6 +1,5 @@
1
1
  import calendar
2
2
  import datetime as dt
3
- import distutils
4
3
  import functools
5
4
  import inspect
6
5
  import itertools
@@ -37,7 +36,6 @@ __all__ = (
37
36
  "coinflip",
38
37
  "concat_strings",
39
38
  "contrast_sets",
40
- "create_path",
41
39
  "date_ago",
42
40
  "date_ahead",
43
41
  "date_count_backward",
@@ -390,21 +388,6 @@ def contrast_sets(x: set, y: set, /, *, n: int = 3) -> dict:
390
388
  return output
391
389
 
392
390
 
393
- def create_path(*strings: str) -> str:
394
- """Create path by concatenating strings.
395
-
396
- Examples
397
- --------
398
- >>> import onekit.pythonkit as pk
399
- >>> pk.create_path("path", "to", "file")
400
- 'path/to/file'
401
-
402
- >>> pk.create_path(["hdfs://", "path", "to", "file"])
403
- 'hdfs://path/to/file'
404
- """
405
- return functools.reduce(os.path.join, flatten(strings))
406
-
407
-
408
391
  @toolz.curry
409
392
  def date_ago(d0: dt.date, /, n: int) -> dt.date:
410
393
  """Compute date that is :math:`n \\in \\mathbb{N}_{0}` days ago.
@@ -683,13 +666,13 @@ def highlight_string_differences(lft_str: str, rgt_str: str, /) -> str:
683
666
  Examples
684
667
  --------
685
668
  >>> import onekit.pythonkit as pk
686
- >>> print(pk.highlight_string_differences("hello", "hall"))
669
+ >>> print(pk.highlight_string_differences("hello", "hall")) # doctest: +SKIP
687
670
  hello
688
671
  | |
689
672
  hall
690
673
 
691
674
  >>> # no differences when there is no '|' character
692
- >>> print(pk.highlight_string_differences("hello", "hello"))
675
+ >>> print(pk.highlight_string_differences("hello", "hello")) # doctest: +SKIP
693
676
  hello
694
677
  <BLANKLINE>
695
678
  hello
@@ -699,7 +682,7 @@ def highlight_string_differences(lft_str: str, rgt_str: str, /) -> str:
699
682
  lft_str,
700
683
  concat_strings(
701
684
  "",
702
- (
685
+ *(
703
686
  " " if x == y else "|"
704
687
  for x, y in itertools.zip_longest(lft_str, rgt_str, fillvalue="")
705
688
  ),
@@ -936,11 +919,31 @@ def prompt_yes_no(question: str, /, *, default: Optional[str] = None) -> bool:
936
919
 
937
920
  answer = input(f"{question} {prompt} ").lower()
938
921
 
922
+ def strtobool(value: str) -> bool:
923
+ """Convert a string representation of truth to ``True`` or ``False``.
924
+
925
+ True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
926
+ are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
927
+ 'value' is anything else.
928
+
929
+ Notes
930
+ -----
931
+ - Shamelessly copied and modified from: distutils.util.strtobool
932
+ - distutils is not available with Python>=3.12
933
+ """
934
+ value = value.lower()
935
+ if value in ("y", "yes", "t", "true", "on", "1"):
936
+ return True
937
+ elif value in ("n", "no", "f", "false", "off", "0"):
938
+ return False
939
+ else:
940
+ raise ValueError("invalid truth value {!r}".format(value))
941
+
939
942
  while True:
940
943
  try:
941
944
  if answer == "" and default in ["yes", "no"]:
942
- return bool(distutils.util.strtobool(default))
943
- return bool(distutils.util.strtobool(answer))
945
+ return bool(strtobool(default))
946
+ return bool(strtobool(answer))
944
947
 
945
948
  except ValueError:
946
949
  response_text = "{} Please respond with 'yes' [{}] or 'no' [{}] ".format(
@@ -0,0 +1,153 @@
1
+ from typing import (
2
+ Optional,
3
+ Union,
4
+ )
5
+
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+ import pandas as pd
9
+ from pandas import DataFrame as PandasDF
10
+ from sklearn import metrics
11
+ from sklearn.utils import validation
12
+
13
+ __all__ = (
14
+ "precision_given_recall_score",
15
+ "threshold_summary",
16
+ )
17
+
18
+
19
+ ArrayLike = npt.ArrayLike
20
+
21
+
22
+ def precision_given_recall_score(
23
+ y_true: ArrayLike,
24
+ y_score: ArrayLike,
25
+ *,
26
+ min_recall: float,
27
+ pos_label: Optional[Union[int, str]] = None,
28
+ ) -> float:
29
+ """Compute precision given a desired minimum recall level.
30
+
31
+ Examples
32
+ --------
33
+ >>> import onekit.sklearnkit as slk
34
+ >>> y_true = [0, 1, 1, 1, 0, 0, 0, 1]
35
+ >>> y_score = [0.1, 0.4, 0.35, 0.8, 0.5, 0.2, 0.75, 0.5]
36
+ >>> slk.precision_given_recall_score(y_true, y_score, min_recall=0.7)
37
+ 0.6
38
+ """
39
+ if not (0 < min_recall <= 1):
40
+ raise ValueError(f"{min_recall=} - must be a float in (0, 1]")
41
+
42
+ df = (
43
+ threshold_summary(y_true, y_score, pos_label=pos_label)
44
+ .filter(items=["precision", "recall"])
45
+ .query(f"recall >= {min_recall}")
46
+ )
47
+
48
+ min_empirical_recall = df["recall"].min()
49
+
50
+ return float(
51
+ 0
52
+ if df.empty
53
+ else df.query(f"recall == {min_empirical_recall}")["precision"].max()
54
+ )
55
+
56
+
57
+ def threshold_summary(
58
+ y_true: ArrayLike,
59
+ y_score: ArrayLike,
60
+ *,
61
+ pos_label: Optional[Union[int, str]] = None,
62
+ ) -> PandasDF:
63
+ """Threshold summary.
64
+
65
+ Notes
66
+ -----
67
+ - Support for binary classification only
68
+ - Assumption: classifier returns scores
69
+ - First values correspond to the edge case where everything is predicted positive
70
+ - Last values correspond to the edge case where everything is predicted negative
71
+
72
+ Examples
73
+ --------
74
+ >>> import onekit.sklearnkit as slk
75
+ >>> y_true = [0, 1, 1, 1, 0, 0, 0, 1]
76
+ >>> y_score = [0.1, 0.4, 0.35, 0.8, 0.5, 0.2, 0.75, 0.5]
77
+ >>> with pd.option_context("display.float_format", "{:.2f}".format):
78
+ ... slk.threshold_summary(y_true, y_score).T
79
+ 0 1 2 3 4 5 6 7
80
+ threshold 0.10 0.20 0.35 0.40 0.50 0.75 0.80 inf
81
+ predicted_positive 8.00 7.00 6.00 5.00 4.00 2.00 1.00 0.00
82
+ true_positive 4.00 4.00 4.00 3.00 2.00 1.00 1.00 0.00
83
+ false_positive 4.00 3.00 2.00 2.00 2.00 1.00 0.00 0.00
84
+ false_negative 0.00 0.00 0.00 1.00 2.00 3.00 3.00 4.00
85
+ true_negative 0.00 1.00 2.00 2.00 2.00 3.00 4.00 4.00
86
+ precision 0.50 0.57 0.67 0.60 0.50 0.50 1.00 1.00
87
+ recall 1.00 1.00 1.00 0.75 0.50 0.25 0.25 0.00
88
+ f1 0.67 0.73 0.80 0.67 0.50 0.33 0.40 0.00
89
+ accuracy 0.50 0.62 0.75 0.62 0.50 0.50 0.62 0.50
90
+ balanced_accuracy 0.50 0.62 0.75 0.62 0.50 0.50 0.62 0.50
91
+ matthews_corrcoef NaN 0.38 0.58 0.26 0.00 0.00 0.38 NaN
92
+ """
93
+ y = validation.column_or_1d(y_true)
94
+ s = validation.column_or_1d(y_score)
95
+ validation.check_consistent_length(y, s)
96
+ validation.assert_all_finite(y)
97
+ validation.assert_all_finite(s)
98
+ pos_label = validation._check_pos_label_consistency(pos_label, y)
99
+
100
+ precision, recall, thresholds = metrics.precision_recall_curve(
101
+ y_true=y,
102
+ y_score=s,
103
+ pos_label=pos_label,
104
+ sample_weight=None,
105
+ drop_intermediate=False,
106
+ )
107
+
108
+ is_true_pos = y == pos_label
109
+ is_true_neg = y != pos_label
110
+
111
+ def is_pred_pos(t: float) -> np.ndarray:
112
+ return s >= t
113
+
114
+ def is_pred_neg(t: float) -> np.ndarray:
115
+ return s < t
116
+
117
+ return (
118
+ pd.DataFrame(np.append(thresholds, np.inf), columns=["t"])
119
+ .assign(
120
+ pp=lambda df: df.t.map(lambda t: is_pred_pos(t).sum()),
121
+ tp=lambda df: df.t.map(lambda t: (is_pred_pos(t) & is_true_pos).sum()),
122
+ fp=lambda df: df.t.map(lambda t: (is_pred_pos(t) & is_true_neg).sum()),
123
+ fn=lambda df: df.t.map(lambda t: (is_pred_neg(t) & is_true_pos).sum()),
124
+ tn=lambda df: df.t.map(lambda t: (is_pred_neg(t) & is_true_neg).sum()),
125
+ precision=precision,
126
+ recall=recall,
127
+ f1=2 * (precision * recall) / (precision + recall),
128
+ acc=lambda df: (df.tp + df.tn) / (df.tp + df.tn + df.fp + df.fn),
129
+ bacc=lambda df: 0.5 * (df.tp / (df.tp + df.fn) + df.tn / (df.tn + df.fp)),
130
+ mcc=lambda df: np.true_divide(
131
+ (df.tp * df.tn - df.fp * df.fn),
132
+ np.sqrt(
133
+ (df.tp + df.fp)
134
+ * (df.tp + df.fn)
135
+ * (df.tn + df.fp)
136
+ * (df.tn + df.fn)
137
+ ),
138
+ ),
139
+ )
140
+ .rename(
141
+ columns={
142
+ "t": "threshold",
143
+ "pp": "predicted_positive",
144
+ "tp": "true_positive",
145
+ "fp": "false_positive",
146
+ "fn": "false_negative",
147
+ "tn": "true_negative",
148
+ "acc": "accuracy",
149
+ "bacc": "balanced_accuracy",
150
+ "mcc": "matthews_corrcoef",
151
+ },
152
+ )
153
+ )
@@ -34,6 +34,7 @@ __all__ = (
34
34
  "assert_row_equal",
35
35
  "assert_schema_equal",
36
36
  "bool_to_int",
37
+ "bool_to_str",
37
38
  "check_column_present",
38
39
  "count_nulls",
39
40
  "cvf",
@@ -46,6 +47,7 @@ __all__ = (
46
47
  "is_schema_equal",
47
48
  "join",
48
49
  "peek",
50
+ "select_col_types",
49
51
  "str_to_col",
50
52
  "union",
51
53
  "with_date_diff_ago",
@@ -475,12 +477,53 @@ def bool_to_int(df: SparkDF, /, *, subset=None) -> SparkDF:
475
477
  <BLANKLINE>
476
478
  """
477
479
  cols = subset or df.columns
478
- bool_cols = [c for c in cols if isinstance(df.schema[c].dataType, T.BooleanType)]
480
+ bool_cols = [c for c in select_col_types(df, T.BooleanType) if c in cols]
479
481
  for bool_col in bool_cols:
480
482
  df = df.withColumn(bool_col, F.col(bool_col).cast(T.IntegerType()))
481
483
  return df
482
484
 
483
485
 
486
+ @toolz.curry
487
+ def bool_to_str(df: SparkDF, /, *, subset=None) -> SparkDF:
488
+ """Cast values of Boolean columns to string values.
489
+
490
+ Examples
491
+ --------
492
+ >>> from pyspark.sql import SparkSession
493
+ >>> import onekit.sparkkit as sk
494
+ >>> spark = SparkSession.builder.getOrCreate()
495
+ >>> df = spark.createDataFrame(
496
+ ... [
497
+ ... dict(x=True, y=False, z=None),
498
+ ... dict(x=False, y=None, z=True),
499
+ ... dict(x=True, y=None, z=None),
500
+ ... ]
501
+ ... )
502
+ >>> sk.bool_to_str(df).show()
503
+ +-----+-----+----+
504
+ | x| y| z|
505
+ +-----+-----+----+
506
+ | true|false|null|
507
+ |false| null|true|
508
+ | true| null|null|
509
+ +-----+-----+----+
510
+ <BLANKLINE>
511
+
512
+ >>> # function is curried
513
+ >>> df.transform(sk.bool_to_str(subset=["y", "z"])).printSchema()
514
+ root
515
+ |-- x: boolean (nullable = true)
516
+ |-- y: string (nullable = true)
517
+ |-- z: string (nullable = true)
518
+ <BLANKLINE>
519
+ """
520
+ cols = subset or df.columns
521
+ bool_cols = [c for c in select_col_types(df, T.BooleanType) if c in cols]
522
+ for bool_col in bool_cols:
523
+ df = df.withColumn(bool_col, F.col(bool_col).cast(T.StringType()))
524
+ return df
525
+
526
+
484
527
  def check_column_present(*cols: str) -> SparkDFTransformFunc:
485
528
  """Check if columns are present in dataframe.
486
529
 
@@ -1004,6 +1047,42 @@ def peek(
1004
1047
  return inner
1005
1048
 
1006
1049
 
1050
+ def select_col_types(df: SparkDF, /, *col_types: T.DataType) -> List[str]:
1051
+ """Identify columns of specified data type.
1052
+
1053
+ Examples
1054
+ --------
1055
+ >>> from pyspark.sql import SparkSession
1056
+ >>> from pyspark.sql import types as T
1057
+ >>> import onekit.sparkkit as sk
1058
+ >>> spark = SparkSession.builder.getOrCreate()
1059
+ >>> df = spark.createDataFrame(
1060
+ ... [dict(bool=True, double=1.0, float=2.0, int=3, long=4, str="string")],
1061
+ ... schema=T.StructType(
1062
+ ... [
1063
+ ... T.StructField("bool", T.BooleanType(), nullable=True),
1064
+ ... T.StructField("double", T.DoubleType(), nullable=True),
1065
+ ... T.StructField("float", T.FloatType(), nullable=True),
1066
+ ... T.StructField("int", T.IntegerType(), nullable=True),
1067
+ ... T.StructField("long", T.LongType(), nullable=True),
1068
+ ... T.StructField("str", T.StringType(), nullable=True),
1069
+ ... ]
1070
+ ... ),
1071
+ ... )
1072
+ >>> sk.select_col_types(df, T.BooleanType)
1073
+ ['bool']
1074
+
1075
+ >>> sk.select_col_types(df, T.IntegerType, T.LongType)
1076
+ ['int', 'long']
1077
+ """
1078
+ valid_types = {v.typeName() for k, v in T.__dict__.items() if k.endswith("Type")}
1079
+ col_types = tuple(pk.flatten(col_types))
1080
+ for col_type in col_types:
1081
+ if not hasattr(col_type, "typeName") or col_type.typeName() not in valid_types:
1082
+ raise TypeError(f"{col_type=} - must be a valid data type: {valid_types}")
1083
+ return [c for c in df.columns if isinstance(df.schema[c].dataType, col_types)]
1084
+
1085
+
1007
1086
  def str_to_col(x: str, /) -> SparkCol:
1008
1087
  """Cast string ``x`` to Spark column else return ``x``.
1009
1088
 
@@ -1145,7 +1224,13 @@ def with_date_diff_ahead(
1145
1224
  return inner
1146
1225
 
1147
1226
 
1148
- def with_digitscale(num_col: str, new_col: str) -> SparkDFTransformFunc:
1227
+ def with_digitscale(
1228
+ num_col: str,
1229
+ new_col: str,
1230
+ /,
1231
+ *,
1232
+ kind: str = "log",
1233
+ ) -> SparkDFTransformFunc:
1149
1234
  """PySpark version of digitscale.
1150
1235
 
1151
1236
  See Also
@@ -1168,33 +1253,95 @@ def with_digitscale(num_col: str, new_col: str) -> SparkDFTransformFunc:
1168
1253
  ... dict(x=10_000.0),
1169
1254
  ... dict(x=100_000.0),
1170
1255
  ... dict(x=1_000_000.0),
1256
+ ... dict(x=2_000_000.0),
1171
1257
  ... dict(x=None),
1172
1258
  ... ],
1173
1259
  ... )
1174
1260
  >>> df.transform(sk.with_digitscale("x", "fx")).show()
1261
+ +---------+-----------------+
1262
+ | x| fx|
1263
+ +---------+-----------------+
1264
+ | 0.1| 0.0|
1265
+ | 1.0| 1.0|
1266
+ | 10.0| 2.0|
1267
+ | 100.0| 3.0|
1268
+ | 1000.0| 4.0|
1269
+ | 10000.0| 5.0|
1270
+ | 100000.0| 6.0|
1271
+ |1000000.0| 7.0|
1272
+ |2000000.0|7.301029995663981|
1273
+ | null| null|
1274
+ +---------+-----------------+
1275
+ <BLANKLINE>
1276
+
1277
+ >>> df.transform(sk.with_digitscale("x", "fx", kind="int")).show()
1175
1278
  +---------+----+
1176
1279
  | x| fx|
1177
1280
  +---------+----+
1178
- | 0.1| 0.0|
1179
- | 1.0| 1.0|
1180
- | 10.0| 2.0|
1181
- | 100.0| 3.0|
1182
- | 1000.0| 4.0|
1183
- | 10000.0| 5.0|
1184
- | 100000.0| 6.0|
1185
- |1000000.0| 7.0|
1281
+ | 0.1| 0|
1282
+ | 1.0| 1|
1283
+ | 10.0| 2|
1284
+ | 100.0| 3|
1285
+ | 1000.0| 4|
1286
+ | 10000.0| 5|
1287
+ | 100000.0| 6|
1288
+ |1000000.0| 7|
1289
+ |2000000.0| 7|
1186
1290
  | null|null|
1187
1291
  +---------+----+
1188
1292
  <BLANKLINE>
1293
+
1294
+ >>> df.transform(sk.with_digitscale("x", "fx", kind="linear")).show()
1295
+ +---------+-----------------+
1296
+ | x| fx|
1297
+ +---------+-----------------+
1298
+ | 0.1| 0.0|
1299
+ | 1.0| 1.0|
1300
+ | 10.0| 2.0|
1301
+ | 100.0| 3.0|
1302
+ | 1000.0| 4.0|
1303
+ | 10000.0| 5.0|
1304
+ | 100000.0| 6.0|
1305
+ |1000000.0| 7.0|
1306
+ |2000000.0|7.111111111111111|
1307
+ | null| null|
1308
+ +---------+-----------------+
1309
+ <BLANKLINE>
1189
1310
  """
1311
+ valid_kind = ["log", "int", "linear"]
1312
+ if kind not in valid_kind:
1313
+ raise ValueError(f"{kind=} - must be a valid value: {valid_kind}")
1190
1314
 
1191
1315
  def inner(df: SparkDF, /) -> SparkDF:
1192
1316
  x = F.abs(num_col)
1193
- return df.withColumn(
1317
+ df = df.withColumn(
1194
1318
  new_col,
1195
1319
  F.when(x.isNull(), None).when(x >= 0.1, 1 + F.log10(x)).otherwise(0.0),
1196
1320
  )
1197
1321
 
1322
+ if kind == "int":
1323
+ df = df.withColumn(new_col, F.floor(new_col).cast(T.IntegerType()))
1324
+
1325
+ if kind == "linear":
1326
+ n = "_n_"
1327
+ y0 = F.col(n)
1328
+ y1 = F.col(n) + 1
1329
+ x0 = 10 ** (F.col(n) - 1)
1330
+ x1 = 10 ** F.col(n)
1331
+
1332
+ df = (
1333
+ df.withColumn(n, F.floor(new_col).cast(T.IntegerType()))
1334
+ .withColumn(
1335
+ new_col,
1336
+ F.when(x.isNull(), None)
1337
+ .when(x >= 0.1, (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0))
1338
+ .otherwise(0.0),
1339
+ )
1340
+ .drop(n)
1341
+ )
1342
+
1343
+ return df
1344
+
1198
1345
  return inner
1199
1346
 
1200
1347
 
File without changes
File without changes
File without changes
File without changes
File without changes