ff-aitoolkit 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aitoolkit/__init__.py +66 -0
- aitoolkit/config.py +107 -0
- aitoolkit/embeddings/__init__.py +5 -0
- aitoolkit/embeddings/client.py +133 -0
- aitoolkit/exceptions.py +35 -0
- aitoolkit/integrations/__init__.py +1 -0
- aitoolkit/integrations/langchain.py +69 -0
- aitoolkit/llm/__init__.py +5 -0
- aitoolkit/llm/client.py +230 -0
- aitoolkit/py.typed +0 -0
- aitoolkit/rag/__init__.py +25 -0
- aitoolkit/rag/agent.py +165 -0
- aitoolkit/rag/query_expansion.py +147 -0
- aitoolkit/rag/retriever.py +141 -0
- aitoolkit/rag/vector_store.py +245 -0
- aitoolkit/retry.py +51 -0
- aitoolkit/stt/__init__.py +5 -0
- aitoolkit/stt/client.py +147 -0
- aitoolkit/tts/__init__.py +10 -0
- aitoolkit/tts/audio.py +68 -0
- aitoolkit/tts/client.py +219 -0
- aitoolkit/types.py +66 -0
- ff_aitoolkit-0.2.0.dist-info/METADATA +159 -0
- ff_aitoolkit-0.2.0.dist-info/RECORD +26 -0
- ff_aitoolkit-0.2.0.dist-info/WHEEL +4 -0
- ff_aitoolkit-0.2.0.dist-info/licenses/LICENSE +21 -0
aitoolkit/rag/agent.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Centralized RAG agent coordinating embeddings, vector store, and retrieval.
|
|
2
|
+
|
|
3
|
+
Provider-agnostic: no Google/Gemini defaults. ``answer_question`` is functional —
|
|
4
|
+
it retrieves context and generates an answer with the toolkit's
|
|
5
|
+
:class:`~aitoolkit.llm.LLMClient`.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
|
|
12
|
+
from loguru import logger
|
|
13
|
+
|
|
14
|
+
from aitoolkit.config import get_settings
|
|
15
|
+
from aitoolkit.embeddings import EmbeddingsClient, get_embeddings_client
|
|
16
|
+
from aitoolkit.llm import LLMClient, get_llm_client
|
|
17
|
+
from aitoolkit.rag.retriever import FilterFn, RAGRetriever
|
|
18
|
+
from aitoolkit.rag.vector_store import UnifiedVectorStore
|
|
19
|
+
from aitoolkit.types import RetrievedChunk
|
|
20
|
+
|
|
21
|
+
_DEFAULT_ANSWER_SYSTEM = (
|
|
22
|
+
"You are a helpful assistant. Answer the question using ONLY the provided "
|
|
23
|
+
"context. If the context is insufficient, say so clearly."
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class RAGAgent:
    """Unified interface for embedding, storing, retrieving and answering.

    Wires an embeddings client, a :class:`UnifiedVectorStore`, a
    :class:`RAGRetriever` (optionally Redis-cached) and an LLM client together
    behind a small async API.
    """

    def __init__(
        self,
        qdrant_url: Optional[str] = None,
        collection_name: Optional[str] = None,
        redis_url: Optional[str] = None,
        embeddings: Optional[EmbeddingsClient] = None,
        llm: Optional[LLMClient] = None,
        enable_caching: bool = True,
        filter_fn: Optional[FilterFn] = None,
    ) -> None:
        """Build the agent, falling back to toolkit settings for unset options.

        Args:
            qdrant_url: Qdrant endpoint; ``None`` lets the vector store use its
                configured default.
            collection_name: Target collection; defaults to
                ``settings.qdrant_collection``.
            redis_url: Redis URL for the retrieval cache; defaults to
                ``settings.redis_url``. Ignored when ``enable_caching`` is False.
            embeddings: Embeddings client; defaults to ``get_embeddings_client()``.
            llm: LLM client used by :meth:`answer_question`; defaults to
                ``get_llm_client()``.
            enable_caching: When False, no Redis URL is handed to the retriever.
            filter_fn: Optional predicate; retrieved chunks for which it returns
                False are dropped.
        """
        settings = get_settings()
        self.collection_name = collection_name or settings.qdrant_collection
        self.embeddings = embeddings or get_embeddings_client()
        self.llm = llm or get_llm_client()

        self.vector_store = UnifiedVectorStore(
            qdrant_url=qdrant_url,
            collection_name=self.collection_name,
            embeddings=self.embeddings,
        )
        self.retriever = RAGRetriever(
            vector_store=self.vector_store,
            redis_url=(redis_url or settings.redis_url) if enable_caching else None,
            cache_ttl=settings.cache_ttl,
            filter_fn=filter_fn,
        )
        logger.success(f"RAGAgent ready (collection='{self.collection_name}')")

    async def add_documents(
        self,
        texts: List[str],
        *,
        file_id: str,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        source_type: str = "upload",
        **extra: Any,
    ) -> List[str]:
        """Embed and store document chunks. Returns point IDs.

        Args:
            texts: Chunk texts to embed and store.
            file_id: Identifier recorded on every chunk; used for filtering
                and :meth:`delete_file`.
            metadatas: Optional per-chunk metadata, same length as ``texts``.
            source_type: Stored in each chunk's payload.
            **extra: Additional payload fields applied to every chunk.
        """
        return await self.vector_store.add_texts(
            texts, metadatas, file_id=file_id, source_type=source_type, **extra
        )

    async def delete_file(self, file_id: str) -> None:
        """Delete every stored chunk whose ``file_id`` matches."""
        await self.vector_store.delete_by_file_id(file_id)

    async def retrieve_context(
        self,
        query: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 10,
        score_threshold: Optional[float] = None,
        use_cache: bool = True,
    ) -> List[RetrievedChunk]:
        """Return chunks relevant to ``query`` via the agent's retriever.

        Filtering, reranking and caching are applied by the retriever.
        ``file_ids`` restricts the search to those files when given.
        """
        return await self.retriever.retrieve(
            query,
            file_ids=file_ids,
            limit=limit,
            score_threshold=score_threshold,
            use_cache=use_cache,
        )

    async def get_formatted_context(
        self,
        query: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 10,
        separator: str = "\n\n---\n\n",
        include_sources: bool = True,
    ) -> str:
        """Retrieve chunks for ``query`` and join them into a single string.

        Args:
            query: Search query.
            file_ids: Optional restriction to these files.
            limit: Maximum number of chunks.
            separator: Text placed between chunks.
            include_sources: Forwarded to the retriever; when True (the
                default, as before) each chunk is prefixed with a source/score
                header.

        Returns:
            The joined context, or ``""`` when nothing is retrieved.
        """
        return await self.retriever.get_context_text(
            query,
            file_ids=file_ids,
            limit=limit,
            separator=separator,
            include_sources=include_sources,
        )

    async def answer_question(
        self,
        question: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 5,
        system: str = _DEFAULT_ANSWER_SYSTEM,
        temperature: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Retrieve context and generate a grounded answer.

        Args:
            question: The user's question; also used as the retrieval query.
            file_ids: Optional restriction to these files.
            limit: Maximum number of context chunks.
            system: System prompt passed to the LLM.
            temperature: Passed through to the LLM client.

        Returns:
            ``{"answer": ..., "sources": [...]}``. ``answer`` is whatever
            ``self.llm.chat`` returns and ``sources`` holds ``model_dump()`` of
            each chunk. When nothing is retrieved the LLM is not called and a
            fixed "not enough context" answer is returned with empty sources.
        """
        chunks = await self.retrieve_context(
            question, file_ids=file_ids, limit=limit
        )
        if not chunks:
            return {
                "answer": "I don't have enough context to answer this question.",
                "sources": [],
            }

        context = "\n\n---\n\n".join(c.text for c in chunks)
        prompt = (
            f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
        )
        answer = await self.llm.chat(
            prompt, system=system, temperature=temperature
        )
        return {
            "answer": answer,
            "sources": [c.model_dump() for c in chunks],
        }

    async def get_file_ids(self, exclude: Optional[List[str]] = None) -> List[str]:
        """Return the distinct ``file_id`` values stored, minus ``exclude``."""
        return await self.vector_store.get_unique_file_ids(exclude=exclude)

    async def get_file_count(self, exclude: Optional[List[str]] = None) -> int:
        """Return how many distinct files are stored, minus ``exclude``."""
        return len(await self.get_file_ids(exclude=exclude))

    async def aclose(self) -> None:
        """Close the retriever's cache client and the vector store's client.

        The vector store is closed even if closing the retriever raises, so a
        cache-close failure cannot leak the Qdrant client.
        """
        try:
            await self.retriever.aclose()
        finally:
            await self.vector_store.aclose()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
_agents: Dict[str, RAGAgent] = {}
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def get_rag_agent(
    collection_name: Optional[str] = None,
    **kwargs: Any,
) -> RAGAgent:
    """Return the shared :class:`RAGAgent` for a collection, creating it lazily.

    One agent is kept per collection name (``settings.qdrant_collection`` when
    ``collection_name`` is omitted), so an app serving several collections
    never receives another collection's agent. ``kwargs`` are forwarded to
    :class:`RAGAgent` only when that collection's agent is first built; later
    calls for the same collection ignore them.
    """
    cache_key = collection_name or get_settings().qdrant_collection
    if cache_key not in _agents:
        _agents[cache_key] = RAGAgent(collection_name=collection_name, **kwargs)
    return _agents[cache_key]
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""Query expansion for improved RAG recall.
|
|
2
|
+
|
|
3
|
+
Domain-agnostic: no domain-specific keyword list is hardcoded. Pass ``domain``
|
|
4
|
+
and an optional ``keywords`` list to tailor expansion to any project. With no
|
|
5
|
+
keywords, expansion is purely structural.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from functools import lru_cache
|
|
11
|
+
from typing import List, Optional, Tuple
|
|
12
|
+
|
|
13
|
+
from loguru import logger
|
|
14
|
+
|
|
15
|
+
# Generic English stop words for naive keyword extraction.
|
|
16
|
+
_STOP_WORDS = {
|
|
17
|
+
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
|
|
18
|
+
"of", "with", "is", "are", "was", "were", "be", "this", "that", "it",
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
# Content-type templates that anticipate document content for better retrieval.
|
|
22
|
+
_TEMPLATES = {
|
|
23
|
+
"quiz": (
|
|
24
|
+
"What are the key concepts, definitions, procedures, and important points "
|
|
25
|
+
"related to {topic}? What questions should learners understand about {topic}?"
|
|
26
|
+
),
|
|
27
|
+
"lesson": (
|
|
28
|
+
"Explain {topic} in detail. What are the fundamental concepts of {topic}? "
|
|
29
|
+
"What should learners understand about {topic}?"
|
|
30
|
+
),
|
|
31
|
+
"summary": (
|
|
32
|
+
"Summarize the main points about {topic}. What are the key takeaways "
|
|
33
|
+
"regarding {topic}? Provide an overview of {topic}."
|
|
34
|
+
),
|
|
35
|
+
"default": (
|
|
36
|
+
"What topics and skills are needed to learn {topic}? "
|
|
37
|
+
"Provide comprehensive material about {topic}."
|
|
38
|
+
),
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class QueryExpander:
    """Expand short topics/questions into retrieval-optimized queries."""

    # Characters trimmed from both ends of a token before keyword filtering, so
    # "the," is recognised as a stop word and "safety," equals "safety".
    _EDGE_PUNCTUATION = ".,;:!?\"'()[]{}<>"

    def __init__(
        self,
        domain: str = "general",
        keywords: Optional[List[str]] = None,
        domain_hint: Optional[str] = None,
    ) -> None:
        """
        Args:
            domain: Free-form domain label (used in chat expansion context).
            keywords: Optional domain keywords (stored lower-cased). They are
                added to keyword extraction, appended as the last query
                variation, and decide when to add ``domain_hint``.
            domain_hint: Extra phrase appended when a topic matches a keyword
                (e.g. "Include relevant regulations and safety procedures.").
        """
        self.domain = domain
        self.keywords = [k.lower() for k in (keywords or [])]
        self.domain_hint = domain_hint

    @staticmethod
    def _words(text: str) -> List[str]:
        """Split lower-cased ``text`` into alphanumeric word tokens."""
        return "".join(ch if ch.isalnum() else " " for ch in text.lower()).split()

    def expand_for_generation(
        self,
        topic: str,
        content_type: str = "default",
        file_names: Optional[List[str]] = None,
    ) -> str:
        """Expand a topic into a detailed query for content generation.

        Args:
            topic: Subject substituted into the content-type template.
            content_type: Template key; unknown values use ``"default"``.
            file_names: Optional document names; the first three (with known
                extensions stripped) are mentioned in a leading sentence.

        Returns:
            The expanded query, with ``domain_hint`` appended when ``topic``
            contains one of the configured keywords.
        """
        prefix = ""
        if file_names:
            # A non-empty list always yields at least one cleaned name.
            clean = [_strip_ext(n) for n in file_names[:3]]
            prefix = f"Related to documents: {', '.join(clean)}. "

        template = _TEMPLATES.get(content_type, _TEMPLATES["default"])
        expanded = prefix + template.format(topic=topic)

        if self.domain_hint and self._matches_domain(topic):
            expanded += f" {self.domain_hint}"

        logger.debug(f"expanded '{topic}' -> '{expanded[:80]}...'")
        return expanded

    def expand_for_chat(self, question: str, conversation_history: str = "") -> str:
        """Lightly expand a short chat question for better retrieval.

        Questions of five or more whitespace-separated words are returned
        unchanged. Shorter ones get the last 200 characters of
        ``conversation_history`` appended, or the domain label when there is
        no history.
        """
        if len(question.split()) >= 5:
            return question
        if conversation_history:
            return f"{question}. Context: {conversation_history[-200:]}"
        return f"{question} in {self.domain} context"

    def generate_query_variations(
        self, query: str, num_variations: int = 3
    ) -> List[str]:
        """Produce up to ``num_variations`` structural variations of ``query``.

        Candidates, in order: ``query`` itself, a "What is" rephrasing, a
        "How to" rephrasing, and ``query`` followed by up to five configured
        keywords. A rephrasing is skipped when the query already contains that
        question word as a whole word (so "show" does not count as "how"), and
        a skipped slot is filled by the next candidate.
        """
        if num_variations <= 0:
            return []
        words = set(self._words(query))
        candidates = [query]
        if "what" not in words:
            candidates.append(f"What is {query}? Explain {query}.")
        if "how" not in words:
            candidates.append(f"How to {query}? Steps for {query}.")
        if self.keywords:
            candidates.append(f"{query} {' '.join(self.keywords[:5])}")
        return candidates[:num_variations]

    def extract_keywords(self, text: str, top_n: int = 10) -> List[str]:
        """Naive keyword extraction plus any configured domain keywords present.

        Whitespace-separated tokens are lower-cased and stripped of surrounding
        punctuation, then kept when longer than three characters and not stop
        words. Configured keywords occurring anywhere in ``text`` are appended.
        The result is de-duplicated in first-seen order and cut to ``top_n``.
        """
        lowered = text.lower()
        stripped = (tok.strip(self._EDGE_PUNCTUATION) for tok in lowered.split())
        words = [w for w in stripped if len(w) > 3 and w not in _STOP_WORDS]
        domain_kws = [k for k in self.keywords if k in lowered]
        # dict.fromkeys de-duplicates in O(n) while preserving insertion order.
        return list(dict.fromkeys(words + domain_kws))[:top_n]

    def _matches_domain(self, topic: str) -> bool:
        """Return True if any configured keyword is a substring of ``topic``."""
        lowered = topic.lower()
        return any(k in lowered for k in self.keywords)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _strip_ext(name: str) -> str:
|
|
125
|
+
for ext in (".pdf", ".docx", ".pptx", ".xlsx", ".txt", ".doc"):
|
|
126
|
+
name = name.replace(ext, "")
|
|
127
|
+
return name
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@lru_cache(maxsize=16)
def _cached_expander(
    domain: str, keywords: Tuple[str, ...], domain_hint: Optional[str]
) -> QueryExpander:
    """Build a :class:`QueryExpander`, memoized on its arguments.

    ``keywords`` arrives as a tuple so the arguments are hashable and can act
    as the ``lru_cache`` key; it is turned back into a list for the expander.
    """
    keyword_list = list(keywords)
    expander = QueryExpander(
        domain=domain,
        keywords=keyword_list,
        domain_hint=domain_hint,
    )
    return expander
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_query_expander(
    domain: str = "general",
    keywords: Optional[List[str]] = None,
    domain_hint: Optional[str] = None,
) -> QueryExpander:
    """Return a memoized query expander for this (domain, keywords, hint) combo.

    Each distinct combination gets its own instance, so one caller's domain or
    keywords are never handed back to a caller asking for different ones.
    ``keywords`` is converted to a tuple so it can participate in the cache key.
    """
    keyword_key: Tuple[str, ...] = tuple(keywords) if keywords else ()
    return _cached_expander(domain, keyword_key, domain_hint)
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""RAG retriever with optional Redis caching and pluggable filtering.
|
|
2
|
+
|
|
3
|
+
Differences from the old implementation:
|
|
4
|
+
|
|
5
|
+
* No LangChain dependency — operates on :class:`RetrievedChunk`.
|
|
6
|
+
* The hardcoded "drop DOCX files" rule is gone; pass a ``filter_fn`` predicate to
|
|
7
|
+
apply any project-specific filtering instead.
|
|
8
|
+
* Redis is optional (``aitoolkit[cache]``); without it, caching is a no-op.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import hashlib
|
|
14
|
+
import json
|
|
15
|
+
from typing import Callable, List, Optional
|
|
16
|
+
|
|
17
|
+
from loguru import logger
|
|
18
|
+
|
|
19
|
+
from aitoolkit.rag.vector_store import UnifiedVectorStore
|
|
20
|
+
from aitoolkit.types import RetrievedChunk
|
|
21
|
+
|
|
22
|
+
try: # optional dependency
|
|
23
|
+
import redis.asyncio as redis
|
|
24
|
+
except ImportError: # pragma: no cover
|
|
25
|
+
redis = None # type: ignore[assignment]
|
|
26
|
+
|
|
27
|
+
# Predicate applied to each retrieved chunk by ``RAGRetriever.retrieve``:
# chunks for which it returns False are dropped before reranking and caching.
FilterFn = Callable[[RetrievedChunk], bool]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RAGRetriever:
    """Retrieve relevant chunks, with optional caching, filtering and reranking."""

    def __init__(
        self,
        vector_store: UnifiedVectorStore,
        redis_url: Optional[str] = None,
        cache_ttl: int = 3600,
        filter_fn: Optional[FilterFn] = None,
        reranker: Optional[Callable[[str, List[RetrievedChunk]], List[RetrievedChunk]]] = None,
    ) -> None:
        """
        Args:
            vector_store: Store queried by :meth:`retrieve`.
            redis_url: Enables result caching when set and the optional
                ``redis`` package is importable; otherwise caching is a no-op.
            cache_ttl: Expiry of cached results, passed to Redis ``SETEX``.
            filter_fn: Predicate; chunks for which it returns False are dropped.
            reranker: Optional ``(query, chunks) -> chunks`` step applied after
                filtering.
        """
        self.vector_store = vector_store
        self.cache_ttl = cache_ttl
        self.filter_fn = filter_fn
        self.reranker = reranker

        self.redis_client = None
        if redis_url and redis is not None:
            try:
                self.redis_client = redis.from_url(redis_url, decode_responses=True)
                logger.info("RAG retriever caching enabled")
            except Exception as exc:  # noqa: BLE001
                logger.warning(f"failed to init Redis cache: {exc}")
        elif redis_url and redis is None:
            logger.warning(
                "redis_url provided but redis not installed; install aitoolkit[cache]"
            )

    def _cache_key(
        self,
        query: str,
        file_ids: Optional[List[str]],
        limit: int,
        score_threshold: Optional[float] = None,
    ) -> str:
        """Build a deterministic Redis key for one retrieval request.

        Every argument that changes the search result is part of the key: the
        collection (so retrievers over different collections sharing one Redis
        never read each other's entries), the query, the sorted ``file_ids``,
        ``limit`` and ``score_threshold``.

        NOTE(review): ``filter_fn``/``reranker`` are not part of the key, so
        retrievers with different filters over the same collection and Redis
        share entries.
        """
        payload = {
            "c": getattr(self.vector_store, "collection_name", None),
            "q": query,
            "f": sorted(file_ids) if file_ids else None,
            "l": limit,
            "s": score_threshold,
        }
        digest = hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
        return f"aitoolkit:rag:{digest}"

    async def _cache_get(self, key: str) -> Optional[List[RetrievedChunk]]:
        """Return cached chunks for ``key``.

        Returns None on a miss, when Redis is not configured, or on any
        read/decode error (which is logged).
        """
        if not self.redis_client:
            return None
        try:
            raw = await self.redis_client.get(key)
            if raw:
                return [RetrievedChunk(**item) for item in json.loads(raw)]
        except Exception as exc:  # noqa: BLE001
            logger.warning(f"cache read failed: {exc}")
        return None

    async def _cache_set(self, key: str, chunks: List[RetrievedChunk]) -> None:
        """Store ``chunks`` under ``key`` with ``cache_ttl`` expiry.

        Errors are logged, not raised.
        """
        if not self.redis_client:
            return
        try:
            raw = json.dumps([c.model_dump() for c in chunks])
            await self.redis_client.setex(key, self.cache_ttl, raw)
        except Exception as exc:  # noqa: BLE001
            logger.warning(f"cache write failed: {exc}")

    async def retrieve(
        self,
        query: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 10,
        score_threshold: Optional[float] = None,
        use_cache: bool = True,
    ) -> List[RetrievedChunk]:
        """Return relevant chunks for ``query``.

        The steps are:

        1. Cache lookup, when ``use_cache`` is set and Redis is configured.
        2. Similarity search.
        3. ``filter_fn``.
        4. ``reranker``.
        5. Cache write of a non-empty result.
        """
        cache_key = (
            self._cache_key(query, file_ids, limit, score_threshold)
            if use_cache
            else None
        )
        if cache_key is not None:
            cached = await self._cache_get(cache_key)
            if cached is not None:
                logger.debug(f"cache hit ({len(cached)} chunks)")
                return cached

        chunks = await self.vector_store.similarity_search(
            query, file_ids=file_ids, limit=limit, score_threshold=score_threshold
        )

        if self.filter_fn:
            chunks = [c for c in chunks if self.filter_fn(c)]
        if self.reranker and chunks:
            chunks = self.reranker(query, chunks)

        # Only non-empty results are written to the cache.
        if cache_key is not None and chunks:
            await self._cache_set(cache_key, chunks)
        return chunks

    async def get_context_text(
        self,
        query: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 10,
        separator: str = "\n\n---\n\n",
        include_sources: bool = True,
    ) -> str:
        """Retrieve and format chunks into a single context string.

        With ``include_sources`` each chunk is prefixed by
        ``[Source: <file_id or 'unknown'>, Score: <score>]``. Returns ``""``
        when nothing is retrieved.
        """
        chunks = await self.retrieve(query, file_ids=file_ids, limit=limit)
        if not chunks:
            return ""
        parts = []
        for c in chunks:
            if include_sources:
                parts.append(
                    f"[Source: {c.file_id or 'unknown'}, Score: {c.score:.3f}]\n{c.text}"
                )
            else:
                parts.append(c.text)
        return separator.join(parts)

    async def aclose(self) -> None:
        """Close the Redis client, if one was created."""
        if self.redis_client:
            await self.redis_client.aclose()
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
"""Unified Qdrant vector store with file-based filtering.
|
|
2
|
+
|
|
3
|
+
LangChain-free: it works with plain ``texts`` + ``metadatas`` and returns
|
|
4
|
+
:class:`~aitoolkit.types.RetrievedChunk` objects. Embeddings are produced by an
|
|
5
|
+
injected :class:`~aitoolkit.embeddings.EmbeddingsClient`, so the vector size is
|
|
6
|
+
detected at runtime rather than hardcoded.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Any, Dict, List, Optional
|
|
12
|
+
from uuid import uuid4
|
|
13
|
+
|
|
14
|
+
from loguru import logger
|
|
15
|
+
from qdrant_client import AsyncQdrantClient
|
|
16
|
+
from qdrant_client.http import models as qmodels
|
|
17
|
+
|
|
18
|
+
from aitoolkit.config import get_settings
|
|
19
|
+
from aitoolkit.embeddings import EmbeddingsClient, get_embeddings_client
|
|
20
|
+
from aitoolkit.exceptions import VectorStoreError
|
|
21
|
+
from aitoolkit.types import RetrievedChunk
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class UnifiedVectorStore:
    """Async Qdrant wrapper for storing and retrieving document embeddings.

    Each stored point's payload carries the chunk ``text``, its ``file_id`` and
    ``source_type`` plus caller metadata. ``file_id`` is what search filtering
    and deletion key on.
    """

    def __init__(
        self,
        qdrant_url: Optional[str] = None,
        collection_name: Optional[str] = None,
        embeddings: Optional[EmbeddingsClient] = None,
        vector_size: Optional[int] = None,
    ) -> None:
        """
        Args:
            qdrant_url: Qdrant endpoint; defaults to ``settings.qdrant_url``.
            collection_name: Collection to use; defaults to
                ``settings.qdrant_collection``.
            embeddings: Embeddings client; defaults to ``get_embeddings_client()``.
            vector_size: Size used when the collection must be created without
                an embedding at hand; defaults to ``settings.qdrant_vector_size``.
                An existing collection's size, or the size of actual
                embeddings, takes precedence.
        """
        settings = get_settings()
        self.qdrant_url = qdrant_url or settings.qdrant_url
        self.collection_name = collection_name or settings.qdrant_collection
        self.embeddings = embeddings or get_embeddings_client()
        self._client = AsyncQdrantClient(
            url=self.qdrant_url,
            check_compatibility=settings.qdrant_check_compatibility,
        )
        self._initialized = False
        self._vector_size: Optional[int] = vector_size or settings.qdrant_vector_size
        logger.info(
            f"UnifiedVectorStore ready (collection='{self.collection_name}', "
            f"url={self.qdrant_url})"
        )

    async def _collection_exists(self) -> bool:
        """Return whether the collection exists, wrapping client errors."""
        try:
            return await self._client.collection_exists(self.collection_name)
        except Exception as exc:  # noqa: BLE001
            raise VectorStoreError(f"failed to check collection: {exc}") from exc

    async def _adopt_existing_collection(self) -> None:
        """Mark an existing collection as ready and learn its vector size.

        Size discovery is best-effort: if the collection's vector config does
        not expose a single integer ``size``, the previously known size is kept.
        """
        self._initialized = True
        try:
            info = await self._client.get_collection(self.collection_name)
            size = info.config.params.vectors.size  # type: ignore[union-attr]
            if isinstance(size, int):
                self._vector_size = size
        except Exception:  # noqa: BLE001 - size discovery is best-effort
            pass

    async def _ensure_collection(self, vector_size: Optional[int] = None) -> None:
        """Create the collection if it does not exist (no destructive recreate).

        Sets ``_initialized`` once the collection is known to exist. If it does
        not exist and no size is known (neither ``vector_size`` nor a
        configured size), nothing happens and ``_initialized`` stays False.

        Raises:
            VectorStoreError: if the existence check fails, or creation fails
                and the collection still does not exist afterwards.
        """
        if self._initialized:
            return

        if await self._collection_exists():
            await self._adopt_existing_collection()
            return

        size = vector_size or self._vector_size
        if size is None:
            # Nothing to create yet — caller will provide a size once embeddings exist.
            return

        try:
            await self._client.create_collection(
                collection_name=self.collection_name,
                vectors_config=qmodels.VectorParams(
                    size=size, distance=qmodels.Distance.COSINE
                ),
            )
        except Exception as exc:  # noqa: BLE001
            # A concurrent coroutine/process may have created the collection
            # between our existence check and the create call; that is success.
            try:
                created_elsewhere = await self._client.collection_exists(
                    self.collection_name
                )
            except Exception:  # noqa: BLE001 - report the original create error
                created_elsewhere = False
            if created_elsewhere:
                await self._adopt_existing_collection()
                return
            raise VectorStoreError(f"failed to create collection: {exc}") from exc

        self._initialized = True
        self._vector_size = size
        logger.success(
            f"Created collection '{self.collection_name}' (size={size})"
        )

    async def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        *,
        file_id: str,
        source_type: str = "upload",
        **extra_metadata: Any,
    ) -> List[str]:
        """Embed and store ``texts`` with associated metadata. Returns point IDs.

        Each payload is built from ``source_type``, the per-text metadata and
        ``extra_metadata``, with later entries winning. ``text`` and
        ``file_id`` are written last and always reflect the arguments, so
        caller metadata cannot corrupt the stored text or break ``file_id``
        filtering and deletion.

        Raises:
            VectorStoreError: if ``texts`` and ``metadatas`` differ in length,
                if the embeddings client returns a different number of vectors
                than texts, or if the upsert fails.
        """
        if not texts:
            return []

        metadatas = metadatas or [{} for _ in texts]
        if len(metadatas) != len(texts):
            raise VectorStoreError("texts and metadatas length mismatch")

        embeddings = await self.embeddings.aembed_documents(texts)
        if not embeddings:
            logger.warning("No embeddings produced; nothing stored")
            return []
        if len(embeddings) != len(texts):
            # zip() below would otherwise silently drop the unmatched texts.
            raise VectorStoreError(
                f"embeddings client returned {len(embeddings)} vectors "
                f"for {len(texts)} texts"
            )

        await self._ensure_collection(len(embeddings[0]))

        point_ids: List[str] = []
        points: List[qmodels.PointStruct] = []
        for text, vector, meta in zip(texts, embeddings, metadatas):
            pid = str(uuid4())
            point_ids.append(pid)
            payload = {
                "source_type": source_type,
                **meta,
                **extra_metadata,
                "text": text,
                "file_id": file_id,
            }
            points.append(qmodels.PointStruct(id=pid, vector=vector, payload=payload))

        try:
            await self._client.upsert(
                collection_name=self.collection_name, points=points
            )
        except Exception as exc:  # noqa: BLE001
            raise VectorStoreError(f"failed to upsert points: {exc}") from exc

        logger.info(f"Stored {len(points)} chunks for file_id={file_id}")
        return point_ids

    async def delete_by_file_id(self, file_id: str) -> None:
        """Delete all points for a given ``file_id``.

        No-op if the collection is not initialized after ``_ensure_collection``.

        Raises:
            VectorStoreError: if the delete request fails.
        """
        await self._ensure_collection()
        if not self._initialized:
            return
        try:
            await self._client.delete(
                collection_name=self.collection_name,
                points_selector=qmodels.FilterSelector(
                    filter=qmodels.Filter(
                        must=[
                            qmodels.FieldCondition(
                                key="file_id",
                                match=qmodels.MatchValue(value=file_id),
                            )
                        ]
                    )
                ),
            )
            logger.info(f"Deleted points for file_id={file_id}")
        except Exception as exc:  # noqa: BLE001
            raise VectorStoreError(f"failed to delete by file_id: {exc}") from exc

    async def similarity_search(
        self,
        query: str,
        *,
        file_ids: Optional[List[str]] = None,
        limit: int = 10,
        score_threshold: Optional[float] = None,
    ) -> List[RetrievedChunk]:
        """Vector similarity search with optional ``file_ids`` filtering.

        Embeds ``query`` and ensures the collection exists, creating it with
        the query vector's size if missing. It then returns up to ``limit``
        chunks, restricted to payload ``file_id`` in ``file_ids`` when given.
        Each chunk's ``metadata`` is its payload without the ``text`` field.

        Raises:
            VectorStoreError: if the Qdrant query fails.
        """
        query_vector = await self.embeddings.aembed_query(query)
        await self._ensure_collection(len(query_vector))
        if not self._initialized:
            return []

        qfilter = None
        if file_ids:
            qfilter = qmodels.Filter(
                must=[
                    qmodels.FieldCondition(
                        key="file_id", match=qmodels.MatchAny(any=file_ids)
                    )
                ]
            )

        try:
            response = await self._client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                limit=limit,
                query_filter=qfilter,
                score_threshold=score_threshold,
            )
        except Exception as exc:  # noqa: BLE001
            raise VectorStoreError(f"similarity search failed: {exc}") from exc

        chunks: List[RetrievedChunk] = []
        for point in response.points:
            payload = dict(point.payload or {})
            text = payload.pop("text", "")
            chunks.append(
                RetrievedChunk(
                    text=text,
                    score=point.score,
                    file_id=payload.get("file_id"),
                    metadata=payload,
                )
            )
        logger.debug(f"similarity_search returned {len(chunks)} chunks")
        return chunks

    async def get_unique_file_ids(self, exclude: Optional[List[str]] = None) -> List[str]:
        """Return all distinct ``file_id`` values stored in the collection.

        Scrolls the collection in pages of 1000 and requests only the
        ``file_id`` payload field, not the chunk text or vectors. Values in
        ``exclude`` are omitted and the result is sorted. A scroll error is
        logged and ``[]`` is returned instead of raising.
        """
        await self._ensure_collection()
        if not self._initialized:
            return []
        exclude_set = set(exclude or [])
        found: set[str] = set()
        offset = None
        try:
            while True:
                records, offset = await self._client.scroll(
                    collection_name=self.collection_name,
                    limit=1000,
                    offset=offset,
                    with_payload=["file_id"],
                    with_vectors=False,
                )
                for rec in records:
                    fid = (rec.payload or {}).get("file_id")
                    if fid and fid not in exclude_set:
                        found.add(fid)
                if offset is None:
                    break
        except Exception as exc:  # noqa: BLE001
            logger.error(f"failed to scroll file_ids: {exc}")
            return []
        return sorted(found)

    async def aclose(self) -> None:
        """Close the underlying Qdrant client."""
        await self._client.close()
|