sdg-hub 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +28 -1
- sdg_hub/_version.py +2 -2
- sdg_hub/core/__init__.py +22 -0
- sdg_hub/core/blocks/__init__.py +58 -0
- sdg_hub/core/blocks/base.py +313 -0
- sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
- sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
- sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
- sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
- sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
- sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
- sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
- sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
- sdg_hub/core/blocks/evaluation/__init__.py +9 -0
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
- sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
- sdg_hub/core/blocks/filtering/__init__.py +12 -0
- sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
- sdg_hub/core/blocks/llm/__init__.py +25 -0
- sdg_hub/core/blocks/llm/client_manager.py +398 -0
- sdg_hub/core/blocks/llm/config.py +336 -0
- sdg_hub/core/blocks/llm/error_handler.py +368 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
- sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +310 -0
- sdg_hub/core/blocks/registry.py +331 -0
- sdg_hub/core/blocks/transform/__init__.py +23 -0
- sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
- sdg_hub/core/blocks/transform/melt_columns.py +126 -0
- sdg_hub/core/blocks/transform/rename_columns.py +69 -0
- sdg_hub/core/blocks/transform/text_concat.py +102 -0
- sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
- sdg_hub/core/flow/__init__.py +20 -0
- sdg_hub/core/flow/base.py +980 -0
- sdg_hub/core/flow/metadata.py +344 -0
- sdg_hub/core/flow/migration.py +187 -0
- sdg_hub/core/flow/registry.py +330 -0
- sdg_hub/core/flow/validation.py +265 -0
- sdg_hub/{utils → core/utils}/__init__.py +6 -4
- sdg_hub/{utils → core/utils}/datautils.py +1 -3
- sdg_hub/core/utils/error_handling.py +208 -0
- sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +191 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
- sdg_hub-0.2.0.dist-info/METADATA +218 -0
- sdg_hub-0.2.0.dist-info/RECORD +63 -0
- sdg_hub/blocks/__init__.py +0 -42
- sdg_hub/blocks/block.py +0 -96
- sdg_hub/blocks/llmblock.py +0 -375
- sdg_hub/blocks/openaichatblock.py +0 -556
- sdg_hub/blocks/utilblocks.py +0 -597
- sdg_hub/checkpointer.py +0 -139
- sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
- sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
- sdg_hub/configs/annotations/detailed_description.yaml +0 -10
- sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
- sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
- sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
- sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
- sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
- sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
- sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
- sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
- sdg_hub/configs/knowledge/router.yaml +0 -12
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
- sdg_hub/configs/reasoning/__init__.py +0 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +0 -48
- sdg_hub/configs/skills/annotation.yaml +0 -36
- sdg_hub/configs/skills/contexts.yaml +0 -28
- sdg_hub/configs/skills/critic.yaml +0 -60
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
- sdg_hub/configs/skills/freeform_questions.yaml +0 -34
- sdg_hub/configs/skills/freeform_responses.yaml +0 -39
- sdg_hub/configs/skills/grounded_questions.yaml +0 -38
- sdg_hub/configs/skills/grounded_responses.yaml +0 -59
- sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
- sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
- sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
- sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
- sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
- sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
- sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
- sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
- sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
- sdg_hub/configs/skills/judge.yaml +0 -53
- sdg_hub/configs/skills/planner.yaml +0 -67
- sdg_hub/configs/skills/respond.yaml +0 -8
- sdg_hub/configs/skills/revised_responder.yaml +0 -78
- sdg_hub/configs/skills/router.yaml +0 -59
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
- sdg_hub/flow.py +0 -477
- sdg_hub/flow_runner.py +0 -450
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
- sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
- sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
- sdg_hub/pipeline.py +0 -121
- sdg_hub/prompts.py +0 -80
- sdg_hub/registry.py +0 -122
- sdg_hub/sdg.py +0 -206
- sdg_hub/utils/config_validation.py +0 -91
- sdg_hub/utils/error_handling.py +0 -94
- sdg_hub/utils/validation_result.py +0 -10
- sdg_hub-0.1.4.dist-info/METADATA +0 -190
- sdg_hub-0.1.4.dist-info/RECORD +0 -89
- sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
- /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
- /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/WHEEL +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,542 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Unified LLM chat block supporting all providers via LiteLLM."""
|
3
|
+
|
4
|
+
# Standard
|
5
|
+
from typing import Any, Optional, Union
|
6
|
+
import asyncio
|
7
|
+
|
8
|
+
# Third Party
|
9
|
+
from datasets import Dataset
|
10
|
+
from pydantic import Field, field_validator
|
11
|
+
|
12
|
+
# Local
|
13
|
+
from ...utils.error_handling import BlockValidationError
|
14
|
+
from ...utils.logger_config import setup_logger
|
15
|
+
from ..base import BaseBlock
|
16
|
+
from ..registry import BlockRegistry
|
17
|
+
from .client_manager import LLMClientManager
|
18
|
+
from .config import LLMConfig
|
19
|
+
|
20
|
+
# Logger for this module, created by the project's setup_logger helper and
# keyed on the module's import name.
logger = setup_logger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
@BlockRegistry.register(
    "LLMChatBlock",
    "llm",
    "Unified LLM chat block supporting 100+ providers via LiteLLM",
)
class LLMChatBlock(BaseBlock):
    """Unified LLM chat block supporting all providers via LiteLLM.

    This block replaces OpenAIChatBlock and OpenAIAsyncChatBlock with a single
    implementation that supports 100+ LLM providers through LiteLLM, including:
    - OpenAI (GPT-3.5, GPT-4, etc.)
    - Anthropic (Claude models)
    - Google (Gemini, PaLM)
    - Local models (vLLM, Ollama, etc.)
    - And many more...

    Parameters
    ----------
    block_name : str
        Name of the block.
    input_cols : Union[str, List[str]]
        Input column name(s). Should contain the messages list. Exactly one
        column is accepted.
    output_cols : Union[str, List[str]]
        Output column name(s) for the response. Exactly one column is accepted.
        When n > 1, the column will contain a list of responses instead of a
        single string.
    model : Optional[str], optional
        Model identifier in LiteLLM format. May be left unset at construction
        time, but must be configured before ``generate()`` is called. Examples:
        - "openai/gpt-4"
        - "anthropic/claude-3-sonnet-20240229"
        - "hosted_vllm/meta-llama/Llama-2-7b-chat-hf"
    api_key : Optional[str], optional
        API key for the provider. Falls back to environment variables.
    api_base : Optional[str], optional
        Base URL for the API. Required for local models.
    async_mode : bool, optional
        Whether to use async processing, by default False.
    timeout : float, optional
        Request timeout in seconds, by default 120.0.
    max_retries : int, optional
        Maximum number of retry attempts, by default 6.

    ### Generation Parameters ###

    temperature : Optional[float], optional
        Sampling temperature (0.0 to 2.0).
    max_tokens : Optional[int], optional
        Maximum tokens to generate.
    top_p : Optional[float], optional
        Nucleus sampling parameter (0.0 to 1.0).
    frequency_penalty : Optional[float], optional
        Frequency penalty (-2.0 to 2.0).
    presence_penalty : Optional[float], optional
        Presence penalty (-2.0 to 2.0).
    stop : Optional[Union[str, List[str]]], optional
        Stop sequences.
    seed : Optional[int], optional
        Random seed for reproducible outputs.
    response_format : Optional[Dict[str, Any]], optional
        Response format specification (e.g., JSON mode).
    stream : Optional[bool], optional
        Whether to stream responses.
    n : Optional[int], optional
        Number of completions to generate. When n > 1, the output column will contain
        a list of responses for each input sample.
    logprobs : Optional[bool], optional
        Whether to return log probabilities.
    top_logprobs : Optional[int], optional
        Number of top log probabilities to return.
    user : Optional[str], optional
        End-user identifier.
    extra_headers : Optional[Dict[str, str]], optional
        Additional headers to send with requests.
    extra_body : Optional[Dict[str, Any]], optional
        Additional parameters for the request body.
    provider_specific : Optional[Dict[str, Any]], optional
        Provider-specific parameters, forwarded to the LLM configuration.
    **kwargs : Any
        Additional provider-specific parameters.

    Examples
    --------
    >>> # OpenAI GPT-4
    >>> block = LLMChatBlock(
    ...     block_name="gpt4_block",
    ...     input_cols="messages",
    ...     output_cols="response",
    ...     model="openai/gpt-4",
    ...     temperature=0.7
    ... )

    >>> # Anthropic Claude
    >>> block = LLMChatBlock(
    ...     block_name="claude_block",
    ...     input_cols="messages",
    ...     output_cols="response",
    ...     model="anthropic/claude-3-sonnet-20240229",
    ...     temperature=0.7
    ... )

    >>> # Local vLLM model
    >>> block = LLMChatBlock(
    ...     block_name="local_llama",
    ...     input_cols="messages",
    ...     output_cols="response",
    ...     model="hosted_vllm/meta-llama/Llama-2-7b-chat-hf",
    ...     api_base="http://localhost:8000/v1",
    ...     temperature=0.7
    ... )

    >>> # Multiple completions (n > 1)
    >>> block = LLMChatBlock(
    ...     block_name="gpt4_multiple",
    ...     input_cols="messages",
    ...     output_cols="responses",  # Will contain lists of strings
    ...     model="openai/gpt-4",
    ...     n=3,  # Generate 3 responses per input
    ...     temperature=0.8
    ... )
    """

    # LLM Configuration
    model: Optional[str] = Field(None, description="Model identifier in LiteLLM format")
    api_key: Optional[str] = Field(None, description="API key for the provider")
    api_base: Optional[str] = Field(None, description="Base URL for the API")
    async_mode: bool = Field(False, description="Whether to use async processing")
    timeout: float = Field(120.0, description="Request timeout in seconds")
    max_retries: int = Field(6, description="Maximum number of retry attempts")

    # Generation parameters
    temperature: Optional[float] = Field(
        None, description="Sampling temperature (0.0 to 2.0)"
    )
    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
    top_p: Optional[float] = Field(
        None, description="Nucleus sampling parameter (0.0 to 1.0)"
    )
    frequency_penalty: Optional[float] = Field(
        None, description="Frequency penalty (-2.0 to 2.0)"
    )
    presence_penalty: Optional[float] = Field(
        None, description="Presence penalty (-2.0 to 2.0)"
    )
    stop: Optional[Union[str, list[str]]] = Field(None, description="Stop sequences")
    seed: Optional[int] = Field(
        None, description="Random seed for reproducible outputs"
    )
    response_format: Optional[dict[str, Any]] = Field(
        None, description="Response format specification"
    )
    stream: Optional[bool] = Field(None, description="Whether to stream responses")
    n: Optional[int] = Field(None, description="Number of completions to generate")
    logprobs: Optional[bool] = Field(
        None, description="Whether to return log probabilities"
    )
    top_logprobs: Optional[int] = Field(
        None, description="Number of top log probabilities to return"
    )
    user: Optional[str] = Field(None, description="End-user identifier")
    extra_headers: Optional[dict[str, str]] = Field(
        None, description="Additional headers"
    )
    extra_body: Optional[dict[str, Any]] = Field(
        None, description="Additional request body parameters"
    )
    provider_specific: Optional[dict[str, Any]] = Field(
        None, description="Provider-specific parameters"
    )

    # Exclude from serialization - internal computed field
    client_manager: Optional[Any] = Field(
        None, exclude=True, description="Internal client manager"
    )

    @staticmethod
    def _normalize_single_column(value: Any, kind: str) -> list[str]:
        """Normalize a column spec to a one-element list.

        Parameters
        ----------
        value : Any
            Column specification: a single name or a list of names.
        kind : str
            Either "input" or "output"; only used to build error messages.

        Returns
        -------
        list[str]
            A list holding exactly one column name.

        Raises
        ------
        ValueError
            If ``value`` is neither a string nor a list, or is a list whose
            length is not exactly one.
        """
        if isinstance(value, str):
            return [value]
        if not isinstance(value, list):
            raise ValueError(f"Invalid {kind}_cols format: {value}")
        if len(value) != 1:
            raise ValueError(
                f"LLMChatBlock expects exactly one {kind} column, got {len(value)}: {value}"
            )
        return value

    @field_validator("input_cols")
    @classmethod
    def validate_single_input_col(cls, v: Any) -> list[str]:
        """Ensure exactly one input column."""
        return cls._normalize_single_column(v, "input")

    @field_validator("output_cols")
    @classmethod
    def validate_single_output_col(cls, v: Any) -> list[str]:
        """Ensure exactly one output column."""
        return cls._normalize_single_column(v, "output")

    def model_post_init(self, __context) -> None:
        """Initialize after Pydantic validation."""
        super().model_post_init(__context)

        # Initialize client manager
        self._setup_client_manager()

    def _setup_client_manager(self) -> None:
        """Set up the LLM client manager with current configuration.

        Builds an ``LLMConfig`` from the block's current field values, wraps it
        in a new ``LLMClientManager``, stores it on ``self.client_manager`` and
        loads it immediately.
        """
        # Create configuration with current values
        config = LLMConfig(
            model=self.model,
            api_key=self.api_key,
            api_base=self.api_base,
            timeout=self.timeout,
            max_retries=self.max_retries,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            top_p=self.top_p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            stop=self.stop,
            seed=self.seed,
            response_format=self.response_format,
            stream=self.stream,
            n=self.n,
            logprobs=self.logprobs,
            top_logprobs=self.top_logprobs,
            user=self.user,
            extra_headers=self.extra_headers,
            extra_body=self.extra_body,
            provider_specific=self.provider_specific,
        )

        # Create client manager
        self.client_manager = LLMClientManager(config)

        # Load client immediately
        self.client_manager.load()

        # Log initialization only when model is configured
        if self.model:
            logger.info(
                f"Initialized LLMChatBlock '{self.block_name}' with model '{self.model}'",
                extra={
                    "block_name": self.block_name,
                    "model": self.model,
                    "provider": self.client_manager.config.get_provider(),
                    "is_local": self.client_manager.config.is_local_model(),
                    "async_mode": self.async_mode,
                    "generation_params": self.client_manager.config.get_generation_kwargs(),
                },
            )

    def _reinitialize_client_manager(self) -> None:
        """Reinitialize the client manager with updated model configuration.

        This should be called after model configuration changes to ensure
        the client manager uses the updated model, api_base, api_key, etc.
        """
        self._setup_client_manager()

    @staticmethod
    def _run_coroutine_blocking(coro: Any) -> Any:
        """Run ``coro`` to completion and return its result, blocking the caller.

        ``asyncio.run()`` raises ``RuntimeError`` when an event loop is already
        running in the current thread (for example inside a notebook or an
        async application). In that case the coroutine is executed with
        ``asyncio.run()`` on a short-lived worker thread instead, so the block
        keeps its synchronous calling convention.

        Parameters
        ----------
        coro : Any
            The coroutine object to execute.

        Returns
        -------
        Any
            Whatever the coroutine returns. Exceptions raised by the coroutine
            propagate to the caller.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so asyncio.run() is safe.
            return asyncio.run(coro)

        # Local import: only needed on this fallback path.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()

    def generate(self, samples: Dataset, **override_kwargs: Any) -> Dataset:
        """Generate responses from the LLM.

        Parameters set at runtime override those set during initialization.
        Supports all LiteLLM parameters for the configured provider.

        Parameters
        ----------
        samples : Dataset
            Input dataset containing the messages column.
        **override_kwargs : Any
            Runtime parameters that override initialization defaults.
            Valid parameters depend on the provider but typically include:
            temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
            stop, seed, response_format, stream, n, and provider-specific params.

        Returns
        -------
        Dataset
            Dataset with responses added to the output column.

        Raises
        ------
        BlockValidationError
            If model is not configured before calling generate().
        """
        # Validate that model is configured
        if not self.model:
            raise BlockValidationError(
                f"Model not configured for block '{self.block_name}'. "
                f"Call flow.set_model_config() before generating."
            )

        # Extract messages
        messages_list = samples[self.input_cols[0]]
        provider = self.client_manager.config.get_provider()

        # Log generation start
        logger.info(
            f"Starting {'async' if self.async_mode else 'sync'} generation for {len(messages_list)} samples",
            extra={
                "block_name": self.block_name,
                "model": self.model,
                "provider": provider,
                "batch_size": len(messages_list),
                "async_mode": self.async_mode,
                "override_params": override_kwargs,
            },
        )

        # Generate responses
        if self.async_mode:
            # Works whether or not the caller already has a running event loop.
            responses = self._run_coroutine_blocking(
                self._generate_async(messages_list, **override_kwargs)
            )
        else:
            responses = self._generate_sync(messages_list, **override_kwargs)

        # Log completion
        logger.info(
            f"Generation completed successfully for {len(responses)} samples",
            extra={
                "block_name": self.block_name,
                "model": self.model,
                "provider": provider,
                "batch_size": len(responses),
            },
        )

        # Add responses as new column
        return samples.add_column(self.output_cols[0], responses)

    def _generate_sync(
        self,
        messages_list: list[list[dict[str, Any]]],
        **override_kwargs: Any,
    ) -> list[Union[str, list[str]]]:
        """Generate responses synchronously.

        Samples are processed one at a time, in order. The first failure is
        logged and re-raised, so no partial result is returned.

        Parameters
        ----------
        messages_list : List[List[Dict[str, Any]]]
            List of message lists to process.
        **override_kwargs : Any
            Runtime parameter overrides.

        Returns
        -------
        List[Union[str, List[str]]]
            List of response strings or lists of response strings (when n > 1).
        """
        responses = []
        total = len(messages_list)

        for i, messages in enumerate(messages_list):
            try:
                response = self.client_manager.create_completion(
                    messages, **override_kwargs
                )
                responses.append(response)

                # Log progress for large batches
                if (i + 1) % 10 == 0:
                    logger.debug(
                        f"Generated {i + 1}/{total} responses",
                        extra={
                            "block_name": self.block_name,
                            "progress": f"{i + 1}/{total}",
                        },
                    )

            except Exception as e:
                error_msg = self.client_manager.error_handler.format_error_message(
                    e, {"model": self.model, "sample_index": i}
                )
                logger.error(
                    f"Failed to generate response for sample {i}: {error_msg}",
                    extra={
                        "block_name": self.block_name,
                        "sample_index": i,
                        "error": str(e),
                    },
                )
                raise

        return responses

    async def _generate_async(
        self,
        messages_list: list[list[dict[str, Any]]],
        **override_kwargs: Any,
    ) -> list[Union[str, list[str]]]:
        """Generate responses asynchronously.

        The whole batch is delegated to the client manager's batch call. Any
        failure is logged and re-raised.

        Parameters
        ----------
        messages_list : List[List[Dict[str, Any]]]
            List of message lists to process.
        **override_kwargs : Any
            Runtime parameter overrides.

        Returns
        -------
        List[Union[str, List[str]]]
            List of response strings or lists of response strings (when n > 1).
        """
        try:
            responses = await self.client_manager.acreate_completions_batch(
                messages_list, **override_kwargs
            )
            return responses

        except Exception as e:
            error_msg = self.client_manager.error_handler.format_error_message(
                e, {"model": self.model}
            )
            logger.error(
                f"Failed to generate async responses: {error_msg}",
                extra={
                    "block_name": self.block_name,
                    "batch_size": len(messages_list),
                    "error": str(e),
                },
            )
            raise

    def get_model_info(self) -> dict[str, Any]:
        """Get information about the configured model.

        Returns
        -------
        Dict[str, Any]
            The client manager's model information, extended with this block's
            name, input column, output column and async mode.
        """
        return {
            **self.client_manager.get_model_info(),
            "block_name": self.block_name,
            "input_column": self.input_cols[0],
            "output_column": self.output_cols[0],
            "async_mode": self.async_mode,
        }

    def _validate_custom(self, dataset: Dataset) -> None:
        """Custom validation for LLMChatBlock message format.

        Validates that every sample's input column holds a non-empty list of
        dicts, each with non-None 'role' and 'content' entries. Validation
        stops at the first offending row.

        Parameters
        ----------
        dataset : Dataset
            The dataset to validate.

        Raises
        ------
        BlockValidationError
            If message format validation fails.
        """
        column = self.input_cols[0]

        # Iterate rows directly instead of copying the dataset into a list first.
        for idx, sample in enumerate(dataset):
            messages = sample[column]

            # Validate messages is a list
            if not isinstance(messages, list):
                raise BlockValidationError(
                    f"Messages column '{column}' must contain a list, "
                    f"got {type(messages)} in row {idx}",
                    details=f"Block: {self.block_name}, Row: {idx}, Value: {messages}",
                )

            # Validate messages is not empty
            if not messages:
                raise BlockValidationError(
                    f"Messages list is empty in row {idx}",
                    details=f"Block: {self.block_name}, Row: {idx}",
                )

            # Validate each message format
            for msg_idx, message in enumerate(messages):
                if not isinstance(message, dict):
                    raise BlockValidationError(
                        f"Message {msg_idx} in row {idx} must be a dict, got {type(message)}",
                        details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Value: {message}",
                    )

                # Validate required fields
                if "role" not in message or message["role"] is None:
                    raise BlockValidationError(
                        f"Message {msg_idx} in row {idx} missing required 'role' field",
                        details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Available fields: {list(message.keys())}",
                    )

                if "content" not in message or message["content"] is None:
                    raise BlockValidationError(
                        f"Message {msg_idx} in row {idx} missing required 'content' field",
                        details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Available fields: {list(message.keys())}",
                    )

    def __del__(self) -> None:
        """Cleanup when block is destroyed."""
        try:
            # The field defaults to None and may be missing entirely if
            # construction failed, so guard both cases explicitly.
            manager = getattr(self, "client_manager", None)
            if manager is not None:
                manager.unload()
        except Exception:
            # Ignore errors during cleanup to prevent issues during shutdown
            pass

    def __repr__(self) -> str:
        """String representation of the block."""
        manager = self.client_manager
        # Keep repr() usable even if the client manager was never created.
        provider = manager.config.get_provider() if manager is not None else None
        return (
            f"LLMChatBlock(name='{self.block_name}', model='{self.model}', "
            f"provider='{provider}', async_mode={self.async_mode})"
        )
|