sdg-hub 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. sdg_hub/__init__.py +28 -1
  2. sdg_hub/_version.py +2 -2
  3. sdg_hub/core/__init__.py +22 -0
  4. sdg_hub/core/blocks/__init__.py +58 -0
  5. sdg_hub/core/blocks/base.py +313 -0
  6. sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
  7. sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
  8. sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
  9. sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
  10. sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
  11. sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
  12. sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
  13. sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
  14. sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
  15. sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
  16. sdg_hub/core/blocks/evaluation/__init__.py +9 -0
  17. sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
  18. sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
  19. sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
  20. sdg_hub/core/blocks/filtering/__init__.py +12 -0
  21. sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
  22. sdg_hub/core/blocks/llm/__init__.py +25 -0
  23. sdg_hub/core/blocks/llm/client_manager.py +398 -0
  24. sdg_hub/core/blocks/llm/config.py +336 -0
  25. sdg_hub/core/blocks/llm/error_handler.py +368 -0
  26. sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
  27. sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
  28. sdg_hub/core/blocks/llm/text_parser_block.py +310 -0
  29. sdg_hub/core/blocks/registry.py +331 -0
  30. sdg_hub/core/blocks/transform/__init__.py +23 -0
  31. sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
  32. sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
  33. sdg_hub/core/blocks/transform/melt_columns.py +126 -0
  34. sdg_hub/core/blocks/transform/rename_columns.py +69 -0
  35. sdg_hub/core/blocks/transform/text_concat.py +102 -0
  36. sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
  37. sdg_hub/core/flow/__init__.py +20 -0
  38. sdg_hub/core/flow/base.py +980 -0
  39. sdg_hub/core/flow/metadata.py +344 -0
  40. sdg_hub/core/flow/migration.py +187 -0
  41. sdg_hub/core/flow/registry.py +330 -0
  42. sdg_hub/core/flow/validation.py +265 -0
  43. sdg_hub/{utils → core/utils}/__init__.py +6 -4
  44. sdg_hub/{utils → core/utils}/datautils.py +1 -3
  45. sdg_hub/core/utils/error_handling.py +208 -0
  46. sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
  47. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
  48. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
  49. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
  50. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
  51. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
  52. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
  53. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +191 -0
  54. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
  55. sdg_hub-0.2.0.dist-info/METADATA +218 -0
  56. sdg_hub-0.2.0.dist-info/RECORD +63 -0
  57. sdg_hub/blocks/__init__.py +0 -42
  58. sdg_hub/blocks/block.py +0 -96
  59. sdg_hub/blocks/llmblock.py +0 -375
  60. sdg_hub/blocks/openaichatblock.py +0 -556
  61. sdg_hub/blocks/utilblocks.py +0 -597
  62. sdg_hub/checkpointer.py +0 -139
  63. sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
  64. sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
  65. sdg_hub/configs/annotations/detailed_description.yaml +0 -10
  66. sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
  67. sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
  68. sdg_hub/configs/knowledge/__init__.py +0 -0
  69. sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
  70. sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
  71. sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
  72. sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
  73. sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
  74. sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
  75. sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
  76. sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
  77. sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
  78. sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
  79. sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
  80. sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
  81. sdg_hub/configs/knowledge/router.yaml +0 -12
  82. sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
  83. sdg_hub/configs/reasoning/__init__.py +0 -0
  84. sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
  85. sdg_hub/configs/skills/__init__.py +0 -0
  86. sdg_hub/configs/skills/analyzer.yaml +0 -48
  87. sdg_hub/configs/skills/annotation.yaml +0 -36
  88. sdg_hub/configs/skills/contexts.yaml +0 -28
  89. sdg_hub/configs/skills/critic.yaml +0 -60
  90. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
  91. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
  92. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
  93. sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
  94. sdg_hub/configs/skills/freeform_questions.yaml +0 -34
  95. sdg_hub/configs/skills/freeform_responses.yaml +0 -39
  96. sdg_hub/configs/skills/grounded_questions.yaml +0 -38
  97. sdg_hub/configs/skills/grounded_responses.yaml +0 -59
  98. sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
  99. sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
  100. sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
  101. sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
  102. sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
  103. sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
  104. sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
  105. sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
  106. sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
  107. sdg_hub/configs/skills/judge.yaml +0 -53
  108. sdg_hub/configs/skills/planner.yaml +0 -67
  109. sdg_hub/configs/skills/respond.yaml +0 -8
  110. sdg_hub/configs/skills/revised_responder.yaml +0 -78
  111. sdg_hub/configs/skills/router.yaml +0 -59
  112. sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
  113. sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
  114. sdg_hub/flow.py +0 -477
  115. sdg_hub/flow_runner.py +0 -450
  116. sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
  117. sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
  118. sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
  119. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -148
  120. sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
  121. sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
  122. sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
  123. sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
  124. sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
  125. sdg_hub/pipeline.py +0 -121
  126. sdg_hub/prompts.py +0 -74
  127. sdg_hub/registry.py +0 -122
  128. sdg_hub/sdg.py +0 -206
  129. sdg_hub/utils/config_validation.py +0 -91
  130. sdg_hub/utils/error_handling.py +0 -94
  131. sdg_hub/utils/validation_result.py +0 -10
  132. sdg_hub-0.1.3.dist-info/METADATA +0 -190
  133. sdg_hub-0.1.3.dist-info/RECORD +0 -89
  134. sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
  135. /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
  136. /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
  137. {sdg_hub-0.1.3.dist-info → sdg_hub-0.2.0.dist-info}/WHEEL +0 -0
  138. {sdg_hub-0.1.3.dist-info → sdg_hub-0.2.0.dist-info}/licenses/LICENSE +0 -0
  139. {sdg_hub-0.1.3.dist-info → sdg_hub-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,542 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Unified LLM chat block supporting all providers via LiteLLM."""
3
+
4
+ # Standard
5
+ from typing import Any, Optional, Union
6
+ import asyncio
7
+
8
+ # Third Party
9
+ from datasets import Dataset
10
+ from pydantic import Field, field_validator
11
+
12
+ # Local
13
+ from ...utils.error_handling import BlockValidationError
14
+ from ...utils.logger_config import setup_logger
15
+ from ..base import BaseBlock
16
+ from ..registry import BlockRegistry
17
+ from .client_manager import LLMClientManager
18
+ from .config import LLMConfig
19
+
20
+ logger = setup_logger(__name__)
21
+
22
+
def _require_single_column(value: Any, kind: str) -> list[str]:
    """Normalize a column spec to a one-element list, or raise ``ValueError``.

    Shared by the ``input_cols`` and ``output_cols`` validators of
    ``LLMChatBlock``, which both accept exactly one column.

    Parameters
    ----------
    value : Any
        The column specification being validated; a ``str`` or a ``list``
        is accepted.
    kind : str
        ``"input"`` or ``"output"``; interpolated into the error messages.

    Returns
    -------
    list[str]
        ``[value]`` for a string, or ``value`` itself for a one-element list.

    Raises
    ------
    ValueError
        If ``value`` is a list whose length is not 1, or is neither a string
        nor a list.
    """
    if isinstance(value, str):
        return [value]
    if isinstance(value, list):
        if len(value) == 1:
            return value
        raise ValueError(
            f"LLMChatBlock expects exactly one {kind} column, got {len(value)}: {value}"
        )
    raise ValueError(f"Invalid {kind}_cols format: {value}")


@BlockRegistry.register(
    "LLMChatBlock",
    "llm",
    "Unified LLM chat block supporting 100+ providers via LiteLLM",
)
class LLMChatBlock(BaseBlock):
    """Unified LLM chat block supporting all providers via LiteLLM.

    This block replaces OpenAIChatBlock and OpenAIAsyncChatBlock with a single
    implementation that supports 100+ LLM providers through LiteLLM, including:
    - OpenAI (GPT-3.5, GPT-4, etc.)
    - Anthropic (Claude models)
    - Google (Gemini, PaLM)
    - Local models (vLLM, Ollama, etc.)
    - And many more...

    Parameters
    ----------
    block_name : str
        Name of the block.
    input_cols : Union[str, list[str]]
        Input column name(s). Must resolve to exactly one column, which should
        contain the messages list.
    output_cols : Union[str, list[str]]
        Output column name(s) for the response. Must resolve to exactly one
        column. When n > 1, the column will contain a list of responses
        instead of a single string.
    model : Optional[str], optional
        Model identifier in LiteLLM format. Must be set before ``generate()``
        is called. Examples:
        - "openai/gpt-4"
        - "anthropic/claude-3-sonnet-20240229"
        - "hosted_vllm/meta-llama/Llama-2-7b-chat-hf"
    api_key : Optional[str], optional
        API key for the provider. Falls back to environment variables.
    api_base : Optional[str], optional
        Base URL for the API. Required for local models.
    async_mode : bool, optional
        Whether to use async processing, by default False.
    timeout : float, optional
        Request timeout in seconds, by default 120.0.
    max_retries : int, optional
        Maximum number of retry attempts, by default 6.

    ### Generation Parameters ###

    temperature : Optional[float], optional
        Sampling temperature (0.0 to 2.0).
    max_tokens : Optional[int], optional
        Maximum tokens to generate.
    top_p : Optional[float], optional
        Nucleus sampling parameter (0.0 to 1.0).
    frequency_penalty : Optional[float], optional
        Frequency penalty (-2.0 to 2.0).
    presence_penalty : Optional[float], optional
        Presence penalty (-2.0 to 2.0).
    stop : Optional[Union[str, list[str]]], optional
        Stop sequences.
    seed : Optional[int], optional
        Random seed for reproducible outputs.
    response_format : Optional[dict[str, Any]], optional
        Response format specification (e.g., JSON mode).
    stream : Optional[bool], optional
        Whether to stream responses.
    n : Optional[int], optional
        Number of completions to generate. When n > 1, the output column will
        contain a list of responses for each input sample.
    logprobs : Optional[bool], optional
        Whether to return log probabilities.
    top_logprobs : Optional[int], optional
        Number of top log probabilities to return.
    user : Optional[str], optional
        End-user identifier.
    extra_headers : Optional[dict[str, str]], optional
        Additional headers to send with requests.
    extra_body : Optional[dict[str, Any]], optional
        Additional parameters for the request body.
    provider_specific : Optional[dict[str, Any]], optional
        Provider-specific parameters, forwarded to the client configuration.

    Examples
    --------
    >>> # OpenAI GPT-4
    >>> block = LLMChatBlock(
    ...     block_name="gpt4_block",
    ...     input_cols="messages",
    ...     output_cols="response",
    ...     model="openai/gpt-4",
    ...     temperature=0.7
    ... )

    >>> # Anthropic Claude
    >>> block = LLMChatBlock(
    ...     block_name="claude_block",
    ...     input_cols="messages",
    ...     output_cols="response",
    ...     model="anthropic/claude-3-sonnet-20240229",
    ...     temperature=0.7
    ... )

    >>> # Local vLLM model
    >>> block = LLMChatBlock(
    ...     block_name="local_llama",
    ...     input_cols="messages",
    ...     output_cols="response",
    ...     model="hosted_vllm/meta-llama/Llama-2-7b-chat-hf",
    ...     api_base="http://localhost:8000/v1",
    ...     temperature=0.7
    ... )

    >>> # Multiple completions (n > 1)
    >>> block = LLMChatBlock(
    ...     block_name="gpt4_multiple",
    ...     input_cols="messages",
    ...     output_cols="responses",  # Will contain lists of strings
    ...     model="openai/gpt-4",
    ...     n=3,  # Generate 3 responses per input
    ...     temperature=0.8
    ... )
    """

    # LLM Configuration
    model: Optional[str] = Field(None, description="Model identifier in LiteLLM format")
    api_key: Optional[str] = Field(None, description="API key for the provider")
    api_base: Optional[str] = Field(None, description="Base URL for the API")
    async_mode: bool = Field(False, description="Whether to use async processing")
    timeout: float = Field(120.0, description="Request timeout in seconds")
    max_retries: int = Field(6, description="Maximum number of retry attempts")

    # Generation parameters
    temperature: Optional[float] = Field(
        None, description="Sampling temperature (0.0 to 2.0)"
    )
    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
    top_p: Optional[float] = Field(
        None, description="Nucleus sampling parameter (0.0 to 1.0)"
    )
    frequency_penalty: Optional[float] = Field(
        None, description="Frequency penalty (-2.0 to 2.0)"
    )
    presence_penalty: Optional[float] = Field(
        None, description="Presence penalty (-2.0 to 2.0)"
    )
    stop: Optional[Union[str, list[str]]] = Field(None, description="Stop sequences")
    seed: Optional[int] = Field(
        None, description="Random seed for reproducible outputs"
    )
    response_format: Optional[dict[str, Any]] = Field(
        None, description="Response format specification"
    )
    stream: Optional[bool] = Field(None, description="Whether to stream responses")
    n: Optional[int] = Field(None, description="Number of completions to generate")
    logprobs: Optional[bool] = Field(
        None, description="Whether to return log probabilities"
    )
    top_logprobs: Optional[int] = Field(
        None, description="Number of top log probabilities to return"
    )
    user: Optional[str] = Field(None, description="End-user identifier")
    extra_headers: Optional[dict[str, str]] = Field(
        None, description="Additional headers"
    )
    extra_body: Optional[dict[str, Any]] = Field(
        None, description="Additional request body parameters"
    )
    provider_specific: Optional[dict[str, Any]] = Field(
        None, description="Provider-specific parameters"
    )

    # Exclude from serialization - internal computed field
    client_manager: Optional[Any] = Field(
        None, exclude=True, description="Internal client manager"
    )

    @field_validator("input_cols")
    @classmethod
    def validate_single_input_col(cls, v):
        """Ensure exactly one input column; a bare string becomes ``[v]``."""
        return _require_single_column(v, "input")

    @field_validator("output_cols")
    @classmethod
    def validate_single_output_col(cls, v):
        """Ensure exactly one output column; a bare string becomes ``[v]``."""
        return _require_single_column(v, "output")

    def model_post_init(self, __context) -> None:
        """Initialize after Pydantic validation by building the client manager."""
        super().model_post_init(__context)
        self._setup_client_manager()

    def _setup_client_manager(self) -> None:
        """Build an ``LLMConfig`` from the current fields and load a client manager.

        Called on construction and again by ``_reinitialize_client_manager``
        after model configuration changes.
        """
        config = LLMConfig(
            model=self.model,
            api_key=self.api_key,
            api_base=self.api_base,
            timeout=self.timeout,
            max_retries=self.max_retries,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            top_p=self.top_p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            stop=self.stop,
            seed=self.seed,
            response_format=self.response_format,
            stream=self.stream,
            n=self.n,
            logprobs=self.logprobs,
            top_logprobs=self.top_logprobs,
            user=self.user,
            extra_headers=self.extra_headers,
            extra_body=self.extra_body,
            provider_specific=self.provider_specific,
        )

        self.client_manager = LLMClientManager(config)
        # Load eagerly so configuration problems surface at construction time.
        self.client_manager.load()

        # A block may be built without a model and configured later
        # (see the check in generate()), so only log once one is set.
        if self.model:
            logger.info(
                f"Initialized LLMChatBlock '{self.block_name}' with model '{self.model}'",
                extra={
                    "block_name": self.block_name,
                    "model": self.model,
                    "provider": self.client_manager.config.get_provider(),
                    "is_local": self.client_manager.config.is_local_model(),
                    "async_mode": self.async_mode,
                    "generation_params": self.client_manager.config.get_generation_kwargs(),
                },
            )

    def _reinitialize_client_manager(self) -> None:
        """Reinitialize the client manager with updated model configuration.

        This should be called after model configuration changes to ensure
        the client manager uses the updated model, api_base, api_key, etc.
        """
        self._setup_client_manager()

    def generate(self, samples: Dataset, **override_kwargs: Any) -> Dataset:
        """Generate responses from the LLM.

        Parameters set at runtime override those set during initialization.
        Supports all LiteLLM parameters for the configured provider.

        Parameters
        ----------
        samples : Dataset
            Input dataset containing the messages column.
        **override_kwargs : Any
            Runtime parameters that override initialization defaults.
            Valid parameters depend on the provider but typically include:
            temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
            stop, seed, response_format, stream, n, and provider-specific params.

        Returns
        -------
        Dataset
            Dataset with responses added to the output column.

        Raises
        ------
        BlockValidationError
            If model is not configured before calling generate().
        """
        # Fail fast with an actionable message rather than a provider error.
        if not self.model:
            raise BlockValidationError(
                f"Model not configured for block '{self.block_name}'. "
                "Call flow.set_model_config() before generating."
            )

        messages_list = samples[self.input_cols[0]]

        logger.info(
            f"Starting {'async' if self.async_mode else 'sync'} generation for {len(messages_list)} samples",
            extra={
                "block_name": self.block_name,
                "model": self.model,
                "provider": self.client_manager.config.get_provider(),
                "batch_size": len(messages_list),
                "async_mode": self.async_mode,
                "override_params": override_kwargs,
            },
        )

        if self.async_mode:
            responses = self._run_async_generation(messages_list, override_kwargs)
        else:
            responses = self._generate_sync(messages_list, **override_kwargs)

        logger.info(
            f"Generation completed successfully for {len(responses)} samples",
            extra={
                "block_name": self.block_name,
                "model": self.model,
                "provider": self.client_manager.config.get_provider(),
                "batch_size": len(responses),
            },
        )

        return samples.add_column(self.output_cols[0], responses)

    def _run_async_generation(
        self,
        messages_list: list[list[dict[str, Any]]],
        override_kwargs: dict[str, Any],
    ) -> list[Union[str, list[str]]]:
        """Drive ``_generate_async`` to completion from synchronous code.

        ``asyncio.run()`` raises ``RuntimeError`` when the calling thread
        already has a running event loop (e.g. inside Jupyter or an async
        application). In that case the coroutine is executed on a fresh event
        loop in a short-lived worker thread; the caller still blocks until the
        results (or the raised exception) are available.

        Parameters
        ----------
        messages_list : list[list[dict[str, Any]]]
            List of message lists to process.
        override_kwargs : dict[str, Any]
            Runtime parameter overrides forwarded to ``_generate_async``.

        Returns
        -------
        list[Union[str, list[str]]]
            Responses in input order, as returned by ``_generate_async``.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: the common path, same as before.
            return asyncio.run(self._generate_async(messages_list, **override_kwargs))

        # Local import: only needed on this fallback path.
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(
                asyncio.run, self._generate_async(messages_list, **override_kwargs)
            )
            # result() re-raises any exception from the worker thread.
            return future.result()

    def _generate_sync(
        self,
        messages_list: list[list[dict[str, Any]]],
        **override_kwargs: Any,
    ) -> list[Union[str, list[str]]]:
        """Generate responses synchronously, one request per sample.

        Parameters
        ----------
        messages_list : list[list[dict[str, Any]]]
            List of message lists to process.
        **override_kwargs : Any
            Runtime parameter overrides.

        Returns
        -------
        list[Union[str, list[str]]]
            List of response strings or lists of response strings (when n > 1).

        Raises
        ------
        Exception
            The first error from the client is logged and re-raised unchanged.
        """
        responses = []
        total = len(messages_list)

        for i, messages in enumerate(messages_list):
            try:
                response = self.client_manager.create_completion(
                    messages, **override_kwargs
                )
                responses.append(response)

                # Log progress for large batches
                if (i + 1) % 10 == 0:
                    logger.debug(
                        f"Generated {i + 1}/{total} responses",
                        extra={
                            "block_name": self.block_name,
                            "progress": f"{i + 1}/{total}",
                        },
                    )

            except Exception as e:
                error_msg = self.client_manager.error_handler.format_error_message(
                    e, {"model": self.model, "sample_index": i}
                )
                logger.error(
                    f"Failed to generate response for sample {i}: {error_msg}",
                    extra={
                        "block_name": self.block_name,
                        "sample_index": i,
                        "error": str(e),
                    },
                )
                raise

        return responses

    async def _generate_async(
        self,
        messages_list: list[list[dict[str, Any]]],
        **override_kwargs: Any,
    ) -> list[Union[str, list[str]]]:
        """Generate responses asynchronously via the client's batch API.

        Parameters
        ----------
        messages_list : list[list[dict[str, Any]]]
            List of message lists to process.
        **override_kwargs : Any
            Runtime parameter overrides.

        Returns
        -------
        list[Union[str, list[str]]]
            List of response strings or lists of response strings (when n > 1).

        Raises
        ------
        Exception
            Any error from the batch call is logged and re-raised unchanged.
        """
        try:
            return await self.client_manager.acreate_completions_batch(
                messages_list, **override_kwargs
            )
        except Exception as e:
            error_msg = self.client_manager.error_handler.format_error_message(
                e, {"model": self.model}
            )
            logger.error(
                f"Failed to generate async responses: {error_msg}",
                extra={
                    "block_name": self.block_name,
                    "batch_size": len(messages_list),
                    "error": str(e),
                },
            )
            raise

    def get_model_info(self) -> dict[str, Any]:
        """Get information about the configured model.

        Returns
        -------
        dict[str, Any]
            The client manager's model info merged with this block's name,
            input/output column and async mode.
        """
        return {
            **self.client_manager.get_model_info(),
            "block_name": self.block_name,
            "input_column": self.input_cols[0],
            "output_column": self.output_cols[0],
            "async_mode": self.async_mode,
        }

    def _validate_custom(self, dataset: Dataset) -> None:
        """Custom validation for LLMChatBlock message format.

        Every row's messages value must be a non-empty list of dicts, each
        with non-None ``role`` and ``content`` entries.

        Parameters
        ----------
        dataset : Dataset
            The dataset to validate.

        Raises
        ------
        BlockValidationError
            On the first row or message that violates the format.
        """
        column = self.input_cols[0]

        # Iterate lazily; stops at the first invalid row.
        for idx, sample in enumerate(dataset):
            messages = sample[column]

            if not isinstance(messages, list):
                raise BlockValidationError(
                    f"Messages column '{column}' must contain a list, "
                    f"got {type(messages)} in row {idx}",
                    details=f"Block: {self.block_name}, Row: {idx}, Value: {messages}",
                )

            if not messages:
                raise BlockValidationError(
                    f"Messages list is empty in row {idx}",
                    details=f"Block: {self.block_name}, Row: {idx}",
                )

            for msg_idx, message in enumerate(messages):
                if not isinstance(message, dict):
                    raise BlockValidationError(
                        f"Message {msg_idx} in row {idx} must be a dict, got {type(message)}",
                        details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Value: {message}",
                    )

                if message.get("role") is None:
                    raise BlockValidationError(
                        f"Message {msg_idx} in row {idx} missing required 'role' field",
                        details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Available fields: {list(message.keys())}",
                    )

                if message.get("content") is None:
                    raise BlockValidationError(
                        f"Message {msg_idx} in row {idx} missing required 'content' field",
                        details=f"Block: {self.block_name}, Row: {idx}, Message: {msg_idx}, Available fields: {list(message.keys())}",
                    )

    def __del__(self) -> None:
        """Unload the client manager, if one was created, when the block is destroyed."""
        try:
            client_manager = getattr(self, "client_manager", None)
            if client_manager is not None:
                client_manager.unload()
        except Exception:
            # Deliberate best-effort: never raise from a finalizer, e.g.
            # during interpreter shutdown.
            pass

    def __repr__(self) -> str:
        """String representation of the block."""
        return (
            f"LLMChatBlock(name='{self.block_name}', model='{self.model}', "
            f"provider='{self.client_manager.config.get_provider()}', async_mode={self.async_mode})"
        )