sdg-hub 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/_version.py +2 -2
- sdg_hub/core/blocks/llm/client_manager.py +63 -26
- sdg_hub/core/blocks/llm/llm_chat_block.py +12 -9
- sdg_hub/core/blocks/llm/text_parser_block.py +88 -21
- sdg_hub/core/blocks/transform/__init__.py +2 -0
- sdg_hub/core/blocks/transform/json_structure_block.py +142 -0
- sdg_hub/core/flow/base.py +199 -56
- sdg_hub/core/utils/datautils.py +45 -2
- sdg_hub/core/utils/flow_metrics.py +261 -0
- sdg_hub/core/utils/logger_config.py +50 -9
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/detailed_summary.yaml +11 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/detailed_summary/flow.yaml +159 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/extractive_summary.yaml +65 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/extractive_summary/flow.yaml +161 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_answers.yaml +15 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_multiple_qa.yaml +21 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/generate_question_list.yaml +44 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/__init__.py +0 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/flow.yaml +104 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/enhanced_multi_summary_qa/key_facts/key_facts_summary.yaml +61 -0
- sdg_hub/flows/text_analysis/__init__.py +2 -0
- sdg_hub/flows/text_analysis/structured_insights/__init__.py +6 -0
- sdg_hub/flows/text_analysis/structured_insights/analyze_sentiment.yaml +27 -0
- sdg_hub/flows/text_analysis/structured_insights/extract_entities.yaml +38 -0
- sdg_hub/flows/text_analysis/structured_insights/extract_keywords.yaml +21 -0
- sdg_hub/flows/text_analysis/structured_insights/flow.yaml +153 -0
- sdg_hub/flows/text_analysis/structured_insights/summarize.yaml +21 -0
- {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.1.dist-info}/METADATA +3 -1
- {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.1.dist-info}/RECORD +35 -13
- {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.1.dist-info}/WHEEL +0 -0
- {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.2.2.dist-info → sdg_hub-0.3.1.dist-info}/top_level.txt +0 -0
sdg_hub/core/flow/base.py
CHANGED
@@ -2,13 +2,22 @@
 """Pydantic-based Flow class for managing data generation pipelines."""
 
 # Standard
+from datetime import datetime
 from pathlib import Path
 from typing import Any, Optional, Union
 import time
+import uuid
 
 # Third Party
 from datasets import Dataset
-from pydantic import
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PrivateAttr,
+    field_validator,
+    model_validator,
+)
 from rich.console import Console
 from rich.panel import Panel
 from rich.table import Table
@@ -21,6 +30,7 @@ from ..blocks.base import BaseBlock
 from ..blocks.registry import BlockRegistry
 from ..utils.datautils import safe_concatenate_with_validation, validate_no_duplicates
 from ..utils.error_handling import EmptyDatasetError, FlowValidationError
+from ..utils.flow_metrics import display_metrics_summary, save_metrics_to_json
 from ..utils.logger_config import setup_logger
 from ..utils.path_resolution import resolve_path
 from ..utils.yaml_utils import save_flow_yaml
@@ -67,6 +77,9 @@ class Flow(BaseModel):
     _migrated_runtime_params: dict[str, dict[str, Any]] = {}
     _llm_client: Any = None  # Only used for backward compatibility with old YAMLs
     _model_config_set: bool = False  # Track if model configuration has been set
+    _block_metrics: list[dict[str, Any]] = PrivateAttr(
+        default_factory=list
+    )  # Track block execution metrics
 
     @field_validator("blocks")
     @classmethod
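The new `_block_metrics` attribute uses Pydantic's `PrivateAttr`, which keeps per-run state out of the model's fields, validation, and serialization. A minimal standalone sketch of that idiom (generic Pydantic v2 code, not taken from sdg_hub):

# Minimal sketch of the PrivateAttr idiom used above (Pydantic v2).
from typing import Any
from pydantic import BaseModel, PrivateAttr

class Pipeline(BaseModel):
    name: str
    # Private attributes are excluded from validation, dumps, and the JSON schema.
    _metrics: list[dict[str, Any]] = PrivateAttr(default_factory=list)

p = Pipeline(name="demo")
p._metrics.append({"block_name": "example", "status": "success"})
print(p.model_dump())  # {'name': 'demo'} -- _metrics never appears in the dump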
@@ -356,6 +369,7 @@ class Flow(BaseModel):
         runtime_params: Optional[dict[str, dict[str, Any]]] = None,
         checkpoint_dir: Optional[str] = None,
         save_freq: Optional[int] = None,
+        log_dir: Optional[str] = None,
         max_concurrency: Optional[int] = None,
     ) -> Dataset:
         """Execute the flow blocks in sequence to generate data.
@@ -378,6 +392,10 @@
         save_freq : Optional[int], optional
             Number of completed samples after which to save a checkpoint.
             If None, only saves final results when checkpointing is enabled.
+        log_dir : Optional[str], optional
+            Directory to save execution logs. If provided, logs will be written to both
+            console and a log file in this directory. Maintains backward compatibility
+            when None.
         max_concurrency : Optional[int], optional
             Maximum number of concurrent requests across all blocks.
             Controls async request concurrency to prevent overwhelming servers.
@@ -400,6 +418,23 @@
                 f"save_freq must be greater than 0, got {save_freq}"
             )
 
+        # Set up file logging if log_dir is provided
+        flow_logger = logger  # Use global logger by default
+        if log_dir is not None:
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            flow_name = self.metadata.name.replace(" ", "_").lower()
+            log_filename = f"{flow_name}_{timestamp}.log"
+
+            # Create a flow-specific logger for this execution
+            unique_id = str(uuid.uuid4())[:8]  # Short unique ID
+            flow_logger_name = f"{__name__}.flow_{flow_name}_{timestamp}_{unique_id}"
+            flow_logger = setup_logger(
+                flow_logger_name, log_dir=log_dir, log_filename=log_filename
+            )
+            flow_logger.propagate = False
+            flow_logger.info(
+                f"Flow logging enabled - logs will be saved to: {log_dir}/{log_filename}"
+            )
         # Validate max_concurrency parameter
         if max_concurrency is not None:
             # Explicitly reject boolean values (bool is a subclass of int in Python)
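For orientation, the file produced by this setup is named `<flow_name>_<timestamp>.log` inside `log_dir`, and the dedicated logger gets a unique name derived from the module path. A small sketch of just the naming logic, with a made-up flow name (illustrative only):

# Illustrative only: reproduces the log-file and logger naming from the hunk above.
import uuid
from datetime import datetime

flow_name = "Document Grounded QA".replace(" ", "_").lower()  # hypothetical flow name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")          # e.g. 20250101_093000
log_filename = f"{flow_name}_{timestamp}.log"                 # document_grounded_qa_20250101_093000.log
unique_id = str(uuid.uuid4())[:8]
# __name__ in base.py is presumably "sdg_hub.core.flow.base"
logger_name = f"sdg_hub.core.flow.base.flow_{flow_name}_{timestamp}_{unique_id}"
print(log_filename, logger_name)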
@@ -459,84 +494,145 @@
             )
 
             if len(remaining_dataset) == 0:
-
+                flow_logger.info(
+                    "All samples already completed, returning existing results"
+                )
+                if log_dir is not None and flow_logger is not logger:
+                    for h in list(getattr(flow_logger, "handlers", [])):
+                        try:
+                            h.flush()
+                            h.close()
+                        except Exception:
+                            pass
+                        finally:
+                            flow_logger.removeHandler(h)
+
                 return completed_dataset
 
             dataset = remaining_dataset
-
+            flow_logger.info(f"Resuming with {len(dataset)} remaining samples")
 
-
+        flow_logger.info(
             f"Starting flow '{self.metadata.name}' v{self.metadata.version} "
             f"with {len(dataset)} samples across {len(self.blocks)} blocks"
             + (f" (max_concurrency={max_concurrency})" if max_concurrency else "")
         )
 
+        # Reset metrics for this execution
+        self._block_metrics = []
+        run_start = time.perf_counter()
+
         # Merge migrated runtime params with provided ones (provided ones take precedence)
         merged_runtime_params = self._migrated_runtime_params.copy()
         if runtime_params:
             merged_runtime_params.update(runtime_params)
         runtime_params = merged_runtime_params
 
-        #
-
-
+        # Execute flow with metrics capture, ensuring metrics are always displayed/saved
+        final_dataset = None
+        execution_successful = False
 
-
-
-
-
+        try:
+            # Process dataset in chunks if checkpointing with save_freq
+            if checkpointer and save_freq:
+                all_processed = []
 
-
-
-
+                # Process in chunks of save_freq
+                for i in range(0, len(dataset), save_freq):
+                    chunk_end = min(i + save_freq, len(dataset))
+                    chunk_dataset = dataset.select(range(i, chunk_end))
 
-
-
-
-                )
-                all_processed.append(processed_chunk)
+                    flow_logger.info(
+                        f"Processing chunk {i // save_freq + 1}: samples {i} to {chunk_end - 1}"
+                    )
 
-
-
+                    # Execute all blocks on this chunk
+                    processed_chunk = self._execute_blocks_on_dataset(
+                        chunk_dataset, runtime_params, flow_logger, max_concurrency
+                    )
+                    all_processed.append(processed_chunk)
 
-
-
+                    # Save checkpoint after chunk completion
+                    checkpointer.add_completed_samples(processed_chunk)
 
-
-
-                    all_processed, "processed chunks from flow execution"
-                )
+                # Save final checkpoint for any remaining samples
+                checkpointer.save_final_checkpoint()
 
-
-        if checkpointer and completed_dataset:
+                # Combine all processed chunks
                 final_dataset = safe_concatenate_with_validation(
-
-                    "completed checkpoint data with newly processed data",
+                    all_processed, "processed chunks from flow execution"
                 )
 
-        else:
-            # Process entire dataset at once
-            final_dataset = self._execute_blocks_on_dataset(
-                dataset, runtime_params, max_concurrency
-            )
-
-        # Save final checkpoint if checkpointing enabled
-        if checkpointer:
-            checkpointer.add_completed_samples(final_dataset)
-            checkpointer.save_final_checkpoint()
-
                 # Combine with previously completed samples if any
-        if completed_dataset:
+                if checkpointer and completed_dataset:
                     final_dataset = safe_concatenate_with_validation(
                         [completed_dataset, final_dataset],
                         "completed checkpoint data with newly processed data",
                     )
 
-
-
-
-
-
+            else:
+                # Process entire dataset at once
+                final_dataset = self._execute_blocks_on_dataset(
+                    dataset, runtime_params, flow_logger, max_concurrency
+                )
+
+                # Save final checkpoint if checkpointing enabled
+                if checkpointer:
+                    checkpointer.add_completed_samples(final_dataset)
+                    checkpointer.save_final_checkpoint()
+
+                # Combine with previously completed samples if any
+                if completed_dataset:
+                    final_dataset = safe_concatenate_with_validation(
+                        [completed_dataset, final_dataset],
+                        "completed checkpoint data with newly processed data",
+                    )
+
+            execution_successful = True
+
+        finally:
+            # Always display metrics and save JSON, even if execution failed
+            display_metrics_summary(
+                self._block_metrics, self.metadata.name, final_dataset
+            )
+
+            # Save metrics to JSON if log_dir is provided
+            if log_dir is not None:
+                # Ensure necessary variables exist
+                if "timestamp" not in locals() or "flow_name" not in locals():
+                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                    flow_name = self.metadata.name.replace(" ", "_").lower()
+
+                save_metrics_to_json(
+                    self._block_metrics,
+                    self.metadata.name,
+                    self.metadata.version,
+                    execution_successful,
+                    run_start,
+                    log_dir,
+                    timestamp,
+                    flow_name,
+                    flow_logger,
+                )
+
+            # Keep a basic log entry for file logs (only if execution was successful)
+            if execution_successful and final_dataset is not None:
+                flow_logger.info(
+                    f"Flow '{self.metadata.name}' completed successfully: "
+                    f"{len(final_dataset)} final samples, "
+                    f"{len(final_dataset.column_names)} final columns"
+                )
+
+            # Close file handlers if we opened a flow-specific logger
+            if log_dir is not None and flow_logger is not logger:
+                for h in list(getattr(flow_logger, "handlers", [])):
+                    try:
+                        h.flush()
+                        h.close()
+                    except Exception:
+                        pass
+                    finally:
+                        flow_logger.removeHandler(h)
 
         return final_dataset
 
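Taken together, `generate()` now supports opt-in file logging plus always-on metrics collection alongside the existing checkpointing. A hedged usage sketch: it assumes `generate()` takes the input `Dataset` as its first positional argument (that part of the signature is outside the hunks shown) and uses a placeholder `flow` object and seed data.

# Hedged usage sketch for the reworked generate(); `flow` and the seed data are placeholders.
from datasets import Dataset

seed_data = Dataset.from_list([{"document": "some source text"}])  # hypothetical input schema

result = flow.generate(
    seed_data,
    checkpoint_dir="checkpoints/",  # resume support, as before
    save_freq=10,                   # checkpoint every 10 completed samples
    log_dir="logs/",                # new: per-run log file plus metrics JSON under this directory
    max_concurrency=8,              # cap concurrent LLM requests across blocks
)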
@@ -544,6 +640,7 @@ class Flow(BaseModel):
         self,
         dataset: Dataset,
         runtime_params: dict[str, dict[str, Any]],
+        flow_logger=None,
         max_concurrency: Optional[int] = None,
     ) -> Dataset:
         """Execute all blocks in sequence on the given dataset.
@@ -554,6 +651,8 @@
             Dataset to process through all blocks.
         runtime_params : Dict[str, Dict[str, Any]]
             Runtime parameters for block execution.
+        flow_logger : logging.Logger, optional
+            Logger to use for this execution. Falls back to global logger if None.
         max_concurrency : Optional[int], optional
             Maximum concurrency for LLM requests across blocks.
 
@@ -562,11 +661,13 @@
         Dataset
             Dataset after processing through all blocks.
         """
+        # Use provided logger or fall back to global logger
+        exec_logger = flow_logger if flow_logger is not None else logger
         current_dataset = dataset
 
         # Execute blocks in sequence
         for i, block in enumerate(self.blocks):
-
+            exec_logger.info(
                 f"Executing block {i + 1}/{len(self.blocks)}: "
                 f"{block.block_name} ({block.__class__.__name__})"
             )
@@ -578,6 +679,11 @@
             if max_concurrency is not None:
                 block_kwargs["_flow_max_concurrency"] = max_concurrency
 
+            # Capture metrics before execution
+            start_time = time.perf_counter()
+            input_rows = len(current_dataset)
+            input_cols = set(current_dataset.column_names)
+
             try:
                 # Check if this is a deprecated block and skip validations
                 is_deprecated_block = (
@@ -587,7 +693,7 @@
                 )
 
                 if is_deprecated_block:
-
+                    exec_logger.debug(
                         f"Skipping validations for deprecated block: {block.block_name}"
                     )
                     # Call generate() directly to skip validations, but keep the runtime params
@@ -602,14 +708,51 @@
                         f"Block '{block.block_name}' produced empty dataset"
                     )
 
-
+                # Capture metrics after successful execution
+                execution_time = time.perf_counter() - start_time
+                output_rows = len(current_dataset)
+                output_cols = set(current_dataset.column_names)
+                added_cols = output_cols - input_cols
+                removed_cols = input_cols - output_cols
+
+                # Store block metrics
+                self._block_metrics.append(
+                    {
+                        "block_name": block.block_name,
+                        "block_type": block.__class__.__name__,
+                        "execution_time": execution_time,
+                        "input_rows": input_rows,
+                        "output_rows": output_rows,
+                        "added_cols": list(added_cols),
+                        "removed_cols": list(removed_cols),
+                        "status": "success",
+                    }
+                )
+
+                exec_logger.info(
                     f"Block '{block.block_name}' completed successfully: "
                     f"{len(current_dataset)} samples, "
                     f"{len(current_dataset.column_names)} columns"
                 )
 
             except Exception as exc:
-
+                # Capture metrics for failed execution
+                execution_time = time.perf_counter() - start_time
+                self._block_metrics.append(
+                    {
+                        "block_name": block.block_name,
+                        "block_type": block.__class__.__name__,
+                        "execution_time": execution_time,
+                        "input_rows": input_rows,
+                        "output_rows": 0,
+                        "added_cols": [],
+                        "removed_cols": [],
+                        "status": "failed",
+                        "error": str(exc),
+                    }
+                )
+
+                exec_logger.error(
                     f"Block '{block.block_name}' failed during execution: {exc}"
                 )
                 raise FlowValidationError(
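The per-block bookkeeping above is a timing-plus-schema-diff pattern: record `perf_counter()` before the block, then compare row counts and column sets afterwards. A self-contained sketch of the same idea outside the Flow class (illustrative; the real code also re-raises the error as a `FlowValidationError` after recording the failed entry):

# Illustrative reimplementation of the per-block metric record built above.
import time
from datasets import Dataset

def run_with_metrics(block_name, block_fn, dataset):
    start = time.perf_counter()
    input_rows, input_cols = len(dataset), set(dataset.column_names)
    out, status, error = None, "success", None
    try:
        out = block_fn(dataset)
    except Exception as exc:  # the real implementation records, logs, then re-raises
        status, error = "failed", str(exc)
    metric = {
        "block_name": block_name,
        "execution_time": time.perf_counter() - start,
        "input_rows": input_rows,
        "output_rows": len(out) if out is not None else 0,
        "added_cols": sorted(set(out.column_names) - input_cols) if out is not None else [],
        "removed_cols": sorted(input_cols - set(out.column_names)) if out is not None else [],
        "status": status,
        "error": error,
    }
    return out, metric

# Example: a trivial "block" that adds one column.
ds = Dataset.from_list([{"text": "hello"}, {"text": "world"}])
out, m = run_with_metrics("add_len", lambda d: d.map(lambda r: {"length": len(r["text"])}), ds)
print(m["added_cols"], m["output_rows"])  # ['length'] 2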
@@ -960,7 +1103,7 @@
             "execution_time_seconds": 0,
         }
 
-        start_time = time.
+        start_time = time.perf_counter()
 
         try:
             # Execute the flow with sample data
@@ -968,7 +1111,7 @@
             runtime_params = runtime_params or {}
 
             for i, block in enumerate(self.blocks):
-                block_start_time = time.
+                block_start_time = time.perf_counter()
                 input_rows = len(current_dataset)
 
                 logger.info(
@@ -1027,7 +1170,7 @@
                 else {},
             }
 
-            execution_time = time.
+            execution_time = time.perf_counter() - start_time
             dry_run_results["execution_time_seconds"] = execution_time
 
             logger.info(
@@ -1038,7 +1181,7 @@
             return dry_run_results
 
         except Exception as exc:
-            execution_time = time.
+            execution_time = time.perf_counter() - start_time
             dry_run_results["execution_successful"] = False
             dry_run_results["execution_time_seconds"] = execution_time
             dry_run_results["error"] = str(exc)
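A recurring change in the hunks above is switching the timing calls to `time.perf_counter()`. Unlike wall-clock `time.time()`, `perf_counter()` is monotonic and high resolution, so interval measurements cannot go backwards if the system clock is adjusted mid-run. A minimal illustration:

# Minimal illustration of interval timing with time.perf_counter().
import time

start = time.perf_counter()
total = sum(i * i for i in range(100_000))   # stand-in for block execution
elapsed = time.perf_counter() - start        # elapsed seconds, monotonic
print(f"computed {total} in {elapsed:.6f}s")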
sdg_hub/core/utils/datautils.py
CHANGED
@@ -1,5 +1,6 @@
 # Third Party
 from datasets import Dataset, concatenate_datasets
+import numpy as np
 
 # Local
 from .error_handling import FlowValidationError
@@ -19,7 +20,8 @@ def validate_no_duplicates(dataset: Dataset) -> None:
     """
     Validate that the input dataset contains only unique rows.
 
-    Uses pandas `.duplicated()` for efficient duplicate detection
+    Uses pandas `.duplicated()` for efficient duplicate detection, with preprocessing
+    to handle numpy arrays that cause TypeError in pandas duplicate detection.
     Raises FlowValidationError if duplicates are found, including a count
     of the duplicate rows detected.
 
@@ -33,9 +35,50 @@ def validate_no_duplicates(dataset: Dataset) -> None:
     FlowValidationError
         If duplicate rows are detected in the dataset.
     """
+    if len(dataset) == 0:
+        return
+
     df = dataset.to_pandas()
-    duplicate_count = int(df.duplicated(keep="first").sum())
 
+    def is_hashable(x):
+        try:
+            hash(x)
+            return True
+        except TypeError:
+            return False
+
+    def make_hashable(x):
+        if is_hashable(x):
+            # int, float, str, bytes, None etc. are already hashable
+            return x
+        if isinstance(x, np.ndarray):
+            if x.ndim == 0:
+                return make_hashable(x.item())
+            return tuple(make_hashable(i) for i in x)
+        if isinstance(x, dict):
+            # sort robustly even with heterogeneous key types
+            return tuple(
+                sorted(
+                    ((k, make_hashable(v)) for k, v in x.items()),
+                    key=lambda kv: repr(kv[0]),
+                )
+            )
+        if isinstance(x, (set, frozenset)):
+            # order-insensitive
+            return frozenset(make_hashable(i) for i in x)
+        if hasattr(x, "__iter__"):
+            # lists, tuples, custom iterables
+            return tuple(make_hashable(i) for i in x)
+        # last-resort fallback to a stable representation
+        return repr(x)
+
+    # Apply to the whole dataframe to ensure every cell is hashable
+    if hasattr(df, "map"):
+        df = df.map(make_hashable)
+    else:
+        df = df.applymap(make_hashable)
+
+    duplicate_count = int(df.duplicated(keep="first").sum())
     if duplicate_count > 0:
         raise FlowValidationError(
             f"Input dataset contains {duplicate_count} duplicate rows. "
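For context on why this preprocessing is needed: pandas duplicate detection hashes row values, so cells that hold numpy arrays (or other unhashable objects) raise a `TypeError`, which is exactly the case the new `make_hashable` helper normalizes away. A small self-contained illustration, independent of sdg_hub:

# Illustrative: unhashable cells break pandas duplicate detection;
# converting them to tuples (as make_hashable does) restores it.
import numpy as np
import pandas as pd

df = pd.DataFrame(
    {"id": [1, 1], "embedding": [np.array([0.1, 0.2]), np.array([0.1, 0.2])]}
)

try:
    df.duplicated(keep="first")
except TypeError as exc:
    print("raises:", exc)  # unhashable type: 'numpy.ndarray'

df["embedding"] = df["embedding"].map(tuple)   # per-column fix; the patch handles arbitrary cells
print(int(df.duplicated(keep="first").sum()))  # 1 duplicate row detected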