pointblank 0.13.4__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. pointblank/__init__.py +4 -0
  2. pointblank/_constants.py +54 -0
  3. pointblank/_constants_translations.py +487 -2
  4. pointblank/_interrogation.py +182 -11
  5. pointblank/_utils.py +3 -3
  6. pointblank/_utils_ai.py +850 -0
  7. pointblank/cli.py +128 -115
  8. pointblank/column.py +1 -1
  9. pointblank/data/api-docs.txt +198 -13
  10. pointblank/data/validations/README.md +108 -0
  11. pointblank/data/validations/complex_preprocessing.json +54 -0
  12. pointblank/data/validations/complex_preprocessing.pkl +0 -0
  13. pointblank/data/validations/generate_test_files.py +127 -0
  14. pointblank/data/validations/multiple_steps.json +83 -0
  15. pointblank/data/validations/multiple_steps.pkl +0 -0
  16. pointblank/data/validations/narwhals_function.json +28 -0
  17. pointblank/data/validations/narwhals_function.pkl +0 -0
  18. pointblank/data/validations/no_preprocessing.json +83 -0
  19. pointblank/data/validations/no_preprocessing.pkl +0 -0
  20. pointblank/data/validations/pandas_compatible.json +28 -0
  21. pointblank/data/validations/pandas_compatible.pkl +0 -0
  22. pointblank/data/validations/preprocessing_functions.py +46 -0
  23. pointblank/data/validations/simple_preprocessing.json +57 -0
  24. pointblank/data/validations/simple_preprocessing.pkl +0 -0
  25. pointblank/datascan.py +4 -4
  26. pointblank/scan_profile.py +6 -6
  27. pointblank/schema.py +8 -82
  28. pointblank/thresholds.py +1 -1
  29. pointblank/validate.py +1233 -12
  30. {pointblank-0.13.4.dist-info → pointblank-0.14.0.dist-info}/METADATA +66 -8
  31. pointblank-0.14.0.dist-info/RECORD +55 -0
  32. pointblank-0.13.4.dist-info/RECORD +0 -39
  33. {pointblank-0.13.4.dist-info → pointblank-0.14.0.dist-info}/WHEEL +0 -0
  34. {pointblank-0.13.4.dist-info → pointblank-0.14.0.dist-info}/entry_points.txt +0 -0
  35. {pointblank-0.13.4.dist-info → pointblank-0.14.0.dist-info}/licenses/LICENSE +0 -0
  36. {pointblank-0.13.4.dist-info → pointblank-0.14.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,850 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ import logging
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import narwhals as nw
10
+ from narwhals.typing import FrameT
11
+
12
+ from pointblank._constants import MODEL_PROVIDERS
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ # ============================================================================
18
+ # LLM Configuration and Chat Interface
19
+ # ============================================================================
20
+
21
+
22
@dataclass
class _LLMConfig:
    """Configuration for LLM provider.

    Parameters
    ----------
    provider
        LLM provider name (e.g., 'anthropic', 'openai', 'ollama', 'bedrock').
    model
        Model name (e.g., 'claude-3-sonnet-20240229', 'gpt-4').
    api_key
        API key for the provider. If None, will be read from environment.
    """

    # Provider identifier; validated against MODEL_PROVIDERS inside
    # `_create_chat_instance()`, not at construction time.
    provider: str
    # Provider-specific model identifier, passed through to chatlas unchanged.
    model: str
    # Optional explicit API key; when None, chatlas falls back to the
    # provider's environment variable (e.g., ANTHROPIC_API_KEY).
    api_key: Optional[str] = None
39
+
40
+
41
def _create_chat_instance(provider: str, model_name: str, api_key: Optional[str] = None):
    """
    Create a chatlas chat instance for the specified provider.

    Parameters
    ----------
    provider
        The provider name (e.g., 'anthropic', 'openai', 'ollama', 'bedrock').
    model_name
        The model name for the provider.
    api_key
        Optional API key. If None, will be read from environment.

    Returns
    -------
    Chat instance from chatlas.

    Raises
    ------
    ImportError
        If `chatlas` (or the provider-specific SDK it needs) is not installed.
    ValueError
        If `provider` is not one of the supported `MODEL_PROVIDERS`.
    """
    # Check if chatlas is installed
    try:
        import chatlas  # noqa
    except ImportError:  # pragma: no cover
        raise ImportError(
            "The `chatlas` package is required for AI validation. "
            "Please install it using `pip install chatlas`."
        )

    # Validate provider
    if provider not in MODEL_PROVIDERS:
        raise ValueError(
            f"Provider '{provider}' is not supported. "
            f"Supported providers: {', '.join(MODEL_PROVIDERS)}"
        )

    # System prompt with role definition and instructions. The per-batch
    # user message (validation criteria + JSON data) is built separately by
    # `_PromptBuilder`; this prompt fixes the response contract that
    # `_ValidationResponseParser` expects (a bare JSON array of
    # {"index": ..., "result": ...} objects).
    system_prompt = """You are a data validation assistant. Your task is to analyze rows of data and determine if they meet the specified validation criteria.

INSTRUCTIONS:
- Analyze each row in the provided data
- For each row, determine if it meets the validation criteria (True) or not (False)
- Return ONLY a JSON array with validation results
- Each result should have: {"index": <row_index>, "result": <true_or_false>}
- Do not include any explanatory text, only the JSON array
- The row_index should match the "_pb_row_index" field from the input data

EXAMPLE OUTPUT FORMAT:
[
{"index": 0, "result": true},
{"index": 1, "result": false},
{"index": 2, "result": true}
]"""

    # Create provider-specific chat instance
    if provider == "anthropic":  # pragma: no cover
        # Check that the anthropic package is installed
        try:
            import anthropic  # noqa
        except ImportError:
            raise ImportError(
                "The `anthropic` package is required to use AI validation with "
                "`anthropic`. Please install it using `pip install anthropic`."
            )

        from chatlas import ChatAnthropic

        chat = ChatAnthropic(
            model=model_name,
            api_key=api_key,
            system_prompt=system_prompt,
        )

    elif provider == "openai":  # pragma: no cover
        # Check that the openai package is installed
        try:
            import openai  # noqa
        except ImportError:
            raise ImportError(
                "The `openai` package is required to use AI validation with "
                "`openai`. Please install it using `pip install openai`."
            )

        from chatlas import ChatOpenAI

        chat = ChatOpenAI(
            model=model_name,
            api_key=api_key,
            system_prompt=system_prompt,
        )

    elif provider == "ollama":  # pragma: no cover
        # Check that the openai package is installed (required for Ollama)
        try:
            import openai  # noqa
        except ImportError:
            raise ImportError(
                "The `openai` package is required to use AI validation with "
                "`ollama`. Please install it using `pip install openai`."
            )

        from chatlas import ChatOllama

        # NOTE: Ollama runs locally, so no `api_key` is passed here.
        chat = ChatOllama(
            model=model_name,
            system_prompt=system_prompt,
        )

    elif provider == "bedrock":  # pragma: no cover
        from chatlas import ChatBedrockAnthropic

        # NOTE(review): no `api_key` is forwarded — presumably AWS credentials
        # come from the standard environment/credential chain; confirm.
        chat = ChatBedrockAnthropic(
            model=model_name,
            system_prompt=system_prompt,
        )

    else:
        # Defensive: unreachable while the branches above cover every entry
        # in MODEL_PROVIDERS, but fails loudly if the two drift apart.
        raise ValueError(f"Unsupported provider: {provider}")

    return chat
158
+
159
+
160
+ # ============================================================================
161
+ # Data Batching and Optimization
162
+ # ============================================================================
163
+
164
+
165
@dataclass
class _BatchConfig:
    """Configuration for AI validation batching.

    Parameters
    ----------
    size
        Batch size for processing rows.
    max_concurrent
        Maximum number of concurrent LLM requests.
    """

    # Number of unique rows sent to the LLM per request.
    size: int = 1000
    # NOTE(review): not currently honored — `_AIValidationEngine.validate_batches`
    # processes batches sequentially; kept for future concurrency support.
    max_concurrent: int = 3
179
+
180
+
181
class _DataBatcher:
    """Optimized batching of data for AI validation with row signature memoization.

    Duplicate rows (over the selected columns) are collapsed to one
    "unique" row before being sent to the LLM; `combine_batch_results` in
    `_ValidationResponseParser` later projects each unique-row result back
    to every original row sharing the same signature.
    """

    def __init__(
        self,
        data: FrameT,
        columns: Optional[List[str]] = None,
        config: Optional[_BatchConfig] = None,
    ):
        """
        Initialize the optimized data batcher.

        Parameters
        ----------
        data
            The data frame to batch.
        columns
            Optional list of columns to include in batches. If None, all columns are included.
        config
            Optional batch configuration. If None, default configuration is used.
        """
        self.data = data
        self.columns = columns
        self.config = config or _BatchConfig()
        self._validate_data()

        # Memoization structures (populated by `create_batches()`)
        # Frame of deduplicated rows (carries a helper `_pb_signature` column).
        self.unique_rows_table = None
        # Maps md5 row signature -> list of original row indices with that signature.
        self.signature_to_original_indices = {}
        # Dedup statistics; filled in by `_build_unique_rows_table()`.
        self.reduction_stats = {}

    def _validate_data(self) -> None:
        """Validate that the data is supported.

        Raises
        ------
        ValueError
            If `data` lacks a `shape` attribute or any requested column is
            missing from the frame.
        """
        if not hasattr(self.data, "shape"):
            raise ValueError("Data must have a 'shape' attribute")

        # Get data with narwhals for compatibility
        self._nw_data = nw.from_native(self.data)

        if self.columns:
            # Validate that specified columns exist
            available_columns = self._nw_data.columns
            missing_columns = set(self.columns) - set(available_columns)
            if missing_columns:
                raise ValueError(f"Columns not found in data: {missing_columns}")

    def _create_row_signature(self, row_dict: Dict[str, Any]) -> str:
        """
        Create a unique signature for a row based on selected columns.

        Parameters
        ----------
        row_dict
            Dictionary representing a row.

        Returns
        -------
        str
            Unique signature for the row.
        """
        # Create deterministic signature from sorted column values. The helper
        # index column is excluded so identical data hashes the same regardless
        # of row position; md5 is used as a fast non-cryptographic fingerprint.
        # `default=str` makes non-JSON values (dates, etc.) serializable.
        signature_data = {k: v for k, v in row_dict.items() if k != "_pb_row_index"}
        signature_str = json.dumps(signature_data, sort_keys=True, default=str)
        return hashlib.md5(signature_str.encode()).hexdigest()

    def _build_unique_rows_table(self) -> Tuple[FrameT, Dict[str, List[int]]]:
        """
        Build unique rows table and mapping back to original indices.

        Returns
        -------
        Tuple[FrameT, Dict[str, List[int]]]
            Unique rows table and signature-to-indices mapping.
        """
        nw_data = self._nw_data

        # Select columns if specified
        if self.columns:
            nw_data = nw_data.select(self.columns)

        # Convert to native for easier manipulation
        native_data = nw.to_native(nw_data)

        # Get all rows as dictionaries
        if hasattr(native_data, "to_dicts"):
            # Polars DataFrame
            all_rows = native_data.to_dicts()
        elif hasattr(native_data, "to_dict"):
            # Pandas DataFrame
            all_rows = native_data.to_dict("records")
        else:  # pragma: no cover
            # Fallback: manual conversion
            all_rows = []
            columns = nw_data.columns
            for i in range(len(native_data)):
                row_dict = {}
                for col in columns:
                    row_dict[col] = (
                        native_data[col].iloc[i]
                        if hasattr(native_data[col], "iloc")
                        else native_data[col][i]
                    )
                all_rows.append(row_dict)

        # Build signature mapping
        signature_to_indices = {}
        unique_rows = []
        unique_signatures = set()

        for original_idx, row_dict in enumerate(all_rows):
            signature = self._create_row_signature(row_dict)

            if signature not in signature_to_indices:
                signature_to_indices[signature] = []

            signature_to_indices[signature].append(original_idx)

            # Add to unique rows if not seen before. Insertion order is a
            # load-bearing invariant: `combine_batch_results` assumes that
            # unique-row index i corresponds to the i-th key of
            # `signature_to_indices` (dicts preserve insertion order).
            if signature not in unique_signatures:
                unique_signatures.add(signature)
                # Add signature tracking for later joining
                row_dict["_pb_signature"] = signature
                unique_rows.append(row_dict)

        # Convert unique rows back to dataframe
        if unique_rows:
            if hasattr(native_data, "with_columns"):  # Polars
                import polars as pl

                unique_df = pl.DataFrame(unique_rows)
            elif hasattr(native_data, "assign"):  # Pandas
                import pandas as pd

                unique_df = pd.DataFrame(unique_rows)
            else:  # pragma: no cover
                # This is tricky for generic case, but let's try
                unique_df = unique_rows  # Fallback to list of dicts
        else:  # pragma: no cover
            unique_df = native_data.head(0)  # Empty dataframe with same structure

        # Store reduction stats
        original_count = len(all_rows)
        unique_count = len(unique_rows)
        reduction_pct = (1 - unique_count / original_count) * 100 if original_count > 0 else 0

        self.reduction_stats = {
            "original_rows": original_count,
            "unique_rows": unique_count,
            "reduction_percentage": reduction_pct,
        }

        logger.info(
            f"Row signature optimization: {original_count} → {unique_count} rows ({reduction_pct:.1f}% reduction)"
        )

        return unique_df, signature_to_indices

    def create_batches(self) -> Tuple[List[Dict[str, Any]], Dict[str, List[int]]]:
        """
        Create optimized batches using unique row signatures.

        Returns
        -------
        Tuple[List[Dict[str, Any]], Dict[str, List[int]]]
            Batches for unique rows and signature-to-indices mapping.
        """
        # Build unique rows table and signature mapping
        unique_rows_table, signature_to_indices = self._build_unique_rows_table()
        self.unique_rows_table = unique_rows_table
        self.signature_to_original_indices = signature_to_indices

        # Create batches from unique rows table
        if hasattr(unique_rows_table, "shape"):
            total_rows = unique_rows_table.shape[0]
        else:  # pragma: no cover
            total_rows = len(unique_rows_table)

        batches = []
        batch_id = 0

        # Convert to narwhals if needed
        if not hasattr(unique_rows_table, "columns"):  # pragma: no cover
            nw_unique = nw.from_native(unique_rows_table)
        else:
            nw_unique = unique_rows_table

        # Slice unique rows into contiguous chunks of `config.size` rows.
        for start_row in range(0, total_rows, self.config.size):
            end_row = min(start_row + self.config.size, total_rows)

            # Get the batch data
            if hasattr(nw_unique, "__getitem__"):
                batch_data = nw_unique[start_row:end_row]
            else:  # pragma: no cover
                # Fallback for list of dicts
                batch_data = unique_rows_table[start_row:end_row]

            # Convert to JSON-serializable format
            batch_json = self._convert_batch_to_json(batch_data, start_row)

            batches.append(
                {
                    "batch_id": batch_id,
                    "start_row": start_row,
                    "end_row": end_row,
                    "data": batch_json,
                }
            )

            batch_id += 1

        logger.info(f"Created {len(batches)} batches for {total_rows} unique rows")
        return batches, signature_to_indices

    def _convert_batch_to_json(self, batch_data, start_row: int) -> Dict[str, Any]:
        """
        Convert a batch of unique data to JSON format for LLM consumption.

        The `_pb_signature` helper column is stripped from what the LLM sees,
        and each row gains a `_pb_row_index` (its global unique-row index,
        `start_row + i`), which is what the LLM echoes back in its results.
        """
        # Handle different input types
        if isinstance(batch_data, list):
            # List of dictionaries
            rows = []
            columns = list(batch_data[0].keys()) if batch_data else []

            for i, row_dict in enumerate(batch_data):
                # Remove signature column from LLM input but keep for joining
                clean_row = {k: v for k, v in row_dict.items() if k != "_pb_signature"}
                clean_row["_pb_row_index"] = start_row + i
                rows.append(clean_row)

            # Remove signature from columns list
            columns = [col for col in columns if col != "_pb_signature"]

        else:
            # DataFrame-like object
            columns = [col for col in batch_data.columns if col != "_pb_signature"]
            rows = []

            # batch_data is already native format from slicing
            native_batch = batch_data

            # Handle different data frame types
            if hasattr(native_batch, "to_dicts"):
                # Polars DataFrame
                batch_dicts = native_batch.to_dicts()
            elif hasattr(native_batch, "to_dict"):
                # Pandas DataFrame
                batch_dicts = native_batch.to_dict("records")
            else:  # pragma: no cover
                # Fallback: manual conversion
                batch_dicts = []
                for i in range(len(native_batch)):
                    row_dict = {}
                    for col in columns:
                        row_dict[col] = (
                            native_batch[col].iloc[i]
                            if hasattr(native_batch[col], "iloc")
                            else native_batch[col][i]
                        )
                    batch_dicts.append(row_dict)

            # Clean up rows and add indices
            for i, row_dict in enumerate(batch_dicts):
                clean_row = {k: v for k, v in row_dict.items() if k != "_pb_signature"}
                clean_row["_pb_row_index"] = start_row + i
                rows.append(clean_row)

        return {
            "columns": columns,
            "rows": rows,
            "batch_info": {
                "start_row": start_row,
                "num_rows": len(rows),
                "columns_count": len(columns),
            },
        }

    def get_reduction_stats(self) -> Dict[str, Any]:
        """Get statistics about the row reduction optimization.

        Returns a copy so callers cannot mutate the batcher's internal state.
        Empty until `create_batches()` has run.
        """
        return self.reduction_stats.copy()
460
+
461
+
462
+ # ============================================================================
463
+ # Prompt Building
464
+ # ============================================================================
465
+
466
+
467
+ class _PromptBuilder:
468
+ """
469
+ Builds user messages for AI validation.
470
+
471
+ Works in conjunction with the system prompt set on the chat instance.
472
+ The system prompt contains role definition and general instructions,
473
+ while this class builds user messages with specific validation criteria and data.
474
+ """
475
+
476
+ USER_MESSAGE_TEMPLATE = """VALIDATION CRITERIA:
477
+ {user_prompt}
478
+
479
+ DATA TO VALIDATE:
480
+ {data_json}"""
481
+
482
+ def __init__(self, user_prompt: str):
483
+ """
484
+ Initialize the prompt builder.
485
+
486
+ Parameters
487
+ ----------
488
+ user_prompt
489
+ The user's validation prompt describing what to check.
490
+ """
491
+ self.user_prompt = user_prompt
492
+
493
+ def build_prompt(self, batch_data: Dict[str, Any]) -> str:
494
+ """
495
+ Build a user message for a data batch.
496
+
497
+ Parameters
498
+ ----------
499
+ batch_data
500
+ The batch data dictionary from DataBatcher.
501
+
502
+ Returns
503
+ -------
504
+ str
505
+ The user message for the LLM.
506
+ """
507
+ data_json = json.dumps(batch_data, indent=2, default=str)
508
+
509
+ return self.USER_MESSAGE_TEMPLATE.format(user_prompt=self.user_prompt, data_json=data_json)
510
+
511
+
512
+ # ============================================================================
513
+ # Response Parsing
514
+ # ============================================================================
515
+
516
+
517
class _ValidationResponseParser:
    """Parses AI validation responses.

    Parsing is fail-safe: any response that cannot be parsed or validated
    yields all-False results for the affected batch rather than raising.
    """

    def __init__(self, total_rows: int):
        """
        Initialize the response parser.

        Parameters
        ----------
        total_rows
            Total number of rows being validated.
        """
        # NOTE(review): `total_rows` is stored but not read by any method in
        # this class; callers currently pass a placeholder (1000).
        self.total_rows = total_rows

    def parse_response(self, response: str, batch_info: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Parse an LLM response for a batch.

        Parameters
        ----------
        response
            The raw response from the LLM.
        batch_info
            Information about the batch being processed.

        Returns
        -------
        List[Dict[str, Any]]
            List of parsed results with 'index' and 'result' keys.
            On any parse/validation failure, all rows in the batch are
            marked False (see `_create_default_results`).
        """
        try:
            # Try to extract JSON from the response
            json_response = self._extract_json(response)

            # Validate the structure
            self._validate_response_structure(json_response, batch_info)

            return json_response

        except Exception as e:
            logger.error(
                f"Failed to parse response for batch {batch_info.get('batch_id', 'unknown')}: {e}"
            )
            logger.error(f"Raw response: {response}")

            # Return default results (all False) for this batch
            return self._create_default_results(batch_info)

    def _extract_json(self, response: str) -> List[Dict[str, Any]]:
        """Extract JSON from LLM response.

        Looks for the first bracketed span in the response; if none parses,
        falls back to parsing the whole response. Raises on invalid JSON
        (handled by `parse_response`).
        """
        # Clean up the response
        response = response.strip()

        # Look for JSON array patterns
        import re

        # Non-greedy: matches up to the first ']'. Safe for the expected
        # flat array of objects, since object values are booleans/ints.
        json_pattern = r"\[.*?\]"
        matches = re.findall(json_pattern, response, re.DOTALL)

        if matches:
            # Try to parse the first match
            try:
                return json.loads(matches[0])
            except json.JSONDecodeError:  # pragma: no cover
                # If that fails, try the raw response
                return json.loads(response)
        else:
            # Try to parse the raw response
            return json.loads(response)

    def _validate_response_structure(
        self, json_response: List[Dict[str, Any]], batch_info: Dict[str, Any]
    ) -> None:
        """Validate that the response has the correct structure.

        Raises ValueError on structural problems; a row-count mismatch only
        logs a warning, since partial results are still usable.
        """
        if not isinstance(json_response, list):
            raise ValueError("Response must be a JSON array")

        expected_rows = batch_info["end_row"] - batch_info["start_row"]
        if len(json_response) != expected_rows:
            logger.warning(
                f"Expected {expected_rows} results, got {len(json_response)} for batch {batch_info.get('batch_id')}"
            )

        for i, result in enumerate(json_response):
            if not isinstance(result, dict):
                raise ValueError(f"Result {i} must be a dictionary")

            if "index" not in result or "result" not in result:
                raise ValueError(f"Result {i} must have 'index' and 'result' keys")

            if not isinstance(result["result"], bool):
                raise ValueError(f"Result {i} 'result' must be a boolean")

    def _create_default_results(self, batch_info: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Create default results (all False) for a batch."""
        results = []
        for i in range(batch_info["start_row"], batch_info["end_row"]):
            results.append({"index": i, "result": False})
        return results

    def combine_batch_results(
        self,
        batch_results: List[List[Dict[str, Any]]],
        signature_mapping: Optional[Dict[str, List[int]]] = None,
    ) -> Dict[int, bool]:
        """
        Combine results from multiple batches and project to original rows using signature mapping.

        Parameters
        ----------
        batch_results
            List of batch results from parse_response.
        signature_mapping
            Optional mapping from row signatures to original row indices for memoization.

        Returns
        -------
        Dict[int, bool]
            Dictionary mapping original row index to validation result.
        """
        logger.debug(f"🔀 Combining results from {len(batch_results)} batches")

        if signature_mapping:
            logger.debug(
                f"🎯 Using signature mapping optimization for {len(signature_mapping)} unique signatures"
            )

        # First, collect results from unique rows
        unique_results = {}
        total_processed = 0

        for batch_idx, batch_result in enumerate(batch_results):
            logger.debug(f"  Batch {batch_idx}: {len(batch_result)} results")

            for result in batch_result:
                index = result["index"]
                validation_result = result["result"]
                # Later batches overwrite on duplicate indices (last wins).
                unique_results[index] = validation_result
                total_processed += 1

                # Log first few results
                if len(unique_results) <= 3 or total_processed <= 3:
                    logger.debug(f"    Unique row {index}: {validation_result}")

        # If no signature mapping, return unique results as-is (fallback to original behavior)
        if not signature_mapping:
            logger.debug("📊 No signature mapping - returning unique results")
            passed_count = sum(1 for v in unique_results.values() if v)
            failed_count = len(unique_results) - passed_count
            logger.debug(f"  - Final count: {passed_count} passed, {failed_count} failed")
            return unique_results

        # Project unique results back to all original rows using signature mapping
        combined_results = {}

        # We need to map from unique row indices back to signatures, then to original indices
        # This requires rebuilding the signatures from the unique rows
        # For now, let's assume the unique_results indices correspond to signature order.
        # NOTE(review): this relies on `_DataBatcher._build_unique_rows_table`
        # inserting signatures in first-seen order — both sides must stay in sync.
        signature_list = list(signature_mapping.keys())

        for unique_idx, validation_result in unique_results.items():
            if unique_idx < len(signature_list):
                signature = signature_list[unique_idx]
                original_indices = signature_mapping[signature]

                logger.debug(
                    f"  Projecting result {validation_result} from unique row {unique_idx} to {len(original_indices)} original rows"
                )

                # Project this result to all original rows with this signature
                for original_idx in original_indices:
                    combined_results[original_idx] = validation_result
            else:  # pragma: no cover
                logger.warning(f"Unique index {unique_idx} out of range for signature mapping")

        logger.debug("📊 Projected results summary:")
        logger.debug(f"  - Unique rows processed: {len(unique_results)}")
        logger.debug(f"  - Original rows mapped: {len(combined_results)}")
        logger.debug(
            f"  - Index range: {min(combined_results.keys()) if combined_results else 'N/A'} to {max(combined_results.keys()) if combined_results else 'N/A'}"
        )

        passed_count = sum(1 for v in combined_results.values() if v)
        failed_count = len(combined_results) - passed_count
        logger.debug(f"  - Final count: {passed_count} passed, {failed_count} failed")

        return combined_results
704
+
705
+
706
+ # ============================================================================
707
+ # AI Validation Engine
708
+ # ============================================================================
709
+
710
+
711
class _AIValidationEngine:
    """Main engine for AI-powered validation using chatlas.

    Owns a single chat instance (with the validation system prompt already
    set) and drives prompt construction, LLM calls, and response parsing
    for each batch. Failures are contained per batch: a failed batch yields
    all-False results instead of aborting the run.
    """

    def __init__(self, llm_config: _LLMConfig):
        """
        Initialize the AI validation engine.

        Parameters
        ----------
        llm_config
            Configuration for the LLM provider.

        Raises
        ------
        ImportError, ValueError
            Propagated from `_create_chat_instance` for missing packages or
            unsupported providers.
        """
        self.llm_config = llm_config
        self.chat = _create_chat_instance(
            provider=llm_config.provider, model_name=llm_config.model, api_key=llm_config.api_key
        )

    def validate_batches(
        self, batches: List[Dict[str, Any]], prompt_builder: Any, max_concurrent: int = 3
    ) -> List[List[Dict[str, Any]]]:
        """
        Validate multiple batches.

        Parameters
        ----------
        batches
            List of batch dictionaries from DataBatcher.
        prompt_builder
            PromptBuilder instance for generating prompts.
        max_concurrent
            Maximum number of concurrent requests (ignored for now with chatlas).

        Returns
        -------
        List[List[Dict[str, Any]]]
            List of batch results, each containing validation results.
        """

        def validate_batch(batch: Dict[str, Any]) -> List[Dict[str, Any]]:
            # Validate one batch; on any error, return all-False results so a
            # single failed LLM call cannot abort the whole run.
            try:
                # Debug: Log batch information
                logger.debug(f"🔍 Processing batch {batch['batch_id']}")
                logger.debug(f"  - Rows: {batch['start_row']} to {batch['end_row'] - 1}")
                logger.debug(
                    f"  - Data shape: {batch['data'].shape if hasattr(batch['data'], 'shape') else 'N/A'}"
                )

                # Build the prompt for this batch
                prompt = prompt_builder.build_prompt(batch["data"])

                # Debug: Log the prompt being sent to LLM
                logger.debug(f"📤 LLM Prompt for batch {batch['batch_id']}:")
                logger.debug("--- PROMPT START ---")
                logger.debug(prompt)
                logger.debug("--- PROMPT END ---")

                # Get response from LLM using chatlas (synchronous)
                response = str(self.chat.chat(prompt, stream=False, echo="none"))

                # Debug: Log the raw LLM response
                logger.debug(f"📥 LLM Response for batch {batch['batch_id']}:")
                logger.debug("--- RESPONSE START ---")
                logger.debug(response)
                logger.debug("--- RESPONSE END ---")

                # Parse the response
                parser = _ValidationResponseParser(total_rows=1000)  # This will be set properly
                results = parser.parse_response(response, batch)

                # Debug: Log parsed results  # pragma: no cover
                logger.debug(f"📊 Parsed results for batch {batch['batch_id']}:")
                for i, result in enumerate(results[:5]):  # Show first 5 results
                    logger.debug(f"  Row {result['index']}: {result['result']}")
                if len(results) > 5:
                    logger.debug(f"  ... and {len(results) - 5} more results")

                passed_count = sum(1 for r in results if r["result"])
                failed_count = len(results) - passed_count
                logger.debug(f"  Summary: {passed_count} passed, {failed_count} failed")

                logger.info(f"Successfully validated batch {batch['batch_id']}")
                return results

            except Exception as e:  # pragma: no cover
                logger.error(
                    f"Failed to validate batch {batch['batch_id']}: {e}"
                )  # pragma: no cover
                # Return default results (all False) for failed batches
                default_results = []  # pragma: no cover
                for i in range(batch["start_row"], batch["end_row"]):  # pragma: no cover
                    default_results.append({"index": i, "result": False})  # pragma: no cover
                return default_results  # pragma: no cover

        # Execute all batch validations sequentially (chatlas is synchronous);
        # `max_concurrent` is accepted for interface stability but unused here.
        final_results = []
        for batch in batches:
            result = validate_batch(batch)
            final_results.append(result)

        return final_results

    def validate_single_batch(
        self, batch: Dict[str, Any], prompt_builder: Any
    ) -> List[Dict[str, Any]]:
        """
        Validate a single batch.

        Parameters
        ----------
        batch
            Batch dictionary from DataBatcher.
        prompt_builder
            PromptBuilder instance for generating prompts.

        Returns
        -------
        List[Dict[str, Any]]
            Validation results for the batch; all-False defaults on failure.
        """
        try:
            # Build the prompt for this batch
            prompt = prompt_builder.build_prompt(batch["data"])

            # Get response from LLM using chatlas (synchronous)
            response = str(self.chat.chat(prompt, stream=False, echo="none"))

            # Parse the response
            parser = _ValidationResponseParser(total_rows=1000)  # This will be set properly
            results = parser.parse_response(response, batch)

            logger.info(f"Successfully validated batch {batch['batch_id']}")
            return results

        except Exception as e:
            logger.error(f"Failed to validate batch {batch['batch_id']}: {e}")
            # Return default results (all False) for failed batch
            default_results = []
            for i in range(batch["start_row"], batch["end_row"]):
                default_results.append({"index": i, "result": False})
            return default_results