sdg-hub 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sdg_hub/_version.py +16 -3
- sdg_hub/core/blocks/deprecated_blocks/selector.py +1 -1
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +175 -416
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +174 -415
- sdg_hub/core/blocks/evaluation/verify_question_block.py +180 -415
- sdg_hub/core/blocks/llm/client_manager.py +92 -43
- sdg_hub/core/blocks/llm/config.py +1 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +74 -16
- sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +277 -115
- sdg_hub/core/blocks/llm/text_parser_block.py +88 -23
- sdg_hub/core/blocks/registry.py +48 -34
- sdg_hub/core/blocks/transform/__init__.py +2 -0
- sdg_hub/core/blocks/transform/index_based_mapper.py +1 -1
- sdg_hub/core/blocks/transform/json_structure_block.py +142 -0
- sdg_hub/core/flow/base.py +326 -62
- sdg_hub/core/utils/datautils.py +54 -0
- sdg_hub/core/utils/flow_metrics.py +261 -0
- sdg_hub/core/utils/logger_config.py +50 -9
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/detailed_summary.yaml +11 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/flow.yaml +159 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/extractive_summary.yaml +65 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/flow.yaml +161 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_answers.yaml +15 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_multiple_qa.yaml +21 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_question_list.yaml +44 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/flow.yaml +104 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/key_facts_summary.yaml +61 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +0 -7
- sdg_hub/flows/text_analysis/__init__.py +2 -0
- sdg_hub/flows/text_analysis/structured_insights/__init__.py +6 -0
- sdg_hub/flows/text_analysis/structured_insights/analyze_sentiment.yaml +27 -0
- sdg_hub/flows/text_analysis/structured_insights/extract_entities.yaml +38 -0
- sdg_hub/flows/text_analysis/structured_insights/extract_keywords.yaml +21 -0
- sdg_hub/flows/text_analysis/structured_insights/flow.yaml +153 -0
- sdg_hub/flows/text_analysis/structured_insights/summarize.yaml +21 -0
- {sdg_hub-0.2.1.dist-info → sdg_hub-0.3.0.dist-info}/METADATA +42 -15
- {sdg_hub-0.2.1.dist-info → sdg_hub-0.3.0.dist-info}/RECORD +44 -22
- {sdg_hub-0.2.1.dist-info → sdg_hub-0.3.0.dist-info}/WHEEL +0 -0
- {sdg_hub-0.2.1.dist-info → sdg_hub-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.2.1.dist-info → sdg_hub-0.3.0.dist-info}/top_level.txt +0 -0
sdg_hub/core/blocks/llm/text_parser_block.py
CHANGED
@@ -51,6 +51,10 @@ class TextParserBlock(BaseBlock):
     expand_lists : bool
         Whether to expand list inputs into individual rows (True) or preserve lists (False).
         Default is True for backward compatibility.
+    save_reasoning_content : bool
+        Whether to save the reasoning content to the output.
+    reasoning_content_field : Optional[str]
+        The field name of the reasoning content to save to the output.
     """
 
     start_tags: list[str] = Field(
@@ -69,6 +73,14 @@ class TextParserBlock(BaseBlock):
         default=True,
         description="Whether to expand list inputs into individual rows (True) or preserve lists (False). ",
     )
+    save_reasoning_content: bool = Field(
+        default=False,
+        description="Whether to save the reasoning content to the output.",
+    )
+    reasoning_content_field: Optional[str] = Field(
+        default="reasoning_content",
+        description="The field name of the reasoning content to save to the output.",
+    )
 
     @field_validator("start_tags", "end_tags", mode="before")
     @classmethod
@@ -234,6 +246,27 @@ class TextParserBlock(BaseBlock):
            value = value.replace(clean_tag, "")
        return value
 
+    def _handle_message(self, sample: dict) -> dict[str, list[str]]:
+        if "content" not in sample:
+            logger.warning(f"Content not found in sample: {sample}")
+            return {}
+        parsed_output = self._parse(sample["content"])
+        if self.save_reasoning_content:
+            parsed_output[self.reasoning_content_field] = [
+                self._get_reasoning_content(sample)
+            ]
+        return parsed_output
+
+    def _get_reasoning_content(self, sample: dict) -> str:
+        if self.save_reasoning_content:
+            if self.reasoning_content_field in sample:
+                return sample[self.reasoning_content_field]
+            else:
+                logger.warning(
+                    f"Reasoning content field '{self.reasoning_content_field}' not found in response"
+                )
+        return ""
+
     def _generate(self, sample: dict) -> list[dict]:
         input_column = self.input_cols[0]
         raw_output = sample[input_column]
@@ -250,21 +283,24 @@ class TextParserBlock(BaseBlock):
                 all_parsed_outputs = {col: [] for col in self.output_cols}
                 valid_responses = 0
 
-                for i, response in enumerate(raw_output):
-                    if not response or not isinstance(response, str):
+                for i, message in enumerate(raw_output):
+                    if not message:
                         logger.warning(
-                            f"List item {i} in column '{input_column}' is not a valid string "
-                            f"(empty or non-string): {type(response)}"
+                            f"List item {i} in column '{input_column}' is empty"
                         )
                         continue
 
-                    parsed_outputs = self._parse(response)
+                    parsed_outputs = self._handle_message(message)
+                    if self.save_reasoning_content:
+                        reasoning_content = parsed_outputs.pop(
+                            self.reasoning_content_field
+                        )
 
                     if not parsed_outputs or not any(
                         len(value) > 0 for value in parsed_outputs.values()
                     ):
                         logger.warning(
-                            f"Failed to parse content from list item {i}. Raw output length: {len(response)}, "
+                            f"Failed to parse content from list item {i}. Raw output length: {len(message)}, "
                            f"parsing method: {'regex' if self.parsing_pattern else 'tags'}"
                        )
                        continue
@@ -273,33 +309,45 @@ class TextParserBlock(BaseBlock):
                    # Collect all parsed values for each column as lists
                    for col in self.output_cols:
                        all_parsed_outputs[col].extend(parsed_outputs.get(col, []))
+                    if self.save_reasoning_content:
+                        if (
+                            self.block_name + "_" + self.reasoning_content_field
+                            not in all_parsed_outputs
+                        ):
+                            all_parsed_outputs[
+                                self.block_name + "_" + self.reasoning_content_field
+                            ] = []
+                        all_parsed_outputs[
+                            self.block_name + "_" + self.reasoning_content_field
+                        ].extend(reasoning_content)
 
                if valid_responses == 0:
                    return []
 
                # Return single row with lists as values
-                # TODO: This breaks retry counting in LLMChatWithParsingRetryBlock until LLMChatWithParsingRetryBlock is re-based
-                # which expects one row per successful parse for counting
                return [{**sample, **all_parsed_outputs}]
 
            else:
                # When expand_lists=True, use existing expanding behavior
                all_results = []
-                for i, response in enumerate(raw_output):
-                    if not response or not isinstance(response, str):
+                for i, message in enumerate(raw_output):
+                    if not message:
                        logger.warning(
-                            f"List item {i} in column '{input_column}' is not a valid string "
-                            f"(empty or non-string): {type(response)}"
+                            f"List item {i} in column '{input_column}' is empty"
                        )
                        continue
 
-                    parsed_outputs = self._parse(response)
+                    parsed_outputs = self._handle_message(message)
+                    if self.save_reasoning_content:
+                        reasoning_content = parsed_outputs.pop(
+                            self.reasoning_content_field
+                        )
 
                    if not parsed_outputs or not any(
                        len(value) > 0 for value in parsed_outputs.values()
                    ):
                        logger.warning(
-                            f"Failed to parse content from list item {i}. Raw output length: {len(response)}, "
+                            f"Failed to parse content from list item {i}. Raw output length: {len(message)}, "
                            f"parsing method: {'regex' if self.parsing_pattern else 'tags'}"
                        )
                        continue
@@ -309,19 +357,30 @@ class TextParserBlock(BaseBlock):
                    for values in zip(
                        *(lst[:max_length] for lst in parsed_outputs.values())
                    ):
-                        all_results.append(
-                            {**sample, **dict(zip(parsed_outputs.keys(), values))}
-                        )
+                        result_row = {
+                            **sample,
+                            **dict(zip(parsed_outputs.keys(), values)),
+                        }
+                        if self.save_reasoning_content:
+                            result_row[
+                                self.block_name + "_" + self.reasoning_content_field
+                            ] = reasoning_content[0]
+                        all_results.append(result_row)
 
                return all_results
 
-        # Handle string inputs (existing logic)
-        elif isinstance(raw_output, str):
+        # Handle dict inputs (existing logic)
+        elif isinstance(raw_output, dict) or isinstance(raw_output, str):
            if not raw_output:
-                logger.warning(f"Input column '{input_column}' contains empty string")
+                logger.warning(f"Input column '{input_column}' contains empty dict")
                return []
 
-            parsed_outputs = self._parse(raw_output)
+            if isinstance(raw_output, str):
+                raw_output = {"content": raw_output}
+
+            parsed_outputs = self._handle_message(raw_output)
+            if self.save_reasoning_content:
+                reasoning_content = parsed_outputs.pop(self.reasoning_content_field)
 
            if not parsed_outputs or not any(
                len(value) > 0 for value in parsed_outputs.values()
@@ -335,13 +394,19 @@ class TextParserBlock(BaseBlock):
            result = []
            max_length = max(len(value) for value in parsed_outputs.values())
            for values in zip(*(lst[:max_length] for lst in parsed_outputs.values())):
-                result.append({**sample, **dict(zip(parsed_outputs.keys(), values))})
+                result_row = {**sample, **dict(zip(parsed_outputs.keys(), values))}
+                if self.save_reasoning_content:
+                    result_row[self.block_name + "_" + self.reasoning_content_field] = (
+                        reasoning_content[0]
+                    )
+                result.append(result_row)
+
            return result
 
        else:
            logger.warning(
                f"Input column '{input_column}' contains invalid data type: {type(raw_output)}. "
-                f"Expected string or List[str]"
+                f"Expected dict or List[dict]"
            )
            return []
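The net effect of the TextParserBlock changes: the block now consumes chat-style message dicts (a "content" field plus, for reasoning models, an optional reasoning trace) instead of bare strings, and can carry the trace through to a `<block_name>_<reasoning_content_field>` output column. A minimal sketch of the new options follows; the block configuration and sample payload are illustrative, not taken from the package:

    from sdg_hub.core.blocks.llm.text_parser_block import TextParserBlock

    # Hypothetical parser: extract <answer>...</answer> spans and keep the
    # model's reasoning trace alongside the parsed answer.
    parser = TextParserBlock(
        block_name="qa_parser",
        input_cols=["raw_response"],
        output_cols=["answer"],
        start_tags=["<answer>"],
        end_tags=["</answer>"],
        save_reasoning_content=True,  # new in 0.3.0, defaults to False
        reasoning_content_field="reasoning_content",
    )

    # Message dicts are now the expected payload; plain strings are wrapped
    # as {"content": ...} for backward compatibility.
    sample = {
        "raw_response": {
            "content": "<answer>Paris</answer>",
            "reasoning_content": "The user asks for the capital of France.",
        }
    }
    rows = parser._generate(sample)
    # rows[0]["answer"] == "Paris"
    # rows[0]["qa_parser_reasoning_content"] == "The user asks for the capital of France."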
sdg_hub/core/blocks/registry.py
CHANGED
@@ -164,8 +164,10 @@ class BlockRegistry:
        ) from exc
 
    @classmethod
-    def get(cls, block_name: str) -> type:
-        """Get a block class with enhanced error handling.
+    def _get(cls, block_name: str) -> type:
+        """Internal method to get a block class with enhanced error handling.
+
+        This is a private method used by the framework internals (Flow system).
 
        Parameters
        ----------
@@ -216,29 +218,6 @@ class BlockRegistry:
 
        return metadata.block_class
 
-    @classmethod
-    def info(cls, block_name: str) -> BlockMetadata:
-        """Get metadata for a specific block.
-
-        Parameters
-        ----------
-        block_name : str
-            Name of the block.
-
-        Returns
-        -------
-        BlockMetadata
-            The block's metadata.
-
-        Raises
-        ------
-        KeyError
-            If the block is not found.
-        """
-        if block_name not in cls._metadata:
-            raise KeyError(f"Block '{block_name}' not found in registry.")
-        return cls._metadata[block_name]
-
    @classmethod
    def categories(cls) -> list[str]:
        """Get all available categories.
@@ -251,8 +230,8 @@ class BlockRegistry:
        return sorted(cls._categories.keys())
 
    @classmethod
-    def …
-        """Get all blocks in a specific category.
+    def _get_category_blocks(cls, category: str) -> list[str]:
+        """Get all blocks in a specific category (private method).
 
        Parameters
        ----------
@@ -278,17 +257,52 @@ class BlockRegistry:
        return sorted(cls._categories[category])
 
    @classmethod
-    def …
-        …
+    def list_blocks(
+        cls,
+        category: Optional[str] = None,
+        *,
+        grouped: bool = False,
+        include_deprecated: bool = True,
+    ) -> list[str] | dict[str, list[str]]:
+        """
+        List registered blocks, optionally filtered by category.
+
+        Args:
+            category: If provided, return only blocks in this category.
+            grouped: If True (and category is None), return a dict
+                mapping categories to lists of blocks.
+            include_deprecated: If True, return deprecated blocks.
 
        Returns
        -------
-        Dict[str, List[str]]
-            …
+        List[str] | Dict[str, List[str]]
+            If grouped is False, returns a list of block names.
+            If grouped is True, returns a dict mapping categories to lists of block names.
        """
-        …
-        …
-        …
+
+        def filter_deprecated(block_names: list[str]) -> list[str]:
+            if include_deprecated:
+                return block_names
+            return [name for name in block_names if not cls._metadata[name].deprecated]
+
+        if category:
+            block_names = cls._get_category_blocks(category)
+            return filter_deprecated(block_names)
+
+        if grouped:
+            result = {}
+            for cat, blocks in cls._categories.items():
+                filtered = filter_deprecated(sorted(blocks))
+                if filtered:
+                    result[cat] = filtered
+            return result
+
+        # Flat list of all block names (across all categories)
+        all_block_names = []
+        for blocks in cls._categories.values():
+            all_block_names.extend(blocks)
+        filtered = filter_deprecated(sorted(all_block_names))
+        return filtered
 
    @classmethod
    def discover_blocks(cls) -> None:
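Taken together, the registry changes shrink the public surface: `info` is removed, class lookup and per-category lookup become private (`_get`, `_get_category_blocks`), and `list_blocks` is the single public query method. A short sketch of the new call patterns (the block names in the comment are illustrative):

    from sdg_hub.core.blocks.registry import BlockRegistry

    # Flat list of every registered block name
    all_blocks = BlockRegistry.list_blocks()

    # Only blocks in one category
    transform_blocks = BlockRegistry.list_blocks("transform")

    # Grouped by category, with deprecated blocks filtered out
    by_category = BlockRegistry.list_blocks(grouped=True, include_deprecated=False)
    # e.g. {"transform": ["DuplicateColumnsBlock", "JSONStructureBlock", ...], ...}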
sdg_hub/core/blocks/transform/__init__.py
CHANGED
@@ -8,6 +8,7 @@ wide-to-long transformations, value selection, and majority value assignment.
 # Local
 from .duplicate_columns import DuplicateColumnsBlock
 from .index_based_mapper import IndexBasedMapperBlock
+from .json_structure_block import JSONStructureBlock
 from .melt_columns import MeltColumnsBlock
 from .rename_columns import RenameColumnsBlock
 from .text_concat import TextConcatBlock
@@ -16,6 +17,7 @@ from .uniform_col_val_setter import UniformColumnValueSetter
 __all__ = [
     "TextConcatBlock",
     "DuplicateColumnsBlock",
+    "JSONStructureBlock",
     "MeltColumnsBlock",
     "IndexBasedMapperBlock",
     "RenameColumnsBlock",
sdg_hub/core/blocks/transform/index_based_mapper.py
CHANGED
@@ -174,7 +174,7 @@ class IndexBasedMapperBlock(BaseBlock):
            sample[output_col] = sample[source_col]
        return sample
 
-    def generate(self, samples: Dataset) -> Dataset:
+    def generate(self, samples: Dataset, **kwargs) -> Dataset:
        """Generate a new dataset with selected values.
 
        Parameters
sdg_hub/core/blocks/transform/json_structure_block.py
ADDED
@@ -0,0 +1,142 @@
+# SPDX-License-Identifier: Apache-2.0
+"""JSON structure block for combining multiple columns into a structured JSON object.
+
+This module provides a block for combining multiple columns into a single column
+containing a structured JSON object with specified field names.
+"""
+
+# Standard
+from typing import Any, Dict
+import json
+
+# Third Party
+from datasets import Dataset
+from pydantic import Field, field_validator
+
+# Local
+from ...utils.logger_config import setup_logger
+from ..base import BaseBlock
+from ..registry import BlockRegistry
+
+logger = setup_logger(__name__)
+
+
+@BlockRegistry.register(
+    "JSONStructureBlock",
+    "transform",
+    "Combines multiple columns into a single column containing a structured JSON object",
+)
+class JSONStructureBlock(BaseBlock):
+    """Block for combining multiple columns into a structured JSON object.
+
+    This block takes values from multiple input columns and combines them into a single
+    output column containing a JSON object. The JSON field names match the input column names.
+
+    Attributes
+    ----------
+    block_name : str
+        Name of the block.
+    input_cols : List[str]
+        List of input column names to include in the JSON object.
+        Column names become the JSON field names.
+    output_cols : List[str]
+        List containing the single output column name.
+    ensure_json_serializable : bool
+        Whether to ensure all values are JSON serializable (default True).
+    pretty_print : bool
+        Whether to format JSON with indentation (default False).
+    """
+
+    ensure_json_serializable: bool = Field(
+        default=True, description="Whether to ensure all values are JSON serializable"
+    )
+    pretty_print: bool = Field(
+        default=False, description="Whether to format JSON with indentation"
+    )
+
+    @field_validator("output_cols", mode="after")
+    @classmethod
+    def validate_output_cols(cls, v):
+        """Validate that exactly one output column is specified."""
+        if not v or len(v) != 1:
+            raise ValueError("JSONStructureBlock requires exactly one output column")
+        return v
+
+    def _make_json_serializable(self, value: Any) -> Any:
+        """Convert value to JSON serializable format."""
+        if value is None:
+            return None
+
+        # Handle basic types that are already JSON serializable
+        if isinstance(value, (str, int, float, bool)):
+            return value
+
+        # Handle lists
+        if isinstance(value, (list, tuple)):
+            return [self._make_json_serializable(item) for item in value]
+
+        # Handle dictionaries
+        if isinstance(value, dict):
+            return {k: self._make_json_serializable(v) for k, v in value.items()}
+
+        # Convert other types to string
+        return str(value)
+
+    def _get_field_mapping(self) -> Dict[str, str]:
+        """Get the mapping of JSON field names to input column names."""
+        # Use column names as JSON field names (standard SDG Hub pattern)
+        if isinstance(self.input_cols, list):
+            return {col: col for col in self.input_cols}
+
+        raise ValueError("input_cols must be a list of column names")
+
+    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+        """Generate a dataset with JSON structured output.
+
+        Parameters
+        ----------
+        samples : Dataset
+            Input dataset to process.
+
+        Returns
+        -------
+        Dataset
+            Dataset with JSON structured output in the specified column.
+        """
+        if not self.output_cols:
+            raise ValueError("output_cols must be specified")
+
+        output_col = self.output_cols[0]
+        field_mapping = self._get_field_mapping()
+
+        def _create_json_structure(sample):
+            """Create JSON structure from input columns."""
+            json_obj = {}
+
+            # Build the JSON object using the field mapping
+            for json_field, col_name in field_mapping.items():
+                if col_name not in sample:
+                    logger.warning(f"Input column '{col_name}' not found in sample")
+                    json_obj[json_field] = None
+                else:
+                    value = sample[col_name]
+                    if self.ensure_json_serializable:
+                        value = self._make_json_serializable(value)
+                    json_obj[json_field] = value
+
+            # Convert to JSON string
+            try:
+                if self.pretty_print:
+                    json_string = json.dumps(json_obj, indent=2, ensure_ascii=False)
+                else:
+                    json_string = json.dumps(json_obj, ensure_ascii=False)
+                sample[output_col] = json_string
+            except (TypeError, ValueError) as e:
+                logger.error(f"Failed to serialize JSON object: {e}")
+                sample[output_col] = "{}"
+
+            return sample
+
+        # Apply the JSON structuring to all samples
+        result = samples.map(_create_json_structure)
+        return result
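Since json_structure_block.py is entirely new in this release, a usage sketch may help; the column names and data below are invented for illustration, assuming the standard BaseBlock constructor fields (block_name, input_cols, output_cols) used elsewhere in the package:

    from datasets import Dataset
    from sdg_hub.core.blocks.transform import JSONStructureBlock

    # Combine three analysis columns into one JSON-encoded "insights" column.
    block = JSONStructureBlock(
        block_name="build_insights",
        input_cols=["summary", "keywords", "sentiment"],
        output_cols=["insights"],
    )

    ds = Dataset.from_list([
        {
            "summary": "Quarterly revenue grew 12%.",
            "keywords": ["revenue", "growth"],
            "sentiment": "positive",
        }
    ])
    result = block.generate(ds)
    print(result[0]["insights"])
    # {"summary": "Quarterly revenue grew 12%.", "keywords": ["revenue", "growth"], "sentiment": "positive"}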