tracepipe-0.2.0-py3-none-any.whl → tracepipe-0.3.0-py3-none-any.whl

@@ -0,0 +1,434 @@
+ # tracepipe/instrumentation/merge_capture.py
+ """
+ Merge provenance using position column injection.
+
+ CI Mode: Stats only (fast)
+ DEBUG Mode: Full parent RID mapping (position column injection)
+ """
+
+ import uuid
+ import warnings
+ from functools import wraps
+
+ import numpy as np
+ import pandas as pd
+
+ from ..context import get_context
+ from ..core import CompletenessLevel, MergeMapping, MergeStats
+ from ..safety import TracePipeWarning, get_caller_info
+
+
+ def wrap_merge_with_lineage(original_merge):
+     """
+     Wrap DataFrame.merge with lineage capture.
+     """
+
+     @wraps(original_merge)
+     def wrapper(self, right, *args, **kwargs):
+         ctx = get_context()
+
+         if not ctx.enabled:
+             return original_merge(self, right, *args, **kwargs)
+
+         if ctx.config.should_capture_merge_provenance:
+             return _merge_with_provenance(original_merge, self, right, args, kwargs, ctx)
+         else:
+             return _merge_with_stats_only(original_merge, self, right, args, kwargs, ctx)
+
+     return wrapper
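
How this wrapper gets installed is not part of this diff; presumably tracepipe patches pd.DataFrame.merge elsewhere. As a minimal sketch (the helper names below are hypothetical, not tracepipe API), replacing the unbound method is what lets `self` arrive as the wrapper's first argument:

    import pandas as pd

    _original_merge = pd.DataFrame.merge  # keep a handle for uninstall

    def install_merge_capture():
        # Patch the unbound method so every df.merge(...) call is wrapped
        pd.DataFrame.merge = wrap_merge_with_lineage(_original_merge)

    def uninstall_merge_capture():
        pd.DataFrame.merge = _original_merge
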
+
+
+ def _merge_with_provenance(original_merge, left, right, args, kwargs, ctx):
+     """
+     Merge with full provenance (debug mode).
+     Uses position column injection.
+
+     Guards against validate= and other merge errors.
+     """
+     row_mgr = ctx.row_manager
+     store = ctx.store
+
+     # Get/register source RIDs
+     left_rids = row_mgr.get_ids_array(left)
+     right_rids = row_mgr.get_ids_array(right)
+
+     if left_rids is None:
+         left_rids = row_mgr.register(left)
+     if right_rids is None:
+         right_rids = row_mgr.register(right)
+
+     # Generate unique position column names
+     token = uuid.uuid4().hex[:12]
+     left_pos_col = f"__tp_lp_{token}__"
+     right_pos_col = f"__tp_rp_{token}__"
+
+     # Inject position columns (int32 for memory efficiency)
+     left_tracked = left.assign(**{left_pos_col: np.arange(len(left), dtype=np.int32)})
+     right_tracked = right.assign(**{right_pos_col: np.arange(len(right), dtype=np.int32)})
+
+     # Run merge in try/except to handle validate= errors
+     try:
+         result_tracked = original_merge(left_tracked, right_tracked, *args, **kwargs)
+     except Exception as e:
+         # Merge failed (e.g., a validate="1:1" violation).
+         # Record an error step for debuggability, then re-raise.
+         return _record_merge_error_and_reraise(e, left, right, kwargs, ctx)
+
+     # Check for a column collision; don't rerun the merge
+     if left_pos_col not in result_tracked.columns or right_pos_col not in result_tracked.columns:
+         warnings.warn(
+             "TracePipe: Position column collision in merge. Provenance marked PARTIAL.",
+             UserWarning,
+         )
+         # Don't rerun - just drop tracking columns and continue
+         result = result_tracked.drop(
+             columns=[c for c in [left_pos_col, right_pos_col] if c in result_tracked.columns],
+             errors="ignore",
+         )
+         return _finalize_merge_partial(result, left, right, kwargs, ctx)
+
+     # Extract indexers
+     left_indexer = result_tracked[left_pos_col].values
+     right_indexer = result_tracked[right_pos_col].values
+
+     # Drop tracking columns
+     result = result_tracked.drop(columns=[left_pos_col, right_pos_col])
+
+     # Build mapping with vectorized parent lookup.
+     # Handle the NaN positions that outer joins produce.
+     left_valid = ~pd.isna(left_indexer)
+     right_valid = ~pd.isna(right_indexer)
+
+     left_parent_rids = np.full(len(result), -1, dtype=np.int64)
+     right_parent_rids = np.full(len(result), -1, dtype=np.int64)
+
+     left_parent_rids[left_valid] = left_rids[left_indexer[left_valid].astype(np.int64)]
+     right_parent_rids[right_valid] = right_rids[right_indexer[right_valid].astype(np.int64)]
+
+     # Assign new RIDs to result
+     result_rids = row_mgr.register(result)
+
+     # Record step
+     code_file, code_line = get_caller_info(skip_frames=3)
+
+     step_id = store.append_step(
+         operation="DataFrame.merge",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params=_merge_params(kwargs),
+         input_shape=(left.shape, right.shape),
+         output_shape=result.shape,
+         completeness=CompletenessLevel.FULL,
+     )
+
+     # Sort mapping arrays by out_rids for O(log n) lookup
+     sort_idx = np.argsort(result_rids)
+     sorted_out_rids = result_rids[sort_idx]
+     sorted_left_parents = left_parent_rids[sort_idx]
+     sorted_right_parents = right_parent_rids[sort_idx]
+
+     mapping = MergeMapping(
+         step_id=step_id,
+         out_rids=sorted_out_rids,
+         left_parent_rids=sorted_left_parents,
+         right_parent_rids=sorted_right_parents,
+     )
+     store.merge_mappings.append(mapping)
+
+     # Compute stats from indexers
+     stats = _compute_stats_from_indexers(
+         left_indexer[left_valid].astype(np.int64),
+         right_indexer[right_valid].astype(np.int64),
+         len(left),
+         len(right),
+         len(result),
+         kwargs,
+     )
+     store.merge_stats.append((step_id, stats))
+
+     return result
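
The core trick above can be reproduced standalone. A sketch with plain pandas/numpy (column names and RID values invented for illustration): an outer merge upcasts the injected int32 position columns to float with NaN for unmatched rows, which is exactly why the NaN masks are needed, and the argsort/searchsorted pairing is what makes the sorted mapping answer lookups in O(log n):

    import numpy as np
    import pandas as pd

    left = pd.DataFrame({"k": [1, 2, 3]})
    right = pd.DataFrame({"k": [2, 3, 4]})

    left_t = left.assign(__lp__=np.arange(len(left), dtype=np.int32))
    right_t = right.assign(__rp__=np.arange(len(right), dtype=np.int32))

    out = left_t.merge(right_t, on="k", how="outer")
    left_idx = out["__lp__"].values    # [0., 1., 2., nan] -- float now
    right_idx = out["__rp__"].values   # [nan, 0., 1., 2.]

    valid = ~pd.isna(left_idx)
    left_positions = left_idx[valid].astype(np.int64)  # rows of `left` used

    # Sorted-array lookup, mirroring the argsort step above:
    out_rids = np.array([7, 3, 9, 5])      # pretend RIDs for 4 result rows
    order = np.argsort(out_rids)
    sorted_rids = out_rids[order]          # [3, 5, 7, 9]
    i = np.searchsorted(sorted_rids, 9)    # binary search: i == 3
    assert sorted_rids[i] == 9             # rid 9 found in O(log n)
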
+
+
+ def _finalize_merge_partial(result, left, right, kwargs, ctx):
+     """Finalize merge when provenance capture failed."""
+     row_mgr = ctx.row_manager
+     store = ctx.store
+
+     row_mgr.register(result)
+
+     code_file, code_line = get_caller_info(skip_frames=4)
+
+     step_id = store.append_step(
+         operation="DataFrame.merge",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params=_merge_params(kwargs),
+         input_shape=(left.shape, right.shape),
+         output_shape=result.shape,
+         completeness=CompletenessLevel.PARTIAL,
+     )
+
+     stats = _compute_stats_approximate(len(left), len(right), len(result), kwargs)
+     store.merge_stats.append((step_id, stats))
+
+     return result
+
+
+ def _record_merge_error_and_reraise(error: Exception, left, right, kwargs, ctx):
+     """
+     Record merge error step for debuggability, then re-raise.
+
+     This handles cases like validate="1:1" violations where pandas raises
+     but we still want to record what was attempted.
+     """
+     store = ctx.store
+
+     code_file, code_line = get_caller_info(skip_frames=4)
+
+     # Record error step with exception info
+     params = _merge_params(kwargs)
+     params["_error"] = str(error)[:200]  # Truncate long error messages
+     params["_error_type"] = type(error).__name__
+
+     store.append_step(
+         operation="DataFrame.merge (error)",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params=params,
+         input_shape=(left.shape, right.shape),
+         output_shape=None,  # No result
+         completeness=CompletenessLevel.PARTIAL,
+     )
+
+     # Re-raise the original exception so user code behaves normally
+     raise error
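
For reference, the failure mode this path exists for: pandas raises pandas.errors.MergeError when a validate= contract is violated, and the wrapper records the attempt before letting the error propagate. A standalone illustration:

    import pandas as pd

    left = pd.DataFrame({"k": [1, 1]})   # duplicate key violates "1:1"
    right = pd.DataFrame({"k": [1]})

    try:
        left.merge(right, on="k", validate="1:1")
    except pd.errors.MergeError as e:
        # Under the wrapper, str(e) is what lands in params["_error"]
        # (truncated to 200 chars) before the re-raise.
        print(type(e).__name__)
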
208
+
209
+
210
+ def _merge_with_stats_only(original_merge, left, right, args, kwargs, ctx):
211
+ """
212
+ Merge with stats only (CI mode).
213
+ Fast: no position injection.
214
+
215
+ Also handles merge errors for debuggability.
216
+ """
217
+ try:
218
+ result = original_merge(left, right, *args, **kwargs)
219
+ except Exception as e:
220
+ return _record_merge_error_and_reraise(e, left, right, kwargs, ctx)
221
+
222
+ row_mgr = ctx.row_manager
223
+ store = ctx.store
224
+
225
+ row_mgr.register(result)
226
+
227
+ code_file, code_line = get_caller_info(skip_frames=3)
228
+
229
+ # CI mode merge = PARTIAL (we know it's a merge, but no parent mapping)
230
+ step_id = store.append_step(
231
+ operation="DataFrame.merge",
232
+ stage=ctx.current_stage,
233
+ code_file=code_file,
234
+ code_line=code_line,
235
+ params=_merge_params(kwargs),
236
+ input_shape=(left.shape, right.shape),
237
+ output_shape=result.shape,
238
+ completeness=CompletenessLevel.PARTIAL,
239
+ )
240
+
241
+ stats = _compute_stats_approximate(len(left), len(right), len(result), kwargs)
242
+ store.merge_stats.append((step_id, stats))
243
+
244
+ return result
+
+
+ def _compute_stats_from_indexers(
+     left_indexer: np.ndarray,
+     right_indexer: np.ndarray,
+     n_left: int,
+     n_right: int,
+     n_result: int,
+     kwargs: dict,
+ ) -> MergeStats:
+     """Compute accurate merge stats from indexers."""
+
+     # Match rates
+     left_match_rate = len(np.unique(left_indexer)) / n_left if n_left > 0 else 0
+     right_match_rate = len(np.unique(right_indexer)) / n_right if n_right > 0 else 0
+
+     # Dup rates (rows appearing more than once)
+     if len(left_indexer) > 0:
+         left_counts = np.bincount(left_indexer, minlength=n_left)
+         left_dup_rate = (left_counts > 1).sum() / n_left if n_left > 0 else 0
+     else:
+         left_dup_rate = 0
+
+     if len(right_indexer) > 0:
+         right_counts = np.bincount(right_indexer, minlength=n_right)
+         right_dup_rate = (right_counts > 1).sum() / n_right if n_right > 0 else 0
+     else:
+         right_dup_rate = 0
+
+     return MergeStats(
+         left_rows=n_left,
+         right_rows=n_right,
+         result_rows=n_result,
+         expansion_ratio=n_result / max(n_left, n_right, 1),
+         left_match_rate=left_match_rate,
+         right_match_rate=right_match_rate,
+         left_dup_rate=left_dup_rate,
+         right_dup_rate=right_dup_rate,
+         how=kwargs.get("how", "inner"),
+     )
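
A worked example of these formulas on a toy indexer (numbers invented for illustration):

    import numpy as np

    left_indexer = np.array([0, 0, 1, 3])  # left row 0 matched twice; row 2 never
    n_left = 4

    match_rate = len(np.unique(left_indexer)) / n_left    # 3/4 = 0.75
    counts = np.bincount(left_indexer, minlength=n_left)  # [2, 1, 0, 1]
    dup_rate = (counts > 1).sum() / n_left                # 1/4 = 0.25
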
+
+
+ def _compute_stats_approximate(
+     n_left: int, n_right: int, n_result: int, kwargs: dict
+ ) -> MergeStats:
+     """Approximate stats (CI mode - fast, skip expensive computations)."""
+     return MergeStats(
+         left_rows=n_left,
+         right_rows=n_right,
+         result_rows=n_result,
+         expansion_ratio=n_result / max(n_left, n_right, 1),
+         left_match_rate=-1.0,  # Unknown in CI mode
+         right_match_rate=-1.0,
+         left_dup_rate=-1.0,
+         right_dup_rate=-1.0,
+         how=kwargs.get("how", "inner"),
+     )
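
Even with the match and dup rates left at the -1.0 "unknown" sentinel, the expansion ratio alone can flag accidental fan-out in CI. For instance, a 1,000-row table joined against a 10-row lookup table should stay near a ratio of 1.0; duplicated lookup keys that blow the result up to 10,000 rows give:

    n_left, n_right, n_result = 1_000, 10, 10_000
    expansion_ratio = n_result / max(n_left, n_right, 1)  # 10_000 / 1_000 = 10.0
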
+
+
+ def _merge_params(kwargs: dict) -> dict:
+     # Only keyword arguments are inspected; merge arguments passed
+     # positionally land in *args and are not recorded here.
+     return {
+         "how": kwargs.get("how", "inner"),
+         "on": str(kwargs.get("on", kwargs.get("left_on", "")))[:50],
+     }
+
+
+ # ============ JOIN WRAPPER ============
+
+
+ def wrap_join_with_lineage(original_join):
+     """
+     Wrap DataFrame.join with lineage capture.
+     Similar to merge but uses index-based joining.
+     """
+
+     @wraps(original_join)
+     def wrapper(self, other, *args, **kwargs):
+         ctx = get_context()
+
+         if not ctx.enabled:
+             return original_join(self, other, *args, **kwargs)
+
+         # Run join; the user's exception always propagates, but in
+         # non-strict mode we also emit a warning so the failure is visible
+         try:
+             result = original_join(self, other, *args, **kwargs)
+         except Exception as e:
+             if not ctx.config.strict_mode:
+                 warnings.warn(f"TracePipe: Join failed: {e}", TracePipeWarning)
+             raise
+
+         row_mgr = ctx.row_manager
+         store = ctx.store
+
+         # Register result
+         row_mgr.register(result)
+
+         code_file, code_line = get_caller_info(skip_frames=2)
+
+         # Record step
+         step_id = store.append_step(
+             operation="DataFrame.join",
+             stage=ctx.current_stage,
+             code_file=code_file,
+             code_line=code_line,
+             params={"how": kwargs.get("how", "left")},
+             input_shape=(self.shape, other.shape if hasattr(other, "shape") else None),
+             output_shape=result.shape,
+             completeness=CompletenessLevel.PARTIAL,  # Join is complex, mark PARTIAL
+         )
+
+         # Compute basic stats
+         n_left = len(self)
+         n_right = len(other) if hasattr(other, "__len__") else 0
+         n_result = len(result)
+
+         stats = MergeStats(
+             left_rows=n_left,
+             right_rows=n_right,
+             result_rows=n_result,
+             expansion_ratio=n_result / max(n_left, 1),
+             left_match_rate=-1.0,
+             right_match_rate=-1.0,
+             left_dup_rate=-1.0,
+             right_dup_rate=-1.0,
+             how=kwargs.get("how", "left"),
+         )
+         store.merge_stats.append((step_id, stats))
+
+         return result
+
+     return wrapper
+
+
+ # ============ CONCAT WRAPPER ============
+
+
+ def wrap_concat_with_lineage(original_concat):
+     """
+     Wrap pd.concat with lineage capture.
+     """
+
+     @wraps(original_concat)
+     def wrapper(objs, *args, **kwargs):
+         ctx = get_context()
+
+         result = original_concat(objs, *args, **kwargs)
+
+         if not ctx.enabled:
+             return result
+
+         if not isinstance(result, pd.DataFrame):
+             return result
+
+         try:
+             row_mgr = ctx.row_manager
+             store = ctx.store
+
+             # Register result
+             row_mgr.register(result)
+
+             code_file, code_line = get_caller_info(skip_frames=2)
+
+             # Compute input shapes. If objs was an iterator, the concat
+             # above already exhausted it, so this loop records nothing;
+             # dict inputs yield keys, which typically have no shape.
+             input_shapes = []
+             for obj in objs:
+                 if hasattr(obj, "shape"):
+                     input_shapes.append(obj.shape)
+
+             store.append_step(
+                 operation="pd.concat",
+                 stage=ctx.current_stage,
+                 code_file=code_file,
+                 code_line=code_line,
+                 params={
+                     "axis": kwargs.get("axis", 0),
+                     "n_inputs": len(objs) if hasattr(objs, "__len__") else 1,
+                 },
+                 input_shape=tuple(input_shapes) if input_shapes else None,
+                 output_shape=result.shape,
+                 completeness=CompletenessLevel.PARTIAL,  # Concat resets lineage
+             )
+         except Exception as e:
+             if ctx.config.strict_mode:
+                 raise
+             warnings.warn(f"TracePipe: Concat capture failed: {e}", TracePipeWarning)
+
+         return result
+
+     return wrapper
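
Unlike the merge and join wrappers, which would patch methods on pd.DataFrame, this one has to replace the module-level pd.concat. A sketch of that wiring (the helper is an assumption; the actual patch site is not in this diff), including the usual module-attribute caveat:

    import pandas as pd

    _original_concat = pd.concat

    def install_concat_capture():
        # Rebinding the attribute on the pandas module is enough for code
        # that calls pd.concat(...). Code that ran
        # `from pandas import concat` before this point still holds the
        # original, untracked function.
        pd.concat = wrap_concat_with_lineage(_original_concat)
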