sdg-hub 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. sdg_hub/_version.py +2 -2
  2. sdg_hub/core/blocks/llm/client_manager.py +63 -26
  3. sdg_hub/core/blocks/llm/llm_chat_block.py +12 -9
  4. sdg_hub/core/blocks/llm/text_parser_block.py +88 -21
  5. sdg_hub/core/blocks/transform/__init__.py +2 -0
  6. sdg_hub/core/blocks/transform/json_structure_block.py +142 -0
  7. sdg_hub/core/flow/base.py +199 -56
  8. sdg_hub/core/utils/datautils.py +45 -2
  9. sdg_hub/core/utils/flow_metrics.py +261 -0
  10. sdg_hub/core/utils/logger_config.py +50 -9
  11. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/__init__.py +0 -0
  12. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/__init__.py +0 -0
  13. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/detailed_summary.yaml +11 -0
  14. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/flow.yaml +159 -0
  15. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/__init__.py +0 -0
  16. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/extractive_summary.yaml +65 -0
  17. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/flow.yaml +161 -0
  18. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_answers.yaml +15 -0
  19. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_multiple_qa.yaml +21 -0
  20. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_question_list.yaml +44 -0
  21. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/__init__.py +0 -0
  22. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/flow.yaml +104 -0
  23. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/key_facts_summary.yaml +61 -0
  24. sdg_hub/flows/text_analysis/__init__.py +2 -0
  25. sdg_hub/flows/text_analysis/structured_insights/__init__.py +6 -0
  26. sdg_hub/flows/text_analysis/structured_insights/analyze_sentiment.yaml +27 -0
  27. sdg_hub/flows/text_analysis/structured_insights/extract_entities.yaml +38 -0
  28. sdg_hub/flows/text_analysis/structured_insights/extract_keywords.yaml +21 -0
  29. sdg_hub/flows/text_analysis/structured_insights/flow.yaml +153 -0
  30. sdg_hub/flows/text_analysis/structured_insights/summarize.yaml +21 -0
  31. {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.1.dist-info}/METADATA +3 -1
  32. {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.1.dist-info}/RECORD +35 -13
  33. {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.1.dist-info}/WHEEL +0 -0
  34. {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.1.dist-info}/licenses/LICENSE +0 -0
  35. {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.1.dist-info}/top_level.txt +0 -0
sdg_hub/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.2.2'
- __version_tuple__ = version_tuple = (0, 2, 2)
+ __version__ = version = '0.3.1'
+ __version_tuple__ = version_tuple = (0, 3, 1)

  __commit_id__ = commit_id = None
sdg_hub/core/blocks/llm/client_manager.py CHANGED
@@ -107,9 +107,18 @@ class LLMClientManager:
                  f"Could not validate setup for model '{self.config.model}': {e}"
              )

+     def _message_to_dict(self, message: Any) -> dict[str, Any]:
+         """Convert a message to a dict."""
+         if hasattr(message, "to_dict"):
+             return message.to_dict()
+         elif hasattr(message, "__dict__"):
+             return message.__dict__
+         else:
+             return dict(message)
+
      def create_completion(
          self, messages: list[dict[str, Any]], **overrides: Any
-     ) -> Union[str, list[str]]:
+     ) -> Union[dict, list[dict]]:
          """Create a completion using LiteLLM.

          Parameters
@@ -121,9 +130,9 @@ class LLMClientManager:

          Returns
          -------
-         Union[str, List[str]]
-             The completion text(s). Returns a single string when n=1 or n is None,
-             returns a list of strings when n>1.
+         Union[dict, List[dict]]
+             The completion response(s). Returns a single response when n=1 or n is None,
+             returns a list of responses when n>1. Response dicts contain 'content' and may contain 'reasoning_content'.

          Raises
          ------
@@ -151,28 +160,30 @@ class LLMClientManager:
          # Make the completion call
          response = completion_func(kwargs)

-         # Extract content from response
+         # Extract message objects from response
          # Check if n > 1 to determine return type
          n_value = final_config.n or 1
          if n_value > 1:
-             return [choice.message.content for choice in response.choices]
+             return [
+                 self._message_to_dict(choice.message) for choice in response.choices
+             ]
          else:
-             return response.choices[0].message.content
+             return self._message_to_dict(response.choices[0].message)

      async def acreate_completion(
          self,
          messages: Union[list[dict[str, Any]], list[list[dict[str, Any]]]],
          max_concurrency: Optional[int] = None,
          **overrides: Any,
-     ) -> Union[str, list[str], list[Union[str, list[str]]]]:
+     ) -> Union[dict, list[dict]] | list[Union[dict, list[dict]]]:
          """Create async completion(s) using LiteLLM with optional concurrency control.

          Parameters
          ----------
          messages : Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]
              Single message list or list of message lists.
-             - For single: List[Dict[str, Any]] - returns Union[str, List[str]]
-             - For multiple: List[List[Dict[str, Any]]] - returns List[Union[str, List[str]]]
+             - For single: List[Dict[str, Any]] - returns Union[Any, List[Any]]
+             - For multiple: List[List[Dict[str, Any]]] - returns List[Union[Any, List[Any]]]
          max_concurrency : Optional[int], optional
              Maximum number of concurrent requests when processing multiple messages.
              If None, all requests run concurrently.
@@ -181,9 +192,9 @@ class LLMClientManager:

          Returns
          -------
-         Union[str, List[str], List[Union[str, List[str]]]]
-             For single message: completion text (string when n=1, list when n>1)
-             For multiple messages: list of completion texts (each element can be str or List[str])
+         Union[dict, List[dict], List[Union[dict, List[dict]]]]
+             For single message: completion response (dict when n=1, List[dict] when n>1)
+             For multiple messages: list of completion responses (each element can be dict or List[dict])

          Raises
          ------
@@ -203,8 +214,33 @@ class LLMClientManager:
              messages_list = messages

          if max_concurrency is not None:
+             if max_concurrency < 1:
+                 raise ValueError(
+                     "max_concurrency must be greater than 0, got {max_concurrency}"
+                 )
+             # Adjust concurrency based on n parameter to avoid overwhelming API
+             # when n > 1 (multiple completions per request)
+             n_value = overrides.get("n") or self.config.n or 1
+             if n_value > 1:
+                 # Warn if max_concurrency is less than n
+                 if max_concurrency < n_value:
+                     logger.warning(
+                         f"max_concurrency ({max_concurrency}) is less than n ({n_value}). "
+                         f"This may result in very low concurrency. Consider increasing max_concurrency "
+                         f"or reducing n for better performance."
+                     )
+
+                 # Reduce concurrency when generating multiple completions per request
+                 adjusted_concurrency = max(1, max_concurrency // n_value)
+                 logger.debug(
+                     f"Adjusted max_concurrency from {max_concurrency} to {adjusted_concurrency} "
+                     f"for n={n_value} completions per request"
+                 )
+             else:
+                 adjusted_concurrency = max_concurrency
+
              # Use semaphore for concurrency control
-             semaphore = asyncio.Semaphore(max_concurrency)
+             semaphore = asyncio.Semaphore(adjusted_concurrency)

              async def _create_with_semaphore(msgs):
                  async with semaphore:
@@ -221,7 +257,7 @@ class LLMClientManager:

      async def _acreate_single(
          self, messages: list[dict[str, Any]], **overrides: Any
-     ) -> Union[str, list[str]]:
+     ) -> Union[dict, list[dict]]:
          """Create a single async completion using LiteLLM.

          Parameters
@@ -233,10 +269,9 @@ class LLMClientManager:

          Returns
          -------
-         Union[str, List[str]]
-             The completion text(s). Returns a single string when n=1 or n is None,
-             returns a list of strings when n>1.
-
+         Union[dict, List[dict]]
+             List of completion message objects. Each element is a dict when n=1 or n is None,
+             or a list of dicts when n>1. Message dicts contain 'content' and may contain 'reasoning_content'.
          Raises
          ------
          Exception
@@ -263,17 +298,19 @@ class LLMClientManager:
          # Make the async completion call
          response = await completion_func(kwargs)

-         # Extract content from response
+         # Extract message objects from response
          # Check if n > 1 to determine return type
          n_value = final_config.n or 1
          if n_value > 1:
-             return [choice.message.content for choice in response.choices]
+             return [
+                 self._message_to_dict(choice.message) for choice in response.choices
+             ]
          else:
-             return response.choices[0].message.content
+             return self._message_to_dict(response.choices[0].message)

      def create_completions_batch(
          self, messages_list: list[list[dict[str, Any]]], **overrides: Any
-     ) -> list[Union[str, list[str]]]:
+     ) -> list[Union[dict, list[dict]]]:
          """Create multiple completions in batch.

          Parameters
@@ -285,9 +322,9 @@ class LLMClientManager:

          Returns
          -------
-         List[Union[str, List[str]]]
-             List of completion texts. Each element is a single string when n=1 or n is None,
-             or a list of strings when n>1.
+         List[dict] | List[List[dict]]
+             List of completion responses. Each element is a dict when n=1 or n is None,
+             or a list of dicts when n>1. Response dicts contain 'content' and may contain 'reasoning_content'.
          """
          results = []
          for messages in messages_list:
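
Note (illustrative, not part of the diff): a minimal sketch of what the reworked LLMClientManager now hands back and how it throttles concurrency, based on the hunks above. The sample response values are hypothetical.

    # Before 0.3.1, create_completion returned plain strings such as "The answer is 42.".
    # From 0.3.1 it returns message dicts; 'reasoning_content' only appears when the
    # provider actually returns it.
    response = {
        "content": "The answer is 42.",
        "reasoning_content": "First, recall that ...",  # optional field
    }
    print(response["content"])

    # Concurrency adjustment applied in acreate_completion when n > 1:
    max_concurrency, n = 16, 4
    adjusted_concurrency = max(1, max_concurrency // n)  # 4 concurrent requests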
sdg_hub/core/blocks/llm/llm_chat_block.py CHANGED
@@ -42,9 +42,10 @@ class LLMChatBlock(BaseBlock):
          Name of the block.
      input_cols : Union[str, List[str]]
          Input column name(s). Should contain the messages list.
-     output_cols : Union[str, List[str]]
+     output_cols : Union[dict, List[dict]]
          Output column name(s) for the response. When n > 1, the column will contain
-         a list of responses instead of a single string.
+         a list of responses instead of a single response. Responses contain 'content',
+         may contain 'reasoning_content' and other fields if any.
      model : str
          Model identifier in LiteLLM format. Examples:
          - "openai/gpt-4"
@@ -131,7 +132,7 @@ class LLMChatBlock(BaseBlock):
      >>> block = LLMChatBlock(
      ...     block_name="gpt4_multiple",
      ...     input_cols="messages",
-     ...     output_cols="responses", # Will contain lists of strings
+     ...     output_cols="responses", # Will contain lists of responses
      ...     model="openai/gpt-4",
      ...     n=3, # Generate 3 responses per input
      ...     temperature=0.8
@@ -406,7 +407,7 @@ class LLMChatBlock(BaseBlock):
          self,
          messages_list: list[list[dict[str, Any]]],
          **override_kwargs: dict[str, Any],
-     ) -> list[Union[str, list[str]]]:
+     ) -> list[Union[dict, list[dict]]]:
          """Generate responses synchronously.

          Parameters
@@ -418,8 +419,9 @@ class LLMChatBlock(BaseBlock):

          Returns
          -------
-         List[Union[str, List[str]]]
-             List of response strings or lists of response strings (when n > 1).
+         List[Union[dict, List[dict]]]
+             List of responses. Each element is a dict when n=1 or n is None,
+             or a list of dicts when n>1. Response dicts contain 'content', may contain 'reasoning_content' and other fields if any.
          """
          responses = []

@@ -461,7 +463,7 @@ class LLMChatBlock(BaseBlock):
          messages_list: list[list[dict[str, Any]]],
          flow_max_concurrency: Optional[int] = None,
          **override_kwargs: dict[str, Any],
-     ) -> list[Union[str, list[str]]]:
+     ) -> list[Union[dict, list[dict]]]:
          """Generate responses asynchronously.

          Parameters
@@ -475,8 +477,9 @@ class LLMChatBlock(BaseBlock):

          Returns
          -------
-         List[Union[str, List[str]]]
-             List of response strings or lists of response strings (when n > 1).
+         List[Union[dict, List[dict]]]
+             List of responses. Each element is a dict when n=1 or n is None,
+             or a list of dicts when n>1. Response dicts contain 'content', may contain 'reasoning_content' and other fields if any.
          """
          try:
              # Use unified client manager method with optional concurrency control
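
Note (illustrative): the doctest above already shows the n > 1 configuration; the sketch below assumes LLMChatBlock is importable from sdg_hub.core.blocks.llm and shows the new shape of the output column (hypothetical values).

    from sdg_hub.core.blocks.llm import LLMChatBlock  # import path assumed from the file layout

    block = LLMChatBlock(
        block_name="gpt4_multiple",
        input_cols="messages",
        output_cols="responses",
        model="openai/gpt-4",
        n=3,
        temperature=0.8,
    )

    # With n=3, each row's "responses" cell now holds a list of message dicts rather
    # than a list of strings, e.g.:
    # [{"content": "..."}, {"content": "...", "reasoning_content": "..."}, {"content": "..."}]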
sdg_hub/core/blocks/llm/text_parser_block.py CHANGED
@@ -51,6 +51,10 @@ class TextParserBlock(BaseBlock):
      expand_lists : bool
          Whether to expand list inputs into individual rows (True) or preserve lists (False).
          Default is True for backward compatibility.
+     save_reasoning_content : bool
+         Whether to save the reasoning content to the output.
+     reasoning_content_field : Optional[str]
+         The field name of the reasoning content to save to the output.
      """

      start_tags: list[str] = Field(
@@ -69,6 +73,14 @@ class TextParserBlock(BaseBlock):
          default=True,
          description="Whether to expand list inputs into individual rows (True) or preserve lists (False). ",
      )
+     save_reasoning_content: bool = Field(
+         default=False,
+         description="Whether to save the reasoning content to the output.",
+     )
+     reasoning_content_field: Optional[str] = Field(
+         default="reasoning_content",
+         description="The field name of the reasoning content to save to the output.",
+     )

      @field_validator("start_tags", "end_tags", mode="before")
      @classmethod
@@ -234,6 +246,27 @@ class TextParserBlock(BaseBlock):
                  value = value.replace(clean_tag, "")
          return value

+     def _handle_message(self, sample: dict) -> dict[str, list[str]]:
+         if "content" not in sample:
+             logger.warning(f"Content not found in sample: {sample}")
+             return {}
+         parsed_output = self._parse(sample["content"])
+         if self.save_reasoning_content:
+             parsed_output[self.reasoning_content_field] = [
+                 self._get_reasoning_content(sample)
+             ]
+         return parsed_output
+
+     def _get_reasoning_content(self, sample: dict) -> str:
+         if self.save_reasoning_content:
+             if self.reasoning_content_field in sample:
+                 return sample[self.reasoning_content_field]
+             else:
+                 logger.warning(
+                     f"Reasoning content field '{self.reasoning_content_field}' not found in response"
+                 )
+         return ""
+
      def _generate(self, sample: dict) -> list[dict]:
          input_column = self.input_cols[0]
          raw_output = sample[input_column]
@@ -250,21 +283,24 @@ class TextParserBlock(BaseBlock):
              all_parsed_outputs = {col: [] for col in self.output_cols}
              valid_responses = 0

-             for i, response in enumerate(raw_output):
-                 if not response or not isinstance(response, str):
+             for i, message in enumerate(raw_output):
+                 if not message:
                      logger.warning(
-                         f"List item {i} in column '{input_column}' contains invalid data "
-                         f"(empty or non-string): {type(response)}"
+                         f"List item {i} in column '{input_column}' is empty"
                      )
                      continue

-                 parsed_outputs = self._parse(response)
+                 parsed_outputs = self._handle_message(message)
+                 if self.save_reasoning_content:
+                     reasoning_content = parsed_outputs.pop(
+                         self.reasoning_content_field
+                     )

                  if not parsed_outputs or not any(
                      len(value) > 0 for value in parsed_outputs.values()
                  ):
                      logger.warning(
-                         f"Failed to parse content from list item {i}. Raw output length: {len(response)}, "
+                         f"Failed to parse content from list item {i}. Raw output length: {len(message)}, "
                          f"parsing method: {'regex' if self.parsing_pattern else 'tags'}"
                      )
                      continue
@@ -273,6 +309,17 @@ class TextParserBlock(BaseBlock):
                  # Collect all parsed values for each column as lists
                  for col in self.output_cols:
                      all_parsed_outputs[col].extend(parsed_outputs.get(col, []))
+                 if self.save_reasoning_content:
+                     if (
+                         self.block_name + "_" + self.reasoning_content_field
+                         not in all_parsed_outputs
+                     ):
+                         all_parsed_outputs[
+                             self.block_name + "_" + self.reasoning_content_field
+                         ] = []
+                     all_parsed_outputs[
+                         self.block_name + "_" + self.reasoning_content_field
+                     ].extend(reasoning_content)

              if valid_responses == 0:
                  return []
@@ -283,21 +330,24 @@ class TextParserBlock(BaseBlock):
          else:
              # When expand_lists=True, use existing expanding behavior
              all_results = []
-             for i, response in enumerate(raw_output):
-                 if not response or not isinstance(response, str):
+             for i, message in enumerate(raw_output):
+                 if not message:
                      logger.warning(
-                         f"List item {i} in column '{input_column}' contains invalid data "
-                         f"(empty or non-string): {type(response)}"
+                         f"List item {i} in column '{input_column}' is empty"
                      )
                      continue

-                 parsed_outputs = self._parse(response)
+                 parsed_outputs = self._handle_message(message)
+                 if self.save_reasoning_content:
+                     reasoning_content = parsed_outputs.pop(
+                         self.reasoning_content_field
+                     )

                  if not parsed_outputs or not any(
                      len(value) > 0 for value in parsed_outputs.values()
                  ):
                      logger.warning(
-                         f"Failed to parse content from list item {i}. Raw output length: {len(response)}, "
+                         f"Failed to parse content from list item {i}. Raw output length: {len(message)}, "
                          f"parsing method: {'regex' if self.parsing_pattern else 'tags'}"
                      )
                      continue
@@ -307,19 +357,30 @@ class TextParserBlock(BaseBlock):
                  for values in zip(
                      *(lst[:max_length] for lst in parsed_outputs.values())
                  ):
-                     all_results.append(
-                         {**sample, **dict(zip(parsed_outputs.keys(), values))}
-                     )
+                     result_row = {
+                         **sample,
+                         **dict(zip(parsed_outputs.keys(), values)),
+                     }
+                     if self.save_reasoning_content:
+                         result_row[
+                             self.block_name + "_" + self.reasoning_content_field
+                         ] = reasoning_content[0]
+                     all_results.append(result_row)

              return all_results

-         # Handle string inputs (existing logic)
-         elif isinstance(raw_output, str):
+         # Handle dict inputs (existing logic)
+         elif isinstance(raw_output, dict) or isinstance(raw_output, str):
              if not raw_output:
-                 logger.warning(f"Input column '{input_column}' contains empty string")
+                 logger.warning(f"Input column '{input_column}' contains empty dict")
                  return []

-             parsed_outputs = self._parse(raw_output)
+             if isinstance(raw_output, str):
+                 raw_output = {"content": raw_output}
+
+             parsed_outputs = self._handle_message(raw_output)
+             if self.save_reasoning_content:
+                 reasoning_content = parsed_outputs.pop(self.reasoning_content_field)

              if not parsed_outputs or not any(
                  len(value) > 0 for value in parsed_outputs.values()
@@ -333,13 +394,19 @@ class TextParserBlock(BaseBlock):
              result = []
              max_length = max(len(value) for value in parsed_outputs.values())
              for values in zip(*(lst[:max_length] for lst in parsed_outputs.values())):
-                 result.append({**sample, **dict(zip(parsed_outputs.keys(), values))})
+                 result_row = {**sample, **dict(zip(parsed_outputs.keys(), values))}
+                 if self.save_reasoning_content:
+                     result_row[self.block_name + "_" + self.reasoning_content_field] = (
+                         reasoning_content[0]
+                     )
+                 result.append(result_row)
+
              return result

          else:
              logger.warning(
                  f"Input column '{input_column}' contains invalid data type: {type(raw_output)}. "
-                 f"Expected str or List[str]"
+                 f"Expected dict or List[dict]"
              )
              return []
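
Note (illustrative): a sketch of how the new reasoning-content options are wired together, based on the hunks above. The import path, tag values, and column names are assumptions.

    from sdg_hub.core.blocks.llm import TextParserBlock  # import path assumed

    parser = TextParserBlock(
        block_name="qa_parser",
        input_cols="raw_response",        # column holding the LLM message dict(s)
        output_cols=["answer"],
        start_tags=["<answer>"],
        end_tags=["</answer>"],
        save_reasoning_content=True,      # new in 0.3.1
        reasoning_content_field="reasoning_content",
    )

    # Each parsed row gains an extra column named "<block_name>_<reasoning_content_field>",
    # here "qa_parser_reasoning_content", holding the model's reasoning text (or "" when
    # the field is missing from the response).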
sdg_hub/core/blocks/transform/__init__.py CHANGED
@@ -8,6 +8,7 @@ wide-to-long transformations, value selection, and majority value assignment.
  # Local
  from .duplicate_columns import DuplicateColumnsBlock
  from .index_based_mapper import IndexBasedMapperBlock
+ from .json_structure_block import JSONStructureBlock
  from .melt_columns import MeltColumnsBlock
  from .rename_columns import RenameColumnsBlock
  from .text_concat import TextConcatBlock
@@ -16,6 +17,7 @@ from .uniform_col_val_setter import UniformColumnValueSetter
  __all__ = [
      "TextConcatBlock",
      "DuplicateColumnsBlock",
+     "JSONStructureBlock",
      "MeltColumnsBlock",
      "IndexBasedMapperBlock",
      "RenameColumnsBlock",
sdg_hub/core/blocks/transform/json_structure_block.py ADDED
@@ -0,0 +1,142 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """JSON structure block for combining multiple columns into a structured JSON object.
+
+ This module provides a block for combining multiple columns into a single column
+ containing a structured JSON object with specified field names.
+ """
+
+ # Standard
+ from typing import Any, Dict
+ import json
+
+ # Third Party
+ from datasets import Dataset
+ from pydantic import Field, field_validator
+
+ # Local
+ from ...utils.logger_config import setup_logger
+ from ..base import BaseBlock
+ from ..registry import BlockRegistry
+
+ logger = setup_logger(__name__)
+
+
+ @BlockRegistry.register(
+     "JSONStructureBlock",
+     "transform",
+     "Combines multiple columns into a single column containing a structured JSON object",
+ )
+ class JSONStructureBlock(BaseBlock):
+     """Block for combining multiple columns into a structured JSON object.
+
+     This block takes values from multiple input columns and combines them into a single
+     output column containing a JSON object. The JSON field names match the input column names.
+
+     Attributes
+     ----------
+     block_name : str
+         Name of the block.
+     input_cols : List[str]
+         List of input column names to include in the JSON object.
+         Column names become the JSON field names.
+     output_cols : List[str]
+         List containing the single output column name.
+     ensure_json_serializable : bool
+         Whether to ensure all values are JSON serializable (default True).
+     pretty_print : bool
+         Whether to format JSON with indentation (default False).
+     """
+
+     ensure_json_serializable: bool = Field(
+         default=True, description="Whether to ensure all values are JSON serializable"
+     )
+     pretty_print: bool = Field(
+         default=False, description="Whether to format JSON with indentation"
+     )
+
+     @field_validator("output_cols", mode="after")
+     @classmethod
+     def validate_output_cols(cls, v):
+         """Validate that exactly one output column is specified."""
+         if not v or len(v) != 1:
+             raise ValueError("JSONStructureBlock requires exactly one output column")
+         return v
+
+     def _make_json_serializable(self, value: Any) -> Any:
+         """Convert value to JSON serializable format."""
+         if value is None:
+             return None
+
+         # Handle basic types that are already JSON serializable
+         if isinstance(value, (str, int, float, bool)):
+             return value
+
+         # Handle lists
+         if isinstance(value, (list, tuple)):
+             return [self._make_json_serializable(item) for item in value]
+
+         # Handle dictionaries
+         if isinstance(value, dict):
+             return {k: self._make_json_serializable(v) for k, v in value.items()}
+
+         # Convert other types to string
+         return str(value)
+
+     def _get_field_mapping(self) -> Dict[str, str]:
+         """Get the mapping of JSON field names to input column names."""
+         # Use column names as JSON field names (standard SDG Hub pattern)
+         if isinstance(self.input_cols, list):
+             return {col: col for col in self.input_cols}
+
+         raise ValueError("input_cols must be a list of column names")
+
+     def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
+         """Generate a dataset with JSON structured output.
+
+         Parameters
+         ----------
+         samples : Dataset
+             Input dataset to process.
+
+         Returns
+         -------
+         Dataset
+             Dataset with JSON structured output in the specified column.
+         """
+         if not self.output_cols:
+             raise ValueError("output_cols must be specified")
+
+         output_col = self.output_cols[0]
+         field_mapping = self._get_field_mapping()
+
+         def _create_json_structure(sample):
+             """Create JSON structure from input columns."""
+             json_obj = {}
+
+             # Build the JSON object using the field mapping
+             for json_field, col_name in field_mapping.items():
+                 if col_name not in sample:
+                     logger.warning(f"Input column '{col_name}' not found in sample")
+                     json_obj[json_field] = None
+                 else:
+                     value = sample[col_name]
+                     if self.ensure_json_serializable:
+                         value = self._make_json_serializable(value)
+                     json_obj[json_field] = value
+
+             # Convert to JSON string
+             try:
+                 if self.pretty_print:
+                     json_string = json.dumps(json_obj, indent=2, ensure_ascii=False)
+                 else:
+                     json_string = json.dumps(json_obj, ensure_ascii=False)
+                 sample[output_col] = json_string
+             except (TypeError, ValueError) as e:
+                 logger.error(f"Failed to serialize JSON object: {e}")
+                 sample[output_col] = "{}"
+
+             return sample
+
+         # Apply the JSON structuring to all samples
+         result = samples.map(_create_json_structure)
+         return result
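
Note (illustrative): a usage sketch for the new JSONStructureBlock, assuming it is imported via sdg_hub.core.blocks.transform (see the __init__.py hunk above) and that block_name/input_cols/output_cols are accepted as BaseBlock fields. Column names and values are hypothetical.

    from datasets import Dataset
    from sdg_hub.core.blocks.transform import JSONStructureBlock

    block = JSONStructureBlock(
        block_name="insights_json",
        input_cols=["summary", "sentiment", "keywords"],
        output_cols=["structured_insights"],
        pretty_print=False,
    )

    ds = Dataset.from_list(
        [{"summary": "Quarterly revenue grew.", "sentiment": "positive", "keywords": ["revenue", "growth"]}]
    )
    result = block.generate(ds)
    # result[0]["structured_insights"] ->
    # '{"summary": "Quarterly revenue grew.", "sentiment": "positive", "keywords": ["revenue", "growth"]}'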