databricks4py 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- databricks4py/__init__.py +56 -0
- databricks4py/catalog.py +65 -0
- databricks4py/config/__init__.py +6 -0
- databricks4py/config/base.py +119 -0
- databricks4py/config/unity.py +72 -0
- databricks4py/filters/__init__.py +17 -0
- databricks4py/filters/base.py +154 -0
- databricks4py/io/__init__.py +40 -0
- databricks4py/io/checkpoint.py +98 -0
- databricks4py/io/dbfs.py +91 -0
- databricks4py/io/delta.py +564 -0
- databricks4py/io/merge.py +176 -0
- databricks4py/io/streaming.py +281 -0
- databricks4py/logging.py +39 -0
- databricks4py/metrics/__init__.py +22 -0
- databricks4py/metrics/base.py +66 -0
- databricks4py/metrics/delta_sink.py +75 -0
- databricks4py/metrics/logging_sink.py +20 -0
- databricks4py/migrations/__init__.py +27 -0
- databricks4py/migrations/alter.py +114 -0
- databricks4py/migrations/runner.py +241 -0
- databricks4py/migrations/schema_diff.py +136 -0
- databricks4py/migrations/validators.py +195 -0
- databricks4py/observability/__init__.py +24 -0
- databricks4py/observability/_utils.py +24 -0
- databricks4py/observability/batch_context.py +134 -0
- databricks4py/observability/health.py +223 -0
- databricks4py/observability/query_listener.py +236 -0
- databricks4py/py.typed +0 -0
- databricks4py/quality/__init__.py +26 -0
- databricks4py/quality/base.py +54 -0
- databricks4py/quality/expectations.py +184 -0
- databricks4py/quality/gate.py +90 -0
- databricks4py/retry.py +102 -0
- databricks4py/secrets.py +69 -0
- databricks4py/spark_session.py +68 -0
- databricks4py/testing/__init__.py +35 -0
- databricks4py/testing/assertions.py +111 -0
- databricks4py/testing/builders.py +127 -0
- databricks4py/testing/fixtures.py +134 -0
- databricks4py/testing/mocks.py +106 -0
- databricks4py/testing/temp_table.py +73 -0
- databricks4py/workflow.py +219 -0
- databricks4py-0.2.0.dist-info/METADATA +589 -0
- databricks4py-0.2.0.dist-info/RECORD +48 -0
- databricks4py-0.2.0.dist-info/WHEEL +5 -0
- databricks4py-0.2.0.dist-info/licenses/LICENSE +21 -0
- databricks4py-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""Table structure validation for migrations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
|
|
9
|
+
from pyspark.sql import SparkSession
|
|
10
|
+
|
|
11
|
+
from databricks4py.io.delta import GeneratedColumn
|
|
12
|
+
from databricks4py.spark_session import active_fallback
|
|
13
|
+
|
|
14
|
+
__all__ = ["MigrationError", "TableValidator", "ValidationResult"]
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MigrationError(Exception):
    """Signals that a table did not pass migration validation.

    The exception message lists every problem, one bullet per line, under a
    header naming the table.

    Attributes:
        table_name: Name of the table whose validation failed.
        errors: The individual validation error messages, in reported order.
    """

    def __init__(self, table_name: str, errors: list[str]) -> None:
        self.table_name = table_name
        self.errors = errors
        header = f"Migration validation failed for '{table_name}':\n"
        bullets = [f" - {e}" for e in errors]
        super().__init__(header + "\n".join(bullets))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class ValidationResult:
    """Outcome of a table validation run.

    Attributes:
        is_valid: Whether all checks passed.
        errors: Fatal validation problems.
        warnings: Non-fatal observations; they do not affect ``is_valid``.
    """

    is_valid: bool
    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    def raise_if_invalid(self, table_name: str) -> None:
        """Turn a failed validation into an exception.

        Args:
            table_name: Table name to embed in the error message.

        Raises:
            MigrationError: When ``is_valid`` is False; carries ``errors``.
        """
        if self.is_valid:
            return
        raise MigrationError(table_name, self.errors)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class TableValidator:
    """Validates Delta table structure against expected configuration.

    Used in migration workflows to verify that a table matches
    expected schema, partitioning, and structure before and after
    migration steps.

    Example::

        validator = TableValidator(
            table_name="catalog.schema.events",
            expected_columns=["id", "name", "event_date"],
            expected_partition_columns=["event_date"],
        )
        result = validator.validate()
        result.raise_if_invalid("catalog.schema.events")

    Args:
        table_name: Fully qualified table name.
        expected_columns: Columns that must exist in the table. Table columns
            not listed here are reported as warnings.
        expected_partition_columns: Expected partition columns, in order. A
            different set of columns is an error; the same columns in a
            different order are reported as a warning.
        expected_generated_columns: Generated columns that must exist. Only
            the column's presence is checked, not its generation expression.
        expected_location_contains: Substring that must appear in table location.
        spark: Optional SparkSession, resolved through ``active_fallback``.
    """

    def __init__(
        self,
        table_name: str,
        *,
        expected_columns: Sequence[str] | None = None,
        expected_partition_columns: Sequence[str] | None = None,
        expected_generated_columns: Sequence[GeneratedColumn] | None = None,
        expected_location_contains: str | None = None,
        spark: SparkSession | None = None,
    ) -> None:
        self._spark = active_fallback(spark)
        self._table_name = table_name
        self._expected_columns = list(expected_columns or [])
        self._expected_partition_columns = list(expected_partition_columns or [])
        self._expected_generated_columns = list(expected_generated_columns or [])
        self._expected_location_contains = expected_location_contains

    def _table_exists(self) -> bool:
        """Return True if ``DESCRIBE TABLE`` succeeds for the configured table."""
        from pyspark.errors import AnalysisException

        try:
            self._spark.sql(f"DESCRIBE TABLE {self._table_name}")
            return True
        except AnalysisException:
            return False

    def _get_actual_columns(self) -> set[str]:
        """Return the column names listed by ``DESCRIBE TABLE``.

        Reading stops at the first empty or ``#``-prefixed ``col_name`` row
        (presumably where the output moves on to partition/metadata sections).
        """
        rows = self._spark.sql(f"DESCRIBE TABLE {self._table_name}").collect()
        columns: set[str] = set()
        for row in rows:
            col_name = row["col_name"]
            if col_name is None or col_name == "" or col_name.startswith("#"):
                break
            columns.add(col_name)
        return columns

    def _get_actual_partitions(self) -> list[str]:
        """Return the partition columns reported by ``DeltaTable.detail()``."""
        from delta.tables import DeltaTable

        dt = DeltaTable.forName(self._spark, self._table_name)
        row = dt.detail().select("partitionColumns").first()
        return list(row["partitionColumns"]) if row else []

    def _get_actual_location(self) -> str:
        """Return the table location from ``DeltaTable.detail()`` ("" if absent)."""
        from delta.tables import DeltaTable

        dt = DeltaTable.forName(self._spark, self._table_name)
        row = dt.detail().select("location").first()
        # Guard against a null location so the substring test cannot raise.
        return (row["location"] or "") if row else ""

    def _check_partitions(self, errors: list[str], warnings: list[str]) -> None:
        """Compare actual partition columns with the expected ones.

        A different set of columns appends an error; the same columns in a
        different order append a warning (order defines the directory layout).
        """
        actual_partitions = self._get_actual_partitions()
        expected_partitions = self._expected_partition_columns
        if sorted(actual_partitions) != sorted(expected_partitions):
            errors.append(
                f"Partition mismatch: expected {self._expected_partition_columns}, "
                f"got {actual_partitions}"
            )
        elif actual_partitions != expected_partitions:
            warnings.append(
                f"Partition order differs: expected {expected_partitions}, "
                f"got {actual_partitions}"
            )

    def _check_location(self, expected: str, errors: list[str]) -> None:
        """Append an error if ``expected`` is not a substring of the location."""
        actual_location = self._get_actual_location()
        if expected not in actual_location:
            errors.append(
                f"Location '{actual_location}' does not contain "
                f"'{self._expected_location_contains}'"
            )

    def validate(self) -> ValidationResult:
        """Run all configured validations.

        A missing table short-circuits with a single error. Otherwise every
        configured check runs; errors make the result invalid, warnings do not.

        Returns:
            ValidationResult with any errors and warnings.
        """
        from pyspark.errors import AnalysisException

        errors: list[str] = []
        warnings: list[str] = []

        if not self._table_exists():
            errors.append(f"Table '{self._table_name}' does not exist")
            return ValidationResult(is_valid=False, errors=errors)

        logger.info("Validating table %s", self._table_name)

        # Both column-based checks use the same DESCRIBE output; query it once.
        actual_columns: set[str] = set()
        if self._expected_columns or self._expected_generated_columns:
            actual_columns = self._get_actual_columns()

        if self._expected_columns:
            expected = set(self._expected_columns)
            missing = expected - actual_columns
            if missing:
                errors.append(f"Missing required columns: {sorted(missing)}")
            extra = actual_columns - expected
            if extra:
                warnings.append(f"Unexpected extra columns: {sorted(extra)}")

        # Partition and location checks go through DeltaTable.forName, which
        # raises AnalysisException for non-Delta tables; report that as a
        # validation error instead of letting validate() raise.
        try:
            if self._expected_partition_columns:
                self._check_partitions(errors, warnings)
            if self._expected_location_contains:
                self._check_location(self._expected_location_contains, errors)
        except AnalysisException as exc:
            errors.append(
                f"Could not read Delta table details for '{self._table_name}': {exc}"
            )

        for gc in self._expected_generated_columns:
            if gc.name not in actual_columns:
                errors.append(f"Missing generated column: '{gc.name}'")

        is_valid = not errors
        if is_valid:
            logger.info("Table %s validation passed", self._table_name)
        else:
            logger.warning("Table %s validation failed: %s", self._table_name, errors)

        return ValidationResult(is_valid=is_valid, errors=errors, warnings=warnings)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Observability: structured batch logging, query listeners, and health checks."""
|
|
2
|
+
|
|
3
|
+
from databricks4py.observability.batch_context import BatchContext, BatchLogger
|
|
4
|
+
from databricks4py.observability.health import (
|
|
5
|
+
CheckDetail,
|
|
6
|
+
HealthResult,
|
|
7
|
+
HealthStatus,
|
|
8
|
+
StreamingHealthCheck,
|
|
9
|
+
)
|
|
10
|
+
from databricks4py.observability.query_listener import (
|
|
11
|
+
QueryProgressObserver,
|
|
12
|
+
QueryProgressSnapshot,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"BatchContext",
|
|
17
|
+
"BatchLogger",
|
|
18
|
+
"CheckDetail",
|
|
19
|
+
"HealthResult",
|
|
20
|
+
"HealthStatus",
|
|
21
|
+
"QueryProgressObserver",
|
|
22
|
+
"QueryProgressSnapshot",
|
|
23
|
+
"StreamingHealthCheck",
|
|
24
|
+
]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Shared utilities for the observability subpackage."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def parse_duration_ms(val: str | int | float) -> int:
    """Parse a Spark duration value to integer milliseconds.

    Handles ``'250 ms'``, ``'1 s'``, ``'2 m'``, fractional magnitudes such as
    ``'1.5 s'``, bare numbers (taken as milliseconds), ints (returned as-is)
    and floats (rounded). Fractional results are rounded to the nearest
    millisecond, so float error such as ``0.29 * 1000`` cannot drop a ms.

    Returns 0 if the value cannot be parsed, including non-finite values.
    """
    if isinstance(val, int):
        return val
    try:
        if isinstance(val, float):
            return round(val)
        stripped = val.strip()
        # "ms" must be tested before "s" and "m" since it ends with both letters' rules in mind.
        for suffix, factor in (("ms", 1), ("s", 1000), ("m", 60_000)):
            if stripped.endswith(suffix):
                return round(float(stripped[: -len(suffix)].strip()) * factor)
        return round(float(stripped))
    except (ValueError, AttributeError, OverflowError):
        # ValueError: non-numeric text or NaN; AttributeError: non-string input;
        # OverflowError: infinite values cannot become an int.
        return 0
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""Structured per-batch logging with correlation IDs.
|
|
2
|
+
|
|
3
|
+
Produces JSON-structured log records for each batch lifecycle event
|
|
4
|
+
(start, complete, error, skip). Designed for use inside
|
|
5
|
+
:class:`~databricks4py.io.streaming.StreamingTableReader` or any
|
|
6
|
+
``foreachBatch`` processor where you need queryable, machine-parseable logs.
|
|
7
|
+
|
|
8
|
+
Example::
|
|
9
|
+
|
|
10
|
+
logger = BatchLogger()
|
|
11
|
+
|
|
12
|
+
ctx = BatchContext.create(batch_id=42, source_table="catalog.schema.events")
|
|
13
|
+
logger.batch_start(ctx)
|
|
14
|
+
# ... process ...
|
|
15
|
+
logger.batch_complete(ctx, row_count=1000, duration_ms=345.2)
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import json
|
|
21
|
+
import logging
|
|
22
|
+
import uuid
|
|
23
|
+
from dataclasses import dataclass, field
|
|
24
|
+
from datetime import datetime, timezone
|
|
25
|
+
from typing import Any
|
|
26
|
+
|
|
27
|
+
__all__ = ["BatchContext", "BatchLogger"]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class BatchContext:
    """Frozen per-micro-batch metadata shared by log, metric and DLQ records.

    Instances are immutable (``frozen=True``), so they can be passed around
    freely once created.

    Args:
        batch_id: Spark-assigned batch identifier.
        source_table: Fully qualified source table or path.
        correlation_id: Identifier used to tie together records emitted for
            this batch. Defaults to the first 12 hex characters of a random
            UUID4.
        start_time: Timezone-aware UTC timestamp; defaults to creation time.
    """

    batch_id: int
    source_table: str
    correlation_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    start_time: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))

    @classmethod
    def create(
        cls,
        batch_id: int,
        source_table: str,
        *,
        correlation_id: str | None = None,
    ) -> BatchContext:
        """Build a context, using ``correlation_id`` only when it is not None."""
        overrides = {} if correlation_id is None else {"correlation_id": correlation_id}
        return cls(batch_id=batch_id, source_table=source_table, **overrides)

    def elapsed_ms(self) -> float:
        """Return wall-clock milliseconds elapsed since ``start_time``."""
        elapsed = datetime.now(tz=timezone.utc) - self.start_time
        return elapsed.total_seconds() * 1000
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class BatchLogger:
    """Structured JSON logger for streaming batch lifecycle events.

    Each log record is a single-line JSON object with a consistent schema,
    making it easy to query in log aggregation systems (Datadog, Splunk,
    CloudWatch, etc.). Every record carries ``event``, ``batch_id``,
    ``source_table``, ``correlation_id`` and ``timestamp``; static extra
    fields are merged in but can never overwrite those core keys.

    Args:
        logger_name: Python logger name. Defaults to ``"databricks4py.batch"``.
        extra_fields: Static fields added to every log record (e.g. environment,
            pipeline name). Keys that collide with the core keys are dropped
            with a warning.
    """

    # Keys _emit always fills itself; extra_fields must not shadow them.
    _CORE_KEYS = frozenset({"event", "batch_id", "source_table", "correlation_id", "timestamp"})

    def __init__(
        self,
        logger_name: str = "databricks4py.batch",
        extra_fields: dict[str, Any] | None = None,
    ) -> None:
        self._logger = logging.getLogger(logger_name)
        extra = dict(extra_fields) if extra_fields else {}
        shadowed = sorted(key for key in extra if key in self._CORE_KEYS)
        if shadowed:
            # Warn on a separate logger so the JSON stream stays pure JSON.
            logging.getLogger(__name__).warning(
                "BatchLogger ignoring extra_fields that shadow core keys: %s", shadowed
            )
            for key in shadowed:
                del extra[key]
        self._extra = extra

    def _emit(self, event: str, ctx: BatchContext, level: int, **fields: Any) -> None:
        """Serialize one lifecycle event as a single-line JSON log record.

        Does nothing when ``level`` is not enabled on the logger. Values that
        ``json`` cannot encode are converted with ``str``.
        """
        if not self._logger.isEnabledFor(level):
            return
        record = {
            "event": event,
            "batch_id": ctx.batch_id,
            "source_table": ctx.source_table,
            "correlation_id": ctx.correlation_id,
            "timestamp": datetime.now(tz=timezone.utc).isoformat(),
            **self._extra,
            **fields,
        }
        self._logger.log(level, json.dumps(record, default=str))

    def batch_start(self, ctx: BatchContext) -> None:
        """Log the start of a batch at INFO level."""
        self._emit("batch_start", ctx, logging.INFO)

    def batch_complete(
        self,
        ctx: BatchContext,
        row_count: int,
        duration_ms: float,
    ) -> None:
        """Log a successful batch at INFO, with ``duration_ms`` rounded to 2 places."""
        self._emit(
            "batch_complete",
            ctx,
            logging.INFO,
            row_count=row_count,
            duration_ms=round(duration_ms, 2),
        )

    def batch_error(self, ctx: BatchContext, error: str | BaseException) -> None:
        """Log a failed batch at ERROR; the error text is capped at 2000 chars.

        ``error`` may be a message or an exception object (``str()`` is used).
        """
        self._emit("batch_error", ctx, logging.ERROR, error=str(error)[:2000])

    def batch_skip(self, ctx: BatchContext, reason: str) -> None:
        """Log a skipped batch at DEBUG; the reason is capped at 500 chars."""
        self._emit("batch_skip", ctx, logging.DEBUG, reason=str(reason)[:500])

    def batch_dlq(self, ctx: BatchContext, dlq_table: str, error: str | BaseException) -> None:
        """Log that a failed batch was routed to the dead-letter queue."""
        self._emit(
            "batch_dlq",
            ctx,
            logging.WARNING,
            dlq_table=dlq_table,
            error=str(error)[:2000],
        )
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
"""Health checks for streaming queries and checkpoints.
|
|
2
|
+
|
|
3
|
+
Polls a ``StreamingQuery`` or ``QueryProgressObserver`` and evaluates
|
|
4
|
+
configurable thresholds to produce a health status. Use in monitoring
|
|
5
|
+
dashboards, alerting hooks, or as a pre-flight check before scaling down.
|
|
6
|
+
|
|
7
|
+
Example::
|
|
8
|
+
|
|
9
|
+
check = StreamingHealthCheck(
|
|
10
|
+
query,
|
|
11
|
+
max_batch_duration_ms=60_000,
|
|
12
|
+
min_processing_rate=100.0,
|
|
13
|
+
stale_timeout_seconds=300,
|
|
14
|
+
)
|
|
15
|
+
result = check.evaluate()
|
|
16
|
+
if result.status == HealthStatus.UNHEALTHY:
|
|
17
|
+
alert(result.summary())
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import json as _json
|
|
23
|
+
import logging
|
|
24
|
+
import time
|
|
25
|
+
from dataclasses import dataclass, field
|
|
26
|
+
from datetime import datetime, timezone
|
|
27
|
+
from enum import Enum
|
|
28
|
+
from typing import TYPE_CHECKING, Any
|
|
29
|
+
|
|
30
|
+
from databricks4py.observability._utils import parse_duration_ms
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from pyspark.sql.streaming import StreamingQuery
|
|
34
|
+
|
|
35
|
+
__all__ = ["CheckDetail", "HealthResult", "HealthStatus", "StreamingHealthCheck"]
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class HealthStatus(Enum):
    """Coarse health level of a monitored component.

    Members are declared from best (HEALTHY) to worst (UNHEALTHY); each value
    is the lowercase member name.
    """

    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True)
class CheckDetail:
    """Immutable verdict produced by one health-check rule.

    Attributes:
        name: Short rule identifier, e.g. ``"stuck_query"``.
        status: Health level this rule assigned.
        message: Human-readable explanation of the verdict.
    """

    name: str
    status: HealthStatus
    message: str
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass(frozen=True)
class HealthResult:
    """Combined outcome of a health evaluation.

    ``status`` is meant to hold the most severe status found in ``checks``:
    UNHEALTHY beats DEGRADED, which beats HEALTHY.

    Attributes:
        status: Aggregated (worst-case) status.
        checks: Per-rule results in evaluation order.
        timestamp: Timezone-aware UTC time at which the result was created.
    """

    status: HealthStatus
    checks: list[CheckDetail] = field(default_factory=list)
    timestamp: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))

    def summary(self) -> str:
        """Render the overall status followed by one line per check."""
        header = f"Overall: {self.status.value}"
        detail_lines = (f" [{c.status.value}] {c.name}: {c.message}" for c in self.checks)
        return "\n".join([header, *detail_lines])
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _worst(statuses: list[HealthStatus]) -> HealthStatus:
    """Return the most severe status in *statuses*; HEALTHY if none is worse."""
    for severe in (HealthStatus.UNHEALTHY, HealthStatus.DEGRADED):
        if severe in statuses:
            return severe
    return HealthStatus.HEALTHY
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class StreamingHealthCheck:
    """Evaluates a streaming query's health by polling its progress.

    Checks (all optional — configure the thresholds you care about):

    - **Stuck query**: No new progress for ``stale_timeout_seconds``. Progress
      counts as new when the reported ``batchId`` *or* ``timestamp`` differs
      from the one seen on the previous :meth:`evaluate` call, so progress
      reports that keep the same ``batchId`` but carry a fresh ``timestamp``
      still count as liveness.
    - **Slow batches**: ``batchDuration`` exceeds ``max_batch_duration_ms``.
    - **Low throughput**: ``processedRowsPerSecond`` below ``min_processing_rate``.
    - **Query inactive**: The query has stopped unexpectedly.

    Staleness is only measured when :meth:`evaluate` is called: the timer
    starts at construction and resets whenever a call observes new progress,
    so reuse one instance across polls of the same query.

    Args:
        query: A PySpark ``StreamingQuery`` to monitor.
        max_batch_duration_ms: DEGRADED if last batch took longer than this.
        min_processing_rate: DEGRADED if processed rows/sec drops below this.
        stale_timeout_seconds: UNHEALTHY if no new progress for this many seconds.
    """

    def __init__(
        self,
        query: StreamingQuery,
        *,
        max_batch_duration_ms: int | None = None,
        min_processing_rate: float | None = None,
        stale_timeout_seconds: int = 600,
    ) -> None:
        self._query = query
        self._max_batch_duration_ms = max_batch_duration_ms
        self._min_processing_rate = min_processing_rate
        self._stale_timeout_seconds = stale_timeout_seconds
        self._last_progress_time: float = time.monotonic()
        self._last_batch_id: int | None = None
        # ``timestamp`` field of the last progress report seen (None until one arrives).
        self._last_progress_timestamp: Any = None

    def _get_progress(self) -> dict[str, Any] | None:
        """Return ``lastProgress`` as a dict, or None if nothing was reported yet.

        Non-dict progress objects are decoded from their ``json`` attribute.
        """
        progress = self._query.lastProgress
        if progress is None:
            return None
        if isinstance(progress, dict):
            return progress
        return _json.loads(progress.json)

    @staticmethod
    def _result(checks: list[CheckDetail]) -> HealthResult:
        """Aggregate per-rule checks into a HealthResult with the worst status."""
        return HealthResult(status=_worst([c.status for c in checks]), checks=checks)

    def _check_first_progress(self) -> CheckDetail:
        """Staleness verdict while the query has not reported any progress yet."""
        elapsed = time.monotonic() - self._last_progress_time
        if elapsed > self._stale_timeout_seconds:
            return CheckDetail(
                name="stale_progress",
                status=HealthStatus.UNHEALTHY,
                message=f"No progress events for {elapsed:.0f}s "
                f"(threshold: {self._stale_timeout_seconds}s)",
            )
        return CheckDetail(
            name="stale_progress",
            status=HealthStatus.HEALTHY,
            message="Waiting for first progress event",
        )

    def _check_staleness(self, progress: dict[str, Any]) -> CheckDetail:
        """Record whether progress advanced and return the staleness verdict."""
        batch_id = progress.get("batchId", -1)
        timestamp = progress.get("timestamp")
        now = time.monotonic()
        if batch_id != self._last_batch_id or timestamp != self._last_progress_timestamp:
            self._last_progress_time = now
            self._last_batch_id = batch_id
            self._last_progress_timestamp = timestamp
        elapsed = now - self._last_progress_time
        if elapsed > self._stale_timeout_seconds:
            return CheckDetail(
                name="stale_progress",
                status=HealthStatus.UNHEALTHY,
                message=f"Batch {batch_id} unchanged for {elapsed:.0f}s",
            )
        return CheckDetail(
            name="stale_progress",
            status=HealthStatus.HEALTHY,
            message=f"Last progress change seen {elapsed:.0f}s ago",
        )

    def _check_batch_duration(self, progress: dict[str, Any], max_ms: int) -> CheckDetail:
        """DEGRADED when the reported ``batchDuration`` exceeds ``max_ms``."""
        duration = parse_duration_ms(progress.get("batchDuration", "0 ms"))
        exceeded = duration > max_ms
        return CheckDetail(
            name="batch_duration",
            status=HealthStatus.DEGRADED if exceeded else HealthStatus.HEALTHY,
            message=f"Batch took {duration}ms"
            + (f" (max: {self._max_batch_duration_ms}ms)" if exceeded else ""),
        )

    def _check_processing_rate(self, progress: dict[str, Any], min_rate: float) -> CheckDetail:
        """DEGRADED when ``processedRowsPerSecond`` is below ``min_rate``.

        A missing or null rate is treated as 0.0.
        """
        rate = float(progress.get("processedRowsPerSecond") or 0.0)
        below = rate < min_rate
        return CheckDetail(
            name="processing_rate",
            status=HealthStatus.DEGRADED if below else HealthStatus.HEALTHY,
            message=f"Processing {rate:.1f} rows/s"
            + (f" (min: {self._min_processing_rate:.1f})" if below else ""),
        )

    def evaluate(self) -> HealthResult:
        """Run all configured checks and return the aggregated result.

        An inactive query short-circuits with a single UNHEALTHY check; a
        query without any progress yet only gets the staleness check.
        """
        checks: list[CheckDetail] = []

        if not self._query.isActive:
            checks.append(
                CheckDetail(
                    name="query_active",
                    status=HealthStatus.UNHEALTHY,
                    message="Query is no longer active",
                )
            )
            return self._result(checks)

        checks.append(
            CheckDetail(
                name="query_active",
                status=HealthStatus.HEALTHY,
                message="Query is running",
            )
        )

        progress = self._get_progress()
        if progress is None:
            checks.append(self._check_first_progress())
            return self._result(checks)

        checks.append(self._check_staleness(progress))

        if self._max_batch_duration_ms is not None:
            checks.append(self._check_batch_duration(progress, self._max_batch_duration_ms))

        if self._min_processing_rate is not None:
            checks.append(self._check_processing_rate(progress, self._min_processing_rate))

        return self._result(checks)
|