odibi-2.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- odibi/__init__.py +32 -0
- odibi/__main__.py +8 -0
- odibi/catalog.py +3011 -0
- odibi/cli/__init__.py +11 -0
- odibi/cli/__main__.py +6 -0
- odibi/cli/catalog.py +553 -0
- odibi/cli/deploy.py +69 -0
- odibi/cli/doctor.py +161 -0
- odibi/cli/export.py +66 -0
- odibi/cli/graph.py +150 -0
- odibi/cli/init_pipeline.py +242 -0
- odibi/cli/lineage.py +259 -0
- odibi/cli/main.py +215 -0
- odibi/cli/run.py +98 -0
- odibi/cli/schema.py +208 -0
- odibi/cli/secrets.py +232 -0
- odibi/cli/story.py +379 -0
- odibi/cli/system.py +132 -0
- odibi/cli/test.py +286 -0
- odibi/cli/ui.py +31 -0
- odibi/cli/validate.py +39 -0
- odibi/config.py +3541 -0
- odibi/connections/__init__.py +9 -0
- odibi/connections/azure_adls.py +499 -0
- odibi/connections/azure_sql.py +709 -0
- odibi/connections/base.py +28 -0
- odibi/connections/factory.py +322 -0
- odibi/connections/http.py +78 -0
- odibi/connections/local.py +119 -0
- odibi/connections/local_dbfs.py +61 -0
- odibi/constants.py +17 -0
- odibi/context.py +528 -0
- odibi/diagnostics/__init__.py +12 -0
- odibi/diagnostics/delta.py +520 -0
- odibi/diagnostics/diff.py +169 -0
- odibi/diagnostics/manager.py +171 -0
- odibi/engine/__init__.py +20 -0
- odibi/engine/base.py +334 -0
- odibi/engine/pandas_engine.py +2178 -0
- odibi/engine/polars_engine.py +1114 -0
- odibi/engine/registry.py +54 -0
- odibi/engine/spark_engine.py +2362 -0
- odibi/enums.py +7 -0
- odibi/exceptions.py +297 -0
- odibi/graph.py +426 -0
- odibi/introspect.py +1214 -0
- odibi/lineage.py +511 -0
- odibi/node.py +3341 -0
- odibi/orchestration/__init__.py +0 -0
- odibi/orchestration/airflow.py +90 -0
- odibi/orchestration/dagster.py +77 -0
- odibi/patterns/__init__.py +24 -0
- odibi/patterns/aggregation.py +599 -0
- odibi/patterns/base.py +94 -0
- odibi/patterns/date_dimension.py +423 -0
- odibi/patterns/dimension.py +696 -0
- odibi/patterns/fact.py +748 -0
- odibi/patterns/merge.py +128 -0
- odibi/patterns/scd2.py +148 -0
- odibi/pipeline.py +2382 -0
- odibi/plugins.py +80 -0
- odibi/project.py +581 -0
- odibi/references.py +151 -0
- odibi/registry.py +246 -0
- odibi/semantics/__init__.py +71 -0
- odibi/semantics/materialize.py +392 -0
- odibi/semantics/metrics.py +361 -0
- odibi/semantics/query.py +743 -0
- odibi/semantics/runner.py +430 -0
- odibi/semantics/story.py +507 -0
- odibi/semantics/views.py +432 -0
- odibi/state/__init__.py +1203 -0
- odibi/story/__init__.py +55 -0
- odibi/story/doc_story.py +554 -0
- odibi/story/generator.py +1431 -0
- odibi/story/lineage.py +1043 -0
- odibi/story/lineage_utils.py +324 -0
- odibi/story/metadata.py +608 -0
- odibi/story/renderers.py +453 -0
- odibi/story/templates/run_story.html +2520 -0
- odibi/story/themes.py +216 -0
- odibi/testing/__init__.py +13 -0
- odibi/testing/assertions.py +75 -0
- odibi/testing/fixtures.py +85 -0
- odibi/testing/source_pool.py +277 -0
- odibi/transformers/__init__.py +122 -0
- odibi/transformers/advanced.py +1472 -0
- odibi/transformers/delete_detection.py +610 -0
- odibi/transformers/manufacturing.py +1029 -0
- odibi/transformers/merge_transformer.py +778 -0
- odibi/transformers/relational.py +675 -0
- odibi/transformers/scd.py +579 -0
- odibi/transformers/sql_core.py +1356 -0
- odibi/transformers/validation.py +165 -0
- odibi/ui/__init__.py +0 -0
- odibi/ui/app.py +195 -0
- odibi/utils/__init__.py +66 -0
- odibi/utils/alerting.py +667 -0
- odibi/utils/config_loader.py +343 -0
- odibi/utils/console.py +231 -0
- odibi/utils/content_hash.py +202 -0
- odibi/utils/duration.py +43 -0
- odibi/utils/encoding.py +102 -0
- odibi/utils/extensions.py +28 -0
- odibi/utils/hashing.py +61 -0
- odibi/utils/logging.py +203 -0
- odibi/utils/logging_context.py +740 -0
- odibi/utils/progress.py +429 -0
- odibi/utils/setup_helpers.py +302 -0
- odibi/utils/telemetry.py +140 -0
- odibi/validation/__init__.py +62 -0
- odibi/validation/engine.py +765 -0
- odibi/validation/explanation_linter.py +155 -0
- odibi/validation/fk.py +547 -0
- odibi/validation/gate.py +252 -0
- odibi/validation/quarantine.py +605 -0
- odibi/writers/__init__.py +15 -0
- odibi/writers/sql_server_writer.py +2081 -0
- odibi-2.5.0.dist-info/METADATA +255 -0
- odibi-2.5.0.dist-info/RECORD +124 -0
- odibi-2.5.0.dist-info/WHEEL +5 -0
- odibi-2.5.0.dist-info/entry_points.txt +2 -0
- odibi-2.5.0.dist-info/licenses/LICENSE +190 -0
- odibi-2.5.0.dist-info/top_level.txt +1 -0
odibi/transformers/relational.py (new file)
@@ -0,0 +1,675 @@

import time
from enum import Enum
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field, field_validator

from odibi.context import EngineContext
from odibi.enums import EngineType
from odibi.utils.logging_context import get_logging_context

# -------------------------------------------------------------------------
# 1. Join
# -------------------------------------------------------------------------


class JoinParams(BaseModel):
    """
    Configuration for joining datasets.

    Scenario 1: Simple Left Join
    ```yaml
    join:
      right_dataset: "customers"
      on: "customer_id"
      how: "left"
    ```

    Scenario 2: Join with Prefix (avoid collisions)
    ```yaml
    join:
      right_dataset: "orders"
      on: ["user_id"]
      how: "inner"
      prefix: "ord"  # Result cols: ord_date, ord_amount...
    ```
    """

    right_dataset: str = Field(..., description="Name of the node/dataset to join with")
    on: Union[str, List[str]] = Field(..., description="Column(s) to join on")
    how: Literal["inner", "left", "right", "full", "cross", "anti", "semi"] = Field(
        "left", description="Join type"
    )
    prefix: Optional[str] = Field(
        None, description="Prefix for columns from right dataset to avoid collisions"
    )

    @field_validator("on")
    @classmethod
    def coerce_on_to_list(cls, v):
        if isinstance(v, str):
            return [v]
        if not v:
            raise ValueError(
                f"Join 'on' parameter must contain at least one join key column. "
                f"Got: {v!r}. Provide column name(s) that exist in both datasets."
            )
        return v
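The validator above coerces a scalar `on` into a single-element list and rejects an empty key list. A minimal sketch of that behaviour, assuming this hunk is `odibi/transformers/relational.py` (the only +675-line file in the listing) so that `JoinParams` imports from that path:

```python
from odibi.transformers.relational import JoinParams  # assumed module path

# A scalar "on" is coerced to a one-element list by coerce_on_to_list.
p = JoinParams(right_dataset="customers", on="customer_id")
assert p.on == ["customer_id"]
assert p.how == "left"  # default join type

# An empty key list is rejected; pydantic surfaces the ValueError as a ValidationError.
try:
    JoinParams(right_dataset="customers", on=[])
except Exception as exc:
    print(type(exc).__name__)  # ValidationError
```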
def join(context: EngineContext, params: JoinParams) -> EngineContext:
    """
    Joins the current dataset with another dataset from the context.
    """
    ctx = get_logging_context()
    start_time = time.time()

    ctx.debug(
        "Join starting",
        right_dataset=params.right_dataset,
        join_type=params.how,
        keys=params.on,
    )

    # Get row count before transformation
    rows_before = None
    try:
        rows_before = context.df.shape[0] if hasattr(context.df, "shape") else None
        if rows_before is None and hasattr(context.df, "count"):
            rows_before = context.df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count: {type(e).__name__}")

    # Get Right DF
    right_df = context.get(params.right_dataset)
    if right_df is None:
        ctx.error(
            "Join failed: right dataset not found",
            right_dataset=params.right_dataset,
            available_datasets=(
                list(context.context._data.keys()) if hasattr(context, "context") else None
            ),
        )
        raise ValueError(
            f"Join failed: dataset '{params.right_dataset}' not found in context. "
            f"Available datasets: {list(context.context._data.keys()) if hasattr(context, 'context') and hasattr(context.context, '_data') else 'unknown'}. "
            f"Ensure '{params.right_dataset}' is listed in 'depends_on' for this node."
        )

    # Get right df row count
    right_rows = None
    try:
        right_rows = right_df.shape[0] if hasattr(right_df, "shape") else None
        if right_rows is None and hasattr(right_df, "count"):
            right_rows = right_df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count: {type(e).__name__}")

    ctx.debug(
        "Join datasets loaded",
        left_rows=rows_before,
        right_rows=right_rows,
        right_dataset=params.right_dataset,
    )

    # Register Right DF as temp view
    right_view_name = f"join_right_{params.right_dataset}"
    context.register_temp_view(right_view_name, right_df)

    # Construct Join Condition
    # params.on is guaranteed to be List[str] by validator
    join_cols = params.on

    join_condition = " AND ".join([f"df.{col} = {right_view_name}.{col}" for col in join_cols])

    # Handle Column Selection (to apply prefix if needed)
    # Strategy: We explicitly construct the projection to handle collisions safely
    # and avoid ambiguous column references.

    # 1. Get Columns
    left_cols = context.columns
    right_cols = list(right_df.columns) if hasattr(right_df, "columns") else []

    # 2. Use Native Pandas optimization if possible
    if context.engine_type == EngineType.PANDAS:
        # Pandas defaults to ('_x', '_y'). We want ('', '_{prefix or right_dataset}')
        suffix = f"_{params.prefix}" if params.prefix else f"_{params.right_dataset}"

        # Handle anti and semi joins for pandas
        if params.how == "anti":
            # Anti join: rows in left that don't match right
            merged = context.df.merge(right_df[params.on], on=params.on, how="left", indicator=True)
            res = merged[merged["_merge"] == "left_only"].drop(columns=["_merge"])
        elif params.how == "semi":
            # Semi join: rows in left that match right (no columns from right)
            merged = context.df.merge(right_df[params.on], on=params.on, how="inner")
            res = merged.drop_duplicates(subset=params.on)
        else:
            res = context.df.merge(right_df, on=params.on, how=params.how, suffixes=("", suffix))

        rows_after = res.shape[0] if hasattr(res, "shape") else None
        elapsed_ms = (time.time() - start_time) * 1000

        ctx.debug(
            "Join completed",
            join_type=params.how,
            rows_before=rows_before,
            rows_after=rows_after,
            row_delta=rows_after - rows_before if rows_before and rows_after else None,
            right_rows=right_rows,
            elapsed_ms=round(elapsed_ms, 2),
        )

        return context.with_df(res)

    # 3. For SQL/Spark, build explicit projection
    projection = []

    # Add Left Columns (with Coalesce for keys in Outer Join)
    for col in left_cols:
        if col in join_cols and params.how in ["right", "full", "outer"]:
            # Coalesce to ensure we get non-null key from either side
            projection.append(f"COALESCE(df.{col}, {right_view_name}.{col}) AS {col}")
        else:
            projection.append(f"df.{col}")

    # Add Right Columns (skip keys, handle collisions)
    for col in right_cols:
        if col in join_cols:
            continue

        if col in left_cols:
            # Collision! Apply prefix or default to right_dataset name
            prefix = params.prefix if params.prefix else params.right_dataset
            projection.append(f"{right_view_name}.{col} AS {prefix}_{col}")
        else:
            projection.append(f"{right_view_name}.{col}")

    select_clause = ", ".join(projection)

    # Map join types to SQL syntax
    join_type_sql = params.how.upper()
    if params.how == "anti":
        join_type_sql = "LEFT ANTI"
    elif params.how == "semi":
        join_type_sql = "LEFT SEMI"

    sql_query = f"""
        SELECT {select_clause}
        FROM df
        {join_type_sql} JOIN {right_view_name}
        ON {join_condition}
    """
    result = context.sql(sql_query)

    # Log completion
    rows_after = None
    try:
        rows_after = result.df.shape[0] if hasattr(result.df, "shape") else None
        if rows_after is None and hasattr(result.df, "count"):
            rows_after = result.df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count: {type(e).__name__}")

    elapsed_ms = (time.time() - start_time) * 1000
    ctx.debug(
        "Join completed",
        join_type=params.how,
        rows_before=rows_before,
        rows_after=rows_after,
        row_delta=rows_after - rows_before if rows_before and rows_after else None,
        right_rows=right_rows,
        elapsed_ms=round(elapsed_ms, 2),
    )

    return result
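In the pandas branch, anti and semi joins are emulated with `merge`, since `DataFrame.merge` has no native `how="anti"` or `how="semi"`. A standalone sketch of the same technique in plain pandas, with illustrative sample frames:

```python
import pandas as pd

left = pd.DataFrame({"customer_id": [1, 2, 3], "name": ["a", "b", "c"]})
right = pd.DataFrame({"customer_id": [2, 3, 3], "amount": [10, 20, 30]})
on = ["customer_id"]

# Anti join: left rows with no match in right (the indicator trick used above)
merged = left.merge(right[on], on=on, how="left", indicator=True)
anti = merged[merged["_merge"] == "left_only"].drop(columns=["_merge"])
print(anti["customer_id"].tolist())  # [1]

# Semi join: left rows with at least one match, keeping no right columns
semi = left.merge(right[on], on=on, how="inner").drop_duplicates(subset=on)
print(semi["customer_id"].tolist())  # [2, 3]
```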
# -------------------------------------------------------------------------
# 2. Union
# -------------------------------------------------------------------------


class UnionParams(BaseModel):
    """
    Configuration for unioning datasets.

    Example (By Name - Default):
    ```yaml
    union:
      datasets: ["sales_2023", "sales_2024"]
      by_name: true
    ```

    Example (By Position):
    ```yaml
    union:
      datasets: ["legacy_data"]
      by_name: false
    ```
    """

    datasets: List[str] = Field(..., description="List of node names to union with current")
    by_name: bool = Field(True, description="Match columns by name (UNION ALL BY NAME)")


def union(context: EngineContext, params: UnionParams) -> EngineContext:
    """
    Unions current dataset with others.
    """
    ctx = get_logging_context()
    start_time = time.time()

    ctx.debug(
        "Union starting",
        datasets=params.datasets,
        by_name=params.by_name,
    )

    # Get row count of current df
    rows_before = None
    try:
        rows_before = context.df.shape[0] if hasattr(context.df, "shape") else None
        if rows_before is None and hasattr(context.df, "count"):
            rows_before = context.df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count: {type(e).__name__}")

    union_sqls = []
    dataset_row_counts = {"current": rows_before}

    # Add current
    union_sqls.append("SELECT * FROM df")

    # Add others
    for ds_name in params.datasets:
        other_df = context.get(ds_name)
        if other_df is None:
            ctx.error(
                "Union failed: dataset not found",
                missing_dataset=ds_name,
                requested_datasets=params.datasets,
            )
            raise ValueError(
                f"Union failed: dataset '{ds_name}' not found in context. "
                f"Requested datasets: {params.datasets}. "
                f"Available datasets: {list(context.context._data.keys()) if hasattr(context, 'context') and hasattr(context.context, '_data') else 'unknown'}. "
                f"Ensure all datasets are listed in 'depends_on'."
            )

        # Get row count of other df
        try:
            other_rows = other_df.shape[0] if hasattr(other_df, "shape") else None
            if other_rows is None and hasattr(other_df, "count"):
                other_rows = other_df.count()
            dataset_row_counts[ds_name] = other_rows
        except Exception as e:
            ctx.debug(f"Could not get row count: {type(e).__name__}")

        view_name = f"union_{ds_name}"
        context.register_temp_view(view_name, other_df)
        union_sqls.append(f"SELECT * FROM {view_name}")

    ctx.debug(
        "Union datasets loaded",
        dataset_row_counts=dataset_row_counts,
    )

    # Construct Query
    # DuckDB supports "UNION ALL BY NAME", Spark does too in recent versions.
    operator = "UNION ALL BY NAME" if params.by_name else "UNION ALL"

    # Fallback for engines without BY NAME if needed (omitted for brevity, assuming modern engines)
    # Spark < 3.1 might need logic.

    sql_query = f" {operator} ".join(union_sqls)
    result = context.sql(sql_query)

    # Log completion
    rows_after = None
    try:
        rows_after = result.df.shape[0] if hasattr(result.df, "shape") else None
        if rows_after is None and hasattr(result.df, "count"):
            rows_after = result.df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count: {type(e).__name__}")

    elapsed_ms = (time.time() - start_time) * 1000
    ctx.debug(
        "Union completed",
        datasets_count=len(params.datasets) + 1,
        rows_after=rows_after,
        elapsed_ms=round(elapsed_ms, 2),
    )

    return result
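Each extra dataset is registered as a temp view named `union_<dataset>`, and the SELECT statements are glued together with the chosen operator. A standalone sketch of the string assembly for the docstring example (no engine required):

```python
# Mirrors the SQL assembly in union() for datasets=["sales_2023", "sales_2024"], by_name=True
datasets = ["sales_2023", "sales_2024"]
union_sqls = ["SELECT * FROM df"]
union_sqls += [f"SELECT * FROM union_{name}" for name in datasets]

operator = "UNION ALL BY NAME"  # "UNION ALL" when by_name is false
sql_query = f" {operator} ".join(union_sqls)
print(sql_query)
# SELECT * FROM df UNION ALL BY NAME SELECT * FROM union_sales_2023
#   UNION ALL BY NAME SELECT * FROM union_sales_2024
```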
# -------------------------------------------------------------------------
# 3. Pivot
# -------------------------------------------------------------------------


class PivotParams(BaseModel):
    """
    Configuration for pivoting data.

    Example:
    ```yaml
    pivot:
      group_by: ["product_id", "region"]
      pivot_col: "month"
      agg_col: "sales"
      agg_func: "sum"
    ```

    Example (Optimized for Spark):
    ```yaml
    pivot:
      group_by: ["id"]
      pivot_col: "category"
      values: ["A", "B", "C"]  # Explicit values avoid extra pass
      agg_col: "amount"
    ```
    """

    group_by: List[str]
    pivot_col: str
    agg_col: str
    agg_func: Literal["sum", "count", "avg", "max", "min", "first"] = "sum"
    values: Optional[List[str]] = Field(
        None, description="Specific values to pivot (for Spark optimization)"
    )


def pivot(context: EngineContext, params: PivotParams) -> EngineContext:
    """
    Pivots row values into columns.
    """
    ctx = get_logging_context()
    start_time = time.time()

    ctx.debug(
        "Pivot starting",
        group_by=params.group_by,
        pivot_col=params.pivot_col,
        agg_col=params.agg_col,
        agg_func=params.agg_func,
        values=params.values,
    )

    # Get row count before transformation
    rows_before = None
    try:
        rows_before = context.df.shape[0] if hasattr(context.df, "shape") else None
        if rows_before is None and hasattr(context.df, "count"):
            rows_before = context.df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count: {type(e).__name__}")

    if context.engine_type == EngineType.SPARK:
        df = context.df.groupBy(*params.group_by)

        if params.values:
            pivot_op = df.pivot(params.pivot_col, params.values)
        else:
            pivot_op = df.pivot(params.pivot_col)

        # Construct agg expression dynamically based on string
        import pyspark.sql.functions as F

        agg_expr = getattr(F, params.agg_func)(params.agg_col)

        res = pivot_op.agg(agg_expr)

        rows_after = res.count() if hasattr(res, "count") else None
        elapsed_ms = (time.time() - start_time) * 1000
        ctx.debug(
            "Pivot completed",
            rows_before=rows_before,
            rows_after=rows_after,
            elapsed_ms=round(elapsed_ms, 2),
        )

        return context.with_df(res)

    elif context.engine_type == EngineType.PANDAS:
        import pandas as pd

        # pivot_table is robust
        res = pd.pivot_table(
            context.df,
            index=params.group_by,
            columns=params.pivot_col,
            values=params.agg_col,
            aggfunc=params.agg_func,
        ).reset_index()

        rows_after = res.shape[0] if hasattr(res, "shape") else None
        elapsed_ms = (time.time() - start_time) * 1000
        ctx.debug(
            "Pivot completed",
            rows_before=rows_before,
            rows_after=rows_after,
            columns_after=len(res.columns) if hasattr(res, "columns") else None,
            elapsed_ms=round(elapsed_ms, 2),
        )

        return context.with_df(res)

    else:
        ctx.error(
            "Pivot failed: unsupported engine",
            engine_type=str(context.engine_type),
        )
        raise ValueError(
            f"Pivot transformer does not support engine type '{context.engine_type}'. "
            f"Supported engines: SPARK, PANDAS. "
            f"Check your engine configuration."
        )
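The pandas branch is a thin wrapper over `pd.pivot_table`. A standalone sketch of the call it issues for the first docstring example, with illustrative sample data:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "product_id": [1, 1, 2, 2],
        "region": ["NA", "NA", "EU", "EU"],
        "month": ["jan", "feb", "jan", "feb"],
        "sales": [100, 150, 80, 90],
    }
)

# group_by=["product_id", "region"], pivot_col="month", agg_col="sales", agg_func="sum"
res = pd.pivot_table(
    df,
    index=["product_id", "region"],
    columns="month",
    values="sales",
    aggfunc="sum",
).reset_index()

# One row per (product_id, region), one "sales" column per month value
print(res)
```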
# -------------------------------------------------------------------------
# 4. Unpivot (Stack)
# -------------------------------------------------------------------------


class UnpivotParams(BaseModel):
    """
    Configuration for unpivoting (melting) data.

    Example:
    ```yaml
    unpivot:
      id_cols: ["product_id"]
      value_vars: ["jan_sales", "feb_sales", "mar_sales"]
      var_name: "month"
      value_name: "sales"
    ```
    """

    id_cols: List[str]
    value_vars: List[str]
    var_name: str = "variable"
    value_name: str = "value"


def unpivot(context: EngineContext, params: UnpivotParams) -> EngineContext:
    """
    Unpivots columns into rows (Melt/Stack).
    """
    ctx = get_logging_context()
    start_time = time.time()

    ctx.debug(
        "Unpivot starting",
        id_cols=params.id_cols,
        value_vars=params.value_vars,
        var_name=params.var_name,
        value_name=params.value_name,
    )

    # Get row count before transformation
    rows_before = None
    try:
        rows_before = context.df.shape[0] if hasattr(context.df, "shape") else None
        if rows_before is None and hasattr(context.df, "count"):
            rows_before = context.df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count: {type(e).__name__}")

    if context.engine_type == EngineType.PANDAS:
        res = context.df.melt(
            id_vars=params.id_cols,
            value_vars=params.value_vars,
            var_name=params.var_name,
            value_name=params.value_name,
        )

        rows_after = res.shape[0] if hasattr(res, "shape") else None
        elapsed_ms = (time.time() - start_time) * 1000
        ctx.debug(
            "Unpivot completed",
            rows_before=rows_before,
            rows_after=rows_after,
            value_vars_count=len(params.value_vars),
            elapsed_ms=round(elapsed_ms, 2),
        )

        return context.with_df(res)

    elif context.engine_type == EngineType.SPARK:
        # Spark Stack Syntax: stack(n, col1, val1, col2, val2, ...)
        import pyspark.sql.functions as F

        # Construct stack expression string
        # "stack(2, 'A', A, 'B', B) as (variable, value)"
        num_vars = len(params.value_vars)
        stack_args = []
        for col in params.value_vars:
            stack_args.append(f"'{col}'")  # The label
            stack_args.append(col)  # The value

        stack_expr = (
            f"stack({num_vars}, {', '.join(stack_args)}) "
            f"as ({params.var_name}, {params.value_name})"
        )

        res = context.df.select(*params.id_cols, F.expr(stack_expr))

        rows_after = res.count() if hasattr(res, "count") else None
        elapsed_ms = (time.time() - start_time) * 1000
        ctx.debug(
            "Unpivot completed",
            rows_before=rows_before,
            rows_after=rows_after,
            value_vars_count=len(params.value_vars),
            elapsed_ms=round(elapsed_ms, 2),
        )

        return context.with_df(res)

    else:
        ctx.error(
            "Unpivot failed: unsupported engine",
            engine_type=str(context.engine_type),
        )
        raise ValueError(
            f"Unpivot transformer does not support engine type '{context.engine_type}'. "
            f"Supported engines: SPARK, PANDAS. "
            f"Check your engine configuration."
        )
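On Spark, the unpivot is expressed through the SQL `stack()` function. A standalone sketch of the expression string the function builds for the docstring example:

```python
# Mirrors the stack() expression built in unpivot() for the docstring example
value_vars = ["jan_sales", "feb_sales", "mar_sales"]
var_name, value_name = "month", "sales"

stack_args = []
for col in value_vars:
    stack_args.append(f"'{col}'")  # the label stored in var_name
    stack_args.append(col)  # the column whose value goes into value_name

stack_expr = (
    f"stack({len(value_vars)}, {', '.join(stack_args)}) "
    f"as ({var_name}, {value_name})"
)
print(stack_expr)
# stack(3, 'jan_sales', jan_sales, 'feb_sales', feb_sales, 'mar_sales', mar_sales) as (month, sales)
```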
# -------------------------------------------------------------------------
# 5. Aggregate
# -------------------------------------------------------------------------


class AggFunc(str, Enum):
    SUM = "sum"
    AVG = "avg"
    MIN = "min"
    MAX = "max"
    COUNT = "count"
    FIRST = "first"


class AggregateParams(BaseModel):
    """
    Configuration for aggregation.

    Example:
    ```yaml
    aggregate:
      group_by: ["department", "region"]
      aggregations:
        salary: "sum"
        employee_id: "count"
        age: "avg"
    ```
    """

    group_by: List[str] = Field(..., description="Columns to group by")
    aggregations: Dict[str, AggFunc] = Field(
        ..., description="Map of column to aggregation function (sum, avg, min, max, count)"
    )


def aggregate(context: EngineContext, params: AggregateParams) -> EngineContext:
    """
    Performs grouping and aggregation via SQL.
    """
    ctx = get_logging_context()
    start_time = time.time()

    ctx.debug(
        "Aggregate starting",
        group_by=params.group_by,
        aggregations={col: func.value for col, func in params.aggregations.items()},
    )

    # Get row count before transformation
    rows_before = None
    try:
        rows_before = context.df.shape[0] if hasattr(context.df, "shape") else None
        if rows_before is None and hasattr(context.df, "count"):
            rows_before = context.df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count: {type(e).__name__}")

    group_cols = ", ".join(params.group_by)
    agg_exprs = []

    for col, func in params.aggregations.items():
        # Construct agg: SUM(col) AS col
        agg_exprs.append(f"{func.value.upper()}({col}) AS {col}")

    # Select grouped cols + aggregated cols
    # Note: params.group_by are already columns, so we list them
    select_items = params.group_by + agg_exprs
    select_clause = ", ".join(select_items)

    sql_query = f"SELECT {select_clause} FROM df GROUP BY {group_cols}"
    result = context.sql(sql_query)

    # Log completion
    rows_after = None
    try:
        rows_after = result.df.shape[0] if hasattr(result.df, "shape") else None
        if rows_after is None and hasattr(result.df, "count"):
            rows_after = result.df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count: {type(e).__name__}")

    elapsed_ms = (time.time() - start_time) * 1000
    ctx.debug(
        "Aggregate completed",
        group_by=params.group_by,
        rows_before=rows_before,
        rows_after=rows_after,
        aggregation_count=len(params.aggregations),
        elapsed_ms=round(elapsed_ms, 2),
    )

    return result
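
For the docstring example, the transformer emits a single GROUP BY statement. A standalone sketch of that assembly, with plain strings standing in for the `AggFunc` enum values:

```python
# Mirrors the SQL assembly in aggregate() for the docstring example
group_by = ["department", "region"]
aggregations = {"salary": "sum", "employee_id": "count", "age": "avg"}

agg_exprs = [f"{func.upper()}({col}) AS {col}" for col, func in aggregations.items()]
select_clause = ", ".join(group_by + agg_exprs)
sql_query = f"SELECT {select_clause} FROM df GROUP BY {', '.join(group_by)}"
print(sql_query)
# SELECT department, region, SUM(salary) AS salary, COUNT(employee_id) AS employee_id,
#   AVG(age) AS age FROM df GROUP BY department, region
```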
|