kssrag 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry, and is provided for informational purposes only.
kssrag/cli.py CHANGED
@@ -2,8 +2,8 @@ import argparse
 import sys
 import os
 from .utils.document_loaders import load_document, load_json_documents
-from .core.chunkers import TextChunker, JSONChunker, PDFChunker
-from .core.vectorstores import BM25VectorStore, FAISSVectorStore, TFIDFVectorStore, HybridVectorStore, HybridOfflineVectorStore
+from .core.chunkers import ImageChunker, OfficeChunker, TextChunker, JSONChunker, PDFChunker
+from .core.vectorstores import BM25SVectorStore, BM25VectorStore, FAISSVectorStore, TFIDFVectorStore, HybridVectorStore, HybridOfflineVectorStore
 from .core.retrievers import SimpleRetriever, HybridRetriever
 from .core.agents import RAGAgent
 from .models.openrouter import OpenRouterLLM
@@ -19,27 +19,36 @@ def main():
     query_parser = subparsers.add_parser("query", help="Query the RAG system")
     query_parser.add_argument("--file", type=str, required=True, help="Path to document file")
     query_parser.add_argument("--query", type=str, required=True, help="Query to ask")
-    query_parser.add_argument("--format", type=str, default="text", choices=["text", "json", "pdf"],
-                              help="Document format")
+    query_parser.add_argument("--format", type=str, default="text",
+                              choices=["text", "json", "pdf", "image", "docx", "excel", "pptx"],
+                              help="Document format")
     query_parser.add_argument("--vector-store", type=str, default=config.VECTOR_STORE_TYPE,
-                              choices=["bm25", "faiss", "tfidf", "hybrid_online", "hybrid_offline"],
-                              help="Vector store type")
+                              choices=["bm25", "bm25s", "faiss", "tfidf", "hybrid_online", "hybrid_offline"],
+                              help="Vector store type")
+    query_parser.add_argument("--stream", action="store_true",
+                              help="Enable streaming response")
     query_parser.add_argument("--top-k", type=int, default=config.TOP_K, help="Number of results to retrieve")
     query_parser.add_argument("--system-prompt", type=str, help="Path to a file containing the system prompt or the prompt text itself")
+    query_parser.add_argument("--ocr-mode", type=str, choices=["typed", "handwritten"],
+                              default=config.OCR_DEFAULT_MODE,
+                              help="OCR mode for image processing")
 
     # Server command
     server_parser = subparsers.add_parser("server", help="Start the RAG API server")
     server_parser.add_argument("--file", type=str, required=True, help="Path to document file")
-    server_parser.add_argument("--format", type=str, default="text", choices=["text", "json", "pdf"],
-                               help="Document format")
+    server_parser.add_argument("--format", type=str, default="text",
+                               choices=["text", "json", "pdf", "image", "docx", "excel", "pptx"],
+                               help="Document format")
     server_parser.add_argument("--vector-store", type=str, default=config.VECTOR_STORE_TYPE,
-                               choices=["bm25", "faiss", "tfidf", "hybrid_online", "hybrid_offline"],
-                               help="Vector store type")
+                               choices=["bm25", "bm25s", "faiss", "tfidf", "hybrid_online", "hybrid_offline"],
+                               help="Vector store type")
     server_parser.add_argument("--port", type=int, default=config.SERVER_PORT, help="Port to run server on")
     server_parser.add_argument("--host", type=str, default=config.SERVER_HOST, help="Host to run server on")
     server_parser.add_argument("--system-prompt", type=str, help="Path to a file containing the system prompt or the prompt text itself")
 
     args = parser.parse_args()
+    vector_store_type = args.vector_store if hasattr(args, 'vector_store') else config.VECTOR_STORE_TYPE
 
     # Validate config
     validate_config()
@@ -52,6 +61,7 @@ def main():
             with open(prompt_arg, 'r', encoding='utf-8') as f:
                 return f.read()
         return prompt_arg
+
 
     if args.command == "query":
         # Load and process document
@@ -66,6 +76,17 @@ def main():
         elif args.format == "pdf":
             chunker = PDFChunker(chunk_size=config.CHUNK_SIZE, overlap=config.CHUNK_OVERLAP)
             documents = chunker.chunk_pdf(args.file, {"source": args.file})
+        elif args.format == "image":
+            chunker = ImageChunker(
+                chunk_size=config.CHUNK_SIZE,
+                overlap=config.CHUNK_OVERLAP,
+                ocr_mode=getattr(args, 'ocr_mode', config.OCR_DEFAULT_MODE)
+            )
+            documents = chunker.chunk(args.file, {"source": args.file})
+        elif args.format in ["docx", "excel", "pptx"]:
+            # Use OfficeChunker for office documents
+            chunker = OfficeChunker(chunk_size=config.CHUNK_SIZE, overlap=config.CHUNK_OVERLAP)
+            documents = chunker.chunk_office(args.file, {"source": args.file})
         else:
             logger.error(f"Unsupported format: {args.format}")
             return 1
@@ -81,6 +102,8 @@ def main():
             vector_store = HybridVectorStore()
         elif args.vector_store == "hybrid_offline":
             vector_store = HybridOfflineVectorStore()
+        elif args.vector_store == "bm25s":
+            vector_store = BM25SVectorStore()
         else:
             logger.error(f"Unsupported vector store: {args.vector_store}")
             return 1
@@ -94,9 +117,30 @@ def main():
         agent = RAGAgent(retriever, llm, system_prompt=system_prompt)
 
         # Query and print response
-        response = agent.query(args.query, top_k=args.top_k)
-        print(f"Query: {args.query}")
-        print(f"Response: {response}")
+        if args.stream:
+            print(f"Query: {args.query}")
+            print("Response: ", end="", flush=True)
+
+            try:
+                # Print chunks as they arrive; query_stream adds the full
+                # response to the conversation history itself
+                full_response = ""
+                for chunk in agent.query_stream(args.query, top_k=args.top_k):
+                    print(chunk, end="", flush=True)
+                    full_response += chunk
+                print()  # Newline at the end
+            except Exception as e:
+                print(f"\nError during streaming: {str(e)}")
+        else:
+            response = agent.query(args.query, top_k=args.top_k)
+            print(f"Query: {args.query}")
+            print(f"Response: {response}")
 
     elif args.command == "server":
         # Load and process document
@@ -126,6 +170,8 @@ def main():
             vector_store = HybridVectorStore()
         elif args.vector_store == "hybrid_offline":
             vector_store = HybridOfflineVectorStore()
+        elif args.vector_store == "bm25s":
+            vector_store = BM25SVectorStore()
         else:
             logger.error(f"Unsupported vector store: {args.vector_store}")
             return 1
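
The CLI changes above map directly onto the library API. As a minimal sketch, the same path in plain Python might look like this — the imports come from the diff itself, while the exact SimpleRetriever and RAGAgent constructor arguments are assumptions based on how the CLI wires these pieces together, and the file path is hypothetical:

    from kssrag.core.chunkers import ImageChunker
    from kssrag.core.vectorstores import BM25SVectorStore
    from kssrag.core.retrievers import SimpleRetriever
    from kssrag.core.agents import RAGAgent
    from kssrag.models.openrouter import OpenRouterLLM

    # Equivalent of: kssrag query --format image --vector-store bm25s --stream
    chunker = ImageChunker(chunk_size=500, overlap=50, ocr_mode="typed")
    documents = chunker.chunk("scan.png", {"source": "scan.png"})

    vector_store = BM25SVectorStore()
    vector_store.add_documents(documents)

    agent = RAGAgent(SimpleRetriever(vector_store), OpenRouterLLM())

    # Print chunks as they arrive, as the new --stream flag does
    for chunk in agent.query_stream("What does the scan say?", top_k=5):
        print(chunk, end="", flush=True)
    print()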
kssrag/config.py CHANGED
@@ -9,6 +9,7 @@ load_dotenv()
 
 class VectorStoreType(str, Enum):
     BM25 = "bm25"
+    BM25S = "bm25s"
     FAISS = "faiss"
     TFIDF = "tfidf"
     HYBRID_ONLINE = "hybrid_online"
@@ -19,6 +20,7 @@ class ChunkerType(str, Enum):
     TEXT = "text"
     JSON = "json"
     PDF = "pdf"
+    IMAGE = "image"
     CUSTOM = "custom"
 
 class RetrieverType(str, Enum):
@@ -36,7 +38,7 @@ class Config(BaseSettings):
     )
 
     DEFAULT_MODEL: str = Field(
-        default=os.getenv("DEFAULT_MODEL", "deepseek/deepseek-chat-v3.1:free"),
+        default=os.getenv("DEFAULT_MODEL", "deepseek/deepseek-chat"),
         description="Default model to use for LLM responses"
     )
 
@@ -183,6 +185,18 @@ class Config(BaseSettings):
         env_file = ".env"
         case_sensitive = False
         use_enum_values = True
+
+    # OCR settings
+    OCR_DEFAULT_MODE: str = Field(
+        default=os.getenv("OCR_DEFAULT_MODE", "typed"),
+        description="Default OCR mode: typed or handwritten"
+    )
+
+    # Streaming settings
+    ENABLE_STREAMING: bool = Field(
+        default=os.getenv("ENABLE_STREAMING", "False").lower() == "true",
+        description="Whether to enable streaming responses"
+    )
 
     @validator('FALLBACK_MODELS', 'CORS_ORIGINS', 'CORS_ALLOW_METHODS', 'CORS_ALLOW_HEADERS', pre=True)
     def split_string(cls, v):
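
Both new settings follow the existing os.getenv pattern, so they can be overridden from the environment (or .env) before the config module is imported. A small sketch, assuming config is instantiated at import time as the `from ..config import config` usages elsewhere in the diff suggest:

    import os

    os.environ["OCR_DEFAULT_MODE"] = "handwritten"  # default is "typed"
    os.environ["ENABLE_STREAMING"] = "true"         # compared case-insensitively

    from kssrag.config import config

    assert config.OCR_DEFAULT_MODE == "handwritten"
    assert config.ENABLE_STREAMING is True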
kssrag/core/agents.py CHANGED
@@ -1,4 +1,4 @@
-from typing import List, Dict, Any, Optional
+from typing import Generator, List, Dict, Any, Optional
 from ..utils.helpers import logger
 
 class RAGAgent:
@@ -29,6 +29,32 @@ class RAGAgent:
         # Keep the most recent messages
         self.conversation = [system_msg] + other_msgs[-9:] if system_msg else other_msgs[-10:]
 
+    def _build_context(self, context_docs: List[Dict[str, Any]]) -> str:
+        """Build a context string from retrieved documents"""
+        if not context_docs:
+            return ""
+
+        context = "Relevant information:\n"
+        for i, doc in enumerate(context_docs, 1):
+            context += f"\n--- Document {i} ---\n{doc['content']}\n"
+        return context
+
+    def _build_messages(self, question: str, context: str = "") -> List[Dict[str, str]]:
+        """Build the LLM message list, including context"""
+        # Start with conversation history
+        messages = self.conversation.copy()
+
+        # Add the user query with context
+        user_message = f"{context}\n\nQuestion: {question}" if context else question
+
+        # Replace the last user message if it exists, otherwise add a new one
+        if messages and messages[-1]["role"] == "user":
+            messages[-1]["content"] = user_message
+        else:
+            messages.append({"role": "user", "content": user_message})
+
+        return messages
+
     def query(self, question: str, top_k: int = 5, include_context: bool = True) -> str:
         """Process a query and return a response"""
         try:
@@ -40,18 +66,13 @@
                 return "I couldn't find relevant information to answer your question."
 
             # Format context
-            context = ""
-            if include_context and context_docs:
-                context = "Relevant information:\n"
-                for i, doc in enumerate(context_docs, 1):
-                    context += f"\n--- Document {i} ---\n{doc['content']}\n"
+            context = self._build_context(context_docs) if include_context and context_docs else ""
 
-            # Add user query with context
-            user_message = f"{context}\n\nQuestion: {question}" if context else question
-            self.add_message("user", user_message)
+            # Build messages
+            messages = self._build_messages(question, context)
 
             # Generate response
-            response = self.llm.predict(self.conversation)
+            response = self.llm.predict(messages)
 
             # Add assistant response to conversation
             self.add_message("assistant", response)
@@ -62,6 +83,36 @@
             logger.error(f"Error processing query: {str(e)}")
             return "I encountered an issue processing your query. Please try again."
 
+    def query_stream(self, question: str, top_k: int = 5) -> Generator[str, None, None]:
+        """Query the RAG system with a streaming response"""
+        try:
+            # Retrieve relevant documents
+            relevant_docs = self.retriever.retrieve(question, top_k=top_k)
+
+            # Build context from documents
+            context = self._build_context(relevant_docs)
+
+            # Build messages
+            messages = self._build_messages(question, context)
+
+            # Stream the response from the LLM, accumulating chunks as they
+            # are yielded so the full reply is added to the conversation
+            # history exactly once (without a second API call)
+            if hasattr(self.llm, 'predict_stream'):
+                full_response = ""
+                for chunk in self.llm.predict_stream(messages):
+                    full_response += chunk
+                    yield chunk
+
+                self.add_message("assistant", full_response)
+            else:
+                # Fall back to non-streaming
+                response = self.llm.predict(messages)
+                self.add_message("assistant", response)
+                yield response
+
+        except Exception as e:
+            logger.error(f"Error in streaming query: {str(e)}")
+            yield f"Error: {str(e)}"
+
     def clear_conversation(self):
         """Clear conversation history except system message"""
         system_msg = next((msg for msg in self.conversation if msg["role"] == "system"), None)
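
For consumers, the new generator contract is simple: query_stream yields str chunks and appends the assembled reply to the conversation history itself. A short sketch, where retriever and llm are stand-ins for any objects exposing the .retrieve() and .predict()/.predict_stream() methods the agent expects:

    agent = RAGAgent(retriever, llm)  # retriever/llm are stand-ins

    pieces = []
    for piece in agent.query_stream("Summarize the document", top_k=3):
        pieces.append(piece)
    answer = "".join(pieces)

    # If llm lacks predict_stream, the generator degrades to a single
    # yield containing the full non-streamed response.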
kssrag/core/chunkers.py CHANGED
@@ -1,8 +1,16 @@
 import json
 import re
+import os
 from typing import List, Dict, Any, Optional
 import pypdf
 from ..utils.helpers import logger
+try:
+    from ..utils.ocr_loader import OCRLoader
+    OCR_AVAILABLE = True
+except ImportError:
+    OCR_AVAILABLE = False
+    OCRLoader = None
 
 class BaseChunker:
     """Base class for document chunkers"""
@@ -46,7 +54,7 @@ class TextChunker(BaseChunker):
         return chunks
 
 class JSONChunker(BaseChunker):
-    """Chunker for JSON documents (like your drug data)"""
+    """Chunker for JSON documents"""
 
     def chunk(self, data: List[Dict[str, Any]], metadata_field: str = "name") -> List[Dict[str, Any]]:
         """Create chunks from JSON data"""
@@ -97,4 +105,90 @@ class PDFChunker(TextChunker):
     def chunk_pdf(self, pdf_path: str, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
         """Extract text from PDF and chunk it"""
         text = self.extract_text(pdf_path)
+        return self.chunk(text, metadata)
+
+class ImageChunker(BaseChunker):
+    """Chunker for image documents using OCR"""
+
+    def __init__(self, chunk_size: int = 500, overlap: int = 50, ocr_mode: str = "typed"):
+        super().__init__(chunk_size, overlap)
+        self.ocr_mode = ocr_mode  # "typed" or "handwritten"
+        self.ocr_loader = None
+
+        # Initialize OCR loader
+        try:
+            from ..utils.ocr_loader import OCRLoader
+            self.ocr_loader = OCRLoader()
+            logger.info(f"OCR loader initialized with mode: {ocr_mode}")
+        except ImportError as e:
+            logger.error(f"OCR dependencies not available: {str(e)}")
+            raise ImportError(
+                "OCR functionality requires extra dependencies. "
+                "Install with: pip install kssrag[ocr]"
+            ) from e
+
+    def extract_text_from_image(self, image_path: str) -> str:
+        """Extract text from an image using the configured OCR engine"""
+        if not self.ocr_loader:
+            raise RuntimeError("OCR loader not initialized")
+
+        if self.ocr_mode not in ["typed", "handwritten"]:
+            raise ValueError(f"Invalid OCR mode: {self.ocr_mode}. Must be 'typed' or 'handwritten'")
+
+        logger.info(f"Extracting text from {image_path} using {self.ocr_mode} OCR")
+
+        try:
+            text = self.ocr_loader.extract_text(image_path, self.ocr_mode)
+
+            if not text.strip():
+                logger.warning(f"No text extracted from image: {image_path}")
+                return ""
+
+            logger.info(f"Successfully extracted {len(text)} characters from {image_path}")
+            return text
+
+        except Exception as e:
+            logger.error(f"OCR extraction failed for {image_path}: {str(e)}")
+            raise RuntimeError(f"Failed to extract text from image {image_path}: {str(e)}")
+
+    def chunk(self, image_path: str, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
+        """Extract text from an image and chunk it"""
+        if metadata is None:
+            metadata = {}
+
+        # Validate image file
+        if not os.path.exists(image_path):
+            raise FileNotFoundError(f"Image file not found: {image_path}")
+
+        # Extract text from image
+        text = self.extract_text_from_image(image_path)
+
+        if not text.strip():
+            return []
+
+        # Use text chunking on the extracted text
+        text_chunker = TextChunker(chunk_size=self.chunk_size, overlap=self.overlap)
+        chunks = text_chunker.chunk(text, metadata)
+
+        # Add OCR-specific metadata
+        for chunk in chunks:
+            chunk["metadata"]["ocr_extracted"] = True
+            chunk["metadata"]["image_source"] = image_path
+            chunk["metadata"]["ocr_mode"] = self.ocr_mode
+
+        logger.info(f"Created {len(chunks)} chunks from image {image_path}")
+        return chunks
+
+class OfficeChunker(TextChunker):
+    """Chunker for Office documents (DOCX, Excel, PowerPoint)"""
+
+    def chunk_office(self, file_path: str, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
+        """Chunk office documents by extracting their text first"""
+        if metadata is None:
+            metadata = {}
+
+        # Extract text based on file type
+        from ..utils.document_loaders import load_document
+        text = load_document(file_path)
+
         return self.chunk(text, metadata)
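
Illustratively, the chunks ImageChunker returns are plain dicts carrying the OCR metadata set above. A small sketch; the file paths and default chunk sizes here are hypothetical:

    chunker = ImageChunker(chunk_size=500, overlap=50, ocr_mode="handwritten")
    for chunk in chunker.chunk("note.jpg", {"source": "note.jpg"}):
        assert chunk["metadata"]["ocr_extracted"] is True
        assert chunk["metadata"]["ocr_mode"] == "handwritten"
        print(chunk["content"][:80])

    # Office formats reuse the text pipeline after extraction
    office_chunks = OfficeChunker(chunk_size=500, overlap=50).chunk_office(
        "report.docx", {"source": "report.docx"}
    )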
kssrag/core/vectorstores.py CHANGED
@@ -13,6 +13,13 @@ from typing import List, Dict, Any, Optional
 from ..utils.helpers import logger
 from ..config import config
 
+FAISS_AVAILABLE = False
+try:
+    import faiss
+    FAISS_AVAILABLE = True
+except ImportError:
+    pass
+
 class BaseVectorStore:
     """Base class for vector stores"""
 
@@ -102,11 +109,23 @@ class BM25VectorStore(BaseVectorStore):
         logger.info(f"BM25 index loaded from {self.persist_path}")
 
 import tempfile
 class FAISSVectorStore(BaseVectorStore):
     def __init__(self, persist_path: Optional[str] = None, model_name: Optional[str] = None):
+        # Only set up FAISS when this vector store is actually used
+        from ..utils.helpers import setup_faiss
+        faiss_available, _ = setup_faiss("faiss")  # Explicitly request FAISS
+
+        if not faiss_available:
+            raise ImportError("FAISS is not available. Please install it with 'pip install faiss-cpu' or use a different vector store.")
+
         super().__init__(persist_path)
         self.model_name = model_name or config.FAISS_MODEL_NAME
-
         # Handle cache directory permissions
         try:
             cache_dir = config.CACHE_DIR
@@ -394,4 +413,86 @@ class HybridOfflineVectorStore(BaseVectorStore):
         self.bm25_store.load()
         self.tfidf_store.load()
         self.documents = self.bm25_store.documents
-        logger.info(f"Hybrid offline index loaded")
+        logger.info(f"Hybrid offline index loaded")
+
+import bm25s
+from Stemmer import Stemmer
+
+class BM25SVectorStore(BaseVectorStore):
+    """BM25S vector store using the bm25s library for ultra-fast retrieval"""
+
+    def __init__(self, persist_path: Optional[str] = "bm25s_index.pkl"):
+        super().__init__(persist_path)
+        self.bm25_retriever = None
+        self.stemmer = Stemmer("english")
+        self.corpus_tokens = None
+
+    def add_documents(self, documents: List[Dict[str, Any]]):
+        self.documents = documents
+        self.doc_texts = [doc["content"] for doc in documents]
+
+        try:
+            # Tokenize the corpus with BM25S
+            self.corpus_tokens = bm25s.tokenize(
+                self.doc_texts,
+                stopwords="en",
+                stemmer=self.stemmer
+            )
+
+            # Create and index with BM25S
+            self.bm25_retriever = bm25s.BM25()
+            self.bm25_retriever.index(self.corpus_tokens)
+
+            logger.info(f"BM25S index created with {len(self.documents)} documents")
+
+        except Exception as e:
+            logger.error(f"BM25S initialization failed: {str(e)}")
+            raise
+
+    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
+        if not self.bm25_retriever:
+            raise ValueError("BM25S index not initialized. Call add_documents first.")
+
+        try:
+            # Tokenize the query with BM25S
+            query_tokens = bm25s.tokenize([query], stemmer=self.stemmer)
+
+            # Retrieve with BM25S
+            results, scores = self.bm25_retriever.retrieve(query_tokens, k=top_k)
+
+            # Format results
+            retrieved_docs = []
+            for i in range(results.shape[1]):
+                doc_idx = results[0, i]
+                score = scores[0, i]
+
+                if doc_idx < len(self.documents):
+                    retrieved_docs.append(self.documents[doc_idx])
+
+            logger.info(f"BM25S retrieved {len(retrieved_docs)} documents for query: {query}")
+            return retrieved_docs
+
+        except Exception as e:
+            logger.error(f"BM25S retrieval failed for query '{query}': {str(e)}")
+            return []
+
+    def persist(self):
+        if self.persist_path:
+            with open(self.persist_path, 'wb') as f:
+                pickle.dump({
+                    'documents': self.documents,
+                    'doc_texts': self.doc_texts,
+                    'corpus_tokens': self.corpus_tokens,
+                    'bm25_retriever': self.bm25_retriever
+                }, f)
+            logger.info(f"BM25S index persisted to {self.persist_path}")
+
+    def load(self):
+        if self.persist_path and os.path.exists(self.persist_path):
+            with open(self.persist_path, 'rb') as f:
+                data = pickle.load(f)
+            self.documents = data['documents']
+            self.doc_texts = data['doc_texts']
+            self.corpus_tokens = data['corpus_tokens']
+            self.bm25_retriever = data['bm25_retriever']
+            logger.info(f"BM25S index loaded from {self.persist_path}")
kssrag/models/openrouter.py CHANGED
@@ -1,17 +1,18 @@
 import requests
 import json
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Generator
 from ..utils.helpers import logger
 from ..config import config
 
 class OpenRouterLLM:
-    """OpenRouter LLM interface with fallback models"""
+    """OpenRouter LLM interface with streaming support"""
 
     def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None,
-                 fallback_models: Optional[List[str]] = None):
+                 fallback_models: Optional[List[str]] = None, stream: bool = False):
         self.api_key = api_key or config.OPENROUTER_API_KEY
         self.model = model or config.DEFAULT_MODEL
         self.fallback_models = fallback_models or config.FALLBACK_MODELS
+        self.stream = stream
         self.base_url = "https://openrouter.ai/api/v1/chat/completions"
         self.headers = {
             "Authorization": f"Bearer {self.api_key}",
@@ -21,8 +22,14 @@ class OpenRouterLLM:
         }
 
     def predict(self, messages: List[Dict[str, str]]) -> str:
-        """Generate a response using OpenRouter's API with fallbacks"""
-        logger.info(f"Attempting to generate response with {len(messages)} messages")
+        """Generate a response with fallback models"""
+        if self.stream:
+            full_response = ""
+            for chunk in self.predict_stream(messages):
+                full_response += chunk
+            return full_response
+
+        logger.info(f"Generating response with {len(messages)} messages")
 
         for model in [self.model] + self.fallback_models:
             payload = {
@@ -36,21 +43,17 @@
             }
 
             try:
-                logger.info(f"Trying model: {model}")
+                logger.info(f"Using model: {model}")
                 response = requests.post(
                     self.base_url,
                     headers=self.headers,
                     json=payload,
-                    timeout=15
+                    timeout=30
                 )
 
-                # Check for HTTP errors
                 response.raise_for_status()
-
-                # Parse JSON response
                 response_data = response.json()
 
-                # Validate response structure
                 if ("choices" not in response_data or
                     len(response_data["choices"]) == 0 or
                     "message" not in response_data["choices"][0] or
@@ -60,7 +63,7 @@
                     continue
 
                 content = response_data["choices"][0]["message"]["content"]
-                logger.info(f"Successfully used model: {model}")
+                logger.info(f"Successfully generated response with model: {model}")
                 return content
 
             except requests.exceptions.Timeout:
@@ -79,7 +82,66 @@
                 logger.warning(f"Unexpected error with model {model}: {str(e)}")
                 continue
 
-        # If all models fail, return a friendly error message
-        error_msg = "I'm having trouble connecting to the knowledge service right now. Please try again in a moment."
+        error_msg = "Unable to generate response from available models. Please try again."
         logger.error("All model fallbacks failed to respond")
-        return error_msg
+        return error_msg
+
+    def predict_stream(self, messages: List[Dict[str, str]]) -> Generator[str, None, None]:
+        """Stream a response from the OpenRouter API"""
+        logger.info(f"Streaming response with {len(messages)} messages")
+
+        for model in [self.model] + self.fallback_models:
+            payload = {
+                "model": model,
+                "messages": messages,
+                "temperature": 0.7,
+                "max_tokens": 1024,
+                "top_p": 1,
+                "stop": None,
+                "stream": True
+            }
+
+            try:
+                logger.info(f"Streaming with model: {model}")
+                response = requests.post(
+                    self.base_url,
+                    headers=self.headers,
+                    json=payload,
+                    timeout=60,
+                    stream=True
+                )
+
+                response.raise_for_status()
+
+                for line in response.iter_lines():
+                    if line:
+                        line = line.decode('utf-8')
+                        if line.startswith('data: '):
+                            data = line[6:]
+                            if data.strip() == '[DONE]':
+                                logger.info("Stream completed successfully")
+                                return
+                            try:
+                                chunk_data = json.loads(data)
+                                if ('choices' in chunk_data and
+                                        len(chunk_data['choices']) > 0 and
+                                        'delta' in chunk_data['choices'][0] and
+                                        'content' in chunk_data['choices'][0]['delta']):
+                                    content = chunk_data['choices'][0]['delta']['content']
+                                    if content:
+                                        yield content
+                            except json.JSONDecodeError as e:
+                                logger.warning(f"Failed to parse stream chunk: {str(e)}")
+                                continue
+
+                logger.info(f"Successfully streamed from model: {model}")
+                return
+
+            except Exception as e:
+                logger.warning(f"Streaming failed with model {model}: {str(e)}")
+                continue
+
+        error_msg = "Unable to stream response from available models. Please try again."
+        logger.error("All model fallbacks failed for streaming")
+        yield error_msg
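
Finally, a hedged sketch of the two ways the new streaming surface can be consumed, assuming OPENROUTER_API_KEY is set so config.OPENROUTER_API_KEY resolves:

    llm = OpenRouterLLM(stream=True)
    messages = [{"role": "user", "content": "One sentence on BM25."}]

    # With stream=True, predict() drains predict_stream() internally and
    # returns the concatenated text...
    print(llm.predict(messages))

    # ...or consume the SSE chunks directly:
    for chunk in llm.predict_stream(messages):
        print(chunk, end="", flush=True)
    print()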