dfguard 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dfguard/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ dfguard: Type-safe dataframe operations for PySpark, Pandas, and Polars.
3
+
4
+ Backends are imported lazily so installing dfguard does not require
5
+ any dataframe library to be present.
6
+ """
7
+
8
+ __version__ = "0.1.0"
dfguard/py.typed ADDED
File without changes
@@ -0,0 +1,114 @@
1
+ """
2
+ dfguard.pyspark: runtime schema enforcement for PySpark DataFrames.
3
+
4
+ Two-line setup for packages (Kedro, Airflow, any importable module)
5
+ -------------------------------------------------------------------
6
+ ::
7
+
8
+ from dfguard.pyspark import schema_of, arm
9
+
10
+ RawSchema = schema_of(raw_df)
11
+
12
+ def enrich(df: RawSchema): ...
13
+ def clean(df: RawSchema): ...
14
+
15
+ arm() # enforces every annotated function above
16
+
17
+ For scripts and notebooks use ``@enforce`` per function::
18
+
19
+ from dfguard.pyspark import schema_of, enforce
20
+
21
+ RawSchema = schema_of(raw_df)
22
+
23
+ @enforce
24
+ def enrich(df: RawSchema): ...
25
+
26
+ Declaring schemas upfront (no live DataFrame required)
27
+ ------------------------------------------------------
28
+ ::
29
+
30
+ from pyspark.sql import types as T
31
+ from dfguard.pyspark import SparkSchema, Optional, enforce
32
+
33
+ class OrderSchema(SparkSchema):
34
+ order_id: T.LongType()
35
+ amount: T.DoubleType()
36
+ tags: T.ArrayType(T.StringType())
37
+ address: AddressSchema # nested struct
38
+ zip: Optional[T.StringType()] # nullable
39
+
40
+ @enforce
41
+ def process(df: OrderSchema): ... # subset matching: df must have these fields
42
+
43
+ Public API
44
+ ----------
45
+ ``schema_of(df)``
46
+ Capture ``df``'s schema as a type. Exact match required.
47
+
48
+ ``dataset(df)``
49
+ Wrap ``df`` in a tracked instance. Every ``withColumn``, ``drop``,
50
+ ``select``, etc. is recorded in ``schema_history``.
51
+
52
+ ``arm()``
53
+ Apply schema enforcement to every annotated function in the calling
54
+ module. Call after all function definitions.
55
+
56
+ ``disarm()``
57
+ Turn off all enforcement globally. Call ``arm()`` to re-enable.
58
+
59
+ ``enforce``
60
+ Per-function decorator. Only checks schema-annotated args.
61
+
62
+ ``SparkSchema``
63
+ Declare a schema as a class using real PySpark types.
64
+ Subset matching: df must have at least the declared fields.
65
+
66
+ ``check_schema(schema)`` / ``typed_transform(input_schema, output_schema)``
67
+ Function decorators for explicit input/output validation.
68
+ """
69
+
70
+ try:
71
+ import pyspark # noqa: F401
72
+ except ImportError as _e:
73
+ raise ImportError(
74
+ "dfguard's PySpark integration requires PySpark. "
75
+ "Install it with: pip install 'dfguard[pyspark]'"
76
+ ) from _e
77
+
78
+ from dfguard.pyspark._enforcement import arm, disarm, enforce
79
+ from dfguard.pyspark._inference import infer_schema
80
+ from dfguard.pyspark._nullable import Optional
81
+ from dfguard.pyspark.coercion import result_type
82
+ from dfguard.pyspark.dataset import TypedGroupedData, _TypedDatasetBase, schema_of
83
+ from dfguard.pyspark.dataset import _make_dataset as dataset
84
+ from dfguard.pyspark.decorators import check_schema, typed_transform
85
+ from dfguard.pyspark.exceptions import (
86
+ ColumnNotFoundError,
87
+ DfTypesError,
88
+ SchemaValidationError,
89
+ TypeAnnotationError,
90
+ )
91
+ from dfguard.pyspark.history import SchemaChange, SchemaHistory
92
+ from dfguard.pyspark.schema import SparkSchema # noqa: E402
93
+
94
+ __all__ = [
95
+ "schema_of",
96
+ "dataset",
97
+ "enforce",
98
+ "arm",
99
+ "disarm",
100
+ "_TypedDatasetBase",
101
+ "TypedGroupedData",
102
+ "SparkSchema",
103
+ "typed_transform",
104
+ "check_schema",
105
+ "SchemaChange",
106
+ "SchemaHistory",
107
+ "DfTypesError",
108
+ "SchemaValidationError",
109
+ "TypeAnnotationError",
110
+ "ColumnNotFoundError",
111
+ "infer_schema",
112
+ "result_type",
113
+ "Optional",
114
+ ]
@@ -0,0 +1,250 @@
1
+ """Schema enforcement without touching non-schema arguments."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import importlib
7
+ import inspect
8
+ import pkgutil
9
+ import types
10
+ import warnings
11
+ from collections.abc import Callable
12
+ from typing import Any, TypeVar, overload
13
+
14
+ F = TypeVar("F", bound=Callable[..., Any])
15
+
16
+ _ENABLED = True # dfg.disarm() / dfg.arm() controls this
17
+ _ARMED = False # tracks whether arm() has been called
18
+ _SUBSET = True # dfg.arm(subset=...) sets the global default; function-level overrides this
19
+
20
+ _UNSET = object() # sentinel: "no function-level override, use global"
21
+
22
+
23
+ def _is_schema_type(annotation: Any) -> bool:
24
+ """
25
+ Return True when *annotation* participates in dfguard enforcement.
26
+
27
+ Any class that exposes a ``_fg_check(value, subset) -> bool`` classmethod
28
+ is treated as a schema type. This is the extension point: new DataFrame
29
+ backends (pandas, polars, …) just need to add ``_fg_check`` to their
30
+ schema class; no changes to enforcement code required.
31
+ """
32
+ return isinstance(annotation, type) and callable(getattr(annotation, "_fg_check", None))
33
+
34
+
35
+ def _schema_matches(value: Any, annotation: type, subset: bool) -> bool:
36
+ """
37
+ Check whether *value* satisfies *annotation*.
38
+
39
+ Delegates to ``annotation._fg_check(value, subset)`` when available.
40
+ The meaning of *subset* is left to each schema type:
41
+
42
+ - ``schema_of`` types (``_TypedDatasetBase``) ignore *subset*: always exact.
43
+ - ``SparkSchema`` types respect *subset*:
44
+ - ``True``: extra columns in *value* are fine.
45
+ - ``False``: *value* must have exactly the declared columns, nothing extra.
46
+ """
47
+ checker = getattr(annotation, "_fg_check", None)
48
+ if callable(checker):
49
+ return bool(checker(value, subset))
50
+ return isinstance(value, annotation)
51
+
52
+
53
# Attribute set on every wrapper produced by enforce(), so arm() can recognise
# functions that are already enforced and avoid wrapping them a second time.
_ENFORCED_MARKER = "__dfguard_enforced__"


def _arm_module_dict(module_dict: dict[str, Any], *, subset: Any) -> None:
    """
    Replace each public plain function in *module_dict* with its enforce() wrapper.

    Skipped entries:

    - names starting with ``_``;
    - anything that is not a plain Python function (classes, builtins, ...);
    - functions that already carry the enforce() marker, e.g. ones decorated
      with ``@enforce(subset=...)``. Wrapping those again would put an outer
      check using the global subset default in front of the function-level
      one, and that outer check would override the per-function setting.

    Functions without schema-annotated parameters come back from enforce()
    unchanged and are left in place.

    Parameters
    ----------
    module_dict:
        A module's ``__dict__``; modified in place.
    subset:
        Forwarded to :func:`enforce`. The callers in this module pass
        ``_UNSET`` so each wrapper reads the global default at call time.
    """
    for name, obj in list(module_dict.items()):
        if name.startswith("_") or not isinstance(obj, types.FunctionType):
            continue
        if getattr(obj, _ENFORCED_MARKER, False):
            continue
        wrapped = enforce(subset=subset)(obj)
        if wrapped is not obj:
            module_dict[name] = wrapped


def arm(
    module: Any = None,
    *,
    package: str | None = None,
    subset: bool = True,
) -> None:
    """
    Arm the entire calling package and set the global subset default.

    Call once from your entry point, ``__init__.py``, or ``settings.py`` (Kedro)::

        import dfguard.pyspark as dfg

        dfg.arm()               # subset=True (default): extra columns are fine
        dfg.arm(subset=False)   # exact match: no extra columns allowed anywhere

    The ``subset`` value becomes the global default. Functions decorated with
    ``@dfg.enforce(subset=...)`` keep their own setting, because arm() never
    wraps an already-enforced function a second time.

    Every call re-enables enforcement (undoing :func:`disarm`) and updates the
    subset default. Once something has been armed, a bare ``arm()`` does only
    that and does not re-walk the package. An explicit ``module`` or
    ``package`` argument is always armed.

    When called from ``__main__``, a warning is issued and nothing is armed, so
    a later ``arm(package=...)`` still takes effect.

    **Specific module object**::

        dfg.arm(my_module)

    **Explicit package name**::

        dfg.arm(package="my_pipeline.nodes")

    Raises
    ------
    ImportError
        If the resolved package itself cannot be imported. Submodules that fail
        to import or arm are skipped with a warning.
    """
    global _SUBSET, _ENABLED, _ARMED
    _SUBSET = subset
    _ENABLED = True

    has_explicit_target = isinstance(module, types.ModuleType) or package is not None
    if _ARMED and not has_explicit_target:
        # Already armed: re-enabling and updating the subset default is all
        # that is needed; no re-walk.
        return

    if isinstance(module, types.ModuleType):
        _arm_module_dict(vars(module), subset=_UNSET)
        _ARMED = True
        return

    if package is None:
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        del frame  # drop the reference to our own frame (avoids a reference cycle)
        if caller is None:
            return
        caller_globals = caller.f_globals
        del caller
        package = caller_globals.get("__package__") or caller_globals.get("__name__", "")

    if not package or package == "__main__":
        # _ARMED stays False here so a later arm(package=...) still works.
        warnings.warn(
            "dfguard.pyspark.arm() called from __main__. "
            "Use @dfguard.pyspark.enforce on individual functions instead.",
            stacklevel=2,
        )
        return

    pkg = importlib.import_module(package)
    # Mark as armed before walking: submodules imported below may call arm()
    # themselves, and those nested bare calls should return immediately.
    _ARMED = True
    _arm_module_dict(vars(pkg), subset=_UNSET)
    pkg_path = getattr(pkg, "__path__", None)
    if pkg_path is None:
        return
    for _, mod_name, _ in pkgutil.walk_packages(pkg_path, prefix=package + "."):
        # Best effort: a broken submodule must not stop the rest being armed.
        try:
            mod = importlib.import_module(mod_name)
            _arm_module_dict(vars(mod), subset=_UNSET)
        except Exception as exc:
            warnings.warn(
                f"dfguard: skipped arming module '{mod_name}': {exc}",
                stacklevel=2,
            )


def disarm() -> None:
    """Turn off all enforcement globally. Call arm() to re-enable."""
    global _ENABLED
    _ENABLED = False


def _resolved_signature(f: Callable[..., Any]) -> inspect.Signature:
    """
    Return the signature of *f* with string annotations evaluated when possible.

    Modules using ``from __future__ import annotations`` store annotations as
    strings. A string never matches a schema type, so without evaluation
    enforcement would be silently skipped. If evaluation fails (for example,
    a name defined only in a local scope, or Python < 3.10 without
    ``eval_str``), the raw signature is returned instead.
    """
    try:
        return inspect.signature(f, eval_str=True)
    except Exception:
        return inspect.signature(f)


@overload
def enforce(func: F) -> F: ...

@overload
def enforce(func: None = None, *, subset: bool = ...) -> Callable[[F], F]: ...


def enforce(
    func: F | None = None,
    *,
    subset: Any = _UNSET,
) -> F | Callable[[F], F]:
    """
    Validate schema annotations on DataFrame arguments.

    Only parameters annotated with a dfguard schema type are checked: a
    ``dfg.schema_of`` type, a ``dfg.SparkSchema`` subclass, or any class with
    an ``_fg_check`` hook. All other arguments are left completely alone.
    String annotations (``from __future__ import annotations``) are evaluated
    when possible, so they are enforced too.

    Only arguments the caller actually passes are validated; a parameter left
    at its default value is not checked. For ``*args`` / ``**kwargs``
    parameters, each element is checked individually. A mismatch raises
    ``TypeError``. While :func:`disarm` is in effect, calls pass straight
    through.

    **Default**: inherits the global ``subset`` set by ``dfg.arm()``::

        @dfg.enforce
        def process(df: OrderSchema, label: str): ...

    **subset=True**: extra columns in the DataFrame are fine (overrides global)::

        @dfg.enforce(subset=True)
        def process(df: OrderSchema): ...

    **subset=False**: DataFrame must match the schema exactly (overrides global)::

        @dfg.enforce(subset=False)
        def process(df: OrderSchema): ...
    """
    # Captured at decoration time. If _UNSET, the wrapper reads _SUBSET at
    # call time, so later dfg.arm(subset=...) calls are respected.
    subset_override = subset

    def decorator(f: F) -> F:
        sig = _resolved_signature(f)
        schema_params = [
            (name, param.annotation, param.kind)
            for name, param in sig.parameters.items()
            if param.annotation is not inspect.Parameter.empty
            and _is_schema_type(param.annotation)
        ]

        if not schema_params:
            return f  # nothing schema-typed, zero overhead

        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not _ENABLED:
                return f(*args, **kwargs)

            # Function-level subset wins; fall back to the global if not set.
            effective_subset = _SUBSET if subset_override is _UNSET else subset_override

            # Defaults are deliberately not applied: only arguments the caller
            # passed are validated, so e.g. ``df: Schema = None`` may be omitted.
            bound = sig.bind(*args, **kwargs)

            for param_name, annotation, kind in schema_params:
                if param_name not in bound.arguments:
                    continue
                supplied = bound.arguments[param_name]
                if kind == inspect.Parameter.VAR_POSITIONAL:
                    candidates = supplied  # tuple of extra positional args
                elif kind == inspect.Parameter.VAR_KEYWORD:
                    candidates = supplied.values()  # dict of extra keyword args
                else:
                    candidates = (supplied,)
                for value in candidates:
                    if not _schema_matches(value, annotation, subset=effective_subset):
                        _raise_schema_mismatch(f.__name__, param_name, annotation, value)

            return f(*args, **kwargs)

        setattr(wrapper, _ENFORCED_MARKER, True)
        return wrapper  # type: ignore[return-value]

    if func is not None:
        return decorator(func)
    return decorator
222
+
223
+
224
+ def _raise_schema_mismatch(
225
+ func_name: str,
226
+ param_name: str,
227
+ annotation: type,
228
+ value: Any,
229
+ ) -> None:
230
+ actual_schema = getattr(value, "schema", None)
231
+ if actual_schema is not None:
232
+ actual_str = ", ".join(
233
+ f"{f.name}:{f.dataType.simpleString()}" for f in actual_schema.fields
234
+ )
235
+ else:
236
+ actual_str = type(value).__name__
237
+
238
+ expected_schema = getattr(annotation, "_expected_schema", None)
239
+ if expected_schema is not None:
240
+ expected_str = ", ".join(
241
+ f"{f.name}:{f.dataType.simpleString()}" for f in expected_schema.fields
242
+ )
243
+ else:
244
+ expected_str = getattr(annotation, "__name__", repr(annotation))
245
+
246
+ raise TypeError(
247
+ f"Schema mismatch in {func_name}() argument '{param_name}':\n"
248
+ f" expected: {expected_str}\n"
249
+ f" received: {actual_str}"
250
+ )
@@ -0,0 +1,77 @@
1
+ """
2
+ dfguard.pyspark._inference
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ ``infer_schema(df)``: inspect a live DataFrame and generate a SparkSchema
5
+ subclass with the correct types, including deeply nested structs.
6
+
7
+ The generated class can immediately be used for validation::
8
+
9
+ schema = infer_schema(df, name="OrderSchema")
10
+ print(schema.to_code()) # copy-paste into your codebase
11
+ ds.validate(schema)
12
+
13
+ Nested StructTypes are emitted as separate named inner classes.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Any
19
+
20
+
21
def infer_schema(df: Any, name: str = "InferredSchema") -> type:
    """
    Build a ``SparkSchema`` subclass that mirrors the current schema of *df*.

    The generated class source is printed to stdout so it can be pasted into
    a codebase, and the class itself is returned for immediate use.

    Parameters
    ----------
    df:
        A live ``pyspark.sql.DataFrame`` or ``dataset`` wrapper; only its
        ``schema`` attribute is read.
    name:
        Class name to give the generated schema.

    Returns
    -------
    type[SparkSchema]
        The class produced by ``SparkSchema.from_struct``.
    """
    from dfguard.pyspark.schema import SparkSchema

    generated = SparkSchema.from_struct(df.schema, name=name)
    print(_render_code(generated, name))
    return generated
47
+
48
+
49
def _render_code(schema_class: Any, name: str) -> str:
    """
    Produce Python source for *schema_class*, recursing into nested schemas.

    Any ``SparkSchema`` subclass stored as a class attribute of *schema_class*
    is rendered first, each followed by a blank line. Then comes the class
    itself under *name*. It has one ``field: annotation`` line per entry of
    ``schema_class._schema_fields``, or ``pass`` when there are none.
    """
    from dfguard.pyspark.schema import SparkSchema

    rendered: list[str] = []

    for candidate in vars(schema_class).values():
        if not isinstance(candidate, type):
            continue
        if candidate is SparkSchema or not issubclass(candidate, SparkSchema):
            continue
        rendered.append(_render_code(candidate, candidate.__name__))
        rendered.append("")

    rendered.append(f"class {name}(SparkSchema):")
    columns = schema_class._schema_fields
    if columns:
        from dfguard.pyspark.schema import _annotation_to_str

        for col_name, annotation in columns.items():
            rendered.append(f" {col_name}: {_annotation_to_str(annotation)}")
    else:
        rendered.append(" pass")

    return "\n".join(rendered)
@@ -0,0 +1,49 @@
1
+ """
2
+ Nullable field marker for SparkSchema.
3
+
4
+ ``typing.Optional[T.StringType()]`` raises TypeError on Python 3.10 because
5
+ PySpark DataType instances are not callable, and Python 3.10's typing module
6
+ rejects non-callable arguments. Python 3.11+ relaxed the check.
7
+
8
+ This module provides a drop-in replacement that works on Python 3.10+.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Any
14
+
15
+
16
+ class _NullableAnnotation:
17
+ """Wraps a DataType annotation to mark the field as nullable."""
18
+
19
+ __slots__ = ("inner",)
20
+
21
+ def __init__(self, inner: Any) -> None:
22
+ self.inner = inner
23
+
24
+ def __class_getitem__(cls, item: Any) -> _NullableAnnotation:
25
+ return cls(item)
26
+
27
+ def __repr__(self) -> str:
28
+ return f"Optional[{self.inner!r}]"
29
+
30
+ def __eq__(self, other: object) -> bool:
31
+ return isinstance(other, _NullableAnnotation) and self.inner == other.inner
32
+
33
+ def __hash__(self) -> int:
34
+ try:
35
+ return hash(("_NullableAnnotation", self.inner))
36
+ except TypeError:
37
+ return id(self)
38
+
39
+
40
#: Typing-style spelling of the nullable field marker.
#:
#: ``typing.Optional`` cannot wrap a PySpark DataType *instance* on Python 3.10,
#: so dfguard provides its own ``Optional``. Subscripting it records the inner
#: annotation and marks the field as nullable::
#:
#:     from pyspark.sql import types as T
#:     from dfguard.pyspark import Optional, SparkSchema
#:
#:     class OrderSchema(SparkSchema):
#:         order_id: T.LongType()
#:         zip: Optional[T.StringType()]  # nullable field
Optional = _NullableAnnotation
@@ -0,0 +1,203 @@
1
+ """
2
+ Type coercion rules for PySpark, implemented in pure Python.
3
+
4
+ This mirrors Spark's Catalyst TypeCoercion and DecimalPrecision rules so that
5
+ dfguard can resolve derived column types without a running Spark session.
6
+
7
+ Rules source: org.apache.spark.sql.catalyst.analysis.TypeCoercion
8
+ org.apache.spark.sql.catalyst.analysis.DecimalPrecision
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Any
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Integer types expressed as Decimal equivalents (precision, scale=0)
17
+ # ---------------------------------------------------------------------------
18
+
19
+ _INT_AS_DECIMAL = {
20
+ "ByteType": (3, 0),
21
+ "ShortType": (5, 0),
22
+ "IntegerType": (10, 0),
23
+ "LongType": (20, 0),
24
+ }
25
+
26
+ # Numeric widening rank (higher = wider). Float/Double are above Long because
27
+ # Spark promotes integer+float → double (lossy but matches Spark behaviour).
28
+ _NUMERIC_RANK: dict[str, int] = {
29
+ "ByteType": 1,
30
+ "ShortType": 2,
31
+ "IntegerType": 3,
32
+ "LongType": 4,
33
+ "FloatType": 5,
34
+ "DoubleType": 6,
35
+ }
36
+
37
+
38
+ def _type_name(dt: Any) -> str:
39
+ return type(dt).__name__
40
+
41
+
42
def _is_integral(dt: Any) -> bool:
    """True for Byte/Short/Integer/Long types, i.e. the keys of ``_INT_AS_DECIMAL``."""
    name = _type_name(dt)
    return name in _INT_AS_DECIMAL
44
+
45
+
46
def _is_fractional(dt: Any) -> bool:
    """True for the binary floating-point types ``FloatType`` and ``DoubleType``."""
    name = _type_name(dt)
    return name == "FloatType" or name == "DoubleType"
48
+
49
+
50
def _is_numeric(dt: Any) -> bool:
    """True for any integral, floating-point, or decimal type."""
    if _type_name(dt) == "DecimalType":
        return True
    return _is_integral(dt) or _is_fractional(dt)
52
+
53
+
54
# ---------------------------------------------------------------------------
# Decimal precision arithmetic (mirrors DecimalPrecision.scala)
# ---------------------------------------------------------------------------

def _decimal_for_integral(dt: Any) -> tuple[int, int]:
    """Look up the (precision, scale) an integral type takes when treated as a decimal."""
    precision, scale = _INT_AS_DECIMAL[_type_name(dt)]
    return precision, scale
61
+
62
+
63
+ def _decimal_add(p1: int, s1: int, p2: int, s2: int) -> tuple[int, int]:
64
+ """Decimal addition/subtraction result precision and scale."""
65
+ scale = max(s1, s2)
66
+ precision = max(p1 - s1, p2 - s2) + scale + 1
67
+ return precision, scale
68
+
69
+
70
+ def _decimal_mul(p1: int, s1: int, p2: int, s2: int) -> tuple[int, int]:
71
+ """Decimal multiplication result precision and scale."""
72
+ scale = s1 + s2
73
+ precision = p1 + p2 + 1
74
+ return precision, scale
75
+
76
+
77
+ def _decimal_div(p1: int, s1: int, p2: int, s2: int) -> tuple[int, int]:
78
+ """Decimal division result precision and scale."""
79
+ scale = max(6, s1 + p2 + 1)
80
+ precision = p1 - s1 + s2 + scale
81
+ return precision, scale
82
+
83
+
84
+ def _make_decimal(p: int, s: int) -> Any:
85
+ from pyspark.sql import types as T
86
+ # Spark caps precision at 38
87
+ p = min(p, 38)
88
+ return T.DecimalType(p, s)
89
+
90
+
91
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def coerce_add(left: Any, right: Any) -> Any:
    """Result type of ``left + right``; subtraction follows the same rules."""
    op = "add"
    return _coerce_binary(left, right, op)
98
+
99
+
100
def coerce_mul(left: Any, right: Any) -> Any:
    """Result type of ``left * right``, using the multiplicative decimal rule."""
    op = "mul"
    return _coerce_binary(left, right, op)
103
+
104
+
105
def coerce_div(left: Any, right: Any) -> Any:
    """
    Return the result type of ``left / right``.

    Follows Spark:

    - one side ``DecimalType`` and the other decimal or integral: a decimal
      sized by Spark's division rule (integral operands are treated as
      ``DecimalType(p, 0)`` first);
    - every other numeric combination: ``DoubleType``.

    Raises
    ------
    TypeError
        If either operand is not numeric.
    """
    from pyspark.sql import types as T

    if not (_is_numeric(left) and _is_numeric(right)):
        raise TypeError(f"Cannot divide {_type_name(left)} by {_type_name(right)}")

    ln, rn = _type_name(left), _type_name(right)
    involves_decimal = ln == "DecimalType" or rn == "DecimalType"
    if involves_decimal and not (_is_fractional(left) or _is_fractional(right)):
        if ln == "DecimalType":
            p1, s1 = left.precision, left.scale
        else:
            p1, s1 = _decimal_for_integral(left)
        if rn == "DecimalType":
            p2, s2 = right.precision, right.scale
        else:
            p2, s2 = _decimal_for_integral(right)
        return _make_decimal(*_decimal_div(p1, s1, p2, s2))

    return T.DoubleType()
111
+
112
+
113
+ def coerce_mod(left: Any, right: Any) -> Any:
114
+ """Return the result type of ``left % right`` (modulo)."""
115
+ return _coerce_binary(left, right, "add") # same widening rules as add
116
+
117
+
118
def _coerce_binary(left: Any, right: Any, op: str) -> Any:
    """
    Resolve the result type of an arithmetic op between two numeric types.

    *op* selects the decimal sizing rule: ``"mul"`` uses the multiplication
    rule, anything else uses the addition rule. Rules:

    - two non-decimal types:
      - ``FloatType`` with ``FloatType`` stays ``FloatType``;
      - any other pairing involving Float/Double gives ``DoubleType``;
      - integral pairs widen to the wider type (``_NUMERIC_RANK``);
    - a decimal with Float/Double gives ``DoubleType``;
    - a decimal with a decimal or integral type gives a decimal sized by
      Spark's rule; integral types count as ``DecimalType(p, 0)``.

    Raises
    ------
    TypeError
        If either operand is not a numeric type.
    """
    from pyspark.sql import types as T

    ln, rn = _type_name(left), _type_name(right)

    # Reject non-numeric operands up front. Otherwise a Decimal paired with a
    # non-numeric type would reach _decimal_for_integral and raise KeyError.
    if not (_is_numeric(left) and _is_numeric(right)):
        raise TypeError(f"Cannot coerce {ln} and {rn} for operation '{op}'")

    # Both are simple numeric (no Decimal involved)
    if ln in _NUMERIC_RANK and rn in _NUMERIC_RANK:
        if ln == rn == "FloatType":
            # Same-type arithmetic needs no coercion in Spark: float op float is float.
            return T.FloatType()
        if _is_fractional(left) or _is_fractional(right):
            # NOTE(review): Float mixed with an integral type resolves to Double
            # here; non-ANSI Spark may produce Float. Kept as-is pending confirmation.
            return T.DoubleType()
        # Pure integer widening
        winner = left if _NUMERIC_RANK[ln] >= _NUMERIC_RANK[rn] else right
        return type(winner)()

    # Both numeric and not both in _NUMERIC_RANK: at least one side is Decimal.
    if _is_fractional(left) or _is_fractional(right):
        # Float or Double wins over Decimal
        return T.DoubleType()

    if ln == "DecimalType":
        p1, s1 = left.precision, left.scale
    else:
        p1, s1 = _decimal_for_integral(left)
    if rn == "DecimalType":
        p2, s2 = right.precision, right.scale
    else:
        p2, s2 = _decimal_for_integral(right)

    if op == "mul":
        p, s = _decimal_mul(p1, s1, p2, s2)
    else:
        p, s = _decimal_add(p1, s1, p2, s2)
    return _make_decimal(p, s)
157
+
158
+
159
def coerce_comparison(left: Any, right: Any) -> Any:
    """
    Result type of a comparison (``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``).

    The operand types are not inspected: comparisons always yield ``BooleanType``.
    """
    from pyspark.sql.types import BooleanType

    return BooleanType()
163
+
164
+
165
def coerce_cast(target: Any) -> Any:
    """Result type of an explicit ``cast``: exactly the requested *target* type, unchanged."""
    result = target
    return result
168
+
169
+
170
def result_type(left: Any, right: Any, op: str) -> Any:
    """
    Resolve the output DataType of a binary operation between two typed columns.

    Parameters
    ----------
    left, right : DataType instances
    op : one of '+', '-', '*', '/', '%', '==', '!=', '<', '<=', '>', '>='

    Returns
    -------
    DataType instance representing the result type.

    Raises
    ------
    ValueError
        If *op* is not one of the operators listed above.

    Examples
    --------
    >>> from pyspark.sql import types as T
    >>> result_type(T.IntegerType(), T.DecimalType(10, 2), '+')
    DecimalType(13,2)
    >>> result_type(T.LongType(), T.DoubleType(), '*')
    DoubleType()
    >>> result_type(T.IntegerType(), T.IntegerType(), '/')
    DoubleType()
    """
    comparison_ops = ("==", "!=", "<", "<=", ">", ">=")

    if op in ("+", "-"):
        resolver = coerce_add
    elif op == "*":
        resolver = coerce_mul
    elif op == "/":
        resolver = coerce_div
    elif op == "%":
        resolver = coerce_mod
    elif op in comparison_ops:
        resolver = coerce_comparison
    else:
        raise ValueError(f"Unknown operator: {op!r}")

    return resolver(left, right)