kite-agent 0.1.0 (kite_agent-0.1.0-py3-none-any.whl)
- kite/__init__.py +46 -0
- kite/ab_testing.py +384 -0
- kite/agent.py +556 -0
- kite/agents/__init__.py +3 -0
- kite/agents/plan_execute.py +191 -0
- kite/agents/react_agent.py +509 -0
- kite/agents/reflective_agent.py +90 -0
- kite/agents/rewoo.py +119 -0
- kite/agents/tot.py +151 -0
- kite/conversation.py +125 -0
- kite/core.py +974 -0
- kite/data_loaders.py +111 -0
- kite/embedding_providers.py +372 -0
- kite/llm_providers.py +1278 -0
- kite/memory/__init__.py +6 -0
- kite/memory/advanced_rag.py +333 -0
- kite/memory/graph_rag.py +719 -0
- kite/memory/session_memory.py +423 -0
- kite/memory/vector_memory.py +579 -0
- kite/monitoring.py +611 -0
- kite/observers.py +107 -0
- kite/optimization/__init__.py +9 -0
- kite/optimization/resource_router.py +80 -0
- kite/persistence.py +42 -0
- kite/pipeline/__init__.py +5 -0
- kite/pipeline/deterministic_pipeline.py +323 -0
- kite/pipeline/reactive_pipeline.py +171 -0
- kite/pipeline_manager.py +15 -0
- kite/routing/__init__.py +6 -0
- kite/routing/aggregator_router.py +325 -0
- kite/routing/llm_router.py +149 -0
- kite/routing/semantic_router.py +228 -0
- kite/safety/__init__.py +6 -0
- kite/safety/circuit_breaker.py +360 -0
- kite/safety/guardrails.py +82 -0
- kite/safety/idempotency_manager.py +304 -0
- kite/safety/kill_switch.py +75 -0
- kite/tool.py +183 -0
- kite/tool_registry.py +87 -0
- kite/tools/__init__.py +21 -0
- kite/tools/code_execution.py +53 -0
- kite/tools/contrib/__init__.py +19 -0
- kite/tools/contrib/calculator.py +26 -0
- kite/tools/contrib/datetime_utils.py +20 -0
- kite/tools/contrib/linkedin.py +428 -0
- kite/tools/contrib/web_search.py +30 -0
- kite/tools/mcp/__init__.py +31 -0
- kite/tools/mcp/database_mcp.py +267 -0
- kite/tools/mcp/gdrive_mcp_server.py +503 -0
- kite/tools/mcp/gmail_mcp_server.py +601 -0
- kite/tools/mcp/postgres_mcp_server.py +490 -0
- kite/tools/mcp/slack_mcp_server.py +538 -0
- kite/tools/mcp/stripe_mcp_server.py +219 -0
- kite/tools/search.py +90 -0
- kite/tools/system_tools.py +54 -0
- kite/tools_manager.py +27 -0
- kite_agent-0.1.0.dist-info/METADATA +621 -0
- kite_agent-0.1.0.dist-info/RECORD +61 -0
- kite_agent-0.1.0.dist-info/WHEEL +5 -0
- kite_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
- kite_agent-0.1.0.dist-info/top_level.txt +1 -0
kite/observers.py
ADDED
@@ -0,0 +1,107 @@
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

class EventFileLogger:
    """Standard JSON file logger for EventBus events."""
    def __init__(self, trace_file: str):
        if not os.path.isabs(trace_file):
            trace_file = os.path.join(os.getcwd(), trace_file)
        self.trace_file = trace_file
        # Initialize file as an empty list
        with open(self.trace_file, "w") as f:
            f.write("[\n]")

    def on_event(self, event: str, data: Any):
        """Append event to the JSON list file."""
        try:
            # Atomic-ish append to JSON array
            with open(self.trace_file, "rb+") as f:
                f.seek(-2, os.SEEK_END)
                pos = f.tell()
                f.truncate()
                if pos > 2:
                    f.write(b",\n")

                log_entry = {
                    "timestamp": datetime.now().isoformat(),
                    "event": event,
                    "data": self._sanitize(data)
                }
                json_str = json.dumps(log_entry, indent=4)
                f.write(json_str.encode('utf-8'))
                f.write(b"\n]")
        except Exception:
            pass

    def _sanitize(self, data):
        if isinstance(data, dict):
            return {k: self._sanitize(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [self._sanitize(v) for v in data]
        elif isinstance(data, (str, int, float, bool, type(None))):
            return data
        else:
            return str(data)

class StateTracker:
    """Standard run state tracker that persists to JSON."""
    def __init__(self, session_file: str, event_map: Dict[str, str] = None):
        if not os.path.isabs(session_file):
            session_file = os.path.join(os.getcwd(), session_file)
        self.session_file = session_file
        # Default map for common events
        self.event_map = event_map or {
            "pipeline:lead_result": "leads",
            "agent:complete": "results"
        }
        self.data = {
            "start_time": datetime.now().isoformat(),
            "status": "running"
        }
        # Initialize collections based on event_map
        for key in self.event_map.values():
            if key not in self.data:
                self.data[key] = []

    def on_event(self, event: str, data: Any):
        if event in self.event_map:
            collection_name = self.event_map[event]
            if collection_name in self.data:
                # Deduplicate if data has 'name' or 'id'
                if isinstance(data, dict) and "name" in data:
                    if any(item.get("name") == data["name"] for item in self.data[collection_name]):
                        return

                self.data[collection_name].append(data)
                self.save()
        elif event == "pipeline:complete":
            self.data["status"] = "completed"
            self.data["end_time"] = datetime.now().isoformat()
            self.save()

    def save(self):
        try:
            with open(self.session_file, "w") as f:
                json.dump(self.data, f, indent=4)
        except Exception:
            pass

class MarkdownReporter:
    """Standard markdown report generator for real-time updates."""
    def __init__(self, output_file: str, title: str = "Execution Report"):
        if not os.path.isabs(output_file):
            output_file = os.path.join(os.getcwd(), output_file)
        self.output_file = output_file
        with open(self.output_file, "w") as f:
            f.write(f"# {title}\n")
            f.write(f"*Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*\n\n")
            f.write("Standardized framework report.\n\n---\n\n")

    def append(self, content: str):
        """Append a section to the report."""
        try:
            with open(self.output_file, "a") as f:
                f.write(content + "\n\n")
        except Exception:
            pass
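
The observers above share one duck-typed contract: anything exposing on_event(event, data) can subscribe (MarkdownReporter is instead driven directly via append()). A minimal usage sketch follows; SimpleBus is a hypothetical stand-in for Kite's real EventBus (which lives in kite/core.py and is not shown in this diff), assumed here to do nothing more than fan events out to subscribers.

from kite.observers import EventFileLogger, StateTracker

class SimpleBus:
    """Hypothetical stand-in for Kite's EventBus: fan-out only."""
    def __init__(self):
        self._observers = []

    def subscribe(self, observer):
        self._observers.append(observer)

    def emit(self, event, data):
        # Deliver each event to every registered observer.
        for obs in self._observers:
            obs.on_event(event, data)

bus = SimpleBus()
bus.subscribe(EventFileLogger("trace.json"))  # JSON array of every event
bus.subscribe(StateTracker("session.json"))   # aggregated run state

bus.emit("agent:complete", {"name": "research", "output": "done"})
bus.emit("pipeline:complete", {})
# trace.json now holds both events; session.json records status "completed".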
kite/optimization/resource_router.py
ADDED
@@ -0,0 +1,80 @@
"""
Resource-Aware Optimization (Chapter 16)
Dynamically selects the optimal model (resource) based on task complexity.
"""

from typing import Dict, Any, Optional
import os

class ResourceAwareRouter:
    """
    Routes queries to the most cost-effective model.
    """
    def __init__(self, config: Dict[str, Any] = None):
        config = config or {}

        # Priority: Config > Env Var (Split) > Env Var (Legacy)

        # 1. Fast Model Resolution
        fast_provider = config.get("fast_llm_provider") or os.getenv("FAST_LLM_PROVIDER")
        fast_model_name = config.get("fast_llm_model") or os.getenv("FAST_LLM_MODEL")

        if fast_provider and fast_model_name:
            self.fast_model = f"{fast_provider}/{fast_model_name}"
        else:
            # Legacy fallback: check for full string in FAST_LLM_MODEL (e.g. "groq/llama...")
            self.fast_model = config.get("fast_model") or os.getenv("FAST_LLM_MODEL")
        self.smart_model = config.get("smart_model") or os.getenv("SMART_LLM_MODEL")

        if not self.fast_model:
            raise ValueError("Configuration Error: 'fast_model' not found. Set FAST_LLM_MODEL env var or pass in config.")

        if not self.smart_model:
            # Fall back to the main LLM from the environment if necessary
            main_provider = os.getenv("LLM_PROVIDER")
            main_model = os.getenv("LLM_MODEL")
            if main_provider and main_model:
                self.smart_model = f"{main_provider}/{main_model}"

        if not self.smart_model:
            raise ValueError("Configuration Error: 'smart_model' not found. Set SMART_LLM_MODEL or LLM_PROVIDER/LLM_MODEL.")

        # Simple heuristic threshold (word count)
        self.complexity_threshold = config.get("complexity_threshold", 20)

    def select_model(self, query: str) -> str:
        """
        Selects a model based on query complexity.
        This is a simple implementation of 'Dynamic Model Switching'.
        """
        # 1. Check length
        word_count = len(query.split())

        if word_count < self.complexity_threshold:
            print(f" [Optimization] Routing to FAST model ({self.fast_model}) for simple query.")
            return self.fast_model

        # 2. Check for complexity keywords
        complex_terms = ["analyze", "reason", "plan", "code", "compare", "evaluate"]
        if any(term in query.lower() for term in complex_terms):
            print(f" [Optimization] Routing to SMART model ({self.smart_model}) for reasoning task.")
            return self.smart_model

        # Default to fast for everything else
        print(" [Optimization] Defaulting to FAST model.")
        return self.fast_model

    async def route(self, query: str, framework) -> Dict[str, Any]:
        """
        Executes the query using the selected model.
        """
        selected_model = self.select_model(query)

        # In a real system, we would instantiate a temporary agent or use the LLM directly.
        # Here we simulate the selection impacting the framework's execution.

        # TODO: This method signature might need adjustment to integrate deeply with Kite.
        # For this demo, it returns the model name for the agent to use.
        return {"model": selected_model}
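
A brief sketch of the heuristic in use, passing both model identifiers through config (note that the FAST_LLM_* environment variables, if set, take part in resolution first). The model names below are illustrative placeholders, not defaults shipped with the package.

from kite.optimization.resource_router import ResourceAwareRouter

router = ResourceAwareRouter(config={
    "fast_model": "groq/llama-3.1-8b-instant",  # placeholder identifier
    "smart_model": "openai/gpt-4o",             # placeholder identifier
})

router.select_model("What time is it in Tokyo?")
# -> fast model (6 words, under the 20-word threshold)

router.select_model(
    "Please analyze the quarterly revenue figures across all five regions, "
    "compare them to last year, and plan next steps for the weakest region."
)
# -> smart model (over the threshold, and contains "analyze"/"compare"/"plan")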
kite/persistence.py
ADDED
@@ -0,0 +1,42 @@
"""
Persistence utilities for saving and loading application state.
Supports the "Pause & Resume" (Checkpointing) pattern.
"""
import json
import os
from typing import Any, Dict, Optional

class JSONCheckpointer:
    """
    Simple file-based persistence for arbitrary state dictionaries.
    """
    def __init__(self, filepath: str):
        self.filepath = filepath

    def save(self, state: Dict[str, Any]) -> None:
        """Save the state dictionary to JSON file."""
        try:
            with open(self.filepath, "w") as f:
                json.dump(state, f, indent=2)
            # print(f" [Checkpoint] Saved to {self.filepath}")
        except Exception as e:
            print(f" [Checkpoint Error] Failed to save: {e}")

    def load(self) -> Optional[Dict[str, Any]]:
        """Load the state dictionary from JSON file. Returns None if not found."""
        if not os.path.exists(self.filepath):
            return None

        try:
            with open(self.filepath, "r") as f:
                data = json.load(f)
            print(f" [Checkpoint] Resuming from {self.filepath}")
            return data
        except Exception as e:
            print(f" [Checkpoint Error] Corrupt file {self.filepath}: {e}")
            return None

    def clear(self) -> None:
        """Remove the checkpoint file."""
        if os.path.exists(self.filepath):
            os.remove(self.filepath)
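
A short sketch of the pause-and-resume pattern the docstring describes: load a prior checkpoint if one exists, otherwise start fresh, and clear the file once the run finishes. The state shape is arbitrary; anything JSON-serializable works.

from kite.persistence import JSONCheckpointer

ckpt = JSONCheckpointer("run_state.json")

# Resume if a previous run left a checkpoint behind, else start from scratch.
state = ckpt.load() or {"cursor": 0, "processed": []}

for i in range(state["cursor"], 10):
    state["processed"].append(i)   # stand-in for real work
    state["cursor"] = i + 1
    ckpt.save(state)               # checkpoint after every unit of work

ckpt.clear()                       # run finished; remove the checkpoint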
kite/pipeline/deterministic_pipeline.py
ADDED
@@ -0,0 +1,323 @@
"""
Deterministic Pipeline Pattern
Level 1 Autonomy: The assembly line pattern with ZERO risk.

Flow: Input -> Step 1 -> Step 2 -> ... -> Action
- No loops
- No choices
- Precise, predictable execution
"""

import inspect
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum

class PipelineStatus(Enum):
    """Generic pipeline processing status."""
    PENDING = "pending"
    PROCESSING = "processing"
    AWAITING_APPROVAL = "awaiting_approval"
    SUSPENDED = "suspended"
    COMPLETED = "completed"
    FAILED = "failed"
    ERROR = "error"

@dataclass
class PipelineState:
    """Current state of data in pipeline."""
    task_id: str
    status: PipelineStatus = PipelineStatus.PENDING
    data: Any = None
    results: Dict[str, Any] = field(default_factory=dict)
    errors: List[str] = field(default_factory=list)
    current_step_index: int = 0
    feedback: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)

    def __getitem__(self, key):
        """Allow dict-like access for backward compatibility."""
        return getattr(self, key)

class DeterministicPipeline:
    """
    A generic deterministic processing pipeline.

    Level 1 Autonomy:
    - Fixed sequence
    - No loops
    - Predictable

    Example:
        pipeline = DeterministicPipeline("data_processor")
        pipeline.add_step("load", load_func)
        pipeline.add_step("process", process_func)

        result = pipeline.execute(raw_data)
    """

    def __init__(self, name: str = "pipeline", event_bus=None):
        self.name = name
        self.event_bus = event_bus
        self.steps: List[tuple[str, Callable]] = []
        self.checkpoints: Dict[str, bool] = {}  # step_name -> approval_required
        self.intervention_points: Dict[str, Callable] = {}  # step_name -> callback
        self.history: List[PipelineState] = []
        if self.event_bus:
            self.event_bus.emit("pipeline:init", {"pipeline": self.name})

    def add_step(self, name: str, func: Callable):
        """Add a step to the pipeline."""
        self.steps.append((name, func))
        print(f" [OK] Added step: {name}")

    def add_checkpoint(self, step_name: str, approval_required: bool = True):
        """Pause execution for approval after this step."""
        self.checkpoints[step_name] = approval_required
        print(f" [OK] Added checkpoint after: {step_name}")

    def add_intervention_point(self, step_name: str, callback: Callable):
        """Call a callback for user intervention before this step."""
        self.intervention_points[step_name] = callback
        print(f" [OK] Added intervention point before: {step_name}")

    def execute(self, data: Any, task_id: Optional[str] = None) -> PipelineState:
        """Execute all steps in the pipeline sequentially."""
        t_id = task_id or f"TASK-{len(self.history)+1:04d}"
        if self.event_bus:
            self.event_bus.emit("pipeline:start", {"pipeline": self.name, "task_id": t_id, "data": str(data)[:100]})
            self.event_bus.emit("pipeline:structure", {
                "pipeline": self.name,
                "task_id": t_id,
                "steps": [name for name, _ in self.steps]
            })

        state = PipelineState(task_id=t_id, data=data)
        state.status = PipelineStatus.PROCESSING
        self.history.append(state)

        return self._run_sync(state)

    def resume(self, task_id: str, feedback: Optional[str] = None) -> PipelineState:
        """Resume a suspended or awaiting_approval task (sync)."""
        state = next((s for s in self.history if s.task_id == task_id), None)
        if not state:
            raise ValueError(f"Task ID {task_id} not found in history")

        if state.status not in [PipelineStatus.SUSPENDED, PipelineStatus.AWAITING_APPROVAL]:
            print(f"[WARNING] Task {task_id} is in status {state.status}, not suspended.")
            return state

        print(f"\n[RESUME] Resuming pipeline: {self.name} (Task: {task_id})")
        state.status = PipelineStatus.PROCESSING
        state.feedback = feedback

        return self._run_sync(state)

    def _run_sync(self, state: PipelineState) -> PipelineState:
        """Internal runner for sync execution."""
        try:
            while state.current_step_index < len(self.steps):
                step_idx = state.current_step_index
                step_name, func = self.steps[step_idx]

                # 1. Intervention Point
                if step_name in self.intervention_points:
                    print(f" [INTERVENTION] Triggering before: {step_name}...")
                    callback = self.intervention_points[step_name]
                    if inspect.iscoroutinefunction(callback):
                        raise RuntimeError(f"Intervention callback for '{step_name}' is async. Use execute_async().")
                    callback(state)

                # 2. Execute Step
                current_data = state.results[self.steps[step_idx-1][0]] if step_idx > 0 else state.data

                if self.event_bus:
                    self.event_bus.emit("pipeline:step_start", {
                        "pipeline": self.name,
                        "task_id": state.task_id,
                        "step": step_name,
                        "index": step_idx
                    })

                if inspect.iscoroutinefunction(func):
                    raise RuntimeError(f"Step '{step_name}' is async. Use execute_async().")

                result = func(current_data)
                state.results[step_name] = result

                if self.event_bus:
                    self.event_bus.emit("pipeline:step", {
                        "pipeline": self.name,
                        "task_id": state.task_id,
                        "step": step_name,
                        "result": str(result)[:200]
                    })
                state.current_step_index += 1
                state.updated_at = datetime.now()

                # 3. Checkpoint
                if step_name in self.checkpoints:
                    approval_req = self.checkpoints[step_name]
                    if approval_req:
                        if self.event_bus:
                            self.event_bus.emit("pipeline:checkpoint", {"pipeline": self.name, "task_id": state.task_id, "step": step_name})
                        state.status = PipelineStatus.AWAITING_APPROVAL
                        return state
                    else:
                        state.status = PipelineStatus.SUSPENDED
                        return state

            state.status = PipelineStatus.COMPLETED
            if self.event_bus:
                self.event_bus.emit("pipeline:complete", {"pipeline": self.name, "task_id": state.task_id})

        except Exception as e:
            state.status = PipelineStatus.ERROR
            state.errors.append(str(e))
            state.updated_at = datetime.now()
            print(f"[ERROR] Pipeline '{self.name}' failed: {e}")

        return state

    async def execute_async(self, data: Any, task_id: Optional[str] = None) -> PipelineState:
        """Execute all steps in the pipeline asynchronously."""
        t_id = task_id or f"TASK-{len(self.history)+1:04d}"
        if self.event_bus:
            self.event_bus.emit("pipeline:start", {"pipeline": self.name, "task_id": t_id, "data": str(data)[:100], "mode": "async"})
            self.event_bus.emit("pipeline:structure", {
                "pipeline": self.name,
                "task_id": t_id,
                "steps": [name for name, _ in self.steps]
            })

        state = PipelineState(task_id=t_id, data=data)
        state.status = PipelineStatus.PROCESSING
        self.history.append(state)

        return await self._run_async(state)

    async def resume_async(self, task_id: str, feedback: Optional[str] = None) -> PipelineState:
        """Resume a suspended or awaiting_approval task."""
        state = next((s for s in self.history if s.task_id == task_id), None)
        if not state:
            raise ValueError(f"Task ID {task_id} not found in history")

        if state.status not in [PipelineStatus.SUSPENDED, PipelineStatus.AWAITING_APPROVAL]:
            print(f"[WARNING] Task {task_id} is in status {state.status}, not suspended.")
            return state

        print(f"\n[RESUME] Resuming pipeline async: {self.name} (Task: {task_id})")
        state.status = PipelineStatus.PROCESSING
        state.feedback = feedback

        return await self._run_async(state)

    async def _run_async(self, state: PipelineState) -> PipelineState:
        """Internal runner for async execution."""
        try:
            while state.current_step_index < len(self.steps):
                step_idx = state.current_step_index
                step_name, func = self.steps[step_idx]

                # 1. Intervention Point (Before Step)
                if step_name in self.intervention_points:
                    print(f" [INTERVENTION] Triggering before: {step_name}...")
                    callback = self.intervention_points[step_name]
                    # We pass the state and results for the human to tweak
                    await self._invoke_callback(callback, state)

                # 2. Execute Step
                current_data = state.results[self.steps[step_idx-1][0]] if step_idx > 0 else state.data

                # Create context for the step (useful for agents)
                step_context = {
                    "task_id": state.task_id,
                    "pipeline": self.name,
                    "step": step_name,
                    "index": step_idx
                }

                if self.event_bus:
                    self.event_bus.emit("pipeline:step_start", {
                        "pipeline": self.name,
                        "task_id": state.task_id,
                        "step": step_name,
                        "index": step_idx
                    })

                # Steps may accept an optional `context` kwarg (agent-like callables
                # can use it); otherwise fall back to a single positional argument.
                if inspect.iscoroutinefunction(func):
                    try:
                        result = await func(current_data, context=step_context)
                    except TypeError:
                        result = await func(current_data)
                else:
                    try:
                        result = func(current_data, context=step_context)
                    except TypeError:
                        result = func(current_data)

                state.results[step_name] = result

                if self.event_bus:
                    self.event_bus.emit("pipeline:step", {
                        "pipeline": self.name,
                        "task_id": state.task_id,
                        "step": step_name,
                        "result": str(result)[:200]
                    })
                state.current_step_index += 1
                state.updated_at = datetime.now()

                # 3. Checkpoint (After Step)
                if step_name in self.checkpoints:
                    approval_req = self.checkpoints[step_name]
                    if approval_req:
                        if self.event_bus:
                            self.event_bus.emit("pipeline:checkpoint", {"pipeline": self.name, "task_id": state.task_id, "step": step_name})
                        state.status = PipelineStatus.AWAITING_APPROVAL
                        return state
                    else:
                        state.status = PipelineStatus.SUSPENDED
                        return state

            state.status = PipelineStatus.COMPLETED
            if self.event_bus:
                self.event_bus.emit("pipeline:complete", {"pipeline": self.name, "task_id": state.task_id})

        except Exception as e:
            state.status = PipelineStatus.ERROR
            state.errors.append(str(e))
            state.updated_at = datetime.now()
            print(f"[ERROR] Pipeline '{self.name}' failed: {e}")

        return state

    async def _invoke_callback(self, callback: Callable, state: PipelineState):
        """Invoke intervention callback (sync or async)."""
        if inspect.iscoroutinefunction(callback):
            await callback(state)
        else:
            callback(state)

    def get_stats(self) -> Dict:
        """Get pipeline statistics."""
        if not self.history:
            return {"total_processed": 0}

        statuses = [state.status for state in self.history]
        return {
            "total_processed": len(self.history),
            "completed": statuses.count(PipelineStatus.COMPLETED),
            "errors": statuses.count(PipelineStatus.ERROR),
            "success_rate": (statuses.count(PipelineStatus.COMPLETED) / len(statuses) * 100) if statuses else 0
        }
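
To make the approval flow concrete, here is a minimal sketch of the sync API above: a two-step pipeline with a checkpoint after the first step, so execute() returns in AWAITING_APPROVAL and resume() finishes the run. The step functions and data are illustrative.

from kite.pipeline.deterministic_pipeline import DeterministicPipeline, PipelineStatus

def load(raw):
    return raw.strip().split(",")             # "load" step: parse the input

def process(items):
    return [item.upper() for item in items]   # "process" step: transform it

pipeline = DeterministicPipeline("data_processor")
pipeline.add_step("load", load)
pipeline.add_step("process", process)
pipeline.add_checkpoint("load")               # pause for approval after "load"

state = pipeline.execute("a,b,c")
assert state.status == PipelineStatus.AWAITING_APPROVAL
print(state.results["load"])                  # ['a', 'b', 'c'] -- review before approving

state = pipeline.resume(state.task_id, feedback="looks good")
assert state.status == PipelineStatus.COMPLETED
print(state.results["process"])               # ['A', 'B', 'C']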