schema-search 0.1.2 (schema_search-0.1.2-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of schema-search has been flagged as a potentially problematic release.

Files changed (38)
  1. schema_search/__init__.py +26 -0
  2. schema_search/chunkers/__init__.py +6 -0
  3. schema_search/chunkers/base.py +95 -0
  4. schema_search/chunkers/factory.py +31 -0
  5. schema_search/chunkers/llm.py +51 -0
  6. schema_search/chunkers/markdown.py +25 -0
  7. schema_search/embedding_cache/__init__.py +5 -0
  8. schema_search/embedding_cache/base.py +40 -0
  9. schema_search/embedding_cache/bm25.py +63 -0
  10. schema_search/embedding_cache/factory.py +20 -0
  11. schema_search/embedding_cache/inmemory.py +112 -0
  12. schema_search/graph_builder.py +69 -0
  13. schema_search/mcp_server.py +82 -0
  14. schema_search/metrics.py +33 -0
  15. schema_search/rankers/__init__.py +5 -0
  16. schema_search/rankers/base.py +45 -0
  17. schema_search/rankers/cross_encoder.py +34 -0
  18. schema_search/rankers/factory.py +11 -0
  19. schema_search/schema_extractor.py +135 -0
  20. schema_search/schema_search.py +263 -0
  21. schema_search/search/__init__.py +15 -0
  22. schema_search/search/base.py +85 -0
  23. schema_search/search/bm25.py +48 -0
  24. schema_search/search/factory.py +61 -0
  25. schema_search/search/fuzzy.py +56 -0
  26. schema_search/search/hybrid.py +82 -0
  27. schema_search/search/semantic.py +49 -0
  28. schema_search/types.py +57 -0
  29. schema_search-0.1.2.dist-info/METADATA +275 -0
  30. schema_search-0.1.2.dist-info/RECORD +38 -0
  31. schema_search-0.1.2.dist-info/WHEEL +5 -0
  32. schema_search-0.1.2.dist-info/entry_points.txt +2 -0
  33. schema_search-0.1.2.dist-info/licenses/LICENSE +21 -0
  34. schema_search-0.1.2.dist-info/top_level.txt +2 -0
  35. tests/__init__.py +0 -0
  36. tests/test_integration.py +352 -0
  37. tests/test_llm_sql_generation.py +320 -0
  38. tests/test_spider_eval.py +484 -0
--- /dev/null
+++ b/schema_search/__init__.py
@@ -0,0 +1,26 @@
+from schema_search.schema_search import SchemaSearch
+from schema_search.types import (
+    IndexResult,
+    SearchResult,
+    SearchResultItem,
+    SearchType,
+    TableSchema,
+    ColumnInfo,
+    ForeignKeyInfo,
+    IndexInfo,
+    ConstraintInfo,
+)
+
+__version__ = "0.1.0"
+__all__ = [
+    "SchemaSearch",
+    "IndexResult",
+    "SearchResult",
+    "SearchResultItem",
+    "SearchType",
+    "TableSchema",
+    "ColumnInfo",
+    "ForeignKeyInfo",
+    "IndexInfo",
+    "ConstraintInfo",
+]
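
For orientation, a minimal usage sketch of the public API exported above (not part of the diff). It mirrors how mcp_server.py drives SchemaSearch later in this release; the SQLite URL and query are illustrative, and it assumes the optional config/LLM arguments default to None:

from sqlalchemy import create_engine
from schema_search import SchemaSearch

engine = create_engine("sqlite:///example.db")  # illustrative database
search = SchemaSearch(engine)                   # config/LLM kwargs assumed optional
search.index()                                  # extract schemas, chunk, embed, build FK graph
result = search.search("tables related to payments", hops=1, limit=5)
print(result["results"], result["latency_sec"])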
--- /dev/null
+++ b/schema_search/chunkers/__init__.py
@@ -0,0 +1,6 @@
+from schema_search.chunkers.base import Chunk, BaseChunker
+from schema_search.chunkers.markdown import MarkdownChunker
+from schema_search.chunkers.llm import LLMChunker
+from schema_search.chunkers.factory import create_chunker
+
+__all__ = ["Chunk", "BaseChunker", "MarkdownChunker", "LLMChunker", "create_chunker"]
--- /dev/null
+++ b/schema_search/chunkers/base.py
@@ -0,0 +1,95 @@
+from typing import Dict, List
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+from tqdm import tqdm
+
+from schema_search.types import TableSchema
+
+
+@dataclass
+class Chunk:
+    table_name: str
+    content: str
+    chunk_id: int
+    token_count: int
+
+
+class BaseChunker(ABC):
+    def __init__(self, max_tokens: int, overlap_tokens: int, show_progress: bool = False):
+        self.max_tokens = max_tokens
+        self.overlap_tokens = overlap_tokens
+        self.show_progress = show_progress
+
+    def chunk_schemas(self, schemas: Dict[str, TableSchema]) -> List[Chunk]:
+        chunks: List[Chunk] = []
+        chunk_id = 0
+
+        iterator = schemas.items()
+        if self.show_progress:
+            iterator = tqdm(iterator, desc="Chunking tables", unit="table")
+
+        for table_name, schema in iterator:
+            table_chunks = self._chunk_table(table_name, schema, chunk_id)
+            chunks.extend(table_chunks)
+            chunk_id += len(table_chunks)
+
+        return chunks
+
+    @abstractmethod
+    def _generate_content(self, table_name: str, schema: TableSchema) -> str:
+        pass
+
+    def _chunk_table(
+        self, table_name: str, schema: TableSchema, start_id: int
+    ) -> List[Chunk]:
+        content = self._generate_content(table_name, schema)
+        lines = content.split("\n")
+
+        header = f"Table: {table_name}"
+        header_tokens = self._estimate_tokens(header)
+
+        chunks: List[Chunk] = []
+        current_chunk_lines = [header]
+        current_tokens = header_tokens
+        chunk_id = start_id
+
+        for line in lines[1:]:
+            line_tokens = self._estimate_tokens(line)
+
+            if (
+                current_tokens + line_tokens > self.max_tokens
+                and len(current_chunk_lines) > 1
+            ):
+                chunk_content = "\n".join(current_chunk_lines)
+                chunks.append(
+                    Chunk(
+                        table_name=table_name,
+                        content=chunk_content,
+                        chunk_id=chunk_id,
+                        token_count=current_tokens,
+                    )
+                )
+                chunk_id += 1
+
+                current_chunk_lines = [header]
+                current_tokens = header_tokens
+
+            current_chunk_lines.append(line)
+            current_tokens += line_tokens
+
+        if len(current_chunk_lines) > 1:
+            chunk_content = "\n".join(current_chunk_lines)
+            chunks.append(
+                Chunk(
+                    table_name=table_name,
+                    content=chunk_content,
+                    chunk_id=chunk_id,
+                    token_count=current_tokens,
+                )
+            )
+
+        return chunks
+
+    def _estimate_tokens(self, text: str) -> int:
+        return len(text.split()) + len(text) // 4
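
The `_estimate_tokens` heuristic above blends a word count with a characters/4 approximation. A quick illustrative check (not from the package):

text = "Columns: id, user_id, amount, created_at"
len(text.split()) + len(text) // 4  # 5 words + 40 chars // 4 = 15 estimated tokens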
--- /dev/null
+++ b/schema_search/chunkers/factory.py
@@ -0,0 +1,31 @@
+from typing import Dict, Optional
+
+from schema_search.chunkers.base import BaseChunker
+from schema_search.chunkers.markdown import MarkdownChunker
+from schema_search.chunkers.llm import LLMChunker
+
+
+def create_chunker(
+    config: Dict, llm_api_key: Optional[str], llm_base_url: Optional[str]
+) -> BaseChunker:
+    chunking_config = config["chunking"]
+    strategy = chunking_config["strategy"]
+    show_progress = config["embedding"].get("show_progress", False)
+
+    if strategy == "llm":
+        return LLMChunker(
+            max_tokens=chunking_config["max_tokens"],
+            overlap_tokens=chunking_config["overlap_tokens"],
+            model=chunking_config["model"],
+            llm_api_key=llm_api_key,
+            llm_base_url=llm_base_url,
+            show_progress=show_progress,
+        )
+    elif strategy == "raw":
+        return MarkdownChunker(
+            max_tokens=chunking_config["max_tokens"],
+            overlap_tokens=chunking_config["overlap_tokens"],
+            show_progress=show_progress,
+        )
+    else:
+        raise ValueError(f"Unknown chunking strategy: {strategy}")
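
The factory only reads a handful of keys. A hypothetical config shape that satisfies it, with the key names taken from the lookups above and the values purely illustrative:

config = {
    "chunking": {"strategy": "raw", "max_tokens": 256, "overlap_tokens": 32},
    "embedding": {"show_progress": False},
}
chunker = create_chunker(config, llm_api_key=None, llm_base_url=None)  # -> MarkdownChunker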
--- /dev/null
+++ b/schema_search/chunkers/llm.py
@@ -0,0 +1,51 @@
+import json
+import logging
+from typing import Optional
+
+from openai import OpenAI
+
+from schema_search.chunkers.base import BaseChunker
+from schema_search.types import TableSchema
+
+logger = logging.getLogger(__name__)
+
+
+class LLMChunker(BaseChunker):
+    def __init__(
+        self,
+        max_tokens: int,
+        overlap_tokens: int,
+        model: str,
+        llm_api_key: Optional[str],
+        llm_base_url: Optional[str],
+        show_progress: bool = False,
+    ):
+        super().__init__(max_tokens, overlap_tokens, show_progress)
+        self.model = model
+        self.llm_client = OpenAI(api_key=llm_api_key, base_url=llm_base_url)
+        logger.info(f"Schema Summarizer Model: {self.model}")
+
+    def _generate_content(self, table_name: str, schema: TableSchema) -> str:
+        prompt = f"""Generate a concise 250 tokens or less semantic summary of this database table schema. Focus on:
+1. What entity or concept this table represents
+2. Key data it stores (main columns)
+3. How it relates to other tables
+4. Any important constraints or indices
+
+Keep it brief and semantic, optimized for embedding-based search.
+
+Schema:
+{json.dumps(schema, indent=2)}
+
+Return ONLY the summary text, no preamble."""
+
+        response = self.llm_client.chat.completions.create(
+            model=self.model,
+            max_tokens=500,
+            messages=[{"role": "user", "content": prompt}],
+        )
+
+        summary = response.choices[0].message.content.strip()  # type: ignore
+        logger.debug(f"Generated LLM summary for {table_name}: {summary[:100]}...")
+
+        return f"Table: {table_name}\n{summary}"
--- /dev/null
+++ b/schema_search/chunkers/markdown.py
@@ -0,0 +1,25 @@
+from schema_search.chunkers.base import BaseChunker
+from schema_search.types import TableSchema
+
+
+class MarkdownChunker(BaseChunker):
+    def _generate_content(self, table_name: str, schema: TableSchema) -> str:
+        lines = [f"Table: {table_name}"]
+
+        if schema["primary_keys"]:
+            lines.append(f"Primary keys: {', '.join(schema['primary_keys'])}")
+
+        if schema["columns"]:
+            col_names = [col["name"] for col in schema["columns"]]
+            lines.append(f"Columns: {', '.join(col_names)}")
+
+        if schema["foreign_keys"]:
+            related = [fk["referred_table"] for fk in schema["foreign_keys"]]
+            lines.append(f"Related to: {', '.join(related)}")
+
+        if schema["indices"]:
+            idx_names = [idx["name"] for idx in schema["indices"] if idx["name"]]
+            if idx_names:
+                lines.append(f"Indexes: {', '.join(idx_names)}")
+
+        return "\n".join(lines)
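
As a sketch of the output (not part of the diff), a toy schema dict carrying only the keys `_generate_content` actually reads renders as follows; the table and column names are invented:

toy_schema = {
    "primary_keys": ["id"],
    "columns": [{"name": "id"}, {"name": "user_id"}, {"name": "amount"}],
    "foreign_keys": [{"referred_table": "users"}],
    "indices": [{"name": "ix_payments_user_id"}],
}
print(MarkdownChunker(max_tokens=256, overlap_tokens=0)._generate_content("payments", toy_schema))
# Table: payments
# Primary keys: id
# Columns: id, user_id, amount
# Related to: users
# Indexes: ix_payments_user_id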
--- /dev/null
+++ b/schema_search/embedding_cache/__init__.py
@@ -0,0 +1,5 @@
+from schema_search.embedding_cache.base import BaseEmbeddingCache
+from schema_search.embedding_cache.inmemory import InMemoryEmbeddingCache
+from schema_search.embedding_cache.factory import create_embedding_cache
+
+__all__ = ["BaseEmbeddingCache", "InMemoryEmbeddingCache", "create_embedding_cache"]
--- /dev/null
+++ b/schema_search/embedding_cache/base.py
@@ -0,0 +1,40 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Dict, List
+
+import numpy as np
+
+from schema_search.chunkers import Chunk
+
+
+class BaseEmbeddingCache(ABC):
+    def __init__(
+        self,
+        cache_dir: Path,
+        model_name: str,
+        metric: str,
+        batch_size: int,
+        show_progress: bool,
+    ):
+        self.cache_dir = cache_dir
+        self.cache_dir.mkdir(exist_ok=True)
+        self.model_name = model_name
+        self.model = None
+        self.metric = metric
+        self.batch_size = batch_size
+        self.show_progress = show_progress
+        self.embeddings = None
+
+    @abstractmethod
+    def load_or_generate(
+        self, chunks: List[Chunk], force: bool, chunking_config: Dict
+    ) -> None:
+        pass
+
+    @abstractmethod
+    def encode_query(self, query: str) -> np.ndarray:
+        pass
+
+    @abstractmethod
+    def compute_similarities(self, query_embedding: np.ndarray) -> np.ndarray:
+        pass
--- /dev/null
+++ b/schema_search/embedding_cache/bm25.py
@@ -0,0 +1,63 @@
+from typing import List
+import re
+import logging
+import numpy as np
+
+import bm25s
+
+from schema_search.chunkers import Chunk
+
+logging.getLogger("bm25s").setLevel(logging.WARNING)
+
+
+def light_stem(token: str) -> str:
+    """Tiny rule-based stemmer for schema tokens."""
+    for suf in ("ing", "ers", "ies", "ied", "ed", "es", "s"):
+        if token.endswith(suf) and len(token) > len(suf) + 2:
+            if suf == "ies":
+                return token[:-3] + "y"
+            return token[: -len(suf)]
+    return token
+
+
+def _tokenize(text: str) -> List[str]:
+    """Tokenize and normalize database-like text."""
+    text = text.lower()
+    text = text.replace("\n", " ")
+    text = re.sub(r"[_\-]+", " ", text)
+    text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
+    text = re.sub(r"([a-z])([0-9])", r"\1 \2", text)
+    text = re.sub(r"([0-9])([a-z])", r"\1 \2", text)
+
+    tokens = re.findall(r"[a-z0-9]+", text)
+    normalized = []
+    for t in tokens:
+        if t in {"pk", "pkey", "key"}:
+            t = "id"
+        elif t in {"ts", "time", "timestamp"}:
+            t = "timestamp"
+        elif t.endswith("id") and len(t) > 2:
+            t = "id"
+        elif t in {"ix", "index", "idx"}:
+            t = "index"
+        normalized.append(light_stem(t))
+    return normalized
+
+
+class BM25Cache:
+    def __init__(self):
+        self.bm25 = None
+        self.tokenized_docs = None
+
+    def build(self, chunks: List[Chunk]) -> None:
+        if self.bm25 is None:
+            self.tokenized_docs = [_tokenize(chunk.content) for chunk in chunks]
+            self.bm25 = bm25s.BM25()
+            self.bm25.index(self.tokenized_docs)
+
+    def get_scores(self, query: str) -> np.ndarray:
+        if self.bm25 is None or self.tokenized_docs is None:
+            raise RuntimeError("BM25 cache not built. Call build() first.")
+        query_tokens = _tokenize(query)
+        scores = self.bm25.get_scores(query_tokens)
+        return scores
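
Illustrative traces of `_tokenize` (not from the package): `pkey` normalizes to `id`, `ts` to `timestamp`, and `light_stem` folds plural/verb suffixes. Note that `_tokenize` lowercases before the camelCase regex runs, so that split is effectively a no-op on camelCase input.

_tokenize("order_pkey")      # -> ["order", "id"]
_tokenize("categories")      # -> ["category"]
_tokenize("customer_id ts")  # -> ["customer", "id", "timestamp"]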
--- /dev/null
+++ b/schema_search/embedding_cache/factory.py
@@ -0,0 +1,20 @@
+from pathlib import Path
+from typing import Dict
+
+from schema_search.embedding_cache.base import BaseEmbeddingCache
+from schema_search.embedding_cache.inmemory import InMemoryEmbeddingCache
+
+
+def create_embedding_cache(config: Dict, cache_dir: Path) -> BaseEmbeddingCache:
+    location = config["embedding"]["location"]
+
+    if location == "memory":
+        return InMemoryEmbeddingCache(
+            cache_dir=cache_dir,
+            model_name=config["embedding"]["model"],
+            metric=config["embedding"]["metric"],
+            batch_size=config["embedding"]["batch_size"],
+            show_progress=config["embedding"]["show_progress"],
+        )
+    else:
+        raise ValueError(f"Unsupported embedding location: {location}")
--- /dev/null
+++ b/schema_search/embedding_cache/inmemory.py
@@ -0,0 +1,112 @@
+import json
+import logging
+from pathlib import Path
+from typing import Dict, List
+
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+from schema_search.chunkers import Chunk
+from schema_search.embedding_cache.base import BaseEmbeddingCache
+from schema_search.metrics import get_metric
+
+logger = logging.getLogger(__name__)
+
+
+class InMemoryEmbeddingCache(BaseEmbeddingCache):
+    def __init__(
+        self,
+        cache_dir: Path,
+        model_name: str,
+        metric: str,
+        batch_size: int,
+        show_progress: bool,
+    ):
+        super().__init__(cache_dir, model_name, metric, batch_size, show_progress)
+        self.model: SentenceTransformer
+
+    def load_or_generate(
+        self, chunks: List[Chunk], force: bool, chunking_config: Dict
+    ) -> None:
+        cache_file = self.cache_dir / "embeddings.npz"
+        config_file = self.cache_dir / "cache_config.json"
+
+        if not force and self._is_cache_valid(cache_file, config_file, chunking_config):
+            self._load_from_cache(cache_file)
+        else:
+            self._generate_and_cache(chunks, cache_file, config_file, chunking_config)
+
+    def _load_from_cache(self, cache_file: Path) -> None:
+        logger.info("Loading embeddings from cache")
+        self.embeddings = np.load(cache_file)["embeddings"]
+
+    def _is_cache_valid(
+        self, cache_file: Path, config_file: Path, chunking_config: Dict
+    ) -> bool:
+        if not (cache_file.exists() and config_file.exists()):
+            return False
+
+        with open(config_file) as f:
+            cached_config = json.load(f)
+
+        current_config = {
+            "strategy": chunking_config["strategy"],
+            "max_tokens": chunking_config["max_tokens"],
+            "embedding_model": self.model_name,
+        }
+
+        if cached_config != current_config:
+            logger.info("Cache invalidated: chunking config changed")
+            return False
+
+        return True
+
+    def _generate_and_cache(
+        self,
+        chunks: List[Chunk],
+        cache_file: Path,
+        config_file: Path,
+        chunking_config: Dict,
+    ) -> None:
+        self._load_model()
+
+        logger.info(f"Generating embeddings for {len(chunks)} chunks")
+        texts = [chunk.content for chunk in chunks]
+
+        self.embeddings = self.model.encode(
+            texts,
+            batch_size=self.batch_size,
+            normalize_embeddings=True,
+            show_progress_bar=self.show_progress,
+        )
+
+        np.savez_compressed(cache_file, embeddings=self.embeddings)
+
+        cache_config = {
+            "strategy": chunking_config["strategy"],
+            "max_tokens": chunking_config["max_tokens"],
+            "embedding_model": self.model_name,
+        }
+        with open(config_file, "w") as f:
+            json.dump(cache_config, f, indent=2)
+
+    def _load_model(self) -> None:
+        if self.model is None:
+            logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
+            self.model = SentenceTransformer(self.model_name)
+            logger.info(f"Loaded embedding model: {self.model_name}")
+
+    def encode_query(self, query: str) -> np.ndarray:
+        self._load_model()
+
+        query_emb = self.model.encode(
+            [query],
+            batch_size=self.batch_size,
+            normalize_embeddings=True,
+        )
+
+        return query_emb
+
+    def compute_similarities(self, query_embedding: np.ndarray) -> np.ndarray:
+        metric_fn = get_metric(self.metric)
+        return metric_fn(self.embeddings, query_embedding).flatten()
--- /dev/null
+++ b/schema_search/graph_builder.py
@@ -0,0 +1,69 @@
+import logging
+import pickle
+from pathlib import Path
+from typing import Dict, Set
+
+import networkx as nx
+
+from schema_search.types import TableSchema
+
+logger = logging.getLogger(__name__)
+
+
+class GraphBuilder:
+    def __init__(self, cache_dir: Path):
+        self.cache_dir = cache_dir
+        self.cache_dir.mkdir(exist_ok=True)
+        self.graph: nx.DiGraph
+
+    def build(self, schemas: Dict[str, TableSchema], force: bool) -> None:
+        cache_file = self.cache_dir / "graph.pkl"
+
+        if not force and cache_file.exists():
+            self._load_from_cache(cache_file)
+        else:
+            self._build_and_cache(schemas, cache_file)
+
+    def _load_from_cache(self, cache_file: Path) -> None:
+        logger.debug(f"Loading graph from cache: {cache_file}")
+        with open(cache_file, "rb") as f:
+            self.graph = pickle.load(f)
+
+    def _build_and_cache(
+        self, schemas: Dict[str, TableSchema], cache_file: Path
+    ) -> None:
+        logger.info("Building foreign key relationship graph")
+        self.graph = nx.DiGraph()
+
+        for table_name, schema in schemas.items():
+            self.graph.add_node(table_name, **schema)
+
+        for table_name, schema in schemas.items():
+            if schema["foreign_keys"]:
+                for fk in schema["foreign_keys"]:
+                    referred_table = fk["referred_table"]
+                    if referred_table in self.graph:
+                        self.graph.add_edge(table_name, referred_table, **fk)
+
+        with open(cache_file, "wb") as f:
+            pickle.dump(self.graph, f)
+
+    def get_neighbors(self, table_name: str, hops: int) -> Set[str]:
+        if table_name not in self.graph:
+            return set()
+
+        neighbors: Set[str] = set()
+
+        forward = nx.single_source_shortest_path_length(
+            self.graph, table_name, cutoff=hops
+        )
+        neighbors.update(forward.keys())
+
+        backward = nx.single_source_shortest_path_length(
+            self.graph.reverse(), table_name, cutoff=hops
+        )
+        neighbors.update(backward.keys())
+
+        neighbors.discard(table_name)
+
+        return neighbors
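
A hypothetical sketch of the bidirectional expansion (not part of the diff). Table names and the cache path are illustrative; only the `foreign_keys` key is read during build:

from pathlib import Path

schemas = {
    "accounts": {"foreign_keys": []},
    "users": {"foreign_keys": [{"referred_table": "accounts"}]},
    "orders": {"foreign_keys": [{"referred_table": "users"}]},
}
builder = GraphBuilder(Path("/tmp/schema_graph_cache"))
builder.build(schemas, force=True)
builder.get_neighbors("users", hops=1)  # -> {"accounts", "orders"}: FK edges in both directions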
--- /dev/null
+++ b/schema_search/mcp_server.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+import logging
+from typing import Optional
+
+from fastmcp import FastMCP
+from sqlalchemy import create_engine
+
+from schema_search import SchemaSearch
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+mcp = FastMCP("schema-search")
+
+
+@mcp.tool()
+def schema_search(
+    query: str,
+    hops: Optional[int] = None,
+    limit: int = 5,
+) -> dict:
+    """Search database schema using natural language.
+
+    Finds relevant database tables and their relationships by searching through schema metadata
+    using semantic similarity. Expands results by traversing foreign key relationships.
+
+    Args:
+        query: Natural language question about database schema (e.g., 'where are user refunds stored?', 'tables related to payments')
+        hops: Number of foreign key relationship hops for graph expansion. Use 0 for exact matches only, 1-2 to include related tables. If not specified, uses value from config.yml (default: 1)
+        limit: Maximum number of table schemas to return in results. Default: 5
+
+    Returns:
+        Dictionary with 'results' (list of table schemas with columns, types, constraints, and relationships) and 'latency_sec' (query execution time)
+    """
+    search_result = mcp.search_engine.search(query, hops=hops, limit=limit)  # type: ignore
+    return {
+        "results": search_result["results"],
+        "latency_sec": search_result["latency_sec"],
+    }
+
+
+def run_server(
+    database_url: str,
+    config_path: Optional[str] = None,
+    llm_api_key: Optional[str] = None,
+    llm_base_url: Optional[str] = None,
+):
+    engine = create_engine(database_url)
+
+    mcp.search_engine = SchemaSearch(  # type: ignore
+        engine,
+        config_path=config_path,
+        llm_api_key=llm_api_key,
+        llm_base_url=llm_base_url,
+    )
+
+    logger.info("Indexing database schema...")
+    mcp.search_engine.index()  # type: ignore
+    logger.info("Index ready")
+
+    mcp.run()
+
+
+def main():
+    import sys
+
+    if len(sys.argv) < 2:
+        print(
+            "Usage: schema-search-mcp <database_url> [config_path] [llm_api_key] [llm_base_url]"
+        )
+        sys.exit(1)
+
+    database_url = sys.argv[1]
+    config_path = sys.argv[2] if len(sys.argv) > 2 else None
+    llm_api_key = sys.argv[3] if len(sys.argv) > 3 else None
+    llm_base_url = sys.argv[4] if len(sys.argv) > 4 else None
+
+    run_server(database_url, config_path, llm_api_key, llm_base_url)
+
+
+if __name__ == "__main__":
+    main()
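
Judging from the usage string above, the release exposes a `schema-search-mcp` console script; equivalently the server can be started in-process. A hedged sketch (the Postgres URL is illustrative):

# Console script:
#   schema-search-mcp postgresql://user:pass@localhost:5432/shop
# Or in-process:
from schema_search.mcp_server import run_server

run_server("postgresql://user:pass@localhost:5432/shop")  # indexes the schema, then serves MCP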
--- /dev/null
+++ b/schema_search/metrics.py
@@ -0,0 +1,33 @@
+import numpy as np
+
+
+def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+    a_norm = a / (np.linalg.norm(a, axis=-1, keepdims=True) + 1e-8)
+    b_norm = b / (np.linalg.norm(b, axis=-1, keepdims=True) + 1e-8)
+    return a_norm @ b_norm.T
+
+
+def dot_product(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+    return a @ b.T
+
+
+def euclidean_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+    return -np.linalg.norm(a[:, None] - b[None, :], axis=-1)
+
+
+def manhattan_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+    return -np.sum(np.abs(a[:, None] - b[None, :]), axis=-1)
+
+
+METRICS = {
+    "cosine": cosine_similarity,
+    "dot": dot_product,
+    "euclidean": euclidean_distance,
+    "manhattan": manhattan_distance,
+}
+
+
+def get_metric(name: str):
+    if name not in METRICS:
+        raise ValueError(f"Unknown metric: {name}. Available: {list(METRICS.keys())}")
+    return METRICS[name]
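
All four metrics return higher-is-better scores, since the two distances are negated above. A quick illustrative check (not from the package):

import numpy as np
from schema_search.metrics import get_metric

docs = np.array([[1.0, 0.0], [0.0, 1.0]])  # two unit document embeddings
query = np.array([[1.0, 0.0]])             # one query embedding
get_metric("cosine")(docs, query)          # ~array([[1.0], [0.0]])
get_metric("euclidean")(docs, query)       # -> array([[-0.0], [-1.414...]])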
--- /dev/null
+++ b/schema_search/rankers/__init__.py
@@ -0,0 +1,5 @@
+from schema_search.rankers.base import BaseRanker
+from schema_search.rankers.cross_encoder import CrossEncoderRanker
+from schema_search.rankers.factory import create_ranker
+
+__all__ = ["BaseRanker", "CrossEncoderRanker", "create_ranker"]