dfguard 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dfguard/__init__.py +8 -0
- dfguard/py.typed +0 -0
- dfguard/pyspark/__init__.py +114 -0
- dfguard/pyspark/_enforcement.py +250 -0
- dfguard/pyspark/_inference.py +77 -0
- dfguard/pyspark/_nullable.py +49 -0
- dfguard/pyspark/coercion.py +203 -0
- dfguard/pyspark/dataset.py +696 -0
- dfguard/pyspark/decorators.py +86 -0
- dfguard/pyspark/exceptions.py +55 -0
- dfguard/pyspark/history.py +139 -0
- dfguard/pyspark/schema.py +418 -0
- dfguard/pyspark/types.py +107 -0
- dfguard-0.1.0.dist-info/METADATA +415 -0
- dfguard-0.1.0.dist-info/RECORD +17 -0
- dfguard-0.1.0.dist-info/WHEEL +4 -0
- dfguard-0.1.0.dist-info/licenses/LICENSE +147 -0
dfguard/__init__.py
ADDED
dfguard/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dfguard.pyspark: runtime schema enforcement for PySpark DataFrames.
|
|
3
|
+
|
|
4
|
+
Two-line setup for packages (Kedro, Airflow, any importable module)
|
|
5
|
+
-------------------------------------------------------------------
|
|
6
|
+
::
|
|
7
|
+
|
|
8
|
+
from dfguard.pyspark import schema_of, arm
|
|
9
|
+
|
|
10
|
+
RawSchema = schema_of(raw_df)
|
|
11
|
+
|
|
12
|
+
def enrich(df: RawSchema): ...
|
|
13
|
+
def clean(df: RawSchema): ...
|
|
14
|
+
|
|
15
|
+
arm() # enforces every annotated function above
|
|
16
|
+
|
|
17
|
+
For scripts and notebooks use ``@enforce`` per function::
|
|
18
|
+
|
|
19
|
+
from dfguard.pyspark import schema_of, enforce
|
|
20
|
+
|
|
21
|
+
RawSchema = schema_of(raw_df)
|
|
22
|
+
|
|
23
|
+
@enforce
|
|
24
|
+
def enrich(df: RawSchema): ...
|
|
25
|
+
|
|
26
|
+
Declaring schemas upfront (no live DataFrame required)
|
|
27
|
+
------------------------------------------------------
|
|
28
|
+
::
|
|
29
|
+
|
|
30
|
+
from pyspark.sql import types as T
|
|
31
|
+
from dfguard.pyspark import SparkSchema, Optional, enforce
|
|
32
|
+
|
|
33
|
+
class OrderSchema(SparkSchema):
|
|
34
|
+
order_id: T.LongType()
|
|
35
|
+
amount: T.DoubleType()
|
|
36
|
+
tags: T.ArrayType(T.StringType())
|
|
37
|
+
address: AddressSchema # nested struct
|
|
38
|
+
zip: Optional[T.StringType()] # nullable
|
|
39
|
+
|
|
40
|
+
@enforce
|
|
41
|
+
def process(df: OrderSchema): ... # subset matching: df must have these fields
|
|
42
|
+
|
|
43
|
+
Public API
|
|
44
|
+
----------
|
|
45
|
+
``schema_of(df)``
|
|
46
|
+
Capture ``df``'s schema as a type. Exact match required.
|
|
47
|
+
|
|
48
|
+
``dataset(df)``
|
|
49
|
+
Wrap ``df`` in a tracked instance. Every ``withColumn``, ``drop``,
|
|
50
|
+
``select``, etc. is recorded in ``schema_history``.
|
|
51
|
+
|
|
52
|
+
``arm()``
|
|
53
|
+
Apply schema enforcement to every annotated function in the calling
|
|
54
|
+
module. Call after all function definitions.
|
|
55
|
+
|
|
56
|
+
``disarm()``
|
|
57
|
+
Turn off all enforcement globally. Call ``arm()`` to re-enable.
|
|
58
|
+
|
|
59
|
+
``enforce``
|
|
60
|
+
Per-function decorator. Only checks schema-annotated args.
|
|
61
|
+
|
|
62
|
+
``SparkSchema``
|
|
63
|
+
Declare a schema as a class using real PySpark types.
|
|
64
|
+
Subset matching: df must have at least the declared fields.
|
|
65
|
+
|
|
66
|
+
``check_schema(schema)`` / ``typed_transform(input_schema, output_schema)``
|
|
67
|
+
Function decorators for explicit input/output validation.
|
|
68
|
+
"""
|
|
69
|
+
|
|
70
|
+
try:
|
|
71
|
+
import pyspark # noqa: F401
|
|
72
|
+
except ImportError as _e:
|
|
73
|
+
raise ImportError(
|
|
74
|
+
"dfguard's PySpark integration requires PySpark. "
|
|
75
|
+
"Install it with: pip install 'dfguard[pyspark]'"
|
|
76
|
+
) from _e
|
|
77
|
+
|
|
78
|
+
from dfguard.pyspark._enforcement import arm, disarm, enforce
|
|
79
|
+
from dfguard.pyspark._inference import infer_schema
|
|
80
|
+
from dfguard.pyspark._nullable import Optional
|
|
81
|
+
from dfguard.pyspark.coercion import result_type
|
|
82
|
+
from dfguard.pyspark.dataset import TypedGroupedData, _TypedDatasetBase, schema_of
|
|
83
|
+
from dfguard.pyspark.dataset import _make_dataset as dataset
|
|
84
|
+
from dfguard.pyspark.decorators import check_schema, typed_transform
|
|
85
|
+
from dfguard.pyspark.exceptions import (
|
|
86
|
+
ColumnNotFoundError,
|
|
87
|
+
DfTypesError,
|
|
88
|
+
SchemaValidationError,
|
|
89
|
+
TypeAnnotationError,
|
|
90
|
+
)
|
|
91
|
+
from dfguard.pyspark.history import SchemaChange, SchemaHistory
|
|
92
|
+
from dfguard.pyspark.schema import SparkSchema # noqa: E402
|
|
93
|
+
|
|
94
|
+
__all__ = [
|
|
95
|
+
"schema_of",
|
|
96
|
+
"dataset",
|
|
97
|
+
"enforce",
|
|
98
|
+
"arm",
|
|
99
|
+
"disarm",
|
|
100
|
+
"_TypedDatasetBase",
|
|
101
|
+
"TypedGroupedData",
|
|
102
|
+
"SparkSchema",
|
|
103
|
+
"typed_transform",
|
|
104
|
+
"check_schema",
|
|
105
|
+
"SchemaChange",
|
|
106
|
+
"SchemaHistory",
|
|
107
|
+
"DfTypesError",
|
|
108
|
+
"SchemaValidationError",
|
|
109
|
+
"TypeAnnotationError",
|
|
110
|
+
"ColumnNotFoundError",
|
|
111
|
+
"infer_schema",
|
|
112
|
+
"result_type",
|
|
113
|
+
"Optional",
|
|
114
|
+
]
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
"""Schema enforcement without touching non-schema arguments."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
import importlib
|
|
7
|
+
import inspect
|
|
8
|
+
import pkgutil
|
|
9
|
+
import types
|
|
10
|
+
import warnings
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
from typing import Any, TypeVar, overload
|
|
13
|
+
|
|
14
|
+
# Generic callable type used so decorators can hand back the type they received.
F = TypeVar("F", bound=Callable[..., Any])

# Global enforcement switch: dfg.disarm() clears it, dfg.arm() sets it again.
_ENABLED = True
# Records whether arm() has been called.
_ARMED = False
# Global subset default set by dfg.arm(subset=...); a function-level value overrides it.
_SUBSET = True

# Sentinel meaning "no function-level override, use the global default".
_UNSET = object()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _is_schema_type(annotation: Any) -> bool:
    """
    Decide whether *annotation* takes part in dfguard enforcement.

    An annotation qualifies when it is a class that exposes a callable
    ``_fg_check`` attribute (called elsewhere as ``_fg_check(value, subset)``).
    That attribute is the extension point: a schema class for another
    DataFrame backend only has to provide ``_fg_check`` to be enforced.
    """
    # Instances, strings and other non-class annotations never qualify.
    if not isinstance(annotation, type):
        return False
    checker = getattr(annotation, "_fg_check", None)
    return callable(checker)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _schema_matches(value: Any, annotation: type, subset: bool) -> bool:
    """
    Report whether *value* satisfies *annotation*.

    When the annotation provides a callable ``_fg_check``, the decision is
    delegated to ``annotation._fg_check(value, subset)`` and coerced to
    ``bool``; otherwise a plain ``isinstance`` check is used.

    How *subset* is interpreted is up to each schema type:

    - ``schema_of`` types (``_TypedDatasetBase``) ignore it: always exact.
    - ``SparkSchema`` types respect it:
        - ``True``: extra columns in *value* are fine.
        - ``False``: *value* must have exactly the declared columns.
    """
    check = getattr(annotation, "_fg_check", None)
    if not callable(check):
        return isinstance(value, annotation)
    return bool(check(value, subset))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _arm_module_dict(module_dict: dict[str, Any], *, subset: Any) -> None:
    """
    Replace the public plain functions in *module_dict* with enforced versions.

    Names starting with an underscore and objects that are not
    ``types.FunctionType`` are left alone. A function is rebound only when
    ``enforce`` actually returned a different object.

    *subset* is forwarded to ``enforce``; callers pass ``_UNSET`` so each
    wrapped function reads the global ``_SUBSET`` at call time.
    """
    decorate = enforce(subset=subset)
    # Iterate over a snapshot because the dict is rebound while looping.
    for name, candidate in list(module_dict.items()):
        if name.startswith("_") or not isinstance(candidate, types.FunctionType):
            continue
        enforced = decorate(candidate)
        if enforced is not candidate:
            module_dict[name] = enforced
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def arm(
    module: Any = None,
    *,
    package: str | None = None,
    subset: bool = True,
) -> None:
    """
    Arm the entire calling package and set the global subset default.

    Call once from your entry point, ``__init__.py``, or ``settings.py`` (Kedro)::

        import dfguard.pyspark as dfg

        dfg.arm()               # subset=True (default): extra columns are fine
        dfg.arm(subset=False)   # exact match: no extra columns allowed anywhere

    The ``subset`` value becomes the global default. Individual functions decorated
    with ``@dfg.enforce(subset=...)`` override it for that function only.

    If called when already armed, re-enables enforcement (sets ``_ENABLED = True``)
    and updates the subset default without re-walking the package.

    A call that could not arm anything (no caller frame, called from
    ``__main__``, or the package failed to import) does not count as armed,
    so a later, corrected call still walks its target.

    **Specific module object**::

        dfg.arm(my_module)

    **Explicit package name**::

        dfg.arm(package="my_pipeline.nodes")

    Parameters
    ----------
    module:
        A module object whose public functions should be wrapped. Anything
        that is not a ``types.ModuleType`` is ignored.
    package:
        Dotted name of the package to import and walk. When omitted, the
        caller's ``__package__`` (or ``__name__``) is used.
    subset:
        New global default for subset matching.

    Raises
    ------
    Exception
        Whatever ``importlib.import_module(package)`` raises for the root
        package. Failures in sub-modules are reported as warnings instead.
    """
    global _SUBSET, _ENABLED, _ARMED
    _SUBSET = subset
    _ENABLED = True

    if _ARMED:
        # Already armed: just re-enable and update subset, no re-walking needed.
        return

    # Set the flag before importing anything so that a re-entrant arm() call
    # made while modules are being imported below returns immediately.
    # It is rolled back on every path that ends up arming nothing.
    _ARMED = True

    if isinstance(module, types.ModuleType):
        _arm_module_dict(vars(module), subset=_UNSET)
        return

    if package is None:
        frame = inspect.currentframe()
        caller = None if frame is None else frame.f_back
        if caller is None:
            _ARMED = False
            return
        caller_globals = caller.f_globals
        package = caller_globals.get("__package__") or caller_globals.get("__name__", "")

    if not package or package == "__main__":
        _ARMED = False
        warnings.warn(
            "dfguard.pyspark.arm() called from __main__. "
            "Use @dfguard.pyspark.enforce on individual functions instead.",
            stacklevel=2,
        )
        return

    try:
        pkg = importlib.import_module(package)
    except Exception:
        # Nothing was armed; let a later call with a valid package succeed.
        _ARMED = False
        raise

    _arm_module_dict(vars(pkg), subset=_UNSET)

    pkg_path = getattr(pkg, "__path__", None)
    if pkg_path is None:
        # A plain module has no sub-modules to walk.
        return

    for _, mod_name, _ in pkgutil.walk_packages(pkg_path, prefix=package + "."):
        try:
            mod = importlib.import_module(mod_name)
            _arm_module_dict(vars(mod), subset=_UNSET)
        except Exception as exc:
            # Best effort: one broken sub-module must not stop the rest.
            warnings.warn(
                f"dfguard: skipped arming module '{mod_name}': {exc}",
                stacklevel=2,
            )
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def disarm() -> None:
    """
    Switch enforcement off globally.

    Clears the module-level ``_ENABLED`` flag that enforced wrappers consult
    on every call. Call ``arm()`` to re-enable.
    """
    global _ENABLED
    _ENABLED = False
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@overload
def enforce(func: F) -> F: ...


@overload
def enforce(func: None = None, *, subset: bool = ...) -> Callable[[F], F]: ...


def enforce(
    func: F | None = None,
    *,
    subset: Any = _UNSET,
) -> F | Callable[[F], F]:
    """
    Validate schema annotations on DataFrame arguments.

    Only intercepts parameters annotated with a ``dfg.schema_of`` type or a
    ``dfg.SparkSchema`` subclass. All other arguments are left completely alone.
    A function with no schema-annotated parameter is returned unwrapped.

    Annotations written as strings (for example in a module that uses
    ``from __future__ import annotations``) are evaluated when possible, so
    such functions are enforced as well.

    **Default**: inherits the global ``subset`` set by ``dfg.arm()``::

        @dfg.enforce
        def process(df: OrderSchema, label: str): ...

    **subset=True**: extra columns in the DataFrame are fine (overrides global)::

        @dfg.enforce(subset=True)
        def process(df: OrderSchema): ...

    **subset=False**: DataFrame must match the schema exactly (overrides global)::

        @dfg.enforce(subset=False)
        def process(df: OrderSchema): ...

    Raises
    ------
    TypeError
        From the wrapped call, when a schema-annotated argument does not match
        its annotation, or when the arguments cannot be bound to the signature.
    """
    # Capture the function-level subset at decoration time.
    # If _UNSET, the wrapper reads _SUBSET at call-time (respects dfg.arm changes).
    subset_override = subset

    def decorator(f: F) -> F:
        # Computed once and shared by the parameter scan and the wrapper.
        sig = _signature_with_resolved_annotations(f)
        schema_params = [
            (name, param.annotation)
            for name, param in sig.parameters.items()
            if param.annotation is not inspect.Parameter.empty
            and _is_schema_type(param.annotation)
        ]

        if not schema_params:
            return f  # nothing schema-typed, zero overhead

        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not _ENABLED:
                return f(*args, **kwargs)

            # Function-level subset wins; fall back to global if not set.
            effective_subset = _SUBSET if subset_override is _UNSET else subset_override

            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()

            for param_name, annotation in schema_params:
                if param_name not in bound.arguments:
                    continue
                value = bound.arguments[param_name]
                if not _schema_matches(value, annotation, subset=effective_subset):
                    _raise_schema_mismatch(f.__name__, param_name, annotation, value)

            return f(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    if func is not None:
        return decorator(func)
    return decorator


def _signature_with_resolved_annotations(f: Callable[..., Any]) -> inspect.Signature:
    """Return *f*'s signature, evaluating string annotations when possible."""
    try:
        return inspect.signature(f, eval_str=True)
    except Exception:
        # Evaluating an annotation can fail (e.g. a name that only exists
        # under TYPE_CHECKING). Fall back to the unevaluated signature, in
        # which string annotations are simply not treated as schema types.
        return inspect.signature(f)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _raise_schema_mismatch(
    func_name: str,
    param_name: str,
    annotation: type,
    value: Any,
) -> None:
    """
    Raise the ``TypeError`` that reports a schema mismatch.

    The "received" side lists ``name:type`` for each field of ``value.schema``
    when that attribute exists and is not ``None``; otherwise it is the type
    name of *value*. The "expected" side does the same with
    ``annotation._expected_schema``, falling back to the annotation's
    ``__name__`` (or its ``repr``).

    Raises
    ------
    TypeError
        Always.
    """

    def _fields_to_str(schema: Any) -> str:
        # One "name:type" entry per field, comma separated.
        return ", ".join(
            f"{field.name}:{field.dataType.simpleString()}" for field in schema.fields
        )

    actual_schema = getattr(value, "schema", None)
    if actual_schema is None:
        actual_str = type(value).__name__
    else:
        actual_str = _fields_to_str(actual_schema)

    expected_schema = getattr(annotation, "_expected_schema", None)
    if expected_schema is None:
        expected_str = getattr(annotation, "__name__", repr(annotation))
    else:
        expected_str = _fields_to_str(expected_schema)

    message = (
        f"Schema mismatch in {func_name}() argument '{param_name}':\n"
        f" expected: {expected_str}\n"
        f" received: {actual_str}"
    )
    raise TypeError(message)
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dfguard.pyspark._inference
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
4
|
+
``infer_schema(df)``: inspect a live DataFrame and generate a SparkSchema
|
|
5
|
+
subclass with the correct types, including deeply nested structs.
|
|
6
|
+
|
|
7
|
+
The generated class can immediately be used for validation::
|
|
8
|
+
|
|
9
|
+
schema = infer_schema(df, name="OrderSchema")
|
|
10
|
+
print(schema.to_code()) # copy-paste into your codebase
|
|
11
|
+
ds.validate(schema)
|
|
12
|
+
|
|
13
|
+
Nested StructTypes are emitted as separate named inner classes.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def infer_schema(df: Any, name: str = "InferredSchema") -> type:
    """
    Build a SparkSchema subclass from the schema of ``df``.

    The class is produced by ``SparkSchema.from_struct(df.schema, name=name)``.
    As a side effect, the Python source rendered for that class is printed so
    it can be copied into a codebase.

    Parameters
    ----------
    df:
        A live ``pyspark.sql.DataFrame`` or ``dataset`` wrapper; only its
        ``schema`` attribute is read.
    name:
        Name to give the generated class.

    Returns
    -------
    type[SparkSchema]
        The generated SparkSchema subclass.
    """
    from dfguard.pyspark.schema import SparkSchema

    generated = SparkSchema.from_struct(df.schema, name=name)
    source = _render_code(generated, name)
    print(source)
    return generated
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _render_code(schema_class: Any, name: str) -> str:
    """
    Recursively render a SparkSchema (and any nested schemas) as Python source.

    Nested SparkSchema subclasses found among the class attributes of
    *schema_class* are rendered first, each followed by a blank line, so they
    appear before the class that refers to them. The class itself is rendered
    as ``class <name>(SparkSchema):`` followed by one ``column: annotation``
    line per entry of ``_schema_fields``, or ``pass`` when there are none.
    Body lines use a four-space indent so the output can be pasted as-is.

    Parameters
    ----------
    schema_class:
        The SparkSchema subclass to render.
    name:
        Class name to use in the rendered ``class`` statement.

    Returns
    -------
    str
        The rendered source, lines joined with newlines.
    """
    from dfguard.pyspark.schema import SparkSchema, _annotation_to_str

    indent = " " * 4
    lines: list[str] = []

    # Nested SparkSchema classes stored as class attributes come first.
    for attr_val in vars(schema_class).values():
        if (
            isinstance(attr_val, type)
            and issubclass(attr_val, SparkSchema)
            and attr_val is not SparkSchema
        ):
            lines.append(_render_code(attr_val, attr_val.__name__))
            lines.append("")

    lines.append(f"class {name}(SparkSchema):")
    fields = schema_class._schema_fields
    if not fields:
        lines.append(f"{indent}pass")
    else:
        lines.extend(
            f"{indent}{col_name}: {_annotation_to_str(annotation)}"
            for col_name, annotation in fields.items()
        )

    return "\n".join(lines)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Nullable field marker for SparkSchema.
|
|
3
|
+
|
|
4
|
+
``typing.Optional[T.StringType()]`` raises TypeError on Python 3.10 because
|
|
5
|
+
PySpark DataType instances are not callable classes, and Python's typing module
|
|
6
|
+
validates this. Python 3.11+ relaxed the check.
|
|
7
|
+
|
|
8
|
+
This module provides a drop-in replacement that works on Python 3.10+.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class _NullableAnnotation:
    """
    Wraps a DataType annotation to mark the field as nullable.

    Subscripting the class (``Optional[T.StringType()]``) returns an instance
    whose ``inner`` attribute holds the wrapped annotation. Two instances are
    equal when their ``inner`` values are equal, and equal instances always
    hash the same.
    """

    __slots__ = ("inner",)

    def __init__(self, inner: Any) -> None:
        self.inner = inner

    def __class_getitem__(cls, item: Any) -> "_NullableAnnotation":
        # Subscription builds an instance rather than a generic alias.
        return cls(item)

    def __repr__(self) -> str:
        return f"Optional[{self.inner!r}]"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, _NullableAnnotation) and self.inner == other.inner

    def __hash__(self) -> int:
        try:
            return hash(("_NullableAnnotation", self.inner))
        except TypeError:
            # ``inner`` is unhashable. Use a constant so that objects which
            # compare equal still hash equal (``id(self)`` would break that).
            return hash("_NullableAnnotation")


#: Drop-in for ``typing.Optional`` that works on Python 3.10+ with PySpark DataType fields.
#:
#: Usage::
#:
#:     from dfguard.pyspark import Optional
#:
#:     class OrderSchema(fg.SparkSchema):
#:         order_id: T.LongType()
#:         zip: Optional[T.StringType()]  # nullable field
Optional = _NullableAnnotation
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Type coercion rules for PySpark, implemented in pure Python.
|
|
3
|
+
|
|
4
|
+
This mirrors Spark's Catalyst TypeCoercion and DecimalPrecision rules so that
|
|
5
|
+
dfguard can resolve derived column types without a running Spark session.
|
|
6
|
+
|
|
7
|
+
Rules source: org.apache.spark.sql.catalyst.analysis.TypeCoercion
|
|
8
|
+
org.apache.spark.sql.catalyst.analysis.DecimalPrecision
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
# ---------------------------------------------------------------------------
|
|
16
|
+
# Integer types expressed as Decimal equivalents (precision, scale=0)
|
|
17
|
+
# ---------------------------------------------------------------------------
|
|
18
|
+
|
|
19
|
+
# (precision, scale) used when an integral type takes part in Decimal arithmetic.
_INT_AS_DECIMAL: dict[str, tuple[int, int]] = {
    "ByteType": (3, 0),
    "ShortType": (5, 0),
    "IntegerType": (10, 0),
    "LongType": (20, 0),
}

# Numeric widening rank (higher = wider). Float and Double rank above every
# integral type, so a fractional operand always outranks an integral one.
_NUMERIC_RANK: dict[str, int] = {
    type_name: rank
    for rank, type_name in enumerate(
        (
            "ByteType",
            "ShortType",
            "IntegerType",
            "LongType",
            "FloatType",
            "DoubleType",
        ),
        start=1,
    )
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _type_name(dt: Any) -> str:
    """Return the class name of *dt* (e.g. ``"LongType"`` for a ``LongType`` instance)."""
    cls = type(dt)
    return cls.__name__
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _is_integral(dt: Any) -> bool:
    """Return True when *dt*'s class name is a key of ``_INT_AS_DECIMAL``."""
    name = _type_name(dt)
    return name in _INT_AS_DECIMAL
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _is_fractional(dt: Any) -> bool:
    """Return True when *dt* is a ``FloatType`` or ``DoubleType`` (matched by class name)."""
    name = _type_name(dt)
    return name == "FloatType" or name == "DoubleType"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _is_numeric(dt: Any) -> bool:
    """Return True when *dt* is integral, fractional, or a ``DecimalType``."""
    if _is_integral(dt):
        return True
    if _is_fractional(dt):
        return True
    return _type_name(dt) == "DecimalType"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ---------------------------------------------------------------------------
|
|
55
|
+
# Decimal precision arithmetic (mirrors DecimalPrecision.scala)
|
|
56
|
+
# ---------------------------------------------------------------------------
|
|
57
|
+
|
|
58
|
+
def _decimal_for_integral(dt: Any) -> tuple[int, int]:
    """
    Return (precision, scale) for an integer type treated as Decimal.

    Raises ``KeyError`` when *dt*'s class name is not in ``_INT_AS_DECIMAL``.
    """
    key = _type_name(dt)
    return _INT_AS_DECIMAL[key]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _decimal_add(p1: int, s1: int, p2: int, s2: int) -> tuple[int, int]:
    """
    Decimal addition/subtraction result precision and scale.

    The scale is the larger operand scale; the precision is the larger count
    of integer digits plus that scale plus one carry digit.
    """
    integer_digits = max(p1 - s1, p2 - s2)
    result_scale = max(s1, s2)
    return integer_digits + result_scale + 1, result_scale
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _decimal_mul(p1: int, s1: int, p2: int, s2: int) -> tuple[int, int]:
    """
    Decimal multiplication result precision and scale.

    The scale is the sum of the operand scales; the precision is the sum of
    the operand precisions plus one.
    """
    return p1 + p2 + 1, s1 + s2
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _decimal_div(p1: int, s1: int, p2: int, s2: int) -> tuple[int, int]:
    """
    Decimal division result precision and scale.

    The scale is ``max(6, s1 + p2 + 1)``; the precision is
    ``p1 - s1 + s2`` plus that scale.
    """
    result_scale = max(6, s1 + p2 + 1)
    integer_digits = p1 - s1 + s2
    return integer_digits + result_scale, result_scale
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _make_decimal(p: int, s: int) -> Any:
    """
    Build a ``DecimalType`` from a computed precision and scale.

    Spark caps precision at 38. When the computed precision exceeds that cap,
    the scale is reduced as well, following Spark's ``adjustPrecisionScale``
    (the default ``allowPrecisionLoss`` behaviour): integer digits are kept
    and at least ``min(s, 6)`` fractional digits are preserved. Capping the
    precision alone could yield a scale larger than the precision.

    Parameters
    ----------
    p:
        Computed precision; may be larger than 38.
    s:
        Computed scale.

    Returns
    -------
    DecimalType
        A type whose precision is at most 38.
    """
    from pyspark.sql import types as T

    max_precision = 38
    min_adjusted_scale = 6

    if p > max_precision:
        if s >= 0:
            integer_digits = p - s
            s = max(max_precision - integer_digits, min(s, min_adjusted_scale))
        # A negative scale is left untouched: shrinking it further would drop
        # integer digits.
        p = max_precision

    return T.DecimalType(p, s)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# ---------------------------------------------------------------------------
|
|
92
|
+
# Public API
|
|
93
|
+
# ---------------------------------------------------------------------------
|
|
94
|
+
|
|
95
|
+
def coerce_add(left: Any, right: Any) -> Any:
    """
    Return the result type of ``left + right`` (or ``left - right``).

    Thin wrapper around ``_coerce_binary`` with the ``"add"`` operation.
    """
    return _coerce_binary(left, right, op="add")
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def coerce_mul(left: Any, right: Any) -> Any:
    """
    Return the result type of ``left * right``.

    Thin wrapper around ``_coerce_binary`` with the ``"mul"`` operation.
    """
    return _coerce_binary(left, right, op="mul")
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def coerce_div(left: Any, right: Any) -> Any:
    """
    Return the result type of ``left / right``.

    - Integral and/or fractional operands (no Decimal): ``DoubleType``.
    - A Decimal together with a Float/Double: ``DoubleType``.
    - Decimal with Decimal or an integral type: ``DecimalType``, using the
      division precision rule. Integral operands are first treated as their
      Decimal equivalent.

    Raises
    ------
    TypeError
        If either operand is not numeric.
    """
    from pyspark.sql import types as T

    if not (_is_numeric(left) and _is_numeric(right)):
        raise TypeError(f"Cannot divide {_type_name(left)} by {_type_name(right)}")

    decimal_involved = "DecimalType" in (_type_name(left), _type_name(right))
    if not decimal_involved or _is_fractional(left) or _is_fractional(right):
        return T.DoubleType()

    def _as_decimal(dt: Any) -> tuple[int, int]:
        # Normalise an operand to (precision, scale).
        if _type_name(dt) == "DecimalType":
            return dt.precision, dt.scale
        return _decimal_for_integral(dt)

    p1, s1 = _as_decimal(left)
    p2, s2 = _as_decimal(right)
    precision, scale = _decimal_div(p1, s1, p2, s2)

    if precision > 38:
        # Division overflows the 38-digit cap easily. Keep the integer digits
        # and give up fractional digits, but never go below 6 (the division
        # scale is always at least 6).
        scale = max(38 - (precision - scale), min(scale, 6))
        precision = 38

    return _make_decimal(precision, scale)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def coerce_mod(left: Any, right: Any) -> Any:
    """
    Return the result type of ``left % right`` (modulo).

    Without a Decimal operand, or with a Float/Double operand, the result
    follows the same widening rules as addition. For Decimal with Decimal or
    an integral type, the remainder rule applies: the scale is the larger
    operand scale and the precision is the smaller count of integer digits
    plus that scale (no extra carry digit, unlike addition).

    Raises
    ------
    TypeError
        From ``_coerce_binary`` when the operands cannot be coerced.
    """
    ln, rn = _type_name(left), _type_name(right)

    if (
        "DecimalType" not in (ln, rn)
        or not (_is_numeric(left) and _is_numeric(right))
        or _is_fractional(left)
        or _is_fractional(right)
    ):
        # Same widening rules as add; also reports unsupported operands.
        return _coerce_binary(left, right, "add")

    p1, s1 = (left.precision, left.scale) if ln == "DecimalType" else _decimal_for_integral(left)
    p2, s2 = (right.precision, right.scale) if rn == "DecimalType" else _decimal_for_integral(right)

    scale = max(s1, s2)
    precision = min(p1 - s1, p2 - s2) + scale
    return _make_decimal(precision, scale)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _coerce_binary(left: Any, right: Any, op: str) -> Any:
    """
    Resolve the result type of an additive or multiplicative operation.

    - Both operands non-Decimal numerics: the operand type with the higher
      ``_NUMERIC_RANK`` wins (Integer + Long -> Long, Integer + Float -> Float,
      Float + Double -> Double), as in Spark's tightest-common-type rule.
    - Decimal with Float/Double: ``DoubleType``.
    - Decimal with Decimal or an integral type: ``DecimalType`` whose
      precision and scale come from ``_decimal_mul`` when ``op == "mul"`` and
      from ``_decimal_add`` otherwise.

    Parameters
    ----------
    left, right:
        DataType instances.
    op:
        ``"mul"`` for multiplication; any other value uses the addition rule.

    Raises
    ------
    TypeError
        If the operands cannot be coerced, including a Decimal paired with a
        non-numeric type.
    """
    from pyspark.sql import types as T

    ln, rn = _type_name(left), _type_name(right)

    # Both are simple numeric (no Decimal involved): the wider type wins.
    if ln in _NUMERIC_RANK and rn in _NUMERIC_RANK:
        winner = left if _NUMERIC_RANK[ln] >= _NUMERIC_RANK[rn] else right
        return type(winner)()

    decimal_involved = "DecimalType" in (ln, rn)
    if not decimal_involved or not (_is_numeric(left) and _is_numeric(right)):
        # Checked before the integral lookup below so that a Decimal paired
        # with e.g. a string raises TypeError rather than KeyError.
        raise TypeError(f"Cannot coerce {ln} and {rn} for operation '{op}'")

    # Float or Double wins over Decimal
    if _is_fractional(left) or _is_fractional(right):
        return T.DoubleType()

    # Normalise both sides to (precision, scale)
    p1, s1 = (left.precision, left.scale) if ln == "DecimalType" else _decimal_for_integral(left)
    p2, s2 = (right.precision, right.scale) if rn == "DecimalType" else _decimal_for_integral(right)

    combine = _decimal_mul if op == "mul" else _decimal_add
    p, s = combine(p1, s1, p2, s2)
    return _make_decimal(p, s)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def coerce_comparison(left: Any, right: Any) -> Any:
    """
    Comparison (==, <, >, <=, >=) always returns BooleanType.

    Neither operand is inspected.
    """
    from pyspark.sql import types as T

    boolean = T.BooleanType()
    return boolean
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def coerce_cast(target: Any) -> Any:
    """Explicit cast: the result type is *target* itself, returned unchanged."""
    result = target
    return result
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def result_type(left: Any, right: Any, op: str) -> Any:
    """
    Resolve the output DataType for a binary operation between two typed columns.

    Parameters
    ----------
    left, right : DataType instances
    op : one of '+', '-', '*', '/', '%', '==', '!=', '<', '<=', '>', '>='

    Returns
    -------
    DataType instance representing the result type.

    Raises
    ------
    ValueError
        If *op* is not one of the supported operators.

    Examples
    --------
    >>> from pyspark.sql import types as T
    >>> result_type(T.IntegerType(), T.DecimalType(10, 2), '+')
    DecimalType(13,2)
    >>> result_type(T.LongType(), T.DoubleType(), '*')
    DoubleType()
    >>> result_type(T.IntegerType(), T.IntegerType(), '/')
    DoubleType()
    """
    # Operator groups paired with the coercion routine that handles them.
    dispatch = (
        (("+", "-"), coerce_add),
        (("*",), coerce_mul),
        (("/",), coerce_div),
        (("%",), coerce_mod),
        (("==", "!=", "<", "<=", ">", ">="), coerce_comparison),
    )
    for symbols, handler in dispatch:
        if op in symbols:
            return handler(left, right)
    raise ValueError(f"Unknown operator: {op!r}")
|