binscatter 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Matthias Kaeding
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,97 @@
1
+ Metadata-Version: 2.3
2
+ Name: binscatter
3
+ Version: 0.1.0
4
+ Summary: Cross-backend binscatter plots.
5
+ Keywords: binscatter,visualization,econometrics
6
+ Author: Matthias Kaeding
7
+ License: MIT License
8
+
9
+ Copyright (c) 2025 Matthias Kaeding
10
+
11
+ Permission is hereby granted, free of charge, to any person obtaining a copy
12
+ of this software and associated documentation files (the "Software"), to deal
13
+ in the Software without restriction, including without limitation the rights
14
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
+ copies of the Software, and to permit persons to whom the Software is
16
+ furnished to do so, subject to the following conditions:
17
+
18
+ The above copyright notice and this permission notice shall be included in all
19
+ copies or substantial portions of the Software.
20
+
21
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
+ SOFTWARE.
28
+ Classifier: Development Status :: 3 - Alpha
29
+ Classifier: Intended Audience :: Science/Research
30
+ Classifier: License :: OSI Approved :: MIT License
31
+ Classifier: Programming Language :: Python :: 3 :: Only
32
+ Classifier: Programming Language :: Python :: 3.11
33
+ Classifier: Topic :: Scientific/Engineering :: Visualization
34
+ Requires-Dist: narwhals>=2.1.2
35
+ Requires-Dist: numpy>=2.3.2
36
+ Requires-Dist: plotly>=6.3.0
37
+ Requires-Python: >=3.11
38
+ Project-URL: homepage, https://github.com/matthiaskaeding/binscatter
39
+ Project-URL: issues, https://github.com/matthiaskaeding/binscatter/issues
40
+ Project-URL: repository, https://github.com/matthiaskaeding/binscatter
41
+ Description-Content-Type: text/markdown
42
+
43
+ # Dataframe agnostic binscatter plots
44
+
45
+ This package implements binscatter plots following:
46
+
47
+ > Cattaneo, Crump, Farrell and Feng (2024)
48
+ > "On Binscatter"
49
+ > American Economic Review, 114(5), pp. 1488-1514
50
+ > [DOI: 10.1257/aer.20221576](https://doi.org/10.1257/aer.20221576)
51
+
52
+ - Uses `narwhals` as the dataframe layer of `binscatter`.
53
+ - Currently supports: pandas, Polars, DuckDB, Dask, and PySpark
54
+ - All other Narwhals backends fall back to a generic quantile handler if a native path is unavailable
55
+ - Lightweight - few dependencies
56
+ - Uses `plotly` as the graphics backend because (1) it's great and (2) it uses `narwhals` as well, minimizing dependencies
57
+ - Pythonic alternative to the excellent **binsreg** package
58
+
59
+ ---
60
+
61
+ ## Example
62
+
63
+ We made this noisy scatterplot:
64
+
65
+ ![Noisy scatterplot](https://raw.githubusercontent.com/matthiaskaeding/binscatter/images/images/readme/scatter.png)
66
+
67
+ This is how we make a nice binscatter plot, controlling for a set of features:
68
+
69
+ ```python
70
+ from binscatter import binscatter
71
+
72
+ p_binscatter_controls = binscatter(
73
+ df,
74
+ "mtr90_lag3",
75
+ "lnpat",
76
+ [
77
+ "top_corp_lag3",
78
+ "real_gdp_pc",
79
+ "population_density",
80
+ "rd_credit_lag3",
81
+ "statenum",
82
+ "year",
83
+ ],
84
+ num_bins=35,
85
+ )
86
+ ```
87
+
88
+ ![Binscatter with controls (35 bins)](https://raw.githubusercontent.com/matthiaskaeding/binscatter/images/images/readme/binscatter_controls.png)
89
+
90
+ The data originates from:
91
+
92
+ Akcigit, Ufuk; Grigsby, John; Nicholas, Tom; Stantcheva, Stefanie, 2021, "Replication Data for: 'Taxation and Innovation in the 20th Century'", https://doi.org/10.7910/DVN/SR410I, Harvard Dataverse, V1
93
+
94
+ ## Tests
95
+
96
+ - Run the full backend matrix, including PySpark: `just test`
97
+ - Use the faster run without PySpark: `just test-fast`
@@ -0,0 +1,55 @@
1
+ # Dataframe agnostic binscatter plots
2
+
3
+ This package implements binscatter plots following:
4
+
5
+ > Cattaneo, Crump, Farrell and Feng (2024)
6
+ > "On Binscatter"
7
+ > American Economic Review, 114(5), pp. 1488-1514
8
+ > [DOI: 10.1257/aer.20221576](https://doi.org/10.1257/aer.20221576)
9
+
10
+ - Uses `narwhals` as the dataframe layer of `binscatter`.
11
+ - Currently supports: pandas, Polars, DuckDB, Dask, and PySpark
12
+ - All other Narwhals backends fall back to a generic quantile handler if a native path is unavailable
13
+ - Lightweight - few dependencies
14
+ - Uses `plotly` as the graphics backend because (1) it's great and (2) it uses `narwhals` as well, minimizing dependencies
15
+ - Pythonic alternative to the excellent **binsreg** package
16
+
17
+ ---
18
+
19
+ ## Example
20
+
21
+ We made this noisy scatterplot:
22
+
23
+ ![Noisy scatterplot](https://raw.githubusercontent.com/matthiaskaeding/binscatter/images/images/readme/scatter.png)
24
+
25
+ This is how we make a nice binscatter plot, controlling for a set of features:
26
+
27
+ ```python
28
+ from binscatter import binscatter
29
+
30
+ p_binscatter_controls = binscatter(
31
+ df,
32
+ "mtr90_lag3",
33
+ "lnpat",
34
+ [
35
+ "top_corp_lag3",
36
+ "real_gdp_pc",
37
+ "population_density",
38
+ "rd_credit_lag3",
39
+ "statenum",
40
+ "year",
41
+ ],
42
+ num_bins=35,
43
+ )
44
+ ```
45
+
46
+ ![Binscatter with controls (35 bins)](https://raw.githubusercontent.com/matthiaskaeding/binscatter/images/images/readme/binscatter_controls.png)
47
+
48
+ The data originates from:
49
+
50
+ Akcigit, Ufuk; Grigsby, John; Nicholas, Tom; Stantcheva, Stefanie, 2021, "Replication Data for: 'Taxation and Innovation in the 20th Century'", https://doi.org/10.7910/DVN/SR410I, Harvard Dataverse, V1
51
+
52
+ ## Tests
53
+
54
+ - Run the full backend matrix, including PySpark: `just test`
55
+ - Use the faster run without PySpark: `just test-fast`
@@ -0,0 +1,48 @@
1
+ [build-system]
2
+ requires = ["uv_build>=0.9.18,<0.10.0"]
3
+ build-backend = "uv_build"
4
+
5
+ [project]
6
+ name = "binscatter"
7
+ version = "0.1.0"
8
+ description = "Cross-backend binscatter plots."
9
+ readme = "README.md"
10
+ requires-python = ">=3.11"
11
+ dependencies = ["narwhals>=2.1.2", "numpy>=2.3.2", "plotly>=6.3.0"]
12
+ authors = [{ name = "Matthias Kaeding" }]
13
+ license = { file = "LICENSE" }
14
+ keywords = ["binscatter", "visualization", "econometrics"]
15
+ classifiers = [
16
+ "Development Status :: 3 - Alpha",
17
+ "Intended Audience :: Science/Research",
18
+ "License :: OSI Approved :: MIT License",
19
+ "Programming Language :: Python :: 3 :: Only",
20
+ "Programming Language :: Python :: 3.11",
21
+ "Topic :: Scientific/Engineering :: Visualization",
22
+ ]
23
+
24
+ [project.urls]
25
+ homepage = "https://github.com/matthiaskaeding/binscatter"
26
+ repository = "https://github.com/matthiaskaeding/binscatter"
27
+ issues = "https://github.com/matthiaskaeding/binscatter/issues"
28
+
29
+ [dependency-groups]
30
+ dev = [
31
+ "kaleido>=0.2.1",
32
+ "ipykernel>=6.29.5",
33
+ "pytest>=8.3.5",
34
+ "polars>=1.22.0",
35
+ "pandas>=1.4.4",
36
+ "duckdb>=1.3.2",
37
+ "pyarrow>=21.0.0",
38
+ "ibis>=3.3.0",
39
+ "sqlframe>=3.39.2",
40
+ "pyspark>=4.0.0",
41
+ "dask>=2025.7.0",
42
+ "nbformat>=5.10.4",
43
+ "ty>=0.0.4",
44
+ "statsmodels>=0.14.6",
45
+ ]
46
+
47
+ [tool.ty.rules]
48
+ unresolved-import = "ignore"
@@ -0,0 +1,3 @@
1
+ from binscatter.core import binscatter
2
+
3
+ __all__ = ["binscatter"]
@@ -0,0 +1,736 @@
1
+ import logging
2
+ import math
3
+ import operator
4
+ import uuid
5
+ import warnings
6
+ from functools import reduce
7
+ from typing import (
8
+ Any,
9
+ Callable,
10
+ Iterable,
11
+ List,
12
+ Literal,
13
+ NamedTuple,
14
+ Tuple,
15
+ cast,
16
+ overload,
17
+ )
18
+
19
+ import narwhals as nw
20
+ import narwhals.selectors as ncs
21
+ import numpy as np
22
+ import plotly.express as px
23
+ from narwhals import Implementation
24
+ from narwhals.typing import IntoDataFrame
25
+ from plotly import graph_objects as go
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
@overload
def binscatter(
    df: IntoDataFrame,
    x: str,
    y: str,
    controls: Iterable[str] | str | None = None,
    num_bins: int = 20,
    return_type: Literal["plotly"] = "plotly",
    **kwargs_binscatter: Any,
) -> go.Figure: ...


@overload
def binscatter(
    df: IntoDataFrame,
    x: str,
    y: str,
    controls: Iterable[str] | str | None = None,
    num_bins: int = 20,
    return_type: Literal["native"] = "native",
    **kwargs_binscatter: Any,
) -> object: ...


def binscatter(
    df: IntoDataFrame,
    x: str,
    y: str,
    controls: Iterable[str] | str | None = None,
    num_bins: int = 20,
    return_type: Literal["plotly", "native"] = "plotly",
    **kwargs_binscatter,
) -> object:
    """Create a binned scatter plot: x is grouped into quantile bins and y is summarised per bin.

    Args:
        df (IntoDataFrame): Input dataframe - must be a type supported by narwhals.
        x (str): Name of the x column.
        y (str): Name of the y column.
        controls (Iterable[str] | str | None): Name(s) of control variables. These are
            partialled out following Cattaneo et al. (2024). Controls that are not
            numeric are expanded into indicator columns.
        num_bins (int, optional): Number of bins to use. Must be greater than 1.
            Defaults to 20.
        return_type (str): "plotly" (default) returns a plotly figure, "native"
            returns a dataframe with one row per bin.
        **kwargs_binscatter: Additional keyword arguments forwarded to
            plotly.express.scatter. The keys "x", "y" and "range_x" are ignored
            with a warning because they are set by this function.

    Returns:
        A plotly figure if return_type == "plotly". Otherwise a native dataframe
        with the columns "bin", x and y, sorted by "bin". For PySpark, DuckDB and
        Dask the native lazy object is returned as is; for every other backend the
        result is collected first.

    Raises:
        ValueError: If return_type is neither "plotly" nor "native".
    """
    if return_type not in ("plotly", "native"):
        msg = f"Invalid return_type: {return_type}"
        raise ValueError(msg)

    # Prepare dataframe: select columns, drop bad values and add bins
    df_prepped, profile = prep(df, x, y, controls, num_bins)

    # Currently there are 2 cases:
    # (1) no controls: the easy one, just compute the means by bin
    # (2) controls: here we need to compute regression coefficients
    #     and partial out the effect of the controls
    #     (see section 2.2 in Cattaneo, Crump, Farrell and Feng (2024))
    #
    # Branch on the normalised tuple stored in the profile rather than on the raw
    # argument: an empty generator is truthy, and prep() has already consumed it.
    if not profile.controls:
        df_plotting: nw.LazyFrame = (
            df_prepped.group_by(profile.bin_name)
            .agg(profile.x_col.mean(), profile.y_col.mean())
            .with_columns(nw.col(profile.bin_name).cast(nw.Int32))
        ).lazy()
    else:
        df_plotting, _ = partial_out_controls(df_prepped, profile)

    if return_type == "plotly":
        return make_plot_plotly(
            df_plotting, profile, kwargs_binscatter=kwargs_binscatter
        )

    # return_type == "native"
    df_out_nw = df_plotting.rename({profile.bin_name: "bin"}).sort("bin")
    logger.debug(
        "Type of df_out_nw: %s, implementation: %s",
        type(df_out_nw),
        df_out_nw.implementation,
    )

    # These backends are lazy natively, so hand back the lazy object untouched.
    if profile.implementation in (
        Implementation.PYSPARK,
        Implementation.DUCKDB,
        Implementation.DASK,
    ):
        return df_out_nw.to_native()
    return df_out_nw.collect().to_native()
122
+
123
+
124
class Profile(NamedTuple):
    """Immutable bundle of user settings and metadata derived from the input frame."""

    # Names of the x and y columns.
    x_name: str
    y_name: str
    # Names of the control columns (empty tuple when there are none).
    controls: tuple[str, ...]
    # Requested number of bins.
    num_bins: int
    # Name of the generated bin column.
    bin_name: str
    # (min, max) of x.
    x_bounds: tuple[float, float]
    # Random suffix used to build bin_name.
    distinct_suffix: str
    # True when the user passed a frame that narwhals wraps as a LazyFrame.
    is_lazy_input: bool
    # Narwhals backend of the frame.
    implementation: Implementation
    # Column names split by whether narwhals selects them as numeric.
    numeric_columns: tuple[str, ...]
    categorical_columns: tuple[str, ...]

    @property
    def x_col(self) -> nw.Expr:
        """Narwhals expression selecting the x column."""
        return nw.col(self.x_name)

    @property
    def y_col(self) -> nw.Expr:
        """Narwhals expression selecting the y column."""
        return nw.col(self.y_name)
146
+
147
+
148
def prep(
    df_in: IntoDataFrame,
    x_name: str,
    y_name: str,
    controls: Iterable[str] | str | None = None,
    num_bins: int = 20,
) -> Tuple[nw.LazyFrame, Profile]:
    """Prepares the input data and derives profile.

    Args:
        df_in: Input dataframe.
        x_name: name of x col
        y_name: name of y col
        controls: A single control name, an iterable of control names, or None
        num_bins: Number of bins to use for binscatter. Must be greater than 1.

    Returns:
        tuple: (narwhals.LazyFrame, Profile)
            - Input restricted to x, y and the controls, with rows holding
              null / non-finite values removed and a bin column added
            - Profile object with metadata about the data

    Raises:
        ValueError: If num_bins <= 1, a requested column is missing, the range
            of x is empty or not finite, or the quantiles are not unique.
        TypeError: If x_name / y_name / controls have the wrong type or a
            column has an unsupported type.
    """
    if num_bins <= 1:
        raise ValueError("num_bins must be greater than 1")
    if not isinstance(x_name, str):
        raise TypeError("x_name must be a string")
    if not isinstance(y_name, str):
        raise TypeError("y_name must be a string")

    # Normalise controls to a tuple of strings
    if controls is None:
        controls = ()
    elif isinstance(controls, str):
        controls = (controls,)
    else:
        try:
            controls = tuple(controls)
        except TypeError as err:
            raise TypeError(
                f"controls must be a string, iterable, or None, got {type(controls)}"
            ) from err
    if not all(isinstance(c, str) for c in controls):
        raise TypeError("controls must contain only strings")

    dfn: nw.DataFrame | nw.LazyFrame = nw.from_native(df_in)
    logger.debug("Type after calling to native: %s", type(dfn.to_native()))
    if isinstance(dfn, nw.LazyFrame):
        is_lazy_input = True
    elif isinstance(dfn, nw.DataFrame):
        is_lazy_input = False
    else:
        msg = f"Unexpected narwhals type {(type(dfn))}"
        raise ValueError(msg)
    dfl: nw.LazyFrame = dfn.lazy()

    try:
        df = dfl.select(x_name, y_name, *controls)
    except Exception:
        # Turn a backend specific error into a clear message if a column is missing
        cols = dfl.columns
        for c in [x_name, y_name, *controls]:
            if c not in cols:
                msg = f"{c} not in input dataframe"
                raise ValueError(msg)
        raise

    # Find name for bins
    distinct_suffix = str(uuid.uuid4()).replace("-", "_")
    bin_name = f"bins____{distinct_suffix}"

    cols_numeric, cols_cat = get_columns(df)
    union_cols = set(cols_numeric) | set(cols_cat)
    if set(df.columns) - union_cols:
        missing = [c for c in df.columns if c not in union_cols]
        msg = f"Columns with unsupported types: {missing}"
        raise TypeError(msg)
    missing_controls = [c for c in controls if c not in union_cols]
    if missing_controls:
        msg = f"Unknown control columns (neither numeric nor categorical): {missing_controls}"
        raise TypeError(msg)

    df_filtered = _remove_bad_values(df, cols_numeric, cols_cat)

    # We need the range of x for plotting
    bounds_df = df_filtered.select(
        nw.col(x_name).min().alias("x_min"),
        nw.col(x_name).max().alias("x_max"),
    ).collect()
    x_bounds = (bounds_df.item(0, "x_min"), bounds_df.item(0, "x_max"))
    for val, fun in zip(x_bounds, ["min", "max"]):
        # val is None when no rows survive the filtering; math.isfinite(None)
        # would raise an unrelated TypeError, so report it like a non-finite bound.
        if val is None or not math.isfinite(val):
            msg = f"{fun}({x_name})={val}"
            raise ValueError(msg)

    profile = Profile(
        num_bins=num_bins,
        x_name=x_name,
        y_name=y_name,
        bin_name=bin_name,
        controls=controls,
        x_bounds=x_bounds,
        distinct_suffix=distinct_suffix,
        is_lazy_input=is_lazy_input,
        implementation=df_filtered.implementation,
        numeric_columns=cols_numeric,
        categorical_columns=cols_cat,
    )
    logger.debug("Profile: %s", profile)

    quantile_handler = configure_quantile_handler(profile)
    try:
        df_with_bins = quantile_handler(df_filtered)
    except ValueError as err:
        # Unify the backend specific messages for duplicated bin edges
        err_text = str(err)
        if (
            "Quantiles are not unique" in err_text
            or "Bin edges must be unique" in err_text
        ):
            raise ValueError(
                "Quantiles are not unique. Decrease number of bins."
            ) from err
        raise

    return df_with_bins.lazy(), profile
274
+
275
+
276
def partial_out_controls(
    df_prepped: nw.LazyFrame, profile: Profile
) -> tuple[nw.LazyFrame, dict[str, np.ndarray]]:
    """Compute binscatter means after partialling out controls following Cattaneo et al. (2024).

    Builds and solves the normal equations of a regression of y on one indicator
    per bin plus the control columns. The y value reported for a bin is its bin
    coefficient plus the control coefficients evaluated at the control means.

    Args:
        df_prepped: Prepared frame that contains the bin column named in profile.
        profile: Profile belonging to df_prepped. profile.controls must not be empty.

    Returns:
        tuple: (narwhals.LazyFrame, dict)
            - Frame with one row per bin holding the bin column, mean x and fitted y
            - {"beta": bin coefficients, "gamma": control coefficients}

    Raises:
        ValueError: If there are no controls or fewer bins than requested.
        TypeError: If a control is neither in the numeric nor the categorical columns.
    """

    controls = profile.controls
    if not controls:
        raise ValueError("Controls must be provided for partial_out_controls")

    numeric_names = set(profile.numeric_columns)
    categorical_names = set(profile.categorical_columns)
    numeric_controls = [c for c in controls if c in numeric_names]
    categorical_controls = [c for c in controls if c in categorical_names]
    unknown_controls = [
        c for c in controls if c not in numeric_names and c not in categorical_names
    ]
    if unknown_controls:
        msg = f"Controls with unsupported types: {unknown_controls}"
        raise TypeError(msg)

    control_aliases: list[str] = []
    new_columns = []

    # Numeric controls enter as float columns
    for c in numeric_controls:
        alias = f"__ctrl_{len(control_aliases)}"
        new_columns.append(nw.col(c).cast(nw.Float64).alias(alias))
        control_aliases.append(alias)

    # Categorical controls enter as 0/1 indicators; the first level is left out
    dummy_exprs: list[nw.Expr] = []
    dummy_aliases: list[str] = []
    for c in categorical_controls:
        unique_values = df_prepped.select(c).unique().collect().get_column(c)
        values = unique_values.to_list()
        if len(values) <= 1:
            # A constant column carries no information beyond the bin indicators
            continue
        for value in values[1:]:
            alias = f"__ctrl_{len(control_aliases) + len(dummy_aliases)}"
            expr = (nw.col(c) == value).cast(nw.Float64).alias(alias)
            dummy_exprs.append(expr)
            dummy_aliases.append(alias)

    df_augmented = (
        df_prepped.with_columns(*new_columns, *dummy_exprs)
        if (new_columns or dummy_exprs)
        else df_prepped
    )
    control_aliases.extend(dummy_aliases)

    bin_index = profile.bin_name

    # Per-bin sufficient statistics
    agg_exprs = [
        nw.len().alias("__count"),
        profile.x_col.mean().alias(profile.x_name),
        profile.y_col.sum().alias("__sum_y"),
    ]
    agg_exprs.extend(nw.col(alias).sum().alias(alias) for alias in control_aliases)

    per_bin = (
        df_augmented.group_by(bin_index).agg(*agg_exprs).sort(bin_index)
    ).collect()

    counts = per_bin.get_column("__count").to_numpy()
    if counts.size < profile.num_bins:
        msg = "Quantiles are not unique. Decrease number of bins."
        raise ValueError(msg)
    sum_y = per_bin.get_column("__sum_y").to_numpy()
    if control_aliases:
        bin_control_sums = np.column_stack(
            [per_bin.get_column(alias).to_numpy() for alias in control_aliases]
        )
    else:
        bin_control_sums = np.zeros((profile.num_bins, 0))

    # Whole-sample sums: control totals, control * y and control cross products
    total_exprs = [nw.len().alias("__total_count")]
    total_exprs.extend(
        nw.col(alias).sum().alias(f"__total_ctrl_{idx}")
        for idx, alias in enumerate(control_aliases)
    )
    total_exprs.extend(
        (nw.col(alias) * profile.y_col).sum().alias(f"__wy_{idx}")
        for idx, alias in enumerate(control_aliases)
    )
    for i, alias_i in enumerate(control_aliases):
        for j, alias_j in enumerate(control_aliases[i:], start=i):
            total_exprs.append(
                (nw.col(alias_i) * nw.col(alias_j)).sum().alias(f"__ww_{i}_{j}")
            )

    totals = df_augmented.select(*total_exprs).collect()
    total_count = totals.item(0, "__total_count")
    if control_aliases:
        total_ctrl_sums = np.array(
            [
                totals.item(0, f"__total_ctrl_{idx}")
                for idx in range(len(control_aliases))
            ]
        )
        wy = np.array(
            [totals.item(0, f"__wy_{idx}") for idx in range(len(control_aliases))]
        )
        # Only the upper triangle was computed; mirror it
        ww = np.zeros((len(control_aliases), len(control_aliases)))
        for i in range(len(control_aliases)):
            for j in range(i, len(control_aliases)):
                alias = f"__ww_{i}_{j}"
                value = totals.item(0, alias)
                ww[i, j] = value
                ww[j, i] = value
    else:
        total_ctrl_sums = np.array([])
        wy = np.array([])
        ww = np.zeros((0, 0))

    # Assemble normal equations
    num_bins = profile.num_bins
    k = len(control_aliases)
    size = num_bins + k
    XTX = np.zeros((size, size))
    XTy = np.zeros(size)

    XTX[:num_bins, :num_bins] = np.diag(counts)
    if k:
        XTX[:num_bins, num_bins:] = bin_control_sums
        XTX[num_bins:, :num_bins] = bin_control_sums.T
        XTX[num_bins:, num_bins:] = ww
        XTy[num_bins:] = wy

    XTy[:num_bins] = sum_y

    try:
        theta = np.linalg.solve(XTX, XTy)
    except np.linalg.LinAlgError:
        # Singular system: fall back to the least squares solution
        theta, *_ = np.linalg.lstsq(XTX, XTy, rcond=None)

    beta = theta[:num_bins]
    gamma = theta[num_bins:]
    mean_controls = total_ctrl_sums / total_count if k else np.array([])
    fitted = beta + (mean_controls @ gamma if k else 0.0)

    y_vals = nw.new_series(
        name=profile.y_name, values=fitted, backend=per_bin.implementation
    )

    df_plotting = per_bin.select(bin_index, profile.x_name).with_columns(y_vals).lazy()

    return df_plotting, {"beta": beta, "gamma": gamma}
422
+
423
+
424
def make_plot_plotly(
    df_prepped: nw.LazyFrame, profile: Profile, kwargs_binscatter: dict[str, Any]
) -> go.Figure:
    """Draw the binscatter with plotly express.

    Args:
        df_prepped (nw.LazyFrame): Frame with one row per bin; the x and y columns
            are looked up by the names stored in profile.
        profile (Profile): Profile belonging to df_prepped.
        kwargs_binscatter (dict): Extra keyword arguments for px.scatter. The keys
            "x", "y" and "range_x" are skipped with a warning.

    Returns:
        The plotly figure.

    Raises:
        ValueError: If there are fewer rows or fewer distinct x values than bins.
    """
    points = df_prepped.select(profile.x_name, profile.y_name).collect()
    if points.shape[0] < profile.num_bins:
        raise ValueError("Quantiles are not unique. Decrease number of bins.")

    x = points.get_column(profile.x_name).to_list()
    n_distinct = len(set(x))
    if n_distinct < profile.num_bins:
        msg = f"Unique number of bins is {n_distinct} fewer than {profile.num_bins} as desired. Decrease parameter num_bins."
        raise ValueError(msg)
    y = points.get_column(profile.y_name).to_list()

    # Arguments owned by this function; user supplied values are not used
    reserved = ("x", "y", "range_x")
    scatter_args = {
        "x": x,
        "y": y,
        "range_x": profile.x_bounds,
        "title": "Binscatter",
        "labels": {
            "x": profile.x_name,
            "y": profile.y_name,
        },
    }
    for key, value in kwargs_binscatter.items():
        if key in reserved:
            warnings.warn(f"px.scatter will ignore keyword argument '{key}'")
            continue
        scatter_args[key] = value

    return px.scatter(**scatter_args)
459
+
460
+
461
+ def _remove_bad_values(
462
+ df: nw.LazyFrame, cols_numeric: Iterable[str], cols_cat: Iterable[str]
463
+ ) -> nw.LazyFrame:
464
+ """Removes nulls and non-finite values for the provided columns."""
465
+
466
+ bad_conditions = []
467
+
468
+ for c in cols_numeric:
469
+ col = nw.col(c)
470
+ bad_conditions.append(col.is_null() | ~col.is_finite() | col.is_nan())
471
+
472
+ for c in cols_cat:
473
+ bad_conditions.append(nw.col(c).is_null())
474
+
475
+ if not bad_conditions:
476
+ return df
477
+
478
+ final_bad_condition = reduce(operator.or_, bad_conditions)
479
+
480
+ return df.filter(~final_bad_condition)
481
+
482
+
483
def get_columns(
    frame: nw.LazyFrame | nw.DataFrame,
) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
    """Split the columns of a narwhals frame into numeric and categorical names.

    Every column that the numeric selector does not pick up counts as categorical.
    """

    def _column_names(selected: Any) -> Tuple[str, ...]:
        # Best effort extraction of column names: first via .columns,
        # then via the schema; an empty tuple if neither works.
        if selected is None:
            return tuple()
        if hasattr(selected, "columns"):
            try:
                found = tuple(selected.columns)  # type: ignore[attr-defined]
            except Exception:  # pragma: no cover - backend quirk
                found = tuple()
            if found:
                return found
        if not hasattr(selected, "collect_schema"):
            return tuple()
        try:
            schema = selected.collect_schema()
        except Exception:  # pragma: no cover - backend quirk
            return tuple()
        if schema is None:
            return tuple()
        names = getattr(schema, "names", None)
        if callable(names):
            return tuple(names())
        if isinstance(schema, dict):
            return tuple(schema.keys())
        return tuple()

    numeric_cols = _column_names(frame.select(ncs.numeric()))
    categorical_cols = tuple(
        name for name in tuple(frame.columns) if name not in numeric_cols
    )
    return numeric_cols, categorical_cols
516
+
517
+
518
+ # Quantiles
519
+
520
+
521
+ # Defined here for testability
522
def _add_fallback(
    df: nw.LazyFrame, profile: Profile, probs: List[float]
) -> nw.LazyFrame:
    """Generic bin assignment built only from narwhals operations.

    Computes the quantiles of x at probs, numbers them in ascending order and
    attaches to every row the index of the first quantile that is >= its x value
    (forward as-of join).
    """
    x_expr = profile.x_col
    try:
        quantile_exprs = [
            x_expr.quantile(p, interpolation="linear").alias(f"q{p}") for p in probs
        ]
        qs = df.select(quantile_exprs).collect()
    except TypeError:
        # Retry without naming the interpolation method
        untyped_expr = cast(Any, profile.x_col)
        qs = df.select(
            [untyped_expr.quantile(p).alias(f"q{p}") for p in probs]
        ).collect()
    except Exception as e:
        logger.error(
            "Tried making quantiles with and without interpolation method for df of type: %s",
            type(df),
        )
        raise e

    # One row per quantile, the row index becomes the bin number
    edges = qs.unpivot(variable_name="prob", value_name="quantile")
    edges = edges.sort("quantile").with_row_index(profile.bin_name)
    lookup = edges.select("quantile", profile.bin_name).lazy()

    # Sorting is not always necessary - but for safety we sort
    ordered = df.sort(profile.x_name)
    joined = ordered.join_asof(
        lookup,
        left_on=profile.x_name,
        right_on="quantile",
        strategy="forward",
    )
    return joined.drop("quantile")
560
+
561
+
562
+ def _make_probs(num_bins) -> List[float]:
563
+ return [i / num_bins for i in range(1, num_bins + 1)]
564
+
565
+
566
def configure_quantile_handler(profile: Profile) -> Callable:
    """Pick the function that adds the bin column for the backend in profile.

    Args:
        profile: Profile of the frame that will be binned.

    Returns:
        A callable that takes the prepared narwhals frame and returns a narwhals
        LazyFrame with the extra column profile.bin_name. pandas, Polars, PySpark,
        DuckDB and Dask get a native handler, every other backend the generic
        fallback.
    """
    probs = _make_probs(profile.num_bins)

    def _quote_identifier(name: str) -> str:
        # Double-quote an SQL identifier so that names with spaces, reserved
        # words or embedded quotes survive the string-built query.
        return '"' + name.replace('"', '""') + '"'

    def add_fallback(df: nw.LazyFrame):
        return _add_fallback(df, profile, probs)

    def add_to_dask(df: nw.LazyFrame) -> nw.LazyFrame:
        try:
            from pandas import cut
        except ImportError as err:
            raise ImportError(
                "Dask support requires dask and pandas to be installed."
            ) from err

        df_native = df.to_native()
        logger.debug("Type of df_native (should be dask): %s", type(df_native))
        quantiles = df_native[profile.x_name].quantile(probs[:-1]).compute()
        bins = (float("-inf"), *quantiles, float("inf"))
        binned = df_native[profile.x_name].map_partitions(
            cut,
            bins=bins,
            labels=range(len(probs)),
            include_lowest=False,
            right=False,
        )
        # assign() returns a new frame and leaves df_native untouched
        df_with_bins = df_native.assign(**{profile.bin_name: binned})

        return nw.from_native(df_with_bins).lazy()

    def add_to_pandas(df: nw.LazyFrame) -> nw.LazyFrame:
        try:
            from pandas import cut
        except ImportError as err:
            raise ImportError(
                "Pandas support requires pandas to be installed."
            ) from err
        df_native = df.to_native()
        x = df_native[profile.x_name]
        quantiles = x.quantile(probs[:-1])

        # Open-ended outer bins so that every value falls into a bucket
        bins = (float("-Inf"), *quantiles, float("Inf"))
        buckets = cut(
            x,
            bins=bins,
            labels=range(len(probs)),
            include_lowest=False,
            right=False,
        )
        # assign() returns a new frame instead of writing into df_native
        df_with_bins = df_native.assign(**{profile.bin_name: buckets})

        return nw.from_native(df_with_bins).lazy()

    def add_to_polars(df: nw.LazyFrame) -> nw.LazyFrame:
        try:
            import polars as pl
        except ImportError as err:
            raise ImportError(
                "Polars support requires Polars to be installed."
            ) from err
        # Because cut and qcut are not stable we use when-then
        df_native = df.to_native()
        x_col = pl.col(profile.x_name)

        qs = df_native.select(
            [x_col.quantile(p, interpolation="linear").alias(f"q{p}") for p in probs]
        ).collect()
        expr = pl
        n = qs.width
        for i in range(n):
            thr = qs.item(0, i)
            # The last bin is closed on the right so that the maximum is included
            cond = x_col.le(thr) if i == n - 1 else x_col.lt(thr)
            expr = expr.when(cond).then(pl.lit(i))
        expr = expr.alias(profile.bin_name)
        df_native_with_bin = df_native.with_columns(expr)

        return nw.from_native(df_native_with_bin).lazy()

    def add_to_duckdb(df: nw.LazyFrame) -> nw.LazyFrame:
        try:
            import duckdb

            rel = df.to_native()
            assert isinstance(rel, duckdb.DuckDBPyRelation), f"{type(rel)=}"

        except Exception as e:
            raise RuntimeError(
                "Failed to use df.to_native(); DuckDB may not be installed."
            ) from e

        x_ident = _quote_identifier(profile.x_name)
        bin_ident = _quote_identifier(profile.bin_name)
        # ntile is 1-based, bins are 0-based
        rel_with_bins = rel.project(
            f"*, ntile({len(probs)}) OVER (ORDER BY {x_ident} ASC) - 1 AS {bin_ident}"
        )
        assert isinstance(rel_with_bins, duckdb.DuckDBPyRelation), (
            f"{type(rel_with_bins)=}"
        )

        return nw.from_native(rel_with_bins).lazy()

    def add_to_pyspark(df: nw.LazyFrame) -> nw.LazyFrame:
        try:
            from pyspark.ml.feature import Bucketizer
            from pyspark.sql.functions import col
        except ImportError as e:
            raise ImportError(
                f"PySpark support requires pyspark to be installed. Original error: {e}"
            ) from e
        sdf = df.to_native()
        qs = sdf.approxQuantile(profile.x_name, (0.0, *probs), relativeError=0.01)
        if logger.isEnabledFor(logging.DEBUG):
            # Debug aid only: compare with pandas quantiles on a small sample
            sample = sdf.sample(False, 0.02, seed=1).select(profile.x_name).toPandas()
            pd_qs = sample[profile.x_name].quantile((0.0, *probs)).to_list()
            logger.debug(
                "Pyspark vs pandas (sample) quantiles: %s", list(zip(qs, pd_qs))
            )
        if len(set(qs)) < len(qs):
            raise ValueError("Quantiles not unique. Decrease number of bins.")

        bucketizer = Bucketizer(
            splits=qs,
            inputCol=profile.x_name,
            outputCol=profile.bin_name,
            handleInvalid="keep",
        )

        sdf_binned = bucketizer.transform(sdf).withColumn(
            profile.bin_name, col(profile.bin_name).cast("int")
        )

        return nw.from_native(sdf_binned).lazy()

    handlers = {
        Implementation.PANDAS: add_to_pandas,
        Implementation.POLARS: add_to_polars,
        Implementation.PYSPARK: add_to_pyspark,
        Implementation.DUCKDB: add_to_duckdb,
        Implementation.DASK: add_to_dask,
    }
    return handlers.get(profile.implementation, add_fallback)
702
+
703
+
704
def _compute_quantiles(
    df: nw.DataFrame, colname: str, probs: Iterable[float], bin_name: str
) -> nw.LazyFrame:
    """Get multiple quantiles in one operation.

    Args:
        df: Narwhals frame holding the column.
        colname: Name of the column whose quantiles are computed.
        probs: Probabilities of the quantiles.
        bin_name: Name of the row index column that numbers the quantiles.

    Returns:
        Lazy frame with one row per quantile, sorted by the column "quantile" and
        carrying the row index bin_name (plus the column "prob" from the unpivot).

    Raises:
        ImportError: If the frame is a PySpark frame and pyspark cannot be imported.
    """
    # probs may be a one-shot iterable; the PySpark branch walks it twice
    probs = list(probs)
    col = nw.col(colname)
    if df.implementation != nw.Implementation.PYSPARK:
        qs = df.select(
            [col.quantile(p, interpolation="linear").alias(f"q{p}") for p in probs]
        )
    else:
        # Pyspark - ugly hack: approximate quantiles natively and wrap them in a
        # one-row Spark frame so that the code below is shared
        try:
            from pyspark.sql import SparkSession
        except ImportError as err:
            raise ImportError(
                "PySpark support requires pyspark to be installed."
            ) from err
        spark = SparkSession.builder.getOrCreate()

        quantiles: list[float] = (
            df.select(colname).to_native().approxQuantile(colname, probs, 0.03)
        )
        q_data = {f"q{p}": [q] for p, q in zip(probs, quantiles)}
        # NOTE(review): createDataFrame is called with a dict of lists - confirm
        # that the installed pyspark version accepts this input.
        qs_spark = spark.createDataFrame(q_data)
        qs = nw.from_native(qs_spark)

    return (
        qs.unpivot(variable_name="prob", value_name="quantile")
        .sort("quantile")
        .with_row_index(bin_name, order_by="quantile")
        .lazy()
    )