kssrag-0.1.2-py3-none-any.whl → kssrag-0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kssrag/cli.py +58 -13
- kssrag/config.py +15 -1
- kssrag/core/agents.py +62 -10
- kssrag/core/chunkers.py +95 -1
- kssrag/core/vectorstores.py +95 -3
- kssrag/models/openrouter.py +78 -16
- kssrag/server.py +66 -4
- kssrag/utils/document_loaders.py +80 -2
- kssrag/utils/helpers.py +38 -25
- kssrag/utils/ocr.py +48 -0
- kssrag/utils/ocr_loader.py +151 -0
- kssrag-0.2.1.dist-info/METADATA +840 -0
- kssrag-0.2.1.dist-info/RECORD +33 -0
- tests/test_bm25s.py +74 -0
- tests/test_config.py +42 -0
- tests/test_image_chunker.py +17 -0
- tests/test_integration.py +35 -0
- tests/test_ocr.py +142 -0
- tests/test_streaming.py +41 -0
- kssrag-0.1.2.dist-info/METADATA +0 -407
- kssrag-0.1.2.dist-info/RECORD +0 -25
- {kssrag-0.1.2.dist-info → kssrag-0.2.1.dist-info}/WHEEL +0 -0
- {kssrag-0.1.2.dist-info → kssrag-0.2.1.dist-info}/entry_points.txt +0 -0
- {kssrag-0.1.2.dist-info → kssrag-0.2.1.dist-info}/top_level.txt +0 -0
kssrag/cli.py
CHANGED

@@ -2,8 +2,8 @@ import argparse
 import sys
 import os  # Add this import if not already present
 from .utils.document_loaders import load_document, load_json_documents
-from .core.chunkers import TextChunker, JSONChunker, PDFChunker
-from .core.vectorstores import BM25VectorStore, FAISSVectorStore, TFIDFVectorStore, HybridVectorStore, HybridOfflineVectorStore
+from .core.chunkers import ImageChunker, OfficeChunker, TextChunker, JSONChunker, PDFChunker
+from .core.vectorstores import BM25SVectorStore, BM25VectorStore, FAISSVectorStore, TFIDFVectorStore, HybridVectorStore, HybridOfflineVectorStore
 from .core.retrievers import SimpleRetriever, HybridRetriever
 from .core.agents import RAGAgent
 from .models.openrouter import OpenRouterLLM
@@ -19,22 +19,30 @@ def main():
     query_parser = subparsers.add_parser("query", help="Query the RAG system")
     query_parser.add_argument("--file", type=str, required=True, help="Path to document file")
     query_parser.add_argument("--query", type=str, required=True, help="Query to ask")
-    query_parser.add_argument("--format", type=str, default="text",
-
+    query_parser.add_argument("--format", type=str, default="text",
+                              choices=["text", "json", "pdf", "image", "docx", "excel", "pptx"],
+                              help="Document format")
     query_parser.add_argument("--vector-store", type=str, default=config.VECTOR_STORE_TYPE,
-
-
+                              choices=["bm25", "bm25s", "faiss", "tfidf", "hybrid_online", "hybrid_offline"],
+                              help="Vector store type")
+    query_parser.add_argument("--stream", action="store_true",
+                              help="Enable streaming response")
     query_parser.add_argument("--top-k", type=int, default=config.TOP_K, help="Number of results to retrieve")
     query_parser.add_argument("--system-prompt", type=str, help="Path to a file containing the system prompt or the prompt text itself")
+    query_parser.add_argument("--ocr-mode", type=str, choices=["typed", "handwritten"],
+                              default=config.OCR_DEFAULT_MODE,
+                              help="OCR mode for image processing")

     # Server command
     server_parser = subparsers.add_parser("server", help="Start the RAG API server")
     server_parser.add_argument("--file", type=str, required=True, help="Path to document file")
-    server_parser.add_argument("--format", type=str, default="text",
-
+    server_parser.add_argument("--format", type=str, default="text",
+                               choices=["text", "json", "pdf", "image", "docx", "excel", "pptx"],
+                               help="Document format")
+    # I Updated the server parser vector store choices
     server_parser.add_argument("--vector-store", type=str, default=config.VECTOR_STORE_TYPE,
-
-
+                               choices=["bm25", "bm25s", "faiss", "tfidf", "hybrid_online", "hybrid_offline"], # Add bm25s
+                               help="Vector store type")
     server_parser.add_argument("--port", type=int, default=config.SERVER_PORT, help="Port to run server on")
     server_parser.add_argument("--host", type=str, default=config.SERVER_HOST, help="Host to run server on")
     server_parser.add_argument("--system-prompt", type=str, help="Path to a file containing the system prompt or the prompt text itself")
@@ -53,6 +61,7 @@ def main():
             with open(prompt_arg, 'r', encoding='utf-8') as f:
                 return f.read()
         return prompt_arg
+

     if args.command == "query":
         # Load and process document
@@ -67,6 +76,17 @@ def main():
         elif args.format == "pdf":
             chunker = PDFChunker(chunk_size=config.CHUNK_SIZE, overlap=config.CHUNK_OVERLAP)
             documents = chunker.chunk_pdf(args.file, {"source": args.file})
+        elif args.format == "image":
+            chunker = ImageChunker(
+                chunk_size=config.CHUNK_SIZE,
+                overlap=config.CHUNK_OVERLAP,
+                ocr_mode=getattr(args, 'ocr_mode', config.OCR_DEFAULT_MODE)
+            )
+            documents = chunker.chunk(args.file, {"source": args.file})
+        elif args.format in ["docx", "excel", "pptx"]:
+            # Use OfficeChunker for office documents
+            chunker = OfficeChunker(chunk_size=config.CHUNK_SIZE, overlap=config.CHUNK_OVERLAP)
+            documents = chunker.chunk_office(args.file, {"source": args.file})
         else:
             logger.error(f"Unsupported format: {args.format}")
             return 1
@@ -82,6 +102,8 @@ def main():
             vector_store = HybridVectorStore()
         elif args.vector_store == "hybrid_offline":
             vector_store = HybridOfflineVectorStore()
+        elif args.vector_store == "bm25s":
+            vector_store = BM25SVectorStore()
         else:
             logger.error(f"Unsupported vector store: {args.vector_store}")
             return 1
@@ -95,9 +117,30 @@ def main():
         agent = RAGAgent(retriever, llm, system_prompt=system_prompt)

         # Query and print response
-        response = agent.query(args.query, top_k=args.top_k)
-        print(f"Query: {args.query}")
-        print(f"Response: {response}")
+        # response = agent.query(args.query, top_k=args.top_k)
+        # print(f"Query: {args.query}")
+        # print(f"Response: {response}")
+
+        # In the query section, after creating the agent:
+        if args.stream:
+            print(f"Query: {args.query}")
+            print("Response: ", end="", flush=True)
+
+            try:
+                # Collect all chunks and print them as they come
+                full_response = ""
+                for chunk in agent.query_stream(args.query, top_k=args.top_k):
+                    print(chunk, end="", flush=True)
+                    full_response += chunk
+                print()  # New line at the end
+
+                # The response is already added to conversation in query_stream
+            except Exception as e:
+                print(f"\nError during streaming: {str(e)}")
+        else:
+            response = agent.query(args.query, top_k=args.top_k)
+            print(f"Query: {args.query}")
+            print(f"Response: {response}")

     elif args.command == "server":
         # Load and process document
@@ -127,6 +170,8 @@ def main():
             vector_store = HybridVectorStore()
         elif args.vector_store == "hybrid_offline":
             vector_store = HybridOfflineVectorStore()
+        elif args.vector_store == "bm25s":
+            vector_store = BM25SVectorStore()
         else:
             logger.error(f"Unsupported vector store: {args.vector_store}")
             return 1
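Taken together, the CLI changes map onto the library API roughly as follows. This is a minimal sketch, not taken from the package documentation: the `SimpleRetriever(vector_store)` call and the `notes.pdf` path are assumptions, and an `OPENROUTER_API_KEY` must be configured for the LLM call to succeed.

```python
# Rough programmatic equivalent of:
#   kssrag query --file notes.pdf --format pdf --vector-store bm25s --query "..." --stream
# Sketch only; SimpleRetriever's constructor signature is assumed, not shown in this diff.
from kssrag.core.chunkers import PDFChunker
from kssrag.core.vectorstores import BM25SVectorStore
from kssrag.core.retrievers import SimpleRetriever
from kssrag.core.agents import RAGAgent
from kssrag.models.openrouter import OpenRouterLLM

# Chunk the document (chunk_pdf signature as shown in the diff above)
documents = PDFChunker(chunk_size=500, overlap=50).chunk_pdf("notes.pdf", {"source": "notes.pdf"})

# Index with the new BM25S store and wire up the agent
vector_store = BM25SVectorStore()
vector_store.add_documents(documents)
agent = RAGAgent(SimpleRetriever(vector_store), OpenRouterLLM(),
                 system_prompt="You are a helpful assistant.")

# Equivalent of --stream: print chunks as they arrive
for chunk in agent.query_stream("What are the key points?", top_k=5):
    print(chunk, end="", flush=True)
print()
```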
kssrag/config.py
CHANGED

@@ -9,6 +9,7 @@ load_dotenv()

 class VectorStoreType(str, Enum):
     BM25 = "bm25"
+    BM25S = "bm25s"
     FAISS = "faiss"
     TFIDF = "tfidf"
     HYBRID_ONLINE = "hybrid_online"
@@ -19,6 +20,7 @@ class ChunkerType(str, Enum):
     TEXT = "text"
     JSON = "json"
     PDF = "pdf"
+    IMAGE = "image"
     CUSTOM = "custom"

 class RetrieverType(str, Enum):
@@ -36,7 +38,7 @@ class Config(BaseSettings):
     )

     DEFAULT_MODEL: str = Field(
-        default=os.getenv("DEFAULT_MODEL", "deepseek/deepseek-chat
+        default=os.getenv("DEFAULT_MODEL", "deepseek/deepseek-chat"),
         description="Default model to use for LLM responses"
     )

@@ -183,6 +185,18 @@ class Config(BaseSettings):
         env_file = ".env"
         case_sensitive = False
         use_enum_values = True
+
+    # OCR settings
+    OCR_DEFAULT_MODE: str = Field(
+        default=os.getenv("OCR_DEFAULT_MODE", "typed"),
+        description="Default OCR mode: typed or handwritten"
+    )
+
+    # Streaming settings
+    ENABLE_STREAMING: bool = Field(
+        default=os.getenv("ENABLE_STREAMING", "False").lower() == "true",
+        description="Whether to enable streaming responses"
+    )

     @validator('FALLBACK_MODELS', 'CORS_ORIGINS', 'CORS_ALLOW_METHODS', 'CORS_ALLOW_HEADERS', pre=True)
     def split_string(cls, v):
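The new settings read from the environment when the config module is imported, so they can be set in `.env` or the process environment beforehand. A small sketch of how the values resolve, assuming only what the field definitions above show:

```python
import os

# Set the variables before importing the config object; both fall back to
# their defaults ("typed" and False) when unset.
os.environ["OCR_DEFAULT_MODE"] = "handwritten"
os.environ["ENABLE_STREAMING"] = "True"   # any casing of "true" parses to True

from kssrag.config import config

print(config.OCR_DEFAULT_MODE)   # "handwritten"
print(config.ENABLE_STREAMING)   # True; "0", "no", or an unset variable yield False
```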
kssrag/core/agents.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import List, Dict, Any, Optional
+from typing import Generator, List, Dict, Any, Optional
 from ..utils.helpers import logger

 class RAGAgent:
@@ -29,6 +29,32 @@ class RAGAgent:
         # Keep the most recent messages
         self.conversation = [system_msg] + other_msgs[-9:] if system_msg else other_msgs[-10:]

+    def _build_context(self, context_docs: List[Dict[str, Any]]) -> str:
+        """Build context string from documents"""
+        if not context_docs:
+            return ""
+
+        context = "Relevant information:\n"
+        for i, doc in enumerate(context_docs, 1):
+            context += f"\n--- Document {i} ---\n{doc['content']}\n"
+        return context
+
+    def _build_messages(self, question: str, context: str = "") -> List[Dict[str, str]]:
+        """Build messages for LLM including context"""
+        # Start with conversation history
+        messages = self.conversation.copy()
+
+        # Add user query with context
+        user_message = f"{context}\n\nQuestion: {question}" if context else question
+
+        # Replace the last user message if it exists, otherwise add new one
+        if messages and messages[-1]["role"] == "user":
+            messages[-1]["content"] = user_message
+        else:
+            messages.append({"role": "user", "content": user_message})
+
+        return messages
+
     def query(self, question: str, top_k: int = 5, include_context: bool = True) -> str:
         """Process a query and return a response"""
         try:
@@ -40,18 +66,13 @@ class RAGAgent:
                 return "I couldn't find relevant information to answer your question."

             # Format context
-            context = ""
-            if include_context and context_docs:
-                context = "Relevant information:\n"
-                for i, doc in enumerate(context_docs, 1):
-                    context += f"\n--- Document {i} ---\n{doc['content']}\n"
+            context = self._build_context(context_docs) if include_context and context_docs else ""

-            #
-
-            self.add_message("user", user_message)
+            # Build messages
+            messages = self._build_messages(question, context)

             # Generate response
-            response = self.llm.predict(
+            response = self.llm.predict(messages)

             # Add assistant response to conversation
             self.add_message("assistant", response)
@@ -62,6 +83,37 @@ class RAGAgent:
             logger.error(f"Error processing query: {str(e)}")
             return "I encountered an issue processing your query. Please try again."

+    def query_stream(self, question: str, top_k: int = 5) -> Generator[str, None, None]:
+        """Query the RAG system with streaming response"""
+        try:
+            # Retrieve relevant documents
+            relevant_docs = self.retriever.retrieve(question, top_k=top_k)
+
+            # Build context from documents
+            context = self._build_context(relevant_docs)
+
+            # Build messages
+            messages = self._build_messages(question, context)
+
+            # Stream response from LLM
+            if hasattr(self.llm, 'predict_stream'):
+                full_response = ""
+                for chunk in self.llm.predict_stream(messages):
+                    full_response += chunk
+                    yield chunk
+
+                # Add the complete response to conversation history
+                self.add_message("assistant", full_response)
+            else:
+                # Fallback to non-streaming
+                response = self.llm.predict(messages)
+                self.add_message("assistant", response)
+                yield response
+
+        except Exception as e:
+            logger.error(f"Error in streaming query: {str(e)}")
+            yield f"Error: {str(e)}"
+
     def clear_conversation(self):
         """Clear conversation history except system message"""
         system_msg = next((msg for msg in self.conversation if msg["role"] == "system"), None)
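As the diff shows, `query_stream` yields token chunks when the LLM exposes `predict_stream` and otherwise falls back to yielding a single `predict()` result. A minimal sketch of both paths using stand-in objects; only `RAGAgent` comes from the package, the stub retriever and LLM below are hypothetical:

```python
# Sketch of the new streaming path; the stubs are stand-ins, not package classes.
from kssrag.core.agents import RAGAgent

class StubRetriever:
    def retrieve(self, question, top_k=5):
        # Documents are plain dicts with a "content" field, as the agent expects
        return [{"content": "kssrag 0.2.1 adds BM25S retrieval and streaming."}]

class StubLLM:
    def predict(self, messages):
        return "Non-streaming answer."
    def predict_stream(self, messages):
        for token in ["Streaming ", "answer."]:
            yield token

agent = RAGAgent(StubRetriever(), StubLLM(), system_prompt="You are a test agent.")

# Streams token by token because StubLLM defines predict_stream;
# without that method, query_stream would yield one predict() result instead.
for chunk in agent.query_stream("What changed in 0.2.1?", top_k=1):
    print(chunk, end="", flush=True)
print()
```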
kssrag/core/chunkers.py
CHANGED

@@ -1,8 +1,16 @@
 import json
 import re
+import os
 from typing import List, Dict, Any, Optional
 import pypdf
 from ..utils.helpers import logger
+import os
+try:
+    from ..utils.ocr_loader import OCRLoader
+    OCR_AVAILABLE = True
+except ImportError:
+    OCR_AVAILABLE = False
+    OCRLoader = None

 class BaseChunker:
     """Base class for document chunkers"""
@@ -46,7 +54,7 @@ class TextChunker(BaseChunker):
         return chunks

 class JSONChunker(BaseChunker):
-    """Chunker for JSON documents
+    """Chunker for JSON documents"""

     def chunk(self, data: List[Dict[str, Any]], metadata_field: str = "name") -> List[Dict[str, Any]]:
         """Create chunks from JSON data"""
@@ -97,4 +105,90 @@ class PDFChunker(TextChunker):
     def chunk_pdf(self, pdf_path: str, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
         """Extract text from PDF and chunk it"""
         text = self.extract_text(pdf_path)
+        return self.chunk(text, metadata)
+
+class ImageChunker(BaseChunker):
+    """Chunker for image documents using OCR"""
+
+    def __init__(self, chunk_size: int = 500, overlap: int = 50, ocr_mode: str = "typed"):
+        super().__init__(chunk_size, overlap)
+        self.ocr_mode = ocr_mode  # typed or handwritten
+        self.ocr_loader = None
+
+        # Initialize OCR loader
+        try:
+            from ..utils.ocr_loader import OCRLoader
+            self.ocr_loader = OCRLoader()
+            logger.info(f"OCR loader initialized with mode: {ocr_mode}")
+        except ImportError as e:
+            logger.error(f"OCR dependencies not available: {str(e)}")
+            raise ImportError(
+                "OCR functionality requires extra dependencies. "
+                "Install with: pip install kssrag[ocr]"
+            ) from e
+
+    def extract_text_from_image(self, image_path: str) -> str:
+        """Extract text from image using specified OCR engine"""
+        if not self.ocr_loader:
+            raise RuntimeError("OCR loader not initialized")
+
+        if self.ocr_mode not in ["typed", "handwritten"]:
+            raise ValueError(f"Invalid OCR mode: {self.ocr_mode}. Must be 'typed' or 'handwritten'")
+
+        logger.info(f"Extracting text from {image_path} using {self.ocr_mode} OCR")
+
+        try:
+            text = self.ocr_loader.extract_text(image_path, self.ocr_mode)
+
+            if not text.strip():
+                logger.warning(f"No text extracted from image: {image_path}")
+                return ""
+
+            logger.info(f"Successfully extracted {len(text)} characters from {image_path}")
+            return text
+
+        except Exception as e:
+            logger.error(f"OCR extraction failed for {image_path}: {str(e)}")
+            raise RuntimeError(f"Failed to extract text from image {image_path}: {str(e)}")
+
+    def chunk(self, image_path: str, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
+        """Extract text from image and chunk it"""
+        if metadata is None:
+            metadata = {}
+
+        # Validate image file
+        if not os.path.exists(image_path):
+            raise FileNotFoundError(f"Image file not found: {image_path}")
+
+        # Extract text from image
+        text = self.extract_text_from_image(image_path)
+
+        if not text.strip():
+            return []
+
+        # Use text chunking on extracted text
+        text_chunker = TextChunker(chunk_size=self.chunk_size, overlap=self.overlap)
+        chunks = text_chunker.chunk(text, metadata)
+
+        # Add OCR-specific metadata
+        for chunk in chunks:
+            chunk["metadata"]["ocr_extracted"] = True
+            chunk["metadata"]["image_source"] = image_path
+            chunk["metadata"]["ocr_mode"] = self.ocr_mode
+
+        logger.info(f"Created {len(chunks)} chunks from image {image_path}")
+        return chunks
+
+class OfficeChunker(TextChunker):
+    """Chunker for Office documents (DOCX, Excel, PowerPoint)"""
+
+    def chunk_office(self, file_path: str, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
+        """Chunk office documents by extracting text first"""
+        if metadata is None:
+            metadata = {}
+
+        # Extract text based on file type
+        from ..utils.document_loaders import load_document
+        text = load_document(file_path)
+
         return self.chunk(text, metadata)
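Both new chunkers return the same chunk dictionaries as `TextChunker`, with `ImageChunker` adding OCR provenance to each chunk's metadata. A brief sketch of their use; the file paths are placeholders, and `ImageChunker` raises `ImportError` unless the OCR extra is installed (`pip install kssrag[ocr]`, per the error message above):

```python
# Sketch of the 0.2.1 chunkers; "scan.png" and "report.docx" are placeholder paths.
from kssrag.core.chunkers import ImageChunker, OfficeChunker

# OCR an image in handwritten mode and chunk the extracted text
image_chunks = ImageChunker(chunk_size=500, overlap=50, ocr_mode="handwritten").chunk(
    "scan.png", {"source": "scan.png"}
)
for c in image_chunks:
    # Each chunk carries the OCR provenance added in ImageChunker.chunk()
    print(c["metadata"]["ocr_mode"], c["metadata"]["image_source"], len(c["content"]))

# Office documents are loaded via load_document() and then chunked as text
office_chunks = OfficeChunker(chunk_size=500, overlap=50).chunk_office(
    "report.docx", {"source": "report.docx"}
)
print(f"{len(office_chunks)} chunks from report.docx")
```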
kssrag/core/vectorstores.py
CHANGED

@@ -109,13 +109,23 @@ class BM25VectorStore(BaseVectorStore):
         logger.info(f"BM25 index loaded from {self.persist_path}")

 import tempfile
+# class FAISSVectorStore(BaseVectorStore):
+#     def __init__(self, persist_path: Optional[str] = None, model_name: Optional[str] = None):
+#         if not FAISS_AVAILABLE:
+#             raise ImportError("FAISS is not available. Please install it with 'pip install faiss-cpu' or use a different vector store.")
+#         super().__init__(persist_path)
+#         self.model_name = model_name or config.FAISS_MODEL_NAME
 class FAISSVectorStore(BaseVectorStore):
     def __init__(self, persist_path: Optional[str] = None, model_name: Optional[str] = None):
-
+        # Only setup FAISS when this vector store is actually used
+        from ..utils.helpers import setup_faiss
+        faiss_available, _ = setup_faiss("faiss") # Explicitly request FAISS
+
+        if not faiss_available:
             raise ImportError("FAISS is not available. Please install it with 'pip install faiss-cpu' or use a different vector store.")
+
         super().__init__(persist_path)
         self.model_name = model_name or config.FAISS_MODEL_NAME
-
         # Handle cache directory permissions
         try:
             cache_dir = config.CACHE_DIR
@@ -403,4 +413,86 @@ class HybridOfflineVectorStore(BaseVectorStore):
         self.bm25_store.load()
         self.tfidf_store.load()
         self.documents = self.bm25_store.documents
-        logger.info(f"Hybrid offline index loaded")
+        logger.info(f"Hybrid offline index loaded")
+
+import bm25s
+from Stemmer import Stemmer
+
+class BM25SVectorStore(BaseVectorStore):
+    """BM25S vector store using the bm25s library for ultra-fast retrieval"""
+
+    def __init__(self, persist_path: Optional[str] = "bm25s_index.pkl"):
+        super().__init__(persist_path)
+        self.bm25_retriever = None
+        self.stemmer = Stemmer("english")
+        self.corpus_tokens = None
+
+    def add_documents(self, documents: List[Dict[str, Any]]):
+        self.documents = documents
+        self.doc_texts = [doc["content"] for doc in documents]
+
+        try:
+            # Tokenize corpus with BM25S
+            self.corpus_tokens = bm25s.tokenize(
+                self.doc_texts,
+                stopwords="en",
+                stemmer=self.stemmer
+            )
+
+            # Create and index with BM25S
+            self.bm25_retriever = bm25s.BM25()
+            self.bm25_retriever.index(self.corpus_tokens)
+
+            logger.info(f"BM25S index created with {len(self.documents)} documents")
+
+        except Exception as e:
+            logger.error(f"BM25S initialization failed: {str(e)}")
+            raise
+
+    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
+        if not self.bm25_retriever:
+            raise ValueError("BM25S index not initialized. Call add_documents first.")
+
+        try:
+            # Tokenize query with BM25S
+            query_tokens = bm25s.tokenize([query], stemmer=self.stemmer)
+
+            # Retrieve with BM25S
+            results, scores = self.bm25_retriever.retrieve(query_tokens, k=top_k)
+
+            # Format results
+            retrieved_docs = []
+            for i in range(results.shape[1]):
+                doc_idx = results[0, i]
+                score = scores[0, i]
+
+                if doc_idx < len(self.documents):
+                    retrieved_docs.append(self.documents[doc_idx])
+
+            logger.info(f"BM25S retrieved {len(retrieved_docs)} documents for query: {query}")
+            return retrieved_docs
+
+        except Exception as e:
+            logger.error(f"BM25S retrieval failed for query '{query}': {str(e)}")
+            return []
+
+    def persist(self):
+        if self.persist_path:
+            with open(self.persist_path, 'wb') as f:
+                pickle.dump({
+                    'documents': self.documents,
+                    'doc_texts': self.doc_texts,
+                    'corpus_tokens': self.corpus_tokens,
+                    'bm25_retriever': self.bm25_retriever
+                }, f)
+            logger.info(f"BM25S index persisted to {self.persist_path}")
+
+    def load(self):
+        if self.persist_path and os.path.exists(self.persist_path):
+            with open(self.persist_path, 'rb') as f:
+                data = pickle.load(f)
+                self.documents = data['documents']
+                self.doc_texts = data['doc_texts']
+                self.corpus_tokens = data['corpus_tokens']
+                self.bm25_retriever = data['bm25_retriever']
+            logger.info(f"BM25S index loaded from {self.persist_path}")
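`BM25SVectorStore` works on documents given as dicts with a `"content"` field (that is what `add_documents` reads) and pickles the whole index to `persist_path`. A short sketch, assuming the `bm25s` and `PyStemmer` packages (which provide the `bm25s` and `Stemmer` modules imported above) are installed:

```python
# Sketch of the BM25S store added in 0.2.1.
from kssrag.core.vectorstores import BM25SVectorStore

docs = [
    {"content": "BM25S builds a sparse lexical index for fast retrieval.", "metadata": {"id": 1}},
    {"content": "FAISS is only set up when FAISSVectorStore is instantiated.", "metadata": {"id": 2}},
]

store = BM25SVectorStore(persist_path="bm25s_index.pkl")
store.add_documents(docs)          # tokenizes and indexes the corpus

for doc in store.retrieve("lexical retrieval index", top_k=1):
    print(doc["content"])

store.persist()   # pickles documents, tokens, and the bm25s retriever
store.load()      # restores the same state from bm25s_index.pkl
```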
kssrag/models/openrouter.py
CHANGED

@@ -1,28 +1,35 @@
 import requests
 import json
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Generator
 from ..utils.helpers import logger
 from ..config import config

 class OpenRouterLLM:
-    """OpenRouter LLM interface with
+    """OpenRouter LLM interface with streaming support"""

     def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None,
-                 fallback_models: Optional[List[str]] = None):
+                 fallback_models: Optional[List[str]] = None, stream: bool = False):
         self.api_key = api_key or config.OPENROUTER_API_KEY
         self.model = model or config.DEFAULT_MODEL
         self.fallback_models = fallback_models or config.FALLBACK_MODELS
+        self.stream = stream
         self.base_url = "https://openrouter.ai/api/v1/chat/completions"
         self.headers = {
             "Authorization": f"Bearer {self.api_key}",
             "Content-Type": "application/json",
             "HTTP-Referer": "https://github.com/Ksschkw/kssrag",
-            "X-Title": "
+            "X-Title": "KSSRAG"
         }

     def predict(self, messages: List[Dict[str, str]]) -> str:
-        """Generate
-
+        """Generate response with fallback models"""
+        if self.stream:
+            full_response = ""
+            for chunk in self.predict_stream(messages):
+                full_response += chunk
+            return full_response
+
+        logger.info(f"Generating response with {len(messages)} messages")

         for model in [self.model] + self.fallback_models:
             payload = {
@@ -36,21 +43,17 @@ class OpenRouterLLM:
             }

             try:
-                logger.info(f"
+                logger.info(f"Using model: {model}")
                 response = requests.post(
                     self.base_url,
                     headers=self.headers,
                     json=payload,
-                    timeout=
+                    timeout=30
                 )

-                # Check for HTTP errors
                 response.raise_for_status()
-
-                # Parse JSON response
                 response_data = response.json()

-                # Validate response structure
                 if ("choices" not in response_data or
                     len(response_data["choices"]) == 0 or
                     "message" not in response_data["choices"][0] or
@@ -60,7 +63,7 @@ class OpenRouterLLM:
                     continue

                 content = response_data["choices"][0]["message"]["content"]
-                logger.info(f"Successfully
+                logger.info(f"Successfully generated response with model: {model}")
                 return content

             except requests.exceptions.Timeout:
@@ -79,7 +82,66 @@ class OpenRouterLLM:
                 logger.warning(f"Unexpected error with model {model}: {str(e)}")
                 continue

-
-        error_msg = "I'm having trouble connecting to the knowledge service right now. Please try again in a moment."
+        error_msg = "Unable to generate response from available models. Please try again."
         logger.error("All model fallbacks failed to respond")
-        return error_msg
+        return error_msg
+
+    def predict_stream(self, messages: List[Dict[str, str]]) -> Generator[str, None, None]:
+        """Stream response from OpenRouter API"""
+        logger.info(f"Streaming response with {len(messages)} messages")
+
+        for model in [self.model] + self.fallback_models:
+            payload = {
+                "model": model,
+                "messages": messages,
+                "temperature": 0.7,
+                "max_tokens": 1024,
+                "top_p": 1,
+                "stop": None,
+                "stream": True
+            }
+
+            try:
+                logger.info(f"Streaming with model: {model}")
+                response = requests.post(
+                    self.base_url,
+                    headers=self.headers,
+                    json=payload,
+                    timeout=60,
+                    stream=True
+                )

+                response.raise_for_status()
+
+                for line in response.iter_lines():
+                    if line:
+                        line = line.decode('utf-8')
+                        if line.startswith('data: '):
+                            data = line[6:]
+                            if data.strip() == '[DONE]':
+                                logger.info("Stream completed successfully")
+                                return
+                            try:
+                                chunk_data = json.loads(data)
+                                if ('choices' in chunk_data and
+                                    len(chunk_data['choices']) > 0 and
+                                    'delta' in chunk_data['choices'][0] and
+                                    'content' in chunk_data['choices'][0]['delta']):
+
+                                    content = chunk_data['choices'][0]['delta']['content']
+                                    if content:
+                                        yield content
+                            except json.JSONDecodeError as e:
+                                logger.warning(f"Failed to parse stream chunk: {str(e)}")
+                                continue
+
+                logger.info(f"Successfully streamed from model: {model}")
+                return
+
+            except Exception as e:
+                logger.warning(f"Streaming failed with model {model}: {str(e)}")
+                continue
+
+        error_msg = "Unable to stream response from available models. Please try again."
+        logger.error("All model fallbacks failed for streaming")
+        yield error_msg
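`predict_stream` parses OpenRouter's server-sent-event lines (`data: {...}` chunks terminated by `data: [DONE]`), and with `stream=True` the regular `predict()` simply joins the streamed chunks into one string. A minimal usage sketch, assuming a valid OpenRouter API key in the environment:

```python
# Sketch only; requires a working OPENROUTER_API_KEY and network access.
import os
from kssrag.models.openrouter import OpenRouterLLM

llm = OpenRouterLLM(
    api_key=os.environ["OPENROUTER_API_KEY"],
    model="deepseek/deepseek-chat",   # same default as config.DEFAULT_MODEL
    stream=True,                      # predict() now accumulates predict_stream()
)

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Summarise BM25 in one sentence."},
]

# Consume tokens as they arrive from the SSE stream
for token in llm.predict_stream(messages):
    print(token, end="", flush=True)
print()

# Or let predict() collect the stream into a single string
print(llm.predict(messages))
```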