schema-search 0.1.10 (schema_search-0.1.10-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. schema_search/__init__.py +26 -0
  2. schema_search/chunkers/__init__.py +6 -0
  3. schema_search/chunkers/base.py +95 -0
  4. schema_search/chunkers/factory.py +31 -0
  5. schema_search/chunkers/llm.py +54 -0
  6. schema_search/chunkers/markdown.py +25 -0
  7. schema_search/embedding_cache/__init__.py +5 -0
  8. schema_search/embedding_cache/base.py +40 -0
  9. schema_search/embedding_cache/bm25.py +63 -0
  10. schema_search/embedding_cache/factory.py +20 -0
  11. schema_search/embedding_cache/inmemory.py +122 -0
  12. schema_search/graph_builder.py +69 -0
  13. schema_search/mcp_server.py +81 -0
  14. schema_search/metrics.py +33 -0
  15. schema_search/rankers/__init__.py +5 -0
  16. schema_search/rankers/base.py +45 -0
  17. schema_search/rankers/cross_encoder.py +40 -0
  18. schema_search/rankers/factory.py +11 -0
  19. schema_search/schema_extractor.py +135 -0
  20. schema_search/schema_search.py +276 -0
  21. schema_search/search/__init__.py +15 -0
  22. schema_search/search/base.py +85 -0
  23. schema_search/search/bm25.py +48 -0
  24. schema_search/search/factory.py +61 -0
  25. schema_search/search/fuzzy.py +56 -0
  26. schema_search/search/hybrid.py +82 -0
  27. schema_search/search/semantic.py +49 -0
  28. schema_search/types.py +57 -0
  29. schema_search/utils/__init__.py +0 -0
  30. schema_search/utils/lazy_import.py +26 -0
  31. schema_search-0.1.10.dist-info/METADATA +308 -0
  32. schema_search-0.1.10.dist-info/RECORD +40 -0
  33. schema_search-0.1.10.dist-info/WHEEL +5 -0
  34. schema_search-0.1.10.dist-info/entry_points.txt +2 -0
  35. schema_search-0.1.10.dist-info/licenses/LICENSE +21 -0
  36. schema_search-0.1.10.dist-info/top_level.txt +2 -0
  37. tests/__init__.py +0 -0
  38. tests/test_integration.py +352 -0
  39. tests/test_llm_sql_generation.py +320 -0
  40. tests/test_spider_eval.py +488 -0
schema_search/rankers/base.py
@@ -0,0 +1,45 @@
+ from typing import Dict, List, Tuple
+ from collections import defaultdict
+ from abc import ABC, abstractmethod
+
+ from schema_search.chunkers import Chunk
+
+
+ class BaseRanker(ABC):
+     def __init__(self):
+         self.chunks: List[Chunk]
+
+     @abstractmethod
+     def build(self, chunks: List[Chunk]) -> None:
+         pass
+
+     @abstractmethod
+     def rank(self, query: str) -> List[Tuple[int, float]]:
+         """Returns: List of (chunk_idx, score)"""
+         pass
+
+     def get_top_tables_from_chunks(
+         self, ranked_chunks: List[Tuple[int, float]], top_k: int
+     ) -> Dict[str, List[int]]:
+         table_to_chunk_indices: Dict[str, List[int]] = defaultdict(list)
+         chunk_idx_to_score: Dict[int, float] = {}
+
+         for chunk_idx, score in ranked_chunks:
+             chunk = self.chunks[chunk_idx]
+             table_to_chunk_indices[chunk.table_name].append(chunk_idx)
+             chunk_idx_to_score[chunk_idx] = score
+
+         table_scores: Dict[str, float] = {}
+         for table_name, chunk_indices in table_to_chunk_indices.items():
+             max_score = max(chunk_idx_to_score[idx] for idx in chunk_indices)
+             table_scores[table_name] = max_score
+
+         top_tables = sorted(table_scores.items(), key=lambda x: x[1], reverse=True)[
+             :top_k
+         ]
+
+         result: Dict[str, List[int]] = {}
+         for table_name, score in top_tables:
+             result[table_name] = table_to_chunk_indices[table_name]
+
+         return result
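
Worth noting: `get_top_tables_from_chunks` collapses chunk-level scores into table-level results, scoring each table by its best-scoring chunk and keeping every chunk index of the top_k tables. A minimal sketch of that behaviour (the `FakeRanker` below is hypothetical; only the `Chunk` constructor used elsewhere in this package is assumed):

# Hypothetical subclass, purely to exercise the aggregation helper.
from schema_search.chunkers import Chunk
from schema_search.rankers.base import BaseRanker

class FakeRanker(BaseRanker):
    def build(self, chunks):
        self.chunks = chunks

    def rank(self, query):
        # Pretend every chunk matched, with descending scores.
        return [(i, 1.0 / (i + 1)) for i in range(len(self.chunks))]

ranker = FakeRanker()
ranker.build([
    Chunk(table_name="users", content="users: id, email", chunk_id=0, token_count=5),
    Chunk(table_name="orders", content="orders: id, user_id", chunk_id=1, token_count=5),
    Chunk(table_name="users", content="users: created_at", chunk_id=2, token_count=4),
])
# A table's score is the max over its chunks, so "users" (1.0) beats "orders" (0.5):
print(ranker.get_top_tables_from_chunks(ranker.rank("user emails"), top_k=1))
# {"users": [0, 2]}
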
schema_search/rankers/cross_encoder.py
@@ -0,0 +1,40 @@
+ from typing import List, Tuple, Optional, TYPE_CHECKING
+ import logging
+
+ from schema_search.chunkers import Chunk
+ from schema_search.rankers.base import BaseRanker
+ from schema_search.utils.lazy_import import lazy_import_check
+
+ if TYPE_CHECKING:
+     from sentence_transformers import CrossEncoder
+
+ logger = logging.getLogger(__name__)
+
+
+ class CrossEncoderRanker(BaseRanker):
+     def __init__(self, model_name: str):
+         super().__init__()
+         self.model_name = model_name
+         self.model: Optional["CrossEncoder"] = None
+
+     def _load_model(self) -> "CrossEncoder":
+         if self.model is None:
+             sentence_transformers = lazy_import_check(
+                 "sentence_transformers", "semantic", "reranking with CrossEncoder"
+             )
+             logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
+             self.model = sentence_transformers.CrossEncoder(self.model_name)
+             assert self.model is not None
+             logger.info(f"Loaded CrossEncoder: {self.model_name}")
+         return self.model
+
+     def build(self, chunks: List[Chunk]) -> None:
+         self.chunks = chunks
+         logger.debug(f"Initialized CrossEncoder reranker with {len(chunks)} chunks")
+
+     def rank(self, query: str) -> List[Tuple[int, float]]:
+         model = self._load_model()
+         pairs = [(query, chunk.content) for chunk in self.chunks]
+         scores = model.predict(pairs, show_progress_bar=False)
+         ranked_indices = scores.argsort()[::-1]
+         return [(int(idx), float(scores[idx])) for idx in ranked_indices]
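
A hedged usage sketch of the ranker above. It requires the optional sentence-transformers dependency; the model name is illustrative (any sentence-transformers CrossEncoder checkpoint should work), and the model is only downloaded and loaded on the first `rank()` call:

from schema_search.chunkers import Chunk
from schema_search.rankers.cross_encoder import CrossEncoderRanker

ranker = CrossEncoderRanker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
ranker.build([
    Chunk(table_name="users", content="Table users: id, email, name", chunk_id=0, token_count=8),
    Chunk(table_name="orders", content="Table orders: id, user_id, total", chunk_id=1, token_count=8),
])
# rank() scores every (query, chunk.content) pair and returns
# (chunk_idx, score) tuples sorted by descending relevance.
for chunk_idx, score in ranker.rank("customer email addresses"):
    print(chunk_idx, score)
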
schema_search/rankers/factory.py
@@ -0,0 +1,11 @@
+ from typing import Dict, Optional
+
+ from schema_search.rankers.base import BaseRanker
+ from schema_search.rankers.cross_encoder import CrossEncoderRanker
+
+
+ def create_ranker(config: Dict) -> Optional[BaseRanker]:
+     reranker_model = config["reranker"]["model"]
+     if reranker_model is None:
+         return None
+     return CrossEncoderRanker(model_name=reranker_model)
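
The factory only inspects `config["reranker"]["model"]`; a null model disables reranking entirely. Both paths in a two-line sketch:

from schema_search.rankers.factory import create_ranker

assert create_ranker({"reranker": {"model": None}}) is None
ranker = create_ranker({"reranker": {"model": "cross-encoder/ms-marco-MiniLM-L-6-v2"}})
# Returns a CrossEncoderRanker; the underlying model is not loaded until rank().
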
schema_search/schema_extractor.py
@@ -0,0 +1,135 @@
+ from typing import Dict, List, Any
+ from sqlalchemy import inspect
+ from sqlalchemy.engine import Engine
+
+ from schema_search.types import (
+     TableSchema,
+     ColumnInfo,
+     ForeignKeyInfo,
+     IndexInfo,
+     ConstraintInfo,
+ )
+
+
+ class SchemaExtractor:
+     def __init__(self, engine: Engine, config: Dict[str, Any]):
+         self.engine = engine
+         self.config = config
+
+     def extract(self) -> Dict[str, TableSchema]:
+         inspector = inspect(self.engine)
+         schemas: Dict[str, TableSchema] = {}
+
+         schema_names = inspector.get_schema_names()
+         for schema_name in schema_names:
+             if self._should_skip_schema(schema_name):
+                 continue
+
+             for table_name in inspector.get_table_names(schema=schema_name):
+                 schemas[table_name] = self._extract_table(
+                     inspector, table_name, schema_name
+                 )
+
+         return schemas
+
+     def _should_skip_schema(self, schema_name: str) -> bool:
+         skip = {
+             "information_schema",
+             "pg_catalog",
+             "pg_toast",
+             "performance_schema",
+             "mysql",
+             "sys",
+         }
+         return schema_name.lower() in skip
+
+     def _extract_table(
+         self, inspector, table_name: str, schema_name: str
+     ) -> TableSchema:
+         pk_constraint = inspector.get_pk_constraint(table_name, schema=schema_name)
+
+         schema: TableSchema = {
+             "name": table_name,
+             "columns": (
+                 self._extract_columns(
+                     inspector.get_columns(table_name, schema=schema_name)
+                 )
+                 if self.config["schema"]["include_columns"]
+                 else None
+             ),
+             "primary_keys": pk_constraint["constrained_columns"],
+             "foreign_keys": (
+                 self._extract_foreign_keys(
+                     inspector.get_foreign_keys(table_name, schema=schema_name)
+                 )
+                 if self.config["schema"]["include_foreign_keys"]
+                 else None
+             ),
+             "indices": (
+                 self._extract_indices(
+                     inspector.get_indexes(table_name, schema=schema_name)
+                 )
+                 if self.config["schema"]["include_indices"]
+                 else None
+             ),
+             "unique_constraints": (
+                 self._extract_constraints(
+                     inspector.get_unique_constraints(table_name, schema=schema_name)
+                 )
+                 if self.config["schema"]["include_constraints"]
+                 else None
+             ),
+             "check_constraints": (
+                 self._extract_constraints(
+                     inspector.get_check_constraints(table_name, schema=schema_name)
+                 )
+                 if self.config["schema"]["include_constraints"]
+                 else None
+             ),
+         }
+
+         return schema
+
+     def _extract_columns(self, columns: List[Dict[str, Any]]) -> List[ColumnInfo]:
+         return [
+             {
+                 "name": col["name"],
+                 "type": str(col["type"]),
+                 "nullable": col["nullable"],
+                 "default": str(col["default"]) if col["default"] else None,
+             }
+             for col in columns
+         ]
+
+     def _extract_foreign_keys(
+         self, foreign_keys: List[Dict[str, Any]]
+     ) -> List[ForeignKeyInfo]:
+         return [
+             {
+                 "constrained_columns": fk["constrained_columns"],
+                 "referred_table": fk["referred_table"],
+                 "referred_columns": fk["referred_columns"],
+             }
+             for fk in foreign_keys
+         ]
+
+     def _extract_indices(self, indices: List[Dict[str, Any]]) -> List[IndexInfo]:
+         return [
+             {
+                 "name": idx["name"],
+                 "columns": idx["column_names"],
+                 "unique": idx["unique"],
+             }
+             for idx in indices
+         ]
+
+     def _extract_constraints(
+         self, constraints: List[Dict[str, Any]]
+     ) -> List[ConstraintInfo]:
+         return [
+             {
+                 "name": constraint["name"],
+                 "columns": constraint["column_names"],
+             }
+             for constraint in constraints
+         ]
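
A minimal sketch of the extractor against an in-memory SQLite database. The config keys mirror exactly the ones read above; the table and values are illustrative:

from sqlalchemy import create_engine, text
from schema_search.schema_extractor import SchemaExtractor

engine = create_engine("sqlite:///:memory:")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT NOT NULL)"))

config = {
    "schema": {
        "include_columns": True,
        "include_foreign_keys": True,
        "include_indices": False,
        "include_constraints": False,
    }
}
schemas = SchemaExtractor(engine, config).extract()
print(schemas["users"]["primary_keys"])  # ["id"]
print(schemas["users"]["indices"])       # None, because include_indices is False
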
schema_search/schema_search.py
@@ -0,0 +1,276 @@
+ import json
+ import logging
+ import time
+ from functools import wraps
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
+ import yaml
+ from sqlalchemy.engine import Engine
+
+ from schema_search.schema_extractor import SchemaExtractor
+ from schema_search.chunkers import Chunk, create_chunker
+ from schema_search.embedding_cache import create_embedding_cache
+ from schema_search.embedding_cache.bm25 import BM25Cache
+ from schema_search.graph_builder import GraphBuilder
+ from schema_search.search import create_search_strategy
+ from schema_search.types import IndexResult, SearchResult, SearchType, TableSchema
+ from schema_search.rankers import create_ranker
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def time_it(func):
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         start = time.time()
+         result = func(*args, **kwargs)
+         elapsed = time.time() - start
+         if isinstance(result, dict):
+             result["latency_sec"] = round(elapsed, 3)
+         return result
+
+     return wrapper
+
+
+ class SchemaSearch:
+     def __init__(
+         self,
+         engine: Engine,
+         config_path: Optional[str] = None,
+         llm_api_key: Optional[str] = None,
+         llm_base_url: Optional[str] = None,
+     ):
+         self.config = self._load_config(config_path)
+         self._setup_logging()
+
+         base_cache_dir = Path(self.config["embedding"]["cache_dir"])
+         db_name = engine.url.database or "default"
+         cache_dir = base_cache_dir / db_name
+         cache_dir.mkdir(parents=True, exist_ok=True)
+
+         self.schemas: Dict[str, TableSchema] = {}
+         self.chunks: List[Chunk] = []
+         self.cache_dir = cache_dir
+
+         self._validate_dependencies()
+
+         self.schema_extractor = SchemaExtractor(engine, self.config)
+         self.chunker = create_chunker(self.config, llm_api_key, llm_base_url)
+         self._embedding_cache = None
+         self._bm25_cache = None
+         self.graph_builder = GraphBuilder(cache_dir)
+         self._reranker = None
+         self._reranker_config = self.config["reranker"]["model"]
+         self._search_strategies = {}
+
+     def _setup_logging(self) -> None:
+         level = getattr(logging, self.config["logging"]["level"])
+         logging.basicConfig(
+             level=level,
+             format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+             force=True,
+         )
+         logger.setLevel(level)
+
+     def _load_config(self, config_path: Optional[str]) -> Dict:
+         if config_path is None:
+             config_path = str(Path(__file__).parent.parent / "config.yml")
+
+         with open(config_path) as f:
+             return yaml.safe_load(f)
+
+     def _validate_dependencies(self) -> None:
+         from schema_search.utils.lazy_import import lazy_import_check
+
+         strategy = self.config["search"]["strategy"]
+         reranker_model = self.config["reranker"]["model"]
+         chunking_strategy = self.config["chunking"]["strategy"]
+
+         needs_semantic = strategy in ("semantic", "hybrid") or reranker_model
+         if needs_semantic:
+             lazy_import_check(
+                 "sentence_transformers",
+                 "semantic",
+                 f"{strategy} search or reranking"
+             )
+
+         if chunking_strategy == "llm":
+             lazy_import_check("openai", "llm", "LLM-based chunking")
+
+     @time_it
+     def index(self, force: bool = False) -> IndexResult:
+         logger.info("Starting schema indexing" + (" (force)" if force else ""))
+
+         current_schema = self._extract_current_schema()
+
+         schema_changed = False
+         if not force:
+             cached_schema = self._load_cached_schema()
+             schema_changed = self._schema_has_changed(cached_schema, current_schema)
+             if schema_changed:
+                 logger.info("Schema change detected; forcing reindex")
+
+         self._cache_schema(current_schema)
+
+         effective_force = force or schema_changed
+
+         self.schemas = current_schema
+         self.graph_builder.build(self.schemas, effective_force)
+         self.chunks = self._load_or_generate_chunks(self.schemas, effective_force)
+         self._index_force = effective_force
+
+         logger.info(
+             f"Indexing complete: {len(self.schemas)} tables, {len(self.chunks)} chunks"
+         )
+         return {
+             "tables": len(self.schemas),
+             "chunks": len(self.chunks),
+             "latency_sec": 0.0,
+         }
+
+     def _extract_current_schema(self) -> Dict[str, TableSchema]:
+         logger.info("Extracting schema from database")
+         return self.schema_extractor.extract()
+
+     def _load_cached_schema(self) -> Optional[Dict[str, TableSchema]]:
+         schema_cache = self.cache_dir / "metadata.json"
+
+         if not schema_cache.exists():
+             logger.debug("Schema cache missing; treating as schema change")
+             return None
+
+         with open(schema_cache) as f:
+             return json.load(f)
+
+     def _cache_schema(self, schema: Dict[str, TableSchema]) -> None:
+         schema_cache = self.cache_dir / "metadata.json"
+         with open(schema_cache, "w") as f:
+             json.dump(schema, f, indent=2)
+
+     def _schema_has_changed(
+         self,
+         cached_schema: Optional[Dict[str, TableSchema]],
+         current_schema: Dict[str, TableSchema],
+     ) -> bool:
+         if cached_schema is None:
+             return True
+         if cached_schema != current_schema:
+             logger.debug("Cached schema differs from current schema")
+             return True
+         logger.debug("Schema matches cached version; reuse existing index")
+         return False
+
+     def _load_or_generate_chunks(
+         self, schemas: Dict[str, TableSchema], force: bool
+     ) -> List[Chunk]:
+         chunks_cache = self.cache_dir / "chunk_metadata.json"
+
+         if not force and chunks_cache.exists():
+             logger.info(f"Loading chunks from cache: {chunks_cache}")
+             with open(chunks_cache) as f:
+                 chunk_data = json.load(f)
+             return [
+                 Chunk(
+                     table_name=c["table_name"],
+                     content=c["content"],
+                     chunk_id=c["chunk_id"],
+                     token_count=c["token_count"],
+                 )
+                 for c in chunk_data
+             ]
+
+         logger.info("Generating chunks from schemas")
+         chunks = self.chunker.chunk_schemas(schemas)
+
+         with open(chunks_cache, "w") as f:
+             chunk_data = [
+                 {
+                     "table_name": c.table_name,
+                     "content": c.content,
+                     "chunk_id": c.chunk_id,
+                     "token_count": c.token_count,
+                 }
+                 for c in chunks
+             ]
+             json.dump(chunk_data, f, indent=2)
+
+         return chunks
+
+     def _get_embedding_cache(self):
+         if self._embedding_cache is None:
+             self._embedding_cache = create_embedding_cache(self.config, self.cache_dir)
+         return self._embedding_cache
+
+     def _get_reranker(self):
+         if self._reranker is None and self._reranker_config:
+             self._reranker = create_ranker(self.config)
+         return self._reranker
+
+     @property
+     def embedding_cache(self):
+         return self._get_embedding_cache()
+
+     @property
+     def reranker(self):
+         return self._get_reranker()
+
+     def _get_bm25_cache(self):
+         if self._bm25_cache is None:
+             self._bm25_cache = BM25Cache()
+         return self._bm25_cache
+
+     def _ensure_embeddings_loaded(self):
+         cache = self._get_embedding_cache()
+         if cache.embeddings is None:
+             cache.load_or_generate(
+                 self.chunks, self._index_force, self.config["chunking"]
+             )
+
+     def _ensure_bm25_built(self):
+         cache = self._get_bm25_cache()
+         if cache.bm25 is None:
+             logger.info("Building BM25 index")
+             cache.build(self.chunks)
+
+     def _get_search_strategy(self, search_type: str):
+         if search_type not in self._search_strategies:
+             self._search_strategies[search_type] = create_search_strategy(
+                 self.config,
+                 self._get_embedding_cache,
+                 self._get_bm25_cache,
+                 self._get_reranker,
+                 search_type,
+             )
+         return self._search_strategies[search_type]
+
+     @time_it
+     def search(
+         self,
+         query: str,
+         hops: Optional[int] = None,
+         limit: int = 5,
+         search_type: Optional[SearchType] = None,
+     ) -> SearchResult:
+         if hops is None:
+             hops = int(self.config["search"]["hops"])
+         logger.debug(f"Searching: {query} (hops={hops}, search_type={search_type})")
+
+         search_type = search_type or self.config["search"]["strategy"]
+
+         if search_type in ["semantic", "hybrid"]:
+             self._ensure_embeddings_loaded()
+
+         if search_type in ["bm25", "hybrid"]:
+             self._ensure_bm25_built()
+
+         strategy = self._get_search_strategy(search_type)
+
+         results = strategy.search(
+             query, self.schemas, self.chunks, self.graph_builder, hops, limit
+         )
+
+         logger.debug(f"Found {len(results)} results")
+
+         return {"results": results, "latency_sec": 0.0}
schema_search/search/__init__.py
@@ -0,0 +1,15 @@
+ from schema_search.search.base import BaseSearchStrategy
+ from schema_search.search.semantic import SemanticSearchStrategy
+ from schema_search.search.fuzzy import FuzzySearchStrategy
+ from schema_search.search.bm25 import BM25SearchStrategy
+ from schema_search.search.hybrid import HybridSearchStrategy
+ from schema_search.search.factory import create_search_strategy
+
+ __all__ = [
+     "BaseSearchStrategy",
+     "SemanticSearchStrategy",
+     "FuzzySearchStrategy",
+     "BM25SearchStrategy",
+     "HybridSearchStrategy",
+     "create_search_strategy",
+ ]
schema_search/search/base.py
@@ -0,0 +1,85 @@
+ from typing import Dict, List, Optional
+ from abc import ABC, abstractmethod
+
+ from schema_search.types import TableSchema, SearchResultItem
+ from schema_search.chunkers import Chunk
+ from schema_search.graph_builder import GraphBuilder
+ from schema_search.rankers.base import BaseRanker
+
+
+ class BaseSearchStrategy(ABC):
+     def __init__(
+         self, reranker: Optional[BaseRanker], initial_top_k: int, rerank_top_k: int
+     ):
+         self.reranker = reranker
+         self.initial_top_k = initial_top_k
+         self.rerank_top_k = rerank_top_k
+
+     def search(
+         self,
+         query: str,
+         schemas: Dict[str, TableSchema],
+         chunks: List[Chunk],
+         graph_builder: GraphBuilder,
+         hops: int,
+         limit: int,
+     ) -> List[SearchResultItem]:
+         initial_results = self._initial_ranking(
+             query, schemas, chunks, graph_builder, hops
+         )
+
+         if self.reranker is None:
+             return initial_results[:limit]
+
+         initial_chunks = []
+         for result in initial_results:
+             for chunk in chunks:
+                 if chunk.table_name == result["table"]:
+                     initial_chunks.append(chunk)
+                     break
+
+         self.reranker.build(initial_chunks)
+         ranked = self.reranker.rank(query)
+
+         reranked_results: List[SearchResultItem] = []
+         for chunk_idx, score in ranked[: self.rerank_top_k]:
+             chunk = initial_chunks[chunk_idx]
+             result = self._build_result_item(
+                 table_name=chunk.table_name,
+                 score=score,
+                 schema=schemas[chunk.table_name],
+                 matched_chunks=[chunk.content],
+                 graph_builder=graph_builder,
+                 hops=hops,
+             )
+             reranked_results.append(result)
+
+         return reranked_results[:limit]
+
+     @abstractmethod
+     def _initial_ranking(
+         self,
+         query: str,
+         schemas: Dict[str, TableSchema],
+         chunks: List[Chunk],
+         graph_builder: GraphBuilder,
+         hops: int,
+     ) -> List[SearchResultItem]:
+         pass
+
+     def _build_result_item(
+         self,
+         table_name: str,
+         score: float,
+         schema: TableSchema,
+         matched_chunks: List[str],
+         graph_builder: GraphBuilder,
+         hops: int,
+     ) -> SearchResultItem:
+         return {
+             "table": table_name,
+             "score": score,
+             "schema": schema,
+             "matched_chunks": matched_chunks,
+             "related_tables": list(graph_builder.get_neighbors(table_name, hops)),
+         }
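
The class follows the template-method pattern: subclasses supply `_initial_ranking`, while the base `search()` handles optional cross-encoder reranking and truncation to `limit`. A sketch of the contract (`ExactNameStrategy` is hypothetical, not part of the package):

from typing import Dict, List
from schema_search.search.base import BaseSearchStrategy
from schema_search.types import TableSchema, SearchResultItem
from schema_search.chunkers import Chunk
from schema_search.graph_builder import GraphBuilder

class ExactNameStrategy(BaseSearchStrategy):
    def _initial_ranking(
        self,
        query: str,
        schemas: Dict[str, TableSchema],
        chunks: List[Chunk],
        graph_builder: GraphBuilder,
        hops: int,
    ) -> List[SearchResultItem]:
        # Trivially match chunks whose table name appears verbatim in the query;
        # _build_result_item attaches the schema and graph neighbours.
        return [
            self._build_result_item(
                table_name=chunk.table_name,
                score=1.0,
                schema=schemas[chunk.table_name],
                matched_chunks=[chunk.content],
                graph_builder=graph_builder,
                hops=hops,
            )
            for chunk in chunks
            if chunk.table_name in query
        ]

strategy = ExactNameStrategy(reranker=None, initial_top_k=10, rerank_top_k=5)
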
schema_search/search/bm25.py
@@ -0,0 +1,48 @@
+ from typing import Dict, List, Optional, TYPE_CHECKING
+
+ from schema_search.search.base import BaseSearchStrategy
+ from schema_search.types import TableSchema, SearchResultItem
+ from schema_search.chunkers import Chunk
+ from schema_search.graph_builder import GraphBuilder
+ from schema_search.rankers.base import BaseRanker
+
+ if TYPE_CHECKING:
+     from schema_search.embedding_cache.bm25 import BM25Cache
+
+
+ class BM25SearchStrategy(BaseSearchStrategy):
+     def __init__(
+         self,
+         bm25_cache: "BM25Cache",
+         initial_top_k: int,
+         rerank_top_k: int,
+         reranker: Optional[BaseRanker],
+     ):
+         super().__init__(reranker, initial_top_k, rerank_top_k)
+         self.bm25_cache = bm25_cache
+
+     def _initial_ranking(
+         self,
+         query: str,
+         schemas: Dict[str, TableSchema],
+         chunks: List[Chunk],
+         graph_builder: GraphBuilder,
+         hops: int,
+     ) -> List[SearchResultItem]:
+         scores = self.bm25_cache.get_scores(query)
+         top_indices = scores.argsort()[::-1][: self.initial_top_k]
+
+         results: List[SearchResultItem] = []
+         for idx in top_indices:
+             chunk = chunks[idx]
+             result = self._build_result_item(
+                 table_name=chunk.table_name,
+                 score=float(scores[idx]),
+                 schema=schemas[chunk.table_name],
+                 matched_chunks=[chunk.content],
+                 graph_builder=graph_builder,
+                 hops=hops,
+             )
+             results.append(result)
+
+         return results
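
BM25Cache itself lives in schema_search/embedding_cache/bm25.py (+63 in the file list above, not expanded in this view). Based only on how `get_scores` and the `bm25` attribute are used here and in SchemaSearch, a plausible sketch of that interface using the rank_bm25 library; the tokenization and internals are assumptions, not the package's actual implementation:

from typing import List, Optional
import numpy as np
from rank_bm25 import BM25Okapi
from schema_search.chunkers import Chunk

class BM25CacheSketch:
    def __init__(self):
        self.bm25: Optional[BM25Okapi] = None

    def build(self, chunks: List[Chunk]) -> None:
        # Whitespace tokenization is an assumption; the real cache may differ.
        corpus = [chunk.content.lower().split() for chunk in chunks]
        self.bm25 = BM25Okapi(corpus)

    def get_scores(self, query: str) -> np.ndarray:
        assert self.bm25 is not None
        # One score per chunk, aligned with the chunks list, which is what
        # the strategy's scores.argsort()[::-1] usage above relies on.
        return np.asarray(self.bm25.get_scores(query.lower().split()))
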