binscatter 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
binscatter-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Matthias Kaeding
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: binscatter
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Cross-backend binscatter plots.
|
|
5
|
+
Keywords: binscatter,visualization,econometrics
|
|
6
|
+
Author: Matthias Kaeding
|
|
7
|
+
License: MIT License
|
|
8
|
+
|
|
9
|
+
Copyright (c) 2025 Matthias Kaeding
|
|
10
|
+
|
|
11
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
12
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
13
|
+
in the Software without restriction, including without limitation the rights
|
|
14
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
15
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
16
|
+
furnished to do so, subject to the following conditions:
|
|
17
|
+
|
|
18
|
+
The above copyright notice and this permission notice shall be included in all
|
|
19
|
+
copies or substantial portions of the Software.
|
|
20
|
+
|
|
21
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
22
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
23
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
24
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
25
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
26
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
27
|
+
SOFTWARE.
|
|
28
|
+
Classifier: Development Status :: 3 - Alpha
|
|
29
|
+
Classifier: Intended Audience :: Science/Research
|
|
30
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
31
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
32
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
33
|
+
Classifier: Topic :: Scientific/Engineering :: Visualization
|
|
34
|
+
Requires-Dist: narwhals>=2.1.2
|
|
35
|
+
Requires-Dist: numpy>=2.3.2
|
|
36
|
+
Requires-Dist: plotly>=6.3.0
|
|
37
|
+
Requires-Python: >=3.11
|
|
38
|
+
Project-URL: homepage, https://github.com/matthiaskaeding/binscatter
|
|
39
|
+
Project-URL: issues, https://github.com/matthiaskaeding/binscatter/issues
|
|
40
|
+
Project-URL: repository, https://github.com/matthiaskaeding/binscatter
|
|
41
|
+
Description-Content-Type: text/markdown
|
|
42
|
+
|
|
43
|
+
# Dataframe-agnostic binscatter plots
|
|
44
|
+
|
|
45
|
+
This package implements binscatter plots following:
|
|
46
|
+
|
|
47
|
+
> Cattaneo, Crump, Farrell and Feng (2024)
|
|
48
|
+
> "On Binscatter"
|
|
49
|
+
> American Economic Review, 114(5), pp. 1488-1514
|
|
50
|
+
> [DOI: 10.1257/aer.20221576](https://doi.org/10.1257/aer.20221576)
|
|
51
|
+
|
|
52
|
+
- Uses `narwhals` as its dataframe layer.
|
|
53
|
+
- Currently supports: pandas, Polars, DuckDB, Dask, and PySpark
|
|
54
|
+
- All other Narwhals backends fall back to a generic quantile handler if a native path is unavailable
|
|
55
|
+
- Lightweight: few dependencies
|
|
56
|
+
- Uses `plotly` as the graphics backend because (1) it's great and (2) it also uses `narwhals`, minimizing dependencies
|
|
57
|
+
- Pythonic alternative to the excellent **binsreg** package
|
|
58
|
+
|
|
59
|
+
---
|
|
60
|
+
|
|
61
|
+
## Example
|
|
62
|
+
|
|
63
|
+
We made this noisy scatterplot:
|
|
64
|
+
|
|
65
|
+

|
|
66
|
+
|
|
67
|
+
This is how we make a nice binscatter plot, controlling for a set of features:
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
from binscatter import binscatter
|
|
71
|
+
|
|
72
|
+
p_binscatter_controls = binscatter(
|
|
73
|
+
df,
|
|
74
|
+
"mtr90_lag3",
|
|
75
|
+
"lnpat",
|
|
76
|
+
[
|
|
77
|
+
"top_corp_lag3",
|
|
78
|
+
"real_gdp_pc",
|
|
79
|
+
"population_density",
|
|
80
|
+
"rd_credit_lag3",
|
|
81
|
+
"statenum",
|
|
82
|
+
"year",
|
|
83
|
+
],
|
|
84
|
+
num_bins=35,
|
|
85
|
+
)
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+

|
|
89
|
+
|
|
90
|
+
The data originates from:
|
|
91
|
+
|
|
92
|
+
Akcigit, Ufuk; Grigsby, John; Nicholas, Tom; Stantcheva, Stefanie, 2021, "Replication Data for: 'Taxation and Innovation in the 20th Century'", https://doi.org/10.7910/DVN/SR410I, Harvard Dataverse, V1
|
|
93
|
+
|
|
94
|
+
## Tests
|
|
95
|
+
|
|
96
|
+
- Run the full backend matrix, including PySpark: `just test`
|
|
97
|
+
- Use the faster run without PySpark: `just test-fast`
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Dataframe-agnostic binscatter plots
|
|
2
|
+
|
|
3
|
+
This package implements binscatter plots following:
|
|
4
|
+
|
|
5
|
+
> Cattaneo, Crump, Farrell and Feng (2024)
|
|
6
|
+
> "On Binscatter"
|
|
7
|
+
> American Economic Review, 114(5), pp. 1488-1514
|
|
8
|
+
> [DOI: 10.1257/aer.20221576](https://doi.org/10.1257/aer.20221576)
|
|
9
|
+
|
|
10
|
+
- Uses `narwhals` as its dataframe layer.
|
|
11
|
+
- Currently supports: pandas, Polars, DuckDB, Dask, and PySpark
|
|
12
|
+
- All other Narwhals backends fall back to a generic quantile handler if a native path is unavailable
|
|
13
|
+
- Lightweight: few dependencies
|
|
14
|
+
- Uses `plotly` as the graphics backend because (1) it's great and (2) it also uses `narwhals`, minimizing dependencies
|
|
15
|
+
- Pythonic alternative to the excellent **binsreg** package
|
|
16
|
+
|
|
17
|
+
---
|
|
18
|
+
|
|
19
|
+
## Example
|
|
20
|
+
|
|
21
|
+
We made this noisy scatterplot:
|
|
22
|
+
|
|
23
|
+

|
|
24
|
+
|
|
25
|
+
This is how we make a nice binscatter plot, controlling for a set of features:
|
|
26
|
+
|
|
27
|
+
```python
|
|
28
|
+
from binscatter import binscatter
|
|
29
|
+
|
|
30
|
+
p_binscatter_controls = binscatter(
|
|
31
|
+
df,
|
|
32
|
+
"mtr90_lag3",
|
|
33
|
+
"lnpat",
|
|
34
|
+
[
|
|
35
|
+
"top_corp_lag3",
|
|
36
|
+
"real_gdp_pc",
|
|
37
|
+
"population_density",
|
|
38
|
+
"rd_credit_lag3",
|
|
39
|
+
"statenum",
|
|
40
|
+
"year",
|
|
41
|
+
],
|
|
42
|
+
num_bins=35,
|
|
43
|
+
)
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+

|
|
47
|
+
|
|
48
|
+
The data originates from:
|
|
49
|
+
|
|
50
|
+
Akcigit, Ufuk; Grigsby, John; Nicholas, Tom; Stantcheva, Stefanie, 2021, "Replication Data for: 'Taxation and Innovation in the 20th Century'", https://doi.org/10.7910/DVN/SR410I, Harvard Dataverse, V1
|
|
51
|
+
|
|
52
|
+
## Tests
|
|
53
|
+
|
|
54
|
+
- Run the full backend matrix, including PySpark: `just test`
|
|
55
|
+
- Use the faster run without PySpark: `just test-fast`
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["uv_build>=0.9.18,<0.10.0"]
|
|
3
|
+
build-backend = "uv_build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "binscatter"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Cross-backend binscatter plots."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
dependencies = ["narwhals>=2.1.2", "numpy>=2.3.2", "plotly>=6.3.0"]
|
|
12
|
+
authors = [{ name = "Matthias Kaeding" }]
|
|
13
|
+
license = { file = "LICENSE" }
|
|
14
|
+
keywords = ["binscatter", "visualization", "econometrics"]
|
|
15
|
+
classifiers = [
|
|
16
|
+
"Development Status :: 3 - Alpha",
|
|
17
|
+
"Intended Audience :: Science/Research",
|
|
18
|
+
"License :: OSI Approved :: MIT License",
|
|
19
|
+
"Programming Language :: Python :: 3 :: Only",
|
|
20
|
+
"Programming Language :: Python :: 3.11",
|
|
21
|
+
"Topic :: Scientific/Engineering :: Visualization",
|
|
22
|
+
]
|
|
23
|
+
|
|
24
|
+
[project.urls]
|
|
25
|
+
homepage = "https://github.com/matthiaskaeding/binscatter"
|
|
26
|
+
repository = "https://github.com/matthiaskaeding/binscatter"
|
|
27
|
+
issues = "https://github.com/matthiaskaeding/binscatter/issues"
|
|
28
|
+
|
|
29
|
+
[dependency-groups]
|
|
30
|
+
dev = [
|
|
31
|
+
"kaleido>=0.2.1",
|
|
32
|
+
"ipykernel>=6.29.5",
|
|
33
|
+
"pytest>=8.3.5",
|
|
34
|
+
"polars>=1.22.0",
|
|
35
|
+
"pandas>=1.4.4",
|
|
36
|
+
"duckdb>=1.3.2",
|
|
37
|
+
"pyarrow>=21.0.0",
|
|
38
|
+
"ibis>=3.3.0",
|
|
39
|
+
"sqlframe>=3.39.2",
|
|
40
|
+
"pyspark>=4.0.0",
|
|
41
|
+
"dask>=2025.7.0",
|
|
42
|
+
"nbformat>=5.10.4",
|
|
43
|
+
"ty>=0.0.4",
|
|
44
|
+
"statsmodels>=0.14.6",
|
|
45
|
+
]
|
|
46
|
+
|
|
47
|
+
[tool.ty.rules]
|
|
48
|
+
unresolved-import = "ignore"
|
|
Binary file
|
|
@@ -0,0 +1,736 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import math
|
|
3
|
+
import operator
|
|
4
|
+
import uuid
|
|
5
|
+
import warnings
|
|
6
|
+
from functools import reduce
|
|
7
|
+
from typing import (
|
|
8
|
+
Any,
|
|
9
|
+
Callable,
|
|
10
|
+
Iterable,
|
|
11
|
+
List,
|
|
12
|
+
Literal,
|
|
13
|
+
NamedTuple,
|
|
14
|
+
Tuple,
|
|
15
|
+
cast,
|
|
16
|
+
overload,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
import narwhals as nw
|
|
20
|
+
import narwhals.selectors as ncs
|
|
21
|
+
import numpy as np
|
|
22
|
+
import plotly.express as px
|
|
23
|
+
from narwhals import Implementation
|
|
24
|
+
from narwhals.typing import IntoDataFrame
|
|
25
|
+
from plotly import graph_objects as go
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Typing-only overloads: the default ``return_type="plotly"`` yields a plotly
# figure, ``return_type="native"`` a dataframe of the input's backend.
# Extra keyword arguments are forwarded to ``plotly.express.scatter``.
@overload
def binscatter(
    df: IntoDataFrame,
    x: str,
    y: str,
    controls: Iterable[str] | str | None = None,
    num_bins: int = 20,
    return_type: Literal["plotly"] = "plotly",
    **kwargs: Any,
) -> go.Figure: ...


@overload
def binscatter(
    df: IntoDataFrame,
    x: str,
    y: str,
    controls: Iterable[str] | str | None = None,
    num_bins: int = 20,
    return_type: Literal["native"] = "native",
    **kwargs: Any,
) -> object: ...
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def binscatter(
|
|
56
|
+
df: IntoDataFrame,
|
|
57
|
+
x: str,
|
|
58
|
+
y: str,
|
|
59
|
+
controls: Iterable[str] | str | None = None,
|
|
60
|
+
num_bins: int = 20,
|
|
61
|
+
return_type: Literal["plotly", "native"] = "plotly",
|
|
62
|
+
**kwargs_binscatter,
|
|
63
|
+
) -> object:
|
|
64
|
+
"""Creates a binned scatter plot by grouping x values into quantile bins and plotting mean y values.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
df (IntoDataFrame): Input dataframe - must be a type supported by narwhals
|
|
68
|
+
x (str): Name of x column
|
|
69
|
+
y (str): Name y column
|
|
70
|
+
controls (Iterable[str]): Names of control variables (numeric). These are partialled out
|
|
71
|
+
following Cattaneo et al. (2024).
|
|
72
|
+
num_bins (int, optional): Number of bins to use. Defaults to 20
|
|
73
|
+
return_type (str): Return type. Default "plotly" gives a plotly plot.
|
|
74
|
+
kwargs (dict, optional): Additional arguments used in plotly.express.scatter to make the binscatter plot.
|
|
75
|
+
Otherwise "native" returns a dataframe that is natural match to input dataframe.
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
plotly plot (default) if return_type == "plotly". Otherwise native dataframe, depending on input.
|
|
80
|
+
"""
|
|
81
|
+
if return_type not in ("plotly", "native"):
|
|
82
|
+
msg = f"Invalid return_type: {return_type}"
|
|
83
|
+
raise ValueError(msg)
|
|
84
|
+
# Prepare dataframe: sort, remove non numerics and add bins
|
|
85
|
+
df_prepped, profile = prep(df, x, y, controls, num_bins)
|
|
86
|
+
|
|
87
|
+
# Currently there are 2 cases:
|
|
88
|
+
# (1) no controls: the easy one, just compute the means by bin
|
|
89
|
+
# (2) controls: here we need to compute regression coefficients
|
|
90
|
+
# and partial out the effect of the controls
|
|
91
|
+
# (see section 2.2 in Cattaneo, Crump, Farrell and Feng (2024))
|
|
92
|
+
if not controls:
|
|
93
|
+
df_plotting: nw.LazyFrame = (
|
|
94
|
+
df_prepped.group_by(profile.bin_name)
|
|
95
|
+
.agg(profile.x_col.mean(), profile.y_col.mean())
|
|
96
|
+
.with_columns(nw.col(profile.bin_name).cast(nw.Int32))
|
|
97
|
+
).lazy()
|
|
98
|
+
else:
|
|
99
|
+
df_plotting, _ = partial_out_controls(df_prepped, profile)
|
|
100
|
+
|
|
101
|
+
match return_type:
|
|
102
|
+
case "plotly":
|
|
103
|
+
return make_plot_plotly(
|
|
104
|
+
df_plotting, profile, kwargs_binscatter=kwargs_binscatter
|
|
105
|
+
)
|
|
106
|
+
case "native":
|
|
107
|
+
df_out_nw = df_plotting.rename({profile.bin_name: "bin"}).sort("bin")
|
|
108
|
+
logger.debug(
|
|
109
|
+
"Type of df_out_nw: %s, implementation: %s",
|
|
110
|
+
type(df_out_nw),
|
|
111
|
+
df_out_nw.implementation,
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
if profile.implementation in (
|
|
115
|
+
Implementation.PYSPARK,
|
|
116
|
+
Implementation.DUCKDB,
|
|
117
|
+
Implementation.DASK,
|
|
118
|
+
):
|
|
119
|
+
return df_out_nw.to_native()
|
|
120
|
+
else:
|
|
121
|
+
return df_out_nw.collect().to_native()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class Profile(NamedTuple):
    """Immutable bundle of metadata derived from the input dataframe.

    Built by ``prep`` and handed to the binning, partialling-out and plotting
    helpers so they share one view of column names and settings.
    """

    # Names of the binned regressor and of the outcome column.
    x_name: str
    y_name: str
    # Control column names, normalised to a tuple.
    controls: tuple[str, ...]
    # Requested number of quantile bins.
    num_bins: int
    # Name of the temporary column holding each row's bin index.
    bin_name: str
    # (min, max) of x after bad values were removed; used as the plot's x range.
    x_bounds: tuple[float, float]
    # Random suffix used when building ``bin_name``.
    distinct_suffix: str
    # Whether the caller passed a lazy frame.
    is_lazy_input: bool
    # Narwhals backend of the prepared frame.
    implementation: Implementation
    # Column names classified as numeric / non-numeric.
    numeric_columns: tuple[str, ...]
    categorical_columns: tuple[str, ...]

    @property
    def x_col(self) -> nw.Expr:
        """Narwhals expression referring to the x column."""
        column_name = self.x_name
        return nw.col(column_name)

    @property
    def y_col(self) -> nw.Expr:
        """Narwhals expression referring to the y column."""
        column_name = self.y_name
        return nw.col(column_name)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def prep(
    df_in: IntoDataFrame,
    x_name: str,
    y_name: str,
    controls: Iterable[str] | str | None = None,
    num_bins: int = 20,
) -> Tuple[nw.LazyFrame, Profile]:
    """Validates the input, removes bad rows, assigns quantile bins and derives the profile.

    Args:
        df_in: Input dataframe (any type supported by narwhals).
        x_name: Name of the x column; must be numeric.
        y_name: Name of the y column.
        controls: A single control name, an iterable of names, or None.
        num_bins: Number of bins to use for binscatter. An integer greater than 1.

    Returns:
        tuple: (narwhals.LazyFrame, Profile)
            - The selected columns with nulls / non-finite values removed and a
              bin-index column named ``profile.bin_name`` added
            - Profile object with metadata about the data

    Raises:
        TypeError: If an argument has the wrong type, a column has an
            unsupported type, or x is not numeric.
        ValueError: If num_bins <= 1, a column is missing, no rows remain after
            removing bad values, x has non-finite bounds, or the quantiles used
            as bin edges are not unique.
    """
    # operator.index accepts Python and NumPy integers but rejects floats,
    # which would otherwise only fail later when the quantile levels are built.
    try:
        num_bins = operator.index(num_bins)
    except TypeError:
        msg = f"num_bins must be an integer, got {type(num_bins).__name__}"
        raise TypeError(msg) from None
    if num_bins <= 1:
        raise ValueError("num_bins must be greater than 1")
    if not isinstance(x_name, str):
        raise TypeError("x_name must be a string")
    if not isinstance(y_name, str):
        raise TypeError("y_name must be a string")

    if controls is None:
        controls = ()
    elif isinstance(controls, str):
        controls = (controls,)
    else:
        try:
            controls = tuple(controls)
        except TypeError as err:
            raise TypeError(
                f"controls must be a string, iterable, or None, got {type(controls)}"
            ) from err
    if not all(isinstance(c, str) for c in controls):
        raise TypeError("controls must contain only strings")

    dfn: nw.DataFrame | nw.LazyFrame = nw.from_native(df_in)
    logger.debug("Type after calling to native: %s", type(dfn.to_native()))
    # isinstance (rather than an exact type check) also accepts narwhals subclasses.
    if isinstance(dfn, nw.DataFrame):
        is_lazy_input = False
    elif isinstance(dfn, nw.LazyFrame):
        is_lazy_input = True
    else:
        msg = f"Unexpected narwhals type {type(dfn)}"
        raise ValueError(msg)
    dfl: nw.LazyFrame = dfn.lazy()

    # Check column names up front so a missing column always yields the same
    # ValueError, independent of when (or whether) the backend's select raises.
    available = set(dfl.columns)
    for c in (x_name, y_name, *controls):
        if c not in available:
            msg = f"{c} not in input dataframe"
            raise ValueError(msg)
    df = dfl.select(x_name, y_name, *controls)

    # Random suffix keeps the temporary bin column from clashing with user columns
    distinct_suffix = str(uuid.uuid4()).replace("-", "_")
    bin_name = f"bins____{distinct_suffix}"

    cols_numeric, cols_cat = get_columns(df)
    union_cols = set(cols_numeric) | set(cols_cat)
    if set(df.columns) - union_cols:
        missing = [c for c in df.columns if c not in union_cols]
        msg = f"Columns with unsupported types: {missing}"
        raise TypeError(msg)
    missing_controls = [c for c in controls if c not in union_cols]
    if missing_controls:
        msg = f"Unknown control columns (neither numeric nor categorical): {missing_controls}"
        raise TypeError(msg)
    # x is binned by quantiles and its min/max must be finite numbers.
    if x_name not in cols_numeric:
        msg = f"x column '{x_name}' must be numeric"
        raise TypeError(msg)

    df_filtered = _remove_bad_values(df, cols_numeric, cols_cat)

    # We need the range of x for plotting
    bounds_df = df_filtered.select(
        nw.col(x_name).min().alias("x_min"),
        nw.col(x_name).max().alias("x_max"),
    ).collect()
    x_bounds = (bounds_df.item(0, "x_min"), bounds_df.item(0, "x_max"))
    # min/max of an empty column can come back as None; report that as
    # "no data" instead of letting math.isfinite raise a TypeError.
    if x_bounds[0] is None or x_bounds[1] is None:
        msg = "No rows left after removing null and non-finite values"
        raise ValueError(msg)
    for val, fun in zip(x_bounds, ["min", "max"]):
        if not math.isfinite(val):
            msg = f"{fun}({x_name})={val}"
            raise ValueError(msg)

    profile = Profile(
        num_bins=num_bins,
        x_name=x_name,
        y_name=y_name,
        bin_name=bin_name,
        controls=controls,
        x_bounds=x_bounds,
        distinct_suffix=distinct_suffix,
        is_lazy_input=is_lazy_input,
        implementation=df_filtered.implementation,
        numeric_columns=cols_numeric,
        categorical_columns=cols_cat,
    )
    logger.debug("Profile: %s", profile)

    quantile_handler = configure_quantile_handler(profile)
    try:
        df_with_bins = quantile_handler(df_filtered)
    except ValueError as err:
        err_text = str(err)
        # Normalise backend-specific duplicate-edge errors into one message.
        if (
            "Quantiles are not unique" in err_text
            or "Bin edges must be unique" in err_text
        ):
            raise ValueError(
                "Quantiles are not unique. Decrease number of bins."
            ) from err
        raise

    return df_with_bins.lazy(), profile
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def partial_out_controls(
    df_prepped: nw.LazyFrame, profile: Profile
) -> tuple[nw.LazyFrame, dict[str, np.ndarray]]:
    """Compute binscatter means after partialling out controls following Cattaneo et al. (2024).

    Fits ``y = sum_j beta_j * 1{bin == j} + W @ gamma`` by solving the normal
    equations, which are assembled from per-bin and whole-frame aggregates so
    only small summaries are collected from the backend. Numeric controls are
    cast to Float64; every other control is expanded into 0/1 dummies, dropping
    the first observed level as the reference category. The plotted value of
    bin j is ``beta_j + mean(W) @ gamma``, i.e. the fit at the average controls.

    Args:
        df_prepped: Frame produced by ``prep`` (contains the bin column).
        profile: Profile produced by ``prep``; must contain at least one control.

    Returns:
        (frame, coefficients): a lazy frame sorted by bin with columns
        ``profile.bin_name`` (Int32), ``profile.x_name`` (mean x per bin) and
        ``profile.y_name`` (adjusted mean y), plus ``{"beta": ..., "gamma": ...}``.

    Raises:
        ValueError: If no controls are given or fewer than ``num_bins`` bins are populated.
        TypeError: If a control has an unsupported type.
    """
    controls = profile.controls
    if not controls:
        raise ValueError("Controls must be provided for partial_out_controls")

    numeric_controls = [c for c in controls if c in profile.numeric_columns]
    categorical_controls = [c for c in controls if c in profile.categorical_columns]
    unknown_controls = [
        c for c in controls if c not in numeric_controls + categorical_controls
    ]
    if unknown_controls:
        msg = f"Controls with unsupported types: {unknown_controls}"
        raise TypeError(msg)

    # Design columns for the controls, all named ``__ctrl_<i>``.
    control_aliases: list[str] = []
    new_columns: list[nw.Expr] = []
    for c in numeric_controls:
        alias = f"__ctrl_{len(control_aliases)}"
        new_columns.append(nw.col(c).cast(nw.Float64).alias(alias))
        control_aliases.append(alias)

    dummy_exprs: list[nw.Expr] = []
    dummy_aliases: list[str] = []
    for c in categorical_controls:
        unique_values = df_prepped.select(c).unique().collect().get_column(c)
        values = unique_values.to_list()
        if len(values) <= 1:
            # A single level yields no dummies once the reference is dropped.
            continue
        # values[0] is the reference level; one dummy per remaining level.
        for value in values[1:]:
            alias = f"__ctrl_{len(control_aliases) + len(dummy_aliases)}"
            expr = (nw.col(c) == value).cast(nw.Float64).alias(alias)
            dummy_exprs.append(expr)
            dummy_aliases.append(alias)

    df_augmented = (
        df_prepped.with_columns(*new_columns, *dummy_exprs)
        if (new_columns or dummy_exprs)
        else df_prepped
    )
    control_aliases.extend(dummy_aliases)
    k = len(control_aliases)

    bin_index = profile.bin_name

    # Per-bin aggregates: row count, mean x, sum of y and sum of each control.
    agg_exprs = [
        nw.len().alias("__count"),
        profile.x_col.mean().alias(profile.x_name),
        profile.y_col.sum().alias("__sum_y"),
    ]
    agg_exprs.extend(nw.col(alias).sum().alias(alias) for alias in control_aliases)

    per_bin = (
        df_augmented.group_by(bin_index).agg(*agg_exprs).sort(bin_index)
    ).collect()

    counts = per_bin.get_column("__count").to_numpy()
    if counts.size < profile.num_bins:
        msg = "Quantiles are not unique. Decrease number of bins."
        raise ValueError(msg)
    sum_y = per_bin.get_column("__sum_y").to_numpy()
    if k:
        bin_control_sums = np.column_stack(
            [per_bin.get_column(alias).to_numpy() for alias in control_aliases]
        )
    else:
        bin_control_sums = np.zeros((profile.num_bins, 0))

    # Whole-frame aggregates: N, sum(W), W'y and the upper triangle of W'W.
    total_exprs = [nw.len().alias("__total_count")]
    total_exprs.extend(
        nw.col(alias).sum().alias(f"__total_ctrl_{idx}")
        for idx, alias in enumerate(control_aliases)
    )
    total_exprs.extend(
        (nw.col(alias) * profile.y_col).sum().alias(f"__wy_{idx}")
        for idx, alias in enumerate(control_aliases)
    )
    for i, alias_i in enumerate(control_aliases):
        for j, alias_j in enumerate(control_aliases[i:], start=i):
            total_exprs.append(
                (nw.col(alias_i) * nw.col(alias_j)).sum().alias(f"__ww_{i}_{j}")
            )

    totals = df_augmented.select(*total_exprs).collect()
    total_count = totals.item(0, "__total_count")
    if k:
        total_ctrl_sums = np.array(
            [totals.item(0, f"__total_ctrl_{idx}") for idx in range(k)]
        )
        wy = np.array([totals.item(0, f"__wy_{idx}") for idx in range(k)])
        ww = np.zeros((k, k))
        for i in range(k):
            for j in range(i, k):
                value = totals.item(0, f"__ww_{i}_{j}")
                ww[i, j] = value
                ww[j, i] = value
    else:
        total_ctrl_sums = np.array([])
        wy = np.array([])
        ww = np.zeros((0, 0))

    # Assemble normal equations X'X theta = X'y with X = [bin dummies | W].
    num_bins = profile.num_bins
    size = num_bins + k
    XTX = np.zeros((size, size))
    XTy = np.zeros(size)

    XTX[:num_bins, :num_bins] = np.diag(counts)
    XTy[:num_bins] = sum_y
    if k:
        XTX[:num_bins, num_bins:] = bin_control_sums
        XTX[num_bins:, :num_bins] = bin_control_sums.T
        XTX[num_bins:, num_bins:] = ww
        XTy[num_bins:] = wy

    try:
        theta = np.linalg.solve(XTX, XTy)
    except np.linalg.LinAlgError:
        # Singular system (e.g. collinear controls): use least squares instead.
        theta, *_ = np.linalg.lstsq(XTX, XTy, rcond=None)

    beta = theta[:num_bins]
    gamma = theta[num_bins:]
    mean_controls = total_ctrl_sums / total_count if k else np.array([])
    fitted = beta + (mean_controls @ gamma if k else 0.0)

    y_vals = nw.new_series(
        name=profile.y_name, values=fitted, backend=per_bin.implementation
    )

    # Cast the bin index to Int32 so the output matches the no-controls path
    # in ``binscatter`` (otherwise e.g. pandas keeps the Categorical from cut).
    df_plotting = (
        per_bin.select(bin_index, profile.x_name)
        .with_columns(y_vals)
        .with_columns(nw.col(bin_index).cast(nw.Int32))
        .lazy()
    )

    return df_plotting, {"beta": beta, "gamma": gamma}
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def make_plot_plotly(
    df_prepped: nw.LazyFrame, profile: Profile, kwargs_binscatter: dict[str, Any]
) -> go.Figure:
    """Make plot from prepared dataframe.

    Args:
        df_prepped (nw.LazyFrame): Prepared dataframe with one row per bin and
            three columns: bin, x, y (names as in ``profile``).
        profile (Profile): Metadata produced by ``prep``.
        kwargs_binscatter (dict[str, Any]): Extra keyword arguments for
            ``plotly.express.scatter``. ``x``, ``y`` and ``range_x`` are set here
            and are ignored with a warning; a ``labels`` dict is merged on top
            of the default axis labels instead of replacing them.

    Returns:
        go.Figure: The binscatter plot.

    Raises:
        ValueError: If fewer than ``profile.num_bins`` distinct bins are present.
    """
    data = df_prepped.select(profile.x_name, profile.y_name).collect()
    if data.shape[0] < profile.num_bins:
        raise ValueError("Quantiles are not unique. Decrease number of bins.")

    x = data.get_column(profile.x_name).to_list()
    n_unique_x = len(set(x))
    if n_unique_x < profile.num_bins:
        msg = f"Unique number of bins is {n_unique_x} fewer than {profile.num_bins} as desired. Decrease parameter num_bins."
        raise ValueError(msg)
    y = data.get_column(profile.y_name).to_list()

    default_labels = {
        "x": profile.x_name,
        "y": profile.y_name,
    }
    scatter_args: dict[str, Any] = {
        "x": x,
        "y": y,
        "range_x": profile.x_bounds,
        "title": "Binscatter",
        "labels": default_labels,
    }
    for key, value in kwargs_binscatter.items():
        if key in ("x", "y", "range_x"):
            msg = f"px.scatter will ignore keyword argument '{key}'"
            # stacklevel=3 attributes the warning to the caller of binscatter().
            warnings.warn(msg, stacklevel=3)
            continue
        if key == "labels" and isinstance(value, dict):
            # Keep the default axis labels unless the user overrides them.
            scatter_args["labels"] = {**default_labels, **value}
            continue
        scatter_args[key] = value

    return px.scatter(**scatter_args)
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def _remove_bad_values(
    df: nw.LazyFrame, cols_numeric: Iterable[str], cols_cat: Iterable[str]
) -> nw.LazyFrame:
    """Drop every row holding an unusable value in any of the given columns.

    Numeric columns reject nulls, NaN and infinite values; categorical columns
    reject nulls only. When there is nothing to check, ``df`` is returned as is.
    """

    def _numeric_is_bad(name: str) -> nw.Expr:
        expr = nw.col(name)
        return expr.is_null() | ~expr.is_finite() | expr.is_nan()

    checks = [_numeric_is_bad(name) for name in cols_numeric]
    checks += [nw.col(name).is_null() for name in cols_cat]

    if checks:
        # A row is dropped as soon as one of its checks flags it.
        any_bad = reduce(operator.or_, checks)
        return df.filter(~any_bad)
    return df
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def get_columns(
    frame: nw.LazyFrame | nw.DataFrame,
) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
    """Split a narwhals frame's columns into numeric and categorical names.

    Numeric columns are those picked up by ``ncs.numeric()``; every other
    column is treated as categorical, so together the two tuples cover all
    columns of ``frame`` (each in frame order).
    """

    def _names_from_attr(selection: Any) -> Tuple[str, ...]:
        # Prefer the ``columns`` attribute; a backend failure yields ().
        if not hasattr(selection, "columns"):
            return ()
        try:
            return tuple(selection.columns)  # type: ignore[attr-defined]
        except Exception:  # pragma: no cover - backend quirk
            return ()

    def _names_from_schema(selection: Any) -> Tuple[str, ...]:
        # Fall back to the collected schema when ``columns`` gave nothing.
        if not hasattr(selection, "collect_schema"):
            return ()
        try:
            schema = selection.collect_schema()
        except Exception:  # pragma: no cover - backend quirk
            return ()
        if schema is None:
            return ()
        if hasattr(schema, "names") and callable(schema.names):
            return tuple(schema.names())
        if isinstance(schema, dict):
            return tuple(schema.keys())
        return ()

    def _safe_columns(selection: Any) -> Tuple[str, ...]:
        if selection is None:
            return ()
        return _names_from_attr(selection) or _names_from_schema(selection)

    numeric_selection = frame.select(ncs.numeric())
    numeric_cols = _safe_columns(numeric_selection)
    all_columns = tuple(frame.columns)
    categorical_cols = tuple(name for name in all_columns if name not in numeric_cols)
    return numeric_cols, categorical_cols
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
# Quantiles
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
# Defined here for testability
|
|
522
|
+
def _add_fallback(
    df: nw.LazyFrame, profile: Profile, probs: List[float]
) -> nw.LazyFrame:
    """Assign quantile bins using only generic narwhals operations.

    Computes the quantiles of x at ``probs`` and attaches to every row the
    index of the first (sorted) quantile that is >= its x value, via an as-of
    join with ``strategy="forward"``. With distinct quantiles q_0 < q_1 < ...,
    bin ``i`` therefore covers ``(q_{i-1}, q_i]``. Rows above the largest
    quantile would get no bin, so ``probs`` is expected to end at 1.0.

    Args:
        df: Filtered frame containing the x column.
        profile: Profile with the x column name and the bin column name.
        probs: Quantile levels in (0, 1].

    Returns:
        ``df`` sorted by x with an added bin-index column ``profile.bin_name``.

    Raises:
        Exception: Whatever the backend raises when the quantiles cannot be
            computed; the failure is logged before being re-raised.
    """
    try:
        try:
            qs = df.select(
                [
                    profile.x_col.quantile(p, interpolation="linear").alias(f"q{p}")
                    for p in probs
                ]
            ).collect()
        except TypeError:
            # Retry without ``interpolation`` for expression APIs that do not
            # accept that keyword.
            expr = cast(Any, profile.x_col)
            qs = df.select([expr.quantile(p).alias(f"q{p}") for p in probs]).collect()
    except Exception:
        # Reached when either attempt fails; log once and keep the original error.
        logger.error(
            "Could not compute quantiles of %s for df of type: %s",
            profile.x_name,
            type(df),
        )
        raise

    qs_long = (
        qs.unpivot(variable_name="prob", value_name="quantile")
        .sort("quantile")
        .with_row_index(profile.bin_name)
    )

    quantile_bins = qs_long.select("quantile", profile.bin_name).lazy()

    # join_asof expects the key sorted; sorting is not always necessary
    # but for safety we sort.
    return (
        df.sort(profile.x_name)
        .join_asof(
            quantile_bins,
            left_on=profile.x_name,
            right_on="quantile",
            strategy="forward",
        )
        .drop("quantile")
    )
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def _make_probs(num_bins) -> List[float]:
|
|
563
|
+
return [i / num_bins for i in range(1, num_bins + 1)]
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def configure_quantile_handler(profile: Profile) -> Callable:
    """Return a function that adds a quantile-bin column for ``profile``'s backend.

    The returned callable takes a narwhals frame and returns a narwhals
    ``LazyFrame`` with an added column ``profile.bin_name`` holding the bin
    index (``0 .. profile.num_bins - 1``) derived from ``profile.x_name``.
    Bin edges come from the quantiles at ``_make_probs(profile.num_bins)``.

    Dedicated implementations exist for pandas, Polars, PySpark, DuckDB and
    Dask; any other implementation uses ``_add_fallback``. In the pandas, Dask
    and Polars handlers a value equal to an inner edge goes to the upper bin.
    The DuckDB handler instead uses ``ntile`` (equal-count groups of the
    sorted rows).
    """
    probs = _make_probs(profile.num_bins)

    def add_fallback(df: nw.LazyFrame):
        return _add_fallback(df, profile, probs)

    def add_to_dask(df: nw.DataFrame) -> nw.LazyFrame:
        """Bin with ``pandas.cut`` per partition, using quantiles computed by Dask."""
        try:
            from pandas import cut
        except ImportError as e:
            raise ImportError(
                "Dask support requires dask and pandas to be installed."
            ) from e

        df_native = df.to_native()
        logger.debug("Type of df_native (should be dask): %s", type(df_native))
        # Only the inner edges are quantiles. The outer edges are -inf/+inf so
        # values below the first or above the last inner quantile still get a bin.
        quantiles = df_native[profile.x_name].quantile(probs[:-1]).compute()
        bins = (float("-inf"), *quantiles, float("inf"))
        binned = df_native[profile.x_name].map_partitions(
            cut,
            bins=bins,
            labels=range(len(probs)),
            include_lowest=False,
            right=False,  # intervals are [left, right)
        )
        # ``assign`` returns a new frame. Item assignment would mutate the
        # caller's native frame object in place.
        return nw.from_native(df_native.assign(**{profile.bin_name: binned})).lazy()

    def add_to_pandas(df: nw.DataFrame) -> nw.LazyFrame:
        """Bin with ``pandas.cut`` using pandas-computed quantiles."""
        try:
            from pandas import cut
        except ImportError as e:
            raise ImportError("Pandas support requires pandas to be installed.") from e
        df_native = df.to_native()
        x = df_native[profile.x_name]
        quantiles = x.quantile(probs[:-1])

        # Inner edges are quantiles; outer edges are -inf/+inf.
        bins = (float("-Inf"), *quantiles, float("Inf"))
        buckets = cut(
            x,
            bins=bins,
            labels=range(len(probs)),
            include_lowest=False,
            right=False,  # intervals are [left, right)
        )
        # ``assign`` leaves the caller's DataFrame untouched. The previous
        # ``df_native[...] = ...`` added the column to the user's frame.
        return nw.from_native(df_native.assign(**{profile.bin_name: buckets})).lazy()

    def add_to_polars(df: nw.DataFrame) -> nw.LazyFrame:
        """Bin with a ``when/then`` chain over Polars-computed quantiles."""
        try:
            import polars as pl
        except ImportError as e:
            raise ImportError("Polars support requires Polars to be installed.") from e
        # Polars' cut/qcut are not stable APIs, so a when/then chain is used.
        df_native = df.to_native()
        x_col = pl.col(profile.x_name)

        # ``.collect()`` is called, so df_native is presumably a polars LazyFrame.
        qs = df_native.select(
            [x_col.quantile(p, interpolation="linear").alias(f"q{p}") for p in probs]
        ).collect()
        # Builds pl.when(x < q0).then(0).when(x < q1).then(1)...
        # .when(x <= q_last).then(last). The module ``pl`` supplies the first
        # ``.when``. Rows matching no branch (e.g. null x) get null.
        expr = pl
        n = qs.width
        for i in range(n):
            thr = qs.item(0, i)
            cond = x_col.le(thr) if i == n - 1 else x_col.lt(thr)
            expr = expr.when(cond).then(pl.lit(i))
        expr = expr.alias(profile.bin_name)
        df_native_with_bin = df_native.with_columns(expr)

        return nw.from_native(df_native_with_bin).lazy()

    def add_to_duckdb(df: nw.DataFrame) -> nw.LazyFrame:
        """Bin with DuckDB's ``ntile`` window function (equal-count groups)."""
        try:
            import duckdb

            rel = df.to_native()
            assert isinstance(rel, duckdb.DuckDBPyRelation), f"{type(rel)=}"

        except Exception as e:
            raise RuntimeError(
                "Failed to use df.to_native(); DuckDB may not be installed."
            ) from e

        def quote_ident(name: str) -> str:
            # Double-quote (escaping embedded quotes) so column names with
            # spaces, reserved words or quotes are valid SQL identifiers.
            return '"' + name.replace('"', '""') + '"'

        x_ident = quote_ident(profile.x_name)
        bin_ident = quote_ident(profile.bin_name)
        # ntile numbers groups from 1; subtract 1 for 0-based bins.
        # NOTE(review): NULL x values take part in the ORDER BY and so receive
        # a bin here, unlike the other handlers. Confirm nulls are dropped upstream.
        rel_with_bins = rel.project(
            f"*, ntile({len(probs)}) OVER (ORDER BY {x_ident} ASC) - 1 AS {bin_ident}"
        )
        assert isinstance(rel_with_bins, duckdb.DuckDBPyRelation), (
            f"{type(rel_with_bins)=}"
        )

        return nw.from_native(rel_with_bins).lazy()

    def add_to_pyspark(df: nw.DataFrame) -> nw.LazyFrame:
        """Bin with Spark ML's ``Bucketizer`` over approximate quantiles."""
        try:
            from pyspark.ml.feature import Bucketizer
            from pyspark.sql.functions import col
        except ImportError as e:
            raise ImportError(
                f"PySpark support requires pyspark to be installed. Original error: {e}"
            ) from e
        sdf = df.to_native()
        # Include the 0.0 quantile so that ``qs`` has len(probs) + 1 split
        # points, i.e. len(probs) buckets. Relative error is 1%.
        qs = sdf.approxQuantile(profile.x_name, (0.0, *probs), relativeError=0.01)
        if logger.isEnabledFor(logging.DEBUG):
            # Debug-only sanity check against pandas on a 2% sample. This
            # triggers extra Spark work.
            sample = sdf.sample(False, 0.02, seed=1).select(profile.x_name).toPandas()
            pd_qs = sample[profile.x_name].quantile((0.0, *probs)).to_list()
            logger.debug(
                "Pyspark vs pandas (sample) quantiles: %s", list(zip(qs, pd_qs))
            )
        # Bucketizer requires strictly increasing splits.
        if len(set(qs)) < len(qs):
            raise ValueError("Quantiles not unique. Decrease number of bins.")

        bucketizer = Bucketizer(
            splits=qs,
            inputCol=profile.x_name,
            outputCol=profile.bin_name,
            handleInvalid="keep",
        )

        # Bucketizer emits a double column; cast the bin index to int.
        sdf_binned = bucketizer.transform(sdf).withColumn(
            profile.bin_name, col(profile.bin_name).cast("int")
        )

        return nw.from_native(sdf_binned).lazy()

    if profile.implementation == Implementation.PANDAS:
        return add_to_pandas
    elif profile.implementation == Implementation.POLARS:
        return add_to_polars
    elif profile.implementation == Implementation.PYSPARK:
        return add_to_pyspark
    elif profile.implementation == Implementation.DUCKDB:
        return add_to_duckdb
    elif profile.implementation == Implementation.DASK:
        return add_to_dask
    else:
        return add_fallback
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
def _compute_quantiles(
    df: nw.DataFrame, colname: str, probs: Iterable[float], bin_name: str
) -> nw.LazyFrame:
    """Get multiple quantiles in one operation.

    Args:
        df: Frame containing ``colname``.
        colname: Column whose quantiles are computed.
        probs: Probabilities at which to evaluate the quantiles. Any iterable
            is accepted; it is materialised once.
        bin_name: Name of the 0-based index column added to the result.

    Returns:
        A lazy frame in long format with columns ``prob`` (the wide column
        name ``q{p}``), ``quantile`` and ``bin_name``, sorted by ``quantile``.
        ``bin_name`` is the row index in that order. On PySpark the quantiles
        are approximate (3% relative error).

    Raises:
        ImportError: PySpark backend without pyspark installed.
        ValueError: PySpark backend and no quantiles could be computed.
    """
    # Materialise once: ``probs`` is iterated more than once below, and
    # PySpark's approxQuantile only accepts a list or tuple.
    probs = list(probs)
    col = nw.col(colname)
    if df.implementation != nw.Implementation.PYSPARK:
        qs = df.select(
            [col.quantile(p, interpolation="linear").alias(f"q{p}") for p in probs]
        )
    else:
        # PySpark: compute approximate quantiles natively, then wrap them in a
        # one-row Spark DataFrame shaped like the branch above.
        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            raise ImportError(
                "PySpark support requires pyspark to be installed."
            ) from e
        spark = SparkSession.builder.getOrCreate()

        quantiles: list[float] = (
            df.select(colname).to_native().approxQuantile(colname, probs, 0.03)
        )
        if not quantiles:
            raise ValueError(
                f"Could not compute quantiles for column {colname!r} "
                "(no usable values)."
            )
        pairs = list(zip(probs, quantiles))
        # A list of row tuples plus explicit column names is a documented
        # createDataFrame input. A dict is not (iterating it yields only keys).
        qs_spark = spark.createDataFrame(
            [tuple(q for _, q in pairs)], schema=[f"q{p}" for p, _ in pairs]
        )
        qs = nw.from_native(qs_spark)

    return (
        qs.unpivot(variable_name="prob", value_name="quantile")
        .sort("quantile")
        .with_row_index(bin_name, order_by="quantile")
        .lazy()
    )
|