RouteKitAI 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- routekitai/__init__.py +53 -0
- routekitai/cli/__init__.py +18 -0
- routekitai/cli/main.py +40 -0
- routekitai/cli/replay.py +80 -0
- routekitai/cli/run.py +95 -0
- routekitai/cli/serve.py +966 -0
- routekitai/cli/test_agent.py +178 -0
- routekitai/cli/trace.py +209 -0
- routekitai/cli/trace_analyze.py +120 -0
- routekitai/cli/trace_search.py +126 -0
- routekitai/core/__init__.py +58 -0
- routekitai/core/agent.py +325 -0
- routekitai/core/errors.py +49 -0
- routekitai/core/hooks.py +174 -0
- routekitai/core/memory.py +54 -0
- routekitai/core/message.py +132 -0
- routekitai/core/model.py +91 -0
- routekitai/core/policies.py +373 -0
- routekitai/core/policy.py +85 -0
- routekitai/core/policy_adapter.py +133 -0
- routekitai/core/runtime.py +1403 -0
- routekitai/core/tool.py +148 -0
- routekitai/core/tools.py +180 -0
- routekitai/evals/__init__.py +13 -0
- routekitai/evals/dataset.py +75 -0
- routekitai/evals/metrics.py +101 -0
- routekitai/evals/runner.py +184 -0
- routekitai/graphs/__init__.py +12 -0
- routekitai/graphs/executors.py +457 -0
- routekitai/graphs/graph.py +164 -0
- routekitai/memory/__init__.py +13 -0
- routekitai/memory/episodic.py +242 -0
- routekitai/memory/kv.py +34 -0
- routekitai/memory/retrieval.py +192 -0
- routekitai/memory/vector.py +700 -0
- routekitai/memory/working.py +66 -0
- routekitai/message.py +29 -0
- routekitai/model.py +48 -0
- routekitai/observability/__init__.py +21 -0
- routekitai/observability/analyzer.py +314 -0
- routekitai/observability/exporters/__init__.py +10 -0
- routekitai/observability/exporters/base.py +30 -0
- routekitai/observability/exporters/jsonl.py +81 -0
- routekitai/observability/exporters/otel.py +119 -0
- routekitai/observability/spans.py +111 -0
- routekitai/observability/streaming.py +117 -0
- routekitai/observability/trace.py +144 -0
- routekitai/providers/__init__.py +9 -0
- routekitai/providers/anthropic.py +227 -0
- routekitai/providers/azure_openai.py +243 -0
- routekitai/providers/local.py +196 -0
- routekitai/providers/openai.py +321 -0
- routekitai/py.typed +0 -0
- routekitai/sandbox/__init__.py +12 -0
- routekitai/sandbox/filesystem.py +131 -0
- routekitai/sandbox/network.py +142 -0
- routekitai/sandbox/permissions.py +70 -0
- routekitai/tool.py +33 -0
- routekitai-0.1.0.dist-info/METADATA +328 -0
- routekitai-0.1.0.dist-info/RECORD +64 -0
- routekitai-0.1.0.dist-info/WHEEL +5 -0
- routekitai-0.1.0.dist-info/entry_points.txt +2 -0
- routekitai-0.1.0.dist-info/licenses/LICENSE +21 -0
- routekitai-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""Message primitive for RouteKit."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MessageRole(str, Enum):
    """Message role types.

    Mixes in ``str`` so members serialize and compare as plain strings
    (e.g. when dumped to JSON by pydantic).
    """

    SYSTEM = "system"  # system/instruction messages
    USER = "user"  # end-user input
    ASSISTANT = "assistant"  # model output; may carry tool_calls
    TOOL = "tool"  # result of a tool execution fed back to the model
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Message(BaseModel):
    """Canonical message schema for RouteKit.

    Instances are normally created through the role-specific classmethod
    constructors (`system`, `user`, `assistant`, `tool`) rather than the
    raw pydantic constructor.
    """

    role: MessageRole = Field(..., description="Message role")
    content: str = Field(..., description="Message content")
    tool_calls: list[dict[str, Any]] | None = Field(
        default=None, description="Tool calls made by the assistant"
    )
    tool_result: dict[str, Any] | None = Field(
        default=None, description="Result from a tool call (for tool role messages)"
    )
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    @classmethod
    def _build(
        cls,
        role: MessageRole,
        content: str,
        metadata: dict[str, Any],
        **fields: Any,
    ) -> "Message":
        # Shared factory behind the role-specific constructors below.
        return cls(role=role, content=content, metadata=metadata, **fields)

    @classmethod
    def system(cls, content: str, **metadata: Any) -> "Message":
        """Create a system-role message.

        Args:
            content: Message content
            **metadata: Additional metadata

        Returns:
            System message
        """
        return cls._build(MessageRole.SYSTEM, content, metadata)

    @classmethod
    def user(cls, content: str, **metadata: Any) -> "Message":
        """Create a user-role message.

        Args:
            content: Message content
            **metadata: Additional metadata

        Returns:
            User message
        """
        return cls._build(MessageRole.USER, content, metadata)

    @classmethod
    def assistant(
        cls,
        content: str,
        tool_calls: list[dict[str, Any]] | None = None,
        **metadata: Any,
    ) -> "Message":
        """Create an assistant-role message.

        Args:
            content: Message content
            tool_calls: Optional tool calls requested by the assistant
            **metadata: Additional metadata

        Returns:
            Assistant message
        """
        return cls._build(MessageRole.ASSISTANT, content, metadata, tool_calls=tool_calls)

    @classmethod
    def tool(cls, content: str, tool_result: dict[str, Any], **metadata: Any) -> "Message":
        """Create a tool-role message carrying a tool execution result.

        Args:
            content: Message content
            tool_result: Tool execution result
            **metadata: Additional metadata

        Returns:
            Tool message
        """
        return cls._build(MessageRole.TOOL, content, metadata, tool_result=tool_result)

    def to_dict(self) -> dict[str, Any]:
        """Serialize this message to a plain dictionary.

        Returns:
            Dictionary representation (pydantic ``model_dump``)
        """
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Message":
        """Deserialize a message from a dictionary.

        Args:
            data: Dictionary representation

        Returns:
            Message instance
        """
        return cls(**data)

    def to_json(self) -> str:
        """Serialize this message to a JSON string.

        Returns:
            JSON string representation
        """
        return self.model_dump_json()

    @classmethod
    def from_json(cls, json_str: str) -> "Message":
        """Deserialize a message from a JSON string.

        Args:
            json_str: JSON string representation

        Returns:
            Message instance
        """
        return cls.model_validate_json(json_str)
|
routekitai/core/model.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Model primitive for RouteKit."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections.abc import AsyncIterator
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field
|
|
8
|
+
|
|
9
|
+
from routekitai.core.message import Message
|
|
10
|
+
from routekitai.core.tool import Tool
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Usage(BaseModel):
    """Token usage and cost information."""

    # Counts as reported by the provider; all default to 0 when unknown.
    prompt_tokens: int = Field(default=0, description="Number of prompt tokens")
    completion_tokens: int = Field(default=0, description="Number of completion tokens")
    total_tokens: int = Field(default=0, description="Total number of tokens")
    # None when no cost estimate is available for the provider/model.
    cost: float | None = Field(default=None, description="Estimated cost in USD")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ToolCall(BaseModel):
    """Represents a tool call from the model."""

    # Provider-assigned identifier used to correlate the eventual tool result.
    id: str = Field(..., description="Tool call ID")
    name: str = Field(..., description="Tool name")
    # Arguments already parsed into a dict (not a raw JSON string).
    arguments: dict[str, Any] = Field(..., description="Tool arguments as dictionary")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ModelResponse(BaseModel):
    """Standard response from a model."""

    content: str = Field(..., description="Response content")
    # None when the model produced a plain-text answer with no tool use.
    tool_calls: list[ToolCall] | None = Field(
        default=None, description="Tool calls requested by the model"
    )
    # None when the provider did not report usage.
    usage: Usage | None = Field(default=None, description="Token usage information")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class StreamEvent(BaseModel):
    """Event emitted during streaming."""

    # Free-form event discriminator; concrete values are producer-defined here.
    type: str = Field(..., description="Event type")
    # Each field below is populated only for the event types that carry it.
    content: str | None = Field(default=None, description="Content chunk")
    tool_calls: list[ToolCall] | None = Field(default=None, description="Tool calls")
    usage: Usage | None = Field(default=None, description="Usage information")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class Model(ABC):
    """Abstract base class for provider-agnostic model interface."""

    @property
    def name(self) -> str:
        """Return the model name.

        Returns:
            The subclass-provided ``_name`` string when present,
            otherwise the concrete class name.
        """
        # Subclasses (e.g. FakeModel) may expose a `_name` string attribute.
        candidate = getattr(self, "_name", None)
        if isinstance(candidate, str):
            return candidate
        # Fall back to the concrete class name.
        return str(type(self).__name__)

    @abstractmethod
    async def chat(
        self,
        messages: list[Message],
        tools: list[Tool] | None = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> ModelResponse | AsyncIterator[StreamEvent]:
        """Chat with the model.

        Args:
            messages: List of messages in the conversation
            tools: Optional list of tools available to the model
            stream: Whether to stream the response
            **kwargs: Additional model-specific parameters

        Returns:
            ModelResponse if stream=False, AsyncIterator[StreamEvent] if stream=True

        Raises:
            ModelError: If the model operation fails
        """
        raise NotImplementedError("Subclasses must implement chat")
|
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
"""Concrete policy implementations for RouteKit."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
from routekitai.core.agent import Agent
|
|
8
|
+
from routekitai.core.message import Message, MessageRole
|
|
9
|
+
from routekitai.core.model import ModelResponse
|
|
10
|
+
from routekitai.core.policy import (
|
|
11
|
+
Action,
|
|
12
|
+
Final,
|
|
13
|
+
ModelAction,
|
|
14
|
+
Parallel,
|
|
15
|
+
Policy,
|
|
16
|
+
ToolAction,
|
|
17
|
+
)
|
|
18
|
+
from routekitai.core.runtime import Runtime
|
|
19
|
+
from routekitai.graphs.graph import Graph
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ReActPolicy(Policy):
    """ReAct (Reasoning + Acting) policy.

    Simple loop: model -> decide tool -> tool -> model -> final
    """

    max_iterations: int = 10

    async def plan(self, state: dict[str, Any]) -> list[Action]:
        """Plan next action in ReAct loop.

        Args:
            state: Current state with agent, messages, etc.

        Returns:
            List of actions (single action per step)
        """
        history: list[Message] = state.get("messages", [])
        step_count: int = state.get("iteration", 0)
        tail = history[-1] if history else None

        # Budget exhausted: wrap up with the last assistant reply if we have one.
        if step_count >= self.max_iterations:
            if tail is not None and tail.role == MessageRole.ASSISTANT:
                return [Final(output=tail)]
            return [Final(output=Message.assistant("Max iterations reached"))]

        # Nothing from the assistant yet -> ask the model.
        if tail is None or tail.role != MessageRole.ASSISTANT:
            return [ModelAction(messages=history)]

        # Assistant requested tools -> execute them (fan out when several).
        if tail.tool_calls:
            pending: list[Action] = []
            for call in tail.tool_calls:
                pending.append(
                    ToolAction(
                        tool_name=call["name"],
                        tool_input=call.get("arguments", {}),
                    )
                )
            return [Parallel(actions=pending)] if len(pending) > 1 else pending

        # Plain assistant reply with no tool calls -> done.
        return [Final(output=tail)]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class FunctionCallingPolicy(Policy):
    """Strict function calling policy.

    Only allows tool calls that match available tool schemas.
    No free-form tool names.
    """

    max_iterations: int = 10

    async def plan(self, state: dict[str, Any]) -> list[Action]:
        """Plan next action with strict function calling.

        Args:
            state: Current state

        Returns:
            List of actions
        """
        agent: Agent = state["agent"]
        history: list[Message] = state.get("messages", [])
        step_count: int = state.get("iteration", 0)

        # Budget exhausted: finalize with the last assistant reply if present.
        if step_count >= self.max_iterations:
            tail = history[-1] if history else None
            if tail is not None and tail.role == MessageRole.ASSISTANT:
                return [Final(output=tail)]
            return [Final(output=Message.assistant("Max iterations reached"))]

        # Empty conversation: bootstrap with a default system-style prompt.
        if not history:
            bootstrap = ModelAction(
                messages=[], prompt="You are a helpful assistant with access to tools."
            )
            return [bootstrap]

        tail = history[-1]
        if tail.role != MessageRole.ASSISTANT:
            return [ModelAction(messages=history)]

        if tail.tool_calls:
            # Strict mode: silently drop calls whose name is not a registered tool.
            known = {tool.name for tool in agent.tools}
            approved: list[Action] = [
                ToolAction(
                    tool_name=call["name"],
                    tool_input=call.get("arguments", {}),
                )
                for call in tail.tool_calls
                if call["name"] in known
            ]

            if not approved:
                # Every requested tool was unknown -> nothing to run.
                return [Final(output=Message.assistant("No valid tools available"))]

            return [Parallel(actions=approved)] if len(approved) > 1 else approved

        # Assistant reply with no tool calls -> done.
        return [Final(output=tail)]
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class GraphPolicy(Policy, BaseModel):
    """Graph-based policy that delegates to graphs module execution."""

    model_config = {"arbitrary_types_allowed": True}

    graph: Graph = Field(..., description="Graph to execute")
    runtime: Runtime | None = Field(default=None, description="Runtime for graph execution")

    async def plan(self, state: dict[str, Any]) -> list[Action]:
        """Run the configured graph to completion and wrap its output.

        Args:
            state: Current state

        Returns:
            A single Final action carrying the graph's output

        Raises:
            ValueError: If no Runtime is available in state or on the policy
        """
        # Function-scope import — presumably to avoid an import cycle with
        # the graphs package; confirm before moving to module level.
        from routekitai.graphs.executors import GraphExecutor

        # Prefer the runtime supplied in state; fall back to the configured one.
        active_runtime = state.get("runtime") or self.runtime
        if not active_runtime:
            raise ValueError("GraphPolicy requires a Runtime instance")

        runner = GraphExecutor(runtime=active_runtime, graph=self.graph)

        # Derive graph input: the newest user message wins, else explicit input_data.
        conversation = state.get("messages", [])
        if conversation:
            latest_user = None
            for entry in reversed(conversation):
                if hasattr(entry, "role") and entry.role.value == "user":
                    latest_user = entry
                    break
            text = None
            if latest_user is not None:
                text = latest_user.content if hasattr(latest_user, "content") else str(latest_user)
            payload = {"prompt": text or ""}
        else:
            payload = state.get("input_data", {})

        outcome = await runner.execute(input_data=payload)

        # The executor reports results under "state"; accept either output key.
        terminal = outcome.get("state", {})
        summary = terminal.get("output") or terminal.get(
            "final_output", "Graph execution completed"
        )
        return [Final(output=Message.assistant(str(summary)))]
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class PlanExecutePolicy(Policy):
    """Plan-Execute policy: plan steps, then execute them.

    State machine over ``state["phase"]``: "planning" asks the model for a
    numbered plan; "executing" runs one planned step per iteration until
    ``current_step`` walks past the plan.
    """

    # NOTE(review): max_plan_steps is currently unused in this class — the
    # extracted plan is not truncated to it. Kept for interface compatibility.
    max_plan_steps: int = 10
    max_iterations: int = 20

    async def plan(self, state: dict[str, Any]) -> list[Action]:
        """Plan execution steps.

        Args:
            state: Current state

        Returns:
            List of actions
        """
        messages: list[Message] = state.get("messages", [])
        phase: str = state.get("phase", "planning")
        iteration: int = state.get("iteration", 0)

        if iteration >= self.max_iterations:
            return [Final(output=Message.assistant("Max iterations reached"))]

        if phase == "planning":
            # Planning phase: ask the model to create a numbered plan.
            planning_prompt = (
                "Create a step-by-step plan to solve this task. List the steps clearly."
            )
            if messages:
                # Prepend the original task so the planner has context.
                planning_prompt = f"{messages[0].content}\n\n{planning_prompt}"

            return [ModelAction(messages=[Message.user(planning_prompt)])]

        elif phase == "executing":
            # Execution phase: run the current planned step.
            plan: list[str] = state.get("plan", [])
            current_step: int = state.get("current_step", 0)

            if current_step >= len(plan):
                # All steps executed.
                return [Final(output=Message.assistant("Plan execution completed"))]

            step = plan[current_step]
            return [ModelAction(messages=[Message.user(f"Execute step: {step}")])]

        # Unknown phase: default back to a plain model call.
        return [ModelAction(messages=messages)]

    async def reflect(self, state: dict[str, Any], observation: dict[str, Any]) -> dict[str, Any]:
        """Reflect and update state based on observation.

        Args:
            state: Current state
            observation: Observation from last action

        Returns:
            Updated state
        """
        state = state.copy()
        state.setdefault("observations", []).append(observation)
        state["iteration"] = state.get("iteration", 0) + 1

        # Extract the plan from the model response while in the planning phase.
        if state.get("phase") == "planning" and "result" in observation:
            result = observation["result"]
            if isinstance(result, ModelResponse):
                # Simple plan extraction: keep lines that start with a digit.
                plan_lines = [
                    line.strip()
                    for line in result.content.split("\n")
                    if line.strip() and line.strip()[0].isdigit()
                ]
                if plan_lines:
                    state["plan"] = plan_lines
                    state["phase"] = "executing"
                    state["current_step"] = 0
        # BUG FIX: this was a second independent `if`, so the very call that
        # transitioned planning -> executing (setting current_step = 0) also
        # incremented current_step to 1, permanently skipping the first plan
        # step. With `elif`, the step counter only advances on reflections
        # that entered in the executing phase — i.e. after a step actually ran.
        elif state.get("phase") == "executing":
            state["current_step"] = state.get("current_step", 0) + 1

        return state
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
class SupervisorPolicy(Policy, BaseModel):
    """Supervisor policy for multi-agent coordination.

    Supervisor delegates tasks to sub-agents with constrained toolsets and merges results.

    Delegation protocol: when a sub-agent is chosen, plan() emits a ModelAction
    whose single user message is "DELEGATE:<agent>:<task>" — a sentinel that a
    runtime adapter is expected to intercept and turn into a sub-agent run
    (the adapter is not visible in this module).
    """

    model_config = {"arbitrary_types_allowed": True}

    sub_agents: dict[str, Agent] = Field(
        default_factory=dict, description="Sub-agents available for delegation"
    )
    # NOTE(review): typed Any (not Runtime) — presumably to avoid a circular
    # import with runtime.py; confirm.
    runtime: Any = Field(default=None, description="Runtime for executing sub-agents")
    max_iterations: int = Field(default=20, description="Maximum iterations")
    delegation_keywords: dict[str, list[str]] = Field(
        default_factory=dict, description="Keywords to identify which agent to delegate to"
    )

    async def plan(self, state: dict[str, Any]) -> list[Action]:
        """Plan using supervisor delegation.

        Args:
            state: Current state (should not be mutated directly)

        Returns:
            List of actions
        """
        messages: list[Message] = state.get("messages", [])
        iteration: int = state.get("iteration", 0)
        runtime: Runtime | None = state.get("runtime") or self.runtime

        if iteration >= self.max_iterations:
            return [Final(output=Message.assistant("Max iterations reached"))]

        # Check if we're waiting for a sub-agent result
        if state.get("waiting_for_subagent"):
            # Sub-agent has completed, merge result
            subagent_result = state.get("subagent_result")
            # NOTE: if subagent_result is falsy we fall through to the
            # generic branches below rather than handling it here.
            if subagent_result:
                # Create message with sub-agent result; dict results use their
                # "output" key, anything else is stringified.
                result_message = Message.assistant(
                    f"Sub-agent completed: {subagent_result.get('output', 'Task completed') if isinstance(subagent_result, dict) else str(subagent_result)}"
                )
                # Note: State mutations should be done via reflect(), not here
                # But we need to signal completion, so we'll let the runtime handle it
                # Supervisor processes the result
                return [ModelAction(messages=[*messages, result_message])]

        # If no messages, supervisor decides which agent to use
        if not messages:
            # Supervisor prompt to choose agent
            agent_list = ", ".join(self.sub_agents.keys())
            prompt = (
                f"You are a supervisor coordinating multiple agents. "
                f"Available agents: {agent_list}. "
                f"Analyze the task and delegate to the appropriate agent. "
                f"Respond with the agent name you want to delegate to."
            )
            return [ModelAction(messages=[Message.user(prompt)])]

        # Check if supervisor has delegated
        last_message = messages[-1]
        if last_message.role == MessageRole.ASSISTANT and not state.get("delegated_agent"):
            # Supervisor responded - check if it's a delegation
            content = last_message.content.lower()
            delegated_agent = None

            # Check for explicit agent mentions (first keyword match wins,
            # in dict insertion order)
            for agent_name, keywords in self.delegation_keywords.items():
                if any(keyword.lower() in content for keyword in keywords):
                    delegated_agent = agent_name
                    break

            # Fallback: check if agent name appears in content
            if not delegated_agent:
                for agent_name in self.sub_agents.keys():
                    if agent_name.lower() in content:
                        delegated_agent = agent_name
                        break

            if delegated_agent and delegated_agent in self.sub_agents:
                # Delegate to sub-agent
                # NOTE(review): these writes mutate the caller's state dict
                # in place, contradicting the docstring above ("should not be
                # mutated directly"); the runtime appears to rely on it —
                # confirm before changing.
                state["delegated_agent"] = delegated_agent
                state["waiting_for_subagent"] = True

                # Extract task from original prompt
                original_prompt = messages[0].content if messages else "Complete the task"

                # Return a special action that will trigger sub-agent execution
                # This is handled by the runtime adapter
                return [
                    ModelAction(
                        messages=[Message.user(f"DELEGATE:{delegated_agent}:{original_prompt}")]
                    )
                ]

        # Check if we need to execute sub-agent (handled by adapter)
        if state.get("delegated_agent") and state.get("waiting_for_subagent") and runtime:
            # This will be handled by the adapter
            pass

        # Default: continue with supervisor
        return [ModelAction(messages=messages)]
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Policy system for routkitai agent execution."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
|
|
8
|
+
from routekitai.core.message import Message
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Action(BaseModel):
    """Base class for execution actions."""

    # Discriminator string; each subclass defaults it ("model", "tool",
    # "parallel", "final").
    action_type: str = Field(..., description="Action type")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ModelAction(Action):
    """Action to call the model."""

    action_type: str = Field(default="model", description="Action type")
    messages: list[Message] = Field(..., description="Messages to send to model")
    # Optional extra prompt; some policies use it to bootstrap an empty conversation.
    prompt: str | None = Field(default=None, description="Optional prompt string")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ToolAction(Action):
    """Action to execute a tool."""

    action_type: str = Field(default="tool", description="Action type")
    # Must match a registered tool's name; validation is up to the policy/runtime.
    tool_name: str = Field(..., description="Tool name to execute")
    tool_input: dict[str, Any] = Field(..., description="Tool input arguments")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Parallel(Action):
    """Action to execute multiple tool actions in parallel."""

    action_type: str = Field(default="parallel", description="Action type")
    # NOTE(review): typed list[Action], but the policies in this package only
    # ever nest ToolActions here — confirm whether other action kinds are
    # supported by the executor.
    actions: list[Action] = Field(
        ..., description="Actions to execute in parallel (typically ToolActions)"
    )
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Final(Action):
    """Action to finalize execution with output.

    Returning a Final from a policy's plan() terminates the run.
    """

    action_type: str = Field(default="final", description="Action type")
    output: Message = Field(..., description="Final output message")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class Policy(ABC):
    """Policy interface for agent execution.

    A policy determines what actions to take based on the current state.
    """

    @abstractmethod
    async def plan(self, state: dict[str, Any]) -> list[Action]:
        """Plan next actions based on current state.

        Args:
            state: Current execution state containing:
                - agent: Agent instance
                - messages: List of conversation messages
                - tools: Available tools
                - memory: Memory instance (if available)
                - metadata: Additional state metadata

        Returns:
            List of actions to execute
        """
        raise NotImplementedError("Subclasses must implement plan")

    async def reflect(self, state: dict[str, Any], observation: dict[str, Any]) -> dict[str, Any]:
        """Reflect on observation and update state.

        Default behavior: append the observation to the state's
        "observations" list on a shallow copy of the state.

        Args:
            state: Current state
            observation: Observation from last action (result, error, etc.)

        Returns:
            Updated state

        Note:
            The copy is shallow — when an "observations" list already exists
            it is shared with (and mutated in) the caller's state.
        """
        updated = dict(state)
        updated.setdefault("observations", []).append(observation)
        return updated
|