schema_search-0.1.10-py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
- schema_search/__init__.py +26 -0
- schema_search/chunkers/__init__.py +6 -0
- schema_search/chunkers/base.py +95 -0
- schema_search/chunkers/factory.py +31 -0
- schema_search/chunkers/llm.py +54 -0
- schema_search/chunkers/markdown.py +25 -0
- schema_search/embedding_cache/__init__.py +5 -0
- schema_search/embedding_cache/base.py +40 -0
- schema_search/embedding_cache/bm25.py +63 -0
- schema_search/embedding_cache/factory.py +20 -0
- schema_search/embedding_cache/inmemory.py +122 -0
- schema_search/graph_builder.py +69 -0
- schema_search/mcp_server.py +81 -0
- schema_search/metrics.py +33 -0
- schema_search/rankers/__init__.py +5 -0
- schema_search/rankers/base.py +45 -0
- schema_search/rankers/cross_encoder.py +40 -0
- schema_search/rankers/factory.py +11 -0
- schema_search/schema_extractor.py +135 -0
- schema_search/schema_search.py +276 -0
- schema_search/search/__init__.py +15 -0
- schema_search/search/base.py +85 -0
- schema_search/search/bm25.py +48 -0
- schema_search/search/factory.py +61 -0
- schema_search/search/fuzzy.py +56 -0
- schema_search/search/hybrid.py +82 -0
- schema_search/search/semantic.py +49 -0
- schema_search/types.py +57 -0
- schema_search/utils/__init__.py +0 -0
- schema_search/utils/lazy_import.py +26 -0
- schema_search-0.1.10.dist-info/METADATA +308 -0
- schema_search-0.1.10.dist-info/RECORD +40 -0
- schema_search-0.1.10.dist-info/WHEEL +5 -0
- schema_search-0.1.10.dist-info/entry_points.txt +2 -0
- schema_search-0.1.10.dist-info/licenses/LICENSE +21 -0
- schema_search-0.1.10.dist-info/top_level.txt +2 -0
- tests/__init__.py +0 -0
- tests/test_integration.py +352 -0
- tests/test_llm_sql_generation.py +320 -0
- tests/test_spider_eval.py +488 -0

schema_search/__init__.py
ADDED
@@ -0,0 +1,26 @@
+from schema_search.schema_search import SchemaSearch
+from schema_search.types import (
+    IndexResult,
+    SearchResult,
+    SearchResultItem,
+    SearchType,
+    TableSchema,
+    ColumnInfo,
+    ForeignKeyInfo,
+    IndexInfo,
+    ConstraintInfo,
+)
+
+__version__ = "0.1.0"
+__all__ = [
+    "SchemaSearch",
+    "IndexResult",
+    "SearchResult",
+    "SearchResultItem",
+    "SearchType",
+    "TableSchema",
+    "ColumnInfo",
+    "ForeignKeyInfo",
+    "IndexInfo",
+    "ConstraintInfo",
+]

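A minimal usage sketch of the public API above, mirroring how `mcp_server.py` (later in this diff) drives the library; the SQLite URL is a placeholder, and any SQLAlchemy-supported database should work:

```python
from sqlalchemy import create_engine

from schema_search import SchemaSearch

# Placeholder database URL; SchemaSearch introspects whatever the engine points at.
engine = create_engine("sqlite:///example.db")
search = SchemaSearch(engine, config_path=None, llm_api_key=None, llm_base_url=None)

search.index()  # extract schemas, chunk, embed, and build the FK graph
result = search.search("tables related to payments", limit=5)

print(result["results"])      # matched table schemas
print(result["latency_sec"])  # query execution time
```
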
schema_search/chunkers/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from schema_search.chunkers.base import Chunk, BaseChunker
+from schema_search.chunkers.markdown import MarkdownChunker
+from schema_search.chunkers.llm import LLMChunker
+from schema_search.chunkers.factory import create_chunker
+
+__all__ = ["Chunk", "BaseChunker", "MarkdownChunker", "LLMChunker", "create_chunker"]

schema_search/chunkers/base.py
ADDED
@@ -0,0 +1,95 @@
+from typing import Dict, List
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+from tqdm import tqdm
+
+from schema_search.types import TableSchema
+
+
+@dataclass
+class Chunk:
+    table_name: str
+    content: str
+    chunk_id: int
+    token_count: int
+
+
+class BaseChunker(ABC):
+    def __init__(self, max_tokens: int, overlap_tokens: int, show_progress: bool = False):
+        self.max_tokens = max_tokens
+        self.overlap_tokens = overlap_tokens
+        self.show_progress = show_progress
+
+    def chunk_schemas(self, schemas: Dict[str, TableSchema]) -> List[Chunk]:
+        chunks: List[Chunk] = []
+        chunk_id = 0
+
+        iterator = schemas.items()
+        if self.show_progress:
+            iterator = tqdm(iterator, desc="Chunking tables", unit="table")
+
+        for table_name, schema in iterator:
+            table_chunks = self._chunk_table(table_name, schema, chunk_id)
+            chunks.extend(table_chunks)
+            chunk_id += len(table_chunks)
+
+        return chunks
+
+    @abstractmethod
+    def _generate_content(self, table_name: str, schema: TableSchema) -> str:
+        pass
+
+    def _chunk_table(
+        self, table_name: str, schema: TableSchema, start_id: int
+    ) -> List[Chunk]:
+        content = self._generate_content(table_name, schema)
+        lines = content.split("\n")
+
+        header = f"Table: {table_name}"
+        header_tokens = self._estimate_tokens(header)
+
+        chunks: List[Chunk] = []
+        current_chunk_lines = [header]
+        current_tokens = header_tokens
+        chunk_id = start_id
+
+        for line in lines[1:]:
+            line_tokens = self._estimate_tokens(line)
+
+            if (
+                current_tokens + line_tokens > self.max_tokens
+                and len(current_chunk_lines) > 1
+            ):
+                chunk_content = "\n".join(current_chunk_lines)
+                chunks.append(
+                    Chunk(
+                        table_name=table_name,
+                        content=chunk_content,
+                        chunk_id=chunk_id,
+                        token_count=current_tokens,
+                    )
+                )
+                chunk_id += 1
+
+                current_chunk_lines = [header]
+                current_tokens = header_tokens
+
+            current_chunk_lines.append(line)
+            current_tokens += line_tokens
+
+        if len(current_chunk_lines) > 1:
+            chunk_content = "\n".join(current_chunk_lines)
+            chunks.append(
+                Chunk(
+                    table_name=table_name,
+                    content=chunk_content,
+                    chunk_id=chunk_id,
+                    token_count=current_tokens,
+                )
+            )
+
+        return chunks
+
+    def _estimate_tokens(self, text: str) -> int:
+        return len(text.split()) + len(text) // 4

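`_estimate_tokens` approximates length as whitespace-delimited words plus one token per four characters, so splitting is heuristic rather than tokenizer-exact. A toy subclass (not part of the package) shows `_chunk_table` re-attaching the `Table:` header to every emitted chunk; the schema payload is a stand-in dict, not a real `TableSchema`:

```python
from schema_search.chunkers.base import BaseChunker


class OneLinePerColumnChunker(BaseChunker):
    def _generate_content(self, table_name, schema):
        # The first line must be the "Table: ..." header; _chunk_table skips it
        # and re-adds the header to each chunk it emits.
        cols = "\n".join(f"column: {c}" for c in schema["columns"])
        return f"Table: {table_name}\n{cols}"


# overlap_tokens is stored by BaseChunker but not used by the splitting loop above.
chunker = OneLinePerColumnChunker(max_tokens=12, overlap_tokens=0)
chunks = chunker.chunk_schemas({"users": {"columns": ["id", "email", "created_at"]}})
for c in chunks:
    print(c.chunk_id, repr(c.content), c.token_count)
```
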
schema_search/chunkers/factory.py
ADDED
@@ -0,0 +1,31 @@
+from typing import Dict, Optional
+
+from schema_search.chunkers.base import BaseChunker
+from schema_search.chunkers.markdown import MarkdownChunker
+from schema_search.chunkers.llm import LLMChunker
+
+
+def create_chunker(
+    config: Dict, llm_api_key: Optional[str], llm_base_url: Optional[str]
+) -> BaseChunker:
+    chunking_config = config["chunking"]
+    strategy = chunking_config["strategy"]
+    show_progress = config["embedding"].get("show_progress", False)
+
+    if strategy == "llm":
+        return LLMChunker(
+            max_tokens=chunking_config["max_tokens"],
+            overlap_tokens=chunking_config["overlap_tokens"],
+            model=chunking_config["model"],
+            llm_api_key=llm_api_key,
+            llm_base_url=llm_base_url,
+            show_progress=show_progress,
+        )
+    elif strategy == "raw":
+        return MarkdownChunker(
+            max_tokens=chunking_config["max_tokens"],
+            overlap_tokens=chunking_config["overlap_tokens"],
+            show_progress=show_progress,
+        )
+    else:
+        raise ValueError(f"Unknown chunking strategy: {strategy}")

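An illustrative config shape for `create_chunker`; the values are examples, not package defaults. `"raw"` selects `MarkdownChunker`, while `"llm"` selects `LLMChunker` and additionally requires a `"model"` key:

```python
from schema_search.chunkers.factory import create_chunker

config = {
    "chunking": {
        "strategy": "raw",   # "llm" would also require a "model" entry
        "max_tokens": 256,   # example value, not a package default
        "overlap_tokens": 32,
    },
    "embedding": {"show_progress": False},
}

chunker = create_chunker(config, llm_api_key=None, llm_base_url=None)
```
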
schema_search/chunkers/llm.py
ADDED
@@ -0,0 +1,54 @@
+import json
+import logging
+from typing import Optional, TYPE_CHECKING
+
+from schema_search.chunkers.base import BaseChunker
+from schema_search.types import TableSchema
+from schema_search.utils.lazy_import import lazy_import_check
+
+if TYPE_CHECKING:
+    from openai import OpenAI
+
+logger = logging.getLogger(__name__)
+
+
+class LLMChunker(BaseChunker):
+    def __init__(
+        self,
+        max_tokens: int,
+        overlap_tokens: int,
+        model: str,
+        llm_api_key: Optional[str],
+        llm_base_url: Optional[str],
+        show_progress: bool = False,
+    ):
+        super().__init__(max_tokens, overlap_tokens, show_progress)
+        self.model = model
+        openai = lazy_import_check("openai", "llm", "LLM-based chunking")
+        self.llm_client: "OpenAI" = openai.OpenAI(api_key=llm_api_key, base_url=llm_base_url)
+        logger.info(f"Schema Summarizer Model: {self.model}")
+
+    def _generate_content(self, table_name: str, schema: TableSchema) -> str:
+        prompt = f"""Generate a concise 250 tokens or less semantic summary of this database table schema. Focus on:
+1. What entity or concept this table represents
+2. Key data it stores (main columns)
+3. How it relates to other tables
+4. Any important constraints or indices
+
+Keep it brief and semantic, optimized for embedding-based search.
+
+Schema:
+{json.dumps(schema, indent=2)}
+
+Return ONLY the summary text, no preamble."""
+
+        response = self.llm_client.chat.completions.create(
+            model=self.model,
+            max_tokens=500,
+            messages=[{"role": "user", "content": prompt}],
+        )
+
+        summary = response.choices[0].message.content.strip()  # type: ignore
+        logger.debug(f"Generated LLM summary for {table_name}: {summary[:100]}...")
+
+        return f"Table: {table_name}\n{summary}"

schema_search/chunkers/markdown.py
ADDED
@@ -0,0 +1,25 @@
+from schema_search.chunkers.base import BaseChunker
+from schema_search.types import TableSchema
+
+
+class MarkdownChunker(BaseChunker):
+    def _generate_content(self, table_name: str, schema: TableSchema) -> str:
+        lines = [f"Table: {table_name}"]
+
+        if schema["primary_keys"]:
+            lines.append(f"Primary keys: {', '.join(schema['primary_keys'])}")
+
+        if schema["columns"]:
+            col_names = [col["name"] for col in schema["columns"]]
+            lines.append(f"Columns: {', '.join(col_names)}")
+
+        if schema["foreign_keys"]:
+            related = [fk["referred_table"] for fk in schema["foreign_keys"]]
+            lines.append(f"Related to: {', '.join(related)}")
+
+        if schema["indices"]:
+            idx_names = [idx["name"] for idx in schema["indices"] if idx["name"]]
+            if idx_names:
+                lines.append(f"Indexes: {', '.join(idx_names)}")
+
+        return "\n".join(lines)

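A hypothetical input/output pair for `MarkdownChunker`; the schema dict is trimmed to just the keys `_generate_content` reads (real `TableSchema` values carry more fields):

```python
from schema_search.chunkers.markdown import MarkdownChunker

schema = {
    "primary_keys": ["id"],
    "columns": [{"name": "id"}, {"name": "user_id"}, {"name": "amount"}],
    "foreign_keys": [{"referred_table": "users"}],
    "indices": [{"name": "ix_payments_user_id"}],
}

chunker = MarkdownChunker(max_tokens=256, overlap_tokens=32)
print(chunker._generate_content("payments", schema))
# Table: payments
# Primary keys: id
# Columns: id, user_id, amount
# Related to: users
# Indexes: ix_payments_user_id
```
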
schema_search/embedding_cache/__init__.py
ADDED
@@ -0,0 +1,5 @@
+from schema_search.embedding_cache.base import BaseEmbeddingCache
+from schema_search.embedding_cache.inmemory import InMemoryEmbeddingCache
+from schema_search.embedding_cache.factory import create_embedding_cache
+
+__all__ = ["BaseEmbeddingCache", "InMemoryEmbeddingCache", "create_embedding_cache"]

schema_search/embedding_cache/base.py
ADDED
@@ -0,0 +1,40 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Dict, List
+
+import numpy as np
+
+from schema_search.chunkers import Chunk
+
+
+class BaseEmbeddingCache(ABC):
+    def __init__(
+        self,
+        cache_dir: Path,
+        model_name: str,
+        metric: str,
+        batch_size: int,
+        show_progress: bool,
+    ):
+        self.cache_dir = cache_dir
+        self.cache_dir.mkdir(exist_ok=True)
+        self.model_name = model_name
+        self.model = None
+        self.metric = metric
+        self.batch_size = batch_size
+        self.show_progress = show_progress
+        self.embeddings = None
+
+    @abstractmethod
+    def load_or_generate(
+        self, chunks: List[Chunk], force: bool, chunking_config: Dict
+    ) -> None:
+        pass
+
+    @abstractmethod
+    def encode_query(self, query: str) -> np.ndarray:
+        pass
+
+    @abstractmethod
+    def compute_similarities(self, query_embedding: np.ndarray) -> np.ndarray:
+        pass

schema_search/embedding_cache/bm25.py
ADDED
@@ -0,0 +1,63 @@
+from typing import List
+import re
+import logging
+import numpy as np
+
+import bm25s
+
+from schema_search.chunkers import Chunk
+
+logging.getLogger("bm25s").setLevel(logging.WARNING)
+
+
+def light_stem(token: str) -> str:
+    """Tiny rule-based stemmer for schema tokens."""
+    for suf in ("ing", "ers", "ies", "ied", "ed", "es", "s"):
+        if token.endswith(suf) and len(token) > len(suf) + 2:
+            if suf == "ies":
+                return token[:-3] + "y"
+            return token[: -len(suf)]
+    return token
+
+
+def _tokenize(text: str) -> List[str]:
+    """Tokenize and normalize database-like text."""
+    text = text.replace("\n", " ")
+    text = re.sub(r"[_\-]+", " ", text)
+    text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
+    text = text.lower()
+    text = re.sub(r"([a-z])([0-9])", r"\1 \2", text)
+    text = re.sub(r"([0-9])([a-z])", r"\1 \2", text)
+
+    tokens = re.findall(r"[a-z0-9]+", text)
+    normalized = []
+    for t in tokens:
+        if t in {"pk", "pkey", "key"}:
+            t = "id"
+        elif t in {"ts", "time", "timestamp"}:
+            t = "timestamp"
+        elif t.endswith("id") and len(t) > 2:
+            t = "id"
+        elif t in {"ix", "index", "idx"}:
+            t = "index"
+        normalized.append(light_stem(t))
+    return normalized
+
+
+class BM25Cache:
+    def __init__(self):
+        self.bm25 = None
+        self.tokenized_docs = None
+
+    def build(self, chunks: List[Chunk]) -> None:
+        if self.bm25 is None:
+            self.tokenized_docs = [_tokenize(chunk.content) for chunk in chunks]
+            self.bm25 = bm25s.BM25()
+            self.bm25.index(self.tokenized_docs)
+
+    def get_scores(self, query: str) -> np.ndarray:
+        if self.bm25 is None or self.tokenized_docs is None:
+            raise RuntimeError("BM25 cache not built. Call build() first.")
+        query_tokens = _tokenize(query)
+        scores = self.bm25.get_scores(query_tokens)
+        return scores

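Worked examples of the normalization above, given that the camelCase split runs before lowercasing (as ordered in `_tokenize`) and that `*_id`-style tokens collapse to `id`; these are expected outputs, not package tests:

```python
from schema_search.embedding_cache.bm25 import _tokenize

print(_tokenize("customer_id"))  # ['customer', 'id']
print(_tokenize("OrderItems"))   # ['order', 'item']      (camelCase split, then stemmed)
print(_tokenize("payment_ts"))   # ['payment', 'timestamp']
```
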
schema_search/embedding_cache/factory.py
ADDED
@@ -0,0 +1,20 @@
+from pathlib import Path
+from typing import Dict
+
+from schema_search.embedding_cache.base import BaseEmbeddingCache
+from schema_search.embedding_cache.inmemory import InMemoryEmbeddingCache
+
+
+def create_embedding_cache(config: Dict, cache_dir: Path) -> BaseEmbeddingCache:
+    location = config["embedding"]["location"]
+
+    if location == "memory":
+        return InMemoryEmbeddingCache(
+            cache_dir=cache_dir,
+            model_name=config["embedding"]["model"],
+            metric=config["embedding"]["metric"],
+            batch_size=config["embedding"]["batch_size"],
+            show_progress=config["embedding"]["show_progress"],
+        )
+    else:
+        raise ValueError(f"Unsupported embedding location: {location}")

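An illustrative `"embedding"` config section for `create_embedding_cache`; the model name and values are examples rather than package defaults, and only `location == "memory"` is supported in this version:

```python
from pathlib import Path

from schema_search.embedding_cache.factory import create_embedding_cache

config = {
    "embedding": {
        "location": "memory",
        "model": "all-MiniLM-L6-v2",  # example sentence-transformers model
        "metric": "cosine",           # see schema_search/metrics.py below
        "batch_size": 32,
        "show_progress": False,
    }
}

cache = create_embedding_cache(config, cache_dir=Path(".schema_search_cache"))
```
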
schema_search/embedding_cache/inmemory.py
ADDED
@@ -0,0 +1,122 @@
+import json
+import logging
+from pathlib import Path
+from typing import Dict, List, Optional, TYPE_CHECKING
+
+import numpy as np
+
+from schema_search.chunkers import Chunk
+from schema_search.embedding_cache.base import BaseEmbeddingCache
+from schema_search.metrics import get_metric
+from schema_search.utils.lazy_import import lazy_import_check
+
+if TYPE_CHECKING:
+    from sentence_transformers import SentenceTransformer
+
+logger = logging.getLogger(__name__)
+
+
+class InMemoryEmbeddingCache(BaseEmbeddingCache):
+    def __init__(
+        self,
+        cache_dir: Path,
+        model_name: str,
+        metric: str,
+        batch_size: int,
+        show_progress: bool,
+    ):
+        super().__init__(cache_dir, model_name, metric, batch_size, show_progress)
+        self.model: Optional["SentenceTransformer"] = None
+
+    def load_or_generate(
+        self, chunks: List[Chunk], force: bool, chunking_config: Dict
+    ) -> None:
+        cache_file = self.cache_dir / "embeddings.npz"
+        config_file = self.cache_dir / "cache_config.json"
+
+        if not force and self._is_cache_valid(cache_file, config_file, chunking_config):
+            self._load_from_cache(cache_file)
+        else:
+            self._generate_and_cache(chunks, cache_file, config_file, chunking_config)
+
+    def _load_from_cache(self, cache_file: Path) -> None:
+        logger.info("Loading embeddings from cache")
+        self.embeddings = np.load(cache_file)["embeddings"]
+
+    def _is_cache_valid(
+        self, cache_file: Path, config_file: Path, chunking_config: Dict
+    ) -> bool:
+        if not (cache_file.exists() and config_file.exists()):
+            return False
+
+        with open(config_file) as f:
+            cached_config = json.load(f)
+
+        current_config = {
+            "strategy": chunking_config["strategy"],
+            "max_tokens": chunking_config["max_tokens"],
+            "embedding_model": self.model_name,
+        }
+
+        if cached_config != current_config:
+            logger.info("Cache invalidated: chunking config changed")
+            return False
+
+        return True
+
+    def _generate_and_cache(
+        self,
+        chunks: List[Chunk],
+        cache_file: Path,
+        config_file: Path,
+        chunking_config: Dict,
+    ) -> None:
+        self._load_model()
+
+        logger.info(f"Generating embeddings for {len(chunks)} chunks")
+        texts = [chunk.content for chunk in chunks]
+
+        assert self.model is not None
+        self.embeddings = self.model.encode(
+            texts,
+            batch_size=self.batch_size,
+            normalize_embeddings=True,
+            show_progress_bar=self.show_progress,
+        )
+
+        np.savez_compressed(cache_file, embeddings=self.embeddings)
+
+        cache_config = {
+            "strategy": chunking_config["strategy"],
+            "max_tokens": chunking_config["max_tokens"],
+            "embedding_model": self.model_name,
+        }
+        with open(config_file, "w") as f:
+            json.dump(cache_config, f, indent=2)
+
+    def _load_model(self) -> None:
+        if self.model is None:
+            sentence_transformers = lazy_import_check(
+                "sentence_transformers",
+                "semantic",
+                "semantic/hybrid search or reranking",
+            )
+            logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
+            self.model = sentence_transformers.SentenceTransformer(self.model_name)
+            logger.info(f"Loaded embedding model: {self.model_name}")
+
+    def encode_query(self, query: str) -> np.ndarray:
+        self._load_model()
+
+        assert self.model is not None
+        query_emb = self.model.encode(
+            [query],
+            batch_size=self.batch_size,
+            normalize_embeddings=True,
+        )
+
+        return query_emb
+
+    def compute_similarities(self, query_embedding: np.ndarray) -> np.ndarray:
+        metric_fn = get_metric(self.metric)
+        return metric_fn(self.embeddings, query_embedding).flatten()

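An end-to-end cache sketch under the same assumptions as the factory example above (the model name is illustrative, and the first run downloads it): build or reload embeddings for a hand-made chunk, then score a query against them.

```python
from pathlib import Path

from schema_search.chunkers import Chunk
from schema_search.embedding_cache.inmemory import InMemoryEmbeddingCache

cache = InMemoryEmbeddingCache(
    cache_dir=Path(".schema_search_cache"),
    model_name="all-MiniLM-L6-v2",  # example model, as in the factory sketch
    metric="cosine",
    batch_size=32,
    show_progress=False,
)

chunks = [
    Chunk(
        table_name="payments",
        content="Table: payments\nColumns: id, amount",
        chunk_id=0,
        token_count=12,
    )
]
# Only these keys are read by the cache validator; force=False reuses a valid cache.
chunking_config = {"strategy": "raw", "max_tokens": 256}

cache.load_or_generate(chunks, force=False, chunking_config=chunking_config)
query_emb = cache.encode_query("tables related to payments")
scores = cache.compute_similarities(query_emb)  # one score per chunk, higher is better
top = scores.argsort()[::-1][:5]                # indices of the best-matching chunks
```
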
schema_search/graph_builder.py
ADDED
@@ -0,0 +1,69 @@
+import logging
+import pickle
+from pathlib import Path
+from typing import Dict, Set
+
+import networkx as nx
+
+from schema_search.types import TableSchema
+
+logger = logging.getLogger(__name__)
+
+
+class GraphBuilder:
+    def __init__(self, cache_dir: Path):
+        self.cache_dir = cache_dir
+        self.cache_dir.mkdir(exist_ok=True)
+        self.graph: nx.DiGraph
+
+    def build(self, schemas: Dict[str, TableSchema], force: bool) -> None:
+        cache_file = self.cache_dir / "graph.pkl"
+
+        if not force and cache_file.exists():
+            self._load_from_cache(cache_file)
+        else:
+            self._build_and_cache(schemas, cache_file)
+
+    def _load_from_cache(self, cache_file: Path) -> None:
+        logger.debug(f"Loading graph from cache: {cache_file}")
+        with open(cache_file, "rb") as f:
+            self.graph = pickle.load(f)
+
+    def _build_and_cache(
+        self, schemas: Dict[str, TableSchema], cache_file: Path
+    ) -> None:
+        logger.info("Building foreign key relationship graph")
+        self.graph = nx.DiGraph()
+
+        for table_name, schema in schemas.items():
+            self.graph.add_node(table_name, **schema)
+
+        for table_name, schema in schemas.items():
+            if schema["foreign_keys"]:
+                for fk in schema["foreign_keys"]:
+                    referred_table = fk["referred_table"]
+                    if referred_table in self.graph:
+                        self.graph.add_edge(table_name, referred_table, **fk)
+
+        with open(cache_file, "wb") as f:
+            pickle.dump(self.graph, f)
+
+    def get_neighbors(self, table_name: str, hops: int) -> Set[str]:
+        if table_name not in self.graph:
+            return set()
+
+        neighbors: Set[str] = set()
+
+        forward = nx.single_source_shortest_path_length(
+            self.graph, table_name, cutoff=hops
+        )
+        neighbors.update(forward.keys())
+
+        backward = nx.single_source_shortest_path_length(
+            self.graph.reverse(), table_name, cutoff=hops
+        )
+        neighbors.update(backward.keys())
+
+        neighbors.discard(table_name)
+
+        return neighbors

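A small sketch of hop expansion, with schema dicts trimmed to the `foreign_keys` key that `build()` reads. Note the traversal runs forward and backward separately, so neighbors are reached along pure FK-direction or pure reverse-direction paths:

```python
from pathlib import Path

from schema_search.graph_builder import GraphBuilder

schemas = {
    "users": {"foreign_keys": []},
    "orders": {"foreign_keys": [{"referred_table": "users"}]},
    "order_items": {"foreign_keys": [{"referred_table": "orders"}]},
}

gb = GraphBuilder(cache_dir=Path(".schema_search_cache"))
gb.build(schemas, force=True)

print(gb.get_neighbors("order_items", hops=1))  # {'orders'}
print(gb.get_neighbors("order_items", hops=2))  # {'orders', 'users'}
print(gb.get_neighbors("users", hops=1))        # {'orders'}  (reverse FK direction)
```
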
schema_search/mcp_server.py
ADDED
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+import logging
+from typing import Optional
+
+from fastmcp import FastMCP
+from sqlalchemy import create_engine
+
+from schema_search import SchemaSearch
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+mcp = FastMCP("schema-search")
+
+
+@mcp.tool()
+def schema_search(
+    query: str,
+    limit: int = 5,
+) -> dict:
+    """Search database schema using natural language.
+
+    Finds relevant database tables and their relationships by searching through schema metadata
+    using semantic similarity. Expands results by traversing foreign key relationships.
+
+    Args:
+        query: Natural language question about database schema (e.g., 'tables related to payments')
+        limit: Maximum number of table schemas to return in results. Default: 5; Max: 10.
+
+    Returns:
+        Dictionary with 'results' (list of table schemas with columns, types, constraints, and relationships) and 'latency_sec' (query execution time)
+    """
+    limit = min(limit, 10)
+    search_result = mcp.search_engine.search(query, limit=limit)  # type: ignore
+    return {
+        "results": search_result["results"],
+        "latency_sec": search_result["latency_sec"],
+    }
+
+
+def run_server(
+    database_url: str,
+    config_path: Optional[str] = None,
+    llm_api_key: Optional[str] = None,
+    llm_base_url: Optional[str] = None,
+):
+    engine = create_engine(database_url)
+
+    mcp.search_engine = SchemaSearch(  # type: ignore
+        engine,
+        config_path=config_path,
+        llm_api_key=llm_api_key,
+        llm_base_url=llm_base_url,
+    )
+
+    logger.info("Indexing database schema...")
+    mcp.search_engine.index()  # type: ignore
+    logger.info("Index ready")
+
+    mcp.run()
+
+
+def main():
+    import sys
+
+    if len(sys.argv) < 2:
+        print(
+            "Usage: schema-search <database_url> [config_path] [llm_api_key] [llm_base_url]"
+        )
+        sys.exit(1)
+
+    database_url = sys.argv[1]
+    config_path = sys.argv[2] if len(sys.argv) > 2 else None
+    llm_api_key = sys.argv[3] if len(sys.argv) > 3 else None
+    llm_base_url = sys.argv[4] if len(sys.argv) > 4 else None
+
+    run_server(database_url, config_path, llm_api_key, llm_base_url)
+
+
+if __name__ == "__main__":
+    main()

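The wheel's `entry_points.txt` presumably registers `main()` as the `schema-search` console script named in the usage message above; the server can also be started programmatically. The connection URL is a placeholder:

```python
from schema_search.mcp_server import run_server

# Placeholder URL; config_path, llm_api_key, and llm_base_url default to None.
run_server("postgresql://user:password@localhost:5432/shop")
```
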
schema_search/metrics.py
ADDED
@@ -0,0 +1,33 @@
+import numpy as np
+
+
+def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+    a_norm = a / (np.linalg.norm(a, axis=-1, keepdims=True) + 1e-8)
+    b_norm = b / (np.linalg.norm(b, axis=-1, keepdims=True) + 1e-8)
+    return a_norm @ b_norm.T
+
+
+def dot_product(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+    return a @ b.T
+
+
+def euclidean_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+    return -np.linalg.norm(a[:, None] - b[None, :], axis=-1)
+
+
+def manhattan_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+    return -np.sum(np.abs(a[:, None] - b[None, :]), axis=-1)
+
+
+METRICS = {
+    "cosine": cosine_similarity,
+    "dot": dot_product,
+    "euclidean": euclidean_distance,
+    "manhattan": manhattan_distance,
+}
+
+
+def get_metric(name: str):
+    if name not in METRICS:
+        raise ValueError(f"Unknown metric: {name}. Available: {list(METRICS.keys())}")
+    return METRICS[name]

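A worked example: every metric returns higher-is-better scores, since the distance-based ones are negated. With a document matrix of shape (N, d) and a query of shape (1, d), flattening yields one score per document:

```python
import numpy as np

from schema_search.metrics import get_metric

docs = np.array([[1.0, 0.0], [0.0, 1.0]])  # two unit vectors
query = np.array([[1.0, 0.0]])

print(get_metric("cosine")(docs, query).flatten())     # ~[1. 0.]        -> doc 0 wins
print(get_metric("euclidean")(docs, query).flatten())  # [ 0. -1.414...] -> doc 0 wins again
```
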