pointblank 0.13.4__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pointblank/__init__.py +4 -0
- pointblank/_constants.py +117 -0
- pointblank/_constants_translations.py +487 -2
- pointblank/_interrogation.py +1065 -12
- pointblank/_spec_utils.py +1015 -0
- pointblank/_utils.py +17 -7
- pointblank/_utils_ai.py +875 -0
- pointblank/assistant.py +1 -1
- pointblank/cli.py +128 -115
- pointblank/column.py +1 -1
- pointblank/data/api-docs.txt +1838 -130
- pointblank/data/validations/README.md +108 -0
- pointblank/data/validations/complex_preprocessing.json +54 -0
- pointblank/data/validations/complex_preprocessing.pkl +0 -0
- pointblank/data/validations/generate_test_files.py +127 -0
- pointblank/data/validations/multiple_steps.json +83 -0
- pointblank/data/validations/multiple_steps.pkl +0 -0
- pointblank/data/validations/narwhals_function.json +28 -0
- pointblank/data/validations/narwhals_function.pkl +0 -0
- pointblank/data/validations/no_preprocessing.json +83 -0
- pointblank/data/validations/no_preprocessing.pkl +0 -0
- pointblank/data/validations/pandas_compatible.json +28 -0
- pointblank/data/validations/pandas_compatible.pkl +0 -0
- pointblank/data/validations/preprocessing_functions.py +46 -0
- pointblank/data/validations/simple_preprocessing.json +57 -0
- pointblank/data/validations/simple_preprocessing.pkl +0 -0
- pointblank/datascan.py +4 -4
- pointblank/draft.py +52 -3
- pointblank/scan_profile.py +6 -6
- pointblank/schema.py +8 -82
- pointblank/thresholds.py +1 -1
- pointblank/validate.py +3069 -437
- {pointblank-0.13.4.dist-info → pointblank-0.15.0.dist-info}/METADATA +67 -8
- pointblank-0.15.0.dist-info/RECORD +56 -0
- pointblank-0.13.4.dist-info/RECORD +0 -39
- {pointblank-0.13.4.dist-info → pointblank-0.15.0.dist-info}/WHEEL +0 -0
- {pointblank-0.13.4.dist-info → pointblank-0.15.0.dist-info}/entry_points.txt +0 -0
- {pointblank-0.13.4.dist-info → pointblank-0.15.0.dist-info}/licenses/LICENSE +0 -0
- {pointblank-0.13.4.dist-info → pointblank-0.15.0.dist-info}/top_level.txt +0 -0
pointblank/_utils_ai.py
ADDED
@@ -0,0 +1,875 @@
from __future__ import annotations

import hashlib
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import narwhals as nw
from narwhals.typing import FrameT

from pointblank._constants import MODEL_PROVIDERS

logger = logging.getLogger(__name__)


# ============================================================================
# LLM Configuration and Chat Interface
# ============================================================================


@dataclass
class _LLMConfig:
    """Configuration for LLM provider.

    Parameters
    ----------
    provider
        LLM provider name (e.g., 'anthropic', 'openai', 'ollama', 'bedrock').
    model
        Model name (e.g., 'claude-sonnet-4-5', 'gpt-4').
    api_key
        API key for the provider. If None, will be read from environment.
    verify_ssl
        Whether to verify SSL certificates when making requests. Defaults to True.
    """

    provider: str
    model: str
    api_key: Optional[str] = None
    verify_ssl: bool = True

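For orientation, a minimal sketch of constructing this private config by hand (illustrative only; the real call sites elsewhere in the package are not shown in this hunk, and the model name below is a placeholder):

    # Illustrative sketch, not from the package source: build an _LLMConfig directly.
    # With api_key=None the key is expected to come from the provider's environment variable.
    from pointblank._utils_ai import _LLMConfig

    cfg = _LLMConfig(provider="ollama", model="llama3.2", verify_ssl=True)
    assert cfg.api_key is None and cfg.verify_ssl is True
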
def _create_chat_instance(
    provider: str, model_name: str, api_key: Optional[str] = None, verify_ssl: bool = True
):
    """
    Create a chatlas chat instance for the specified provider.

    Parameters
    ----------
    provider
        The provider name (e.g., 'anthropic', 'openai', 'ollama', 'bedrock').
    model_name
        The model name for the provider.
    api_key
        Optional API key. If None, will be read from environment.
    verify_ssl
        Whether to verify SSL certificates when making requests. Defaults to True.

    Returns
    -------
    Chat instance from chatlas.
    """
    # Check if chatlas is installed
    try:
        import chatlas  # noqa
    except ImportError:  # pragma: no cover
        raise ImportError(
            "The `chatlas` package is required for AI validation. "
            "Please install it using `pip install chatlas`."
        )

    # Validate provider
    if provider not in MODEL_PROVIDERS:
        raise ValueError(
            f"Provider '{provider}' is not supported. "
            f"Supported providers: {', '.join(MODEL_PROVIDERS)}"
        )

    # System prompt with role definition and instructions
    system_prompt = """You are a data validation assistant. Your task is to analyze rows of data and determine if they meet the specified validation criteria.

INSTRUCTIONS:
- Analyze each row in the provided data
- For each row, determine if it meets the validation criteria (True) or not (False)
- Return ONLY a JSON array with validation results
- Each result should have: {"index": <row_index>, "result": <true_or_false>}
- Do not include any explanatory text, only the JSON array
- The row_index should match the "_pb_row_index" field from the input data

EXAMPLE OUTPUT FORMAT:
[
{"index": 0, "result": true},
{"index": 1, "result": false},
{"index": 2, "result": true}
]"""

    # Create httpx client with SSL verification settings
    try:
        import httpx  # noqa
    except ImportError:  # pragma: no cover
        raise ImportError(  # pragma: no cover
            "The `httpx` package is required for SSL configuration. "
            "Please install it using `pip install httpx`."
        )

    http_client = httpx.AsyncClient(verify=verify_ssl)

    # Create provider-specific chat instance
    if provider == "anthropic":  # pragma: no cover
        # Check that the anthropic package is installed
        try:
            import anthropic  # noqa
        except ImportError:
            raise ImportError(
                "The `anthropic` package is required to use AI validation with "
                "`anthropic`. Please install it using `pip install anthropic`."
            )

        from chatlas import ChatAnthropic

        chat = ChatAnthropic(
            model=model_name,
            api_key=api_key,
            system_prompt=system_prompt,
            kwargs={"http_client": http_client},
        )

    elif provider == "openai":  # pragma: no cover
        # Check that the openai package is installed
        try:
            import openai  # noqa
        except ImportError:
            raise ImportError(
                "The `openai` package is required to use AI validation with "
                "`openai`. Please install it using `pip install openai`."
            )

        from chatlas import ChatOpenAI

        chat = ChatOpenAI(
            model=model_name,
            api_key=api_key,
            system_prompt=system_prompt,
            kwargs={"http_client": http_client},
        )

    elif provider == "ollama":  # pragma: no cover
        # Check that the openai package is installed (required for Ollama)
        try:
            import openai  # noqa
        except ImportError:
            raise ImportError(
                "The `openai` package is required to use AI validation with "
                "`ollama`. Please install it using `pip install openai`."
            )

        from chatlas import ChatOllama

        chat = ChatOllama(
            model=model_name,
            system_prompt=system_prompt,
            kwargs={"http_client": http_client},
        )

    elif provider == "bedrock":  # pragma: no cover
        from chatlas import ChatBedrockAnthropic

        chat = ChatBedrockAnthropic(
            model=model_name,
            system_prompt=system_prompt,
            kwargs={"http_client": http_client},
        )

    else:
        raise ValueError(f"Unsupported provider: {provider}")

    return chat

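The system prompt above pins the model to a strict output contract: a bare JSON array of {"index": ..., "result": ...} objects whose indices echo the "_pb_row_index" values sent with the data. A self-contained sketch of that contract (illustrative; no chat instance or network call involved):

    import json

    # Illustrative sketch of the response shape the system prompt requests.
    example_response = '[{"index": 0, "result": true}, {"index": 1, "result": false}]'
    parsed = json.loads(example_response)
    assert all({"index", "result"} <= item.keys() and isinstance(item["result"], bool) for item in parsed)
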
# ============================================================================
# Data Batching and Optimization
# ============================================================================


@dataclass
class _BatchConfig:
    """Configuration for AI validation batching.

    Parameters
    ----------
    size
        Batch size for processing rows.
    max_concurrent
        Maximum number of concurrent LLM requests.
    """

    size: int = 1000
    max_concurrent: int = 3

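With the defaults above, the number of LLM round trips is just the unique-row count divided by `size`, rounded up; `max_concurrent` is carried along but, as noted later in `validate_batches()`, requests are currently issued sequentially. A one-line sketch of that arithmetic (illustrative figures):

    import math

    from pointblank._utils_ai import _BatchConfig

    # Illustrative: the default size of 1000 turns 2500 unique rows into 3 batches.
    assert math.ceil(2500 / _BatchConfig().size) == 3
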
class _DataBatcher:
    """Optimized batching of data for AI validation with row signature memoization."""

    def __init__(
        self,
        data: FrameT,
        columns: Optional[List[str]] = None,
        config: Optional[_BatchConfig] = None,
    ):
        """
        Initialize the optimized data batcher.

        Parameters
        ----------
        data
            The data frame to batch.
        columns
            Optional list of columns to include in batches. If None, all columns are included.
        config
            Optional batch configuration. If None, default configuration is used.
        """
        self.data = data
        self.columns = columns
        self.config = config or _BatchConfig()
        self._validate_data()

        # Memoization structures
        self.unique_rows_table = None
        self.signature_to_original_indices = {}
        self.reduction_stats = {}

    def _validate_data(self) -> None:
        """Validate that the data is supported."""
        if not hasattr(self.data, "shape"):
            raise ValueError("Data must have a 'shape' attribute")

        # Get data with narwhals for compatibility
        self._nw_data = nw.from_native(self.data)

        if self.columns:
            # Validate that specified columns exist
            available_columns = self._nw_data.columns
            missing_columns = set(self.columns) - set(available_columns)
            if missing_columns:
                raise ValueError(f"Columns not found in data: {missing_columns}")

    def _create_row_signature(self, row_dict: Dict[str, Any]) -> str:
        """
        Create a unique signature for a row based on selected columns.

        Parameters
        ----------
        row_dict
            Dictionary representing a row.

        Returns
        -------
        str
            Unique signature for the row.
        """
        # Create deterministic signature from sorted column values
        signature_data = {k: v for k, v in row_dict.items() if k != "_pb_row_index"}
        signature_str = json.dumps(signature_data, sort_keys=True, default=str)
        return hashlib.md5(signature_str.encode()).hexdigest()

    def _build_unique_rows_table(self) -> Tuple[FrameT, Dict[str, List[int]]]:
        """
        Build unique rows table and mapping back to original indices.

        Returns
        -------
        Tuple[FrameT, Dict[str, List[int]]]
            Unique rows table and signature-to-indices mapping.
        """
        nw_data = self._nw_data

        # Select columns if specified
        if self.columns:
            nw_data = nw_data.select(self.columns)

        # Convert to native for easier manipulation
        native_data = nw.to_native(nw_data)

        # Get all rows as dictionaries
        if hasattr(native_data, "to_dicts"):
            # Polars DataFrame
            all_rows = native_data.to_dicts()
        elif hasattr(native_data, "to_dict"):
            # Pandas DataFrame
            all_rows = native_data.to_dict("records")
        else:  # pragma: no cover
            # Fallback: manual conversion
            all_rows = []
            columns = nw_data.columns
            for i in range(len(native_data)):
                row_dict = {}
                for col in columns:
                    row_dict[col] = (
                        native_data[col].iloc[i]
                        if hasattr(native_data[col], "iloc")
                        else native_data[col][i]
                    )
                all_rows.append(row_dict)

        # Build signature mapping
        signature_to_indices = {}
        unique_rows = []
        unique_signatures = set()

        for original_idx, row_dict in enumerate(all_rows):
            signature = self._create_row_signature(row_dict)

            if signature not in signature_to_indices:
                signature_to_indices[signature] = []

            signature_to_indices[signature].append(original_idx)

            # Add to unique rows if not seen before
            if signature not in unique_signatures:
                unique_signatures.add(signature)
                # Add signature tracking for later joining
                row_dict["_pb_signature"] = signature
                unique_rows.append(row_dict)

        # Convert unique rows back to dataframe
        if unique_rows:
            if hasattr(native_data, "with_columns"):  # Polars
                import polars as pl

                unique_df = pl.DataFrame(unique_rows)
            elif hasattr(native_data, "assign"):  # Pandas
                import pandas as pd

                unique_df = pd.DataFrame(unique_rows)
            else:  # pragma: no cover
                # This is tricky for generic case, but let's try
                unique_df = unique_rows  # Fallback to list of dicts
        else:  # pragma: no cover
            unique_df = native_data.head(0)  # Empty dataframe with same structure

        # Store reduction stats
        original_count = len(all_rows)
        unique_count = len(unique_rows)
        reduction_pct = (1 - unique_count / original_count) * 100 if original_count > 0 else 0

        self.reduction_stats = {
            "original_rows": original_count,
            "unique_rows": unique_count,
            "reduction_percentage": reduction_pct,
        }

        logger.info(
            f"Row signature optimization: {original_count} → {unique_count} rows ({reduction_pct:.1f}% reduction)"
        )

        return unique_df, signature_to_indices

    def create_batches(self) -> Tuple[List[Dict[str, Any]], Dict[str, List[int]]]:
        """
        Create optimized batches using unique row signatures.

        Returns
        -------
        Tuple[List[Dict[str, Any]], Dict[str, List[int]]]
            Batches for unique rows and signature-to-indices mapping.
        """
        # Build unique rows table and signature mapping
        unique_rows_table, signature_to_indices = self._build_unique_rows_table()
        self.unique_rows_table = unique_rows_table
        self.signature_to_original_indices = signature_to_indices

        # Create batches from unique rows table
        if hasattr(unique_rows_table, "shape"):
            total_rows = unique_rows_table.shape[0]
        else:  # pragma: no cover
            total_rows = len(unique_rows_table)

        batches = []
        batch_id = 0

        # Convert to narwhals if needed
        if not hasattr(unique_rows_table, "columns"):  # pragma: no cover
            nw_unique = nw.from_native(unique_rows_table)
        else:
            nw_unique = unique_rows_table

        for start_row in range(0, total_rows, self.config.size):
            end_row = min(start_row + self.config.size, total_rows)

            # Get the batch data
            if hasattr(nw_unique, "__getitem__"):
                batch_data = nw_unique[start_row:end_row]
            else:  # pragma: no cover
                # Fallback for list of dicts
                batch_data = unique_rows_table[start_row:end_row]

            # Convert to JSON-serializable format
            batch_json = self._convert_batch_to_json(batch_data, start_row)

            batches.append(
                {
                    "batch_id": batch_id,
                    "start_row": start_row,
                    "end_row": end_row,
                    "data": batch_json,
                }
            )

            batch_id += 1

        logger.info(f"Created {len(batches)} batches for {total_rows} unique rows")
        return batches, signature_to_indices

    def _convert_batch_to_json(self, batch_data, start_row: int) -> Dict[str, Any]:
        """
        Convert a batch of unique data to JSON format for LLM consumption.
        """
        # Handle different input types
        if isinstance(batch_data, list):
            # List of dictionaries
            rows = []
            columns = list(batch_data[0].keys()) if batch_data else []

            for i, row_dict in enumerate(batch_data):
                # Remove signature column from LLM input but keep for joining
                clean_row = {k: v for k, v in row_dict.items() if k != "_pb_signature"}
                clean_row["_pb_row_index"] = start_row + i
                rows.append(clean_row)

            # Remove signature from columns list
            columns = [col for col in columns if col != "_pb_signature"]

        else:
            # DataFrame-like object
            columns = [col for col in batch_data.columns if col != "_pb_signature"]
            rows = []

            # batch_data is already native format from slicing
            native_batch = batch_data

            # Handle different data frame types
            if hasattr(native_batch, "to_dicts"):
                # Polars DataFrame
                batch_dicts = native_batch.to_dicts()
            elif hasattr(native_batch, "to_dict"):
                # Pandas DataFrame
                batch_dicts = native_batch.to_dict("records")
            else:  # pragma: no cover
                # Fallback: manual conversion
                batch_dicts = []
                for i in range(len(native_batch)):
                    row_dict = {}
                    for col in columns:
                        row_dict[col] = (
                            native_batch[col].iloc[i]
                            if hasattr(native_batch[col], "iloc")
                            else native_batch[col][i]
                        )
                    batch_dicts.append(row_dict)

            # Clean up rows and add indices
            for i, row_dict in enumerate(batch_dicts):
                clean_row = {k: v for k, v in row_dict.items() if k != "_pb_signature"}
                clean_row["_pb_row_index"] = start_row + i
                rows.append(clean_row)

        return {
            "columns": columns,
            "rows": rows,
            "batch_info": {
                "start_row": start_row,
                "num_rows": len(rows),
                "columns_count": len(columns),
            },
        }

    def get_reduction_stats(self) -> Dict[str, Any]:
        """Get statistics about the row reduction optimization."""
        return self.reduction_stats.copy()

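The batcher's core trick is in `_create_row_signature()` above: each row's selected columns are serialized to sorted JSON and hashed, duplicate rows collapse to one signature, and only unique rows are sent to the model while `signature_to_original_indices` remembers how to fan results back out. A standalone sketch of the same hashing idea on plain dictionaries (illustrative; it does not touch the private class):

    import hashlib
    import json

    # Illustrative sketch of the row-signature idea: identical rows hash to the same key,
    # so only one representative row needs an LLM verdict.
    rows = [{"a": 1, "b": "x"}, {"a": 1, "b": "x"}, {"a": 2, "b": "y"}]
    signatures = [
        hashlib.md5(json.dumps(r, sort_keys=True, default=str).encode()).hexdigest() for r in rows
    ]
    assert signatures[0] == signatures[1] and signatures[0] != signatures[2]
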
# ============================================================================
# Prompt Building
# ============================================================================


class _PromptBuilder:
    """
    Builds user messages for AI validation.

    Works in conjunction with the system prompt set on the chat instance.
    The system prompt contains role definition and general instructions,
    while this class builds user messages with specific validation criteria and data.
    """

    USER_MESSAGE_TEMPLATE = """VALIDATION CRITERIA:
{user_prompt}

DATA TO VALIDATE:
{data_json}"""

    def __init__(self, user_prompt: str):
        """
        Initialize the prompt builder.

        Parameters
        ----------
        user_prompt
            The user's validation prompt describing what to check.
        """
        self.user_prompt = user_prompt

    def build_prompt(self, batch_data: Dict[str, Any]) -> str:
        """
        Build a user message for a data batch.

        Parameters
        ----------
        batch_data
            The batch data dictionary from DataBatcher.

        Returns
        -------
        str
            The user message for the LLM.
        """
        data_json = json.dumps(batch_data, indent=2, default=str)

        return self.USER_MESSAGE_TEMPLATE.format(user_prompt=self.user_prompt, data_json=data_json)

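`build_prompt()` simply interpolates the user's criteria and the JSON-serialized batch into the fixed template above. A sketch of the resulting message for a tiny hand-made batch (illustrative; the criterion and values are placeholders, and the batch dict is built by hand rather than by `_DataBatcher`):

    from pointblank._utils_ai import _PromptBuilder

    # Illustrative sketch: what the user message looks like for a two-row batch.
    batch = {
        "columns": ["age"],
        "rows": [{"age": 34, "_pb_row_index": 0}, {"age": -5, "_pb_row_index": 1}],
        "batch_info": {"start_row": 0, "num_rows": 2, "columns_count": 1},
    }
    builder = _PromptBuilder("Each 'age' value must be a plausible adult age (18-120).")
    print(builder.build_prompt(batch))  # "VALIDATION CRITERIA: ..." then "DATA TO VALIDATE: ..."
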
# ============================================================================
# Response Parsing
# ============================================================================


class _ValidationResponseParser:
    """Parses AI validation responses."""

    def __init__(self, total_rows: int):
        """
        Initialize the response parser.

        Parameters
        ----------
        total_rows
            Total number of rows being validated.
        """
        self.total_rows = total_rows

    def parse_response(self, response: str, batch_info: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Parse an LLM response for a batch.

        Parameters
        ----------
        response
            The raw response from the LLM.
        batch_info
            Information about the batch being processed.

        Returns
        -------
        List[Dict[str, Any]]
            List of parsed results with 'index' and 'result' keys.
        """
        try:
            # Try to extract JSON from the response
            json_response = self._extract_json(response)

            # Validate the structure
            self._validate_response_structure(json_response, batch_info)

            return json_response

        except Exception as e:
            logger.error(
                f"Failed to parse response for batch {batch_info.get('batch_id', 'unknown')}: {e}"
            )
            logger.error(f"Raw response: {response}")

            # Return default results (all False) for this batch
            return self._create_default_results(batch_info)

    def _extract_json(self, response: str) -> List[Dict[str, Any]]:
        """Extract JSON from LLM response."""
        # Clean up the response
        response = response.strip()

        # Look for JSON array patterns
        import re

        json_pattern = r"\[.*?\]"
        matches = re.findall(json_pattern, response, re.DOTALL)

        if matches:
            # Try to parse the first match
            try:
                return json.loads(matches[0])
            except json.JSONDecodeError:  # pragma: no cover
                # If that fails, try the raw response
                return json.loads(response)
        else:
            # Try to parse the raw response
            return json.loads(response)

    def _validate_response_structure(
        self, json_response: List[Dict[str, Any]], batch_info: Dict[str, Any]
    ) -> None:
        """Validate that the response has the correct structure."""
        if not isinstance(json_response, list):
            raise ValueError("Response must be a JSON array")

        expected_rows = batch_info["end_row"] - batch_info["start_row"]
        if len(json_response) != expected_rows:
            logger.warning(
                f"Expected {expected_rows} results, got {len(json_response)} for batch {batch_info.get('batch_id')}"
            )

        for i, result in enumerate(json_response):
            if not isinstance(result, dict):
                raise ValueError(f"Result {i} must be a dictionary")

            if "index" not in result or "result" not in result:
                raise ValueError(f"Result {i} must have 'index' and 'result' keys")

            if not isinstance(result["result"], bool):
                raise ValueError(f"Result {i} 'result' must be a boolean")

    def _create_default_results(self, batch_info: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Create default results (all False) for a batch."""
        results = []
        for i in range(batch_info["start_row"], batch_info["end_row"]):
            results.append({"index": i, "result": False})
        return results

    def combine_batch_results(
        self,
        batch_results: List[List[Dict[str, Any]]],
        signature_mapping: Optional[Dict[str, List[int]]] = None,
    ) -> Dict[int, bool]:
        """
        Combine results from multiple batches and project to original rows using signature mapping.

        Parameters
        ----------
        batch_results
            List of batch results from parse_response.
        signature_mapping
            Optional mapping from row signatures to original row indices for memoization.

        Returns
        -------
        Dict[int, bool]
            Dictionary mapping original row index to validation result.
        """
        logger.debug(f"🔀 Combining results from {len(batch_results)} batches")

        if signature_mapping:
            logger.debug(
                f"🎯 Using signature mapping optimization for {len(signature_mapping)} unique signatures"
            )

        # First, collect results from unique rows
        unique_results = {}
        total_processed = 0

        for batch_idx, batch_result in enumerate(batch_results):
            logger.debug(f" Batch {batch_idx}: {len(batch_result)} results")

            for result in batch_result:
                index = result["index"]
                validation_result = result["result"]
                unique_results[index] = validation_result
                total_processed += 1

                # Log first few results
                if len(unique_results) <= 3 or total_processed <= 3:
                    logger.debug(f" Unique row {index}: {validation_result}")

        # If no signature mapping, return unique results as-is (fallback to original behavior)
        if not signature_mapping:
            logger.debug("📊 No signature mapping - returning unique results")
            passed_count = sum(1 for v in unique_results.values() if v)
            failed_count = len(unique_results) - passed_count
            logger.debug(f" - Final count: {passed_count} passed, {failed_count} failed")
            return unique_results

        # Project unique results back to all original rows using signature mapping
        combined_results = {}

        # We need to map from unique row indices back to signatures, then to original indices
        # This requires rebuilding the signatures from the unique rows
        # For now, let's assume the unique_results indices correspond to signature order
        signature_list = list(signature_mapping.keys())

        for unique_idx, validation_result in unique_results.items():
            if unique_idx < len(signature_list):
                signature = signature_list[unique_idx]
                original_indices = signature_mapping[signature]

                logger.debug(
                    f" Projecting result {validation_result} from unique row {unique_idx} to {len(original_indices)} original rows"
                )

                # Project this result to all original rows with this signature
                for original_idx in original_indices:
                    combined_results[original_idx] = validation_result
            else:  # pragma: no cover
                logger.warning(f"Unique index {unique_idx} out of range for signature mapping")

        logger.debug("📊 Projected results summary:")
        logger.debug(f" - Unique rows processed: {len(unique_results)}")
        logger.debug(f" - Original rows mapped: {len(combined_results)}")
        logger.debug(
            f" - Index range: {min(combined_results.keys()) if combined_results else 'N/A'} to {max(combined_results.keys()) if combined_results else 'N/A'}"
        )

        passed_count = sum(1 for v in combined_results.values() if v)
        failed_count = len(combined_results) - passed_count
        logger.debug(f" - Final count: {passed_count} passed, {failed_count} failed")

        return combined_results

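Note that `combine_batch_results()` relies on the unique-row indices lining up with the insertion order of `signature_mapping` (the code comments above call this assumption out). A standalone sketch of that fan-out step with plain dicts (illustrative values):

    # Illustrative sketch mirroring the projection logic in combine_batch_results():
    # unique index -> signature (by position) -> all original row indices with that signature.
    signature_mapping = {"sig_a": [0, 2], "sig_b": [1]}  # hypothetical signatures
    unique_results = {0: True, 1: False}  # verdicts for unique rows 0 and 1
    signature_list = list(signature_mapping.keys())

    combined = {}
    for unique_idx, verdict in unique_results.items():
        for original_idx in signature_mapping[signature_list[unique_idx]]:
            combined[original_idx] = verdict

    assert combined == {0: True, 2: True, 1: False}
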
# ============================================================================
# AI Validation Engine
# ============================================================================


class _AIValidationEngine:
    """Main engine for AI-powered validation using chatlas."""

    def __init__(self, llm_config: _LLMConfig):
        """
        Initialize the AI validation engine.

        Parameters
        ----------
        llm_config
            Configuration for the LLM provider.
        """
        self.llm_config = llm_config
        self.chat = _create_chat_instance(
            provider=llm_config.provider,
            model_name=llm_config.model,
            api_key=llm_config.api_key,
            verify_ssl=llm_config.verify_ssl,
        )

    def validate_batches(
        self, batches: List[Dict[str, Any]], prompt_builder: Any, max_concurrent: int = 3
    ) -> List[List[Dict[str, Any]]]:
        """
        Validate multiple batches.

        Parameters
        ----------
        batches
            List of batch dictionaries from DataBatcher.
        prompt_builder
            PromptBuilder instance for generating prompts.
        max_concurrent
            Maximum number of concurrent requests (ignored for now with chatlas).

        Returns
        -------
        List[List[Dict[str, Any]]]
            List of batch results, each containing validation results.
        """

        def validate_batch(batch: Dict[str, Any]) -> List[Dict[str, Any]]:
            try:
                # Debug: Log batch information
                logger.debug(f"🔍 Processing batch {batch['batch_id']}")
                logger.debug(f" - Rows: {batch['start_row']} to {batch['end_row'] - 1}")
                logger.debug(
                    f" - Data shape: {batch['data'].shape if hasattr(batch['data'], 'shape') else 'N/A'}"
                )

                # Build the prompt for this batch
                prompt = prompt_builder.build_prompt(batch["data"])

                # Debug: Log the prompt being sent to LLM
                logger.debug(f"📤 LLM Prompt for batch {batch['batch_id']}:")
                logger.debug("--- PROMPT START ---")
                logger.debug(prompt)
                logger.debug("--- PROMPT END ---")

                # Get response from LLM using chatlas (synchronous)
                response = str(self.chat.chat(prompt, stream=False, echo="none"))

                # Debug: Log the raw LLM response
                logger.debug(f"📥 LLM Response for batch {batch['batch_id']}:")
                logger.debug("--- RESPONSE START ---")
                logger.debug(response)
                logger.debug("--- RESPONSE END ---")

                # Parse the response
                parser = _ValidationResponseParser(total_rows=1000)  # This will be set properly
                results = parser.parse_response(response, batch)

                # Debug: Log parsed results  # pragma: no cover
                logger.debug(f"📊 Parsed results for batch {batch['batch_id']}:")
                for i, result in enumerate(results[:5]):  # Show first 5 results
                    logger.debug(f" Row {result['index']}: {result['result']}")
                if len(results) > 5:
                    logger.debug(f" ... and {len(results) - 5} more results")

                passed_count = sum(1 for r in results if r["result"])
                failed_count = len(results) - passed_count
                logger.debug(f" Summary: {passed_count} passed, {failed_count} failed")

                logger.info(f"Successfully validated batch {batch['batch_id']}")
                return results

            except Exception as e:  # pragma: no cover
                logger.error(
                    f"Failed to validate batch {batch['batch_id']}: {e}"
                )  # pragma: no cover
                # Return default results (all False) for failed batches
                default_results = []  # pragma: no cover
                for i in range(batch["start_row"], batch["end_row"]):  # pragma: no cover
                    default_results.append({"index": i, "result": False})  # pragma: no cover
                return default_results  # pragma: no cover

        # Execute all batch validations sequentially (chatlas is synchronous)
        final_results = []
        for batch in batches:
            result = validate_batch(batch)
            final_results.append(result)

        return final_results

    def validate_single_batch(
        self, batch: Dict[str, Any], prompt_builder: Any
    ) -> List[Dict[str, Any]]:
        """
        Validate a single batch.

        Parameters
        ----------
        batch
            Batch dictionary from DataBatcher.
        prompt_builder
            PromptBuilder instance for generating prompts.

        Returns
        -------
        List[Dict[str, Any]]
            Validation results for the batch.
        """
        try:
            # Build the prompt for this batch
            prompt = prompt_builder.build_prompt(batch["data"])

            # Get response from LLM using chatlas (synchronous)
            response = str(self.chat.chat(prompt, stream=False, echo="none"))

            # Parse the response
            parser = _ValidationResponseParser(total_rows=1000)  # This will be set properly
            results = parser.parse_response(response, batch)

            logger.info(f"Successfully validated batch {batch['batch_id']}")
            return results

        except Exception as e:
            logger.error(f"Failed to validate batch {batch['batch_id']}: {e}")
            # Return default results (all False) for failed batch
            default_results = []
            for i in range(batch["start_row"], batch["end_row"]):
                default_results.append({"index": i, "result": False})
            return default_results
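Taken together, the new module composes as batch → prompt → chat → parse → combine. A minimal end-to-end sketch under stated assumptions (it needs `chatlas`, the relevant provider SDK, and here a locally served Ollama model; the model name and data are placeholders rather than anything taken from this diff):

    import polars as pl

    from pointblank._utils_ai import (
        _AIValidationEngine,
        _BatchConfig,
        _DataBatcher,
        _LLMConfig,
        _PromptBuilder,
        _ValidationResponseParser,
    )

    df = pl.DataFrame({"email": ["a@example.com", "not-an-email", "a@example.com"]})

    batcher = _DataBatcher(df, columns=["email"], config=_BatchConfig(size=1000))
    batches, signature_mapping = batcher.create_batches()  # 3 rows collapse to 2 unique rows

    engine = _AIValidationEngine(_LLMConfig(provider="ollama", model="llama3.2"))  # placeholder model
    prompt_builder = _PromptBuilder("Each value must be a syntactically valid email address.")
    batch_results = engine.validate_batches(batches, prompt_builder)

    parser = _ValidationResponseParser(total_rows=df.height)
    results_by_row = parser.combine_batch_results(batch_results, signature_mapping)
    # -> {0: ..., 1: ..., 2: ...}: one boolean verdict per original row index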