kssrag-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kssrag/__init__.py +66 -0
- kssrag/cli.py +142 -0
- kssrag/config.py +193 -0
- kssrag/core/__init__.py +0 -0
- kssrag/core/agents.py +68 -0
- kssrag/core/chunkers.py +100 -0
- kssrag/core/retrievers.py +74 -0
- kssrag/core/vectorstores.py +397 -0
- kssrag/kssrag.py +116 -0
- kssrag/models/__init__.py +0 -0
- kssrag/models/local_llms.py +30 -0
- kssrag/models/openrouter.py +85 -0
- kssrag/server.py +116 -0
- kssrag/utils/__init__.py +0 -0
- kssrag/utils/document_loaders.py +40 -0
- kssrag/utils/helpers.py +55 -0
- kssrag/utils/preprocessors.py +30 -0
- kssrag-0.1.0.dist-info/METADATA +407 -0
- kssrag-0.1.0.dist-info/RECORD +26 -0
- kssrag-0.1.0.dist-info/WHEEL +5 -0
- kssrag-0.1.0.dist-info/entry_points.txt +2 -0
- kssrag-0.1.0.dist-info/licenses/LICENSE +0 -0
- kssrag-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +0 -0
- tests/test_basic.py +43 -0
- tests/test_vectorstores.py +35 -0
kssrag/core/vectorstores.py
ADDED

@@ -0,0 +1,397 @@
import json
import os
import re
import pickle
import numpy as np
import faiss
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import scipy.sparse as sp
from typing import List, Dict, Any, Optional
from ..utils.helpers import logger
from ..config import config

class BaseVectorStore:
    """Base class for vector stores"""

    def __init__(self, persist_path: Optional[str] = None):
        self.persist_path = persist_path
        self.documents: List[Dict[str, Any]] = []
        self.doc_texts: List[str] = []

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Add documents to the vector store"""
        raise NotImplementedError("Subclasses must implement this method")

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Retrieve documents based on query"""
        raise NotImplementedError("Subclasses must implement this method")

    def persist(self):
        """Persist the vector store to disk"""
        raise NotImplementedError("Subclasses must implement this method")

    def load(self):
        """Load the vector store from disk"""
        raise NotImplementedError("Subclasses must implement this method")

class BM25VectorStore(BaseVectorStore):
    """BM25 vector store implementation"""

    def __init__(self, persist_path: Optional[str] = "bm25_index.pkl"):
        super().__init__(persist_path)
        self.bm25 = None

    def add_documents(self, documents: List[Dict[str, Any]]):
        self.documents = documents
        self.doc_texts = [doc["content"] for doc in documents]

        # Tokenize corpus for BM25
        tokenized_corpus = [self._tokenize(doc) for doc in self.doc_texts]
        self.bm25 = BM25Okapi(tokenized_corpus)

        logger.info(f"BM25 index created with {len(self.documents)} documents")

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize text for BM25"""
        return re.findall(r'\w+', text.lower())

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        if not self.bm25:
            raise ValueError("BM25 index not initialized. Call add_documents first.")

        try:
            tokenized_query = self._tokenize(query)
            doc_scores = self.bm25.get_scores(tokenized_query)

            # Get top-k indices
            top_indices = np.argsort(doc_scores)[::-1][:top_k]

            # Filter out invalid indices
            valid_indices = [i for i in top_indices if i < len(self.documents)]

            if not valid_indices:
                logger.warning(f"No valid indices found for query: {query}")
                return []

            return [self.documents[i] for i in valid_indices]

        except Exception as e:
            logger.error(f"Error in BM25 retrieval: {str(e)}")
            return []

    def persist(self):
        if self.persist_path:
            with open(self.persist_path, 'wb') as f:
                pickle.dump({
                    'documents': self.documents,
                    'doc_texts': self.doc_texts,
                    'bm25': self.bm25
                }, f)
            logger.info(f"BM25 index persisted to {self.persist_path}")

    def load(self):
        if self.persist_path and os.path.exists(self.persist_path):
            with open(self.persist_path, 'rb') as f:
                data = pickle.load(f)
                self.documents = data['documents']
                self.doc_texts = data['doc_texts']
                self.bm25 = data['bm25']
            logger.info(f"BM25 index loaded from {self.persist_path}")

import tempfile
class FAISSVectorStore(BaseVectorStore):
    def __init__(self, persist_path: Optional[str] = None, model_name: Optional[str] = None):
        super().__init__(persist_path)
        self.model_name = model_name or config.FAISS_MODEL_NAME

        # Handle cache directory permissions
        try:
            cache_dir = config.CACHE_DIR
            os.makedirs(cache_dir, exist_ok=True)
            # Test if we can write to the cache directory
            test_file = os.path.join(cache_dir, 'write_test.txt')
            with open(test_file, 'w') as f:
                f.write('test')
            os.remove(test_file)
        except PermissionError:
            logger.warning(f"Could not write to cache directory {cache_dir}. Using temp directory.")
            cache_dir = tempfile.gettempdir()

        self.model = SentenceTransformer(self.model_name, cache_folder=cache_dir)
        self.dimension = self.model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatL2(self.dimension)
        self.metadata_path = persist_path + ".meta" if persist_path else None

    def add_documents(self, documents: List[Dict[str, Any]]):
        self.documents = documents
        self.doc_texts = [doc["content"] for doc in documents]

        # Generate embeddings in batches
        embeddings = []
        batch_size = config.BATCH_SIZE

        for i in range(0, len(self.doc_texts), batch_size):
            batch_texts = self.doc_texts[i:i+batch_size]
            batch_embeddings = self.model.encode(batch_texts, show_progress_bar=False)
            embeddings.append(batch_embeddings)
            logger.info(f"Processed batch {i//batch_size + 1}/{(len(self.doc_texts)-1)//batch_size + 1}")

        embeddings = np.vstack(embeddings).astype('float32')
        self.index.add(embeddings)

        logger.info(f"FAISS index created with {len(self.documents)} documents")

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        if not self.index or self.index.ntotal == 0:
            raise ValueError("FAISS index not initialized. Call add_documents first.")

        try:
            query_embedding = self.model.encode([query])
            distances, indices = self.index.search(query_embedding.astype('float32'), top_k)

            # Filter out invalid indices (FAISS might return -1 for no results)
            valid_indices = [i for i in indices[0] if i >= 0 and i < len(self.documents)]

            if not valid_indices:
                logger.warning(f"No valid indices found for query: {query}")
                return []

            return [self.documents[i] for i in valid_indices]
        except Exception as e:
            logger.error(f"Error in FAISS retrieval: {str(e)}")
            return []

    def persist(self):
        if self.persist_path and self.metadata_path:
            faiss.write_index(self.index, self.persist_path)

            # Save metadata
            with open(self.metadata_path, 'wb') as f:
                pickle.dump({
                    'documents': self.documents,
                    'doc_texts': self.doc_texts
                }, f)
            logger.info(f"FAISS index persisted to {self.persist_path}")

    def load(self):
        if (self.persist_path and os.path.exists(self.persist_path) and
                self.metadata_path and os.path.exists(self.metadata_path)):

            self.index = faiss.read_index(self.persist_path)

            # Load metadata
            with open(self.metadata_path, 'rb') as f:
                data = pickle.load(f)
                self.documents = data['documents']
                self.doc_texts = data['doc_texts']
            logger.info(f"FAISS index loaded from {self.persist_path}")

class TFIDFVectorStore(BaseVectorStore):
    """TFIDF vector store implementation"""

    def __init__(self, persist_path: Optional[str] = "tfidf_index.pkl", max_features: int = 10000):
        super().__init__(persist_path)
        self.vectorizer = TfidfVectorizer(max_features=max_features, stop_words='english')
        self.tfidf_matrix = None

    def add_documents(self, documents: List[Dict[str, Any]]):
        self.documents = documents
        self.doc_texts = [doc["content"] for doc in documents]

        # Fit and transform the documents
        self.tfidf_matrix = self.vectorizer.fit_transform(self.doc_texts)

        logger.info(f"TFIDF index created with {len(self.documents)} documents")

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        if self.tfidf_matrix is None:
            raise ValueError("TFIDF index not initialized. Call add_documents first.")

        try:
            # Transform the query
            query_vec = self.vectorizer.transform([query])

            # Calculate cosine similarities
            similarities = cosine_similarity(query_vec, self.tfidf_matrix).flatten()

            # Get top-k indices
            top_indices = np.argsort(similarities)[::-1][:top_k]

            # Filter out invalid indices
            valid_indices = [i for i in top_indices if i < len(self.documents)]

            if not valid_indices:
                logger.warning(f"No valid indices found for query: {query}")
                return []

            return [self.documents[i] for i in valid_indices]

        except Exception as e:
            logger.error(f"Error in TFIDF retrieval: {str(e)}")
            return []

    def persist(self):
        if self.persist_path:
            with open(self.persist_path, 'wb') as f:
                pickle.dump({
                    'documents': self.documents,
                    'doc_texts': self.doc_texts,
                    'vectorizer': self.vectorizer,
                    'tfidf_matrix': self.tfidf_matrix
                }, f)
            logger.info(f"TFIDF index persisted to {self.persist_path}")

    def load(self):
        if self.persist_path and os.path.exists(self.persist_path):
            with open(self.persist_path, 'rb') as f:
                data = pickle.load(f)
                self.documents = data['documents']
                self.doc_texts = data['doc_texts']
                self.vectorizer = data['vectorizer']
                self.tfidf_matrix = data['tfidf_matrix']
            logger.info(f"TFIDF index loaded from {self.persist_path}")

class HybridVectorStore(BaseVectorStore):
    """Hybrid vector store combining BM25 and FAISS"""

    def __init__(self, persist_path: Optional[str] = "hybrid_index"):
        super().__init__(persist_path)
        self.bm25_store = BM25VectorStore(persist_path + "_bm25")
        self.faiss_store = FAISSVectorStore(persist_path + "_faiss")

    def add_documents(self, documents: List[Dict[str, Any]]):
        self.documents = documents
        self.bm25_store.add_documents(documents)
        self.faiss_store.add_documents(documents)
        logger.info(f"Hybrid index created with {len(self.documents)} documents")

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        try:
            # Get results from both methods
            bm25_results = self.bm25_store.retrieve(query, top_k * 2)
            faiss_results = self.faiss_store.retrieve(query, top_k * 2)

            # Combine and deduplicate by content
            combined = {}
            for doc in bm25_results + faiss_results:
                # Use a combination of content and metadata for deduplication
                key = hash(doc["content"] + str(doc["metadata"]))
                if key not in combined:
                    combined[key] = doc

            all_results = list(combined.values())

            # If no results after deduplication
            if not all_results:
                return []

            # Rerank by relevance to query using FAISS similarity
            query_embedding = self.faiss_store.model.encode(query)
            scored_results = []

            for doc in all_results:
                doc_embedding = self.faiss_store.model.encode(doc["content"])
                similarity = np.dot(query_embedding, doc_embedding) / (
                    np.linalg.norm(query_embedding) * np.linalg.norm(doc_embedding))

                scored_results.append((doc, similarity))

            scored_results.sort(key=lambda x: x[1], reverse=True)
            return [doc for doc, _ in scored_results[:top_k]]

        except Exception as e:
            logger.error(f"Error in hybrid retrieval: {str(e)}")
            return []

    def persist(self):
        self.bm25_store.persist()
        self.faiss_store.persist()
        logger.info(f"Hybrid index persisted")

    def load(self):
        self.bm25_store.load()
        self.faiss_store.load()
        self.documents = self.bm25_store.documents
        logger.info(f"Hybrid index loaded")

class HybridOfflineVectorStore(BaseVectorStore):
    """Hybrid offline vector store combining BM25 and TFIDF"""

    def __init__(self, persist_path: Optional[str] = "hybrid_offline_index"):
        super().__init__(persist_path)
        self.bm25_store = BM25VectorStore(persist_path + "_bm25")
        self.tfidf_store = TFIDFVectorStore(persist_path + "_tfidf")
        self.alpha = 0.5  # Weight for BM25 vs TFIDF

    def add_documents(self, documents: List[Dict[str, Any]]):
        self.documents = documents
        self.bm25_store.add_documents(documents)
        self.tfidf_store.add_documents(documents)
        logger.info(f"Hybrid offline index created with {len(self.documents)} documents")

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        try:
            # Get results from both methods
            bm25_results = self.bm25_store.retrieve(query, top_k * 2)
            tfidf_results = self.tfidf_store.retrieve(query, top_k * 2)

            # Combine and deduplicate by content
            combined = {}
            for doc in bm25_results + tfidf_results:
                # Use a combination of content and metadata for deduplication
                key = hash(doc["content"] + str(doc["metadata"]))
                if key not in combined:
                    combined[key] = doc

            all_results = list(combined.values())

            # If no results after deduplication
            if not all_results:
                return []

            # Score results based on both methods
            scored_results = []

            for doc in all_results:
                # Get BM25 score
                bm25_score = 0
                for i, bm25_doc in enumerate(bm25_results):
                    if (doc["content"] == bm25_doc["content"] and
                            doc["metadata"] == bm25_doc["metadata"]):
                        # Normalize score based on position
                        bm25_score = (len(bm25_results) - i) / len(bm25_results)
                        break

                # Get TFIDF score
                tfidf_score = 0
                for i, tfidf_doc in enumerate(tfidf_results):
                    if (doc["content"] == tfidf_doc["content"] and
                            doc["metadata"] == tfidf_doc["metadata"]):
                        # Normalize score based on position
                        tfidf_score = (len(tfidf_results) - i) / len(tfidf_results)
                        break

                # Combine scores
                combined_score = self.alpha * bm25_score + (1 - self.alpha) * tfidf_score
                scored_results.append((doc, combined_score))

            scored_results.sort(key=lambda x: x[1], reverse=True)
            return [doc for doc, _ in scored_results[:top_k]]

        except Exception as e:
            logger.error(f"Error in hybrid offline retrieval: {str(e)}")
            return []

    def persist(self):
        self.bm25_store.persist()
        self.tfidf_store.persist()
        logger.info(f"Hybrid offline index persisted")

    def load(self):
        self.bm25_store.load()
        self.tfidf_store.load()
        self.documents = self.bm25_store.documents
        logger.info(f"Hybrid offline index loaded")
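
Note: the diff does not show these stores being driven directly, so here is a minimal usage sketch under stated assumptions — the document dicts follow the "content"/"metadata" shape the classes above read, and the sample texts, paths, and query are illustrative:

# Minimal sketch (assumed usage, not from the package docs): build a BM25
# index over two toy documents and retrieve the best match for a query.
from kssrag.core.vectorstores import BM25VectorStore

docs = [
    {"content": "FAISS performs dense vector search.", "metadata": {"source": "a.txt"}},
    {"content": "BM25 ranks documents by term frequency.", "metadata": {"source": "b.txt"}},
]

store = BM25VectorStore(persist_path="bm25_index.pkl")
store.add_documents(docs)                 # tokenizes and builds the BM25Okapi index
hits = store.retrieve("term frequency ranking", top_k=1)
print(hits[0]["metadata"]["source"])      # expected: b.txt
store.persist()                           # pickles documents + index to persist_path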
kssrag/kssrag.py
ADDED
@@ -0,0 +1,116 @@
"""
Main KSSRAG class that ties everything together for easy usage.
"""
from typing import Optional, List, Dict, Any
import os
from .core.chunkers import TextChunker, JSONChunker, PDFChunker
from .core.vectorstores import BM25VectorStore, FAISSVectorStore, TFIDFVectorStore, HybridVectorStore, HybridOfflineVectorStore
from .core.retrievers import SimpleRetriever, HybridRetriever
from .core.agents import RAGAgent
from .models.openrouter import OpenRouterLLM
from .utils.document_loaders import load_document, load_json_documents
from .config import Config, VectorStoreType, ChunkerType, RetrieverType
from .utils.helpers import logger, validate_config, import_custom_component

class KSSRAG:
    """Main class for KSS RAG functionality"""

    def __init__(self, config: Optional[Config] = None):
        self.config = config or Config()
        self.vector_store = None
        self.retriever = None
        self.agent = None
        self.documents = []

        # Validate configuration
        validate_config()

    def load_document(self, file_path: str, format: Optional[str] = None,
                      chunker: Optional[Any] = None, metadata: Optional[Dict[str, Any]] = None):
        """Load and process a document"""
        if format is None:
            # Auto-detect format
            if file_path.endswith('.txt'):
                format = 'text'
            elif file_path.endswith('.json'):
                format = 'json'
            elif file_path.endswith('.pdf'):
                format = 'pdf'
            else:
                raise ValueError(f"Unsupported file format: {file_path}")

        # Use custom chunker if provided
        if chunker is None:
            if format == 'text':
                chunker = TextChunker(chunk_size=self.config.CHUNK_SIZE, overlap=self.config.CHUNK_OVERLAP)
            elif format == 'json':
                chunker = JSONChunker()
            elif format == 'pdf':
                chunker = PDFChunker(chunk_size=self.config.CHUNK_SIZE, overlap=self.config.CHUNK_OVERLAP)

        if metadata is None:
            metadata = {"source": file_path}

        # Load and chunk document
        if format == 'text':
            content = load_document(file_path)
            self.documents = chunker.chunk(content, metadata)
        elif format == 'json':
            data = load_json_documents(file_path)
            self.documents = chunker.chunk(data)
        elif format == 'pdf':
            self.documents = chunker.chunk_pdf(file_path, metadata)

        # Create vector store
        if self.config.CUSTOM_VECTOR_STORE:
            vector_store_class = import_custom_component(self.config.CUSTOM_VECTOR_STORE)
            self.vector_store = vector_store_class()
        else:
            if self.config.VECTOR_STORE_TYPE == VectorStoreType.BM25:
                self.vector_store = BM25VectorStore()
            elif self.config.VECTOR_STORE_TYPE == VectorStoreType.FAISS:
                self.vector_store = FAISSVectorStore()
            elif self.config.VECTOR_STORE_TYPE == VectorStoreType.TFIDF:
                self.vector_store = TFIDFVectorStore()
            elif self.config.VECTOR_STORE_TYPE == VectorStoreType.HYBRID_ONLINE:
                self.vector_store = HybridVectorStore()
            elif self.config.VECTOR_STORE_TYPE == VectorStoreType.HYBRID_OFFLINE:
                self.vector_store = HybridOfflineVectorStore()

        self.vector_store.add_documents(self.documents)

        # Create retriever
        if self.config.CUSTOM_RETRIEVER:
            retriever_class = import_custom_component(self.config.CUSTOM_RETRIEVER)
            self.retriever = retriever_class(self.vector_store)
        else:
            if self.config.RETRIEVER_TYPE == RetrieverType.SIMPLE:
                self.retriever = SimpleRetriever(self.vector_store)
            elif self.config.RETRIEVER_TYPE == RetrieverType.HYBRID:
                # For hybrid retriever, extract entity names from documents if available
                entity_names = []
                if format == 'json':
                    entity_names = [doc['metadata'].get('name', '') for doc in self.documents if doc['metadata'].get('name')]
                self.retriever = HybridRetriever(self.vector_store, entity_names)

        # Create LLM
        if self.config.CUSTOM_LLM:
            llm_class = import_custom_component(self.config.CUSTOM_LLM)
            llm = llm_class()
        else:
            llm = OpenRouterLLM()

        # Create agent
        self.agent = RAGAgent(self.retriever, llm)

    def query(self, question: str, top_k: Optional[int] = None) -> str:
        """Query the RAG system"""
        if not self.agent:
            raise ValueError("No documents loaded. Call load_document first.")

        return self.agent.query(question, top_k=top_k or self.config.TOP_K)

    def create_server(self, server_config=None):
        """Create a FastAPI server for the RAG system"""
        from .server import create_app
        return create_app(self.agent, server_config)
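
For context, a sketch of the end-to-end flow this class implies — the import path follows the module location (kssrag/kssrag.py), the file name and question are placeholders, and an OpenRouter key is assumed unless CUSTOM_LLM is configured:

# Assumed end-to-end usage; "notes.txt" and the question are placeholders.
from kssrag.kssrag import KSSRAG

rag = KSSRAG()                  # default Config(); validate_config() runs here
rag.load_document("notes.txt")  # format auto-detected from the .txt extension
print(rag.query("What do the notes say about deadlines?", top_k=5))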
kssrag/models/local_llms.py
ADDED

@@ -0,0 +1,30 @@
"""
Local LLM implementations for offline usage
"""
from typing import List, Dict, Any
from ..utils.helpers import logger

class LocalLLM:
    """Base class for local LLM implementations"""

    def predict(self, messages: List[Dict[str, str]]) -> str:
        """Generate a response using a local model"""
        raise NotImplementedError("Subclasses must implement this method")

class MockLLM(LocalLLM):
    """Mock LLM for testing without API calls"""

    def predict(self, messages: List[Dict[str, str]]) -> str:
        """Generate a mock response for testing"""
        logger.info("Using mock LLM for response generation")

        # Extract the last user message
        user_message = ""
        for msg in reversed(messages):
            if msg["role"] == "user":
                user_message = msg["content"]
                break

        return f"This is a mock response to: {user_message}"

# TO-DO - Add more local LLM implementations as needed - I hope i do not forget, someone if you are seeing this, the date i added this was 8th of september 2025, how long has it been? DM me on WhatsApp +2349019549473.
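
The TO-DO above leaves local backends unimplemented. Purely as a hypothetical sketch of what could fill it — transformers is not a dependency of this package, and TransformersLLM, the model name, and the generation settings are all invented for illustration — a subclass might look like:

# Hypothetical LocalLLM subclass backed by a Hugging Face text-generation
# pipeline; nothing here ships with kssrag.
from typing import List, Dict
from transformers import pipeline

class TransformersLLM(LocalLLM):
    def __init__(self, model_name: str = "distilgpt2"):
        self.generator = pipeline("text-generation", model=model_name)

    def predict(self, messages: List[Dict[str, str]]) -> str:
        # Flatten the chat history into a single prompt string.
        prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
        out = self.generator(prompt, max_new_tokens=128, return_full_text=False)
        return out[0]["generated_text"]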
kssrag/models/openrouter.py
ADDED
@@ -0,0 +1,85 @@
import requests
import json
from typing import List, Dict, Any, Optional
from ..utils.helpers import logger
from ..config import config

class OpenRouterLLM:
    """OpenRouter LLM interface with fallback models"""

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None,
                 fallback_models: Optional[List[str]] = None):
        self.api_key = api_key or config.OPENROUTER_API_KEY
        self.model = model or config.DEFAULT_MODEL
        self.fallback_models = fallback_models or config.FALLBACK_MODELS
        self.base_url = "https://openrouter.ai/api/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/Ksschkw/kssrag",
            "X-Title": "KSS RAG Agent"
        }

    def predict(self, messages: List[Dict[str, str]]) -> str:
        """Generate a response using OpenRouter's API with fallbacks"""
        logger.info(f"Attempting to generate response with {len(messages)} messages")

        for model in [self.model] + self.fallback_models:
            payload = {
                "model": model,
                "messages": messages,
                "temperature": 0.7,
                "max_tokens": 1024,
                "top_p": 1,
                "stop": None,
                "stream": False
            }

            try:
                logger.info(f"Trying model: {model}")
                response = requests.post(
                    self.base_url,
                    headers=self.headers,
                    json=payload,
                    timeout=15
                )

                # Check for HTTP errors
                response.raise_for_status()

                # Parse JSON response
                response_data = response.json()

                # Validate response structure
                if ("choices" not in response_data or
                        len(response_data["choices"]) == 0 or
                        "message" not in response_data["choices"][0] or
                        "content" not in response_data["choices"][0]["message"]):

                    logger.warning(f"Invalid response format from {model}: {response_data}")
                    continue

                content = response_data["choices"][0]["message"]["content"]
                logger.info(f"Successfully used model: {model}")
                return content

            except requests.exceptions.Timeout:
                logger.warning(f"Model {model} timed out")
                continue
            except requests.exceptions.RequestException as e:
                logger.warning(f"Request error with model {model}: {str(e)}")
                if hasattr(e, 'response') and e.response is not None:
                    try:
                        error_data = e.response.json()
                        logger.warning(f"Error response: {error_data}")
                    except:
                        logger.warning(f"Error response text: {e.response.text}")
                continue
            except Exception as e:
                logger.warning(f"Unexpected error with model {model}: {str(e)}")
                continue

        # If all models fail, return a friendly error message
        error_msg = "I'm having trouble connecting to the knowledge service right now. Please try again in a moment."
        logger.error("All model fallbacks failed to respond")
        return error_msg
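
For reference, an assumed call pattern for this class — the model slug is illustrative, and either config.OPENROUTER_API_KEY or an explicit api_key is required:

# Assumed usage; on errors or timeouts the predict() loop above falls
# through config.FALLBACK_MODELS before returning the friendly error string.
from kssrag.models.openrouter import OpenRouterLLM

llm = OpenRouterLLM(model="mistralai/mistral-7b-instruct")
reply = llm.predict([
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Summarize BM25 in one sentence."},
])
print(reply)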