sdg-hub 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/_version.py +2 -2
- sdg_hub/core/blocks/llm/client_manager.py +37 -25
- sdg_hub/core/blocks/llm/llm_chat_block.py +12 -9
- sdg_hub/core/blocks/llm/text_parser_block.py +88 -21
- sdg_hub/core/blocks/transform/__init__.py +2 -0
- sdg_hub/core/blocks/transform/json_structure_block.py +142 -0
- sdg_hub/core/flow/base.py +199 -56
- sdg_hub/core/utils/datautils.py +27 -2
- sdg_hub/core/utils/flow_metrics.py +261 -0
- sdg_hub/core/utils/logger_config.py +50 -9
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/detailed_summary.yaml +11 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/flow.yaml +159 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/extractive_summary.yaml +65 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/flow.yaml +161 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_answers.yaml +15 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_multiple_qa.yaml +21 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_question_list.yaml +44 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/flow.yaml +104 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/key_facts_summary.yaml +61 -0
- sdg_hub/flows/text_analysis/__init__.py +2 -0
- sdg_hub/flows/text_analysis/structured_insights/__init__.py +6 -0
- sdg_hub/flows/text_analysis/structured_insights/analyze_sentiment.yaml +27 -0
- sdg_hub/flows/text_analysis/structured_insights/extract_entities.yaml +38 -0
- sdg_hub/flows/text_analysis/structured_insights/extract_keywords.yaml +21 -0
- sdg_hub/flows/text_analysis/structured_insights/flow.yaml +153 -0
- sdg_hub/flows/text_analysis/structured_insights/summarize.yaml +21 -0
- {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.0.dist-info}/METADATA +3 -1
- {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.0.dist-info}/RECORD +35 -13
- {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.0.dist-info}/WHEEL +0 -0
- {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.0.dist-info}/top_level.txt +0 -0
sdg_hub/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
28
28
|
commit_id: COMMIT_ID
|
29
29
|
__commit_id__: COMMIT_ID
|
30
30
|
|
31
|
-
__version__ = version = '0.2.2'
|
32
|
-
__version_tuple__ = version_tuple = (0, 2, 2)
|
31
|
+
__version__ = version = '0.3.0'
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 0)
|
33
33
|
|
34
34
|
__commit_id__ = commit_id = None
|
@@ -107,9 +107,18 @@ class LLMClientManager:
|
|
107
107
|
f"Could not validate setup for model '{self.config.model}': {e}"
|
108
108
|
)
|
109
109
|
|
110
|
+
def _message_to_dict(self, message: Any) -> dict[str, Any]:
|
111
|
+
"""Convert a message to a dict."""
|
112
|
+
if hasattr(message, "to_dict"):
|
113
|
+
return message.to_dict()
|
114
|
+
elif hasattr(message, "__dict__"):
|
115
|
+
return message.__dict__
|
116
|
+
else:
|
117
|
+
return dict(message)
|
118
|
+
|
110
119
|
def create_completion(
|
111
120
|
self, messages: list[dict[str, Any]], **overrides: Any
|
112
|
-
) -> Union[
|
121
|
+
) -> Union[dict, list[dict]]:
|
113
122
|
"""Create a completion using LiteLLM.
|
114
123
|
|
115
124
|
Parameters
|
@@ -121,9 +130,9 @@ class LLMClientManager:
|
|
121
130
|
|
122
131
|
Returns
|
123
132
|
-------
|
124
|
-
Union[
|
125
|
-
The completion
|
126
|
-
returns a list of
|
133
|
+
Union[dict, List[dict]]
|
134
|
+
The completion response(s). Returns a single response when n=1 or n is None,
|
135
|
+
returns a list of responses when n>1. Response dicts contain 'content' and may contain 'reasoning_content'.
|
127
136
|
|
128
137
|
Raises
|
129
138
|
------
|
@@ -151,28 +160,30 @@ class LLMClientManager:
|
|
151
160
|
# Make the completion call
|
152
161
|
response = completion_func(kwargs)
|
153
162
|
|
154
|
-
# Extract
|
163
|
+
# Extract message objects from response
|
155
164
|
# Check if n > 1 to determine return type
|
156
165
|
n_value = final_config.n or 1
|
157
166
|
if n_value > 1:
|
158
|
-
return [
|
167
|
+
return [
|
168
|
+
self._message_to_dict(choice.message) for choice in response.choices
|
169
|
+
]
|
159
170
|
else:
|
160
|
-
return response.choices[0].message
|
171
|
+
return self._message_to_dict(response.choices[0].message)
|
161
172
|
|
162
173
|
async def acreate_completion(
|
163
174
|
self,
|
164
175
|
messages: Union[list[dict[str, Any]], list[list[dict[str, Any]]]],
|
165
176
|
max_concurrency: Optional[int] = None,
|
166
177
|
**overrides: Any,
|
167
|
-
) -> Union[
|
178
|
+
) -> Union[dict, list[dict]] | list[Union[dict, list[dict]]]:
|
168
179
|
"""Create async completion(s) using LiteLLM with optional concurrency control.
|
169
180
|
|
170
181
|
Parameters
|
171
182
|
----------
|
172
183
|
messages : Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]
|
173
184
|
Single message list or list of message lists.
|
174
|
-
- For single: List[Dict[str, Any]] - returns Union[
|
175
|
-
- For multiple: List[List[Dict[str, Any]]] - returns List[Union[
|
185
|
+
- For single: List[Dict[str, Any]] - returns Union[Any, List[Any]]
|
186
|
+
- For multiple: List[List[Dict[str, Any]]] - returns List[Union[Any, List[Any]]]
|
176
187
|
max_concurrency : Optional[int], optional
|
177
188
|
Maximum number of concurrent requests when processing multiple messages.
|
178
189
|
If None, all requests run concurrently.
|
@@ -181,9 +192,9 @@ class LLMClientManager:
|
|
181
192
|
|
182
193
|
Returns
|
183
194
|
-------
|
184
|
-
Union[
|
185
|
-
For single message: completion
|
186
|
-
For multiple messages: list of completion
|
195
|
+
Union[dict, List[dict], List[Union[dict, List[dict]]]]
|
196
|
+
For single message: completion response (dict when n=1, List[dict] when n>1)
|
197
|
+
For multiple messages: list of completion responses (each element can be dict or List[dict])
|
187
198
|
|
188
199
|
Raises
|
189
200
|
------
|
@@ -221,7 +232,7 @@ class LLMClientManager:
|
|
221
232
|
|
222
233
|
async def _acreate_single(
|
223
234
|
self, messages: list[dict[str, Any]], **overrides: Any
|
224
|
-
) -> Union[
|
235
|
+
) -> Union[dict, list[dict]]:
|
225
236
|
"""Create a single async completion using LiteLLM.
|
226
237
|
|
227
238
|
Parameters
|
@@ -233,10 +244,9 @@ class LLMClientManager:
|
|
233
244
|
|
234
245
|
Returns
|
235
246
|
-------
|
236
|
-
Union[
|
237
|
-
|
238
|
-
|
239
|
-
|
247
|
+
Union[dict, List[dict]]
|
248
|
+
List of completion message objects. Each element is a dict when n=1 or n is None,
|
249
|
+
or a list of dicts when n>1. Message dicts contain 'content' and may contain 'reasoning_content'.
|
240
250
|
Raises
|
241
251
|
------
|
242
252
|
Exception
|
@@ -263,17 +273,19 @@ class LLMClientManager:
|
|
263
273
|
# Make the async completion call
|
264
274
|
response = await completion_func(kwargs)
|
265
275
|
|
266
|
-
# Extract
|
276
|
+
# Extract message objects from response
|
267
277
|
# Check if n > 1 to determine return type
|
268
278
|
n_value = final_config.n or 1
|
269
279
|
if n_value > 1:
|
270
|
-
return [
|
280
|
+
return [
|
281
|
+
self._message_to_dict(choice.message) for choice in response.choices
|
282
|
+
]
|
271
283
|
else:
|
272
|
-
return response.choices[0].message
|
284
|
+
return self._message_to_dict(response.choices[0].message)
|
273
285
|
|
274
286
|
def create_completions_batch(
|
275
287
|
self, messages_list: list[list[dict[str, Any]]], **overrides: Any
|
276
|
-
) -> list[Union[
|
288
|
+
) -> list[Union[dict, list[dict]]]:
|
277
289
|
"""Create multiple completions in batch.
|
278
290
|
|
279
291
|
Parameters
|
@@ -285,9 +297,9 @@ class LLMClientManager:
|
|
285
297
|
|
286
298
|
Returns
|
287
299
|
-------
|
288
|
-
List[
|
289
|
-
List of completion
|
290
|
-
or a list of
|
300
|
+
List[dict] | List[List[dict]]
|
301
|
+
List of completion responses. Each element is a dict when n=1 or n is None,
|
302
|
+
or a list of dicts when n>1. Response dicts contain 'content' and may contain 'reasoning_content'.
|
291
303
|
"""
|
292
304
|
results = []
|
293
305
|
for messages in messages_list:
|
@@ -42,9 +42,10 @@ class LLMChatBlock(BaseBlock):
|
|
42
42
|
Name of the block.
|
43
43
|
input_cols : Union[str, List[str]]
|
44
44
|
Input column name(s). Should contain the messages list.
|
45
|
-
output_cols : Union[
|
45
|
+
output_cols : Union[dict, List[dict]]
|
46
46
|
Output column name(s) for the response. When n > 1, the column will contain
|
47
|
-
a list of responses instead of a single
|
47
|
+
a list of responses instead of a single response. Responses contain 'content',
|
48
|
+
may contain 'reasoning_content' and other fields if any.
|
48
49
|
model : str
|
49
50
|
Model identifier in LiteLLM format. Examples:
|
50
51
|
- "openai/gpt-4"
|
@@ -131,7 +132,7 @@ class LLMChatBlock(BaseBlock):
|
|
131
132
|
>>> block = LLMChatBlock(
|
132
133
|
... block_name="gpt4_multiple",
|
133
134
|
... input_cols="messages",
|
134
|
-
... output_cols="responses", # Will contain lists of
|
135
|
+
... output_cols="responses", # Will contain lists of responses
|
135
136
|
... model="openai/gpt-4",
|
136
137
|
... n=3, # Generate 3 responses per input
|
137
138
|
... temperature=0.8
|
@@ -406,7 +407,7 @@ class LLMChatBlock(BaseBlock):
|
|
406
407
|
self,
|
407
408
|
messages_list: list[list[dict[str, Any]]],
|
408
409
|
**override_kwargs: dict[str, Any],
|
409
|
-
) -> list[Union[
|
410
|
+
) -> list[Union[dict, list[dict]]]:
|
410
411
|
"""Generate responses synchronously.
|
411
412
|
|
412
413
|
Parameters
|
@@ -418,8 +419,9 @@ class LLMChatBlock(BaseBlock):
|
|
418
419
|
|
419
420
|
Returns
|
420
421
|
-------
|
421
|
-
List[Union[
|
422
|
-
List of
|
422
|
+
List[Union[dict, List[dict]]]
|
423
|
+
List of responses. Each element is a dict when n=1 or n is None,
|
424
|
+
or a list of dicts when n>1. Response dicts contain 'content', may contain 'reasoning_content' and other fields if any.
|
423
425
|
"""
|
424
426
|
responses = []
|
425
427
|
|
@@ -461,7 +463,7 @@ class LLMChatBlock(BaseBlock):
|
|
461
463
|
messages_list: list[list[dict[str, Any]]],
|
462
464
|
flow_max_concurrency: Optional[int] = None,
|
463
465
|
**override_kwargs: dict[str, Any],
|
464
|
-
) -> list[Union[
|
466
|
+
) -> list[Union[dict, list[dict]]]:
|
465
467
|
"""Generate responses asynchronously.
|
466
468
|
|
467
469
|
Parameters
|
@@ -475,8 +477,9 @@ class LLMChatBlock(BaseBlock):
|
|
475
477
|
|
476
478
|
Returns
|
477
479
|
-------
|
478
|
-
List[Union[
|
479
|
-
List of
|
480
|
+
List[Union[dict, List[dict]]]
|
481
|
+
List of responses. Each element is a dict when n=1 or n is None,
|
482
|
+
or a list of dicts when n>1. Response dicts contain 'content', may contain 'reasoning_content' and other fields if any.
|
480
483
|
"""
|
481
484
|
try:
|
482
485
|
# Use unified client manager method with optional concurrency control
|
@@ -51,6 +51,10 @@ class TextParserBlock(BaseBlock):
|
|
51
51
|
expand_lists : bool
|
52
52
|
Whether to expand list inputs into individual rows (True) or preserve lists (False).
|
53
53
|
Default is True for backward compatibility.
|
54
|
+
save_reasoning_content : bool
|
55
|
+
Whether to save the reasoning content to the output.
|
56
|
+
reasoning_content_field : Optional[str]
|
57
|
+
The field name of the reasoning content to save to the output.
|
54
58
|
"""
|
55
59
|
|
56
60
|
start_tags: list[str] = Field(
|
@@ -69,6 +73,14 @@ class TextParserBlock(BaseBlock):
|
|
69
73
|
default=True,
|
70
74
|
description="Whether to expand list inputs into individual rows (True) or preserve lists (False). ",
|
71
75
|
)
|
76
|
+
save_reasoning_content: bool = Field(
|
77
|
+
default=False,
|
78
|
+
description="Whether to save the reasoning content to the output.",
|
79
|
+
)
|
80
|
+
reasoning_content_field: Optional[str] = Field(
|
81
|
+
default="reasoning_content",
|
82
|
+
description="The field name of the reasoning content to save to the output.",
|
83
|
+
)
|
72
84
|
|
73
85
|
@field_validator("start_tags", "end_tags", mode="before")
|
74
86
|
@classmethod
|
@@ -234,6 +246,27 @@ class TextParserBlock(BaseBlock):
|
|
234
246
|
value = value.replace(clean_tag, "")
|
235
247
|
return value
|
236
248
|
|
249
|
+
def _handle_message(self, sample: dict) -> dict[str, list[str]]:
|
250
|
+
if "content" not in sample:
|
251
|
+
logger.warning(f"Content not found in sample: {sample}")
|
252
|
+
return {}
|
253
|
+
parsed_output = self._parse(sample["content"])
|
254
|
+
if self.save_reasoning_content:
|
255
|
+
parsed_output[self.reasoning_content_field] = [
|
256
|
+
self._get_reasoning_content(sample)
|
257
|
+
]
|
258
|
+
return parsed_output
|
259
|
+
|
260
|
+
def _get_reasoning_content(self, sample: dict) -> str:
|
261
|
+
if self.save_reasoning_content:
|
262
|
+
if self.reasoning_content_field in sample:
|
263
|
+
return sample[self.reasoning_content_field]
|
264
|
+
else:
|
265
|
+
logger.warning(
|
266
|
+
f"Reasoning content field '{self.reasoning_content_field}' not found in response"
|
267
|
+
)
|
268
|
+
return ""
|
269
|
+
|
237
270
|
def _generate(self, sample: dict) -> list[dict]:
|
238
271
|
input_column = self.input_cols[0]
|
239
272
|
raw_output = sample[input_column]
|
@@ -250,21 +283,24 @@ class TextParserBlock(BaseBlock):
|
|
250
283
|
all_parsed_outputs = {col: [] for col in self.output_cols}
|
251
284
|
valid_responses = 0
|
252
285
|
|
253
|
-
for i,
|
254
|
-
if not
|
286
|
+
for i, message in enumerate(raw_output):
|
287
|
+
if not message:
|
255
288
|
logger.warning(
|
256
|
-
f"List item {i} in column '{input_column}'
|
257
|
-
f"(empty or non-string): {type(response)}"
|
289
|
+
f"List item {i} in column '{input_column}' is empty"
|
258
290
|
)
|
259
291
|
continue
|
260
292
|
|
261
|
-
parsed_outputs = self.
|
293
|
+
parsed_outputs = self._handle_message(message)
|
294
|
+
if self.save_reasoning_content:
|
295
|
+
reasoning_content = parsed_outputs.pop(
|
296
|
+
self.reasoning_content_field
|
297
|
+
)
|
262
298
|
|
263
299
|
if not parsed_outputs or not any(
|
264
300
|
len(value) > 0 for value in parsed_outputs.values()
|
265
301
|
):
|
266
302
|
logger.warning(
|
267
|
-
f"Failed to parse content from list item {i}. Raw output length: {len(
|
303
|
+
f"Failed to parse content from list item {i}. Raw output length: {len(message)}, "
|
268
304
|
f"parsing method: {'regex' if self.parsing_pattern else 'tags'}"
|
269
305
|
)
|
270
306
|
continue
|
@@ -273,6 +309,17 @@ class TextParserBlock(BaseBlock):
|
|
273
309
|
# Collect all parsed values for each column as lists
|
274
310
|
for col in self.output_cols:
|
275
311
|
all_parsed_outputs[col].extend(parsed_outputs.get(col, []))
|
312
|
+
if self.save_reasoning_content:
|
313
|
+
if (
|
314
|
+
self.block_name + "_" + self.reasoning_content_field
|
315
|
+
not in all_parsed_outputs
|
316
|
+
):
|
317
|
+
all_parsed_outputs[
|
318
|
+
self.block_name + "_" + self.reasoning_content_field
|
319
|
+
] = []
|
320
|
+
all_parsed_outputs[
|
321
|
+
self.block_name + "_" + self.reasoning_content_field
|
322
|
+
].extend(reasoning_content)
|
276
323
|
|
277
324
|
if valid_responses == 0:
|
278
325
|
return []
|
@@ -283,21 +330,24 @@ class TextParserBlock(BaseBlock):
|
|
283
330
|
else:
|
284
331
|
# When expand_lists=True, use existing expanding behavior
|
285
332
|
all_results = []
|
286
|
-
for i,
|
287
|
-
if not
|
333
|
+
for i, message in enumerate(raw_output):
|
334
|
+
if not message:
|
288
335
|
logger.warning(
|
289
|
-
f"List item {i} in column '{input_column}'
|
290
|
-
f"(empty or non-string): {type(response)}"
|
336
|
+
f"List item {i} in column '{input_column}' is empty"
|
291
337
|
)
|
292
338
|
continue
|
293
339
|
|
294
|
-
parsed_outputs = self.
|
340
|
+
parsed_outputs = self._handle_message(message)
|
341
|
+
if self.save_reasoning_content:
|
342
|
+
reasoning_content = parsed_outputs.pop(
|
343
|
+
self.reasoning_content_field
|
344
|
+
)
|
295
345
|
|
296
346
|
if not parsed_outputs or not any(
|
297
347
|
len(value) > 0 for value in parsed_outputs.values()
|
298
348
|
):
|
299
349
|
logger.warning(
|
300
|
-
f"Failed to parse content from list item {i}. Raw output length: {len(
|
350
|
+
f"Failed to parse content from list item {i}. Raw output length: {len(message)}, "
|
301
351
|
f"parsing method: {'regex' if self.parsing_pattern else 'tags'}"
|
302
352
|
)
|
303
353
|
continue
|
@@ -307,19 +357,30 @@ class TextParserBlock(BaseBlock):
|
|
307
357
|
for values in zip(
|
308
358
|
*(lst[:max_length] for lst in parsed_outputs.values())
|
309
359
|
):
|
310
|
-
|
311
|
-
|
312
|
-
|
360
|
+
result_row = {
|
361
|
+
**sample,
|
362
|
+
**dict(zip(parsed_outputs.keys(), values)),
|
363
|
+
}
|
364
|
+
if self.save_reasoning_content:
|
365
|
+
result_row[
|
366
|
+
self.block_name + "_" + self.reasoning_content_field
|
367
|
+
] = reasoning_content[0]
|
368
|
+
all_results.append(result_row)
|
313
369
|
|
314
370
|
return all_results
|
315
371
|
|
316
|
-
# Handle
|
317
|
-
elif isinstance(raw_output, str):
|
372
|
+
# Handle dict inputs (existing logic)
|
373
|
+
elif isinstance(raw_output, dict) or isinstance(raw_output, str):
|
318
374
|
if not raw_output:
|
319
|
-
logger.warning(f"Input column '{input_column}' contains empty
|
375
|
+
logger.warning(f"Input column '{input_column}' contains empty dict")
|
320
376
|
return []
|
321
377
|
|
322
|
-
|
378
|
+
if isinstance(raw_output, str):
|
379
|
+
raw_output = {"content": raw_output}
|
380
|
+
|
381
|
+
parsed_outputs = self._handle_message(raw_output)
|
382
|
+
if self.save_reasoning_content:
|
383
|
+
reasoning_content = parsed_outputs.pop(self.reasoning_content_field)
|
323
384
|
|
324
385
|
if not parsed_outputs or not any(
|
325
386
|
len(value) > 0 for value in parsed_outputs.values()
|
@@ -333,13 +394,19 @@ class TextParserBlock(BaseBlock):
|
|
333
394
|
result = []
|
334
395
|
max_length = max(len(value) for value in parsed_outputs.values())
|
335
396
|
for values in zip(*(lst[:max_length] for lst in parsed_outputs.values())):
|
336
|
-
|
397
|
+
result_row = {**sample, **dict(zip(parsed_outputs.keys(), values))}
|
398
|
+
if self.save_reasoning_content:
|
399
|
+
result_row[self.block_name + "_" + self.reasoning_content_field] = (
|
400
|
+
reasoning_content[0]
|
401
|
+
)
|
402
|
+
result.append(result_row)
|
403
|
+
|
337
404
|
return result
|
338
405
|
|
339
406
|
else:
|
340
407
|
logger.warning(
|
341
408
|
f"Input column '{input_column}' contains invalid data type: {type(raw_output)}. "
|
342
|
-
f"Expected
|
409
|
+
f"Expected dict or List[dict]"
|
343
410
|
)
|
344
411
|
return []
|
345
412
|
|
@@ -8,6 +8,7 @@ wide-to-long transformations, value selection, and majority value assignment.
|
|
8
8
|
# Local
|
9
9
|
from .duplicate_columns import DuplicateColumnsBlock
|
10
10
|
from .index_based_mapper import IndexBasedMapperBlock
|
11
|
+
from .json_structure_block import JSONStructureBlock
|
11
12
|
from .melt_columns import MeltColumnsBlock
|
12
13
|
from .rename_columns import RenameColumnsBlock
|
13
14
|
from .text_concat import TextConcatBlock
|
@@ -16,6 +17,7 @@ from .uniform_col_val_setter import UniformColumnValueSetter
|
|
16
17
|
__all__ = [
|
17
18
|
"TextConcatBlock",
|
18
19
|
"DuplicateColumnsBlock",
|
20
|
+
"JSONStructureBlock",
|
19
21
|
"MeltColumnsBlock",
|
20
22
|
"IndexBasedMapperBlock",
|
21
23
|
"RenameColumnsBlock",
|
@@ -0,0 +1,142 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""JSON structure block for combining multiple columns into a structured JSON object.
|
3
|
+
|
4
|
+
This module provides a block for combining multiple columns into a single column
|
5
|
+
containing a structured JSON object with specified field names.
|
6
|
+
"""
|
7
|
+
|
8
|
+
# Standard
|
9
|
+
from typing import Any, Dict
|
10
|
+
import json
|
11
|
+
|
12
|
+
# Third Party
|
13
|
+
from datasets import Dataset
|
14
|
+
from pydantic import Field, field_validator
|
15
|
+
|
16
|
+
# Local
|
17
|
+
from ...utils.logger_config import setup_logger
|
18
|
+
from ..base import BaseBlock
|
19
|
+
from ..registry import BlockRegistry
|
20
|
+
|
21
|
+
logger = setup_logger(__name__)
|
22
|
+
|
23
|
+
|
24
|
+
@BlockRegistry.register(
|
25
|
+
"JSONStructureBlock",
|
26
|
+
"transform",
|
27
|
+
"Combines multiple columns into a single column containing a structured JSON object",
|
28
|
+
)
|
29
|
+
class JSONStructureBlock(BaseBlock):
|
30
|
+
"""Block for combining multiple columns into a structured JSON object.
|
31
|
+
|
32
|
+
This block takes values from multiple input columns and combines them into a single
|
33
|
+
output column containing a JSON object. The JSON field names match the input column names.
|
34
|
+
|
35
|
+
Attributes
|
36
|
+
----------
|
37
|
+
block_name : str
|
38
|
+
Name of the block.
|
39
|
+
input_cols : List[str]
|
40
|
+
List of input column names to include in the JSON object.
|
41
|
+
Column names become the JSON field names.
|
42
|
+
output_cols : List[str]
|
43
|
+
List containing the single output column name.
|
44
|
+
ensure_json_serializable : bool
|
45
|
+
Whether to ensure all values are JSON serializable (default True).
|
46
|
+
pretty_print : bool
|
47
|
+
Whether to format JSON with indentation (default False).
|
48
|
+
"""
|
49
|
+
|
50
|
+
ensure_json_serializable: bool = Field(
|
51
|
+
default=True, description="Whether to ensure all values are JSON serializable"
|
52
|
+
)
|
53
|
+
pretty_print: bool = Field(
|
54
|
+
default=False, description="Whether to format JSON with indentation"
|
55
|
+
)
|
56
|
+
|
57
|
+
@field_validator("output_cols", mode="after")
|
58
|
+
@classmethod
|
59
|
+
def validate_output_cols(cls, v):
|
60
|
+
"""Validate that exactly one output column is specified."""
|
61
|
+
if not v or len(v) != 1:
|
62
|
+
raise ValueError("JSONStructureBlock requires exactly one output column")
|
63
|
+
return v
|
64
|
+
|
65
|
+
def _make_json_serializable(self, value: Any) -> Any:
|
66
|
+
"""Convert value to JSON serializable format."""
|
67
|
+
if value is None:
|
68
|
+
return None
|
69
|
+
|
70
|
+
# Handle basic types that are already JSON serializable
|
71
|
+
if isinstance(value, (str, int, float, bool)):
|
72
|
+
return value
|
73
|
+
|
74
|
+
# Handle lists
|
75
|
+
if isinstance(value, (list, tuple)):
|
76
|
+
return [self._make_json_serializable(item) for item in value]
|
77
|
+
|
78
|
+
# Handle dictionaries
|
79
|
+
if isinstance(value, dict):
|
80
|
+
return {k: self._make_json_serializable(v) for k, v in value.items()}
|
81
|
+
|
82
|
+
# Convert other types to string
|
83
|
+
return str(value)
|
84
|
+
|
85
|
+
def _get_field_mapping(self) -> Dict[str, str]:
|
86
|
+
"""Get the mapping of JSON field names to input column names."""
|
87
|
+
# Use column names as JSON field names (standard SDG Hub pattern)
|
88
|
+
if isinstance(self.input_cols, list):
|
89
|
+
return {col: col for col in self.input_cols}
|
90
|
+
|
91
|
+
raise ValueError("input_cols must be a list of column names")
|
92
|
+
|
93
|
+
def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
|
94
|
+
"""Generate a dataset with JSON structured output.
|
95
|
+
|
96
|
+
Parameters
|
97
|
+
----------
|
98
|
+
samples : Dataset
|
99
|
+
Input dataset to process.
|
100
|
+
|
101
|
+
Returns
|
102
|
+
-------
|
103
|
+
Dataset
|
104
|
+
Dataset with JSON structured output in the specified column.
|
105
|
+
"""
|
106
|
+
if not self.output_cols:
|
107
|
+
raise ValueError("output_cols must be specified")
|
108
|
+
|
109
|
+
output_col = self.output_cols[0]
|
110
|
+
field_mapping = self._get_field_mapping()
|
111
|
+
|
112
|
+
def _create_json_structure(sample):
|
113
|
+
"""Create JSON structure from input columns."""
|
114
|
+
json_obj = {}
|
115
|
+
|
116
|
+
# Build the JSON object using the field mapping
|
117
|
+
for json_field, col_name in field_mapping.items():
|
118
|
+
if col_name not in sample:
|
119
|
+
logger.warning(f"Input column '{col_name}' not found in sample")
|
120
|
+
json_obj[json_field] = None
|
121
|
+
else:
|
122
|
+
value = sample[col_name]
|
123
|
+
if self.ensure_json_serializable:
|
124
|
+
value = self._make_json_serializable(value)
|
125
|
+
json_obj[json_field] = value
|
126
|
+
|
127
|
+
# Convert to JSON string
|
128
|
+
try:
|
129
|
+
if self.pretty_print:
|
130
|
+
json_string = json.dumps(json_obj, indent=2, ensure_ascii=False)
|
131
|
+
else:
|
132
|
+
json_string = json.dumps(json_obj, ensure_ascii=False)
|
133
|
+
sample[output_col] = json_string
|
134
|
+
except (TypeError, ValueError) as e:
|
135
|
+
logger.error(f"Failed to serialize JSON object: {e}")
|
136
|
+
sample[output_col] = "{}"
|
137
|
+
|
138
|
+
return sample
|
139
|
+
|
140
|
+
# Apply the JSON structuring to all samples
|
141
|
+
result = samples.map(_create_json_structure)
|
142
|
+
return result
|