agentic-builder 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentic_builder-0.1.0/PKG-INFO +25 -0
- agentic_builder-0.1.0/README.md +0 -0
- agentic_builder-0.1.0/agentic_builder/__init__.py +32 -0
- agentic_builder-0.1.0/agentic_builder/agent/__init__.py +161 -0
- agentic_builder-0.1.0/agentic_builder/agent/base.py +18 -0
- agentic_builder-0.1.0/agentic_builder/constants.py +35 -0
- agentic_builder-0.1.0/agentic_builder/langgraph/__init__.py +96 -0
- agentic_builder-0.1.0/agentic_builder/langgraph/nodes.py +140 -0
- agentic_builder-0.1.0/agentic_builder/langgraph/prompts.py +469 -0
- agentic_builder-0.1.0/agentic_builder/langgraph/states.py +17 -0
- agentic_builder-0.1.0/agentic_builder/llm/__init__.py +21 -0
- agentic_builder-0.1.0/agentic_builder/llm/factory.py +42 -0
- agentic_builder-0.1.0/agentic_builder/llm/llm.py +45 -0
- agentic_builder-0.1.0/agentic_builder/mixins.py +77 -0
- agentic_builder-0.1.0/agentic_builder/server.py +184 -0
- agentic_builder-0.1.0/agentic_builder/settings.py +109 -0
- agentic_builder-0.1.0/agentic_builder/utils.py +89 -0
- agentic_builder-0.1.0/pyproject.toml +35 -0
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: agentic-builder
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary:
|
|
5
|
+
Author: jalal
|
|
6
|
+
Author-email: jalalkhaldi3@gmail.com
|
|
7
|
+
Requires-Python: >=3.12,<3.14
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
11
|
+
Requires-Dist: arize-phoenix-client (>=1.28.0,<2.0.0)
|
|
12
|
+
Requires-Dist: arize-phoenix-otel (==0.13.1)
|
|
13
|
+
Requires-Dist: dotenv (>=0.9.9,<0.10.0)
|
|
14
|
+
Requires-Dist: fastapi (>=0.128.0,<0.129.0)
|
|
15
|
+
Requires-Dist: langchain (>=1.1.0,<2.0.0)
|
|
16
|
+
Requires-Dist: langchain-mcp-adapters (>=0.2.1,<0.3.0)
|
|
17
|
+
Requires-Dist: langchain-ollama (>=1.0.0,<2.0.0)
|
|
18
|
+
Requires-Dist: langchain-openai (>=1.1.5,<2.0.0)
|
|
19
|
+
Requires-Dist: langchain-openrouter (>=0.0.2,<0.0.3)
|
|
20
|
+
Requires-Dist: openinference-instrumentation-langchain (>=0.1.56,<0.2.0)
|
|
21
|
+
Requires-Dist: pydantic (>=2.12.5,<3.0.0)
|
|
22
|
+
Requires-Dist: pydantic-settings (>=2.12.0,<3.0.0)
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
|
|
25
|
+
|
|
File without changes
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any, Generic
|
|
3
|
+
|
|
4
|
+
from agentic_builder.agent import Agent
|
|
5
|
+
from agentic_builder.constants import TCAPIRunner
|
|
6
|
+
from agentic_builder.mixins import FromConfigMixin
|
|
7
|
+
from agentic_builder.settings import BaseAPIRunnerSettings
|
|
8
|
+
from agentic_builder.utils import load_class
|
|
9
|
+
|
|
10
|
+
_logger = logging.getLogger(__name__)


class APIRunner(FromConfigMixin[TCAPIRunner], Generic[TCAPIRunner]):
    """Hold an API-runner configuration together with the agent it drives.

    Build instances with the async :meth:`create` factory. Plain ``__init__``
    only stores the configuration and leaves ``self.agent`` unassigned.

    NOTE(review): an identical ``APIRunner`` is also defined in
    ``agentic_builder/agent/__init__.py``; the two are distinct classes
    (``isinstance`` checks across them fail). Consider keeping one and
    re-exporting it.
    """

    def __init__(self, config: TCAPIRunner):
        self.config = config
        # Declared only; assigned by ``init_agent`` (which ``create`` calls).
        self.agent: Agent[Any]

    @classmethod
    async def create(cls, config: TCAPIRunner) -> "APIRunner[TCAPIRunner]":
        """Build a runner for ``config`` and initialise its agent.

        Args:
            config: Runner settings; ``config.agent.module_path`` names the
                agent class to load and ``config.agent`` is passed to its
                ``create`` factory.

        Returns:
            A runner whose ``agent`` attribute is ready to use.
        """
        self = cls(config)
        await self.init_agent()
        _logger.info("APIRunner initialized")
        return self

    async def init_agent(self) -> None:
        """(Re)load the agent class named in the config and create a fresh agent."""
        agent_cls = load_class(self.config.agent.module_path)
        self.agent = await agent_cls.create(self.config.agent)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class BaseAPIRunner(APIRunner[BaseAPIRunnerSettings]):
    """API runner bound to :class:`BaseAPIRunnerSettings` configurations."""

    def __init__(self, config: BaseAPIRunnerSettings):
        # No extra state of its own: everything is handled by APIRunner.
        super().__init__(config=config)
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, AsyncGenerator, Generic, List, Optional
|
|
6
|
+
|
|
7
|
+
from agentic_builder.constants import GRAPHS_DIR, TCAgentSettings, TCAPIRunner
|
|
8
|
+
from agentic_builder.llm.factory import LLMFactory
|
|
9
|
+
from agentic_builder.mixins import FromConfigMixin
|
|
10
|
+
from agentic_builder.settings import BaseAPIRunnerSettings
|
|
11
|
+
from agentic_builder.utils import _cached_mcp_tools, load_class
|
|
12
|
+
from langchain.messages import AIMessageChunk
|
|
13
|
+
from langchain.tools import BaseTool
|
|
14
|
+
from langgraph.checkpoint.memory import InMemorySaver
|
|
15
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
16
|
+
from openinference.instrumentation import TracerProvider
|
|
17
|
+
from phoenix.otel import register
|
|
18
|
+
|
|
19
|
+
_logger = logging.getLogger(__name__)


class Agent(FromConfigMixin[TCAgentSettings], Generic[TCAgentSettings]):
    """Base class wiring MCP tools, an LLM factory and Phoenix tracing into a LangGraph graph.

    Subclasses implement :meth:`build_agent` to return the compiled graph.
    Build instances with the async :meth:`create` factory. It loads all
    components and writes a PNG rendering of the graph under ``GRAPHS_DIR``.

    NOTE(review): both the tracer provider and the compiled graph are cached
    on the concrete class. Every instance of a given subclass therefore
    reuses the graph built by the first instance, including whatever tools,
    LLM factory and checkpointer that instance passed to ``build_agent``.
    """

    _tracer_provider: Optional[TracerProvider] = None
    _compiled_graph: Optional[CompiledStateGraph] = None

    def __init__(self, config: TCAgentSettings) -> None:
        self.config = config
        self.checkpointer = InMemorySaver()
        # Default run config: every call without an explicit thread_id
        # shares this single conversation thread.
        self.configurable = {"configurable": {"thread_id": "1"}}
        _logger.info("Agent initialized")

    @classmethod
    async def create(cls, config: TCAgentSettings) -> "Agent[TCAgentSettings]":
        """Instantiate the agent, load its components and save a graph rendering.

        Args:
            config: Agent settings (MCP servers, Phoenix tracing, LLM factory).

        Returns:
            A fully initialised agent.
        """
        self = cls(config)
        await self._init_components()
        _logger.info("Agent initialized")
        # NOTE(review): ``push`` renders via ``draw_mermaid_png``; depending on
        # the draw method this may require network access — confirm this is
        # acceptable at startup.
        self.push()
        _logger.info("Graph saved locally")
        return self

    def _run_config(self, thread_id: Optional[str] = None) -> dict[str, Any]:
        """Return the LangGraph run config, using ``thread_id`` when one is given."""
        if thread_id is None:
            return self.configurable
        return {"configurable": {"thread_id": thread_id}}

    async def _load_tools(self) -> List[BaseTool]:
        """Fetch the tools exposed by the configured MCP servers.

        A ``name:url`` fingerprint of ``config.mcps`` is passed to
        ``_cached_mcp_tools`` (presumably as its cache key).
        """
        fingerprint = ",".join(f"{mcp.name}:{mcp.url}" for mcp in self.config.mcps)
        tools = await _cached_mcp_tools(
            mcps_fingerprint=fingerprint,
            mcps_config=self.config.mcps,
        )
        _logger.info("Loaded %d MCP tools (cached)", len(tools))
        return tools

    def _load_tracer_provider(self) -> TracerProvider:
        """Register the Phoenix/OpenInference tracer once per concrete class and return it."""
        if self.__class__._tracer_provider is None:
            _logger.info("Initializing OpenInference tracer provider")
            self.__class__._tracer_provider = register(
                endpoint=self.config.phoenix.collector_endpoint,
                project_name=self.config.phoenix.project_name,
                auto_instrument=True,
            )
            _logger.info("LangChain OpenInference instrumentation enabled")
        return self.__class__._tracer_provider

    async def _init_components(self) -> None:
        """Set up tracing, tools and the LLM factory, then build (or reuse) the graph."""
        self.tracer_provider = self._load_tracer_provider()
        self.tools = await self._load_tools()
        self.llm_factory: LLMFactory = load_class(self.config.llm_factory.module_path)(self.config.llm_factory)
        # Only the first instance of a concrete class calls build_agent();
        # later instances reuse the class-level compiled graph.
        if self.__class__._compiled_graph is None:
            self.__class__._compiled_graph = self.build_agent()
        self.agent: CompiledStateGraph = self.__class__._compiled_graph
        _logger.info("Agent created")

    @abstractmethod
    def build_agent(self) -> CompiledStateGraph:
        """Build and compile this agent's LangGraph graph.

        NOTE(review): ``@abstractmethod`` is only enforced if the metaclass is
        ``ABCMeta`` (e.g. through ``FromConfigMixin``) — confirm.
        """
        raise NotImplementedError()

    async def arun(self, query: str, thread_id: Optional[str] = None) -> str:
        """Run the agent on ``query`` and return the concatenated LLM text.

        Args:
            query: User message.
            thread_id: Conversation thread to use; defaults to the shared
                thread in ``self.configurable``.

        Returns:
            The text of every non-empty ``AIMessageChunk`` streamed by the
            graph, joined in order.
        """
        chunks: List[str] = []
        # In "messages" mode the graph yields (message_chunk, metadata) tuples,
        # not dicts. The run config carries the thread id the checkpointer needs.
        async for token, _metadata in self.agent.astream(
            {"messages": [{"role": "user", "content": query}]},
            config=self._run_config(thread_id),  # type: ignore[arg-type]
            stream_mode="messages",
        ):
            if isinstance(token, AIMessageChunk) and token.text:
                chunks.append(token.text)
        return "".join(chunks)

    async def astream_llm_tokens(self, query: str, thread_id: Optional[str] = None) -> AsyncGenerator[str, None]:
        """
        Stream ONLY LLM tokens (no tool messages, no updates).

        Args:
            query: User message.
            thread_id: Conversation thread to use; defaults to the shared
                thread in ``self.configurable``.

        Yields:
            The non-empty text of each ``AIMessageChunk``.
        """
        _logger.info("Streaming agent (LLM-only): %s", query)

        async for token, _metadata in self.agent.astream(
            {"messages": [{"role": "user", "content": query}]},
            config=self._run_config(thread_id),  # type: ignore[arg-type]
            stream_mode="messages",
        ):
            if isinstance(token, AIMessageChunk) and token.text:
                yield token.text

    async def render_steps(self, query: str, thread_id: Optional[str] = None) -> None:
        """Debug helper: print every node update produced while answering ``query``."""
        async for chunk in self.agent.astream(
            {"messages": [{"role": "user", "content": query}]},
            config=self._run_config(thread_id),  # type: ignore[arg-type]
        ):
            for node, update in chunk.items():
                print("Update from node", node)
                print(update)
                # Only dict updates are inspected for messages; anything else
                # (e.g. None) is just printed above.
                if isinstance(update, dict) and update.get("messages"):
                    update["messages"][-1].pretty_print()

    def push(self) -> None:
        """Render the compiled graph as PNG and save it under ``GRAPHS_DIR``.

        The file is named ``<ClassName>_graph_<YYYYMMDD_HHMMSS>.png`` using the
        current UTC time.
        """
        # Local import: timezone is only needed here. Replaces the
        # utcnow() call, which is deprecated since Python 3.12.
        from datetime import timezone

        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        filename = f"{self.__class__.__name__}_graph_{timestamp}.png"
        path: Path = GRAPHS_DIR / filename

        png_bytes = self.agent.get_graph().draw_mermaid_png()
        path.write_bytes(png_bytes)
        _logger.info("Graph saved locally at %s", path)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class APIRunner(FromConfigMixin[TCAPIRunner], Generic[TCAPIRunner]):
    """Hold an API-runner configuration together with the agent it drives.

    Build instances with the async :meth:`create` factory. Plain ``__init__``
    only stores the configuration and leaves ``self.agent`` unassigned.

    NOTE(review): an identical ``APIRunner`` is also defined in
    ``agentic_builder/__init__.py``; the two are distinct classes
    (``isinstance`` checks across them fail). Consider keeping one and
    re-exporting it.
    """

    def __init__(self, config: TCAPIRunner):
        self.config = config
        # Declared only; assigned by ``init_agent`` (which ``create`` calls).
        self.agent: Agent[Any]

    @classmethod
    async def create(cls, config: TCAPIRunner) -> "APIRunner[TCAPIRunner]":
        """Build a runner for ``config`` and initialise its agent.

        Args:
            config: Runner settings; ``config.agent.module_path`` names the
                agent class to load and ``config.agent`` is passed to its
                ``create`` factory.

        Returns:
            A runner whose ``agent`` attribute is ready to use.
        """
        self = cls(config)
        await self.init_agent()
        _logger.info("APIRunner initialized")
        return self

    async def init_agent(self) -> None:
        """(Re)load the agent class named in the config and create a fresh agent."""
        agent_cls = load_class(self.config.agent.module_path)
        self.agent = await agent_cls.create(self.config.agent)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class BaseAPIRunner(APIRunner[BaseAPIRunnerSettings]):
    """API runner bound to :class:`BaseAPIRunnerSettings` configurations."""

    def __init__(self, config: BaseAPIRunnerSettings):
        # No extra state of its own: everything is handled by APIRunner.
        super().__init__(config=config)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from agentic_builder import Agent
|
|
4
|
+
from agentic_builder.settings import BaseAgentSettings
|
|
5
|
+
from langchain.agents import create_agent
|
|
6
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
7
|
+
|
|
8
|
+
_logger = logging.getLogger(__name__)


class BaseAgent(Agent[BaseAgentSettings]):
    """Default agent: a LangChain ``create_agent`` graph over the loaded MCP tools.

    The graph uses the ``"fast_latency"`` model from the configured LLM
    factory and this instance's checkpointer.
    """

    def __init__(self, config: BaseAgentSettings) -> None:
        super().__init__(config=config)

    def build_agent(self) -> CompiledStateGraph:
        """Compile the tool-calling agent graph."""
        llm = self.llm_factory.get("fast_latency")
        chat_model = llm.to_langchain()
        return create_agent(
            model=chat_model,
            tools=self.tools,
            checkpointer=self.checkpointer,
        )
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar, Union
|
|
3
|
+
|
|
4
|
+
from langchain_core.tools.base import BaseTool
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
|
|
8
|
+
from agentic_builder.settings import (
|
|
9
|
+
AgentSettings,
|
|
10
|
+
APIRunnerSettings,
|
|
11
|
+
DeletionStrategySettings,
|
|
12
|
+
FromConfigMixinSettings,
|
|
13
|
+
LLMSettings,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
import os
from pathlib import Path

# Path of the YAML configuration file. Defaults to the container mount
# point; override with the AGENTIC_BUILDER_CONFIG_PATH environment variable.
CONFIG_PATH = os.environ.get("AGENTIC_BUILDER_CONFIG_PATH") or "/config/config.yaml"

# Directory containing this package.
agentic_builder_ROOT = Path(__file__).resolve().parent

# Where Agent.push() writes graph renderings. The default is a "graphs" folder
# two levels above the package directory. For an installed package that may
# be outside any writable tree, so allow overriding it with the
# AGENTIC_BUILDER_GRAPHS_DIR environment variable.
_DEFAULT_GRAPHS_DIR = agentic_builder_ROOT.parent.parent / "graphs"
GRAPHS_DIR = Path(os.environ.get("AGENTIC_BUILDER_GRAPHS_DIR") or _DEFAULT_GRAPHS_DIR)
# Created at import time so callers can write into it directly.
GRAPHS_DIR.mkdir(parents=True, exist_ok=True)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Type variables tying each generic component to its settings class. Bounds
# are forward references (strings) because the settings module is imported
# only under TYPE_CHECKING.
TCFromConfigMixin = TypeVar("TCFromConfigMixin", bound="FromConfigMixinSettings")
TCAgentSettings = TypeVar("TCAgentSettings", bound="AgentSettings")
TCAgent = TypeVar("TCAgent", bound="AgentSettings")  # same bound as TCAgentSettings
TCLLM = TypeVar("TCLLM", bound="LLMSettings")
TCAPIRunner = TypeVar("TCAPIRunner", bound="APIRunnerSettings")
TCDeletionStrategy = TypeVar("TCDeletionStrategy", bound="DeletionStrategySettings")

# Unbounded type variable (the name suggests graph node types).
TNode = TypeVar("TNode")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Anything accepted where a tool is expected: a schema dict, a class, a plain
# callable, or a LangChain BaseTool.
ToolLike = Union[Dict[str, Any], Type, Callable, BaseTool]
# Structured-output spec: a schema dict, a type, or None for plain output.
StructuredOutput = Union[Dict[str, Any], Type[Any], None]
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import Any, Generic, List, Optional, Type, Union
|
|
4
|
+
|
|
5
|
+
from agentic_builder.constants import TCDeletionStrategy
|
|
6
|
+
from agentic_builder.mixins import FromConfigMixin
|
|
7
|
+
from langchain_core.language_models.chat_models import BaseChatModel
|
|
8
|
+
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
|
|
9
|
+
from langchain_core.messages.utils import AnyMessage
|
|
10
|
+
from langchain_core.runnables import Runnable
|
|
11
|
+
from langgraph.graph._node import _Node
|
|
12
|
+
from langgraph.typing import NodeInputT_contra
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)


class Prompt(ABC, Generic[NodeInputT_contra]):
    """Base class for prompts that turn a graph state into a chat message list.

    Subclasses implement :meth:`_render` to produce the instruction text.
    :meth:`render` appends that text as a ``SystemMessage`` after the state's
    existing ``messages``.

    NOTE(review): the system message is placed last, not first; some chat
    providers expect system messages at the start — confirm this is intended.
    """

    @classmethod
    @abstractmethod
    def _render(cls, state: NodeInputT_contra) -> str:
        """Return the system-prompt text for ``state``."""
        raise NotImplementedError()

    @classmethod
    def render(cls, state: NodeInputT_contra) -> List[BaseMessage]:
        """Build the message list to send to the model for ``state``.

        Returns:
            A new list holding every message in ``state["messages"]``,
            followed by a ``SystemMessage`` with the rendered prompt.
        """
        instructions = cls._render(state)
        conversation = list(state["messages"])  # type: ignore[index]
        conversation.append(SystemMessage(content=instructions))
        return conversation
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class SchemaBasedPrompt(Prompt[NodeInputT_contra], Generic[NodeInputT_contra]):
    """Prompt that also provides a structured-output schema for the model."""

    @classmethod
    @abstractmethod
    def get_schema(cls) -> Any:
        """Return the schema (dict or type) used for structured model output."""
        raise NotImplementedError
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DeletionStrategy(FromConfigMixin[TCDeletionStrategy], Generic[TCDeletionStrategy]):
    """Configurable policy that prunes a list of conversation messages.

    Subclasses override :meth:`_delete`. Callers use :meth:`delete`, which
    wraps it with before/after logging.
    """

    def __init__(self, config: TCDeletionStrategy) -> None:
        self.config = config

    def _delete(self, messages: List[AnyMessage]) -> List[AnyMessage]:
        """Return the messages to keep; must be overridden by subclasses."""
        raise NotImplementedError()

    def delete(self, messages: List[AnyMessage]) -> List[AnyMessage]:
        """Apply :meth:`_delete` to ``messages`` and log how many were removed.

        Returns:
            The list produced by :meth:`_delete`.
        """
        strategy_name = self.__class__.__name__
        count_before = len(messages)
        logger.debug(
            "DeletionStrategy.start | strategy=%s | messages_before=%s",
            strategy_name,
            count_before,
        )
        kept = self._delete(messages)
        count_after = len(kept)
        logger.info(
            "DeletionStrategy.done | strategy=%s | removed=%s | remaining=%s",
            strategy_name,
            count_before - count_after,
            count_after,
        )
        return kept
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class Node(_Node[NodeInputT_contra], Generic[NodeInputT_contra]):
    """Named, async-callable LangGraph node.

    Subclasses override ``__call__`` to turn ``state`` into a state update or,
    for routers, into a branch label.

    NOTE(review): ``_Node`` comes from the private module
    ``langgraph.graph._node``, which may change between langgraph releases.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    async def __call__(self, state: NodeInputT_contra) -> Any:
        """Default implementation: does nothing and returns ``None``."""
        return None
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class RouterNode(Node[NodeInputT_contra], Generic[NodeInputT_contra]):
    """Node whose ``__call__`` returns a routing decision instead of a state update."""

    def __init__(self, name: str) -> None:
        super().__init__(name=name)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class LLMNode(Node[NodeInputT_contra], Generic[NodeInputT_contra]):
    """Node that renders a prompt from the state and sends it to a chat model.

    Args:
        name: Node name.
        model: Chat model to invoke.
        prompt: Prompt class whose ``render`` builds the message list.
        structured_output: Optional schema (dict or type). When truthy, the
            model is wrapped with ``with_structured_output`` and
            :meth:`_ainvoke` returns the parsed structure instead of an
            ``AIMessage``. Defaults to ``None`` (plain chat output).
    """

    def __init__(
        self,
        name: str,
        model: BaseChatModel,
        prompt: Type[Prompt[NodeInputT_contra]],
        structured_output: Optional[Union[dict, type]] = None,
    ):
        super().__init__(name)
        # With structured output the runnable no longer yields AIMessage, so
        # its output type is Any.
        runnable: Runnable[Any, Any] = model
        if structured_output:
            runnable = model.with_structured_output(structured_output)
        self.runnable: Runnable[Any, Any] = runnable
        self.prompt = prompt

    async def _ainvoke(self, state: NodeInputT_contra) -> Any:
        """Render the prompt for ``state`` and invoke the model asynchronously.

        Returns:
            An ``AIMessage`` when no structured output was configured,
            otherwise whatever the structured-output runnable produces.
        """
        messages = self.prompt.render(state)
        return await self.runnable.ainvoke(messages)
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from agentic_builder.langgraph import LLMNode, Node, RouterNode
|
|
5
|
+
from agentic_builder.langgraph.prompts import (
|
|
6
|
+
AnswerFailurePrompt,
|
|
7
|
+
GenerateAnswerPrompt,
|
|
8
|
+
GenerateQuerySchemaBasedPrompt,
|
|
9
|
+
GraderSchemaBasedPrompt,
|
|
10
|
+
IrrelevantQueryPrompt,
|
|
11
|
+
)
|
|
12
|
+
from agentic_builder.langgraph.states import AgentState
|
|
13
|
+
from langchain_core.language_models.chat_models import BaseChatModel
|
|
14
|
+
from langchain_core.tools.base import BaseTool
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GenerateQueryLLMNode(LLMNode[AgentState]):
    """Produce the structured ``query`` that the retrieval step consumes."""

    def __init__(self, name: str, model: BaseChatModel):
        prompt_cls = GenerateQuerySchemaBasedPrompt
        super().__init__(
            name=name,
            model=model,
            prompt=prompt_cls,
            structured_output=prompt_cls.get_schema(),
        )

    async def __call__(self, state: AgentState) -> Any:
        """Return the model's structured output as the ``query`` state update."""
        return {"query": await self._ainvoke(state)}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SearchDocumentsNode(Node[AgentState]):
    """Call a retriever tool with the generated query and store its evidence.

    Args:
        name: Node name.
        retriever_tool: Tool invoked with ``{"query": <text>}``. Its first
            result item must carry a JSON string under ``"text"`` with a
            ``retrieval.evidence`` field.
    """

    def __init__(self, name: str, retriever_tool: BaseTool):
        super().__init__(name=name)
        self.retriever_tool = retriever_tool

    async def __call__(self, state: AgentState) -> Any:
        """Run the retriever on ``state["query"]["content"]``.

        Returns:
            ``{"documents": evidence}``, the state update with the parsed
            evidence.

        Raises:
            ValueError: if ``query`` is missing or ``None`` in the state, or the
                tool returns no result items.
        """
        # .get(): a missing key gets the same clear error as an explicit None.
        query = state.get("query")
        if query is None:
            raise ValueError("key 'query' is None in the state dict")

        results = await self.retriever_tool.arun({"query": query["content"]})
        if not results:
            raise ValueError(f"retriever tool {self.retriever_tool.name!r} returned no results")

        # Parse the JSON payload carried by the first result item.
        payload = json.loads(results[0]["text"])
        evidence = payload["retrieval"]["evidence"]
        return {"documents": evidence}
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class GraderLLMNode(LLMNode[AgentState]):
    """Ask the model (through a structured-output schema) whether the state is answerable."""

    def __init__(self, name: str, model: BaseChatModel):
        grading_schema = GraderSchemaBasedPrompt.get_schema()
        super().__init__(
            name=name,
            model=model,
            prompt=GraderSchemaBasedPrompt,
            structured_output=grading_schema,
        )

    async def __call__(self, state: AgentState) -> Any:
        """Return the structured grading result as the ``is_answerable`` update."""
        verdict = await self._ainvoke(state)
        return {"is_answerable": verdict}
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class GenerateAnswerLLMNode(LLMNode[AgentState]):
    """Generate the answer message with ``GenerateAnswerPrompt`` (plain chat output)."""

    def __init__(self, name: str, model: BaseChatModel):
        super().__init__(name=name, model=model, prompt=GenerateAnswerPrompt, structured_output=None)

    async def __call__(self, state: AgentState) -> Any:
        """Return the model's reply as a ``messages`` state update."""
        reply = await self._ainvoke(state)
        return {"messages": [reply]}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class AnswerFailureLLMNode(LLMNode[AgentState]):
    """Generate a reply with ``AnswerFailurePrompt`` (plain chat output)."""

    def __init__(self, name: str, model: BaseChatModel):
        super().__init__(name=name, model=model, prompt=AnswerFailurePrompt, structured_output=None)

    async def __call__(self, state: AgentState) -> Any:
        """Return the model's reply as a ``messages`` state update."""
        reply = await self._ainvoke(state)
        return {"messages": [reply]}
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class IrrelevantQueryLLMNode(LLMNode[AgentState]):
    """Generate a reply with ``IrrelevantQueryPrompt`` (plain chat output)."""

    def __init__(self, name: str, model: BaseChatModel):
        super().__init__(name=name, model=model, prompt=IrrelevantQueryPrompt, structured_output=None)

    async def __call__(self, state: AgentState) -> Any:
        """Return the model's reply as a ``messages`` state update."""
        reply = await self._ainvoke(state)
        return {"messages": [reply]}
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class IsQueryRelevantRouterNode(RouterNode[AgentState]):
    """Route on whether the generated query is real or the ``"[NAN]"`` sentinel."""

    def __init__(self, name: str) -> None:
        super().__init__(name=name)

    async def __call__(self, state: AgentState) -> Any:
        """Return ``"yes"`` unless ``state["query"]["content"]`` is ``"[NAN]"``, then ``"no"``."""
        content = state["query"]["content"]  # type: ignore[index]
        return "yes" if content != "[NAN]" else "no"
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class IsDocumentRelevantRouterNode(RouterNode[AgentState]):
    """Route on the grader's verdict stored under ``is_answerable``."""

    def __init__(self, name: str) -> None:
        super().__init__(name=name)

    async def __call__(self, state: AgentState) -> Any:
        """Return ``state["is_answerable"]["response"]`` as the branch label.

        Raises:
            ValueError: if ``is_answerable`` is falsy (e.g. ``None``).
        """
        grade = state["is_answerable"]
        if not grade:
            raise ValueError("key 'is_answerable' is None in the state dict")
        return grade["response"]
|