pytest-drift 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,146 @@
Metadata-Version: 2.4
Name: pytest-drift
Version: 0.1.0
Summary: Pytest plugin for regression testing via branch comparison
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: pytest>=7.0
Requires-Dist: cloudpickle>=3.0
Requires-Dist: pandas>=1.5
Provides-Extra: datacompy
Requires-Dist: datacompy>=0.9; extra == "datacompy"
Provides-Extra: parquet
Requires-Dist: pyarrow>=10.0; extra == "parquet"
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: pandas; extra == "dev"
Requires-Dist: numpy; extra == "dev"
Requires-Dist: datacompy>=0.9; extra == "dev"
Requires-Dist: pyarrow>=10.0; extra == "dev"

# pytest-drift

A pytest plugin for regression testing via branch comparison. When a test returns a value, the plugin runs the same test on a base git branch and compares the results — catching regressions before they merge.

## How it works

1. You run `pytest --drift BASE_BRANCH`
2. For every test that **returns a non-None value**, the plugin:
   - Records the return value from the current branch (HEAD)
   - Simultaneously runs the same tests on `BASE_BRANCH` in a git worktree
   - Compares the two results at the end of the session
3. Tests returning `None` (the default for normal pytest tests) are ignored entirely

The base branch runs in parallel with your HEAD tests, so total wall time is approximately `max(HEAD_time, BASE_time)` rather than `HEAD_time + BASE_time`.
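The overlap described above can be sketched with stdlib threading; `run_suite` and its sleep durations are illustrative stand-ins for shelling out to pytest in each checkout:

```python
import threading
import time

def run_suite(label: str, duration: float, results: dict) -> None:
    # Stand-in for "run pytest in this checkout"; the real plugin shells
    # out to pytest, but the overlap works the same way.
    time.sleep(duration)
    results[label] = "done"

results: dict = {}
base = threading.Thread(target=run_suite, args=("BASE", 0.1, results))
base.start()                      # BASE suite starts in the background
run_suite("HEAD", 0.1, results)   # HEAD suite runs in the foreground
base.join()                       # wait for BASE before comparing
```

Total elapsed time here is roughly one `duration`, not two, which is the whole point of running both suites concurrently.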
35
+
36
+ ## Installation
37
+
38
+ ```bash
39
+ pip install pytest-drift
40
+
41
+ # With smart DataFrame diff reports (recommended):
42
+ pip install "pytest-drift[datacompy]"
43
+ ```
44
+
45
+ ## Usage
46
+
47
+ ### CLI flag
48
+
49
+ ```bash
50
+ pytest --drift main
51
+ pytest --drift origin/main
52
+ ```
53
+
54
+ ### Environment variable
55
+
56
+ ```bash
57
+ export PYTEST_DRIFT_BASE_BRANCH=main
58
+ pytest
59
+ ```
60
+
61
+ ## Writing regression tests
62
+
63
+ Return a value from your test — that's it:
64
+
65
+ ```python
66
+ def test_revenue_calculation():
67
+ df = compute_revenue(load_data())
68
+ return df # compared against the same function on BASE_BRANCH
69
+
70
+ def test_model_accuracy():
71
+ return evaluate_model() # compared as a float
72
+
73
+ def test_pipeline_output():
74
+ return run_pipeline() # compared as a dict, list, DataFrame, etc.
75
+ ```
76
+
77
+ Normal tests (returning `None`) are unaffected and run as usual.
78
+
79
+ ## Comparison logic
80
+
81
+ The plugin dispatches comparison based on the return type:
82
+
83
+ | Type | Comparison method |
84
+ |---|---|
85
+ | `pd.DataFrame` | Auto-detects join columns; uses `datacompy` if installed, else `pd.testing.assert_frame_equal` |
86
+ | `pd.Series` | Converted to DataFrame, same path as above |
87
+ | `float` / `np.floating` | `math.isclose` with `rtol=1e-5, atol=1e-8` |
88
+ | `np.ndarray` | `np.testing.assert_array_almost_equal` (5 decimal places) |
89
+ | `dict` | Recursive key-by-key comparison |
90
+ | `list` / `tuple` | Element-wise comparison |
91
+ | Everything else | `==`, with `repr()` diff on failure |
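The float row, for instance, boils down to a single `math.isclose` call (the helper name below is illustrative, not the plugin's internal name):

```python
import math

# The documented float rule: equal within rel_tol=1e-5 or abs_tol=1e-8.
def floats_match(head: float, base: float) -> bool:
    return math.isclose(head, base, rel_tol=1e-5, abs_tol=1e-8)

floats_match(1.000001, 1.0000015)  # tiny relative drift: treated as equal
floats_match(0.923, 0.941)         # real drift: flagged as a regression
```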
92
+
93
+ ### Pandas index auto-detection
94
+
95
+ When comparing DataFrames, the plugin automatically finds the best join key:
96
+
97
+ 1. **Named index**: if the DataFrame already has a named (non-RangeIndex) index, it's used directly
98
+ 2. **MultiIndex**: all named index levels are used
99
+ 3. **Column heuristic**: searches combinations of up to 3 non-float columns with full cardinality (every row is unique in that combination)
100
+ 4. **Positional fallback**: if no unique key is found, rows are compared positionally
101
+
102
+ You can also pass `join_columns` explicitly by calling `compare_dataframes` directly from `pandas_utils`.
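The column heuristic in step 3 can be sketched as follows; this is a simplified version of what `detect_index_columns` does (the cardinality pre-sort is omitted), and the example DataFrame is illustrative:

```python
import itertools
import pandas as pd

def find_join_key(df: pd.DataFrame, max_combo_size: int = 3):
    # Exclude float columns (poor join keys), then look for the smallest
    # column combination under which every row is unique.
    cols = [c for c in df.columns if not pd.api.types.is_float_dtype(df[c].dtype)]
    for r in range(1, min(max_combo_size, len(cols)) + 1):
        for combo in itertools.combinations(cols, r):
            if df.groupby(list(combo)).ngroups == len(df):
                return list(combo)
    return None  # no unique key: fall back to positional comparison

df = pd.DataFrame({
    "region": ["EU", "EU", "US", "US"],
    "year": [2023, 2024, 2023, 2024],
    "revenue": [1.0, 2.0, 3.0, 4.0],  # float column, skipped as a key
})
```

Here neither `region` nor `year` alone is unique, but the pair is, so the pair becomes the join key.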
103
+
104
+ ## Terminal output
105
+
106
+ At the end of the session a regression summary is printed:
107
+
108
+ ```
109
+ ========================================================================
110
+ REGRESSION COMPARISON SUMMARY
111
+ ========================================================================
112
+ PASSED tests/test_revenue.py::test_revenue_calculation
113
+ FAILED tests/test_model.py::test_model_accuracy
114
+ Float mismatch:
115
+ head: 0.923
116
+ base: 0.941
117
+ ------------------------------------------------------------------------
118
+ 1 passed, 1 failed (2 total regression comparisons)
119
+ ```
120
+
121
+ ## How branch switching works
122
+
123
+ The plugin uses `git worktree add` to check out `BASE_BRANCH` into a temporary directory — your working tree is never touched. The worktree is cleaned up automatically after the session.
124
+
125
+ ```
126
+ HEAD tests run ─────────────────────────▶ sessionfinish
127
+
128
+ git worktree add ──▶ BASE tests run in parallel ────┘ compare
129
+ ```
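The worktree lifecycle the plugin relies on can be reproduced by hand. The snippet below is a self-contained sketch (the throwaway repo, paths, and commit are illustrative, not what the plugin literally runs):

```bash
repo=$(mktemp -d)
wt="${repo}-base"
cd "$repo"
git init -q
git -c user.name=ci -c user.email=ci@example.com commit -q --allow-empty -m init
git worktree add "$wt" HEAD          # detached checkout of the base commit
git -C "$wt" rev-parse --short HEAD  # the BASE suite would run inside "$wt"
git worktree remove "$wt"            # cleanup, as the plugin does at session end
```

Because the checkout lands in a separate directory, the main working tree (including uncommitted changes) is never disturbed.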
130
+
131
+ ## Requirements
132
+
133
+ | Package | Required | Purpose |
134
+ |---|---|---|
135
+ | `pytest >= 7.0` | Yes | Core |
136
+ | `cloudpickle >= 3.0` | Yes | Serialization of return values |
137
+ | `pandas >= 1.5` | Yes | DataFrame/Series support |
138
+ | `datacompy >= 0.9` | Optional | Rich DataFrame diff reports |
139
+ | `pyarrow >= 10.0` | Optional | Parquet storage for large DataFrames |
140
+
141
+ ## Caveats
142
+
143
+ - The base branch subprocess uses the same Python environment as HEAD — if your project uses `tox` or `nox`, point to the correct environment
144
+ - Session-scoped fixtures with side effects (e.g. starting a server) will run twice — once per session
145
+ - Tests that fail on HEAD are not compared (no base result is fetched for them)
146
+ - Tests that fail on BASE produce a "base branch test failed, cannot compare" warning
@@ -0,0 +1,33 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "pytest-drift"
version = "0.1.0"
description = "Pytest plugin for regression testing via branch comparison"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "pytest>=7.0",
    "cloudpickle>=3.0",
    "pandas>=1.5",
]

[project.optional-dependencies]
datacompy = ["datacompy>=0.9"]
parquet = ["pyarrow>=10.0"]
dev = [
    "pytest",
    "pandas",
    "numpy",
    "datacompy>=0.9",
    "pyarrow>=10.0",
]

[project.entry-points.pytest11]
drift = "pytest_drift.plugin"

[tool.setuptools.packages.find]
where = ["."]
include = ["pytest_drift*"]
@@ -0,0 +1,154 @@
"""Type-dispatching comparison logic for test return values."""
from __future__ import annotations

import math
from typing import Any

from .pandas_utils import ComparisonResult, compare_dataframes, compare_series


def compare_values(head: Any, base: Any, node_id: str = "") -> ComparisonResult:
    """
    Compare head (current branch) and base (base branch) return values.
    Dispatches based on type.
    """
    result = _dispatch(head, base)
    result.node_id = node_id
    return result


def _dispatch(head: Any, base: Any) -> ComparisonResult:
    # Check pandas types first (before generic checks)
    try:
        import pandas as pd

        if isinstance(head, pd.DataFrame) and isinstance(base, pd.DataFrame):
            return compare_dataframes(head, base)
        if isinstance(head, pd.Series) and isinstance(base, pd.Series):
            return compare_series(head, base)
        if isinstance(head, (pd.DataFrame, pd.Series)) or isinstance(
            base, (pd.DataFrame, pd.Series)
        ):
            return ComparisonResult(
                equal=False,
                report=f"Type mismatch: head={type(head).__name__}, base={type(base).__name__}",
            )
    except ImportError:
        pass

    # numpy arrays
    try:
        import numpy as np

        if isinstance(head, np.ndarray) and isinstance(base, np.ndarray):
            return _compare_arrays(head, base)
    except ImportError:
        pass

    # float scalars
    if isinstance(head, float) and isinstance(base, float):
        return _compare_floats(head, base)

    # numpy scalars that are float-like
    try:
        import numpy as np

        if isinstance(head, np.floating) and isinstance(base, np.floating):
            return _compare_floats(float(head), float(base))
    except ImportError:
        pass

    # dict
    if isinstance(head, dict) and isinstance(base, dict):
        return _compare_dicts(head, base)

    # list / tuple
    if isinstance(head, (list, tuple)) and isinstance(base, (list, tuple)):
        return _compare_sequences(head, base)

    # generic fallback
    return _compare_generic(head, base)


def _compare_floats(head: float, base: float, rtol: float = 1e-5, atol: float = 1e-8) -> ComparisonResult:
    if math.isnan(head) and math.isnan(base):
        return ComparisonResult(equal=True, report=None)
    equal = math.isclose(head, base, rel_tol=rtol, abs_tol=atol)
    report = None if equal else f"Float mismatch: head={head!r}, base={base!r}"
    return ComparisonResult(equal=equal, report=report)


def _compare_arrays(head, base) -> ComparisonResult:
    import numpy as np

    if head.shape != base.shape:
        return ComparisonResult(
            equal=False,
            report=f"Shape mismatch: head={head.shape}, base={base.shape}",
        )
    try:
        np.testing.assert_array_almost_equal(head, base, decimal=5)
        return ComparisonResult(equal=True, report=None)
    except AssertionError as e:
        return ComparisonResult(equal=False, report=str(e))


def _compare_dicts(head: dict, base: dict) -> ComparisonResult:
    head_keys = set(head.keys())
    base_keys = set(base.keys())

    if head_keys != base_keys:
        only_head = head_keys - base_keys
        only_base = base_keys - head_keys
        parts = []
        if only_head:
            parts.append(f"Keys only in head: {sorted(str(k) for k in only_head)}")
        if only_base:
            parts.append(f"Keys only in base: {sorted(str(k) for k in only_base)}")
        return ComparisonResult(equal=False, report="\n".join(parts))

    mismatches = []
    for key in sorted(head_keys, key=str):
        sub = _dispatch(head[key], base[key])
        if not sub.equal:
            mismatches.append(f"  Key {key!r}: {sub.report}")

    if mismatches:
        return ComparisonResult(
            equal=False, report="Dict mismatches:\n" + "\n".join(mismatches)
        )
    return ComparisonResult(equal=True, report=None)


def _compare_sequences(head, base) -> ComparisonResult:
    if len(head) != len(base):
        return ComparisonResult(
            equal=False,
            report=f"Length mismatch: head={len(head)}, base={len(base)}",
        )
    mismatches = []
    for i, (h, b) in enumerate(zip(head, base)):
        sub = _dispatch(h, b)
        if not sub.equal:
            mismatches.append(f"  Index {i}: {sub.report}")
    if mismatches:
        return ComparisonResult(
            equal=False,
            report=f"{type(head).__name__} mismatches:\n" + "\n".join(mismatches),
        )
    return ComparisonResult(equal=True, report=None)


def _compare_generic(head: Any, base: Any) -> ComparisonResult:
    try:
        equal = bool(head == base)
    except Exception:
        equal = False

    if equal:
        return ComparisonResult(equal=True, report=None)

    return ComparisonResult(
        equal=False,
        report=f"Value mismatch:\n  head: {head!r}\n  base: {base!r}",
    )
@@ -0,0 +1,135 @@
"""Pandas DataFrame/Series index auto-detection and comparison."""
from __future__ import annotations

import itertools
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas as pd


@dataclass
class ComparisonResult:
    equal: bool
    report: str | None
    node_id: str = ""
    extra: dict = field(default_factory=dict)


def detect_index_columns(df: "pd.DataFrame", max_combo_size: int = 3) -> list[str] | None:
    """
    Auto-detect which columns can serve as join keys for comparison.

    Priority:
    1. Named non-RangeIndex (already set as index)
    2. Heuristic: find smallest combo of non-float cols with full cardinality
    3. None → fall back to positional comparison
    """
    import pandas as pd

    # Case A: already has a meaningful named index
    if not isinstance(df.index, pd.RangeIndex):
        if isinstance(df.index, pd.MultiIndex):
            if all(name is not None for name in df.index.names):
                return list(df.index.names)
        elif df.index.name is not None:
            return [df.index.name]

    n = len(df)
    if n == 0 or len(df.columns) == 0:
        return None

    # Case B: heuristic column search (exclude float columns — poor join keys)
    candidate_cols = [
        c for c in df.columns if not pd.api.types.is_float_dtype(df[c].dtype)
    ]
    # Sort by cardinality descending (higher = better key candidate)
    candidate_cols = sorted(candidate_cols, key=lambda c: df[c].nunique(), reverse=True)

    for r in range(1, min(max_combo_size + 1, len(candidate_cols) + 1)):
        for combo in itertools.combinations(candidate_cols, r):
            try:
                if df.groupby(list(combo)).ngroups == n:
                    return list(combo)
            except Exception:
                continue

    return None


def _reset_named_index(df: "pd.DataFrame") -> "pd.DataFrame":
    """If df has a named index, reset it to columns."""
    import pandas as pd

    if not isinstance(df.index, pd.RangeIndex):
        return df.reset_index()
    return df


def compare_dataframes(
    head_df: "pd.DataFrame",
    base_df: "pd.DataFrame",
    join_columns: list[str] | None = None,
) -> ComparisonResult:
    """Compare two DataFrames, auto-detecting join columns if not provided."""
    import pandas as pd

    head_flat = _reset_named_index(head_df)
    base_flat = _reset_named_index(base_df)

    if join_columns is None:
        join_columns = detect_index_columns(head_flat)

    # Try datacompy first
    try:
        import datacompy

        if not hasattr(datacompy, "Compare"):
            raise ImportError("datacompy.Compare not available")

        if join_columns is None:
            # No key found; use all columns positionally by adding a row-number key
            head_flat = head_flat.copy()
            base_flat = base_flat.copy()
            head_flat["__row__"] = range(len(head_flat))
            base_flat["__row__"] = range(len(base_flat))
            join_columns = ["__row__"]

        cmp = datacompy.Compare(
            head_flat,
            base_flat,
            join_columns=join_columns,
            df1_name="head",
            df2_name="base",
        )
        equal = cmp.matches()
        return ComparisonResult(
            equal=equal,
            report=None if equal else cmp.report(),
        )
    except ImportError:
        pass

    # Fallback: pd.testing.assert_frame_equal
    try:
        if join_columns:
            head_sorted = head_flat.set_index(join_columns).sort_index()
            base_sorted = base_flat.set_index(join_columns).sort_index()
        else:
            head_sorted = head_flat.reset_index(drop=True)
            base_sorted = base_flat.reset_index(drop=True)

        pd.testing.assert_frame_equal(
            head_sorted, base_sorted, check_like=True, rtol=1e-5
        )
        return ComparisonResult(equal=True, report=None)
    except AssertionError as e:
        return ComparisonResult(equal=False, report=str(e))


def compare_series(head_s: "pd.Series", base_s: "pd.Series") -> ComparisonResult:
    """Compare two Series by converting to DataFrames."""
    head_df = head_s.to_frame(name=head_s.name or "value")
    base_df = base_s.to_frame(name=base_s.name or "value")
    return compare_dataframes(head_df, base_df)