schema_search-0.1.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. schema_search/__init__.py +26 -0
  2. schema_search/chunkers/__init__.py +6 -0
  3. schema_search/chunkers/base.py +95 -0
  4. schema_search/chunkers/factory.py +31 -0
  5. schema_search/chunkers/llm.py +54 -0
  6. schema_search/chunkers/markdown.py +25 -0
  7. schema_search/embedding_cache/__init__.py +5 -0
  8. schema_search/embedding_cache/base.py +40 -0
  9. schema_search/embedding_cache/bm25.py +63 -0
  10. schema_search/embedding_cache/factory.py +20 -0
  11. schema_search/embedding_cache/inmemory.py +122 -0
  12. schema_search/graph_builder.py +69 -0
  13. schema_search/mcp_server.py +81 -0
  14. schema_search/metrics.py +33 -0
  15. schema_search/rankers/__init__.py +5 -0
  16. schema_search/rankers/base.py +45 -0
  17. schema_search/rankers/cross_encoder.py +40 -0
  18. schema_search/rankers/factory.py +11 -0
  19. schema_search/schema_extractor.py +135 -0
  20. schema_search/schema_search.py +276 -0
  21. schema_search/search/__init__.py +15 -0
  22. schema_search/search/base.py +85 -0
  23. schema_search/search/bm25.py +48 -0
  24. schema_search/search/factory.py +61 -0
  25. schema_search/search/fuzzy.py +56 -0
  26. schema_search/search/hybrid.py +82 -0
  27. schema_search/search/semantic.py +49 -0
  28. schema_search/types.py +57 -0
  29. schema_search/utils/__init__.py +0 -0
  30. schema_search/utils/lazy_import.py +26 -0
  31. schema_search-0.1.10.dist-info/METADATA +308 -0
  32. schema_search-0.1.10.dist-info/RECORD +40 -0
  33. schema_search-0.1.10.dist-info/WHEEL +5 -0
  34. schema_search-0.1.10.dist-info/entry_points.txt +2 -0
  35. schema_search-0.1.10.dist-info/licenses/LICENSE +21 -0
  36. schema_search-0.1.10.dist-info/top_level.txt +2 -0
  37. tests/__init__.py +0 -0
  38. tests/test_integration.py +352 -0
  39. tests/test_llm_sql_generation.py +320 -0
  40. tests/test_spider_eval.py +488 -0
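The largest file in this release, tests/test_llm_sql_generation.py (+320), is reproduced in full below and exercises the package's public API end to end. As orientation before the diff, here is a minimal usage sketch assembled from the calls that test makes; the connection string and key values are placeholder assumptions, and only SchemaSearch(engine, llm_api_key=..., llm_base_url=...), index(force=...), and search(query, hops=..., limit=..., search_type=...) are taken from the test itself.

from sqlalchemy import create_engine

from schema_search import SchemaSearch

# Placeholder DSN; any SQLAlchemy-supported database URL should work here.
engine = create_engine("postgresql://user:pass@localhost:5432/mydb")

search = SchemaSearch(
    engine,
    llm_api_key="sk-...",  # placeholder; presumably feeds the LLM chunker (chunkers/llm.py)
    llm_base_url=None,     # optional alternate endpoint, as in the test fixture below
)
search.index(force=False)  # index the schema; force=False presumably reuses an existing index

# "hops" appears to expand results to neighboring tables in the schema graph;
# "semantic" is one of the backends listed under schema_search/search/.
response = search.search("user email deposit", hops=1, limit=5, search_type="semantic")
for result in response["results"]:
    columns = result["schema"].get("columns", [])
    print(result["table"], [col["name"] for col in columns])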
tests/test_llm_sql_generation.py
@@ -0,0 +1,320 @@
+ import os
+ from pathlib import Path
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ import anthropic
+ import pytest
+ from dotenv import load_dotenv
+ from sqlalchemy import create_engine
+
+ from schema_search import SchemaSearch
+
+
+ @pytest.fixture(scope="module")
+ def database_url():
+     env_path = Path(__file__).parent / ".env"
+     load_dotenv(env_path)
+
+     url = os.getenv("DATABASE_URL")
+     if not url:
+         pytest.skip("DATABASE_URL not set in tests/.env file")
+
+     return url
+
+
+ @pytest.fixture(scope="module")
+ def llm_config():
+     env_path = Path(__file__).parent / ".env"
+     load_dotenv(env_path)
+
+     api_key = os.getenv("LLM_API_KEY")
+     base_url = os.getenv("LLM_BASE_URL")
+
+     if not api_key:
+         pytest.skip("LLM_API_KEY not set in tests/.env file")
+
+     return {"api_key": api_key, "base_url": base_url}
+
+
+ @pytest.fixture(scope="module")
+ def search_engine(database_url, llm_config):
+     engine = create_engine(database_url)
+     search = SchemaSearch(
+         engine,
+         llm_api_key=llm_config["api_key"],
+         llm_base_url=llm_config["base_url"],
+     )
+     search.index(force=False)
+     return search
+
+
+ def test_table_identification_with_schema_search(search_engine, llm_config):
+     """
+     Compare table identification quality when LLM has:
+     1. Full schema context (all tables and indices)
+     2. Limited context from schema search with graph hops
+
+     For each natural language question, we:
+     - Ask LLM which tables are needed with full schema context (baseline)
+     - Ask LLM which tables are needed with schema search context (our approach)
+     - Compare both against the objective list of required tables
+     """
+
+     eval_data = [
+         {
+             "question": "how many unique users do we have?",
+             "required_tables": ["user_metadata"],
+             "searches": ["user table"],
+             "hops": 1,
+         },
+         {
+             "question": "what is the email of the user who deposited the most last month",
+             "required_tables": ["user_metadata", "user_deposits"],
+             "searches": ["user email deposit"],
+             "hops": 1,
+         },
+         {
+             "question": "what is the twitter handle of the agent that posted the most?",
+             "required_tables": ["agent_metadata", "agent_content"],
+             "searches": ["agent metadata content"],
+             "hops": 1,
+         },
+         {
+             "question": "which topic was covered the most in news articles last month?",
+             "required_tables": ["news_to_topic_map", "topic_metadata"],
+             "searches": ["topic metadata news map"],
+             "hops": 1,
+         },
+         {
+             "question": "which coin's price increased the most last month?",
+             "required_tables": ["historical_market_data"],
+             "searches": ["historical market data"],
+             "hops": 1,
+         },
+         {
+             "question": "find the 5 most recent news about the coin that increased the most last month?",
+             "required_tables": [
+                 "historical_market_data",
+                 "news_to_topic_map",
+                 "topic_metadata",
+                 "news_summary",
+             ],
+             "searches": ["historical market data news topic"],
+             "hops": 1,
+         },
+         {
+             "question": "which model did the top user of last month use?",
+             "required_tables": ["user_metadata", "model_metadata", "query_metrics"],
+             "searches": ["user metadata model query metrics"],
+             "hops": 1,
+         },
+         {
+             "question": "which agent gained the most followers last month?",
+             "required_tables": ["agent_metadata", "twitter_follow_activity"],
+             "searches": ["agent metadata twitter follow activity"],
+             "hops": 1,
+         },
+         {
+             "question": "which agent posted the most content last month?",
+             "required_tables": ["agent_metadata", "agent_content"],
+             "searches": ["agent metadata agent content"],
+             "hops": 1,
+         },
+         {
+             "question": "which api key was most used during last month?",
+             "required_tables": ["api_token", "query_metrics", "user_metadata"],
+             "searches": ["api token query metrics user metadata"],
+             "hops": 1,
+         },
+     ]
+
+     def get_baseline_context(search_engine):
+         """Get minimal context: just table names and indices."""
+         context_parts = []
+
+         for table_name, table_schema in search_engine.schemas.items():
+             context_parts.append(f"Table: {table_name}")
+
+             indices = table_schema.get("indices")
+             if indices:
+                 idx_list = ", ".join([idx["name"] for idx in indices])
+                 context_parts.append(f"Indices: {idx_list}")
+
+         return "\n\n".join(context_parts)
+
+     def get_search_results_context(search_engine, searches, hops):
+         """Get detailed schema from search results to add to baseline."""
+         context_parts = []
+         seen_tables = set()
+
+         for search_query in searches:
+             response = search_engine.search(
+                 search_query, hops=hops, limit=5, search_type="semantic"
+             )
+             for result in response["results"]:
+                 table_name = result["table"]
+                 if table_name in seen_tables:
+                     continue
+                 seen_tables.add(table_name)
+
+                 columns = result["schema"].get("columns")
+                 if columns:
+                     col_list = ", ".join(
+                         [f"{col['name']} ({col['type']})" for col in columns]
+                     )
+                     context_parts.append(f"Table: {table_name}\nColumns: {col_list}")
+         print("Search results tables: ", list(seen_tables))
+
+         return "\n\n".join(context_parts)
+
+     def call_llm_for_tables(question, schema_context, llm_config):
+         """Call LLM to identify which tables are needed."""
+         client = anthropic.Anthropic(api_key=llm_config["api_key"])
+
+         prompt = f"""Given the following database schema:
+
+ {schema_context}
+
+ Which tables are necessary to answer this question: {question}
+
+ Return ONLY a comma-separated list of table names, nothing else. No explanations or additional text.
+ Example format: table1, table2, table3"""
+
+         response = client.messages.create(
+             model="claude-sonnet-4-5-20250929",
+             max_tokens=512,
+             system="You are a database expert. Identify only the tables needed to answer the question.",
+             messages=[
+                 {"role": "user", "content": prompt},
+             ],
+             temperature=0,
+         )
+
+         tables_str = response.content[0].text.strip()  # type: ignore
+         tables = [t.strip() for t in tables_str.split(",") if t.strip()]
+         # Remove schema prefix if present
+         tables = [t.split(".")[-1] for t in tables]
+         return tables
+
+     def compare_tables(identified_tables, required_tables):
+         """Compare identified tables with required tables."""
+         identified_set = set(t.lower() for t in identified_tables)
+         required_set = set(t.lower() for t in required_tables)
+
+         correct = identified_set & required_set
+         missing = required_set - identified_set
+         extra = identified_set - required_set
+
+         is_perfect = len(missing) == 0 and len(extra) == 0
+
+         return {
+             "is_perfect": is_perfect,
+             "correct": correct,
+             "missing": missing,
+             "extra": extra,
+             "precision": len(correct) / len(identified_set) if identified_set else 0,
+             "recall": len(correct) / len(required_set) if required_set else 0,
+         }
+
+     if len(eval_data) == 0:
+         pytest.skip("No evaluation data provided")
+
+     print("\n" + "=" * 100)
+     print("EVALUATION: Table Identification - Baseline vs Baseline + Search Results")
+     print("=" * 100)
+
+     baseline_context = get_baseline_context(search_engine)
+
+     baseline_perfect = 0
+     baseline_total_precision = 0
+     baseline_total_recall = 0
+
+     search_perfect = 0
+     search_total_precision = 0
+     search_total_recall = 0
+
+     for idx, eval_item in enumerate(eval_data, 1):
+         question = eval_item["question"]
+         required_tables = eval_item.get("required_tables", [])
+         searches = eval_item.get("searches", [question])
+         hops = eval_item.get("hops", 1)
+
+         print(f"\n{'='*100}")
+         print(f"Question {idx}: {question}")
+         print(f"Required tables: {required_tables}")
+         print(f"{'='*100}")
+
+         # Get search results and combine with baseline
+         search_results_context = get_search_results_context(
+             search_engine, searches, hops
+         )
+         enhanced_context = baseline_context + "\n\n" + search_results_context
+
+         print(f"\n[Baseline only] Context: {len(baseline_context)} chars")
+         print(f"[Baseline + Search] Context: {len(enhanced_context)} chars")
+         print(f"Additional context from search: {len(search_results_context)} chars")
+
+         # Identify tables with baseline only
+         print("\n--- Identifying tables with BASELINE ONLY ---")
+         tables_baseline = call_llm_for_tables(question, baseline_context, llm_config)
+         print(f"Identified tables: {tables_baseline}")
+
+         comparison_baseline = compare_tables(tables_baseline, required_tables)
+         print(
+             f"Precision: {comparison_baseline['precision']:.2f}, Recall: {comparison_baseline['recall']:.2f}"
+         )
+         if comparison_baseline["missing"]:
+             print(f"Missing: {comparison_baseline['missing']}")
+         if comparison_baseline["extra"]:
+             print(f"Extra: {comparison_baseline['extra']}")
+
+         # Identify tables with baseline + search results
+         print("\n--- Identifying tables with BASELINE + SEARCH ---")
+         tables_search = call_llm_for_tables(question, enhanced_context, llm_config)
+         print(f"Identified tables: {tables_search}")
+
+         comparison_search = compare_tables(tables_search, required_tables)
+         print(
+             f"Precision: {comparison_search['precision']:.2f}, Recall: {comparison_search['recall']:.2f}"
+         )
+         if comparison_search["missing"]:
+             print(f"Missing: {comparison_search['missing']}")
+         if comparison_search["extra"]:
+             print(f"Extra: {comparison_search['extra']}")
+
+         # Track metrics
+         if comparison_baseline["is_perfect"]:
+             baseline_perfect += 1
+             print("\n✓ Baseline: PERFECT")
+         else:
+             print("\n✗ Baseline: Not perfect")
+
+         if comparison_search["is_perfect"]:
+             search_perfect += 1
+             print("✓ Schema Search: PERFECT")
+         else:
+             print("✗ Schema Search: Not perfect")
+
+         baseline_total_precision += comparison_baseline["precision"]
+         baseline_total_recall += comparison_baseline["recall"]
+         search_total_precision += comparison_search["precision"]
+         search_total_recall += comparison_search["recall"]
+
+     print("\n" + "=" * 100)
+     print("FINAL RESULTS")
+     print("=" * 100)
+     total_questions = len(eval_data)
+     print(f"Total questions: {total_questions}")
+     print(f"\nBaseline Only:")
+     print(f"  Perfect matches: {baseline_perfect}/{total_questions}")
+     print(f"  Avg Precision: {baseline_total_precision/total_questions:.2f}")
+     print(f"  Avg Recall: {baseline_total_recall/total_questions:.2f}")
+
+     print(f"\nBaseline + Search Results:")
+     print(f"  Perfect matches: {search_perfect}/{total_questions}")
+     print(f"  Avg Precision: {search_total_precision/total_questions:.2f}")
+     print(f"  Avg Recall: {search_total_recall/total_questions:.2f}")
+
+     print(f"\nImprovement: {search_perfect - baseline_perfect} more perfect matches")
+     print("=" * 100)