pytest-drift 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytest_drift-0.1.0/PKG-INFO +146 -0
- pytest_drift-0.1.0/README.md +126 -0
- pytest_drift-0.1.0/pyproject.toml +33 -0
- pytest_drift-0.1.0/pytest_drift/__init__.py +0 -0
- pytest_drift-0.1.0/pytest_drift/compare.py +154 -0
- pytest_drift-0.1.0/pytest_drift/pandas_utils.py +135 -0
- pytest_drift-0.1.0/pytest_drift/plugin.py +292 -0
- pytest_drift-0.1.0/pytest_drift/report.py +68 -0
- pytest_drift-0.1.0/pytest_drift/runner.py +100 -0
- pytest_drift-0.1.0/pytest_drift/storage.py +92 -0
- pytest_drift-0.1.0/pytest_drift.egg-info/PKG-INFO +146 -0
- pytest_drift-0.1.0/pytest_drift.egg-info/SOURCES.txt +19 -0
- pytest_drift-0.1.0/pytest_drift.egg-info/dependency_links.txt +1 -0
- pytest_drift-0.1.0/pytest_drift.egg-info/entry_points.txt +2 -0
- pytest_drift-0.1.0/pytest_drift.egg-info/requires.txt +16 -0
- pytest_drift-0.1.0/pytest_drift.egg-info/top_level.txt +1 -0
- pytest_drift-0.1.0/setup.cfg +4 -0
- pytest_drift-0.1.0/tests/test_compare.py +129 -0
- pytest_drift-0.1.0/tests/test_pandas_utils.py +109 -0
- pytest_drift-0.1.0/tests/test_plugin.py +161 -0
- pytest_drift-0.1.0/tests/test_storage.py +79 -0
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pytest-drift
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Pytest plugin for regression testing via branch comparison
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: pytest>=7.0
|
|
8
|
+
Requires-Dist: cloudpickle>=3.0
|
|
9
|
+
Requires-Dist: pandas>=1.5
|
|
10
|
+
Provides-Extra: datacompy
|
|
11
|
+
Requires-Dist: datacompy>=0.9; extra == "datacompy"
|
|
12
|
+
Provides-Extra: parquet
|
|
13
|
+
Requires-Dist: pyarrow>=10.0; extra == "parquet"
|
|
14
|
+
Provides-Extra: dev
|
|
15
|
+
Requires-Dist: pytest; extra == "dev"
|
|
16
|
+
Requires-Dist: pandas; extra == "dev"
|
|
17
|
+
Requires-Dist: numpy; extra == "dev"
|
|
18
|
+
Requires-Dist: datacompy>=0.9; extra == "dev"
|
|
19
|
+
Requires-Dist: pyarrow>=10.0; extra == "dev"
|
|
20
|
+
|
|
21
|
+
# pytest-drift
|
|
22
|
+
|
|
23
|
+
A pytest plugin for regression testing via branch comparison. When a test returns a value, the plugin runs the same test on a base git branch and compares the results — catching regressions before they merge.
|
|
24
|
+
|
|
25
|
+
## How it works
|
|
26
|
+
|
|
27
|
+
1. You run `pytest --drift BASE_BRANCH`
|
|
28
|
+
2. For every test that **returns a non-None value**, the plugin:
|
|
29
|
+
- Records the return value from the current branch (HEAD)
|
|
30
|
+
- Simultaneously runs the same tests on `BASE_BRANCH` in a git worktree
|
|
31
|
+
- Compares the two results at the end of the session
|
|
32
|
+
3. Tests returning `None` (the default for normal pytest tests) are ignored entirely
|
|
33
|
+
|
|
34
|
+
The base branch runs in parallel with your HEAD tests, so total wall time is approximately `max(HEAD_time, BASE_time)` rather than `HEAD_time + BASE_time`.
|
|
35
|
+
|
|
36
|
+
## Installation
|
|
37
|
+
|
|
38
|
+
```bash
|
|
39
|
+
pip install pytest-drift
|
|
40
|
+
|
|
41
|
+
# With smart DataFrame diff reports (recommended):
|
|
42
|
+
pip install "pytest-drift[datacompy]"
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
## Usage
|
|
46
|
+
|
|
47
|
+
### CLI flag
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
pytest --drift main
|
|
51
|
+
pytest --drift origin/main
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
### Environment variable
|
|
55
|
+
|
|
56
|
+
```bash
|
|
57
|
+
export PYTEST_DRIFT_BASE_BRANCH=main
|
|
58
|
+
pytest
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Writing regression tests
|
|
62
|
+
|
|
63
|
+
Return a value from your test — that's it:
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
def test_revenue_calculation():
|
|
67
|
+
df = compute_revenue(load_data())
|
|
68
|
+
return df # compared against the same function on BASE_BRANCH
|
|
69
|
+
|
|
70
|
+
def test_model_accuracy():
|
|
71
|
+
return evaluate_model() # compared as a float
|
|
72
|
+
|
|
73
|
+
def test_pipeline_output():
|
|
74
|
+
return run_pipeline() # compared as a dict, list, DataFrame, etc.
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
Normal tests (returning `None`) are unaffected and run as usual.
|
|
78
|
+
|
|
79
|
+
## Comparison logic
|
|
80
|
+
|
|
81
|
+
The plugin dispatches comparison based on the return type:
|
|
82
|
+
|
|
83
|
+
| Type | Comparison method |
|
|
84
|
+
|---|---|
|
|
85
|
+
| `pd.DataFrame` | Auto-detects join columns; uses `datacompy` if installed, else `pd.testing.assert_frame_equal` |
|
|
86
|
+
| `pd.Series` | Converted to DataFrame, same path as above |
|
|
87
|
+
| `float` / `np.floating` | `math.isclose` with `rel_tol=1e-5, abs_tol=1e-8` (NaN equals NaN) |
|
|
88
|
+
| `np.ndarray` | `np.testing.assert_array_almost_equal` (5 decimal places) |
|
|
89
|
+
| `dict` | Recursive key-by-key comparison |
|
|
90
|
+
| `list` / `tuple` | Element-wise comparison |
|
|
91
|
+
| Everything else | `==`, with `repr()` diff on failure |
|
|
92
|
+
|
|
93
|
+
### Pandas index auto-detection
|
|
94
|
+
|
|
95
|
+
When comparing DataFrames, the plugin automatically finds the best join key:
|
|
96
|
+
|
|
97
|
+
1. **Named index**: if the DataFrame already has a named (non-RangeIndex) index, it's used directly
|
|
98
|
+
2. **MultiIndex**: all named index levels are used
|
|
99
|
+
3. **Column heuristic**: searches combinations of up to 3 non-float columns with full cardinality (every row is unique in that combination)
|
|
100
|
+
4. **Positional fallback**: if no unique key is found, rows are compared positionally
|
|
101
|
+
|
|
102
|
+
You can also pass `join_columns` explicitly by calling `compare_dataframes` directly from `pandas_utils`.
|
|
103
|
+
|
|
104
|
+
## Terminal output
|
|
105
|
+
|
|
106
|
+
At the end of the session a regression summary is printed:
|
|
107
|
+
|
|
108
|
+
```
|
|
109
|
+
========================================================================
|
|
110
|
+
REGRESSION COMPARISON SUMMARY
|
|
111
|
+
========================================================================
|
|
112
|
+
PASSED tests/test_revenue.py::test_revenue_calculation
|
|
113
|
+
FAILED tests/test_model.py::test_model_accuracy
|
|
114
|
+
Float mismatch:
|
|
115
|
+
head: 0.923
|
|
116
|
+
base: 0.941
|
|
117
|
+
------------------------------------------------------------------------
|
|
118
|
+
1 passed, 1 failed (2 total regression comparisons)
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
## How branch switching works
|
|
122
|
+
|
|
123
|
+
The plugin uses `git worktree add` to check out `BASE_BRANCH` into a temporary directory — your working tree is never touched. The worktree is cleaned up automatically after the session.
|
|
124
|
+
|
|
125
|
+
```
|
|
126
|
+
HEAD tests run ─────────────────────────▶ sessionfinish
|
|
127
|
+
│
|
|
128
|
+
git worktree add ──▶ BASE tests run in parallel ────┘ compare
|
|
129
|
+
```
|
|
130
|
+
|
|
131
|
+
## Requirements
|
|
132
|
+
|
|
133
|
+
| Package | Required | Purpose |
|
|
134
|
+
|---|---|---|
|
|
135
|
+
| `pytest >= 7.0` | Yes | Core |
|
|
136
|
+
| `cloudpickle >= 3.0` | Yes | Serialization of return values |
|
|
137
|
+
| `pandas >= 1.5` | Yes | DataFrame/Series support |
|
|
138
|
+
| `datacompy >= 0.9` | Optional | Rich DataFrame diff reports |
|
|
139
|
+
| `pyarrow >= 10.0` | Optional | Parquet storage for large DataFrames |
|
|
140
|
+
|
|
141
|
+
## Caveats
|
|
142
|
+
|
|
143
|
+
- The base branch subprocess uses the same Python environment as HEAD — if your project uses `tox` or `nox`, point to the correct environment
|
|
144
|
+
- Session-scoped fixtures with side effects (e.g. starting a server) will run twice — once in the HEAD session and once in the BASE session
|
|
145
|
+
- Tests that fail on HEAD are not compared (no base result is fetched for them)
|
|
146
|
+
- Tests that fail on BASE produce a "base branch test failed, cannot compare" warning
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
# pytest-drift
|
|
2
|
+
|
|
3
|
+
A pytest plugin for regression testing via branch comparison. When a test returns a value, the plugin runs the same test on a base git branch and compares the results — catching regressions before they merge.
|
|
4
|
+
|
|
5
|
+
## How it works
|
|
6
|
+
|
|
7
|
+
1. You run `pytest --drift BASE_BRANCH`
|
|
8
|
+
2. For every test that **returns a non-None value**, the plugin:
|
|
9
|
+
- Records the return value from the current branch (HEAD)
|
|
10
|
+
- Simultaneously runs the same tests on `BASE_BRANCH` in a git worktree
|
|
11
|
+
- Compares the two results at the end of the session
|
|
12
|
+
3. Tests returning `None` (the default for normal pytest tests) are ignored entirely
|
|
13
|
+
|
|
14
|
+
The base branch runs in parallel with your HEAD tests, so total wall time is approximately `max(HEAD_time, BASE_time)` rather than `HEAD_time + BASE_time`.
|
|
15
|
+
|
|
16
|
+
## Installation
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
pip install pytest-drift
|
|
20
|
+
|
|
21
|
+
# With smart DataFrame diff reports (recommended):
|
|
22
|
+
pip install "pytest-drift[datacompy]"
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
## Usage
|
|
26
|
+
|
|
27
|
+
### CLI flag
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
pytest --drift main
|
|
31
|
+
pytest --drift origin/main
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
### Environment variable
|
|
35
|
+
|
|
36
|
+
```bash
|
|
37
|
+
export PYTEST_DRIFT_BASE_BRANCH=main
|
|
38
|
+
pytest
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
## Writing regression tests
|
|
42
|
+
|
|
43
|
+
Return a value from your test — that's it:
|
|
44
|
+
|
|
45
|
+
```python
|
|
46
|
+
def test_revenue_calculation():
|
|
47
|
+
df = compute_revenue(load_data())
|
|
48
|
+
return df # compared against the same function on BASE_BRANCH
|
|
49
|
+
|
|
50
|
+
def test_model_accuracy():
|
|
51
|
+
return evaluate_model() # compared as a float
|
|
52
|
+
|
|
53
|
+
def test_pipeline_output():
|
|
54
|
+
return run_pipeline() # compared as a dict, list, DataFrame, etc.
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
Normal tests (returning `None`) are unaffected and run as usual.
|
|
58
|
+
|
|
59
|
+
## Comparison logic
|
|
60
|
+
|
|
61
|
+
The plugin dispatches comparison based on the return type:
|
|
62
|
+
|
|
63
|
+
| Type | Comparison method |
|
|
64
|
+
|---|---|
|
|
65
|
+
| `pd.DataFrame` | Auto-detects join columns; uses `datacompy` if installed, else `pd.testing.assert_frame_equal` |
|
|
66
|
+
| `pd.Series` | Converted to DataFrame, same path as above |
|
|
67
|
+
| `float` / `np.floating` | `math.isclose` with `rel_tol=1e-5, abs_tol=1e-8` (NaN equals NaN) |
|
|
68
|
+
| `np.ndarray` | `np.testing.assert_array_almost_equal` (5 decimal places) |
|
|
69
|
+
| `dict` | Recursive key-by-key comparison |
|
|
70
|
+
| `list` / `tuple` | Element-wise comparison |
|
|
71
|
+
| Everything else | `==`, with `repr()` diff on failure |
|
|
72
|
+
|
|
73
|
+
### Pandas index auto-detection
|
|
74
|
+
|
|
75
|
+
When comparing DataFrames, the plugin automatically finds the best join key:
|
|
76
|
+
|
|
77
|
+
1. **Named index**: if the DataFrame already has a named (non-RangeIndex) index, it's used directly
|
|
78
|
+
2. **MultiIndex**: all named index levels are used
|
|
79
|
+
3. **Column heuristic**: searches combinations of up to 3 non-float columns with full cardinality (every row is unique in that combination)
|
|
80
|
+
4. **Positional fallback**: if no unique key is found, rows are compared positionally
|
|
81
|
+
|
|
82
|
+
You can also pass `join_columns` explicitly by calling `compare_dataframes` directly from `pandas_utils`.
|
|
83
|
+
|
|
84
|
+
## Terminal output
|
|
85
|
+
|
|
86
|
+
At the end of the session a regression summary is printed:
|
|
87
|
+
|
|
88
|
+
```
|
|
89
|
+
========================================================================
|
|
90
|
+
REGRESSION COMPARISON SUMMARY
|
|
91
|
+
========================================================================
|
|
92
|
+
PASSED tests/test_revenue.py::test_revenue_calculation
|
|
93
|
+
FAILED tests/test_model.py::test_model_accuracy
|
|
94
|
+
Float mismatch:
|
|
95
|
+
head: 0.923
|
|
96
|
+
base: 0.941
|
|
97
|
+
------------------------------------------------------------------------
|
|
98
|
+
1 passed, 1 failed (2 total regression comparisons)
|
|
99
|
+
```
|
|
100
|
+
|
|
101
|
+
## How branch switching works
|
|
102
|
+
|
|
103
|
+
The plugin uses `git worktree add` to check out `BASE_BRANCH` into a temporary directory — your working tree is never touched. The worktree is cleaned up automatically after the session.
|
|
104
|
+
|
|
105
|
+
```
|
|
106
|
+
HEAD tests run ─────────────────────────▶ sessionfinish
|
|
107
|
+
│
|
|
108
|
+
git worktree add ──▶ BASE tests run in parallel ────┘ compare
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
## Requirements
|
|
112
|
+
|
|
113
|
+
| Package | Required | Purpose |
|
|
114
|
+
|---|---|---|
|
|
115
|
+
| `pytest >= 7.0` | Yes | Core |
|
|
116
|
+
| `cloudpickle >= 3.0` | Yes | Serialization of return values |
|
|
117
|
+
| `pandas >= 1.5` | Yes | DataFrame/Series support |
|
|
118
|
+
| `datacompy >= 0.9` | Optional | Rich DataFrame diff reports |
|
|
119
|
+
| `pyarrow >= 10.0` | Optional | Parquet storage for large DataFrames |
|
|
120
|
+
|
|
121
|
+
## Caveats
|
|
122
|
+
|
|
123
|
+
- The base branch subprocess uses the same Python environment as HEAD — if your project uses `tox` or `nox`, point to the correct environment
|
|
124
|
+
- Session-scoped fixtures with side effects (e.g. starting a server) will run twice — once in the HEAD session and once in the BASE session
|
|
125
|
+
- Tests that fail on HEAD are not compared (no base result is fetched for them)
|
|
126
|
+
- Tests that fail on BASE produce a "base branch test failed, cannot compare" warning
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "pytest-drift"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Pytest plugin for regression testing via branch comparison"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"pytest>=7.0",
|
|
13
|
+
"cloudpickle>=3.0",
|
|
14
|
+
"pandas>=1.5",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[project.optional-dependencies]
|
|
18
|
+
datacompy = ["datacompy>=0.9"]
|
|
19
|
+
parquet = ["pyarrow>=10.0"]
|
|
20
|
+
dev = [
|
|
21
|
+
"pytest",
|
|
22
|
+
"pandas",
|
|
23
|
+
"numpy",
|
|
24
|
+
"datacompy>=0.9",
|
|
25
|
+
"pyarrow>=10.0",
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
[project.entry-points.pytest11]
|
|
29
|
+
drift = "pytest_drift.plugin"
|
|
30
|
+
|
|
31
|
+
[tool.setuptools.packages.find]
|
|
32
|
+
where = ["."]
|
|
33
|
+
include = ["pytest_drift*"]
|
|
File without changes
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""Type-dispatching comparison logic for test return values."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import math
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from .pandas_utils import ComparisonResult, compare_dataframes, compare_series
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def compare_values(head: Any, base: Any, node_id: str = "") -> ComparisonResult:
    """
    Compare a HEAD (current branch) return value against its BASE counterpart.

    The actual comparison is delegated to the type-dispatching ``_dispatch``;
    the resulting ``ComparisonResult`` is then tagged with ``node_id`` before
    being handed back to the caller.

    Args:
        head: Value returned by the test on the current branch.
        base: Value returned by the same test on the base branch.
        node_id: Identifier stored on the result (empty string by default).

    Returns:
        The ``ComparisonResult`` produced by ``_dispatch``, with ``node_id`` set.
    """
    outcome = _dispatch(head, base)
    outcome.node_id = node_id
    return outcome
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _dispatch(head: Any, base: Any) -> ComparisonResult:
    """
    Route a (head, base) pair to the comparison routine for its type.

    Checks run in this order:
      1. pandas DataFrame/Series. A pandas object on only one side, or a
         DataFrame paired with a Series, is reported as a type mismatch.
      2. numpy arrays.
      3. Float-like scalars. This covers Python ``float`` (including
         ``np.float64``, which subclasses it) and numpy floating scalars,
         in any combination.
      4. dicts, then lists/tuples; both recurse back into this function.
      5. A generic ``==`` fallback.

    pandas and numpy are optional here. Their branches are skipped when the
    import fails; only the import itself is guarded, so errors raised by the
    comparison routines are not swallowed.
    """
    try:
        import pandas as pd
    except ImportError:
        pd = None
    try:
        import numpy as np
    except ImportError:
        np = None

    if pd is not None:
        frame_types = (pd.DataFrame, pd.Series)
        if isinstance(head, pd.DataFrame) and isinstance(base, pd.DataFrame):
            return compare_dataframes(head, base)
        if isinstance(head, pd.Series) and isinstance(base, pd.Series):
            return compare_series(head, base)
        if isinstance(head, frame_types) or isinstance(base, frame_types):
            return ComparisonResult(
                equal=False,
                report=f"Type mismatch: head={type(head).__name__}, base={type(base).__name__}",
            )

    if np is not None and isinstance(head, np.ndarray) and isinstance(base, np.ndarray):
        return _compare_arrays(head, base)

    # Plain Python floats (np.float64 included) are passed through unchanged.
    if isinstance(head, float) and isinstance(base, float):
        return _compare_floats(head, base)

    if np is not None:
        float_like = (float, np.floating)
        if isinstance(head, float_like) and isinstance(base, float_like):
            # At least one side is a non-Python numpy float (e.g. np.float32).
            # Normalise both sides to float so that mixed pairs also get the
            # tolerant comparison instead of falling through to exact ``==``.
            return _compare_floats(float(head), float(base))

    if isinstance(head, dict) and isinstance(base, dict):
        return _compare_dicts(head, base)

    if isinstance(head, (list, tuple)) and isinstance(base, (list, tuple)):
        return _compare_sequences(head, base)

    return _compare_generic(head, base)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _compare_floats(head: float, base: float, rtol: float = 1e-5, atol: float = 1e-8) -> ComparisonResult:
    """
    Compare two floats using ``math.isclose`` with relative/absolute tolerance.

    Two NaNs are treated as equal. On its own, ``math.isclose`` never
    considers NaN close to anything.

    Args:
        head: Value from the current branch.
        base: Value from the base branch.
        rtol: Relative tolerance (``rel_tol`` of ``math.isclose``).
        atol: Absolute tolerance (``abs_tol`` of ``math.isclose``).
    """
    both_nan = math.isnan(head) and math.isnan(base)
    if both_nan or math.isclose(head, base, rel_tol=rtol, abs_tol=atol):
        return ComparisonResult(equal=True, report=None)
    return ComparisonResult(equal=False, report=f"Float mismatch: head={head!r}, base={base!r}")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _compare_arrays(head, base) -> ComparisonResult:
    """
    Compare two numpy arrays.

    Arrays with different shapes are reported as a shape mismatch. Otherwise
    they are compared to 5 decimal places with
    ``np.testing.assert_array_almost_equal``.

    That check subtracts the arrays, which raises ``TypeError`` for dtypes
    without arithmetic (e.g. string arrays). Such arrays are compared exactly
    with ``np.testing.assert_array_equal`` instead of letting the error crash
    the comparison.

    Returns:
        ComparisonResult whose ``report`` is numpy's assertion message on
        mismatch.
    """
    import numpy as np

    if head.shape != base.shape:
        return ComparisonResult(
            equal=False,
            report=f"Shape mismatch: head={head.shape}, base={base.shape}",
        )

    try:
        try:
            np.testing.assert_array_almost_equal(head, base, decimal=5)
        except TypeError:
            # Non-arithmetic dtype: tolerance makes no sense, compare exactly.
            np.testing.assert_array_equal(head, base)
    except AssertionError as e:
        return ComparisonResult(equal=False, report=str(e))
    return ComparisonResult(equal=True, report=None)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _compare_dicts(head: dict, base: dict) -> ComparisonResult:
    """
    Compare two dicts key by key.

    If the key sets differ, only the key difference is reported; values are
    not compared. Otherwise each value pair is compared recursively through
    ``_dispatch``, visiting keys in ``str`` order. Every mismatching key is
    collected into a single report.
    """
    head_keys = set(head.keys())
    base_keys = set(base.keys())

    if head_keys != base_keys:
        key_diffs = [
            f"Keys only in {side}: {sorted(str(k) for k in extra)}"
            for side, extra in (
                ("head", head_keys - base_keys),
                ("base", base_keys - head_keys),
            )
            if extra
        ]
        return ComparisonResult(equal=False, report="\n".join(key_diffs))

    problems = []
    for key in sorted(head_keys, key=str):
        verdict = _dispatch(head[key], base[key])
        if not verdict.equal:
            problems.append(f" Key {key!r}: {verdict.report}")

    if not problems:
        return ComparisonResult(equal=True, report=None)
    return ComparisonResult(
        equal=False, report="Dict mismatches:\n" + "\n".join(problems)
    )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _compare_sequences(head, base) -> ComparisonResult:
    """
    Compare two lists/tuples element by element.

    A length difference is reported immediately. Otherwise every element pair
    is compared recursively via ``_dispatch``. The report lists each
    mismatching index and is headed by the type name of ``head``.
    """
    if len(head) != len(base):
        return ComparisonResult(
            equal=False,
            report=f"Length mismatch: head={len(head)}, base={len(base)}",
        )

    verdicts = (_dispatch(h, b) for h, b in zip(head, base))
    problems = [
        f" Index {idx}: {verdict.report}"
        for idx, verdict in enumerate(verdicts)
        if not verdict.equal
    ]

    if not problems:
        return ComparisonResult(equal=True, report=None)
    return ComparisonResult(
        equal=False,
        report=f"{type(head).__name__} mismatches:\n" + "\n".join(problems),
    )
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _compare_generic(head: Any, base: Any) -> ComparisonResult:
    """
    Fallback comparison using ``==``; the report shows both ``repr`` values.

    If ``==`` raises, or returns something whose truth value cannot be taken
    (e.g. an element-wise array), the values are treated as unequal.
    """
    try:
        same = bool(head == base)
    except Exception:
        same = False

    if not same:
        return ComparisonResult(
            equal=False,
            report=f"Value mismatch:\n head: {head!r}\n base: {base!r}",
        )
    return ComparisonResult(equal=True, report=None)
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Pandas DataFrame/Series index auto-detection and comparison."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import itertools
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
import pandas as pd
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class ComparisonResult:
    """Outcome of comparing one HEAD return value with its BASE counterpart."""

    # True when the two values were judged equal.
    equal: bool
    # Human-readable description of the difference. The comparison helpers
    # set this to None when ``equal`` is True.
    report: str | None
    # Identifier of the test this result belongs to; empty unless the caller sets it.
    node_id: str = ""
    # Free-form additional data; every instance gets its own fresh dict.
    extra: dict = field(default_factory=dict)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def detect_index_columns(df: "pd.DataFrame", max_combo_size: int = 3) -> list[str] | None:
    """
    Auto-detect which columns can serve as join keys for comparison.

    Priority:
      1. A non-RangeIndex that is already set and named. For a MultiIndex,
         every level must be named.
      2. Heuristic: the smallest combination of up to ``max_combo_size``
         non-float columns whose value combinations are unique across all
         rows. Candidates are tried in descending cardinality order.
      3. None: the caller falls back to positional comparison.

    Columns whose values are unhashable (e.g. lists or dicts) cannot be
    grouped on, so they are skipped rather than raising.

    Args:
        df: DataFrame to inspect.
        max_combo_size: Largest number of columns tried as one composite key.

    Returns:
        A list of index level / column names, or None if no key was found.
    """
    import math

    import pandas as pd

    # Case A: already has a meaningful named index
    if not isinstance(df.index, pd.RangeIndex):
        if isinstance(df.index, pd.MultiIndex):
            if all(name is not None for name in df.index.names):
                return list(df.index.names)
        elif df.index.name is not None:
            return [df.index.name]

    n = len(df)
    if n == 0 or len(df.columns) == 0:
        return None

    # Case B: heuristic column search. Float columns are excluded because
    # they make poor join keys. Cardinality is computed once per column;
    # nunique() raises TypeError on unhashable values, and such columns
    # cannot be grouped on anyway.
    cardinality = {}
    for col in df.columns:
        if pd.api.types.is_float_dtype(df[col].dtype):
            continue
        try:
            cardinality[col] = df[col].nunique()
        except TypeError:
            continue

    # Higher cardinality first: better key candidates are tried earlier.
    # sorted() is stable, so ties keep their column order.
    candidate_cols = sorted(cardinality, key=cardinality.__getitem__, reverse=True)

    for r in range(1, min(max_combo_size, len(candidate_cols)) + 1):
        for combo in itertools.combinations(candidate_cols, r):
            # Distinct non-null key tuples can never exceed the product of
            # the per-column distinct counts, so such combos cannot be unique.
            if math.prod(cardinality[c] for c in combo) < n:
                continue
            try:
                # observed=True: with categorical columns, count only the
                # combinations that actually occur, not every category pairing.
                if df.groupby(list(combo), observed=True).ngroups == n:
                    return list(combo)
            except Exception:
                # Best-effort heuristic: an ungroupable combo is simply skipped.
                continue

    return None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _reset_named_index(df: "pd.DataFrame") -> "pd.DataFrame":
|
|
62
|
+
"""If df has a named index, reset it to columns."""
|
|
63
|
+
import pandas as pd
|
|
64
|
+
|
|
65
|
+
if not isinstance(df.index, pd.RangeIndex):
|
|
66
|
+
return df.reset_index()
|
|
67
|
+
return df
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def compare_dataframes(
    head_df: "pd.DataFrame",
    base_df: "pd.DataFrame",
    join_columns: list[str] | None = None,
) -> ComparisonResult:
    """
    Compare two DataFrames, auto-detecting join columns if not provided.

    When ``join_columns`` is None it is resolved as follows:
      * If ``head_df`` has a unique non-RangeIndex whose levels are all
        named, those index names are used. The index is moved into columns
        before comparing.
      * Otherwise ``detect_index_columns`` runs on the flattened head frame.
      * If still nothing is found, rows are compared positionally.

    The comparison uses ``datacompy.Compare`` when datacompy is importable
    and exposes ``Compare``, and ``pd.testing.assert_frame_equal`` otherwise.
    Both paths use the same numeric tolerance (relative 1e-5, absolute 1e-8).

    Returns:
        ComparisonResult whose ``report`` carries the diff text when the
        frames differ. A join column missing from either frame is reported as
        a mismatch instead of raising.
    """
    head_flat = _reset_named_index(head_df)
    base_flat = _reset_named_index(base_df)

    if join_columns is None:
        # Inspect head_df *before* flattening. After reset_index() the frame
        # has a RangeIndex, so the named-index case could never apply.
        join_columns = _named_index_keys(head_df) or detect_index_columns(head_flat)

    if join_columns:
        missing = _describe_missing_join_columns(join_columns, head_flat, base_flat)
        if missing is not None:
            return ComparisonResult(equal=False, report=missing)

    try:
        import datacompy
    except ImportError:
        datacompy = None

    if datacompy is not None and hasattr(datacompy, "Compare"):
        return _compare_with_datacompy(datacompy, head_flat, base_flat, join_columns)
    return _compare_with_pandas(head_flat, base_flat, join_columns)


def _named_index_keys(df: "pd.DataFrame") -> list | None:
    """Return df's index level names if its index is a unique, fully named non-RangeIndex."""
    import pandas as pd

    # A non-unique index would not identify rows, so leave it to the heuristic.
    if isinstance(df.index, pd.RangeIndex) or not df.index.is_unique:
        return None
    names = list(df.index.names)
    return names if all(name is not None for name in names) else None


def _describe_missing_join_columns(
    join_columns: list, head_flat: "pd.DataFrame", base_flat: "pd.DataFrame"
) -> str | None:
    """Describe join columns absent from either frame, or return None if all are present."""
    parts = []
    for side, frame in (("head", head_flat), ("base", base_flat)):
        absent = [c for c in join_columns if c not in frame.columns]
        if absent:
            parts.append(f"Join columns missing from {side}: {absent}")
    return "\n".join(parts) if parts else None


def _compare_with_datacompy(
    datacompy, head_flat: "pd.DataFrame", base_flat: "pd.DataFrame", join_columns
) -> ComparisonResult:
    """Compare via datacompy.Compare, adding a positional row key when no join columns are known."""
    if not join_columns:
        # No key found: join on a synthetic row-number column instead.
        head_flat = head_flat.copy()
        base_flat = base_flat.copy()
        head_flat["__row__"] = range(len(head_flat))
        base_flat["__row__"] = range(len(base_flat))
        join_columns = ["__row__"]

    cmp = datacompy.Compare(
        head_flat,
        base_flat,
        join_columns=join_columns,
        # Match the pandas fallback's tolerance; datacompy defaults to exact.
        abs_tol=1e-8,
        rel_tol=1e-5,
        df1_name="head",
        df2_name="base",
    )
    equal = cmp.matches()
    return ComparisonResult(
        equal=equal,
        report=None if equal else cmp.report(),
    )


def _compare_with_pandas(
    head_flat: "pd.DataFrame", base_flat: "pd.DataFrame", join_columns
) -> ComparisonResult:
    """Compare with pd.testing.assert_frame_equal, aligning on join columns when given."""
    import pandas as pd

    if join_columns:
        head_sorted = head_flat.set_index(join_columns).sort_index()
        base_sorted = base_flat.set_index(join_columns).sort_index()
    else:
        head_sorted = head_flat.reset_index(drop=True)
        base_sorted = base_flat.reset_index(drop=True)

    try:
        pd.testing.assert_frame_equal(
            head_sorted, base_sorted, check_like=True, rtol=1e-5, atol=1e-8
        )
    except AssertionError as e:
        return ComparisonResult(equal=False, report=str(e))
    return ComparisonResult(equal=True, report=None)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def compare_series(head_s: "pd.Series", base_s: "pd.Series") -> ComparisonResult:
    """
    Compare two Series by converting each into a one-column DataFrame.

    The column takes the Series' own name. Only an unnamed Series
    (``name is None``) gets the placeholder column name ``"value"``.
    Falsy but legitimate names such as ``0`` or ``""`` are preserved.
    The resulting frames are compared with ``compare_dataframes``.
    """
    head_df = head_s.to_frame(name="value" if head_s.name is None else head_s.name)
    base_df = base_s.to_frame(name="value" if base_s.name is None else base_s.name)
    return compare_dataframes(head_df, base_df)
|