synapsekit 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synapsekit/__init__.py +158 -0
- synapsekit/_compat.py +29 -0
- synapsekit/agents/__init__.py +33 -0
- synapsekit/agents/base.py +58 -0
- synapsekit/agents/executor.py +83 -0
- synapsekit/agents/function_calling.py +123 -0
- synapsekit/agents/memory.py +47 -0
- synapsekit/agents/react.py +147 -0
- synapsekit/agents/registry.py +42 -0
- synapsekit/agents/tools/__init__.py +13 -0
- synapsekit/agents/tools/calculator.py +72 -0
- synapsekit/agents/tools/file_read.py +42 -0
- synapsekit/agents/tools/python_repl.py +55 -0
- synapsekit/agents/tools/sql_query.py +106 -0
- synapsekit/agents/tools/web_search.py +59 -0
- synapsekit/embeddings/__init__.py +3 -0
- synapsekit/embeddings/backend.py +44 -0
- synapsekit/graph/__init__.py +26 -0
- synapsekit/graph/checkpointers/__init__.py +9 -0
- synapsekit/graph/checkpointers/base.py +23 -0
- synapsekit/graph/checkpointers/memory.py +26 -0
- synapsekit/graph/checkpointers/sqlite.py +41 -0
- synapsekit/graph/compiled.py +164 -0
- synapsekit/graph/edge.py +22 -0
- synapsekit/graph/errors.py +9 -0
- synapsekit/graph/graph.py +125 -0
- synapsekit/graph/mermaid.py +31 -0
- synapsekit/graph/node.py +34 -0
- synapsekit/graph/state.py +9 -0
- synapsekit/llm/__init__.py +34 -0
- synapsekit/llm/_cache.py +52 -0
- synapsekit/llm/_retry.py +44 -0
- synapsekit/llm/anthropic.py +125 -0
- synapsekit/llm/base.py +158 -0
- synapsekit/llm/bedrock.py +97 -0
- synapsekit/llm/cohere.py +45 -0
- synapsekit/llm/gemini.py +123 -0
- synapsekit/llm/mistral.py +74 -0
- synapsekit/llm/ollama.py +46 -0
- synapsekit/llm/openai.py +95 -0
- synapsekit/loaders/__init__.py +34 -0
- synapsekit/loaders/base.py +9 -0
- synapsekit/loaders/csv.py +35 -0
- synapsekit/loaders/directory.py +57 -0
- synapsekit/loaders/html.py +23 -0
- synapsekit/loaders/json_loader.py +38 -0
- synapsekit/loaders/pdf.py +23 -0
- synapsekit/loaders/text.py +31 -0
- synapsekit/loaders/web.py +44 -0
- synapsekit/memory/__init__.py +3 -0
- synapsekit/memory/conversation.py +38 -0
- synapsekit/observability/__init__.py +3 -0
- synapsekit/observability/tracer.py +70 -0
- synapsekit/parsers/__init__.py +5 -0
- synapsekit/parsers/json_parser.py +26 -0
- synapsekit/parsers/list_parser.py +16 -0
- synapsekit/parsers/pydantic_parser.py +23 -0
- synapsekit/prompts/__init__.py +3 -0
- synapsekit/prompts/template.py +45 -0
- synapsekit/py.typed +0 -0
- synapsekit/rag/__init__.py +4 -0
- synapsekit/rag/facade.py +187 -0
- synapsekit/rag/pipeline.py +98 -0
- synapsekit/retrieval/__init__.py +5 -0
- synapsekit/retrieval/base.py +23 -0
- synapsekit/retrieval/chroma.py +68 -0
- synapsekit/retrieval/faiss.py +72 -0
- synapsekit/retrieval/pinecone.py +53 -0
- synapsekit/retrieval/qdrant.py +76 -0
- synapsekit/retrieval/retriever.py +65 -0
- synapsekit/retrieval/vectorstore.py +83 -0
- synapsekit/text_splitters/__init__.py +13 -0
- synapsekit/text_splitters/base.py +12 -0
- synapsekit/text_splitters/character.py +63 -0
- synapsekit/text_splitters/recursive.py +68 -0
- synapsekit/text_splitters/semantic.py +73 -0
- synapsekit/text_splitters/token.py +33 -0
- synapsekit-0.5.0.dist-info/METADATA +268 -0
- synapsekit-0.5.0.dist-info/RECORD +81 -0
- synapsekit-0.5.0.dist-info/WHEEL +4 -0
- synapsekit-0.5.0.dist-info/licenses/LICENSE +21 -0
synapsekit/__init__.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SynapseKit — lightweight, async-first RAG framework.
|
|
3
|
+
|
|
4
|
+
3-line happy path:
|
|
5
|
+
|
|
6
|
+
from synapsekit import RAG
|
|
7
|
+
|
|
8
|
+
rag = RAG(model="gpt-4o-mini", api_key="sk-...")
|
|
9
|
+
rag.add("Your document text here")
|
|
10
|
+
|
|
11
|
+
async for token in rag.stream("What is the main topic?"):
|
|
12
|
+
print(token, end="", flush=True)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from .agents import (
|
|
18
|
+
AgentConfig,
|
|
19
|
+
AgentExecutor,
|
|
20
|
+
AgentMemory,
|
|
21
|
+
AgentStep,
|
|
22
|
+
BaseTool,
|
|
23
|
+
CalculatorTool,
|
|
24
|
+
FileReadTool,
|
|
25
|
+
FunctionCallingAgent,
|
|
26
|
+
PythonREPLTool,
|
|
27
|
+
ReActAgent,
|
|
28
|
+
SQLQueryTool,
|
|
29
|
+
ToolRegistry,
|
|
30
|
+
ToolResult,
|
|
31
|
+
WebSearchTool,
|
|
32
|
+
)
|
|
33
|
+
from .embeddings.backend import SynapsekitEmbeddings
|
|
34
|
+
from .graph import (
|
|
35
|
+
END,
|
|
36
|
+
BaseCheckpointer,
|
|
37
|
+
CompiledGraph,
|
|
38
|
+
ConditionalEdge,
|
|
39
|
+
ConditionFn,
|
|
40
|
+
Edge,
|
|
41
|
+
GraphConfigError,
|
|
42
|
+
GraphRuntimeError,
|
|
43
|
+
GraphState,
|
|
44
|
+
InMemoryCheckpointer,
|
|
45
|
+
Node,
|
|
46
|
+
NodeFn,
|
|
47
|
+
SQLiteCheckpointer,
|
|
48
|
+
StateGraph,
|
|
49
|
+
agent_node,
|
|
50
|
+
rag_node,
|
|
51
|
+
)
|
|
52
|
+
from .llm.base import BaseLLM, LLMConfig
|
|
53
|
+
from .loaders.base import Document
|
|
54
|
+
from .loaders.csv import CSVLoader
|
|
55
|
+
from .loaders.directory import DirectoryLoader
|
|
56
|
+
from .loaders.html import HTMLLoader
|
|
57
|
+
from .loaders.json_loader import JSONLoader
|
|
58
|
+
from .loaders.pdf import PDFLoader
|
|
59
|
+
from .loaders.text import StringLoader, TextLoader
|
|
60
|
+
from .loaders.web import WebLoader
|
|
61
|
+
from .memory.conversation import ConversationMemory
|
|
62
|
+
from .observability.tracer import TokenTracer
|
|
63
|
+
from .parsers.json_parser import JSONParser
|
|
64
|
+
from .parsers.list_parser import ListParser
|
|
65
|
+
from .parsers.pydantic_parser import PydanticParser
|
|
66
|
+
from .prompts.template import ChatPromptTemplate, FewShotPromptTemplate, PromptTemplate
|
|
67
|
+
from .rag.facade import RAG
|
|
68
|
+
from .rag.pipeline import RAGConfig, RAGPipeline
|
|
69
|
+
from .retrieval.base import VectorStore
|
|
70
|
+
from .retrieval.retriever import Retriever
|
|
71
|
+
from .retrieval.vectorstore import InMemoryVectorStore
|
|
72
|
+
from .text_splitters import (
|
|
73
|
+
BaseSplitter,
|
|
74
|
+
CharacterTextSplitter,
|
|
75
|
+
RecursiveCharacterTextSplitter,
|
|
76
|
+
SemanticSplitter,
|
|
77
|
+
TokenAwareSplitter,
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
__version__ = "0.5.0"
|
|
81
|
+
__all__ = [
|
|
82
|
+
# Facade
|
|
83
|
+
"RAG",
|
|
84
|
+
# Pipeline
|
|
85
|
+
"RAGPipeline",
|
|
86
|
+
"RAGConfig",
|
|
87
|
+
# LLM
|
|
88
|
+
"BaseLLM",
|
|
89
|
+
"LLMConfig",
|
|
90
|
+
# Embeddings
|
|
91
|
+
"SynapsekitEmbeddings",
|
|
92
|
+
# Vector stores
|
|
93
|
+
"VectorStore",
|
|
94
|
+
"InMemoryVectorStore",
|
|
95
|
+
# Retrieval
|
|
96
|
+
"Retriever",
|
|
97
|
+
# Memory / observability
|
|
98
|
+
"ConversationMemory",
|
|
99
|
+
"TokenTracer",
|
|
100
|
+
# Loaders
|
|
101
|
+
"Document",
|
|
102
|
+
"TextLoader",
|
|
103
|
+
"StringLoader",
|
|
104
|
+
"PDFLoader",
|
|
105
|
+
"HTMLLoader",
|
|
106
|
+
"CSVLoader",
|
|
107
|
+
"JSONLoader",
|
|
108
|
+
"DirectoryLoader",
|
|
109
|
+
"WebLoader",
|
|
110
|
+
# Parsers
|
|
111
|
+
"JSONParser",
|
|
112
|
+
"PydanticParser",
|
|
113
|
+
"ListParser",
|
|
114
|
+
# Prompts
|
|
115
|
+
"PromptTemplate",
|
|
116
|
+
"ChatPromptTemplate",
|
|
117
|
+
"FewShotPromptTemplate",
|
|
118
|
+
# Agents
|
|
119
|
+
"BaseTool",
|
|
120
|
+
"ToolResult",
|
|
121
|
+
"ToolRegistry",
|
|
122
|
+
"AgentMemory",
|
|
123
|
+
"AgentStep",
|
|
124
|
+
"ReActAgent",
|
|
125
|
+
"FunctionCallingAgent",
|
|
126
|
+
"AgentExecutor",
|
|
127
|
+
"AgentConfig",
|
|
128
|
+
# Built-in tools
|
|
129
|
+
"CalculatorTool",
|
|
130
|
+
"FileReadTool",
|
|
131
|
+
"PythonREPLTool",
|
|
132
|
+
"SQLQueryTool",
|
|
133
|
+
"WebSearchTool",
|
|
134
|
+
# Text splitters
|
|
135
|
+
"BaseSplitter",
|
|
136
|
+
"CharacterTextSplitter",
|
|
137
|
+
"RecursiveCharacterTextSplitter",
|
|
138
|
+
"TokenAwareSplitter",
|
|
139
|
+
"SemanticSplitter",
|
|
140
|
+
# Graph workflows
|
|
141
|
+
"END",
|
|
142
|
+
"GraphState",
|
|
143
|
+
"GraphConfigError",
|
|
144
|
+
"GraphRuntimeError",
|
|
145
|
+
"Node",
|
|
146
|
+
"NodeFn",
|
|
147
|
+
"agent_node",
|
|
148
|
+
"rag_node",
|
|
149
|
+
"Edge",
|
|
150
|
+
"ConditionalEdge",
|
|
151
|
+
"ConditionFn",
|
|
152
|
+
"StateGraph",
|
|
153
|
+
"CompiledGraph",
|
|
154
|
+
# Checkpointers
|
|
155
|
+
"BaseCheckpointer",
|
|
156
|
+
"InMemoryCheckpointer",
|
|
157
|
+
"SQLiteCheckpointer",
|
|
158
|
+
]
|
synapsekit/_compat.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import Coroutine
|
|
5
|
+
from typing import Any, TypeVar
|
|
6
|
+
|
|
7
|
+
T = TypeVar("T")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def run_sync(coro: Coroutine[Any, Any, T]) -> T:
    """Execute *coro* to completion and return its result, blocking the caller.

    Safe to call both from plain synchronous code and from code that is
    already inside a running event loop (e.g. a Jupyter cell): in the latter
    case the coroutine is handed to a one-off worker thread that spins up its
    own event loop, which avoids deadlocking the outer loop.
    """
    try:
        # get_running_loop() raises RuntimeError when no loop is active here.
        inside_loop = asyncio.get_running_loop().is_running()
    except RuntimeError:
        inside_loop = False

    if not inside_loop:
        # No active loop in this thread — the simple path.
        return asyncio.run(coro)

    # A loop is already running in this thread; asyncio.run() would raise.
    # Delegate to a dedicated thread with its own fresh loop.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from .base import BaseTool, ToolResult
|
|
2
|
+
from .executor import AgentConfig, AgentExecutor
|
|
3
|
+
from .function_calling import FunctionCallingAgent
|
|
4
|
+
from .memory import AgentMemory, AgentStep
|
|
5
|
+
from .react import ReActAgent
|
|
6
|
+
from .registry import ToolRegistry
|
|
7
|
+
from .tools import (
|
|
8
|
+
CalculatorTool,
|
|
9
|
+
FileReadTool,
|
|
10
|
+
PythonREPLTool,
|
|
11
|
+
SQLQueryTool,
|
|
12
|
+
WebSearchTool,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
# Core
|
|
17
|
+
"BaseTool",
|
|
18
|
+
"ToolResult",
|
|
19
|
+
"ToolRegistry",
|
|
20
|
+
"AgentMemory",
|
|
21
|
+
"AgentStep",
|
|
22
|
+
# Agents
|
|
23
|
+
"ReActAgent",
|
|
24
|
+
"FunctionCallingAgent",
|
|
25
|
+
"AgentExecutor",
|
|
26
|
+
"AgentConfig",
|
|
27
|
+
# Built-in tools
|
|
28
|
+
"CalculatorTool",
|
|
29
|
+
"FileReadTool",
|
|
30
|
+
"PythonREPLTool",
|
|
31
|
+
"SQLQueryTool",
|
|
32
|
+
"WebSearchTool",
|
|
33
|
+
]
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class ToolResult:
    """Outcome of a single tool invocation.

    Attributes:
        output: the tool's textual result (meaningful when no error occurred)
        error: human-readable failure description, or ``None`` on success
    """

    output: str
    error: str | None = None

    @property
    def is_error(self) -> bool:
        """True when the invocation failed."""
        return self.error is not None

    def __str__(self) -> str:
        # Prefer the error text so agents see failures in their observation.
        if self.error is None:
            return self.output
        return self.error
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BaseTool(ABC):
    """Abstract base class for all agent tools.

    Subclasses must set ``name`` and ``description`` as class attributes and
    implement :meth:`run`.  They may also assign ``parameters`` (a JSON
    Schema dict describing the tool's inputs); when absent, an empty object
    schema is substituted by the schema builders below.
    """

    name: str
    description: str

    # JSON Schema for the tool's input parameters.
    # Annotation only — subclasses with non-trivial inputs assign a real dict.
    # BUGFIX: the previous `parameters: dict = field(default_factory=dict)`
    # used dataclasses.field() outside a @dataclass, which left a
    # dataclasses.Field sentinel as the class attribute; getattr() then found
    # that sentinel and schema()/anthropic_schema() embedded it verbatim
    # instead of a dict.  With an annotation-only declaration the attribute
    # is absent and the getattr fallback works as intended.
    parameters: dict

    @abstractmethod
    async def run(self, **kwargs: Any) -> ToolResult:
        """Execute the tool. kwargs come from the parsed Action Input."""
        ...

    def schema(self) -> dict:
        """OpenAI-compatible function-calling schema."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                # Fallback: empty object schema when the subclass set none.
                "parameters": getattr(self, "parameters", {"type": "object", "properties": {}}),
            },
        }

    def anthropic_schema(self) -> dict:
        """Anthropic-compatible tool schema."""
        return {
            "name": self.name,
            "description": self.description,
            "input_schema": getattr(self, "parameters", {"type": "object", "properties": {}}),
        }

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name={self.name!r})"
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncGenerator
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
from .._compat import run_sync
|
|
8
|
+
from ..llm.base import BaseLLM
|
|
9
|
+
from .base import BaseTool
|
|
10
|
+
from .function_calling import FunctionCallingAgent
|
|
11
|
+
from .memory import AgentMemory
|
|
12
|
+
from .react import ReActAgent
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class AgentConfig:
    """Configuration bundle consumed by :class:`AgentExecutor`.

    Notes:
        - ``system_prompt`` is honored only by the function-calling agent;
          the ReAct agent builds its own system prompt from the tool list.
        - ``verbose`` is currently never read by the executor —
          NOTE(review): confirm whether it should gate step logging.
    """

    llm: BaseLLM  # provider-agnostic chat model driving the agent
    tools: list[BaseTool]  # tools made available for the agent to call
    agent_type: Literal["react", "function_calling"] = "react"
    max_iterations: int = 10  # hard cap on Thought/Action cycles per run
    system_prompt: str = "You are a helpful AI assistant."
    verbose: bool = False
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AgentExecutor:
    """
    High-level agent runner. Picks ReActAgent or FunctionCallingAgent based on config.

    Usage::

        executor = AgentExecutor(AgentConfig(
            llm=OpenAILLM(config),
            tools=[CalculatorTool(), WebSearchTool()],
            agent_type="function_calling",
        ))

        answer = await executor.run("What is 2 ** 10?")
        answer = executor.run_sync("What is 2 ** 10?")
    """

    def __init__(self, config: AgentConfig) -> None:
        self.config = config
        self._agent = self._build_agent()

    def _build_agent(self) -> ReActAgent | FunctionCallingAgent:
        """Instantiate the concrete agent selected by ``config.agent_type``."""
        cfg = self.config
        # One shared scratchpad, sized to the iteration budget.
        scratchpad = AgentMemory(max_steps=cfg.max_iterations)
        if cfg.agent_type == "react":
            return ReActAgent(
                llm=cfg.llm,
                tools=cfg.tools,
                max_iterations=cfg.max_iterations,
                memory=scratchpad,
            )
        if cfg.agent_type == "function_calling":
            return FunctionCallingAgent(
                llm=cfg.llm,
                tools=cfg.tools,
                max_iterations=cfg.max_iterations,
                memory=scratchpad,
                system_prompt=cfg.system_prompt,
            )
        raise ValueError(
            f"Unknown agent_type: {cfg.agent_type!r}. "
            "Use 'react' or 'function_calling'."
        )

    async def run(self, query: str) -> str:
        """Async: run agent and return final answer."""
        return await self._agent.run(query)

    async def stream(self, query: str) -> AsyncGenerator[str]:
        """Async: stream final answer tokens."""
        async for token in self._agent.stream(query):
            yield token

    def run_sync(self, query: str) -> str:
        """Sync: run agent (for scripts / notebooks)."""
        return run_sync(self.run(query))

    @property
    def memory(self) -> AgentMemory:
        # Scratchpad of the underlying agent (shared via _build_agent).
        return self._agent.memory
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from ..llm.base import BaseLLM
|
|
8
|
+
from .base import BaseTool
|
|
9
|
+
from .memory import AgentMemory, AgentStep
|
|
10
|
+
from .registry import ToolRegistry
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FunctionCallingAgent:
    """
    Agent that uses native LLM function-calling (OpenAI tool_calls / Anthropic tool_use).

    Falls back gracefully: if the LLM doesn't support call_with_tools(),
    raises RuntimeError with a suggestion to use ReActAgent instead.
    """

    def __init__(
        self,
        llm: BaseLLM,
        tools: list[BaseTool],
        max_iterations: int = 10,
        memory: AgentMemory | None = None,
        system_prompt: str = "You are a helpful AI assistant.",
    ) -> None:
        # llm: must override BaseLLM.call_with_tools (checked lazily in run()).
        # tools: registered by name; their schemas are sent with every call.
        # max_iterations: cap on LLM round-trips per run() call.
        # memory: scratchpad to reuse; a fresh one is created when omitted.
        self._llm = llm
        self._registry = ToolRegistry(tools)
        self._max_iterations = max_iterations
        self._memory = memory or AgentMemory(max_steps=max_iterations)
        self._system_prompt = system_prompt

    def _check_support(self) -> None:
        """Raise RuntimeError unless the LLM class overrides call_with_tools."""
        # Check if the provider has overridden call_with_tools (not just the base NotImplementedError)
        method = getattr(type(self._llm), "call_with_tools", None)
        if method is getattr(BaseLLM, "call_with_tools", None):
            raise RuntimeError(
                f"{type(self._llm).__name__} does not support native function calling. "
                "Use ReActAgent instead, or switch to OpenAILLM / AnthropicLLM / GeminiLLM / MistralLLM."
            )

    async def run(self, query: str) -> str:
        """Run the function-calling loop and return the final answer.

        Alternates call_with_tools() round-trips and local tool executions
        until the model answers without requesting a tool, or the iteration
        budget is exhausted (then a fixed apology string is returned).

        Raises:
            RuntimeError: if the configured LLM never overrode call_with_tools.
        """
        self._check_support()
        self._memory.clear()

        # OpenAI-style chat transcript; tool results are appended below as
        # role="tool" messages keyed by tool_call_id.
        messages: list[dict] = [
            {"role": "system", "content": self._system_prompt},
            {"role": "user", "content": query},
        ]

        tool_schemas = self._registry.schemas()

        for _ in range(self._max_iterations):
            # Expected result shape: {"content": str | None, "tool_calls":
            # [{"id", "name", "arguments": dict}, ...]} — presumably
            # normalized by the provider adapters; confirm against llm/*.py.
            result: dict[str, Any] = await self._llm.call_with_tools(messages, tool_schemas)

            tool_calls = result.get("tool_calls")
            content = result.get("content")

            # No tool calls → final answer
            if not tool_calls:
                return content or ""

            # Append assistant message with tool_calls
            messages.append(
                {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": tc["id"],
                            "type": "function",
                            "function": {
                                "name": tc["name"],
                                # OpenAI wire format carries arguments as a JSON string.
                                "arguments": json.dumps(tc["arguments"]),
                            },
                        }
                        for tc in tool_calls
                    ],
                }
            )

            # Execute each tool and append observations
            for tc in tool_calls:
                try:
                    tool = self._registry.get(tc["name"])
                    tool_result = await tool.run(**tc["arguments"])
                    observation = str(tool_result)
                except KeyError as e:
                    # registry.get raised — most likely an unknown tool name.
                    observation = f"Error: {e}"
                except Exception as e:
                    # Tool crashed; report the error back to the model.
                    observation = f"Tool error: {e}"

                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tc["id"],
                        "content": observation,
                    }
                )

                # Mirror the call into the scratchpad.  thought is empty —
                # native function calling exposes no reasoning text.
                self._memory.add_step(
                    AgentStep(
                        thought="",
                        action=tc["name"],
                        action_input=json.dumps(tc["arguments"]),
                        observation=observation,
                    )
                )

        return "I was unable to complete the task within the allowed number of steps."

    async def stream(self, query: str) -> AsyncGenerator[str]:
        """Stream the final answer (intermediate tool calls run silently)."""
        # Pseudo-streaming: the answer is computed fully first, then
        # re-emitted word by word (each chunk carries a trailing space).
        answer = await self.run(query)
        for word in answer.split(" "):
            yield word + " "

    @property
    def memory(self) -> AgentMemory:
        # Steps recorded during the most recent run().
        return self._memory
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class AgentStep:
    """One complete Thought → Action → Observation cycle."""

    thought: str  # reasoning text parsed from the LLM response (may be "")
    action: str  # tool name the agent chose
    action_input: str  # raw input string handed to the tool
    observation: str  # tool output (or error text) fed back to the LLM
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AgentMemory:
    """Scratchpad recording the Thought/Action/Observation steps of one run."""

    def __init__(self, max_steps: int = 20) -> None:
        # NOTE: max_steps is advisory — add_step never rejects; callers are
        # expected to consult is_full() themselves.
        self._max_steps = max_steps
        self._steps: list[AgentStep] = []

    def add_step(self, step: AgentStep) -> None:
        """Record one completed cycle."""
        self._steps.append(step)

    @property
    def steps(self) -> list[AgentStep]:
        """Defensive copy of the recorded steps."""
        return list(self._steps)

    def format_scratchpad(self) -> str:
        """Render every step in ReAct text form, newline separated."""
        lines: list[str] = []
        for entry in self._steps:
            lines.extend(
                (
                    f"Thought: {entry.thought}",
                    f"Action: {entry.action}",
                    f"Action Input: {entry.action_input}",
                    f"Observation: {entry.observation}",
                )
            )
        return "\n".join(lines)

    def is_full(self) -> bool:
        """True once the step count has reached the max_steps budget."""
        return len(self._steps) >= self._max_steps

    def clear(self) -> None:
        """Drop all recorded steps (done at the start of each agent run)."""
        self._steps.clear()

    def __len__(self) -> int:
        return len(self._steps)
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
|
|
6
|
+
from ..llm.base import BaseLLM
|
|
7
|
+
from .base import BaseTool
|
|
8
|
+
from .memory import AgentMemory, AgentStep
|
|
9
|
+
from .registry import ToolRegistry
|
|
10
|
+
|
|
11
|
+
_REACT_SYSTEM = """\
|
|
12
|
+
You are a helpful AI assistant with access to tools.
|
|
13
|
+
|
|
14
|
+
Available tools:
|
|
15
|
+
{tools}
|
|
16
|
+
|
|
17
|
+
Use EXACTLY this format for every response until you have a final answer:
|
|
18
|
+
|
|
19
|
+
Thought: (your reasoning about what to do next)
|
|
20
|
+
Action: (the exact tool name from the list above)
|
|
21
|
+
Action Input: (the input to pass to the tool, as a plain string)
|
|
22
|
+
|
|
23
|
+
When you have enough information to answer:
|
|
24
|
+
|
|
25
|
+
Thought: I now know the final answer.
|
|
26
|
+
Final Answer: (your complete answer to the original question)
|
|
27
|
+
|
|
28
|
+
Rules:
|
|
29
|
+
- Only use tools from the list above.
|
|
30
|
+
- Never invent tool results — always call the tool and wait for the Observation.
|
|
31
|
+
- Never skip the Thought step.
|
|
32
|
+
- Provide Final Answer only when you are confident.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
# Regexes for pulling the ReAct fields out of a raw LLM completion.
# All are case-insensitive; Thought stops at the next Action/Final Answer
# label, while Action Input and Final Answer run to the end of the text.
_ACTION_RE = re.compile(r"Action:\s*(.+)", re.IGNORECASE)
_ACTION_INPUT_RE = re.compile(r"Action Input:\s*(.+)", re.IGNORECASE | re.DOTALL)
_THOUGHT_RE = re.compile(
    r"Thought:\s*(.+?)(?=\n(?:Action|Final Answer)|$)", re.IGNORECASE | re.DOTALL
)
_FINAL_ANSWER_RE = re.compile(r"Final Answer:\s*(.+)", re.IGNORECASE | re.DOTALL)


def _parse_thought(text: str) -> str:
    """Return the Thought text, or "" when the response carries none."""
    match = _THOUGHT_RE.search(text)
    if match is None:
        return ""
    return match.group(1).strip()


def _parse_action(text: str) -> tuple[str, str]:
    """Return ``(tool_name, tool_input)``; either is "" when absent.

    NOTE(review): the Action Input pattern is DOTALL, so it captures to the
    end of the response — fine while the prompt forbids trailing text after
    Action Input, but worth confirming against real completions.
    """
    name_match = _ACTION_RE.search(text)
    input_match = _ACTION_INPUT_RE.search(text)
    name = "" if name_match is None else name_match.group(1).strip()
    tool_input = "" if input_match is None else input_match.group(1).strip()
    return name, tool_input


def _parse_final_answer(text: str) -> str | None:
    """Return the Final Answer text, or None when the agent is not done."""
    match = _FINAL_ANSWER_RE.search(text)
    if match is None:
        return None
    return match.group(1).strip()
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ReActAgent:
    """
    Reasoning + Acting agent.

    Loops: Thought → Action → Observation → repeat until Final Answer.
    Works with any BaseLLM — no native function-calling required.
    """

    def __init__(
        self,
        llm: BaseLLM,
        tools: list[BaseTool],
        max_iterations: int = 10,
        memory: AgentMemory | None = None,
    ) -> None:
        # llm: any chat model exposing generate_with_messages().
        # tools: looked up by name via ToolRegistry when an Action is parsed.
        # max_iterations: cap on LLM round-trips per run() call.
        # memory: scratchpad to reuse; a fresh one is created when omitted.
        self._llm = llm
        self._registry = ToolRegistry(tools)
        self._max_iterations = max_iterations
        self._memory = memory or AgentMemory(max_steps=max_iterations)

    def _build_system_prompt(self) -> str:
        # Inject the tool list into the fixed ReAct prompt template.
        return _REACT_SYSTEM.format(tools=self._registry.describe())

    def _build_messages(self, query: str) -> list[dict]:
        """Build the chat payload: system prompt, question, scratchpad so far."""
        scratchpad = self._memory.format_scratchpad()
        user_content = f"Question: {query}"
        if scratchpad:
            # Prior Thought/Action/Observation steps keep the LLM in context.
            user_content += f"\n\n{scratchpad}"
        return [
            {"role": "system", "content": self._build_system_prompt()},
            {"role": "user", "content": user_content},
        ]

    async def run(self, query: str) -> str:
        """Run the ReAct loop and return the final answer.

        Clears the scratchpad, then alternates LLM calls and tool executions
        until a "Final Answer:" appears or max_iterations is exhausted (then
        a fixed apology string is returned).
        """
        self._memory.clear()

        for _ in range(self._max_iterations):
            messages = self._build_messages(query)
            response = await self._llm.generate_with_messages(messages)

            # Check for final answer first
            final = _parse_final_answer(response)
            if final is not None:
                return final

            # Parse action
            action_name, action_input = _parse_action(response)
            thought = _parse_thought(response)

            if not action_name:
                # LLM didn't follow format — treat whole response as final answer
                return response.strip()

            # Execute tool
            try:
                tool = self._registry.get(action_name)
                # The raw Action Input string is passed as keyword `input`;
                # tools driven by this agent must accept that keyword.
                result = await tool.run(input=action_input)
                observation = str(result)
            except KeyError as e:
                # registry.get raised — most likely an unknown tool name;
                # surface it to the LLM as the observation so it can retry.
                observation = f"Error: {e}"
            except Exception as e:
                # Tool crashed; keep looping so the agent can recover.
                observation = f"Tool error: {e}"

            self._memory.add_step(
                AgentStep(
                    thought=thought,
                    action=action_name,
                    action_input=action_input,
                    observation=observation,
                )
            )

        return "I was unable to find the answer within the allowed number of steps."

    async def stream(self, query: str) -> AsyncGenerator[str]:
        """
        Stream the final answer. Intermediate tool calls run silently.
        Yields the final answer string (may be multi-token on last LLM call).
        """
        # Pseudo-streaming: compute the whole answer, then re-emit it word
        # by word (each chunk carries a trailing space).
        answer = await self.run(query)
        for word in answer.split(" "):
            yield word + " "

    @property
    def memory(self) -> AgentMemory:
        # Scratchpad from the most recent run (shared when injected).
        return self._memory
|