schema-search 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. schema_search/__init__.py +26 -0
  2. schema_search/chunkers/__init__.py +6 -0
  3. schema_search/chunkers/base.py +95 -0
  4. schema_search/chunkers/factory.py +31 -0
  5. schema_search/chunkers/llm.py +54 -0
  6. schema_search/chunkers/markdown.py +25 -0
  7. schema_search/embedding_cache/__init__.py +5 -0
  8. schema_search/embedding_cache/base.py +40 -0
  9. schema_search/embedding_cache/bm25.py +63 -0
  10. schema_search/embedding_cache/factory.py +20 -0
  11. schema_search/embedding_cache/inmemory.py +122 -0
  12. schema_search/graph_builder.py +69 -0
  13. schema_search/mcp_server.py +81 -0
  14. schema_search/metrics.py +33 -0
  15. schema_search/rankers/__init__.py +5 -0
  16. schema_search/rankers/base.py +45 -0
  17. schema_search/rankers/cross_encoder.py +40 -0
  18. schema_search/rankers/factory.py +11 -0
  19. schema_search/schema_extractor.py +135 -0
  20. schema_search/schema_search.py +276 -0
  21. schema_search/search/__init__.py +15 -0
  22. schema_search/search/base.py +85 -0
  23. schema_search/search/bm25.py +48 -0
  24. schema_search/search/factory.py +61 -0
  25. schema_search/search/fuzzy.py +56 -0
  26. schema_search/search/hybrid.py +82 -0
  27. schema_search/search/semantic.py +49 -0
  28. schema_search/types.py +57 -0
  29. schema_search/utils/__init__.py +0 -0
  30. schema_search/utils/lazy_import.py +26 -0
  31. schema_search-0.1.10.dist-info/METADATA +308 -0
  32. schema_search-0.1.10.dist-info/RECORD +40 -0
  33. schema_search-0.1.10.dist-info/WHEEL +5 -0
  34. schema_search-0.1.10.dist-info/entry_points.txt +2 -0
  35. schema_search-0.1.10.dist-info/licenses/LICENSE +21 -0
  36. schema_search-0.1.10.dist-info/top_level.txt +2 -0
  37. tests/__init__.py +0 -0
  38. tests/test_integration.py +352 -0
  39. tests/test_llm_sql_generation.py +320 -0
  40. tests/test_spider_eval.py +488 -0
@@ -0,0 +1,488 @@
+ """Spider benchmark evaluation.
+
+ Warning: this test intentionally creates and drops PostgreSQL databases
+ matching Spider db_ids. Only run against an isolated server you are
+ comfortable wiping.
+ """
+
+ import re
+ import time
+ import json
+ from collections import defaultdict
+ from typing import List, Set, Dict
+ from pathlib import Path
+
+ import numpy as np
+ import pytest
+ from datasets import load_dataset
+ from sqlalchemy import create_engine, text
+ from tqdm import tqdm
+
+ from schema_search import SchemaSearch
+ from schema_search.types import SearchType
+
+ # WARNING: Do not swap this for an environment variable; this benchmark is
+ # meant to run against an explicit throwaway Postgres instance.
+ DATABASE_URL = "postgresql://user:pass@localhost/db"
+ MIN_TABLES = 10
+
+
+ def parse_spider_schema(schema_str: str) -> Dict[str, List[tuple]]:
+     """Parse spider-schema format into table definitions."""
+     tables = {}
+     table_parts = schema_str.split(" | ")
+
+     for table_part in table_parts:
+         if " : " not in table_part:
+             continue
+
+         table_name, columns_str = table_part.split(" : ", 1)
+         table_name = table_name.strip()
+
+         columns = []
+         col_parts = columns_str.split(" , ")
+
+         for col_part in col_parts:
+             match = re.match(r"(.+?)\s*\((\w+)\)", col_part.strip())
+             if match:
+                 col_name = match.group(1).strip()
+                 col_type = match.group(2).strip()
+                 sql_type = "INTEGER" if col_type == "number" else "TEXT"
+                 columns.append((col_name, sql_type))
+
+         if columns:
+             tables[table_name] = columns
+
+     return tables
+
+
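+ # Minimal sanity check of the parser on a made-up spider-schema string;
+ # the " | ", " : ", and " , " separators are assumed from the splits above.
+ def test_parse_spider_schema_example():
+     parsed = parse_spider_schema("stadium : stadium id (number) , name (text)")
+     assert parsed == {"stadium": [("stadium id", "INTEGER"), ("name", "TEXT")]}
+
+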
+ def extract_tables_from_sql(sql: str) -> Set[str]:
+     """Extract table names from SQL query using regex.
+
+     Handles: aliases, quoted names, schemas, comma-separated joins.
+     """
+     sql_lower = sql.lower()
+
+     def clean_table_name(name: str) -> str:
+         """Strip quotes, brackets, schema prefix, and aliases."""
+         name = name.strip().strip('`"[]')
+         if "." in name:
+             name = name.split(".")[-1]
+         if " " in name:
+             name = name.split()[0]
+         return name.strip()
+
+     tables = set()
+
+     from_pattern = r"\bfrom\s+([`\"[]?[\w]+[`\"\]]?(?:\.[\w]+)?(?:\s+(?:as\s+)?\w+)?)"
+     join_pattern = r"\bjoin\s+([`\"[]?[\w]+[`\"\]]?(?:\.[\w]+)?(?:\s+(?:as\s+)?\w+)?)"
+
+     for match in re.finditer(from_pattern, sql_lower):
+         table = clean_table_name(match.group(1))
+         if table:
+             tables.add(table)
+
+     for match in re.finditer(join_pattern, sql_lower):
+         table = clean_table_name(match.group(1))
+         if table:
+             tables.add(table)
+
+     from_clause_match = re.search(
+         r"\bfrom\s+([^;]+?)(?:\bwhere\b|\bjoin\b|\bgroup\b|\border\b|\blimit\b|$)",
+         sql_lower,
+     )
+     if from_clause_match:
+         from_clause = from_clause_match.group(1)
+         for table_expr in re.split(r"\s*,\s*", from_clause):
+             table = clean_table_name(table_expr)
+             if table and re.match(r"^[a-zA-Z_][\w]*$", table):
+                 tables.add(table)
+
+     return tables
+
+
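+ # Minimal sanity check of the extractor on a made-up query exercising an
+ # "AS" alias, a bare alias, and a join (not a Spider query):
+ def test_extract_tables_from_sql_example():
+     sql = "SELECT s.name FROM singer AS s JOIN concert c ON s.id = c.singer_id"
+     assert extract_tables_from_sql(sql) == {"singer", "concert"}
+
+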
+ def create_database_from_schema(db_id: str, schema_str: str):
+     """Create PostgreSQL database from spider-schema string.
+
+     Returns:
+         tuple: (db_engine, num_tables)
+     """
+     safe_db_id = db_id.lower().replace("-", "_")
+
+     assert (
+         "localhost" in DATABASE_URL
+     ), "DATABASE_URL must be a test database on localhost"
+
+     admin_engine = create_engine(
+         DATABASE_URL, isolation_level="AUTOCOMMIT", pool_pre_ping=True, pool_recycle=60
+     )
+
+     with admin_engine.connect() as conn:
+         conn.execute(
+             text(
+                 f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{safe_db_id}' AND pid != pg_backend_pid()"
+             )
+         )
+         conn.execute(text(f"DROP DATABASE IF EXISTS {safe_db_id}"))
+         conn.execute(text(f"CREATE DATABASE {safe_db_id}"))
+
+     admin_engine.dispose()
+
+     db_engine = create_engine(
+         DATABASE_URL.rsplit("/", 1)[0] + f"/{safe_db_id}",
+         pool_pre_ping=True,
+         pool_recycle=60,
+     )
+     tables = parse_spider_schema(schema_str)
+
+     with db_engine.connect() as conn:
+         for table_name, columns in tables.items():
+             col_defs = ", ".join(
+                 [f'"{col_name}" {col_type}' for col_name, col_type in columns]
+             )
+             create_sql = f'CREATE TABLE "{table_name}" ({col_defs})'
+             conn.execute(text(create_sql))
+         conn.commit()
+
+     return db_engine, len(tables)
+
+
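+ # For a parsed schema like {"stadium": [("stadium id", "INTEGER"), ("name", "TEXT")]}
+ # (hypothetical values), the DDL emitted above would be:
+ #
+ #   CREATE TABLE "stadium" ("stadium id" INTEGER, "name" TEXT)
+
+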
+ def calculate_recall_at_k(
+     predicted: List[str], ground_truth: Set[str], k: int
+ ) -> float:
+     """Calculate Recall@k = (relevant retrieved in top k) / (total relevant).
+
+     For k=1, ground_truth={A,B,C}, predicted=[A,...]: Recall@1 = 1/3 = 0.33
+     """
+     if not ground_truth:
+         return 0.0
+
+     predicted_at_k = {p.lower() for p in predicted[:k]}
+     ground_truth_lower = {g.lower() for g in ground_truth}
+
+     matches = len(predicted_at_k & ground_truth_lower)
+     return matches / len(ground_truth_lower)
+
+
+ def calculate_precision_at_k(
+     predicted: List[str], ground_truth: Set[str], k: int
+ ) -> float:
+     """Calculate Precision@k = (relevant retrieved in top k) / k.
+
+     For k=1, ground_truth={A,B,C}, predicted=[A,...]: Precision@1 = 1/1 = 1.0
+     """
+     if k == 0:
+         return 0.0
+
+     predicted_at_k = {p.lower() for p in predicted[:k]}
+     ground_truth_lower = {g.lower() for g in ground_truth}
+
+     matches = len(predicted_at_k & ground_truth_lower)
+     return matches / k
+
+
+ def calculate_mrr(predicted: List[str], ground_truth: Set[str]) -> float:
+     """Calculate reciprocal rank = 1/rank of the first correct item.
+
+     Averaging this over queries yields Mean Reciprocal Rank (MRR).
+     """
+     if not ground_truth:
+         return 0.0
+
+     ground_truth_lower = {g.lower() for g in ground_truth}
+
+     for rank, pred in enumerate(predicted, 1):
+         if pred.lower() in ground_truth_lower:
+             return 1.0 / rank
+
+     return 0.0
+
+
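+ # Small sanity check of the three metric helpers on one hypothetical query;
+ # the predicted/ground-truth values below are illustrative only.
+ def test_metric_helpers_worked_example():
+     predicted = ["x", "a", "b"]
+     ground_truth = {"a", "b"}
+     assert calculate_recall_at_k(predicted, ground_truth, 3) == 1.0
+     assert calculate_precision_at_k(predicted, ground_truth, 3) == pytest.approx(2 / 3)
+     assert calculate_mrr(predicted, ground_truth) == 0.5  # first hit at rank 2
+
+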
+ def save_benchmark_results(results_by_strategy, index_latencies, strategies):
+     """Save benchmark results as JSON."""
+     import yaml
+
+     config_path = Path(__file__).parent.parent / "config.yml"
+     with open(config_path) as f:
+         config = yaml.safe_load(f)
+
+     reranker_model = config.get("reranker", {}).get("model")
+     has_reranker = reranker_model is not None
+
+     img_dir = Path(__file__).parent.parent / "img"
+     img_dir.mkdir(exist_ok=True)
+
+     output = {
+         "reranker_enabled": has_reranker,
+         "reranker_model": reranker_model,
+         "indexing": {},
+         "strategies": {},
+     }
+
+     if index_latencies:
+         output["indexing"] = {
+             "num_databases": len(index_latencies),
+             "mean_latency": float(np.mean(index_latencies)),
+             "std_latency": float(np.std(index_latencies)),
+         }
+
+     for strategy in strategies:
+         stats = results_by_strategy[strategy]
+         if stats["num_queries"] == 0:
+             continue
+
+         strategy_results = {"num_queries": stats["num_queries"]}
+
+         metric_names = [
+             "recall_at_1",
+             "recall_at_3",
+             "recall_at_5",
+             "mrr",
+             "precision_at_1",
+             "precision_at_3",
+             "precision_at_5",
+             "latency",
+         ]
+
+         for metric_name in metric_names:
+             if metric_name in stats and stats[metric_name]:
+                 values = np.array(stats[metric_name])
+                 strategy_results[metric_name] = {
+                     "mean": float(np.mean(values)),
+                     "std": float(np.std(values)),
+                 }
+
+         output["strategies"][strategy] = strategy_results
+
+     output_filename = (
+         "spider_benchmark_with_reranker.json"
+         if has_reranker
+         else "spider_benchmark_without_reranker.json"
+     )
+     output_path = img_dir / output_filename
+
+     with open(output_path, "w") as f:
+         json.dump(output, f, indent=2)
+
+     print(f"\nBenchmark results saved to: {output_path}")
+
+
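+ # Rough shape of the JSON written above (numbers purely illustrative):
+ #
+ #   {"reranker_enabled": false,
+ #    "reranker_model": null,
+ #    "indexing": {"num_databases": 12, "mean_latency": 1.8, "std_latency": 0.4},
+ #    "strategies": {"hybrid": {"num_queries": 100,
+ #                              "recall_at_5": {"mean": 0.9, "std": 0.1}, ...}}}
+
+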
+ def cleanup_spider_databases():
+     """Drop Spider test databases and clear cache before starting tests.
+
+     Only drops databases matching the Spider naming pattern (lowercase
+     alphanumerics and underscores); system databases are excluded explicitly.
+     """
+     import shutil
+
+     admin_engine = create_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
+
+     with admin_engine.connect() as conn:
+         result = conn.execute(
+             text(
+                 "SELECT datname FROM pg_database WHERE datname NOT IN ('postgres', 'template0', 'template1', 'db')"
+             )
+         )
+         databases = [row[0] for row in result]
+
+         for db_name in databases:
+             if re.match(r"^[a-z0-9_]+$", db_name):
+                 conn.execute(
+                     text(
+                         f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{db_name}'"
+                     )
+                 )
+                 conn.execute(text(f"DROP DATABASE IF EXISTS {db_name}"))
+
+     admin_engine.dispose()
+
+     cache_path = Path("/tmp/.schema_search_cache")
+     if cache_path.exists():
+         shutil.rmtree(cache_path)
+         print(f"Cleared cache: {cache_path}")
+
+
+ @pytest.fixture(scope="module")
+ def spider_data():
+     """Load Spider questions/queries and spider-schema definitions."""
+     cleanup_spider_databases()
+
+     spider = load_dataset("spider", split="train")
+     spider_schema = load_dataset("richardr1126/spider-schema", split="train")
+
+     schema_map = {ex["db_id"]: ex for ex in spider_schema}
+
+     return spider, schema_map
+
+
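+ # Each Spider example is consumed via three fields; a row looks roughly like
+ # this (illustrative values):
+ #
+ #   {"db_id": "concert_singer",
+ #    "question": "How many singers do we have?",
+ #    "query": "SELECT count(*) FROM singer"}
+ #
+ # Each spider-schema row keys its schema string under "Schema (values (type))".
+
+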
+ def test_spider_evaluation(spider_data):
+     """Evaluate schema search on Spider benchmark.
+
+     Tests table-level retrieval accuracy across multiple databases.
+     Metrics: Recall@1, Recall@3, Recall@5, MRR, Precision@k, Latency
+     """
+     spider, schema_map = spider_data
+
+     results_by_strategy = defaultdict(
+         lambda: {
+             "recall_at_1": [],
+             "recall_at_3": [],
+             "recall_at_5": [],
+             "mrr": [],
+             "precision_at_1": [],
+             "precision_at_3": [],
+             "precision_at_5": [],
+             "latency": [],
+             "num_queries": 0,
+         }
+     )
+
+     strategies: List[SearchType] = ["hybrid", "fuzzy", "bm25", "semantic"]
+
+     current_db_id = None
+     search_engine = None
+     db_engine = None
+     index_latencies = []
+     num_dbs_filtered = 0
+     total_queries_evaluated = 0
+
+     for example in tqdm(spider, desc="Evaluating Spider"):
+         db_id = example["db_id"]
+         question = example["question"]
+         sql = example["query"]
+
+         if db_id not in schema_map:
+             continue
+
+         if db_id != current_db_id:
+             if db_engine is not None:
+                 db_engine.dispose()
+
+             schema_str = schema_map[db_id]["Schema (values (type))"]
+             db_engine, num_tables = create_database_from_schema(db_id, schema_str)
+
+             if num_tables < MIN_TABLES:
+                 num_dbs_filtered += 1
+                 safe_db_id = db_id.lower().replace("-", "_")
+                 db_engine.dispose()
+
+                 admin_engine = create_engine(
+                     DATABASE_URL, isolation_level="AUTOCOMMIT"
+                 )
+                 with admin_engine.connect() as conn:
+                     conn.execute(
+                         text(
+                             f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{safe_db_id}'"
+                         )
+                     )
+                     conn.execute(text(f"DROP DATABASE IF EXISTS {safe_db_id}"))
+                 admin_engine.dispose()
+
+                 db_engine = None
+                 search_engine = None
+                 current_db_id = db_id
+                 continue
+
+             search_engine = SchemaSearch(db_engine)
+
+             index_start = time.time()
+             search_engine.index(force=False)
+             index_latency = time.time() - index_start
+             index_latencies.append(index_latency)
+
+             current_db_id = db_id
+
+         if search_engine is None:
+             continue
+
+         ground_truth_tables = extract_tables_from_sql(sql)
+
+         if not ground_truth_tables:
+             continue
+
+         for strategy in strategies:
+             start_time = time.time()
+             response = search_engine.search(
+                 question, search_type=strategy, limit=10, hops=1
+             )
+             latency = time.time() - start_time
+
+             predicted_tables = [r["table"] for r in response["results"]]
+
+             for k in [1, 3, 5]:
+                 results_by_strategy[strategy][f"recall_at_{k}"].append(
+                     calculate_recall_at_k(predicted_tables, ground_truth_tables, k)
+                 )
+                 results_by_strategy[strategy][f"precision_at_{k}"].append(
+                     calculate_precision_at_k(predicted_tables, ground_truth_tables, k)
+                 )
+
+             results_by_strategy[strategy]["mrr"].append(
+                 calculate_mrr(predicted_tables, ground_truth_tables)
+             )
+             results_by_strategy[strategy]["latency"].append(latency)
+             results_by_strategy[strategy]["num_queries"] += 1
+
+         total_queries_evaluated += 1
+
+     print(f"\nFiltered out {num_dbs_filtered} databases with <{MIN_TABLES} tables")
+     print(
+         f"Evaluated {total_queries_evaluated} queries on {len(index_latencies)} databases"
+     )
+
+     if db_engine is not None:
+         db_engine.dispose()
+
+     print(f"\n{'='*80}")
+     print("FINAL RESULTS")
+     print(f"{'='*80}")
+
+     if index_latencies:
+         mean_index = np.mean(index_latencies)
+         std_index = np.std(index_latencies)
+         print("\nINDEXING")
+         print(f"  Databases indexed: {len(index_latencies)}")
+         print(f"  Index latency: {mean_index:.3f}s ± {std_index:.3f}s")
+
+     metric_order = [
+         "recall_at_1",
+         "recall_at_3",
+         "recall_at_5",
+         "mrr",
+         "precision_at_1",
+         "precision_at_3",
+         "precision_at_5",
+         "latency",
+     ]
+     metric_labels = {
+         "recall_at_1": "Recall@1",
+         "recall_at_3": "Recall@3",
+         "recall_at_5": "Recall@5",
+         "mrr": "MRR",
+         "precision_at_1": "Precision@1",
+         "precision_at_3": "Precision@3",
+         "precision_at_5": "Precision@5",
+         "latency": "Latency",
+     }
+
+     for strategy in strategies:
+         stats = results_by_strategy[strategy]
+
+         if stats["num_queries"] == 0:
+             continue
+
+         print(f"\n{strategy.upper()}")
+         print(f"  Queries: {stats['num_queries']}")
+
+         for metric_name in metric_order:
+             if metric_name in stats and stats[metric_name]:
+                 values = np.array(stats[metric_name])
+                 mean = np.mean(values)
+                 std = np.std(values)
+                 label = metric_labels[metric_name]
+
+                 if metric_name == "latency":
+                     print(f"  {label}: {mean:.3f}s ± {std:.3f}s")
+                 else:
+                     print(f"  {label}: {mean:.3f} ± {std:.3f}")
+
+         mean_recall_1 = np.mean(stats["recall_at_1"]) if stats["recall_at_1"] else 0.0
+         mean_mrr = np.mean(stats["mrr"]) if stats["mrr"] else 0.0
+         assert mean_recall_1 >= 0.0, f"{strategy}: Invalid recall@1"
+         assert mean_mrr >= 0.0, f"{strategy}: Invalid MRR"
+
+     save_benchmark_results(results_by_strategy, index_latencies, strategies)
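+
+
+ # Assuming a throwaway Postgres is reachable at DATABASE_URL, the benchmark
+ # runs directly under pytest, e.g.: pytest tests/test_spider_eval.py -s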