pytest-pyspark-utils 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
1
+ from pytest_pyspark_utils.plugin import TableConfig, DeltaTablesResult
2
+ from pytest_pyspark_utils.delta_caching import DeltaCaching
3
+
4
+ __all__ = ["TableConfig", "DeltaTablesResult", "DeltaCaching"]
@@ -0,0 +1,230 @@
1
+ """
2
+ Delta Lake caching layer for pytest-pyspark-utils.
3
+
4
+ Converts CSV/JSONL source files to Delta format and caches them on disk.
5
+ Cache validity is determined by an MD5 hash of the source file content
6
+ and schema. A cache hit skips re-conversion; a miss cleans the old cache
7
+ and re-generates it.
8
+ """
9
+
10
+ import hashlib
11
+ import logging
12
+ import shutil
13
+ from pathlib import Path
14
+
15
+ from pyspark.sql import DataFrame, SparkSession
16
+ from pyspark.sql.types import StructType
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class DeltaCaching:
    """Manages on-disk Delta Lake caching for a single CSV or JSONL source file.

    On first use (or when the source file, schema or layout settings change)
    the data is converted to Delta format and stored under
    ``cache_base_dir / dataset_name``, where ``dataset_name`` is the source
    file's stem. Subsequent calls with the same inputs reuse the cached Delta
    data.

    Args:
        source_path: Absolute or relative path to the CSV or JSONL source file.
        cache_base_dir: Directory under which per-dataset Delta caches are stored.
            Both ``str`` and :class:`~pathlib.Path` are accepted.
        spark: Active SparkSession.
        schema: Optional Spark schema. If omitted, schema is inferred from the file.
        partition_by: Column names used for Hive-style Delta partitioning or,
            when ``liquid_clustering`` is set, as the ``CLUSTER BY`` columns.
        liquid_clustering: Write using Delta liquid clustering instead of Hive
            partitioning. Requires *schema* to be provided.
        debug: Emit extra debug log messages.
    """

    # File inside the cache directory that records the hash the cache was built from.
    _HASH_FILE_NAME = "_source_data_hash"

    def __init__(
        self,
        source_path: str,
        cache_base_dir: Path | str,
        spark: SparkSession,
        schema: StructType | None = None,
        partition_by: list[str] | None = None,
        liquid_clustering: bool = False,
        debug: bool = False,
    ) -> None:
        self.source_path = Path(source_path)
        self.spark = spark
        self.schema = schema
        self.partition_by = partition_by
        self.liquid_clustering = liquid_clustering
        self.debug = debug
        # The cache directory is keyed by the source file stem only, so source
        # files sharing a stem share a cache directory under one cache_base_dir.
        self.dataset = self.source_path.stem
        # Path() lets callers pass a plain string as the cache root.
        self.cached_path = Path(cache_base_dir) / self.dataset

        if self.debug:
            logger.debug("cache_base_dir=%s", cache_base_dir)
            logger.debug("dataset=%s", self.dataset)
            logger.debug("cached_path=%s", self.cached_path)

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Backtick-quote a Spark SQL identifier, doubling embedded backticks."""
        return "`" + name.replace("`", "``") + "`"

    @staticmethod
    def _sql_string_literal(value: str) -> str:
        """Return *value* as a single-quoted Spark SQL string literal."""
        # Spark SQL processes backslash escapes inside string literals by
        # default, so escape backslashes first and then single quotes.
        escaped = value.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    @property
    def hash_source(self) -> str:
        """MD5 hex digest identifying the source content and cache-relevant settings.

        The digest covers the source file text and the schema JSON (when a
        schema is set). When ``partition_by`` or ``liquid_clustering`` is
        configured, the layout settings are included too, so changing the
        table layout invalidates the cache.

        Returns ``"-2"`` when the source file does not exist. A missing file
        therefore never falsely matches a cache directory without a hash file,
        which reports ``"-1"``.
        """
        if not self.source_path.exists():
            return "-2"
        # Decode explicitly as UTF-8 instead of the platform locale encoding so
        # the digest does not depend on the machine. Text mode also normalises
        # CRLF/LF line endings. surrogateescape/surrogatepass keep undecodable
        # bytes hashable instead of raising.
        text = self.source_path.read_text(encoding="utf-8", errors="surrogateescape")
        payload = text.encode("UTF-16", errors="surrogatepass")
        if self.schema:
            payload += self.schema.json().encode("UTF-16")
        # Only non-default layouts contribute, so plain tables keep the same
        # hash as before layout settings were part of the key.
        if self.partition_by or self.liquid_clustering:
            layout = f"partition_by={list(self.partition_by or [])};liquid_clustering={self.liquid_clustering}"
            payload += layout.encode("UTF-16")
        # MD5 is only a change detector here, not a security control.
        return hashlib.md5(payload, usedforsecurity=False).hexdigest()

    @property
    def hash_cache(self) -> str:
        """Hash stored alongside the cached Delta table.

        Returns ``"-1"`` when no hash file exists (cache is absent or corrupted).
        """
        hash_file = self.cached_path / self._HASH_FILE_NAME
        if not hash_file.exists():
            return "-1"
        # strip() tolerates a trailing newline added by editors or VCS tooling.
        return hash_file.read_text(encoding="utf-8").strip()

    def probe_cache(self) -> bool:
        """Return ``True`` if the cached Delta table is up-to-date with the source."""
        return self.hash_source == self.hash_cache

    def cache(self) -> DataFrame:
        """Ensure the Delta cache is valid and return a DataFrame over the source.

        If the source hash matches the stored hash, the existing cache is reused.
        Otherwise the cache directory is cleaned, the source is converted to
        Delta, and the hash file is written.

        Returns:
            DataFrame read from the source file (not from the Delta cache).
        """
        # Hash once, *before* converting. If the source changes while the
        # conversion runs, the stored hash then describes the older content and
        # the next run re-converts. Hashing again afterwards could instead mark
        # a cache built from older data as up to date.
        source_hash = self.hash_source
        if source_hash == self.hash_cache:
            if self.debug:
                logger.debug("%s: skipping, cached data is up to date", self.dataset)
            return self.read()

        if self.debug:
            logger.debug("%s: refreshing the cache", self.dataset)

        self.clean_cache()
        df = self.write_delta()
        self.write_cache_hash(source_hash)
        self.remove_crc_files()
        return df

    def remove_crc_files(self) -> None:
        """Delete every ``*.crc`` file anywhere below the cache directory."""
        for crc_file in self.cached_path.rglob("*.crc"):
            crc_file.unlink()

    def write_cache_hash(self, source_hash: str | None = None) -> None:
        """Write the source hash into the cache directory.

        Args:
            source_hash: Hash to record. Defaults to the current :attr:`hash_source`.
        """
        if source_hash is None:
            source_hash = self.hash_source
        self.cached_path.joinpath(self._HASH_FILE_NAME).write_text(source_hash, encoding="utf-8")

    def read(self) -> DataFrame:
        """Read the source file into a Spark DataFrame.

        Dispatches to :meth:`read_csv` or :meth:`read_jsonl` based on the file
        extension.

        Raises:
            ValueError: For file extensions other than ``.csv`` or ``.jsonl``.
        """
        if self.source_path.suffix == ".csv":
            return self.read_csv()
        elif self.source_path.suffix == ".jsonl":
            return self.read_jsonl()
        else:
            raise ValueError(f"Unsupported file format: {self.source_path.suffix}")

    def read_jsonl(self) -> DataFrame:
        """Read a JSONL file into a Spark DataFrame.

        Uses the provided schema when available; otherwise infers the schema.
        """
        jsonl_path = self.source_path.as_posix()

        if self.schema:
            return self.spark.read.schema(self.schema).json(jsonl_path)
        else:
            return self.spark.read.option("inferSchema", "true").json(jsonl_path)

    def read_csv(self) -> DataFrame:
        """Read a CSV file into a Spark DataFrame.

        Expects a header row. Uses the provided schema when available;
        otherwise infers the schema. Fields containing the literal string
        ``null`` are read as SQL ``NULL`` (``nullValue="null"``).
        """
        csv_path = self.source_path.as_posix()

        if self.schema:
            return self.spark.read.options(header=True).option("nullValue", "null").schema(self.schema).csv(csv_path)
        else:
            return self.spark.read.options(header=True, inferSchema=True).option("nullValue", "null").csv(csv_path)

    def write_delta(self) -> DataFrame:
        """Convert the source file to Delta and write it to the cache directory.

        Three write modes are supported (in priority order):

        1. **Liquid clustering**: when ``liquid_clustering=True``, drops any
           table already registered under the dataset name, creates the table
           via DDL and saves with ``saveAsTable``. Requires *schema*.
        2. **Partitioned**: when ``partition_by`` is set, writes a
           Hive-partitioned Delta table.
        3. **Plain**: unpartitioned Delta table saved directly to ``cached_path``.

        Returns:
            The source DataFrame (same object returned by :meth:`read`).

        Raises:
            ValueError: When ``liquid_clustering=True`` but no schema is provided.
        """
        if self.liquid_clustering and self.schema is None:
            raise ValueError("liquid_clustering=True requires an explicit schema to be provided")

        reader = self.read()
        delta_location = self.cached_path.as_posix()
        df_writer = reader.repartition(1).write.format("delta").mode("overwrite")

        if self.liquid_clustering:
            table = self._quote_identifier(self.dataset)
            # clean_cache() has just removed the directory, so any table still
            # registered under this name is stale. Without the drop, CREATE
            # TABLE below fails because the table already exists.
            self.spark.sql(f"DROP TABLE IF EXISTS {table}")
            self.spark.sql(self.construct_table_ddl())
            df_writer.saveAsTable(table)
        elif self.partition_by:
            df_writer.partitionBy(self.partition_by).save(delta_location)
        else:
            df_writer.save(delta_location)

        return reader

    def clean_cache(self) -> None:
        """Remove the cached Delta directory if it exists."""
        if self.cached_path.exists():
            shutil.rmtree(self.cached_path)

    def construct_table_ddl(self) -> str:
        """Build a ``CREATE TABLE … USING DELTA`` DDL statement for liquid clustering.

        Uses ``self.schema`` to enumerate columns and ``self.partition_by`` for
        the ``CLUSTER BY`` clause. The table and column names are
        backtick-quoted and the location is an escaped string literal, so names
        such as ``my-data`` or ``order date`` are valid.

        Returns:
            A DDL string ready to pass to ``spark.sql()``.
        """
        quote = self._quote_identifier
        columns = [f"{quote(column.name)} {column.dataType.simpleString()}" for column in self.schema.fields]
        columns_str = ",\n    ".join(columns)

        if self.partition_by:
            cluster_columns = ", ".join(quote(column) for column in self.partition_by)
            cluster_str = f"CLUSTER BY ({cluster_columns})"
        else:
            cluster_str = ""

        location = self._sql_string_literal(self.cached_path.as_posix())
        return f"""
        CREATE TABLE {quote(self.dataset)}({columns_str})
        USING DELTA LOCATION {location}
        {cluster_str}
        """
@@ -0,0 +1,320 @@
1
+ """pytest plugin providing PySpark fixtures with Delta Lake table caching.
2
+
3
+ Fixtures:
4
+ spark: Session-scoped SparkSession with optional Delta Lake support.
5
+ delta_tables: Function-scoped ``DeltaTablesResult`` with per-test isolation.
6
+ set_utc_timezone: Sets TZ=UTC for the test session.
7
+ drop_hive_objects: Drops all Hive tables (utility, not auto-used).
8
+
9
+ Configuration (pytest.ini / pyproject.toml / CLI):
10
+ delta_jar: Maven coordinates for Delta Lake JAR.
11
+ spark_app_name: Spark application name (default: pytest-pyspark).
12
+ delta_cache_dir: Cache directory name, resolved relative to each test module's directory (default: _delta_cache).
13
+
14
+ Usage:
15
+ Define a module-scoped ``delta_tables_config`` fixture returning
16
+ ``dict[str, TableConfig]``, then use ``delta_tables`` in your tests.
17
+ """
18
+
19
+ import logging
20
+ import os
21
+ import random
22
+ import shutil
23
+ import string
24
+ import sys
25
+ import time
26
+ from dataclasses import dataclass, field
27
+ from datetime import datetime
28
+ from pathlib import Path
29
+
30
+ import pytest
31
+
32
+ from pytest_pyspark_utils.delta_caching import DeltaCaching
33
+ from pyspark.sql import DataFrame
34
+ from pyspark.sql.types import StructType
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
@dataclass
class TableConfig:
    """Configuration for a single test table loaded from CSV or JSONL.

    Args:
        source: Subdirectory under the test module's folder where the source
            file lives (for example ``"input"`` or ``"expected"``). Defaults to
            ``"input"``.
        schema: Explicit Spark schema. If ``None``, schema is inferred from the file.
        table_name: SQL table name to register. If ``None``, the table is named
            after the source file stem, i.e. the key this config is stored
            under in ``delta_tables_config``.
        partition_by: Column names that determine the Delta table layout.
            Without ``liquid_clustering`` they are used as Hive-style partition
            columns. With ``liquid_clustering=True`` they become the
            ``CLUSTER BY`` columns instead.
        liquid_clustering: Enable Delta Lake liquid clustering on the
            ``partition_by`` columns instead of Hive-style partitioning.
            Requires ``schema``.
    """

    source: str = "input"
    schema: StructType | None = None
    table_name: str | None = None
    partition_by: list[str] | None = None
    liquid_clustering: bool = False
59
+
60
+
61
@dataclass
class DeltaTablesResult:
    """Result returned by the :func:`delta_tables` fixture.

    Attributes:
        tables: Mapping from config key (source filename stem) to a DataFrame
            defined over that source CSV/JSONL file. These DataFrames are not
            backed by the registered Delta tables, so changes made to a table
            during a test are not visible through them. Query the table by name
            (e.g. ``spark.table(...)``) to see its current contents.
        path: Filesystem path (POSIX string) of this test's isolated Delta
            table copies. It contains one subdirectory per registered table name.
    """

    tables: dict[str, DataFrame]
    path: str
72
+
73
+
74
@dataclass
class _CachedTables:
    """Module-scoped snapshot produced by the ``_prepare_tables_for_test`` factory.

    Attributes:
        entries: Maps each config key to a ``(table_name, DataFrame)`` pair.
        path: Directory (POSIX string) holding the module's Delta table copies.
    """

    entries: dict[str, tuple[str, DataFrame]] = field(default_factory=lambda: {})
    path: str = field(default="")
80
+
81
+
82
def determine_file_path(base_path: str, filename: str) -> str:
    """Find the unique CSV or JSONL source file named *filename* in *base_path*.

    Only ``<filename>.jsonl`` and ``<filename>.csv`` are considered, and the
    stem must match exactly. For example, ``orders.backup.csv`` is not a
    candidate for ``orders``, and characters such as ``*`` or ``[`` in
    *filename* are taken literally rather than as glob patterns. Directories
    are ignored.

    Args:
        base_path: Directory to search in.
        filename: Stem name (no extension) to match.

    Returns:
        ``f"{base_path}/{name}"`` for the matched file. It is absolute only if
        *base_path* is.

    Raises:
        FileNotFoundError: If no matching file exists.
        FileExistsError: If both a ``.jsonl`` and a ``.csv`` file exist.
    """
    base = Path(base_path)
    candidates = (base / f"{filename}{suffix}" for suffix in (".jsonl", ".csv"))
    file_matches = [candidate for candidate in candidates if candidate.is_file()]

    if not file_matches:
        raise FileNotFoundError(f"No file found for {filename} in {base_path}")
    if len(file_matches) > 1:
        raise FileExistsError(
            f"Multiple files found for {filename} in {base_path}: {[file.name for file in file_matches]}. "
            f"Please ensure there is only one file for {filename} in the directory."
        )
    return f"{base_path}/{file_matches[0].name}"
107
+
108
+
109
def pytest_addoption(parser):
    """Register CLI flags and INI options for the pyspark-delta-caching plugin.

    Adds ``--delta-jar`` to the ``pyspark-delta-caching`` option group and the
    ini keys ``delta_jar``, ``spark_app_name`` and ``delta_cache_dir``.
    """
    group = parser.getgroup("pyspark-delta-caching")
    group.addoption(
        "--delta-jar",
        action="store",
        dest="delta_jar",
        default=None,
        help="Delta Lake Maven coordinates for spark.jars.packages, e.g. io.delta:delta-spark_2.13:4.0.1",
    )
    parser.addini(
        "delta_jar",
        "Delta Lake Maven coordinates for spark.jars.packages",
        default=None,
    )
    parser.addini(
        "spark_app_name",
        "Spark application name used in tests",
        default="pytest-pyspark",
    )
    # The plugin resolves this against each test module's directory
    # (request.path.parent), not against the rootdir.
    parser.addini(
        "delta_cache_dir",
        "Directory for persistent delta table cache (relative to each test module's directory)",
        default="_delta_cache",
    )
134
+
135
+
136
+ # --- Internal fixtures ---
137
+
138
+
139
@pytest.fixture(scope="session")
def _pyspark_tmp_dir(tmp_path_factory):
    """Session-wide scratch directory for Delta table copies, removed at teardown."""
    delta_root = tmp_path_factory.mktemp("delta")
    yield delta_root
    # Best-effort removal: a failure here must not become a teardown error.
    shutil.rmtree(delta_root, ignore_errors=True)
144
+
145
+
146
@pytest.fixture(scope="module")
def _pyspark_module_delta_path(_pyspark_tmp_dir, request):
    """Return a per-module directory (POSIX string) under the session temp dir.

    The directory name combines the module's file stem with a short hash of
    its pytest node id. Using the stem alone made modules with the same file
    name in different directories (e.g. ``unit/test_io.py`` and
    ``integration/test_io.py``) share one directory. Copying their tables
    there then collided.
    """
    import hashlib

    module_node = request.node
    node_digest = hashlib.sha256(module_node.nodeid.encode("utf-8")).hexdigest()[:8]
    return (_pyspark_tmp_dir / f"{Path(module_node.name).stem}_{node_digest}").as_posix()
149
+
150
+
151
@pytest.fixture(scope="module")
def _prepare_tables_for_test(spark, _pyspark_module_delta_path, request):
    """Return a function that materialises a module's configured tables.

    The returned callable takes a ``dict`` mapping source file stems to
    ``TableConfig`` objects. For each entry it:

    1. brings the table's on-disk Delta cache up to date via ``DeltaCaching``;
    2. copies the cache into the module's temp directory;
    3. registers the copy as a SQL table.

    It returns a ``_CachedTables`` describing the registered tables.
    """

    def _quote(identifier: str) -> str:
        # Backtick-quote (doubling embedded backticks) so table names such as
        # "my-data" are valid SQL identifiers.
        return "`" + identifier.replace("`", "``") + "`"

    def _prepare_tables_for_test(files: dict) -> _CachedTables:
        # Reject clashing table names before doing any work. Spark resolves
        # table names case-insensitively by default. Two configs sharing a
        # name would otherwise fail later in copytree with a confusing error,
        # or one table would silently replace the other.
        owners: dict[str, str] = {}
        for filename, config in files.items():
            table_name = config.table_name or filename
            key = table_name.lower()
            if key in owners:
                raise ValueError(
                    f"delta_tables_config entries {owners[key]!r} and {filename!r} both map to "
                    f"table {table_name!r}; table names must be unique (case-insensitive)."
                )
            owners[key] = filename

        start = datetime.now()
        test_dir = request.path.parent
        cache_base_dir = test_dir / request.config.getini("delta_cache_dir")
        temp_delta = Path(_pyspark_module_delta_path)
        entries: dict[str, tuple[str, DataFrame]] = {}

        for filename, config in files.items():
            table_name = config.table_name or filename

            location = (test_dir / config.source).as_posix()

            file_path = determine_file_path(base_path=location, filename=filename)

            delta_caching = DeltaCaching(
                source_path=file_path,
                cache_base_dir=cache_base_dir,
                spark=spark,
                schema=config.schema,
                partition_by=config.partition_by,
                liquid_clustering=config.liquid_clustering,
            )
            _df = delta_caching.cache()

            # Work on a copy so tests never modify the persistent cache.
            delta_target_path = temp_delta / table_name
            shutil.copytree(delta_caching.cached_path, delta_target_path)

            quoted_table = _quote(table_name)
            spark.sql(f"DROP TABLE IF EXISTS {quoted_table}")
            spark.sql(f"CREATE TABLE {quoted_table} USING DELTA LOCATION '{delta_target_path.as_posix()}'")

            entries[filename] = (table_name, _df)
            print(f"successfully created delta table for {filename}")

        duration = round((datetime.now() - start).total_seconds(), 1)
        print(f"done with creating tables ({duration}s).")

        return _CachedTables(entries=entries, path=_pyspark_module_delta_path)

    return _prepare_tables_for_test
192
+
193
+
194
@pytest.fixture(scope="module")
def _delta_tables_cached(_prepare_tables_for_test, delta_tables_config) -> _CachedTables:
    """Build the module's tables from the user-supplied ``delta_tables_config`` once per module."""
    prepare = _prepare_tables_for_test
    prepared = prepare(delta_tables_config)
    return prepared
197
+
198
+
199
+ # --- Public fixtures ---
200
+
201
+
202
@pytest.fixture(scope="session")
def set_utc_timezone():
    """Set the process timezone to UTC for the duration of the test session.

    Exports ``TZ=UTC`` and, where the platform supports it, calls
    ``time.tzset()`` so the running interpreter picks up the change.
    """
    os.environ["TZ"] = "UTC"
    # time.tzset() exists only on Unix. Calling it unconditionally raised
    # AttributeError on Windows; there only the environment variable is set.
    if hasattr(time, "tzset"):
        time.tzset()
207
+
208
+
209
@pytest.fixture(scope="session")
def spark(set_utc_timezone, request):
    """Create a session-scoped SparkSession configured for local testing.

    Enables Delta Lake support when ``delta_jar`` is configured; the
    ``--delta-jar`` CLI flag takes precedence over the ini setting.

    All SQL runs in a freshly created, randomly named database
    (``pytest_`` + 4 random characters), which is made the current database.
    At teardown the database is dropped (best effort) before the session is
    stopped, so it does not outlive the test session.

    Yields:
        SparkSession ready for use in tests.
    """
    from pyspark.sql import SparkSession

    delta_jar = request.config.getoption("--delta-jar") or request.config.getini("delta_jar") or None
    app_name = request.config.getini("spark_app_name")
    database_name = "pytest_" + "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(4))

    # Make PySpark use the interpreter running pytest for driver and workers.
    os.environ["PYSPARK_PYTHON"] = sys.executable
    os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

    builder = (
        SparkSession.builder.master("local[*]")
        .appName(app_name)
        .config("spark.sql.shuffle.partitions", "1")
        .config("spark.databricks.delta.snapshotPartitions", "2")
        .config("spark.ui.showConsoleProgress", "false")
        .config("spark.ui.enabled", "false")
        .config("spark.ui.dagGraph.retainedRootRDDs", "1")
        .config("spark.ui.retainedJobs", "1")
        .config("spark.ui.retainedStages", "1")
        .config("spark.ui.retainedTasks", "1")
        .config("spark.sql.ui.retainedExecutions", "1")
        .config("spark.worker.ui.retainedExecutors", "1")
        .config("spark.worker.ui.retainedDrivers", "1")
        .config("spark.driver.memory", "4g")
        .config("spark.sql.autoBroadcastJoinThreshold", "-1")
        .config(
            "spark.driver.extraJavaOptions",
            "-Duser.timezone=UTC -XX:+UseCompressedOops",
        )
        .config("spark.executor.extraJavaOptions", "-Duser.timezone=UTC")
        .config("spark.sql.session.timeZone", "UTC")
    )

    if delta_jar:
        builder = (
            builder.config("spark.jars.packages", delta_jar)
            .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
            .config(
                "spark.sql.catalog.spark_catalog",
                "org.apache.spark.sql.delta.catalog.DeltaCatalog",
            )
        )

    spark_session = builder.getOrCreate()
    spark_session.sparkContext.setLogLevel("ERROR")
    spark_session.sql(f"create database if not exists {database_name}")
    spark_session.sql(f"use {database_name}")

    try:
        yield spark_session
    finally:
        try:
            # Switch away first so the per-session database is not the current
            # one while it is dropped.
            spark_session.sql("use default")
            spark_session.sql(f"drop database if exists {database_name} cascade")
        except Exception:
            # Cleanup is best effort and must not mask the test results.
            logger.warning("could not drop test database %s", database_name, exc_info=True)
        finally:
            spark_session.stop()
272
+
273
+
274
@pytest.fixture(scope="function")
def delta_tables(spark, _delta_tables_cached: _CachedTables, _pyspark_tmp_dir, tmp_path) -> DeltaTablesResult:
    """Provide per-test isolated Delta tables as a :class:`DeltaTablesResult`.

    The fixture:

    1. copies the module-level Delta table copies to a test-specific directory;
    2. issues ``DROP TABLE IF EXISTS`` for every entry listed by ``SHOW TABLES``;
    3. registers each configured table against its fresh copy.

    Writes made to the tables during one test therefore do not affect sibling
    tests.

    The DataFrames in ``tables`` are the ones returned by ``DeltaCaching.cache``.
    They are defined over the source CSV/JSONL files, not over the registered
    Delta tables. Query a table by name to see its current contents.

    Args:
        spark: The session-scoped SparkSession.
        _delta_tables_cached: Module-level cached table entries.
        _pyspark_tmp_dir: Session-scoped temp directory.
        tmp_path: pytest-provided per-test temp directory (used as a unique suffix).

    Returns:
        A :class:`DeltaTablesResult` with ``tables`` (filename → DataFrame) and
        ``path`` (directory holding the isolated Delta copies).
    """

    def _quote(identifier: str) -> str:
        # Backtick-quote (doubling embedded backticks) so names that are not
        # plain SQL identifiers, e.g. containing '-', still parse.
        return "`" + identifier.replace("`", "``") + "`"

    dest = Path(_pyspark_tmp_dir) / "isolated_tables" / tmp_path.name
    if _delta_tables_cached.entries:
        shutil.copytree(Path(_delta_tables_cached.path), dest, dirs_exist_ok=True)
    else:
        # No tables were configured, so the module directory was never created
        # and there is nothing to copy.
        dest.mkdir(parents=True, exist_ok=True)

    for table in spark.sql("SHOW TABLES").collect():
        name = _quote(table.tableName)
        fqn = f"{_quote(table.namespace)}.{name}" if table.namespace else name
        spark.sql(f"DROP TABLE IF EXISTS {fqn}")

    result_tables: dict[str, DataFrame] = {}
    for filename, (table_name, df) in _delta_tables_cached.entries.items():
        table_path = dest / table_name
        spark.sql(f"CREATE TABLE {_quote(table_name)} USING DELTA LOCATION '{table_path.as_posix()}'")
        result_tables[filename] = df

    return DeltaTablesResult(tables=result_tables, path=dest.as_posix())
308
+
309
+
310
@pytest.fixture(scope="function")
def drop_hive_objects(spark):
    """Issue ``DROP TABLE IF EXISTS`` for every entry listed by ``SHOW TABLES``.

    The drop happens when the fixture is set up, i.e. before the requesting
    test's body runs; there is no teardown step. Request it from tests that
    must start without tables left behind by earlier tests, for example tables
    created outside of the ``delta_tables`` fixture.
    """

    def _quote(identifier: str) -> str:
        # Backtick-quote (doubling embedded backticks) so names that are not
        # plain SQL identifiers, e.g. containing '-', still parse.
        return "`" + identifier.replace("`", "``") + "`"

    for table in spark.sql("SHOW TABLES").collect():
        name = _quote(table.tableName)
        fqn = f"{_quote(table.namespace)}.{name}" if table.namespace else name
        spark.sql(f"DROP TABLE IF EXISTS {fqn}")