sdg-hub 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/_version.py +2 -2
- sdg_hub/core/blocks/__init__.py +2 -4
- sdg_hub/core/blocks/base.py +61 -6
- sdg_hub/core/blocks/filtering/column_value_filter.py +3 -2
- sdg_hub/core/blocks/llm/__init__.py +2 -4
- sdg_hub/core/blocks/llm/llm_chat_block.py +251 -265
- sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +216 -98
- sdg_hub/core/blocks/llm/llm_parser_block.py +320 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +53 -152
- sdg_hub/core/flow/base.py +7 -4
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/flow.yaml +51 -11
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/doc_direct_qa/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/doc_direct_qa/flow.yaml +159 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/flow.yaml +51 -11
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/flow.yaml +14 -2
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +146 -26
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/README.md +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/atomic_facts_ja.yaml +41 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/detailed_summary_ja.yaml +14 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/extractive_summary_ja.yaml +14 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/flow.yaml +304 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/generate_questions_responses_ja.yaml +55 -0
- sdg_hub/flows/text_analysis/structured_insights/flow.yaml +28 -4
- {sdg_hub-0.3.1.dist-info → sdg_hub-0.4.0.dist-info}/METADATA +1 -1
- {sdg_hub-0.3.1.dist-info → sdg_hub-0.4.0.dist-info}/RECORD +29 -25
- sdg_hub/core/blocks/evaluation/__init__.py +0 -9
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +0 -323
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +0 -323
- sdg_hub/core/blocks/evaluation/verify_question_block.py +0 -329
- sdg_hub/core/blocks/llm/client_manager.py +0 -472
- sdg_hub/core/blocks/llm/config.py +0 -337
- {sdg_hub-0.3.1.dist-info → sdg_hub-0.4.0.dist-info}/WHEEL +0 -0
- {sdg_hub-0.3.1.dist-info → sdg_hub-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.3.1.dist-info → sdg_hub-0.4.0.dist-info}/top_level.txt +0 -0
sdg_hub/core/blocks/llm/llm_chat_block.py
@@ -2,21 +2,23 @@
 """Unified LLM chat block supporting all providers via LiteLLM."""

 # Standard
-from typing import Any, Optional
+from typing import Any, Optional
 import asyncio

 # Third Party
 from datasets import Dataset
-from
+from litellm import acompletion, completion
+from pydantic import ConfigDict, Field, field_validator
+import litellm

-# Local
 from ...utils.error_handling import BlockValidationError
 from ...utils.logger_config import setup_logger
+
+# Local
 from ..base import BaseBlock
 from ..registry import BlockRegistry
-from .client_manager import LLMClientManager
-from .config import LLMConfig

+litellm.drop_params = True
 logger = setup_logger(__name__)

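Note: the rewritten block calls LiteLLM directly rather than going through the removed client_manager and config modules. A minimal, self-contained sketch of that call pattern, assuming an API key in the environment; the model name, prompt, and parameter values are illustrative, not taken from sdg_hub:

    # Minimal sketch of direct LiteLLM usage; model and params are illustrative.
    import litellm
    from litellm import completion

    litellm.drop_params = True  # drop params a provider does not support instead of erroring

    messages = [{"role": "user", "content": "Say hello in one word."}]
    response = completion(
        model="openai/gpt-4",  # LiteLLM "provider/model" format
        messages=messages,
        temperature=0.7,
        num_retries=6,  # LiteLLM's built-in retry mechanism
    )
    print(response.choices[0].message.content)
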
@@ -26,10 +28,12 @@ logger = setup_logger(__name__)
     "Unified LLM chat block supporting 100+ providers via LiteLLM",
 )
 class LLMChatBlock(BaseBlock):
+    model_config = ConfigDict(extra="allow")
+
     """Unified LLM chat block supporting all providers via LiteLLM.

-    This block
-
+    This block provides a minimal wrapper around LiteLLM's completion API,
+    supporting 100+ LLM providers including:
     - OpenAI (GPT-3.5, GPT-4, etc.)
     - Anthropic (Claude models)
     - Google (Gemini, PaLM)
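Note: ConfigDict(extra="allow") is what lets arbitrary LiteLLM parameters ride along on the block. A small illustrative Pydantic v2 sketch of that behavior (class and field names are made up, not sdg_hub code):

    # Illustrative Pydantic v2 behavior of extra="allow"; names are made up.
    from pydantic import BaseModel, ConfigDict


    class DemoBlock(BaseModel):
        model_config = ConfigDict(extra="allow")

        name: str


    block = DemoBlock(name="demo", temperature=0.7, max_tokens=256)
    print(block.model_extra)                     # {'temperature': 0.7, 'max_tokens': 256}
    print(block.model_dump(exclude_unset=True))  # includes the extra generation params
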
@@ -43,14 +47,10 @@ class LLMChatBlock(BaseBlock):
     input_cols : Union[str, List[str]]
         Input column name(s). Should contain the messages list.
     output_cols : Union[dict, List[dict]]
-        Output column name(s) for the response.
-
-
-
-        Model identifier in LiteLLM format. Examples:
-        - "openai/gpt-4"
-        - "anthropic/claude-3-sonnet-20240229"
-        - "hosted_vllm/meta-llama/Llama-2-7b-chat-hf"
+        Output column name(s) for the response.
+    model : Optional[str], optional
+        Model identifier in LiteLLM format. Can be set later via flow.set_model_config().
+        Examples: "openai/gpt-4", "anthropic/claude-3-sonnet-20240229"
     api_key : Optional[str], optional
         API key for the provider. Falls back to environment variables.
     api_base : Optional[str], optional
@@ -59,138 +59,68 @@ class LLMChatBlock(BaseBlock):
         Whether to use async processing, by default False.
     timeout : float, optional
         Request timeout in seconds, by default 120.0.
-
-
-
-
-
-    temperature : Optional[float], optional
-        Sampling temperature (0.0 to 2.0).
-    max_tokens : Optional[int], optional
-        Maximum tokens to generate.
-    top_p : Optional[float], optional
-        Nucleus sampling parameter (0.0 to 1.0).
-    frequency_penalty : Optional[float], optional
-        Frequency penalty (-2.0 to 2.0).
-    presence_penalty : Optional[float], optional
-        Presence penalty (-2.0 to 2.0).
-    stop : Optional[Union[str, List[str]]], optional
-        Stop sequences.
-    seed : Optional[int], optional
-        Random seed for reproducible outputs.
-    response_format : Optional[Dict[str, Any]], optional
-        Response format specification (e.g., JSON mode).
-    stream : Optional[bool], optional
-        Whether to stream responses.
-    n : Optional[int], optional
-        Number of completions to generate. When n > 1, the output column will contain
-        a list of responses for each input sample.
-    logprobs : Optional[bool], optional
-        Whether to return log probabilities.
-    top_logprobs : Optional[int], optional
-        Number of top log probabilities to return.
-    user : Optional[str], optional
-        End-user identifier.
-    extra_headers : Optional[Dict[str, str]], optional
-        Additional headers to send with requests.
-    extra_body : Optional[Dict[str, Any]], optional
-        Additional parameters for the request body.
+    num_retries : int, optional
+        Number of retry attempts (uses LiteLLM's built-in retry mechanism), by default 6.
+        Note: For rate limit handling, use LiteLLM's fallbacks parameter instead.
+    drop_params : bool, optional
+        Whether to drop unsupported parameters to prevent API errors, by default True.
     **kwargs : Any
-
+        Any LiteLLM completion parameters (temperature, max_tokens, top_p, etc.).
+        See https://docs.litellm.ai/docs/completion/input for full list.

     Examples
     --------
-    >>> # OpenAI GPT-4
+    >>> # OpenAI GPT-4 with generation parameters
     >>> block = LLMChatBlock(
     ... block_name="gpt4_block",
     ... input_cols="messages",
     ... output_cols="response",
     ... model="openai/gpt-4",
-    ... temperature=0.7
-    ...
-
-    >>> # Anthropic Claude
-    >>> block = LLMChatBlock(
-    ... block_name="claude_block",
-    ... input_cols="messages",
-    ... output_cols="response",
-    ... model="anthropic/claude-3-sonnet-20240229",
-    ... temperature=0.7
+    ... temperature=0.7,
+    ... max_tokens=1000
     ... )

-    >>> # Local vLLM model
+    >>> # Local vLLM model with custom parameters
     >>> block = LLMChatBlock(
     ... block_name="local_llama",
     ... input_cols="messages",
     ... output_cols="response",
     ... model="hosted_vllm/meta-llama/Llama-2-7b-chat-hf",
     ... api_base="http://localhost:8000/v1",
-    ... temperature=0.7
-    ...
-
-    >>> # Multiple completions (n > 1)
-    >>> block = LLMChatBlock(
-    ... block_name="gpt4_multiple",
-    ... input_cols="messages",
-    ... output_cols="responses", # Will contain lists of responses
-    ... model="openai/gpt-4",
-    ... n=3, # Generate 3 responses per input
-    ... temperature=0.8
+    ... temperature=0.7,
+    ... response_format={"type": "json_object"}
     ... )
     """

-    #
-    model: Optional[str] = Field(
-
-    api_base: Optional[str] = Field(None, description="Base URL for the API")
-    async_mode: bool = Field(False, description="Whether to use async processing")
-    timeout: float = Field(120.0, description="Request timeout in seconds")
-    max_retries: int = Field(6, description="Maximum number of retry attempts")
-
-    # Generation parameters
-    temperature: Optional[float] = Field(
-        None, description="Sampling temperature (0.0 to 2.0)"
-    )
-    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
-    top_p: Optional[float] = Field(
-        None, description="Nucleus sampling parameter (0.0 to 1.0)"
+    # Essential operational fields (excluded from YAML serialization)
+    model: Optional[str] = Field(
+        None, exclude=True, description="Model identifier in LiteLLM format"
     )
-
-        None, description="
+    api_key: Optional[str] = Field(
+        None, exclude=True, description="API key for the provider"
     )
-
-        None, description="
+    api_base: Optional[str] = Field(
+        None, exclude=True, description="Base URL for the API"
     )
-
-
-        None, description="Random seed for reproducible outputs"
+    async_mode: bool = Field(
+        False, exclude=True, description="Whether to use async processing"
     )
-
-
+    timeout: float = Field(
+        120.0, exclude=True, description="Request timeout in seconds"
     )
-
-
-
-
+    num_retries: int = Field(
+        6,
+        exclude=True,
+        description="Number of retry attempts (uses LiteLLM's built-in retry mechanism)",
     )
-
-
-    )
-    user: Optional[str] = Field(None, description="End-user identifier")
-    extra_headers: Optional[dict[str, str]] = Field(
-        None, description="Additional headers"
-    )
-    extra_body: Optional[dict[str, Any]] = Field(
-        None, description="Additional request body parameters"
-    )
-    provider_specific: Optional[dict[str, Any]] = Field(
-        None, description="Provider-specific parameters"
+    drop_params: bool = Field(
+        True, description="Whether to drop unsupported parameters to prevent API errors"
     )

-    #
-
-
-
+    # All LiteLLM completion parameters can be passed via extra="allow"
+    # Common examples: temperature, max_tokens, top_p, frequency_penalty,
+    # presence_penalty, stop, seed, response_format, stream, n, logprobs,
+    # top_logprobs, user, extra_headers, extra_body, etc.

     @field_validator("input_cols")
     @classmethod
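Note: the operational fields above are declared with Field(exclude=True), which keeps the values usable at runtime but leaves them out of serialized output such as flow YAML. An illustrative Pydantic v2 sketch of that behavior (names are made up, not sdg_hub code):

    # Illustrative: Field(exclude=True) keeps a value at runtime but omits it from model_dump().
    from typing import Optional
    from pydantic import BaseModel, Field


    class DemoConfig(BaseModel):
        api_key: Optional[str] = Field(None, exclude=True)
        temperature: float = 0.7


    cfg = DemoConfig(api_key="secret", temperature=0.2)
    print(cfg.api_key)       # 'secret' - still accessible in code
    print(cfg.model_dump())  # {'temperature': 0.2} - api_key left out of serialization
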
@@ -224,83 +154,29 @@ class LLMChatBlock(BaseBlock):
         """Initialize after Pydantic validation."""
         super().model_post_init(__context)

-        # Initialize client manager
-        self._setup_client_manager()
-
-    def _setup_client_manager(self) -> None:
-        """Set up the LLM client manager with current configuration."""
-        # Create configuration with current values
-        config = LLMConfig(
-            model=self.model,
-            api_key=self.api_key,
-            api_base=self.api_base,
-            timeout=self.timeout,
-            max_retries=self.max_retries,
-            temperature=self.temperature,
-            max_tokens=self.max_tokens,
-            top_p=self.top_p,
-            frequency_penalty=self.frequency_penalty,
-            presence_penalty=self.presence_penalty,
-            stop=self.stop,
-            seed=self.seed,
-            response_format=self.response_format,
-            stream=self.stream,
-            n=self.n,
-            logprobs=self.logprobs,
-            top_logprobs=self.top_logprobs,
-            user=self.user,
-            extra_headers=self.extra_headers,
-            extra_body=self.extra_body,
-            provider_specific=self.provider_specific,
-        )
-
-        # Create client manager
-        self.client_manager = LLMClientManager(config)
-
-        # Load client immediately
-        self.client_manager.load()
-
         # Log initialization only when model is configured
         if self.model:
             logger.info(
-
+                "Initialized LLMChatBlock '%s' with model '%s'",
+                self.block_name,
+                self.model,
                 extra={
                     "block_name": self.block_name,
                     "model": self.model,
-                    "provider": self.client_manager.config.get_provider(),
-                    "is_local": self.client_manager.config.is_local_model(),
                     "async_mode": self.async_mode,
-                    "generation_params": self.client_manager.config.get_generation_kwargs(),
                 },
             )

-    def
-        """Reinitialize the client manager with updated model configuration.
-
-        This should be called after model configuration changes to ensure
-        the client manager uses the updated model, api_base, api_key, etc.
-        """
-        self._setup_client_manager()
-
-    def generate(self, samples: Dataset, **override_kwargs: dict[str, Any]) -> Dataset:
+    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
         """Generate responses from the LLM.

-        Parameters set at runtime override those set during initialization.
-        Supports all LiteLLM parameters for the configured provider.
-
         Parameters
         ----------
         samples : Dataset
             Input dataset containing the messages column.
-        **
+        **kwargs : Any
             Runtime parameters that override initialization defaults.
-
-            temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
-            stop, seed, response_format, stream, n, and provider-specific params.
-
-        Special flow-level parameters:
-        _flow_max_concurrency : int, optional
-            Maximum concurrency for async requests (passed by Flow).
+            Supports all LiteLLM completion parameters.

         Returns
         -------
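Note: a hedged usage sketch of the new generate() path, assuming the block defined in this diff, a Hugging Face Dataset with a "messages" column, and an API key in the environment. The import path is inferred from the file list above; the model name and kwargs are illustrative:

    # Hypothetical usage sketch; dataset contents, model, and kwargs are illustrative.
    from datasets import Dataset
    from sdg_hub.core.blocks.llm.llm_chat_block import LLMChatBlock  # path per the file list above

    ds = Dataset.from_list(
        [{"messages": [{"role": "user", "content": "Summarize: LiteLLM wraps many providers."}]}]
    )

    block = LLMChatBlock(
        block_name="demo_chat",
        input_cols="messages",
        output_cols="response",
        model="openai/gpt-4o-mini",  # assumes OPENAI_API_KEY is set in the environment
    )

    # Runtime kwargs override init-time defaults and are passed through to LiteLLM.
    out = block.generate(ds, temperature=0.2, max_tokens=128)
    print(out["response"][0])  # a list of response dicts (one per completion)
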
@@ -319,16 +195,21 @@ class LLMChatBlock(BaseBlock):
                 f"Call flow.set_model_config() before generating."
             )

-        # Extract
-        flow_max_concurrency =
+        # Extract flow-specific parameters (BaseBlock already handled block field overrides)
+        flow_max_concurrency = kwargs.pop("_flow_max_concurrency", None)
+
+        # Build completion kwargs from ALL fields + runtime overrides
+        completion_kwargs = self._build_completion_kwargs(**kwargs)

         # Extract messages
         messages_list = samples[self.input_cols[0]]

         # Log generation start
         logger.info(
-
-
+            "Starting %s generation for %d samples%s",
+            "async" if self.async_mode else "sync",
+            len(messages_list),
+            (
                 f" (max_concurrency={flow_max_concurrency})"
                 if flow_max_concurrency
                 else ""
@@ -336,21 +217,9 @@ class LLMChatBlock(BaseBlock):
             extra={
                 "block_name": self.block_name,
                 "model": self.model,
-                "provider": self.client_manager.config.get_provider(),
                 "batch_size": len(messages_list),
                 "async_mode": self.async_mode,
                 "flow_max_concurrency": flow_max_concurrency,
-                "override_params": {
-                    k: (
-                        "***"
-                        if any(
-                            s in k.lower()
-                            for s in ["key", "token", "secret", "authorization"]
-                        )
-                        else v
-                    )
-                    for k, v in override_kwargs.items()
-                },
             },
         )

@@ -360,7 +229,6 @@ class LLMChatBlock(BaseBlock):
                 # Check if there's already a running event loop
                 loop = asyncio.get_running_loop()
                 # Check if nest_asyncio is applied (allows nested asyncio.run)
-                # Use multiple detection methods for robustness
                 nest_asyncio_applied = (
                     hasattr(loop, "_nest_patched")
                     or getattr(asyncio.run, "__module__", "") == "nest_asyncio"
@@ -370,7 +238,7 @@ class LLMChatBlock(BaseBlock):
                     # nest_asyncio is applied, safe to use asyncio.run
                     responses = asyncio.run(
                         self._generate_async(
-                            messages_list,
+                            messages_list, completion_kwargs, flow_max_concurrency
                         )
                     )
                 else:
@@ -383,19 +251,19 @@ class LLMChatBlock(BaseBlock):
                 # No running loop; safe to create one
                 responses = asyncio.run(
                     self._generate_async(
-                        messages_list,
+                        messages_list, completion_kwargs, flow_max_concurrency
                     )
                 )
         else:
-            responses = self._generate_sync(messages_list,
+            responses = self._generate_sync(messages_list, completion_kwargs)

         # Log completion
         logger.info(
-
+            "Generation completed successfully for %d samples",
+            len(responses),
             extra={
                 "block_name": self.block_name,
                 "model": self.model,
-                "provider": self.client_manager.config.get_provider(),
                 "batch_size": len(responses),
             },
         )
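Note: the sync/async dispatch above hinges on whether an event loop is already running. A standalone sketch of that check (helper name is illustrative):

    # Sketch of the dispatch idea: only fall back to asyncio.run() when no loop is running.
    import asyncio


    def has_running_loop() -> bool:
        """Return True when called from inside a running event loop."""
        try:
            asyncio.get_running_loop()
            return True
        except RuntimeError:
            return False


    print(has_running_loop())  # False in a plain script; True inside e.g. a Jupyter cell
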
@@ -403,39 +271,98 @@ class LLMChatBlock(BaseBlock):
         # Add responses as new column
         return samples.add_column(self.output_cols[0], responses)

+    def _build_completion_kwargs(self, **overrides) -> dict[str, Any]:
+        """Build kwargs for LiteLLM completion call.
+
+        Returns
+        -------
+        dict[str, Any]
+            Kwargs for litellm.completion() or litellm.acompletion().
+        """
+        # Start with extra fields (temperature, max_tokens, etc.) from extra="allow"
+        extra_values = self.model_dump(exclude_unset=True)
+
+        # Remove block-operational fields that shouldn't go to LiteLLM
+        block_only_fields = {
+            "block_name",
+            "input_cols",
+            "output_cols",
+            "async_mode",
+        }
+
+        completion_kwargs = {
+            k: v for k, v in extra_values.items() if k not in block_only_fields
+        }
+
+        # Add essential LiteLLM fields (even though they're excluded from serialization)
+        if self.model is not None:
+            completion_kwargs["model"] = self.model
+        if self.api_key is not None:
+            completion_kwargs["api_key"] = self.api_key
+        if self.api_base is not None:
+            completion_kwargs["api_base"] = self.api_base
+        if self.timeout is not None:
+            completion_kwargs["timeout"] = self.timeout
+        if self.num_retries is not None:
+            completion_kwargs["num_retries"] = self.num_retries
+
+        # Apply only non-block-field overrides (flow params + unknown LiteLLM params)
+        # BaseBlock already handles block field overrides by modifying instance attributes
+        non_block_overrides = {
+            k: v for k, v in overrides.items() if k not in self.__class__.model_fields
+        }
+        completion_kwargs.update(non_block_overrides)
+
+        # Ensure drop_params is set to handle unknown parameters gracefully
+        completion_kwargs["drop_params"] = self.drop_params
+
+        return completion_kwargs
+
+    def _message_to_dict(self, message) -> dict[str, Any]:
+        """Convert LiteLLM message to dict."""
+        return {"content": message.content, **getattr(message, "__dict__", {})}
+
     def _generate_sync(
         self,
         messages_list: list[list[dict[str, Any]]],
-
-    ) -> list[
+        completion_kwargs: dict[str, Any],
+    ) -> list[list[dict[str, Any]]]:
         """Generate responses synchronously.

         Parameters
         ----------
-        messages_list :
+        messages_list : list[list[dict[str, Any]]]
             List of message lists to process.
-
-
+        completion_kwargs : dict[str, Any]
+            Kwargs for LiteLLM completion.

         Returns
         -------
-
-        List of
-        or a list of dicts when n>1. Response dicts contain 'content', may contain 'reasoning_content' and other fields if any.
+        list[list[dict[str, Any]]]
+            List of response lists, each containing LiteLLM completion response dictionaries.
         """
         responses = []

         for i, messages in enumerate(messages_list):
             try:
-                response =
-
-                )
-
+                response = completion(messages=messages, **completion_kwargs)
+                # Extract response based on n parameter
+                n_value = completion_kwargs.get("n", 1)
+                if n_value > 1:
+                    response_data = [
+                        self._message_to_dict(choice.message)
+                        for choice in response.choices
+                    ]
+                else:
+                    response_data = [self._message_to_dict(response.choices[0].message)]
+                responses.append(response_data)

                 # Log progress for large batches
                 if (i + 1) % 10 == 0:
                     logger.debug(
-
+                        "Generated %d/%d responses",
+                        i + 1,
+                        len(messages_list),
                         extra={
                             "block_name": self.block_name,
                             "progress": f"{i + 1}/{len(messages_list)}",
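Note: a small sketch of the response shaping used in _generate_sync: with n > 1 each choice's message becomes a dict, so every dataset row ends up with a list of response dicts. The helper below mirrors _message_to_dict but is illustrative, and the model name is a placeholder:

    # Illustrative shaping of a LiteLLM completion response into a per-sample list of dicts.
    from litellm import completion


    def message_to_dict(message) -> dict:
        # Keep 'content' plus any extra fields (e.g. reasoning_content) the provider returned.
        return {"content": message.content, **getattr(message, "__dict__", {})}


    response = completion(
        model="openai/gpt-4o-mini",  # illustrative model
        messages=[{"role": "user", "content": "Give me a color."}],
        n=3,  # three completions for this one input
    )
    row_responses = [message_to_dict(choice.message) for choice in response.choices]
    print(len(row_responses))  # 3 - stored as a list in the output column
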
@@ -443,11 +370,10 @@ class LLMChatBlock(BaseBlock):
                     )

             except Exception as e:
-                error_msg = self.client_manager.error_handler.format_error_message(
-                    e, {"model": self.model, "sample_index": i}
-                )
                 logger.error(
-
+                    "Failed to generate response for sample %d: %s",
+                    i,
+                    str(e),
                     extra={
                         "block_name": self.block_name,
                         "sample_index": i,
@@ -458,43 +384,127 @@ class LLMChatBlock(BaseBlock):

         return responses

+    async def _make_acompletion(
+        self,
+        messages: list[dict[str, Any]],
+        completion_kwargs: dict[str, Any],
+        semaphore: Optional[asyncio.Semaphore] = None,
+    ) -> list[dict[str, Any]]:
+        """Make a single async completion with optional concurrency control.
+
+        Parameters
+        ----------
+        messages : list[dict[str, Any]]
+            Messages for this completion.
+        completion_kwargs : dict[str, Any]
+            Kwargs for LiteLLM acompletion.
+        semaphore : Optional[asyncio.Semaphore], optional
+            Semaphore for concurrency control.
+
+        Returns
+        -------
+        list[dict[str, Any]]
+            List of response dictionaries.
+        """
+        if semaphore:
+            async with semaphore:
+                response = await acompletion(messages=messages, **completion_kwargs)
+        else:
+            response = await acompletion(messages=messages, **completion_kwargs)
+
+        # Extract response based on n parameter
+        n_value = completion_kwargs.get("n", 1)
+        if n_value > 1:
+            return [
+                self._message_to_dict(choice.message) for choice in response.choices
+            ]
+        return [self._message_to_dict(response.choices[0].message)]
+
     async def _generate_async(
         self,
         messages_list: list[list[dict[str, Any]]],
+        completion_kwargs: dict[str, Any],
         flow_max_concurrency: Optional[int] = None,
-
-    ) -> list[Union[dict, list[dict]]]:
+    ) -> list[list[dict[str, Any]]]:
         """Generate responses asynchronously.

         Parameters
         ----------
-        messages_list :
+        messages_list : list[list[dict[str, Any]]]
            List of message lists to process.
+        completion_kwargs : dict[str, Any]
+            Kwargs for LiteLLM acompletion.
         flow_max_concurrency : Optional[int], optional
             Maximum concurrency for async requests.
-        **override_kwargs : Dict[str, Any]
-            Runtime parameter overrides.

         Returns
         -------
-
-        List of
-        or a list of dicts when n>1. Response dicts contain 'content', may contain 'reasoning_content' and other fields if any.
+        list[list[dict[str, Any]]]
+            List of response lists, each containing LiteLLM completion response dictionaries.
         """
+
         try:
-
-
-
-
+            if flow_max_concurrency is not None:
+                # Validate max_concurrency parameter
+                if flow_max_concurrency < 1:
+                    raise ValueError(
+                        f"max_concurrency must be greater than 0, got {flow_max_concurrency}"
+                    )

+                # Adjust concurrency based on n parameter (number of completions per request)
+                effective_concurrency = flow_max_concurrency
+                n_value = completion_kwargs.get("n", 1)
+
+                if n_value and n_value > 1:
+                    if flow_max_concurrency >= n_value:
+                        # Adjust concurrency to account for n completions per request
+                        effective_concurrency = flow_max_concurrency // n_value
+                        logger.debug(
+                            "Adjusted max_concurrency from %d to %d for n=%d completions per request",
+                            flow_max_concurrency,
+                            effective_concurrency,
+                            n_value,
+                            extra={
+                                "block_name": self.block_name,
+                                "original_max_concurrency": flow_max_concurrency,
+                                "adjusted_max_concurrency": effective_concurrency,
+                                "n_value": n_value,
+                            },
+                        )
+                    else:
+                        # Warn when max_concurrency is less than n
+                        logger.warning(
+                            "max_concurrency (%d) is less than n (%d). Consider increasing max_concurrency for optimal performance.",
+                            flow_max_concurrency,
+                            n_value,
+                            extra={
+                                "block_name": self.block_name,
+                                "max_concurrency": flow_max_concurrency,
+                                "n_value": n_value,
+                            },
+                        )
+                        effective_concurrency = flow_max_concurrency
+
+                # Use semaphore for concurrency control
+                semaphore = asyncio.Semaphore(effective_concurrency)
+                tasks = [
+                    self._make_acompletion(messages, completion_kwargs, semaphore)
+                    for messages in messages_list
+                ]
+            else:
+                # No concurrency limit
+                tasks = [
+                    self._make_acompletion(messages, completion_kwargs)
+                    for messages in messages_list
+                ]
+
+            responses = await asyncio.gather(*tasks)
             return responses

         except Exception as e:
-            error_msg = self.client_manager.error_handler.format_error_message(
-                e, {"model": self.model}
-            )
             logger.error(
-
+                "Failed to generate async responses: %s",
+                str(e),
                 extra={
                     "block_name": self.block_name,
                     "batch_size": len(messages_list),
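Note: the async path above reduces to a semaphore-bounded asyncio.gather, with the concurrency budget divided by n when multiple completions are requested per sample. A standalone sketch of that pattern (names are illustrative; no LiteLLM calls):

    # Standalone sketch of semaphore-bounded fan-out, as used for async generation above.
    import asyncio


    async def call_one(i: int, semaphore: asyncio.Semaphore) -> str:
        async with semaphore:          # at most `effective_concurrency` calls in flight
            await asyncio.sleep(0.01)  # stand-in for an acompletion() call
            return f"response-{i}"


    async def run_all(num_requests: int, max_concurrency: int, n: int = 1) -> list[str]:
        # Mirror the diff's adjustment: budget concurrency per request when n > 1.
        effective_concurrency = max_concurrency // n if max_concurrency >= n else max_concurrency
        semaphore = asyncio.Semaphore(effective_concurrency)
        tasks = [call_one(i, semaphore) for i in range(num_requests)]
        return await asyncio.gather(*tasks)


    print(asyncio.run(run_all(num_requests=8, max_concurrency=4, n=2)))
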
@@ -503,27 +513,9 @@ class LLMChatBlock(BaseBlock):
             )
             raise

-    def get_model_info(self) -> dict[str, Any]:
-        """Get information about the configured model.
-
-        Returns
-        -------
-        Dict[str, Any]
-            Model information including provider, capabilities, etc.
-        """
-        return {
-            **self.client_manager.get_model_info(),
-            "block_name": self.block_name,
-            "input_column": self.input_cols[0],
-            "output_column": self.output_cols[0],
-            "async_mode": self.async_mode,
-        }
-
     def _validate_custom(self, dataset: Dataset) -> None:
         """Custom validation for LLMChatBlock message format.

-        Validates that all samples contain properly formatted messages.
-
         Parameters
         ----------
         dataset : Dataset
@@ -576,25 +568,19 @@ class LLMChatBlock(BaseBlock):
                     details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Available fields: {list(message.keys())}",
                 )

-            return True
+            return True

-        #
-        # Add index to each sample for better error reporting
+        # Validate all samples
         indexed_samples = [(i, sample) for i, sample in enumerate(dataset)]
         list(map(validate_sample, indexed_samples))

-    def __del__(self) -> None:
-        """Cleanup when block is destroyed."""
-        try:
-            if hasattr(self, "client_manager"):
-                self.client_manager.unload()
-        except Exception:
-            # Ignore errors during cleanup to prevent issues during shutdown
-            pass
-
     def __repr__(self) -> str:
         """String representation of the block."""
+        provider = None
+        if self.model and "/" in self.model:
+            provider = self.model.split("/")[0]
+
         return (
             f"LLMChatBlock(name='{self.block_name}', model='{self.model}', "
-            f"provider='{
+            f"provider='{provider}', async_mode={self.async_mode})"
         )
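Note: the provider shown in the new __repr__ is simply the prefix of the LiteLLM "provider/model" id; an illustrative one-liner:

    # Illustrative: derive the provider from a LiteLLM model id prefix.
    def provider_of(model: str | None) -> str | None:
        return model.split("/")[0] if model and "/" in model else None

    print(provider_of("anthropic/claude-3-sonnet-20240229"))  # anthropic
    print(provider_of(None))                                  # None
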