sdg-hub 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sdg_hub/_version.py +16 -3
- sdg_hub/core/blocks/deprecated_blocks/selector.py +1 -1
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +175 -416
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +174 -415
- sdg_hub/core/blocks/evaluation/verify_question_block.py +180 -415
- sdg_hub/core/blocks/llm/__init__.py +2 -0
- sdg_hub/core/blocks/llm/client_manager.py +61 -24
- sdg_hub/core/blocks/llm/config.py +1 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +62 -7
- sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +653 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +75 -30
- sdg_hub/core/blocks/registry.py +49 -35
- sdg_hub/core/blocks/transform/index_based_mapper.py +1 -1
- sdg_hub/core/flow/base.py +370 -20
- sdg_hub/core/flow/checkpointer.py +333 -0
- sdg_hub/core/flow/metadata.py +45 -0
- sdg_hub/core/flow/migration.py +12 -1
- sdg_hub/core/flow/registry.py +121 -58
- sdg_hub/core/flow/validation.py +12 -0
- sdg_hub/core/utils/__init__.py +2 -1
- sdg_hub/core/utils/datautils.py +81 -1
- sdg_hub/core/utils/flow_id_words.yaml +231 -0
- sdg_hub/core/utils/flow_identifier.py +94 -0
- sdg_hub/core/utils/yaml_utils.py +59 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +1 -7
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.2.dist-info}/METADATA +59 -31
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.2.dist-info}/RECORD +30 -25
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.2.dist-info}/WHEEL +0 -0
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.2.0.dist-info → sdg_hub-0.2.2.dist-info}/top_level.txt +0 -0
sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py (new file)
@@ -0,0 +1,653 @@
# SPDX-License-Identifier: Apache-2.0
"""Composite block combining LLM chat and text parsing with retry logic.

This module provides the LLMChatWithParsingRetryBlock that encapsulates the complete
LLM generation and parsing workflow with automatic retry on parsing failures.
"""

# Standard
from typing import Any, Optional

# Third Party
from datasets import Dataset
from pydantic import ConfigDict, Field, field_validator

# Local
from ...utils.error_handling import BlockValidationError
from ...utils.logger_config import setup_logger
from ..base import BaseBlock
from ..registry import BlockRegistry
from .llm_chat_block import LLMChatBlock
from .text_parser_block import TextParserBlock

logger = setup_logger(__name__)


class MaxRetriesExceededError(Exception):
    """Raised when maximum retry attempts are exceeded without achieving target count."""

    def __init__(self, target_count: int, actual_count: int, max_retries: int):
        self.target_count = target_count
        self.actual_count = actual_count
        self.max_retries = max_retries
        super().__init__(
            f"Failed to achieve target count {target_count} after {max_retries} retries. "
            f"Only got {actual_count} successful parses."
        )


@BlockRegistry.register(
    "LLMChatWithParsingRetryBlock",
    "llm",
    "Composite block combining LLM chat and text parsing with automatic retry on parsing failures",
)
class LLMChatWithParsingRetryBlock(BaseBlock):
    """Composite block for LLM generation with parsing retry logic.

    This block combines LLMChatBlock and TextParserBlock into a single cohesive block
    that automatically retries LLM generation when parsing fails, accumulating successful
    results until the target count is reached or max retries exceeded.

    Parameters
    ----------
    block_name : str
        Name of the block.
    input_cols : Union[str, List[str]]
        Input column name(s). Should contain the messages list.
    output_cols : Union[str, List[str]]
        Output column name(s) for parsed results.
    model : str
        Model identifier in LiteLLM format.
    api_base : Optional[str]
        Base URL for the API. Required for local models.
    api_key : Optional[str]
        API key for the provider. Falls back to environment variables.
    parsing_max_retries : int, optional
        Maximum number of retry attempts for parsing failures (default: 3).
        This is different from max_retries, which handles LLM network/API failures.

    ### LLM Generation Parameters ###
    async_mode : bool, optional
        Whether to use async processing (default: False).
    timeout : float, optional
        Request timeout in seconds (default: 120.0).
    max_retries : int, optional
        Maximum number of LLM retry attempts for network failures (default: 6).
    temperature : Optional[float], optional
        Sampling temperature (0.0 to 2.0).
    max_tokens : Optional[int], optional
        Maximum tokens to generate.
    top_p : Optional[float], optional
        Nucleus sampling parameter (0.0 to 1.0).
    frequency_penalty : Optional[float], optional
        Frequency penalty (-2.0 to 2.0).
    presence_penalty : Optional[float], optional
        Presence penalty (-2.0 to 2.0).
    stop : Optional[Union[str, List[str]]], optional
        Stop sequences.
    seed : Optional[int], optional
        Random seed for reproducible outputs.
    response_format : Optional[Dict[str, Any]], optional
        Response format specification (e.g., JSON mode).
    stream : Optional[bool], optional
        Whether to stream responses.
    n : Optional[int], optional
        Number of completions to generate per retry attempt.
    logprobs : Optional[bool], optional
        Whether to return log probabilities.
    top_logprobs : Optional[int], optional
        Number of top log probabilities to return.
    user : Optional[str], optional
        End-user identifier.
    extra_headers : Optional[Dict[str, str]], optional
        Additional headers to send with requests.
    extra_body : Optional[Dict[str, Any]], optional
        Additional parameters for the request body.
    provider_specific : Optional[Dict[str, Any]], optional
        Provider-specific parameters.

    ### Text Parser Parameters ###
    start_tags : List[str], optional
        List of start tags for tag-based parsing.
    end_tags : List[str], optional
        List of end tags for tag-based parsing.
    parsing_pattern : Optional[str], optional
        Regex pattern for custom parsing.
    parser_cleanup_tags : Optional[List[str]], optional
        List of tags to clean from parsed output.

    Examples
    --------
    >>> # Basic JSON parsing with retry
    >>> block = LLMChatWithParsingRetryBlock(
    ...     block_name="json_retry_block",
    ...     input_cols="messages",
    ...     output_cols="parsed_json",
    ...     model="openai/gpt-4",
    ...     parsing_max_retries=3,
    ...     parsing_pattern=r'"result":\s*"([^"]*)"',
    ...     n=3
    ... )

    >>> # Tag-based parsing with retry
    >>> block = LLMChatWithParsingRetryBlock(
    ...     block_name="tag_retry_block",
    ...     input_cols="messages",
    ...     output_cols=["explanation", "answer"],
    ...     model="anthropic/claude-3-sonnet-20240229",
    ...     parsing_max_retries=5,
    ...     start_tags=["<explanation>", "<answer>"],
    ...     end_tags=["</explanation>", "</answer>"],
    ...     n=2
    ... )
    """

    model_config = ConfigDict(
        extra="allow"
    )  # Allow extra fields for dynamic forwarding

    # --- Composite-specific configuration ---
    parsing_max_retries: int = Field(
        3, description="Maximum number of retry attempts for parsing failures"
    )

    # --- Parser configuration (required for internal TextParserBlock) ---
    start_tags: Optional[list[str]] = Field(
        None, description="Start tags for tag-based parsing"
    )
    end_tags: Optional[list[str]] = Field(
        None, description="End tags for tag-based parsing"
    )
    parsing_pattern: Optional[str] = Field(
        None, description="Regex pattern for custom parsing"
    )
    parser_cleanup_tags: Optional[list[str]] = Field(
        None, description="List of tags to clean from parsed output"
    )

    # Internal blocks - excluded from serialization
    llm_chat: Optional[LLMChatBlock] = Field(None, exclude=True)
    text_parser: Optional[TextParserBlock] = Field(None, exclude=True)

    @field_validator("input_cols")
    @classmethod
    def validate_single_input_col(cls, v):
        """Ensure exactly one input column."""
        if isinstance(v, str):
            return [v]
        if isinstance(v, list) and len(v) == 1:
            return v
        if isinstance(v, list) and len(v) != 1:
            raise ValueError(
                f"LLMChatWithParsingRetryBlock expects exactly one input column, got {len(v)}: {v}"
            )
        raise ValueError(f"Invalid input_cols format: {v}")

    @field_validator("parsing_max_retries")
    @classmethod
    def validate_parsing_max_retries(cls, v):
        """Ensure parsing_max_retries is positive."""
        if v < 1:
            raise ValueError("parsing_max_retries must be at least 1")
        return v

    def __init__(self, **kwargs):
        """Initialize with dynamic parameter routing."""
        super().__init__(**kwargs)
        self._create_internal_blocks(**kwargs)

        # Log initialization if model is configured
        if hasattr(self, "model") and self.model:
            logger.info(
                f"Initialized LLMChatWithParsingRetryBlock '{self.block_name}' with model '{self.model}'",
                extra={
                    "block_name": self.block_name,
                    "model": self.model,
                    "parsing_max_retries": self.parsing_max_retries,
                },
            )

    def _extract_params(self, kwargs: dict, block_class) -> dict:
        """Extract parameters for specific block class based on its model_fields."""
        # Exclude parameters that are handled by this wrapper
        wrapper_params = {
            "block_name",
            "input_cols",
            "output_cols",
            "parsing_max_retries",
        }

        # Extract parameters that the target block accepts
        params = {
            k: v
            for k, v in kwargs.items()
            if k in block_class.model_fields and k not in wrapper_params
        }

        # Also include declared fields from this composite block that the target block accepts
        for field_name in self.__class__.model_fields:
            if (
                field_name in block_class.model_fields
                and field_name not in wrapper_params
            ):
                field_value = getattr(self, field_name)
                if field_value is not None:  # Only forward non-None values
                    params[field_name] = field_value

        return params

    def _create_internal_blocks(self, **kwargs):
        """Create internal blocks with parameter routing."""
        # Route parameters to appropriate blocks
        llm_params = self._extract_params(kwargs, LLMChatBlock)
        parser_params = self._extract_params(kwargs, TextParserBlock)

        # 1. LLMChatBlock
        self.llm_chat = LLMChatBlock(
            block_name=f"{self.block_name}_llm_chat",
            input_cols=self.input_cols,
            output_cols=[f"{self.block_name}_raw_response"],
            **llm_params,
        )

        # 2. TextParserBlock
        self.text_parser = TextParserBlock(
            block_name=f"{self.block_name}_text_parser",
            input_cols=[f"{self.block_name}_raw_response"],
            output_cols=self.output_cols,
            **parser_params,
        )

    def __getattr__(self, name: str) -> Any:
        """Forward attribute access to appropriate internal block."""
        # Check each internal block to see which one has this parameter
        for block_attr, block_class in [
            ("llm_chat", LLMChatBlock),
            ("text_parser", TextParserBlock),
        ]:
            if hasattr(self, block_attr) and name in block_class.model_fields:
                internal_block = getattr(self, block_attr)
                if internal_block is not None:
                    return getattr(internal_block, name)
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    def __setattr__(self, name: str, value: Any) -> None:
        """Handle dynamic parameter updates from flow.set_model_config()."""
        super().__setattr__(name, value)

        # Forward to appropriate internal blocks
        for block_attr, block_class in [
            ("llm_chat", LLMChatBlock),
            ("text_parser", TextParserBlock),
        ]:
            if hasattr(self, block_attr) and name in block_class.model_fields:
                setattr(getattr(self, block_attr), name, value)

    def _reinitialize_client_manager(self) -> None:
        """Reinitialize the internal LLM chat block's client manager.

        This should be called after model configuration changes to ensure
        the internal LLM chat block uses the updated model configuration.
        """
        if self.llm_chat and hasattr(self.llm_chat, "_reinitialize_client_manager"):
            # The parameters should already be forwarded via __setattr__ magic method
            # Just reinitialize the client manager with the current configuration
            self.llm_chat._reinitialize_client_manager()

    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
        """Generate responses with parsing retry logic.

        For each input sample, this method:
        1. Generates LLM responses using the configured n parameter
        2. Attempts to parse the responses using TextParserBlock
        3. Counts successful parses and retries if below target
        4. Accumulates results across retry attempts
        5. Returns final dataset with all successful parses

        Parameters
        ----------
        samples : Dataset
            Input dataset containing the messages column.
        **kwargs : Any
            Additional keyword arguments passed to internal blocks.

        Returns
        -------
        Dataset
            Dataset with parsed results from successful generations.

        Raises
        ------
        BlockValidationError
            If model is not configured before calling generate().
        MaxRetriesExceededError
            If target count not reached after max retries for any sample.
        """
        # Validate that model is configured
        if not hasattr(self, "model") or not self.model:
            raise BlockValidationError(
                f"Model not configured for block '{self.block_name}'. "
                f"Call flow.set_model_config() before generating."
            )

        logger.info(
            f"Starting LLM generation with parsing retry for {len(samples)} samples",
            extra={
                "block_name": self.block_name,
                "model": self.model,
                "batch_size": len(samples),
                "parsing_max_retries": self.parsing_max_retries,
            },
        )

        all_results = []

        # Process each sample independently with retry logic
        for sample_idx, sample in enumerate(samples):
            # Determine target count for this sample (number of completions requested)
            target = kwargs.get("n", getattr(self, "n", None)) or 1

            logger.debug(
                f"Processing sample {sample_idx} with target count {target}",
                extra={
                    "block_name": self.block_name,
                    "sample_idx": sample_idx,
                    "target_count": target,
                },
            )

            if self.text_parser.expand_lists:
                # Current behavior for expand_lists=True: count rows directly
                sample_results = []
                total_parsed_count = 0

                # Retry loop for this sample
                for attempt in range(self.parsing_max_retries):
                    if total_parsed_count >= target:
                        break  # Already reached target

                    try:
                        # Generate LLM responses for this sample
                        temp_dataset = Dataset.from_list([sample])
                        llm_result = self.llm_chat.generate(temp_dataset, **kwargs)

                        # Parse the responses
                        parsed_result = self.text_parser.generate(llm_result, **kwargs)

                        # Count successful parses and accumulate results
                        new_parsed_count = len(parsed_result)
                        total_parsed_count += new_parsed_count
                        sample_results.extend(parsed_result)

                        logger.debug(
                            f"Attempt {attempt + 1} for sample {sample_idx}: {new_parsed_count} successful parses "
                            f"(total: {total_parsed_count}/{target})",
                            extra={
                                "block_name": self.block_name,
                                "sample_idx": sample_idx,
                                "attempt": attempt + 1,
                                "new_parses": new_parsed_count,
                                "total_parses": total_parsed_count,
                                "target_count": target,
                            },
                        )

                        if total_parsed_count >= target:
                            logger.debug(
                                f"Target reached for sample {sample_idx} after {attempt + 1} attempts",
                                extra={
                                    "block_name": self.block_name,
                                    "sample_idx": sample_idx,
                                    "attempts": attempt + 1,
                                    "final_count": total_parsed_count,
                                },
                            )
                            break

                    except Exception as e:
                        logger.warning(
                            f"Error during attempt {attempt + 1} for sample {sample_idx}: {e}",
                            extra={
                                "block_name": self.block_name,
                                "sample_idx": sample_idx,
                                "attempt": attempt + 1,
                                "error": str(e),
                            },
                        )
                        # Continue to next attempt
                        continue

            else:
                # New behavior for expand_lists=False: parse individual responses and accumulate
                accumulated_parsed_items = {col: [] for col in self.output_cols}
                total_parsed_count = 0

                # Retry loop for this sample
                for attempt in range(self.parsing_max_retries):
                    if total_parsed_count >= target:
                        break  # Already reached target

                    try:
                        # Generate LLM responses for this sample
                        temp_dataset = Dataset.from_list([sample])
                        llm_result = self.llm_chat.generate(temp_dataset, **kwargs)

                        # Get the raw responses (should be a list when n > 1)
                        raw_response_col = f"{self.block_name}_raw_response"
                        raw_responses = llm_result[0][raw_response_col]
                        if not isinstance(raw_responses, list):
                            raw_responses = [raw_responses]

                        # Parse each response individually and accumulate successful ones
                        new_parsed_count = 0
                        for response in raw_responses:
                            if total_parsed_count >= target:
                                break  # Stop if we've reached target

                            # Create temporary dataset with single response for parsing
                            temp_parse_data = [{**sample, raw_response_col: response}]
                            temp_parse_dataset = Dataset.from_list(temp_parse_data)

                            # Force expand_lists=True temporarily to get individual parsed items
                            original_expand_lists = self.text_parser.expand_lists
                            try:
                                self.text_parser.expand_lists = True
                                parsed_result = self.text_parser.generate(
                                    temp_parse_dataset, **kwargs
                                )
                            except Exception as parse_e:
                                logger.debug(
                                    f"Failed to parse individual response: {parse_e}"
                                )
                                continue
                            finally:
                                self.text_parser.expand_lists = original_expand_lists

                            # If parsing was successful, accumulate the results
                            if len(parsed_result) > 0:
                                for parsed_row in parsed_result:
                                    if total_parsed_count >= target:
                                        break

                                    # Only count as successful if ALL output columns are present
                                    if all(
                                        col in parsed_row for col in self.output_cols
                                    ):
                                        for col in self.output_cols:
                                            accumulated_parsed_items[col].append(
                                                parsed_row[col]
                                            )
                                        total_parsed_count += 1
                                        new_parsed_count += 1
                                    # If any column is missing, skip this parsed response entirely

                        logger.debug(
                            f"Attempt {attempt + 1} for sample {sample_idx}: {new_parsed_count} successful parses "
                            f"(total: {total_parsed_count}/{target})",
                            extra={
                                "block_name": self.block_name,
                                "sample_idx": sample_idx,
                                "attempt": attempt + 1,
                                "new_parses": new_parsed_count,
                                "total_parses": total_parsed_count,
                                "target_count": target,
                            },
                        )

                        if total_parsed_count >= target:
                            logger.debug(
                                f"Target reached for sample {sample_idx} after {attempt + 1} attempts",
                                extra={
                                    "block_name": self.block_name,
                                    "sample_idx": sample_idx,
                                    "attempts": attempt + 1,
                                    "final_count": total_parsed_count,
                                },
                            )
                            break

                    except Exception as e:
                        logger.warning(
                            f"Error during attempt {attempt + 1} for sample {sample_idx}: {e}",
                            extra={
                                "block_name": self.block_name,
                                "sample_idx": sample_idx,
                                "attempt": attempt + 1,
                                "error": str(e),
                            },
                        )
                        # Continue to next attempt
                        continue

                # Create final result row with accumulated lists
                if total_parsed_count > 0:
                    # Trim to exact target count if needed
                    for col in self.output_cols:
                        if len(accumulated_parsed_items[col]) > target:
                            accumulated_parsed_items[col] = accumulated_parsed_items[
                                col
                            ][:target]

                    # Only add the parsed output columns as lists, preserve other columns as-is
                    final_row = {**sample, **accumulated_parsed_items}
                    sample_results = [final_row]
                else:
                    sample_results = []

            # Check if we reached the target count
            if total_parsed_count < target:
                raise MaxRetriesExceededError(
                    target_count=target,
                    actual_count=total_parsed_count,
                    max_retries=self.parsing_max_retries,
                )

            # For expand_lists=True, trim results to exact target count if we exceeded it
            if self.text_parser.expand_lists and total_parsed_count > target:
                sample_results = sample_results[:target]
                logger.debug(
                    f"Trimmed sample {sample_idx} results from {total_parsed_count} to {target}",
                    extra={
                        "block_name": self.block_name,
                        "sample_idx": sample_idx,
                        "trimmed_from": total_parsed_count,
                        "trimmed_to": target,
                    },
                )

            # Add this sample's results to final dataset
            all_results.extend(sample_results)

        logger.info(
            f"LLM generation with parsing retry completed: {len(samples)} input samples → {len(all_results)} output rows",
            extra={
                "block_name": self.block_name,
                "input_samples": len(samples),
                "output_rows": len(all_results),
                "model": self.model,
            },
        )

        return Dataset.from_list(all_results)

    def _validate_custom(self, dataset: Dataset) -> None:
        """Custom validation for LLMChatWithParsingRetryBlock.

        This method validates the entire chain of internal blocks by simulating
        the data flow through each block to ensure they can all process the data correctly.
        """
        # Validate that required input column exists
        if len(self.input_cols) != 1:
            raise ValueError(
                f"LLMChatWithParsingRetryBlock expects exactly one input column, got {len(self.input_cols)}"
            )

        input_col = self.input_cols[0]
        if input_col not in dataset.column_names:
            raise ValueError(
                f"Required input column '{input_col}' not found in dataset. "
                f"Available columns: {dataset.column_names}"
            )

        # Validate parsing configuration
        has_regex = getattr(self, "parsing_pattern", None) is not None
        has_tags = bool(getattr(self, "start_tags", [])) or bool(
            getattr(self, "end_tags", [])
        )

        if not has_regex and not has_tags:
            raise ValueError(
                "LLMChatWithParsingRetryBlock requires at least one parsing method: "
                "either 'parsing_pattern' (regex) or 'start_tags'/'end_tags' (tag-based parsing)"
            )

        # Validate that internal blocks are initialized
        if not all([self.llm_chat, self.text_parser]):
            raise ValueError(
                "All internal blocks must be initialized before validation"
            )

        # Validate internal blocks
        try:
            logger.debug("Validating internal LLM chat block")
            self.llm_chat._validate_custom(dataset)

            # Create temporary dataset with expected LLM output for parser validation
            temp_data = []
            for sample in dataset:
                temp_sample = dict(sample)
                temp_sample[f"{self.block_name}_raw_response"] = "test output"
                temp_data.append(temp_sample)
            temp_dataset = Dataset.from_list(temp_data)

            logger.debug("Validating internal text parser block")
            self.text_parser._validate_custom(temp_dataset)

            logger.debug("All internal blocks validated successfully")

        except Exception as e:
            logger.error(f"Validation failed in internal blocks: {e}")
            raise ValueError(f"Internal block validation failed: {e}") from e

    def get_internal_blocks_info(self) -> dict[str, Any]:
        """Get information about the internal blocks.

        Returns
        -------
        Dict[str, Any]
            Information about each internal block.
        """
        return {
            "llm_chat": self.llm_chat.get_info() if self.llm_chat else None,
            "text_parser": self.text_parser.get_info() if self.text_parser else None,
        }

    def __repr__(self) -> str:
        """String representation of the block."""
        model = getattr(self, "model", "not_configured")
        return (
            f"LLMChatWithParsingRetryBlock(name='{self.block_name}', "
            f"model='{model}', parsing_max_retries={self.parsing_max_retries})"
        )
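
For orientation, the sketch below shows how the composite block added in this release might be used on its own. It is a minimal, hedged example based only on the docstring and code above: the block name, model id, prompt, and column names are placeholders, and API credentials are assumed to come from environment variables; nothing in this sketch is shipped in 0.2.2.

# Minimal usage sketch (not part of the diff); values marked as placeholders
# are assumptions for illustration only.
from datasets import Dataset
from sdg_hub.core.blocks.llm.llm_chat_with_parsing_retry_block import (
    LLMChatWithParsingRetryBlock,
    MaxRetriesExceededError,
)

block = LLMChatWithParsingRetryBlock(
    block_name="qa_retry",                     # placeholder block name
    input_cols="messages",                     # column holding the chat messages
    output_cols=["explanation", "answer"],     # columns produced by the parser
    model="openai/gpt-4",                      # placeholder LiteLLM model id
    start_tags=["<explanation>", "<answer>"],  # tag-based parsing config
    end_tags=["</explanation>", "</answer>"],
    parsing_max_retries=3,                     # re-generate up to 3 times on parse failures
    n=2,                                       # completions per attempt, also the target count
)

samples = Dataset.from_list(
    [{"messages": [{"role": "user", "content": "Explain the answer, then give it."}]}]
)

try:
    result = block.generate(samples)  # accumulates successful parses until n are collected
except MaxRetriesExceededError as err:
    print(f"Gave up: {err}")          # raised when the target count is never reached

Whether the result holds one row with list-valued output columns or one row per parsed completion depends on the internal TextParserBlock's expand_lists setting, handled by the two branches of generate() in the diff above.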