tracepipe-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1024 @@
+ # tracepipe/instrumentation/pandas_inst.py
+ """
+ Pandas DataFrame instrumentation for row-level lineage tracking.
+ """
+
+ import warnings
+ from functools import wraps
+ from typing import Any
+
+ import numpy as np
+ import pandas as pd
+
+ from ..context import TracePipeContext, get_context
+ from ..core import ChangeType, CompletenessLevel
+ from ..safety import (
+     TracePipeWarning,
+     get_caller_info,
+     wrap_pandas_filter_method,
+     wrap_pandas_method,
+     wrap_pandas_method_inplace,
+ )
+ from ..utils.value_capture import find_changed_indices_vectorized
+
+ # Store original methods for restore
+ _originals: dict[str, Any] = {}
+
+
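+ # A minimal usage sketch (illustrative, not part of the module): patching is
+ # installed and removed via the two entry points defined at the bottom of this
+ # file, and get_context() decides at call time whether capture is active.
+ #
+ #     from tracepipe.instrumentation.pandas_inst import (
+ #         instrument_pandas, uninstrument_pandas,
+ #     )
+ #     instrument_pandas()           # patch DataFrame methods in place
+ #     df = pd.read_csv("data.csv")  # auto-registered by the reader hook
+ #     df = df.dropna()              # dropped rows recorded as a lineage step
+ #     uninstrument_pandas()         # restore the original pandas methods
+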
+ # === FILTER CAPTURE ===
+
+
+ def _capture_filter(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """Capture lineage for filter operations (dropna, query, head, etc.)."""
+     if not isinstance(result, pd.DataFrame):
+         return
+
+     source_ids = ctx.row_manager.get_ids(self)
+     if source_ids is None:
+         # Auto-register if not tracked
+         ctx.row_manager.register(self)
+         source_ids = ctx.row_manager.get_ids(self)
+         if source_ids is None:
+             return
+
+     # Propagate IDs to result
+     ctx.row_manager.propagate(self, result)
+
+     # Find dropped rows (returns numpy array for performance)
+     dropped_ids = ctx.row_manager.get_dropped_ids(self, result)
+
+     if len(dropped_ids) > 0:
+         code_file, code_line = get_caller_info(skip_frames=2)
+         step_id = ctx.store.append_step(
+             operation=f"DataFrame.{method_name}",
+             stage=ctx.current_stage,
+             code_file=code_file,
+             code_line=code_line,
+             params=_safe_params(kwargs),
+             input_shape=self.shape,
+             output_shape=result.shape,
+         )
+
+         # Bulk record all drops at once (10-50x faster than loop)
+         ctx.store.append_bulk_drops(step_id, dropped_ids)
+
+
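+ # For intuition (illustrative values): if three tracked rows go through
+ #     df = df.dropna()
+ # and one contains a NaN, the capture above appends a single step
+ # (operation="DataFrame.dropna", input_shape=(3, n), output_shape=(2, n))
+ # plus one bulk-drop record holding the dropped row's stable ID.
+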
+ # === TRANSFORM CAPTURE ===
+
+
+ def _capture_transform(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """Capture lineage for transform operations (fillna, replace, astype)."""
+     if not isinstance(result, pd.DataFrame):
+         return
+
+     source_ids = ctx.row_manager.get_ids(self)
+     if source_ids is None:
+         return
+
+     # Propagate IDs
+     ctx.row_manager.propagate(self, result)
+
+     # Only track watched columns
+     if not ctx.watched_columns:
+         return
+
+     # Determine affected columns
+     affected = _get_affected_columns(method_name, args, kwargs, self.columns)
+     cols_to_track = affected & ctx.watched_columns
+
+     if not cols_to_track:
+         return
+
+     # Check if mass update
+     if not ctx.store.should_track_cell_diffs(len(self)):
+         code_file, code_line = get_caller_info(skip_frames=2)
+         ctx.store.append_step(
+             operation=f"DataFrame.{method_name}",
+             stage=ctx.current_stage,
+             code_file=code_file,
+             code_line=code_line,
+             params=_safe_params(kwargs),
+             input_shape=self.shape,
+             output_shape=result.shape,
+             is_mass_update=True,
+             rows_affected=len(self),
+         )
+         return
+
+     # Find common index for vectorized comparison
+     common_index = self.index.intersection(result.index).intersection(source_ids.index)
+     if len(common_index) == 0:
+         return
+
+     ids_aligned = source_ids.reindex(common_index)
+     ids_arr = ids_aligned.values
+
+     # Check each column for changes using vectorized comparison
+     all_changes = []  # Collect (row_id, col, old_val, new_val)
+
+     for col in cols_to_track:
+         if col not in self.columns or col not in result.columns:
+             continue
+
+         old_aligned = self[col].reindex(common_index)
+         new_aligned = result[col].reindex(common_index)
+
+         # Vectorized change detection (~50-100x faster)
+         changed_mask = find_changed_indices_vectorized(old_aligned, new_aligned)
+
+         if not changed_mask.any():
+             continue
+
+         changed_indices = np.where(changed_mask)[0]
+         old_arr = old_aligned.values
+         new_arr = new_aligned.values
+
+         for i in changed_indices:
+             all_changes.append((int(ids_arr[i]), col, old_arr[i], new_arr[i]))
+
+     if not all_changes:
+         return  # No changes detected, skip step creation
+
+     # Create step only if there are actual changes
+     code_file, code_line = get_caller_info(skip_frames=2)
+     step_id = ctx.store.append_step(
+         operation=f"DataFrame.{method_name}",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params=_safe_params(kwargs),
+         input_shape=self.shape,
+         output_shape=result.shape,
+     )
+
+     # Record all changes
+     for row_id, col, old_val, new_val in all_changes:
+         ctx.store.append_diff(
+             step_id=step_id,
+             row_id=row_id,
+             col=col,
+             old_val=old_val,
+             new_val=new_val,
+             change_type=ChangeType.MODIFIED,
+         )
+
+
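+ # Affected-column resolution in practice (deterministic from
+ # _get_affected_columns below; values illustrative): df.fillna({"age": 0})
+ # touches only {"age"}, while df.fillna(0) is treated as touching every
+ # column, so only the intersection with ctx.watched_columns is diffed.
+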
+ # === APPLY/PIPE CAPTURE (PARTIAL) ===
+
+
+ def _capture_apply(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """Capture apply/pipe with PARTIAL completeness."""
+     func = args[0] if args else kwargs.get("func")
+     func_name = getattr(func, "__name__", "<lambda>")
+
+     code_file, code_line = get_caller_info(skip_frames=2)
+     step_id = ctx.store.append_step(
+         operation=f"DataFrame.{method_name}({func_name})",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params={"func": func_name},
+         input_shape=self.shape,
+         output_shape=result.shape if hasattr(result, "shape") else None,
+         completeness=CompletenessLevel.PARTIAL,
+     )
+
+     # Still propagate and track output changes if result is DataFrame
+     if isinstance(result, pd.DataFrame):
+         ctx.row_manager.propagate(self, result)
+
+         # Track changes to watched columns
+         if ctx.watched_columns:
+             _capture_cell_changes(ctx, step_id, self, result)
+
+
+ # === GROUPBY CAPTURE ===
+
+
+ def _capture_groupby(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """Capture groupby without re-calling groupby."""
+     row_ids = ctx.row_manager.get_ids(self)
+     if row_ids is None:
+         return
+
+     by = args[0] if args else kwargs.get("by")
+
+     # Extract groups from RESULT (already computed)
+     if hasattr(result, "groups"):
+         # Clear any stale groupby state from same source (handles new groupby on same df)
+         ctx.clear_groupby_for_source(id(self))
+
+         ctx.push_groupby(
+             {
+                 "source_df": self,
+                 "source_id": id(self),
+                 "row_ids": row_ids,
+                 "by": by,
+                 "groups": result.groups,
+             }
+         )
+
+
+ def _capture_agg(self, args, kwargs, result, ctx: TracePipeContext, method_name: str):
+     """
+     Capture aggregation and record group membership.
+
+     Note: Uses peek_groupby() instead of pop_groupby() to support
+     multiple aggregations on the same GroupBy object:
+
+         grouped = df.groupby("category")
+         means = grouped.mean()  # First agg - state preserved
+         sums = grouped.sum()    # Second agg - still works!
+
+     State is cleared when a new groupby() is called on the same source.
+     """
+     pending = ctx.peek_groupby()
+     if pending is None:
+         return
+
+     row_ids = pending["row_ids"]
+     source_df = pending["source_df"]
+     max_membership = ctx.config.max_group_membership_size
+
+     # Build membership mapping (with threshold for large groups)
+     membership = {}
+     for group_key, indices in pending["groups"].items():
+         group_size = len(indices)
+
+         if group_size > max_membership:
+             # Large group - store count only to save memory
+             # Use special marker: negative count indicates "count only"
+             membership[str(group_key)] = [-group_size]
+         else:
+             # Normal group - store full membership
+             member_ids = []
+             for idx in indices:
+                 if idx in row_ids.index:
+                     member_ids.append(int(row_ids.loc[idx]))
+             membership[str(group_key)] = member_ids
+
+     # Determine aggregation functions
+     agg_funcs = {}
+     if args:
+         agg_arg = args[0]
+         if isinstance(agg_arg, dict):
+             agg_funcs = {k: str(v) for k, v in agg_arg.items()}
+         elif isinstance(agg_arg, str):
+             agg_funcs = {"_all_": agg_arg}
+         elif isinstance(agg_arg, list):
+             agg_funcs = {"_all_": str(agg_arg)}
+
+     code_file, code_line = get_caller_info(skip_frames=2)
+     step_id = ctx.store.append_step(
+         operation=f"GroupBy.{method_name}",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params={"by": str(pending["by"])},
+         input_shape=source_df.shape,
+         output_shape=result.shape if hasattr(result, "shape") else None,
+     )
+
+     ctx.store.append_aggregation(
+         step_id=step_id,
+         group_column=str(pending["by"]),
+         membership=membership,
+         agg_functions=agg_funcs,
+     )
+
+     # Register result with new IDs (aggregation creates new "rows")
+     if isinstance(result, pd.DataFrame):
+         ctx.row_manager.register(result)
+
+
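+ # Membership encoding at a glance (illustrative, assuming
+ # max_group_membership_size=3):
+ #     {"electronics": [101, 102], "misc": [-5000]}
+ # A plain list holds the row IDs of a small group; a single negative value is
+ # the "count only" marker for a group whose 5000 members were not stored.
+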
+ # === MERGE/CONCAT (UNKNOWN - OUT OF SCOPE) ===
+
+
+ def _capture_merge(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """Mark merge as UNKNOWN completeness and reset lineage."""
+     code_file, code_line = get_caller_info(skip_frames=2)
+     ctx.store.append_step(
+         operation=f"DataFrame.{method_name}",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params={"how": kwargs.get("how", "inner")},
+         input_shape=self.shape,
+         output_shape=result.shape if hasattr(result, "shape") else None,
+         completeness=CompletenessLevel.UNKNOWN,
+     )
+
+     warnings.warn(
+         f"TracePipe: {method_name}() resets row lineage. "
+         "Rows in result cannot be traced back to source rows.",
+         TracePipeWarning,
+     )
+
+     # Register result with NEW IDs
+     if isinstance(result, pd.DataFrame):
+         ctx.row_manager.register(result)
+
+
+ def _capture_concat(args, kwargs, result, ctx: TracePipeContext):
+     """Capture pd.concat (module-level function)."""
+     code_file, code_line = get_caller_info(skip_frames=2)
+     ctx.store.append_step(
+         operation="pd.concat",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params={"axis": kwargs.get("axis", 0)},
+         input_shape=None,
+         output_shape=result.shape if hasattr(result, "shape") else None,
+         completeness=CompletenessLevel.UNKNOWN,
+     )
+
+     warnings.warn("TracePipe: pd.concat() resets row lineage.", TracePipeWarning)
+
+     if isinstance(result, pd.DataFrame):
+         ctx.row_manager.register(result)
+
+
+ # === INDEX OPERATIONS ===
+
+
+ def _capture_reset_index(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """Handle reset_index which changes index alignment."""
+     if not isinstance(result, pd.DataFrame):
+         return
+
+     drop = kwargs.get("drop", False)
+     # Older pandas also accepted drop positionally after level:
+     # reset_index(level, drop). Pick it up if it was passed that way.
+     if len(args) > 1 and isinstance(args[1], bool):
+         drop = args[1]
+
+     if drop:
+         ctx.row_manager.realign_for_reset_index(self, result)
+     else:
+         ctx.row_manager.propagate(self, result)
+
+
+ def _capture_set_index(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """Handle set_index."""
+     if isinstance(result, pd.DataFrame):
+         ctx.row_manager.propagate(self, result)
+
+
+ def _capture_sort_values(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """Handle sort_values with order tracking."""
+     if not isinstance(result, pd.DataFrame):
+         return
+
+     source_ids = ctx.row_manager.get_ids(self)
+     if source_ids is None:
+         return
+
+     ctx.row_manager.propagate(self, result)
+
+     by = args[0] if args else kwargs.get("by")
+     ascending = kwargs.get("ascending", True)
+
+     code_file, code_line = get_caller_info(skip_frames=2)
+     step_id = ctx.store.append_step(
+         operation="DataFrame.sort_values",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params={"by": str(by), "ascending": ascending},
+         input_shape=self.shape,
+         output_shape=result.shape,
+     )
+
+     # Record reorder for each row
+     result_ids = ctx.row_manager.get_ids(result)
+     if result_ids is not None:
+         # Precompute old positions once; rebuilding list(source_ids.index)
+         # and calling .index(idx) per row made this loop O(n^2)
+         old_positions = {idx: pos for pos, idx in enumerate(source_ids.index)}
+         for new_pos, (idx, row_id) in enumerate(result_ids.items()):
+             old_pos = old_positions.get(idx)
+             if old_pos is None or old_pos == new_pos:
+                 continue
+             ctx.store.append_diff(
+                 step_id=step_id,
+                 row_id=int(row_id),
+                 col="__position__",
+                 old_val=old_pos,
+                 new_val=new_pos,
+                 change_type=ChangeType.REORDERED,
+             )
+
+
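+ # Reorders are stored as synthetic cell diffs on the "__position__"
+ # pseudo-column (illustrative): a row that moved from position 4 to
+ # position 0 yields
+ #     append_diff(col="__position__", old_val=4, new_val=0,
+ #                 change_type=ChangeType.REORDERED)
+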
+ # === COPY CAPTURE ===
+
+
+ def _capture_copy(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """
+     Capture df.copy() - propagate row IDs to the copy.
+
+     Without this, copies would lose their row identity and become untracked.
+     """
+     if not isinstance(result, pd.DataFrame):
+         return
+
+     source_ids = ctx.row_manager.get_ids(self)
+     if source_ids is None:
+         return
+
+     # Propagate IDs to the copy (same rows, new DataFrame object)
+     ctx.row_manager.propagate(self, result)
+
+
+ # === DROP CAPTURE ===
+
+
+ def _capture_drop(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """
+     Capture df.drop() - handles both row and column drops.
+
+     - Row drops (axis=0): Track as filter operation
+     - Column drops (axis=1): Track as schema change (step metadata only)
+     """
+     if not isinstance(result, pd.DataFrame):
+         return
+
+     axis = kwargs.get("axis", 0)
+     # The first positional argument of drop() is labels, not axis; older
+     # pandas allowed axis as the SECOND positional: drop(labels, axis)
+     if len(args) > 1:
+         axis = args[1]
+
+     source_ids = ctx.row_manager.get_ids(self)
+
+     if axis in (0, "index"):
+         # Row drop - similar to filter
+         if source_ids is None:
+             return
+
+         ctx.row_manager.propagate(self, result)
+         dropped_ids = ctx.row_manager.get_dropped_ids(self, result)
+
+         if len(dropped_ids) > 0:
+             labels = kwargs.get("labels") or kwargs.get("index") or (args[0] if args else None)
+             code_file, code_line = get_caller_info(skip_frames=2)
+             step_id = ctx.store.append_step(
+                 operation="DataFrame.drop",
+                 stage=ctx.current_stage,
+                 code_file=code_file,
+                 code_line=code_line,
+                 params={"axis": "index", "labels": str(labels)[:100]},
+                 input_shape=self.shape,
+                 output_shape=result.shape,
+             )
+
+             # Bulk record all drops at once
+             ctx.store.append_bulk_drops(step_id, dropped_ids)
+     else:
+         # Column drop - schema change, just propagate IDs
+         if source_ids is not None:
+             ctx.row_manager.propagate(self, result)
+
+         columns = kwargs.get("columns") or kwargs.get("labels") or (args[0] if args else None)
+         code_file, code_line = get_caller_info(skip_frames=2)
+         ctx.store.append_step(
+             operation="DataFrame.drop",
+             stage=ctx.current_stage,
+             code_file=code_file,
+             code_line=code_line,
+             params={"axis": "columns", "columns": str(columns)[:100]},
+             input_shape=self.shape,
+             output_shape=result.shape,
+         )
+
+
+ # === __getitem__ DISPATCH ===
+
+
+ def _capture_getitem(
+     self: pd.DataFrame, args, kwargs, result, ctx: TracePipeContext, method_name: str
+ ):
+     """
+     Dispatch __getitem__ based on key type.
+
+     - df['col'] -> Series (ignore)
+     - df[['a','b']] -> DataFrame column select (propagate)
+     - df[mask] -> DataFrame row filter (track drops)
+     - df[slice] -> DataFrame row slice (track drops)
+     """
+     if len(args) != 1:
+         return
+
+     key = args[0]
+
+     # Series result - column access, not row filter
+     if isinstance(result, pd.Series):
+         return
+
+     if not isinstance(result, pd.DataFrame):
+         return
+
+     # Boolean mask - row filter. is_bool_dtype() also matches pandas'
+     # nullable boolean dtype and avoids a fragile dtype identity check.
+     if isinstance(key, (pd.Series, np.ndarray)) and pd.api.types.is_bool_dtype(key):
+         # Skip if we're inside a named filter op (e.g., drop_duplicates)
+         # to avoid double-counting drops
+         if ctx._filter_op_depth > 0:
+             ctx.row_manager.propagate(self, result)
+             return
+         _capture_filter(self, args, kwargs, result, ctx, "__getitem__[mask]")
+         return
+
+     # List of columns - column selection
+     if isinstance(key, list):
+         ctx.row_manager.propagate(self, result)
+         return
+
+     # Slice - row selection
+     if isinstance(key, slice):
+         # Skip if we're inside a named filter op
+         if ctx._filter_op_depth > 0:
+             ctx.row_manager.propagate(self, result)
+             return
+         _capture_filter(self, args, kwargs, result, ctx, "__getitem__[slice]")
+         return
+
+     # Default: propagate
+     ctx.row_manager.propagate(self, result)
+
+
+ # === __setitem__ CAPTURE ===
+
+
+ def _capture_setitem_with_before(
+     self: pd.DataFrame, key: str, before_values: pd.Series, ctx: TracePipeContext
+ ):
+     """
+     Capture column assignment with before/after values.
+
+     Called after assignment completes, with before_values captured earlier.
+     Uses vectorized comparison for ~50-100x speedup over row-by-row .loc access.
+     """
+     source_ids = ctx.row_manager.get_ids(self)
+     if source_ids is None:
+         return
+
+     after_values = self[key]
+
+     # Align series to same index for vectorized comparison
+     common_index = before_values.index.intersection(after_values.index).intersection(
+         source_ids.index
+     )
+     if len(common_index) == 0:
+         return
+
+     before_aligned = before_values.reindex(common_index)
+     after_aligned = after_values.reindex(common_index)
+     ids_aligned = source_ids.reindex(common_index)
+
+     # Vectorized: find which rows changed (~50-100x faster than loop)
+     changed_mask = find_changed_indices_vectorized(before_aligned, after_aligned)
+
+     if not changed_mask.any():
+         return  # No changes, skip step creation
+
+     code_file, code_line = get_caller_info(skip_frames=2)
+     step_id = ctx.store.append_step(
+         operation=f"DataFrame.__setitem__[{key}]",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params={"column": str(key)},
+         input_shape=self.shape,
+         output_shape=self.shape,
+     )
+
+     # Extract only changed values (numpy arrays for fast access)
+     changed_indices = np.where(changed_mask)[0]
+     old_arr = before_aligned.values
+     new_arr = after_aligned.values
+     ids_arr = ids_aligned.values
+
+     # Only loop over changed rows (typically small fraction of total)
+     for i in changed_indices:
+         ctx.store.append_diff(
+             step_id=step_id,
+             row_id=int(ids_arr[i]),
+             col=key,
+             old_val=old_arr[i],
+             new_val=new_arr[i],
+             change_type=ChangeType.MODIFIED,
+         )
+
+
+ def _capture_setitem_new_column(self: pd.DataFrame, key: str, ctx: TracePipeContext):
+     """
+     Capture assignment to a new column (no before values).
+
+     Uses vectorized array access for performance.
+     """
+     source_ids = ctx.row_manager.get_ids(self)
+     if source_ids is None:
+         return
+
+     new_values = self[key]
+
+     # Align to common index
+     common_index = new_values.index.intersection(source_ids.index)
+     if len(common_index) == 0:
+         return
+
+     code_file, code_line = get_caller_info(skip_frames=2)
+     step_id = ctx.store.append_step(
+         operation=f"DataFrame.__setitem__[{key}]",
+         stage=ctx.current_stage,
+         code_file=code_file,
+         code_line=code_line,
+         params={"column": str(key), "is_new_column": True},
+         input_shape=self.shape,
+         output_shape=self.shape,
+     )
+
+     # Use numpy arrays for fast access (avoid .loc per row)
+     new_aligned = new_values.reindex(common_index)
+     ids_aligned = source_ids.reindex(common_index)
+
+     new_arr = new_aligned.values
+     ids_arr = ids_aligned.values
+
+     for i in range(len(ids_arr)):
+         ctx.store.append_diff(
+             step_id=step_id,
+             row_id=int(ids_arr[i]),
+             col=key,
+             old_val=None,
+             new_val=new_arr[i],
+             change_type=ChangeType.ADDED,
+         )
+
+
+ def _wrap_setitem(original):
+     """
+     Wrap __setitem__ to capture column assignments.
+
+     Captures BEFORE state for existing columns, then executes assignment,
+     then records the diff with actual old/new values.
+     """
+
+     @wraps(original)
+     def wrapper(self, key, value):
+         ctx = get_context()
+
+         # === CAPTURE BEFORE STATE ===
+         before_values = None
+         is_new_column = False
+         should_track = False
+
+         if ctx.enabled and isinstance(key, str):
+             if key in ctx.watched_columns:
+                 should_track = True
+                 if key in self.columns:
+                     # Existing column - capture before values
+                     try:
+                         before_values = self[key].copy()
+                     except Exception:
+                         pass
+                 else:
+                     # New column
+                     is_new_column = True
+
+         # === EXECUTE ORIGINAL ===
+         original(self, key, value)
+
+         # === CAPTURE AFTER STATE ===
+         if should_track:
+             try:
+                 if is_new_column:
+                     _capture_setitem_new_column(self, key, ctx)
+                 elif before_values is not None:
+                     _capture_setitem_with_before(self, key, before_values, ctx)
+             except Exception as e:
+                 if ctx.config.strict_mode:
+                     from ..safety import TracePipeError
+
+                     raise TracePipeError(f"__setitem__ instrumentation failed: {e}") from e
+                 else:
+                     warnings.warn(f"TracePipe: __setitem__ failed: {e}", TracePipeWarning)
+
+     return wrapper
+
+
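+ # End-to-end (illustrative, assuming "price" is a watched column):
+ #     df["price"] = df["price"] * 1.1
+ # wrapper() snapshots the old "price" Series, lets pandas perform the
+ # assignment, then _capture_setitem_with_before() records one MODIFIED diff
+ # per changed row; assigning to a brand-new column instead records ADDED
+ # diffs with old_val=None.
+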
+ # === AUTO-REGISTRATION ===
+
+
+ def _wrap_dataframe_reader(original, reader_name: str):
+     """Wrap pd.read_csv etc. to auto-register."""
+
+     @wraps(original)
+     def wrapper(*args, **kwargs):
+         result = original(*args, **kwargs)
+
+         ctx = get_context()
+         if ctx.enabled and isinstance(result, pd.DataFrame):
+             ctx.row_manager.register(result)
+
+         return result
+
+     return wrapper
+
+
+ def _wrap_dataframe_init(original):
+     """Wrap DataFrame.__init__ for auto-registration."""
+
+     @wraps(original)
+     def wrapper(self, *args, **kwargs):
+         original(self, *args, **kwargs)
+
+         ctx = get_context()
+         if ctx.enabled and ctx.row_manager.get_ids(self) is None:
+             ctx.row_manager.register(self)
+
+     return wrapper
+
+
+ # === EXPORT HOOKS (Auto-strip hidden column) ===
+
+
+ def _make_export_wrapper(original):
+     """Create a wrapper that strips hidden column before export."""
+
+     @wraps(original)
+     def wrapper(self, *args, **kwargs):
+         ctx = get_context()
+         if ctx.enabled:
+             clean_df = ctx.row_manager.strip_hidden_column(self)
+             return original(clean_df, *args, **kwargs)
+         return original(self, *args, **kwargs)
+
+     return wrapper
+
+
+ # === HELPER FUNCTIONS ===
+
+
+ def _safe_params(kwargs: dict) -> dict:
+     """Extract safe (serializable) params from kwargs."""
+     safe = {}
+     for k, v in kwargs.items():
+         if isinstance(v, (str, int, float, bool, type(None))):
+             safe[k] = v
+         elif isinstance(v, (list, tuple)) and all(
+             isinstance(x, (str, int, float, bool)) for x in v
+         ):
+             safe[k] = list(v)
+         else:
+             safe[k] = type(v).__name__
+     return safe
+
+
+ def _get_affected_columns(method_name: str, args, kwargs, all_columns: pd.Index) -> set[str]:
+     """Determine which columns are affected by an operation."""
+     if method_name == "fillna":
+         value = args[0] if args else kwargs.get("value")
+         if isinstance(value, dict):
+             return set(value.keys())
+         return set(all_columns)
+
+     elif method_name == "replace":
+         return set(all_columns)
+
+     elif method_name == "astype":
+         dtype = args[0] if args else kwargs.get("dtype")
+         if isinstance(dtype, dict):
+             return set(dtype.keys())
+         return set(all_columns)
+
+     elif method_name == "__setitem__":
+         key = args[0] if args else None
+         if isinstance(key, str):
+             return {key}
+         elif isinstance(key, list):
+             return set(key)
+
+     return set(all_columns)
+
+
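+ # _safe_params keeps only JSON-friendly values, replacing everything else
+ # with its type name (deterministic from the code above; input illustrative,
+ # assumes `import re`):
+ #     _safe_params({"subset": ["a", "b"], "thresh": 2, "regex": re.compile("x")})
+ #     -> {"subset": ["a", "b"], "thresh": 2, "regex": "Pattern"}
+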
+ def _capture_cell_changes(
+     ctx: TracePipeContext, step_id: int, before: pd.DataFrame, after: pd.DataFrame
+ ):
+     """
+     Capture cell-level changes between two DataFrames.
+
+     Uses vectorized comparison for ~50-100x speedup.
+     """
+     source_ids = ctx.row_manager.get_ids(before)
+     if source_ids is None:
+         return
+
+     cols_to_track = ctx.watched_columns & set(before.columns) & set(after.columns)
+     if not cols_to_track:
+         return
+
+     # Find common index for vectorized comparison
+     common_index = before.index.intersection(after.index).intersection(source_ids.index)
+     if len(common_index) == 0:
+         return
+
+     ids_aligned = source_ids.reindex(common_index)
+     ids_arr = ids_aligned.values
+
+     for col in cols_to_track:
+         old_aligned = before[col].reindex(common_index)
+         new_aligned = after[col].reindex(common_index)
+
+         # Vectorized change detection
+         changed_mask = find_changed_indices_vectorized(old_aligned, new_aligned)
+
+         if not changed_mask.any():
+             continue
+
+         changed_indices = np.where(changed_mask)[0]
+         old_arr = old_aligned.values
+         new_arr = new_aligned.values
+
+         for i in changed_indices:
+             ctx.store.append_diff(
+                 step_id=step_id,
+                 row_id=int(ids_arr[i]),
+                 col=col,
+                 old_val=old_arr[i],
+                 new_val=new_arr[i],
+                 change_type=ChangeType.MODIFIED,
+             )
+
+
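+ # For reference, the imported find_changed_indices_vectorized can be thought
+ # of as roughly the following (a sketch only - the real implementation lives
+ # in tracepipe/utils/value_capture and may handle more dtypes):
+ #
+ #     def _changed_mask_sketch(old: pd.Series, new: pd.Series) -> np.ndarray:
+ #         old_na, new_na = pd.isna(old.values), pd.isna(new.values)
+ #         both_nan = old_na & new_na            # NaN == NaN counts as unchanged
+ #         with np.errstate(invalid="ignore"):
+ #             neq = old.values != new.values    # elementwise, object-safe
+ #         return np.asarray(neq) & ~both_nan
+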
+ # === INSTRUMENTATION SETUP ===
+
+
+ def instrument_pandas():
+     """Install all pandas instrumentation."""
+     global _originals
+
+     if _originals:
+         # Already instrumented
+         return
+
+     # === DataFrame filter methods ===
+     # Use wrap_pandas_filter_method to prevent double-counting when
+     # methods like drop_duplicates internally call __getitem__
+     filter_methods = ["dropna", "drop_duplicates", "query", "head", "tail", "sample"]
+     for method_name in filter_methods:
+         if hasattr(pd.DataFrame, method_name):
+             original = getattr(pd.DataFrame, method_name)
+             _originals[f"DataFrame.{method_name}"] = original
+             wrapped = wrap_pandas_filter_method(method_name, original, _capture_filter)
+             setattr(pd.DataFrame, method_name, wrapped)
+
+     # === DataFrame transform methods (with inplace support) ===
+     transform_methods = ["fillna", "replace"]
+     for method_name in transform_methods:
+         if hasattr(pd.DataFrame, method_name):
+             original = getattr(pd.DataFrame, method_name)
+             _originals[f"DataFrame.{method_name}"] = original
+             wrapped = wrap_pandas_method_inplace(method_name, original, _capture_transform)
+             setattr(pd.DataFrame, method_name, wrapped)
+
+     # === astype (no inplace) ===
+     _originals["DataFrame.astype"] = pd.DataFrame.astype
+     pd.DataFrame.astype = wrap_pandas_method("astype", pd.DataFrame.astype, _capture_transform)
+
+     # === copy (preserves row identity) ===
+     _originals["DataFrame.copy"] = pd.DataFrame.copy
+     pd.DataFrame.copy = wrap_pandas_method("copy", pd.DataFrame.copy, _capture_copy)
+
+     # === drop (row/column removal) ===
+     _originals["DataFrame.drop"] = pd.DataFrame.drop
+     pd.DataFrame.drop = wrap_pandas_method("drop", pd.DataFrame.drop, _capture_drop)
+
+     # === apply/pipe ===
+     _originals["DataFrame.apply"] = pd.DataFrame.apply
+     pd.DataFrame.apply = wrap_pandas_method("apply", pd.DataFrame.apply, _capture_apply)
+
+     _originals["DataFrame.pipe"] = pd.DataFrame.pipe
+     pd.DataFrame.pipe = wrap_pandas_method("pipe", pd.DataFrame.pipe, _capture_apply)
+
+     # === groupby ===
+     _originals["DataFrame.groupby"] = pd.DataFrame.groupby
+     pd.DataFrame.groupby = wrap_pandas_method("groupby", pd.DataFrame.groupby, _capture_groupby)
+
+     # === GroupBy aggregation methods ===
+     from pandas.core.groupby import DataFrameGroupBy
+
+     for agg_method in ["agg", "aggregate", "sum", "mean", "count", "min", "max", "std", "var"]:
+         if hasattr(DataFrameGroupBy, agg_method):
+             original = getattr(DataFrameGroupBy, agg_method)
+             _originals[f"DataFrameGroupBy.{agg_method}"] = original
+             wrapped = wrap_pandas_method(agg_method, original, _capture_agg)
+             setattr(DataFrameGroupBy, agg_method, wrapped)
+
+     # === merge ===
+     _originals["DataFrame.merge"] = pd.DataFrame.merge
+     pd.DataFrame.merge = wrap_pandas_method("merge", pd.DataFrame.merge, _capture_merge)
+
+     _originals["DataFrame.join"] = pd.DataFrame.join
+     pd.DataFrame.join = wrap_pandas_method("join", pd.DataFrame.join, _capture_merge)
+
+     # === Index operations ===
+     _originals["DataFrame.reset_index"] = pd.DataFrame.reset_index
+     pd.DataFrame.reset_index = wrap_pandas_method(
+         "reset_index", pd.DataFrame.reset_index, _capture_reset_index
+     )
+
+     _originals["DataFrame.set_index"] = pd.DataFrame.set_index
+     pd.DataFrame.set_index = wrap_pandas_method(
+         "set_index", pd.DataFrame.set_index, _capture_set_index
+     )
+
+     _originals["DataFrame.sort_values"] = pd.DataFrame.sort_values
+     pd.DataFrame.sort_values = wrap_pandas_method(
+         "sort_values", pd.DataFrame.sort_values, _capture_sort_values
+     )
+
+     # === __getitem__ ===
+     _originals["DataFrame.__getitem__"] = pd.DataFrame.__getitem__
+     pd.DataFrame.__getitem__ = wrap_pandas_method(
+         "__getitem__", pd.DataFrame.__getitem__, _capture_getitem
+     )
+
+     # === __setitem__ (column assignment) ===
+     _originals["DataFrame.__setitem__"] = pd.DataFrame.__setitem__
+     pd.DataFrame.__setitem__ = _wrap_setitem(pd.DataFrame.__setitem__)
+
+     # === Readers (auto-registration) ===
+     readers = [
+         "read_csv",
+         "read_excel",
+         "read_parquet",
+         "read_json",
+         "read_sql",
+         "read_feather",
+         "read_pickle",
+     ]
+     for reader_name in readers:
+         if hasattr(pd, reader_name):
+             original = getattr(pd, reader_name)
+             _originals[f"pd.{reader_name}"] = original
+             setattr(pd, reader_name, _wrap_dataframe_reader(original, reader_name))
+
+     # === DataFrame.__init__ ===
+     _originals["DataFrame.__init__"] = pd.DataFrame.__init__
+     pd.DataFrame.__init__ = _wrap_dataframe_init(pd.DataFrame.__init__)
+
+     # === Export methods (auto-strip hidden column) ===
+     _originals["DataFrame.to_csv"] = pd.DataFrame.to_csv
+     pd.DataFrame.to_csv = _make_export_wrapper(pd.DataFrame.to_csv)
+
+     _originals["DataFrame.to_parquet"] = pd.DataFrame.to_parquet
+     pd.DataFrame.to_parquet = _make_export_wrapper(pd.DataFrame.to_parquet)
+
+     # === pd.concat ===
+     _originals["pd.concat"] = pd.concat
+
+     def wrapped_concat(*args, **kwargs):
+         result = _originals["pd.concat"](*args, **kwargs)
+         ctx = get_context()
+         if ctx.enabled:
+             _capture_concat(args, kwargs, result, ctx)
+         return result
+
+     pd.concat = wrapped_concat
+
+
+ def uninstrument_pandas():
+     """Restore original pandas methods."""
+     global _originals
+
+     for key, original in _originals.items():
+         parts = key.split(".")
+         if parts[0] == "pd":
+             setattr(pd, parts[1], original)
+         elif parts[0] == "DataFrame":
+             setattr(pd.DataFrame, parts[1], original)
+         elif parts[0] == "DataFrameGroupBy":
+             from pandas.core.groupby import DataFrameGroupBy
+
+             setattr(DataFrameGroupBy, parts[1], original)
+
+     _originals.clear()