ff-aitoolkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aitoolkit/rag/agent.py ADDED
@@ -0,0 +1,165 @@
1
+ """Centralized RAG agent coordinating embeddings, vector store, and retrieval.
2
+
3
+ Provider-agnostic: no Google/Gemini defaults. ``answer_question`` is functional —
4
+ it retrieves context and generates an answer with the toolkit's
5
+ :class:`~aitoolkit.llm.LLMClient`.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ from loguru import logger
13
+
14
+ from aitoolkit.config import get_settings
15
+ from aitoolkit.embeddings import EmbeddingsClient, get_embeddings_client
16
+ from aitoolkit.llm import LLMClient, get_llm_client
17
+ from aitoolkit.rag.retriever import FilterFn, RAGRetriever
18
+ from aitoolkit.rag.vector_store import UnifiedVectorStore
19
+ from aitoolkit.types import RetrievedChunk
20
+
21
+ _DEFAULT_ANSWER_SYSTEM = (
22
+ "You are a helpful assistant. Answer the question using ONLY the provided "
23
+ "context. If the context is insufficient, say so clearly."
24
+ )
25
+
26
+
27
class RAGAgent:
    """Unified interface for embedding, storing, retrieving and answering.

    Wires an embeddings client, a :class:`UnifiedVectorStore`, a
    :class:`RAGRetriever` (optionally Redis-cached) and an :class:`LLMClient`
    behind one object. Unset constructor arguments fall back to
    :func:`get_settings` and the module-level client factories.
    """

    def __init__(
        self,
        qdrant_url: Optional[str] = None,
        collection_name: Optional[str] = None,
        redis_url: Optional[str] = None,
        embeddings: Optional[EmbeddingsClient] = None,
        llm: Optional[LLMClient] = None,
        enable_caching: bool = True,
        filter_fn: Optional[FilterFn] = None,
    ) -> None:
        """
        Args:
            qdrant_url: Qdrant endpoint, passed through to :class:`UnifiedVectorStore`.
            collection_name: Collection to use; defaults to ``settings.qdrant_collection``.
            redis_url: Cache location; defaults to ``settings.redis_url``. Ignored
                when ``enable_caching`` is False.
            embeddings: Embeddings client; defaults to :func:`get_embeddings_client`.
            llm: Chat client used by :meth:`answer_question`; defaults to
                :func:`get_llm_client`.
            enable_caching: When False, no Redis URL is given to the retriever.
            filter_fn: Optional per-chunk predicate forwarded to the retriever.
        """
        settings = get_settings()
        self.collection_name = collection_name or settings.qdrant_collection
        self.embeddings = embeddings or get_embeddings_client()
        self.llm = llm or get_llm_client()

        self.vector_store = UnifiedVectorStore(
            qdrant_url=qdrant_url,
            collection_name=self.collection_name,
            embeddings=self.embeddings,
        )
        self.retriever = RAGRetriever(
            vector_store=self.vector_store,
            redis_url=(redis_url or settings.redis_url) if enable_caching else None,
            cache_ttl=settings.cache_ttl,
            filter_fn=filter_fn,
        )
        logger.success(f"RAGAgent ready (collection='{self.collection_name}')")

    async def add_documents(
        self,
        texts: List[str],
        *,
        file_id: str,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        source_type: str = "upload",
        **extra: Any,
    ) -> List[str]:
        """Embed and store document chunks under ``file_id``.

        Args:
            texts: Chunk texts to embed and store.
            file_id: Identifier attached to every stored chunk.
            metadatas: Optional per-chunk metadata, parallel to ``texts``.
            source_type: Stored alongside each chunk.
            **extra: Additional metadata applied to every chunk.

        Returns:
            The IDs of the stored points.
        """
        return await self.vector_store.add_texts(
            texts, metadatas, file_id=file_id, source_type=source_type, **extra
        )

    async def delete_file(self, file_id: str) -> None:
        """Delete every stored chunk belonging to ``file_id``."""
        await self.vector_store.delete_by_file_id(file_id)

    async def retrieve_context(
        self,
        query: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 10,
        score_threshold: Optional[float] = None,
        use_cache: bool = True,
    ) -> List[RetrievedChunk]:
        """Return up to ``limit`` chunks relevant to ``query``.

        ``file_ids`` restricts the search to those files; ``score_threshold`` and
        ``use_cache`` are forwarded to :meth:`RAGRetriever.retrieve`.
        """
        return await self.retriever.retrieve(
            query,
            file_ids=file_ids,
            limit=limit,
            score_threshold=score_threshold,
            use_cache=use_cache,
        )

    async def get_formatted_context(
        self,
        query: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 10,
        separator: str = "\n\n---\n\n",
        include_sources: bool = True,
    ) -> str:
        """Retrieve chunks for ``query`` and join them into one context string.

        Args:
            query: Search text.
            file_ids: Optional restriction to these files.
            limit: Maximum number of chunks.
            separator: String placed between chunks.
            include_sources: When True, each chunk is prefixed with its source
                and score header (the retriever's formatting).

        Returns:
            The joined context, or ``""`` when nothing was retrieved.
        """
        return await self.retriever.get_context_text(
            query,
            file_ids=file_ids,
            limit=limit,
            separator=separator,
            include_sources=include_sources,
        )

    async def answer_question(
        self,
        question: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 5,
        system: str = _DEFAULT_ANSWER_SYSTEM,
        temperature: Optional[float] = None,
        score_threshold: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Retrieve context and generate a grounded answer.

        Args:
            question: The user's question; also used as the retrieval query.
            file_ids: Optional restriction to these files.
            limit: Maximum number of chunks placed in the prompt.
            system: System prompt for the LLM.
            temperature: Forwarded to ``LLMClient.chat``.
            score_threshold: Optional minimum similarity score for retrieved chunks.

        Returns:
            ``{"answer": str, "sources": [chunk dicts]}``. When no chunks are
            found, a fixed "not enough context" answer is returned and the LLM
            is not called.
        """
        chunks = await self.retrieve_context(
            question,
            file_ids=file_ids,
            limit=limit,
            score_threshold=score_threshold,
        )
        if not chunks:
            return {
                "answer": "I don't have enough context to answer this question.",
                "sources": [],
            }

        context = "\n\n---\n\n".join(c.text for c in chunks)
        prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
        answer = await self.llm.chat(prompt, system=system, temperature=temperature)
        return {
            "answer": answer,
            "sources": [c.model_dump() for c in chunks],
        }

    async def get_file_ids(self, exclude: Optional[List[str]] = None) -> List[str]:
        """Return the distinct file IDs in the collection, minus ``exclude``."""
        return await self.vector_store.get_unique_file_ids(exclude=exclude)

    async def get_file_count(self, exclude: Optional[List[str]] = None) -> int:
        """Return how many distinct file IDs :meth:`get_file_ids` reports."""
        return len(await self.get_file_ids(exclude=exclude))

    async def aclose(self) -> None:
        """Close the retriever (Redis) and the vector store (Qdrant) clients.

        The vector store is closed even if closing the retriever raises, so a
        failing cache connection cannot leak the Qdrant client.
        """
        try:
            await self.retriever.aclose()
        finally:
            await self.vector_store.aclose()
144
+
145
+
146
+ _agents: Dict[str, RAGAgent] = {}
147
+
148
+
149
def get_rag_agent(
    collection_name: Optional[str] = None,
    **kwargs: Any,
) -> RAGAgent:
    """Return the cached :class:`RAGAgent` for a collection, creating it once.

    Agents are keyed by collection name (``settings.qdrant_collection`` when
    ``collection_name`` is omitted), so each collection gets its own agent
    instead of every caller receiving whichever agent was built first.
    ``kwargs`` are forwarded to :class:`RAGAgent` only when the agent for that
    collection is first created; later calls ignore them.
    """
    cache_key = collection_name or get_settings().qdrant_collection
    if cache_key not in _agents:
        _agents[cache_key] = RAGAgent(collection_name=collection_name, **kwargs)
    return _agents[cache_key]
@@ -0,0 +1,147 @@
1
+ """Query expansion for improved RAG recall.
2
+
3
+ Domain-agnostic: no domain-specific keyword list is hardcoded. Pass ``domain``
4
+ and an optional ``keywords`` list to tailor expansion to any project. With no
5
+ keywords, expansion is purely structural.
6
+ """
7
+
8
from __future__ import annotations

import re
from functools import lru_cache
from typing import List, Optional, Tuple

from loguru import logger
14
+
15
+ # Generic English stop words for naive keyword extraction.
16
+ _STOP_WORDS = {
17
+ "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
18
+ "of", "with", "is", "are", "was", "were", "be", "this", "that", "it",
19
+ }
20
+
21
+ # Content-type templates that anticipate document content for better retrieval.
22
+ _TEMPLATES = {
23
+ "quiz": (
24
+ "What are the key concepts, definitions, procedures, and important points "
25
+ "related to {topic}? What questions should learners understand about {topic}?"
26
+ ),
27
+ "lesson": (
28
+ "Explain {topic} in detail. What are the fundamental concepts of {topic}? "
29
+ "What should learners understand about {topic}?"
30
+ ),
31
+ "summary": (
32
+ "Summarize the main points about {topic}. What are the key takeaways "
33
+ "regarding {topic}? Provide an overview of {topic}."
34
+ ),
35
+ "default": (
36
+ "What topics and skills are needed to learn {topic}? "
37
+ "Provide comprehensive material about {topic}."
38
+ ),
39
+ }
40
+
41
+
42
class QueryExpander:
    """Expand short topics/questions into retrieval-optimized queries.

    Expansion is template/structure based. The optional ``keywords`` list is the
    only domain-specific input: it drives :meth:`_matches_domain`, the keyword
    variation in :meth:`generate_query_variations`, and :meth:`extract_keywords`.
    """

    def __init__(
        self,
        domain: str = "general",
        keywords: Optional[List[str]] = None,
        domain_hint: Optional[str] = None,
    ) -> None:
        """
        Args:
            domain: Free-form domain label (used in chat expansion context).
            keywords: Optional domain keywords appended to keyword extraction and
                used to decide when to add ``domain_hint``. Stored lower-cased.
            domain_hint: Extra phrase appended when a topic matches a keyword
                (e.g. "Include relevant regulations and safety procedures.").
        """
        self.domain = domain
        self.keywords = [k.lower() for k in (keywords or [])]
        self.domain_hint = domain_hint

    def expand_for_generation(
        self,
        topic: str,
        content_type: str = "default",
        file_names: Optional[List[str]] = None,
    ) -> str:
        """Expand a topic into a detailed query for content generation.

        Args:
            topic: Subject substituted into the content-type template.
            content_type: Template key; unknown values fall back to ``"default"``.
            file_names: Optional document names; the first three (extension
                stripped) are mentioned in a prefix.

        Returns:
            The expanded query, with ``domain_hint`` appended when the topic
            matches a configured keyword.
        """
        prefix = ""
        if file_names:
            clean = [_strip_ext(n) for n in file_names[:3]]
            prefix = f"Related to documents: {', '.join(clean)}. "

        template = _TEMPLATES.get(content_type, _TEMPLATES["default"])
        expanded = prefix + template.format(topic=topic)

        if self.domain_hint and self._matches_domain(topic):
            expanded += f" {self.domain_hint}"

        logger.debug(f"expanded '{topic}' -> '{expanded[:80]}...'")
        return expanded

    def expand_for_chat(self, question: str, conversation_history: str = "") -> str:
        """Lightly expand a short chat question for better retrieval.

        Questions of five or more words are returned unchanged. Shorter ones get
        the last 200 characters of ``conversation_history`` appended, or the
        domain label when there is no history.
        """
        if len(question.split()) >= 5:
            return question
        if conversation_history:
            return f"{question}. Context: {conversation_history[-200:]}"
        return f"{question} in {self.domain} context"

    def generate_query_variations(
        self, query: str, num_variations: int = 3
    ) -> List[str]:
        """Produce structural query variations for hybrid search.

        The first entry is always ``query`` itself (for ``num_variations >= 1``).
        A "What is ..." form is added unless the query already contains the word
        "what", a "How to ..." form unless it contains the word "how", and a
        keyword-augmented form when keywords are configured. At most
        ``num_variations`` entries are returned.
        """
        variations = [query]
        lower = query.lower()
        # Whole-word checks: a substring test would treat "show" as containing
        # "how" and "whatever" as containing "what", suppressing variations.
        if num_variations >= 2 and not re.search(r"\bwhat\b", lower):
            variations.append(f"What is {query}? Explain {query}.")
        if num_variations >= 3 and not re.search(r"\bhow\b", lower):
            variations.append(f"How to {query}? Steps for {query}.")
        if self.keywords and num_variations >= 4:
            variations.append(f"{query} {' '.join(self.keywords[:5])}")
        return variations[:num_variations]

    def extract_keywords(self, text: str, top_n: int = 10) -> List[str]:
        """Naive keyword extraction plus any configured domain keywords present.

        Words longer than three characters that are not stop words are kept in
        order of first appearance, followed by configured keywords occurring in
        ``text``; duplicates are removed and at most ``top_n`` are returned.
        """
        lowered = text.lower()
        # Tokenize on word characters (keeping in-word apostrophes/hyphens) so
        # punctuation does not yield variants such as "safety?" or "rules,".
        tokens = (tok.strip("'-") for tok in re.findall(r"[\w'-]+", lowered))
        words = [w for w in tokens if w not in _STOP_WORDS and len(w) > 3]
        domain_kws = [k for k in self.keywords if k in lowered]
        # dict.fromkeys dedupes in O(n) while preserving first-seen order.
        return list(dict.fromkeys(words + domain_kws))[:top_n]

    def _matches_domain(self, topic: str) -> bool:
        """Return True if any configured keyword occurs in ``topic`` (case-insensitive)."""
        if not self.keywords:
            return False
        low = topic.lower()
        return any(k in low for k in self.keywords)
122
+
123
+
124
+ def _strip_ext(name: str) -> str:
125
+ for ext in (".pdf", ".docx", ".pptx", ".xlsx", ".txt", ".doc"):
126
+ name = name.replace(ext, "")
127
+ return name
128
+
129
+
130
@lru_cache(maxsize=16)
def _cached_expander(
    domain: str, keywords: Tuple[str, ...], domain_hint: Optional[str]
) -> QueryExpander:
    """Build a :class:`QueryExpander`, memoized on its (hashable) arguments.

    ``keywords`` arrives as a tuple so it can be part of the cache key; it is
    converted back to a list for the expander.
    """
    keyword_list = list(keywords)
    expander = QueryExpander(
        domain=domain, keywords=keyword_list, domain_hint=domain_hint
    )
    return expander
135
+
136
+
137
def get_query_expander(
    domain: str = "general",
    keywords: Optional[List[str]] = None,
    domain_hint: Optional[str] = None,
) -> QueryExpander:
    """Return a shared :class:`QueryExpander` for this argument combination.

    Instances are memoized per ``(domain, keywords, domain_hint)``, so callers
    with different settings never receive an expander configured by someone
    else, while repeated calls with the same settings reuse one instance.
    """
    keyword_key: Tuple[str, ...] = tuple(keywords) if keywords else ()
    return _cached_expander(domain, keyword_key, domain_hint)
@@ -0,0 +1,141 @@
1
+ """RAG retriever with optional Redis caching and pluggable filtering.
2
+
3
+ Differences from the old implementation:
4
+
5
+ * No LangChain dependency — operates on :class:`RetrievedChunk`.
6
+ * The hardcoded "drop DOCX files" rule is gone; pass a ``filter_fn`` predicate to
7
+ apply any project-specific filtering instead.
8
+ * Redis is optional (``aitoolkit[cache]``); without it, caching is a no-op.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import hashlib
14
+ import json
15
+ from typing import Callable, List, Optional
16
+
17
+ from loguru import logger
18
+
19
+ from aitoolkit.rag.vector_store import UnifiedVectorStore
20
+ from aitoolkit.types import RetrievedChunk
21
+
22
# Redis is an optional extra (``aitoolkit[cache]``). When it is missing the
# module-level ``redis`` name is None and RAGRetriever runs without a cache.
try:
    from redis import asyncio as redis
except ImportError:  # pragma: no cover - only hit when the extra is absent
    redis = None  # type: ignore[assignment]

# Per-chunk predicate: RAGRetriever.retrieve keeps only chunks for which it
# returns True.
FilterFn = Callable[[RetrievedChunk], bool]
28
+
29
+
30
class RAGRetriever:
    """Retrieve relevant chunks, with optional caching, filtering and reranking.

    Results of :meth:`retrieve` are cached in Redis after ``filter_fn`` and
    ``reranker`` have run, but only when a Redis client could be created;
    otherwise caching is a no-op. Cache keys are scoped to the vector store's
    collection and include every request parameter that changes the result.
    """

    def __init__(
        self,
        vector_store: UnifiedVectorStore,
        redis_url: Optional[str] = None,
        cache_ttl: int = 3600,
        filter_fn: Optional[FilterFn] = None,
        reranker: Optional[Callable[[str, List[RetrievedChunk]], List[RetrievedChunk]]] = None,
    ) -> None:
        """
        Args:
            vector_store: Store queried by :meth:`retrieve`.
            redis_url: Enables caching when set and ``redis`` is importable. A
                failure to build the client is logged and caching stays off.
            cache_ttl: Cache entry lifetime in seconds (passed to ``SETEX``).
            filter_fn: Optional predicate; chunks for which it returns False are dropped.
            reranker: Optional ``(query, chunks) -> chunks`` callable applied after filtering.
        """
        self.vector_store = vector_store
        self.cache_ttl = cache_ttl
        self.filter_fn = filter_fn
        self.reranker = reranker

        self.redis_client = None
        if redis_url and redis is not None:
            try:
                self.redis_client = redis.from_url(redis_url, decode_responses=True)
                logger.info("RAG retriever caching enabled")
            except Exception as exc:  # noqa: BLE001
                logger.warning(f"failed to init Redis cache: {exc}")
        elif redis_url and redis is None:
            logger.warning(
                "redis_url provided but redis not installed; install aitoolkit[cache]"
            )

    def _cache_key(
        self,
        query: str,
        file_ids: Optional[List[str]],
        limit: int,
        score_threshold: Optional[float] = None,
    ) -> str:
        """Build a deterministic Redis key for one retrieval request.

        The key covers everything that changes the result set:
        - the collection, so retrievers for different collections sharing one
          Redis cannot serve each other's chunks;
        - the query;
        - the order-insensitive ``file_ids`` filter;
        - ``limit``;
        - ``score_threshold``, so a thresholded query is never answered from an
          unthresholded entry.
        """
        payload = {
            "c": getattr(self.vector_store, "collection_name", None),
            "q": query,
            "f": sorted(file_ids) if file_ids else None,
            "l": limit,
            "s": score_threshold,
        }
        digest = hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
        return f"aitoolkit:rag:{digest}"

    async def _cache_get(self, key: str) -> Optional[List[RetrievedChunk]]:
        """Return cached chunks for ``key``, or None on miss/disabled cache/error."""
        if not self.redis_client:
            return None
        try:
            raw = await self.redis_client.get(key)
            if raw:
                return [RetrievedChunk(**item) for item in json.loads(raw)]
        except Exception as exc:  # noqa: BLE001 - cache failures must not break retrieval
            logger.warning(f"cache read failed: {exc}")
        return None

    async def _cache_set(self, key: str, chunks: List[RetrievedChunk]) -> None:
        """Store ``chunks`` under ``key`` with the configured TTL (best-effort)."""
        if not self.redis_client:
            return
        try:
            raw = json.dumps([c.model_dump() for c in chunks])
            await self.redis_client.setex(key, self.cache_ttl, raw)
        except Exception as exc:  # noqa: BLE001 - cache failures must not break retrieval
            logger.warning(f"cache write failed: {exc}")

    async def retrieve(
        self,
        query: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 10,
        score_threshold: Optional[float] = None,
        use_cache: bool = True,
    ) -> List[RetrievedChunk]:
        """Return relevant chunks for ``query``.

        Runs a similarity search, then ``filter_fn`` and ``reranker`` if set.
        Non-empty results are cached when ``use_cache`` is True and a Redis
        client exists. A cache hit returns the stored chunks without searching.
        """
        cache_key = (
            self._cache_key(query, file_ids, limit, score_threshold)
            if use_cache and self.redis_client
            else None
        )
        if cache_key is not None:
            cached = await self._cache_get(cache_key)
            if cached is not None:
                logger.debug(f"cache hit ({len(cached)} chunks)")
                return cached

        chunks = await self.vector_store.similarity_search(
            query, file_ids=file_ids, limit=limit, score_threshold=score_threshold
        )

        if self.filter_fn:
            chunks = [c for c in chunks if self.filter_fn(c)]
        if self.reranker and chunks:
            chunks = self.reranker(query, chunks)

        if cache_key is not None and chunks:
            await self._cache_set(cache_key, chunks)
        return chunks

    async def get_context_text(
        self,
        query: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 10,
        separator: str = "\n\n---\n\n",
        include_sources: bool = True,
        score_threshold: Optional[float] = None,
    ) -> str:
        """Retrieve and format chunks into a single context string.

        Each chunk is optionally prefixed with a ``[Source: ..., Score: ...]``
        header. Chunks are joined with ``separator``. Returns ``""`` when
        nothing is retrieved.
        """
        chunks = await self.retrieve(
            query, file_ids=file_ids, limit=limit, score_threshold=score_threshold
        )
        if not chunks:
            return ""
        parts = []
        for c in chunks:
            if include_sources:
                parts.append(
                    f"[Source: {c.file_id or 'unknown'}, Score: {c.score:.3f}]\n{c.text}"
                )
            else:
                parts.append(c.text)
        return separator.join(parts)

    async def aclose(self) -> None:
        """Close the Redis client, if one was created."""
        if self.redis_client:
            await self.redis_client.aclose()
@@ -0,0 +1,245 @@
1
+ """Unified Qdrant vector store with file-based filtering.
2
+
3
+ LangChain-free: it works with plain ``texts`` + ``metadatas`` and returns
4
+ :class:`~aitoolkit.types.RetrievedChunk` objects. Embeddings are produced by an
5
+ injected :class:`~aitoolkit.embeddings.EmbeddingsClient`, so the vector size is
6
+ detected at runtime rather than hardcoded.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Any, Dict, List, Optional
12
+ from uuid import uuid4
13
+
14
+ from loguru import logger
15
+ from qdrant_client import AsyncQdrantClient
16
+ from qdrant_client.http import models as qmodels
17
+
18
+ from aitoolkit.config import get_settings
19
+ from aitoolkit.embeddings import EmbeddingsClient, get_embeddings_client
20
+ from aitoolkit.exceptions import VectorStoreError
21
+ from aitoolkit.types import RetrievedChunk
22
+
23
+
24
class UnifiedVectorStore:
    """Async Qdrant wrapper for storing and retrieving document embeddings.

    Each point's payload holds the chunk ``text``, its ``file_id`` and
    ``source_type`` plus caller metadata. ``file_id`` is what
    :meth:`delete_by_file_id`, the ``file_ids`` filter of
    :meth:`similarity_search` and :meth:`get_unique_file_ids` rely on. The
    collection is created lazily (cosine distance) once a vector size is known.
    """

    def __init__(
        self,
        qdrant_url: Optional[str] = None,
        collection_name: Optional[str] = None,
        embeddings: Optional[EmbeddingsClient] = None,
        vector_size: Optional[int] = None,
    ) -> None:
        """
        Args:
            qdrant_url: Qdrant endpoint; defaults to ``settings.qdrant_url``.
            collection_name: Target collection; defaults to ``settings.qdrant_collection``.
            embeddings: Embeddings client; defaults to :func:`get_embeddings_client`.
            vector_size: Size used if the collection must be created before any
                embedding has been computed; defaults to ``settings.qdrant_vector_size``.
        """
        settings = get_settings()
        self.qdrant_url = qdrant_url or settings.qdrant_url
        self.collection_name = collection_name or settings.qdrant_collection
        self.embeddings = embeddings or get_embeddings_client()
        self._client = AsyncQdrantClient(
            url=self.qdrant_url,
            check_compatibility=settings.qdrant_check_compatibility,
        )
        self._initialized = False
        self._vector_size: Optional[int] = vector_size or settings.qdrant_vector_size
        logger.info(
            f"UnifiedVectorStore ready (collection='{self.collection_name}', "
            f"url={self.qdrant_url})"
        )

    async def _discover_vector_size(self) -> None:
        """Best-effort: record the existing collection's single-vector size."""
        try:
            info = await self._client.get_collection(self.collection_name)
            size = info.config.params.vectors.size  # type: ignore[union-attr]
        except Exception:  # noqa: BLE001 - size discovery is best-effort
            return
        if isinstance(size, int):
            self._vector_size = size

    async def _ensure_collection(self, vector_size: Optional[int] = None) -> None:
        """Create the collection if it does not exist (no destructive recreate).

        Args:
            vector_size: Size to use if the collection has to be created. Falls
                back to the configured/discovered size.

        If the collection does not exist and no size is known yet, nothing is
        created and ``self._initialized`` stays False.

        Raises:
            VectorStoreError: If the existence check fails, or creation fails and
                the collection still does not exist afterwards.
        """
        if self._initialized:
            return

        try:
            exists = await self._client.collection_exists(self.collection_name)
        except Exception as exc:  # noqa: BLE001
            raise VectorStoreError(f"failed to check collection: {exc}") from exc

        if exists:
            self._initialized = True
            await self._discover_vector_size()
            return

        size = vector_size or self._vector_size
        if size is None:
            # Nothing to create yet — caller will provide a size once embeddings exist.
            return

        try:
            await self._client.create_collection(
                collection_name=self.collection_name,
                vectors_config=qmodels.VectorParams(
                    size=size, distance=qmodels.Distance.COSINE
                ),
            )
        except Exception as exc:  # noqa: BLE001
            # Another coroutine or process may have created the collection between
            # our existence check and create_collection; that is not a failure.
            try:
                created_elsewhere = await self._client.collection_exists(
                    self.collection_name
                )
            except Exception:  # noqa: BLE001 - report the original error instead
                created_elsewhere = False
            if not created_elsewhere:
                raise VectorStoreError(f"failed to create collection: {exc}") from exc
            self._initialized = True
            await self._discover_vector_size()
            return

        self._initialized = True
        self._vector_size = size
        logger.success(
            f"Created collection '{self.collection_name}' (size={size})"
        )

    async def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        *,
        file_id: str,
        source_type: str = "upload",
        **extra_metadata: Any,
    ) -> List[str]:
        """Embed and store ``texts`` with associated metadata.

        Args:
            texts: Chunk texts to embed and store.
            metadatas: Optional per-chunk metadata, same length as ``texts``.
            file_id: Stored on every point. It cannot be overridden by metadata.
            source_type: Stored on every point unless a chunk's metadata sets it.
            **extra_metadata: Extra payload fields applied to every chunk.

        Returns:
            The new point IDs, one per text (empty if there is nothing to store).

        Raises:
            VectorStoreError: On a ``texts``/``metadatas`` or ``texts``/embeddings
                length mismatch, or if collection creation or the upsert fails.
        """
        if not texts:
            return []

        metadatas = metadatas or [{} for _ in texts]
        if len(metadatas) != len(texts):
            raise VectorStoreError("texts and metadatas length mismatch")

        embeddings = await self.embeddings.aembed_documents(texts)
        if not embeddings:
            logger.warning("No embeddings produced; nothing stored")
            return []
        if len(embeddings) != len(texts):
            # zip() below would otherwise silently drop the unmatched texts.
            raise VectorStoreError(
                f"embeddings client returned {len(embeddings)} vectors "
                f"for {len(texts)} texts"
            )

        await self._ensure_collection(len(embeddings[0]))

        point_ids: List[str] = []
        points: List[qmodels.PointStruct] = []
        for text, vector, meta in zip(texts, embeddings, metadatas):
            pid = str(uuid4())
            point_ids.append(pid)
            points.append(
                qmodels.PointStruct(
                    id=pid,
                    vector=vector,
                    payload={
                        "source_type": source_type,
                        **meta,
                        **extra_metadata,
                        # Written last so caller metadata cannot clobber the fields
                        # that retrieval and delete_by_file_id depend on.
                        "text": text,
                        "file_id": file_id,
                    },
                )
            )

        try:
            await self._client.upsert(
                collection_name=self.collection_name, points=points
            )
        except Exception as exc:  # noqa: BLE001
            raise VectorStoreError(f"failed to upsert points: {exc}") from exc

        logger.info(f"Stored {len(points)} chunks for file_id={file_id}")
        return point_ids

    async def delete_by_file_id(self, file_id: str) -> None:
        """Delete all points for a given ``file_id``.

        This is a no-op when the collection does not exist and cannot be created.

        Raises:
            VectorStoreError: If the delete request fails.
        """
        await self._ensure_collection()
        if not self._initialized:
            return
        try:
            await self._client.delete(
                collection_name=self.collection_name,
                points_selector=qmodels.FilterSelector(
                    filter=qmodels.Filter(
                        must=[
                            qmodels.FieldCondition(
                                key="file_id",
                                match=qmodels.MatchValue(value=file_id),
                            )
                        ]
                    )
                ),
            )
            logger.info(f"Deleted points for file_id={file_id}")
        except Exception as exc:  # noqa: BLE001
            raise VectorStoreError(f"failed to delete by file_id: {exc}") from exc

    async def similarity_search(
        self,
        query: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 10,
        score_threshold: Optional[float] = None,
    ) -> List[RetrievedChunk]:
        """Vector similarity search with optional ``file_ids`` filtering.

        Returns up to ``limit`` chunks. The stored ``text`` becomes
        ``RetrievedChunk.text`` and the remaining payload becomes its metadata.

        Raises:
            VectorStoreError: If the query fails.
        """
        query_vector = await self.embeddings.aembed_query(query)
        await self._ensure_collection(len(query_vector))
        if not self._initialized:
            return []

        qfilter = None
        if file_ids:
            qfilter = qmodels.Filter(
                must=[
                    qmodels.FieldCondition(
                        key="file_id", match=qmodels.MatchAny(any=file_ids)
                    )
                ]
            )

        try:
            response = await self._client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                limit=limit,
                query_filter=qfilter,
                score_threshold=score_threshold,
            )
        except Exception as exc:  # noqa: BLE001
            raise VectorStoreError(f"similarity search failed: {exc}") from exc

        chunks: List[RetrievedChunk] = []
        for point in response.points:
            payload = dict(point.payload or {})
            text = payload.pop("text", "")
            chunks.append(
                RetrievedChunk(
                    text=text,
                    score=point.score,
                    file_id=payload.get("file_id"),
                    metadata=payload,
                )
            )
        logger.debug(f"similarity_search returned {len(chunks)} chunks")
        return chunks

    async def get_unique_file_ids(self, exclude: Optional[List[str]] = None) -> List[str]:
        """Return all distinct ``file_id`` values stored in the collection.

        IDs listed in ``exclude`` are omitted, and the result is sorted. This is
        best-effort: a scroll error is logged and ``[]`` is returned.
        """
        await self._ensure_collection()
        if not self._initialized:
            return []
        exclude_set = set(exclude or [])
        found: set[str] = set()
        offset = None
        try:
            while True:
                records, offset = await self._client.scroll(
                    collection_name=self.collection_name,
                    limit=1000,
                    offset=offset,
                    # Only file_id is needed; skip transferring chunk text/metadata.
                    with_payload=["file_id"],
                    with_vectors=False,
                )
                for rec in records:
                    fid = (rec.payload or {}).get("file_id")
                    if fid and fid not in exclude_set:
                        found.add(fid)
                if offset is None:
                    break
        except Exception as exc:  # noqa: BLE001
            logger.error(f"failed to scroll file_ids: {exc}")
            return []
        return sorted(found)

    async def aclose(self) -> None:
        """Close the underlying Qdrant client."""
        await self._client.close()