databricks4py 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- databricks4py/__init__.py +56 -0
- databricks4py/catalog.py +65 -0
- databricks4py/config/__init__.py +6 -0
- databricks4py/config/base.py +119 -0
- databricks4py/config/unity.py +72 -0
- databricks4py/filters/__init__.py +17 -0
- databricks4py/filters/base.py +154 -0
- databricks4py/io/__init__.py +40 -0
- databricks4py/io/checkpoint.py +98 -0
- databricks4py/io/dbfs.py +91 -0
- databricks4py/io/delta.py +564 -0
- databricks4py/io/merge.py +176 -0
- databricks4py/io/streaming.py +281 -0
- databricks4py/logging.py +39 -0
- databricks4py/metrics/__init__.py +22 -0
- databricks4py/metrics/base.py +66 -0
- databricks4py/metrics/delta_sink.py +75 -0
- databricks4py/metrics/logging_sink.py +20 -0
- databricks4py/migrations/__init__.py +27 -0
- databricks4py/migrations/alter.py +114 -0
- databricks4py/migrations/runner.py +241 -0
- databricks4py/migrations/schema_diff.py +136 -0
- databricks4py/migrations/validators.py +195 -0
- databricks4py/observability/__init__.py +24 -0
- databricks4py/observability/_utils.py +24 -0
- databricks4py/observability/batch_context.py +134 -0
- databricks4py/observability/health.py +223 -0
- databricks4py/observability/query_listener.py +236 -0
- databricks4py/py.typed +0 -0
- databricks4py/quality/__init__.py +26 -0
- databricks4py/quality/base.py +54 -0
- databricks4py/quality/expectations.py +184 -0
- databricks4py/quality/gate.py +90 -0
- databricks4py/retry.py +102 -0
- databricks4py/secrets.py +69 -0
- databricks4py/spark_session.py +68 -0
- databricks4py/testing/__init__.py +35 -0
- databricks4py/testing/assertions.py +111 -0
- databricks4py/testing/builders.py +127 -0
- databricks4py/testing/fixtures.py +134 -0
- databricks4py/testing/mocks.py +106 -0
- databricks4py/testing/temp_table.py +73 -0
- databricks4py/workflow.py +219 -0
- databricks4py-0.2.0.dist-info/METADATA +589 -0
- databricks4py-0.2.0.dist-info/RECORD +48 -0
- databricks4py-0.2.0.dist-info/WHEEL +5 -0
- databricks4py-0.2.0.dist-info/licenses/LICENSE +21 -0
- databricks4py-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""Streaming query progress observer via PySpark StreamingQueryListener.
|
|
2
|
+
|
|
3
|
+
Wraps the PySpark 3.4+ ``StreamingQueryListener`` API into a simple observer
|
|
4
|
+
that collects progress snapshots, optionally emits them to a
|
|
5
|
+
:class:`~databricks4py.metrics.base.MetricsSink`, and exposes the latest
|
|
6
|
+
state for health checks.
|
|
7
|
+
|
|
8
|
+
Requires PySpark >= 3.4. On older versions, :meth:`QueryProgressObserver.attach`
|
|
9
|
+
raises ``ImportError`` with a clear message.
|
|
10
|
+
|
|
11
|
+
Example::
|
|
12
|
+
|
|
13
|
+
observer = QueryProgressObserver(spark=spark, metrics_sink=my_sink)
|
|
14
|
+
observer.attach()
|
|
15
|
+
|
|
16
|
+
query = df.writeStream.start()
|
|
17
|
+
# ... wait for progress ...
|
|
18
|
+
|
|
19
|
+
latest = observer.latest_progress()
|
|
20
|
+
if latest:
|
|
21
|
+
print(f"Batch {latest.batch_id}: {latest.num_input_rows} rows")
|
|
22
|
+
|
|
23
|
+
observer.detach()
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import json
|
|
29
|
+
import logging
|
|
30
|
+
from collections import deque
|
|
31
|
+
from dataclasses import dataclass, field
|
|
32
|
+
from datetime import datetime, timezone
|
|
33
|
+
from typing import TYPE_CHECKING, Any
|
|
34
|
+
|
|
35
|
+
from databricks4py.observability._utils import parse_duration_ms
|
|
36
|
+
from databricks4py.spark_session import active_fallback
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from collections.abc import Callable
|
|
40
|
+
|
|
41
|
+
from pyspark.sql import SparkSession
|
|
42
|
+
|
|
43
|
+
from databricks4py.metrics.base import MetricsSink
|
|
44
|
+
|
|
45
|
+
__all__ = ["QueryProgressObserver", "QueryProgressSnapshot"]
|
|
46
|
+
|
|
47
|
+
logger = logging.getLogger(__name__)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass(frozen=True)
class QueryProgressSnapshot:
    """Frozen record describing one streaming-query progress event.

    Attributes:
        query_id: Unique query identifier (UUID string).
        query_name: Optional query name (set via ``.queryName()``).
        batch_id: Monotonically increasing batch counter.
        batch_duration_ms: Wall-clock time for the batch in milliseconds.
        input_rows_per_second: Rate of data arriving from the source.
        processed_rows_per_second: Rate of data being processed.
        num_input_rows: Total rows read in this batch.
        timestamp: UTC time at which :meth:`from_progress` built this snapshot
            (taken from the local clock, not from the progress JSON).
        sources: Per-source offset and rate info.
        sink: Sink commit progress info.
        raw_json: Full progress JSON for custom parsing.
    """

    query_id: str
    query_name: str | None
    batch_id: int
    batch_duration_ms: int
    input_rows_per_second: float
    processed_rows_per_second: float
    num_input_rows: int
    timestamp: datetime
    sources: list[dict[str, Any]] = field(default_factory=list)
    sink: dict[str, Any] = field(default_factory=dict)
    raw_json: str = ""

    @classmethod
    def from_progress(cls, progress: Any) -> QueryProgressSnapshot:
        """Create a snapshot from a PySpark ``StreamingQueryProgress``.

        ``progress.json`` must be a JSON string, as exposed by the standard
        PySpark 3.4+ ``StreamingQueryProgress``. Missing keys are read with
        ``dict.get`` defaults rather than raising ``KeyError``.
        """
        payload = progress.json
        parsed = json.loads(payload)
        lookup = parsed.get

        return cls(
            query_id=lookup("id", ""),
            query_name=lookup("name"),
            batch_id=lookup("batchId", -1),
            batch_duration_ms=parse_duration_ms(lookup("batchDuration", "0 ms")),
            input_rows_per_second=lookup("inputRowsPerSecond", 0.0),
            processed_rows_per_second=lookup("processedRowsPerSecond", 0.0),
            num_input_rows=lookup("numInputRows", 0),
            timestamp=datetime.now(tz=timezone.utc),
            sources=lookup("sources", []),
            sink=lookup("sink", {}),
            raw_json=payload,
        )
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class QueryProgressObserver:
    """Collects streaming query progress events and routes them to a metrics sink.

    Attaches a ``StreamingQueryListener`` to the SparkSession. Each progress
    event is captured as a :class:`QueryProgressSnapshot`, stored in a bounded
    history, and optionally emitted as a ``MetricEvent``.

    Args:
        spark: SparkSession to attach the listener to (resolved via
            ``active_fallback``).
        metrics_sink: Optional sink for ``query_progress`` metric events.
        on_progress: Optional callback invoked on each progress event.
        history_size: Maximum number of snapshots to retain. Oldest are evicted.
        query_name_filter: If set, only track queries with this name.
    """

    def __init__(
        self,
        *,
        spark: SparkSession | None = None,
        metrics_sink: MetricsSink | None = None,
        on_progress: Callable[[QueryProgressSnapshot], None] | None = None,
        history_size: int = 100,
        query_name_filter: str | None = None,
    ) -> None:
        self._spark = active_fallback(spark)
        self._metrics_sink = metrics_sink
        self._on_progress = on_progress
        self._history: deque[QueryProgressSnapshot] = deque(maxlen=history_size)
        self._query_name_filter = query_name_filter
        self._listener: Any = None
        self._attached = False

    def attach(self) -> None:
        """Register the listener with the SparkSession.

        Calling this while already attached is a no-op. Registering a second
        listener would overwrite ``self._listener``, leaving the first one
        registered (duplicate events) and impossible to remove via
        :meth:`detach`.

        Raises:
            ImportError: If PySpark < 3.4 (no Python listener support).
        """
        if self._attached:
            return

        try:
            from pyspark.sql.streaming.listener import StreamingQueryListener
        except ImportError as exc:
            raise ImportError(
                "StreamingQueryListener requires PySpark >= 3.4. "
                "Upgrade PySpark or use polling via query.lastProgress instead."
            ) from exc

        # Captured by the nested listener, whose own ``self`` is the listener.
        observer = self

        class _Listener(StreamingQueryListener):
            def onQueryStarted(self, event: Any) -> None:
                logger.debug("Query started: id=%s name=%s", event.id, event.name)

            def onQueryProgress(self, event: Any) -> None:
                observer._handle_progress(event.progress)

            def onQueryTerminated(self, event: Any) -> None:
                # Truncate: termination exceptions can carry very long traces.
                exc_str = str(event.exception)[:500] if event.exception else None
                logger.info("Query terminated: id=%s exception=%s", event.id, exc_str)

        self._listener = _Listener()
        self._spark.streams.addListener(self._listener)
        self._attached = True
        logger.info("QueryProgressObserver attached")

    def detach(self) -> None:
        """Remove the listener from the SparkSession (no-op if not attached)."""
        if self._listener is not None and self._attached:
            self._spark.streams.removeListener(self._listener)
            self._attached = False
            logger.info("QueryProgressObserver detached")

    def _handle_progress(self, progress: Any) -> None:
        """Parse, filter, record and fan out one progress event.

        Invoked from the listener's ``onQueryProgress``. A failure in the user
        ``on_progress`` callback or in the metrics sink is logged and does not
        prevent the other hook from running.
        """
        try:
            snapshot = QueryProgressSnapshot.from_progress(progress)
        except (json.JSONDecodeError, KeyError, TypeError, ValueError, AttributeError):
            logger.exception("Failed to process query progress event")
            return

        if self._query_name_filter is not None and snapshot.query_name != self._query_name_filter:
            return

        self._history.append(snapshot)

        if self._on_progress is not None:
            try:
                self._on_progress(snapshot)
            except Exception:
                # Listener boundary: a buggy callback must not stop metrics emission.
                logger.exception("on_progress callback failed")

        if self._metrics_sink is not None:
            try:
                from databricks4py.metrics.base import MetricEvent

                self._metrics_sink.emit(
                    MetricEvent(
                        job_name=snapshot.query_name or snapshot.query_id,
                        event_type="query_progress",
                        timestamp=snapshot.timestamp,
                        duration_ms=snapshot.batch_duration_ms,
                        row_count=snapshot.num_input_rows,
                        batch_id=snapshot.batch_id,
                        metadata={
                            "input_rows_per_second": snapshot.input_rows_per_second,
                            "processed_rows_per_second": snapshot.processed_rows_per_second,
                        },
                    )
                )
            except Exception:
                logger.exception("Failed to emit query progress metric")

        logger.debug(
            "Progress: query=%s batch=%d rows=%d rate=%.1f rows/s",
            snapshot.query_name or snapshot.query_id,
            snapshot.batch_id,
            snapshot.num_input_rows,
            snapshot.processed_rows_per_second,
        )

    def latest_progress(self) -> QueryProgressSnapshot | None:
        """Most recent progress snapshot, or None if no events received yet."""
        return self._history[-1] if self._history else None

    def history(self, limit: int | None = None) -> list[QueryProgressSnapshot]:
        """Recent progress snapshots, newest last.

        Args:
            limit: Maximum number of snapshots to return. Returns all if None
                and an empty list if ``limit <= 0``.
        """
        items = list(self._history)
        if limit is None:
            return items
        if limit <= 0:
            # items[-0:] is items[0:], i.e. everything — not what limit=0 means.
            return []
        return items[-limit:]

    @property
    def is_attached(self) -> bool:
        """Whether the listener is currently registered with the SparkSession."""
        return self._attached
|
databricks4py/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Data quality expectations and enforcement."""
|
|
2
|
+
|
|
3
|
+
from databricks4py.quality.base import Expectation, ExpectationResult, QualityReport
|
|
4
|
+
from databricks4py.quality.expectations import (
|
|
5
|
+
ColumnExists,
|
|
6
|
+
InRange,
|
|
7
|
+
MatchesRegex,
|
|
8
|
+
NotNull,
|
|
9
|
+
RowCount,
|
|
10
|
+
Unique,
|
|
11
|
+
)
|
|
12
|
+
from databricks4py.quality.gate import QualityError, QualityGate
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"ColumnExists",
|
|
16
|
+
"Expectation",
|
|
17
|
+
"ExpectationResult",
|
|
18
|
+
"InRange",
|
|
19
|
+
"MatchesRegex",
|
|
20
|
+
"NotNull",
|
|
21
|
+
"QualityError",
|
|
22
|
+
"QualityGate",
|
|
23
|
+
"QualityReport",
|
|
24
|
+
"RowCount",
|
|
25
|
+
"Unique",
|
|
26
|
+
]
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Core quality abstractions: expectations, results, and reports."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from pyspark.sql import Column, DataFrame, Row
|
|
11
|
+
|
|
12
|
+
__all__ = ["Expectation", "ExpectationResult", "QualityReport"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
class ExpectationResult:
    """Outcome of evaluating one expectation against a DataFrame.

    Attributes:
        expectation: Human-readable description (the built-in expectations use
            their ``repr``).
        passed: Whether the expectation held.
        total_rows: Number of rows in the evaluated DataFrame.
        failing_rows: Number of failures reported by the expectation; defaults to 0.
        sample: Optional list of example rows; the built-in expectations do not
            populate it.
    """

    expectation: str
    passed: bool
    total_rows: int
    failing_rows: int = 0
    sample: list[Row] | None = None


@dataclass(frozen=True)
class QualityReport:
    """Aggregated results from running multiple expectations.

    Attributes:
        results: One :class:`ExpectationResult` per evaluated expectation.
        passed: Overall verdict, stored as given (not recomputed from ``results``).
    """

    results: list[ExpectationResult]
    passed: bool

    def summary(self) -> str:
        """Render one ``[PASS]``/``[FAIL]`` line per result, then an ``Overall`` line."""

        def _describe(r: ExpectationResult) -> str:
            status = "PASS" if r.passed else "FAIL"
            return f"[{status}] {r.expectation} ({r.failing_rows}/{r.total_rows} failing)"

        overall = "PASSED" if self.passed else "FAILED"
        return "\n".join([*map(_describe, self.results), f"Overall: {overall}"])
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Expectation(ABC):
    """Base class for a single data-quality rule evaluated against a DataFrame.

    Subclasses must implement :meth:`validate`. Row-level rules should also
    override :meth:`failing_condition` so callers can separate offending rows
    from the rest (``QualityGate`` relies on it for quarantine).
    """

    @abstractmethod
    def validate(self, df: DataFrame) -> ExpectationResult:
        """Evaluate the rule against ``df`` and return its result."""

    def failing_condition(self) -> Column | None:
        """Column expression that is True for rows violating the rule.

        The default returns ``None``: the rule is aggregate-only (for example
        a row-count or schema check) and cannot flag individual rows.
        """
        return None
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""Built-in expectation implementations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from functools import reduce
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from databricks4py.quality.base import Expectation, ExpectationResult
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from pyspark.sql import Column, DataFrame
|
|
12
|
+
|
|
13
|
+
__all__ = ["ColumnExists", "InRange", "MatchesRegex", "NotNull", "RowCount", "Unique"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class NotNull(Expectation):
    """Validates that specified columns contain no null values.

    A row fails when *any* of the listed columns is null. Constructed with no
    columns, the expectation has nothing to check and passes.
    """

    def __init__(self, *columns: str) -> None:
        self._columns = columns

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Count all rows and the rows holding a null in any listed column."""
        total = df.count()
        condition = self.failing_condition()
        failing = df.where(condition).count() if condition is not None else 0
        return ExpectationResult(
            expectation=repr(self),
            passed=failing == 0,
            total_rows=total,
            failing_rows=failing,
        )

    def failing_condition(self) -> Column | None:
        """OR of ``isNull()`` over the listed columns, or None if there are none."""
        # reduce() over an empty sequence raises TypeError, which made NotNull()
        # crash in validate() and in QualityGate quarantine.
        if not self._columns:
            return None

        from pyspark.sql import functions as F

        conditions = [F.col(c).isNull() for c in self._columns]
        return reduce(lambda a, b: a | b, conditions)

    def __repr__(self) -> str:
        cols = ", ".join(repr(c) for c in self._columns)
        return f"NotNull({cols})"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class InRange(Expectation):
    """Validates that a column's values fall within inclusive bounds.

    Either bound may be ``None`` to leave that side open; with both open the
    expectation has no row-level condition and passes. A row fails when its
    value is strictly below ``min_val`` or strictly above ``max_val``.
    """

    def __init__(
        self, column: str, *, min_val: float | None = None, max_val: float | None = None
    ) -> None:
        self._column = column
        self._min_val = min_val
        self._max_val = max_val

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Count all rows and the rows outside the configured bounds."""
        row_total = df.count()
        predicate = self.failing_condition()
        if predicate is None:
            violations = 0
        else:
            violations = df.where(predicate).count()
        return ExpectationResult(
            expectation=repr(self),
            passed=violations == 0,
            total_rows=row_total,
            failing_rows=violations,
        )

    def failing_condition(self) -> Column | None:
        """OR of the out-of-bounds comparisons, or None when both bounds are open."""
        from pyspark.sql import functions as F

        value = F.col(self._column)
        out_of_range = []
        if self._min_val is not None:
            out_of_range.append(value < self._min_val)
        if self._max_val is not None:
            out_of_range.append(value > self._max_val)
        return reduce(lambda a, b: a | b, out_of_range) if out_of_range else None

    def __repr__(self) -> str:
        return f"InRange({self._column!r}, min_val={self._min_val!r}, max_val={self._max_val!r})"
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class Unique(Expectation):
    """Validates no duplicate rows exist for the specified columns.

    ``failing_rows`` is ``total - distinct``, i.e. the number of surplus
    duplicate rows (a key appearing three times contributes 2). When no
    columns are given, entire rows are compared.
    """

    def __init__(self, *columns: str) -> None:
        self._columns = columns

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Compare the row count with the distinct count over the key columns."""
        total = df.count()
        # ``df.select()`` with zero columns yields rows that are all identical
        # (empty), so distinct() would flag every row after the first as a
        # duplicate. Fall back to whole-row uniqueness instead.
        projected = df.select(*self._columns) if self._columns else df
        distinct = projected.distinct().count()
        failing = total - distinct
        return ExpectationResult(
            expectation=repr(self),
            passed=failing == 0,
            total_rows=total,
            failing_rows=failing,
        )

    def __repr__(self) -> str:
        cols = ", ".join(repr(c) for c in self._columns)
        return f"Unique({cols})"
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class RowCount(Expectation):
    """Validates that the DataFrame row count falls within inclusive bounds.

    Aggregate-only: it does not override ``failing_condition`` and leaves the
    result's ``failing_rows`` at its default.
    """

    def __init__(self, *, min_count: int | None = None, max_count: int | None = None) -> None:
        self._min_count = min_count
        self._max_count = max_count

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Pass unless the row count is below ``min_count`` or above ``max_count``."""
        total = df.count()
        below_min = self._min_count is not None and total < self._min_count
        above_max = self._max_count is not None and total > self._max_count
        return ExpectationResult(
            expectation=repr(self),
            passed=not (below_min or above_max),
            total_rows=total,
        )

    def __repr__(self) -> str:
        return f"RowCount(min_count={self._min_count!r}, max_count={self._max_count!r})"
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class MatchesRegex(Expectation):
    """Validates that a column's values match a regex pattern.

    The check uses Spark's ``rlike``; a row fails when the negated match is
    True. NOTE: ``rlike`` succeeds if the pattern is found anywhere in the
    value, so anchor it with ``^``/``$`` for a full-string match.
    """

    def __init__(self, column: str, pattern: str) -> None:
        self._column = column
        self._pattern = pattern

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Count all rows and the rows whose value does not match the pattern."""
        row_total = df.count()
        predicate = self.failing_condition()
        bad = 0 if predicate is None else df.where(predicate).count()
        return ExpectationResult(
            expectation=repr(self),
            passed=bad == 0,
            total_rows=row_total,
            failing_rows=bad,
        )

    def failing_condition(self) -> Column | None:
        """True for rows whose column value does not match the pattern."""
        from pyspark.sql import functions as F

        matches = F.col(self._column).rlike(self._pattern)
        return ~matches

    def __repr__(self) -> str:
        return f"MatchesRegex({self._column!r}, {self._pattern!r})"
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class ColumnExists(Expectation):
    """Validates that specified columns exist in the DataFrame schema.

    When ``dtype`` is given (a Spark ``simpleString`` such as ``"int"``) and
    every column exists, each column must also have that type. The result's
    ``failing_rows`` counts missing or mistyped *columns*, not rows.
    """

    def __init__(self, *columns: str, dtype: str | None = None) -> None:
        self._columns = columns
        self._dtype = dtype

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Report columns that are absent or, if ``dtype`` is set, of another type."""
        total = df.count()
        present = set(df.columns)
        problems = [name for name in self._columns if name not in present]

        # Type checking only runs once every requested column is present.
        if self._dtype and not problems:
            type_of = {f.name: f.dataType.simpleString() for f in df.schema.fields}
            problems.extend(name for name in self._columns if type_of.get(name) != self._dtype)

        return ExpectationResult(
            expectation=repr(self),
            passed=not problems,
            total_rows=total,
            failing_rows=len(problems),
        )

    def __repr__(self) -> str:
        cols = ", ".join(map(repr, self._columns))
        return f"ColumnExists({cols}, dtype={self._dtype!r})" if self._dtype else f"ColumnExists({cols})"
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Quality gate: orchestrate expectations and enforce data quality."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import TYPE_CHECKING, Literal
|
|
8
|
+
|
|
9
|
+
from databricks4py.quality.base import Expectation, QualityReport
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from pyspark.sql import DataFrame
|
|
13
|
+
|
|
14
|
+
__all__ = ["QualityError", "QualityGate"]
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class QualityError(Exception):
    """Raised when a quality gate fails in 'raise' mode.

    Attributes:
        report: The report whose ``summary()`` forms the exception message.
    """

    def __init__(self, report: QualityReport) -> None:
        message = f"Quality check failed:\n{report.summary()}"
        super().__init__(message)
        self.report = report
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class QualityGate:
    """Runs expectations against a DataFrame and enforces quality policy.

    Args:
        *expectations: Expectation instances to evaluate.
        on_fail: Action on failure: "raise", "warn", or "quarantine".
        quarantine_handler: Callable receiving the bad-rows DataFrame.
            Required when on_fail="quarantine".

    Raises:
        ValueError: If ``on_fail`` is not a supported action, or if
            ``on_fail="quarantine"`` is given without a ``quarantine_handler``.
    """

    def __init__(
        self,
        *expectations: Expectation,
        on_fail: Literal["raise", "warn", "quarantine"] = "raise",
        quarantine_handler: Callable[[DataFrame], None] | None = None,
    ) -> None:
        # An unrecognised action would otherwise fall through to the quarantine
        # branch of enforce() with no handler configured.
        if on_fail not in ("raise", "warn", "quarantine"):
            raise ValueError(
                f"on_fail must be one of 'raise', 'warn', 'quarantine'; got {on_fail!r}"
            )
        if on_fail == "quarantine" and quarantine_handler is None:
            raise ValueError("quarantine_handler is required when on_fail='quarantine'")
        self._expectations = expectations
        self._on_fail = on_fail
        self._quarantine_handler = quarantine_handler

    def check(self, df: DataFrame) -> QualityReport:
        """Run all expectations and return a report without applying ``on_fail``."""
        results = [exp.validate(df) for exp in self._expectations]
        passed = all(r.passed for r in results)
        return QualityReport(results=results, passed=passed)

    def enforce(self, df: DataFrame) -> DataFrame:
        """Run expectations and enforce the configured failure policy.

        Returns:
            ``df`` unchanged when every expectation passes or in "warn" mode.
            In "quarantine" mode, the rows not flagged by any row-level
            ``failing_condition()``; flagged rows go to ``quarantine_handler``.

        Raises:
            QualityError: If a check fails and ``on_fail="raise"``.
        """
        report = self.check(df)
        if report.passed:
            return df

        if self._on_fail == "raise":
            raise QualityError(report)

        if self._on_fail == "warn":
            logger.warning("Quality check failed:\n%s", report.summary())
            return df

        # quarantine: split bad rows using failing_condition()
        from functools import reduce

        conditions = [
            cond
            for cond in (exp.failing_condition() for exp in self._expectations)
            if cond is not None
        ]

        if not conditions:
            # no row-level conditions available — quarantine nothing
            logger.warning("Quarantine requested but no row-level failing conditions available")
            return df

        # A failing condition evaluates to NULL when the checked value is NULL
        # (e.g. ``NULL < 0`` or ``NULL RLIKE ...``). ``where`` drops rows whose
        # predicate is NULL, so both ``where(cond)`` and ``where(~cond)`` would
        # discard them and they would silently vanish. eqNullSafe(True) maps
        # NULL to False, keeping such rows clean — consistent with the built-in
        # validate() methods, whose ``df.where(condition)`` does not count them.
        is_bad = reduce(lambda a, b: a | b, conditions).eqNullSafe(True)
        bad_rows = df.where(is_bad)
        clean_rows = df.where(~is_bad)
        self._quarantine_handler(bad_rows)  # type: ignore[misc]
        return clean_rows
|
databricks4py/retry.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Retry decorator with exponential backoff for transient failures."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
__all__ = ["RetryConfig", "retry"]
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _default_retryable_exceptions() -> tuple[type[BaseException], ...]:
|
|
18
|
+
exceptions: list[type[BaseException]] = [ConnectionError, TimeoutError, OSError]
|
|
19
|
+
try:
|
|
20
|
+
from py4j.protocol import Py4JNetworkError
|
|
21
|
+
|
|
22
|
+
exceptions.append(Py4JNetworkError)
|
|
23
|
+
except ImportError:
|
|
24
|
+
pass
|
|
25
|
+
return tuple(exceptions)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True)
|
|
29
|
+
class RetryConfig:
|
|
30
|
+
"""Configuration for :func:`retry` behaviour.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
max_attempts: Total number of tries (including the first).
|
|
34
|
+
base_delay_seconds: Initial delay before the first retry.
|
|
35
|
+
max_delay_seconds: Upper cap on the exponentially increasing delay.
|
|
36
|
+
backoff_factor: Multiplier applied to the delay after each failure.
|
|
37
|
+
retryable_exceptions: Exception types that trigger a retry.
|
|
38
|
+
Defaults to ``ConnectionError``, ``TimeoutError``, ``OSError``,
|
|
39
|
+
and ``Py4JNetworkError`` (if py4j is installed).
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
max_attempts: int = 3
|
|
43
|
+
base_delay_seconds: float = 1.0
|
|
44
|
+
max_delay_seconds: float = 60.0
|
|
45
|
+
backoff_factor: float = 2.0
|
|
46
|
+
retryable_exceptions: tuple[type[BaseException], ...] = field(default_factory=tuple)
|
|
47
|
+
|
|
48
|
+
def __post_init__(self) -> None:
|
|
49
|
+
if not self.retryable_exceptions:
|
|
50
|
+
object.__setattr__(self, "retryable_exceptions", _default_retryable_exceptions())
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def retry(config: RetryConfig | None = None) -> Callable:
    """Decorator factory for retrying functions with exponential backoff.

    Only exceptions in ``config.retryable_exceptions`` are retried; anything
    else propagates immediately. Before retry ``n`` (1-based) the wrapper
    sleeps ``base_delay_seconds * backoff_factor ** (n - 1)`` seconds, capped
    at ``max_delay_seconds``. After the final attempt the last exception is
    re-raised unchanged.

    Example::

        @retry(RetryConfig(max_attempts=5, base_delay_seconds=2.0))
        def fetch_data():
            return requests.get(url).json()

    Args:
        config: Retry parameters. Uses defaults if ``None``.

    Raises:
        ValueError: If ``config.max_attempts < 1``. Raised when ``retry`` is
            called, since such a config would otherwise never call the function.
    """
    if config is None:
        config = RetryConfig()
    if config.max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {config.max_attempts!r}")

    def decorator(fn: Callable) -> Callable:
        # functools.partial objects and callable instances have no __name__.
        # Resolve the name once here so logging inside the except handler can
        # never raise AttributeError and mask the real failure.
        fn_name = getattr(fn, "__name__", None) or repr(fn)

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            attempt = 1
            while True:
                try:
                    return fn(*args, **kwargs)
                except config.retryable_exceptions as exc:
                    if attempt >= config.max_attempts:
                        logger.error(
                            "%s failed after %d attempts: %s",
                            fn_name,
                            config.max_attempts,
                            exc,
                        )
                        raise
                    delay = min(
                        config.base_delay_seconds * (config.backoff_factor ** (attempt - 1)),
                        config.max_delay_seconds,
                    )
                    logger.warning(
                        "%s attempt %d/%d failed: %s. Retrying in %.1fs",
                        fn_name,
                        attempt,
                        config.max_attempts,
                        exc,
                        delay,
                    )
                    time.sleep(delay)
                    attempt += 1

        return wrapper

    return decorator
|