pyspark-inspect 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,44 @@
1
+ Metadata-Version: 2.4
2
+ Name: pyspark-inspect
3
+ Version: 0.1.0
4
+ Summary: Inspect pyspark query plans
5
+ License: MIT
6
+ Author: Matthias Ossadnik
7
+ Author-email: ossadnik.matthias@gmail.com
8
+ Requires-Python: >=3.11,<4.0
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Classifier: Programming Language :: Python :: 3.13
15
+ Classifier: Programming Language :: Python :: 3.14
16
+ Requires-Dist: pyspark (==3.5.*)
17
+ Project-URL: homepage, https://github.com/mossadnik/pyspark-inspect
18
+ Description-Content-Type: text/markdown
19
+
20
+ # pyspark-inspect
21
+
22
+ pyspark-inspect lets you inspect the query plans of pyspark dataframes from Python.
23
+
24
+ It converts analyzed Catalyst plans into Python data structures that can be queried programmatically.
25
+
26
+ ## Compatibility & Limitations
27
+
28
+ - Tested against Spark 3.5
29
+ - Spark Connect is not supported
30
+ - Coverage of Spark SQL operations is currently very limited
31
+
32
+ ## Basic Usage
33
+
34
+ ```python
35
+ from pyspark.sql import functions as F
36
+ from pyspark_inspect import inspect_dataframe
37
+
38
+ df = (
39
+ spark.createDataFrame([[1, 2]], ['a', 'b'])
40
+ .withColumn('c', F.col('a') + F.col('b'))
41
+ )
42
+ plan = inspect_dataframe(df)
43
+ ```
44
+
@@ -0,0 +1,24 @@
1
+ # pyspark-inspect
2
+
3
+ pyspark-inspect lets you inspect the query plans of pyspark dataframes from Python.
4
+
5
+ It converts analyzed Catalyst plans into Python data structures that can be queried programmatically.
6
+
7
+ ## Compatibility & Limitations
8
+
9
+ - Tested against Spark 3.5
10
+ - Spark Connect is not supported
11
+ - Coverage of Spark SQL operations is currently very limited
12
+
13
+ ## Basic Usage
14
+
15
+ ```python
16
+ from pyspark.sql import functions as F
17
+ from pyspark_inspect import inspect_dataframe
18
+
19
+ df = (
20
+ spark.createDataFrame([[1, 2]], ['a', 'b'])
21
+ .withColumn('c', F.col('a') + F.col('b'))
22
+ )
23
+ plan = inspect_dataframe(df)
24
+ ```
@@ -0,0 +1,34 @@
1
+ [tool.poetry]
2
+ name = "pyspark-inspect"
3
+ version = "0.1.0"
4
+ description = "Inspect pyspark query plans"
5
+ authors = [
6
+ "Matthias Ossadnik <ossadnik.matthias@gmail.com>"
7
+ ]
8
+ classifiers = [
9
+ "Development Status :: 3 - Alpha"
10
+ ]
11
+ readme = "README.md"
12
+ license = "MIT"
13
+
14
+ [tool.poetry.urls]
15
+ homepage = "https://github.com/mossadnik/pyspark-inspect"
16
+
17
+
18
+ [tool.poetry.dependencies]
19
+ python = "^3.11"
20
+ pyspark = "==3.5.*"
21
+
22
+ [build-system]
23
+ requires = ["poetry-core"]
24
+ build-backend = "poetry.core.masonry.api"
25
+
26
+ [dependency-groups]
27
+ dev = [
28
+ "jupyter (>=1.1.1,<2.0.0)",
29
+ "pytest (>=9.0.2,<10.0.0)",
30
+ "mkdocs (>=1.6.1,<2.0.0)",
31
+ "mkdocstrings[python] (>=1.0.4,<2.0.0)",
32
+ "pyyaml (>=6.0.3,<7.0.0)",
33
+ "z3-solver (>=4.14.0,<4.15.0)"
34
+ ]
@@ -0,0 +1,6 @@
1
+ from .api import inspect_dataframe
2
+
3
+
4
+ __all__ = [
5
+ 'inspect_dataframe'
6
+ ]
@@ -0,0 +1,9 @@
1
+ """High-level public functions and classes."""
2
+
3
+ from pyspark.sql import DataFrame
4
+ from .plan import Plan
5
+ from .parser.catalyst import get_dataframe_plan, load_catalyst_plan, parse_plan
6
+
7
+
8
def inspect_dataframe(df: DataFrame) -> Plan:
    """Return the query plan of ``df`` as a tree of ``Plan`` objects.

    The DataFrame's Catalyst plan is fetched as JSON, loaded into a raw
    parse tree and finally converted into plan objects.
    """
    catalyst_json = get_dataframe_plan(df)
    raw_tree = load_catalyst_plan(catalyst_json)
    return parse_plan(raw_tree)
@@ -0,0 +1,52 @@
1
+ """Translate plans into pyspark DataFrames and Columns."""
2
+
3
+ import typing as tp
4
+ from functools import singledispatch
5
+ from pyspark.sql import SparkSession, DataFrame, Column, functions as F
6
+ from . import plan as P
7
+ from . import expression as E
8
+
9
+
10
+ @singledispatch
11
+ def get_spark_expression(expr: tp.Any) -> Column:
12
+ raise NotImplementedError(f'Cannot translate expression of type {type(expr)}')
13
+
14
+
15
+ @get_spark_expression.register
16
+ def _(expr: E.Coalesce) -> Column:
17
+ return F.coalesce(*map(get_spark_expression, expr.children))
18
+
19
+
20
+ @get_spark_expression.register
21
+ def _(expr: E.Alias) -> Column:
22
+ return get_spark_expression(expr.arg).alias(expr.name)
23
+
24
@get_spark_expression.register
def _(expr: E.AttributeReference) -> Column:
    """Translate an attribute reference into a column lookup by name.

    The back-quoted name is taken from ``expr.qualified_name``, which only
    emits the qualifier part when a qualifier is present. Formatting the
    qualifier unconditionally would produce an empty quoted identifier
    (two back-ticks with nothing in between) in front of the column name
    for unqualified attributes.
    """
    return F.col(expr.qualified_name)
27
+
28
+
29
+ @singledispatch
30
+ def get_plan_dataframe(plan: P.Plan, spark: SparkSession) -> DataFrame:
31
+ """convert a plan into a Dataframe."""
32
+ raise NotImplementedError(f'Cannot translate plan of type {type(plan)}')
33
+
34
+
35
@get_plan_dataframe.register
def _(plan: P.Project, spark: SparkSession) -> DataFrame:
    """Translate a projection: select its columns from the translated child."""
    source = get_plan_dataframe(plan.child, spark)
    selected = [get_spark_expression(column) for column in plan.columns]
    return source.select(*selected)
42
+
43
+
44
+ @get_plan_dataframe.register
45
+ def _(plan: P.Table, spark: SparkSession) -> DataFrame:
46
+ return spark.read.table(plan.qualified_name)
47
+
48
+
49
+ @get_plan_dataframe.register
50
+ def _(plan: P.Alias, spark: SparkSession) -> DataFrame:
51
+ df = get_plan_dataframe(plan.child, spark)
52
+ return df.alias('.'.join([*plan.qualifier, plan.alias]))
@@ -0,0 +1,201 @@
1
+ """Representation of Column expressions."""
2
+
3
+
4
+ import typing as tp
5
+ import operator as op
6
+ from dataclasses import dataclass
7
+ from pyspark.sql import Column, functions as F
8
+
9
+
10
@dataclass(frozen=True)
class Expression:
    """Base class for the nodes of a column-expression tree.

    Subclasses either set the class attribute ``_func`` to the callable that
    builds a pyspark ``Column`` from the translated children, or override
    ``to_column`` themselves.
    """

    _func: tp.ClassVar[tp.Optional[tp.Callable[..., Column]]] = None

    @property
    def children(self) -> list['Expression']:
        """Direct sub-expressions of this node; provided by subclasses."""
        raise NotImplementedError()

    def to_column(self) -> Column:
        """Translate this expression into a pyspark ``Column``."""
        raise NotImplementedError()

    def _get_func(self) -> tp.Callable[..., Column]:
        """Return the callable stored in the class attribute ``_func``.

        Raises:
            NotImplementedError: if the class does not define ``_func``.
        """
        # Read the attribute from the class rather than the instance. A plain
        # Python function stored as a class attribute is turned into a bound
        # method when accessed through an instance, so it would receive this
        # expression object as an unwanted first argument. Built-in callables
        # such as ``operator.eq`` are not affected, which hides the problem
        # for most operators.
        func = type(self)._func
        if func is None:
            raise NotImplementedError()
        return func
26
+
27
+
28
+ @dataclass(frozen=True)
29
+ class Leaf(Expression):
30
+ @property
31
+ def children(self):
32
+ return []
33
+
34
+
35
+ @dataclass(frozen=True)
36
+ class UnaryOperator(Expression):
37
+ arg: Expression
38
+
39
+ @property
40
+ def children(self):
41
+ return [self.arg]
42
+
43
+ def to_column(self) -> Column:
44
+ return self._get_func()(self.arg.to_column())
45
+
46
+
47
+ @dataclass(frozen=True)
48
+ class BinaryOperator(Expression):
49
+ left: Expression
50
+ right: Expression
51
+
52
+ @property
53
+ def children(self):
54
+ return [self.left, self.right]
55
+
56
+ def to_column(self) -> Column:
57
+ return self._get_func()(self.left.to_column(), self.right.to_column())
58
+
59
+
60
+ @dataclass(frozen=True)
61
+ class Variadic(Expression):
62
+ args: tuple[Expression, ...]
63
+
64
+ @property
65
+ def children(self):
66
+ return list(self.args)
67
+
68
+ def to_column(self):
69
+ return self._get_func()(*[arg.to_column() for arg in self.args])
70
+
71
+
72
+ @dataclass(frozen=True)
73
+ class UnknownExpression(Variadic):
74
+ class_name: str
75
+
76
+ def to_column(self):
77
+ raise NotImplementedError(f'Cannot translate unknown expression of type {self.class_name}')
78
+
79
+
80
+ @dataclass(frozen=True)
81
+ class Literal(Leaf):
82
+ value: tp.Hashable
83
+ data_type: str | None = None
84
+
85
+ def to_column(self) -> Column:
86
+ return F.lit(self.value)
87
+
88
+
89
+ @dataclass(frozen=True)
90
+ class AttributeReference(Leaf):
91
+ name: str
92
+ qualifier: tuple[str, ...]
93
+ attribute_id: str
94
+
95
+ @property
96
+ def qualified_name(self) -> str:
97
+ name = f'`{self.name}`'
98
+ if self.qualifier:
99
+ qualifier = '.'.join(self.qualifier)
100
+ return f'`{qualifier}`.{name}'
101
+ else:
102
+ return name
103
+
104
+ def to_column(self) -> Column:
105
+ return F.col(self.qualified_name)
106
+
107
+
108
@dataclass(frozen=True)
class Alias(UnaryOperator):
    """An expression that is given an output name."""

    name: str
    qualifier: tuple[str, ...]

    @property
    def qualified_name(self) -> str:
        """Back-quoted name, prefixed by the back-quoted qualifier if present.

        All qualifier parts are joined with ``.`` inside a single pair of
        back-ticks.
        """
        name = f'`{self.name}`'
        if self.qualifier:
            qualifier = '.'.join(self.qualifier)
            return f'`{qualifier}`.{name}'
        else:
            return name

    def to_column(self) -> Column:
        """Translate the aliased expression and name the result ``name``."""
        # ``Column.alias`` uses its argument verbatim as the output column
        # name, so the plain name is passed. The back-quoted, qualified form
        # would end up literally (back-ticks included) in the column name.
        return self.arg.to_column().alias(self.name)
124
+
125
+
126
+ @dataclass(frozen=True)
127
+ class Equals(BinaryOperator):
128
+ _func = op.eq
129
+
130
+
131
+ @dataclass(frozen=True)
132
+ class EqNullSafe(BinaryOperator):
133
+ _func = Column.eqNullSafe
134
+
135
+
136
+ @dataclass(frozen=True)
137
+ class GreaterThanOrEqual(BinaryOperator):
138
+ _func = op.ge
139
+
140
+
141
+ @dataclass(frozen=True)
142
+ class GreaterThan(BinaryOperator):
143
+ _func = op.gt
144
+
145
+
146
+ @dataclass(frozen=True)
147
+ class LessThanOrEqual(BinaryOperator):
148
+ _func = op.le
149
+
150
+
151
+ @dataclass(frozen=True)
152
+ class LessThan(BinaryOperator):
153
+ _func = op.lt
154
+
155
+
156
+ @dataclass(frozen=True)
157
+ class And(BinaryOperator):
158
+ _func = op.and_
159
+
160
+
161
+ @dataclass(frozen=True)
162
+ class Or(BinaryOperator):
163
+ _func = op.or_
164
+
165
+
166
@dataclass(frozen=True)
class Not(UnaryOperator):
    """Logical negation of a boolean expression."""

    # pyspark expresses boolean NOT as ``~column`` (``operator.invert``);
    # ``operator.neg`` would build the arithmetic negation ``-column``.
    _func = op.invert
169
+
170
+
171
+ @dataclass(frozen=True)
172
+ class IsNull(UnaryOperator):
173
+ _func = F.isnull
174
+
175
+
176
+ @dataclass(frozen=True)
177
+ class IsNotNull(UnaryOperator):
178
+ _func = F.isnotnull
179
+
180
+
181
+ @dataclass(frozen=True)
182
+ class Coalesce(Variadic):
183
+ _func = F.coalesce
184
+
185
+
186
+ @dataclass(frozen=True)
187
+ class Concat(Variadic):
188
+ _func = F.concat
189
+
190
+
191
@dataclass(frozen=True)
class When(Variadic):
    """Conditional (CASE WHEN) expression.

    ``args`` is read as consecutive ``condition, value`` pairs. When the
    number of arguments is odd, the last one is used as the ``otherwise``
    value.
    """

    def to_column(self):
        """Translate into a chained ``F.when(...).when(...)`` column.

        Raises:
            ValueError: if fewer than two arguments are present, i.e. there
                is not even one condition/value pair.
        """
        args = [a.to_column() for a in self.args]
        if len(args) < 2:
            raise ValueError(
                f'When requires at least one condition/value pair, got {len(args)} argument(s)'
            )
        first_condition, first_value = args[:2]
        column = F.when(first_condition, first_value)
        # Remaining complete pairs; with an odd count, zip drops the trailing
        # default, which is handled separately below.
        for condition, value in zip(args[2::2], args[3::2]):
            column = column.when(condition, value)
        if len(args) % 2 == 1:
            column = column.otherwise(args[-1])
        return column
@@ -0,0 +1 @@
1
+ """Low-level parsing of Spark plans into Python data structures."""
@@ -0,0 +1,297 @@
1
+ """Parse Catalyst JSON plan into Python data structures."""
2
+
3
+ import typing as tp
4
+ import json
5
+ from dataclasses import dataclass
6
+ from pyspark.errors import PySparkAttributeError
7
+ from pyspark.sql import DataFrame, Column, types as T
8
+ from pyspark_inspect import plan as P, expression as E
9
+
10
+
11
+ @dataclass
12
+ class CatalystPlan:
13
+ children: list['CatalystPlan']
14
+ data: dict[str, tp.Any]
15
+
16
+ @property
17
+ def class_name(self) -> str:
18
+ return self.data['class']
19
+
20
+
21
def add_node(stack, plan: CatalystPlan, num_children: int):
    """Insert ``plan`` into the tree that is being assembled on ``stack``.

    ``stack`` holds ``(node, expected_child_count)`` pairs. A node that still
    expects children (or arrives while the stack is empty) is pushed.
    Otherwise it is attached to the node on top of the stack; a parent whose
    child list thereby reaches its expected length is popped and inserted the
    same way, as a node that expects no further children.
    """
    node, pending = plan, num_children
    while True:
        if pending > 0 or not stack:
            stack.append((node, pending))
            return
        parent, expected = stack[-1]
        parent.children.append(node)
        if len(parent.children) != expected:
            return
        # Parent is complete: take it off the stack and attach it in turn.
        stack.pop()
        node, pending = parent, 0
30
+
31
+
32
def get_dataframe_plan(df: DataFrame) -> str:
    """Get Catalyst json representation of a DataFrame query plan.

    Raises:
        NotImplementedError: if the DataFrame exposes no ``_jdf`` attribute
            (accessing it raised ``PySparkAttributeError``).
    """
    try:
        jdf = df._jdf
    except PySparkAttributeError as err:
        # Chain explicitly so the traceback names the attribute error as the cause.
        raise NotImplementedError(
            'Cannot retrieve query plan from JVM. Are you using Spark Connect?'
        ) from err
    return jdf.logicalPlan().toJSON()
39
+
40
+
41
def get_column_expression(c: Column) -> str:
    """Get Catalyst json representation of a column expression.

    Raises:
        NotImplementedError: if the column exposes no ``_jc`` attribute
            (accessing it raised ``PySparkAttributeError``).
    """
    try:
        jc = c._jc
    except PySparkAttributeError as err:
        # Chain explicitly so the traceback names the attribute error as the cause.
        raise NotImplementedError(
            'Cannot retrieve query plan from JVM. Are you using Spark Connect?'
        ) from err
    return jc.expr().toJSON()
48
+
49
+
50
def load_catalyst_plan(plan: str | list) -> CatalystPlan:
    """Load Catalyst plan string or parsed JSON into raw parse tree.

    Args:
        plan: JSON text, or the already decoded list of node dictionaries.
            Every entry must provide a ``'num-children'`` count.

    Returns:
        The single ``CatalystPlan`` left on the stack, with ``children``
        filled in by ``add_node``.
    """
    if isinstance(plan, str):
        plan = tp.cast(list, json.loads(plan))
    # (node, expected child count) pairs maintained by add_node.
    stack: list[tuple[CatalystPlan, int]] = []
    for item in plan:
        num_children = item['num-children']
        node = CatalystPlan(
            children=[],
            data=item
        )
        add_node(stack, node, num_children)
    # Exactly one entry is expected to remain once every node was attached.
    # NOTE(review): this check is an assert and is skipped under ``python -O``.
    assert len(stack) == 1
    return stack[0][0]
64
+
65
+
66
+ def _parse_logical_relation(plan: CatalystPlan, children: list[P.Plan]) -> P.Table:
67
+ tid = plan.data['catalogTable']['identifier']
68
+ qualified_name = '.'.join([tid['catalog'], tid['database'], tid['table']])
69
+ columns = [parse_expression(load_catalyst_plan(c)) for c in plan.data['output']]
70
+ return P.Table(qualified_name=qualified_name, columns=tuple(columns))
71
+
72
+
73
+ def _parse_subquery_alias(plan: CatalystPlan, children: list[P.Plan]) -> P.Alias:
74
+ return P.Alias(
75
+ child=children[0],
76
+ alias=plan.data['identifier']['name'],
77
+ qualifier=parse_array_string(plan.data['identifier'].get('qualifier', ''))
78
+ )
79
+
80
+ def _parse_project(plan: CatalystPlan, children: list[P.Plan]) -> P.Project:
81
+ return P.Project(
82
+ child=children[0],
83
+ columns=tuple([parse_expression(load_catalyst_plan(p)) for p in plan.data['projectList']])
84
+ )
85
+
86
+
87
def _parse_window(plan: CatalystPlan, children: list[P.Plan]) -> P.Project:
    """Fold a Catalyst ``Window`` node into a ``P.Project``.

    The result reuses the child's own child and columns and appends the
    parsed ``windowExpressions`` as additional columns.
    """
    # The cast only informs the type checker; the code relies on the first
    # child having ``columns`` and ``child`` attributes.
    child = tp.cast(P.Project, children[0])
    columns = tuple([
        *child.columns,
        *[parse_expression(load_catalyst_plan(p)) for p in plan.data['windowExpressions']],
    ])
    return P.Project(child=child.child, columns=columns)
94
+
95
+
96
def _parse_join(plan: CatalystPlan, children: list[P.Plan]) -> P.Join:
    """Build a ``P.Join`` from a Catalyst ``Join`` node and its two parsed children.

    Raises:
        KeyError: if the join type object is not listed in ``JOIN_TYPE`` or
            the node data has no ``'condition'`` entry.
    """
    # Map the JVM join-type object name to the short name used by P.Join.
    how = JOIN_TYPE[plan.data['joinType']['object']]
    # The condition is itself a serialized expression tree.
    # NOTE(review): presumably joins without a condition omit this key — confirm.
    on = parse_expression(load_catalyst_plan(plan.data['condition']))
    return P.Join(left=children[0], right=children[1], on=on, how=how)
100
+
101
+
102
+ def _parse_one_row_relation(plan: CatalystPlan, children: list[P.Plan]) -> P.OneRowRelation:
103
+ return P.OneRowRelation()
104
+
105
+
106
+ def _parse_cte_relation_def(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
107
+ return P.CTEDef(
108
+ cte_id=plan.data['id'],
109
+ child=children[0]
110
+ )
111
+
112
+
113
+ def _parse_cte_relation_ref(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
114
+ return P.CTERef(cte_id=plan.data['cteId'])
115
+
116
+
117
+ def _parse_with_cte(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
118
+ return P.WithCTE(
119
+ ctes=tuple(children[:-1]),
120
+ main=children[-1]
121
+ )
122
+
123
+
124
+ def _parse_union(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
125
+ return P.Union(
126
+ left=children[0],
127
+ right=children[1],
128
+ by_name=plan.data['byName'],
129
+ allow_missing_columns=plan.data['allowMissingCol']
130
+ )
131
+
132
+
133
+ def _parse_except(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
134
+ return P.Except(
135
+ left=children[0],
136
+ right=children[1],
137
+ is_all=plan.data['isAll'],
138
+ )
139
+
140
+
141
+ def _parse_intersect(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
142
+ return P.Intersect(
143
+ left=children[0],
144
+ right=children[1],
145
+ is_all=plan.data['isAll'],
146
+ )
147
+
148
+ def _parse_local_limit(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
149
+ limit = int(plan.data['limitExpr'][0]['value'])
150
+ return P.LocalLimit(children[0], limit=limit)
151
+
152
+
153
+ def _parse_global_limit(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
154
+ limit = int(plan.data['limitExpr'][0]['value'])
155
+ return P.GlobalLimit(children[0], limit=limit)
156
+
157
+
158
+ def _parse_aggregate(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
159
+ return P.Aggregate(
160
+ child=children[0],
161
+ grouping_expressions=tuple([parse_expression(load_catalyst_plan(p)) for p in plan.data['groupingExpressions']]),
162
+ columns=tuple([parse_expression(load_catalyst_plan(p)) for p in plan.data['aggregateExpressions']])
163
+ )
164
+
165
+
166
+ def _parse_distinct(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
167
+ return P.Distinct(children[0])
168
+
169
+
170
+ def _skip_unary(plan: CatalystPlan, children: list[P.Plan]) -> P.Plan:
171
+ """Skip a unary plan node, e.g. repartition"""
172
+ return children[0]
173
+
174
+
175
+ LOGICAL_PLAN = 'org.apache.spark.sql.catalyst.plans.logical'
176
+
177
+ JOIN_TYPE = {
178
+ 'org.apache.spark.sql.catalyst.plans.Inner$': 'inner',
179
+ 'org.apache.spark.sql.catalyst.plans.LeftOuter$': 'left',
180
+ 'org.apache.spark.sql.catalyst.plans.RightOuter$': 'right',
181
+ 'org.apache.spark.sql.catalyst.plans.FullOuter$': 'outer',
182
+ 'org.apache.spark.sql.catalyst.plans.LeftSemi$': 'left-semi',
183
+ 'org.apache.spark.sql.catalyst.plans.LeftAnti$': 'left-anti',
184
+ }
185
+
186
+ PLAN_PARSER: dict[str, tp.Callable[[CatalystPlan, list[P.Plan]], P.Plan]] = {
187
+ 'org.apache.spark.sql.execution.datasources.LogicalRelation': _parse_logical_relation,
188
+ f'{LOGICAL_PLAN}.Aggregate': _parse_aggregate,
189
+ f'{LOGICAL_PLAN}.CTERelationDef': _parse_cte_relation_def,
190
+ f'{LOGICAL_PLAN}.CTERelationRef': _parse_cte_relation_ref,
191
+ f'{LOGICAL_PLAN}.Distinct': _parse_distinct,
192
+ f'{LOGICAL_PLAN}.Except': _parse_except,
193
+ f'{LOGICAL_PLAN}.GlobalLimit': _parse_global_limit,
194
+ f'{LOGICAL_PLAN}.Intersect': _parse_intersect,
195
+ f'{LOGICAL_PLAN}.Join': _parse_join,
196
+ f'{LOGICAL_PLAN}.LocalLimit': _parse_local_limit,
197
+ f'{LOGICAL_PLAN}.SubqueryAlias': _parse_subquery_alias,
198
+ f'{LOGICAL_PLAN}.OneRowRelation': _parse_one_row_relation,
199
+ f'{LOGICAL_PLAN}.Project': _parse_project,
200
+ f'{LOGICAL_PLAN}.Union': _parse_union,
201
+ f'{LOGICAL_PLAN}.Window': _parse_window,
202
+ f'{LOGICAL_PLAN}.WithCTE': _parse_with_cte,
203
+ # skipped
204
+ f'{LOGICAL_PLAN}.Repartition': _skip_unary,
205
+ f'{LOGICAL_PLAN}.RepartitionByExpression': _skip_unary,
206
+ f'{LOGICAL_PLAN}.Sort': _skip_unary,
207
+ }
208
+
209
+
210
def parse_plan(plan: CatalystPlan):
    """Convert a raw Catalyst parse tree into plan objects.

    Children are converted first; the node itself is then handed, together
    with its converted children, to the parser registered for its class name.

    Raises:
        ValueError: if no parser is registered for the node's class name.
    """
    parsed_children = list(map(parse_plan, plan.children))
    handler = PLAN_PARSER.get(plan.class_name)
    if handler:
        return handler(plan, parsed_children)
    raise ValueError(f'Unsupported plan: {plan.class_name}')
216
+
217
+
218
+ def _parse_expression_alias(plan: CatalystPlan, children: list[E.Expression]) -> E.Expression:
219
+ return E.Alias(
220
+ arg=children[0],
221
+ name=plan.data['name'],
222
+ qualifier=parse_array_string(plan.data['qualifier'])
223
+ )
224
+
225
+
226
+ def _parse_attribute_reference(plan: CatalystPlan, children: list[E.Expression]) -> E.Expression:
227
+ # TODO consider to parse data type: T._parse_datatype_json_value(plan.data['dataType'])
228
+ return E.AttributeReference(
229
+ name=plan.data['name'],
230
+ qualifier=parse_array_string(plan.data['qualifier']),
231
+ attribute_id=plan.data['exprId']['id']
232
+ )
233
+
234
+
235
+ def _parse_literal(plan: CatalystPlan, children: list[E.Expression]) -> E.Expression:
236
+ # TODO: Should cast values
237
+ raw_value = plan.data['value']
238
+ raw_data_type = plan.data['dataType']
239
+ return E.Literal(value=raw_value, data_type=raw_data_type)
240
+
241
+
242
def _parse_variadic(cls: type) -> tp.Callable:
    """Generic parser for variadic expressions without additional arguments."""
    def parser(plan: CatalystPlan, children: list[E.Expression]):
        # All parsed children become the single tuple argument of ``cls``.
        return cls(tuple(children))
    return parser
247
+
248
+
249
+ def _parse_binary(cls: type) -> tp.Callable:
250
+ """Generic parser for binary expressions without additional arguments."""
251
+ def parser(plan: CatalystPlan, children: list[E.Expression]):
252
+ return cls(children[0], children[1])
253
+ return parser
254
+
255
+
256
+ def _parse_unary(cls: type) -> tp.Callable:
257
+ """Generic parser for unary expressions without additional arguments."""
258
+ def parser(plan: CatalystPlan, children: list[E.Expression]):
259
+ return cls(children[0])
260
+ return parser
261
+
262
+
263
+ EXPRESSION_PARSER: dict[str, tp.Callable[[CatalystPlan, list[E.Expression]], E.Expression]] = {
264
+ 'org.apache.spark.sql.catalyst.expressions.Alias': _parse_expression_alias,
265
+ 'org.apache.spark.sql.catalyst.expressions.Coalesce': _parse_variadic(E.Coalesce),
266
+ 'org.apache.spark.sql.catalyst.expressions.AttributeReference': _parse_attribute_reference,
267
+ 'org.apache.spark.sql.catalyst.expressions.And': _parse_binary(E.And),
268
+ 'org.apache.spark.sql.catalyst.expressions.Or': _parse_binary(E.Or),
269
+ 'org.apache.spark.sql.catalyst.expressions.EqualTo': _parse_binary(E.Equals),
270
+ 'org.apache.spark.sql.catalyst.expressions.EqualNullSafe': _parse_binary(E.EqNullSafe),
271
+ 'org.apache.spark.sql.catalyst.expressions.IsNull': _parse_unary(E.IsNull),
272
+ 'org.apache.spark.sql.catalyst.expressions.IsNotNull': _parse_unary(E.IsNotNull),
273
+ 'org.apache.spark.sql.catalyst.expressions.Literal': _parse_literal,
274
+ }
275
+
276
+
277
+ def _parse_unknown_expression(plan: CatalystPlan, children: list[E.Expression]) -> E.Expression:
278
+ """Fallback parser to capture unknown expressions with their children."""
279
+ return E.UnknownExpression(args=tuple(children), class_name=plan.class_name)
280
+
281
+
282
def parse_expression(plan: CatalystPlan) -> E.Expression:
    """Convert a raw Catalyst expression tree into expression objects.

    Children are converted first. Class names without a registered parser
    are handled by ``_parse_unknown_expression``.
    """
    sub_expressions = list(map(parse_expression, plan.children))
    handler = EXPRESSION_PARSER.get(plan.class_name, _parse_unknown_expression)
    return handler(plan, sub_expressions)
287
+
288
+
289
+ def parse_array_string(s: str) -> tuple[str, ...]:
290
+ """split comma separated array.
291
+
292
+ Note that not all values can be parsed reliably since values are not quoted.
293
+ """
294
+ stripped = s[1:-1]
295
+ if not stripped:
296
+ return ()
297
+ return tuple(stripped.split(', '))
@@ -0,0 +1,137 @@
1
+ """Classes representing steps in a query plan."""
2
+
3
+ import typing as tp
4
+ from dataclasses import dataclass
5
+ from .expression import Expression
6
+ from pyspark.sql import types as T
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class Plan:
11
+ @property
12
+ def children(self) -> list['Plan']:
13
+ return []
14
+
15
+
16
+ @dataclass(frozen=True)
17
+ class UnaryPlan(Plan):
18
+ child: Plan
19
+
20
+ @property
21
+ def children(self) -> list[Plan]:
22
+ return [self.child]
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class BinaryPlan(Plan):
27
+ left: Plan
28
+ right: Plan
29
+
30
+ @property
31
+ def children(self) -> list[Plan]:
32
+ return [self.left, self.right]
33
+
34
+
35
+ @dataclass(frozen=True)
36
+ class Alias(UnaryPlan):
37
+ alias: str
38
+ qualifier: tuple[str, ...]
39
+
40
+
41
+ @dataclass(frozen=True)
42
+ class Project(UnaryPlan):
43
+ """A projection is a select without row generators like explode."""
44
+ columns: tuple[Expression, ...]
45
+
46
+
47
+ @dataclass(frozen=True)
48
+ class Filter(UnaryPlan):
49
+ condition: Expression
50
+
51
+
52
@dataclass(frozen=True)
class Join(BinaryPlan):
    """A left/right/inner/outer/left-anti/left-semi join."""
    # Join condition.
    on: Expression
    # Join type, restricted to the literal values listed in the annotation.
    how: tp.Literal['left', 'right', 'inner', 'outer', 'left-anti', 'left-semi']
57
+
58
+
59
+ @dataclass(frozen=True)
60
+ class Table(Plan):
61
+ """A catalog table."""
62
+ qualified_name: str
63
+ columns: tuple[Expression, ...]
64
+
65
+
66
@dataclass(frozen=True)
class RDD(Plan):
    """Output of SparkSession.createDataFrame.

    Note that the data cannot be reconstructed from the plan; it would have
    to be supplied by the user.
    """
    # Schema of the DataFrame.
    schema: T.StructType
74
+
75
+
76
+ @dataclass(frozen=True)
77
+ class OneRowRelation(Plan):
78
+ """Dummy input for SQL select without from clause."""
79
+ pass
80
+
81
+
82
+ @dataclass(frozen=True)
83
+ class WithCTE(Plan):
84
+ """Statement with with-block."""
85
+ ctes: tuple[Plan, ...]
86
+ main: Plan
87
+
88
+ @property
89
+ def children(self):
90
+ return [*self.ctes, self.main]
91
+
92
+
93
+ @dataclass(frozen=True)
94
+ class CTEDef(UnaryPlan):
95
+ cte_id: int
96
+
97
+
98
+ @dataclass(frozen=True)
99
+ class CTERef(Plan):
100
+ cte_id: int
101
+
102
+
103
+ @dataclass(frozen=True)
104
+ class Union(BinaryPlan):
105
+ by_name: bool
106
+ allow_missing_columns: bool
107
+
108
+
109
+ @dataclass(frozen=True)
110
+ class Except(BinaryPlan):
111
+ is_all: bool
112
+
113
+
114
+ @dataclass(frozen=True)
115
+ class Intersect(BinaryPlan):
116
+ is_all: bool
117
+
118
+
119
+ @dataclass(frozen=True)
120
+ class Aggregate(UnaryPlan):
121
+ grouping_expressions: tuple[Expression, ...]
122
+ columns: tuple[Expression, ...]
123
+
124
+
125
+ @dataclass(frozen=True)
126
+ class Distinct(UnaryPlan):
127
+ pass
128
+
129
+
130
+ @dataclass(frozen=True)
131
+ class GlobalLimit(UnaryPlan):
132
+ limit: int
133
+
134
+
135
+ @dataclass(frozen=True)
136
+ class LocalLimit(UnaryPlan):
137
+ limit: int