odibi-2.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- odibi/__init__.py +32 -0
- odibi/__main__.py +8 -0
- odibi/catalog.py +3011 -0
- odibi/cli/__init__.py +11 -0
- odibi/cli/__main__.py +6 -0
- odibi/cli/catalog.py +553 -0
- odibi/cli/deploy.py +69 -0
- odibi/cli/doctor.py +161 -0
- odibi/cli/export.py +66 -0
- odibi/cli/graph.py +150 -0
- odibi/cli/init_pipeline.py +242 -0
- odibi/cli/lineage.py +259 -0
- odibi/cli/main.py +215 -0
- odibi/cli/run.py +98 -0
- odibi/cli/schema.py +208 -0
- odibi/cli/secrets.py +232 -0
- odibi/cli/story.py +379 -0
- odibi/cli/system.py +132 -0
- odibi/cli/test.py +286 -0
- odibi/cli/ui.py +31 -0
- odibi/cli/validate.py +39 -0
- odibi/config.py +3541 -0
- odibi/connections/__init__.py +9 -0
- odibi/connections/azure_adls.py +499 -0
- odibi/connections/azure_sql.py +709 -0
- odibi/connections/base.py +28 -0
- odibi/connections/factory.py +322 -0
- odibi/connections/http.py +78 -0
- odibi/connections/local.py +119 -0
- odibi/connections/local_dbfs.py +61 -0
- odibi/constants.py +17 -0
- odibi/context.py +528 -0
- odibi/diagnostics/__init__.py +12 -0
- odibi/diagnostics/delta.py +520 -0
- odibi/diagnostics/diff.py +169 -0
- odibi/diagnostics/manager.py +171 -0
- odibi/engine/__init__.py +20 -0
- odibi/engine/base.py +334 -0
- odibi/engine/pandas_engine.py +2178 -0
- odibi/engine/polars_engine.py +1114 -0
- odibi/engine/registry.py +54 -0
- odibi/engine/spark_engine.py +2362 -0
- odibi/enums.py +7 -0
- odibi/exceptions.py +297 -0
- odibi/graph.py +426 -0
- odibi/introspect.py +1214 -0
- odibi/lineage.py +511 -0
- odibi/node.py +3341 -0
- odibi/orchestration/__init__.py +0 -0
- odibi/orchestration/airflow.py +90 -0
- odibi/orchestration/dagster.py +77 -0
- odibi/patterns/__init__.py +24 -0
- odibi/patterns/aggregation.py +599 -0
- odibi/patterns/base.py +94 -0
- odibi/patterns/date_dimension.py +423 -0
- odibi/patterns/dimension.py +696 -0
- odibi/patterns/fact.py +748 -0
- odibi/patterns/merge.py +128 -0
- odibi/patterns/scd2.py +148 -0
- odibi/pipeline.py +2382 -0
- odibi/plugins.py +80 -0
- odibi/project.py +581 -0
- odibi/references.py +151 -0
- odibi/registry.py +246 -0
- odibi/semantics/__init__.py +71 -0
- odibi/semantics/materialize.py +392 -0
- odibi/semantics/metrics.py +361 -0
- odibi/semantics/query.py +743 -0
- odibi/semantics/runner.py +430 -0
- odibi/semantics/story.py +507 -0
- odibi/semantics/views.py +432 -0
- odibi/state/__init__.py +1203 -0
- odibi/story/__init__.py +55 -0
- odibi/story/doc_story.py +554 -0
- odibi/story/generator.py +1431 -0
- odibi/story/lineage.py +1043 -0
- odibi/story/lineage_utils.py +324 -0
- odibi/story/metadata.py +608 -0
- odibi/story/renderers.py +453 -0
- odibi/story/templates/run_story.html +2520 -0
- odibi/story/themes.py +216 -0
- odibi/testing/__init__.py +13 -0
- odibi/testing/assertions.py +75 -0
- odibi/testing/fixtures.py +85 -0
- odibi/testing/source_pool.py +277 -0
- odibi/transformers/__init__.py +122 -0
- odibi/transformers/advanced.py +1472 -0
- odibi/transformers/delete_detection.py +610 -0
- odibi/transformers/manufacturing.py +1029 -0
- odibi/transformers/merge_transformer.py +778 -0
- odibi/transformers/relational.py +675 -0
- odibi/transformers/scd.py +579 -0
- odibi/transformers/sql_core.py +1356 -0
- odibi/transformers/validation.py +165 -0
- odibi/ui/__init__.py +0 -0
- odibi/ui/app.py +195 -0
- odibi/utils/__init__.py +66 -0
- odibi/utils/alerting.py +667 -0
- odibi/utils/config_loader.py +343 -0
- odibi/utils/console.py +231 -0
- odibi/utils/content_hash.py +202 -0
- odibi/utils/duration.py +43 -0
- odibi/utils/encoding.py +102 -0
- odibi/utils/extensions.py +28 -0
- odibi/utils/hashing.py +61 -0
- odibi/utils/logging.py +203 -0
- odibi/utils/logging_context.py +740 -0
- odibi/utils/progress.py +429 -0
- odibi/utils/setup_helpers.py +302 -0
- odibi/utils/telemetry.py +140 -0
- odibi/validation/__init__.py +62 -0
- odibi/validation/engine.py +765 -0
- odibi/validation/explanation_linter.py +155 -0
- odibi/validation/fk.py +547 -0
- odibi/validation/gate.py +252 -0
- odibi/validation/quarantine.py +605 -0
- odibi/writers/__init__.py +15 -0
- odibi/writers/sql_server_writer.py +2081 -0
- odibi-2.5.0.dist-info/METADATA +255 -0
- odibi-2.5.0.dist-info/RECORD +124 -0
- odibi-2.5.0.dist-info/WHEEL +5 -0
- odibi-2.5.0.dist-info/entry_points.txt +2 -0
- odibi-2.5.0.dist-info/licenses/LICENSE +190 -0
- odibi-2.5.0.dist-info/top_level.txt +1 -0
odibi/transformers/advanced.py
@@ -0,0 +1,1472 @@
import time
from enum import Enum
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field, field_validator

from odibi.context import EngineContext
from odibi.enums import EngineType
from odibi.utils.logging_context import get_logging_context

# -------------------------------------------------------------------------
# 1. Deduplicate (Window)
# -------------------------------------------------------------------------


class DeduplicateParams(BaseModel):
    """
    Configuration for deduplication.

    Scenario: Keep latest record
    ```yaml
    deduplicate:
      keys: ["id"]
      order_by: "updated_at DESC"
    ```
    """

    keys: List[str] = Field(
        ..., description="List of columns to partition by (columns that define uniqueness)"
    )
    order_by: Optional[str] = Field(
        None,
        description="SQL Order by clause (e.g. 'updated_at DESC') to determine which record to keep (first one is kept)",
    )


def deduplicate(context: EngineContext, params: DeduplicateParams) -> EngineContext:
    """
    Deduplicates data using Window functions.
    """
    ctx = get_logging_context()
    start_time = time.time()

    ctx.debug(
        "Deduplicate starting",
        keys=params.keys,
        order_by=params.order_by,
    )

    # Get row count before transformation (optional, for logging only)
    rows_before = None
    try:
        rows_before = context.df.shape[0] if hasattr(context.df, "shape") else None
        if rows_before is None and hasattr(context.df, "count"):
            rows_before = context.df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count before transform: {type(e).__name__}")

    partition_clause = ", ".join(params.keys)
    order_clause = params.order_by if params.order_by else "(SELECT NULL)"

    # Dialect handling for EXCEPT/EXCLUDE
    except_clause = "EXCEPT"
    if context.engine_type == EngineType.PANDAS:
        # DuckDB uses EXCLUDE
        except_clause = "EXCLUDE"

    sql_query = f"""
    SELECT * {except_clause}(_rn) FROM (
        SELECT *,
        ROW_NUMBER() OVER (PARTITION BY {partition_clause} ORDER BY {order_clause}) as _rn
        FROM df
    ) WHERE _rn = 1
    """
    result = context.sql(sql_query)

    # Get row count after transformation (optional, for logging only)
    rows_after = None
    try:
        rows_after = result.df.shape[0] if hasattr(result.df, "shape") else None
        if rows_after is None and hasattr(result.df, "count"):
            rows_after = result.df.count()
    except Exception as e:
        ctx.debug(f"Could not get row count after transform: {type(e).__name__}")

    elapsed_ms = (time.time() - start_time) * 1000
    duplicates_removed = rows_before - rows_after if rows_before and rows_after else None
    ctx.debug(
        "Deduplicate completed",
        keys=params.keys,
        rows_before=rows_before,
        rows_after=rows_after,
        duplicates_removed=duplicates_removed,
        elapsed_ms=round(elapsed_ms, 2),
    )

    return result
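
# Illustrative sketch (not part of the odibi source): the same windowed dedup
# pattern run standalone against DuckDB on a small pandas frame. Assumes the
# `duckdb` and `pandas` packages; the `records` data and column names are made up.
import duckdb
import pandas as pd

records = pd.DataFrame(
    {
        "id": [1, 1, 2],
        "updated_at": ["2024-01-01", "2024-01-03", "2024-01-02"],
        "value": ["old", "new", "only"],
    }
)
deduped = duckdb.sql(
    """
    SELECT * EXCLUDE (_rn) FROM (
        SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY updated_at DESC) AS _rn
        FROM records
    ) WHERE _rn = 1
    """
).df()
# One row per id survives; for id=1 the latest record (updated_at='2024-01-03') is kept.
assert set(deduped["value"]) == {"new", "only"}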


# -------------------------------------------------------------------------
# 2. Explode List
# -------------------------------------------------------------------------


class ExplodeParams(BaseModel):
    """
    Configuration for exploding lists.

    Scenario: Flatten list of items per order
    ```yaml
    explode_list_column:
      column: "items"
      outer: true # Keep orders with empty items list
    ```
    """

    column: str = Field(..., description="Column containing the list/array to explode")
    outer: bool = Field(
        False,
        description="If True, keep rows with empty lists (explode_outer behavior). If False, drops them.",
    )


def explode_list_column(context: EngineContext, params: ExplodeParams) -> EngineContext:
    ctx = get_logging_context()
    start_time = time.time()

    ctx.debug(
        "Explode starting",
        column=params.column,
        outer=params.outer,
    )

    rows_before = None
    try:
        rows_before = context.df.shape[0] if hasattr(context.df, "shape") else None
        if rows_before is None and hasattr(context.df, "count"):
            rows_before = context.df.count()
    except Exception:
        pass

    if context.engine_type == EngineType.SPARK:
        import pyspark.sql.functions as F

        func = F.explode_outer if params.outer else F.explode
        df = context.df.withColumn(params.column, func(F.col(params.column)))

        rows_after = df.count() if hasattr(df, "count") else None
        elapsed_ms = (time.time() - start_time) * 1000
        ctx.debug(
            "Explode completed",
            column=params.column,
            rows_before=rows_before,
            rows_after=rows_after,
            elapsed_ms=round(elapsed_ms, 2),
        )
        return context.with_df(df)

    elif context.engine_type == EngineType.PANDAS:
        df = context.df.explode(params.column)
        if not params.outer:
            df = df.dropna(subset=[params.column])

        rows_after = df.shape[0] if hasattr(df, "shape") else None
        elapsed_ms = (time.time() - start_time) * 1000
        ctx.debug(
            "Explode completed",
            column=params.column,
            rows_before=rows_before,
            rows_after=rows_after,
            elapsed_ms=round(elapsed_ms, 2),
        )
        return context.with_df(df)

    else:
        ctx.error("Explode failed: unsupported engine", engine_type=str(context.engine_type))
        raise ValueError(
            f"Explode transformer does not support engine type '{context.engine_type}'. "
            f"Supported engines: SPARK, PANDAS. "
            f"Check your engine configuration."
        )
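
# Illustrative sketch (not part of the odibi source): the pandas behaviour the
# `outer` flag relies on, using a made-up `orders` frame. `explode` turns an
# empty list into NaN, so the non-outer path simply drops those rows.
import pandas as pd

orders = pd.DataFrame({"order_id": [1, 2], "items": [["a", "b"], []]})
outer = orders.explode("items")                           # keeps order 2 with items = NaN
inner = orders.explode("items").dropna(subset=["items"])  # drops the empty order
assert len(outer) == 3 and len(inner) == 2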


# -------------------------------------------------------------------------
# 3. Dict Mapping
# -------------------------------------------------------------------------

JsonScalar = Union[str, int, float, bool, None]


class DictMappingParams(BaseModel):
    """
    Configuration for dictionary mapping.

    Scenario: Map status codes to labels
    ```yaml
    dict_based_mapping:
      column: "status_code"
      mapping:
        "1": "Active"
        "0": "Inactive"
      default: "Unknown"
      output_column: "status_desc"
    ```
    """

    column: str = Field(..., description="Column to map values from")
    mapping: Dict[str, JsonScalar] = Field(
        ..., description="Dictionary of source value -> target value"
    )
    default: Optional[JsonScalar] = Field(
        None, description="Default value if source value is not found in mapping"
    )
    output_column: Optional[str] = Field(
        None, description="Name of output column. If not provided, overwrites source column."
    )


def dict_based_mapping(context: EngineContext, params: DictMappingParams) -> EngineContext:
    target_col = params.output_column or params.column

    if context.engine_type == EngineType.SPARK:
        from itertools import chain

        import pyspark.sql.functions as F

        # Create map expression
        mapping_expr = F.create_map([F.lit(x) for x in chain(*params.mapping.items())])

        df = context.df.withColumn(target_col, mapping_expr[F.col(params.column)])
        if params.default is not None:
            df = df.withColumn(target_col, F.coalesce(F.col(target_col), F.lit(params.default)))
        return context.with_df(df)

    elif context.engine_type == EngineType.PANDAS:
        df = context.df.copy()
        # Pandas map is fast
        df[target_col] = df[params.column].map(params.mapping)
        if params.default is not None:
            df[target_col] = df[target_col].fillna(params.default).infer_objects(copy=False)
        return context.with_df(df)

    else:
        raise ValueError(
            f"Dict-based mapping does not support engine type '{context.engine_type}'. "
            f"Supported engines: SPARK, PANDAS. "
            f"Check your engine configuration."
        )
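
# Illustrative sketch (not part of the odibi source): the pandas path in plain
# terms, with a made-up status column. Unmapped values become NaN and are then
# filled with the configured default.
import pandas as pd

status = pd.Series(["1", "0", "9"])
labels = status.map({"1": "Active", "0": "Inactive"}).fillna("Unknown")
assert list(labels) == ["Active", "Inactive", "Unknown"]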


# -------------------------------------------------------------------------
# 4. Regex Replace
# -------------------------------------------------------------------------


class RegexReplaceParams(BaseModel):
    """
    Configuration for regex replacement.

    Example:
    ```yaml
    regex_replace:
      column: "phone"
      pattern: "[^0-9]"
      replacement: ""
    ```
    """

    column: str = Field(..., description="Column to apply regex replacement on")
    pattern: str = Field(..., description="Regex pattern to match")
    replacement: str = Field(..., description="String to replace matches with")


def regex_replace(context: EngineContext, params: RegexReplaceParams) -> EngineContext:
    """
    SQL-based Regex replacement.
    """
    # Spark and DuckDB both support REGEXP_REPLACE(col, pattern, replacement)
    sql_query = f"SELECT *, REGEXP_REPLACE({params.column}, '{params.pattern}', '{params.replacement}') AS {params.column} FROM df"
    return context.sql(sql_query)


# -------------------------------------------------------------------------
# 5. Unpack Struct (Flatten)
# -------------------------------------------------------------------------


class UnpackStructParams(BaseModel):
    """
    Configuration for unpacking structs.

    Example:
    ```yaml
    unpack_struct:
      column: "user_info"
    ```
    """

    column: str = Field(
        ..., description="Struct/Dictionary column to unpack/flatten into individual columns"
    )


def unpack_struct(context: EngineContext, params: UnpackStructParams) -> EngineContext:
    """
    Flattens a struct/dict column into top-level columns.
    """
    if context.engine_type == EngineType.SPARK:
        # Spark: "select col.* from df"
        sql_query = f"SELECT *, {params.column}.* FROM df"
        # Usually we want to drop the original struct?
        # For safety, we keep original but append fields.
        # Actually "SELECT *" includes the struct.
        # Let's assume users drop it later or we just select expanded.
        return context.sql(sql_query)

    elif context.engine_type == EngineType.PANDAS:
        import pandas as pd

        # Pandas: json_normalize or Apply(pd.Series)
        # Optimization: df[col].tolist() is much faster than apply(pd.Series)
        # assuming the column contains dictionaries/structs.
        try:
            expanded = pd.DataFrame(context.df[params.column].tolist(), index=context.df.index)
        except Exception as e:
            import logging

            logger = logging.getLogger(__name__)
            logger.debug(f"Optimized struct unpack failed (falling back to slow apply): {e}")
            # Fallback if tolist() fails (e.g. mixed types)
            expanded = context.df[params.column].apply(pd.Series)

        # Rename to avoid collisions? Default behavior is to use keys.
        # Join back
        res = pd.concat([context.df, expanded], axis=1)
        return context.with_df(res)

    else:
        raise ValueError(
            f"Unpack struct does not support engine type '{context.engine_type}'. "
            f"Supported engines: SPARK, PANDAS. "
            f"Check your engine configuration."
        )


# -------------------------------------------------------------------------
# 6. Hash Columns
# -------------------------------------------------------------------------


class HashAlgorithm(str, Enum):
    SHA256 = "sha256"
    MD5 = "md5"


class HashParams(BaseModel):
    """
    Configuration for column hashing.

    Example:
    ```yaml
    hash_columns:
      columns: ["email", "ssn"]
      algorithm: "sha256"
    ```
    """

    columns: List[str] = Field(..., description="List of columns to hash")
    algorithm: HashAlgorithm = Field(
        HashAlgorithm.SHA256, description="Hashing algorithm. Options: 'sha256', 'md5'"
    )


def hash_columns(context: EngineContext, params: HashParams) -> EngineContext:
    """
    Hashes columns for PII/Anonymization.
    """
    # Removed unused 'expressions' variable

    # Since SQL syntax differs, use Dual Engine
    if context.engine_type == EngineType.SPARK:
        import pyspark.sql.functions as F

        df = context.df
        for col in params.columns:
            if params.algorithm == HashAlgorithm.SHA256:
                df = df.withColumn(col, F.sha2(F.col(col), 256))
            elif params.algorithm == HashAlgorithm.MD5:
                df = df.withColumn(col, F.md5(F.col(col)))
        return context.with_df(df)

    elif context.engine_type == EngineType.PANDAS:
        df = context.df.copy()

        # Optimization: Try PyArrow compute for vectorized hashing if available
        # For now, the below logic is a placeholder for future vectorized hashing.
        # The import is unused in the current implementation fallback, triggering linter errors.
        # We will stick to the stable hashlib fallback for now.
        pass

        import hashlib

        def hash_val(val, alg):
            if val is None:
                return None
            encoded = str(val).encode("utf-8")
            if alg == HashAlgorithm.SHA256:
                return hashlib.sha256(encoded).hexdigest()
            return hashlib.md5(encoded).hexdigest()

        # Vectorize? difficult with standard lib hashlib.
        # Apply is acceptable for this security feature vs complexity of numpy deps
        for col in params.columns:
            # Optimization: Ensure string type once
            s_col = df[col].astype(str)
            df[col] = s_col.apply(lambda x: hash_val(x, params.algorithm))

        return context.with_df(df)

    else:
        raise ValueError(f"Unsupported engine: {context.engine_type}")


# -------------------------------------------------------------------------
# 7. Generate Surrogate Key
# -------------------------------------------------------------------------


class SurrogateKeyParams(BaseModel):
    """
    Configuration for surrogate key generation.

    Example:
    ```yaml
    generate_surrogate_key:
      columns: ["region", "product_id"]
      separator: "-"
      output_col: "unique_id"
    ```
    """

    columns: List[str] = Field(..., description="Columns to combine for the key")
    separator: str = Field("-", description="Separator between values")
    output_col: str = Field("surrogate_key", description="Name of the output column")


def generate_surrogate_key(context: EngineContext, params: SurrogateKeyParams) -> EngineContext:
    """
    Generates a deterministic surrogate key (MD5) from a combination of columns.
    Handles NULLs by treating them as empty strings to ensure consistency.
    """
    # Logic: MD5( CONCAT_WS( separator, COALESCE(col1, ''), COALESCE(col2, '') ... ) )

    from odibi.enums import EngineType

    # 1. Build the concatenation expression
    # We must cast to string and coalesce nulls

    def safe_col(col, quote_char):
        # Spark/DuckDB cast syntax slightly different but standard SQL CAST(x AS STRING) usually works
        # Spark: cast(col as string) with backticks for quoting
        # DuckDB: cast(col as varchar) with double quotes for quoting
        return f"COALESCE(CAST({quote_char}{col}{quote_char} AS STRING), '')"

    if context.engine_type == EngineType.SPARK:
        # Spark CONCAT_WS skips nulls, but we coerced them to empty string above anyway for safety.
        # Actually, if we want strict "dbt style" surrogate keys, we often treat NULL as a specific token.
        # But empty string is standard for "simple" SKs.
        quote_char = "`"
        cols_expr = ", ".join([safe_col(c, quote_char) for c in params.columns])
        concat_expr = f"concat_ws('{params.separator}', {cols_expr})"
        final_expr = f"md5({concat_expr})"
        output_col = f"`{params.output_col}`"

    else:
        # DuckDB / Pandas
        # DuckDB also supports concat_ws and md5.
        # Note: DuckDB CAST AS STRING is valid.
        quote_char = '"'
        cols_expr = ", ".join([safe_col(c, quote_char) for c in params.columns])
        concat_expr = f"concat_ws('{params.separator}', {cols_expr})"
        final_expr = f"md5({concat_expr})"
        output_col = f'"{params.output_col}"'

    sql_query = f"SELECT *, {final_expr} AS {output_col} FROM df"
    return context.sql(sql_query)
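
# Illustrative sketch (not part of the odibi source): a pure-Python equivalent of
# md5(concat_ws(separator, COALESCE(col, ''), ...)), assuming the engine hashes
# the UTF-8 bytes of the concatenated string. The values used here are made up.
import hashlib

def _surrogate_key(values, separator="-"):
    joined = separator.join("" if v is None else str(v) for v in values)  # NULL -> ''
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

assert _surrogate_key(["east", 42]) == _surrogate_key(["east", 42])        # deterministic
assert _surrogate_key(["east", None]) == hashlib.md5(b"east-").hexdigest()  # NULL handled as ''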


# -------------------------------------------------------------------------
# 7b. Generate Numeric Key (BIGINT surrogate key)
# -------------------------------------------------------------------------


class NumericKeyParams(BaseModel):
    """
    Configuration for numeric surrogate key generation.

    Generates a deterministic BIGINT key from a hash of specified columns.
    Useful when unioning data from multiple sources where some have IDs
    and others don't.

    Example:
    ```yaml
    - function: generate_numeric_key
      params:
        columns: [DateID, store_id, reason_id, duration_min, notes]
        output_col: ID
        coalesce_with: ID # Keep existing ID if not null
    ```

    The generated key is:
    - Deterministic: same input data = same ID every time
    - BIGINT: large numeric space to avoid collisions
    - Stable: safe for gold layer / incremental loads
    """

    columns: List[str] = Field(..., description="Columns to combine for the key")
    separator: str = Field("|", description="Separator between values")
    output_col: str = Field("numeric_key", description="Name of the output column")
    coalesce_with: Optional[str] = Field(
        None,
        description="Existing column to coalesce with (keep existing value if not null)",
    )


def generate_numeric_key(context: EngineContext, params: NumericKeyParams) -> EngineContext:
    """
    Generates a deterministic BIGINT surrogate key from a hash of columns.

    This is useful when:
    - Unioning data from multiple sources
    - Some sources have IDs, some don't
    - You need stable numeric IDs for gold layer

    The key is generated by:
    1. Concatenating columns with separator
    2. Computing MD5 hash
    3. Converting first 15 hex chars to BIGINT

    If coalesce_with is specified, keeps the existing value when not null.
    If output_col == coalesce_with, the original column is replaced.
    """
    from odibi.enums import EngineType

    def safe_col(col, quote_char):
        # Normalize: TRIM whitespace, then treat empty string and NULL as equivalent
        return f"COALESCE(NULLIF(TRIM(CAST({quote_char}{col}{quote_char} AS STRING)), ''), '')"

    # Check if we need to replace the original column
    # Replace if: coalesce_with == output_col, OR output_col already exists in dataframe
    col_names = list(context.df.columns)
    output_exists = params.output_col in col_names
    replace_column = (
        params.coalesce_with and params.output_col == params.coalesce_with
    ) or output_exists

    if context.engine_type == EngineType.SPARK:
        quote_char = "`"
        cols_expr = ", ".join([safe_col(c, quote_char) for c in params.columns])
        concat_expr = f"concat_ws('{params.separator}', {cols_expr})"
        hash_expr = f"CAST(CONV(SUBSTRING(md5({concat_expr}), 1, 15), 16, 10) AS BIGINT)"

        if params.coalesce_with:
            final_expr = f"COALESCE(CAST(`{params.coalesce_with}` AS BIGINT), {hash_expr})"
        else:
            final_expr = hash_expr

        output_col = f"`{params.output_col}`"

    else:
        # DuckDB/Pandas - use ABS(HASH()) instead of CONV (DuckDB doesn't have CONV)
        quote_char = '"'
        cols_expr = ", ".join([safe_col(c, quote_char) for c in params.columns])
        concat_expr = f"concat_ws('{params.separator}', {cols_expr})"
        # hash() in DuckDB returns BIGINT directly
        hash_expr = f"ABS(hash({concat_expr}))"

        if params.coalesce_with:
            final_expr = f'COALESCE(CAST("{params.coalesce_with}" AS BIGINT), {hash_expr})'
        else:
            final_expr = hash_expr

        output_col = f'"{params.output_col}"'

    if replace_column:
        # Replace the original column by selecting all columns except the original,
        # then adding the new computed column
        col_to_exclude = params.coalesce_with if params.coalesce_with else params.output_col
        if context.engine_type == EngineType.SPARK:
            all_cols = [f"`{c}`" for c in col_names if c != col_to_exclude]
        else:
            # Pandas/DuckDB
            all_cols = [f'"{c}"' for c in col_names if c != col_to_exclude]
        cols_select = ", ".join(all_cols)
        sql_query = f"SELECT {cols_select}, {final_expr} AS {output_col} FROM df"
    else:
        sql_query = f"SELECT *, {final_expr} AS {output_col} FROM df"

    return context.sql(sql_query)
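
# Illustrative sketch (not part of the odibi source): the Spark expression
# CONV(SUBSTRING(md5(...), 1, 15), 16, 10) in plain Python. Fifteen hex chars are
# 60 bits, so the result always fits a signed 64-bit BIGINT. Inputs are made up.
import hashlib

def _numeric_key(values, separator="|"):
    joined = separator.join("" if v is None else str(v).strip() for v in values)
    return int(hashlib.md5(joined.encode("utf-8")).hexdigest()[:15], 16)

key = _numeric_key(["2024-01-01", 7, 3, 45, "unplanned stop"])
assert 0 <= key < 2**60 and key == _numeric_key(["2024-01-01", 7, 3, 45, "unplanned stop"])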


# -------------------------------------------------------------------------
# 8. Parse JSON
# -------------------------------------------------------------------------


class ParseJsonParams(BaseModel):
    """
    Configuration for JSON parsing.

    Example:
    ```yaml
    parse_json:
      column: "raw_json"
      json_schema: "id INT, name STRING"
      output_col: "parsed_struct"
    ```
    """

    column: str = Field(..., description="String column containing JSON")
    json_schema: str = Field(
        ..., description="DDL schema string (e.g. 'a INT, b STRING') or Spark StructType DDL"
    )
    output_col: Optional[str] = None


def parse_json(context: EngineContext, params: ParseJsonParams) -> EngineContext:
    """
    Parses a JSON string column into a Struct/Map column.
    """
    from odibi.enums import EngineType

    target = params.output_col or f"{params.column}_parsed"

    if context.engine_type == EngineType.SPARK:
        # Spark: from_json(col, schema)
        expr = f"from_json({params.column}, '{params.json_schema}')"

    else:
        # DuckDB / Pandas
        # DuckDB: json_transform(col, 'schema') is experimental.
        # Standard: from_json(col, 'schema') works in recent DuckDB versions.
        # But reliable way is usually casting or json extraction if we know the structure?
        # Actually, DuckDB allows: cast(json_parse(col) as STRUCT(a INT, b VARCHAR...))

        # We need to convert the generic DDL schema string to DuckDB STRUCT syntax?
        # That is complex.
        # SIMPLIFICATION: For DuckDB, we might rely on automatic inference if we use `json_parse`?
        # Or just `json_parse(col)` which returns a JSON type (which is distinct).
        # Then user can unpack it.

        # Let's try `json_parse(col)`.
        # Note: If user provided specific schema to enforce types, that's harder in DuckDB SQL string without parsing the DDL.
        # Spark's schema string "a INT, b STRING" is not valid DuckDB STRUCT(a INT, b VARCHAR).

        # For V1 of this function, we will focus on Spark (where it's critical).
        # For DuckDB, we will use `CAST(col AS JSON)` which is the standard way to parse JSON string to JSON type.
        # `json_parse` is an alias in some versions but CAST is more stable.

        expr = f"CAST({params.column} AS JSON)"

    sql_query = f"SELECT *, {expr} AS {target} FROM df"
    return context.sql(sql_query)


# -------------------------------------------------------------------------
# 9. Validate and Flag
# -------------------------------------------------------------------------


class ValidateAndFlagParams(BaseModel):
    """
    Configuration for validation flagging.

    Example:
    ```yaml
    validate_and_flag:
      flag_col: "data_issues"
      rules:
        age_check: "age >= 0"
        email_format: "email LIKE '%@%'"
    ```
    """

    # key: rule name, value: sql condition (must be true for valid)
    rules: Dict[str, str] = Field(
        ..., description="Map of rule name to SQL condition (must be TRUE)"
    )
    flag_col: str = Field("_issues", description="Name of the column to store failed rules")

    @field_validator("rules")
    @classmethod
    def require_non_empty_rules(cls, v):
        if not v:
            raise ValueError("ValidateAndFlag: 'rules' must not be empty")
        return v


def validate_and_flag(context: EngineContext, params: ValidateAndFlagParams) -> EngineContext:
    """
    Validates rules and appends a column with a list/string of failed rule names.
    """
    ctx = get_logging_context()
    start_time = time.time()

    ctx.debug(
        "ValidateAndFlag starting",
        rules=list(params.rules.keys()),
        flag_col=params.flag_col,
    )

    rule_exprs = []

    for name, condition in params.rules.items():
        expr = f"CASE WHEN NOT ({condition}) THEN '{name}' ELSE NULL END"
        rule_exprs.append(expr)

    if not rule_exprs:
        return context.sql(f"SELECT *, NULL AS {params.flag_col} FROM df")

    concatted = f"concat_ws(', ', {', '.join(rule_exprs)})"
    final_expr = f"NULLIF({concatted}, '')"

    sql_query = f"SELECT *, {final_expr} AS {params.flag_col} FROM df"
    result = context.sql(sql_query)

    elapsed_ms = (time.time() - start_time) * 1000
    ctx.debug(
        "ValidateAndFlag completed",
        rules_count=len(params.rules),
        elapsed_ms=round(elapsed_ms, 2),
    )

    return result
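
# Illustrative sketch (not part of the odibi source): the statement this builds
# for the YAML example above. Rows failing a rule collect its name; rows passing
# every rule get NULL because concat_ws over all-NULL parts yields ''.
example_rules = {"age_check": "age >= 0", "email_format": "email LIKE '%@%'"}
rule_cases = [f"CASE WHEN NOT ({cond}) THEN '{name}' ELSE NULL END" for name, cond in example_rules.items()]
example_sql = f"SELECT *, NULLIF(concat_ws(', ', {', '.join(rule_cases)}), '') AS data_issues FROM df"
# e.g. a row with age = -1 and a valid email ends up with data_issues = 'age_check'.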


# -------------------------------------------------------------------------
# 10. Window Calculation
# -------------------------------------------------------------------------


class WindowCalculationParams(BaseModel):
    """
    Configuration for window functions.

    Example:
    ```yaml
    window_calculation:
      target_col: "cumulative_sales"
      function: "sum(sales)"
      partition_by: ["region"]
      order_by: "date ASC"
    ```
    """

    target_col: str
    function: str = Field(..., description="Window function e.g. 'sum(amount)', 'rank()'")
    partition_by: List[str] = Field(default_factory=list)
    order_by: Optional[str] = None


def window_calculation(context: EngineContext, params: WindowCalculationParams) -> EngineContext:
    """
    Generic wrapper for Window functions.
    """
    partition_clause = ""
    if params.partition_by:
        partition_clause = f"PARTITION BY {', '.join(params.partition_by)}"

    order_clause = ""
    if params.order_by:
        order_clause = f"ORDER BY {params.order_by}"

    over_clause = f"OVER ({partition_clause} {order_clause})".strip()

    expr = f"{params.function} {over_clause}"

    sql_query = f"SELECT *, {expr} AS {params.target_col} FROM df"
    return context.sql(sql_query)
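
# Illustrative sketch (not part of the odibi source): the SQL generated for the
# YAML example above, run standalone against DuckDB on a made-up frame.
import duckdb
import pandas as pd

daily_sales = pd.DataFrame(
    {"region": ["east", "east", "west"], "date": ["2024-01-01", "2024-01-02", "2024-01-01"], "sales": [10, 5, 7]}
)
windowed = duckdb.sql(
    "SELECT *, sum(sales) OVER (PARTITION BY region ORDER BY date ASC) AS cumulative_sales FROM daily_sales"
).df()
# east accumulates 10 -> 15 across its two days; the single west row stays at 7.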


# -------------------------------------------------------------------------
# 11. Normalize JSON
# -------------------------------------------------------------------------


class NormalizeJsonParams(BaseModel):
    """
    Configuration for JSON normalization.

    Example:
    ```yaml
    normalize_json:
      column: "json_data"
      sep: "_"
    ```
    """

    column: str = Field(..., description="Column containing nested JSON/Struct")
    sep: str = Field("_", description="Separator for nested fields (e.g., 'parent_child')")


def normalize_json(context: EngineContext, params: NormalizeJsonParams) -> EngineContext:
    """
    Flattens a nested JSON/Struct column.
    """
    if context.engine_type == EngineType.SPARK:
        # Spark: Top-level flatten using "col.*"
        sql_query = f"SELECT *, {params.column}.* FROM df"
        return context.sql(sql_query)

    elif context.engine_type == EngineType.PANDAS:
        import json

        import pandas as pd

        df = context.df.copy()

        # Ensure we have dicts
        s = df[params.column]
        if len(s) > 0:
            first_val = s.iloc[0]
            if isinstance(first_val, str):
                # Try to parse if string
                try:
                    s = s.apply(json.loads)
                except Exception as e:
                    import logging

                    logger = logging.getLogger(__name__)
                    logger.warning(f"Failed to parse JSON strings in column '{params.column}': {e}")
                    # We proceed, but json_normalize will likely fail if data is not dicts.

        # json_normalize
        # Handle empty case
        if s.empty:
            return context.with_df(df)

        normalized = pd.json_normalize(s, sep=params.sep)
        # Align index
        normalized.index = df.index

        # Join back (avoid collision if possible, or use suffixes)
        # We use rsuffix just in case
        df = df.join(normalized, rsuffix="_json")
        return context.with_df(df)

    else:
        raise ValueError(f"Unsupported engine: {context.engine_type}")


# -------------------------------------------------------------------------
# 12. Sessionize
# -------------------------------------------------------------------------


class SessionizeParams(BaseModel):
    """
    Configuration for sessionization.

    Example:
    ```yaml
    sessionize:
      timestamp_col: "event_time"
      user_col: "user_id"
      threshold_seconds: 1800
    ```
    """

    timestamp_col: str = Field(
        ..., description="Timestamp column to calculate session duration from"
    )
    user_col: str = Field(..., description="User identifier to partition sessions by")
    threshold_seconds: int = Field(
        1800,
        description="Inactivity threshold in seconds (default: 30 minutes). If gap > threshold, new session starts.",
    )
    session_col: str = Field(
        "session_id", description="Output column name for the generated session ID"
    )


def sessionize(context: EngineContext, params: SessionizeParams) -> EngineContext:
    """
    Assigns session IDs based on inactivity threshold.
    """
    if context.engine_type == EngineType.SPARK:
        # Spark SQL
        # 1. Lag timestamp to get prev_timestamp
        # 2. Calculate diff: ts - prev_ts
        # 3. Flag new session: if diff > threshold OR prev_ts is null -> 1 else 0
        # 4. Sum(flags) over (partition by user order by ts) -> session_id

        threshold = params.threshold_seconds

        # We use nested queries for clarity and safety against multiple aggregations
        sql = f"""
        WITH lagged AS (
            SELECT *,
            LAG({params.timestamp_col}) OVER (PARTITION BY {params.user_col} ORDER BY {params.timestamp_col}) as _prev_ts
            FROM df
        ),
        flagged AS (
            SELECT *,
            CASE
                WHEN _prev_ts IS NULL THEN 1
                WHEN (unix_timestamp({params.timestamp_col}) - unix_timestamp(_prev_ts)) > {threshold} THEN 1
                ELSE 0
            END as _is_new_session
            FROM lagged
        )
        SELECT *,
        concat({params.user_col}, '-', sum(_is_new_session) OVER (PARTITION BY {params.user_col} ORDER BY {params.timestamp_col})) as {params.session_col}
        FROM flagged
        """
        # Note: This returns intermediate columns (_prev_ts, _is_new_session) as well.
        # Ideally we select * EXCEPT ... but Spark < 3.1 doesn't support EXCEPT in SELECT list easily without listing all cols.
        # We leave them for now, or user can drop them.
        return context.sql(sql)

    elif context.engine_type == EngineType.PANDAS:
        import pandas as pd

        df = context.df.copy()

        # Ensure datetime
        if not pd.api.types.is_datetime64_any_dtype(df[params.timestamp_col]):
            df[params.timestamp_col] = pd.to_datetime(df[params.timestamp_col])

        # Sort
        df = df.sort_values([params.user_col, params.timestamp_col])

        user = df[params.user_col]

        # Calculate time diff (in seconds)
        # We groupby user to ensure shift doesn't cross user boundaries for diff
        # But diff() doesn't support groupby well directly on Series without apply?
        # Actually `groupby().diff()` works.
        time_diff = df.groupby(params.user_col)[params.timestamp_col].diff().dt.total_seconds()

        # Flag new session
        # New if: time_diff > threshold OR time_diff is NaT (start of group)
        is_new = (time_diff > params.threshold_seconds) | (time_diff.isna())

        # Cumulative sum per user
        session_ids = is_new.groupby(user).cumsum()

        df[params.session_col] = user.astype(str) + "-" + session_ids.astype(int).astype(str)

        return context.with_df(df)

    else:
        raise ValueError(f"Unsupported engine: {context.engine_type}")


# -------------------------------------------------------------------------
# 13. Geocode (Stub)
# -------------------------------------------------------------------------


class GeocodeParams(BaseModel):
    """
    Configuration for geocoding.

    Example:
    ```yaml
    geocode:
      address_col: "full_address"
      output_col: "lat_long"
    ```
    """

    address_col: str = Field(..., description="Column containing the address to geocode")
    output_col: str = Field("lat_long", description="Name of the output column for coordinates")


def geocode(context: EngineContext, params: GeocodeParams) -> EngineContext:
    """
    Geocoding Stub.
    """
    import logging

    logger = logging.getLogger(__name__)
    logger.warning("Geocode transformer is a stub. No actual geocoding performed.")

    # Pass-through
    return context.with_df(context.df)


# -------------------------------------------------------------------------
# 14. Split Events by Period
# -------------------------------------------------------------------------


class ShiftDefinition(BaseModel):
    """Definition of a single shift."""

    name: str = Field(..., description="Name of the shift (e.g., 'Day', 'Night')")
    start: str = Field(..., description="Start time in HH:MM format (e.g., '06:00')")
    end: str = Field(..., description="End time in HH:MM format (e.g., '14:00')")


class SplitEventsByPeriodParams(BaseModel):
    """
    Configuration for splitting events that span multiple time periods.

    Splits events that span multiple days, hours, or shifts into individual
    segments per period. Useful for OEE/downtime analysis, billing, and
    time-based aggregations.

    Example - Split by day:
    ```yaml
    split_events_by_period:
      start_col: "Shutdown_Start_Time"
      end_col: "Shutdown_End_Time"
      period: "day"
      duration_col: "Shutdown_Duration_Min"
    ```

    Example - Split by shift:
    ```yaml
    split_events_by_period:
      start_col: "event_start"
      end_col: "event_end"
      period: "shift"
      duration_col: "duration_minutes"
      shifts:
        - name: "Day"
          start: "06:00"
          end: "14:00"
        - name: "Swing"
          start: "14:00"
          end: "22:00"
        - name: "Night"
          start: "22:00"
          end: "06:00"
    ```
    """

    start_col: str = Field(..., description="Column containing the event start timestamp")
    end_col: str = Field(..., description="Column containing the event end timestamp")
    period: str = Field(
        "day",
        description="Period type to split by: 'day', 'hour', or 'shift'",
    )
    duration_col: Optional[str] = Field(
        None,
        description="Output column name for duration in minutes. If not set, no duration column is added.",
    )
    shifts: Optional[List[ShiftDefinition]] = Field(
        None,
        description="List of shift definitions (required when period='shift')",
    )
    shift_col: Optional[str] = Field(
        "shift_name",
        description="Output column name for shift name (only used when period='shift')",
    )

    def model_post_init(self, __context):
        if self.period == "shift" and not self.shifts:
            raise ValueError("shifts must be provided when period='shift'")
        if self.period not in ("day", "hour", "shift"):
            raise ValueError(f"Invalid period: {self.period}. Must be 'day', 'hour', or 'shift'")


def split_events_by_period(
    context: EngineContext, params: SplitEventsByPeriodParams
) -> EngineContext:
    """
    Splits events that span multiple time periods into individual segments.

    For events spanning multiple days/hours/shifts, this creates separate rows
    for each period with adjusted start/end times and recalculated durations.
    """
    if params.period == "day":
        return _split_by_day(context, params)
    elif params.period == "hour":
        return _split_by_hour(context, params)
    elif params.period == "shift":
        return _split_by_shift(context, params)
    else:
        raise ValueError(f"Unsupported period: {params.period}")


def _split_by_day(context: EngineContext, params: SplitEventsByPeriodParams) -> EngineContext:
    """Split events by day boundaries."""
    start_col = params.start_col
    end_col = params.end_col

    if context.engine_type == EngineType.SPARK:
        duration_col_exists = params.duration_col and params.duration_col.lower() in [
            c.lower() for c in context.df.columns
        ]

        ts_start = f"to_timestamp({start_col})"
        ts_end = f"to_timestamp({end_col})"

        duration_expr = ""
        if params.duration_col:
            duration_expr = f", (unix_timestamp(adj_{end_col}) - unix_timestamp(adj_{start_col})) / 60.0 AS {params.duration_col}"

        single_day_except = "_event_days"
        if duration_col_exists:
            single_day_except = f"_event_days, {params.duration_col}"

        single_day_sql = f"""
        WITH events_with_days AS (
            SELECT *,
            {ts_start} AS _ts_start,
            {ts_end} AS _ts_end,
            datediff(to_date({ts_end}), to_date({ts_start})) + 1 AS _event_days
            FROM df
        )
        SELECT * EXCEPT({single_day_except}, _ts_start, _ts_end){f", (unix_timestamp(_ts_end) - unix_timestamp(_ts_start)) / 60.0 AS {params.duration_col}" if params.duration_col else ""}
        FROM events_with_days
        WHERE _event_days = 1
        """

        multi_day_except_adjusted = "_exploded_day, _event_days, _ts_start, _ts_end"
        if duration_col_exists:
            multi_day_except_adjusted = (
                f"_exploded_day, _event_days, _ts_start, _ts_end, {params.duration_col}"
            )

        multi_day_sql = f"""
        WITH events_with_days AS (
            SELECT *,
            {ts_start} AS _ts_start,
            {ts_end} AS _ts_end,
            datediff(to_date({ts_end}), to_date({ts_start})) + 1 AS _event_days
            FROM df
        ),
        multi_day AS (
            SELECT *,
            explode(sequence(to_date(_ts_start), to_date(_ts_end), interval 1 day)) AS _exploded_day
            FROM events_with_days
            WHERE _event_days > 1
        ),
        multi_day_adjusted AS (
            SELECT * EXCEPT({multi_day_except_adjusted}),
            CASE
                WHEN to_date(_exploded_day) = to_date(_ts_start) THEN _ts_start
                ELSE to_timestamp(concat(cast(_exploded_day as string), ' 00:00:00'))
            END AS adj_{start_col},
            CASE
                WHEN to_date(_exploded_day) = to_date(_ts_end) THEN _ts_end
                ELSE to_timestamp(concat(cast(date_add(_exploded_day, 1) as string), ' 00:00:00'))
            END AS adj_{end_col}
            FROM multi_day
        )
        SELECT * EXCEPT({start_col}, {end_col}, adj_{start_col}, adj_{end_col}),
        adj_{start_col} AS {start_col},
        adj_{end_col} AS {end_col}
        {duration_expr}
        FROM multi_day_adjusted
        """

        single_day_df = context.sql(single_day_sql).df
        multi_day_df = context.sql(multi_day_sql).df
        result_df = single_day_df.unionByName(multi_day_df, allowMissingColumns=True)
        return context.with_df(result_df)

    elif context.engine_type == EngineType.PANDAS:
        import pandas as pd

        df = context.df.copy()

        df[start_col] = pd.to_datetime(df[start_col])
        df[end_col] = pd.to_datetime(df[end_col])

        df["_event_days"] = (df[end_col].dt.normalize() - df[start_col].dt.normalize()).dt.days + 1

        single_day = df[df["_event_days"] == 1].copy()
        multi_day = df[df["_event_days"] > 1].copy()

        if params.duration_col and not single_day.empty:
            single_day[params.duration_col] = (
                single_day[end_col] - single_day[start_col]
            ).dt.total_seconds() / 60.0

        if not multi_day.empty:
            rows = []
            for _, row in multi_day.iterrows():
                start = row[start_col]
                end = row[end_col]
                current_day = start.normalize()

                while current_day <= end.normalize():
                    new_row = row.copy()

                    if current_day == start.normalize():
                        new_start = start
                    else:
                        new_start = current_day

                    next_day = current_day + pd.Timedelta(days=1)
                    if next_day > end:
                        new_end = end
                    else:
                        new_end = next_day

                    new_row[start_col] = new_start
                    new_row[end_col] = new_end

                    if params.duration_col:
                        new_row[params.duration_col] = (new_end - new_start).total_seconds() / 60.0

                    rows.append(new_row)
                    current_day = next_day

            multi_day_exploded = pd.DataFrame(rows)
            result = pd.concat([single_day, multi_day_exploded], ignore_index=True)
        else:
            result = single_day

        result = result.drop(columns=["_event_days"], errors="ignore")
        return context.with_df(result)

    else:
        raise ValueError(
            f"Split events by day does not support engine type '{context.engine_type}'. "
            f"Supported engines: SPARK, PANDAS. "
            f"Check your engine configuration."
        )
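
# Illustrative sketch (not part of the odibi source): the day-boundary walk used
# in the pandas branch above, applied to one made-up overnight event.
import pandas as pd

start, end = pd.Timestamp("2024-01-01 22:30"), pd.Timestamp("2024-01-02 01:00")
segments = []
current_day = start.normalize()
while current_day <= end.normalize():
    seg_start = start if current_day == start.normalize() else current_day
    next_day = current_day + pd.Timedelta(days=1)
    seg_end = end if next_day > end else next_day
    segments.append((seg_start, seg_end, (seg_end - seg_start).total_seconds() / 60.0))
    current_day = next_day
assert [minutes for _, _, minutes in segments] == [90.0, 60.0]  # 22:30-24:00, then 00:00-01:00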
|
|
1225
|
+
|
|
1226
|
+
|
|
+def _split_by_hour(context: EngineContext, params: SplitEventsByPeriodParams) -> EngineContext:
+    """Split events by hour boundaries."""
+    start_col = params.start_col
+    end_col = params.end_col
+
+    if context.engine_type == EngineType.SPARK:
+        ts_start = f"to_timestamp({start_col})"
+        ts_end = f"to_timestamp({end_col})"
+
+        duration_expr = ""
+        if params.duration_col:
+            duration_expr = f", (unix_timestamp({end_col}) - unix_timestamp({start_col})) / 60.0 AS {params.duration_col}"
+
+        sql = f"""
+        WITH events_with_hours AS (
+            SELECT *,
+                   {ts_start} AS _ts_start,
+                   {ts_end} AS _ts_end,
+                   CAST((unix_timestamp({ts_end}) - unix_timestamp({ts_start})) / 3600 AS INT) + 1 AS _event_hours
+            FROM df
+        ),
+        multi_hour AS (
+            SELECT *,
+                   explode(sequence(
+                       date_trunc('hour', _ts_start),
+                       date_trunc('hour', _ts_end),
+                       interval 1 hour
+                   )) AS _exploded_hour
+            FROM events_with_hours
+            WHERE _event_hours > 1
+        ),
+        multi_hour_adjusted AS (
+            SELECT * EXCEPT(_exploded_hour, _event_hours, _ts_start, _ts_end),
+                   CASE
+                       WHEN date_trunc('hour', _ts_start) = _exploded_hour THEN _ts_start
+                       ELSE _exploded_hour
+                   END AS {start_col},
+                   CASE
+                       WHEN date_trunc('hour', _ts_end) = _exploded_hour THEN _ts_end
+                       ELSE _exploded_hour + interval 1 hour
+                   END AS {end_col}
+            FROM multi_hour
+        ),
+        single_hour AS (
+            SELECT * EXCEPT(_event_hours, _ts_start, _ts_end){duration_expr}
+            FROM events_with_hours
+            WHERE _event_hours = 1
+        )
+        SELECT *{duration_expr} FROM multi_hour_adjusted
+        UNION ALL
+        SELECT * FROM single_hour
+        """
+        return context.sql(sql)
+
+    elif context.engine_type == EngineType.PANDAS:
+        import pandas as pd
+
+        df = context.df.copy()
+
+        df[start_col] = pd.to_datetime(df[start_col])
+        df[end_col] = pd.to_datetime(df[end_col])
+
+        df["_event_hours"] = ((df[end_col] - df[start_col]).dt.total_seconds() / 3600).astype(
+            int
+        ) + 1
+
+        single_hour = df[df["_event_hours"] == 1].copy()
+        multi_hour = df[df["_event_hours"] > 1].copy()
+
+        if params.duration_col and not single_hour.empty:
+            single_hour[params.duration_col] = (
+                single_hour[end_col] - single_hour[start_col]
+            ).dt.total_seconds() / 60.0
+
+        if not multi_hour.empty:
+            rows = []
+            for _, row in multi_hour.iterrows():
+                start = row[start_col]
+                end = row[end_col]
+                current_hour = start.floor("h")
+
+                while current_hour <= end.floor("h"):
+                    new_row = row.copy()
+
+                    if current_hour == start.floor("h"):
+                        new_start = start
+                    else:
+                        new_start = current_hour
+
+                    next_hour = current_hour + pd.Timedelta(hours=1)
+                    if next_hour > end:
+                        new_end = end
+                    else:
+                        new_end = next_hour
+
+                    new_row[start_col] = new_start
+                    new_row[end_col] = new_end
+
+                    if params.duration_col:
+                        new_row[params.duration_col] = (new_end - new_start).total_seconds() / 60.0
+
+                    rows.append(new_row)
+                    current_hour = next_hour
+
+            multi_hour_exploded = pd.DataFrame(rows)
+            result = pd.concat([single_hour, multi_hour_exploded], ignore_index=True)
+        else:
+            result = single_hour
+
+        result = result.drop(columns=["_event_hours"], errors="ignore")
+        return context.with_df(result)
+
+    else:
+        raise ValueError(
+            f"Split events by hour does not support engine type '{context.engine_type}'. "
+            f"Supported engines: SPARK, PANDAS. "
+            f"Check your engine configuration."
+        )
+
+
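Both engine branches apply the same rule: an event is cut at every hour boundary (date_trunc('hour', ...) in the Spark SQL, Timestamp.floor("h") in pandas), and the first and last pieces are clipped to the original start and end. A standalone sketch of the per-event expansion (the helper name and sample timestamps are illustrative, not part of the package API):

import pandas as pd

def split_event_by_hour(start: pd.Timestamp, end: pd.Timestamp) -> list[tuple[pd.Timestamp, pd.Timestamp]]:
    """Cut [start, end] at each hour boundary, mirroring the loop above."""
    pieces = []
    current = start.floor("h")
    while current <= end.floor("h"):
        nxt = current + pd.Timedelta(hours=1)
        seg_start = start if current == start.floor("h") else current
        seg_end = end if nxt > end else nxt
        pieces.append((seg_start, seg_end))
        current = nxt
    return pieces

for s, e in split_event_by_hour(pd.Timestamp("2024-01-01 09:40"), pd.Timestamp("2024-01-01 11:10")):
    print(s, "->", e, (e - s).total_seconds() / 60.0, "min")
# 09:40 -> 10:00 (20 min), 10:00 -> 11:00 (60 min), 11:00 -> 11:10 (10 min)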
+def _split_by_shift(context: EngineContext, params: SplitEventsByPeriodParams) -> EngineContext:
+    """Split events by shift boundaries."""
+    start_col = params.start_col
+    end_col = params.end_col
+    shift_col = params.shift_col or "shift_name"
+    shifts = params.shifts
+
+    if context.engine_type == EngineType.SPARK:
+        ts_start = f"to_timestamp({start_col})"
+        ts_end = f"to_timestamp({end_col})"
+
+        duration_expr = ""
+        if params.duration_col:
+            duration_expr = f", (unix_timestamp({end_col}) - unix_timestamp({start_col})) / 60.0 AS {params.duration_col}"
+
+        sql = f"""
+        WITH base AS (
+            SELECT *,
+                   {ts_start} AS _ts_start,
+                   {ts_end} AS _ts_end,
+                   datediff(to_date({ts_end}), to_date({ts_start})) + 1 AS _span_days
+            FROM df
+        ),
+        day_exploded AS (
+            SELECT *,
+                   explode(sequence(to_date(_ts_start), to_date(_ts_end), interval 1 day)) AS _day
+            FROM base
+        ),
+        with_shift_times AS (
+            SELECT *,
+                   {
+                       ", ".join(
+                           [
+                               f"to_timestamp(concat(cast(_day as string), ' {s.start}:00')) AS _shift_{i}_start, "
+                               + f"to_timestamp(concat(cast(date_add(_day, {1 if int(s.end.split(':')[0]) < int(s.start.split(':')[0]) else 0}) as string), ' {s.end}:00')) AS _shift_{i}_end"
+                               for i, s in enumerate(shifts)
+                           ]
+                       )
+                   }
+            FROM day_exploded
+        ),
+        shift_segments AS (
+            {
+                " UNION ALL ".join(
+                    [
+                        f"SELECT * EXCEPT({', '.join([f'_shift_{j}_start, _shift_{j}_end' for j in range(len(shifts))])}, _ts_start, _ts_end), "
+                        f"GREATEST(_ts_start, _shift_{i}_start) AS {start_col}, "
+                        f"LEAST(_ts_end, _shift_{i}_end) AS {end_col}, "
+                        f"'{s.name}' AS {shift_col} "
+                        f"FROM with_shift_times "
+                        f"WHERE _ts_start < _shift_{i}_end AND _ts_end > _shift_{i}_start"
+                        for i, s in enumerate(shifts)
+                    ]
+                )
+            }
+        )
+        SELECT * EXCEPT(_span_days, _day){duration_expr}
+        FROM shift_segments
+        WHERE {start_col} < {end_col}
+        """
+        return context.sql(sql)
+
+    elif context.engine_type == EngineType.PANDAS:
+        import pandas as pd
+        from datetime import timedelta
+
+        df = context.df.copy()
+        df[start_col] = pd.to_datetime(df[start_col])
+        df[end_col] = pd.to_datetime(df[end_col])
+
+        def parse_time(t_str):
+            h, m = map(int, t_str.split(":"))
+            return timedelta(hours=h, minutes=m)
+
+        rows = []
+        for _, row in df.iterrows():
+            event_start = row[start_col]
+            event_end = row[end_col]
+
+            current_day = event_start.normalize()
+            while current_day <= event_end.normalize():
+                for shift in shifts:
+                    shift_start_delta = parse_time(shift.start)
+                    shift_end_delta = parse_time(shift.end)
+
+                    if shift_end_delta <= shift_start_delta:
+                        shift_start_dt = current_day + shift_start_delta
+                        shift_end_dt = current_day + timedelta(days=1) + shift_end_delta
+                    else:
+                        shift_start_dt = current_day + shift_start_delta
+                        shift_end_dt = current_day + shift_end_delta
+
+                    seg_start = max(event_start, shift_start_dt)
+                    seg_end = min(event_end, shift_end_dt)
+
+                    if seg_start < seg_end:
+                        new_row = row.copy()
+                        new_row[start_col] = seg_start
+                        new_row[end_col] = seg_end
+                        new_row[shift_col] = shift.name
+
+                        if params.duration_col:
+                            new_row[params.duration_col] = (
+                                seg_end - seg_start
+                            ).total_seconds() / 60.0
+
+                        rows.append(new_row)
+
+                current_day += timedelta(days=1)
+
+        if rows:
+            result = pd.DataFrame(rows)
+        else:
+            result = df.copy()
+            result[shift_col] = None
+            if params.duration_col:
+                result[params.duration_col] = 0.0
+
+        return context.with_df(result)
+
+    else:
+        raise ValueError(
+            f"Split events by shift does not support engine type '{context.engine_type}'. "
+            f"Supported engines: SPARK, PANDAS. "
+            f"Check your engine configuration."
+        )
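In both engine branches, each event is intersected with every configured shift window on every day it touches: the window runs from day + start to day + end, rolled into the next day when the end time is earlier than the start time (an overnight shift), and a segment is kept only when the clipped interval is non-empty. A standalone sketch with a stand-in shift object (the real shift config objects expose name/start/end as read above; the namedtuple and sample data here are illustrative assumptions):

from collections import namedtuple
from datetime import timedelta
import pandas as pd

# Stand-in for the shift config: "HH:MM" strings, end < start means overnight.
Shift = namedtuple("Shift", ["name", "start", "end"])
shifts = [Shift("Day", "06:00", "18:00"), Shift("Night", "18:00", "06:00")]

def parse_time(t_str: str) -> timedelta:
    h, m = map(int, t_str.split(":"))
    return timedelta(hours=h, minutes=m)

event_start = pd.Timestamp("2024-01-01 16:00")
event_end = pd.Timestamp("2024-01-01 20:30")

segments = []
current_day = event_start.normalize()
while current_day <= event_end.normalize():
    for shift in shifts:
        start_delta, end_delta = parse_time(shift.start), parse_time(shift.end)
        shift_start = current_day + start_delta
        # Overnight shifts roll the end boundary into the next day.
        shift_end = current_day + (timedelta(days=1) if end_delta <= start_delta else timedelta()) + end_delta
        # Clip the event to the shift window; keep only non-empty overlaps.
        seg_start, seg_end = max(event_start, shift_start), min(event_end, shift_end)
        if seg_start < seg_end:
            segments.append((shift.name, seg_start, seg_end,
                             (seg_end - seg_start).total_seconds() / 60.0))
    current_day += timedelta(days=1)

print(pd.DataFrame(segments, columns=["shift_name", "start", "end", "minutes"]))
# Day   16:00 -> 18:00  (120.0)
# Night 18:00 -> 20:30  (150.0)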