schema_search-0.1.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- schema_search/__init__.py +26 -0
- schema_search/chunkers/__init__.py +6 -0
- schema_search/chunkers/base.py +95 -0
- schema_search/chunkers/factory.py +31 -0
- schema_search/chunkers/llm.py +54 -0
- schema_search/chunkers/markdown.py +25 -0
- schema_search/embedding_cache/__init__.py +5 -0
- schema_search/embedding_cache/base.py +40 -0
- schema_search/embedding_cache/bm25.py +63 -0
- schema_search/embedding_cache/factory.py +20 -0
- schema_search/embedding_cache/inmemory.py +122 -0
- schema_search/graph_builder.py +69 -0
- schema_search/mcp_server.py +81 -0
- schema_search/metrics.py +33 -0
- schema_search/rankers/__init__.py +5 -0
- schema_search/rankers/base.py +45 -0
- schema_search/rankers/cross_encoder.py +40 -0
- schema_search/rankers/factory.py +11 -0
- schema_search/schema_extractor.py +135 -0
- schema_search/schema_search.py +276 -0
- schema_search/search/__init__.py +15 -0
- schema_search/search/base.py +85 -0
- schema_search/search/bm25.py +48 -0
- schema_search/search/factory.py +61 -0
- schema_search/search/fuzzy.py +56 -0
- schema_search/search/hybrid.py +82 -0
- schema_search/search/semantic.py +49 -0
- schema_search/types.py +57 -0
- schema_search/utils/__init__.py +0 -0
- schema_search/utils/lazy_import.py +26 -0
- schema_search-0.1.10.dist-info/METADATA +308 -0
- schema_search-0.1.10.dist-info/RECORD +40 -0
- schema_search-0.1.10.dist-info/WHEEL +5 -0
- schema_search-0.1.10.dist-info/entry_points.txt +2 -0
- schema_search-0.1.10.dist-info/licenses/LICENSE +21 -0
- schema_search-0.1.10.dist-info/top_level.txt +2 -0
- tests/__init__.py +0 -0
- tests/test_integration.py +352 -0
- tests/test_llm_sql_generation.py +320 -0
- tests/test_spider_eval.py +488 -0
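
Of these files, only tests/test_llm_sql_generation.py is reproduced in the diff below. For orientation, here is a minimal usage sketch of the package's entry point, inferred solely from the calls that test makes (`SchemaSearch(...)`, `index(force=...)`, `search(...)`); the connection URL and API key are placeholders, not documented defaults:

    # Minimal sketch inferred from tests/test_llm_sql_generation.py; not official docs.
    from sqlalchemy import create_engine
    from schema_search import SchemaSearch

    engine = create_engine("postgresql://user:pass@localhost/db")  # placeholder URL
    search = SchemaSearch(engine, llm_api_key="...", llm_base_url=None)  # placeholder credentials
    search.index(force=False)  # index the schema; force=False presumably reuses a cached index
    response = search.search("user email deposit", hops=1, limit=5, search_type="semantic")
    for result in response["results"]:
        print(result["table"])  # each result carries a table name plus its schema dict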
tests/test_llm_sql_generation.py
@@ -0,0 +1,320 @@
+import os
+from pathlib import Path
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+import anthropic
+import pytest
+from dotenv import load_dotenv
+from sqlalchemy import create_engine
+
+from schema_search import SchemaSearch
+
+
+@pytest.fixture(scope="module")
+def database_url():
+    env_path = Path(__file__).parent / ".env"
+    load_dotenv(env_path)
+
+    url = os.getenv("DATABASE_URL")
+    if not url:
+        pytest.skip("DATABASE_URL not set in tests/.env file")
+
+    return url
+
+
+@pytest.fixture(scope="module")
+def llm_config():
+    env_path = Path(__file__).parent / ".env"
+    load_dotenv(env_path)
+
+    api_key = os.getenv("LLM_API_KEY")
+    base_url = os.getenv("LLM_BASE_URL")
+
+    if not api_key:
+        pytest.skip("LLM_API_KEY not set in tests/.env file")
+
+    return {"api_key": api_key, "base_url": base_url}
+
+
+@pytest.fixture(scope="module")
+def search_engine(database_url, llm_config):
+    engine = create_engine(database_url)
+    search = SchemaSearch(
+        engine,
+        llm_api_key=llm_config["api_key"],
+        llm_base_url=llm_config["base_url"],
+    )
+    search.index(force=False)
+    return search
+
+
+def test_table_identification_with_schema_search(search_engine, llm_config):
+    """
+    Compare table identification quality when LLM has:
+    1. Full schema context (all tables and indices)
+    2. Limited context from schema search with graph hops
+
+    For each natural language question, we:
+    - Ask LLM which tables are needed with full schema context (baseline)
+    - Ask LLM which tables are needed with schema search context (our approach)
+    - Compare both against the objective list of required tables
+    """
+
+    eval_data = [
+        {
+            "question": "how many unique users do we have?",
+            "required_tables": ["user_metadata"],
+            "searches": ["user table"],
+            "hops": 1,
+        },
+        {
+            "question": "what is the email of the user who deposited the most last month",
+            "required_tables": ["user_metadata", "user_deposits"],
+            "searches": ["user email deposit"],
+            "hops": 1,
+        },
+        {
+            "question": "what is the twitter handle of the agent that posted the most?",
+            "required_tables": ["agent_metadata", "agent_content"],
+            "searches": ["agent metadata content"],
+            "hops": 1,
+        },
+        {
+            "question": "which topic was covered the most in news articles last month?",
+            "required_tables": ["news_to_topic_map", "topic_metadata"],
+            "searches": ["topic metadata news map"],
+            "hops": 1,
+        },
+        {
+            "question": "which coin's price increased the most last month?",
+            "required_tables": ["historical_market_data"],
+            "searches": ["historical market data"],
+            "hops": 1,
+        },
+        {
+            "question": "find the 5 most recent news about the coin that increased the most last month?",
+            "required_tables": [
+                "historical_market_data",
+                "news_to_topic_map",
+                "topic_metadata",
+                "news_summary",
+            ],
+            "searches": ["historical market data news topic"],
+            "hops": 1,
+        },
+        {
+            "question": "which model did the top user of last month use?",
+            "required_tables": ["user_metadata", "model_metadata", "query_metrics"],
+            "searches": ["user metadata model query metrics"],
+            "hops": 1,
+        },
+        {
+            "question": "which agent gained the most followers last month?",
+            "required_tables": ["agent_metadata", "twitter_follow_activity"],
+            "searches": ["agent metadata twitter follow activity"],
+            "hops": 1,
+        },
+        {
+            "question": "which agent posted the most content last month?",
+            "required_tables": ["agent_metadata", "agent_content"],
+            "searches": ["agent metadata agent content"],
+            "hops": 1,
+        },
+        {
+            "question": "which api key was most used during last month?",
+            "required_tables": ["api_token", "query_metrics", "user_metadata"],
+            "searches": ["api token query metrics user metadata"],
+            "hops": 1,
+        },
+    ]
+
+    def get_baseline_context(search_engine):
+        """Get minimal context: just table names and indices."""
+        context_parts = []
+
+        for table_name, table_schema in search_engine.schemas.items():
+            context_parts.append(f"Table: {table_name}")
+
+            indices = table_schema.get("indices")
+            if indices:
+                idx_list = ", ".join([idx["name"] for idx in indices])
+                context_parts.append(f"Indices: {idx_list}")
+
+        return "\n\n".join(context_parts)
+
+    def get_search_results_context(search_engine, searches, hops):
+        """Get detailed schema from search results to add to baseline."""
+        context_parts = []
+        seen_tables = set()
+
+        for search_query in searches:
+            response = search_engine.search(
+                search_query, hops=hops, limit=5, search_type="semantic"
+            )
+            for result in response["results"]:
+                table_name = result["table"]
+                if table_name in seen_tables:
+                    continue
+                seen_tables.add(table_name)
+
+                columns = result["schema"].get("columns")
+                if columns:
+                    col_list = ", ".join(
+                        [f"{col['name']} ({col['type']})" for col in columns]
+                    )
+                    context_parts.append(f"Table: {table_name}\nColumns: {col_list}")
+        print("Search results tables: ", list(seen_tables))
+
+        return "\n\n".join(context_parts)
+
+    def call_llm_for_tables(question, schema_context, llm_config):
+        """Call LLM to identify which tables are needed."""
+        client = anthropic.Anthropic(api_key=llm_config["api_key"])
+
+        prompt = f"""Given the following database schema:
+
+{schema_context}
+
+Which tables are necessary to answer this question: {question}
+
+Return ONLY a comma-separated list of table names, nothing else. No explanations or additional text.
+Example format: table1, table2, table3"""
+
+        response = client.messages.create(
+            model="claude-sonnet-4-5-20250929",
+            max_tokens=512,
+            system="You are a database expert. Identify only the tables needed to answer the question.",
+            messages=[
+                {"role": "user", "content": prompt},
+            ],
+            temperature=0,
+        )
+
+        tables_str = response.content[0].text.strip()  # type: ignore
+        tables = [t.strip() for t in tables_str.split(",") if t.strip()]
+        # Remove schema prefix if present
+        tables = [t.split(".")[-1] for t in tables]
+        return tables
+
+    def compare_tables(identified_tables, required_tables):
+        """Compare identified tables with required tables."""
+        identified_set = set(t.lower() for t in identified_tables)
+        required_set = set(t.lower() for t in required_tables)
+
+        correct = identified_set & required_set
+        missing = required_set - identified_set
+        extra = identified_set - required_set
+
+        is_perfect = len(missing) == 0 and len(extra) == 0
+
+        return {
+            "is_perfect": is_perfect,
+            "correct": correct,
+            "missing": missing,
+            "extra": extra,
+            "precision": len(correct) / len(identified_set) if identified_set else 0,
+            "recall": len(correct) / len(required_set) if required_set else 0,
+        }
+
+    if len(eval_data) == 0:
+        pytest.skip("No evaluation data provided")
+
+    print("\n" + "=" * 100)
+    print("EVALUATION: Table Identification - Baseline vs Baseline + Search Results")
+    print("=" * 100)
+
+    baseline_context = get_baseline_context(search_engine)
+
+    baseline_perfect = 0
+    baseline_total_precision = 0
+    baseline_total_recall = 0
+
+    search_perfect = 0
+    search_total_precision = 0
+    search_total_recall = 0
+
+    for idx, eval_item in enumerate(eval_data, 1):
+        question = eval_item["question"]
+        required_tables = eval_item.get("required_tables", [])
+        searches = eval_item.get("searches", [question])
+        hops = eval_item.get("hops", 1)
+
+        print(f"\n{'='*100}")
+        print(f"Question {idx}: {question}")
+        print(f"Required tables: {required_tables}")
+        print(f"{'='*100}")
+
+        # Get search results and combine with baseline
+        search_results_context = get_search_results_context(
+            search_engine, searches, hops
+        )
+        enhanced_context = baseline_context + "\n\n" + search_results_context
+
+        print(f"\n[Baseline only] Context: {len(baseline_context)} chars")
+        print(f"[Baseline + Search] Context: {len(enhanced_context)} chars")
+        print(f"Additional context from search: {len(search_results_context)} chars")
+
+        # Identify tables with baseline only
+        print("\n--- Identifying tables with BASELINE ONLY ---")
+        tables_baseline = call_llm_for_tables(question, baseline_context, llm_config)
+        print(f"Identified tables: {tables_baseline}")
+
+        comparison_baseline = compare_tables(tables_baseline, required_tables)
+        print(
+            f"Precision: {comparison_baseline['precision']:.2f}, Recall: {comparison_baseline['recall']:.2f}"
+        )
+        if comparison_baseline["missing"]:
+            print(f"Missing: {comparison_baseline['missing']}")
+        if comparison_baseline["extra"]:
+            print(f"Extra: {comparison_baseline['extra']}")
+
+        # Identify tables with baseline + search results
+        print("\n--- Identifying tables with BASELINE + SEARCH ---")
+        tables_search = call_llm_for_tables(question, enhanced_context, llm_config)
+        print(f"Identified tables: {tables_search}")
+
+        comparison_search = compare_tables(tables_search, required_tables)
+        print(
+            f"Precision: {comparison_search['precision']:.2f}, Recall: {comparison_search['recall']:.2f}"
+        )
+        if comparison_search["missing"]:
+            print(f"Missing: {comparison_search['missing']}")
+        if comparison_search["extra"]:
+            print(f"Extra: {comparison_search['extra']}")
+
+        # Track metrics
+        if comparison_baseline["is_perfect"]:
+            baseline_perfect += 1
+            print("\n✓ Baseline: PERFECT")
+        else:
+            print("\n✗ Baseline: Not perfect")
+
+        if comparison_search["is_perfect"]:
+            search_perfect += 1
+            print("✓ Schema Search: PERFECT")
+        else:
+            print("✗ Schema Search: Not perfect")
+
+        baseline_total_precision += comparison_baseline["precision"]
+        baseline_total_recall += comparison_baseline["recall"]
+        search_total_precision += comparison_search["precision"]
+        search_total_recall += comparison_search["recall"]
+
+    print("\n" + "=" * 100)
+    print("FINAL RESULTS")
+    print("=" * 100)
+    total_questions = len(eval_data)
+    print(f"Total questions: {total_questions}")
+    print(f"\nBaseline Only:")
+    print(f"  Perfect matches: {baseline_perfect}/{total_questions}")
+    print(f"  Avg Precision: {baseline_total_precision/total_questions:.2f}")
+    print(f"  Avg Recall: {baseline_total_recall/total_questions:.2f}")
+
+    print(f"\nBaseline + Search Results:")
+    print(f"  Perfect matches: {search_perfect}/{total_questions}")
+    print(f"  Avg Precision: {search_total_precision/total_questions:.2f}")
+    print(f"  Avg Recall: {search_total_recall/total_questions:.2f}")
+
+    print(f"\nImprovement: {search_perfect - baseline_perfect} more perfect matches")
+    print("=" * 100)