odibi 2.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. odibi/__init__.py +32 -0
  2. odibi/__main__.py +8 -0
  3. odibi/catalog.py +3011 -0
  4. odibi/cli/__init__.py +11 -0
  5. odibi/cli/__main__.py +6 -0
  6. odibi/cli/catalog.py +553 -0
  7. odibi/cli/deploy.py +69 -0
  8. odibi/cli/doctor.py +161 -0
  9. odibi/cli/export.py +66 -0
  10. odibi/cli/graph.py +150 -0
  11. odibi/cli/init_pipeline.py +242 -0
  12. odibi/cli/lineage.py +259 -0
  13. odibi/cli/main.py +215 -0
  14. odibi/cli/run.py +98 -0
  15. odibi/cli/schema.py +208 -0
  16. odibi/cli/secrets.py +232 -0
  17. odibi/cli/story.py +379 -0
  18. odibi/cli/system.py +132 -0
  19. odibi/cli/test.py +286 -0
  20. odibi/cli/ui.py +31 -0
  21. odibi/cli/validate.py +39 -0
  22. odibi/config.py +3541 -0
  23. odibi/connections/__init__.py +9 -0
  24. odibi/connections/azure_adls.py +499 -0
  25. odibi/connections/azure_sql.py +709 -0
  26. odibi/connections/base.py +28 -0
  27. odibi/connections/factory.py +322 -0
  28. odibi/connections/http.py +78 -0
  29. odibi/connections/local.py +119 -0
  30. odibi/connections/local_dbfs.py +61 -0
  31. odibi/constants.py +17 -0
  32. odibi/context.py +528 -0
  33. odibi/diagnostics/__init__.py +12 -0
  34. odibi/diagnostics/delta.py +520 -0
  35. odibi/diagnostics/diff.py +169 -0
  36. odibi/diagnostics/manager.py +171 -0
  37. odibi/engine/__init__.py +20 -0
  38. odibi/engine/base.py +334 -0
  39. odibi/engine/pandas_engine.py +2178 -0
  40. odibi/engine/polars_engine.py +1114 -0
  41. odibi/engine/registry.py +54 -0
  42. odibi/engine/spark_engine.py +2362 -0
  43. odibi/enums.py +7 -0
  44. odibi/exceptions.py +297 -0
  45. odibi/graph.py +426 -0
  46. odibi/introspect.py +1214 -0
  47. odibi/lineage.py +511 -0
  48. odibi/node.py +3341 -0
  49. odibi/orchestration/__init__.py +0 -0
  50. odibi/orchestration/airflow.py +90 -0
  51. odibi/orchestration/dagster.py +77 -0
  52. odibi/patterns/__init__.py +24 -0
  53. odibi/patterns/aggregation.py +599 -0
  54. odibi/patterns/base.py +94 -0
  55. odibi/patterns/date_dimension.py +423 -0
  56. odibi/patterns/dimension.py +696 -0
  57. odibi/patterns/fact.py +748 -0
  58. odibi/patterns/merge.py +128 -0
  59. odibi/patterns/scd2.py +148 -0
  60. odibi/pipeline.py +2382 -0
  61. odibi/plugins.py +80 -0
  62. odibi/project.py +581 -0
  63. odibi/references.py +151 -0
  64. odibi/registry.py +246 -0
  65. odibi/semantics/__init__.py +71 -0
  66. odibi/semantics/materialize.py +392 -0
  67. odibi/semantics/metrics.py +361 -0
  68. odibi/semantics/query.py +743 -0
  69. odibi/semantics/runner.py +430 -0
  70. odibi/semantics/story.py +507 -0
  71. odibi/semantics/views.py +432 -0
  72. odibi/state/__init__.py +1203 -0
  73. odibi/story/__init__.py +55 -0
  74. odibi/story/doc_story.py +554 -0
  75. odibi/story/generator.py +1431 -0
  76. odibi/story/lineage.py +1043 -0
  77. odibi/story/lineage_utils.py +324 -0
  78. odibi/story/metadata.py +608 -0
  79. odibi/story/renderers.py +453 -0
  80. odibi/story/templates/run_story.html +2520 -0
  81. odibi/story/themes.py +216 -0
  82. odibi/testing/__init__.py +13 -0
  83. odibi/testing/assertions.py +75 -0
  84. odibi/testing/fixtures.py +85 -0
  85. odibi/testing/source_pool.py +277 -0
  86. odibi/transformers/__init__.py +122 -0
  87. odibi/transformers/advanced.py +1472 -0
  88. odibi/transformers/delete_detection.py +610 -0
  89. odibi/transformers/manufacturing.py +1029 -0
  90. odibi/transformers/merge_transformer.py +778 -0
  91. odibi/transformers/relational.py +675 -0
  92. odibi/transformers/scd.py +579 -0
  93. odibi/transformers/sql_core.py +1356 -0
  94. odibi/transformers/validation.py +165 -0
  95. odibi/ui/__init__.py +0 -0
  96. odibi/ui/app.py +195 -0
  97. odibi/utils/__init__.py +66 -0
  98. odibi/utils/alerting.py +667 -0
  99. odibi/utils/config_loader.py +343 -0
  100. odibi/utils/console.py +231 -0
  101. odibi/utils/content_hash.py +202 -0
  102. odibi/utils/duration.py +43 -0
  103. odibi/utils/encoding.py +102 -0
  104. odibi/utils/extensions.py +28 -0
  105. odibi/utils/hashing.py +61 -0
  106. odibi/utils/logging.py +203 -0
  107. odibi/utils/logging_context.py +740 -0
  108. odibi/utils/progress.py +429 -0
  109. odibi/utils/setup_helpers.py +302 -0
  110. odibi/utils/telemetry.py +140 -0
  111. odibi/validation/__init__.py +62 -0
  112. odibi/validation/engine.py +765 -0
  113. odibi/validation/explanation_linter.py +155 -0
  114. odibi/validation/fk.py +547 -0
  115. odibi/validation/gate.py +252 -0
  116. odibi/validation/quarantine.py +605 -0
  117. odibi/writers/__init__.py +15 -0
  118. odibi/writers/sql_server_writer.py +2081 -0
  119. odibi-2.5.0.dist-info/METADATA +255 -0
  120. odibi-2.5.0.dist-info/RECORD +124 -0
  121. odibi-2.5.0.dist-info/WHEEL +5 -0
  122. odibi-2.5.0.dist-info/entry_points.txt +2 -0
  123. odibi-2.5.0.dist-info/licenses/LICENSE +190 -0
  124. odibi-2.5.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1472 @@
1
+ import time
2
+ from enum import Enum
3
+ from typing import Dict, List, Optional, Union
4
+
5
+ from pydantic import BaseModel, Field, field_validator
6
+
7
+ from odibi.context import EngineContext
8
+ from odibi.enums import EngineType
9
+ from odibi.utils.logging_context import get_logging_context
10
+
11
+ # -------------------------------------------------------------------------
12
+ # 1. Deduplicate (Window)
13
+ # -------------------------------------------------------------------------
14
+
15
+
16
+ class DeduplicateParams(BaseModel):
17
+ """
18
+ Configuration for deduplication.
19
+
20
+ Scenario: Keep latest record
21
+ ```yaml
22
+ deduplicate:
23
+ keys: ["id"]
24
+ order_by: "updated_at DESC"
25
+ ```
26
+ """
27
+
28
+ keys: List[str] = Field(
29
+ ..., description="List of columns to partition by (columns that define uniqueness)"
30
+ )
31
+ order_by: Optional[str] = Field(
32
+ None,
33
+ description="SQL ORDER BY clause (e.g. 'updated_at DESC') that ranks records within each key; the first record is kept",
34
+ )
35
+
36
+
37
+ def deduplicate(context: EngineContext, params: DeduplicateParams) -> EngineContext:
38
+ """
39
+ Deduplicates data using Window functions.
40
+ """
41
+ ctx = get_logging_context()
42
+ start_time = time.time()
43
+
44
+ ctx.debug(
45
+ "Deduplicate starting",
46
+ keys=params.keys,
47
+ order_by=params.order_by,
48
+ )
49
+
50
+ # Get row count before transformation (optional, for logging only)
51
+ rows_before = None
52
+ try:
53
+ rows_before = context.df.shape[0] if hasattr(context.df, "shape") else None
54
+ if rows_before is None and hasattr(context.df, "count"):
55
+ rows_before = context.df.count()
56
+ except Exception as e:
57
+ ctx.debug(f"Could not get row count before transform: {type(e).__name__}")
58
+
59
+ partition_clause = ", ".join(params.keys)
60
+ order_clause = params.order_by if params.order_by else "(SELECT NULL)"
61
+
62
+ # Dialect handling for EXCEPT/EXCLUDE
63
+ except_clause = "EXCEPT"
64
+ if context.engine_type == EngineType.PANDAS:
65
+ # DuckDB uses EXCLUDE
66
+ except_clause = "EXCLUDE"
67
+
68
+ sql_query = f"""
69
+ SELECT * {except_clause}(_rn) FROM (
70
+ SELECT *,
71
+ ROW_NUMBER() OVER (PARTITION BY {partition_clause} ORDER BY {order_clause}) as _rn
72
+ FROM df
73
+ ) WHERE _rn = 1
74
+ """
75
+ result = context.sql(sql_query)
76
+
77
+ # Get row count after transformation (optional, for logging only)
78
+ rows_after = None
79
+ try:
80
+ rows_after = result.df.shape[0] if hasattr(result.df, "shape") else None
81
+ if rows_after is None and hasattr(result.df, "count"):
82
+ rows_after = result.df.count()
83
+ except Exception as e:
84
+ ctx.debug(f"Could not get row count after transform: {type(e).__name__}")
85
+
86
+ elapsed_ms = (time.time() - start_time) * 1000
87
+ duplicates_removed = (rows_before - rows_after) if None not in (rows_before, rows_after) else None
88
+ ctx.debug(
89
+ "Deduplicate completed",
90
+ keys=params.keys,
91
+ rows_before=rows_before,
92
+ rows_after=rows_after,
93
+ duplicates_removed=duplicates_removed,
94
+ elapsed_ms=round(elapsed_ms, 2),
95
+ )
96
+
97
+ return result
98
+
99
+
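# A minimal standalone sketch (not part of the package) of the SQL that
# deduplicate() emits for the docstring example on the pandas engine, run
# directly against DuckDB here; the transformer itself routes it through
# EngineContext.sql().
import duckdb
import pandas as pd

df = pd.DataFrame(
    {
        "id": [1, 1, 2],
        "updated_at": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-01"]),
    }
)
latest = duckdb.sql(
    """
    SELECT * EXCLUDE(_rn) FROM (
        SELECT *,
               ROW_NUMBER() OVER (PARTITION BY id ORDER BY updated_at DESC) AS _rn
        FROM df
    ) WHERE _rn = 1
    """
).df()
print(latest)  # one row per id, keeping the latest updated_at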
100
+ # -------------------------------------------------------------------------
101
+ # 2. Explode List
102
+ # -------------------------------------------------------------------------
103
+
104
+
105
+ class ExplodeParams(BaseModel):
106
+ """
107
+ Configuration for exploding lists.
108
+
109
+ Scenario: Flatten list of items per order
110
+ ```yaml
111
+ explode_list_column:
112
+ column: "items"
113
+ outer: true # Keep orders with empty items list
114
+ ```
115
+ """
116
+
117
+ column: str = Field(..., description="Column containing the list/array to explode")
118
+ outer: bool = Field(
119
+ False,
120
+ description="If True, keep rows with empty lists (explode_outer behavior). If False, drop them.",
121
+ )
122
+
123
+
124
+ def explode_list_column(context: EngineContext, params: ExplodeParams) -> EngineContext:
125
+ ctx = get_logging_context()
126
+ start_time = time.time()
127
+
128
+ ctx.debug(
129
+ "Explode starting",
130
+ column=params.column,
131
+ outer=params.outer,
132
+ )
133
+
134
+ rows_before = None
135
+ try:
136
+ rows_before = context.df.shape[0] if hasattr(context.df, "shape") else None
137
+ if rows_before is None and hasattr(context.df, "count"):
138
+ rows_before = context.df.count()
139
+ except Exception:
140
+ pass
141
+
142
+ if context.engine_type == EngineType.SPARK:
143
+ import pyspark.sql.functions as F
144
+
145
+ func = F.explode_outer if params.outer else F.explode
146
+ df = context.df.withColumn(params.column, func(F.col(params.column)))
147
+
148
+ rows_after = df.count() if hasattr(df, "count") else None
149
+ elapsed_ms = (time.time() - start_time) * 1000
150
+ ctx.debug(
151
+ "Explode completed",
152
+ column=params.column,
153
+ rows_before=rows_before,
154
+ rows_after=rows_after,
155
+ elapsed_ms=round(elapsed_ms, 2),
156
+ )
157
+ return context.with_df(df)
158
+
159
+ elif context.engine_type == EngineType.PANDAS:
160
+ df = context.df.explode(params.column)
161
+ if not params.outer:
162
+ df = df.dropna(subset=[params.column])
163
+
164
+ rows_after = df.shape[0] if hasattr(df, "shape") else None
165
+ elapsed_ms = (time.time() - start_time) * 1000
166
+ ctx.debug(
167
+ "Explode completed",
168
+ column=params.column,
169
+ rows_before=rows_before,
170
+ rows_after=rows_after,
171
+ elapsed_ms=round(elapsed_ms, 2),
172
+ )
173
+ return context.with_df(df)
174
+
175
+ else:
176
+ ctx.error("Explode failed: unsupported engine", engine_type=str(context.engine_type))
177
+ raise ValueError(
178
+ f"Explode transformer does not support engine type '{context.engine_type}'. "
179
+ f"Supported engines: SPARK, PANDAS. "
180
+ f"Check your engine configuration."
181
+ )
182
+
183
+
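# A minimal standalone sketch (not part of the package) of the pandas branch
# above, for the docstring example with outer=true semantics.
import pandas as pd

orders = pd.DataFrame({"order_id": [1, 2], "items": [["a", "b"], []]})
print(orders.explode("items"))
# order 1 -> two rows ("a" and "b"); order 2 -> one row with items = NaN (kept).
# With outer=false the transformer drops that NaN row via dropna(subset=["items"]).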
184
+ # -------------------------------------------------------------------------
185
+ # 3. Dict Mapping
186
+ # -------------------------------------------------------------------------
187
+
188
+ JsonScalar = Union[str, int, float, bool, None]
189
+
190
+
191
+ class DictMappingParams(BaseModel):
192
+ """
193
+ Configuration for dictionary mapping.
194
+
195
+ Scenario: Map status codes to labels
196
+ ```yaml
197
+ dict_based_mapping:
198
+ column: "status_code"
199
+ mapping:
200
+ "1": "Active"
201
+ "0": "Inactive"
202
+ default: "Unknown"
203
+ output_column: "status_desc"
204
+ ```
205
+ """
206
+
207
+ column: str = Field(..., description="Column to map values from")
208
+ mapping: Dict[str, JsonScalar] = Field(
209
+ ..., description="Dictionary of source value -> target value"
210
+ )
211
+ default: Optional[JsonScalar] = Field(
212
+ None, description="Default value if source value is not found in mapping"
213
+ )
214
+ output_column: Optional[str] = Field(
215
+ None, description="Name of output column. If not provided, overwrites source column."
216
+ )
217
+
218
+
219
+ def dict_based_mapping(context: EngineContext, params: DictMappingParams) -> EngineContext:
220
+ target_col = params.output_column or params.column
221
+
222
+ if context.engine_type == EngineType.SPARK:
223
+ from itertools import chain
224
+
225
+ import pyspark.sql.functions as F
226
+
227
+ # Create map expression
228
+ mapping_expr = F.create_map([F.lit(x) for x in chain(*params.mapping.items())])
229
+
230
+ df = context.df.withColumn(target_col, mapping_expr[F.col(params.column)])
231
+ if params.default is not None:
232
+ df = df.withColumn(target_col, F.coalesce(F.col(target_col), F.lit(params.default)))
233
+ return context.with_df(df)
234
+
235
+ elif context.engine_type == EngineType.PANDAS:
236
+ df = context.df.copy()
237
+ # Pandas map is fast
238
+ df[target_col] = df[params.column].map(params.mapping)
239
+ if params.default is not None:
240
+ df[target_col] = df[target_col].fillna(params.default).infer_objects(copy=False)
241
+ return context.with_df(df)
242
+
243
+ else:
244
+ raise ValueError(
245
+ f"Dict-based mapping does not support engine type '{context.engine_type}'. "
246
+ f"Supported engines: SPARK, PANDAS. "
247
+ f"Check your engine configuration."
248
+ )
249
+
250
+
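# A minimal standalone sketch (not part of the package) of the pandas mapping
# path for the docstring example. Mapping keys are compared to the column values
# as-is, so the string keys "1"/"0" assume status_code is stored as text.
import pandas as pd

df = pd.DataFrame({"status_code": ["1", "0", "9"]})
df["status_desc"] = df["status_code"].map({"1": "Active", "0": "Inactive"})
df["status_desc"] = df["status_desc"].fillna("Unknown")
print(df)  # "9" falls back to the default "Unknown"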
251
+ # -------------------------------------------------------------------------
252
+ # 4. Regex Replace
253
+ # -------------------------------------------------------------------------
254
+
255
+
256
+ class RegexReplaceParams(BaseModel):
257
+ """
258
+ Configuration for regex replacement.
259
+
260
+ Example:
261
+ ```yaml
262
+ regex_replace:
263
+ column: "phone"
264
+ pattern: "[^0-9]"
265
+ replacement: ""
266
+ ```
267
+ """
268
+
269
+ column: str = Field(..., description="Column to apply regex replacement on")
270
+ pattern: str = Field(..., description="Regex pattern to match")
271
+ replacement: str = Field(..., description="String to replace matches with")
272
+
273
+
274
+ def regex_replace(context: EngineContext, params: RegexReplaceParams) -> EngineContext:
275
+ """
276
+ SQL-based Regex replacement.
277
+ """
278
+ # Spark and DuckDB both support REGEXP_REPLACE(col, pattern, replacement)
279
+ sql_query = f"SELECT *, REGEXP_REPLACE({params.column}, '{params.pattern}', '{params.replacement}') AS {params.column} FROM df"
280
+ return context.sql(sql_query)
281
+
282
+
283
+ # -------------------------------------------------------------------------
284
+ # 5. Unpack Struct (Flatten)
285
+ # -------------------------------------------------------------------------
286
+
287
+
288
+ class UnpackStructParams(BaseModel):
289
+ """
290
+ Configuration for unpacking structs.
291
+
292
+ Example:
293
+ ```yaml
294
+ unpack_struct:
295
+ column: "user_info"
296
+ ```
297
+ """
298
+
299
+ column: str = Field(
300
+ ..., description="Struct/Dictionary column to unpack/flatten into individual columns"
301
+ )
302
+
303
+
304
+ def unpack_struct(context: EngineContext, params: UnpackStructParams) -> EngineContext:
305
+ """
306
+ Flattens a struct/dict column into top-level columns.
307
+ """
308
+ if context.engine_type == EngineType.SPARK:
309
+ # Spark: "select col.* from df"
310
+ sql_query = f"SELECT *, {params.column}.* FROM df"
311
+ # "SELECT *" keeps the original struct column alongside the expanded fields;
+ # callers can drop the struct downstream if they do not need it.
315
+ return context.sql(sql_query)
316
+
317
+ elif context.engine_type == EngineType.PANDAS:
318
+ import pandas as pd
319
+
320
+ # Pandas: json_normalize or Apply(pd.Series)
321
+ # Optimization: df[col].tolist() is much faster than apply(pd.Series)
322
+ # assuming the column contains dictionaries/structs.
323
+ try:
324
+ expanded = pd.DataFrame(context.df[params.column].tolist(), index=context.df.index)
325
+ except Exception as e:
326
+ import logging
327
+
328
+ logger = logging.getLogger(__name__)
329
+ logger.debug(f"Optimized struct unpack failed (falling back to slow apply): {e}")
330
+ # Fallback if tolist() fails (e.g. mixed types)
331
+ expanded = context.df[params.column].apply(pd.Series)
332
+
333
+ # Expanded columns take their names from the struct keys; concat them back onto the original frame.
335
+ res = pd.concat([context.df, expanded], axis=1)
336
+ return context.with_df(res)
337
+
338
+ else:
339
+ raise ValueError(
340
+ f"Unpack struct does not support engine type '{context.engine_type}'. "
341
+ f"Supported engines: SPARK, PANDAS. "
342
+ f"Check your engine configuration."
343
+ )
344
+
345
+
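# A minimal standalone sketch (not part of the package) of the fast pandas path
# above: expanding a dict column via tolist() and concatenating it back.
import pandas as pd

df = pd.DataFrame(
    {"order_id": [1, 2], "user_info": [{"name": "Ada", "age": 36}, {"name": "Bob", "age": 41}]}
)
expanded = pd.DataFrame(df["user_info"].tolist(), index=df.index)
print(pd.concat([df, expanded], axis=1))  # name/age columns appear next to user_info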
346
+ # -------------------------------------------------------------------------
347
+ # 6. Hash Columns
348
+ # -------------------------------------------------------------------------
349
+
350
+
351
+ class HashAlgorithm(str, Enum):
352
+ SHA256 = "sha256"
353
+ MD5 = "md5"
354
+
355
+
356
+ class HashParams(BaseModel):
357
+ """
358
+ Configuration for column hashing.
359
+
360
+ Example:
361
+ ```yaml
362
+ hash_columns:
363
+ columns: ["email", "ssn"]
364
+ algorithm: "sha256"
365
+ ```
366
+ """
367
+
368
+ columns: List[str] = Field(..., description="List of columns to hash")
369
+ algorithm: HashAlgorithm = Field(
370
+ HashAlgorithm.SHA256, description="Hashing algorithm. Options: 'sha256', 'md5'"
371
+ )
372
+
373
+
374
+ def hash_columns(context: EngineContext, params: HashParams) -> EngineContext:
375
+ """
376
+ Hashes columns for PII/Anonymization.
377
+ """
378
+ # Hash SQL syntax differs across engines, so hash natively per engine.
381
+ if context.engine_type == EngineType.SPARK:
382
+ import pyspark.sql.functions as F
383
+
384
+ df = context.df
385
+ for col in params.columns:
386
+ if params.algorithm == HashAlgorithm.SHA256:
387
+ df = df.withColumn(col, F.sha2(F.col(col), 256))
388
+ elif params.algorithm == HashAlgorithm.MD5:
389
+ df = df.withColumn(col, F.md5(F.col(col)))
390
+ return context.with_df(df)
391
+
392
+ elif context.engine_type == EngineType.PANDAS:
393
+ df = context.df.copy()
394
+
395
+ # A vectorized hash (e.g. via PyArrow compute) could speed this up later;
+ # for now we use the stable hashlib fallback below.
400
+
401
+ import hashlib
402
+
403
+ def hash_val(val, alg):
404
+ if val is None:
405
+ return None
406
+ encoded = str(val).encode("utf-8")
407
+ if alg == HashAlgorithm.SHA256:
408
+ return hashlib.sha256(encoded).hexdigest()
409
+ return hashlib.md5(encoded).hexdigest()
410
+
411
+ # hashlib cannot be vectorized with pandas alone; .apply() is acceptable here
+ # versus pulling in heavier numeric dependencies.
413
+ for col in params.columns:
414
+ # hash_val stringifies each value itself and preserves None
+ df[col] = df[col].apply(lambda x: hash_val(x, params.algorithm))
417
+
418
+ return context.with_df(df)
419
+
420
+ else:
421
+ raise ValueError(f"Unsupported engine: {context.engine_type}")
422
+
423
+
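# A minimal standalone sketch (not part of the package) of the hashlib path the
# pandas branch uses for a single value.
import hashlib

email = "jane.doe@example.com"
print(hashlib.sha256(email.encode("utf-8")).hexdigest())  # 64-char hex digest
print(hashlib.md5(email.encode("utf-8")).hexdigest())     # 32-char hex digest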
424
+ # -------------------------------------------------------------------------
425
+ # 7. Generate Surrogate Key
426
+ # -------------------------------------------------------------------------
427
+
428
+
429
+ class SurrogateKeyParams(BaseModel):
430
+ """
431
+ Configuration for surrogate key generation.
432
+
433
+ Example:
434
+ ```yaml
435
+ generate_surrogate_key:
436
+ columns: ["region", "product_id"]
437
+ separator: "-"
438
+ output_col: "unique_id"
439
+ ```
440
+ """
441
+
442
+ columns: List[str] = Field(..., description="Columns to combine for the key")
443
+ separator: str = Field("-", description="Separator between values")
444
+ output_col: str = Field("surrogate_key", description="Name of the output column")
445
+
446
+
447
+ def generate_surrogate_key(context: EngineContext, params: SurrogateKeyParams) -> EngineContext:
448
+ """
449
+ Generates a deterministic surrogate key (MD5) from a combination of columns.
450
+ Handles NULLs by treating them as empty strings to ensure consistency.
451
+ """
452
+ # Logic: MD5( CONCAT_WS( separator, COALESCE(col1, ''), COALESCE(col2, '') ... ) )
453
+
454
+ from odibi.enums import EngineType
455
+
456
+ # 1. Build the concatenation expression
457
+ # We must cast to string and coalesce nulls
458
+
459
+ def safe_col(col, quote_char):
460
+ # Cast to string and coalesce NULLs to '' so the key is deterministic.
+ # Spark quotes identifiers with backticks, DuckDB with double quotes;
+ # CAST(x AS STRING) is valid in both engines.
463
+ return f"COALESCE(CAST({quote_char}{col}{quote_char} AS STRING), '')"
464
+
465
+ if context.engine_type == EngineType.SPARK:
466
+ # CONCAT_WS skips NULLs, but the COALESCE above already coerces them to empty
+ # strings. dbt-style surrogate keys often use a sentinel token for NULL; the
+ # simpler empty-string convention is used here.
469
+ quote_char = "`"
470
+ cols_expr = ", ".join([safe_col(c, quote_char) for c in params.columns])
471
+ concat_expr = f"concat_ws('{params.separator}', {cols_expr})"
472
+ final_expr = f"md5({concat_expr})"
473
+ output_col = f"`{params.output_col}`"
474
+
475
+ else:
476
+ # DuckDB / Pandas
477
+ # DuckDB also supports concat_ws and md5.
478
+ # Note: DuckDB CAST AS STRING is valid.
479
+ quote_char = '"'
480
+ cols_expr = ", ".join([safe_col(c, quote_char) for c in params.columns])
481
+ concat_expr = f"concat_ws('{params.separator}', {cols_expr})"
482
+ final_expr = f"md5({concat_expr})"
483
+ output_col = f'"{params.output_col}"'
484
+
485
+ sql_query = f"SELECT *, {final_expr} AS {output_col} FROM df"
486
+ return context.sql(sql_query)
487
+
488
+
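# A minimal standalone sketch (not part of the package) of the same key computed
# in plain Python for one row: MD5(CONCAT_WS(sep, COALESCE(CAST(col AS STRING), ''))).
import hashlib

def surrogate_key_py(values, separator="-"):
    # NULLs become empty strings; everything else is stringified, joined, hashed
    joined = separator.join("" if v is None else str(v) for v in values)
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

print(surrogate_key_py(["EMEA", 42]))    # deterministic: same inputs -> same key
print(surrogate_key_py(["EMEA", None]))  # NULL handled as ''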
489
+ # -------------------------------------------------------------------------
490
+ # 7b. Generate Numeric Key (BIGINT surrogate key)
491
+ # -------------------------------------------------------------------------
492
+
493
+
494
+ class NumericKeyParams(BaseModel):
495
+ """
496
+ Configuration for numeric surrogate key generation.
497
+
498
+ Generates a deterministic BIGINT key from a hash of specified columns.
499
+ Useful when unioning data from multiple sources where some have IDs
500
+ and others don't.
501
+
502
+ Example:
503
+ ```yaml
504
+ - function: generate_numeric_key
505
+ params:
506
+ columns: [DateID, store_id, reason_id, duration_min, notes]
507
+ output_col: ID
508
+ coalesce_with: ID # Keep existing ID if not null
509
+ ```
510
+
511
+ The generated key is:
512
+ - Deterministic: same input data = same ID every time
513
+ - BIGINT: large numeric space to avoid collisions
514
+ - Stable: safe for gold layer / incremental loads
515
+ """
516
+
517
+ columns: List[str] = Field(..., description="Columns to combine for the key")
518
+ separator: str = Field("|", description="Separator between values")
519
+ output_col: str = Field("numeric_key", description="Name of the output column")
520
+ coalesce_with: Optional[str] = Field(
521
+ None,
522
+ description="Existing column to coalesce with (keep existing value if not null)",
523
+ )
524
+
525
+
526
+ def generate_numeric_key(context: EngineContext, params: NumericKeyParams) -> EngineContext:
527
+ """
528
+ Generates a deterministic BIGINT surrogate key from a hash of columns.
529
+
530
+ This is useful when:
531
+ - Unioning data from multiple sources
532
+ - Some sources have IDs, some don't
533
+ - You need stable numeric IDs for gold layer
534
+
535
+ The key is generated by:
536
+ 1. Concatenating columns with separator
537
+ 2. Computing MD5 hash
538
+ 3. Converting first 15 hex chars to BIGINT
539
+
540
+ If coalesce_with is specified, keeps the existing value when not null.
541
+ If output_col == coalesce_with, the original column is replaced.
542
+ """
543
+ from odibi.enums import EngineType
544
+
545
+ def safe_col(col, quote_char):
546
+ # Normalize: TRIM whitespace, then treat empty string and NULL as equivalent
547
+ return f"COALESCE(NULLIF(TRIM(CAST({quote_char}{col}{quote_char} AS STRING)), ''), '')"
548
+
549
+ # Check if we need to replace the original column
550
+ # Replace if: coalesce_with == output_col, OR output_col already exists in dataframe
551
+ col_names = list(context.df.columns)
552
+ output_exists = params.output_col in col_names
553
+ replace_column = (
554
+ params.coalesce_with and params.output_col == params.coalesce_with
555
+ ) or output_exists
556
+
557
+ if context.engine_type == EngineType.SPARK:
558
+ quote_char = "`"
559
+ cols_expr = ", ".join([safe_col(c, quote_char) for c in params.columns])
560
+ concat_expr = f"concat_ws('{params.separator}', {cols_expr})"
561
+ hash_expr = f"CAST(CONV(SUBSTRING(md5({concat_expr}), 1, 15), 16, 10) AS BIGINT)"
562
+
563
+ if params.coalesce_with:
564
+ final_expr = f"COALESCE(CAST(`{params.coalesce_with}` AS BIGINT), {hash_expr})"
565
+ else:
566
+ final_expr = hash_expr
567
+
568
+ output_col = f"`{params.output_col}`"
569
+
570
+ else:
571
+ # DuckDB/Pandas - use ABS(HASH()) instead of CONV (DuckDB doesn't have CONV)
572
+ quote_char = '"'
573
+ cols_expr = ", ".join([safe_col(c, quote_char) for c in params.columns])
574
+ concat_expr = f"concat_ws('{params.separator}', {cols_expr})"
575
+ # hash() in DuckDB returns BIGINT directly
576
+ hash_expr = f"ABS(hash({concat_expr}))"
577
+
578
+ if params.coalesce_with:
579
+ final_expr = f'COALESCE(CAST("{params.coalesce_with}" AS BIGINT), {hash_expr})'
580
+ else:
581
+ final_expr = hash_expr
582
+
583
+ output_col = f'"{params.output_col}"'
584
+
585
+ if replace_column:
586
+ # Replace the original column by selecting all columns except the original,
587
+ # then adding the new computed column
588
+ col_to_exclude = params.coalesce_with if params.coalesce_with else params.output_col
589
+ if context.engine_type == EngineType.SPARK:
590
+ all_cols = [f"`{c}`" for c in col_names if c != col_to_exclude]
591
+ else:
592
+ # Pandas/DuckDB
593
+ all_cols = [f'"{c}"' for c in col_names if c != col_to_exclude]
594
+ cols_select = ", ".join(all_cols)
595
+ sql_query = f"SELECT {cols_select}, {final_expr} AS {output_col} FROM df"
596
+ else:
597
+ sql_query = f"SELECT *, {final_expr} AS {output_col} FROM df"
598
+
599
+ return context.sql(sql_query)
600
+
601
+
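# A minimal standalone sketch (not part of the package) of the Spark-side key:
# CAST(CONV(SUBSTRING(md5(concat_ws('|', ...)), 1, 15), 16, 10) AS BIGINT).
# The DuckDB/Pandas branch uses ABS(hash(...)) instead, so the two engines do not
# produce identical numbers for the same input.
import hashlib

def numeric_key_py(values, separator="|"):
    # mirror the SQL normalization: trim, treat NULL/empty as ''
    normalized = ["" if v is None else str(v).strip() for v in values]
    digest = hashlib.md5(separator.join(normalized).encode("utf-8")).hexdigest()
    return int(digest[:15], 16)  # 15 hex chars (60 bits) always fit in a signed BIGINT

print(numeric_key_py([20240101, "STORE-7", 3, 45, "unplanned stop"]))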
602
+ # -------------------------------------------------------------------------
603
+ # 8. Parse JSON
604
+ # -------------------------------------------------------------------------
605
+
606
+
607
+ class ParseJsonParams(BaseModel):
608
+ """
609
+ Configuration for JSON parsing.
610
+
611
+ Example:
612
+ ```yaml
613
+ parse_json:
614
+ column: "raw_json"
615
+ json_schema: "id INT, name STRING"
616
+ output_col: "parsed_struct"
617
+ ```
618
+ """
619
+
620
+ column: str = Field(..., description="String column containing JSON")
621
+ json_schema: str = Field(
622
+ ..., description="DDL schema string (e.g. 'a INT, b STRING') or Spark StructType DDL"
623
+ )
624
+ output_col: Optional[str] = None
625
+
626
+
627
+ def parse_json(context: EngineContext, params: ParseJsonParams) -> EngineContext:
628
+ """
629
+ Parses a JSON string column into a Struct/Map column.
630
+ """
631
+ from odibi.enums import EngineType
632
+
633
+ target = params.output_col or f"{params.column}_parsed"
634
+
635
+ if context.engine_type == EngineType.SPARK:
636
+ # Spark: from_json(col, schema)
637
+ expr = f"from_json({params.column}, '{params.json_schema}')"
638
+
639
+ else:
640
+ # DuckDB / Pandas: Spark's DDL schema string ("a INT, b STRING") cannot be
+ # translated into a DuckDB STRUCT type without parsing the DDL, so the
+ # json_schema is not enforced on this engine. We cast the string to DuckDB's
+ # JSON type (CAST is more stable than json_parse across versions); callers can
+ # extract or unpack fields from it downstream.
659
+
660
+ expr = f"CAST({params.column} AS JSON)"
661
+
662
+ sql_query = f"SELECT *, {expr} AS {target} FROM df"
663
+ return context.sql(sql_query)
664
+
665
+
666
+ # -------------------------------------------------------------------------
667
+ # 9. Validate and Flag
668
+ # -------------------------------------------------------------------------
669
+
670
+
671
+ class ValidateAndFlagParams(BaseModel):
672
+ """
673
+ Configuration for validation flagging.
674
+
675
+ Example:
676
+ ```yaml
677
+ validate_and_flag:
678
+ flag_col: "data_issues"
679
+ rules:
680
+ age_check: "age >= 0"
681
+ email_format: "email LIKE '%@%'"
682
+ ```
683
+ """
684
+
685
+ # key: rule name, value: sql condition (must be true for valid)
686
+ rules: Dict[str, str] = Field(
687
+ ..., description="Map of rule name to SQL condition (must be TRUE)"
688
+ )
689
+ flag_col: str = Field("_issues", description="Name of the column to store failed rules")
690
+
691
+ @field_validator("rules")
692
+ @classmethod
693
+ def require_non_empty_rules(cls, v):
694
+ if not v:
695
+ raise ValueError("ValidateAndFlag: 'rules' must not be empty")
696
+ return v
697
+
698
+
699
+ def validate_and_flag(context: EngineContext, params: ValidateAndFlagParams) -> EngineContext:
700
+ """
701
+ Validates rules and appends a column with a list/string of failed rule names.
702
+ """
703
+ ctx = get_logging_context()
704
+ start_time = time.time()
705
+
706
+ ctx.debug(
707
+ "ValidateAndFlag starting",
708
+ rules=list(params.rules.keys()),
709
+ flag_col=params.flag_col,
710
+ )
711
+
712
+ rule_exprs = []
713
+
714
+ for name, condition in params.rules.items():
715
+ expr = f"CASE WHEN NOT ({condition}) THEN '{name}' ELSE NULL END"
716
+ rule_exprs.append(expr)
717
+
718
+ if not rule_exprs:
719
+ return context.sql(f"SELECT *, NULL AS {params.flag_col} FROM df")
720
+
721
+ concatted = f"concat_ws(', ', {', '.join(rule_exprs)})"
722
+ final_expr = f"NULLIF({concatted}, '')"
723
+
724
+ sql_query = f"SELECT *, {final_expr} AS {params.flag_col} FROM df"
725
+ result = context.sql(sql_query)
726
+
727
+ elapsed_ms = (time.time() - start_time) * 1000
728
+ ctx.debug(
729
+ "ValidateAndFlag completed",
730
+ rules_count=len(params.rules),
731
+ elapsed_ms=round(elapsed_ms, 2),
732
+ )
733
+
734
+ return result
735
+
736
+
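# A minimal standalone sketch (not part of the package) of the flag expression
# built for the docstring rules, run directly against DuckDB here.
import duckdb
import pandas as pd

df = pd.DataFrame({"age": [25, -1], "email": ["a@b.com", "nope"]})
print(
    duckdb.sql(
        """
        SELECT *,
               NULLIF(concat_ws(', ',
                   CASE WHEN NOT (age >= 0) THEN 'age_check' ELSE NULL END,
                   CASE WHEN NOT (email LIKE '%@%') THEN 'email_format' ELSE NULL END
               ), '') AS data_issues
        FROM df
        """
    ).df()
)  # row 1 -> NULL, row 2 -> 'age_check, email_format'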
737
+ # -------------------------------------------------------------------------
738
+ # 10. Window Calculation
739
+ # -------------------------------------------------------------------------
740
+
741
+
742
+ class WindowCalculationParams(BaseModel):
743
+ """
744
+ Configuration for window functions.
745
+
746
+ Example:
747
+ ```yaml
748
+ window_calculation:
749
+ target_col: "cumulative_sales"
750
+ function: "sum(sales)"
751
+ partition_by: ["region"]
752
+ order_by: "date ASC"
753
+ ```
754
+ """
755
+
756
+ target_col: str
757
+ function: str = Field(..., description="Window function e.g. 'sum(amount)', 'rank()'")
758
+ partition_by: List[str] = Field(default_factory=list)
759
+ order_by: Optional[str] = None
760
+
761
+
762
+ def window_calculation(context: EngineContext, params: WindowCalculationParams) -> EngineContext:
763
+ """
764
+ Generic wrapper for Window functions.
765
+ """
766
+ partition_clause = ""
767
+ if params.partition_by:
768
+ partition_clause = f"PARTITION BY {', '.join(params.partition_by)}"
769
+
770
+ order_clause = ""
771
+ if params.order_by:
772
+ order_clause = f"ORDER BY {params.order_by}"
773
+
774
+ over_clause = f"OVER ({partition_clause} {order_clause})".strip()
775
+
776
+ expr = f"{params.function} {over_clause}"
777
+
778
+ sql_query = f"SELECT *, {expr} AS {params.target_col} FROM df"
779
+ return context.sql(sql_query)
780
+
781
+
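# A minimal standalone sketch (not part of the package) of the SQL emitted for
# the docstring example, run directly against DuckDB here.
import duckdb
import pandas as pd

df = pd.DataFrame(
    {"region": ["EU", "EU"], "date": ["2024-01-01", "2024-01-02"], "sales": [10, 5]}
)
print(
    duckdb.sql(
        "SELECT *, sum(sales) OVER (PARTITION BY region ORDER BY date ASC)"
        " AS cumulative_sales FROM df"
    ).df()
)  # cumulative_sales: 10, then 15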
782
+ # -------------------------------------------------------------------------
783
+ # 11. Normalize JSON
784
+ # -------------------------------------------------------------------------
785
+
786
+
787
+ class NormalizeJsonParams(BaseModel):
788
+ """
789
+ Configuration for JSON normalization.
790
+
791
+ Example:
792
+ ```yaml
793
+ normalize_json:
794
+ column: "json_data"
795
+ sep: "_"
796
+ ```
797
+ """
798
+
799
+ column: str = Field(..., description="Column containing nested JSON/Struct")
800
+ sep: str = Field("_", description="Separator for nested fields (e.g., 'parent_child')")
801
+
802
+
803
+ def normalize_json(context: EngineContext, params: NormalizeJsonParams) -> EngineContext:
804
+ """
805
+ Flattens a nested JSON/Struct column.
806
+ """
807
+ if context.engine_type == EngineType.SPARK:
808
+ # Spark: Top-level flatten using "col.*"
809
+ sql_query = f"SELECT *, {params.column}.* FROM df"
810
+ return context.sql(sql_query)
811
+
812
+ elif context.engine_type == EngineType.PANDAS:
813
+ import json
814
+
815
+ import pandas as pd
816
+
817
+ df = context.df.copy()
818
+
819
+ # Ensure we have dicts
820
+ s = df[params.column]
821
+ if len(s) > 0:
822
+ first_val = s.iloc[0]
823
+ if isinstance(first_val, str):
824
+ # Try to parse if string
825
+ try:
826
+ s = s.apply(json.loads)
827
+ except Exception as e:
828
+ import logging
829
+
830
+ logger = logging.getLogger(__name__)
831
+ logger.warning(f"Failed to parse JSON strings in column '{params.column}': {e}")
832
+ # We proceed, but json_normalize will likely fail if data is not dicts.
833
+
834
+ # json_normalize
835
+ # Handle empty case
836
+ if s.empty:
837
+ return context.with_df(df)
838
+
839
+ normalized = pd.json_normalize(s, sep=params.sep)
840
+ # Align index
841
+ normalized.index = df.index
842
+
843
+ # Join back (avoid collision if possible, or use suffixes)
844
+ # We use rsuffix just in case
845
+ df = df.join(normalized, rsuffix="_json")
846
+ return context.with_df(df)
847
+
848
+ else:
849
+ raise ValueError(f"Unsupported engine: {context.engine_type}")
850
+
851
+
852
+ # -------------------------------------------------------------------------
853
+ # 12. Sessionize
854
+ # -------------------------------------------------------------------------
855
+
856
+
857
+ class SessionizeParams(BaseModel):
858
+ """
859
+ Configuration for sessionization.
860
+
861
+ Example:
862
+ ```yaml
863
+ sessionize:
864
+ timestamp_col: "event_time"
865
+ user_col: "user_id"
866
+ threshold_seconds: 1800
867
+ ```
868
+ """
869
+
870
+ timestamp_col: str = Field(
871
+ ..., description="Timestamp column to calculate session duration from"
872
+ )
873
+ user_col: str = Field(..., description="User identifier to partition sessions by")
874
+ threshold_seconds: int = Field(
875
+ 1800,
876
+ description="Inactivity threshold in seconds (default: 30 minutes). If gap > threshold, new session starts.",
877
+ )
878
+ session_col: str = Field(
879
+ "session_id", description="Output column name for the generated session ID"
880
+ )
881
+
882
+
883
+ def sessionize(context: EngineContext, params: SessionizeParams) -> EngineContext:
884
+ """
885
+ Assigns session IDs based on inactivity threshold.
886
+ """
887
+ if context.engine_type == EngineType.SPARK:
888
+ # Spark SQL
889
+ # 1. Lag timestamp to get prev_timestamp
890
+ # 2. Calculate diff: ts - prev_ts
891
+ # 3. Flag new session: if diff > threshold OR prev_ts is null -> 1 else 0
892
+ # 4. Sum(flags) over (partition by user order by ts) -> session_id
893
+
894
+ threshold = params.threshold_seconds
895
+
896
+ # We use nested queries for clarity and safety against multiple aggregations
897
+ sql = f"""
898
+ WITH lagged AS (
899
+ SELECT *,
900
+ LAG({params.timestamp_col}) OVER (PARTITION BY {params.user_col} ORDER BY {params.timestamp_col}) as _prev_ts
901
+ FROM df
902
+ ),
903
+ flagged AS (
904
+ SELECT *,
905
+ CASE
906
+ WHEN _prev_ts IS NULL THEN 1
907
+ WHEN (unix_timestamp({params.timestamp_col}) - unix_timestamp(_prev_ts)) > {threshold} THEN 1
908
+ ELSE 0
909
+ END as _is_new_session
910
+ FROM lagged
911
+ )
912
+ SELECT *,
913
+ concat({params.user_col}, '-', sum(_is_new_session) OVER (PARTITION BY {params.user_col} ORDER BY {params.timestamp_col})) as {params.session_col}
914
+ FROM flagged
915
+ """
916
+ # Note: the result keeps the intermediate columns (_prev_ts, _is_new_session).
+ # SELECT * EXCEPT(...) is not available on older Spark versions, so they are
+ # left in place; callers can drop them downstream.
919
+ return context.sql(sql)
920
+
921
+ elif context.engine_type == EngineType.PANDAS:
922
+ import pandas as pd
923
+
924
+ df = context.df.copy()
925
+
926
+ # Ensure datetime
927
+ if not pd.api.types.is_datetime64_any_dtype(df[params.timestamp_col]):
928
+ df[params.timestamp_col] = pd.to_datetime(df[params.timestamp_col])
929
+
930
+ # Sort
931
+ df = df.sort_values([params.user_col, params.timestamp_col])
932
+
933
+ user = df[params.user_col]
934
+
935
+ # Time gap (seconds) to the previous event per user; groupby().diff() keeps
+ # the diff from crossing user boundaries.
939
+ time_diff = df.groupby(params.user_col)[params.timestamp_col].diff().dt.total_seconds()
940
+
941
+ # Flag new session
942
+ # New if: time_diff > threshold OR time_diff is NaT (start of group)
943
+ is_new = (time_diff > params.threshold_seconds) | (time_diff.isna())
944
+
945
+ # Cumulative sum per user
946
+ session_ids = is_new.groupby(user).cumsum()
947
+
948
+ df[params.session_col] = user.astype(str) + "-" + session_ids.astype(int).astype(str)
949
+
950
+ return context.with_df(df)
951
+
952
+ else:
953
+ raise ValueError(f"Unsupported engine: {context.engine_type}")
954
+
955
+
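# A minimal standalone sketch (not part of the package) of the session logic on
# a tiny frame, mirroring the pandas branch above with the default 30-minute gap.
import pandas as pd

events = pd.DataFrame(
    {
        "user_id": ["u1", "u1", "u1"],
        "event_time": pd.to_datetime(
            ["2024-01-01 09:00", "2024-01-01 09:10", "2024-01-01 11:00"]
        ),
    }
)
gap = events.groupby("user_id")["event_time"].diff().dt.total_seconds()
is_new = (gap > 1800) | gap.isna()
session_no = is_new.astype(int).groupby(events["user_id"]).cumsum()
events["session_id"] = events["user_id"] + "-" + session_no.astype(str)
print(events)  # 09:00 and 09:10 share session u1-1; the 11:00 event starts u1-2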
956
+ # -------------------------------------------------------------------------
957
+ # 13. Geocode (Stub)
958
+ # -------------------------------------------------------------------------
959
+
960
+
961
+ class GeocodeParams(BaseModel):
962
+ """
963
+ Configuration for geocoding.
964
+
965
+ Example:
966
+ ```yaml
967
+ geocode:
968
+ address_col: "full_address"
969
+ output_col: "lat_long"
970
+ ```
971
+ """
972
+
973
+ address_col: str = Field(..., description="Column containing the address to geocode")
974
+ output_col: str = Field("lat_long", description="Name of the output column for coordinates")
975
+
976
+
977
+ def geocode(context: EngineContext, params: GeocodeParams) -> EngineContext:
978
+ """
979
+ Geocoding Stub.
980
+ """
981
+ import logging
982
+
983
+ logger = logging.getLogger(__name__)
984
+ logger.warning("Geocode transformer is a stub. No actual geocoding performed.")
985
+
986
+ # Pass-through
987
+ return context.with_df(context.df)
988
+
989
+
990
+ # -------------------------------------------------------------------------
991
+ # 14. Split Events by Period
992
+ # -------------------------------------------------------------------------
993
+
994
+
995
+ class ShiftDefinition(BaseModel):
996
+ """Definition of a single shift."""
997
+
998
+ name: str = Field(..., description="Name of the shift (e.g., 'Day', 'Night')")
999
+ start: str = Field(..., description="Start time in HH:MM format (e.g., '06:00')")
1000
+ end: str = Field(..., description="End time in HH:MM format (e.g., '14:00')")
1001
+
1002
+
1003
+ class SplitEventsByPeriodParams(BaseModel):
1004
+ """
1005
+ Configuration for splitting events that span multiple time periods.
1006
+
1007
+ Splits events that span multiple days, hours, or shifts into individual
1008
+ segments per period. Useful for OEE/downtime analysis, billing, and
1009
+ time-based aggregations.
1010
+
1011
+ Example - Split by day:
1012
+ ```yaml
1013
+ split_events_by_period:
1014
+ start_col: "Shutdown_Start_Time"
1015
+ end_col: "Shutdown_End_Time"
1016
+ period: "day"
1017
+ duration_col: "Shutdown_Duration_Min"
1018
+ ```
1019
+
1020
+ Example - Split by shift:
1021
+ ```yaml
1022
+ split_events_by_period:
1023
+ start_col: "event_start"
1024
+ end_col: "event_end"
1025
+ period: "shift"
1026
+ duration_col: "duration_minutes"
1027
+ shifts:
1028
+ - name: "Day"
1029
+ start: "06:00"
1030
+ end: "14:00"
1031
+ - name: "Swing"
1032
+ start: "14:00"
1033
+ end: "22:00"
1034
+ - name: "Night"
1035
+ start: "22:00"
1036
+ end: "06:00"
1037
+ ```
1038
+ """
1039
+
1040
+ start_col: str = Field(..., description="Column containing the event start timestamp")
1041
+ end_col: str = Field(..., description="Column containing the event end timestamp")
1042
+ period: str = Field(
1043
+ "day",
1044
+ description="Period type to split by: 'day', 'hour', or 'shift'",
1045
+ )
1046
+ duration_col: Optional[str] = Field(
1047
+ None,
1048
+ description="Output column name for duration in minutes. If not set, no duration column is added.",
1049
+ )
1050
+ shifts: Optional[List[ShiftDefinition]] = Field(
1051
+ None,
1052
+ description="List of shift definitions (required when period='shift')",
1053
+ )
1054
+ shift_col: Optional[str] = Field(
1055
+ "shift_name",
1056
+ description="Output column name for shift name (only used when period='shift')",
1057
+ )
1058
+
1059
+ def model_post_init(self, __context):
1060
+ if self.period == "shift" and not self.shifts:
1061
+ raise ValueError("shifts must be provided when period='shift'")
1062
+ if self.period not in ("day", "hour", "shift"):
1063
+ raise ValueError(f"Invalid period: {self.period}. Must be 'day', 'hour', or 'shift'")
1064
+
1065
+
1066
+ def split_events_by_period(
1067
+ context: EngineContext, params: SplitEventsByPeriodParams
1068
+ ) -> EngineContext:
1069
+ """
1070
+ Splits events that span multiple time periods into individual segments.
1071
+
1072
+ For events spanning multiple days/hours/shifts, this creates separate rows
1073
+ for each period with adjusted start/end times and recalculated durations.
1074
+ """
1075
+ if params.period == "day":
1076
+ return _split_by_day(context, params)
1077
+ elif params.period == "hour":
1078
+ return _split_by_hour(context, params)
1079
+ elif params.period == "shift":
1080
+ return _split_by_shift(context, params)
1081
+ else:
1082
+ raise ValueError(f"Unsupported period: {params.period}")
1083
+
1084
+
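# A minimal standalone sketch (not part of the package) of what splitting one
# event that crosses midnight into per-day segments looks like; the _split_by_*
# helpers below do this at DataFrame scale.
import pandas as pd

start, end = pd.Timestamp("2024-03-01 22:00"), pd.Timestamp("2024-03-02 03:00")
midnight = start.normalize() + pd.Timedelta(days=1)
for seg_start, seg_end in [(start, midnight), (midnight, end)]:
    minutes = (seg_end - seg_start).total_seconds() / 60.0
    print(seg_start, "->", seg_end, f"({minutes:.0f} min)")
# 2024-03-01 22:00:00 -> 2024-03-02 00:00:00 (120 min)
# 2024-03-02 00:00:00 -> 2024-03-02 03:00:00 (180 min)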
1085
+ def _split_by_day(context: EngineContext, params: SplitEventsByPeriodParams) -> EngineContext:
1086
+ """Split events by day boundaries."""
1087
+ start_col = params.start_col
1088
+ end_col = params.end_col
1089
+
1090
+ if context.engine_type == EngineType.SPARK:
1091
+ duration_col_exists = params.duration_col and params.duration_col.lower() in [
1092
+ c.lower() for c in context.df.columns
1093
+ ]
1094
+
1095
+ ts_start = f"to_timestamp({start_col})"
1096
+ ts_end = f"to_timestamp({end_col})"
1097
+
1098
+ duration_expr = ""
1099
+ if params.duration_col:
1100
+ duration_expr = f", (unix_timestamp(adj_{end_col}) - unix_timestamp(adj_{start_col})) / 60.0 AS {params.duration_col}"
1101
+
1102
+ single_day_except = "_event_days"
1103
+ if duration_col_exists:
1104
+ single_day_except = f"_event_days, {params.duration_col}"
1105
+
1106
+ single_day_sql = f"""
1107
+ WITH events_with_days AS (
1108
+ SELECT *,
1109
+ {ts_start} AS _ts_start,
1110
+ {ts_end} AS _ts_end,
1111
+ datediff(to_date({ts_end}), to_date({ts_start})) + 1 AS _event_days
1112
+ FROM df
1113
+ )
1114
+ SELECT * EXCEPT({single_day_except}, _ts_start, _ts_end){f", (unix_timestamp(_ts_end) - unix_timestamp(_ts_start)) / 60.0 AS {params.duration_col}" if params.duration_col else ""}
1115
+ FROM events_with_days
1116
+ WHERE _event_days = 1
1117
+ """
1118
+
1119
+ multi_day_except_adjusted = "_exploded_day, _event_days, _ts_start, _ts_end"
1120
+ if duration_col_exists:
1121
+ multi_day_except_adjusted = (
1122
+ f"_exploded_day, _event_days, _ts_start, _ts_end, {params.duration_col}"
1123
+ )
1124
+
1125
+ multi_day_sql = f"""
1126
+ WITH events_with_days AS (
1127
+ SELECT *,
1128
+ {ts_start} AS _ts_start,
1129
+ {ts_end} AS _ts_end,
1130
+ datediff(to_date({ts_end}), to_date({ts_start})) + 1 AS _event_days
1131
+ FROM df
1132
+ ),
1133
+ multi_day AS (
1134
+ SELECT *,
1135
+ explode(sequence(to_date(_ts_start), to_date(_ts_end), interval 1 day)) AS _exploded_day
1136
+ FROM events_with_days
1137
+ WHERE _event_days > 1
1138
+ ),
1139
+ multi_day_adjusted AS (
1140
+ SELECT * EXCEPT({multi_day_except_adjusted}),
1141
+ CASE
1142
+ WHEN to_date(_exploded_day) = to_date(_ts_start) THEN _ts_start
1143
+ ELSE to_timestamp(concat(cast(_exploded_day as string), ' 00:00:00'))
1144
+ END AS adj_{start_col},
1145
+ CASE
1146
+ WHEN to_date(_exploded_day) = to_date(_ts_end) THEN _ts_end
1147
+ ELSE to_timestamp(concat(cast(date_add(_exploded_day, 1) as string), ' 00:00:00'))
1148
+ END AS adj_{end_col}
1149
+ FROM multi_day
1150
+ )
1151
+ SELECT * EXCEPT({start_col}, {end_col}, adj_{start_col}, adj_{end_col}),
1152
+ adj_{start_col} AS {start_col},
1153
+ adj_{end_col} AS {end_col}
1154
+ {duration_expr}
1155
+ FROM multi_day_adjusted
1156
+ """
1157
+
1158
+ single_day_df = context.sql(single_day_sql).df
1159
+ multi_day_df = context.sql(multi_day_sql).df
1160
+ result_df = single_day_df.unionByName(multi_day_df, allowMissingColumns=True)
1161
+ return context.with_df(result_df)
1162
+
1163
+ elif context.engine_type == EngineType.PANDAS:
1164
+ import pandas as pd
1165
+
1166
+ df = context.df.copy()
1167
+
1168
+ df[start_col] = pd.to_datetime(df[start_col])
1169
+ df[end_col] = pd.to_datetime(df[end_col])
1170
+
1171
+ df["_event_days"] = (df[end_col].dt.normalize() - df[start_col].dt.normalize()).dt.days + 1
1172
+
1173
+ single_day = df[df["_event_days"] == 1].copy()
1174
+ multi_day = df[df["_event_days"] > 1].copy()
1175
+
1176
+ if params.duration_col and not single_day.empty:
1177
+ single_day[params.duration_col] = (
1178
+ single_day[end_col] - single_day[start_col]
1179
+ ).dt.total_seconds() / 60.0
1180
+
1181
+ if not multi_day.empty:
1182
+ rows = []
1183
+ for _, row in multi_day.iterrows():
1184
+ start = row[start_col]
1185
+ end = row[end_col]
1186
+ current_day = start.normalize()
1187
+
1188
+ while current_day <= end.normalize():
1189
+ new_row = row.copy()
1190
+
1191
+ if current_day == start.normalize():
1192
+ new_start = start
1193
+ else:
1194
+ new_start = current_day
1195
+
1196
+ next_day = current_day + pd.Timedelta(days=1)
1197
+ if next_day > end:
1198
+ new_end = end
1199
+ else:
1200
+ new_end = next_day
1201
+
1202
+ new_row[start_col] = new_start
1203
+ new_row[end_col] = new_end
1204
+
1205
+ if params.duration_col:
1206
+ new_row[params.duration_col] = (new_end - new_start).total_seconds() / 60.0
1207
+
1208
+ rows.append(new_row)
1209
+ current_day = next_day
1210
+
1211
+ multi_day_exploded = pd.DataFrame(rows)
1212
+ result = pd.concat([single_day, multi_day_exploded], ignore_index=True)
1213
+ else:
1214
+ result = single_day
1215
+
1216
+ result = result.drop(columns=["_event_days"], errors="ignore")
1217
+ return context.with_df(result)
1218
+
1219
+ else:
1220
+ raise ValueError(
1221
+ f"Split events by day does not support engine type '{context.engine_type}'. "
1222
+ f"Supported engines: SPARK, PANDAS. "
1223
+ f"Check your engine configuration."
1224
+ )
1225
+
1226
+
1227
+ def _split_by_hour(context: EngineContext, params: SplitEventsByPeriodParams) -> EngineContext:
1228
+ """Split events by hour boundaries."""
1229
+ start_col = params.start_col
1230
+ end_col = params.end_col
1231
+
1232
+ if context.engine_type == EngineType.SPARK:
1233
+ ts_start = f"to_timestamp({start_col})"
1234
+ ts_end = f"to_timestamp({end_col})"
1235
+
1236
+ duration_expr = ""
1237
+ if params.duration_col:
1238
+ duration_expr = f", (unix_timestamp({end_col}) - unix_timestamp({start_col})) / 60.0 AS {params.duration_col}"
1239
+
1240
+ sql = f"""
1241
+ WITH events_with_hours AS (
1242
+ SELECT *,
1243
+ {ts_start} AS _ts_start,
1244
+ {ts_end} AS _ts_end,
1245
+ CAST((unix_timestamp({ts_end}) - unix_timestamp({ts_start})) / 3600 AS INT) + 1 AS _event_hours
1246
+ FROM df
1247
+ ),
1248
+ multi_hour AS (
1249
+ SELECT *,
1250
+ explode(sequence(
1251
+ date_trunc('hour', _ts_start),
1252
+ date_trunc('hour', _ts_end),
1253
+ interval 1 hour
1254
+ )) AS _exploded_hour
1255
+ FROM events_with_hours
1256
+ WHERE _event_hours > 1
1257
+ ),
1258
+ multi_hour_adjusted AS (
1259
+ SELECT * EXCEPT(_exploded_hour, _event_hours, _ts_start, _ts_end),
1260
+ CASE
1261
+ WHEN date_trunc('hour', _ts_start) = _exploded_hour THEN _ts_start
1262
+ ELSE _exploded_hour
1263
+ END AS {start_col},
1264
+ CASE
1265
+ WHEN date_trunc('hour', _ts_end) = _exploded_hour THEN _ts_end
1266
+ ELSE _exploded_hour + interval 1 hour
1267
+ END AS {end_col}
1268
+ FROM multi_hour
1269
+ ),
1270
+ single_hour AS (
1271
+ SELECT * EXCEPT(_event_hours, _ts_start, _ts_end){duration_expr}
1272
+ FROM events_with_hours
1273
+ WHERE _event_hours = 1
1274
+ )
1275
+ SELECT *{duration_expr} FROM multi_hour_adjusted
1276
+ UNION ALL
1277
+ SELECT * FROM single_hour
1278
+ """
1279
+ return context.sql(sql)
1280
+
1281
+ elif context.engine_type == EngineType.PANDAS:
1282
+ import pandas as pd
1283
+
1284
+ df = context.df.copy()
1285
+
1286
+ df[start_col] = pd.to_datetime(df[start_col])
1287
+ df[end_col] = pd.to_datetime(df[end_col])
1288
+
1289
+ df["_event_hours"] = ((df[end_col] - df[start_col]).dt.total_seconds() / 3600).astype(
1290
+ int
1291
+ ) + 1
1292
+
1293
+ single_hour = df[df["_event_hours"] == 1].copy()
1294
+ multi_hour = df[df["_event_hours"] > 1].copy()
1295
+
1296
+ if params.duration_col and not single_hour.empty:
1297
+ single_hour[params.duration_col] = (
1298
+ single_hour[end_col] - single_hour[start_col]
1299
+ ).dt.total_seconds() / 60.0
1300
+
1301
+ if not multi_hour.empty:
1302
+ rows = []
1303
+ for _, row in multi_hour.iterrows():
1304
+ start = row[start_col]
1305
+ end = row[end_col]
1306
+ current_hour = start.floor("h")
1307
+
1308
+ while current_hour <= end.floor("h"):
1309
+ new_row = row.copy()
1310
+
1311
+ if current_hour == start.floor("h"):
1312
+ new_start = start
1313
+ else:
1314
+ new_start = current_hour
1315
+
1316
+ next_hour = current_hour + pd.Timedelta(hours=1)
1317
+ if next_hour > end:
1318
+ new_end = end
1319
+ else:
1320
+ new_end = next_hour
1321
+
1322
+ new_row[start_col] = new_start
1323
+ new_row[end_col] = new_end
1324
+
1325
+ if params.duration_col:
1326
+ new_row[params.duration_col] = (new_end - new_start).total_seconds() / 60.0
1327
+
1328
+ rows.append(new_row)
1329
+ current_hour = next_hour
1330
+
1331
+ multi_hour_exploded = pd.DataFrame(rows)
1332
+ result = pd.concat([single_hour, multi_hour_exploded], ignore_index=True)
1333
+ else:
1334
+ result = single_hour
1335
+
1336
+ result = result.drop(columns=["_event_hours"], errors="ignore")
1337
+ return context.with_df(result)
1338
+
1339
+ else:
1340
+ raise ValueError(
1341
+ f"Split events by hour does not support engine type '{context.engine_type}'. "
1342
+ f"Supported engines: SPARK, PANDAS. "
1343
+ f"Check your engine configuration."
1344
+ )
1345
+
1346
+
1347
+ def _split_by_shift(context: EngineContext, params: SplitEventsByPeriodParams) -> EngineContext:
1348
+ """Split events by shift boundaries."""
1349
+ start_col = params.start_col
1350
+ end_col = params.end_col
1351
+ shift_col = params.shift_col or "shift_name"
1352
+ shifts = params.shifts
1353
+
1354
+ if context.engine_type == EngineType.SPARK:
1355
+ ts_start = f"to_timestamp({start_col})"
1356
+ ts_end = f"to_timestamp({end_col})"
1357
+
1358
+ duration_expr = ""
1359
+ if params.duration_col:
1360
+ duration_expr = f", (unix_timestamp({end_col}) - unix_timestamp({start_col})) / 60.0 AS {params.duration_col}"
1361
+
1362
+ sql = f"""
1363
+ WITH base AS (
1364
+ SELECT *,
1365
+ {ts_start} AS _ts_start,
1366
+ {ts_end} AS _ts_end,
1367
+ datediff(to_date({ts_end}), to_date({ts_start})) + 1 AS _span_days
1368
+ FROM df
1369
+ ),
1370
+ day_exploded AS (
1371
+ SELECT *,
1372
+ explode(sequence(to_date(_ts_start), to_date(_ts_end), interval 1 day)) AS _day
1373
+ FROM base
1374
+ ),
1375
+ with_shift_times AS (
1376
+ SELECT *,
1377
+ {
1378
+ ", ".join(
1379
+ [
1380
+ f"to_timestamp(concat(cast(_day as string), ' {s.start}:00')) AS _shift_{i}_start, "
1381
+ + f"to_timestamp(concat(cast(date_add(_day, {1 if int(s.end.split(':')[0]) < int(s.start.split(':')[0]) else 0}) as string), ' {s.end}:00')) AS _shift_{i}_end"
1382
+ for i, s in enumerate(shifts)
1383
+ ]
1384
+ )
1385
+ }
1386
+ FROM day_exploded
1387
+ ),
1388
+ shift_segments AS (
1389
+ {
1390
+ " UNION ALL ".join(
1391
+ [
1392
+ f"SELECT * EXCEPT({', '.join([f'_shift_{j}_start, _shift_{j}_end' for j in range(len(shifts))])}, _ts_start, _ts_end), "
1393
+ f"GREATEST(_ts_start, _shift_{i}_start) AS {start_col}, "
1394
+ f"LEAST(_ts_end, _shift_{i}_end) AS {end_col}, "
1395
+ f"'{s.name}' AS {shift_col} "
1396
+ f"FROM with_shift_times "
1397
+ f"WHERE _ts_start < _shift_{i}_end AND _ts_end > _shift_{i}_start"
1398
+ for i, s in enumerate(shifts)
1399
+ ]
1400
+ )
1401
+ }
1402
+ )
1403
+ SELECT * EXCEPT(_span_days, _day){duration_expr}
1404
+ FROM shift_segments
1405
+ WHERE {start_col} < {end_col}
1406
+ """
1407
+ return context.sql(sql)
1408
+
1409
+ elif context.engine_type == EngineType.PANDAS:
1410
+ import pandas as pd
1411
+ from datetime import timedelta
1412
+
1413
+ df = context.df.copy()
1414
+ df[start_col] = pd.to_datetime(df[start_col])
1415
+ df[end_col] = pd.to_datetime(df[end_col])
1416
+
1417
+ def parse_time(t_str):
1418
+ h, m = map(int, t_str.split(":"))
1419
+ return timedelta(hours=h, minutes=m)
1420
+
1421
+ rows = []
1422
+ for _, row in df.iterrows():
1423
+ event_start = row[start_col]
1424
+ event_end = row[end_col]
1425
+
1426
+ current_day = event_start.normalize()
1427
+ while current_day <= event_end.normalize():
1428
+ for shift in shifts:
1429
+ shift_start_delta = parse_time(shift.start)
1430
+ shift_end_delta = parse_time(shift.end)
1431
+
1432
+ if shift_end_delta <= shift_start_delta:
1433
+ shift_start_dt = current_day + shift_start_delta
1434
+ shift_end_dt = current_day + timedelta(days=1) + shift_end_delta
1435
+ else:
1436
+ shift_start_dt = current_day + shift_start_delta
1437
+ shift_end_dt = current_day + shift_end_delta
1438
+
1439
+ seg_start = max(event_start, shift_start_dt)
1440
+ seg_end = min(event_end, shift_end_dt)
1441
+
1442
+ if seg_start < seg_end:
1443
+ new_row = row.copy()
1444
+ new_row[start_col] = seg_start
1445
+ new_row[end_col] = seg_end
1446
+ new_row[shift_col] = shift.name
1447
+
1448
+ if params.duration_col:
1449
+ new_row[params.duration_col] = (
1450
+ seg_end - seg_start
1451
+ ).total_seconds() / 60.0
1452
+
1453
+ rows.append(new_row)
1454
+
1455
+ current_day += timedelta(days=1)
1456
+
1457
+ if rows:
1458
+ result = pd.DataFrame(rows)
1459
+ else:
1460
+ result = df.copy()
1461
+ result[shift_col] = None
1462
+ if params.duration_col:
1463
+ result[params.duration_col] = 0.0
1464
+
1465
+ return context.with_df(result)
1466
+
1467
+ else:
1468
+ raise ValueError(
1469
+ f"Split events by shift does not support engine type '{context.engine_type}'. "
1470
+ f"Supported engines: SPARK, PANDAS. "
1471
+ f"Check your engine configuration."
1472
+ )