pytest-regtest 2.3.2__tar.gz → 2.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. pytest_regtest-2.3.4/MANIFEST.in +2 -0
  2. {pytest_regtest-2.3.2 → pytest_regtest-2.3.4}/PKG-INFO +29 -7
  3. {pytest_regtest-2.3.2 → pytest_regtest-2.3.4}/pyproject.toml +6 -11
  4. pytest_regtest-2.3.4/setup.cfg +4 -0
  5. pytest_regtest-2.3.4/src/pytest_regtest/__init__.py +107 -0
  6. pytest_regtest-2.3.4/src/pytest_regtest/numpy_handler.py +216 -0
  7. pytest_regtest-2.3.4/src/pytest_regtest/pandas_handler.py +143 -0
  8. pytest_regtest-2.3.4/src/pytest_regtest/polars_handler.py +114 -0
  9. pytest_regtest-2.3.4/src/pytest_regtest/pytest_regtest.py +656 -0
  10. pytest_regtest-2.3.4/src/pytest_regtest/register_third_party_handlers.py +43 -0
  11. pytest_regtest-2.3.4/src/pytest_regtest/snapshot_handler.py +188 -0
  12. pytest_regtest-2.3.4/src/pytest_regtest/utils.py +28 -0
  13. pytest_regtest-2.3.4/src/pytest_regtest.egg-info/PKG-INFO +113 -0
  14. pytest_regtest-2.3.4/src/pytest_regtest.egg-info/SOURCES.txt +27 -0
  15. pytest_regtest-2.3.4/src/pytest_regtest.egg-info/dependency_links.txt +1 -0
  16. pytest_regtest-2.3.4/src/pytest_regtest.egg-info/entry_points.txt +2 -0
  17. pytest_regtest-2.3.4/src/pytest_regtest.egg-info/requires.txt +23 -0
  18. pytest_regtest-2.3.4/src/pytest_regtest.egg-info/top_level.txt +1 -0
  19. pytest_regtest-2.3.4/tests/test_cli.py +69 -0
  20. pytest_regtest-2.3.4/tests/test_regtest.py +587 -0
  21. pytest_regtest-2.3.4/tests/test_snapshot.py +37 -0
  22. pytest_regtest-2.3.4/tests/test_snapshot_numpy.py +510 -0
  23. pytest_regtest-2.3.4/tests/test_snapshot_pandas.py +224 -0
  24. pytest_regtest-2.3.4/tests/test_snapshot_polars.py +214 -0
  25. pytest_regtest-2.3.4/tests/test_snapshot_python_types.py +109 -0
  26. pytest_regtest-2.3.4/tests/test_utils.py +21 -0
  27. pytest_regtest-2.3.2/.gitignore +0 -27
  28. {pytest_regtest-2.3.2 → pytest_regtest-2.3.4}/LICENSE.txt +0 -0
  29. {pytest_regtest-2.3.2 → pytest_regtest-2.3.4}/README.md +0 -0
  30. {pytest_regtest-2.3.2 → pytest_regtest-2.3.4}/tests/conftest.py +0 -0
@@ -0,0 +1,2 @@
1
+ include LICENSE.txt
2
+ include tests/*.py
@@ -1,20 +1,42 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: pytest-regtest
3
- Version: 2.3.2
3
+ Version: 2.3.4
4
4
  Summary: pytest plugin for snapshot regression testing
5
- Project-URL: Source, https://gitlab.com/uweschmitt/pytest-regtest
6
- Project-URL: Documentation, https://pytest-regtest.readthedocs.org
7
5
  Author-email: Uwe Schmitt <uwe.schmitt@id.ethz.ch>
8
6
  License: MIT License
9
- License-File: LICENSE.txt
7
+ Project-URL: Source, https://gitlab.com/uweschmitt/pytest-regtest
8
+ Project-URL: Documentation, https://pytest-regtest.readthedocs.org
10
9
  Classifier: Intended Audience :: Developers
11
- Classifier: License :: OSI Approved :: MIT License
12
10
  Classifier: Programming Language :: Python :: 3.9
13
11
  Classifier: Programming Language :: Python :: 3.10
14
12
  Classifier: Programming Language :: Python :: 3.11
15
13
  Classifier: Programming Language :: Python :: 3.12
16
- Requires-Dist: pytest>7.2
14
+ Classifier: License :: OSI Approved :: MIT License
17
15
  Description-Content-Type: text/markdown
16
+ License-File: LICENSE.txt
17
+ Requires-Dist: pytest>7.2
18
+ Provides-Extra: dev
19
+ Requires-Dist: twine; extra == "dev"
20
+ Requires-Dist: build; extra == "dev"
21
+ Requires-Dist: hatchling; extra == "dev"
22
+ Requires-Dist: wheel; extra == "dev"
23
+ Requires-Dist: pre-commit; extra == "dev"
24
+ Requires-Dist: ruff; extra == "dev"
25
+ Requires-Dist: black; extra == "dev"
26
+ Requires-Dist: pytest-cov; extra == "dev"
27
+ Requires-Dist: numpy; extra == "dev"
28
+ Requires-Dist: pandas; extra == "dev"
29
+ Requires-Dist: mkdocs; extra == "dev"
30
+ Requires-Dist: mkdocs-material; extra == "dev"
31
+ Requires-Dist: mistletoe; extra == "dev"
32
+ Requires-Dist: mkdocs-awesome-pages-plugin; extra == "dev"
33
+ Requires-Dist: jinja2-cli; extra == "dev"
34
+ Requires-Dist: mkdocstrings[python]; extra == "dev"
35
+ Requires-Dist: numpy>=2; extra == "dev"
36
+ Requires-Dist: pandas>=2; extra == "dev"
37
+ Requires-Dist: polars>=1.9; extra == "dev"
38
+ Requires-Dist: md-transformer>=0.0.3; extra == "dev"
39
+ Dynamic: license-file
18
40
 
19
41
  ![](https://gitlab.com/uweschmitt/pytest-regtest/badges/main/pipeline.svg)
20
42
  ![](https://gitlab.com/uweschmitt/pytest-regtest/badges/main/coverage.svg?job=coverage)
@@ -1,11 +1,12 @@
1
1
  [build-system]
2
- requires = ["hatchling"]
3
- build-backend = "hatchling.build"
2
+ requires = ["setuptools>=61"]
3
+ build-backend = "setuptools.build_meta"
4
+
4
5
 
5
6
 
6
7
  [project]
7
8
  name = "pytest-regtest"
8
- version = "2.3.2"
9
+ version = "2.3.4"
9
10
  description = "pytest plugin for snapshot regression testing"
10
11
  readme = "README.md"
11
12
  authors = [
@@ -35,12 +36,6 @@ Documentation = "https://pytest-regtest.readthedocs.org"
35
36
  [project.entry-points.pytest11]
36
37
  regtest = "pytest_regtest"
37
38
 
38
- [tool.hatch.build.targets.sdist]
39
- only-include = ["pytest_regtest", "tests/conftest.py", "tests/test_plugin.py"]
40
-
41
- [tool.hatch.build.targets.wheel]
42
- packages = ["src/pytest_regtest"]
43
-
44
39
  [tool.ruff]
45
40
  line-length = 88
46
41
  exclude = ["_regtest_output"]
@@ -48,8 +43,8 @@ exclude = ["_regtest_output"]
48
43
  [tool.ruff.lint]
49
44
  ignore = ["E731", "E203"]
50
45
 
51
- [tool.uv]
52
- dev-dependencies = [
46
+ [project.optional-dependencies]
47
+ dev = [
53
48
  "twine", "build", "hatchling", "wheel", "pre-commit", "ruff", "black",
54
49
  "pytest-cov", "numpy", "pandas", "mkdocs", "mkdocs-material", "mistletoe",
55
50
  "mkdocs-awesome-pages-plugin", "jinja2-cli", "mkdocstrings[python]",
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,107 @@
1
+ from importlib.metadata import version as _version
2
+
3
+ import pytest
4
+
5
+ from .pytest_regtest import clear_converters # noqa: F401
6
+ from .pytest_regtest import patch_terminal_size # noqa: F401
7
+ from .pytest_regtest import register_converter_post # noqa: F401
8
+ from .pytest_regtest import register_converter_pre # noqa: F401
9
+ from .pytest_regtest import (
10
+ PytestRegtestCommonHooks,
11
+ PytestRegtestPlugin,
12
+ RegtestStream,
13
+ Snapshot,
14
+ SnapshotPlugin,
15
+ )
16
+ from .register_third_party_handlers import (
17
+ register_numpy_handler,
18
+ register_pandas_handler,
19
+ register_polars_handler,
20
+ )
21
+ from .snapshot_handler import register_python_object_handler
22
+
23
# Version string of the installed distribution, looked up through
# importlib.metadata using this package's name.
__version__ = _version(__package__)
24
+
25
+
26
def pytest_addoption(parser):
    """Add the command line options of the regtest plugin.

    All options live in the "regtest" option group and are boolean flags
    which default to ``False``.
    """
    group = parser.getgroup("regtest", "regression test plugin")
    group.addoption(
        "--regtest-reset",
        action="store_true",
        default=False,
        help="do not run regtest but record current output",
    )
    group.addoption(
        "--regtest-tee",
        action="store_true",
        default=False,
        help="print recorded results to console too",
    )
    group.addoption(
        "--regtest-consider-line-endings",
        action="store_true",
        default=False,
        help="do not strip whitespaces at end of recorded lines",
    )
    group.addoption(
        "--regtest-nodiff",
        action="store_true",
        default=False,
        help="do not show diff output for failed regression tests",
    )
    group.addoption(
        "--regtest-disable-stdconv",
        action="store_true",
        default=False,
        help=(
            "do not apply standard output converters to clean up indeterministic output"
        ),
    )
60
+
61
+
62
def pytest_configure(config):
    """Register the plugin objects with pytest's plugin manager.

    One ``PytestRegtestCommonHooks`` instance is registered first and then
    handed to both ``PytestRegtestPlugin`` and ``SnapshotPlugin``, which are
    registered afterwards in that order.
    """
    manager = config.pluginmanager
    shared_hooks = PytestRegtestCommonHooks()
    manager.register(shared_hooks)
    for plugin_class in (PytestRegtestPlugin, SnapshotPlugin):
        manager.register(plugin_class(shared_hooks))
67
+
68
+
69
@pytest.fixture
def regtest(request):
    """Fixture yielding a ``RegtestStream`` built from the requesting test."""
    stream = RegtestStream(request)
    yield stream
72
+
73
+
74
@pytest.fixture
def snapshot(request):
    """Fixture yielding a ``Snapshot`` built from the requesting test."""
    snapshot_obj = Snapshot(request)
    yield snapshot_obj
77
+
78
+
79
@pytest.fixture
def regtest_all(regtest):
    """Fixture yielding the very same object as the ``regtest`` fixture."""
    stream = regtest
    yield stream


# the same fixture function, additionally exposed under a second module-level name
snapshot_all_output = regtest_all
85
+
86
# Registered unconditionally, before the optional handlers below.
register_python_object_handler()

# Optional handlers: each one is registered only if the corresponding library
# can be imported. The ``try`` also encloses the ``register_*`` call, so an
# ImportError raised during registration is silently ignored as well.
try:
    import pandas  # noqa: F401

    register_pandas_handler()
except ImportError:
    pass

try:
    import numpy  # noqa: F401

    register_numpy_handler()
except ImportError:
    pass

try:
    import polars  # noqa: F401

    register_polars_handler()
except ImportError:
    pass
@@ -0,0 +1,216 @@
1
+ import difflib
2
+ import io
3
+ import os.path
4
+ import warnings
5
+
6
+ import numpy as np
7
+
8
+ from .snapshot_handler import BaseSnapshotHandler
9
+ from .utils import highlight_mismatches
10
+
11
+
12
class NumpyHandler(BaseSnapshotHandler):
    """Snapshot handler for ``numpy.ndarray`` objects.

    Arrays are stored in ``arrays.npy`` inside the snapshot folder. Two arrays
    match if shape and dtype agree and all entries are close in the sense of
    ``numpy.allclose`` with the ``atol``, ``rtol`` and ``equal_nan`` handler
    options. The recorded array is used as reference for ``rtol`` in the
    comparison as well as in all diagnostics.
    """

    def __init__(self, handler_options, pytest_config, tw):
        self.atol = handler_options.get("atol", 0.0)
        self.rtol = handler_options.get("rtol", 0.0)
        self.equal_nan = handler_options.get("equal_nan", True)

        print_options = handler_options.get("print_options")
        if print_options:
            warnings.warn(
                "please use the numpy.printoptions context manager instead of"
                " the print_options argument.",
                DeprecationWarning,
            )

        # start from the print options active while the handler is created and
        # let the (deprecated) handler option override single entries; a value
        # of None is treated like "no overrides":
        self.print_options = np.get_printoptions() | (print_options or {})

    def _filename(self, folder):
        return os.path.join(folder, "arrays.npy")

    def save(self, folder, obj):
        np.save(self._filename(folder), obj)

    def load(self, folder):
        return np.load(self._filename(folder))

    def show(self, obj):
        """Return the printed form of ``obj`` as a list of lines."""
        stream = io.StringIO()
        with np.printoptions(**self.print_options):
            print(obj, file=stream)
        return stream.getvalue().splitlines()

    def _allclose(self, current, recorded):
        """``numpy.allclose`` with the tolerances configured for this handler.

        ``numpy.allclose(a, b)`` is not symmetric (``rtol`` scales ``|b|``);
        the recorded values are always passed as ``b``.
        """
        return np.allclose(
            current,
            recorded,
            rtol=self.rtol,
            atol=self.atol,
            equal_nan=self.equal_nan,
        )

    @staticmethod
    def _unified_diff(current_lines, recorded_lines):
        """Unified diff from the current to the recorded ("expected") lines."""
        return list(
            difflib.unified_diff(
                current_lines,
                recorded_lines,
                "current",
                "expected",
                lineterm="",
            )
        )

    def compare(self, current_obj, recorded_obj):
        return (
            isinstance(current_obj, np.ndarray)
            and current_obj.shape == recorded_obj.shape
            and current_obj.dtype == recorded_obj.dtype
            and self._allclose(current_obj, recorded_obj)
        )

    def show_differences(self, current_obj, recorded_obj, has_markup):
        """Describe why ``current_obj`` does not match ``recorded_obj``.

        Returns a list of text lines, or ``None`` if shapes and dtypes agree
        and all entries are close.
        """
        lines = []

        if recorded_obj.dtype != current_obj.dtype:
            lines.extend(
                [
                    f"dtype mismatch: current dtype: {current_obj.dtype}",
                    f" recorded dtype: {recorded_obj.dtype}",
                ]
            )

        if recorded_obj.shape == current_obj.shape:
            if self._allclose(current_obj, recorded_obj):
                return lines or None
            lines.extend(self.error_diagnostics(recorded_obj, current_obj))
        else:
            lines.extend(
                [
                    f"shape mismatch: current shape: {current_obj.shape}",
                    f" recorded shape: {recorded_obj.shape}",
                ]
            )

        if recorded_obj.ndim > 2:
            return lines

        recorded_as_text = self.show(recorded_obj)
        current_as_text = self.show(current_obj)

        if recorded_obj.ndim == 2 and current_obj.ndim > 0:
            diff_lines = self.error_diagnostics_2d_linewise(
                current_obj,
                current_as_text,
                recorded_obj,
                recorded_as_text,
                has_markup,
            )
            lines.extend(diff_lines)
        else:
            # 0d and 1d arrays (and 2d arrays compared to a 0d array) are
            # diffed as printed text; iterating rows of a 0d array would fail.
            diff_lines = self._unified_diff(current_as_text, recorded_as_text)
            lines.append("")
            lines.extend(diff_lines)

        if not diff_lines:
            lines.append("diff is empty, you may want to change the print options")

        return lines

    def error_diagnostics(self, recorded_obj, current_obj):
        """Summarize how two arrays of identical shape deviate from each other.

        Returns text lines with the maximum relative and absolute deviation and
        the number of entries which are not close according to the configured
        ``rtol``, ``atol`` and ``equal_nan`` settings.
        """
        # promote to at least float64: unsigned integers would wrap around on
        # subtraction and boolean arrays do not support ``-`` at all.
        dtype = np.result_type(recorded_obj.dtype, current_obj.dtype, np.float64)
        recorded = np.asarray(recorded_obj, dtype=dtype).ravel()
        current = np.asarray(current_obj, dtype=dtype).ravel()

        # entries considered identical have no deviation; without this matching
        # infinities (inf - inf) and, if equal_nan is set, matching NaNs would
        # produce NaN deviations.
        identical = current == recorded
        if self.equal_nan:
            identical |= np.isnan(current) & np.isnan(recorded)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            abs_err = np.abs(current - recorded)
            # relative to the magnitude of the recorded value, so that negative
            # recorded values do not yield negative deviations:
            rel_err = abs_err / np.abs(recorded)
        abs_err[identical] = 0.0
        rel_err[identical] = 0.0

        rel_err_max = np.max(rel_err)
        # entries with recorded value 0 have infinite relative deviation; if
        # all recorded values are 0 there is no finite maximum to report.
        finite_rel_err = rel_err[recorded != 0]
        rel_err_max_finite = (
            np.max(finite_rel_err) if finite_rel_err.size else rel_err_max
        )
        abs_err_max = np.max(abs_err)

        lines = [f"max relative deviation: {rel_err_max:e}"]
        if rel_err_max_finite != rel_err_max:
            lines.append(f"max relative deviation except inf: {rel_err_max_finite:e}")
        lines.append(f"max absolute deviation: {abs_err_max:e}")

        n_diff = np.count_nonzero(
            ~np.isclose(
                current,
                recorded,
                rtol=self.rtol,
                atol=self.atol,
                equal_nan=self.equal_nan,
            )
        )

        lines.append(
            f"both arrays differ in {n_diff} out of {recorded.size}"
            " entries"
        )
        lines.append(
            f"up to given precision settings rtol={self.rtol:e} and"
            f" atol={self.atol:e}"
        )

        return lines

    def error_diagnostics_2d_linewise(
        self, current_obj, current_as_text, recorded_obj, recorded_as_text, has_markup
    ):
        """Diff two arrays row by row, reporting only rows which are not close.

        Printed lines present in only one of both texts are listed at the end.
        """
        sub_diff = []

        # the printed lines take part in the zip so that iteration also stops
        # at the shorter of both texts.
        rows = zip(current_as_text, recorded_as_text, current_obj, recorded_obj)
        for i, (_, _, current_row, recorded_row) in enumerate(rows):
            same_shape = current_row.shape == recorded_row.shape
            if same_shape and self._allclose(current_row, recorded_row):
                continue

            if same_shape:
                # enforces more uniform formatting of both lines; the outer
                # brackets of the 2-row array are stripped afterwards:
                lines_together = self.show(np.vstack((current_row, recorded_row)))
                line_diff = self._unified_diff(
                    [lines_together[0][1:].strip()],
                    [lines_together[1][:-1].strip()],
                )
            else:
                line_diff = self._unified_diff(
                    self.show(current_row), self.show(recorded_row)
                )

            if line_diff:
                if not sub_diff:
                    # the "--- current" / "+++ expected" header lines
                    sub_diff = line_diff[:2]

                current_line, recorded_line = line_diff[-2], line_diff[-1]
                if has_markup:
                    current_line, recorded_line = highlight_mismatches(
                        current_line, recorded_line
                    )

                sub_diff.append(f"row {i:3d}: {current_line}")
                sub_diff.append(f" {recorded_line}")

        missing = len(current_as_text) - len(recorded_as_text)
        if missing > 0:
            for i, row in enumerate(current_as_text[-missing:], len(recorded_as_text)):
                # remove duplicate brackets
                row = row.rstrip("]") + "]"
                sub_diff.append(f"row {i:3d}: -{row.lstrip()}")
        if missing < 0:
            for i, row in enumerate(recorded_as_text[missing:], len(current_as_text)):
                # remove duplicate brackets
                row = row.rstrip("]") + "]"
                sub_diff.append(f"row {i:3d}: +{row.lstrip()}")

        return sub_diff
@@ -0,0 +1,143 @@
1
+ import difflib
2
+ import io
3
+ import os.path
4
+ import warnings
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from .snapshot_handler import BaseSnapshotHandler
10
+
11
+
12
class DataFrameHandler(BaseSnapshotHandler):
    """Snapshot handler for ``pandas.DataFrame`` objects.

    Frames are stored as gzip-compressed pickle ``dataframe.pkl`` in the
    snapshot folder. Two frames match if they have the same columns with the
    same dtypes (column order does not matter) and the same row index, if all
    float columns are close in the sense of ``numpy.allclose`` (``atol`` and
    ``rtol`` handler options, NaNs at the same position count as equal) and
    if all other columns are equal according to ``DataFrame.equals``.
    """

    def __init__(self, handler_options, pytest_config, tw):
        display_options = handler_options.get("display_options")
        if display_options:
            warnings.warn(
                "please use the 'pandas.option_context' context manager instead of"
                " the display_options argument.",
                DeprecationWarning,
            )

        # default contains a few nested dicts and we flatten those, e.g.
        # { "html": {"border": 1} } -> { "html.border": 1 }
        default_flattened = {}
        for key, value in list(pd.options.display.d.items()):
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    default_flattened[f"{key}.{sub_key}"] = sub_value
            else:
                default_flattened[key] = value

        # overwrite with user settings (a value of None means "no overrides"):
        merged = default_flattened | (display_options or {})

        # pandas.option_context expects alternating option names and values:
        self.display_options_flat = [
            entry
            for name, value in merged.items()
            for entry in (f"display.{name}", value)
        ]
        self.atol = handler_options.get("atol", 0.0)
        self.rtol = handler_options.get("rtol", 0.0)

    def _filename(self, folder):
        return os.path.join(folder, "dataframe.pkl")

    def save(self, folder, obj):
        obj.to_pickle(self._filename(folder), compression="gzip")

    def load(self, folder):
        # NOTE: unpickling can execute arbitrary code; snapshot files must
        # come from a trusted source (e.g. the project's own repository).
        return pd.read_pickle(self._filename(folder), compression="gzip")

    def show(self, obj):
        """Return the printed form of ``obj`` as a list of lines."""
        stream = io.StringIO()
        with pd.option_context(*self.display_options_flat):
            print(obj, file=stream)
        return stream.getvalue().splitlines()

    def compare(self, current, recorded):
        """Return True if ``current`` matches the ``recorded`` frame."""
        if set(zip(current.columns, current.dtypes)) != set(
            zip(recorded.columns, recorded.dtypes)
        ):
            return False

        # frames with different rows cannot be compared element-wise (pandas
        # raises for differently labeled frames), they simply do not match.
        if not current.index.equals(recorded.index):
            return False

        # align the column order of the recorded frame with the current one:
        if not current.columns.equals(recorded.columns):
            recorded = recorded[list(current.columns)]

        def is_float(dtype):
            # numpy float dtypes of any precision; extension dtypes are
            # compared exactly together with all other columns.
            return isinstance(dtype, np.dtype) and dtype.kind == "f"

        float_mask = [is_float(t) for t in current.dtypes]
        other_mask = [not flag for flag in float_mask]

        floats_close = np.allclose(
            current.loc[:, float_mask].to_numpy(),
            recorded.loc[:, float_mask].to_numpy(),
            atol=self.atol,
            rtol=self.rtol,
            equal_nan=True,
        )
        # DataFrame.equals treats missing values at the same position as equal
        return bool(floats_close) and current.loc[:, other_mask].equals(
            recorded.loc[:, other_mask]
        )

    @staticmethod
    def _info_lines(df):
        """``DataFrame.info`` output without its first two and its last line."""
        stream = io.StringIO()
        df.info(buf=stream, verbose=True, memory_usage=False)
        return stream.getvalue().splitlines()[2:-1]

    @staticmethod
    def _unified_diff(current_lines, recorded_lines):
        """Unified diff from the current to the recorded ("expected") lines."""
        return list(
            difflib.unified_diff(
                current_lines,
                recorded_lines,
                "current",
                "expected",
                lineterm="",
            )
        )

    def show_differences(self, current, recorded, has_markup):
        """Return text lines describing the differences between both frames.

        The diff of the column summaries comes first, followed by an empty
        line and the diff of the printed frames.
        """
        lines = self._unified_diff(
            self._info_lines(current), self._info_lines(recorded)
        )

        diffs = self._unified_diff(self.show(current), self.show(recorded))

        lines.append("")
        if diffs:
            lines.extend(diffs)
        else:
            lines.append("diff is empty, you may want to change the print options")

        return lines
@@ -0,0 +1,114 @@
1
+ import difflib
2
+ import io
3
+ import os
4
+ from typing import Any, Union
5
+
6
+ import polars as pl
7
+ from polars.testing import assert_frame_equal
8
+
9
+ from .snapshot_handler import BaseSnapshotHandler
10
+
11
+
12
class PolarsHandler(BaseSnapshotHandler):
    """
    Snapshot handler for Polars DataFrames in pytest-regtest.

    Frames are written to ``polars.parquet`` inside the snapshot folder and
    compared with ``polars.testing.assert_frame_equal`` using the ``atol`` and
    ``rtol`` handler options.
    """

    def __init__(self, handler_options: dict[str, Any], pytest_config, tw):
        self.atol = handler_options.get("atol", 0.0)
        self.rtol = handler_options.get("rtol", 0.0)
        # keyword arguments for ``pl.Config``, applied only when non-empty
        self.display_options = handler_options.get("display_options", None)

    def _filename(self, folder: Union[str, os.PathLike[Any]]) -> str:
        return os.path.join(folder, "polars.parquet")

    def save(self, folder: Union[str, os.PathLike[Any]], obj: pl.DataFrame):
        target = self._filename(folder)
        obj.write_parquet(target)

    def load(self, folder: Union[str, os.PathLike[Any]]) -> pl.DataFrame:
        source = self._filename(folder)
        return pl.read_parquet(source)

    def show(self, obj: pl.DataFrame) -> list[str]:
        """Return ``str(obj)`` split into lines, honoring ``display_options``."""
        buffer = io.StringIO()
        if not self.display_options:
            print(obj, file=buffer, end="")
        else:
            with pl.Config(**self.display_options):
                print(obj, file=buffer, end="")
        return buffer.getvalue().splitlines()

    def compare(self, current_obj: pl.DataFrame, recorded_obj: pl.DataFrame) -> bool:
        try:
            assert_frame_equal(
                current_obj, recorded_obj, atol=self.atol, rtol=self.rtol
            )
        except AssertionError:
            return False
        return True

    @staticmethod
    def create_schema_info(df: pl.DataFrame) -> list[str]:
        """
        Summarize the schema of a Polars DataFrame as text lines.

        Parameters:
            df (pl.DataFrame): The frame to describe.

        Returns:
            list[str]: Three header lines followed by one line per column with
            its position, name, number of non-null values and dtype.
        """
        schema = df.schema
        n_rows = df.height
        header = [
            f"Data columns (total {len(schema)} columns):",
            " # Column Non-Null Count Dtype ",
            "--- ------ -------------- ----- ",
        ]
        column_lines = [
            f" {position} {name} {n_rows - df[name].null_count()} non-null {dtype!s}"
            for position, (name, dtype) in enumerate(schema.items())
        ]
        return header + column_lines

    def show_differences(
        self, current_obj: pl.DataFrame, recorded_obj: pl.DataFrame, has_markup: bool
    ) -> list[str]:
        """Return the schema diff, an empty line and the diff of the printed frames."""

        def unified(current_lines: list[str], recorded_lines: list[str]) -> list[str]:
            return list(
                difflib.unified_diff(
                    current_lines,
                    recorded_lines,
                    "current",
                    "expected",
                    lineterm="",
                )
            )

        result = unified(
            self.create_schema_info(current_obj),
            self.create_schema_info(recorded_obj),
        )
        text_diff = unified(self.show(current_obj), self.show(recorded_obj))

        result.append("")
        result.extend(
            text_diff or ["diff is empty, you may want to change the print options"]
        )
        return result