sdg-hub 0.7.2__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/_version.py +2 -2
- sdg_hub/core/__init__.py +13 -1
- sdg_hub/core/blocks/__init__.py +11 -2
- sdg_hub/core/blocks/agent/__init__.py +6 -0
- sdg_hub/core/blocks/agent/agent_block.py +397 -0
- sdg_hub/core/blocks/base.py +4 -1
- sdg_hub/core/blocks/filtering/column_value_filter.py +2 -0
- sdg_hub/core/blocks/llm/__init__.py +3 -2
- sdg_hub/core/blocks/llm/llm_chat_block.py +2 -0
- sdg_hub/core/blocks/llm/{llm_parser_block.py → llm_response_extractor_block.py} +32 -9
- sdg_hub/core/blocks/llm/prompt_builder_block.py +2 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +2 -0
- sdg_hub/core/blocks/transform/duplicate_columns.py +2 -0
- sdg_hub/core/blocks/transform/index_based_mapper.py +2 -0
- sdg_hub/core/blocks/transform/json_structure_block.py +2 -0
- sdg_hub/core/blocks/transform/melt_columns.py +2 -0
- sdg_hub/core/blocks/transform/rename_columns.py +12 -0
- sdg_hub/core/blocks/transform/text_concat.py +2 -0
- sdg_hub/core/blocks/transform/uniform_col_val_setter.py +2 -0
- sdg_hub/core/connectors/__init__.py +46 -0
- sdg_hub/core/connectors/agent/__init__.py +10 -0
- sdg_hub/core/connectors/agent/base.py +233 -0
- sdg_hub/core/connectors/agent/langflow.py +151 -0
- sdg_hub/core/connectors/base.py +99 -0
- sdg_hub/core/connectors/exceptions.py +41 -0
- sdg_hub/core/connectors/http/__init__.py +6 -0
- sdg_hub/core/connectors/http/client.py +150 -0
- sdg_hub/core/connectors/registry.py +112 -0
- sdg_hub/core/flow/base.py +7 -31
- sdg_hub/core/utils/flow_metrics.py +3 -3
- sdg_hub/flows/evaluation/rag/flow.yaml +6 -6
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/flow.yaml +4 -4
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/doc_direct_qa/flow.yaml +3 -3
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/flow.yaml +4 -4
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/flow.yaml +2 -2
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +7 -7
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/multilingual/japanese/flow.yaml +7 -7
- sdg_hub/flows/text_analysis/structured_insights/flow.yaml +4 -4
- {sdg_hub-0.7.2.dist-info → sdg_hub-0.8.0.dist-info}/METADATA +2 -2
- {sdg_hub-0.7.2.dist-info → sdg_hub-0.8.0.dist-info}/RECORD +43 -32
- {sdg_hub-0.7.2.dist-info → sdg_hub-0.8.0.dist-info}/WHEEL +1 -1
- {sdg_hub-0.7.2.dist-info → sdg_hub-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.7.2.dist-info → sdg_hub-0.8.0.dist-info}/top_level.txt +0 -0
sdg_hub/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.
|
|
32
|
-
__version_tuple__ = version_tuple = (0,
|
|
31
|
+
__version__ = version = '0.8.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 8, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
sdg_hub/core/__init__.py
CHANGED
|
@@ -2,14 +2,26 @@
|
|
|
2
2
|
"""Core SDG Hub components."""
|
|
3
3
|
|
|
4
4
|
# Local
|
|
5
|
-
from .blocks import BaseBlock, BlockRegistry
|
|
5
|
+
from .blocks import AgentBlock, BaseBlock, BlockRegistry
|
|
6
|
+
from .connectors import (
|
|
7
|
+
BaseConnector,
|
|
8
|
+
ConnectorConfig,
|
|
9
|
+
ConnectorError,
|
|
10
|
+
ConnectorRegistry,
|
|
11
|
+
)
|
|
6
12
|
from .flow import Flow, FlowMetadata, FlowRegistry, FlowValidator
|
|
7
13
|
from .utils import GenerateError, resolve_path
|
|
8
14
|
|
|
9
15
|
__all__ = [
|
|
10
16
|
# Block components
|
|
17
|
+
"AgentBlock",
|
|
11
18
|
"BaseBlock",
|
|
12
19
|
"BlockRegistry",
|
|
20
|
+
# Connector components
|
|
21
|
+
"BaseConnector",
|
|
22
|
+
"ConnectorConfig",
|
|
23
|
+
"ConnectorError",
|
|
24
|
+
"ConnectorRegistry",
|
|
13
25
|
# Flow components
|
|
14
26
|
"Flow",
|
|
15
27
|
"FlowRegistry",
|
sdg_hub/core/blocks/__init__.py
CHANGED
|
@@ -4,9 +4,16 @@ This package provides various block implementations for data generation, process
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
# Local
|
|
7
|
+
from .agent import AgentBlock
|
|
7
8
|
from .base import BaseBlock
|
|
8
9
|
from .filtering import ColumnValueFilterBlock
|
|
9
|
-
from .llm import
|
|
10
|
+
from .llm import (
|
|
11
|
+
LLMChatBlock,
|
|
12
|
+
LLMParserBlock,
|
|
13
|
+
LLMResponseExtractorBlock,
|
|
14
|
+
PromptBuilderBlock,
|
|
15
|
+
TextParserBlock,
|
|
16
|
+
)
|
|
10
17
|
from .registry import BlockRegistry
|
|
11
18
|
from .transform import (
|
|
12
19
|
DuplicateColumnsBlock,
|
|
@@ -18,6 +25,7 @@ from .transform import (
|
|
|
18
25
|
)
|
|
19
26
|
|
|
20
27
|
__all__ = [
|
|
28
|
+
"AgentBlock",
|
|
21
29
|
"BaseBlock",
|
|
22
30
|
"BlockRegistry",
|
|
23
31
|
"ColumnValueFilterBlock",
|
|
@@ -28,7 +36,8 @@ __all__ = [
|
|
|
28
36
|
"TextConcatBlock",
|
|
29
37
|
"UniformColumnValueSetter",
|
|
30
38
|
"LLMChatBlock",
|
|
31
|
-
"LLMParserBlock",
|
|
39
|
+
"LLMParserBlock", # Deprecated alias for LLMResponseExtractorBlock
|
|
40
|
+
"LLMResponseExtractorBlock",
|
|
32
41
|
"TextParserBlock",
|
|
33
42
|
"PromptBuilderBlock",
|
|
34
43
|
]
|
|
@@ -0,0 +1,397 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Agent block for integrating external agent frameworks."""
|
|
3
|
+
|
|
4
|
+
from typing import Any, Optional
|
|
5
|
+
import asyncio
|
|
6
|
+
import uuid
|
|
7
|
+
|
|
8
|
+
from pydantic import Field, PrivateAttr
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
import pandas as pd
|
|
11
|
+
|
|
12
|
+
from ...connectors.agent.base import BaseAgentConnector
|
|
13
|
+
from ...connectors.base import ConnectorConfig
|
|
14
|
+
from ...connectors.exceptions import ConnectorError
|
|
15
|
+
from ...connectors.registry import ConnectorRegistry
|
|
16
|
+
from ...utils.logger_config import setup_logger
|
|
17
|
+
from ..base import BaseBlock
|
|
18
|
+
from ..registry import BlockRegistry
|
|
19
|
+
|
|
20
|
+
# Module-level logger for this file, created via sdg_hub's setup_logger helper.
logger = setup_logger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@BlockRegistry.register(
    "AgentBlock",
    category="agent",
    description="Execute agent frameworks (Langflow, etc.) on DataFrame rows",
)
class AgentBlock(BaseBlock):
    """Block for executing external agent frameworks on DataFrame rows.

    This block integrates with various agent frameworks through the connector
    system. Each row in the DataFrame is processed by sending messages to the
    agent and storing the response in a single output column.

    The block supports both sync and async execution modes. In both modes the
    output column is aligned with the input rows by position, so DataFrames
    with duplicate or non-integer index labels are handled correctly.

    Parameters
    ----------
    agent_framework : str
        Name of the connector to use (e.g., 'langflow').
    agent_url : str
        API endpoint URL for the agent.
    agent_api_key : str, optional
        API key for authentication.
    timeout : float
        Request timeout in seconds. Default 120.0.
    max_retries : int
        Maximum retry attempts. Default 3.
    session_id_col : str, optional
        Column containing session IDs. Rows without this column, or with a
        missing (None/NaN) value in it, get a freshly generated UUID.
    async_mode : bool
        Whether to use async execution. Default False.
    max_concurrency : int
        Maximum concurrent requests in async mode. Default 10.
    extract_response : bool
        Passed to the connector as ``extract_text``; requests just the text
        content of the agent response. Default False.

    Example YAML Configuration
    --------------------------
    ```yaml
    - block_type: AgentBlock
      block_config:
        block_name: my_agent
        agent_framework: langflow
        agent_url: http://localhost:7860/api/v1/run/my-flow
        agent_api_key: ${LANGFLOW_API_KEY}
        input_cols:
          messages: messages_col
        output_cols:
          - agent_response
    ```

    Example
    -------
    >>> block = AgentBlock(
    ...     block_name="qa_agent",
    ...     agent_framework="langflow",
    ...     agent_url="http://localhost:7860/api/v1/run/qa-flow",
    ...     input_cols={"messages": "question"},
    ...     output_cols=["response"],
    ... )
    >>> result_df = block(df)
    """

    # Block category, set for consistency with the other built-in blocks
    # (which declare "llm", "transform", "parser", ...); matches the
    # registry category used above.
    block_type: str = "agent"

    # Required configuration
    agent_framework: str = Field(
        ...,
        description="Connector name (e.g., 'langflow')",
    )
    agent_url: str = Field(
        ...,
        description="Agent API endpoint URL",
    )

    # Optional configuration
    agent_api_key: Optional[str] = Field(
        None,
        description="API key for authentication",
    )
    timeout: float = Field(
        120.0,
        description="Request timeout in seconds",
        gt=0,
    )
    max_retries: int = Field(
        3,
        description="Maximum retry attempts",
        ge=0,
    )
    session_id_col: Optional[str] = Field(
        None,
        description="Column containing session IDs",
    )
    async_mode: bool = Field(
        False,
        description="Use async execution for better throughput",
    )
    max_concurrency: int = Field(
        10,
        description="Maximum concurrent requests in async mode",
        gt=0,
    )
    extract_response: bool = Field(
        False,
        description="Extract just the text content from agent response",
    )

    # Private attributes: cached connector plus the config tuple it was built
    # from, so a change to any connection setting rebuilds the connector.
    _connector: Optional[BaseAgentConnector] = PrivateAttr(default=None)
    _connector_config_key: Optional[tuple] = PrivateAttr(default=None)

    def _get_connector(self) -> BaseAgentConnector:
        """Get or create the connector instance.

        Invalidates the cached connector if the config has changed (e.g., due
        to runtime overrides).

        Returns
        -------
        BaseAgentConnector
            The configured connector instance.
        """
        config_key = (
            self.agent_framework,
            self.agent_url,
            self.agent_api_key,
            self.timeout,
            self.max_retries,
            self.extract_response,
        )
        if self._connector is None or self._connector_config_key != config_key:
            connector_class = ConnectorRegistry.get(self.agent_framework)
            config = ConnectorConfig(
                url=self.agent_url,
                api_key=self.agent_api_key,
                timeout=self.timeout,
                max_retries=self.max_retries,
                extract_text=self.extract_response,
            )
            self._connector = connector_class(config=config)
            self._connector_config_key = config_key
        return self._connector

    def _get_messages_col(self) -> str:
        """Get the input column name for messages.

        Resolution order:

        - ``input_cols`` is a non-empty str: that column.
        - dict with a ``"messages"`` key: the value mapped to that key.
        - any other non-empty dict: its first key.
        - non-empty list: its first element.

        Returns
        -------
        str
            Column name containing messages.

        Raises
        ------
        ConnectorError
            If ``input_cols`` is unset or empty.
        """
        cols = self.input_cols
        if isinstance(cols, str) and cols:
            return cols
        if isinstance(cols, dict) and cols:
            if "messages" in cols:
                return cols["messages"]
            return next(iter(cols))
        if isinstance(cols, list) and cols:
            return cols[0]
        raise ConnectorError("input_cols must specify the messages column")

    def _get_output_col(self) -> str:
        """Get the output column name for responses.

        Uses a non-empty str as-is, otherwise the first key/element of a
        non-empty dict/list, and falls back to ``"agent_response"``.

        Returns
        -------
        str
            Column name for storing responses.
        """
        cols = self.output_cols
        if isinstance(cols, str) and cols:
            return cols
        if isinstance(cols, (dict, list)) and cols:
            # Guarded by the emptiness check: an empty dict previously raised
            # IndexError instead of using the default.
            return next(iter(cols))
        return "agent_response"

    def _build_messages(self, content: Any) -> list[dict[str, Any]]:
        """Build message list from row content.

        Parameters
        ----------
        content : Any
            Content from the DataFrame cell.

        Returns
        -------
        list[dict]
            A list is returned unchanged, a dict is wrapped in a one-element
            list, and anything else is stringified into a single user message.
        """
        if isinstance(content, list):
            # Already a message list
            return content
        elif isinstance(content, dict):
            # Single message dict
            return [content]
        else:
            # Plain text - wrap as user message
            return [{"role": "user", "content": str(content)}]

    def _get_session_id(self, row: pd.Series, idx: int) -> str:
        """Get session ID for a row.

        Parameters
        ----------
        row : pd.Series
            DataFrame row.
        idx : int
            Row position (currently unused).

        Returns
        -------
        str
            The row's session ID from ``session_id_col`` when present and not
            missing, otherwise a new random UUID.
        """
        if self.session_id_col and self.session_id_col in row:
            value = row[self.session_id_col]
            # None/NaN would otherwise stringify to "None"/"nan" and silently
            # merge every such row into one shared agent session.
            if not (pd.api.types.is_scalar(value) and pd.isna(value)):
                return str(value)
        return str(uuid.uuid4())

    def _process_row_sync(
        self,
        row: pd.Series,
        idx: int,
        connector: BaseAgentConnector,
        messages_col: str,
    ) -> dict[str, Any]:
        """Process a single row synchronously.

        Parameters
        ----------
        row : pd.Series
            DataFrame row.
        idx : int
            Row position.
        connector : BaseAgentConnector
            Connector instance.
        messages_col : str
            Column containing messages.

        Returns
        -------
        dict
            Response from the agent.
        """
        messages = self._build_messages(row[messages_col])
        session_id = self._get_session_id(row, idx)
        return connector.send(messages, session_id)

    async def _process_row_async(
        self,
        row: pd.Series,
        idx: int,
        connector: BaseAgentConnector,
        messages_col: str,
        semaphore: asyncio.Semaphore,
    ) -> tuple[int, dict[str, Any]]:
        """Process a single row asynchronously.

        Parameters
        ----------
        row : pd.Series
            DataFrame row.
        idx : int
            Row position; echoed back so results can be re-ordered.
        connector : BaseAgentConnector
            Connector instance.
        messages_col : str
            Column containing messages.
        semaphore : asyncio.Semaphore
            Semaphore for concurrency control.

        Returns
        -------
        tuple[int, dict]
            Row position and response.
        """
        async with semaphore:
            messages = self._build_messages(row[messages_col])
            session_id = self._get_session_id(row, idx)
            response = await connector.asend(messages, session_id)
            return idx, response

    async def _process_batch_async(
        self,
        df: pd.DataFrame,
        connector: BaseAgentConnector,
        messages_col: str,
    ) -> dict[int, dict[str, Any]]:
        """Process all rows asynchronously.

        Parameters
        ----------
        df : pd.DataFrame
            Input DataFrame.
        connector : BaseAgentConnector
            Connector instance.
        messages_col : str
            Column containing messages.

        Returns
        -------
        dict[int, dict]
            Mapping from row *position* (0..len(df)-1) to response. Positions
            are used instead of index labels so duplicate labels cannot
            overwrite each other.
        """
        semaphore = asyncio.Semaphore(self.max_concurrency)
        tasks = [
            self._process_row_async(row, position, connector, messages_col, semaphore)
            for position, (_, row) in enumerate(df.iterrows())
        ]

        results = {}
        for coro in tqdm(
            asyncio.as_completed(tasks),
            total=len(tasks),
            desc=f"{self.block_name} (async)",
        ):
            position, response = await coro
            results[position] = response

        return results

    def _run_batch_async(
        self,
        df: pd.DataFrame,
        connector: BaseAgentConnector,
        messages_col: str,
    ) -> dict[int, dict[str, Any]]:
        """Drive ``_process_batch_async`` to completion from synchronous code.

        Returns
        -------
        dict[int, dict]
            Mapping from row position to response.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop in this thread: run the batch directly.
            return asyncio.run(self._process_batch_async(df, connector, messages_col))

        # Already inside a running loop (asyncio.run cannot nest), so run the
        # batch on a fresh loop in a worker thread. Only get_running_loop() is
        # guarded above, so errors raised by the batch itself propagate as-is.
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(
                asyncio.run,
                self._process_batch_async(df, connector, messages_col),
            )
            return future.result()

    def generate(self, samples: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
        """Process DataFrame rows through the agent.

        Parameters
        ----------
        samples : pd.DataFrame
            Input DataFrame with messages column. Not modified.
        **kwargs : Any
            Runtime overrides (not read directly by this method).

        Returns
        -------
        pd.DataFrame
            Copy of ``samples`` with the agent responses in the output column.
        """
        df = samples.copy()
        connector = self._get_connector()
        messages_col = self._get_messages_col()
        output_col = self._get_output_col()

        if self.async_mode:
            results = self._run_batch_async(df, connector, messages_col)
            # Re-assemble in row order; assigning a plain list aligns by
            # position regardless of the index labels.
            df[output_col] = [results[position] for position in range(len(df))]
        else:
            # Sync execution with progress bar
            responses = []
            for position, (_, row) in enumerate(
                tqdm(df.iterrows(), total=len(df), desc=self.block_name)
            ):
                responses.append(
                    self._process_row_sync(row, position, connector, messages_col)
                )
            df[output_col] = responses

        logger.info(f"Processed {len(df)} rows with {self.agent_framework} agent")
        return df
|
sdg_hub/core/blocks/base.py
CHANGED
|
@@ -49,6 +49,9 @@ class BaseBlock(BaseModel, ABC):
|
|
|
49
49
|
block_name: str = Field(
|
|
50
50
|
..., description="Unique identifier for this block instance"
|
|
51
51
|
)
|
|
52
|
+
block_type: Optional[str] = Field(
|
|
53
|
+
None, description="Block type (e.g., 'llm', 'transform', 'parser', 'filtering')"
|
|
54
|
+
)
|
|
52
55
|
input_cols: Union[str, list[str], dict[str, Any], None] = Field(
|
|
53
56
|
None, description="Input columns: str, list, or dict"
|
|
54
57
|
)
|
|
@@ -366,5 +369,5 @@ class BaseBlock(BaseModel, ABC):
|
|
|
366
369
|
Dict[str, Any]
|
|
367
370
|
"""
|
|
368
371
|
config = self.get_config()
|
|
369
|
-
config["
|
|
372
|
+
config["block_class"] = self.__class__.__name__
|
|
370
373
|
return config
|
|
@@ -46,6 +46,8 @@ DTYPE_MAP = {
|
|
|
46
46
|
"Filters datasets based on column values using various comparison operations",
|
|
47
47
|
)
|
|
48
48
|
class ColumnValueFilterBlock(BaseBlock):
|
|
49
|
+
block_type: str = "filtering"
|
|
50
|
+
|
|
49
51
|
"""A block for filtering datasets based on column values.
|
|
50
52
|
|
|
51
53
|
This block allows filtering of datasets using various operations (e.g., equals, contains)
|
|
@@ -9,7 +9,7 @@ local models (vLLM, Ollama), and more.
|
|
|
9
9
|
# Local
|
|
10
10
|
from .error_handler import ErrorCategory, LLMErrorHandler
|
|
11
11
|
from .llm_chat_block import LLMChatBlock
|
|
12
|
-
from .
|
|
12
|
+
from .llm_response_extractor_block import LLMParserBlock, LLMResponseExtractorBlock
|
|
13
13
|
from .prompt_builder_block import PromptBuilderBlock
|
|
14
14
|
from .text_parser_block import TextParserBlock
|
|
15
15
|
|
|
@@ -17,7 +17,8 @@ __all__ = [
|
|
|
17
17
|
"LLMErrorHandler",
|
|
18
18
|
"ErrorCategory",
|
|
19
19
|
"LLMChatBlock",
|
|
20
|
-
"LLMParserBlock",
|
|
20
|
+
"LLMParserBlock", # Deprecated alias for LLMResponseExtractorBlock
|
|
21
|
+
"LLMResponseExtractorBlock",
|
|
21
22
|
"PromptBuilderBlock",
|
|
22
23
|
"TextParserBlock",
|
|
23
24
|
]
|
|
@@ -32,6 +32,8 @@ logger = setup_logger(__name__)
|
|
|
32
32
|
class LLMChatBlock(BaseBlock):
|
|
33
33
|
model_config = ConfigDict(extra="allow")
|
|
34
34
|
|
|
35
|
+
block_type: str = "llm"
|
|
36
|
+
|
|
35
37
|
"""Unified LLM chat block supporting all providers via LiteLLM.
|
|
36
38
|
|
|
37
39
|
This block provides a minimal wrapper around LiteLLM's completion API,
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
-
"""LLM
|
|
2
|
+
"""LLM response extractor block for extracting fields from LLM response objects.
|
|
3
3
|
|
|
4
|
-
This module provides the
|
|
4
|
+
This module provides the LLMResponseExtractorBlock for extracting specific fields
|
|
5
5
|
(content, reasoning_content, tool_calls) from chat completion response objects.
|
|
6
6
|
"""
|
|
7
7
|
|
|
@@ -22,13 +22,15 @@ logger = setup_logger(__name__)
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
@BlockRegistry.register(
|
|
25
|
-
"
|
|
25
|
+
"LLMResponseExtractorBlock",
|
|
26
26
|
"llm",
|
|
27
27
|
"Extracts specified fields from LLM response objects",
|
|
28
28
|
)
|
|
29
|
-
class
|
|
29
|
+
class LLMResponseExtractorBlock(BaseBlock):
|
|
30
30
|
_flow_requires_jsonl_tmp: bool = True
|
|
31
31
|
|
|
32
|
+
block_type: str = "llm_util"
|
|
33
|
+
|
|
32
34
|
"""Block for extracting fields from LLM response objects.
|
|
33
35
|
|
|
34
36
|
This block extracts specified fields from chat completion response objects.
|
|
@@ -88,7 +90,7 @@ class LLMParserBlock(BaseBlock):
|
|
|
88
90
|
]
|
|
89
91
|
):
|
|
90
92
|
raise ValueError(
|
|
91
|
-
"
|
|
93
|
+
"LLMResponseExtractorBlock requires at least one extraction field to be enabled: "
|
|
92
94
|
"extract_content, extract_reasoning_content, or extract_tool_calls"
|
|
93
95
|
)
|
|
94
96
|
|
|
@@ -106,7 +108,7 @@ class LLMParserBlock(BaseBlock):
|
|
|
106
108
|
return self
|
|
107
109
|
|
|
108
110
|
def _validate_custom(self, dataset: pd.DataFrame) -> None:
|
|
109
|
-
"""Validate
|
|
111
|
+
"""Validate LLMResponseExtractorBlock specific requirements.
|
|
110
112
|
|
|
111
113
|
Parameters
|
|
112
114
|
----------
|
|
@@ -116,14 +118,16 @@ class LLMParserBlock(BaseBlock):
|
|
|
116
118
|
Raises
|
|
117
119
|
------
|
|
118
120
|
ValueError
|
|
119
|
-
If
|
|
121
|
+
If LLMResponseExtractorBlock requirements are not met.
|
|
120
122
|
"""
|
|
121
123
|
# Validate that we have exactly one input column
|
|
122
124
|
if len(self.input_cols) == 0:
|
|
123
|
-
raise ValueError(
|
|
125
|
+
raise ValueError(
|
|
126
|
+
"LLMResponseExtractorBlock expects at least one input column"
|
|
127
|
+
)
|
|
124
128
|
if len(self.input_cols) > 1:
|
|
125
129
|
logger.warning(
|
|
126
|
-
f"
|
|
130
|
+
f"LLMResponseExtractorBlock expects exactly one input column, but got {len(self.input_cols)}. "
|
|
127
131
|
f"Using the first column: {self.input_cols[0]}"
|
|
128
132
|
)
|
|
129
133
|
|
|
@@ -324,3 +328,22 @@ class LLMParserBlock(BaseBlock):
|
|
|
324
328
|
new_data.extend(self._generate(sample))
|
|
325
329
|
|
|
326
330
|
return pd.DataFrame(new_data)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
# Deprecated name kept for backwards compatibility. It is registered in
# BlockRegistry under the old name (flagged deprecated, pointing at the
# replacement) so YAML flows written against earlier releases still resolve.
@BlockRegistry.register(
    "LLMParserBlock",
    category="llm",
    description="Deprecated: Use LLMResponseExtractorBlock instead",
    deprecated=True,
    replacement="LLMResponseExtractorBlock",
)
class LLMParserBlock(LLMResponseExtractorBlock):
    """Deprecated alias of LLMResponseExtractorBlock.

    Behaves exactly like LLMResponseExtractorBlock; it only exists so that
    existing Python code and YAML flows using the old name keep working.
    New code should use LLMResponseExtractorBlock directly.
    """
|
|
@@ -222,6 +222,8 @@ class PromptRenderer:
|
|
|
222
222
|
"Formats prompts into structured chat messages or plain text using Jinja templates",
|
|
223
223
|
)
|
|
224
224
|
class PromptBuilderBlock(BaseBlock):
|
|
225
|
+
block_type: str = "llm_util"
|
|
226
|
+
|
|
225
227
|
"""Block for formatting prompts into structured chat messages or plain text.
|
|
226
228
|
|
|
227
229
|
This block takes input from dataset columns, applies Jinja templates from a YAML config
|
|
@@ -30,6 +30,8 @@ logger = setup_logger(__name__)
|
|
|
30
30
|
class TextParserBlock(BaseBlock):
|
|
31
31
|
_flow_requires_jsonl_tmp: bool = True
|
|
32
32
|
|
|
33
|
+
block_type: str = "parser"
|
|
34
|
+
|
|
33
35
|
"""Block for parsing and post-processing text content.
|
|
34
36
|
|
|
35
37
|
This block handles text parsing using start/end tags, custom regex patterns,
|
|
@@ -27,6 +27,8 @@ logger = setup_logger(__name__)
|
|
|
27
27
|
"Duplicates existing columns with new names according to a mapping specification",
|
|
28
28
|
)
|
|
29
29
|
class DuplicateColumnsBlock(BaseBlock):
|
|
30
|
+
block_type: str = "transform"
|
|
31
|
+
|
|
30
32
|
"""Block for duplicating existing columns with new names.
|
|
31
33
|
|
|
32
34
|
This block creates copies of existing columns with new names according to a mapping specification.
|
|
@@ -28,6 +28,8 @@ logger = setup_logger(__name__)
|
|
|
28
28
|
"Maps values from source columns to output columns based on choice columns using shared mapping",
|
|
29
29
|
)
|
|
30
30
|
class IndexBasedMapperBlock(BaseBlock):
|
|
31
|
+
block_type: str = "transform"
|
|
32
|
+
|
|
31
33
|
"""Block for mapping values from source columns to output columns based on choice columns.
|
|
32
34
|
|
|
33
35
|
This block uses a shared mapping dictionary to select values from source columns and
|
|
@@ -28,6 +28,8 @@ logger = setup_logger(__name__)
|
|
|
28
28
|
"Combines multiple columns into a single column containing a structured JSON object",
|
|
29
29
|
)
|
|
30
30
|
class JSONStructureBlock(BaseBlock):
|
|
31
|
+
block_type: str = "transform"
|
|
32
|
+
|
|
31
33
|
"""Block for combining multiple columns into a structured JSON object.
|
|
32
34
|
|
|
33
35
|
This block takes values from multiple input columns and combines them into a single
|
|
@@ -28,6 +28,8 @@ logger = setup_logger(__name__)
|
|
|
28
28
|
"Transforms wide dataset format into long format by melting columns into rows",
|
|
29
29
|
)
|
|
30
30
|
class MeltColumnsBlock(BaseBlock):
|
|
31
|
+
block_type: str = "transform"
|
|
32
|
+
|
|
31
33
|
"""Block for flattening multiple columns into a long format.
|
|
32
34
|
|
|
33
35
|
This block transforms a wide dataset format into a long format by melting
|