schema_search-0.1.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. schema_search/__init__.py +26 -0
  2. schema_search/chunkers/__init__.py +6 -0
  3. schema_search/chunkers/base.py +95 -0
  4. schema_search/chunkers/factory.py +31 -0
  5. schema_search/chunkers/llm.py +54 -0
  6. schema_search/chunkers/markdown.py +25 -0
  7. schema_search/embedding_cache/__init__.py +5 -0
  8. schema_search/embedding_cache/base.py +40 -0
  9. schema_search/embedding_cache/bm25.py +63 -0
  10. schema_search/embedding_cache/factory.py +20 -0
  11. schema_search/embedding_cache/inmemory.py +122 -0
  12. schema_search/graph_builder.py +69 -0
  13. schema_search/mcp_server.py +81 -0
  14. schema_search/metrics.py +33 -0
  15. schema_search/rankers/__init__.py +5 -0
  16. schema_search/rankers/base.py +45 -0
  17. schema_search/rankers/cross_encoder.py +40 -0
  18. schema_search/rankers/factory.py +11 -0
  19. schema_search/schema_extractor.py +135 -0
  20. schema_search/schema_search.py +276 -0
  21. schema_search/search/__init__.py +15 -0
  22. schema_search/search/base.py +85 -0
  23. schema_search/search/bm25.py +48 -0
  24. schema_search/search/factory.py +61 -0
  25. schema_search/search/fuzzy.py +56 -0
  26. schema_search/search/hybrid.py +82 -0
  27. schema_search/search/semantic.py +49 -0
  28. schema_search/types.py +57 -0
  29. schema_search/utils/__init__.py +0 -0
  30. schema_search/utils/lazy_import.py +26 -0
  31. schema_search-0.1.10.dist-info/METADATA +308 -0
  32. schema_search-0.1.10.dist-info/RECORD +40 -0
  33. schema_search-0.1.10.dist-info/WHEEL +5 -0
  34. schema_search-0.1.10.dist-info/entry_points.txt +2 -0
  35. schema_search-0.1.10.dist-info/licenses/LICENSE +21 -0
  36. schema_search-0.1.10.dist-info/top_level.txt +2 -0
  37. tests/__init__.py +0 -0
  38. tests/test_integration.py +352 -0
  39. tests/test_llm_sql_generation.py +320 -0
  40. tests/test_spider_eval.py +488 -0
schema_search/__init__.py
@@ -0,0 +1,26 @@
+ from schema_search.schema_search import SchemaSearch
+ from schema_search.types import (
+     IndexResult,
+     SearchResult,
+     SearchResultItem,
+     SearchType,
+     TableSchema,
+     ColumnInfo,
+     ForeignKeyInfo,
+     IndexInfo,
+     ConstraintInfo,
+ )
+
+ __version__ = "0.1.0"
+ __all__ = [
+     "SchemaSearch",
+     "IndexResult",
+     "SearchResult",
+     "SearchResultItem",
+     "SearchType",
+     "TableSchema",
+     "ColumnInfo",
+     "ForeignKeyInfo",
+     "IndexInfo",
+     "ConstraintInfo",
+ ]
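
For orientation, a minimal usage sketch of the public API, mirroring how `mcp_server.py` (further down in this diff) wires it up. The SQLite URL is a placeholder, and passing only the engine assumes the optional arguments (`config_path`, `llm_api_key`, `llm_base_url`) default sensibly:

```python
# Minimal sketch; the database URL is a placeholder.
from sqlalchemy import create_engine
from schema_search import SchemaSearch

engine = create_engine("sqlite:///example.db")
search = SchemaSearch(engine)  # config_path / llm_api_key / llm_base_url are optional
search.index()                 # extract schemas, chunk, embed, build the FK graph

result = search.search("tables related to payments", limit=5)
print(result["results"], result["latency_sec"])
```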
schema_search/chunkers/__init__.py
@@ -0,0 +1,6 @@
+ from schema_search.chunkers.base import Chunk, BaseChunker
+ from schema_search.chunkers.markdown import MarkdownChunker
+ from schema_search.chunkers.llm import LLMChunker
+ from schema_search.chunkers.factory import create_chunker
+
+ __all__ = ["Chunk", "BaseChunker", "MarkdownChunker", "LLMChunker", "create_chunker"]
schema_search/chunkers/base.py
@@ -0,0 +1,95 @@
+ from typing import Dict, List
+ from dataclasses import dataclass
+ from abc import ABC, abstractmethod
+
+ from tqdm import tqdm
+
+ from schema_search.types import TableSchema
+
+
+ @dataclass
+ class Chunk:
+     table_name: str
+     content: str
+     chunk_id: int
+     token_count: int
+
+
+ class BaseChunker(ABC):
+     def __init__(self, max_tokens: int, overlap_tokens: int, show_progress: bool = False):
+         self.max_tokens = max_tokens
+         self.overlap_tokens = overlap_tokens
+         self.show_progress = show_progress
+
+     def chunk_schemas(self, schemas: Dict[str, TableSchema]) -> List[Chunk]:
+         chunks: List[Chunk] = []
+         chunk_id = 0
+
+         iterator = schemas.items()
+         if self.show_progress:
+             iterator = tqdm(iterator, desc="Chunking tables", unit="table")
+
+         for table_name, schema in iterator:
+             table_chunks = self._chunk_table(table_name, schema, chunk_id)
+             chunks.extend(table_chunks)
+             chunk_id += len(table_chunks)
+
+         return chunks
+
+     @abstractmethod
+     def _generate_content(self, table_name: str, schema: TableSchema) -> str:
+         pass
+
+     def _chunk_table(
+         self, table_name: str, schema: TableSchema, start_id: int
+     ) -> List[Chunk]:
+         content = self._generate_content(table_name, schema)
+         lines = content.split("\n")
+
+         header = f"Table: {table_name}"
+         header_tokens = self._estimate_tokens(header)
+
+         chunks: List[Chunk] = []
+         current_chunk_lines = [header]
+         current_tokens = header_tokens
+         chunk_id = start_id
+
+         for line in lines[1:]:
+             line_tokens = self._estimate_tokens(line)
+
+             if (
+                 current_tokens + line_tokens > self.max_tokens
+                 and len(current_chunk_lines) > 1
+             ):
+                 chunk_content = "\n".join(current_chunk_lines)
+                 chunks.append(
+                     Chunk(
+                         table_name=table_name,
+                         content=chunk_content,
+                         chunk_id=chunk_id,
+                         token_count=current_tokens,
+                     )
+                 )
+                 chunk_id += 1
+
+                 current_chunk_lines = [header]
+                 current_tokens = header_tokens
+
+             current_chunk_lines.append(line)
+             current_tokens += line_tokens
+
+         if len(current_chunk_lines) > 1:
+             chunk_content = "\n".join(current_chunk_lines)
+             chunks.append(
+                 Chunk(
+                     table_name=table_name,
+                     content=chunk_content,
+                     chunk_id=chunk_id,
+                     token_count=current_tokens,
+                 )
+             )
+
+         return chunks
+
+     def _estimate_tokens(self, text: str) -> int:
+         return len(text.split()) + len(text) // 4
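
Note that `_chunk_table` re-emits the `Table: <name>` header at the top of every chunk so each piece stays self-describing, and `_estimate_tokens` is a cheap heuristic (whitespace word count plus one token per four characters) rather than a real tokenizer. A standalone check of the arithmetic:

```python
# Standalone replica of the _estimate_tokens heuristic above.
def estimate_tokens(text: str) -> int:
    return len(text.split()) + len(text) // 4

line = "Columns: id, email, created_at"
# 4 whitespace-separated words + 30 characters // 4 = 4 + 7
print(estimate_tokens(line))  # 11
```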
schema_search/chunkers/factory.py
@@ -0,0 +1,31 @@
+ from typing import Dict, Optional
+
+ from schema_search.chunkers.base import BaseChunker
+ from schema_search.chunkers.markdown import MarkdownChunker
+ from schema_search.chunkers.llm import LLMChunker
+
+
+ def create_chunker(
+     config: Dict, llm_api_key: Optional[str], llm_base_url: Optional[str]
+ ) -> BaseChunker:
+     chunking_config = config["chunking"]
+     strategy = chunking_config["strategy"]
+     show_progress = config["embedding"].get("show_progress", False)
+
+     if strategy == "llm":
+         return LLMChunker(
+             max_tokens=chunking_config["max_tokens"],
+             overlap_tokens=chunking_config["overlap_tokens"],
+             model=chunking_config["model"],
+             llm_api_key=llm_api_key,
+             llm_base_url=llm_base_url,
+             show_progress=show_progress,
+         )
+     elif strategy == "raw":
+         return MarkdownChunker(
+             max_tokens=chunking_config["max_tokens"],
+             overlap_tokens=chunking_config["overlap_tokens"],
+             show_progress=show_progress,
+         )
+     else:
+         raise ValueError(f"Unknown chunking strategy: {strategy}")
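
The factory expects a nested config dict; its source (presumably a file loaded via `config_path`) is not part of this diff, but the keys read above imply roughly this shape. The values below are illustrative guesses, not the package's shipped defaults:

```python
# Illustrative config; values are guesses, only the key names are taken from the code.
config = {
    "chunking": {
        "strategy": "raw",       # "raw" -> MarkdownChunker, "llm" -> LLMChunker
        "max_tokens": 256,
        "overlap_tokens": 32,
        "model": "gpt-4o-mini",  # consulted only for the "llm" strategy
    },
    "embedding": {
        "show_progress": False,
    },
}

chunker = create_chunker(config, llm_api_key=None, llm_base_url=None)
```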
schema_search/chunkers/llm.py
@@ -0,0 +1,54 @@
+ import json
+ import logging
+ from typing import Optional, TYPE_CHECKING
+
+ from schema_search.chunkers.base import BaseChunker
+ from schema_search.types import TableSchema
+ from schema_search.utils.lazy_import import lazy_import_check
+
+ if TYPE_CHECKING:
+     from openai import OpenAI
+
+ logger = logging.getLogger(__name__)
+
+
+ class LLMChunker(BaseChunker):
+     def __init__(
+         self,
+         max_tokens: int,
+         overlap_tokens: int,
+         model: str,
+         llm_api_key: Optional[str],
+         llm_base_url: Optional[str],
+         show_progress: bool = False,
+     ):
+         super().__init__(max_tokens, overlap_tokens, show_progress)
+         self.model = model
+         openai = lazy_import_check("openai", "llm", "LLM-based chunking")
+         self.llm_client: "OpenAI" = openai.OpenAI(api_key=llm_api_key, base_url=llm_base_url)
+         logger.info(f"Schema Summarizer Model: {self.model}")
+
+     def _generate_content(self, table_name: str, schema: TableSchema) -> str:
+         prompt = f"""Generate a concise 250 tokens or less semantic summary of this database table schema. Focus on:
+ 1. What entity or concept this table represents
+ 2. Key data it stores (main columns)
+ 3. How it relates to other tables
+ 4. Any important constraints or indices
+
+ Keep it brief and semantic, optimized for embedding-based search.
+
+ Schema:
+ {json.dumps(schema, indent=2)}
+
+ Return ONLY the summary text, no preamble."""
+
+         response = self.llm_client.chat.completions.create(
+             model=self.model,
+             max_tokens=500,
+             messages=[{"role": "user", "content": prompt}],
+         )
+
+         summary = response.choices[0].message.content.strip()  # type: ignore
+         logger.debug(f"Generated LLM summary for {table_name}: {summary[:100]}...")
+
+         return f"Table: {table_name}\n{summary}"
schema_search/chunkers/markdown.py
@@ -0,0 +1,25 @@
+ from schema_search.chunkers.base import BaseChunker
+ from schema_search.types import TableSchema
+
+
+ class MarkdownChunker(BaseChunker):
+     def _generate_content(self, table_name: str, schema: TableSchema) -> str:
+         lines = [f"Table: {table_name}"]
+
+         if schema["primary_keys"]:
+             lines.append(f"Primary keys: {', '.join(schema['primary_keys'])}")
+
+         if schema["columns"]:
+             col_names = [col["name"] for col in schema["columns"]]
+             lines.append(f"Columns: {', '.join(col_names)}")
+
+         if schema["foreign_keys"]:
+             related = [fk["referred_table"] for fk in schema["foreign_keys"]]
+             lines.append(f"Related to: {', '.join(related)}")
+
+         if schema["indices"]:
+             idx_names = [idx["name"] for idx in schema["indices"] if idx["name"]]
+             if idx_names:
+                 lines.append(f"Indexes: {', '.join(idx_names)}")
+
+         return "\n".join(lines)
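
For a concrete sense of what gets embedded, this is the kind of content `_generate_content` produces. The schema dict below is hypothetical; the real `TableSchema` shape comes from `schema_search/types.py`:

```python
# Hypothetical input and the resulting chunk text (shown as comments).
schema = {
    "primary_keys": ["id"],
    "columns": [{"name": "id"}, {"name": "customer_id"}, {"name": "total"}],
    "foreign_keys": [{"referred_table": "customers"}],
    "indices": [{"name": "ix_orders_customer_id"}],
}
# _generate_content("orders", schema) returns:
#   Table: orders
#   Primary keys: id
#   Columns: id, customer_id, total
#   Related to: customers
#   Indexes: ix_orders_customer_id
```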
schema_search/embedding_cache/__init__.py
@@ -0,0 +1,5 @@
+ from schema_search.embedding_cache.base import BaseEmbeddingCache
+ from schema_search.embedding_cache.inmemory import InMemoryEmbeddingCache
+ from schema_search.embedding_cache.factory import create_embedding_cache
+
+ __all__ = ["BaseEmbeddingCache", "InMemoryEmbeddingCache", "create_embedding_cache"]
schema_search/embedding_cache/base.py
@@ -0,0 +1,40 @@
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from typing import Dict, List
+
+ import numpy as np
+
+ from schema_search.chunkers import Chunk
+
+
+ class BaseEmbeddingCache(ABC):
+     def __init__(
+         self,
+         cache_dir: Path,
+         model_name: str,
+         metric: str,
+         batch_size: int,
+         show_progress: bool,
+     ):
+         self.cache_dir = cache_dir
+         self.cache_dir.mkdir(exist_ok=True)
+         self.model_name = model_name
+         self.model = None
+         self.metric = metric
+         self.batch_size = batch_size
+         self.show_progress = show_progress
+         self.embeddings = None
+
+     @abstractmethod
+     def load_or_generate(
+         self, chunks: List[Chunk], force: bool, chunking_config: Dict
+     ) -> None:
+         pass
+
+     @abstractmethod
+     def encode_query(self, query: str) -> np.ndarray:
+         pass
+
+     @abstractmethod
+     def compute_similarities(self, query_embedding: np.ndarray) -> np.ndarray:
+         pass
schema_search/embedding_cache/bm25.py
@@ -0,0 +1,63 @@
+ from typing import List
+ import re
+ import logging
+ import numpy as np
+
+ import bm25s
+
+ from schema_search.chunkers import Chunk
+
+ logging.getLogger("bm25s").setLevel(logging.WARNING)
+
+
+ def light_stem(token: str) -> str:
+     """Tiny rule-based stemmer for schema tokens."""
+     for suf in ("ing", "ers", "ies", "ied", "ed", "es", "s"):
+         if token.endswith(suf) and len(token) > len(suf) + 2:
+             if suf == "ies":
+                 return token[:-3] + "y"
+             return token[: -len(suf)]
+     return token
+
+
+ def _tokenize(text: str) -> List[str]:
+     """Tokenize and normalize database-like text."""
+     text = text.lower()
+     text = text.replace("\n", " ")
+     text = re.sub(r"[_\-]+", " ", text)
+     text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
+     text = re.sub(r"([a-z])([0-9])", r"\1 \2", text)
+     text = re.sub(r"([0-9])([a-z])", r"\1 \2", text)
+
+     tokens = re.findall(r"[a-z0-9]+", text)
+     normalized = []
+     for t in tokens:
+         if t in {"pk", "pkey", "key"}:
+             t = "id"
+         elif t in {"ts", "time", "timestamp"}:
+             t = "timestamp"
+         elif t.endswith("id") and len(t) > 2:
+             t = "id"
+         elif t in {"ix", "index", "idx"}:
+             t = "index"
+         normalized.append(light_stem(t))
+     return normalized
+
+
+ class BM25Cache:
+     def __init__(self):
+         self.bm25 = None
+         self.tokenized_docs = None
+
+     def build(self, chunks: List[Chunk]) -> None:
+         if self.bm25 is None:
+             self.tokenized_docs = [_tokenize(chunk.content) for chunk in chunks]
+             self.bm25 = bm25s.BM25()
+             self.bm25.index(self.tokenized_docs)
+
+     def get_scores(self, query: str) -> np.ndarray:
+         if self.bm25 is None or self.tokenized_docs is None:
+             raise RuntimeError("BM25 cache not built. Call build() first.")
+         query_tokens = _tokenize(query)
+         scores = self.bm25.get_scores(query_tokens)
+         return scores
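
The tokenizer is tuned for schema text: snake_case is split, key/timestamp abbreviations collapse onto canonical tokens, and the suffix stemmer folds plurals and participles together. Hand-tracing `_tokenize` (a module-private helper, imported here only for illustration) on a made-up string:

```python
# Hand-trace on a made-up schema string; _tokenize is module-private.
from schema_search.embedding_cache.bm25 import _tokenize

print(_tokenize("order_items pk created_ts"))
# -> ['order', 'item', 'id', 'creat', 'timestamp']
#    items -> item (stemmed), pk -> id, ts -> timestamp, created -> creat
```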
schema_search/embedding_cache/factory.py
@@ -0,0 +1,20 @@
+ from pathlib import Path
+ from typing import Dict
+
+ from schema_search.embedding_cache.base import BaseEmbeddingCache
+ from schema_search.embedding_cache.inmemory import InMemoryEmbeddingCache
+
+
+ def create_embedding_cache(config: Dict, cache_dir: Path) -> BaseEmbeddingCache:
+     location = config["embedding"]["location"]
+
+     if location == "memory":
+         return InMemoryEmbeddingCache(
+             cache_dir=cache_dir,
+             model_name=config["embedding"]["model"],
+             metric=config["embedding"]["metric"],
+             batch_size=config["embedding"]["batch_size"],
+             show_progress=config["embedding"]["show_progress"],
+         )
+     else:
+         raise ValueError(f"Unsupported embedding location: {location}")
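
As with the chunker factory, the embedding-side keys can be read off the accesses above. The values here are illustrative, not shipped defaults, and the cache directory name is also made up:

```python
# Illustrative embedding config; only the key names are taken from the code.
from pathlib import Path

config = {
    "embedding": {
        "location": "memory",         # only "memory" is supported in this version
        "model": "all-MiniLM-L6-v2",  # any sentence-transformers model name
        "metric": "cosine",           # resolved via schema_search/metrics.py
        "batch_size": 32,
        "show_progress": False,
    }
}

cache = create_embedding_cache(config, cache_dir=Path(".schema_search_cache"))
```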
schema_search/embedding_cache/inmemory.py
@@ -0,0 +1,122 @@
+ import json
+ import logging
+ from pathlib import Path
+ from typing import Dict, List, Optional, TYPE_CHECKING
+
+ import numpy as np
+
+ from schema_search.chunkers import Chunk
+ from schema_search.embedding_cache.base import BaseEmbeddingCache
+ from schema_search.metrics import get_metric
+ from schema_search.utils.lazy_import import lazy_import_check
+
+ if TYPE_CHECKING:
+     from sentence_transformers import SentenceTransformer
+
+ logger = logging.getLogger(__name__)
+
+
+ class InMemoryEmbeddingCache(BaseEmbeddingCache):
+     def __init__(
+         self,
+         cache_dir: Path,
+         model_name: str,
+         metric: str,
+         batch_size: int,
+         show_progress: bool,
+     ):
+         super().__init__(cache_dir, model_name, metric, batch_size, show_progress)
+         self.model: Optional["SentenceTransformer"] = None
+
+     def load_or_generate(
+         self, chunks: List[Chunk], force: bool, chunking_config: Dict
+     ) -> None:
+         cache_file = self.cache_dir / "embeddings.npz"
+         config_file = self.cache_dir / "cache_config.json"
+
+         if not force and self._is_cache_valid(cache_file, config_file, chunking_config):
+             self._load_from_cache(cache_file)
+         else:
+             self._generate_and_cache(chunks, cache_file, config_file, chunking_config)
+
+     def _load_from_cache(self, cache_file: Path) -> None:
+         logger.info("Loading embeddings from cache")
+         self.embeddings = np.load(cache_file)["embeddings"]
+
+     def _is_cache_valid(
+         self, cache_file: Path, config_file: Path, chunking_config: Dict
+     ) -> bool:
+         if not (cache_file.exists() and config_file.exists()):
+             return False
+
+         with open(config_file) as f:
+             cached_config = json.load(f)
+
+         current_config = {
+             "strategy": chunking_config["strategy"],
+             "max_tokens": chunking_config["max_tokens"],
+             "embedding_model": self.model_name,
+         }
+
+         if cached_config != current_config:
+             logger.info("Cache invalidated: chunking config changed")
+             return False
+
+         return True
+
+     def _generate_and_cache(
+         self,
+         chunks: List[Chunk],
+         cache_file: Path,
+         config_file: Path,
+         chunking_config: Dict,
+     ) -> None:
+         self._load_model()
+
+         logger.info(f"Generating embeddings for {len(chunks)} chunks")
+         texts = [chunk.content for chunk in chunks]
+
+         assert self.model is not None
+         self.embeddings = self.model.encode(
+             texts,
+             batch_size=self.batch_size,
+             normalize_embeddings=True,
+             show_progress_bar=self.show_progress,
+         )
+
+         np.savez_compressed(cache_file, embeddings=self.embeddings)
+
+         cache_config = {
+             "strategy": chunking_config["strategy"],
+             "max_tokens": chunking_config["max_tokens"],
+             "embedding_model": self.model_name,
+         }
+         with open(config_file, "w") as f:
+             json.dump(cache_config, f, indent=2)
+
+     def _load_model(self) -> None:
+         if self.model is None:
+             sentence_transformers = lazy_import_check(
+                 "sentence_transformers",
+                 "semantic",
+                 "semantic/hybrid search or reranking",
+             )
+             logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
+             self.model = sentence_transformers.SentenceTransformer(self.model_name)
+             logger.info(f"Loaded embedding model: {self.model_name}")
+
+     def encode_query(self, query: str) -> np.ndarray:
+         self._load_model()
+
+         assert self.model is not None
+         query_emb = self.model.encode(
+             [query],
+             batch_size=self.batch_size,
+             normalize_embeddings=True,
+         )
+
+         return query_emb
+
+     def compute_similarities(self, query_embedding: np.ndarray) -> np.ndarray:
+         metric_fn = get_metric(self.metric)
+         return metric_fn(self.embeddings, query_embedding).flatten()
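
Worth noting: the cache key covers only the chunking strategy, `max_tokens`, and the embedding model name, so changing `overlap_tokens` or the similarity metric alone reuses cached vectors. The `cache_config.json` written beside `embeddings.npz` holds exactly this dict (values illustrative):

```python
# Shape of cache_config.json as written by _generate_and_cache; values illustrative.
cache_config = {
    "strategy": "raw",
    "max_tokens": 256,
    "embedding_model": "all-MiniLM-L6-v2",
}
```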
schema_search/graph_builder.py
@@ -0,0 +1,69 @@
+ import logging
+ import pickle
+ from pathlib import Path
+ from typing import Dict, Set
+
+ import networkx as nx
+
+ from schema_search.types import TableSchema
+
+ logger = logging.getLogger(__name__)
+
+
+ class GraphBuilder:
+     def __init__(self, cache_dir: Path):
+         self.cache_dir = cache_dir
+         self.cache_dir.mkdir(exist_ok=True)
+         self.graph: nx.DiGraph
+
+     def build(self, schemas: Dict[str, TableSchema], force: bool) -> None:
+         cache_file = self.cache_dir / "graph.pkl"
+
+         if not force and cache_file.exists():
+             self._load_from_cache(cache_file)
+         else:
+             self._build_and_cache(schemas, cache_file)
+
+     def _load_from_cache(self, cache_file: Path) -> None:
+         logger.debug(f"Loading graph from cache: {cache_file}")
+         with open(cache_file, "rb") as f:
+             self.graph = pickle.load(f)
+
+     def _build_and_cache(
+         self, schemas: Dict[str, TableSchema], cache_file: Path
+     ) -> None:
+         logger.info("Building foreign key relationship graph")
+         self.graph = nx.DiGraph()
+
+         for table_name, schema in schemas.items():
+             self.graph.add_node(table_name, **schema)
+
+         for table_name, schema in schemas.items():
+             if schema["foreign_keys"]:
+                 for fk in schema["foreign_keys"]:
+                     referred_table = fk["referred_table"]
+                     if referred_table in self.graph:
+                         self.graph.add_edge(table_name, referred_table, **fk)
+
+         with open(cache_file, "wb") as f:
+             pickle.dump(self.graph, f)
+
+     def get_neighbors(self, table_name: str, hops: int) -> Set[str]:
+         if table_name not in self.graph:
+             return set()
+
+         neighbors: Set[str] = set()
+
+         forward = nx.single_source_shortest_path_length(
+             self.graph, table_name, cutoff=hops
+         )
+         neighbors.update(forward.keys())
+
+         backward = nx.single_source_shortest_path_length(
+             self.graph.reverse(), table_name, cutoff=hops
+         )
+         neighbors.update(backward.keys())
+
+         neighbors.discard(table_name)
+
+         return neighbors
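
`get_neighbors` unions a forward BFS (tables this one references via foreign keys) with a backward BFS on the reversed graph (tables referencing it), so related tables are found regardless of which side declares the key. A self-contained sketch of the same traversal on a toy graph with invented table names:

```python
# Toy reproduction of the two-direction neighbor expansion; table names invented.
import networkx as nx

g = nx.DiGraph()
g.add_edges_from([("orders", "customers"), ("order_items", "orders")])

hops = 1
forward = nx.single_source_shortest_path_length(g, "orders", cutoff=hops)
backward = nx.single_source_shortest_path_length(g.reverse(), "orders", cutoff=hops)

neighbors = (set(forward) | set(backward)) - {"orders"}
print(neighbors)  # {'customers', 'order_items'}
```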
schema_search/mcp_server.py
@@ -0,0 +1,81 @@
+ #!/usr/bin/env python3
+ import logging
+ from typing import Optional
+
+ from fastmcp import FastMCP
+ from sqlalchemy import create_engine
+
+ from schema_search import SchemaSearch
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ mcp = FastMCP("schema-search")
+
+
+ @mcp.tool()
+ def schema_search(
+     query: str,
+     limit: int = 5,
+ ) -> dict:
+     """Search database schema using natural language.
+
+     Finds relevant database tables and their relationships by searching through schema metadata
+     using semantic similarity. Expands results by traversing foreign key relationships.
+
+     Args:
+         query: Natural language question about database schema (e.g., 'tables related to payments')
+         limit: Maximum number of table schemas to return in results. Default: 5; Max: 10.
+
+     Returns:
+         Dictionary with 'results' (list of table schemas with columns, types, constraints, and relationships) and 'latency_sec' (query execution time)
+     """
+     limit = min(limit, 10)
+     search_result = mcp.search_engine.search(query, limit=limit)  # type: ignore
+     return {
+         "results": search_result["results"],
+         "latency_sec": search_result["latency_sec"],
+     }
+
+
+ def run_server(
+     database_url: str,
+     config_path: Optional[str] = None,
+     llm_api_key: Optional[str] = None,
+     llm_base_url: Optional[str] = None,
+ ):
+     engine = create_engine(database_url)
+
+     mcp.search_engine = SchemaSearch(  # type: ignore
+         engine,
+         config_path=config_path,
+         llm_api_key=llm_api_key,
+         llm_base_url=llm_base_url,
+     )
+
+     logger.info("Indexing database schema...")
+     mcp.search_engine.index()  # type: ignore
+     logger.info("Index ready")
+
+     mcp.run()
+
+
+ def main():
+     import sys
+
+     if len(sys.argv) < 2:
+         print(
+             "Usage: schema-search <database_url> [config_path] [llm_api_key] [llm_base_url]"
+         )
+         sys.exit(1)
+
+     database_url = sys.argv[1]
+     config_path = sys.argv[2] if len(sys.argv) > 2 else None
+     llm_api_key = sys.argv[3] if len(sys.argv) > 3 else None
+     llm_base_url = sys.argv[4] if len(sys.argv) > 4 else None
+
+     run_server(database_url, config_path, llm_api_key, llm_base_url)
+
+
+ if __name__ == "__main__":
+     main()
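
The same server can be launched from Python, equivalent to the `schema-search <database_url>` CLI handled by `main()`; the SQLite URL below is a placeholder:

```python
# Equivalent of `schema-search sqlite:///example.db`; URL is a placeholder.
from schema_search.mcp_server import run_server

run_server("sqlite:///example.db")
```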
schema_search/metrics.py
@@ -0,0 +1,33 @@
+ import numpy as np
+
+
+ def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+     a_norm = a / (np.linalg.norm(a, axis=-1, keepdims=True) + 1e-8)
+     b_norm = b / (np.linalg.norm(b, axis=-1, keepdims=True) + 1e-8)
+     return a_norm @ b_norm.T
+
+
+ def dot_product(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+     return a @ b.T
+
+
+ def euclidean_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+     return -np.linalg.norm(a[:, None] - b[None, :], axis=-1)
+
+
+ def manhattan_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+     return -np.sum(np.abs(a[:, None] - b[None, :]), axis=-1)
+
+
+ METRICS = {
+     "cosine": cosine_similarity,
+     "dot": dot_product,
+     "euclidean": euclidean_distance,
+     "manhattan": manhattan_distance,
+ }
+
+
+ def get_metric(name: str):
+     if name not in METRICS:
+         raise ValueError(f"Unknown metric: {name}. Available: {list(METRICS.keys())}")
+     return METRICS[name]
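
All four metrics are oriented so that larger means more similar (the two distances are negated), letting callers rank by score uniformly. A quick sanity check on toy vectors:

```python
# Sanity check of get_metric / cosine_similarity on toy vectors.
import numpy as np
from schema_search.metrics import get_metric

metric = get_metric("cosine")
docs = np.array([[1.0, 0.0], [0.0, 1.0]])   # two "document" embeddings
query = np.array([[1.0, 0.0]])              # one query embedding

print(metric(docs, query).flatten())  # ~[1.0, 0.0]
```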
schema_search/rankers/__init__.py
@@ -0,0 +1,5 @@
+ from schema_search.rankers.base import BaseRanker
+ from schema_search.rankers.cross_encoder import CrossEncoderRanker
+ from schema_search.rankers.factory import create_ranker
+
+ __all__ = ["BaseRanker", "CrossEncoderRanker", "create_ranker"]