cursus 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cursus/__init__.py +120 -0
- cursus/__version__.py +9 -0
- cursus/api/__init__.py +30 -0
- cursus/api/dag/__init__.py +29 -0
- cursus/api/dag/base_dag.py +74 -0
- cursus/api/dag/edge_types.py +281 -0
- cursus/api/dag/enhanced_dag.py +372 -0
- cursus/cli/__init__.py +416 -0
- cursus/core/__init__.py +163 -0
- cursus/core/assembler/__init__.py +14 -0
- cursus/core/assembler/pipeline_assembler.py +468 -0
- cursus/core/assembler/pipeline_template_base.py +420 -0
- cursus/core/base/__init__.py +38 -0
- cursus/core/base/builder_base.py +822 -0
- cursus/core/base/config_base.py +450 -0
- cursus/core/base/contract_base.py +303 -0
- cursus/core/base/enums.py +46 -0
- cursus/core/base/hyperparameters_base.py +338 -0
- cursus/core/base/specification_base.py +626 -0
- cursus/core/compiler/__init__.py +58 -0
- cursus/core/compiler/config_resolver.py +725 -0
- cursus/core/compiler/dag_compiler.py +529 -0
- cursus/core/compiler/dynamic_template.py +864 -0
- cursus/core/compiler/exceptions.py +104 -0
- cursus/core/compiler/name_generator.py +112 -0
- cursus/core/compiler/validation.py +339 -0
- cursus/core/config_fields/__init__.py +308 -0
- cursus/core/config_fields/circular_reference_tracker.py +212 -0
- cursus/core/config_fields/config_class_detector.py +209 -0
- cursus/core/config_fields/config_class_store.py +136 -0
- cursus/core/config_fields/config_field_categorizer.py +393 -0
- cursus/core/config_fields/config_merger.py +363 -0
- cursus/core/config_fields/constants.py +89 -0
- cursus/core/config_fields/type_aware_config_serializer.py +661 -0
- cursus/core/deps/__init__.py +52 -0
- cursus/core/deps/base_specifications.py +664 -0
- cursus/core/deps/dependency_resolver.py +362 -0
- cursus/core/deps/factory.py +50 -0
- cursus/core/deps/property_reference.py +244 -0
- cursus/core/deps/registry_manager.py +226 -0
- cursus/core/deps/semantic_matcher.py +268 -0
- cursus/core/deps/specification_registry.py +115 -0
- cursus/steps/__init__.py +36 -0
- cursus/steps/builders/__init__.py +53 -0
- cursus/steps/builders/builder_batch_transform_step.py +295 -0
- cursus/steps/builders/builder_currency_conversion_step.py +367 -0
- cursus/steps/builders/builder_data_load_step_cradle.py +597 -0
- cursus/steps/builders/builder_dummy_training_step.py +508 -0
- cursus/steps/builders/builder_model_calibration_step.py +504 -0
- cursus/steps/builders/builder_model_eval_step_xgboost.py +359 -0
- cursus/steps/builders/builder_model_step_pytorch.py +249 -0
- cursus/steps/builders/builder_model_step_xgboost.py +247 -0
- cursus/steps/builders/builder_package_step.py +394 -0
- cursus/steps/builders/builder_payload_step.py +360 -0
- cursus/steps/builders/builder_registration_step.py +387 -0
- cursus/steps/builders/builder_risk_table_mapping_step.py +506 -0
- cursus/steps/builders/builder_tabular_preprocessing_step.py +367 -0
- cursus/steps/builders/builder_training_step_pytorch.py +474 -0
- cursus/steps/builders/builder_training_step_xgboost.py +610 -0
- cursus/steps/builders/s3_utils.py +205 -0
- cursus/steps/configs/__init__.py +98 -0
- cursus/steps/configs/config_batch_transform_step.py +99 -0
- cursus/steps/configs/config_currency_conversion_step.py +194 -0
- cursus/steps/configs/config_data_load_step_cradle.py +889 -0
- cursus/steps/configs/config_dummy_training_step.py +165 -0
- cursus/steps/configs/config_model_calibration_step.py +347 -0
- cursus/steps/configs/config_model_eval_step_xgboost.py +223 -0
- cursus/steps/configs/config_model_step_pytorch.py +105 -0
- cursus/steps/configs/config_model_step_xgboost.py +164 -0
- cursus/steps/configs/config_package_step.py +113 -0
- cursus/steps/configs/config_payload_step.py +635 -0
- cursus/steps/configs/config_processing_step_base.py +363 -0
- cursus/steps/configs/config_registration_step.py +413 -0
- cursus/steps/configs/config_risk_table_mapping_step.py +110 -0
- cursus/steps/configs/config_tabular_preprocessing_step.py +231 -0
- cursus/steps/configs/config_training_step_pytorch.py +83 -0
- cursus/steps/configs/config_training_step_xgboost.py +213 -0
- cursus/steps/configs/utils.py +466 -0
- cursus/steps/contracts/__init__.py +61 -0
- cursus/steps/contracts/contract_validator.py +262 -0
- cursus/steps/contracts/cradle_data_loading_contract.py +65 -0
- cursus/steps/contracts/currency_conversion_contract.py +77 -0
- cursus/steps/contracts/dummy_training_contract.py +31 -0
- cursus/steps/contracts/mims_package_contract.py +47 -0
- cursus/steps/contracts/mims_payload_contract.py +62 -0
- cursus/steps/contracts/mims_registration_contract.py +63 -0
- cursus/steps/contracts/model_calibration_contract.py +68 -0
- cursus/steps/contracts/model_evaluation_contract.py +67 -0
- cursus/steps/contracts/pytorch_train_contract.py +97 -0
- cursus/steps/contracts/risk_table_mapping_contract.py +75 -0
- cursus/steps/contracts/tabular_preprocess_contract.py +56 -0
- cursus/steps/contracts/training_script_contract.py +153 -0
- cursus/steps/contracts/xgboost_train_contract.py +105 -0
- cursus/steps/hyperparams/__init__.py +17 -0
- cursus/steps/hyperparams/hyperparameters_bsm.py +256 -0
- cursus/steps/hyperparams/hyperparameters_xgboost.py +194 -0
- cursus/steps/registry/__init__.py +42 -0
- cursus/steps/registry/builder_registry.py +610 -0
- cursus/steps/registry/exceptions.py +26 -0
- cursus/steps/registry/hyperparameter_registry.py +59 -0
- cursus/steps/registry/step_names.py +201 -0
- cursus/steps/scripts/__init__.py +38 -0
- cursus/steps/scripts/contract_utils.py +349 -0
- cursus/steps/scripts/currency_conversion.py +278 -0
- cursus/steps/scripts/dummy_training.py +259 -0
- cursus/steps/scripts/mims_package.py +269 -0
- cursus/steps/scripts/mims_payload.py +492 -0
- cursus/steps/scripts/model_calibration.py +939 -0
- cursus/steps/scripts/model_evaluation_xgb.py +382 -0
- cursus/steps/scripts/risk_table_mapping.py +462 -0
- cursus/steps/scripts/tabular_preprocess.py +183 -0
- cursus/steps/specs/__init__.py +109 -0
- cursus/steps/specs/batch_transform_calibration_spec.py +44 -0
- cursus/steps/specs/batch_transform_testing_spec.py +44 -0
- cursus/steps/specs/batch_transform_training_spec.py +44 -0
- cursus/steps/specs/batch_transform_validation_spec.py +44 -0
- cursus/steps/specs/currency_conversion_calibration_spec.py +41 -0
- cursus/steps/specs/currency_conversion_testing_spec.py +41 -0
- cursus/steps/specs/currency_conversion_training_spec.py +41 -0
- cursus/steps/specs/currency_conversion_validation_spec.py +41 -0
- cursus/steps/specs/data_loading_calibration_spec.py +45 -0
- cursus/steps/specs/data_loading_spec.py +52 -0
- cursus/steps/specs/data_loading_testing_spec.py +45 -0
- cursus/steps/specs/data_loading_training_spec.py +45 -0
- cursus/steps/specs/data_loading_validation_spec.py +45 -0
- cursus/steps/specs/dummy_training_spec.py +50 -0
- cursus/steps/specs/model_calibration_spec.py +61 -0
- cursus/steps/specs/model_eval_spec.py +61 -0
- cursus/steps/specs/packaging_spec.py +60 -0
- cursus/steps/specs/payload_spec.py +42 -0
- cursus/steps/specs/preprocessing_calibration_spec.py +36 -0
- cursus/steps/specs/preprocessing_spec.py +35 -0
- cursus/steps/specs/preprocessing_testing_spec.py +35 -0
- cursus/steps/specs/preprocessing_training_spec.py +42 -0
- cursus/steps/specs/preprocessing_validation_spec.py +35 -0
- cursus/steps/specs/pytorch_model_spec.py +36 -0
- cursus/steps/specs/pytorch_training_spec.py +51 -0
- cursus/steps/specs/registration_spec.py +50 -0
- cursus/steps/specs/risk_table_mapping_calibration_spec.py +72 -0
- cursus/steps/specs/risk_table_mapping_testing_spec.py +72 -0
- cursus/steps/specs/risk_table_mapping_training_spec.py +64 -0
- cursus/steps/specs/risk_table_mapping_validation_spec.py +72 -0
- cursus/steps/specs/xgboost_model_spec.py +36 -0
- cursus/steps/specs/xgboost_training_spec.py +59 -0
- cursus/validation/__init__.py +7 -0
- cursus-1.0.1.dist-info/METADATA +302 -0
- cursus-1.0.1.dist-info/RECORD +151 -0
- cursus-1.0.1.dist-info/WHEEL +5 -0
- cursus-1.0.1.dist-info/entry_points.txt +2 -0
- cursus-1.0.1.dist-info/licenses/LICENSE +21 -0
- cursus-1.0.1.dist-info/top_level.txt +1 -0
cursus/__init__.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cursus: Automatic SageMaker Pipeline Generation
|
|
3
|
+
|
|
4
|
+
Transform pipeline graphs into production-ready SageMaker pipelines automatically.
|
|
5
|
+
An intelligent pipeline generation system that automatically creates complete SageMaker
|
|
6
|
+
pipelines from user-provided pipeline graphs with intelligent dependency resolution
|
|
7
|
+
and configuration management.
|
|
8
|
+
|
|
9
|
+
Key Features:
|
|
10
|
+
- 🎯 Graph-to-Pipeline Automation: Automatically generate complete SageMaker pipelines
|
|
11
|
+
- ⚡ 10x Faster Development: Minutes to working pipeline vs. weeks of manual configuration
|
|
12
|
+
- 🧠 Intelligent Dependency Resolution: Automatic step connections and data flow
|
|
13
|
+
- 🛡️ Production Ready: Built-in quality gates and validation
|
|
14
|
+
- 📈 Proven Results: 60% average code reduction across pipeline components
|
|
15
|
+
|
|
16
|
+
Basic Usage:
|
|
17
|
+
>>> import cursus
|
|
18
|
+
>>> pipeline = cursus.compile_dag(my_dag)
|
|
19
|
+
|
|
20
|
+
>>> from cursus import PipelineDAGCompiler
|
|
21
|
+
>>> compiler = PipelineDAGCompiler()
|
|
22
|
+
>>> pipeline = compiler.compile(my_dag, pipeline_name="fraud-detection")
|
|
23
|
+
|
|
24
|
+
Advanced Usage:
|
|
25
|
+
>>> from cursus.api.dag import PipelineDAG
|
|
26
|
+
>>> from cursus import compile_dag_to_pipeline
|
|
27
|
+
>>>
|
|
28
|
+
>>> dag = PipelineDAG()
|
|
29
|
+
>>> # ... build your DAG
|
|
30
|
+
>>> pipeline = compile_dag_to_pipeline(dag, config_path="config.yaml")
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from .__version__ import __version__, __title__, __description__, __author__
|
|
34
|
+
|
|
35
|
+
# Core API exports - main user interface.
#
# Both import groups below degrade gracefully: if the import fails, each public
# name is bound to a stub that raises ImportError when it is called or
# instantiated. That keeps ``import cursus`` working and guarantees that every
# name listed in ``__all__`` exists, so ``from cursus import *`` cannot fail
# with AttributeError.
try:
    from .core.compiler import compile_dag_to_pipeline, PipelineDAGCompiler
    from .core.compiler import compile_dag_to_pipeline as compile_dag
    from .core.compiler import DynamicPipelineTemplate
except ImportError as e:
    # Graceful degradation if dependencies are missing
    import warnings
    warnings.warn(f"Some Cursus features may not be available: {e}")

    def compile_dag(*args, **kwargs):
        """Stub used when the compiler could not be imported; always raises ImportError."""
        raise ImportError("Core Cursus dependencies not available. Please install with: pip install cursus[all]")

    def compile_dag_to_pipeline(*args, **kwargs):
        """Stub used when the compiler could not be imported; always raises ImportError."""
        raise ImportError("Core Cursus dependencies not available. Please install with: pip install cursus[all]")

    class PipelineDAGCompiler:
        """Stub used when the compiler could not be imported; instantiation raises ImportError."""

        def __init__(self, *args, **kwargs):
            raise ImportError("Core Cursus dependencies not available. Please install with: pip install cursus[all]")

    class DynamicPipelineTemplate:
        """Stub used when the compiler could not be imported; instantiation raises ImportError.

        Defined so that the ``DynamicPipelineTemplate`` entry in ``__all__``
        still resolves in the degraded configuration.
        """

        def __init__(self, *args, **kwargs):
            raise ImportError("Core Cursus dependencies not available. Please install with: pip install cursus[all]")

# Core data structures
try:
    from .api.dag import PipelineDAG, EnhancedPipelineDAG
except ImportError:
    class PipelineDAG:
        """Stub used when the DAG module could not be imported; instantiation raises ImportError."""

        def __init__(self, *args, **kwargs):
            raise ImportError("DAG functionality not available. Please install with: pip install cursus[all]")

    class EnhancedPipelineDAG:
        """Stub used when the DAG module could not be imported; instantiation raises ImportError."""

        def __init__(self, *args, **kwargs):
            raise ImportError("Enhanced DAG functionality not available. Please install with: pip install cursus[all]")
|
|
66
|
+
|
|
67
|
+
# Convenience function for quick pipeline creation
|
|
68
|
+
def create_pipeline_from_dag(dag, pipeline_name=None, **kwargs):
    """
    Compile a DAG into a pipeline in a single call.

    Thin convenience wrapper: every argument is forwarded unchanged to
    ``compile_dag_to_pipeline``.

    Args:
        dag: DAG object to compile (a PipelineDAG instance or DAG specification).
        pipeline_name: Optional pipeline name, forwarded as the
            ``pipeline_name`` keyword argument.
        **kwargs: Additional keyword arguments passed through to the compiler.

    Returns:
        Whatever ``compile_dag_to_pipeline`` returns; intended to be a
        SageMaker Pipeline instance ready for execution.

    Example:
        >>> dag = PipelineDAG()
        >>> # ... configure your DAG
        >>> pipeline = create_pipeline_from_dag(dag, "my-ml-pipeline")
        >>> pipeline.start()
    """
    compiler_kwargs = {"pipeline_name": pipeline_name, **kwargs}
    return compile_dag_to_pipeline(dag, **compiler_kwargs)
|
|
90
|
+
|
|
91
|
+
# Public API
|
|
92
|
+
__all__ = [
|
|
93
|
+
# Version info
|
|
94
|
+
"__version__",
|
|
95
|
+
"__title__",
|
|
96
|
+
"__description__",
|
|
97
|
+
"__author__",
|
|
98
|
+
|
|
99
|
+
# Main API functions
|
|
100
|
+
"compile_dag",
|
|
101
|
+
"compile_dag_to_pipeline",
|
|
102
|
+
"create_pipeline_from_dag",
|
|
103
|
+
|
|
104
|
+
# Core classes
|
|
105
|
+
"PipelineDAGCompiler",
|
|
106
|
+
"PipelineDAG",
|
|
107
|
+
"EnhancedPipelineDAG",
|
|
108
|
+
"DynamicPipelineTemplate",
|
|
109
|
+
]
|
|
110
|
+
|
|
111
|
+
# Package metadata for introspection
|
|
112
|
+
# Static package metadata, exposed as a plain dict for runtime introspection.
# The first four values come from the names imported from ``__version__``.
__package_info__ = dict(
    name=__title__,
    version=__version__,
    description=__description__,
    author=__author__,
    license="MIT",
    python_requires=">=3.8",
    keywords=[
        "sagemaker",
        "pipeline",
        "dag",
        "machine-learning",
        "aws",
        "automation",
    ],
)
|
cursus/__version__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Version information for Cursus."""
|
|
2
|
+
|
|
3
|
+
# Distribution identity.
__title__ = "cursus"
__version__ = "1.0.1"
__description__ = "Automatic SageMaker Pipeline Generation from DAG Specifications"

# Authorship.
__author__ = "Tianpei Xie"
__author_email__ = "unidoctor@gmail.com"

# Licensing and project home page.
__license__ = "MIT"
__url__ = "https://github.com/TianpeiLuke/cursus"
|
cursus/api/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cursus API module.
|
|
3
|
+
|
|
4
|
+
This module provides the main API interfaces for Cursus functionality,
|
|
5
|
+
including DAG management and pipeline compilation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# Import DAG classes for direct access
|
|
9
|
+
from .dag import (
|
|
10
|
+
PipelineDAG,
|
|
11
|
+
EnhancedPipelineDAG,
|
|
12
|
+
EdgeType,
|
|
13
|
+
DependencyEdge,
|
|
14
|
+
ConditionalEdge,
|
|
15
|
+
ParallelEdge,
|
|
16
|
+
EdgeCollection
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
# DAG classes
|
|
21
|
+
"PipelineDAG",
|
|
22
|
+
"EnhancedPipelineDAG",
|
|
23
|
+
|
|
24
|
+
# Edge types and management
|
|
25
|
+
"EdgeType",
|
|
26
|
+
"DependencyEdge",
|
|
27
|
+
"ConditionalEdge",
|
|
28
|
+
"ParallelEdge",
|
|
29
|
+
"EdgeCollection",
|
|
30
|
+
]
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pipeline DAG API module.
|
|
3
|
+
|
|
4
|
+
This module provides the core DAG classes for building and managing
|
|
5
|
+
pipeline topologies with intelligent dependency resolution.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .base_dag import PipelineDAG
|
|
9
|
+
from .edge_types import (
|
|
10
|
+
EdgeType,
|
|
11
|
+
DependencyEdge,
|
|
12
|
+
ConditionalEdge,
|
|
13
|
+
ParallelEdge,
|
|
14
|
+
EdgeCollection
|
|
15
|
+
)
|
|
16
|
+
from .enhanced_dag import EnhancedPipelineDAG
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
# Core DAG classes
|
|
20
|
+
"PipelineDAG",
|
|
21
|
+
"EnhancedPipelineDAG",
|
|
22
|
+
|
|
23
|
+
# Edge types and management
|
|
24
|
+
"EdgeType",
|
|
25
|
+
"DependencyEdge",
|
|
26
|
+
"ConditionalEdge",
|
|
27
|
+
"ParallelEdge",
|
|
28
|
+
"EdgeCollection",
|
|
29
|
+
]
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from typing import Dict, List, Any, Optional, Type, Set, Tuple
|
|
2
|
+
from collections import deque
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
class PipelineDAG:
    """
    Represents a pipeline topology as a directed acyclic graph (DAG).

    Each node is a step name; an edge ``(src, dst)`` means ``dst`` depends on
    ``src``.

    Attributes:
        nodes: Step names in insertion order, without duplicates.
        edges: ``(src, dst)`` pairs in insertion order.
        adj_list: Maps each node to the nodes it points to (children).
        reverse_adj: Maps each node to the nodes pointing to it (parents).
    """

    def __init__(self, nodes: Optional[List[str]] = None, edges: Optional[List[tuple]] = None):
        """
        Build the DAG from optional initial nodes and edges.

        Args:
            nodes: List of step names (str). Duplicates are collapsed, keeping
                the first occurrence.
            edges: List of (from_step, to_step) tuples. Endpoints that are not
                listed in ``nodes`` are registered automatically, mirroring
                what ``add_edge`` does.
        """
        # Work on copies so add_node/add_edge never mutate the caller's lists.
        # dict.fromkeys de-duplicates while preserving order.
        self.nodes = list(dict.fromkeys(nodes)) if nodes else []
        self.edges = list(edges) if edges else []
        self.adj_list = {n: [] for n in self.nodes}
        self.reverse_adj = {n: [] for n in self.nodes}

        for src, dst in self.edges:
            # Register endpoints missing from `nodes` instead of failing with
            # KeyError on the adjacency lookups below.
            for endpoint in (src, dst):
                if endpoint not in self.adj_list:
                    self.nodes.append(endpoint)
                    self.adj_list[endpoint] = []
                    self.reverse_adj[endpoint] = []
            self.adj_list[src].append(dst)
            self.reverse_adj[dst].append(src)

    def add_node(self, node: str) -> None:
        """Add a single node to the DAG; a node that already exists is left untouched."""
        if node not in self.nodes:
            self.nodes.append(node)
            self.adj_list[node] = []
            self.reverse_adj[node] = []
            logger.info(f"Added node: {node}")

    def add_edge(self, src: str, dst: str) -> None:
        """
        Add a directed edge from src to dst.

        Missing endpoints are created first. An edge that is already present
        is not added again.
        """
        # Ensure both nodes exist
        if src not in self.nodes:
            self.add_node(src)
        if dst not in self.nodes:
            self.add_node(dst)

        # Add the edge if it doesn't already exist
        edge = (src, dst)
        if edge not in self.edges:
            self.edges.append(edge)
            self.adj_list[src].append(dst)
            self.reverse_adj[dst].append(src)
            logger.info(f"Added edge: {src} -> {dst}")

    def get_dependencies(self, node: str) -> List[str]:
        """
        Return immediate dependencies (parents) of a node.

        Returns an empty list for an unknown node. For a known node the
        returned list is the DAG's internal one, so callers should not
        mutate it.
        """
        return self.reverse_adj.get(node, [])

    def topological_sort(self) -> List[str]:
        """
        Return nodes in topological order (Kahn's algorithm).

        Nodes without incoming edges are emitted first, in insertion order.

        Raises:
            ValueError: If not every node could be ordered, which happens
                when the graph contains a cycle.
        """
        in_degree = {n: 0 for n in self.nodes}
        for _, dst in self.edges:
            in_degree[dst] += 1

        queue = deque(n for n in self.nodes if in_degree[n] == 0)
        order = []
        while queue:
            node = queue.popleft()
            order.append(node)
            for neighbor in self.adj_list[node]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        # Nodes on a cycle never reach in-degree 0 and are therefore missing
        # from `order`. Isolated nodes start at in-degree 0 and are always
        # included. The message text is kept unchanged for existing callers.
        if len(order) != len(self.nodes):
            raise ValueError("DAG has cycles or disconnected nodes")
        return order
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Edge types for enhanced pipeline DAG with typed dependencies.
|
|
3
|
+
|
|
4
|
+
This module defines the various types of edges that can exist between
|
|
5
|
+
pipeline steps, including typed dependency edges with confidence scoring.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
9
|
+
from typing import Optional, Dict, Any
|
|
10
|
+
from enum import Enum
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class EdgeType(Enum):
    """Enumerates the kinds of edges a pipeline DAG can contain."""

    # Standard dependency edge
    DEPENDENCY = "dependency"
    # Conditional dependency
    CONDITIONAL = "conditional"
    # Parallel execution hint
    PARALLEL = "parallel"
    # Sequential execution requirement
    SEQUENTIAL = "sequential"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class DependencyEdge(BaseModel):
    """
    A typed connection from one step's output to another step's input.

    The four name fields must be non-empty strings. ``confidence`` is limited
    to the closed range [0.0, 1.0] and defaults to 1.0.
    """

    source_step: str = Field(min_length=1, description="Name of the source step")
    target_step: str = Field(min_length=1, description="Name of the target step")
    source_output: str = Field(min_length=1, description="Logical name of source output")
    target_input: str = Field(min_length=1, description="Logical name of target input")
    confidence: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Confidence score for auto-resolved edges",
    )
    edge_type: EdgeType = Field(default=EdgeType.DEPENDENCY, description="Type of edge")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    # Model configuration: validate on attribute assignment as well as on
    # construction, and store enum fields by value.
    model_config = dict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        use_enum_values=True,
    )

    def to_property_reference_dict(self) -> Dict[str, Any]:
        """Return a ``{"Get": "Steps.<source_step>.<source_output>"}`` reference dict for SageMaker."""
        reference = {
            "Get": f"Steps.{self.source_step}.{self.source_output}"
        }
        return reference

    def is_high_confidence(self, threshold: float = 0.8) -> bool:
        """Return True when the confidence is at or above ``threshold``."""
        return self.confidence >= threshold

    def is_auto_resolved(self) -> bool:
        """Return True when the confidence is below 1.0, the marker used here for auto-resolved edges."""
        return self.confidence < 1.0

    def __str__(self):
        """Render the edge as ``source_step.source_output -> target_step.target_input``."""
        rendered = f"{self.source_step}.{self.source_output} -> {self.target_step}.{self.target_input}"
        return rendered

    def __repr__(self):
        """Render both endpoints plus the confidence with three decimals."""
        rendered = (f"DependencyEdge(source='{self.source_step}.{self.source_output}', "
                    f"target='{self.target_step}.{self.target_input}', "
                    f"confidence={self.confidence:.3f})")
        return rendered
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ConditionalEdge(DependencyEdge):
    """
    Dependency edge that carries a condition expression.

    ``edge_type`` defaults to ``EdgeType.CONDITIONAL``. An empty condition is
    accepted but reported through a logged warning.
    """

    condition: str = Field(default="", description="Condition expression")
    edge_type: EdgeType = Field(default=EdgeType.CONDITIONAL, description="Type of edge")

    @model_validator(mode='after')
    def validate_condition(self) -> 'ConditionalEdge':
        """Log a warning when no condition was supplied; the model is returned unchanged."""
        if self.condition:
            return self
        logger.warning(f"ConditionalEdge {self} has no condition specified")
        return self
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class ParallelEdge(DependencyEdge):
    """
    Dependency edge used as a parallel execution hint.

    ``edge_type`` defaults to ``EdgeType.PARALLEL``. ``max_parallel`` is
    optional and must be at least 1 when given.
    """

    max_parallel: Optional[int] = Field(
        default=None,
        ge=1,
        description="Maximum parallel executions",
    )
    edge_type: EdgeType = Field(default=EdgeType.PARALLEL, description="Type of edge")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class EdgeCollection:
    """
    Collection of edges with utility methods.

    Edges are stored under an id of the form
    ``source_step:source_output->target_step:target_input``. Two indices map
    step names to the ids of edges leaving or entering that step.
    """

    def __init__(self):
        self.edges: Dict[str, DependencyEdge] = {}
        self._source_index: Dict[str, list] = {}  # source_step -> list of edge_ids
        self._target_index: Dict[str, list] = {}  # target_step -> list of edge_ids

    def _index_edge(self, edge_id: str, edge: DependencyEdge) -> None:
        """Record ``edge_id`` under the edge's source and target steps."""
        self._source_index.setdefault(edge.source_step, []).append(edge_id)
        self._target_index.setdefault(edge.target_step, []).append(edge_id)

    def _unindex_edge(self, edge_id: str, edge: DependencyEdge) -> None:
        """Drop ``edge_id`` from both indices and remove step entries that become empty."""
        for index, step in (
            (self._source_index, edge.source_step),
            (self._target_index, edge.target_step),
        ):
            edge_ids = index.get(step)
            if edge_ids is None:
                continue
            if edge_id in edge_ids:
                edge_ids.remove(edge_id)
            if not edge_ids:
                del index[step]

    def add_edge(self, edge: DependencyEdge) -> str:
        """
        Add an edge to the collection.

        If an edge with the same id already exists, it is replaced only when
        the new edge has strictly higher confidence; otherwise the new edge
        is ignored.

        Args:
            edge: DependencyEdge to add

        Returns:
            Edge ID for the added (or already present) edge
        """
        edge_id = f"{edge.source_step}:{edge.source_output}->{edge.target_step}:{edge.target_input}"

        # Check for duplicate edges
        if edge_id in self.edges:
            existing = self.edges[edge_id]
            if existing.confidence < edge.confidence:
                # Replace with higher confidence edge
                logger.info(f"Replacing edge {edge_id} with higher confidence "
                            f"({existing.confidence:.3f} -> {edge.confidence:.3f})")
                # Remove the old index entries first. Otherwise the id would
                # be listed twice, lookups would return the edge twice, and
                # remove_edge would leave a dangling id behind.
                self._unindex_edge(edge_id, existing)
            else:
                logger.debug(f"Ignoring duplicate edge {edge_id} with lower confidence")
                return edge_id

        self.edges[edge_id] = edge
        self._index_edge(edge_id, edge)

        logger.debug(f"Added edge: {edge}")
        return edge_id

    def remove_edge(self, edge_id: str) -> bool:
        """Remove an edge from the collection; return False if the id is unknown."""
        edge = self.edges.pop(edge_id, None)
        if edge is None:
            return False

        self._unindex_edge(edge_id, edge)

        logger.debug(f"Removed edge: {edge}")
        return True

    def get_edges_from_step(self, step_name: str) -> list:
        """Get all edges originating from a step."""
        edge_ids = self._source_index.get(step_name, [])
        return [self.edges[edge_id] for edge_id in edge_ids]

    def get_edges_to_step(self, step_name: str) -> list:
        """Get all edges targeting a step."""
        edge_ids = self._target_index.get(step_name, [])
        return [self.edges[edge_id] for edge_id in edge_ids]

    def get_edge(self, source_step: str, source_output: str,
                 target_step: str, target_input: str) -> Optional[DependencyEdge]:
        """Get a specific edge by its components, or None if it is not stored."""
        edge_id = f"{source_step}:{source_output}->{target_step}:{target_input}"
        return self.edges.get(edge_id)

    def list_all_edges(self) -> list:
        """Get list of all edges."""
        return list(self.edges.values())

    def list_auto_resolved_edges(self) -> list:
        """Get list of automatically resolved edges."""
        return [edge for edge in self.edges.values() if edge.is_auto_resolved()]

    def list_high_confidence_edges(self, threshold: float = 0.8) -> list:
        """Get list of edges whose confidence is at or above ``threshold``."""
        return [edge for edge in self.edges.values() if edge.is_high_confidence(threshold)]

    def list_low_confidence_edges(self, threshold: float = 0.6) -> list:
        """Get list of edges whose confidence is below ``threshold`` and may need review."""
        return [edge for edge in self.edges.values() if edge.confidence < threshold]

    def get_step_dependencies(self, step_name: str) -> Dict[str, DependencyEdge]:
        """
        Get all dependencies for a step as a dictionary keyed by target input.

        If several edges feed the same target input, the last one indexed wins.
        """
        edges = self.get_edges_to_step(step_name)
        return {edge.target_input: edge for edge in edges}

    def validate_edges(self) -> list:
        """Validate all edges and return a list of error messages (empty when all are valid)."""
        errors = []

        for edge_id, edge in self.edges.items():
            # Check for self-dependencies
            if edge.source_step == edge.target_step:
                errors.append(f"Self-dependency detected: {edge_id}")

            # Check confidence bounds
            if not 0.0 <= edge.confidence <= 1.0:
                errors.append(f"Invalid confidence {edge.confidence} for edge {edge_id}")

            # Check for empty names
            if not all((edge.source_step, edge.target_step, edge.source_output, edge.target_input)):
                errors.append(f"Empty component in edge {edge_id}")

        return errors

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about the edge collection.

        Returns:
            Dict with edge counts, the average confidence and a per-type count
            under ``edge_types``. For a non-empty collection it also contains
            the min/max confidence and the number of distinct source and
            target steps.
        """
        edges = list(self.edges.values())

        if not edges:
            return {
                'total_edges': 0,
                'auto_resolved_edges': 0,
                'high_confidence_edges': 0,
                'low_confidence_edges': 0,
                'average_confidence': 0.0,
                'edge_types': {}
            }

        confidences = [edge.confidence for edge in edges]
        edge_types = {}
        for edge in edges:
            # DependencyEdge sets use_enum_values, so an explicitly supplied
            # edge type is stored as its plain string value, while an
            # unvalidated default stays an Enum member. Accept both rather
            # than assuming `.value` exists.
            raw_type = edge.edge_type
            edge_type = raw_type.value if isinstance(raw_type, Enum) else raw_type
            edge_types[edge_type] = edge_types.get(edge_type, 0) + 1

        return {
            'total_edges': len(edges),
            'auto_resolved_edges': len(self.list_auto_resolved_edges()),
            'high_confidence_edges': len(self.list_high_confidence_edges()),
            'low_confidence_edges': len(self.list_low_confidence_edges()),
            'average_confidence': sum(confidences) / len(confidences),
            'min_confidence': min(confidences),
            'max_confidence': max(confidences),
            'edge_types': edge_types,
            'unique_source_steps': len(self._source_index),
            'unique_target_steps': len(self._target_index)
        }

    def __len__(self):
        return len(self.edges)

    def __iter__(self):
        return iter(self.edges.values())

    def __contains__(self, edge_id: str):
        return edge_id in self.edges
|