databricks4py 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. databricks4py/__init__.py +56 -0
  2. databricks4py/catalog.py +65 -0
  3. databricks4py/config/__init__.py +6 -0
  4. databricks4py/config/base.py +119 -0
  5. databricks4py/config/unity.py +72 -0
  6. databricks4py/filters/__init__.py +17 -0
  7. databricks4py/filters/base.py +154 -0
  8. databricks4py/io/__init__.py +40 -0
  9. databricks4py/io/checkpoint.py +98 -0
  10. databricks4py/io/dbfs.py +91 -0
  11. databricks4py/io/delta.py +564 -0
  12. databricks4py/io/merge.py +176 -0
  13. databricks4py/io/streaming.py +281 -0
  14. databricks4py/logging.py +39 -0
  15. databricks4py/metrics/__init__.py +22 -0
  16. databricks4py/metrics/base.py +66 -0
  17. databricks4py/metrics/delta_sink.py +75 -0
  18. databricks4py/metrics/logging_sink.py +20 -0
  19. databricks4py/migrations/__init__.py +27 -0
  20. databricks4py/migrations/alter.py +114 -0
  21. databricks4py/migrations/runner.py +241 -0
  22. databricks4py/migrations/schema_diff.py +136 -0
  23. databricks4py/migrations/validators.py +195 -0
  24. databricks4py/observability/__init__.py +24 -0
  25. databricks4py/observability/_utils.py +24 -0
  26. databricks4py/observability/batch_context.py +134 -0
  27. databricks4py/observability/health.py +223 -0
  28. databricks4py/observability/query_listener.py +236 -0
  29. databricks4py/py.typed +0 -0
  30. databricks4py/quality/__init__.py +26 -0
  31. databricks4py/quality/base.py +54 -0
  32. databricks4py/quality/expectations.py +184 -0
  33. databricks4py/quality/gate.py +90 -0
  34. databricks4py/retry.py +102 -0
  35. databricks4py/secrets.py +69 -0
  36. databricks4py/spark_session.py +68 -0
  37. databricks4py/testing/__init__.py +35 -0
  38. databricks4py/testing/assertions.py +111 -0
  39. databricks4py/testing/builders.py +127 -0
  40. databricks4py/testing/fixtures.py +134 -0
  41. databricks4py/testing/mocks.py +106 -0
  42. databricks4py/testing/temp_table.py +73 -0
  43. databricks4py/workflow.py +219 -0
  44. databricks4py-0.2.0.dist-info/METADATA +589 -0
  45. databricks4py-0.2.0.dist-info/RECORD +48 -0
  46. databricks4py-0.2.0.dist-info/WHEEL +5 -0
  47. databricks4py-0.2.0.dist-info/licenses/LICENSE +21 -0
  48. databricks4py-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,236 @@
1
+ """Streaming query progress observer via PySpark StreamingQueryListener.
2
+
3
+ Wraps the PySpark 3.4+ ``StreamingQueryListener`` API into a simple observer
4
+ that collects progress snapshots, optionally emits them to a
5
+ :class:`~databricks4py.metrics.base.MetricsSink`, and exposes the latest
6
+ state for health checks.
7
+
8
+ Requires PySpark >= 3.4. On older versions, :meth:`QueryProgressObserver.attach`
9
+ raises ``ImportError`` with a clear message.
10
+
11
+ Example::
12
+
13
+ observer = QueryProgressObserver(spark=spark, metrics_sink=my_sink)
14
+ observer.attach()
15
+
16
+ query = df.writeStream.start()
17
+ # ... wait for progress ...
18
+
19
+ latest = observer.latest_progress()
20
+ if latest:
21
+ print(f"Batch {latest.batch_id}: {latest.num_input_rows} rows")
22
+
23
+ observer.detach()
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import json
29
+ import logging
30
+ from collections import deque
31
+ from dataclasses import dataclass, field
32
+ from datetime import datetime, timezone
33
+ from typing import TYPE_CHECKING, Any
34
+
35
+ from databricks4py.observability._utils import parse_duration_ms
36
+ from databricks4py.spark_session import active_fallback
37
+
38
+ if TYPE_CHECKING:
39
+ from collections.abc import Callable
40
+
41
+ from pyspark.sql import SparkSession
42
+
43
+ from databricks4py.metrics.base import MetricsSink
44
+
45
+ __all__ = ["QueryProgressObserver", "QueryProgressSnapshot"]
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
@dataclass(frozen=True)
class QueryProgressSnapshot:
    """Frozen record describing one streaming query progress event.

    Attributes:
        query_id: Value of the ``id`` key in the progress JSON ("" if absent).
        query_name: Value of the ``name`` key, or ``None`` when absent.
        batch_id: Value of the ``batchId`` key (-1 if absent).
        batch_duration_ms: ``batchDuration`` converted by ``parse_duration_ms``.
        input_rows_per_second: Value of ``inputRowsPerSecond`` (0.0 if absent).
        processed_rows_per_second: Value of ``processedRowsPerSecond``
            (0.0 if absent).
        num_input_rows: Value of ``numInputRows`` (0 if absent).
        timestamp: UTC wall-clock time at which this snapshot was built; it is
            not read from the progress payload.
        sources: Value of the ``sources`` key (empty list if absent).
        sink: Value of the ``sink`` key (empty dict if absent).
        raw_json: The unparsed progress JSON string, kept for custom parsing.
    """

    query_id: str
    query_name: str | None
    batch_id: int
    batch_duration_ms: int
    input_rows_per_second: float
    processed_rows_per_second: float
    num_input_rows: int
    timestamp: datetime
    sources: list[dict[str, Any]] = field(default_factory=list)
    sink: dict[str, Any] = field(default_factory=dict)
    raw_json: str = ""

    @classmethod
    def from_progress(cls, progress: Any) -> QueryProgressSnapshot:
        """Create a snapshot from a PySpark ``StreamingQueryProgress``.

        Args:
            progress: Object exposing a ``.json`` attribute that holds the
                progress payload as a JSON string.

        Returns:
            A snapshot populated from the decoded payload, with defaults
            substituted for any missing key.
        """
        payload = progress.json
        lookup = json.loads(payload).get

        # Built as a dict so values are evaluated in field order, one per line.
        values: dict[str, Any] = {
            "query_id": lookup("id", ""),
            "query_name": lookup("name"),
            "batch_id": lookup("batchId", -1),
            "batch_duration_ms": parse_duration_ms(lookup("batchDuration", "0 ms")),
            "input_rows_per_second": lookup("inputRowsPerSecond", 0.0),
            "processed_rows_per_second": lookup("processedRowsPerSecond", 0.0),
            "num_input_rows": lookup("numInputRows", 0),
            "timestamp": datetime.now(tz=timezone.utc),
            "sources": lookup("sources", []),
            "sink": lookup("sink", {}),
            "raw_json": payload,
        }
        return cls(**values)
105
+
106
+
107
class QueryProgressObserver:
    """Collects streaming query progress events and routes them to a metrics sink.

    Attaches a ``StreamingQueryListener`` to the SparkSession. Each progress
    event is captured as a :class:`QueryProgressSnapshot`, stored in a bounded
    history, and optionally emitted as a ``MetricEvent``.

    Args:
        spark: SparkSession to attach the listener to. The value is passed
            through ``active_fallback`` (presumably resolving ``None`` to the
            active session; confirm in ``databricks4py.spark_session``).
        metrics_sink: Optional sink for ``query_progress`` metric events.
        on_progress: Optional callback invoked on each progress event.
        history_size: Maximum number of snapshots to retain. Oldest are evicted.
        query_name_filter: If set, only track queries with this name.
    """

    def __init__(
        self,
        *,
        spark: SparkSession | None = None,
        metrics_sink: MetricsSink | None = None,
        on_progress: Callable[[QueryProgressSnapshot], None] | None = None,
        history_size: int = 100,
        query_name_filter: str | None = None,
    ) -> None:
        self._spark = active_fallback(spark)
        self._metrics_sink = metrics_sink
        self._on_progress = on_progress
        # A bounded deque drops the oldest snapshot once maxlen is reached.
        self._history: deque[QueryProgressSnapshot] = deque(maxlen=history_size)
        self._query_name_filter = query_name_filter
        self._listener: Any = None
        self._attached = False

    def attach(self) -> None:
        """Register the listener with the SparkSession.

        Calling this while already attached is a no-op, so a repeated call
        cannot register a second listener and orphan the first one.

        Raises:
            ImportError: If PySpark < 3.4 (no Python listener support).
        """
        if self._attached:
            logger.debug("QueryProgressObserver already attached; ignoring attach()")
            return

        try:
            from pyspark.sql.streaming.listener import StreamingQueryListener
        except ImportError as exc:
            raise ImportError(
                "StreamingQueryListener requires PySpark >= 3.4. "
                "Upgrade PySpark or use polling via query.lastProgress instead."
            ) from exc

        # The nested listener class has its own ``self``; capture the observer.
        observer = self

        class _Listener(StreamingQueryListener):
            def onQueryStarted(self, event: Any) -> None:
                logger.debug("Query started: id=%s name=%s", event.id, event.name)

            def onQueryProgress(self, event: Any) -> None:
                observer._handle_progress(event.progress)

            def onQueryTerminated(self, event: Any) -> None:
                # Cap the exception text so one failure cannot flood the log.
                exc_str = str(event.exception)[:500] if event.exception else None
                logger.info("Query terminated: id=%s exception=%s", event.id, exc_str)

        self._listener = _Listener()
        self._spark.streams.addListener(self._listener)
        self._attached = True
        logger.info("QueryProgressObserver attached")

    def detach(self) -> None:
        """Remove the listener from the SparkSession, if one is attached."""
        if self._listener is not None and self._attached:
            self._spark.streams.removeListener(self._listener)
            self._listener = None
            self._attached = False
            logger.info("QueryProgressObserver detached")

    def _handle_progress(self, progress: Any) -> None:
        """Convert a progress object to a snapshot and fan it out.

        Parsing failures are logged and swallowed. Snapshots rejected by
        ``query_name_filter`` are dropped before being stored or emitted.
        """
        try:
            snapshot = QueryProgressSnapshot.from_progress(progress)
        except (json.JSONDecodeError, KeyError, TypeError, ValueError, AttributeError):
            logger.exception("Failed to process query progress event")
            return

        if self._query_name_filter is not None and snapshot.query_name != self._query_name_filter:
            return

        self._history.append(snapshot)

        if self._on_progress is not None:
            self._on_progress(snapshot)

        if self._metrics_sink is not None:
            from databricks4py.metrics.base import MetricEvent

            self._metrics_sink.emit(
                MetricEvent(
                    # Unnamed queries are reported under their id instead.
                    job_name=snapshot.query_name or snapshot.query_id,
                    event_type="query_progress",
                    timestamp=snapshot.timestamp,
                    duration_ms=snapshot.batch_duration_ms,
                    row_count=snapshot.num_input_rows,
                    batch_id=snapshot.batch_id,
                    metadata={
                        "input_rows_per_second": snapshot.input_rows_per_second,
                        "processed_rows_per_second": snapshot.processed_rows_per_second,
                    },
                )
            )

        logger.debug(
            "Progress: query=%s batch=%d rows=%d rate=%.1f rows/s",
            snapshot.query_name or snapshot.query_id,
            snapshot.batch_id,
            snapshot.num_input_rows,
            snapshot.processed_rows_per_second,
        )

    def latest_progress(self) -> QueryProgressSnapshot | None:
        """Most recent progress snapshot, or None if no events received yet."""
        return self._history[-1] if self._history else None

    def history(self, limit: int | None = None) -> list[QueryProgressSnapshot]:
        """Recent progress snapshots, newest last.

        Args:
            limit: Maximum number of snapshots to return. Returns all if None.
                Zero or a negative value returns an empty list.
        """
        items = list(self._history)
        if limit is None:
            return items
        if limit <= 0:
            # ``items[-0:]`` is the whole list, so handle this case explicitly.
            return []
        return items[-limit:]

    @property
    def is_attached(self) -> bool:
        """Whether the listener is currently registered."""
        return self._attached
databricks4py/py.typed ADDED
File without changes
@@ -0,0 +1,26 @@
1
+ """Data quality expectations and enforcement."""
2
+
3
+ from databricks4py.quality.base import Expectation, ExpectationResult, QualityReport
4
+ from databricks4py.quality.expectations import (
5
+ ColumnExists,
6
+ InRange,
7
+ MatchesRegex,
8
+ NotNull,
9
+ RowCount,
10
+ Unique,
11
+ )
12
+ from databricks4py.quality.gate import QualityError, QualityGate
13
+
14
+ __all__ = [
15
+ "ColumnExists",
16
+ "Expectation",
17
+ "ExpectationResult",
18
+ "InRange",
19
+ "MatchesRegex",
20
+ "NotNull",
21
+ "QualityError",
22
+ "QualityGate",
23
+ "QualityReport",
24
+ "RowCount",
25
+ "Unique",
26
+ ]
@@ -0,0 +1,54 @@
1
+ """Core quality abstractions: expectations, results, and reports."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ from pyspark.sql import Column, DataFrame, Row
11
+
12
+ __all__ = ["Expectation", "ExpectationResult", "QualityReport"]
13
+
14
+
15
@dataclass(frozen=True)
class ExpectationResult:
    """Outcome of evaluating one expectation against a DataFrame.

    Attributes:
        expectation: Label identifying the expectation that was evaluated.
        passed: Whether the expectation held.
        total_rows: Row count recorded for the DataFrame that was checked.
        failing_rows: Number of violations found; defaults to 0.
        sample: Optional list of example rows; defaults to ``None``.
    """

    expectation: str
    passed: bool
    total_rows: int
    failing_rows: int = 0
    sample: list[Row] | None = None
24
+
25
+
26
@dataclass(frozen=True)
class QualityReport:
    """Collection of expectation results plus an overall verdict.

    Attributes:
        results: Individual expectation outcomes, in evaluation order.
        passed: Overall verdict supplied by whoever builds the report.
    """

    results: list[ExpectationResult]
    passed: bool

    def summary(self) -> str:
        """Render one line per result followed by an ``Overall:`` line."""
        labels = {True: "PASS", False: "FAIL"}
        rendered = [
            f"[{labels[bool(item.passed)]}] {item.expectation} "
            f"({item.failing_rows}/{item.total_rows} failing)"
            for item in self.results
        ]
        verdict = "PASSED" if self.passed else "FAILED"
        rendered.append(f"Overall: {verdict}")
        return "\n".join(rendered)
41
+
42
+
43
class Expectation(ABC):
    """Base class every DataFrame quality expectation derives from."""

    @abstractmethod
    def validate(self, df: DataFrame) -> ExpectationResult:
        """Evaluate this expectation against ``df`` and report the outcome."""

    def failing_condition(self) -> Column | None:
        """Column expression that is True for rows violating the expectation.

        The base implementation returns ``None``, which signals an aggregate
        check that cannot single out individual rows. Subclasses that can
        filter row by row override this.
        """
        return None
@@ -0,0 +1,184 @@
1
+ """Built-in expectation implementations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from functools import reduce
6
+ from typing import TYPE_CHECKING
7
+
8
+ from databricks4py.quality.base import Expectation, ExpectationResult
9
+
10
+ if TYPE_CHECKING:
11
+ from pyspark.sql import Column, DataFrame
12
+
13
+ __all__ = ["ColumnExists", "InRange", "MatchesRegex", "NotNull", "RowCount", "Unique"]
14
+
15
+
16
class NotNull(Expectation):
    """Validates that specified columns contain no null values.

    Args:
        *columns: Names of the columns to check. With no columns there is
            nothing to check and the expectation always passes.
    """

    def __init__(self, *columns: str) -> None:
        self._columns = columns

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Count rows where any configured column is null.

        Returns:
            A result whose ``failing_rows`` is the number of rows matching
            :meth:`failing_condition` (0 when there is no condition).
        """
        total = df.count()
        condition = self.failing_condition()
        failing = df.where(condition).count() if condition is not None else 0
        return ExpectationResult(
            expectation=repr(self),
            passed=failing == 0,
            total_rows=total,
            failing_rows=failing,
        )

    def failing_condition(self) -> Column | None:
        """Expression that is True when any configured column is null.

        Returns ``None`` when no columns were configured; reducing an empty
        list of conditions would otherwise raise ``TypeError``.
        """
        if not self._columns:
            return None

        from pyspark.sql import functions as F

        conditions = [F.col(c).isNull() for c in self._columns]
        return reduce(lambda a, b: a | b, conditions)

    def __repr__(self) -> str:
        cols = ", ".join(repr(c) for c in self._columns)
        return f"NotNull({cols})"
42
+
43
+
44
class InRange(Expectation):
    """Checks that a column's values stay between optional lower/upper bounds.

    Args:
        column: Name of the column to check.
        min_val: Lower bound; values strictly below it fail. ``None`` disables it.
        max_val: Upper bound; values strictly above it fail. ``None`` disables it.
    """

    def __init__(
        self, column: str, *, min_val: float | None = None, max_val: float | None = None
    ) -> None:
        self._column = column
        self._min_val = min_val
        self._max_val = max_val

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Count the rows that fall outside the configured bounds."""
        row_total = df.count()
        out_of_range = self.failing_condition()
        if out_of_range is None:
            # No bound configured, so nothing can fail.
            bad_total = 0
        else:
            bad_total = df.where(out_of_range).count()
        return ExpectationResult(
            expectation=repr(self),
            passed=bad_total == 0,
            total_rows=row_total,
            failing_rows=bad_total,
        )

    def failing_condition(self) -> Column | None:
        """Expression that is True for out-of-range rows, or ``None`` if unbounded.

        NOTE(review): rows whose value is NULL presumably match neither
        comparison and so are not reported as failing; confirm against Spark.
        """
        from pyspark.sql import functions as F

        value = F.col(self._column)
        too_low = value < self._min_val if self._min_val is not None else None
        too_high = value > self._max_val if self._max_val is not None else None
        active = [check for check in (too_low, too_high) if check is not None]
        if not active:
            return None
        return reduce(lambda a, b: a | b, active)

    def __repr__(self) -> str:
        return f"InRange({self._column!r}, min_val={self._min_val!r}, max_val={self._max_val!r})"
80
+
81
+
82
class Unique(Expectation):
    """Checks that the given column combination has no duplicate rows.

    Args:
        *columns: Names of the columns that together must be unique.
    """

    def __init__(self, *columns: str) -> None:
        self._columns = columns

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Compare the total row count with the distinct count of the key columns.

        ``failing_rows`` is the difference between the two counts.
        """
        row_total = df.count()
        key_frame = df.select(*self._columns)
        distinct_total = key_frame.distinct().count()
        surplus = row_total - distinct_total
        return ExpectationResult(
            expectation=repr(self),
            passed=surplus == 0,
            total_rows=row_total,
            failing_rows=surplus,
        )

    def __repr__(self) -> str:
        joined = ", ".join(map(repr, self._columns))
        return f"Unique({joined})"
102
+
103
+
104
class RowCount(Expectation):
    """Checks that the DataFrame's row count lies within optional bounds.

    Args:
        min_count: Smallest acceptable row count; ``None`` disables the check.
        max_count: Largest acceptable row count; ``None`` disables the check.
    """

    def __init__(self, *, min_count: int | None = None, max_count: int | None = None) -> None:
        self._min_count = min_count
        self._max_count = max_count

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Count the rows once and compare against both bounds.

        The result leaves ``failing_rows`` at its default because this is an
        aggregate check with no per-row failures.
        """
        observed = df.count()
        too_few = self._min_count is not None and observed < self._min_count
        too_many = self._max_count is not None and observed > self._max_count
        return ExpectationResult(
            expectation=repr(self),
            passed=not (too_few or too_many),
            total_rows=observed,
        )

    def __repr__(self) -> str:
        return f"RowCount(min_count={self._min_count!r}, max_count={self._max_count!r})"
126
+
127
+
128
class MatchesRegex(Expectation):
    """Checks that a column's values match a regular expression.

    Args:
        column: Name of the column to check.
        pattern: Regex handed to ``Column.rlike``.
    """

    def __init__(self, column: str, pattern: str) -> None:
        self._column = column
        self._pattern = pattern

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Count the rows whose value does not match the pattern."""
        row_total = df.count()
        mismatch = self.failing_condition()
        if mismatch is None:
            bad_total = 0
        else:
            bad_total = df.where(mismatch).count()
        return ExpectationResult(
            expectation=repr(self),
            passed=bad_total == 0,
            total_rows=row_total,
            failing_rows=bad_total,
        )

    def failing_condition(self) -> Column | None:
        """Negated ``rlike`` expression: True where the value does not match.

        NOTE(review): NULL values presumably yield NULL here and therefore are
        not reported as failing; confirm against Spark.
        """
        from pyspark.sql import functions as F

        matches = F.col(self._column).rlike(self._pattern)
        return ~matches

    def __repr__(self) -> str:
        return f"MatchesRegex({self._column!r}, {self._pattern!r})"
153
+
154
+
155
class ColumnExists(Expectation):
    """Checks that the named columns are present, optionally with a given type.

    Args:
        *columns: Column names that must be present in the DataFrame.
        dtype: If given, every listed column's ``dataType.simpleString()``
            must equal this string.
    """

    def __init__(self, *columns: str, dtype: str | None = None) -> None:
        self._columns = columns
        self._dtype = dtype

    def validate(self, df: DataFrame) -> ExpectationResult:
        """Check presence first, then the type if every column is present.

        ``failing_rows`` in the result holds the number of offending columns,
        not a number of rows.
        """
        row_total = df.count()
        present = set(df.columns)
        offending = [name for name in self._columns if name not in present]

        # The type check only runs once every column is known to exist.
        if not offending and self._dtype:
            type_by_name = {f.name: f.dataType.simpleString() for f in df.schema.fields}
            offending = [
                name for name in self._columns if type_by_name.get(name) != self._dtype
            ]

        return ExpectationResult(
            expectation=repr(self),
            passed=len(offending) == 0,
            total_rows=row_total,
            failing_rows=len(offending),
        )

    def __repr__(self) -> str:
        joined = ", ".join(map(repr, self._columns))
        if not self._dtype:
            return f"ColumnExists({joined})"
        return f"ColumnExists({joined}, dtype={self._dtype!r})"
@@ -0,0 +1,90 @@
1
+ """Quality gate: orchestrate expectations and enforce data quality."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from collections.abc import Callable
7
+ from typing import TYPE_CHECKING, Literal
8
+
9
+ from databricks4py.quality.base import Expectation, QualityReport
10
+
11
+ if TYPE_CHECKING:
12
+ from pyspark.sql import DataFrame
13
+
14
+ __all__ = ["QualityError", "QualityGate"]
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class QualityError(Exception):
    """Exception carrying the failed report; raised by the gate in 'raise' mode.

    Attributes:
        report: The report whose summary forms the exception message.
    """

    def __init__(self, report: QualityReport) -> None:
        message = f"Quality check failed:\n{report.summary()}"
        super().__init__(message)
        self.report = report
25
+
26
+
27
class QualityGate:
    """Runs expectations against a DataFrame and enforces quality policy.

    Args:
        *expectations: Expectation instances to evaluate.
        on_fail: Action on failure: "raise", "warn", or "quarantine".
        quarantine_handler: Callable receiving the bad-rows DataFrame.
            Required when on_fail="quarantine".

    Raises:
        ValueError: If ``on_fail="quarantine"`` is given without a handler.
    """

    def __init__(
        self,
        *expectations: Expectation,
        on_fail: Literal["raise", "warn", "quarantine"] = "raise",
        quarantine_handler: Callable[[DataFrame], None] | None = None,
    ) -> None:
        if on_fail == "quarantine" and quarantine_handler is None:
            raise ValueError("quarantine_handler is required when on_fail='quarantine'")
        self._expectations = expectations
        self._on_fail = on_fail
        self._quarantine_handler = quarantine_handler

    def check(self, df: DataFrame) -> QualityReport:
        """Run all expectations and return a report.

        The report passes only if every individual result passed.
        """
        results = [exp.validate(df) for exp in self._expectations]
        passed = all(r.passed for r in results)
        return QualityReport(results=results, passed=passed)

    def enforce(self, df: DataFrame) -> DataFrame:
        """Run expectations and enforce the configured failure policy.

        Returns the clean DataFrame (original if all pass, or filtered
        if quarantine mode splits out bad rows).

        Raises:
            QualityError: In "raise" mode when any expectation fails.
        """
        report = self.check(df)
        if report.passed:
            return df

        if self._on_fail == "raise":
            raise QualityError(report)

        if self._on_fail == "warn":
            logger.warning("Quality check failed:\n%s", report.summary())
            return df

        # quarantine: split bad rows using failing_condition()
        from functools import reduce

        per_expectation = [exp.failing_condition() for exp in self._expectations]
        conditions = [cond for cond in per_expectation if cond is not None]

        if not conditions:
            # no row-level conditions available, so quarantine nothing
            logger.warning("Quarantine requested but no row-level failing conditions available")
            return df

        # check() yields results in expectation order, so the lists line up.
        unfilterable = [
            result.expectation
            for cond, result in zip(per_expectation, report.results)
            if cond is None and not result.passed
        ]
        if unfilterable:
            logger.warning(
                "Quarantine cannot isolate rows for failed aggregate expectations: %s",
                ", ".join(unfilterable),
            )

        from pyspark.sql import functions as F

        # A comparison on a NULL value evaluates to NULL, which neither
        # where(cond) nor where(~cond) keeps. Coalescing to False sends such
        # rows to the clean side, so every input row lands in exactly one of
        # the two outputs.
        bad_condition = F.coalesce(reduce(lambda a, b: a | b, conditions), F.lit(False))
        bad_rows = df.where(bad_condition)
        clean_rows = df.where(~bad_condition)
        self._quarantine_handler(bad_rows)  # type: ignore[misc]
        return clean_rows
databricks4py/retry.py ADDED
@@ -0,0 +1,102 @@
1
+ """Retry decorator with exponential backoff for transient failures."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import logging
7
+ import time
8
+ from collections.abc import Callable
9
+ from dataclasses import dataclass, field
10
+ from typing import Any
11
+
12
+ __all__ = ["RetryConfig", "retry"]
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def _default_retryable_exceptions() -> tuple[type[BaseException], ...]:
18
+ exceptions: list[type[BaseException]] = [ConnectionError, TimeoutError, OSError]
19
+ try:
20
+ from py4j.protocol import Py4JNetworkError
21
+
22
+ exceptions.append(Py4JNetworkError)
23
+ except ImportError:
24
+ pass
25
+ return tuple(exceptions)
26
+
27
+
28
@dataclass(frozen=True)
class RetryConfig:
    """Configuration for :func:`retry` behaviour.

    Args:
        max_attempts: Total number of tries (including the first). Must be
            at least 1.
        base_delay_seconds: Initial delay before the first retry. Must be
            non-negative.
        max_delay_seconds: Upper cap on the exponentially increasing delay.
            Must be non-negative.
        backoff_factor: Multiplier applied to the delay after each failure.
            Must be non-negative.
        retryable_exceptions: Exception types that trigger a retry.
            Defaults to ``ConnectionError``, ``TimeoutError``, ``OSError``,
            and ``Py4JNetworkError`` (if py4j is installed).

    Raises:
        ValueError: If any numeric field is outside its allowed range.
    """

    max_attempts: int = 3
    base_delay_seconds: float = 1.0
    max_delay_seconds: float = 60.0
    backoff_factor: float = 2.0
    retryable_exceptions: tuple[type[BaseException], ...] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        # Reject values that would only fail later, inside the retry loop:
        # fewer than one attempt means the function is never called, and a
        # negative delay or multiplier yields a negative sleep duration.
        if self.max_attempts < 1:
            raise ValueError(f"max_attempts must be >= 1, got {self.max_attempts}")
        if self.base_delay_seconds < 0:
            raise ValueError(
                f"base_delay_seconds must be >= 0, got {self.base_delay_seconds}"
            )
        if self.max_delay_seconds < 0:
            raise ValueError(
                f"max_delay_seconds must be >= 0, got {self.max_delay_seconds}"
            )
        if self.backoff_factor < 0:
            raise ValueError(f"backoff_factor must be >= 0, got {self.backoff_factor}")

        if not self.retryable_exceptions:
            # The dataclass is frozen, so bypass the blocked attribute setter.
            object.__setattr__(self, "retryable_exceptions", _default_retryable_exceptions())
51
+
52
+
53
def retry(config: RetryConfig | None = None) -> Callable:
    """Decorator factory for retrying functions with exponential backoff.

    Example::

        @retry(RetryConfig(max_attempts=5, base_delay_seconds=2.0))
        def fetch_data():
            return requests.get(url).json()

    Args:
        config: Retry parameters. Uses defaults if ``None``.

    Returns:
        A decorator. The wrapped function is called up to
        ``config.max_attempts`` times; exceptions not listed in
        ``config.retryable_exceptions`` propagate immediately, and the last
        retryable exception is re-raised once attempts are exhausted.

    Raises:
        ValueError: If ``config.max_attempts`` is less than 1, since the
            wrapped function could then never be called.
    """
    settings = RetryConfig() if config is None else config
    if settings.max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {settings.max_attempts}")

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            attempt = 1
            # Every pass either returns, re-raises, or sleeps and retries, so
            # the loop needs no fall-through handling after it.
            while True:
                try:
                    return fn(*args, **kwargs)
                except settings.retryable_exceptions as exc:
                    if attempt >= settings.max_attempts:
                        logger.error(
                            "%s failed after %d attempts: %s",
                            fn.__name__,
                            settings.max_attempts,
                            exc,
                        )
                        raise
                    # The delay grows by backoff_factor per failed attempt and
                    # is capped at max_delay_seconds.
                    delay = min(
                        settings.base_delay_seconds
                        * (settings.backoff_factor ** (attempt - 1)),
                        settings.max_delay_seconds,
                    )
                    logger.warning(
                        "%s attempt %d/%d failed: %s. Retrying in %.1fs",
                        fn.__name__,
                        attempt,
                        settings.max_attempts,
                        exc,
                        delay,
                    )
                    time.sleep(delay)
                attempt += 1

        return wrapper

    return decorator