msaas-rag 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,20 @@
1
+ Metadata-Version: 2.4
2
+ Name: msaas-rag
3
+ Version: 1.0.0
4
+ Summary: RAG pipeline library — chunking, embeddings, vector search, and retrieval for the Willian SaaS platform
5
+ License: MIT
6
+ Requires-Python: >=3.12
7
+ Requires-Dist: asyncpg>=0.30.0
8
+ Requires-Dist: pydantic>=2.0
9
+ Provides-Extra: all
10
+ Requires-Dist: numpy>=2.0; extra == 'all'
11
+ Requires-Dist: openai>=1.50.0; extra == 'all'
12
+ Provides-Extra: dev
13
+ Requires-Dist: numpy>=2.0; extra == 'dev'
14
+ Requires-Dist: pytest-asyncio>=0.24.0; extra == 'dev'
15
+ Requires-Dist: pytest>=8.0; extra == 'dev'
16
+ Requires-Dist: ruff>=0.8; extra == 'dev'
17
+ Provides-Extra: numpy
18
+ Requires-Dist: numpy>=2.0; extra == 'numpy'
19
+ Provides-Extra: openai
20
+ Requires-Dist: openai>=1.50.0; extra == 'openai'
@@ -0,0 +1,11 @@
1
+ rag/__init__.py,sha256=ca84grhMOHLh_cfFsJtbAU-jp2DkqEPjAjJwO65tdGo,1235
2
+ rag/chunking.py,sha256=XpuIU44VMyHlp_YuSKbIZpEvQct1VsryCs-qHDuufgA,7538
3
+ rag/config.py,sha256=ITny9Edku1M82KdzfeZ-SmZA9wotR5P7O8cUO6JDKgQ,1240
4
+ rag/embeddings.py,sha256=UwK63qcJEBvkXPXkYCbljx_pSWRG-DNZRmAkrln840o,4870
5
+ rag/models.py,sha256=NdtpR6X9I5lSnMISWOQ9c8iY4ranBg6qsME4PxvPUg8,1777
6
+ rag/pipeline.py,sha256=DAmgmzwa_BnnJ27hgf9l0-e0rPoAfVoNN3br-1NATAc,3631
7
+ rag/reranker.py,sha256=2a7veHSoIA5SgOGxEMzVLAdzKizY0O9jbkz8EBN7sl8,3605
8
+ rag/vector_store.py,sha256=8IuMhv9fkexbkd2BSpyKoQQWqjzzS37gw-nJDJZJM8w,10659
9
+ msaas_rag-1.0.0.dist-info/METADATA,sha256=79VoK8PKNFWy-xg35JoHjVbJEiSFErWcYhdjrH_YsH8,711
10
+ msaas_rag-1.0.0.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
11
+ msaas_rag-1.0.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.30.1
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
rag/__init__.py ADDED
@@ -0,0 +1,54 @@
1
+ """Willian RAG — retrieval-augmented generation pipeline library."""
2
+
3
+ from rag.chunking import MarkdownChunker, TextChunker
4
+ from rag.config import get_config, get_pipeline, init_rag
5
+ from rag.embeddings import EmbeddingProvider, LocalEmbeddings, OpenAIEmbeddings
6
+ from rag.models import (
7
+ Chunk,
8
+ ChunkingConfig,
9
+ Document,
10
+ EmbeddingConfig,
11
+ RAGConfig,
12
+ SearchResult,
13
+ VectorStoreConfig,
14
+ )
15
+ from rag.pipeline import RAGPipeline
16
+ from rag.reranker import (
17
+ CrossEncoderReranker,
18
+ LLMReranker,
19
+ Reranker,
20
+ reciprocal_rank_fusion,
21
+ )
22
+ from rag.vector_store import InMemoryVectorStore, PgVectorStore, VectorStore
23
+
24
+ __all__ = [
25
+ # Pipeline
26
+ "RAGPipeline",
27
+ "init_rag",
28
+ "get_config",
29
+ "get_pipeline",
30
+ # Models
31
+ "Document",
32
+ "Chunk",
33
+ "SearchResult",
34
+ "RAGConfig",
35
+ "ChunkingConfig",
36
+ "EmbeddingConfig",
37
+ "VectorStoreConfig",
38
+ # Chunking
39
+ "TextChunker",
40
+ "MarkdownChunker",
41
+ # Embeddings
42
+ "EmbeddingProvider",
43
+ "OpenAIEmbeddings",
44
+ "LocalEmbeddings",
45
+ # Vector stores
46
+ "VectorStore",
47
+ "PgVectorStore",
48
+ "InMemoryVectorStore",
49
+ # Reranking
50
+ "Reranker",
51
+ "CrossEncoderReranker",
52
+ "LLMReranker",
53
+ "reciprocal_rank_fusion",
54
+ ]
rag/chunking.py ADDED
@@ -0,0 +1,202 @@
1
+ """Text chunking strategies for document processing."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ import uuid
7
+ from typing import Any
8
+
9
+ from rag.models import Chunk
10
+
11
+
12
class TextChunker:
    """Split plain text into chunks using configurable strategies.

    Supported strategies:
    - ``fixed_size``: split at exact character boundaries
    - ``recursive``: split by paragraph > line > sentence > word, preferring natural breaks
    - ``semantic``: sentence-based splitting that groups sentences up to chunk_size

    ``chunk_size`` and ``chunk_overlap`` are measured in characters.
    """

    # Tried in order by the recursive strategy, coarsest boundary first.
    SEPARATORS = ["\n\n", "\n", ". ", " "]

    def __init__(
        self,
        strategy: str = "recursive",
        chunk_size: int = 512,
        chunk_overlap: int = 64,
    ) -> None:
        """Configure the chunker.

        Raises:
            ValueError: if *strategy* is unknown, *chunk_size* is smaller than 1,
                or *chunk_overlap* is negative.
        """
        if strategy not in ("fixed_size", "recursive", "semantic"):
            msg = f"Unknown strategy: {strategy!r}. Use fixed_size, recursive, or semantic."
            raise ValueError(msg)
        if chunk_size < 1:
            msg = f"chunk_size must be at least 1, got {chunk_size!r}"
            raise ValueError(msg)
        if chunk_overlap < 0:
            msg = f"chunk_overlap must not be negative, got {chunk_overlap!r}"
            raise ValueError(msg)
        self.strategy = strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def chunk(
        self,
        text: str,
        document_id: str,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """Split *text* into Chunk objects using the configured strategy.

        Every returned chunk carries *document_id* and *metadata* (an empty
        dict when none is given).
        """
        meta = metadata or {}
        if self.strategy == "fixed_size":
            return self._fixed_size(text, document_id, meta)
        if self.strategy == "semantic":
            return self._semantic(text, document_id, meta)
        return self._recursive(text, document_id, meta)

    # -- strategies ----------------------------------------------------------

    def _fixed_size(self, text: str, document_id: str, meta: dict[str, Any]) -> list[Chunk]:
        """Cut *text* into windows of ``chunk_size`` characters.

        Consecutive windows start ``chunk_size - chunk_overlap`` characters
        apart (at least 1, so the loop always advances).
        """
        chunks: list[Chunk] = []
        step = max(1, self.chunk_size - self.chunk_overlap)
        for start in range(0, len(text), step):
            end = min(start + self.chunk_size, len(text))
            chunks.append(self._make(text[start:end], document_id, start, end, meta))
            if end == len(text):
                # The window already reached the end; later windows would only
                # repeat the tail.
                break
        return chunks

    def _recursive(self, text: str, document_id: str, meta: dict[str, Any]) -> list[Chunk]:
        """Split *text* on ``SEPARATORS``; blank text yields no chunks."""
        if not text.strip():
            # Matches the other strategies, which never emit a blank chunk.
            return []
        return self._recursive_split(text, document_id, meta, self.SEPARATORS)

    def _recursive_split(
        self,
        text: str,
        document_id: str,
        meta: dict[str, Any],
        separators: list[str],
        base_offset: int = 0,
    ) -> list[Chunk]:
        """Greedily pack the pieces of *text* split on the first separator.

        Pieces are joined back with the separator until adding the next one
        would exceed ``chunk_size``. A single piece that is too large on its
        own is split again with the remaining (finer) separators.

        *base_offset* is the position of *text* inside the original document,
        so the ``start_idx`` / ``end_idx`` of nested splits stay absolute.
        """
        if len(text) <= self.chunk_size:
            return [self._make(text, document_id, base_offset, base_offset + len(text), meta)]

        sep = separators[0] if separators else ""
        parts = text.split(sep) if sep else list(text)
        remaining_seps = separators[1:] if separators else []

        chunks: list[Chunk] = []
        current = ""
        current_start = 0  # index in *text* where ``current`` begins
        cursor = 0  # index in *text* where the next part begins

        for part in parts:
            part_start = cursor
            cursor += len(part) + len(sep)
            candidate = current + sep + part if current else part

            if len(candidate) <= self.chunk_size:
                if not current:
                    current_start = part_start
                current = candidate
            elif current:
                # Adding this part would overflow: emit what has accumulated.
                chunks.append(
                    self._make(
                        current,
                        document_id,
                        base_offset + current_start,
                        base_offset + current_start + len(current),
                        meta,
                    )
                )
                overlap_start = max(0, len(current) - self.chunk_overlap)
                if overlap_start < len(current):
                    # Carry the tail of the emitted chunk into the next one.
                    current_start += overlap_start
                    current = current[overlap_start:] + sep + part
                else:
                    current_start = part_start
                    current = part
            elif remaining_seps:
                # A single part is too large on its own: split it more finely.
                chunks.extend(
                    self._recursive_split(
                        part, document_id, meta, remaining_seps, base_offset + part_start
                    )
                )
                current = ""
            else:
                # Oversized part and no finer separator left: keep it whole.
                current_start = part_start
                current = candidate

        if current.strip():
            chunks.append(
                self._make(
                    current,
                    document_id,
                    base_offset + current_start,
                    base_offset + current_start + len(current),
                    meta,
                )
            )
        return chunks

    def _semantic(self, text: str, document_id: str, meta: dict[str, Any]) -> list[Chunk]:
        """Sentence-based chunking: group sentences up to chunk_size.

        Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace
        and re-joined with a single space. Offsets advance by one character per
        sentence break, so they are exact only when sentences are separated by
        a single whitespace character.
        """
        sentences = re.split(r"(?<=[.!?])\s+", text)
        chunks: list[Chunk] = []
        current = ""
        offset = 0

        for sentence in sentences:
            candidate = (current + " " + sentence).strip() if current else sentence
            if len(candidate) > self.chunk_size and current:
                end = offset + len(current)
                chunks.append(self._make(current, document_id, offset, end, meta))
                offset = end + 1
                current = sentence
            else:
                current = candidate
        if current.strip():
            chunks.append(self._make(current, document_id, offset, offset + len(current), meta))
        return chunks

    @staticmethod
    def _make(text: str, document_id: str, start: int, end: int, meta: dict[str, Any]) -> Chunk:
        """Build a Chunk with a fresh random hex id."""
        return Chunk(
            id=uuid.uuid4().hex,
            text=text,
            document_id=document_id,
            start_idx=start,
            end_idx=end,
            metadata=meta,
        )
138
+
139
+
140
class MarkdownChunker:
    """Split Markdown documents by headers, preserving structure.

    Each header section becomes a chunk. Sections exceeding ``chunk_size`` are
    further split by a recursive ``TextChunker`` built with the same
    ``chunk_size`` and ``chunk_overlap``.

    NOTE(review): headers are found with a line-anchored regex, so a line
    starting with ``#`` inside a fenced code block is also treated as a header.
    """

    HEADER_RE = re.compile(r"^(#{1,6})\s+(.*)", re.MULTILINE)

    def __init__(self, chunk_size: int = 1024, chunk_overlap: int = 64) -> None:
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self._fallback = TextChunker(
            strategy="recursive", chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )

    def chunk(
        self,
        text: str,
        document_id: str,
        metadata: dict[str, Any] | None = None,
    ) -> list[Chunk]:
        """Split *text* into one or more chunks per header section.

        Each chunk's metadata is a copy of *metadata*; chunks that belong to a
        header section also get a ``"header"`` entry holding the header line.
        Sections whose text is blank are skipped.
        """
        meta = metadata or {}
        sections = self._split_by_headers(text)
        chunks: list[Chunk] = []

        for header, body, start in sections:
            section_meta = {**meta}
            if header:
                section_meta["header"] = header
                full_text = f"{header}\n{body}".strip()
            else:
                # The chunk text is stripped, so move the start past any
                # leading whitespace to keep the offset aligned with the text.
                start += len(body) - len(body.lstrip())
                full_text = body.strip()

            if not full_text:
                continue

            if len(full_text) <= self.chunk_size:
                chunks.append(
                    TextChunker._make(
                        full_text, document_id, start, start + len(full_text), section_meta
                    )
                )
            else:
                sub = self._fallback.chunk(full_text, document_id, section_meta)
                # The fallback numbers offsets from the start of the section
                # text; shift them so they are relative to the whole document.
                for piece in sub:
                    piece.start_idx += start
                    piece.end_idx += start
                chunks.extend(sub)
        return chunks

    def _split_by_headers(self, text: str) -> list[tuple[str, str, int]]:
        """Return list of (header_line, body_text, start_index).

        ``header_line`` is ``""`` for text that precedes the first header (or
        for a document without headers). ``start_index`` is the position in
        *text* where the section begins.
        """
        matches = list(self.HEADER_RE.finditer(text))
        if not matches:
            return [("", text, 0)]

        sections: list[tuple[str, str, int]] = []
        # Text before first header
        if matches[0].start() > 0:
            preamble = text[: matches[0].start()]
            if preamble.strip():
                leading = len(preamble) - len(preamble.lstrip())
                sections.append(("", preamble.strip(), leading))

        for i, match in enumerate(matches):
            header_line = match.group(0)
            body_start = match.end()
            body_end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
            body = text[body_start:body_end].strip()
            sections.append((header_line, body, match.start()))

        return sections
rag/config.py ADDED
@@ -0,0 +1,46 @@
1
+ """Global configuration and pipeline singleton management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ from rag.models import RAGConfig
8
+
9
+ if TYPE_CHECKING:
10
+ from rag.pipeline import RAGPipeline
11
+
12
+ _pipeline: RAGPipeline | None = None
13
+ _config: RAGConfig | None = None
14
+
15
+
16
def init_rag(config: RAGConfig | None = None) -> RAGConfig:
    """Install *config* as the global RAG configuration.

    A default ``RAGConfig`` is created when no config is passed. The cached
    pipeline singleton is dropped so the next ``get_pipeline()`` call builds a
    new one from this configuration. The active config is returned.
    """
    global _config, _pipeline
    active = config if config else RAGConfig()
    # Forget the cached pipeline: it was built from the previous config.
    _pipeline = None
    _config = active
    return active
26
+
27
+
28
def get_config() -> RAGConfig:
    """Return the global RAG config, creating a default one on first use."""
    global _config
    if _config is not None:
        return _config
    _config = RAGConfig()
    return _config
34
+
35
+
36
def get_pipeline() -> RAGPipeline:
    """Return the global RAGPipeline singleton, building it on first use.

    The pipeline is constructed from ``get_config()``; call ``init_rag()``
    beforehand to customize the configuration.
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    # Imported here rather than at module level (the module-level import is
    # guarded by TYPE_CHECKING).
    from rag.pipeline import RAGPipeline

    _pipeline = RAGPipeline(config=get_config())
    return _pipeline
rag/embeddings.py ADDED
@@ -0,0 +1,136 @@
1
+ """Embedding providers for converting text to vector representations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import hashlib
7
+ from abc import ABC, abstractmethod
8
+ from typing import Any
9
+
10
+
11
class EmbeddingProvider(ABC):
    """Abstract base for embedding providers.

    Subclasses turn text into fixed-length float vectors. Both methods are
    abstract, so a subclass must implement them to be instantiable.
    """

    @abstractmethod
    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Return embedding vectors for each text in *texts*.

        The result has one vector per input text, in the same order.
        """

    @abstractmethod
    def dimensions(self) -> int:
        """Return the dimensionality of the embedding vectors."""
21
+
22
+
23
class OpenAIEmbeddings(EmbeddingProvider):
    """OpenAI embedding provider with caching, batching, and retries.

    Requires the ``openai`` extra: ``pip install msaas-rag[openai]``

    Embeddings are cached in memory per instance, keyed by the SHA-256 of the
    text. The cache is unbounded and lives as long as the instance.
    """

    # Vector sizes of the known models; other models fall back to 1536.
    MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: str | None = None,
        batch_size: int = 100,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ) -> None:
        """Create the provider and its async OpenAI client.

        Args:
            model: embedding model name sent with every request.
            api_key: passed to ``openai.AsyncOpenAI``.
            batch_size: maximum number of texts per API request.
            max_retries: number of attempts per batch (at least one is made).
            retry_delay: base delay in seconds; attempt *n* waits ``n * retry_delay``.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        try:
            import openai
        except ImportError as exc:
            msg = "Install the openai extra: pip install msaas-rag[openai]"
            raise ImportError(msg) from exc

        self.model = model
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._client = openai.AsyncOpenAI(api_key=api_key)
        self._cache: dict[str, list[float]] = {}

    def dimensions(self) -> int:
        """Return the vector size for ``self.model`` (1536 if the model is unknown)."""
        return self.MODEL_DIMENSIONS.get(self.model, 1536)

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts*, using the cache and batching the remaining requests.

        Texts already in the cache are not sent to the API, and a text that
        appears several times in *texts* is sent only once.

        Raises:
            RuntimeError: if a batch keeps failing, or the API returns a
                different number of vectors than texts sent.
        """
        results: dict[int, list[float]] = {}
        # cache key -> (text, positions in *texts* waiting for that embedding)
        pending: dict[str, tuple[str, list[int]]] = {}

        for i, text in enumerate(texts):
            key = self._cache_key(text)
            if key in self._cache:
                results[i] = self._cache[key]
            elif key in pending:
                pending[key][1].append(i)
            else:
                pending[key] = (text, [i])

        # Process uncached texts in batches; guard against a non-positive
        # batch_size, which would make range() fail or never advance.
        todo = list(pending.items())
        step = max(1, self.batch_size)
        for batch_start in range(0, len(todo), step):
            batch = todo[batch_start : batch_start + step]
            embeddings = await self._embed_with_retry([text for _, (text, _) in batch])
            if len(embeddings) != len(batch):
                msg = f"Expected {len(batch)} embeddings, got {len(embeddings)}"
                raise RuntimeError(msg)
            for (key, (_, positions)), emb in zip(batch, embeddings):
                self._cache[key] = emb
                for pos in positions:
                    results[pos] = emb

        return [results[i] for i in range(len(texts))]

    async def _embed_with_retry(self, texts: list[str]) -> list[list[float]]:
        """Request embeddings for one batch, retrying on any exception.

        The delay before the next attempt grows linearly with the attempt
        number. The last error is chained to the raised ``RuntimeError``.
        """
        attempts = max(1, self.max_retries)  # always try at least once
        last_error: Exception | None = None
        for attempt in range(attempts):
            try:
                response = await self._client.embeddings.create(model=self.model, input=texts)
                return [item.embedding for item in response.data]
            except Exception as exc:
                last_error = exc
                if attempt < attempts - 1:
                    await asyncio.sleep(self.retry_delay * (attempt + 1))
        msg = f"Embedding failed after {attempts} retries"
        raise RuntimeError(msg) from last_error

    @staticmethod
    def _cache_key(text: str) -> str:
        """Return the SHA-256 hex digest of *text*, used as the cache key."""
        return hashlib.sha256(text.encode()).hexdigest()
98
+
99
+
100
class LocalEmbeddings(EmbeddingProvider):
    """Stand-in for a local (sentence-transformers style) embedding model.

    Produces deterministic pseudo-embeddings derived from a hash of the text,
    which is enough for tests. Swap ``embed`` for a real model before using
    this in production.
    """

    def __init__(self, dims: int = 384) -> None:
        self._dims = dims

    def dimensions(self) -> int:
        return self._dims

    async def embed(self, texts: list[str]) -> list[list[float]]:
        vectors: list[list[float]] = []
        for text in texts:
            vectors.append(self._pseudo_embedding(text))
        return vectors

    def _pseudo_embedding(self, text: str) -> list[float]:
        """Derive a unit-length vector of ``self._dims`` floats from the text's SHA-256."""
        digest = hashlib.sha256(text.encode()).digest()
        values = [byte / 255.0 for byte in digest]

        # Repeat the 32 digest values until there are enough, then cut to size.
        while len(values) < self._dims:
            values = values + values
        values = values[: self._dims]

        # Scale to unit length; an all-zero vector stays all zeros.
        norm = sum(v * v for v in values) ** 0.5
        if norm > 0:
            return [v / norm for v in values]
        return [0.0 for _ in values]
127
+
128
+
129
+ def _build_provider(config: Any) -> EmbeddingProvider:
130
+ """Factory: create an EmbeddingProvider from an EmbeddingConfig."""
131
+ if config.provider == "openai":
132
+ return OpenAIEmbeddings(model=config.model, batch_size=config.batch_size)
133
+ if config.provider == "local":
134
+ return LocalEmbeddings(dims=config.dimensions)
135
+ msg = f"Unknown embedding provider: {config.provider!r}"
136
+ raise ValueError(msg)
rag/models.py ADDED
@@ -0,0 +1,73 @@
1
+ """Core domain models for the RAG pipeline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import uuid
6
+ from typing import Any
7
+
8
+ from pydantic import BaseModel, Field
9
+
10
+
11
+ def _default_id() -> str:
12
+ return uuid.uuid4().hex
13
+
14
+
15
class Document(BaseModel):
    """A source document to be ingested into the RAG pipeline."""

    # Identifier; a random UUID4 hex string is generated when omitted.
    id: str = Field(default_factory=_default_id)
    # Full document text; required.
    text: str
    # Free-form metadata; defaults to a new empty dict per instance.
    metadata: dict[str, Any] = Field(default_factory=dict)
21
+
22
+
23
class Chunk(BaseModel):
    """A text chunk derived from a document."""

    # Identifier; a random UUID4 hex string is generated when omitted.
    id: str = Field(default_factory=_default_id)
    # The chunk's text; required.
    text: str
    # Id of the document this chunk was cut from.
    document_id: str
    # Start and end position of the chunk; presumably character offsets into
    # the source document text, verify against the chunker in use.
    start_idx: int
    end_idx: int
    # Free-form metadata; defaults to a new empty dict per instance.
    metadata: dict[str, Any] = Field(default_factory=dict)
    # Embedding vector; None until one has been assigned.
    embedding: list[float] | None = None
33
+
34
+
35
class SearchResult(BaseModel):
    """A single result returned from vector search."""

    # The matched chunk.
    chunk: Chunk
    # Relevance score assigned by whichever component produced the result.
    score: float
    # Free-form metadata; defaults to a new empty dict per instance.
    metadata: dict[str, Any] = Field(default_factory=dict)
41
+
42
+
43
class ChunkingConfig(BaseModel):
    """Configuration for text chunking."""

    # Name of the chunking strategy.
    strategy: str = "recursive"
    # Target size of a chunk.
    chunk_size: int = 512
    # Amount shared between consecutive chunks.
    chunk_overlap: int = 64
49
+
50
+
51
class EmbeddingConfig(BaseModel):
    """Configuration for the embedding provider."""

    # Which provider to build.
    provider: str = "openai"
    # Model name handed to the provider.
    model: str = "text-embedding-3-small"
    # Length of the embedding vectors.
    dimensions: int = 1536
    # Number of texts per embedding request.
    batch_size: int = 100
58
+
59
+
60
class VectorStoreConfig(BaseModel):
    """Configuration for the vector store backend."""

    # Name of the backend.
    backend: str = "pgvector"
    # Table that holds the chunks.
    table_name: str = "rag_chunks"
    # Connection string; empty by default.
    dsn: str = ""
66
+
67
+
68
class RAGConfig(BaseModel):
    """Top-level configuration for the entire RAG pipeline.

    Each section defaults to a new instance of its config class with that
    class's default values.
    """

    chunking: ChunkingConfig = Field(default_factory=ChunkingConfig)
    embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
    vector_store: VectorStoreConfig = Field(default_factory=VectorStoreConfig)
rag/pipeline.py ADDED
@@ -0,0 +1,110 @@
1
+ """RAG pipeline orchestrating chunking, embedding, storage, and retrieval."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ from rag.chunking import MarkdownChunker, TextChunker
8
+ from rag.embeddings import LocalEmbeddings, _build_provider
9
+ from rag.models import Chunk, Document, RAGConfig, SearchResult
10
+ from rag.reranker import Reranker
11
+ from rag.vector_store import InMemoryVectorStore, VectorStore
12
+
13
+ if TYPE_CHECKING:
14
+ from rag.embeddings import EmbeddingProvider
15
+
16
+
17
class RAGPipeline:
    """Orchestrates the full RAG flow: chunk -> embed -> store -> retrieve.

    Can be configured via RAGConfig or by injecting components directly.
    Injected components always take precedence over the config.
    """

    def __init__(
        self,
        config: RAGConfig | None = None,
        chunker: TextChunker | MarkdownChunker | None = None,
        embedding_provider: EmbeddingProvider | None = None,
        vector_store: VectorStore | None = None,
    ) -> None:
        """Assemble the pipeline from injected components or from *config*.

        Components that are not injected are built as follows: a
        ``TextChunker`` from ``config.chunking``, an embedding provider from
        ``config.embedding`` (with a local fallback, see ``_safe_provider``),
        and an ``InMemoryVectorStore``.
        """
        cfg = config if config is not None else RAGConfig()

        # Compare against None rather than truthiness so that an injected
        # component that happens to be falsy (e.g. an empty store defining
        # __len__) is not silently replaced.
        if chunker is not None:
            self.chunker = chunker
        else:
            self.chunker = TextChunker(
                strategy=cfg.chunking.strategy,
                chunk_size=cfg.chunking.chunk_size,
                chunk_overlap=cfg.chunking.chunk_overlap,
            )

        self.embedding_provider: EmbeddingProvider = (
            embedding_provider if embedding_provider is not None else self._safe_provider(cfg)
        )

        # NOTE(review): cfg.vector_store (backend / dsn / table_name) is not
        # consulted here; without an injected store the pipeline always uses
        # the in-memory store. Confirm whether that is intended.
        self.vector_store = vector_store if vector_store is not None else InMemoryVectorStore()

    @staticmethod
    def _safe_provider(cfg: RAGConfig) -> EmbeddingProvider:
        """Try to build the configured provider, fall back to local.

        The fallback produces deterministic pseudo-embeddings, so the failure
        is logged as a warning instead of being swallowed silently.
        """
        try:
            return _build_provider(cfg.embedding)
        except Exception:
            import logging

            logging.getLogger(__name__).warning(
                "Could not build embedding provider %r; falling back to LocalEmbeddings",
                cfg.embedding.provider,
                exc_info=True,
            )
            return LocalEmbeddings(dims=cfg.embedding.dimensions)

    async def ingest(self, documents: list[Document]) -> int:
        """Process documents through the full pipeline: chunk -> embed -> store.

        Returns the total number of chunks stored (0 when the documents yield
        no chunks).

        Raises:
            RuntimeError: if the provider returns a different number of
                embeddings than chunks.
        """
        all_chunks: list[Chunk] = []

        for doc in documents:
            all_chunks.extend(
                self.chunker.chunk(
                    text=doc.text,
                    document_id=doc.id,
                    metadata=doc.metadata,
                )
            )

        if not all_chunks:
            return 0

        # Embed all chunks in one call; the provider decides how to batch.
        embeddings = await self.embedding_provider.embed([c.text for c in all_chunks])
        if len(embeddings) != len(all_chunks):
            # zip() would silently drop the surplus and store chunks without
            # an embedding.
            msg = f"Expected {len(all_chunks)} embeddings, got {len(embeddings)}"
            raise RuntimeError(msg)

        for chunk, embedding in zip(all_chunks, embeddings):
            chunk.embedding = embedding

        return await self.vector_store.upsert(all_chunks)

    async def retrieve(
        self,
        query: str,
        top_k: int = 10,
        filters: dict | None = None,
    ) -> list[SearchResult]:
        """Embed the query and search the vector store.

        *filters* is passed through unchanged to the vector store's ``search``.
        """
        query_embeddings = await self.embedding_provider.embed([query])
        return await self.vector_store.search(
            query_embedding=query_embeddings[0],
            top_k=top_k,
            filters=filters,
        )

    async def retrieve_with_rerank(
        self,
        query: str,
        top_k: int = 10,
        reranker: Reranker | None = None,
        initial_k: int | None = None,
        filters: dict | None = None,
    ) -> list[SearchResult]:
        """Retrieve a larger set, then rerank down to top_k.

        ``initial_k`` is the size of the candidate set handed to the reranker;
        it defaults to ``top_k * 3``.

        If no reranker is provided, behaves like plain retrieve.
        """
        if reranker is None:
            # Nothing will reorder the candidates, so fetching more than
            # top_k only to truncate would be wasted work.
            return await self.retrieve(query, top_k=top_k, filters=filters)

        fetch_k = initial_k or top_k * 3
        results = await self.retrieve(query, top_k=fetch_k, filters=filters)
        return await reranker.rerank(query, results, top_k=top_k)
rag/reranker.py ADDED
@@ -0,0 +1,113 @@
1
+ """Reranking strategies for improving retrieval quality."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+
7
+ from rag.models import SearchResult
8
+
9
+
10
class Reranker(ABC):
    """Abstract base for reranking retrieved results.

    ``rerank`` is abstract, so a subclass must implement it to be instantiable.
    """

    @abstractmethod
    async def rerank(
        self, query: str, results: list[SearchResult], top_k: int = 10
    ) -> list[SearchResult]:
        """Rerank results by relevance to *query*. Returns top_k results."""
18
+
19
+
20
class CrossEncoderReranker(Reranker):
    """Stand-in for a cross-encoder reranking model.

    A production version would load a cross-encoder (e.g., ms-marco-MiniLM)
    and score every (query, chunk.text) pair. This placeholder scores each
    result by the fraction of query terms that also occur in the chunk text,
    which is enough for testing.
    """

    async def rerank(
        self, query: str, results: list[SearchResult], top_k: int = 10
    ) -> list[SearchResult]:
        """Return the *top_k* results with the highest keyword overlap.

        Terms are lowercased, whitespace-separated tokens. The returned
        results carry the overlap ratio as their score.
        """
        wanted = set(query.lower().split())
        # Avoid dividing by zero for a query without any terms.
        denominator = len(wanted) if wanted else 1

        pairs = [
            (len(wanted & set(hit.chunk.text.lower().split())) / denominator, hit)
            for hit in results
        ]
        ordered = sorted(pairs, key=lambda pair: pair[0], reverse=True)

        reranked: list[SearchResult] = []
        for relevance, hit in ordered[:top_k]:
            reranked.append(
                SearchResult(chunk=hit.chunk, score=relevance, metadata=hit.metadata)
            )
        return reranked
50
+
51
+
52
class LLMReranker(Reranker):
    """Reranker that uses an LLM to score relevance.

    Accepts a callable ``score_fn`` that takes (query, text) and returns a
    relevance score convertible to ``float``. The callable may be a plain
    function or a coroutine function; an awaitable result is awaited. This
    allows integration with willian-ai or any LLM client.
    """

    def __init__(
        self,
        score_fn: callable | None = None,
    ) -> None:
        self._score_fn = score_fn

    async def rerank(
        self, query: str, results: list[SearchResult], top_k: int = 10
    ) -> list[SearchResult]:
        """Score every result with ``score_fn`` and return the *top_k* best.

        Without a scoring function the first *top_k* results are returned
        unchanged. Results are scored one at a time, in order.
        """
        if self._score_fn is None:
            # No scoring function — return as-is
            return results[:top_k]

        import inspect

        scored: list[tuple[float, SearchResult]] = []
        for result in results:
            outcome = self._score_fn(query, result.chunk.text)
            if inspect.isawaitable(outcome):
                # Async scorers hand back a coroutine; sync scorers return
                # the score directly.
                outcome = await outcome
            scored.append((float(outcome), result))

        scored.sort(key=lambda x: x[0], reverse=True)
        return [
            SearchResult(chunk=r.chunk, score=score, metadata=r.metadata)
            for score, r in scored[:top_k]
        ]
83
+
84
+
85
def reciprocal_rank_fusion(
    *result_lists: list[SearchResult],
    k: int = 60,
    top_n: int = 10,
) -> list[SearchResult]:
    """Fuse several ranked result lists with Reciprocal Rank Fusion (RRF).

    Each chunk scores the sum, over all lists it appears in, of
    ``1 / (k + rank)`` with ranks starting at 1. The default k=60 follows the
    original RRF paper. The *top_n* highest-scoring chunks are returned; chunk
    and metadata are taken from the first list in which the chunk was seen.
    """
    totals: dict[str, float] = {}
    first_seen: dict[str, SearchResult] = {}

    for ranking in result_lists:
        for position, hit in enumerate(ranking, start=1):
            chunk_id = hit.chunk.id
            first_seen.setdefault(chunk_id, hit)
            totals[chunk_id] = totals.get(chunk_id, 0.0) + 1.0 / (k + position)

    # sorted() is stable, so chunks with equal scores keep first-seen order.
    winners = sorted(totals, key=totals.get, reverse=True)[:top_n]

    fused: list[SearchResult] = []
    for chunk_id in winners:
        source = first_seen[chunk_id]
        fused.append(
            SearchResult(chunk=source.chunk, score=totals[chunk_id], metadata=source.metadata)
        )
    return fused
rag/vector_store.py ADDED
@@ -0,0 +1,293 @@
1
+ """Vector store backends for chunk storage and similarity search."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from abc import ABC, abstractmethod
7
+ from typing import Any
8
+
9
+ from rag.models import Chunk, SearchResult
10
+
11
+
12
class VectorStore(ABC):
    """Abstract base for vector store backends.

    All three methods are abstract, so a subclass must implement them to be
    instantiable.
    """

    @abstractmethod
    async def upsert(self, chunks: list[Chunk]) -> int:
        """Insert or update chunks. Returns the number of affected rows."""

    @abstractmethod
    async def search(
        self,
        query_embedding: list[float],
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Search by cosine similarity. Returns results ordered by score descending.

        At most *top_k* results are returned. *filters* optionally restricts
        the candidates; how it is matched is up to the backend.
        """

    @abstractmethod
    async def delete(self, chunk_ids: list[str]) -> int:
        """Delete chunks by ID. Returns the number of deleted rows."""
31
+
32
+
33
class PgVectorStore(VectorStore):
    """PostgreSQL + pgvector backend using asyncpg.

    Supports both pure vector search and hybrid (vector + full-text) search
    via Reciprocal Rank Fusion.

    Metadata filters are matched with jsonb containment
    (``metadata @> filters``) and are always sent as a bound parameter.
    """

    def __init__(self, dsn: str, table_name: str = "rag_chunks", dimensions: int = 1536) -> None:
        """Store connection settings; the pool is created lazily on first use.

        Raises:
            ValueError: if *table_name* is not a plain (optionally
                schema-qualified) identifier, or *dimensions* is not a
                positive integer.
        """
        # Identifiers and the vector size cannot be bound as query parameters
        # and are interpolated into SQL text, so they are validated up front.
        if not table_name or not all(part.isidentifier() for part in table_name.split(".")):
            msg = f"Invalid table name: {table_name!r}"
            raise ValueError(msg)
        if not isinstance(dimensions, int) or dimensions < 1:
            msg = f"dimensions must be a positive integer, got {dimensions!r}"
            raise ValueError(msg)
        self.dsn = dsn
        self.table_name = table_name
        self.dimensions = dimensions
        self._pool: Any = None

    async def _get_pool(self) -> Any:
        """Return the asyncpg pool, creating it on first use."""
        if self._pool is None:
            import asyncpg

            self._pool = await asyncpg.create_pool(self.dsn, min_size=2, max_size=10)
        return self._pool

    async def create_table(self) -> None:
        """Create the chunks table with vector and tsvector columns.

        Also enables the ``vector`` extension and creates an ivfflat index on
        the embedding and a GIN index on the generated tsvector column.
        """
        # Index names cannot contain a dot, so flatten a schema-qualified name.
        index_base = self.table_name.replace(".", "_")
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            await conn.execute(f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    id TEXT PRIMARY KEY,
                    text TEXT NOT NULL,
                    document_id TEXT NOT NULL,
                    start_idx INTEGER NOT NULL,
                    end_idx INTEGER NOT NULL,
                    metadata JSONB DEFAULT '{{}}',
                    embedding vector({self.dimensions}),
                    tsv tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
                )
            """)
            await conn.execute(
                f"CREATE INDEX IF NOT EXISTS idx_{index_base}_embedding "
                f"ON {self.table_name} USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100)"
            )
            await conn.execute(
                f"CREATE INDEX IF NOT EXISTS idx_{index_base}_tsv "
                f"ON {self.table_name} USING gin (tsv)"
            )

    async def upsert(self, chunks: list[Chunk]) -> int:
        """Bulk upsert chunks with ON CONFLICT.

        An existing row with the same id has all of its columns replaced.
        Returns the number of chunks submitted.
        """
        if not chunks:
            return 0
        pool = await self._get_pool()
        values = [
            (
                c.id,
                c.text,
                c.document_id,
                c.start_idx,
                c.end_idx,
                json.dumps(c.metadata),
                # The text form of a Python float list is cast to vector by
                # the statement; a missing or empty embedding is stored as NULL.
                str(c.embedding) if c.embedding else None,
            )
            for c in chunks
        ]
        async with pool.acquire() as conn:
            await conn.executemany(
                f"""
                INSERT INTO {self.table_name} (id, text, document_id, start_idx, end_idx, metadata, embedding)
                VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7::vector)
                ON CONFLICT (id) DO UPDATE SET
                    text = EXCLUDED.text,
                    document_id = EXCLUDED.document_id,
                    start_idx = EXCLUDED.start_idx,
                    end_idx = EXCLUDED.end_idx,
                    metadata = EXCLUDED.metadata,
                    embedding = EXCLUDED.embedding
                """,
                values,
            )
        return len(chunks)

    async def search(
        self,
        query_embedding: list[float],
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Cosine similarity search.

        Returns up to *top_k* results, nearest first; the score is
        ``1 - cosine distance``.
        """
        pool = await self._get_pool()
        args: list[Any] = [str(query_embedding), top_k]
        filter_clause = ""
        if filters:
            args.append(json.dumps(filters))
            filter_clause = self._build_filter_clause(filters, param_index=len(args))

        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"""
                SELECT id, text, document_id, start_idx, end_idx, metadata,
                       1 - (embedding <=> $1::vector) AS score
                FROM {self.table_name}
                {filter_clause}
                ORDER BY embedding <=> $1::vector
                LIMIT $2
                """,
                *args,
            )
        return [self._row_to_result(row) for row in rows]

    async def hybrid_search(
        self,
        query_text: str,
        query_embedding: list[float],
        top_k: int = 10,
        vector_weight: float = 0.7,
        filters: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Hybrid search combining vector similarity and full-text search via RRF.

        The *top_k* nearest rows by embedding and the *top_k* best full-text
        matches are fused; a row's score is
        ``vector_weight / (60 + vec_rank) + (1 - vector_weight) / (60 + text_rank)``,
        where a row missing from one list contributes 0 for that list.
        """
        pool = await self._get_pool()
        args: list[Any] = [str(query_embedding), query_text, top_k, vector_weight]
        where_sql = ""
        and_sql = ""
        if filters:
            args.append(json.dumps(filters))
            where_sql = self._build_filter_clause(filters, param_index=len(args))
            # Same condition, appended to the existing full-text WHERE.
            and_sql = "AND " + where_sql.removeprefix("WHERE ")

        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"""
                WITH vector_results AS (
                    SELECT id, 1 - (embedding <=> $1::vector) AS vec_score,
                           ROW_NUMBER() OVER (ORDER BY embedding <=> $1::vector) AS vec_rank
                    FROM {self.table_name} {where_sql}
                    ORDER BY vec_rank
                    LIMIT $3
                ),
                text_results AS (
                    SELECT id, ts_rank_cd(tsv, plainto_tsquery('english', $2)) AS text_score,
                           ROW_NUMBER() OVER (ORDER BY ts_rank_cd(tsv, plainto_tsquery('english', $2)) DESC) AS text_rank
                    FROM {self.table_name}
                    WHERE tsv @@ plainto_tsquery('english', $2) {and_sql}
                    ORDER BY text_rank
                    LIMIT $3
                ),
                combined AS (
                    SELECT COALESCE(v.id, t.id) AS id,
                           $4::float8 * COALESCE(1.0 / (60 + v.vec_rank), 0) +
                           (1 - $4::float8) * COALESCE(1.0 / (60 + t.text_rank), 0) AS rrf_score
                    FROM vector_results v
                    FULL OUTER JOIN text_results t ON v.id = t.id
                    ORDER BY rrf_score DESC
                    LIMIT $3
                )
                SELECT c.id, c.text, c.document_id, c.start_idx, c.end_idx,
                       c.metadata, combined.rrf_score AS score
                FROM combined
                JOIN {self.table_name} c ON c.id = combined.id
                ORDER BY combined.rrf_score DESC
                """,
                *args,
            )
        return [self._row_to_result(row) for row in rows]

    async def delete(self, chunk_ids: list[str]) -> int:
        """Delete the rows with the given ids and return how many were removed."""
        if not chunk_ids:
            return 0
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                f"DELETE FROM {self.table_name} WHERE id = ANY($1)", chunk_ids
            )
        # asyncpg returns "DELETE N"
        return int(result.split()[-1]) if result else 0

    async def close(self) -> None:
        """Close the connection pool."""
        if self._pool:
            await self._pool.close()
            self._pool = None

    @staticmethod
    def _build_filter_clause(filters: dict[str, Any] | None, param_index: int = 1) -> str:
        """Return a WHERE clause matching *filters*, or ``""`` without filters.

        The clause refers to the bound parameter ``$<param_index>``; the
        caller must pass ``json.dumps(filters)`` in that position. No filter
        key or value is ever placed in the SQL text.
        """
        if not filters:
            return ""
        return f"WHERE metadata @> ${param_index}::jsonb"

    @staticmethod
    def _row_to_result(row: Any) -> SearchResult:
        """Convert a result row (with a ``score`` column) into a SearchResult."""
        raw = row["metadata"]
        if raw is None:
            # The column is nullable; treat a missing value as no metadata.
            metadata: dict[str, Any] = {}
        elif isinstance(raw, dict):
            metadata = raw
        else:
            metadata = json.loads(raw)
        chunk = Chunk(
            id=row["id"],
            text=row["text"],
            document_id=row["document_id"],
            start_idx=row["start_idx"],
            end_idx=row["end_idx"],
            metadata=metadata,
        )
        return SearchResult(chunk=chunk, score=float(row["score"]), metadata=metadata)
229
+
230
+
231
class InMemoryVectorStore(VectorStore):
    """Dictionary-backed vector store that ranks with numpy cosine similarity. For testing."""

    def __init__(self) -> None:
        self._chunks: dict[str, Chunk] = {}

    async def upsert(self, chunks: list[Chunk]) -> int:
        """Store every chunk under its id, replacing any chunk with the same id."""
        stored = 0
        for item in chunks:
            self._chunks[item.id] = item
            stored += 1
        return stored

    async def search(
        self,
        query_embedding: list[float],
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Return the *top_k* stored chunks most similar to *query_embedding*.

        Only chunks that have an embedding and whose metadata equals every
        entry of *filters* are considered. Zero-length vectors (query or
        chunk) never match.
        """
        import numpy as np

        def matches(chunk: Chunk) -> bool:
            if not filters:
                return True
            return all(chunk.metadata.get(key) == wanted for key, wanted in filters.items())

        pool = [c for c in self._chunks.values() if matches(c) and c.embedding is not None]
        if not pool:
            return []

        query_vec = np.array(query_embedding, dtype=np.float64)
        query_norm = np.linalg.norm(query_vec)
        if query_norm == 0:
            return []
        unit_query = query_vec / query_norm

        hits: list[tuple[float, Chunk]] = []
        for candidate in pool:
            vec = np.array(candidate.embedding, dtype=np.float64)
            length = np.linalg.norm(vec)
            if length != 0:
                hits.append((float(np.dot(unit_query, vec / length)), candidate))

        best = sorted(hits, key=lambda pair: pair[0], reverse=True)[:top_k]
        return [
            SearchResult(chunk=candidate, score=similarity, metadata=candidate.metadata)
            for similarity, candidate in best
        ]

    async def delete(self, chunk_ids: list[str]) -> int:
        """Remove the given ids and return how many were actually present."""
        removed = 0
        for cid in chunk_ids:
            if self._chunks.pop(cid, None) is not None:
                removed += 1
        return removed

    @property
    def size(self) -> int:
        """Number of chunks currently stored."""
        return len(self._chunks)