odibi-2.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. odibi/__init__.py +32 -0
  2. odibi/__main__.py +8 -0
  3. odibi/catalog.py +3011 -0
  4. odibi/cli/__init__.py +11 -0
  5. odibi/cli/__main__.py +6 -0
  6. odibi/cli/catalog.py +553 -0
  7. odibi/cli/deploy.py +69 -0
  8. odibi/cli/doctor.py +161 -0
  9. odibi/cli/export.py +66 -0
  10. odibi/cli/graph.py +150 -0
  11. odibi/cli/init_pipeline.py +242 -0
  12. odibi/cli/lineage.py +259 -0
  13. odibi/cli/main.py +215 -0
  14. odibi/cli/run.py +98 -0
  15. odibi/cli/schema.py +208 -0
  16. odibi/cli/secrets.py +232 -0
  17. odibi/cli/story.py +379 -0
  18. odibi/cli/system.py +132 -0
  19. odibi/cli/test.py +286 -0
  20. odibi/cli/ui.py +31 -0
  21. odibi/cli/validate.py +39 -0
  22. odibi/config.py +3541 -0
  23. odibi/connections/__init__.py +9 -0
  24. odibi/connections/azure_adls.py +499 -0
  25. odibi/connections/azure_sql.py +709 -0
  26. odibi/connections/base.py +28 -0
  27. odibi/connections/factory.py +322 -0
  28. odibi/connections/http.py +78 -0
  29. odibi/connections/local.py +119 -0
  30. odibi/connections/local_dbfs.py +61 -0
  31. odibi/constants.py +17 -0
  32. odibi/context.py +528 -0
  33. odibi/diagnostics/__init__.py +12 -0
  34. odibi/diagnostics/delta.py +520 -0
  35. odibi/diagnostics/diff.py +169 -0
  36. odibi/diagnostics/manager.py +171 -0
  37. odibi/engine/__init__.py +20 -0
  38. odibi/engine/base.py +334 -0
  39. odibi/engine/pandas_engine.py +2178 -0
  40. odibi/engine/polars_engine.py +1114 -0
  41. odibi/engine/registry.py +54 -0
  42. odibi/engine/spark_engine.py +2362 -0
  43. odibi/enums.py +7 -0
  44. odibi/exceptions.py +297 -0
  45. odibi/graph.py +426 -0
  46. odibi/introspect.py +1214 -0
  47. odibi/lineage.py +511 -0
  48. odibi/node.py +3341 -0
  49. odibi/orchestration/__init__.py +0 -0
  50. odibi/orchestration/airflow.py +90 -0
  51. odibi/orchestration/dagster.py +77 -0
  52. odibi/patterns/__init__.py +24 -0
  53. odibi/patterns/aggregation.py +599 -0
  54. odibi/patterns/base.py +94 -0
  55. odibi/patterns/date_dimension.py +423 -0
  56. odibi/patterns/dimension.py +696 -0
  57. odibi/patterns/fact.py +748 -0
  58. odibi/patterns/merge.py +128 -0
  59. odibi/patterns/scd2.py +148 -0
  60. odibi/pipeline.py +2382 -0
  61. odibi/plugins.py +80 -0
  62. odibi/project.py +581 -0
  63. odibi/references.py +151 -0
  64. odibi/registry.py +246 -0
  65. odibi/semantics/__init__.py +71 -0
  66. odibi/semantics/materialize.py +392 -0
  67. odibi/semantics/metrics.py +361 -0
  68. odibi/semantics/query.py +743 -0
  69. odibi/semantics/runner.py +430 -0
  70. odibi/semantics/story.py +507 -0
  71. odibi/semantics/views.py +432 -0
  72. odibi/state/__init__.py +1203 -0
  73. odibi/story/__init__.py +55 -0
  74. odibi/story/doc_story.py +554 -0
  75. odibi/story/generator.py +1431 -0
  76. odibi/story/lineage.py +1043 -0
  77. odibi/story/lineage_utils.py +324 -0
  78. odibi/story/metadata.py +608 -0
  79. odibi/story/renderers.py +453 -0
  80. odibi/story/templates/run_story.html +2520 -0
  81. odibi/story/themes.py +216 -0
  82. odibi/testing/__init__.py +13 -0
  83. odibi/testing/assertions.py +75 -0
  84. odibi/testing/fixtures.py +85 -0
  85. odibi/testing/source_pool.py +277 -0
  86. odibi/transformers/__init__.py +122 -0
  87. odibi/transformers/advanced.py +1472 -0
  88. odibi/transformers/delete_detection.py +610 -0
  89. odibi/transformers/manufacturing.py +1029 -0
  90. odibi/transformers/merge_transformer.py +778 -0
  91. odibi/transformers/relational.py +675 -0
  92. odibi/transformers/scd.py +579 -0
  93. odibi/transformers/sql_core.py +1356 -0
  94. odibi/transformers/validation.py +165 -0
  95. odibi/ui/__init__.py +0 -0
  96. odibi/ui/app.py +195 -0
  97. odibi/utils/__init__.py +66 -0
  98. odibi/utils/alerting.py +667 -0
  99. odibi/utils/config_loader.py +343 -0
  100. odibi/utils/console.py +231 -0
  101. odibi/utils/content_hash.py +202 -0
  102. odibi/utils/duration.py +43 -0
  103. odibi/utils/encoding.py +102 -0
  104. odibi/utils/extensions.py +28 -0
  105. odibi/utils/hashing.py +61 -0
  106. odibi/utils/logging.py +203 -0
  107. odibi/utils/logging_context.py +740 -0
  108. odibi/utils/progress.py +429 -0
  109. odibi/utils/setup_helpers.py +302 -0
  110. odibi/utils/telemetry.py +140 -0
  111. odibi/validation/__init__.py +62 -0
  112. odibi/validation/engine.py +765 -0
  113. odibi/validation/explanation_linter.py +155 -0
  114. odibi/validation/fk.py +547 -0
  115. odibi/validation/gate.py +252 -0
  116. odibi/validation/quarantine.py +605 -0
  117. odibi/writers/__init__.py +15 -0
  118. odibi/writers/sql_server_writer.py +2081 -0
  119. odibi-2.5.0.dist-info/METADATA +255 -0
  120. odibi-2.5.0.dist-info/RECORD +124 -0
  121. odibi-2.5.0.dist-info/WHEEL +5 -0
  122. odibi-2.5.0.dist-info/entry_points.txt +2 -0
  123. odibi-2.5.0.dist-info/licenses/LICENSE +190 -0
  124. odibi-2.5.0.dist-info/top_level.txt +1 -0
odibi/patterns/dimension.py
@@ -0,0 +1,696 @@
+import time
+from datetime import datetime
+from typing import Any, List, Optional
+
+from pydantic import BaseModel, Field
+
+from odibi.context import EngineContext
+from odibi.enums import EngineType
+from odibi.patterns.base import Pattern
+from odibi.transformers.scd import SCD2Params, scd2
+from odibi.utils.logging_context import get_logging_context
+
+
+class AuditConfig(BaseModel):
+    """Configuration for audit columns."""
+
+    load_timestamp: bool = Field(default=True, description="Add load_timestamp column")
+    source_system: Optional[str] = Field(
+        default=None, description="Source system name for source_system column"
+    )
+
+
+class DimensionPattern(Pattern):
+    """
+    Dimension Pattern: Builds complete dimension tables with surrogate keys and SCD support.
+
+    Features:
+    - Auto-generate integer surrogate keys (MAX(existing) + ROW_NUMBER for new rows)
+    - SCD Type 0 (static), 1 (overwrite), 2 (history tracking)
+    - Optional unknown member row (SK=0) for orphan FK handling
+    - Audit columns (load_timestamp, source_system)
+
+    Configuration Options (via params dict):
+    - **natural_key** (str): Natural/business key column name
+    - **surrogate_key** (str): Surrogate key column name to generate
+    - **scd_type** (int): 0=static, 1=overwrite, 2=history tracking (default: 1)
+    - **track_cols** (list): Columns to track for SCD1/2 changes
+    - **target** (str): Target table path (required for SCD2 to read existing history)
+    - **unknown_member** (bool): If true, insert a row with SK=0 for orphan FK handling
+    - **audit** (dict): Audit configuration with load_timestamp and source_system
+
+    Supported target formats:
+        Spark:
+            - Catalog tables: catalog.schema.table, warehouse.dim_customer
+            - Delta paths: /path/to/delta (no extension)
+            - Parquet: /path/to/file.parquet
+            - CSV: /path/to/file.csv
+            - JSON: /path/to/file.json
+            - ORC: /path/to/file.orc
+        Pandas:
+            - Parquet: path/to/file.parquet (or directory)
+            - CSV: path/to/file.csv
+            - JSON: path/to/file.json
+            - Excel: path/to/file.xlsx, path/to/file.xls
+            - Feather/Arrow: path/to/file.feather, path/to/file.arrow
+            - Pickle: path/to/file.pickle, path/to/file.pkl
+            - Connection-prefixed: warehouse.dim_customer
+    """
+
+    def validate(self) -> None:
+        ctx = get_logging_context()
+        ctx.debug(
+            "DimensionPattern validation starting",
+            pattern="DimensionPattern",
+            params=self.params,
+        )
+
+        if not self.params.get("natural_key"):
+            ctx.error(
+                "DimensionPattern validation failed: 'natural_key' is required",
+                pattern="DimensionPattern",
+            )
+            raise ValueError(
+                "DimensionPattern: 'natural_key' parameter is required. "
+                "The natural_key identifies the business key column(s) that uniquely identify "
+                "each dimension record in the source system. "
+                "Provide natural_key as a string (single column) or list of strings (composite key)."
+            )
+
+        if not self.params.get("surrogate_key"):
+            ctx.error(
+                "DimensionPattern validation failed: 'surrogate_key' is required",
+                pattern="DimensionPattern",
+            )
+            raise ValueError(
+                "DimensionPattern: 'surrogate_key' parameter is required. "
+                "The surrogate_key is the auto-generated primary key column for the dimension table, "
+                "used to join with fact tables instead of the natural key. "
+                "Provide surrogate_key as a string specifying the column name (e.g., 'customer_sk')."
+            )
+
+        scd_type = self.params.get("scd_type", 1)
+        if scd_type not in (0, 1, 2):
+            ctx.error(
+                f"DimensionPattern validation failed: invalid scd_type {scd_type}",
+                pattern="DimensionPattern",
+            )
+            raise ValueError(
+                f"DimensionPattern: 'scd_type' must be 0, 1, or 2. Got: {scd_type}. "
+                "SCD Type 0: No changes tracked (static dimension). "
+                "SCD Type 1: Overwrite changes (no history). "
+                "SCD Type 2: Track full history with valid_from/valid_to dates."
+            )
+
+        if scd_type == 2 and not self.params.get("target"):
+            ctx.error(
+                "DimensionPattern validation failed: 'target' required for SCD2",
+                pattern="DimensionPattern",
+            )
+            raise ValueError(
+                "DimensionPattern: 'target' parameter is required for scd_type=2. "
+                "SCD Type 2 compares incoming data against existing records to detect changes, "
+                "so a target DataFrame containing current dimension data must be provided. "
+                "Pass the existing dimension table as the 'target' parameter."
+            )
+
+        if scd_type in (1, 2) and not self.params.get("track_cols"):
+            ctx.error(
+                "DimensionPattern validation failed: 'track_cols' required for SCD1/2",
+                pattern="DimensionPattern",
+            )
+            raise ValueError(
+                "DimensionPattern: 'track_cols' parameter is required for scd_type 1 or 2. "
+                "The track_cols specifies which columns to monitor for changes. "
+                "When these columns change, SCD1 overwrites values or SCD2 creates new history records. "
+                "Provide track_cols as a list of column names (e.g., ['address', 'phone', 'email'])."
+            )
+
+        ctx.debug(
+            "DimensionPattern validation passed",
+            pattern="DimensionPattern",
+        )
+
+    def execute(self, context: EngineContext) -> Any:
+        ctx = get_logging_context()
+        start_time = time.time()
+
+        natural_key = self.params.get("natural_key")
+        surrogate_key = self.params.get("surrogate_key")
+        scd_type = self.params.get("scd_type", 1)
+        track_cols = self.params.get("track_cols", [])
+        target = self.params.get("target")
+        unknown_member = self.params.get("unknown_member", False)
+        audit_config = self.params.get("audit", {})
+
+        ctx.debug(
+            "DimensionPattern starting",
+            pattern="DimensionPattern",
+            natural_key=natural_key,
+            surrogate_key=surrogate_key,
+            scd_type=scd_type,
+            track_cols=track_cols,
+            target=target,
+            unknown_member=unknown_member,
+        )
+
+        source_count = self._get_row_count(context.df, context.engine_type)
+        ctx.debug("Dimension source loaded", pattern="DimensionPattern", source_rows=source_count)
+
+        try:
+            if scd_type == 0:
+                result_df = self._execute_scd0(context, natural_key, surrogate_key, target)
+            elif scd_type == 1:
+                result_df = self._execute_scd1(
+                    context, natural_key, surrogate_key, track_cols, target
+                )
+            else:
+                result_df = self._execute_scd2(
+                    context, natural_key, surrogate_key, track_cols, target
+                )
+
+            result_df = self._add_audit_columns(context, result_df, audit_config)
+
+            if unknown_member:
+                result_df = self._ensure_unknown_member(
+                    context, result_df, natural_key, surrogate_key, audit_config
+                )
+
+            result_count = self._get_row_count(result_df, context.engine_type)
+            elapsed_ms = (time.time() - start_time) * 1000
+
+            ctx.info(
+                "DimensionPattern completed",
+                pattern="DimensionPattern",
+                elapsed_ms=round(elapsed_ms, 2),
+                source_rows=source_count,
+                result_rows=result_count,
+                scd_type=scd_type,
+            )
+
+            return result_df
+
+        except Exception as e:
+            elapsed_ms = (time.time() - start_time) * 1000
+            ctx.error(
+                f"DimensionPattern failed: {e}",
+                pattern="DimensionPattern",
+                error_type=type(e).__name__,
+                elapsed_ms=round(elapsed_ms, 2),
+            )
+            raise
+
+    def _get_row_count(self, df, engine_type) -> Optional[int]:
+        try:
+            if engine_type == EngineType.SPARK:
+                return df.count()
+            else:
+                return len(df)
+        except Exception:
+            return None
+
+    def _load_existing_target(self, context: EngineContext, target: str):
+        """Load existing target table if it exists."""
+        if context.engine_type == EngineType.SPARK:
+            return self._load_existing_spark(context, target)
+        else:
+            return self._load_existing_pandas(context, target)
+
+    def _load_existing_spark(self, context: EngineContext, target: str):
+        """Load existing target table from Spark with multi-format support."""
+        ctx = get_logging_context()
+        spark = context.spark
+
+        # Try catalog table first
+        try:
+            return spark.table(target)
+        except Exception:
+            pass
+
+        # Check file extension for format detection
+        target_lower = target.lower()
+
+        try:
+            if target_lower.endswith(".parquet"):
+                return spark.read.parquet(target)
+            elif target_lower.endswith(".csv"):
+                return spark.read.option("header", "true").option("inferSchema", "true").csv(target)
+            elif target_lower.endswith(".json"):
+                return spark.read.json(target)
+            elif target_lower.endswith(".orc"):
+                return spark.read.orc(target)
+            else:
+                # Try Delta format as fallback (for paths without extension)
+                return spark.read.format("delta").load(target)
+        except Exception as e:
+            ctx.warning(
+                f"Could not load existing target '{target}': {e}. Treating as initial load.",
+                pattern="DimensionPattern",
+                target=target,
+            )
+            return None
+
+    def _load_existing_pandas(self, context: EngineContext, target: str):
+        """Load existing target table from Pandas with multi-format support."""
+        import os
+
+        import pandas as pd
+
+        ctx = get_logging_context()
+        path = target
+
+        # Handle connection-prefixed paths
+        if hasattr(context, "engine") and context.engine:
+            if "." in path:
+                parts = path.split(".", 1)
+                conn_name = parts[0]
+                rel_path = parts[1]
+                if conn_name in context.engine.connections:
+                    try:
+                        path = context.engine.connections[conn_name].get_path(rel_path)
+                    except Exception:
+                        pass
+
+        if not os.path.exists(path):
+            return None
+
+        path_lower = str(path).lower()
+
+        try:
+            # Parquet (file or directory)
+            if path_lower.endswith(".parquet") or os.path.isdir(path):
+                return pd.read_parquet(path)
+            # CSV
+            elif path_lower.endswith(".csv"):
+                return pd.read_csv(path)
+            # JSON
+            elif path_lower.endswith(".json"):
+                return pd.read_json(path)
+            # Excel
+            elif path_lower.endswith(".xlsx") or path_lower.endswith(".xls"):
+                return pd.read_excel(path)
+            # Feather / Arrow IPC
+            elif path_lower.endswith(".feather") or path_lower.endswith(".arrow"):
+                return pd.read_feather(path)
+            # Pickle
+            elif path_lower.endswith(".pickle") or path_lower.endswith(".pkl"):
+                return pd.read_pickle(path)
+            else:
+                ctx.warning(
+                    f"Unrecognized file format for target '{target}'. "
+                    "Supported formats: parquet, csv, json, xlsx, xls, feather, arrow, pickle. "
+                    "Treating as initial load.",
+                    pattern="DimensionPattern",
+                    target=target,
+                )
+                return None
+        except Exception as e:
+            ctx.warning(
+                f"Could not load existing target '{target}': {e}. Treating as initial load.",
+                pattern="DimensionPattern",
+                target=target,
+            )
+            return None
+
+    def _get_max_sk(self, df, surrogate_key: str, engine_type) -> int:
+        """Get the maximum surrogate key value from existing data."""
+        if df is None:
+            return 0
+        try:
+            if engine_type == EngineType.SPARK:
+                from pyspark.sql import functions as F
+
+                max_row = df.agg(F.max(surrogate_key)).collect()[0]
+                max_val = max_row[0]
+                return max_val if max_val is not None else 0
+            else:
+                if surrogate_key not in df.columns:
+                    return 0
+                max_val = df[surrogate_key].max()
+                return int(max_val) if max_val is not None and not (max_val != max_val) else 0
+        except Exception:
+            return 0
+
+    def _generate_surrogate_keys(
+        self,
+        context: EngineContext,
+        df,
+        natural_key: str,
+        surrogate_key: str,
+        start_sk: int,
+    ):
+        """Generate surrogate keys starting from start_sk + 1."""
+        if context.engine_type == EngineType.SPARK:
+            from pyspark.sql import functions as F
+            from pyspark.sql.window import Window
+
+            window = Window.orderBy(natural_key)
+            df = df.withColumn(
+                surrogate_key, (F.row_number().over(window) + F.lit(start_sk)).cast("int")
+            )
+            return df
+        else:
+            df = df.copy()
+            df = df.sort_values(by=natural_key).reset_index(drop=True)
+            df[surrogate_key] = range(start_sk + 1, start_sk + 1 + len(df))
+            df[surrogate_key] = df[surrogate_key].astype("int64")
+            return df
+
+    def _execute_scd0(
+        self,
+        context: EngineContext,
+        natural_key: str,
+        surrogate_key: str,
+        target: Optional[str],
+    ):
+        """
+        SCD Type 0: Static dimension - never update existing records.
+        Only insert new records that don't exist in target.
+        """
+        existing_df = self._load_existing_target(context, target) if target else None
+        source_df = context.df
+
+        if existing_df is None:
+            return self._generate_surrogate_keys(
+                context, source_df, natural_key, surrogate_key, start_sk=0
+            )
+
+        max_sk = self._get_max_sk(existing_df, surrogate_key, context.engine_type)
+
+        if context.engine_type == EngineType.SPARK:
+            existing_keys = existing_df.select(natural_key).distinct()
+            new_records = source_df.join(existing_keys, on=natural_key, how="left_anti")
+        else:
+            existing_keys = set(existing_df[natural_key].unique())
+            new_records = source_df[~source_df[natural_key].isin(existing_keys)].copy()
+
+        if self._get_row_count(new_records, context.engine_type) == 0:
+            return existing_df
+
+        new_with_sk = self._generate_surrogate_keys(
+            context, new_records, natural_key, surrogate_key, start_sk=max_sk
+        )
+
+        if context.engine_type == EngineType.SPARK:
+            return existing_df.unionByName(new_with_sk, allowMissingColumns=True)
+        else:
+            import pandas as pd
+
+            return pd.concat([existing_df, new_with_sk], ignore_index=True)
+
+    def _execute_scd1(
+        self,
+        context: EngineContext,
+        natural_key: str,
+        surrogate_key: str,
+        track_cols: List[str],
+        target: Optional[str],
+    ):
+        """
+        SCD Type 1: Overwrite changes - no history tracking.
+        Update existing records in place, insert new records.
+        """
+        existing_df = self._load_existing_target(context, target) if target else None
+        source_df = context.df
+
+        if existing_df is None:
+            return self._generate_surrogate_keys(
+                context, source_df, natural_key, surrogate_key, start_sk=0
+            )
+
+        max_sk = self._get_max_sk(existing_df, surrogate_key, context.engine_type)
+
+        if context.engine_type == EngineType.SPARK:
+            return self._execute_scd1_spark(
+                context, source_df, existing_df, natural_key, surrogate_key, track_cols, max_sk
+            )
+        else:
+            return self._execute_scd1_pandas(
+                context, source_df, existing_df, natural_key, surrogate_key, track_cols, max_sk
+            )
+
+    def _execute_scd1_spark(
+        self,
+        context: EngineContext,
+        source_df,
+        existing_df,
+        natural_key: str,
+        surrogate_key: str,
+        track_cols: List[str],
+        max_sk: int,
+    ):
+        from pyspark.sql import functions as F
+
+        t_prefix = "__existing_"
+        renamed_existing = existing_df
+        for c in existing_df.columns:
+            renamed_existing = renamed_existing.withColumnRenamed(c, f"{t_prefix}{c}")
+
+        joined = source_df.join(
+            renamed_existing,
+            source_df[natural_key] == renamed_existing[f"{t_prefix}{natural_key}"],
+            "left",
+        )
+
+        new_records = joined.filter(F.col(f"{t_prefix}{natural_key}").isNull()).select(
+            source_df.columns
+        )
+
+        update_records = joined.filter(F.col(f"{t_prefix}{natural_key}").isNotNull())
+        update_cols = [F.col(f"{t_prefix}{surrogate_key}").alias(surrogate_key)] + [
+            F.col(c) for c in source_df.columns
+        ]
+        updated_records = update_records.select(update_cols)
+
+        unchanged_keys = update_records.select(F.col(f"{t_prefix}{natural_key}").alias(natural_key))
+        unchanged = existing_df.join(unchanged_keys, on=natural_key, how="left_anti")
+
+        new_with_sk = self._generate_surrogate_keys(
+            context, new_records, natural_key, surrogate_key, start_sk=max_sk
+        )
+
+        result = unchanged.unionByName(updated_records, allowMissingColumns=True).unionByName(
+            new_with_sk, allowMissingColumns=True
+        )
+        return result
+
+    def _execute_scd1_pandas(
+        self,
+        context: EngineContext,
+        source_df,
+        existing_df,
+        natural_key: str,
+        surrogate_key: str,
+        track_cols: List[str],
+        max_sk: int,
+    ):
+        import pandas as pd
+
+        merged = pd.merge(
+            source_df,
+            existing_df[[natural_key, surrogate_key]],
+            on=natural_key,
+            how="left",
+            suffixes=("", "_existing"),
+        )
+
+        has_existing_sk = f"{surrogate_key}_existing" in merged.columns
+        if has_existing_sk:
+            merged[surrogate_key] = merged[f"{surrogate_key}_existing"]
+            merged = merged.drop(columns=[f"{surrogate_key}_existing"])
+
+        new_mask = merged[surrogate_key].isna()
+        new_records = merged[new_mask].drop(columns=[surrogate_key])
+        existing_records = merged[~new_mask]
+
+        if len(new_records) > 0:
+            new_with_sk = self._generate_surrogate_keys(
+                context, new_records, natural_key, surrogate_key, start_sk=max_sk
+            )
+        else:
+            new_with_sk = pd.DataFrame()
+
+        unchanged = existing_df[~existing_df[natural_key].isin(source_df[natural_key])]
+
+        result = pd.concat([unchanged, existing_records, new_with_sk], ignore_index=True)
+        return result
+
+    def _execute_scd2(
+        self,
+        context: EngineContext,
+        natural_key: str,
+        surrogate_key: str,
+        track_cols: List[str],
+        target: str,
+    ):
+        """
+        SCD Type 2: History tracking - reuse existing scd2 transformer.
+        Surrogate keys are generated for new/changed records.
+        """
+        existing_df = self._load_existing_target(context, target)
+
+        valid_from_col = self.params.get("valid_from_col", "valid_from")
+        valid_to_col = self.params.get("valid_to_col", "valid_to")
+        is_current_col = self.params.get("is_current_col", "is_current")
+
+        if context.engine_type == EngineType.SPARK:
+            from pyspark.sql import functions as F
+
+            source_with_time = context.df.withColumn(valid_from_col, F.current_timestamp())
+        else:
+            source_df = context.df.copy()
+            source_df[valid_from_col] = datetime.now()
+            source_with_time = source_df
+
+        temp_context = context.with_df(source_with_time)
+
+        scd_params = SCD2Params(
+            target=target,
+            keys=[natural_key],
+            track_cols=track_cols,
+            effective_time_col=valid_from_col,
+            end_time_col=valid_to_col,
+            current_flag_col=is_current_col,
+        )
+
+        result_context = scd2(temp_context, scd_params)
+        result_df = result_context.df
+
+        max_sk = self._get_max_sk(existing_df, surrogate_key, context.engine_type)
+
+        if context.engine_type == EngineType.SPARK:
+            from pyspark.sql import functions as F
+            from pyspark.sql.window import Window
+
+            if surrogate_key not in result_df.columns:
+                window = Window.orderBy(natural_key, valid_from_col)
+                result_df = result_df.withColumn(
+                    surrogate_key, (F.row_number().over(window) + F.lit(max_sk)).cast("int")
+                )
+            else:
+                null_sk_df = result_df.filter(F.col(surrogate_key).isNull())
+                has_sk_df = result_df.filter(F.col(surrogate_key).isNotNull())
+
+                if null_sk_df.count() > 0:
+                    window = Window.orderBy(natural_key, valid_from_col)
+                    null_sk_df = null_sk_df.withColumn(
+                        surrogate_key, (F.row_number().over(window) + F.lit(max_sk)).cast("int")
+                    )
+                    result_df = has_sk_df.unionByName(null_sk_df)
+        else:
+            import pandas as pd
+
+            if surrogate_key not in result_df.columns:
+                result_df = result_df.sort_values([natural_key, valid_from_col]).reset_index(
+                    drop=True
+                )
+                result_df[surrogate_key] = range(max_sk + 1, max_sk + 1 + len(result_df))
+            else:
+                null_mask = result_df[surrogate_key].isna()
+                if null_mask.any():
+                    null_df = result_df[null_mask].copy()
+                    null_df = null_df.sort_values([natural_key, valid_from_col]).reset_index(
+                        drop=True
+                    )
+                    null_df[surrogate_key] = range(max_sk + 1, max_sk + 1 + len(null_df))
+                    result_df = pd.concat([result_df[~null_mask], null_df], ignore_index=True)
+
+        return result_df
+
+    def _add_audit_columns(self, context: EngineContext, df, audit_config: dict):
+        """Add audit columns (load_timestamp, source_system) to the dataframe."""
+        load_timestamp = audit_config.get("load_timestamp", True)
+        source_system = audit_config.get("source_system")
+
+        if context.engine_type == EngineType.SPARK:
+            from pyspark.sql import functions as F
+
+            if load_timestamp:
+                df = df.withColumn("load_timestamp", F.current_timestamp())
+            if source_system:
+                df = df.withColumn("source_system", F.lit(source_system))
+        else:
+            df = df.copy()
+            if load_timestamp:
+                df["load_timestamp"] = datetime.now()
+            if source_system:
+                df["source_system"] = source_system
+
+        return df
+
+    def _ensure_unknown_member(
+        self,
+        context: EngineContext,
+        df,
+        natural_key: str,
+        surrogate_key: str,
+        audit_config: dict,
+    ):
+        """Ensure unknown member row exists with SK=0."""
+        valid_from_col = self.params.get("valid_from_col", "valid_from")
+        valid_to_col = self.params.get("valid_to_col", "valid_to")
+        is_current_col = self.params.get("is_current_col", "is_current")
+
+        if context.engine_type == EngineType.SPARK:
+            from pyspark.sql import functions as F
+
+            existing_unknown = df.filter(F.col(surrogate_key) == 0)
+            if existing_unknown.count() > 0:
+                return df
+
+            columns = df.columns
+            unknown_values = []
+            for col in columns:
+                if col == surrogate_key:
+                    unknown_values.append(0)
+                elif col == natural_key:
+                    unknown_values.append("-1")
+                elif col == valid_from_col:
+                    unknown_values.append(datetime(1900, 1, 1))
+                elif col == valid_to_col:
+                    unknown_values.append(None)
+                elif col == is_current_col:
+                    unknown_values.append(True)
+                elif col == "load_timestamp":
+                    unknown_values.append(datetime.now())
+                elif col == "source_system":
+                    unknown_values.append(audit_config.get("source_system", "Unknown"))
+                else:
+                    unknown_values.append("Unknown")
+
+            unknown_row = context.spark.createDataFrame([unknown_values], columns)
+            return unknown_row.unionByName(df)
+        else:
+            import pandas as pd
+
+            if (df[surrogate_key] == 0).any():
+                return df
+
+            unknown_row = {}
+            for col in df.columns:
+                if col == surrogate_key:
+                    unknown_row[col] = 0
+                elif col == natural_key:
+                    unknown_row[col] = "-1"
+                elif col == valid_from_col:
+                    unknown_row[col] = datetime(1900, 1, 1)
+                elif col == valid_to_col:
+                    unknown_row[col] = None
+                elif col == is_current_col:
+                    unknown_row[col] = True
+                elif col == "load_timestamp":
+                    unknown_row[col] = datetime.now()
+                elif col == "source_system":
+                    unknown_row[col] = audit_config.get("source_system", "Unknown")
+                else:
+                    dtype = df[col].dtype
+                    if pd.api.types.is_numeric_dtype(dtype):
+                        unknown_row[col] = 0
+                    else:
+                        unknown_row[col] = "Unknown"
+
+            unknown_df = pd.DataFrame([unknown_row])
+            for col in unknown_df.columns:
+                if col in df.columns:
+                    unknown_df[col] = unknown_df[col].astype(df[col].dtype)
+            return pd.concat([unknown_df, df], ignore_index=True)
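
For orientation, a minimal usage sketch follows, built only from the params dict documented in the DimensionPattern docstring above. It is hypothetical rather than taken from the package's documentation: the column names, the warehouse.dim_customer target, and the CRM source system are placeholders, and the wiring shown (constructing the pattern with a params keyword and executing it against an EngineContext) is an assumption inferred from validate()/execute(), not confirmed by this diff.

    from odibi.patterns.dimension import DimensionPattern

    # Hypothetical SCD Type 2 customer dimension configuration (placeholder names).
    params = {
        "natural_key": "customer_id",                 # business key in the source data
        "surrogate_key": "customer_sk",               # integer key generated by the pattern
        "scd_type": 2,                                # 0=static, 1=overwrite, 2=history tracking
        "track_cols": ["address", "phone", "email"],  # changes here open new history rows
        "target": "warehouse.dim_customer",           # existing dimension read for SCD2 history
        "unknown_member": True,                       # add the SK=0 row for orphan foreign keys
        "audit": {"load_timestamp": True, "source_system": "CRM"},
    }

    # Assumed wiring: construct with the params dict, then execute against an odibi
    # EngineContext that carries the source DataFrame (context.df) and engine type.
    pattern = DimensionPattern(params=params)
    pattern.validate()                 # raises ValueError if required params are missing
    dim_df = pattern.execute(context)  # returns the dimension DataFrame for the active engine

validate() enforces exactly the rules visible in the code: natural_key and surrogate_key are always required, target is required for scd_type=2, and track_cols is required for scd_type 1 or 2.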