sdg-hub 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. sdg_hub/_version.py +16 -3
  2. sdg_hub/core/blocks/deprecated_blocks/selector.py +1 -1
  3. sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +175 -416
  4. sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +174 -415
  5. sdg_hub/core/blocks/evaluation/verify_question_block.py +180 -415
  6. sdg_hub/core/blocks/llm/client_manager.py +92 -43
  7. sdg_hub/core/blocks/llm/config.py +1 -0
  8. sdg_hub/core/blocks/llm/llm_chat_block.py +74 -16
  9. sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +277 -115
  10. sdg_hub/core/blocks/llm/text_parser_block.py +88 -23
  11. sdg_hub/core/blocks/registry.py +48 -34
  12. sdg_hub/core/blocks/transform/__init__.py +2 -0
  13. sdg_hub/core/blocks/transform/index_based_mapper.py +1 -1
  14. sdg_hub/core/blocks/transform/json_structure_block.py +142 -0
  15. sdg_hub/core/flow/base.py +326 -62
  16. sdg_hub/core/utils/datautils.py +54 -0
  17. sdg_hub/core/utils/flow_metrics.py +261 -0
  18. sdg_hub/core/utils/logger_config.py +50 -9
  19. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/__init__.py +0 -0
  20. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/__init__.py +0 -0
  21. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/detailed_summary.yaml +11 -0
  22. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/flow.yaml +159 -0
  23. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/__init__.py +0 -0
  24. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/extractive_summary.yaml +65 -0
  25. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/flow.yaml +161 -0
  26. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_answers.yaml +15 -0
  27. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_multiple_qa.yaml +21 -0
  28. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_question_list.yaml +44 -0
  29. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/__init__.py +0 -0
  30. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/flow.yaml +104 -0
  31. sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/key_facts_summary.yaml +61 -0
  32. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +0 -7
  33. sdg_hub/flows/text_analysis/__init__.py +2 -0
  34. sdg_hub/flows/text_analysis/structured_insights/__init__.py +6 -0
  35. sdg_hub/flows/text_analysis/structured_insights/analyze_sentiment.yaml +27 -0
  36. sdg_hub/flows/text_analysis/structured_insights/extract_entities.yaml +38 -0
  37. sdg_hub/flows/text_analysis/structured_insights/extract_keywords.yaml +21 -0
  38. sdg_hub/flows/text_analysis/structured_insights/flow.yaml +153 -0
  39. sdg_hub/flows/text_analysis/structured_insights/summarize.yaml +21 -0
  40. {sdg_hub-0.2.1.dist-info → sdg_hub-0.3.0.dist-info}/METADATA +42 -15
  41. {sdg_hub-0.2.1.dist-info → sdg_hub-0.3.0.dist-info}/RECORD +44 -22
  42. {sdg_hub-0.2.1.dist-info → sdg_hub-0.3.0.dist-info}/WHEEL +0 -0
  43. {sdg_hub-0.2.1.dist-info → sdg_hub-0.3.0.dist-info}/licenses/LICENSE +0 -0
  44. {sdg_hub-0.2.1.dist-info → sdg_hub-0.3.0.dist-info}/top_level.txt +0 -0
sdg_hub/core/flow/base.py CHANGED
@@ -2,29 +2,40 @@
  """Pydantic-based Flow class for managing data generation pipelines."""

  # Standard
+ from datetime import datetime
  from pathlib import Path
  from typing import Any, Optional, Union
  import time
+ import uuid

  # Third Party
  from datasets import Dataset
- from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+ from pydantic import (
+ BaseModel,
+ ConfigDict,
+ Field,
+ PrivateAttr,
+ field_validator,
+ model_validator,
+ )
  from rich.console import Console
  from rich.panel import Panel
  from rich.table import Table
  from rich.tree import Tree
+ import datasets
  import yaml

  # Local
  from ..blocks.base import BaseBlock
  from ..blocks.registry import BlockRegistry
- from ..utils.datautils import safe_concatenate_with_validation
+ from ..utils.datautils import safe_concatenate_with_validation, validate_no_duplicates
  from ..utils.error_handling import EmptyDatasetError, FlowValidationError
+ from ..utils.flow_metrics import display_metrics_summary, save_metrics_to_json
  from ..utils.logger_config import setup_logger
  from ..utils.path_resolution import resolve_path
  from ..utils.yaml_utils import save_flow_yaml
  from .checkpointer import FlowCheckpointer
- from .metadata import FlowMetadata, FlowParameter
+ from .metadata import DatasetRequirements, FlowMetadata, FlowParameter
  from .migration import FlowMigration
  from .validation import FlowValidator

@@ -66,6 +77,9 @@ class Flow(BaseModel):
  _migrated_runtime_params: dict[str, dict[str, Any]] = {}
  _llm_client: Any = None # Only used for backward compatibility with old YAMLs
  _model_config_set: bool = False # Track if model configuration has been set
+ _block_metrics: list[dict[str, Any]] = PrivateAttr(
+ default_factory=list
+ ) # Track block execution metrics

  @field_validator("blocks")
  @classmethod
@@ -306,13 +320,11 @@ class Flow(BaseModel):

  # Get block class from registry
  try:
- block_class = BlockRegistry.get(block_type_name)
+ block_class = BlockRegistry._get(block_type_name)
  except KeyError as exc:
  # Get all available blocks from all categories
- all_blocks = BlockRegistry.all()
- available_blocks = ", ".join(
- [block for blocks in all_blocks.values() for block in blocks]
- )
+ all_blocks = BlockRegistry.list_blocks()
+ available_blocks = ", ".join(all_blocks)
  raise FlowValidationError(
  f"Block type '{block_type_name}' not found in registry. "
  f"Available blocks: {available_blocks}"
@@ -357,6 +369,8 @@ class Flow(BaseModel):
  runtime_params: Optional[dict[str, dict[str, Any]]] = None,
  checkpoint_dir: Optional[str] = None,
  save_freq: Optional[int] = None,
+ log_dir: Optional[str] = None,
+ max_concurrency: Optional[int] = None,
  ) -> Dataset:
  """Execute the flow blocks in sequence to generate data.

@@ -378,6 +392,13 @@
  save_freq : Optional[int], optional
  Number of completed samples after which to save a checkpoint.
  If None, only saves final results when checkpointing is enabled.
+ log_dir : Optional[str], optional
+ Directory to save execution logs. If provided, logs will be written to both
+ console and a log file in this directory. Maintains backward compatibility
+ when None.
+ max_concurrency : Optional[int], optional
+ Maximum number of concurrent requests across all blocks.
+ Controls async request concurrency to prevent overwhelming servers.

  Returns
  -------
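
Reviewer note: a minimal sketch of calling generate() with the two new keywords. The import path, flow path, and dataset columns below are illustrative assumptions, not taken from this diff; only the generate() keywords themselves come from the change above.

# Hypothetical usage of the new generate() keywords in 0.3.0
from datasets import Dataset
from sdg_hub.core.flow import Flow  # import path assumed from the module layout

flow = Flow.from_yaml("path/to/flow.yaml")          # as in the docstring examples
dataset = Dataset.from_dict({"document": ["..."]})  # required columns depend on the flow
# (flows with LLM blocks also need their model configuration set before generate())

result = flow.generate(
    dataset,
    log_dir="logs",        # per-run log file plus a metrics JSON are written here
    max_concurrency=8,     # cap concurrent LLM requests across all blocks
)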
@@ -397,6 +418,37 @@
  f"save_freq must be greater than 0, got {save_freq}"
  )

+ # Set up file logging if log_dir is provided
+ flow_logger = logger # Use global logger by default
+ if log_dir is not None:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ flow_name = self.metadata.name.replace(" ", "_").lower()
+ log_filename = f"{flow_name}_{timestamp}.log"
+
+ # Create a flow-specific logger for this execution
+ unique_id = str(uuid.uuid4())[:8] # Short unique ID
+ flow_logger_name = f"{__name__}.flow_{flow_name}_{timestamp}_{unique_id}"
+ flow_logger = setup_logger(
+ flow_logger_name, log_dir=log_dir, log_filename=log_filename
+ )
+ flow_logger.propagate = False
+ flow_logger.info(
+ f"Flow logging enabled - logs will be saved to: {log_dir}/{log_filename}"
+ )
+ # Validate max_concurrency parameter
+ if max_concurrency is not None:
+ # Explicitly reject boolean values (bool is a subclass of int in Python)
+ if isinstance(max_concurrency, bool) or not isinstance(
+ max_concurrency, int
+ ):
+ raise FlowValidationError(
+ f"max_concurrency must be an int, got {type(max_concurrency).__name__}"
+ )
+ if max_concurrency <= 0:
+ raise FlowValidationError(
+ f"max_concurrency must be greater than 0, got {max_concurrency}"
+ )
+
  # Validate preconditions
  if not self.blocks:
  raise FlowValidationError("Cannot generate with empty flow")
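
The explicit isinstance check against bool above is needed because bool is a subclass of int in Python, so a plain int check alone would accept True/False. A quick illustration (not package code):

# bool passes an int check, which is why the validation rejects it first
print(isinstance(True, int))    # True
print(isinstance(True, bool))   # True
# With this validation, generate(..., max_concurrency=True) raises FlowValidationError.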
@@ -404,6 +456,8 @@
  if len(dataset) == 0:
  raise EmptyDatasetError("Input dataset is empty")

+ validate_no_duplicates(dataset)
+
  # Check if model configuration has been set for flows with LLM blocks
  llm_blocks = self._detect_llm_blocks()
  if llm_blocks and not self._model_config_set:
@@ -420,6 +474,10 @@
  "Dataset validation failed:\n" + "\n".join(dataset_errors)
  )

+ # Log concurrency control if specified
+ if max_concurrency is not None:
+ logger.info(f"Using max_concurrency={max_concurrency} for LLM requests")
+
  # Initialize checkpointer if enabled
  checkpointer = None
  completed_dataset = None
@@ -436,86 +494,154 @@
  )

  if len(remaining_dataset) == 0:
- logger.info("All samples already completed, returning existing results")
+ flow_logger.info(
+ "All samples already completed, returning existing results"
+ )
+ if log_dir is not None and flow_logger is not logger:
+ for h in list(getattr(flow_logger, "handlers", [])):
+ try:
+ h.flush()
+ h.close()
+ except Exception:
+ pass
+ finally:
+ flow_logger.removeHandler(h)
+
  return completed_dataset

  dataset = remaining_dataset
- logger.info(f"Resuming with {len(dataset)} remaining samples")
+ flow_logger.info(f"Resuming with {len(dataset)} remaining samples")

- logger.info(
+ flow_logger.info(
  f"Starting flow '{self.metadata.name}' v{self.metadata.version} "
  f"with {len(dataset)} samples across {len(self.blocks)} blocks"
+ + (f" (max_concurrency={max_concurrency})" if max_concurrency else "")
  )

+ # Reset metrics for this execution
+ self._block_metrics = []
+ run_start = time.perf_counter()
+
  # Merge migrated runtime params with provided ones (provided ones take precedence)
  merged_runtime_params = self._migrated_runtime_params.copy()
  if runtime_params:
  merged_runtime_params.update(runtime_params)
  runtime_params = merged_runtime_params

- # Process dataset in chunks if checkpointing with save_freq
- if checkpointer and save_freq:
- all_processed = []
+ # Execute flow with metrics capture, ensuring metrics are always displayed/saved
+ final_dataset = None
+ execution_successful = False

- # Process in chunks of save_freq
- for i in range(0, len(dataset), save_freq):
- chunk_end = min(i + save_freq, len(dataset))
- chunk_dataset = dataset.select(range(i, chunk_end))
+ try:
+ # Process dataset in chunks if checkpointing with save_freq
+ if checkpointer and save_freq:
+ all_processed = []

- logger.info(
- f"Processing chunk {i // save_freq + 1}: samples {i} to {chunk_end - 1}"
- )
+ # Process in chunks of save_freq
+ for i in range(0, len(dataset), save_freq):
+ chunk_end = min(i + save_freq, len(dataset))
+ chunk_dataset = dataset.select(range(i, chunk_end))

- # Execute all blocks on this chunk
- processed_chunk = self._execute_blocks_on_dataset(
- chunk_dataset, runtime_params
- )
- all_processed.append(processed_chunk)
+ flow_logger.info(
+ f"Processing chunk {i // save_freq + 1}: samples {i} to {chunk_end - 1}"
+ )

- # Save checkpoint after chunk completion
- checkpointer.add_completed_samples(processed_chunk)
+ # Execute all blocks on this chunk
+ processed_chunk = self._execute_blocks_on_dataset(
+ chunk_dataset, runtime_params, flow_logger, max_concurrency
+ )
+ all_processed.append(processed_chunk)

- # Save final checkpoint for any remaining samples
- checkpointer.save_final_checkpoint()
+ # Save checkpoint after chunk completion
+ checkpointer.add_completed_samples(processed_chunk)

- # Combine all processed chunks
- final_dataset = safe_concatenate_with_validation(
- all_processed, "processed chunks from flow execution"
- )
+ # Save final checkpoint for any remaining samples
+ checkpointer.save_final_checkpoint()

- # Combine with previously completed samples if any
- if checkpointer and completed_dataset:
+ # Combine all processed chunks
  final_dataset = safe_concatenate_with_validation(
- [completed_dataset, final_dataset],
- "completed checkpoint data with newly processed data",
+ all_processed, "processed chunks from flow execution"
  )

- else:
- # Process entire dataset at once
- final_dataset = self._execute_blocks_on_dataset(dataset, runtime_params)
-
- # Save final checkpoint if checkpointing enabled
- if checkpointer:
- checkpointer.add_completed_samples(final_dataset)
- checkpointer.save_final_checkpoint()
-
  # Combine with previously completed samples if any
- if completed_dataset:
+ if checkpointer and completed_dataset:
  final_dataset = safe_concatenate_with_validation(
  [completed_dataset, final_dataset],
  "completed checkpoint data with newly processed data",
  )

- logger.info(
- f"Flow '{self.metadata.name}' completed successfully: "
- f"{len(final_dataset)} final samples, "
- f"{len(final_dataset.column_names)} final columns"
- )
+ else:
+ # Process entire dataset at once
+ final_dataset = self._execute_blocks_on_dataset(
+ dataset, runtime_params, flow_logger, max_concurrency
+ )
+
+ # Save final checkpoint if checkpointing enabled
+ if checkpointer:
+ checkpointer.add_completed_samples(final_dataset)
+ checkpointer.save_final_checkpoint()
+
+ # Combine with previously completed samples if any
+ if completed_dataset:
+ final_dataset = safe_concatenate_with_validation(
+ [completed_dataset, final_dataset],
+ "completed checkpoint data with newly processed data",
+ )
+
+ execution_successful = True
+
+ finally:
+ # Always display metrics and save JSON, even if execution failed
+ display_metrics_summary(
+ self._block_metrics, self.metadata.name, final_dataset
+ )
+
+ # Save metrics to JSON if log_dir is provided
+ if log_dir is not None:
+ # Ensure necessary variables exist
+ if "timestamp" not in locals() or "flow_name" not in locals():
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ flow_name = self.metadata.name.replace(" ", "_").lower()
+
+ save_metrics_to_json(
+ self._block_metrics,
+ self.metadata.name,
+ self.metadata.version,
+ execution_successful,
+ run_start,
+ log_dir,
+ timestamp,
+ flow_name,
+ flow_logger,
+ )
+
+ # Keep a basic log entry for file logs (only if execution was successful)
+ if execution_successful and final_dataset is not None:
+ flow_logger.info(
+ f"Flow '{self.metadata.name}' completed successfully: "
+ f"{len(final_dataset)} final samples, "
+ f"{len(final_dataset.column_names)} final columns"
+ )
+
+ # Close file handlers if we opened a flow-specific logger
+ if log_dir is not None and flow_logger is not logger:
+ for h in list(getattr(flow_logger, "handlers", [])):
+ try:
+ h.flush()
+ h.close()
+ except Exception:
+ pass
+ finally:
+ flow_logger.removeHandler(h)

  return final_dataset

  def _execute_blocks_on_dataset(
- self, dataset: Dataset, runtime_params: dict[str, dict[str, Any]]
+ self,
+ dataset: Dataset,
+ runtime_params: dict[str, dict[str, Any]],
+ flow_logger=None,
+ max_concurrency: Optional[int] = None,
  ) -> Dataset:
  """Execute all blocks in sequence on the given dataset.

@@ -525,17 +651,23 @@ class Flow(BaseModel):
  Dataset to process through all blocks.
  runtime_params : Dict[str, Dict[str, Any]]
  Runtime parameters for block execution.
+ flow_logger : logging.Logger, optional
+ Logger to use for this execution. Falls back to global logger if None.
+ max_concurrency : Optional[int], optional
+ Maximum concurrency for LLM requests across blocks.

  Returns
  -------
  Dataset
  Dataset after processing through all blocks.
  """
+ # Use provided logger or fall back to global logger
+ exec_logger = flow_logger if flow_logger is not None else logger
  current_dataset = dataset

  # Execute blocks in sequence
  for i, block in enumerate(self.blocks):
- logger.info(
+ exec_logger.info(
  f"Executing block {i + 1}/{len(self.blocks)}: "
  f"{block.block_name} ({block.__class__.__name__})"
  )
@@ -543,6 +675,15 @@
  # Prepare block execution parameters
  block_kwargs = self._prepare_block_kwargs(block, runtime_params)

+ # Add max_concurrency to block kwargs if provided
+ if max_concurrency is not None:
+ block_kwargs["_flow_max_concurrency"] = max_concurrency
+
+ # Capture metrics before execution
+ start_time = time.perf_counter()
+ input_rows = len(current_dataset)
+ input_cols = set(current_dataset.column_names)
+
  try:
  # Check if this is a deprecated block and skip validations
  is_deprecated_block = (
@@ -552,7 +693,7 @@
  )

  if is_deprecated_block:
- logger.debug(
+ exec_logger.debug(
  f"Skipping validations for deprecated block: {block.block_name}"
  )
  # Call generate() directly to skip validations, but keep the runtime params
@@ -567,14 +708,51 @@
  f"Block '{block.block_name}' produced empty dataset"
  )

- logger.info(
+ # Capture metrics after successful execution
+ execution_time = time.perf_counter() - start_time
+ output_rows = len(current_dataset)
+ output_cols = set(current_dataset.column_names)
+ added_cols = output_cols - input_cols
+ removed_cols = input_cols - output_cols
+
+ # Store block metrics
+ self._block_metrics.append(
+ {
+ "block_name": block.block_name,
+ "block_type": block.__class__.__name__,
+ "execution_time": execution_time,
+ "input_rows": input_rows,
+ "output_rows": output_rows,
+ "added_cols": list(added_cols),
+ "removed_cols": list(removed_cols),
+ "status": "success",
+ }
+ )
+
+ exec_logger.info(
  f"Block '{block.block_name}' completed successfully: "
  f"{len(current_dataset)} samples, "
  f"{len(current_dataset.column_names)} columns"
  )

  except Exception as exc:
- logger.error(
+ # Capture metrics for failed execution
+ execution_time = time.perf_counter() - start_time
+ self._block_metrics.append(
+ {
+ "block_name": block.block_name,
+ "block_type": block.__class__.__name__,
+ "execution_time": execution_time,
+ "input_rows": input_rows,
+ "output_rows": 0,
+ "added_cols": [],
+ "removed_cols": [],
+ "status": "failed",
+ "error": str(exc),
+ }
+ )
+
+ exec_logger.error(
  f"Block '{block.block_name}' failed during execution: {exc}"
  )
  raise FlowValidationError(
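
For reference, each entry appended to _block_metrics is a plain dict. An illustrative successful record follows; the field names come from the diff above, while the block name, type, and values are invented for the example:

# Illustrative _block_metrics entry (values are made up)
{
    "block_name": "detailed_summary",     # hypothetical block name
    "block_type": "LLMChatBlock",          # hypothetical block class
    "execution_time": 12.7,                # seconds, measured with time.perf_counter()
    "input_rows": 100,
    "output_rows": 100,
    "added_cols": ["summary"],
    "removed_cols": [],
    "status": "success",                   # failed entries carry "failed" plus an "error" string
}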
@@ -899,6 +1077,8 @@
  if len(dataset) == 0:
  raise EmptyDatasetError("Input dataset is empty")

+ validate_no_duplicates(dataset)
+
  # Use smaller sample size if dataset is smaller
  actual_sample_size = min(sample_size, len(dataset))

@@ -923,7 +1103,7 @@
  "execution_time_seconds": 0,
  }

- start_time = time.time()
+ start_time = time.perf_counter()

  try:
  # Execute the flow with sample data
@@ -931,7 +1111,7 @@
  runtime_params = runtime_params or {}

  for i, block in enumerate(self.blocks):
- block_start_time = time.time()
+ block_start_time = time.perf_counter()
  input_rows = len(current_dataset)

  logger.info(
@@ -990,7 +1170,7 @@
  else {},
  }

- execution_time = time.time() - start_time
+ execution_time = time.perf_counter() - start_time
  dry_run_results["execution_time_seconds"] = execution_time

  logger.info(
@@ -1001,7 +1181,7 @@
  return dry_run_results

  except Exception as exc:
- execution_time = time.time() - start_time
+ execution_time = time.perf_counter() - start_time
  dry_run_results["execution_successful"] = False
  dry_run_results["execution_time_seconds"] = execution_time
  dry_run_results["error"] = str(exc)
@@ -1066,6 +1246,90 @@
  "block_names": [block.block_name for block in self.blocks],
  }

+ def get_dataset_requirements(self) -> Optional[DatasetRequirements]:
+ """Get the dataset requirements for this flow.
+
+ Returns
+ -------
+ Optional[DatasetRequirements]
+ Dataset requirements object or None if not defined.
+
+ Examples
+ --------
+ >>> flow = Flow.from_yaml("path/to/flow.yaml")
+ >>> requirements = flow.get_dataset_requirements()
+ >>> if requirements:
+ ... print(f"Required columns: {requirements.required_columns}")
+ """
+ return self.metadata.dataset_requirements
+
+ def get_dataset_schema(self) -> Dataset:
+ """Get an empty dataset with the correct schema for this flow.
+
+ Returns
+ -------
+ Dataset
+ Empty HuggingFace Dataset with the correct schema/features for this flow.
+ Users can add data to this dataset or use it to validate their own dataset schema.
+
+ Examples
+ --------
+ >>> flow = Flow.from_yaml("path/to/flow.yaml")
+ >>> schema_dataset = flow.get_dataset_schema()
+ >>>
+ >>> # Add your data
+ >>> schema_dataset = schema_dataset.add_item({
+ ... "document": "Your document text",
+ ... "domain": "Computer Science",
+ ... "icl_document": "Example document"
+ ... })
+ >>>
+ >>> # Or validate your existing dataset schema
+ >>> my_dataset = Dataset.from_dict(my_data)
+ >>> if my_dataset.features == schema_dataset.features:
+ ... print("Schema matches!")
+ """
+
+ requirements = self.get_dataset_requirements()
+
+ if requirements is None:
+ # Return empty dataset with no schema requirements
+ return Dataset.from_dict({})
+
+ # Build schema features
+ schema_features = {}
+
+ # Process required columns
+ for col_name in requirements.required_columns:
+ col_type = requirements.column_types.get(col_name, "string")
+ schema_features[col_name] = self._map_column_type_to_feature(col_type)
+
+ # Process optional columns
+ for col_name in requirements.optional_columns:
+ col_type = requirements.column_types.get(col_name, "string")
+ schema_features[col_name] = self._map_column_type_to_feature(col_type)
+
+ # Create empty dataset with the correct features
+ features = datasets.Features(schema_features)
+ empty_data = {col_name: [] for col_name in schema_features.keys()}
+
+ return Dataset.from_dict(empty_data, features=features)
+
+ def _map_column_type_to_feature(self, col_type: str):
+ """Map column type string to HuggingFace feature type."""
+ # Map common type names to HuggingFace types
+ if col_type in ["str", "string", "text"]:
+ return datasets.Value("string")
+ elif col_type in ["int", "integer"]:
+ return datasets.Value("int64")
+ elif col_type in ["float", "number"]:
+ return datasets.Value("float64")
+ elif col_type in ["bool", "boolean"]:
+ return datasets.Value("bool")
+ else:
+ # Default to string for unknown types
+ return datasets.Value("string")
+
  def print_info(self) -> None:
  """
  Print an interactive summary of the Flow in the console.
sdg_hub/core/utils/datautils.py CHANGED
@@ -15,6 +15,60 @@ def safe_concatenate_datasets(datasets: list):
  return concatenate_datasets(filtered_datasets)


+ def validate_no_duplicates(dataset: Dataset) -> None:
+ """
+ Validate that the input dataset contains only unique rows.
+
+ Uses pandas `.duplicated()` for efficient duplicate detection, with preprocessing
+ to handle numpy arrays that cause TypeError in pandas duplicate detection.
+ Raises FlowValidationError if duplicates are found, including a count
+ of the duplicate rows detected.
+
+ Parameters
+ ----------
+ dataset : Dataset
+ Input dataset to validate.
+
+ Raises
+ ------
+ FlowValidationError
+ If duplicate rows are detected in the dataset.
+ """
+ if len(dataset) == 0:
+ return
+
+ df = dataset.to_pandas()
+
+ # Try pandas duplicated() first - only convert types if we hit unhashable error
+ try:
+ duplicate_count = int(df.duplicated(keep="first").sum())
+ except TypeError as e:
+ if "unhashable type" in str(e):
+ # Convert unhashable types to tuples so pandas can hash them
+ for col in df.columns:
+ if df[col].dtype == "object": # Only check object columns
+ df[col] = df[col].apply(
+ lambda x: (
+ tuple(sorted(x.items()))
+ if isinstance(x, dict)
+ else tuple(x)
+ if hasattr(x, "__iter__")
+ and not isinstance(x, (str, bytes))
+ else x
+ )
+ )
+ duplicate_count = int(df.duplicated(keep="first").sum())
+ else:
+ raise # Re-raise if it's a different TypeError
+
+ if duplicate_count > 0:
+ raise FlowValidationError(
+ f"Input dataset contains {duplicate_count} duplicate rows. "
+ f"SDG Hub operations require unique input rows. "
+ f"Please deduplicate your dataset before processing."
+ )
+
+
  def safe_concatenate_with_validation(
  datasets: list, context: str = "datasets"
  ) -> Dataset:
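
A minimal sketch of the new duplicate check used in isolation. Flow.generate() and the sampled dry run call it automatically; the import paths are assumed from the module layout shown in this diff, and the column values are invented:

# Hypothetical standalone use of validate_no_duplicates
from datasets import Dataset
from sdg_hub.core.utils.datautils import validate_no_duplicates
from sdg_hub.core.utils.error_handling import FlowValidationError

ds = Dataset.from_dict({"document": ["a", "a", "b"]})  # "a" appears twice

try:
    validate_no_duplicates(ds)
except FlowValidationError as err:
    print(err)  # reports 1 duplicate row and asks for deduplication before processing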