pixie-examples 0.1.1.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +0 -0
- examples/langchain/README.md +39 -0
- examples/langchain/__init__.py +1 -0
- examples/langchain/basic_agent.py +100 -0
- examples/langchain/customer_support.py +238 -0
- examples/langchain/personal_assistant.py +163 -0
- examples/langchain/sql_agent.py +176 -0
- examples/langgraph/__init__.py +0 -0
- examples/langgraph/langgraph_rag.py +241 -0
- examples/langgraph/langgraph_sql_agent.py +218 -0
- examples/openai_agents_sdk/README.md +299 -0
- examples/openai_agents_sdk/__init__.py +0 -0
- examples/openai_agents_sdk/customer_service.py +258 -0
- examples/openai_agents_sdk/financial_research_agent.py +328 -0
- examples/openai_agents_sdk/llm_as_a_judge.py +108 -0
- examples/openai_agents_sdk/routing.py +177 -0
- examples/pydantic_ai/.env.example +26 -0
- examples/pydantic_ai/README.md +246 -0
- examples/pydantic_ai/__init__.py +0 -0
- examples/pydantic_ai/bank_support.py +154 -0
- examples/pydantic_ai/flight_booking.py +250 -0
- examples/pydantic_ai/question_graph.py +152 -0
- examples/pydantic_ai/sql_gen.py +182 -0
- examples/pydantic_ai/structured_output.py +64 -0
- examples/quickstart/__init__.py +0 -0
- examples/quickstart/chatbot.py +25 -0
- examples/quickstart/sleepy_poet.py +96 -0
- examples/quickstart/weather_agent.py +110 -0
- examples/sql_utils.py +241 -0
- pixie_examples-0.1.1.dev3.dist-info/METADATA +113 -0
- pixie_examples-0.1.1.dev3.dist-info/RECORD +33 -0
- pixie_examples-0.1.1.dev3.dist-info/WHEEL +4 -0
- pixie_examples-0.1.1.dev3.dist-info/licenses/LICENSE +21 -0
examples/langchain/sql_agent.py
@@ -0,0 +1,176 @@
+"""
+SQL Agent (Multi-Turn & Multi-Step)
+
+This example demonstrates an agent that can answer questions about a SQL database.
+The agent can:
+1. Fetch available tables and schemas
+2. Decide which tables are relevant
+3. Generate SQL queries
+4. Execute queries and handle errors
+5. Formulate responses based on results
+
+Based on: https://docs.langchain.com/oss/python/langchain/sql-agent
+
+WARNING: Building Q&A systems of SQL databases requires executing model-generated SQL
+queries. Make sure database connection permissions are scoped as narrowly as possible.
+"""
+
+import pathlib
+import requests
+from langchain.agents import create_agent
+from langchain.chat_models import init_chat_model
+from langgraph.checkpoint.memory import InMemorySaver
+
+from langfuse.langchain import CallbackHandler
+import pixie
+from ..sql_utils import SQLDatabase, SQLDatabaseToolkit
+
+
+langfuse_handler = CallbackHandler()
+
+
+# System prompt for SQL agent
+SQL_AGENT_PROMPT = """
+You are an agent designed to interact with a SQL database.
+Given an input question, create a syntactically correct {dialect} query to run,
+then look at the results of the query and return the answer. Unless the user
+specifies a specific number of examples they wish to obtain, always limit your
+query to at most {top_k} results.
+
+You can order the results by a relevant column to return the most interesting
+examples in the database. Never query for all the columns from a specific table,
+only ask for the relevant columns given the question.
+
+You MUST double check your query before executing it. If you get an error while
+executing a query, rewrite the query and try again.
+
+DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
+database.
+
+To start you should ALWAYS look at the tables in the database to see what you
+can query. Do NOT skip this step.
+
+Then you should query the schema of the most relevant tables.
+"""
+
+
+def setup_database():
+    """Download and setup the Chinook database if not already present."""
+    url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
+    local_path = pathlib.Path("Chinook.db")
+
+    if local_path.exists():
+        print(f"{local_path} already exists, skipping download.")
+    else:
+        print("Downloading Chinook database...")
+        response = requests.get(url)
+        if response.status_code == 200:
+            local_path.write_bytes(response.content)
+            print(f"File downloaded and saved as {local_path}")
+        else:
+            raise Exception(
+                f"Failed to download the file. Status code: {response.status_code}"
+            )
+
+    return SQLDatabase.from_uri("sqlite:///Chinook.db")
+
+
+@pixie.app
+async def langchain_sql_query_agent(question: str) -> str:
+    """SQL database query agent that can answer questions about the Chinook database.
+
+    The Chinook database represents a digital media store with tables for artists,
+    albums, tracks, customers, invoices, etc.
+
+    Args:
+        question: Natural language question about the database
+
+    Returns:
+        AI-generated answer based on SQL query results
+    """
+    # Setup database
+    db = setup_database()
+
+    # Initialize model
+    model = init_chat_model("gpt-4o-mini", temperature=0)
+
+    # Create SQL toolkit with tools for database interaction
+    toolkit = SQLDatabaseToolkit(db=db, llm=model)
+    tools = toolkit.get_tools()
+
+    # Format system prompt with database info
+    system_prompt = SQL_AGENT_PROMPT.format(dialect=db.dialect, top_k=5)
+
+    # Create agent
+    agent = create_agent(model, tools, system_prompt=system_prompt)
+
+    # Run the agent
+    result = agent.invoke(
+        {"messages": [{"role": "user", "content": question}]},
+        config={"callbacks": [langfuse_handler]},
+    )
+
+    # Return the final answer
+    return result["messages"][-1].content
+
+
+@pixie.app
+async def langchain_interactive_sql_agent() -> pixie.PixieGenerator[str, str]:
+    """Interactive SQL database query agent with multi-turn conversation.
+
+    This agent maintains conversation history and can handle follow-up questions.
+
+    Yields:
+        AI responses to database queries
+    """
+    # Setup database
+    db = setup_database()
+
+    # Initialize model
+    model = init_chat_model("gpt-4o-mini", temperature=0)
+
+    # Create SQL toolkit
+    toolkit = SQLDatabaseToolkit(db=db, llm=model)
+    tools = toolkit.get_tools()
+
+    # Format system prompt
+    system_prompt = SQL_AGENT_PROMPT.format(dialect=db.dialect, top_k=5)
+
+    # Create agent with checkpointer for conversation memory
+    agent = create_agent(
+        model, tools, system_prompt=system_prompt, checkpointer=InMemorySaver()
+    )
+
+    # Send welcome message
+    yield f"""Welcome to the SQL Query Assistant!
+
+I can help you query the Chinook database, which contains information about:
+- Artists and Albums
+- Tracks and Genres
+- Customers and Invoices
+- Employees and more
+
+Available tables: {', '.join(db.get_usable_table_names())}
+
+Ask me any question about the data!"""
+
+    # Initialize conversation
+    thread_id = "sql_thread"
+    config = {"configurable": {"thread_id": thread_id}, "callbacks": [langfuse_handler]}
+
+    while True:
+        # Get user question
+        user_question = yield pixie.InputRequired(str)
+
+        # Check for exit
+        if user_question.lower() in {"exit", "quit", "bye"}:
+            yield "Goodbye! Feel free to come back if you have more questions about the database."
+            break
+
+        # Process with agent
+        result = agent.invoke(
+            {"messages": [{"role": "user", "content": user_question}]}, config  # type: ignore
+        )
+
+        # Yield the agent's response
+        yield result["messages"][-1].content
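
Usage note (not part of the diff): a minimal sketch of driving the single-shot app above from a plain script. It assumes the @pixie.app decorator leaves the coroutine directly awaitable outside the pixie runtime, that the wheel's examples package is importable, and that OPENAI_API_KEY plus Langfuse credentials are exported; the question string is illustrative.

import asyncio

# Hypothetical driver; assumes the decorated app stays awaitable.
from examples.langchain.sql_agent import langchain_sql_query_agent

async def main() -> None:
    answer = await langchain_sql_query_agent(
        "Which three artists have the most tracks?"  # illustrative question
    )
    print(answer)

if __name__ == "__main__":
    asyncio.run(main())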
File without changes
examples/langgraph/langgraph_rag.py
@@ -0,0 +1,241 @@
+"""
+LangGraph RAG Agent (Retrieval Augmented Generation)
+
+This example demonstrates building an agentic RAG system using LangGraph that can:
+1. Decide when to use retrieval vs. respond directly
+2. Grade retrieved documents for relevance
+3. Rewrite questions if documents aren't relevant
+4. Generate answers based on retrieved context
+
+Based on: https://docs.langchain.com/oss/python/langgraph/agentic-rag
+"""
+
+from pydantic import BaseModel, Field
+from typing import Literal
+from langchain.chat_models import init_chat_model
+from langchain.tools import tool
+from langchain.messages import HumanMessage
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_core.documents import Document
+from langchain_openai import OpenAIEmbeddings
+from langgraph.graph import END, START, MessagesState, StateGraph
+from langgraph.prebuilt import ToolNode, tools_condition
+
+from langfuse.langchain import CallbackHandler
+import pixie
+import requests
+from bs4 import BeautifulSoup
+
+
+langfuse_handler = CallbackHandler()
+
+
+def load_web_page(url: str) -> list[Document]:
+    """Simple web page loader using requests and BeautifulSoup.
+
+    Replaces langchain_community.document_loaders.WebBaseLoader
+    to avoid the langchain-community dependency.
+    """
+    response = requests.get(url)
+    response.raise_for_status()
+    soup = BeautifulSoup(response.content, "html.parser")
+
+    # Extract text from the page
+    text = soup.get_text(separator="\n", strip=True)
+
+    return [Document(page_content=text, metadata={"source": url})]
+
+
+def setup_vectorstore():
+    """Setup vectorstore with documents from Lilian Weng's blog."""
+    print("Loading documents from web...")
+
+    urls = [
+        "https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
+        "https://lilianweng.github.io/posts/2024-07-07-hallucination/",
+        "https://lilianweng.github.io/posts/2024-04-12-diffusion-video/",
+    ]
+
+    docs = [load_web_page(url) for url in urls]
+    docs_list = [item for sublist in docs for item in sublist]
+
+    print("Splitting documents...")
+    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+        chunk_size=100, chunk_overlap=50
+    )
+    doc_splits = text_splitter.split_documents(docs_list)
+
+    print("Creating vectorstore...")
+    vectorstore = InMemoryVectorStore.from_documents(
+        documents=doc_splits, embedding=OpenAIEmbeddings()
+    )
+
+    return vectorstore.as_retriever()
+
+
+def create_rag_graph(retriever, model):
+    """Create a LangGraph-based RAG agent."""
+
+    # Create retriever tool
+    @tool
+    def retrieve_blog_posts(query: str) -> str:
+        """Search and return information about Lilian Weng blog posts."""
+        docs = retriever.invoke(query)
+        return "\n\n".join([doc.page_content for doc in docs])
+
+    retriever_tool = retrieve_blog_posts
+
+    # Node: Generate query or respond
+    def generate_query_or_respond(state: MessagesState):
+        """Call the model to generate a response or use retrieval tool."""
+        response = model.bind_tools([retriever_tool]).invoke(
+            state["messages"], config={"callbacks": [langfuse_handler]}
+        )
+        return {"messages": [response]}
+
+    # Grade documents schema
+    class GradeDocuments(BaseModel):
+        """Grade documents using a binary score for relevance check."""
+
+        binary_score: str = Field(
+            description="Relevance score: 'yes' if relevant, or 'no' if not relevant"
+        )
+
+    GRADE_PROMPT = (
+        "You are a grader assessing relevance of a retrieved document to a user question. \n "
+        "Here is the retrieved document: \n\n {context} \n\n"
+        "Here is the user question: {question} \n"
+        "If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n"
+        "Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."
+    )
+
+    grader_model = init_chat_model("gpt-4o", temperature=0)
+
+    # Conditional edge: Grade documents
+    def grade_documents(
+        state: MessagesState,
+    ) -> Literal["generate_answer", "rewrite_question"]:
+        """Determine whether the retrieved documents are relevant to the question."""
+        question = state["messages"][0].content
+        context = state["messages"][-1].content
+
+        prompt = GRADE_PROMPT.format(question=question, context=context)
+        response = grader_model.with_structured_output(GradeDocuments).invoke(
+            [{"role": "user", "content": prompt}],
+            config={"callbacks": [langfuse_handler]},
+        )
+
+        score = response.binary_score  # type: ignore
+        if score == "yes":
+            return "generate_answer"
+        else:
+            return "rewrite_question"
+
+    # Node: Rewrite question
+    REWRITE_PROMPT = (
+        "Look at the input and try to reason about the underlying semantic intent / meaning.\n"
+        "Here is the initial question:"
+        "\n ------- \n"
+        "{question}"
+        "\n ------- \n"
+        "Formulate an improved question:"
+    )
+
+    def rewrite_question(state: MessagesState):
+        """Rewrite the original user question."""
+        messages = state["messages"]
+        question = messages[0].content
+        prompt = REWRITE_PROMPT.format(question=question)
+        response = model.invoke(
+            [{"role": "user", "content": prompt}],
+            config={"callbacks": [langfuse_handler]},
+        )
+        return {"messages": [HumanMessage(content=response.content)]}
+
+    # Node: Generate answer
+    GENERATE_PROMPT = (
+        "You are an assistant for question-answering tasks. "
+        "Use the following pieces of retrieved context to answer the question. "
+        "If you don't know the answer, just say that you don't know. "
+        "Use three sentences maximum and keep the answer concise.\n"
+        "Question: {question} \n"
+        "Context: {context}"
+    )
+
+    def generate_answer(state: MessagesState):
+        """Generate an answer."""
+        question = state["messages"][0].content
+        context = state["messages"][-1].content
+        prompt = GENERATE_PROMPT.format(question=question, context=context)
+        response = model.invoke(
+            [{"role": "user", "content": prompt}],
+            config={"callbacks": [langfuse_handler]},
+        )
+        return {"messages": [response]}
+
+    # Build graph
+    workflow = StateGraph(MessagesState)
+
+    # Define nodes
+    workflow.add_node(generate_query_or_respond)
+    workflow.add_node("retrieve", ToolNode([retriever_tool]))
+    workflow.add_node(rewrite_question)
+    workflow.add_node(generate_answer)
+
+    # Define edges
+    workflow.add_edge(START, "generate_query_or_respond")
+
+    # Decide whether to retrieve
+    workflow.add_conditional_edges(
+        "generate_query_or_respond",
+        tools_condition,
+        {
+            "tools": "retrieve",
+            END: END,
+        },
+    )
+
+    # Grade documents after retrieval
+    workflow.add_conditional_edges("retrieve", grade_documents)
+    workflow.add_edge("generate_answer", END)
+    workflow.add_edge("rewrite_question", "generate_query_or_respond")
+
+    return workflow.compile()
+
+
+@pixie.app
+async def langgraph_rag_agent(question: str) -> str:
+    """Agentic RAG system that can answer questions about Lilian Weng's blog posts.
+
+    The agent:
+    1. Decides whether to retrieve or respond directly
+    2. Grades retrieved documents for relevance
+    3. Rewrites questions if needed
+    4. Generates answers based on context
+
+    Args:
+        question: Natural language question about the blog content
+
+    Returns:
+        AI-generated answer based on retrieved context
+    """
+    # Setup retriever (this will take a moment on first run)
+    retriever = setup_vectorstore()

+    # Initialize model
+    model = init_chat_model("gpt-4o-mini", temperature=0)
+
+    # Create graph
+    graph = create_rag_graph(retriever, model)
+
+    print(f"Processing question: {question}")
+
+    # Run the graph
+    result = graph.invoke(
+        {"messages": [{"role": "user", "content": question}]},  # type: ignore
+        config={"callbacks": [langfuse_handler]},
+    )
+
+    # Return the final answer
+    return result["messages"][-1].content
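
Side note (not in the diff): because create_rag_graph returns a compiled LangGraph graph, the retrieve/grade/rewrite hops can be watched one update at a time with LangGraph's stream API instead of a single invoke. A hedged sketch reusing the module's own helpers; network access and OPENAI_API_KEY are assumed, and content is truncated naively (tool-call messages may carry non-string content).

from langchain.chat_models import init_chat_model
from examples.langgraph.langgraph_rag import create_rag_graph, setup_vectorstore

# Build the same graph the app builds, then stream node-by-node updates
# (LangGraph's default "updates" stream mode yields {node_name: state_update}).
graph = create_rag_graph(setup_vectorstore(), init_chat_model("gpt-4o-mini", temperature=0))
for update in graph.stream(
    {"messages": [{"role": "user", "content": "What is reward hacking?"}]}
):
    for node, state in update.items():
        print(node, "->", str(state["messages"][-1].content)[:80])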
examples/langgraph/langgraph_sql_agent.py
@@ -0,0 +1,218 @@
+"""
+LangGraph SQL Agent (Custom Implementation)
+
+This example demonstrates building a SQL agent directly using LangGraph primitives
+for deeper customization. This gives more control over the agent's behavior compared
+to the higher-level LangChain agent.
+
+Based on: https://docs.langchain.com/oss/python/langgraph/sql-agent
+"""
+
+import pathlib
+import requests
+from typing import Literal
+from langchain.chat_models import init_chat_model
+from langchain.messages import AIMessage
+from langgraph.graph import START, MessagesState, StateGraph
+from ..sql_utils import SQLDatabase, SQLDatabaseToolkit
+from langgraph.prebuilt import ToolNode
+from langgraph.checkpoint.memory import InMemorySaver
+
+from langfuse.langchain import CallbackHandler
+import pixie
+
+
+langfuse_handler = CallbackHandler()
+
+
+def setup_database():
+    """Download and setup the Chinook database if not already present."""
+    url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
+    local_path = pathlib.Path("Chinook.db")
+
+    if local_path.exists():
+        print(f"{local_path} already exists, skipping download.")
+    else:
+        print("Downloading Chinook database...")
+        response = requests.get(url)
+        if response.status_code == 200:
+            local_path.write_bytes(response.content)
+            print(f"File downloaded and saved as {local_path}")
+        else:
+            raise Exception(
+                f"Failed to download the file. Status code: {response.status_code}"
+            )
+
+    return SQLDatabase.from_uri("sqlite:///Chinook.db")
+
+
+def create_sql_graph(db: SQLDatabase, model):
+    """Create a LangGraph-based SQL agent with custom workflow."""
+
+    # Get tools from toolkit
+    toolkit = SQLDatabaseToolkit(db=db, llm=model)
+    tools = toolkit.get_tools()
+
+    # Extract specific tools
+    get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
+    get_schema_node = ToolNode([get_schema_tool], name="get_schema")
+
+    run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
+    run_query_node = ToolNode([run_query_tool], name="run_query")
+
+    list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
+
+    # Node: List tables (forced tool call)
+    def list_tables(state: MessagesState):
+        tool_call = {
+            "name": "sql_db_list_tables",
+            "args": {},
+            "id": "list_tables_call",
+            "type": "tool_call",
+        }
+        tool_call_message = AIMessage(content="", tool_calls=[tool_call])
+        tool_message = list_tables_tool.invoke(tool_call)
+        response = AIMessage(f"Available tables: {tool_message.content}")
+        return {"messages": [tool_call_message, tool_message, response]}
+
+    # Node: Force model to call get_schema
+    def call_get_schema(state: MessagesState):
+        llm_with_tools = model.bind_tools([get_schema_tool], tool_choice="any")
+        response = llm_with_tools.invoke(
+            state["messages"], config={"callbacks": [langfuse_handler]}
+        )
+        return {"messages": [response]}
+
+    # Node: Generate query
+    generate_query_prompt = f"""
+You are an agent designed to interact with a SQL database.
+Given an input question, create a syntactically correct {db.dialect} query to run,
+then look at the results of the query and return the answer. Unless the user
+specifies a specific number of examples they wish to obtain, always limit your
+query to at most 5 results.
+
+You can order the results by a relevant column to return the most interesting
+examples in the database. Never query for all the columns from a specific table,
+only ask for the relevant columns given the question.
+
+DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
+"""
+
+    def generate_query(state: MessagesState):
+        system_message = {"role": "system", "content": generate_query_prompt}
+        llm_with_tools = model.bind_tools([run_query_tool])
+        response = llm_with_tools.invoke(
+            [system_message] + state["messages"],
+            config={"callbacks": [langfuse_handler]},
+        )
+        return {"messages": [response]}
+
+    # Node: Check query
+    check_query_prompt = f"""
+You are a SQL expert with a strong attention to detail.
+Double check the {db.dialect} query for common mistakes, including:
+- Using NOT IN with NULL values
+- Using UNION when UNION ALL should have been used
+- Using BETWEEN for exclusive ranges
+- Data type mismatch in predicates
+- Properly quoting identifiers
+- Using the correct number of arguments for functions
+- Casting to the correct data type
+- Using the proper columns for joins
+
+If there are any of the above mistakes, rewrite the query. If there are no mistakes,
+just reproduce the original query.
+
+You will call the appropriate tool to execute the query after running this check.
+"""
+
+    def check_query(state: MessagesState):
+        from langchain.messages import AIMessage as AI
+
+        system_message = {"role": "system", "content": check_query_prompt}
+        last_message = state["messages"][-1]
+        # Only AIMessage has tool_calls
+        if isinstance(last_message, AI) and last_message.tool_calls:
+            tool_call = last_message.tool_calls[0]
+            user_message = {"role": "user", "content": tool_call["args"]["query"]}
+        else:
+            # Fallback if no tool calls
+            user_message = {"role": "user", "content": "Please check the query"}
+        llm_with_tools = model.bind_tools([run_query_tool], tool_choice="any")
+        response = llm_with_tools.invoke(
+            [system_message, user_message], config={"callbacks": [langfuse_handler]}
+        )
+        if isinstance(last_message, AI):
+            response.id = last_message.id
+        return {"messages": [response]}
+
+    # Conditional edge: Continue or end
+    def should_continue(state: MessagesState) -> Literal["__end__", "check_query"]:
+        from langchain.messages import AIMessage as AI
+
+        messages = state["messages"]
+        last_message = messages[-1]
+        # Check if last message is AIMessage and has tool calls
+        if isinstance(last_message, AI) and last_message.tool_calls:
+            return "check_query"
+        else:
+            return "__end__"
+
+    # Build graph
+    builder = StateGraph(MessagesState)
+    builder.add_node("list_tables", list_tables)
+    builder.add_node("call_get_schema", call_get_schema)
+    builder.add_node("get_schema", get_schema_node)
+    builder.add_node("generate_query", generate_query)
+    builder.add_node("check_query", check_query)
+    builder.add_node("run_query", run_query_node)
+
+    builder.add_edge(START, "list_tables")
+    builder.add_edge("list_tables", "call_get_schema")
+    builder.add_edge("call_get_schema", "get_schema")
+    builder.add_edge("get_schema", "generate_query")
+    builder.add_conditional_edges("generate_query", should_continue)
+    builder.add_edge("check_query", "run_query")
+    builder.add_edge("run_query", "generate_query")
+
+    return builder.compile(checkpointer=InMemorySaver())
+
+
+@pixie.app
+async def langgraph_sql_agent(question: str) -> str:
+    """Custom SQL agent built with LangGraph primitives.
+
+    This agent has explicit control over the workflow:
+    1. Lists all tables
+    2. Gets schema for relevant tables
+    3. Generates SQL query
+    4. Checks query for errors
+    5. Executes query
+    6. Returns natural language answer
+
+    Args:
+        question: Natural language question about the database
+
+    Returns:
+        AI-generated answer based on SQL query results
+    """
+    # Setup database
+    db = setup_database()
+
+    # Initialize model
+    model = init_chat_model("gpt-4o-mini", temperature=0)
+
+    # Create graph
+    graph = create_sql_graph(db, model)
+
+    # Run the graph
+    result = graph.invoke(
+        {"messages": [{"role": "user", "content": question}]},  # type: ignore
+        {
+            "configurable": {"thread_id": "langgraph_sql"},
+            "callbacks": [langfuse_handler],
+        },
+    )
+
+    # Return the final message
+    return result["messages"][-1].content
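
Design note (not in the diff): create_sql_graph compiles with an InMemorySaver checkpointer, but langgraph_sql_agent rebuilds the graph on every call, so the fixed "langgraph_sql" thread never accumulates history across calls. To actually get multi-turn memory, reuse one compiled graph; a hedged sketch built from the module's own helpers (questions and thread id are illustrative, OPENAI_API_KEY assumed):

from langchain.chat_models import init_chat_model
from examples.langgraph.langgraph_sql_agent import create_sql_graph, setup_database

# One compiled graph, so the checkpointer carries state between turns.
graph = create_sql_graph(setup_database(), init_chat_model("gpt-4o-mini", temperature=0))
config = {"configurable": {"thread_id": "demo"}}

for q in ["How many employees are there?", "And how many customers?"]:
    result = graph.invoke({"messages": [{"role": "user", "content": q}]}, config)
    print(result["messages"][-1].content)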