sdg-hub 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (35)
  1. sdg_hub/_version.py +2 -2
  2. sdg_hub/core/blocks/__init__.py +2 -4
  3. sdg_hub/core/blocks/base.py +61 -6
  4. sdg_hub/core/blocks/filtering/column_value_filter.py +3 -2
  5. sdg_hub/core/blocks/llm/__init__.py +2 -4
  6. sdg_hub/core/blocks/llm/llm_chat_block.py +251 -265
  7. sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +216 -98
  8. sdg_hub/core/blocks/llm/llm_parser_block.py +320 -0
  9. sdg_hub/core/blocks/llm/text_parser_block.py +53 -152
  10. sdg_hub/core/flow/base.py +7 -4
  11. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/flow.yaml +51 -11
  12. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/doc_direct_qa/__init__.py +0 -0
  13. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/doc_direct_qa/flow.yaml +159 -0
  14. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/flow.yaml +51 -11
  15. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/flow.yaml +14 -2
  16. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +146 -26
  17. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/README.md +0 -0
  18. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/__init__.py +0 -0
  19. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/atomic_facts_ja.yaml +41 -0
  20. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/detailed_summary_ja.yaml +14 -0
  21. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/extractive_summary_ja.yaml +14 -0
  22. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/flow.yaml +304 -0
  23. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/generate_questions_responses_ja.yaml +55 -0
  24. sdg_hub/flows/text_analysis/structured_insights/flow.yaml +28 -4
  25. {sdg_hub-0.3.1.dist-info → sdg_hub-0.4.0.dist-info}/METADATA +1 -1
  26. {sdg_hub-0.3.1.dist-info → sdg_hub-0.4.0.dist-info}/RECORD +29 -25
  27. sdg_hub/core/blocks/evaluation/__init__.py +0 -9
  28. sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +0 -323
  29. sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +0 -323
  30. sdg_hub/core/blocks/evaluation/verify_question_block.py +0 -329
  31. sdg_hub/core/blocks/llm/client_manager.py +0 -472
  32. sdg_hub/core/blocks/llm/config.py +0 -337
  33. {sdg_hub-0.3.1.dist-info → sdg_hub-0.4.0.dist-info}/WHEEL +0 -0
  34. {sdg_hub-0.3.1.dist-info → sdg_hub-0.4.0.dist-info}/licenses/LICENSE +0 -0
  35. {sdg_hub-0.3.1.dist-info → sdg_hub-0.4.0.dist-info}/top_level.txt +0 -0
@@ -2,21 +2,23 @@
  """Unified LLM chat block supporting all providers via LiteLLM."""

  # Standard
- from typing import Any, Optional, Union
+ from typing import Any, Optional
  import asyncio

  # Third Party
  from datasets import Dataset
- from pydantic import Field, field_validator
+ from litellm import acompletion, completion
+ from pydantic import ConfigDict, Field, field_validator
+ import litellm

- # Local
  from ...utils.error_handling import BlockValidationError
  from ...utils.logger_config import setup_logger
+
+ # Local
  from ..base import BaseBlock
  from ..registry import BlockRegistry
- from .client_manager import LLMClientManager
- from .config import LLMConfig

+ litellm.drop_params = True
  logger = setup_logger(__name__)


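In 0.4.0 the block calls LiteLLM directly instead of going through a client manager, and the module-level `litellm.drop_params = True` means parameters a provider does not support are dropped rather than raising. A minimal sketch of that behavior (the model name and extra parameter are illustrative, not taken from sdg_hub):

# Sketch only: demonstrates litellm.drop_params; model and kwargs are illustrative.
import litellm
from litellm import completion

litellm.drop_params = True  # same default the block sets at import time

response = completion(
    model="openai/gpt-4o-mini",  # any LiteLLM model id
    messages=[{"role": "user", "content": "Say hello."}],
    frequency_penalty=0.1,  # silently dropped if the provider rejects it
)
print(response.choices[0].message.content)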
@@ -26,10 +28,12 @@ logger = setup_logger(__name__)
  "Unified LLM chat block supporting 100+ providers via LiteLLM",
  )
  class LLMChatBlock(BaseBlock):
+ model_config = ConfigDict(extra="allow")
+
  """Unified LLM chat block supporting all providers via LiteLLM.

- This block replaces OpenAIChatBlock and OpenAIAsyncChatBlock with a single
- implementation that supports 100+ LLM providers through LiteLLM, including:
+ This block provides a minimal wrapper around LiteLLM's completion API,
+ supporting 100+ LLM providers including:
  - OpenAI (GPT-3.5, GPT-4, etc.)
  - Anthropic (Claude models)
  - Google (Gemini, PaLM)
@@ -43,14 +47,10 @@ class LLMChatBlock(BaseBlock):
  input_cols : Union[str, List[str]]
  Input column name(s). Should contain the messages list.
  output_cols : Union[dict, List[dict]]
- Output column name(s) for the response. When n > 1, the column will contain
- a list of responses instead of a single response. Responses contain 'content',
- may contain 'reasoning_content' and other fields if any.
- model : str
- Model identifier in LiteLLM format. Examples:
- - "openai/gpt-4"
- - "anthropic/claude-3-sonnet-20240229"
- - "hosted_vllm/meta-llama/Llama-2-7b-chat-hf"
+ Output column name(s) for the response.
+ model : Optional[str], optional
+ Model identifier in LiteLLM format. Can be set later via flow.set_model_config().
+ Examples: "openai/gpt-4", "anthropic/claude-3-sonnet-20240229"
  api_key : Optional[str], optional
  API key for the provider. Falls back to environment variables.
  api_base : Optional[str], optional
@@ -59,138 +59,68 @@
  Whether to use async processing, by default False.
  timeout : float, optional
  Request timeout in seconds, by default 120.0.
- max_retries : int, optional
- Maximum number of retry attempts, by default 6.
-
- ### Generation Parameters ###
-
- temperature : Optional[float], optional
- Sampling temperature (0.0 to 2.0).
- max_tokens : Optional[int], optional
- Maximum tokens to generate.
- top_p : Optional[float], optional
- Nucleus sampling parameter (0.0 to 1.0).
- frequency_penalty : Optional[float], optional
- Frequency penalty (-2.0 to 2.0).
- presence_penalty : Optional[float], optional
- Presence penalty (-2.0 to 2.0).
- stop : Optional[Union[str, List[str]]], optional
- Stop sequences.
- seed : Optional[int], optional
- Random seed for reproducible outputs.
- response_format : Optional[Dict[str, Any]], optional
- Response format specification (e.g., JSON mode).
- stream : Optional[bool], optional
- Whether to stream responses.
- n : Optional[int], optional
- Number of completions to generate. When n > 1, the output column will contain
- a list of responses for each input sample.
- logprobs : Optional[bool], optional
- Whether to return log probabilities.
- top_logprobs : Optional[int], optional
- Number of top log probabilities to return.
- user : Optional[str], optional
- End-user identifier.
- extra_headers : Optional[Dict[str, str]], optional
- Additional headers to send with requests.
- extra_body : Optional[Dict[str, Any]], optional
- Additional parameters for the request body.
+ num_retries : int, optional
+ Number of retry attempts (uses LiteLLM's built-in retry mechanism), by default 6.
+ Note: For rate limit handling, use LiteLLM's fallbacks parameter instead.
+ drop_params : bool, optional
+ Whether to drop unsupported parameters to prevent API errors, by default True.
  **kwargs : Any
- Additional provider-specific parameters.
+ Any LiteLLM completion parameters (temperature, max_tokens, top_p, etc.).
+ See https://docs.litellm.ai/docs/completion/input for full list.

  Examples
  --------
- >>> # OpenAI GPT-4
+ >>> # OpenAI GPT-4 with generation parameters
  >>> block = LLMChatBlock(
  ... block_name="gpt4_block",
  ... input_cols="messages",
  ... output_cols="response",
  ... model="openai/gpt-4",
- ... temperature=0.7
- ... )
-
- >>> # Anthropic Claude
- >>> block = LLMChatBlock(
- ... block_name="claude_block",
- ... input_cols="messages",
- ... output_cols="response",
- ... model="anthropic/claude-3-sonnet-20240229",
- ... temperature=0.7
+ ... temperature=0.7,
+ ... max_tokens=1000
  ... )

- >>> # Local vLLM model
+ >>> # Local vLLM model with custom parameters
  >>> block = LLMChatBlock(
  ... block_name="local_llama",
  ... input_cols="messages",
  ... output_cols="response",
  ... model="hosted_vllm/meta-llama/Llama-2-7b-chat-hf",
  ... api_base="http://localhost:8000/v1",
- ... temperature=0.7
- ... )
-
- >>> # Multiple completions (n > 1)
- >>> block = LLMChatBlock(
- ... block_name="gpt4_multiple",
- ... input_cols="messages",
- ... output_cols="responses", # Will contain lists of responses
- ... model="openai/gpt-4",
- ... n=3, # Generate 3 responses per input
- ... temperature=0.8
+ ... temperature=0.7,
+ ... response_format={"type": "json_object"}
  ... )
  """

- # LLM Configuration
- model: Optional[str] = Field(None, description="Model identifier in LiteLLM format")
- api_key: Optional[str] = Field(None, description="API key for the provider")
- api_base: Optional[str] = Field(None, description="Base URL for the API")
- async_mode: bool = Field(False, description="Whether to use async processing")
- timeout: float = Field(120.0, description="Request timeout in seconds")
- max_retries: int = Field(6, description="Maximum number of retry attempts")
-
- # Generation parameters
- temperature: Optional[float] = Field(
- None, description="Sampling temperature (0.0 to 2.0)"
- )
- max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
- top_p: Optional[float] = Field(
- None, description="Nucleus sampling parameter (0.0 to 1.0)"
+ # Essential operational fields (excluded from YAML serialization)
+ model: Optional[str] = Field(
+ None, exclude=True, description="Model identifier in LiteLLM format"
  )
- frequency_penalty: Optional[float] = Field(
- None, description="Frequency penalty (-2.0 to 2.0)"
+ api_key: Optional[str] = Field(
+ None, exclude=True, description="API key for the provider"
  )
- presence_penalty: Optional[float] = Field(
- None, description="Presence penalty (-2.0 to 2.0)"
+ api_base: Optional[str] = Field(
+ None, exclude=True, description="Base URL for the API"
  )
- stop: Optional[Union[str, list[str]]] = Field(None, description="Stop sequences")
- seed: Optional[int] = Field(
- None, description="Random seed for reproducible outputs"
+ async_mode: bool = Field(
+ False, exclude=True, description="Whether to use async processing"
  )
- response_format: Optional[dict[str, Any]] = Field(
- None, description="Response format specification"
+ timeout: float = Field(
+ 120.0, exclude=True, description="Request timeout in seconds"
  )
- stream: Optional[bool] = Field(None, description="Whether to stream responses")
- n: Optional[int] = Field(None, description="Number of completions to generate")
- logprobs: Optional[bool] = Field(
- None, description="Whether to return log probabilities"
+ num_retries: int = Field(
+ 6,
+ exclude=True,
+ description="Number of retry attempts (uses LiteLLM's built-in retry mechanism)",
  )
- top_logprobs: Optional[int] = Field(
- None, description="Number of top log probabilities to return"
- )
- user: Optional[str] = Field(None, description="End-user identifier")
- extra_headers: Optional[dict[str, str]] = Field(
- None, description="Additional headers"
- )
- extra_body: Optional[dict[str, Any]] = Field(
- None, description="Additional request body parameters"
- )
- provider_specific: Optional[dict[str, Any]] = Field(
- None, description="Provider-specific parameters"
+ drop_params: bool = Field(
+ True, description="Whether to drop unsupported parameters to prevent API errors"
  )

- # Exclude from serialization - internal computed field
- client_manager: Optional[Any] = Field(
- None, exclude=True, description="Internal client manager"
- )
+ # All LiteLLM completion parameters can be passed via extra="allow"
+ # Common examples: temperature, max_tokens, top_p, frequency_penalty,
+ # presence_penalty, stop, seed, response_format, stream, n, logprobs,
+ # top_logprobs, user, extra_headers, extra_body, etc.

  @field_validator("input_cols")
  @classmethod
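The old explicit generation-parameter fields are replaced by Pydantic's `extra="allow"`: anything not declared as a field is kept as an extra attribute and later forwarded to LiteLLM. A small standalone sketch of that pattern (the class below is a hypothetical stand-in, not the real LLMChatBlock):

# Sketch of the extra="allow" pattern; ChatBlockSketch is a stand-in class.
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field

class ChatBlockSketch(BaseModel):
    model_config = ConfigDict(extra="allow")
    model: Optional[str] = Field(None, exclude=True)  # excluded from dumps

block = ChatBlockSketch(model="openai/gpt-4", temperature=0.7, max_tokens=256)
# Undeclared kwargs survive as extra fields and appear in model_dump(),
# which is how _build_completion_kwargs later collects them.
print(block.model_dump(exclude_unset=True))  # {'temperature': 0.7, 'max_tokens': 256}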
@@ -224,83 +154,29 @@ class LLMChatBlock(BaseBlock):
  """Initialize after Pydantic validation."""
  super().model_post_init(__context)

- # Initialize client manager
- self._setup_client_manager()
-
- def _setup_client_manager(self) -> None:
- """Set up the LLM client manager with current configuration."""
- # Create configuration with current values
- config = LLMConfig(
- model=self.model,
- api_key=self.api_key,
- api_base=self.api_base,
- timeout=self.timeout,
- max_retries=self.max_retries,
- temperature=self.temperature,
- max_tokens=self.max_tokens,
- top_p=self.top_p,
- frequency_penalty=self.frequency_penalty,
- presence_penalty=self.presence_penalty,
- stop=self.stop,
- seed=self.seed,
- response_format=self.response_format,
- stream=self.stream,
- n=self.n,
- logprobs=self.logprobs,
- top_logprobs=self.top_logprobs,
- user=self.user,
- extra_headers=self.extra_headers,
- extra_body=self.extra_body,
- provider_specific=self.provider_specific,
- )
-
- # Create client manager
- self.client_manager = LLMClientManager(config)
-
- # Load client immediately
- self.client_manager.load()
-
  # Log initialization only when model is configured
  if self.model:
  logger.info(
- f"Initialized LLMChatBlock '{self.block_name}' with model '{self.model}'",
+ "Initialized LLMChatBlock '%s' with model '%s'",
+ self.block_name,
+ self.model,
  extra={
  "block_name": self.block_name,
  "model": self.model,
- "provider": self.client_manager.config.get_provider(),
- "is_local": self.client_manager.config.is_local_model(),
  "async_mode": self.async_mode,
- "generation_params": self.client_manager.config.get_generation_kwargs(),
  },
  )

- def _reinitialize_client_manager(self) -> None:
- """Reinitialize the client manager with updated model configuration.
-
- This should be called after model configuration changes to ensure
- the client manager uses the updated model, api_base, api_key, etc.
- """
- self._setup_client_manager()
-
- def generate(self, samples: Dataset, **override_kwargs: dict[str, Any]) -> Dataset:
+ def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
  """Generate responses from the LLM.

- Parameters set at runtime override those set during initialization.
- Supports all LiteLLM parameters for the configured provider.
-
  Parameters
  ----------
  samples : Dataset
  Input dataset containing the messages column.
- **override_kwargs : Dict[str, Any]
+ **kwargs : Any
  Runtime parameters that override initialization defaults.
- Valid parameters depend on the provider but typically include:
- temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
- stop, seed, response_format, stream, n, and provider-specific params.
-
- Special flow-level parameters:
- _flow_max_concurrency : int, optional
- Maximum concurrency for async requests (passed by Flow).
+ Supports all LiteLLM completion parameters.

  Returns
  -------
@@ -319,16 +195,21 @@ class LLMChatBlock(BaseBlock):
  f"Call flow.set_model_config() before generating."
  )

- # Extract max_concurrency if provided by flow
- flow_max_concurrency = override_kwargs.pop("_flow_max_concurrency", None)
+ # Extract flow-specific parameters (BaseBlock already handled block field overrides)
+ flow_max_concurrency = kwargs.pop("_flow_max_concurrency", None)
+
+ # Build completion kwargs from ALL fields + runtime overrides
+ completion_kwargs = self._build_completion_kwargs(**kwargs)

  # Extract messages
  messages_list = samples[self.input_cols[0]]

  # Log generation start
  logger.info(
- f"Starting {'async' if self.async_mode else 'sync'} generation for {len(messages_list)} samples"
- + (
+ "Starting %s generation for %d samples%s",
+ "async" if self.async_mode else "sync",
+ len(messages_list),
+ (
  f" (max_concurrency={flow_max_concurrency})"
  if flow_max_concurrency
  else ""
@@ -336,21 +217,9 @@ class LLMChatBlock(BaseBlock):
  extra={
  "block_name": self.block_name,
  "model": self.model,
- "provider": self.client_manager.config.get_provider(),
  "batch_size": len(messages_list),
  "async_mode": self.async_mode,
  "flow_max_concurrency": flow_max_concurrency,
- "override_params": {
- k: (
- "***"
- if any(
- s in k.lower()
- for s in ["key", "token", "secret", "authorization"]
- )
- else v
- )
- for k, v in override_kwargs.items()
- },
  },
  )

@@ -360,7 +229,6 @@ class LLMChatBlock(BaseBlock):
  # Check if there's already a running event loop
  loop = asyncio.get_running_loop()
  # Check if nest_asyncio is applied (allows nested asyncio.run)
- # Use multiple detection methods for robustness
  nest_asyncio_applied = (
  hasattr(loop, "_nest_patched")
  or getattr(asyncio.run, "__module__", "") == "nest_asyncio"
@@ -370,7 +238,7 @@ class LLMChatBlock(BaseBlock):
  # nest_asyncio is applied, safe to use asyncio.run
  responses = asyncio.run(
  self._generate_async(
- messages_list, flow_max_concurrency, **override_kwargs
+ messages_list, completion_kwargs, flow_max_concurrency
  )
  )
  else:
@@ -383,19 +251,19 @@ class LLMChatBlock(BaseBlock):
  # No running loop; safe to create one
  responses = asyncio.run(
  self._generate_async(
- messages_list, flow_max_concurrency, **override_kwargs
+ messages_list, completion_kwargs, flow_max_concurrency
  )
  )
  else:
- responses = self._generate_sync(messages_list, **override_kwargs)
+ responses = self._generate_sync(messages_list, completion_kwargs)

  # Log completion
  logger.info(
- f"Generation completed successfully for {len(responses)} samples",
+ "Generation completed successfully for %d samples",
+ len(responses),
  extra={
  "block_name": self.block_name,
  "model": self.model,
- "provider": self.client_manager.config.get_provider(),
  "batch_size": len(responses),
  },
  )
@@ -403,39 +271,98 @@ class LLMChatBlock(BaseBlock):
  # Add responses as new column
  return samples.add_column(self.output_cols[0], responses)

+ def _build_completion_kwargs(self, **overrides) -> dict[str, Any]:
+ """Build kwargs for LiteLLM completion call.
+
+ Returns
+ -------
+ dict[str, Any]
+ Kwargs for litellm.completion() or litellm.acompletion().
+ """
+ # Start with extra fields (temperature, max_tokens, etc.) from extra="allow"
+ extra_values = self.model_dump(exclude_unset=True)
+
+ # Remove block-operational fields that shouldn't go to LiteLLM
+ block_only_fields = {
+ "block_name",
+ "input_cols",
+ "output_cols",
+ "async_mode",
+ }
+
+ completion_kwargs = {
+ k: v for k, v in extra_values.items() if k not in block_only_fields
+ }
+
+ # Add essential LiteLLM fields (even though they're excluded from serialization)
+ if self.model is not None:
+ completion_kwargs["model"] = self.model
+ if self.api_key is not None:
+ completion_kwargs["api_key"] = self.api_key
+ if self.api_base is not None:
+ completion_kwargs["api_base"] = self.api_base
+ if self.timeout is not None:
+ completion_kwargs["timeout"] = self.timeout
+ if self.num_retries is not None:
+ completion_kwargs["num_retries"] = self.num_retries
+
+ # Apply only non-block-field overrides (flow params + unknown LiteLLM params)
+ # BaseBlock already handles block field overrides by modifying instance attributes
+ non_block_overrides = {
+ k: v for k, v in overrides.items() if k not in self.__class__.model_fields
+ }
+ completion_kwargs.update(non_block_overrides)
+
+ # Ensure drop_params is set to handle unknown parameters gracefully
+ completion_kwargs["drop_params"] = self.drop_params
+
+ return completion_kwargs
+
+ def _message_to_dict(self, message) -> dict[str, Any]:
+ """Convert LiteLLM message to dict."""
+ return {"content": message.content, **getattr(message, "__dict__", {})}
+
  def _generate_sync(
  self,
  messages_list: list[list[dict[str, Any]]],
- **override_kwargs: dict[str, Any],
- ) -> list[Union[dict, list[dict]]]:
+ completion_kwargs: dict[str, Any],
+ ) -> list[list[dict[str, Any]]]:
  """Generate responses synchronously.

  Parameters
  ----------
- messages_list : List[List[Dict[str, Any]]]
+ messages_list : list[list[dict[str, Any]]]
  List of message lists to process.
- **override_kwargs : Dict[str, Any]
- Runtime parameter overrides.
+ completion_kwargs : dict[str, Any]
+ Kwargs for LiteLLM completion.

  Returns
  -------
- List[Union[dict, List[dict]]]
- List of responses. Each element is a dict when n=1 or n is None,
- or a list of dicts when n>1. Response dicts contain 'content', may contain 'reasoning_content' and other fields if any.
+ list[list[dict[str, Any]]]
+ List of response lists, each containing LiteLLM completion response dictionaries.
  """
  responses = []

  for i, messages in enumerate(messages_list):
  try:
- response = self.client_manager.create_completion(
- messages, **override_kwargs
- )
- responses.append(response)
+ response = completion(messages=messages, **completion_kwargs)
+ # Extract response based on n parameter
+ n_value = completion_kwargs.get("n", 1)
+ if n_value > 1:
+ response_data = [
+ self._message_to_dict(choice.message)
+ for choice in response.choices
+ ]
+ else:
+ response_data = [self._message_to_dict(response.choices[0].message)]
+ responses.append(response_data)

  # Log progress for large batches
  if (i + 1) % 10 == 0:
  logger.debug(
- f"Generated {i + 1}/{len(messages_list)} responses",
+ "Generated %d/%d responses",
+ i + 1,
+ len(messages_list),
  extra={
  "block_name": self.block_name,
  "progress": f"{i + 1}/{len(messages_list)}",
@@ -443,11 +370,10 @@ class LLMChatBlock(BaseBlock):
  )

  except Exception as e:
- error_msg = self.client_manager.error_handler.format_error_message(
- e, {"model": self.model, "sample_index": i}
- )
  logger.error(
- f"Failed to generate response for sample {i}: {error_msg}",
+ "Failed to generate response for sample %d: %s",
+ i,
+ str(e),
  extra={
  "block_name": self.block_name,
  "sample_index": i,
@@ -458,43 +384,127 @@ class LLMChatBlock(BaseBlock):

  return responses

+ async def _make_acompletion(
+ self,
+ messages: list[dict[str, Any]],
+ completion_kwargs: dict[str, Any],
+ semaphore: Optional[asyncio.Semaphore] = None,
+ ) -> list[dict[str, Any]]:
+ """Make a single async completion with optional concurrency control.
+
+ Parameters
+ ----------
+ messages : list[dict[str, Any]]
+ Messages for this completion.
+ completion_kwargs : dict[str, Any]
+ Kwargs for LiteLLM acompletion.
+ semaphore : Optional[asyncio.Semaphore], optional
+ Semaphore for concurrency control.
+
+ Returns
+ -------
+ list[dict[str, Any]]
+ List of response dictionaries.
+ """
+ if semaphore:
+ async with semaphore:
+ response = await acompletion(messages=messages, **completion_kwargs)
+ else:
+ response = await acompletion(messages=messages, **completion_kwargs)
+
+ # Extract response based on n parameter
+ n_value = completion_kwargs.get("n", 1)
+ if n_value > 1:
+ return [
+ self._message_to_dict(choice.message) for choice in response.choices
+ ]
+ return [self._message_to_dict(response.choices[0].message)]
+
  async def _generate_async(
  self,
  messages_list: list[list[dict[str, Any]]],
+ completion_kwargs: dict[str, Any],
  flow_max_concurrency: Optional[int] = None,
- **override_kwargs: dict[str, Any],
- ) -> list[Union[dict, list[dict]]]:
+ ) -> list[list[dict[str, Any]]]:
  """Generate responses asynchronously.

  Parameters
  ----------
- messages_list : List[List[Dict[str, Any]]]
+ messages_list : list[list[dict[str, Any]]]
  List of message lists to process.
+ completion_kwargs : dict[str, Any]
+ Kwargs for LiteLLM acompletion.
  flow_max_concurrency : Optional[int], optional
  Maximum concurrency for async requests.
- **override_kwargs : Dict[str, Any]
- Runtime parameter overrides.

  Returns
  -------
- List[Union[dict, List[dict]]]
- List of responses. Each element is a dict when n=1 or n is None,
- or a list of dicts when n>1. Response dicts contain 'content', may contain 'reasoning_content' and other fields if any.
+ list[list[dict[str, Any]]]
+ List of response lists, each containing LiteLLM completion response dictionaries.
  """
+
  try:
- # Use unified client manager method with optional concurrency control
- responses = await self.client_manager.acreate_completion(
- messages_list, max_concurrency=flow_max_concurrency, **override_kwargs
- )
+ if flow_max_concurrency is not None:
+ # Validate max_concurrency parameter
+ if flow_max_concurrency < 1:
+ raise ValueError(
+ f"max_concurrency must be greater than 0, got {flow_max_concurrency}"
+ )

+ # Adjust concurrency based on n parameter (number of completions per request)
+ effective_concurrency = flow_max_concurrency
+ n_value = completion_kwargs.get("n", 1)
+
+ if n_value and n_value > 1:
+ if flow_max_concurrency >= n_value:
+ # Adjust concurrency to account for n completions per request
+ effective_concurrency = flow_max_concurrency // n_value
+ logger.debug(
+ "Adjusted max_concurrency from %d to %d for n=%d completions per request",
+ flow_max_concurrency,
+ effective_concurrency,
+ n_value,
+ extra={
+ "block_name": self.block_name,
+ "original_max_concurrency": flow_max_concurrency,
+ "adjusted_max_concurrency": effective_concurrency,
+ "n_value": n_value,
+ },
+ )
+ else:
+ # Warn when max_concurrency is less than n
+ logger.warning(
+ "max_concurrency (%d) is less than n (%d). Consider increasing max_concurrency for optimal performance.",
+ flow_max_concurrency,
+ n_value,
+ extra={
+ "block_name": self.block_name,
+ "max_concurrency": flow_max_concurrency,
+ "n_value": n_value,
+ },
+ )
+ effective_concurrency = flow_max_concurrency
+
+ # Use semaphore for concurrency control
+ semaphore = asyncio.Semaphore(effective_concurrency)
+ tasks = [
+ self._make_acompletion(messages, completion_kwargs, semaphore)
+ for messages in messages_list
+ ]
+ else:
+ # No concurrency limit
+ tasks = [
+ self._make_acompletion(messages, completion_kwargs)
+ for messages in messages_list
+ ]
+
+ responses = await asyncio.gather(*tasks)
  return responses

  except Exception as e:
- error_msg = self.client_manager.error_handler.format_error_message(
- e, {"model": self.model}
- )
  logger.error(
- f"Failed to generate async responses: {error_msg}",
+ "Failed to generate async responses: %s",
+ str(e),
  extra={
  "block_name": self.block_name,
  "batch_size": len(messages_list),
@@ -503,27 +513,9 @@ class LLMChatBlock(BaseBlock):
  )
  raise

- def get_model_info(self) -> dict[str, Any]:
- """Get information about the configured model.
-
- Returns
- -------
- Dict[str, Any]
- Model information including provider, capabilities, etc.
- """
- return {
- **self.client_manager.get_model_info(),
- "block_name": self.block_name,
- "input_column": self.input_cols[0],
- "output_column": self.output_cols[0],
- "async_mode": self.async_mode,
- }
-
  def _validate_custom(self, dataset: Dataset) -> None:
  """Custom validation for LLMChatBlock message format.

- Validates that all samples contain properly formatted messages.
-
  Parameters
  ----------
  dataset : Dataset
@@ -576,25 +568,19 @@ class LLMChatBlock(BaseBlock):
  details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Available fields: {list(message.keys())}",
  )

- return True # Return something for map
+ return True

- # Use map to validate all samples
- # Add index to each sample for better error reporting
+ # Validate all samples
  indexed_samples = [(i, sample) for i, sample in enumerate(dataset)]
  list(map(validate_sample, indexed_samples))

- def __del__(self) -> None:
- """Cleanup when block is destroyed."""
- try:
- if hasattr(self, "client_manager"):
- self.client_manager.unload()
- except Exception:
- # Ignore errors during cleanup to prevent issues during shutdown
- pass
-
  def __repr__(self) -> str:
  """String representation of the block."""
+ provider = None
+ if self.model and "/" in self.model:
+ provider = self.model.split("/")[0]
+
  return (
  f"LLMChatBlock(name='{self.block_name}', model='{self.model}', "
- f"provider='{self.client_manager.config.get_provider()}', async_mode={self.async_mode})"
+ f"provider='{provider}', async_mode={self.async_mode})"
  )
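Putting the changes together, a minimal 0.4.0 usage sketch, assuming LLMChatBlock is importable from sdg_hub.core.blocks.llm (as the file layout above suggests) and that provider credentials are configured. Note that the output column now holds a list of response dicts per row, even when n is 1:

# Minimal end-to-end sketch; the model id, parameter values, and import path are
# assumptions for illustration, not prescribed by the diff.
from datasets import Dataset
from sdg_hub.core.blocks.llm import LLMChatBlock

ds = Dataset.from_dict(
    {"messages": [[{"role": "user", "content": "Name one use of synthetic data."}]]}
)

block = LLMChatBlock(
    block_name="demo_chat",
    input_cols="messages",
    output_cols="response",
    model="openai/gpt-4o-mini",  # any LiteLLM model identifier
    temperature=0.2,             # captured via extra="allow"
)

result = block.generate(ds, max_tokens=64)  # runtime override forwarded to LiteLLM
print(result["response"][0][0]["content"])  # row 0, first (and only) choice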