kssrag-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,397 @@
+ import json
+ import os
+ import re
+ import pickle
+ import tempfile
+ import numpy as np
+ import faiss
+ from rank_bm25 import BM25Okapi
+ from sentence_transformers import SentenceTransformer
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ import scipy.sparse as sp
+ from typing import List, Dict, Any, Optional
+ from ..utils.helpers import logger
+ from ..config import config
+
+ class BaseVectorStore:
+     """Base class for vector stores"""
+
+     def __init__(self, persist_path: Optional[str] = None):
+         self.persist_path = persist_path
+         self.documents: List[Dict[str, Any]] = []
+         self.doc_texts: List[str] = []
+
+     def add_documents(self, documents: List[Dict[str, Any]]):
+         """Add documents to the vector store"""
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
+         """Retrieve documents based on query"""
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def persist(self):
+         """Persist the vector store to disk"""
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def load(self):
+         """Load the vector store from disk"""
+         raise NotImplementedError("Subclasses must implement this method")
+
+ class BM25VectorStore(BaseVectorStore):
+     """BM25 vector store implementation"""
+
+     def __init__(self, persist_path: Optional[str] = "bm25_index.pkl"):
+         super().__init__(persist_path)
+         self.bm25 = None
+
+     def add_documents(self, documents: List[Dict[str, Any]]):
+         self.documents = documents
+         self.doc_texts = [doc["content"] for doc in documents]
+
+         # Tokenize corpus for BM25
+         tokenized_corpus = [self._tokenize(doc) for doc in self.doc_texts]
+         self.bm25 = BM25Okapi(tokenized_corpus)
+
+         logger.info(f"BM25 index created with {len(self.documents)} documents")
+
+     def _tokenize(self, text: str) -> List[str]:
+         """Tokenize text for BM25"""
+         return re.findall(r'\w+', text.lower())
+
+     def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
+         if not self.bm25:
+             raise ValueError("BM25 index not initialized. Call add_documents first.")
+
+         try:
+             tokenized_query = self._tokenize(query)
+             doc_scores = self.bm25.get_scores(tokenized_query)
+
+             # Get top-k indices
+             top_indices = np.argsort(doc_scores)[::-1][:top_k]
+
+             # Filter out invalid indices
+             valid_indices = [i for i in top_indices if i < len(self.documents)]
+
+             if not valid_indices:
+                 logger.warning(f"No valid indices found for query: {query}")
+                 return []
+
+             return [self.documents[i] for i in valid_indices]
+
+         except Exception as e:
+             logger.error(f"Error in BM25 retrieval: {str(e)}")
+             return []
+
+     def persist(self):
+         if self.persist_path:
+             with open(self.persist_path, 'wb') as f:
+                 pickle.dump({
+                     'documents': self.documents,
+                     'doc_texts': self.doc_texts,
+                     'bm25': self.bm25
+                 }, f)
+             logger.info(f"BM25 index persisted to {self.persist_path}")
+
+     def load(self):
+         if self.persist_path and os.path.exists(self.persist_path):
+             with open(self.persist_path, 'rb') as f:
+                 data = pickle.load(f)
+                 self.documents = data['documents']
+                 self.doc_texts = data['doc_texts']
+                 self.bm25 = data['bm25']
+             logger.info(f"BM25 index loaded from {self.persist_path}")
+
+ class FAISSVectorStore(BaseVectorStore):
+     """FAISS vector store implementation using sentence-transformer embeddings"""
+
+     def __init__(self, persist_path: Optional[str] = None, model_name: Optional[str] = None):
+         super().__init__(persist_path)
+         self.model_name = model_name or config.FAISS_MODEL_NAME
+
+         # Handle cache directory permissions
+         try:
+             cache_dir = config.CACHE_DIR
+             os.makedirs(cache_dir, exist_ok=True)
+             # Test if we can write to the cache directory
+             test_file = os.path.join(cache_dir, 'write_test.txt')
+             with open(test_file, 'w') as f:
+                 f.write('test')
+             os.remove(test_file)
+         except PermissionError:
+             logger.warning(f"Could not write to cache directory {cache_dir}. Using temp directory.")
+             cache_dir = tempfile.gettempdir()
+
+         self.model = SentenceTransformer(self.model_name, cache_folder=cache_dir)
+         self.dimension = self.model.get_sentence_embedding_dimension()
+         self.index = faiss.IndexFlatL2(self.dimension)
+         self.metadata_path = persist_path + ".meta" if persist_path else None
+
+     def add_documents(self, documents: List[Dict[str, Any]]):
+         self.documents = documents
+         self.doc_texts = [doc["content"] for doc in documents]
+
+         # Generate embeddings in batches
+         embeddings = []
+         batch_size = config.BATCH_SIZE
+
+         for i in range(0, len(self.doc_texts), batch_size):
+             batch_texts = self.doc_texts[i:i+batch_size]
+             batch_embeddings = self.model.encode(batch_texts, show_progress_bar=False)
+             embeddings.append(batch_embeddings)
+             logger.info(f"Processed batch {i//batch_size + 1}/{(len(self.doc_texts)-1)//batch_size + 1}")
+
+         embeddings = np.vstack(embeddings).astype('float32')
+         self.index.add(embeddings)
+
+         logger.info(f"FAISS index created with {len(self.documents)} documents")
+
+     def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
+         if not self.index or self.index.ntotal == 0:
+             raise ValueError("FAISS index not initialized. Call add_documents first.")
+
+         try:
+             query_embedding = self.model.encode([query])
+             distances, indices = self.index.search(query_embedding.astype('float32'), top_k)
+
+             # Filter out invalid indices (FAISS might return -1 for no results)
+             valid_indices = [i for i in indices[0] if i >= 0 and i < len(self.documents)]
+
+             if not valid_indices:
+                 logger.warning(f"No valid indices found for query: {query}")
+                 return []
+
+             return [self.documents[i] for i in valid_indices]
+         except Exception as e:
+             logger.error(f"Error in FAISS retrieval: {str(e)}")
+             return []
+
+     def persist(self):
+         if self.persist_path and self.metadata_path:
+             faiss.write_index(self.index, self.persist_path)
+
+             # Save metadata
+             with open(self.metadata_path, 'wb') as f:
+                 pickle.dump({
+                     'documents': self.documents,
+                     'doc_texts': self.doc_texts
+                 }, f)
+             logger.info(f"FAISS index persisted to {self.persist_path}")
+
+     def load(self):
+         if (self.persist_path and os.path.exists(self.persist_path) and
+                 self.metadata_path and os.path.exists(self.metadata_path)):
+
+             self.index = faiss.read_index(self.persist_path)
+
+             # Load metadata
+             with open(self.metadata_path, 'rb') as f:
+                 data = pickle.load(f)
+                 self.documents = data['documents']
+                 self.doc_texts = data['doc_texts']
+             logger.info(f"FAISS index loaded from {self.persist_path}")
+
+ class TFIDFVectorStore(BaseVectorStore):
+     """TFIDF vector store implementation"""
+
+     def __init__(self, persist_path: Optional[str] = "tfidf_index.pkl", max_features: int = 10000):
+         super().__init__(persist_path)
+         self.vectorizer = TfidfVectorizer(max_features=max_features, stop_words='english')
+         self.tfidf_matrix = None
+
+     def add_documents(self, documents: List[Dict[str, Any]]):
+         self.documents = documents
+         self.doc_texts = [doc["content"] for doc in documents]
+
+         # Fit and transform the documents
+         self.tfidf_matrix = self.vectorizer.fit_transform(self.doc_texts)
+
+         logger.info(f"TFIDF index created with {len(self.documents)} documents")
+
+     def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
+         if self.tfidf_matrix is None:
+             raise ValueError("TFIDF index not initialized. Call add_documents first.")
+
+         try:
+             # Transform the query
+             query_vec = self.vectorizer.transform([query])
+
+             # Calculate cosine similarities
+             similarities = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
+
+             # Get top-k indices
+             top_indices = np.argsort(similarities)[::-1][:top_k]
+
+             # Filter out invalid indices
+             valid_indices = [i for i in top_indices if i < len(self.documents)]
+
+             if not valid_indices:
+                 logger.warning(f"No valid indices found for query: {query}")
+                 return []
+
+             return [self.documents[i] for i in valid_indices]
+
+         except Exception as e:
+             logger.error(f"Error in TFIDF retrieval: {str(e)}")
+             return []
+
+     def persist(self):
+         if self.persist_path:
+             with open(self.persist_path, 'wb') as f:
+                 pickle.dump({
+                     'documents': self.documents,
+                     'doc_texts': self.doc_texts,
+                     'vectorizer': self.vectorizer,
+                     'tfidf_matrix': self.tfidf_matrix
+                 }, f)
+             logger.info(f"TFIDF index persisted to {self.persist_path}")
+
+     def load(self):
+         if self.persist_path and os.path.exists(self.persist_path):
+             with open(self.persist_path, 'rb') as f:
+                 data = pickle.load(f)
+                 self.documents = data['documents']
+                 self.doc_texts = data['doc_texts']
+                 self.vectorizer = data['vectorizer']
+                 self.tfidf_matrix = data['tfidf_matrix']
+             logger.info(f"TFIDF index loaded from {self.persist_path}")
+
+ class HybridVectorStore(BaseVectorStore):
+     """Hybrid vector store combining BM25 and FAISS"""
+
+     def __init__(self, persist_path: Optional[str] = "hybrid_index"):
+         super().__init__(persist_path)
+         self.bm25_store = BM25VectorStore(persist_path + "_bm25")
+         self.faiss_store = FAISSVectorStore(persist_path + "_faiss")
+
+     def add_documents(self, documents: List[Dict[str, Any]]):
+         self.documents = documents
+         self.bm25_store.add_documents(documents)
+         self.faiss_store.add_documents(documents)
+         logger.info(f"Hybrid index created with {len(self.documents)} documents")
+
+     def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
+         try:
+             # Get results from both methods
+             bm25_results = self.bm25_store.retrieve(query, top_k * 2)
+             faiss_results = self.faiss_store.retrieve(query, top_k * 2)
+
+             # Combine and deduplicate by content
+             combined = {}
+             for doc in bm25_results + faiss_results:
+                 # Use a combination of content and metadata for deduplication
+                 key = hash(doc["content"] + str(doc["metadata"]))
+                 if key not in combined:
+                     combined[key] = doc
+
+             all_results = list(combined.values())
+
+             # If no results after deduplication
+             if not all_results:
+                 return []
+
+             # Rerank by relevance to query using FAISS similarity
+             query_embedding = self.faiss_store.model.encode(query)
+             scored_results = []
+
+             for doc in all_results:
+                 doc_embedding = self.faiss_store.model.encode(doc["content"])
+                 similarity = np.dot(query_embedding, doc_embedding) / (
+                     np.linalg.norm(query_embedding) * np.linalg.norm(doc_embedding))
+
+                 scored_results.append((doc, similarity))
+
+             scored_results.sort(key=lambda x: x[1], reverse=True)
+             return [doc for doc, _ in scored_results[:top_k]]
+
+         except Exception as e:
+             logger.error(f"Error in hybrid retrieval: {str(e)}")
+             return []
+
+     def persist(self):
+         self.bm25_store.persist()
+         self.faiss_store.persist()
+         logger.info(f"Hybrid index persisted")
+
+     def load(self):
+         self.bm25_store.load()
+         self.faiss_store.load()
+         self.documents = self.bm25_store.documents
+         logger.info(f"Hybrid index loaded")
+
+ class HybridOfflineVectorStore(BaseVectorStore):
+     """Hybrid offline vector store combining BM25 and TFIDF"""
+
+     def __init__(self, persist_path: Optional[str] = "hybrid_offline_index"):
+         super().__init__(persist_path)
+         self.bm25_store = BM25VectorStore(persist_path + "_bm25")
+         self.tfidf_store = TFIDFVectorStore(persist_path + "_tfidf")
+         self.alpha = 0.5  # Weight for BM25 vs TFIDF
+
+     def add_documents(self, documents: List[Dict[str, Any]]):
+         self.documents = documents
+         self.bm25_store.add_documents(documents)
+         self.tfidf_store.add_documents(documents)
+         logger.info(f"Hybrid offline index created with {len(self.documents)} documents")
+
+     def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
+         try:
+             # Get results from both methods
+             bm25_results = self.bm25_store.retrieve(query, top_k * 2)
+             tfidf_results = self.tfidf_store.retrieve(query, top_k * 2)
+
+             # Combine and deduplicate by content
+             combined = {}
+             for doc in bm25_results + tfidf_results:
+                 # Use a combination of content and metadata for deduplication
+                 key = hash(doc["content"] + str(doc["metadata"]))
+                 if key not in combined:
+                     combined[key] = doc
+
+             all_results = list(combined.values())
+
+             # If no results after deduplication
+             if not all_results:
+                 return []
+
+             # Score results based on both methods
+             scored_results = []
+
+             for doc in all_results:
+                 # Get BM25 score
+                 bm25_score = 0
+                 for i, bm25_doc in enumerate(bm25_results):
+                     if (doc["content"] == bm25_doc["content"] and
+                             doc["metadata"] == bm25_doc["metadata"]):
+                         # Normalize score based on position
+                         bm25_score = (len(bm25_results) - i) / len(bm25_results)
+                         break
+
+                 # Get TFIDF score
+                 tfidf_score = 0
+                 for i, tfidf_doc in enumerate(tfidf_results):
+                     if (doc["content"] == tfidf_doc["content"] and
+                             doc["metadata"] == tfidf_doc["metadata"]):
+                         # Normalize score based on position
+                         tfidf_score = (len(tfidf_results) - i) / len(tfidf_results)
+                         break
+
+                 # Combine scores
+                 combined_score = self.alpha * bm25_score + (1 - self.alpha) * tfidf_score
+                 scored_results.append((doc, combined_score))
+
+             scored_results.sort(key=lambda x: x[1], reverse=True)
+             return [doc for doc, _ in scored_results[:top_k]]
+
+         except Exception as e:
+             logger.error(f"Error in hybrid offline retrieval: {str(e)}")
+             return []
+
+     def persist(self):
+         self.bm25_store.persist()
+         self.tfidf_store.persist()
+         logger.info(f"Hybrid offline index persisted")
+
+     def load(self):
+         self.bm25_store.load()
+         self.tfidf_store.load()
+         self.documents = self.bm25_store.documents
+         logger.info(f"Hybrid offline index loaded")
kssrag/kssrag.py ADDED
@@ -0,0 +1,116 @@
+ """
+ Main KSSRAG class that ties everything together for easy usage.
+ """
+ from typing import Optional, List, Dict, Any
+ import os
+ from .core.chunkers import TextChunker, JSONChunker, PDFChunker
+ from .core.vectorstores import BM25VectorStore, FAISSVectorStore, TFIDFVectorStore, HybridVectorStore, HybridOfflineVectorStore
+ from .core.retrievers import SimpleRetriever, HybridRetriever
+ from .core.agents import RAGAgent
+ from .models.openrouter import OpenRouterLLM
+ from .utils.document_loaders import load_document, load_json_documents
+ from .config import Config, VectorStoreType, ChunkerType, RetrieverType
+ from .utils.helpers import logger, validate_config, import_custom_component
+
+ class KSSRAG:
+     """Main class for KSS RAG functionality"""
+
+     def __init__(self, config: Optional[Config] = None):
+         self.config = config or Config()
+         self.vector_store = None
+         self.retriever = None
+         self.agent = None
+         self.documents = []
+
+         # Validate configuration
+         validate_config()
+
+     def load_document(self, file_path: str, format: Optional[str] = None,
+                       chunker: Optional[Any] = None, metadata: Optional[Dict[str, Any]] = None):
+         """Load and process a document"""
+         if format is None:
+             # Auto-detect format
+             if file_path.endswith('.txt'):
+                 format = 'text'
+             elif file_path.endswith('.json'):
+                 format = 'json'
+             elif file_path.endswith('.pdf'):
+                 format = 'pdf'
+             else:
+                 raise ValueError(f"Unsupported file format: {file_path}")
+
+         # Use custom chunker if provided
+         if chunker is None:
+             if format == 'text':
+                 chunker = TextChunker(chunk_size=self.config.CHUNK_SIZE, overlap=self.config.CHUNK_OVERLAP)
+             elif format == 'json':
+                 chunker = JSONChunker()
+             elif format == 'pdf':
+                 chunker = PDFChunker(chunk_size=self.config.CHUNK_SIZE, overlap=self.config.CHUNK_OVERLAP)
+
+         if metadata is None:
+             metadata = {"source": file_path}
+
+         # Load and chunk document
+         if format == 'text':
+             content = load_document(file_path)
+             self.documents = chunker.chunk(content, metadata)
+         elif format == 'json':
+             data = load_json_documents(file_path)
+             self.documents = chunker.chunk(data)
+         elif format == 'pdf':
+             self.documents = chunker.chunk_pdf(file_path, metadata)
+
+         # Create vector store
+         if self.config.CUSTOM_VECTOR_STORE:
+             vector_store_class = import_custom_component(self.config.CUSTOM_VECTOR_STORE)
+             self.vector_store = vector_store_class()
+         else:
+             if self.config.VECTOR_STORE_TYPE == VectorStoreType.BM25:
+                 self.vector_store = BM25VectorStore()
+             elif self.config.VECTOR_STORE_TYPE == VectorStoreType.FAISS:
+                 self.vector_store = FAISSVectorStore()
+             elif self.config.VECTOR_STORE_TYPE == VectorStoreType.TFIDF:
+                 self.vector_store = TFIDFVectorStore()
+             elif self.config.VECTOR_STORE_TYPE == VectorStoreType.HYBRID_ONLINE:
+                 self.vector_store = HybridVectorStore()
+             elif self.config.VECTOR_STORE_TYPE == VectorStoreType.HYBRID_OFFLINE:
+                 self.vector_store = HybridOfflineVectorStore()
+
+         self.vector_store.add_documents(self.documents)
+
+         # Create retriever
+         if self.config.CUSTOM_RETRIEVER:
+             retriever_class = import_custom_component(self.config.CUSTOM_RETRIEVER)
+             self.retriever = retriever_class(self.vector_store)
+         else:
+             if self.config.RETRIEVER_TYPE == RetrieverType.SIMPLE:
+                 self.retriever = SimpleRetriever(self.vector_store)
+             elif self.config.RETRIEVER_TYPE == RetrieverType.HYBRID:
+                 # For hybrid retriever, extract entity names from documents if available
+                 entity_names = []
+                 if format == 'json':
+                     entity_names = [doc['metadata'].get('name', '') for doc in self.documents if doc['metadata'].get('name')]
+                 self.retriever = HybridRetriever(self.vector_store, entity_names)
+
+         # Create LLM
+         if self.config.CUSTOM_LLM:
+             llm_class = import_custom_component(self.config.CUSTOM_LLM)
+             llm = llm_class()
+         else:
+             llm = OpenRouterLLM()
+
+         # Create agent
+         self.agent = RAGAgent(self.retriever, llm)
+
+     def query(self, question: str, top_k: Optional[int] = None) -> str:
+         """Query the RAG system"""
+         if not self.agent:
+             raise ValueError("No documents loaded. Call load_document first.")
+
+         return self.agent.query(question, top_k=top_k or self.config.TOP_K)
+
+     def create_server(self, server_config=None):
+         """Create a FastAPI server for the RAG system"""
+         from .server import create_app
+         return create_app(self.agent, server_config)
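A minimal end-to-end sketch (editor's illustration, not part of the package contents); the import path and the document file name are assumptions:

from kssrag.kssrag import KSSRAG  # import path assumed; the package may also re-export KSSRAG at top level

rag = KSSRAG()                    # default Config; the default OpenRouterLLM needs OPENROUTER_API_KEY set
rag.load_document("notes.txt")    # hypothetical .txt file; format is auto-detected from the extension
print(rag.query("What is this document about?"))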
@@ -0,0 +1,30 @@
+ """
+ Local LLM implementations for offline usage
+ """
+ from typing import List, Dict, Any
+ from ..utils.helpers import logger
+
+ class LocalLLM:
+     """Base class for local LLM implementations"""
+
+     def predict(self, messages: List[Dict[str, str]]) -> str:
+         """Generate a response using a local model"""
+         raise NotImplementedError("Subclasses must implement this method")
+
+ class MockLLM(LocalLLM):
+     """Mock LLM for testing without API calls"""
+
+     def predict(self, messages: List[Dict[str, str]]) -> str:
+         """Generate a mock response for testing"""
+         logger.info("Using mock LLM for response generation")
+
+         # Extract the last user message
+         user_message = ""
+         for msg in reversed(messages):
+             if msg["role"] == "user":
+                 user_message = msg["content"]
+                 break
+
+         return f"This is a mock response to: {user_message}"
+
+ # TODO: Add more local LLM implementations as needed.
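A short sketch (editor's illustration) of how MockLLM behaves, based only on the predict method above:

mock = MockLLM()
print(mock.predict([{"role": "user", "content": "Hello"}]))
# -> "This is a mock response to: Hello"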
@@ -0,0 +1,85 @@
+ import requests
+ import json
+ from typing import List, Dict, Any, Optional
+ from ..utils.helpers import logger
+ from ..config import config
+
+ class OpenRouterLLM:
+     """OpenRouter LLM interface with fallback models"""
+
+     def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None,
+                  fallback_models: Optional[List[str]] = None):
+         self.api_key = api_key or config.OPENROUTER_API_KEY
+         self.model = model or config.DEFAULT_MODEL
+         self.fallback_models = fallback_models or config.FALLBACK_MODELS
+         self.base_url = "https://openrouter.ai/api/v1/chat/completions"
+         self.headers = {
+             "Authorization": f"Bearer {self.api_key}",
+             "Content-Type": "application/json",
+             "HTTP-Referer": "https://github.com/Ksschkw/kssrag",
+             "X-Title": "KSS RAG Agent"
+         }
+
+     def predict(self, messages: List[Dict[str, str]]) -> str:
+         """Generate a response using OpenRouter's API with fallbacks"""
+         logger.info(f"Attempting to generate response with {len(messages)} messages")
+
+         for model in [self.model] + self.fallback_models:
+             payload = {
+                 "model": model,
+                 "messages": messages,
+                 "temperature": 0.7,
+                 "max_tokens": 1024,
+                 "top_p": 1,
+                 "stop": None,
+                 "stream": False
+             }
+
+             try:
+                 logger.info(f"Trying model: {model}")
+                 response = requests.post(
+                     self.base_url,
+                     headers=self.headers,
+                     json=payload,
+                     timeout=15
+                 )
+
+                 # Check for HTTP errors
+                 response.raise_for_status()
+
+                 # Parse JSON response
+                 response_data = response.json()
+
+                 # Validate response structure
+                 if ("choices" not in response_data or
+                         len(response_data["choices"]) == 0 or
+                         "message" not in response_data["choices"][0] or
+                         "content" not in response_data["choices"][0]["message"]):
+
+                     logger.warning(f"Invalid response format from {model}: {response_data}")
+                     continue
+
+                 content = response_data["choices"][0]["message"]["content"]
+                 logger.info(f"Successfully used model: {model}")
+                 return content
+
+             except requests.exceptions.Timeout:
+                 logger.warning(f"Model {model} timed out")
+                 continue
+             except requests.exceptions.RequestException as e:
+                 logger.warning(f"Request error with model {model}: {str(e)}")
+                 if hasattr(e, 'response') and e.response is not None:
+                     try:
+                         error_data = e.response.json()
+                         logger.warning(f"Error response: {error_data}")
+                     except Exception:
+                         logger.warning(f"Error response text: {e.response.text}")
+                 continue
+             except Exception as e:
+                 logger.warning(f"Unexpected error with model {model}: {str(e)}")
+                 continue
+
+         # If all models fail, return a friendly error message
+         error_msg = "I'm having trouble connecting to the knowledge service right now. Please try again in a moment."
+         logger.error("All model fallbacks failed to respond")
+         return error_msg
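A minimal usage sketch (editor's illustration, not part of the package contents); it assumes OPENROUTER_API_KEY, DEFAULT_MODEL, and FALLBACK_MODELS are already set in the package config, as the constructor above falls back to them:

llm = OpenRouterLLM()  # picks up api key and model names from config when not passed explicitly
reply = llm.predict([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello."},
])
print(reply)  # model text on success, or the friendly error string if every model fails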