schema-search 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- schema_search/__init__.py +26 -0
- schema_search/chunkers/__init__.py +6 -0
- schema_search/chunkers/base.py +95 -0
- schema_search/chunkers/factory.py +31 -0
- schema_search/chunkers/llm.py +54 -0
- schema_search/chunkers/markdown.py +25 -0
- schema_search/embedding_cache/__init__.py +5 -0
- schema_search/embedding_cache/base.py +40 -0
- schema_search/embedding_cache/bm25.py +63 -0
- schema_search/embedding_cache/factory.py +20 -0
- schema_search/embedding_cache/inmemory.py +122 -0
- schema_search/graph_builder.py +69 -0
- schema_search/mcp_server.py +81 -0
- schema_search/metrics.py +33 -0
- schema_search/rankers/__init__.py +5 -0
- schema_search/rankers/base.py +45 -0
- schema_search/rankers/cross_encoder.py +40 -0
- schema_search/rankers/factory.py +11 -0
- schema_search/schema_extractor.py +135 -0
- schema_search/schema_search.py +276 -0
- schema_search/search/__init__.py +15 -0
- schema_search/search/base.py +85 -0
- schema_search/search/bm25.py +48 -0
- schema_search/search/factory.py +61 -0
- schema_search/search/fuzzy.py +56 -0
- schema_search/search/hybrid.py +82 -0
- schema_search/search/semantic.py +49 -0
- schema_search/types.py +57 -0
- schema_search/utils/__init__.py +0 -0
- schema_search/utils/lazy_import.py +26 -0
- schema_search-0.1.10.dist-info/METADATA +308 -0
- schema_search-0.1.10.dist-info/RECORD +40 -0
- schema_search-0.1.10.dist-info/WHEEL +5 -0
- schema_search-0.1.10.dist-info/entry_points.txt +2 -0
- schema_search-0.1.10.dist-info/licenses/LICENSE +21 -0
- schema_search-0.1.10.dist-info/top_level.txt +2 -0
- tests/__init__.py +0 -0
- tests/test_integration.py +352 -0
- tests/test_llm_sql_generation.py +320 -0
- tests/test_spider_eval.py +488 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from typing import Dict, List, Tuple
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
from schema_search.chunkers import Chunk
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseRanker(ABC):
    """Abstract interface for rankers that score schema chunks against a query.

    Subclasses receive their candidate chunks through :meth:`build` and return
    ``(chunk_idx, score)`` pairs from :meth:`rank`, where ``chunk_idx`` indexes
    into ``self.chunks``.
    """

    def __init__(self):
        # A bare annotation (``self.chunks: List[Chunk]``) never creates the
        # attribute; give it a real value so it exists before ``build`` runs.
        self.chunks: List[Chunk] = []

    @abstractmethod
    def build(self, chunks: List[Chunk]) -> None:
        """Store/prepare ``chunks`` as the candidate set for :meth:`rank`."""

    @abstractmethod
    def rank(self, query: str) -> List[Tuple[int, float]]:
        """Returns: List of (chunk_idx, score)"""

    def get_top_tables_from_chunks(
        self, ranked_chunks: List[Tuple[int, float]], top_k: int
    ) -> Dict[str, List[int]]:
        """Group ranked chunks by table and keep the ``top_k`` best tables.

        A table's score is the maximum score among its chunks. Tables are
        returned in descending score order (ties keep first-seen order). Each
        table maps to its chunk indices in the order they appear in
        ``ranked_chunks``.

        Args:
            ranked_chunks: ``(chunk_idx, score)`` pairs, where ``chunk_idx``
                indexes ``self.chunks``.
            top_k: Maximum number of tables to return; ``<= 0`` yields ``{}``.

        Returns:
            Ordered mapping of table name -> list of chunk indices.
        """
        if top_k <= 0:
            return {}

        table_to_chunk_indices: Dict[str, List[int]] = defaultdict(list)
        table_scores: Dict[str, float] = {}

        for chunk_idx, score in ranked_chunks:
            table_name = self.chunks[chunk_idx].table_name
            table_to_chunk_indices[table_name].append(chunk_idx)
            # Track each table's best chunk score in the same pass.
            if table_name not in table_scores or score > table_scores[table_name]:
                table_scores[table_name] = score

        # sorted() is stable (also with reverse=True), so ties keep first-seen order.
        top_tables = sorted(table_scores, key=table_scores.__getitem__, reverse=True)
        return {name: table_to_chunk_indices[name] for name in top_tables[:top_k]}
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import List, Tuple, Optional, TYPE_CHECKING
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from schema_search.chunkers import Chunk
|
|
5
|
+
from schema_search.rankers.base import BaseRanker
|
|
6
|
+
from schema_search.utils.lazy_import import lazy_import_check
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from sentence_transformers import CrossEncoder
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)


class CrossEncoderRanker(BaseRanker):
    """Rerank chunks by scoring ``(query, chunk.content)`` pairs with a
    sentence-transformers ``CrossEncoder``.

    The model is imported and loaded lazily, on the first ``rank`` call that
    has at least one chunk to score.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.model: Optional["CrossEncoder"] = None

    def _load_model(self) -> "CrossEncoder":
        """Return the cached CrossEncoder, importing and loading it on first use.

        Raises whatever ``lazy_import_check`` raises when the optional
        ``sentence_transformers`` dependency is unavailable.
        """
        if self.model is None:
            sentence_transformers = lazy_import_check(
                "sentence_transformers", "semantic", "reranking with CrossEncoder"
            )
            # Quiet sentence_transformers' own logger below WARNING.
            logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
            self.model = sentence_transformers.CrossEncoder(self.model_name)
            assert self.model is not None  # type narrowing for checkers
            logger.info("Loaded CrossEncoder: %s", self.model_name)
        return self.model

    def build(self, chunks: List[Chunk]) -> None:
        """Use ``chunks`` as the candidate set for subsequent ``rank`` calls."""
        self.chunks = chunks
        logger.debug("Initialized CrossEncoder reranker with %d chunks", len(chunks))

    def rank(self, query: str) -> List[Tuple[int, float]]:
        """Score every candidate chunk against ``query``.

        Returns:
            ``(chunk_idx, score)`` pairs sorted by descending score. Returns an
            empty list when there are no candidates; in that case the model is
            neither loaded nor called with an empty batch.
        """
        if not self.chunks:
            return []
        model = self._load_model()
        pairs = [(query, chunk.content) for chunk in self.chunks]
        scores = model.predict(pairs, show_progress_bar=False)
        # argsort is ascending; reverse for best-first.
        ranked_indices = scores.argsort()[::-1]
        return [(int(idx), float(scores[idx])) for idx in ranked_indices]
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from typing import Dict, Optional
|
|
2
|
+
|
|
3
|
+
from schema_search.rankers.base import BaseRanker
|
|
4
|
+
from schema_search.rankers.cross_encoder import CrossEncoderRanker
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def create_ranker(config: Dict) -> Optional[BaseRanker]:
    """Build the reranker configured under ``config["reranker"]["model"]``.

    Args:
        config: Parsed configuration mapping.

    Returns:
        ``None`` (reranking disabled) when the ``reranker`` section or its
        ``model`` entry is missing, null or empty. Otherwise a
        :class:`CrossEncoderRanker` for that model name.
    """
    reranker_model = (config.get("reranker") or {}).get("model")
    # Truthiness rather than ``is None``: an empty model name cannot be loaded.
    # SchemaSearch also only asks for a reranker when the model is truthy.
    if not reranker_model:
        return None
    return CrossEncoderRanker(model_name=reranker_model)
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from typing import Dict, List, Any
|
|
2
|
+
from sqlalchemy import inspect
|
|
3
|
+
from sqlalchemy.engine import Engine
|
|
4
|
+
|
|
5
|
+
from schema_search.types import (
|
|
6
|
+
TableSchema,
|
|
7
|
+
ColumnInfo,
|
|
8
|
+
ForeignKeyInfo,
|
|
9
|
+
IndexInfo,
|
|
10
|
+
ConstraintInfo,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SchemaExtractor:
    """Reflect database tables into ``TableSchema`` dictionaries.

    Args:
        engine: SQLAlchemy engine to inspect.
        config: Full configuration mapping. The boolean flags
            ``include_columns``, ``include_foreign_keys``, ``include_indices``
            and ``include_constraints`` under ``config["schema"]`` decide
            which optional sections are reflected; disabled sections are
            stored as ``None``.
    """

    # Catalog/system schemas that are never indexed (compared lower-cased).
    _SKIPPED_SCHEMAS = frozenset(
        {
            "information_schema",
            "pg_catalog",
            "pg_toast",
            "performance_schema",
            "mysql",
            "sys",
        }
    )

    def __init__(self, engine: Engine, config: Dict[str, Any]):
        self.engine = engine
        self.config = config

    def extract(self) -> Dict[str, TableSchema]:
        """Reflect every table in every non-system schema.

        Returns:
            Mapping of table name -> ``TableSchema``.
        """
        inspector = inspect(self.engine)
        schemas: Dict[str, TableSchema] = {}

        for schema_name in inspector.get_schema_names():
            if self._should_skip_schema(schema_name):
                continue
            for table_name in inspector.get_table_names(schema=schema_name):
                # NOTE: keyed by bare table name, so a same-named table in a
                # later schema overwrites one from an earlier schema.
                schemas[table_name] = self._extract_table(
                    inspector, table_name, schema_name
                )

        return schemas

    def _should_skip_schema(self, schema_name: str) -> bool:
        """Return True for catalog/system schemas (case-insensitive)."""
        return schema_name.lower() in self._SKIPPED_SCHEMAS

    def _extract_table(
        self, inspector, table_name: str, schema_name: str
    ) -> TableSchema:
        """Reflect one table; sections disabled in config are set to ``None``."""
        options = self.config["schema"]
        pk_constraint = inspector.get_pk_constraint(table_name, schema=schema_name)

        schema: TableSchema = {
            "name": table_name,
            "columns": (
                self._extract_columns(
                    inspector.get_columns(table_name, schema=schema_name)
                )
                if options["include_columns"]
                else None
            ),
            "primary_keys": pk_constraint.get("constrained_columns") or [],
            "foreign_keys": (
                self._extract_foreign_keys(
                    inspector.get_foreign_keys(table_name, schema=schema_name)
                )
                if options["include_foreign_keys"]
                else None
            ),
            "indices": (
                self._extract_indices(
                    inspector.get_indexes(table_name, schema=schema_name)
                )
                if options["include_indices"]
                else None
            ),
            "unique_constraints": (
                self._extract_constraints(
                    self._reflect_optional(
                        inspector.get_unique_constraints, table_name, schema_name
                    )
                )
                if options["include_constraints"]
                else None
            ),
            "check_constraints": (
                self._extract_constraints(
                    self._reflect_optional(
                        inspector.get_check_constraints, table_name, schema_name
                    )
                )
                if options["include_constraints"]
                else None
            ),
        }

        return schema

    @staticmethod
    def _reflect_optional(
        reflect, table_name: str, schema_name: str
    ) -> List[Dict[str, Any]]:
        """Call an inspector reflection method, treating "unsupported" as empty.

        SQLAlchemy dialects that cannot reflect a constraint kind raise
        ``NotImplementedError``. That should not abort extraction of the whole
        database.
        """
        try:
            return reflect(table_name, schema=schema_name)
        except NotImplementedError:
            return []

    def _extract_columns(self, columns: List[Dict[str, Any]]) -> List[ColumnInfo]:
        """Map reflected columns to ``ColumnInfo``; type/default are stringified."""
        return [
            {
                "name": col["name"],
                "type": str(col["type"]),
                "nullable": col["nullable"],
                # Compare with None so falsy defaults such as 0 are kept.
                "default": (
                    str(col.get("default"))
                    if col.get("default") is not None
                    else None
                ),
            }
            for col in columns
        ]

    def _extract_foreign_keys(
        self, foreign_keys: List[Dict[str, Any]]
    ) -> List[ForeignKeyInfo]:
        """Keep the local columns, target table and target columns of each FK."""
        return [
            {
                "constrained_columns": fk["constrained_columns"],
                "referred_table": fk["referred_table"],
                "referred_columns": fk["referred_columns"],
            }
            for fk in foreign_keys
        ]

    def _extract_indices(self, indices: List[Dict[str, Any]]) -> List[IndexInfo]:
        """Keep the name, column list and uniqueness flag of each index."""
        return [
            {
                "name": idx["name"],
                "columns": idx["column_names"],
                "unique": idx["unique"],
            }
            for idx in indices
        ]

    def _extract_constraints(
        self, constraints: List[Dict[str, Any]]
    ) -> List[ConstraintInfo]:
        """Keep the name and column list of each unique/check constraint."""
        return [
            {
                "name": constraint["name"],
                # Reflected check constraints carry "sqltext" instead of
                # "column_names"; treat a missing list as "no columns".
                "columns": constraint.get("column_names") or [],
            }
            for constraint in constraints
        ]
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
from functools import wraps
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
import yaml
|
|
9
|
+
from sqlalchemy.engine import Engine
|
|
10
|
+
|
|
11
|
+
from schema_search.schema_extractor import SchemaExtractor
|
|
12
|
+
from schema_search.chunkers import Chunk, create_chunker
|
|
13
|
+
from schema_search.embedding_cache import create_embedding_cache
|
|
14
|
+
from schema_search.embedding_cache.bm25 import BM25Cache
|
|
15
|
+
from schema_search.graph_builder import GraphBuilder
|
|
16
|
+
from schema_search.search import create_search_strategy
|
|
17
|
+
from schema_search.types import IndexResult, SearchResult, SearchType, TableSchema
|
|
18
|
+
from schema_search.rankers import create_ranker
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)


def time_it(func):
    """Decorator that measures how long ``func`` takes.

    When the wrapped call returns a ``dict``, its ``"latency_sec"`` entry is
    set in place to the elapsed seconds, rounded to 3 decimals. Any other
    return value passes through untouched, and exceptions propagate unchanged.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution. time.time() follows
        # the system clock, which can be adjusted mid-call and skew the result.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        if isinstance(result, dict):
            result["latency_sec"] = round(elapsed, 3)
        return result

    return wrapper
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class SchemaSearch:
    """Index a database schema and search it for tables relevant to a query.

    ``index()`` reflects the schema, builds the table graph and the text
    chunks, and reuses the on-disk caches under
    ``<embedding.cache_dir>/<database>`` when the schema is unchanged.
    ``search()`` runs the configured search strategy over those chunks.

    Args:
        engine: SQLAlchemy engine of the database to index.
        config_path: YAML config file; defaults to ``config.yml`` two
            directories above this module.
        llm_api_key: Forwarded to ``create_chunker``.
        llm_base_url: Forwarded to ``create_chunker``.
    """

    def __init__(
        self,
        engine: Engine,
        config_path: Optional[str] = None,
        llm_api_key: Optional[str] = None,
        llm_base_url: Optional[str] = None,
    ):
        self.config = self._load_config(config_path)
        self._setup_logging()

        base_cache_dir = Path(self.config["embedding"]["cache_dir"])
        db_path = Path(engine.url.database or "default")
        # With an absolute database path (e.g. "sqlite:////abs/app.db"),
        # ``base_cache_dir / db_path`` would discard base_cache_dir and point
        # at the database file itself. Strip the anchor so it nests instead.
        if db_path.anchor:
            db_path = db_path.relative_to(db_path.anchor)
        cache_dir = base_cache_dir / db_path
        cache_dir.mkdir(parents=True, exist_ok=True)

        self.schemas: Dict[str, TableSchema] = {}
        self.chunks: List[Chunk] = []
        self.cache_dir = cache_dir
        # Whether the last index() rebuilt its artifacts. The embedding cache
        # reads it when it is loaded lazily. Set here so search() before
        # index() does not fail with AttributeError.
        self._index_force = False

        self._validate_dependencies()

        self.schema_extractor = SchemaExtractor(engine, self.config)
        self.chunker = create_chunker(self.config, llm_api_key, llm_base_url)
        self._embedding_cache = None
        self._bm25_cache = None
        self.graph_builder = GraphBuilder(cache_dir)
        self._reranker = None
        self._reranker_config = self.config["reranker"]["model"]
        self._search_strategies = {}

    def _setup_logging(self) -> None:
        """Configure logging from ``config["logging"]["level"]``.

        The level may be a level name in any case (``"info"``, ``"DEBUG"``)
        or a numeric logging level.

        NOTE: ``basicConfig(force=True)`` removes and closes any handlers
        already attached to the root logger.
        """
        configured = self.config["logging"]["level"]
        if isinstance(configured, int):
            level = configured
        else:
            level = getattr(logging, str(configured).upper())
        logging.basicConfig(
            level=level,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            force=True,
        )
        logger.setLevel(level)

    def _load_config(self, config_path: Optional[str]) -> Dict:
        """Parse the YAML config (safe loader) from ``config_path`` or the default."""
        if config_path is None:
            config_path = str(Path(__file__).parent.parent / "config.yml")

        with open(config_path, encoding="utf-8") as f:
            return yaml.safe_load(f)

    def _validate_dependencies(self) -> None:
        """Fail early if optional packages needed by the config are missing."""
        from schema_search.utils.lazy_import import lazy_import_check

        strategy = self.config["search"]["strategy"]
        reranker_model = self.config["reranker"]["model"]
        chunking_strategy = self.config["chunking"]["strategy"]

        needs_semantic = strategy in ("semantic", "hybrid") or reranker_model
        if needs_semantic:
            lazy_import_check(
                "sentence_transformers",
                "semantic",
                f"{strategy} search or reranking"
            )

        if chunking_strategy == "llm":
            lazy_import_check("openai", "llm", "LLM-based chunking")

    @time_it
    def index(self, force: bool = False) -> IndexResult:
        """Reflect the schema and (re)build the graph and chunks.

        Artifacts are rebuilt when ``force`` is true or when the reflected
        schema differs from the cached snapshot; otherwise the caches are reused.

        Returns:
            Table/chunk counts; ``latency_sec`` is filled in by ``time_it``.
        """
        logger.info("Starting schema indexing" + (" (force)" if force else ""))

        current_schema = self._extract_current_schema()

        schema_changed = False
        if not force:
            cached_schema = self._load_cached_schema()
            schema_changed = self._schema_has_changed(cached_schema, current_schema)
            if schema_changed:
                logger.info("Schema change detected; forcing reindex")

        effective_force = force or schema_changed
        if effective_force:
            self._reset_search_state()

        self.schemas = current_schema
        self.graph_builder.build(self.schemas, effective_force)
        self.chunks = self._load_or_generate_chunks(self.schemas, effective_force)
        self._index_force = effective_force

        # Persist the schema snapshot only after the graph and chunks were
        # rebuilt. Writing it first would let a chunking failure leave a fresh
        # snapshot beside stale caches, which the next non-forced run trusts.
        self._cache_schema(current_schema)

        logger.info(
            f"Indexing complete: {len(self.schemas)} tables, {len(self.chunks)} chunks"
        )
        return {
            "tables": len(self.schemas),
            "chunks": len(self.chunks),
            "latency_sec": 0.0,
        }

    def _reset_search_state(self) -> None:
        """Drop in-memory search state derived from a previous chunk list.

        ``_ensure_bm25_built`` and ``_ensure_embeddings_loaded`` only build
        when nothing is loaded yet. Without this reset, a re-index in the same
        process would keep using indexes built for the old chunks. The state is
        recreated lazily by the next ``search()``.
        """
        self._embedding_cache = None
        self._bm25_cache = None
        self._search_strategies = {}

    def _extract_current_schema(self) -> Dict[str, TableSchema]:
        """Reflect the live database schema."""
        logger.info("Extracting schema from database")
        return self.schema_extractor.extract()

    def _load_cached_schema(self) -> Optional[Dict[str, TableSchema]]:
        """Return the cached schema snapshot, or None if it is missing/unreadable."""
        schema_cache = self.cache_dir / "metadata.json"

        if not schema_cache.exists():
            logger.debug("Schema cache missing; treating as schema change")
            return None

        try:
            with open(schema_cache, encoding="utf-8") as f:
                return json.load(f)
        except json.JSONDecodeError:
            logger.warning(
                "Schema cache %s is unreadable; treating as schema change",
                schema_cache,
            )
            return None

    def _cache_schema(self, schema: Dict[str, TableSchema]) -> None:
        """Write the schema snapshot used for change detection."""
        schema_cache = self.cache_dir / "metadata.json"
        with open(schema_cache, "w", encoding="utf-8") as f:
            json.dump(schema, f, indent=2)

    def _schema_has_changed(
        self,
        cached_schema: Optional[Dict[str, TableSchema]],
        current_schema: Dict[str, TableSchema],
    ) -> bool:
        """True when there is no snapshot or it differs from ``current_schema``."""
        if cached_schema is None:
            return True
        if cached_schema != current_schema:
            logger.debug("Cached schema differs from current schema")
            return True
        logger.debug("Schema matches cached version; reuse existing index")
        return False

    def _load_or_generate_chunks(
        self, schemas: Dict[str, TableSchema], force: bool
    ) -> List[Chunk]:
        """Load chunks from the JSON cache, or regenerate and cache them.

        A corrupt cache file is regenerated rather than aborting indexing.
        """
        chunks_cache = self.cache_dir / "chunk_metadata.json"

        if not force and chunks_cache.exists():
            logger.info(f"Loading chunks from cache: {chunks_cache}")
            try:
                with open(chunks_cache, encoding="utf-8") as f:
                    chunk_data = json.load(f)
                return [
                    Chunk(
                        table_name=c["table_name"],
                        content=c["content"],
                        chunk_id=c["chunk_id"],
                        token_count=c["token_count"],
                    )
                    for c in chunk_data
                ]
            except (json.JSONDecodeError, KeyError):
                logger.warning(
                    "Chunk cache %s is unreadable; regenerating", chunks_cache
                )

        logger.info("Generating chunks from schemas")
        chunks = self.chunker.chunk_schemas(schemas)

        chunk_data = [
            {
                "table_name": c.table_name,
                "content": c.content,
                "chunk_id": c.chunk_id,
                "token_count": c.token_count,
            }
            for c in chunks
        ]
        with open(chunks_cache, "w", encoding="utf-8") as f:
            json.dump(chunk_data, f, indent=2)

        return chunks

    def _get_embedding_cache(self):
        """Create the embedding cache on first use and return it."""
        if self._embedding_cache is None:
            self._embedding_cache = create_embedding_cache(self.config, self.cache_dir)
        return self._embedding_cache

    def _get_reranker(self):
        """Create the reranker on first use when one is configured, else None."""
        if self._reranker is None and self._reranker_config:
            self._reranker = create_ranker(self.config)
        return self._reranker

    @property
    def embedding_cache(self):
        return self._get_embedding_cache()

    @property
    def reranker(self):
        return self._get_reranker()

    def _get_bm25_cache(self):
        """Create the (empty) BM25 cache on first use and return it."""
        if self._bm25_cache is None:
            self._bm25_cache = BM25Cache()
        return self._bm25_cache

    def _ensure_embeddings_loaded(self):
        """Load or generate embeddings for ``self.chunks`` if not loaded yet."""
        cache = self._get_embedding_cache()
        if cache.embeddings is None:
            # NOTE(review): after a process restart, _index_force is False even
            # if the previous run changed the schema. Confirm that the
            # embedding cache itself detects stale on-disk embeddings.
            cache.load_or_generate(
                self.chunks, self._index_force, self.config["chunking"]
            )

    def _ensure_bm25_built(self):
        """Build the BM25 index over ``self.chunks`` if not built yet."""
        cache = self._get_bm25_cache()
        if cache.bm25 is None:
            logger.info("Building BM25 index")
            cache.build(self.chunks)

    def _get_search_strategy(self, search_type: str):
        """Return the memoised strategy object for ``search_type``."""
        if search_type not in self._search_strategies:
            self._search_strategies[search_type] = create_search_strategy(
                self.config,
                self._get_embedding_cache,
                self._get_bm25_cache,
                self._get_reranker,
                search_type,
            )
        return self._search_strategies[search_type]

    @time_it
    def search(
        self,
        query: str,
        hops: Optional[int] = None,
        limit: int = 5,
        search_type: Optional[SearchType] = None,
    ) -> SearchResult:
        """Find up to ``limit`` tables relevant to ``query``.

        Args:
            query: Search text.
            hops: Depth passed to ``GraphBuilder.get_neighbors`` for related
                tables; defaults to ``config["search"]["hops"]``.
            limit: Maximum number of results.
            search_type: Strategy name; defaults to ``config["search"]["strategy"]``.

        Returns:
            ``{"results": [...], "latency_sec": ...}``; latency is set by ``time_it``.
        """
        if hops is None:
            hops = int(self.config["search"]["hops"])
        logger.debug(
            "Searching: %s (hops=%s, search_type=%s)", query, hops, search_type
        )

        search_type = search_type or self.config["search"]["strategy"]

        if search_type in ["semantic", "hybrid"]:
            self._ensure_embeddings_loaded()

        if search_type in ["bm25", "hybrid"]:
            self._ensure_bm25_built()

        strategy = self._get_search_strategy(search_type)
        results = strategy.search(
            query, self.schemas, self.chunks, self.graph_builder, hops, limit
        )

        logger.debug("Found %d results", len(results))
        return {"results": results, "latency_sec": 0.0}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from schema_search.search.base import BaseSearchStrategy
|
|
2
|
+
from schema_search.search.semantic import SemanticSearchStrategy
|
|
3
|
+
from schema_search.search.fuzzy import FuzzySearchStrategy
|
|
4
|
+
from schema_search.search.bm25 import BM25SearchStrategy
|
|
5
|
+
from schema_search.search.hybrid import HybridSearchStrategy
|
|
6
|
+
from schema_search.search.factory import create_search_strategy
|
|
7
|
+
|
|
8
|
+
# Public API of this package: the abstract strategy base, the concrete
# strategies, and the factory that selects one by name.
__all__ = [
    # Abstract base
    "BaseSearchStrategy",
    # Concrete strategies
    "SemanticSearchStrategy",
    "FuzzySearchStrategy",
    "BM25SearchStrategy",
    "HybridSearchStrategy",
    # Factory
    "create_search_strategy",
]
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from typing import Dict, List, Optional
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
from schema_search.types import TableSchema, SearchResultItem
|
|
5
|
+
from schema_search.chunkers import Chunk
|
|
6
|
+
from schema_search.graph_builder import GraphBuilder
|
|
7
|
+
from schema_search.rankers.base import BaseRanker
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseSearchStrategy(ABC):
    """Shared search flow: a subclass-provided initial ranking, optionally reranked.

    Args:
        reranker: Reranker applied to the initial results, or ``None`` to
            return the initial ranking as-is.
        initial_top_k: Candidate budget for subclasses' ``_initial_ranking``.
        rerank_top_k: Maximum number of reranked items kept before ``limit``.
    """

    def __init__(
        self, reranker: Optional[BaseRanker], initial_top_k: int, rerank_top_k: int
    ):
        self.reranker = reranker
        self.initial_top_k = initial_top_k
        self.rerank_top_k = rerank_top_k

    def search(
        self,
        query: str,
        schemas: Dict[str, TableSchema],
        chunks: List[Chunk],
        graph_builder: GraphBuilder,
        hops: int,
        limit: int,
    ) -> List[SearchResultItem]:
        """Return up to ``limit`` results for ``query``.

        Without a reranker, the initial ranking is truncated to ``limit``.
        With a reranker, one chunk per distinct result table is rescored and
        the results are rebuilt in reranked order.
        """
        initial_results = self._initial_ranking(
            query, schemas, chunks, graph_builder, hops
        )

        if self.reranker is None:
            return initial_results[:limit]

        candidates = self._select_rerank_candidates(initial_results, chunks)
        if not candidates:
            # Nothing to rerank; don't invoke the reranker on an empty batch.
            return []

        # NOTE: build() stores the candidates on the reranker and rank()
        # returns indices into them. Concurrent searches sharing one reranker
        # instance could therefore interfere.
        self.reranker.build(candidates)
        ranked = self.reranker.rank(query)

        reranked_results: List[SearchResultItem] = []
        for chunk_idx, score in ranked[: self.rerank_top_k]:
            chunk = candidates[chunk_idx]
            reranked_results.append(
                self._build_result_item(
                    table_name=chunk.table_name,
                    score=score,
                    schema=schemas[chunk.table_name],
                    matched_chunks=[chunk.content],
                    graph_builder=graph_builder,
                    hops=hops,
                )
            )

        return reranked_results[:limit]

    @staticmethod
    def _select_rerank_candidates(
        results: List[SearchResultItem], chunks: List[Chunk]
    ) -> List[Chunk]:
        """Pick one chunk per distinct table of ``results``, in result order.

        Prefers the chunk whose content equals the result's first matched
        chunk, falling back to the table's first chunk in ``chunks``. Tables
        with no chunk are skipped. The lookups are built once, so this is
        linear in ``len(results) + len(chunks)``.
        """
        first_by_table: Dict[str, Chunk] = {}
        by_table_and_content = {}
        for chunk in chunks:
            first_by_table.setdefault(chunk.table_name, chunk)
            by_table_and_content.setdefault((chunk.table_name, chunk.content), chunk)

        selected: List[Chunk] = []
        seen_tables = set()
        for result in results:
            table_name = result["table"]
            if table_name in seen_tables:
                continue
            matched = result.get("matched_chunks") or []
            chunk = None
            if matched:
                chunk = by_table_and_content.get((table_name, matched[0]))
            if chunk is None:
                chunk = first_by_table.get(table_name)
            if chunk is None:
                continue
            seen_tables.add(table_name)
            selected.append(chunk)
        return selected

    @abstractmethod
    def _initial_ranking(
        self,
        query: str,
        schemas: Dict[str, TableSchema],
        chunks: List[Chunk],
        graph_builder: GraphBuilder,
        hops: int,
    ) -> List[SearchResultItem]:
        """Return result items ordered best-first (subclass-specific scoring)."""

    def _build_result_item(
        self,
        table_name: str,
        score: float,
        schema: TableSchema,
        matched_chunks: List[str],
        graph_builder: GraphBuilder,
        hops: int,
    ) -> SearchResultItem:
        """Assemble a result dict, adding neighbours from ``graph_builder``."""
        return {
            "table": table_name,
            "score": score,
            "schema": schema,
            "matched_chunks": matched_chunks,
            "related_tables": list(graph_builder.get_neighbors(table_name, hops)),
        }
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from typing import Dict, List, Optional, TYPE_CHECKING
|
|
2
|
+
|
|
3
|
+
from schema_search.search.base import BaseSearchStrategy
|
|
4
|
+
from schema_search.types import TableSchema, SearchResultItem
|
|
5
|
+
from schema_search.chunkers import Chunk
|
|
6
|
+
from schema_search.graph_builder import GraphBuilder
|
|
7
|
+
from schema_search.rankers.base import BaseRanker
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from schema_search.embedding_cache.bm25 import BM25Cache
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BM25SearchStrategy(BaseSearchStrategy):
    """Keyword search that ranks chunks with a prebuilt BM25 index."""

    def __init__(
        self,
        bm25_cache: "BM25Cache",
        initial_top_k: int,
        rerank_top_k: int,
        reranker: Optional[BaseRanker],
    ):
        super().__init__(reranker, initial_top_k, rerank_top_k)
        self.bm25_cache = bm25_cache

    def _initial_ranking(
        self,
        query: str,
        schemas: Dict[str, TableSchema],
        chunks: List[Chunk],
        graph_builder: GraphBuilder,
        hops: int,
    ) -> List[SearchResultItem]:
        """Rank tables by their best-scoring chunk.

        Chunks are visited in descending BM25 score. Each table yields a single
        result that carries its best chunk's score; later chunks of the same
        table are appended to its ``matched_chunks``. A table split into
        several chunks therefore occupies one result slot. Scanning stops once
        ``initial_top_k`` tables are collected and another table's chunk
        comes up.
        """
        if self.initial_top_k <= 0:
            return []

        scores = self.bm25_cache.get_scores(query)

        results: List[SearchResultItem] = []
        result_by_table: Dict[str, SearchResultItem] = {}
        # argsort is ascending; reverse for best-first.
        for idx in scores.argsort()[::-1]:
            chunk = chunks[idx]
            existing = result_by_table.get(chunk.table_name)
            if existing is not None:
                existing["matched_chunks"].append(chunk.content)
                continue
            if len(results) >= self.initial_top_k:
                break
            result = self._build_result_item(
                table_name=chunk.table_name,
                score=float(scores[idx]),
                schema=schemas[chunk.table_name],
                matched_chunks=[chunk.content],
                graph_builder=graph_builder,
                hops=hops,
            )
            result_by_table[chunk.table_name] = result
            results.append(result)

        return results
|