kite-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kite/__init__.py +46 -0
- kite/ab_testing.py +384 -0
- kite/agent.py +556 -0
- kite/agents/__init__.py +3 -0
- kite/agents/plan_execute.py +191 -0
- kite/agents/react_agent.py +509 -0
- kite/agents/reflective_agent.py +90 -0
- kite/agents/rewoo.py +119 -0
- kite/agents/tot.py +151 -0
- kite/conversation.py +125 -0
- kite/core.py +974 -0
- kite/data_loaders.py +111 -0
- kite/embedding_providers.py +372 -0
- kite/llm_providers.py +1278 -0
- kite/memory/__init__.py +6 -0
- kite/memory/advanced_rag.py +333 -0
- kite/memory/graph_rag.py +719 -0
- kite/memory/session_memory.py +423 -0
- kite/memory/vector_memory.py +579 -0
- kite/monitoring.py +611 -0
- kite/observers.py +107 -0
- kite/optimization/__init__.py +9 -0
- kite/optimization/resource_router.py +80 -0
- kite/persistence.py +42 -0
- kite/pipeline/__init__.py +5 -0
- kite/pipeline/deterministic_pipeline.py +323 -0
- kite/pipeline/reactive_pipeline.py +171 -0
- kite/pipeline_manager.py +15 -0
- kite/routing/__init__.py +6 -0
- kite/routing/aggregator_router.py +325 -0
- kite/routing/llm_router.py +149 -0
- kite/routing/semantic_router.py +228 -0
- kite/safety/__init__.py +6 -0
- kite/safety/circuit_breaker.py +360 -0
- kite/safety/guardrails.py +82 -0
- kite/safety/idempotency_manager.py +304 -0
- kite/safety/kill_switch.py +75 -0
- kite/tool.py +183 -0
- kite/tool_registry.py +87 -0
- kite/tools/__init__.py +21 -0
- kite/tools/code_execution.py +53 -0
- kite/tools/contrib/__init__.py +19 -0
- kite/tools/contrib/calculator.py +26 -0
- kite/tools/contrib/datetime_utils.py +20 -0
- kite/tools/contrib/linkedin.py +428 -0
- kite/tools/contrib/web_search.py +30 -0
- kite/tools/mcp/__init__.py +31 -0
- kite/tools/mcp/database_mcp.py +267 -0
- kite/tools/mcp/gdrive_mcp_server.py +503 -0
- kite/tools/mcp/gmail_mcp_server.py +601 -0
- kite/tools/mcp/postgres_mcp_server.py +490 -0
- kite/tools/mcp/slack_mcp_server.py +538 -0
- kite/tools/mcp/stripe_mcp_server.py +219 -0
- kite/tools/search.py +90 -0
- kite/tools/system_tools.py +54 -0
- kite/tools_manager.py +27 -0
- kite_agent-0.1.0.dist-info/METADATA +621 -0
- kite_agent-0.1.0.dist-info/RECORD +61 -0
- kite_agent-0.1.0.dist-info/WHEEL +5 -0
- kite_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
- kite_agent-0.1.0.dist-info/top_level.txt +1 -0

kite/pipeline/reactive_pipeline.py
ADDED

@@ -0,0 +1,171 @@
+"""
+Reactive Pipeline Pattern
+Level 2 Autonomy: Event-driven, concurrent processing with workers.
+
+Flow: Producer -> Stage 1 (N workers) -> Stage 2 (M workers) -> ... -> Result
+"""
+
+import asyncio
+import logging
+import inspect
+from typing import Dict, List, Optional, Any, Callable
+from dataclasses import dataclass, field
+from datetime import datetime
+from .deterministic_pipeline import PipelineStatus, PipelineState
+
+@dataclass
+class ReactiveStage:
+    name: str
+    func: Callable
+    workers: int
+    input_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
+    output_queue: Optional[asyncio.Queue] = None
+    tasks: List[asyncio.Task] = field(default_factory=list)
+
+class ReactivePipeline:
+    """
+    A reactive, streaming data pipeline with parallel workers.
+
+    Each stage runs N workers in parallel, reading from an input queue
+    and passing results to the next stage's queue.
+    """
+
+    def __init__(self, name: str = "reactive_pipeline", event_bus = None):
+        self.name = name
+        self.event_bus = event_bus
+        self.stages: List[ReactiveStage] = []
+        self.logger = logging.getLogger(f"Pipeline:{name}")
+        self.history: List[PipelineState] = []
+        self._running_tasks = []
+
+        if self.event_bus:
+            self.event_bus.emit("pipeline:init", {"pipeline": self.name, "type": "reactive"})
+
+    def add_stage(self, name: str, func: Callable, workers: int = 1):
+        """Add a processing stage with a specific number of workers."""
+        stage = ReactiveStage(name=name, func=func, workers=workers)
+
+        # Connect output of previous stage to input of this stage
+        if self.stages:
+            self.stages[-1].output_queue = stage.input_queue
+
+        self.stages.append(stage)
+        self.logger.info(f" [OK] Added stage: {name} (Workers: {workers})")
+
+    async def execute(self, initial_data: Any, task_id: Optional[str] = None):
+        """
+        Start the pipeline and feed it initial data.
+        Note: initial_data can be a single item or a list of items.
+        """
+        t_id = task_id or f"RTASK-{datetime.now().strftime('%H%M%S')}"
+
+        if self.event_bus:
+            self.event_bus.emit("pipeline:start", {
+                "pipeline": self.name,
+                "task_id": t_id,
+                "mode": "reactive",
+                "stages": [s.name for s in self.stages]
+            })
+            # Emit structure for dashboard
+            self.event_bus.emit("pipeline:structure", {
+                "pipeline": self.name,
+                "task_id": t_id,
+                "steps": [s.name for s in self.stages]
+            })
+
+        # Start all workers for all stages
+        for i, stage in enumerate(self.stages):
+            for w_idx in range(stage.workers):
+                task = asyncio.create_task(
+                    self._worker_loop(stage, t_id),
+                    name=f"Worker-{stage.name}-{w_idx}"
+                )
+                stage.tasks.append(task)
+                self._running_tasks.append(task)
+
+        # Feed the first stage
+        if isinstance(initial_data, list):
+            for item in initial_data:
+                await self.stages[0].input_queue.put(item)
+        else:
+            await self.stages[0].input_queue.put(initial_data)
+
+        self.logger.info(f"Pipeline {self.name} is running...")
+        return t_id
+
+    async def _worker_loop(self, stage: ReactiveStage, task_id: str):
+        """Internal loop for a single worker in a stage."""
+        while True:
+            item = await stage.input_queue.get()
+            if item is None:  # Shutdown signal
+                stage.input_queue.task_done()
+                break
+
+            try:
+                if self.event_bus:
+                    self.event_bus.emit("pipeline:step_start", {
+                        "pipeline": self.name,
+                        "task_id": task_id,
+                        "step": stage.name,
+                        "data": str(item)[:100]
+                    })
+
+                # Execute function
+                if inspect.isasyncgenfunction(stage.func):
+                    # Handle async generator for streaming
+                    async for result in stage.func(item):
+                        if stage.output_queue and result is not None:
+                            if isinstance(result, list):
+                                for sub_item in result:
+                                    await stage.output_queue.put(sub_item)
+                            else:
+                                await stage.output_queue.put(result)
+                elif inspect.iscoroutinefunction(stage.func):
+                    result = await stage.func(item)
+                    if stage.output_queue and result is not None:
+                        if isinstance(result, list):
+                            for sub_item in result:
+                                await stage.output_queue.put(sub_item)
+                        else:
+                            await stage.output_queue.put(result)
+                else:
+                    result = stage.func(item)
+                    if stage.output_queue and result is not None:
+                        if isinstance(result, list):
+                            for sub_item in result:
+                                await stage.output_queue.put(sub_item)
+                        else:
+                            await stage.output_queue.put(result)
+
+            except Exception as e:
+                self.logger.error(f"Error in stage {stage.name}: {e}")
+                if self.event_bus:
+                    self.event_bus.emit("pipeline:error", {
+                        "pipeline": self.name,
+                        "task_id": task_id,
+                        "step": stage.name,
+                        "error": str(e)
+                    })
+            finally:
+                stage.input_queue.task_done()
+
+    async def wait_until_complete(self):
+        """Wait for all items to flow through all queues and workers to exit."""
+        for stage in self.stages:
+            # 1. Wait for all items in the current queue to be PROCESSED
+            await stage.input_queue.join()
+
+            # 2. Tell all workers in this stage to SHUT DOWN
+            for _ in range(stage.workers):
+                await stage.input_queue.put(None)
+
+            # 3. Wait for these specific workers to finish
+            if stage.tasks:
+                await asyncio.gather(*stage.tasks)
+
+        self.logger.info(f"Pipeline {self.name} completed successfully.")
+
+    def stop(self):
+        """Force stop all workers."""
+        for task in self._running_tasks:
+            task.cancel()
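To make the worker/queue mechanics above concrete, here is a minimal usage sketch (not part of the package). It assumes the module is importable as kite.pipeline.reactive_pipeline, per the file list; the stage names and functions are hypothetical. Note how a stage that returns a list fans out into one queue item per element for the next stage, and how wait_until_complete() drains each queue before injecting one None sentinel per worker.

import asyncio
from kite.pipeline.reactive_pipeline import ReactivePipeline  # path per the file list above

results = []

def tokenize(line):
    # Sync stage: returning a list fans out into one queue item per token.
    return line.split()

async def enrich(token):
    # Async stage: each of its 4 workers awaits this concurrently.
    await asyncio.sleep(0.01)  # stand-in for real I/O (API call, DB lookup)
    return token.upper()

def collect(token):
    # Sink stage: returns None, so nothing is forwarded downstream.
    results.append(token)

async def main():
    pipeline = ReactivePipeline(name="demo")
    pipeline.add_stage("tokenize", tokenize, workers=1)
    pipeline.add_stage("enrich", enrich, workers=4)
    pipeline.add_stage("collect", collect, workers=1)

    await pipeline.execute(["hello world", "reactive pipelines"])
    await pipeline.wait_until_complete()  # join each queue, then send None sentinels
    print(sorted(results))  # ['HELLO', 'PIPELINES', 'REACTIVE', 'WORLD']

asyncio.run(main())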

kite/pipeline_manager.py
ADDED

@@ -0,0 +1,15 @@
+"""Pipeline Manager"""
+
+class PipelineManager:
+    def __init__(self, pipeline_class, logger):
+        self.pipeline_class = pipeline_class
+        self.logger = logger
+        self.pipelines = {}
+
+    def create(self, name: str, event_bus = None):
+        pipeline = self.pipeline_class(name, event_bus=event_bus)
+        self.pipelines[name] = pipeline
+        return pipeline
+
+    def get(self, name: str):
+        return self.pipelines.get(name)
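PipelineManager is a thin name-to-instance registry around whichever pipeline class it is given. A short sketch of how it might compose with ReactivePipeline, under the same hypothetical import assumptions as above:

import logging
from kite.pipeline_manager import PipelineManager
from kite.pipeline.reactive_pipeline import ReactivePipeline

manager = PipelineManager(ReactivePipeline, logging.getLogger("manager"))
ingest = manager.create("ingest")        # constructs ReactivePipeline("ingest", event_bus=None)
assert manager.get("ingest") is ingest   # later lookups return the same instance
assert manager.get("missing") is None    # unknown names return None, not an error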
kite/routing/aggregator_router.py
ADDED

@@ -0,0 +1,325 @@
+from kite.agent import Agent
+from kite.llm_providers import LLMFactory
+import os
+import asyncio
+from typing import Dict, List, Optional, Callable, Any
+from dataclasses import dataclass, field
+from concurrent.futures import ThreadPoolExecutor
+from dotenv import load_dotenv
+import json
+
+load_dotenv()
+
+
+# ============================================================================
+# AGENT RESPONSE
+# ============================================================================
+
+@dataclass
+class AgentResponse:
+    """Response from a specialist agent."""
+    agent_name: str
+    subtask: str
+    response: str
+    success: bool
+    metadata: Dict = field(default_factory=dict)
+
+
+class SpecialistAgent(Agent):
+    """Base class for specialist agents using LLM logic."""
+
+    def __init__(self, name: str, specialty: str, instructions: str, llm = None):
+        super().__init__(
+            name=name,
+            system_prompt=f"Specialty: {specialty}\nInstructions: {instructions}",
+            llm=llm or LLMFactory.auto_detect(),
+            tools=[],
+            framework=None
+        )
+        self.specialty = specialty
+        self.instructions = instructions
+
+    async def handle(self, task: str) -> AgentResponse:
+        """Handle a specific task using LLM."""
+        print(f" [{self.name}] Processing: {task}")
+
+        prompt = f"""You are the {self.name} specialist.
+Your specialty: {self.specialty}
+Instructions: {self.instructions}
+
+Task to perform: {task}
+
+Provide a detailed and helpful response."""
+
+        try:
+            # Use the Agent's run method
+            response = await self.run(prompt)
+            return AgentResponse(
+                agent_name=self.name,
+                subtask=task,
+                response=response,
+                success=True
+            )
+        except Exception as e:
+            return AgentResponse(
+                agent_name=self.name,
+                subtask=task,
+                response=f"Error processing task: {str(e)}",
+                success=False
+            )
+
+
+# ============================================================================
+# TASK DECOMPOSITION
+# ============================================================================
+
+@dataclass
+class Subtask:
+    """A decomposed subtask."""
+    description: str
+    assigned_agent: str
+    priority: int = 1  # 1 = highest
+
+
+class TaskDecomposer:
+    """
+    Decomposes complex queries into subtasks.
+
+    Uses LLM to analyze intent and split into actionable subtasks.
+    """
+
+    def __init__(self, llm = None):
+        self.llm = llm
+        self.agents_info = {}
+
+    def decompose(self, query: str) -> List[Subtask]:
+        """
+        Decompose query into subtasks.
+
+        Args:
+            query: User's complex query
+
+        Returns:
+            List of subtasks with assigned agents
+        """
+        if not self.agents_info:
+            return [Subtask(description=query, assigned_agent="DefaultAgent", priority=1)]
+
+        print(f"\n Decomposing query: {query}")
+
+        # Create prompt for LLM
+        agents_desc = "\n".join([
+            f"- {name}: {desc}"
+            for name, desc in self.agents_info.items()
+        ])
+
+        prompt = f"""You are a task decomposition expert. Analyze this user query and break it into subtasks.
+
+Available agents:
+{agents_desc}
+
+User query: "{query}"
+
+Decompose this into 1-3 specific subtasks. For each subtask:
+1. Write a clear, actionable description
+2. Assign to the most appropriate agent
+3. Set priority (1=highest, 3=lowest)
+
+Respond ONLY with valid JSON array:
+[
+  {{"description": "...", "agent": "SelectedAgentName", "priority": 1}},
+  {{"description": "...", "agent": "AnotherAgentName", "priority": 2}}
+]"""
+
+        if self.llm:
+            response = self.llm.complete(prompt, temperature=0.3)
+            content = response.strip()
+        else:
+            # Fallback if no LLM provided (should not happen in production)
+            return [Subtask(description=query, assigned_agent=list(self.agents_info.keys())[0], priority=1)]
+
+        # Robust JSON extraction
+        try:
+            # Find the first '[' and last ']'
+            start_idx = content.find('[')
+            end_idx = content.rfind(']')
+            if start_idx != -1 and end_idx != -1:
+                content = content[start_idx:end_idx + 1]
+
+            tasks_data = json.loads(content)
+            subtasks = [
+                Subtask(
+                    description=task["description"],
+                    assigned_agent=task["agent"],
+                    priority=task.get("priority", 1)
+                )
+                for task in tasks_data
+            ]
+
+            print(f" [OK] Decomposed into {len(subtasks)} subtasks:")
+            for i, task in enumerate(subtasks, 1):
+                print(f"   {i}. [{task.assigned_agent}] {task.description}")
+
+            return subtasks
+
+        except json.JSONDecodeError as e:
+            print(f" Failed to parse LLM response: {e}")
+            # Fallback
+            return [Subtask(
+                description=query,
+                assigned_agent=list(self.agents_info.keys())[0],
+                priority=1
+            )]
+
+
+# ============================================================================
+# AGGREGATOR ROUTER
+# ============================================================================
+
+class AggregatorRouter:
+    """
+    The Router (Aggregator Agent) that orchestrates specialist agents.
+
+    Responsibilities:
+    1. Analyze user intent
+    2. Decompose into subtasks
+    3. Route to specialist agents
+    4. Execute in parallel
+    5. Merge results
+    6. Present unified response
+    """
+
+    def __init__(self, llm = None):
+        self.llm = llm or LLMFactory.auto_detect()
+        self.agents: Dict[str, Any] = {}
+        self.decomposer = TaskDecomposer(llm=self.llm)
+        self.conversation_history: List[Dict] = []
+
+        print("[OK] Aggregator Router initialized")
+
+    def register_agent(self, name: str, agent: Any, description: Optional[str] = None):
+        """Register a new specialist agent."""
+        self.agents[name] = agent
+
+        # Update decomposer's info
+        if description:
+            self.decomposer.agents_info[name] = description
+        elif hasattr(agent, 'specialty'):
+            self.decomposer.agents_info[name] = agent.specialty
+        elif hasattr(agent, 'metadata'):
+            self.decomposer.agents_info[name] = agent.metadata.get('specialty', 'Specialist agent')
+        else:
+            self.decomposer.agents_info[name] = "Specialist agent"
+
+        print(f" [OK] Registered agent: {name}")
+
+    async def _execute_subtask(self, subtask: Subtask) -> AgentResponse:
+        """Execute a single subtask with assigned agent."""
+        agent = self.agents.get(subtask.assigned_agent)
+
+        if not agent:
+            return AgentResponse(
+                agent_name="Router",
+                subtask=subtask.description,
+                response=f"Error: Agent {subtask.assigned_agent} not available",
+                success=False
+            )
+
+        # Support both SpecialistAgent (has handle method) and generic Agent (has run method)
+        if hasattr(agent, 'handle'):
+            return await agent.handle(subtask.description)
+        elif hasattr(agent, 'run'):
+            resp = await agent.run(subtask.description)
+            return AgentResponse(
+                agent_name=subtask.assigned_agent,
+                subtask=subtask.description,
+                response=resp,
+                success=True
+            )
+        else:
+            return AgentResponse(
+                agent_name=subtask.assigned_agent,
+                subtask=subtask.description,
+                response=f"Error: Agent {subtask.assigned_agent} is not compatible",
+                success=False
+            )
+
+    async def _execute_parallel(self, subtasks: List[Subtask]) -> List[AgentResponse]:
+        """Execute multiple subtasks in parallel."""
+        print(f"\n Executing {len(subtasks)} subtasks in parallel...")
+
+        tasks = [self._execute_subtask(task) for task in subtasks]
+        responses = await asyncio.gather(*tasks)
+
+        print(f" [OK] All {len(responses)} agents completed")
+        return responses
+
+    def _merge_responses(self, responses: List[AgentResponse], query: str) -> str:
+        """Merge multiple agent responses into unified answer."""
+        print(f"\n Merging {len(responses)} responses using LLM...")
+
+        successful = [r for r in responses if r.success]
+        if not successful:
+            return "I apologize, but I encountered errors processing your request."
+
+        context = "\n\n".join([f"Agent: {r.agent_name}\nResponse: {r.response}" for r in successful])
+
+        prompt = f"""You are the Multi-Agent Aggregator. Your goal is to combine specialist responses into a single, cohesive answer for the user.
+
+User original query: "{query}"
+
+Specialist Responses:
+{context}
+
+Respond as a single helpful assistant. Maintain the specific details provided by each specialist."""
+
+        return self.llm.complete(prompt)
+
+    async def route(self, query: str) -> Dict[str, Any]:
+        """Main routing method."""
+        print(f"\n{'='*70}")
+        print(f"ROUTING REQUEST")
+        print('='*70)
+
+        # Step 1: Decompose
+        subtasks = self.decomposer.decompose(query)
+
+        # Step 2: Execute in parallel
+        responses = await self._execute_parallel(subtasks)
+
+        # Step 3: Merge
+        final_response = self._merge_responses(responses, query)
+
+        # Add to conversation history
+        self.conversation_history.append({
+            "query": query,
+            "subtasks": [{"desc": t.description, "agent": t.assigned_agent} for t in subtasks],
+            "responses": [{"agent": r.agent_name, "success": r.success} for r in responses]
+        })
+
+        return {
+            "query": query,
+            "route": subtasks[0].assigned_agent if len(subtasks) == 1 else "multi",
+            "subtasks_count": len(subtasks),
+            "agents_used": list(set(r.agent_name for r in responses)),
+            "workers": list(set(r.agent_name for r in responses)),
+            "parallel": True,
+            "response": final_response,
+            "answer": final_response,
+            "metadata": {
+                "successful_tasks": sum(1 for r in responses if r.success),
+                "failed_tasks": sum(1 for r in responses if not r.success),
+                "total_tasks": len(responses)
+            }
+        }
+
+    def get_stats(self) -> Dict:
+        """Get router statistics."""
+        return {
+            "total_requests": len(self.conversation_history),
+            "registered_agents": len(self.agents),
+            "average_subtasks": (
+                sum(len(h["subtasks"]) for h in self.conversation_history) / len(self.conversation_history)
+                if self.conversation_history else 0
+            )
+        }
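A hedged sketch of the decompose-execute-merge flow implemented above, wiring two invented specialists into the router. The agent names, specialties, and query are hypothetical; LLMFactory.auto_detect() presumably picks a provider from environment credentials (hence the load_dotenv() call at module import), so this only runs with an API key configured.

import asyncio
from kite.routing.aggregator_router import AggregatorRouter, SpecialistAgent

async def main():
    router = AggregatorRouter()  # auto-detects an LLM provider from the environment

    # Invented specialists; register_agent feeds each specialty to the TaskDecomposer.
    router.register_agent(
        "BillingAgent",
        SpecialistAgent("BillingAgent", "billing and invoices", "Answer billing questions."),
    )
    router.register_agent(
        "TechAgent",
        SpecialistAgent("TechAgent", "technical troubleshooting", "Diagnose product issues."),
    )

    # One query spanning both specialties -> decomposed, run in parallel, merged.
    result = await router.route("My invoice is wrong and the app crashes on login")
    print(result["route"])        # "multi" when more than one subtask was produced
    print(result["agents_used"])  # e.g. ["BillingAgent", "TechAgent"]
    print(result["answer"])       # LLM-merged unified response

asyncio.run(main())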
kite/routing/llm_router.py
ADDED

@@ -0,0 +1,149 @@
+"""
+LLM-based Router Implementation
+Uses LLM/SLM to classify user intent with reasoning.
+More accurate than embeddings but slower and more expensive.
+"""
+
+import json
+import asyncio
+from typing import Dict, List, Optional, Callable
+from dataclasses import dataclass, field
+
+@dataclass
+class LLMRoute:
+    name: str
+    description: str
+    handler: Callable
+
+class LLMRouter:
+    """
+    Routes user queries using LLM classification.
+    """
+
+    def __init__(self, llm=None):
+        self.routes: Dict[str, LLMRoute] = {}
+        self.llm = llm
+
+    def add_route(self, name: str, examples: List[str] | str = None, samples: List[str] | str = None, description: str = "", handler: Callable = None):
+        """Add a new route. Examples serve as context. 'samples' is an alias for 'examples'."""
+        final_examples = examples or samples
+        self.routes[name] = LLMRoute(
+            name=name,
+            description=description or f"Handle queries related to {name}",
+            handler=handler
+        )
+        print(f"[OK] Added LLM route: {name}")
+
+    async def route(self, query: str, context: Optional[Dict] = None) -> Dict:
+        """Route query to appropriate specialist agent using LLM."""
+        if not self.routes:
+            raise RuntimeError("No routes configured in LLMRouter")
+
+        # Prepare prompt
+        routes_desc = ""
+        for route in self.routes.values():
+            routes_desc += f"- {route.name}: {route.description}\n"
+
+        prompt = f"""Classify the user query into one of the following categories.
+Available Categories:
+{routes_desc}
+- none: Use this if the query doesn't fit any of the above.
+
+User Query: "{query}"
+
+Respond ONLY with a JSON object:
+{{"category": "category_name", "confidence": 0.0-1.0, "reasoning": "why?"}}"""
+
+        try:
+            # Use LLM for classification
+            response = await asyncio.to_thread(self.llm.complete, prompt, temperature=0.1)
+
+            # Clean and parse JSON
+            content = response.strip()
+            if "```json" in content:
+                content = content.split("```json")[-1].split("```")[0].strip()
+            elif "```" in content:
+                content = content.split("```")[-1].split("```")[0].strip()
+
+            # Robust JSON parsing with fallback
+            try:
+                # Try to find JSON in text
+                import re
+                json_match = re.search(r'\{[^}]+\}', content, re.DOTALL)
+                if json_match:
+                    content = json_match.group(0)
+
+                data = json.loads(content)
+                category = data.get("category", "none")
+                confidence = data.get("confidence", 0.0)
+                reasoning = data.get("reasoning", "No reasoning")
+
+            except (json.JSONDecodeError, AttributeError) as e:
+                # Fallback: keyword matching
+                print(f"[WARN] Router JSON parse failed creating fallback")
+                content_lower = response.lower()
+                category = "none"
+                confidence = 0.5
+                reasoning = "Fallback text classification"
+
+                # Check for route keywords
+                for route_name in self.routes.keys():
+                    if route_name.replace("_", " ") in content_lower:
+                        category = route_name
+                        break
+            print(f"\n LLM Intent Classification:")
+            print(f"   Query: {query}")
+            print(f"   Category: {category} (confidence: {confidence:.0%})")
+            print(f"   Reasoning: {reasoning}")
+
+            if category == "none" or category not in self.routes:
+                return {
+                    "route": "none",
+                    "confidence": confidence,
+                    "response": "I'm not sure how to help with that. Could you be more specific?",
+                    "needs_clarification": True
+                }
+
+            # Execute handler
+            route = self.routes[category]
+            print(f"[OK] Routing to {route.name}")
+
+            # Use context if provided
+            try:
+                resp = route.handler(query, context)
+            except TypeError:
+                # Fallback for handlers that don't accept context
+                resp = route.handler(query)
+
+            if asyncio.iscoroutine(resp):
+                resp = await resp
+
+            # Extract response text
+            if isinstance(resp, dict) and 'response' in resp:
+                response_text = resp['response']
+            else:
+                response_text = str(resp)
+
+            return {
+                "route": route.name,
+                "confidence": confidence,
+                "response": response_text,
+                "needs_clarification": False
+            }
+
+        except Exception as e:
+            print(f"[ERROR] LLM Routing failed: {e}")
+            return {
+                "route": "error",
+                "confidence": 0,
+                "response": f"Routing error: {str(e)}",
+                "needs_clarification": False
+            }
+
+    def get_stats(self) -> Dict:
+        return {
+            "total_routes": len(self.routes),
+            "confidence_threshold": 0.0,
+            "cache_hit_rate": 0.0,
+            "type": "LLM"
+        }
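Finally, a sketch of LLMRouter in use, with invented route names, handlers, and query. The llm object only needs the complete(prompt, temperature=...) method the router calls via asyncio.to_thread. Handlers may be sync or async and may take (query) or (query, context), per the TypeError fallback above; async handlers are awaited and dict returns are unwrapped via their 'response' key.

import asyncio
from kite.routing.llm_router import LLMRouter
from kite.llm_providers import LLMFactory

def handle_billing(query, context=None):
    # Sync handler; its return value is stringified into the response.
    return f"Escalating billing issue: {query}"

async def handle_support(query, context=None):
    # Async handler; the router awaits the coroutine and unwraps 'response'.
    return {"response": f"Support ticket opened for: {query}"}

async def main():
    router = LLMRouter(llm=LLMFactory.auto_detect())  # any object with .complete() works
    router.add_route("billing", description="Invoices, charges, refunds", handler=handle_billing)
    router.add_route("support", description="Bugs, crashes, troubleshooting", handler=handle_support)

    result = await router.route("I was charged twice this month")
    print(result["route"], result["confidence"])  # e.g. billing 0.95
    print(result["response"])

asyncio.run(main())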