cursus 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. cursus/__init__.py +120 -0
  2. cursus/__version__.py +9 -0
  3. cursus/api/__init__.py +30 -0
  4. cursus/api/dag/__init__.py +29 -0
  5. cursus/api/dag/base_dag.py +74 -0
  6. cursus/api/dag/edge_types.py +281 -0
  7. cursus/api/dag/enhanced_dag.py +372 -0
  8. cursus/cli/__init__.py +416 -0
  9. cursus/core/__init__.py +163 -0
  10. cursus/core/assembler/__init__.py +14 -0
  11. cursus/core/assembler/pipeline_assembler.py +468 -0
  12. cursus/core/assembler/pipeline_template_base.py +420 -0
  13. cursus/core/base/__init__.py +38 -0
  14. cursus/core/base/builder_base.py +822 -0
  15. cursus/core/base/config_base.py +450 -0
  16. cursus/core/base/contract_base.py +303 -0
  17. cursus/core/base/enums.py +46 -0
  18. cursus/core/base/hyperparameters_base.py +338 -0
  19. cursus/core/base/specification_base.py +626 -0
  20. cursus/core/compiler/__init__.py +58 -0
  21. cursus/core/compiler/config_resolver.py +725 -0
  22. cursus/core/compiler/dag_compiler.py +529 -0
  23. cursus/core/compiler/dynamic_template.py +864 -0
  24. cursus/core/compiler/exceptions.py +104 -0
  25. cursus/core/compiler/name_generator.py +112 -0
  26. cursus/core/compiler/validation.py +339 -0
  27. cursus/core/config_fields/__init__.py +308 -0
  28. cursus/core/config_fields/circular_reference_tracker.py +212 -0
  29. cursus/core/config_fields/config_class_detector.py +209 -0
  30. cursus/core/config_fields/config_class_store.py +136 -0
  31. cursus/core/config_fields/config_field_categorizer.py +393 -0
  32. cursus/core/config_fields/config_merger.py +363 -0
  33. cursus/core/config_fields/constants.py +89 -0
  34. cursus/core/config_fields/type_aware_config_serializer.py +661 -0
  35. cursus/core/deps/__init__.py +52 -0
  36. cursus/core/deps/base_specifications.py +664 -0
  37. cursus/core/deps/dependency_resolver.py +362 -0
  38. cursus/core/deps/factory.py +50 -0
  39. cursus/core/deps/property_reference.py +244 -0
  40. cursus/core/deps/registry_manager.py +226 -0
  41. cursus/core/deps/semantic_matcher.py +268 -0
  42. cursus/core/deps/specification_registry.py +115 -0
  43. cursus/steps/__init__.py +36 -0
  44. cursus/steps/builders/__init__.py +53 -0
  45. cursus/steps/builders/builder_batch_transform_step.py +295 -0
  46. cursus/steps/builders/builder_currency_conversion_step.py +367 -0
  47. cursus/steps/builders/builder_data_load_step_cradle.py +597 -0
  48. cursus/steps/builders/builder_dummy_training_step.py +508 -0
  49. cursus/steps/builders/builder_model_calibration_step.py +504 -0
  50. cursus/steps/builders/builder_model_eval_step_xgboost.py +359 -0
  51. cursus/steps/builders/builder_model_step_pytorch.py +249 -0
  52. cursus/steps/builders/builder_model_step_xgboost.py +247 -0
  53. cursus/steps/builders/builder_package_step.py +394 -0
  54. cursus/steps/builders/builder_payload_step.py +360 -0
  55. cursus/steps/builders/builder_registration_step.py +387 -0
  56. cursus/steps/builders/builder_risk_table_mapping_step.py +506 -0
  57. cursus/steps/builders/builder_tabular_preprocessing_step.py +367 -0
  58. cursus/steps/builders/builder_training_step_pytorch.py +474 -0
  59. cursus/steps/builders/builder_training_step_xgboost.py +610 -0
  60. cursus/steps/builders/s3_utils.py +205 -0
  61. cursus/steps/configs/__init__.py +98 -0
  62. cursus/steps/configs/config_batch_transform_step.py +99 -0
  63. cursus/steps/configs/config_currency_conversion_step.py +194 -0
  64. cursus/steps/configs/config_data_load_step_cradle.py +889 -0
  65. cursus/steps/configs/config_dummy_training_step.py +165 -0
  66. cursus/steps/configs/config_model_calibration_step.py +347 -0
  67. cursus/steps/configs/config_model_eval_step_xgboost.py +223 -0
  68. cursus/steps/configs/config_model_step_pytorch.py +105 -0
  69. cursus/steps/configs/config_model_step_xgboost.py +164 -0
  70. cursus/steps/configs/config_package_step.py +113 -0
  71. cursus/steps/configs/config_payload_step.py +635 -0
  72. cursus/steps/configs/config_processing_step_base.py +363 -0
  73. cursus/steps/configs/config_registration_step.py +413 -0
  74. cursus/steps/configs/config_risk_table_mapping_step.py +110 -0
  75. cursus/steps/configs/config_tabular_preprocessing_step.py +231 -0
  76. cursus/steps/configs/config_training_step_pytorch.py +83 -0
  77. cursus/steps/configs/config_training_step_xgboost.py +213 -0
  78. cursus/steps/configs/utils.py +466 -0
  79. cursus/steps/contracts/__init__.py +61 -0
  80. cursus/steps/contracts/contract_validator.py +262 -0
  81. cursus/steps/contracts/cradle_data_loading_contract.py +65 -0
  82. cursus/steps/contracts/currency_conversion_contract.py +77 -0
  83. cursus/steps/contracts/dummy_training_contract.py +31 -0
  84. cursus/steps/contracts/mims_package_contract.py +47 -0
  85. cursus/steps/contracts/mims_payload_contract.py +62 -0
  86. cursus/steps/contracts/mims_registration_contract.py +63 -0
  87. cursus/steps/contracts/model_calibration_contract.py +68 -0
  88. cursus/steps/contracts/model_evaluation_contract.py +67 -0
  89. cursus/steps/contracts/pytorch_train_contract.py +97 -0
  90. cursus/steps/contracts/risk_table_mapping_contract.py +75 -0
  91. cursus/steps/contracts/tabular_preprocess_contract.py +56 -0
  92. cursus/steps/contracts/training_script_contract.py +153 -0
  93. cursus/steps/contracts/xgboost_train_contract.py +105 -0
  94. cursus/steps/hyperparams/__init__.py +17 -0
  95. cursus/steps/hyperparams/hyperparameters_bsm.py +256 -0
  96. cursus/steps/hyperparams/hyperparameters_xgboost.py +194 -0
  97. cursus/steps/registry/__init__.py +42 -0
  98. cursus/steps/registry/builder_registry.py +610 -0
  99. cursus/steps/registry/exceptions.py +26 -0
  100. cursus/steps/registry/hyperparameter_registry.py +59 -0
  101. cursus/steps/registry/step_names.py +201 -0
  102. cursus/steps/scripts/__init__.py +38 -0
  103. cursus/steps/scripts/contract_utils.py +349 -0
  104. cursus/steps/scripts/currency_conversion.py +278 -0
  105. cursus/steps/scripts/dummy_training.py +259 -0
  106. cursus/steps/scripts/mims_package.py +269 -0
  107. cursus/steps/scripts/mims_payload.py +492 -0
  108. cursus/steps/scripts/model_calibration.py +939 -0
  109. cursus/steps/scripts/model_evaluation_xgb.py +382 -0
  110. cursus/steps/scripts/risk_table_mapping.py +462 -0
  111. cursus/steps/scripts/tabular_preprocess.py +183 -0
  112. cursus/steps/specs/__init__.py +109 -0
  113. cursus/steps/specs/batch_transform_calibration_spec.py +44 -0
  114. cursus/steps/specs/batch_transform_testing_spec.py +44 -0
  115. cursus/steps/specs/batch_transform_training_spec.py +44 -0
  116. cursus/steps/specs/batch_transform_validation_spec.py +44 -0
  117. cursus/steps/specs/currency_conversion_calibration_spec.py +41 -0
  118. cursus/steps/specs/currency_conversion_testing_spec.py +41 -0
  119. cursus/steps/specs/currency_conversion_training_spec.py +41 -0
  120. cursus/steps/specs/currency_conversion_validation_spec.py +41 -0
  121. cursus/steps/specs/data_loading_calibration_spec.py +45 -0
  122. cursus/steps/specs/data_loading_spec.py +52 -0
  123. cursus/steps/specs/data_loading_testing_spec.py +45 -0
  124. cursus/steps/specs/data_loading_training_spec.py +45 -0
  125. cursus/steps/specs/data_loading_validation_spec.py +45 -0
  126. cursus/steps/specs/dummy_training_spec.py +50 -0
  127. cursus/steps/specs/model_calibration_spec.py +61 -0
  128. cursus/steps/specs/model_eval_spec.py +61 -0
  129. cursus/steps/specs/packaging_spec.py +60 -0
  130. cursus/steps/specs/payload_spec.py +42 -0
  131. cursus/steps/specs/preprocessing_calibration_spec.py +36 -0
  132. cursus/steps/specs/preprocessing_spec.py +35 -0
  133. cursus/steps/specs/preprocessing_testing_spec.py +35 -0
  134. cursus/steps/specs/preprocessing_training_spec.py +42 -0
  135. cursus/steps/specs/preprocessing_validation_spec.py +35 -0
  136. cursus/steps/specs/pytorch_model_spec.py +36 -0
  137. cursus/steps/specs/pytorch_training_spec.py +51 -0
  138. cursus/steps/specs/registration_spec.py +50 -0
  139. cursus/steps/specs/risk_table_mapping_calibration_spec.py +72 -0
  140. cursus/steps/specs/risk_table_mapping_testing_spec.py +72 -0
  141. cursus/steps/specs/risk_table_mapping_training_spec.py +64 -0
  142. cursus/steps/specs/risk_table_mapping_validation_spec.py +72 -0
  143. cursus/steps/specs/xgboost_model_spec.py +36 -0
  144. cursus/steps/specs/xgboost_training_spec.py +59 -0
  145. cursus/validation/__init__.py +7 -0
  146. cursus-1.0.1.dist-info/METADATA +302 -0
  147. cursus-1.0.1.dist-info/RECORD +151 -0
  148. cursus-1.0.1.dist-info/WHEEL +5 -0
  149. cursus-1.0.1.dist-info/entry_points.txt +2 -0
  150. cursus-1.0.1.dist-info/licenses/LICENSE +21 -0
  151. cursus-1.0.1.dist-info/top_level.txt +1 -0
cursus/__init__.py ADDED
@@ -0,0 +1,120 @@
1
"""
Cursus: Automatic SageMaker Pipeline Generation

Transform pipeline graphs into production-ready SageMaker pipelines automatically.
An intelligent pipeline generation system that automatically creates complete SageMaker
pipelines from user-provided pipeline graphs with intelligent dependency resolution
and configuration management.

Key Features:
- 🎯 Graph-to-Pipeline Automation: Automatically generate complete SageMaker pipelines
- ⚡ 10x Faster Development: Minutes to working pipeline vs. weeks of manual configuration
- 🧠 Intelligent Dependency Resolution: Automatic step connections and data flow
- 🛡️ Production Ready: Built-in quality gates and validation
- 📈 Proven Results: 60% average code reduction across pipeline components

Basic Usage:
    >>> import cursus
    >>> pipeline = cursus.compile_dag(my_dag)

    >>> from cursus import PipelineDAGCompiler
    >>> compiler = PipelineDAGCompiler()
    >>> pipeline = compiler.compile(my_dag, pipeline_name="fraud-detection")

Advanced Usage:
    >>> from cursus.api.dag import PipelineDAG
    >>> from cursus.core.compiler import compile_dag_to_pipeline
    >>>
    >>> dag = PipelineDAG()
    >>> # ... build your DAG
    >>> pipeline = compile_dag_to_pipeline(dag, config_path="config.yaml")

If the compiler dependencies cannot be imported, ``import cursus`` still
succeeds: a warning is emitted and the compiler entry points are replaced by
stubs that raise ``ImportError`` when used.
"""
32
+
33
+ from .__version__ import __version__, __title__, __description__, __author__
34
+
35
+ # Core API exports - main user interface
36
try:
    from .core.compiler import compile_dag_to_pipeline, PipelineDAGCompiler
    from .core.compiler import compile_dag_to_pipeline as compile_dag
    from .core.compiler import DynamicPipelineTemplate
except ImportError as e:
    # Graceful degradation if dependencies are missing: keep the package
    # importable, warn once, and install stubs that fail loudly on use.
    import warnings
    warnings.warn(f"Some Cursus features may not be available: {e}")

    def compile_dag(*args, **kwargs):
        """Stub used when the compiler cannot be imported; always raises ImportError."""
        raise ImportError("Core Cursus dependencies not available. Please install with: pip install cursus[all]")

    def compile_dag_to_pipeline(*args, **kwargs):
        """Stub used when the compiler cannot be imported; always raises ImportError."""
        raise ImportError("Core Cursus dependencies not available. Please install with: pip install cursus[all]")

    class PipelineDAGCompiler:
        """Stub used when the compiler cannot be imported; instantiation raises ImportError."""

        def __init__(self, *args, **kwargs):
            raise ImportError("Core Cursus dependencies not available. Please install with: pip install cursus[all]")

    # DynamicPipelineTemplate is listed in the package's __all__, so it must
    # exist even in degraded mode; otherwise `from cursus import *` fails
    # with AttributeError exactly when dependencies are missing.
    class DynamicPipelineTemplate:
        """Stub used when the compiler cannot be imported; instantiation raises ImportError."""

        def __init__(self, *args, **kwargs):
            raise ImportError("Core Cursus dependencies not available. Please install with: pip install cursus[all]")
+
55
+ # Core data structures
56
try:
    from .api.dag import PipelineDAG, EnhancedPipelineDAG
except ImportError as exc:
    # Mirror the compiler-export fallback: surface the cause as a warning
    # instead of silently replacing the DAG classes with stubs.
    import warnings
    warnings.warn(f"Cursus DAG functionality may not be available: {exc}")

    class PipelineDAG:
        """Stub used when ``cursus.api.dag`` cannot be imported; instantiation raises ImportError."""

        def __init__(self, *args, **kwargs):
            raise ImportError("DAG functionality not available. Please install with: pip install cursus[all]")

    class EnhancedPipelineDAG:
        """Stub used when ``cursus.api.dag`` cannot be imported; instantiation raises ImportError."""

        def __init__(self, *args, **kwargs):
            raise ImportError("Enhanced DAG functionality not available. Please install with: pip install cursus[all]")
66
+
67
+ # Convenience function for quick pipeline creation
68
def create_pipeline_from_dag(dag, pipeline_name=None, **kwargs):
    """
    Build a SageMaker pipeline from a DAG in a single call.

    Thin convenience wrapper: ``dag`` is passed positionally and
    ``pipeline_name`` plus every extra keyword argument are forwarded
    unchanged to ``compile_dag_to_pipeline``.

    Args:
        dag: PipelineDAG instance or DAG specification
        pipeline_name: Optional name for the pipeline (forwarded as a keyword)
        **kwargs: Additional arguments passed to the compiler

    Returns:
        Whatever ``compile_dag_to_pipeline`` returns (a SageMaker Pipeline
        instance when the core dependencies are installed).

    Raises:
        ImportError: when the core Cursus dependencies could not be imported
            and ``compile_dag_to_pipeline`` is the fallback stub.

    Example:
        >>> dag = PipelineDAG()
        >>> # ... configure your DAG
        >>> pipeline = create_pipeline_from_dag(dag, "my-ml-pipeline")
        >>> pipeline.upsert(role_arn=role)  # register the definition before starting
        >>> execution = pipeline.start()
    """
    # pipeline_name is a named parameter here, so it can never also appear in
    # kwargs; listing it first keeps the original keyword order.
    forwarded_kwargs = {"pipeline_name": pipeline_name, **kwargs}
    return compile_dag_to_pipeline(dag, **forwarded_kwargs)
90
+
91
+ # Public API
92
__all__ = [
    # Version info
    "__version__", "__title__", "__description__", "__author__",
    # Main API functions
    "compile_dag", "compile_dag_to_pipeline", "create_pipeline_from_dag",
    # Core classes
    "PipelineDAGCompiler", "PipelineDAG", "EnhancedPipelineDAG", "DynamicPipelineTemplate",
]

# Static package metadata exposed for introspection (name/version/description/
# author come from cursus.__version__; license and python_requires are literals).
__package_info__ = dict(
    name=__title__,
    version=__version__,
    description=__description__,
    author=__author__,
    license="MIT",
    python_requires=">=3.8",
    keywords=["sagemaker", "pipeline", "dag", "machine-learning", "aws", "automation"],
)
cursus/__version__.py ADDED
@@ -0,0 +1,9 @@
1
"""Release metadata for the Cursus package: version, description, authorship, license and project URL."""

# Identity and release
__title__ = "cursus"
__version__ = "1.0.1"
__description__ = "Automatic SageMaker Pipeline Generation from DAG Specifications"

# Authorship and licensing
__author__ = "Tianpei Xie"
__author_email__ = "unidoctor@gmail.com"
__license__ = "MIT"
__url__ = "https://github.com/TianpeiLuke/cursus"
cursus/api/__init__.py ADDED
@@ -0,0 +1,30 @@
1
"""
Cursus API module.

This module re-exports the public DAG-building interfaces of Cursus: the
pipeline DAG classes and the typed edge models used to connect steps.
Pipeline compilation is provided by ``cursus.core.compiler`` (and re-exported
from the top-level ``cursus`` package), not by this module.
"""

# Import DAG classes for direct access
from .dag import (
    PipelineDAG,
    EnhancedPipelineDAG,
    EdgeType,
    DependencyEdge,
    ConditionalEdge,
    ParallelEdge,
    EdgeCollection
)

__all__ = [
    # DAG classes
    "PipelineDAG",
    "EnhancedPipelineDAG",

    # Edge types and management
    "EdgeType",
    "DependencyEdge",
    "ConditionalEdge",
    "ParallelEdge",
    "EdgeCollection",
]
cursus/api/dag/__init__.py ADDED
@@ -0,0 +1,29 @@
1
"""
DAG building blocks for Cursus pipelines.

Re-exports, from their defining submodules, the plain and enhanced pipeline
DAG classes together with the typed edge models and the edge collection.
"""

from .base_dag import PipelineDAG
from .edge_types import EdgeType, DependencyEdge, ConditionalEdge, ParallelEdge, EdgeCollection
from .enhanced_dag import EnhancedPipelineDAG

__all__ = [
    # DAG containers
    "PipelineDAG", "EnhancedPipelineDAG",
    # Edge models and the collection that indexes them
    "EdgeType", "DependencyEdge", "ConditionalEdge", "ParallelEdge", "EdgeCollection",
]
cursus/api/dag/base_dag.py ADDED
@@ -0,0 +1,74 @@
1
+ from typing import Dict, List, Any, Optional, Type, Set, Tuple
2
+ from collections import deque
3
+
4
+ import logging
5
+
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
class PipelineDAG:
    """
    Represents a pipeline topology as a directed acyclic graph (DAG).
    Each node is a step name; edges define dependencies.

    An edge ``(src, dst)`` means ``dst`` depends on ``src``: ``adj_list[src]``
    lists the children of ``src`` and ``reverse_adj[dst]`` lists the parents
    of ``dst``.
    """

    def __init__(self, nodes: Optional[List[str]] = None, edges: Optional[List[tuple]] = None):
        """
        Build a DAG from optional initial nodes and edges.

        Args:
            nodes: List of step names (str). The list is copied, so later
                ``add_node`` calls never mutate the caller's list, and duplicate
                names are dropped (first occurrence wins).
            edges: List of ``(from_step, to_step)`` pairs. Pairs are copied and
                normalized to tuples, duplicates are dropped, and any endpoint
                not already in ``nodes`` is registered as a node, matching the
                behaviour of ``add_edge``.
        """
        self.nodes: List[str] = []
        self.edges: List[Tuple[str, str]] = []
        self.adj_list: Dict[str, List[str]] = {}
        self.reverse_adj: Dict[str, List[str]] = {}

        for node in nodes or []:
            self._register_node(node)

        seen_edges = set()
        for src, dst in edges or []:
            # Normalize to a tuple so list-form edges compare equal in add_edge.
            edge = (src, dst)
            if edge in seen_edges:
                continue
            seen_edges.add(edge)
            self._register_node(src)
            self._register_node(dst)
            self.edges.append(edge)
            self.adj_list[src].append(dst)
            self.reverse_adj[dst].append(src)

    def _register_node(self, node: str) -> bool:
        """Add ``node`` to the bookkeeping structures without logging; return True if it was new."""
        if node in self.adj_list:
            return False
        self.nodes.append(node)
        self.adj_list[node] = []
        self.reverse_adj[node] = []
        return True

    def add_node(self, node: str) -> None:
        """Add a single node to the DAG (no-op if it already exists)."""
        if self._register_node(node):
            logger.info(f"Added node: {node}")

    def add_edge(self, src: str, dst: str) -> None:
        """Add a directed edge from src to dst, creating either node if missing."""
        # add_node is itself a no-op for existing nodes.
        self.add_node(src)
        self.add_node(dst)

        # Add the edge if it doesn't already exist
        edge = (src, dst)
        if edge not in self.edges:
            self.edges.append(edge)
            self.adj_list[src].append(dst)
            self.reverse_adj[dst].append(src)
            logger.info(f"Added edge: {src} -> {dst}")

    def get_dependencies(self, node: str) -> List[str]:
        """
        Return immediate dependencies (parents) of a node.

        A copy is returned so callers cannot corrupt the DAG's internal state;
        unknown nodes yield an empty list.
        """
        return list(self.reverse_adj.get(node, []))

    def topological_sort(self) -> List[str]:
        """
        Return nodes in topological order (Kahn's algorithm).

        Ready nodes are processed FIFO, seeded in ``self.nodes`` order.

        Raises:
            ValueError: if not every node can be ordered, which happens when
                the graph contains a cycle. Isolated nodes have in-degree 0 and
                are always ordered.
        """
        in_degree = {n: 0 for n in self.nodes}
        for _, dst in self.edges:
            in_degree[dst] += 1

        queue = deque(n for n in self.nodes if in_degree[n] == 0)
        order: List[str] = []
        while queue:
            node = queue.popleft()
            order.append(node)
            for neighbor in self.adj_list[node]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        if len(order) != len(self.nodes):
            # Nodes with remaining in-degree sit on (or downstream of) a cycle.
            unresolved = [n for n in self.nodes if in_degree[n] > 0]
            raise ValueError(f"DAG has cycles or disconnected nodes; unresolved nodes: {unresolved}")
        return order
cursus/api/dag/edge_types.py ADDED
@@ -0,0 +1,281 @@
1
+ """
2
+ Edge types for enhanced pipeline DAG with typed dependencies.
3
+
4
+ This module defines the various types of edges that can exist between
5
+ pipeline steps, including typed dependency edges with confidence scoring.
6
+ """
7
+
8
+ from pydantic import BaseModel, Field, field_validator, model_validator
9
+ from typing import Optional, Dict, Any
10
+ from enum import Enum
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class EdgeType(Enum):
    """
    Types of edges in the pipeline DAG.

    Members:
        DEPENDENCY: standard dependency edge.
        CONDITIONAL: conditional dependency.
        PARALLEL: parallel execution hint.
        SEQUENTIAL: sequential execution requirement.
    """

    DEPENDENCY = "dependency"
    CONDITIONAL = "conditional"
    PARALLEL = "parallel"
    SEQUENTIAL = "sequential"
22
+
23
+
24
class DependencyEdge(BaseModel):
    """
    Typed dependency edge linking an output port of one step to an input port
    of another, identified by logical port names.

    ``confidence`` defaults to 1.0 and is constrained to [0.0, 1.0]; any value
    below 1.0 is reported as auto-resolved by ``is_auto_resolved``.
    """

    # Producer side of the edge.
    source_step: str = Field(min_length=1, description="Name of the source step")
    # Consumer side of the edge.
    target_step: str = Field(min_length=1, description="Name of the target step")
    source_output: str = Field(min_length=1, description="Logical name of source output")
    target_input: str = Field(min_length=1, description="Logical name of target input")
    confidence: float = Field(
        ge=0.0,
        le=1.0,
        default=1.0,
        description="Confidence score for auto-resolved edges",
    )
    edge_type: EdgeType = Field(description="Type of edge", default=EdgeType.DEPENDENCY)
    metadata: Dict[str, Any] = Field(description="Additional metadata", default_factory=dict)

    # NOTE: per pydantic's use_enum_values semantics, an explicitly supplied
    # edge_type is stored as its plain string value, while the (unvalidated)
    # default remains an EdgeType member; consumers should accept both forms.
    model_config = dict(
        use_enum_values=True,
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    def to_property_reference_dict(self) -> Dict[str, Any]:
        """Return ``{"Get": "Steps.<source_step>.<source_output>"}`` for SageMaker."""
        reference_path = f"Steps.{self.source_step}.{self.source_output}"
        return {"Get": reference_path}

    def is_high_confidence(self, threshold: float = 0.8) -> bool:
        """Return True when ``confidence`` is at least ``threshold``."""
        return threshold <= self.confidence

    def is_auto_resolved(self) -> bool:
        """Return True when ``confidence`` is strictly below 1.0."""
        return 1.0 > self.confidence

    def __str__(self):
        source = f"{self.source_step}.{self.source_output}"
        target = f"{self.target_step}.{self.target_input}"
        return f"{source} -> {target}"

    def __repr__(self):
        source = f"{self.source_step}.{self.source_output}"
        target = f"{self.target_step}.{self.target_input}"
        return f"DependencyEdge(source='{source}', target='{target}', confidence={self.confidence:.3f})"
84
+
85
+
86
class ConditionalEdge(DependencyEdge):
    """
    Dependency edge guarded by a condition expression.

    An empty ``condition`` is accepted, but a warning is logged once the model
    has been validated.
    """

    condition: str = Field(description="Condition expression", default="")
    edge_type: EdgeType = Field(description="Type of edge", default=EdgeType.CONDITIONAL)

    @model_validator(mode='after')
    def validate_condition(self) -> 'ConditionalEdge':
        """Log a warning (without failing validation) when no condition is set."""
        if self.condition:
            return self
        logger.warning(f"ConditionalEdge {self} has no condition specified")
        return self
103
+
104
+
105
class ParallelEdge(DependencyEdge):
    """
    Dependency edge carrying a parallel-execution hint.

    ``max_parallel`` must be at least 1 when given and defaults to None.
    NOTE(review): nothing in this module reads ``max_parallel``; how it is
    honoured is up to downstream consumers. Confirm there.
    """

    max_parallel: Optional[int] = Field(
        ge=1,
        default=None,
        description="Maximum parallel executions",
    )
    edge_type: EdgeType = Field(description="Type of edge", default=EdgeType.PARALLEL)
116
+
117
+
118
class EdgeCollection:
    """
    Collection of edges with utility methods.

    Edges are keyed by the ID
    ``"<source_step>:<source_output>-><target_step>:<target_input>"``, so the
    collection holds at most one edge per (source port, target port) pair; when
    the same pair is added again, the higher-confidence edge is kept. Two
    secondary indices map a step name to the IDs of the edges leaving it
    (``_source_index``) and entering it (``_target_index``).
    """

    def __init__(self):
        self.edges: Dict[str, DependencyEdge] = {}
        self._source_index: Dict[str, list] = {}  # source_step -> list of edge_ids
        self._target_index: Dict[str, list] = {}  # target_step -> list of edge_ids

    @staticmethod
    def _make_edge_id(source_step: str, source_output: str,
                      target_step: str, target_input: str) -> str:
        """Build the canonical edge ID used as the key in ``self.edges``."""
        return f"{source_step}:{source_output}->{target_step}:{target_input}"

    @staticmethod
    def _discard_from_index(index: Dict[str, list], step_name: str, edge_id: str) -> None:
        """Remove ``edge_id`` from ``index[step_name]``, dropping the entry once empty."""
        edge_ids = index.get(step_name)
        if edge_ids is None:
            return
        if edge_id in edge_ids:
            edge_ids.remove(edge_id)
        if not edge_ids:
            del index[step_name]

    def add_edge(self, edge: DependencyEdge) -> str:
        """
        Add an edge to the collection.

        If an edge with the same ID already exists, the new edge replaces it
        only when its confidence is strictly higher; otherwise it is ignored.

        Args:
            edge: DependencyEdge to add

        Returns:
            Edge ID for the added edge
        """
        edge_id = self._make_edge_id(edge.source_step, edge.source_output,
                                     edge.target_step, edge.target_input)

        existing = self.edges.get(edge_id)
        if existing is not None:
            if edge.confidence <= existing.confidence:
                logger.debug(f"Ignoring duplicate edge {edge_id} with lower confidence")
                return edge_id
            logger.info(f"Replacing edge {edge_id} with higher confidence "
                        f"({existing.confidence:.3f} -> {edge.confidence:.3f})")
            # The ID encodes both step names, so both indices already list it.
            # Only the stored object changes; re-appending would duplicate index
            # entries and leave a dangling ID after remove_edge.
            self.edges[edge_id] = edge
            return edge_id

        self.edges[edge_id] = edge
        self._source_index.setdefault(edge.source_step, []).append(edge_id)
        self._target_index.setdefault(edge.target_step, []).append(edge_id)

        logger.debug(f"Added edge: {edge}")
        return edge_id

    def remove_edge(self, edge_id: str) -> bool:
        """Remove an edge by ID; return True if it was present, False otherwise."""
        edge = self.edges.pop(edge_id, None)
        if edge is None:
            return False

        self._discard_from_index(self._source_index, edge.source_step, edge_id)
        self._discard_from_index(self._target_index, edge.target_step, edge_id)

        logger.debug(f"Removed edge: {edge}")
        return True

    def get_edges_from_step(self, step_name: str) -> list:
        """Get all edges originating from a step, in insertion order."""
        edge_ids = self._source_index.get(step_name, [])
        return [self.edges[edge_id] for edge_id in edge_ids]

    def get_edges_to_step(self, step_name: str) -> list:
        """Get all edges targeting a step, in insertion order."""
        edge_ids = self._target_index.get(step_name, [])
        return [self.edges[edge_id] for edge_id in edge_ids]

    def get_edge(self, source_step: str, source_output: str,
                 target_step: str, target_input: str) -> Optional[DependencyEdge]:
        """Get a specific edge by its components, or None if absent."""
        return self.edges.get(self._make_edge_id(source_step, source_output,
                                                 target_step, target_input))

    def list_all_edges(self) -> list:
        """Get list of all edges."""
        return list(self.edges.values())

    def list_auto_resolved_edges(self) -> list:
        """Get list of automatically resolved edges (``is_auto_resolved()`` is True)."""
        return [edge for edge in self.edges.values() if edge.is_auto_resolved()]

    def list_high_confidence_edges(self, threshold: float = 0.8) -> list:
        """Get list of edges whose confidence is at least ``threshold``."""
        return [edge for edge in self.edges.values() if edge.is_high_confidence(threshold)]

    def list_low_confidence_edges(self, threshold: float = 0.6) -> list:
        """Get list of edges whose confidence is below ``threshold`` (may need review)."""
        return [edge for edge in self.edges.values() if edge.confidence < threshold]

    def get_step_dependencies(self, step_name: str) -> Dict[str, DependencyEdge]:
        """
        Get all dependencies for a step as ``{target_input: edge}``.

        If several edges feed the same input, the most recently added one wins.
        """
        edges = self.get_edges_to_step(step_name)
        return {edge.target_input: edge for edge in edges}

    def validate_edges(self) -> list:
        """Validate all edges and return a list of human-readable error strings."""
        errors = []

        for edge_id, edge in self.edges.items():
            # Check for self-dependencies
            if edge.source_step == edge.target_step:
                errors.append(f"Self-dependency detected: {edge_id}")

            # Check confidence bounds
            if not 0.0 <= edge.confidence <= 1.0:
                errors.append(f"Invalid confidence {edge.confidence} for edge {edge_id}")

            # Check for empty names
            if not all([edge.source_step, edge.target_step, edge.source_output, edge.target_input]):
                errors.append(f"Empty component in edge {edge_id}")

        return errors

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about the edge collection.

        The returned dict has the same keys whether or not the collection is
        empty; confidence figures are 0.0 for an empty collection.
        """
        edges = list(self.edges.values())

        if not edges:
            return {
                'total_edges': 0,
                'auto_resolved_edges': 0,
                'high_confidence_edges': 0,
                'low_confidence_edges': 0,
                'average_confidence': 0.0,
                'min_confidence': 0.0,
                'max_confidence': 0.0,
                'edge_types': {},
                'unique_source_steps': 0,
                'unique_target_steps': 0
            }

        confidences = [edge.confidence for edge in edges]
        edge_types: Dict[str, int] = {}
        for edge in edges:
            # DependencyEdge uses use_enum_values=True: an explicitly supplied
            # edge_type is stored as a plain str, while the default stays an
            # EdgeType member. Accept both instead of assuming `.value` exists.
            edge_type = getattr(edge.edge_type, "value", edge.edge_type)
            edge_types[edge_type] = edge_types.get(edge_type, 0) + 1

        return {
            'total_edges': len(edges),
            'auto_resolved_edges': len(self.list_auto_resolved_edges()),
            'high_confidence_edges': len(self.list_high_confidence_edges()),
            'low_confidence_edges': len(self.list_low_confidence_edges()),
            'average_confidence': sum(confidences) / len(confidences),
            'min_confidence': min(confidences),
            'max_confidence': max(confidences),
            'edge_types': edge_types,
            'unique_source_steps': len(self._source_index),
            'unique_target_steps': len(self._target_index)
        }

    def __len__(self):
        return len(self.edges)

    def __iter__(self):
        return iter(self.edges.values())

    def __contains__(self, edge_id: str):
        return edge_id in self.edges