msaas-rag 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- msaas_rag-1.0.0.dist-info/METADATA +20 -0
- msaas_rag-1.0.0.dist-info/RECORD +11 -0
- msaas_rag-1.0.0.dist-info/WHEEL +4 -0
- rag/__init__.py +54 -0
- rag/chunking.py +202 -0
- rag/config.py +46 -0
- rag/embeddings.py +136 -0
- rag/models.py +73 -0
- rag/pipeline.py +110 -0
- rag/reranker.py +113 -0
- rag/vector_store.py +293 -0
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: msaas-rag
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: RAG pipeline library — chunking, embeddings, vector search, and retrieval for the Willian SaaS platform
|
|
5
|
+
License: MIT
|
|
6
|
+
Requires-Python: >=3.12
|
|
7
|
+
Requires-Dist: asyncpg>=0.30.0
|
|
8
|
+
Requires-Dist: pydantic>=2.0
|
|
9
|
+
Provides-Extra: all
|
|
10
|
+
Requires-Dist: numpy>=2.0; extra == 'all'
|
|
11
|
+
Requires-Dist: openai>=1.50.0; extra == 'all'
|
|
12
|
+
Provides-Extra: dev
|
|
13
|
+
Requires-Dist: numpy>=2.0; extra == 'dev'
|
|
14
|
+
Requires-Dist: pytest-asyncio>=0.24.0; extra == 'dev'
|
|
15
|
+
Requires-Dist: pytest>=8.0; extra == 'dev'
|
|
16
|
+
Requires-Dist: ruff>=0.8; extra == 'dev'
|
|
17
|
+
Provides-Extra: numpy
|
|
18
|
+
Requires-Dist: numpy>=2.0; extra == 'numpy'
|
|
19
|
+
Provides-Extra: openai
|
|
20
|
+
Requires-Dist: openai>=1.50.0; extra == 'openai'
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
rag/__init__.py,sha256=ca84grhMOHLh_cfFsJtbAU-jp2DkqEPjAjJwO65tdGo,1235
|
|
2
|
+
rag/chunking.py,sha256=XpuIU44VMyHlp_YuSKbIZpEvQct1VsryCs-qHDuufgA,7538
|
|
3
|
+
rag/config.py,sha256=ITny9Edku1M82KdzfeZ-SmZA9wotR5P7O8cUO6JDKgQ,1240
|
|
4
|
+
rag/embeddings.py,sha256=UwK63qcJEBvkXPXkYCbljx_pSWRG-DNZRmAkrln840o,4870
|
|
5
|
+
rag/models.py,sha256=NdtpR6X9I5lSnMISWOQ9c8iY4ranBg6qsME4PxvPUg8,1777
|
|
6
|
+
rag/pipeline.py,sha256=DAmgmzwa_BnnJ27hgf9l0-e0rPoAfVoNN3br-1NATAc,3631
|
|
7
|
+
rag/reranker.py,sha256=2a7veHSoIA5SgOGxEMzVLAdzKizY0O9jbkz8EBN7sl8,3605
|
|
8
|
+
rag/vector_store.py,sha256=8IuMhv9fkexbkd2BSpyKoQQWqjzzS37gw-nJDJZJM8w,10659
|
|
9
|
+
msaas_rag-1.0.0.dist-info/METADATA,sha256=79VoK8PKNFWy-xg35JoHjVbJEiSFErWcYhdjrH_YsH8,711
|
|
10
|
+
msaas_rag-1.0.0.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
|
|
11
|
+
msaas_rag-1.0.0.dist-info/RECORD,,
|
rag/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Willian RAG — retrieval-augmented generation pipeline library."""
|
|
2
|
+
|
|
3
|
+
from rag.chunking import MarkdownChunker, TextChunker
|
|
4
|
+
from rag.config import get_config, get_pipeline, init_rag
|
|
5
|
+
from rag.embeddings import EmbeddingProvider, LocalEmbeddings, OpenAIEmbeddings
|
|
6
|
+
from rag.models import (
|
|
7
|
+
Chunk,
|
|
8
|
+
ChunkingConfig,
|
|
9
|
+
Document,
|
|
10
|
+
EmbeddingConfig,
|
|
11
|
+
RAGConfig,
|
|
12
|
+
SearchResult,
|
|
13
|
+
VectorStoreConfig,
|
|
14
|
+
)
|
|
15
|
+
from rag.pipeline import RAGPipeline
|
|
16
|
+
from rag.reranker import (
|
|
17
|
+
CrossEncoderReranker,
|
|
18
|
+
LLMReranker,
|
|
19
|
+
Reranker,
|
|
20
|
+
reciprocal_rank_fusion,
|
|
21
|
+
)
|
|
22
|
+
from rag.vector_store import InMemoryVectorStore, PgVectorStore, VectorStore
|
|
23
|
+
|
|
24
|
+
# Public API re-exported from the submodules, grouped by concern. Keep this
# list in sync with the imports at the top of this module.
__all__ = [
    # Pipeline and global configuration helpers
    "RAGPipeline", "init_rag", "get_config", "get_pipeline",
    # Domain models and configuration models
    "Document", "Chunk", "SearchResult",
    "RAGConfig", "ChunkingConfig", "EmbeddingConfig", "VectorStoreConfig",
    # Chunkers
    "TextChunker", "MarkdownChunker",
    # Embedding providers
    "EmbeddingProvider", "OpenAIEmbeddings", "LocalEmbeddings",
    # Vector store backends
    "VectorStore", "PgVectorStore", "InMemoryVectorStore",
    # Reranking
    "Reranker", "CrossEncoderReranker", "LLMReranker", "reciprocal_rank_fusion",
]
|
rag/chunking.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""Text chunking strategies for document processing."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
import uuid
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from rag.models import Chunk
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TextChunker:
    """Split plain text into chunks using configurable strategies.

    Supported strategies:
    - ``fixed_size``: split at exact character boundaries, windows advancing by
      ``chunk_size - chunk_overlap`` characters
    - ``recursive``: split by paragraph > line > sentence > word, falling back to
      hard character cuts, and greedily merge neighbouring pieces up to chunk_size
    - ``semantic``: sentence-based splitting that groups whole sentences up to
      chunk_size (a single oversized sentence is split recursively)

    Every strategy works on ``(start, end)`` spans into the original text, so a
    chunk's ``text`` is always exactly ``text[start_idx:end_idx]`` and is never
    longer than ``chunk_size`` characters. The ``recursive`` and ``semantic``
    strategies drop whitespace-only chunks (so empty input yields no chunks).

    Raises:
        ValueError: for an unknown strategy, a non-positive ``chunk_size`` or a
            negative ``chunk_overlap``.
    """

    SEPARATORS = ["\n\n", "\n", ". ", " "]

    # Whitespace that follows terminal punctuation marks a sentence boundary.
    _SENTENCE_BREAK_RE = re.compile(r"(?<=[.!?])\s+")

    def __init__(
        self,
        strategy: str = "recursive",
        chunk_size: int = 512,
        chunk_overlap: int = 64,
    ) -> None:
        if strategy not in ("fixed_size", "recursive", "semantic"):
            msg = f"Unknown strategy: {strategy!r}. Use fixed_size, recursive, or semantic."
            raise ValueError(msg)
        if chunk_size <= 0:
            msg = f"chunk_size must be positive, got {chunk_size}"
            raise ValueError(msg)
        if chunk_overlap < 0:
            msg = f"chunk_overlap must be non-negative, got {chunk_overlap}"
            raise ValueError(msg)
        self.strategy = strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def chunk(
        self,
        text: str,
        document_id: str,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """Split *text* into Chunk objects.

        Args:
            text: The text to split.
            document_id: Stored on every chunk as ``document_id``.
            metadata: Attached to every chunk; defaults to an empty dict.

        Returns:
            The chunks in document order.
        """
        meta = metadata or {}
        if self.strategy == "fixed_size":
            return self._fixed_size(text, document_id, meta)
        if self.strategy == "semantic":
            return self._semantic(text, document_id, meta)
        return self._recursive(text, document_id, meta)

    # -- strategies ----------------------------------------------------------

    def _fixed_size(self, text: str, document_id: str, meta: dict[str, Any]) -> list[Chunk]:
        """Cut *text* into fixed windows (whitespace-only windows are kept)."""
        return [
            self._make(text[start:end], document_id, start, end, meta)
            for start, end in self._window_spans(0, len(text))
        ]

    def _recursive(self, text: str, document_id: str, meta: dict[str, Any]) -> list[Chunk]:
        """Split *text* on :attr:`SEPARATORS`, coarsest separator first."""
        return self._recursive_split(text, document_id, meta, self.SEPARATORS)

    def _recursive_split(
        self,
        text: str,
        document_id: str,
        meta: dict[str, Any],
        separators: list[str],
    ) -> list[Chunk]:
        """Split *text* trying *separators* in order (see :meth:`_split_spans`)."""
        spans = self._split_spans(text, 0, len(text), separators)
        return self._spans_to_chunks(text, spans, document_id, meta)

    def _semantic(self, text: str, document_id: str, meta: dict[str, Any]) -> list[Chunk]:
        """Sentence-based chunking: group whole sentences up to chunk_size."""
        return self._spans_to_chunks(text, self._semantic_spans(text), document_id, meta)

    # -- span helpers (all offsets are absolute positions in the text) -------

    def _window_spans(self, start: int, end: int) -> list[tuple[int, int]]:
        """Return fixed windows over ``[start, end)`` of at most ``chunk_size`` chars."""
        # max(1, ...) keeps the loop advancing when chunk_overlap >= chunk_size.
        step = max(1, self.chunk_size - self.chunk_overlap)
        spans: list[tuple[int, int]] = []
        for lo in range(start, end, step):
            hi = min(lo + self.chunk_size, end)
            spans.append((lo, hi))
            if hi == end:
                break
        return spans

    @staticmethod
    def _pieces(text: str, start: int, end: int, sep: str) -> list[tuple[int, int]]:
        """Return the spans of ``text[start:end].split(sep)`` as absolute offsets."""
        pieces: list[tuple[int, int]] = []
        pos = start
        while (idx := text.find(sep, pos, end)) != -1:
            pieces.append((pos, idx))
            pos = idx + len(sep)
        pieces.append((pos, end))
        return pieces

    def _split_spans(
        self, text: str, start: int, end: int, separators: list[str]
    ) -> list[tuple[int, int]]:
        """Recursively split ``text[start:end]`` into spans of <= ``chunk_size`` chars.

        Pieces separated by ``separators[0]`` are merged greedily; a piece that is
        too long on its own is split with the remaining separators, and hard
        character windows are the last resort. When a span is flushed, the next
        one starts up to ``chunk_overlap`` characters back, if that still fits.
        """
        if end - start <= self.chunk_size:
            return [(start, end)]
        if not separators or not separators[0]:
            return self._window_spans(start, end)

        sep, rest = separators[0], separators[1:]
        pieces = self._pieces(text, start, end, sep)
        if len(pieces) == 1:
            # Separator absent from this range: try the next, finer one.
            return self._split_spans(text, start, end, rest)

        spans: list[tuple[int, int]] = []
        current: tuple[int, int] | None = None
        for p_start, p_end in pieces:
            if p_end - p_start > self.chunk_size:
                # Too large on its own: flush what we have and split it finer.
                if current is not None:
                    spans.append(current)
                    current = None
                spans.extend(self._split_spans(text, p_start, p_end, rest))
            elif current is None:
                current = (p_start, p_end)
            elif p_end - current[0] <= self.chunk_size:
                # Extend the window; the separator between pieces is kept.
                current = (current[0], p_end)
            else:
                spans.append(current)
                cur_start, cur_end = current
                if self.chunk_overlap > 0:
                    new_start = max(cur_start, cur_end - self.chunk_overlap)
                else:
                    new_start = p_start
                if p_end - new_start > self.chunk_size:
                    # The overlap would push the window past chunk_size: drop it.
                    new_start = p_start
                current = (new_start, p_end)
        if current is not None:
            spans.append(current)
        return spans

    def _sentence_spans(self, text: str) -> list[tuple[int, int]]:
        """Return whitespace-trimmed spans of the non-empty sentences in *text*."""
        raw: list[tuple[int, int]] = []
        pos = 0
        for match in self._SENTENCE_BREAK_RE.finditer(text):
            raw.append((pos, match.start()))
            pos = match.end()
        raw.append((pos, len(text)))

        trimmed: list[tuple[int, int]] = []
        for start, end in raw:
            segment = text[start:end]
            stripped = segment.strip()
            if stripped:
                lead = len(segment) - len(segment.lstrip())
                trimmed.append((start + lead, start + lead + len(stripped)))
        return trimmed

    def _semantic_spans(self, text: str) -> list[tuple[int, int]]:
        """Greedily group consecutive sentences into spans of <= ``chunk_size`` chars."""
        spans: list[tuple[int, int]] = []
        current: tuple[int, int] | None = None
        for s_start, s_end in self._sentence_spans(text):
            if s_end - s_start > self.chunk_size:
                if current is not None:
                    spans.append(current)
                    current = None
                spans.extend(self._split_spans(text, s_start, s_end, self.SEPARATORS))
            elif current is None:
                current = (s_start, s_end)
            elif s_end - current[0] <= self.chunk_size:
                current = (current[0], s_end)
            else:
                spans.append(current)
                current = (s_start, s_end)
        if current is not None:
            spans.append(current)
        return spans

    def _spans_to_chunks(
        self,
        text: str,
        spans: list[tuple[int, int]],
        document_id: str,
        meta: dict[str, Any],
    ) -> list[Chunk]:
        """Materialize *spans* as chunks, skipping whitespace-only ones."""
        return [
            self._make(text[start:end], document_id, start, end, meta)
            for start, end in spans
            if text[start:end].strip()
        ]

    @staticmethod
    def _make(text: str, document_id: str, start: int, end: int, meta: dict[str, Any]) -> Chunk:
        """Build a Chunk with a fresh random hex id."""
        return Chunk(
            id=uuid.uuid4().hex,
            text=text,
            document_id=document_id,
            start_idx=start,
            end_idx=end,
            metadata=meta,
        )
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class MarkdownChunker:
    """Split Markdown documents by headers, preserving structure.

    Each ATX header section (``#`` to ``######``) becomes a chunk whose text is
    ``"<header line>\\n<stripped body>"`` and whose metadata carries the header
    line under ``"header"``. Text before the first header becomes its own
    header-less chunk. Sections exceeding ``chunk_size`` are further split with a
    recursive ``TextChunker``. Empty or whitespace-only sections are skipped.

    ``start_idx``/``end_idx`` are anchored at the section's position in the
    original document. Because the body is stripped and joined to the header
    with a single newline, offsets inside a section are approximate.
    """

    # ATX header: 1-6 '#' followed by spaces/tabs on the same line. Using [ \t]
    # rather than \s keeps a bare "#" line from swallowing the next lines.
    HEADER_RE = re.compile(r"^(#{1,6})[ \t]+(.*)", re.MULTILINE)

    def __init__(self, chunk_size: int = 1024, chunk_overlap: int = 64) -> None:
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self._fallback = TextChunker(
            strategy="recursive", chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )

    def chunk(
        self,
        text: str,
        document_id: str,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """Split Markdown *text* into header-section chunks.

        Args:
            text: The Markdown document.
            document_id: Stored on every chunk.
            metadata: Base metadata copied into every chunk's metadata.

        Returns:
            Chunks in document order; an empty list for empty or blank input.
        """
        meta = metadata or {}
        chunks: list[Chunk] = []

        for header, body, start in self._split_by_headers(text):
            full_text = f"{header}\n{body}".strip() if header else body.strip()
            if not full_text:
                # Nothing worth embedding (e.g. empty or whitespace-only input).
                continue
            section_meta = {**meta}
            if header:
                section_meta["header"] = header
            if len(full_text) <= self.chunk_size:
                chunks.append(
                    TextChunker._make(
                        full_text, document_id, start, start + len(full_text), section_meta
                    )
                )
                continue
            # The fallback reports offsets relative to full_text; shift them so
            # they point (approximately) into the original document.
            for sub in self._fallback.chunk(full_text, document_id, section_meta):
                sub.start_idx += start
                sub.end_idx += start
                chunks.append(sub)
        return chunks

    def _split_by_headers(self, text: str) -> list[tuple[str, str, int]]:
        """Return a list of ``(header_line, body_text, start_index)`` sections.

        ``header_line`` is ``""`` for text preceding the first header (or for the
        whole text when it has no headers). Bodies of header sections and of the
        preamble are stripped; ``start_index`` is where the section begins.
        """
        matches = list(self.HEADER_RE.finditer(text))
        if not matches:
            return [("", text, 0)]

        sections: list[tuple[str, str, int]] = []
        # Text before the first header.
        if matches[0].start() > 0:
            preamble = text[: matches[0].start()]
            if preamble.strip():
                lead = len(preamble) - len(preamble.lstrip())
                sections.append(("", preamble.strip(), lead))

        for i, match in enumerate(matches):
            header_line = match.group(0)
            body_start = match.end()
            body_end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
            body = text[body_start:body_end].strip()
            sections.append((header_line, body, match.start()))

        return sections
|
rag/config.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Global configuration and pipeline singleton management."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from rag.models import RAGConfig
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from rag.pipeline import RAGPipeline
|
|
11
|
+
|
|
12
|
+
# Module-level singletons. init_rag() replaces _config and clears _pipeline;
# get_config() and get_pipeline() create them lazily on first use.
_pipeline: RAGPipeline | None = None
_config: RAGConfig | None = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def init_rag(config: RAGConfig | None = None) -> RAGConfig:
    """Install *config* (or a fresh default ``RAGConfig``) as the global RAG config.

    Any previously built global pipeline is discarded, so the next
    ``get_pipeline()`` call builds a new one from this configuration.

    Returns:
        The configuration that is now active, for further customization.
    """
    global _config, _pipeline
    if not config:
        config = RAGConfig()
    _config = config
    # Drop the cached pipeline; it was built from the previous configuration.
    _pipeline = None
    return _config
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_config() -> RAGConfig:
    """Return the active global RAG config, creating a default one on first use."""
    global _config
    if _config is not None:
        return _config
    _config = RAGConfig()
    return _config
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_pipeline() -> RAGPipeline:
    """Return the global RAGPipeline singleton, building it on first use.

    The pipeline is built from ``get_config()``; call ``init_rag()`` beforehand
    to customize it (doing so later discards the cached pipeline).
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    # Deferred import: at module level RAGPipeline is only imported under
    # TYPE_CHECKING, presumably to keep rag.config free of that dependency.
    from rag.pipeline import RAGPipeline

    _pipeline = RAGPipeline(config=get_config())
    return _pipeline
|
rag/embeddings.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
"""Embedding providers for converting text to vector representations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import hashlib
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class EmbeddingProvider(ABC):
    """Interface implemented by every embedding backend.

    Implementations turn strings into float vectors and report the vector
    length through :meth:`dimensions`.
    """

    @abstractmethod
    def dimensions(self) -> int:
        """Length of the vectors produced by :meth:`embed`."""

    @abstractmethod
    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed every string in *texts*, returning one vector per input in order."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class OpenAIEmbeddings(EmbeddingProvider):
    """OpenAI text-embedding-3-small/large provider with caching and batching.

    Requires the ``openai`` extra: ``pip install msaas-rag[openai]``

    Vectors are cached in memory per instance, keyed by the SHA-256 of the text,
    for the lifetime of the object (the cache is unbounded). Texts repeated
    within a single ``embed`` call are sent to the API only once.
    """

    MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: str | None = None,
        batch_size: int = 100,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ) -> None:
        """Create the provider.

        Args:
            model: OpenAI embedding model name.
            api_key: Passed to ``openai.AsyncOpenAI``; ``None`` lets the client
                resolve it itself.
            batch_size: Maximum number of texts per API request (values < 1 are
                treated as 1).
            max_retries: Total attempts per batch (values < 1 are treated as 1).
            retry_delay: Base delay in seconds; attempt *n* waits ``n * retry_delay``.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        try:
            import openai  # noqa: F811
        except ImportError as exc:
            msg = "Install the openai extra: pip install msaas-rag[openai]"
            raise ImportError(msg) from exc

        self.model = model
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._client = openai.AsyncOpenAI(api_key=api_key)
        self._cache: dict[str, list[float]] = {}

    def dimensions(self) -> int:
        """Return the vector length for ``self.model`` (1536 for unknown models)."""
        return self.MODEL_DIMENSIONS.get(self.model, 1536)

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* with caching, de-duplication, batching and retries.

        Returns:
            One vector per input text, in input order.

        Raises:
            RuntimeError: if a batch still fails after all attempts, or the API
                returns a different number of vectors than texts sent.
        """
        results: dict[int, list[float]] = {}
        # Uncached text -> every index it occupies in *texts*.
        pending: dict[str, list[int]] = {}

        for i, text in enumerate(texts):
            cached = self._cache.get(self._cache_key(text))
            if cached is not None:
                results[i] = cached
            else:
                pending.setdefault(text, []).append(i)

        unique = list(pending)
        step = max(1, self.batch_size)  # guard against a zero/negative batch size
        for batch_start in range(0, len(unique), step):
            batch = unique[batch_start : batch_start + step]
            embeddings = await self._embed_with_retry(batch)
            for text, emb in zip(batch, embeddings):
                self._cache[self._cache_key(text)] = emb
                for idx in pending[text]:
                    results[idx] = emb

        return [results[i] for i in range(len(texts))]

    async def _embed_with_retry(self, texts: list[str]) -> list[list[float]]:
        """Call the embeddings API for one batch, retrying with linear backoff."""
        attempts = max(1, self.max_retries)  # always try at least once
        last_error: Exception | None = None
        for attempt in range(attempts):
            try:
                response = await self._client.embeddings.create(model=self.model, input=texts)
            except Exception as exc:
                last_error = exc
                if attempt < attempts - 1:
                    await asyncio.sleep(self.retry_delay * (attempt + 1))
                continue
            # Each item carries the index of the input it belongs to; order by it
            # rather than relying on the response order.
            data = sorted(response.data, key=lambda item: item.index)
            if len(data) != len(texts):
                msg = f"Embedding API returned {len(data)} vectors for {len(texts)} inputs"
                raise RuntimeError(msg)
            return [item.embedding for item in data]
        msg = f"Embedding failed after {attempts} attempts"
        raise RuntimeError(msg) from last_error

    @staticmethod
    def _cache_key(text: str) -> str:
        """Return the hex SHA-256 of *text* (UTF-8), used as the cache key."""
        return hashlib.sha256(text.encode()).hexdigest()
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class LocalEmbeddings(EmbeddingProvider):
    """Offline stand-in for a sentence-transformers model.

    Each vector is derived deterministically from the SHA-256 digest of the
    text, so identical texts always map to identical vectors, normalized to unit
    length. The vectors carry no semantic meaning: use this for testing and
    replace the ``embed`` method with a real model for production use.
    """

    def __init__(self, dims: int = 384) -> None:
        self._dims = dims

    def dimensions(self) -> int:
        return self._dims

    async def embed(self, texts: list[str]) -> list[list[float]]:
        vectors: list[list[float]] = []
        for text in texts:
            vectors.append(self._pseudo_embedding(text))
        return vectors

    def _pseudo_embedding(self, text: str) -> list[float]:
        """Map *text* to a deterministic vector of length ``self._dims``."""
        digest = hashlib.sha256(text.encode()).digest()
        tiled = [byte / 255.0 for byte in digest]
        # Double the 32 scaled digest bytes until there are enough values, then cut.
        while len(tiled) < self._dims:
            tiled = tiled + tiled
        values = tiled[: self._dims]
        norm = sum(v * v for v in values) ** 0.5
        if norm > 0:
            return [v / norm for v in values]
        return [0.0 for _ in values]
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _build_provider(config: Any) -> EmbeddingProvider:
|
|
130
|
+
"""Factory: create an EmbeddingProvider from an EmbeddingConfig."""
|
|
131
|
+
if config.provider == "openai":
|
|
132
|
+
return OpenAIEmbeddings(model=config.model, batch_size=config.batch_size)
|
|
133
|
+
if config.provider == "local":
|
|
134
|
+
return LocalEmbeddings(dims=config.dimensions)
|
|
135
|
+
msg = f"Unknown embedding provider: {config.provider!r}"
|
|
136
|
+
raise ValueError(msg)
|
rag/models.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Core domain models for the RAG pipeline."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import uuid
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _default_id() -> str:
    """Return a new random identifier: a UUID4 rendered as 32 lowercase hex digits."""
    return format(uuid.uuid4().int, "032x")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Document(BaseModel):
    """A source document to be ingested into the RAG pipeline.

    ``RAGPipeline.ingest`` chunks ``text`` and passes ``id`` and ``metadata``
    to the chunker, which attaches them to every resulting chunk.
    """

    # Unique identifier; defaults to a random 32-character hex UUID.
    id: str = Field(default_factory=_default_id)
    # Raw text to be chunked and embedded.
    text: str
    # Arbitrary caller-supplied metadata propagated to the document's chunks.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Chunk(BaseModel):
    """A text chunk derived from a document, optionally carrying its embedding."""

    # Unique chunk identifier; defaults to a random 32-character hex UUID.
    id: str = Field(default_factory=_default_id)
    # The chunk's text content.
    text: str
    # ``id`` of the Document this chunk was cut from.
    document_id: str
    # Character offsets of the chunk in the source text (end exclusive), as
    # reported by the chunker. NOTE(review): treat as approximate for chunkers
    # that normalize whitespace.
    start_idx: int
    end_idx: int
    # Metadata inherited from the document (plus e.g. "header" for Markdown).
    metadata: dict[str, Any] = Field(default_factory=dict)
    # Embedding vector; None until RAGPipeline.ingest assigns one.
    embedding: list[float] | None = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SearchResult(BaseModel):
    """A single result returned from vector search or reranking."""

    # The matched chunk.
    chunk: Chunk
    # Relevance score; result lists are ordered by this value, highest first.
    score: float
    # Extra per-result metadata (rerankers and RRF copy it through unchanged).
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ChunkingConfig(BaseModel):
    """Configuration for text chunking (consumed by TextChunker)."""

    # One of "fixed_size", "recursive" or "semantic"; TextChunker rejects others.
    strategy: str = "recursive"
    # Target maximum chunk length, in characters.
    chunk_size: int = 512
    # Number of characters shared between consecutive chunks.
    chunk_overlap: int = 64
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class EmbeddingConfig(BaseModel):
    """Configuration for the embedding provider."""

    # "openai" or "local"; any other value makes _build_provider raise ValueError.
    provider: str = "openai"
    # Model name, used only by the "openai" provider.
    model: str = "text-embedding-3-small"
    # Vector length for LocalEmbeddings (also used by RAGPipeline's local
    # fallback); not passed to the OpenAI provider.
    dimensions: int = 1536
    # Maximum texts per OpenAI request.
    batch_size: int = 100
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class VectorStoreConfig(BaseModel):
    """Configuration for the vector store backend.

    NOTE(review): RAGPipeline.__init__ does not read this config; it uses an
    InMemoryVectorStore unless a store is injected explicitly.
    """

    # Backend name; presumably selects PgVectorStore for "pgvector" — confirm.
    backend: str = "pgvector"
    # Table name, mirroring PgVectorStore's ``table_name`` parameter.
    table_name: str = "rag_chunks"
    # PostgreSQL connection string (empty by default).
    dsn: str = ""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class RAGConfig(BaseModel):
    """Top-level configuration for the entire RAG pipeline.

    Each section defaults to its own model's defaults.
    """

    # Chunking strategy and sizes.
    chunking: ChunkingConfig = Field(default_factory=ChunkingConfig)
    # Embedding provider selection.
    embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
    # Vector store backend settings.
    vector_store: VectorStoreConfig = Field(default_factory=VectorStoreConfig)
|
rag/pipeline.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""RAG pipeline orchestrating chunking, embedding, storage, and retrieval."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from rag.chunking import MarkdownChunker, TextChunker
|
|
8
|
+
from rag.embeddings import LocalEmbeddings, _build_provider
|
|
9
|
+
from rag.models import Chunk, Document, RAGConfig, SearchResult
|
|
10
|
+
from rag.reranker import Reranker
|
|
11
|
+
from rag.vector_store import InMemoryVectorStore, VectorStore
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from rag.embeddings import EmbeddingProvider
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RAGPipeline:
    """Orchestrates the full RAG flow: chunk -> embed -> store -> retrieve.

    Components can be injected directly; anything not injected is built from
    ``config`` (or a default ``RAGConfig``):

    - chunker: a ``TextChunker`` configured from ``config.chunking``;
    - embedding provider: built from ``config.embedding``. If that fails (for
      example because the ``openai`` extra is not installed), a warning is
      logged and hash-based ``LocalEmbeddings`` are used instead;
    - vector store: an ``InMemoryVectorStore``. ``config.vector_store`` is not
      consulted here, so inject a persistent store explicitly.
    """

    def __init__(
        self,
        config: RAGConfig | None = None,
        chunker: TextChunker | MarkdownChunker | None = None,
        embedding_provider: EmbeddingProvider | None = None,
        vector_store: VectorStore | None = None,
    ) -> None:
        cfg = config if config is not None else RAGConfig()

        # Explicit ``is None`` checks: an injected component must never be
        # replaced just because it happens to be falsy (e.g. an empty store).
        if chunker is None:
            chunker = TextChunker(
                strategy=cfg.chunking.strategy,
                chunk_size=cfg.chunking.chunk_size,
                chunk_overlap=cfg.chunking.chunk_overlap,
            )
        self.chunker = chunker
        self.embedding_provider: EmbeddingProvider = (
            embedding_provider if embedding_provider is not None else self._safe_provider(cfg)
        )
        self.vector_store = vector_store if vector_store is not None else InMemoryVectorStore()

    @staticmethod
    def _safe_provider(cfg: RAGConfig) -> EmbeddingProvider:
        """Build the configured provider, falling back to LocalEmbeddings on failure.

        The fallback keeps the pipeline usable without optional dependencies,
        but its pseudo-embeddings carry no meaning, so the failure is logged
        instead of being swallowed silently.
        """
        try:
            return _build_provider(cfg.embedding)
        except Exception:
            import logging

            logging.getLogger(__name__).warning(
                "Could not build embedding provider %r; falling back to LocalEmbeddings "
                "(hash-based pseudo-embeddings, not suitable for production)",
                cfg.embedding.provider,
                exc_info=True,
            )
            return LocalEmbeddings(dims=cfg.embedding.dimensions)

    async def ingest(self, documents: list[Document]) -> int:
        """Process documents through the full pipeline: chunk -> embed -> store.

        Args:
            documents: Documents to chunk; each document's ``id`` and
                ``metadata`` are passed on to its chunks.

        Returns:
            The number of chunks stored, as reported by the vector store.

        Raises:
            RuntimeError: if the embedding provider returns a different number
                of vectors than chunks.
        """
        all_chunks: list[Chunk] = []

        for doc in documents:
            chunks = self.chunker.chunk(
                text=doc.text,
                document_id=doc.id,
                metadata=doc.metadata,
            )
            all_chunks.extend(chunks)

        if not all_chunks:
            return 0

        # One embed call for everything; the provider is responsible for batching.
        texts = [c.text for c in all_chunks]
        embeddings = await self.embedding_provider.embed(texts)
        if len(embeddings) != len(all_chunks):
            msg = (
                f"Embedding provider returned {len(embeddings)} vectors "
                f"for {len(all_chunks)} chunks"
            )
            raise RuntimeError(msg)

        for chunk, embedding in zip(all_chunks, embeddings):
            chunk.embedding = embedding

        return await self.vector_store.upsert(all_chunks)

    async def retrieve(
        self,
        query: str,
        top_k: int = 10,
        filters: dict | None = None,
    ) -> list[SearchResult]:
        """Embed *query* and return up to *top_k* results from the vector store.

        *filters* is passed through unchanged to ``VectorStore.search``.
        """
        query_embeddings = await self.embedding_provider.embed([query])
        query_embedding = query_embeddings[0]
        return await self.vector_store.search(
            query_embedding=query_embedding,
            top_k=top_k,
            filters=filters,
        )

    async def retrieve_with_rerank(
        self,
        query: str,
        top_k: int = 10,
        reranker: Reranker | None = None,
        initial_k: int | None = None,
        filters: dict | None = None,
    ) -> list[SearchResult]:
        """Retrieve a larger candidate set, then rerank it down to *top_k*.

        Args:
            query: The search query.
            top_k: Number of results to return.
            reranker: Reranker to apply; if ``None``, behaves like plain retrieve.
            initial_k: Candidates fetched before reranking; a falsy value
                means ``3 * top_k``.
            filters: Passed through to the vector store search.
        """
        fetch_k = initial_k or top_k * 3
        results = await self.retrieve(query, top_k=fetch_k, filters=filters)

        if reranker is None:
            return results[:top_k]

        return await reranker.rerank(query, results, top_k=top_k)
|
rag/reranker.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Reranking strategies for improving retrieval quality."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations

import inspect
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable

from rag.models import SearchResult
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Reranker(ABC):
    """Interface for strategies that reorder retrieved results by relevance."""

    @abstractmethod
    async def rerank(
        self, query: str, results: list[SearchResult], top_k: int = 10
    ) -> list[SearchResult]:
        """Reorder *results* by relevance to *query*, returning at most *top_k* of them."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CrossEncoderReranker(Reranker):
    """Placeholder for cross-encoder model reranking.

    In production, load a cross-encoder model (e.g., ms-marco-MiniLM) and score
    each (query, chunk.text) pair. This placeholder instead scores each result by
    keyword overlap: the fraction of distinct lower-cased, whitespace-separated
    query terms that also occur in the chunk text. Scores are in ``[0, 1]``;
    results with equal scores keep their incoming order.
    """

    async def rerank(
        self, query: str, results: list[SearchResult], top_k: int = 10
    ) -> list[SearchResult]:
        """Return the *top_k* results with the highest keyword-overlap score.

        Each returned SearchResult carries the overlap ratio as its ``score``
        and keeps the original chunk and metadata.
        """
        query_terms = set(query.lower().split())
        # An empty query would divide by zero; every score becomes 0 instead.
        total = len(query_terms) if query_terms else 1

        scored: list[tuple[float, SearchResult]] = []
        for result in results:
            chunk_terms = set(result.chunk.text.lower().split())
            relevance = len(query_terms & chunk_terms) / total
            scored.append((relevance, result))

        # list.sort is stable, so ties preserve the retrieval order.
        scored.sort(key=lambda x: x[0], reverse=True)
        return [
            SearchResult(
                chunk=r.chunk,
                score=score,
                metadata=r.metadata,
            )
            for score, r in scored[:top_k]
        ]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class LLMReranker(Reranker):
    """Reranker that uses an LLM (or any scoring function) to score relevance.

    ``score_fn`` is called as ``score_fn(query, chunk_text)`` and must return a
    relevance score convertible to ``float``. It may return the score directly
    or an awaitable of it (e.g. an ``async def`` wrapping willian-ai or any
    other LLM client). Higher scores rank first. Without a ``score_fn``,
    results are passed through unchanged, truncated to ``top_k``.
    """

    def __init__(
        self,
        score_fn: Callable[[str, str], Awaitable[float] | float] | None = None,
    ) -> None:
        self._score_fn = score_fn

    async def rerank(
        self, query: str, results: list[SearchResult], top_k: int = 10
    ) -> list[SearchResult]:
        """Score every result with ``score_fn`` and return the *top_k* best.

        Results are scored one at a time, in their incoming order; equal scores
        keep that order.
        """
        if self._score_fn is None:
            # No scoring function — return as-is
            return results[:top_k]

        scored: list[tuple[float, SearchResult]] = []
        for result in results:
            score = self._score_fn(query, result.chunk.text)
            # Accept both synchronous and asynchronous scoring functions.
            if inspect.isawaitable(score):
                score = await score
            scored.append((float(score), result))

        scored.sort(key=lambda x: x[0], reverse=True)
        return [
            SearchResult(chunk=r.chunk, score=score, metadata=r.metadata)
            for score, r in scored[:top_k]
        ]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def reciprocal_rank_fusion(
    *result_lists: list[SearchResult],
    k: int = 60,
    top_n: int = 10,
) -> list[SearchResult]:
    """Merge several ranked result lists using Reciprocal Rank Fusion (RRF).

    A chunk's fused score is the sum over all lists of ``1 / (k + rank)``, with
    1-based ranks; ``k=60`` is the constant from the original RRF paper. Chunks
    are identified by ``chunk.id``. The chunk and metadata of each chunk's first
    occurrence are kept. Returns the ``top_n`` best chunks, highest score first,
    and ties keep first-seen order.
    """
    fused: dict[str, float] = {}
    first_seen: dict[str, SearchResult] = {}

    for ranked_list in result_lists:
        for position, hit in enumerate(ranked_list, start=1):
            chunk_id = hit.chunk.id
            fused[chunk_id] = fused.get(chunk_id, 0.0) + 1.0 / (k + position)
            first_seen.setdefault(chunk_id, hit)

    ordering = sorted(fused, key=fused.__getitem__, reverse=True)
    return [
        SearchResult(
            chunk=first_seen[chunk_id].chunk,
            score=fused[chunk_id],
            metadata=first_seen[chunk_id].metadata,
        )
        for chunk_id in ordering[:top_n]
    ]
|
rag/vector_store.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
"""Vector store backends for chunk storage and similarity search."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from rag.models import Chunk, SearchResult
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class VectorStore(ABC):
    """Interface for chunk storage backends that support similarity search."""

    @abstractmethod
    async def upsert(self, chunks: list[Chunk]) -> int:
        """Insert new chunks or update existing ones; return the number of affected rows."""

    @abstractmethod
    async def search(
        self,
        query_embedding: list[float],
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Return up to *top_k* chunks by cosine similarity to *query_embedding*, best first.

        How *filters* narrows the candidates is up to the implementation.
        """

    @abstractmethod
    async def delete(self, chunk_ids: list[str]) -> int:
        """Remove the chunks whose ids are in *chunk_ids*; return how many were deleted."""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class PgVectorStore(VectorStore):
    """PostgreSQL + pgvector backend using asyncpg.

    Supports both pure vector search and hybrid (vector + full-text) search
    via Reciprocal Rank Fusion.

    ``table_name`` is interpolated directly into SQL text (identifiers cannot
    be bound as parameters), so it must come from trusted configuration,
    never from user input.
    """

    def __init__(self, dsn: str, table_name: str = "rag_chunks", dimensions: int = 1536) -> None:
        self.dsn = dsn
        self.table_name = table_name
        # Size of the ``vector(...)`` column created by create_table().
        self.dimensions = dimensions
        # asyncpg pool, created lazily by _get_pool().
        self._pool: Any = None

    async def _get_pool(self) -> Any:
        """Return the shared asyncpg pool, creating it on first use."""
        # NOTE(review): two coroutines racing through the first call can both
        # see ``None`` and each create a pool; the first one is then leaked.
        if self._pool is None:
            import asyncpg

            self._pool = await asyncpg.create_pool(self.dsn, min_size=2, max_size=10)
        return self._pool

    async def create_table(self) -> None:
        """Create the chunks table with vector and tsvector columns.

        Also ensures the ``vector`` extension exists and creates an IVFFlat
        cosine index on ``embedding`` plus a GIN index on the generated
        English ``tsv`` column. Every statement uses ``IF NOT EXISTS``.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            await conn.execute(f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    id TEXT PRIMARY KEY,
                    text TEXT NOT NULL,
                    document_id TEXT NOT NULL,
                    start_idx INTEGER NOT NULL,
                    end_idx INTEGER NOT NULL,
                    metadata JSONB DEFAULT '{{}}',
                    embedding vector({self.dimensions}),
                    tsv tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
                )
            """)
            # NOTE: pgvector recommends building IVFFlat indexes once the table
            # already holds data; an index built on an empty table has poor recall.
            await conn.execute(
                f"CREATE INDEX IF NOT EXISTS idx_{self.table_name}_embedding "
                f"ON {self.table_name} USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100)"
            )
            await conn.execute(
                f"CREATE INDEX IF NOT EXISTS idx_{self.table_name}_tsv "
                f"ON {self.table_name} USING gin (tsv)"
            )

    async def upsert(self, chunks: list[Chunk]) -> int:
        """Bulk upsert chunks with ON CONFLICT, overwriting every stored column.

        Chunks with a missing or empty embedding are stored with a NULL
        embedding and are skipped by the vector searches.

        Returns:
            The number of chunks submitted. asyncpg's ``executemany`` does not
            report per-row counts.
        """
        if not chunks:
            return 0
        values = [
            (
                c.id,
                c.text,
                c.document_id,
                c.start_idx,
                c.end_idx,
                json.dumps(c.metadata),
                self._to_vector_literal(c.embedding)
                if c.embedding is not None and len(c.embedding) > 0
                else None,
            )
            for c in chunks
        ]
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.executemany(
                f"""
                INSERT INTO {self.table_name} (id, text, document_id, start_idx, end_idx, metadata, embedding)
                VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7::vector)
                ON CONFLICT (id) DO UPDATE SET
                    text = EXCLUDED.text,
                    document_id = EXCLUDED.document_id,
                    start_idx = EXCLUDED.start_idx,
                    end_idx = EXCLUDED.end_idx,
                    metadata = EXCLUDED.metadata,
                    embedding = EXCLUDED.embedding
                """,
                values,
            )
        return len(chunks)

    async def search(
        self,
        query_embedding: list[float],
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Cosine similarity search.

        Args:
            query_embedding: Query vector; must match the column dimensions.
            top_k: Maximum number of results (SQL ``LIMIT``).
            filters: Optional metadata key/value pairs that a chunk's metadata
                must contain (JSONB containment).

        Returns:
            Results ordered by ``1 - cosine distance``, highest first. Rows
            without an embedding are never returned.
        """
        # $1 = query vector, $2 = LIMIT, so an optional filter binds as $3.
        filter_sql, filter_args = self._build_filter_clause(filters, param_index=3)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"""
                SELECT id, text, document_id, start_idx, end_idx, metadata,
                       1 - (embedding <=> $1::vector) AS score
                FROM {self.table_name}
                WHERE embedding IS NOT NULL {filter_sql}
                ORDER BY embedding <=> $1::vector
                LIMIT $2
                """,
                self._to_vector_literal(query_embedding),
                top_k,
                *filter_args,
            )
        return [self._row_to_result(row) for row in rows]

    async def hybrid_search(
        self,
        query_text: str,
        query_embedding: list[float],
        top_k: int = 10,
        vector_weight: float = 0.7,
        filters: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Hybrid search combining vector similarity and full-text search via RRF.

        The ``top_k`` nearest embeddings and the ``top_k`` best full-text matches
        (English ``plainto_tsquery``, ranked by ``ts_rank_cd``) are each ranked
        1..n. Each chunk then scores
        ``vector_weight / (60 + vec_rank) + (1 - vector_weight) / (60 + text_rank)``.
        A missing rank contributes 0.

        Returns:
            At most ``top_k`` results ordered by that fused score, highest first.
        """
        # $1 vector, $2 text, $3 limit, $4 weight, so an optional filter binds as $5.
        filter_sql, filter_args = self._build_filter_clause(filters, param_index=5)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            # Each candidate list is ordered *before* LIMIT so it really is the
            # top-k; ROW_NUMBER is then taken over that bounded set.
            rows = await conn.fetch(
                f"""
                WITH vector_results AS (
                    SELECT id, ROW_NUMBER() OVER (ORDER BY distance) AS vec_rank
                    FROM (
                        SELECT id, embedding <=> $1::vector AS distance
                        FROM {self.table_name}
                        WHERE embedding IS NOT NULL {filter_sql}
                        ORDER BY embedding <=> $1::vector
                        LIMIT $3
                    ) AS nearest
                ),
                text_results AS (
                    SELECT id, ROW_NUMBER() OVER (ORDER BY text_score DESC) AS text_rank
                    FROM (
                        SELECT id, ts_rank_cd(tsv, q) AS text_score
                        FROM {self.table_name}, plainto_tsquery('english', $2) AS q
                        WHERE tsv @@ q {filter_sql}
                        ORDER BY text_score DESC
                        LIMIT $3
                    ) AS matched
                ),
                combined AS (
                    SELECT COALESCE(v.id, t.id) AS id,
                           $4::float8 * COALESCE(1.0 / (60 + v.vec_rank), 0)
                           + (1 - $4::float8) * COALESCE(1.0 / (60 + t.text_rank), 0) AS rrf_score
                    FROM vector_results v
                    FULL OUTER JOIN text_results t ON v.id = t.id
                    ORDER BY rrf_score DESC
                    LIMIT $3
                )
                SELECT c.id, c.text, c.document_id, c.start_idx, c.end_idx,
                       c.metadata, combined.rrf_score AS score
                FROM combined
                JOIN {self.table_name} c ON c.id = combined.id
                ORDER BY combined.rrf_score DESC
                """,
                self._to_vector_literal(query_embedding),
                query_text,
                top_k,
                vector_weight,
                *filter_args,
            )
        return [self._row_to_result(row) for row in rows]

    async def delete(self, chunk_ids: list[str]) -> int:
        """Delete chunks by id and return the row count Postgres reports."""
        if not chunk_ids:
            return 0
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                f"DELETE FROM {self.table_name} WHERE id = ANY($1)", chunk_ids
            )
        # asyncpg returns the command status tag, e.g. "DELETE 3".
        return int(result.split()[-1]) if result else 0

    async def close(self) -> None:
        """Close the connection pool (a later call will create a new one)."""
        if self._pool:
            await self._pool.close()
            self._pool = None

    @staticmethod
    def _build_filter_clause(
        filters: dict[str, Any] | None, param_index: int = 1
    ) -> tuple[str, list[Any]]:
        """Translate metadata equality filters into a parameterized SQL fragment.

        Uses JSONB containment (``metadata @> $N::jsonb``). Keys and values
        therefore travel as one bound parameter instead of being spliced into
        SQL text, and values keep their JSON type (string, number, bool, null).

        Args:
            filters: Required metadata key/value pairs, or ``None``/empty.
            param_index: The ``$N`` placeholder number to bind the filter to.

        Returns:
            ``("AND metadata @> $N::jsonb", [filters_json])``, meant to follow an
            existing ``WHERE`` condition, or ``("", [])`` when there are no filters.
        """
        if not filters:
            return "", []
        return f"AND metadata @> ${param_index}::jsonb", [json.dumps(filters)]

    @staticmethod
    def _to_vector_literal(values: list[float]) -> str:
        """Render a vector as pgvector text input, e.g. ``[0.5,1.0]``.

        Each element is coerced through ``float`` so that numpy scalars (whose
        repr is ``np.float64(...)`` under numpy 2) still serialize correctly.
        """
        return "[" + ",".join(repr(float(v)) for v in values) + "]"

    @staticmethod
    def _row_to_result(row: Any) -> SearchResult:
        """Convert a result row that has a ``score`` column into a SearchResult."""
        raw_metadata = row["metadata"]
        # Without a registered JSONB codec asyncpg returns JSONB values as text.
        if isinstance(raw_metadata, dict):
            metadata = raw_metadata
        elif raw_metadata is None:
            metadata = {}
        else:
            metadata = json.loads(raw_metadata)
        if metadata is None:  # JSON ``null`` stored in the column
            metadata = {}
        chunk = Chunk(
            id=row["id"],
            text=row["text"],
            document_id=row["document_id"],
            start_idx=row["start_idx"],
            end_idx=row["end_idx"],
            metadata=metadata,
        )
        return SearchResult(chunk=chunk, score=float(row["score"]), metadata=metadata)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class InMemoryVectorStore(VectorStore):
    """Dictionary-backed store that ranks chunks by cosine similarity.

    Meant for tests. numpy is imported lazily, only when ``search`` runs.
    """

    def __init__(self) -> None:
        # Chunk id -> chunk; a later upsert with the same id replaces the entry.
        self._chunks: dict[str, Chunk] = {}

    async def upsert(self, chunks: list[Chunk]) -> int:
        """Store every chunk under its id and return how many were written."""
        written = 0
        for written, item in enumerate(chunks, start=1):
            self._chunks[item.id] = item
        return written

    async def search(
        self,
        query_embedding: list[float],
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Rank stored chunks by cosine similarity to ``query_embedding``.

        Chunks must match every ``filters`` entry exactly (``metadata.get``
        equality) and must carry an embedding. A zero-length query vector
        yields no results, and zero-length chunk vectors are skipped.
        """
        import numpy as np

        def _passes(candidate: Chunk) -> bool:
            if not filters:
                return True
            return all(candidate.metadata.get(key) == wanted for key, wanted in filters.items())

        pool = [
            candidate
            for candidate in self._chunks.values()
            if _passes(candidate) and candidate.embedding is not None
        ]
        if not pool:
            return []

        raw_query = np.array(query_embedding, dtype=np.float64)
        query_length = np.linalg.norm(raw_query)
        if query_length == 0:
            return []
        unit_query = raw_query / query_length

        ranked: list[tuple[float, Chunk]] = []
        for candidate in pool:
            raw_vec = np.array(candidate.embedding, dtype=np.float64)
            vec_length = np.linalg.norm(raw_vec)
            if vec_length != 0:
                similarity = float(np.dot(unit_query, raw_vec / vec_length))
                ranked.append((similarity, candidate))

        # Stable sort: equal scores keep insertion order.
        ranked.sort(key=lambda pair: pair[0], reverse=True)
        best = ranked[:top_k]
        return [
            SearchResult(chunk=candidate, score=similarity, metadata=candidate.metadata)
            for similarity, candidate in best
        ]

    async def delete(self, chunk_ids: list[str]) -> int:
        """Remove the listed ids and return how many were actually present."""
        removed = 0
        for cid in chunk_ids:
            try:
                del self._chunks[cid]
            except KeyError:
                continue
            removed += 1
        return removed

    @property
    def size(self) -> int:
        """Number of chunks currently stored."""
        return len(self._chunks)
|