kssrag 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kssrag/__init__.py ADDED
@@ -0,0 +1,66 @@
1
+ """
2
+ KSS RAG - A flexible Retrieval-Augmented Generation framework by Ksschkw
3
+ """
4
+ from .kssrag import KSSRAG
5
+ from .core.chunkers import TextChunker, JSONChunker, PDFChunker
6
+ from .core.vectorstores import BM25VectorStore, FAISSVectorStore, TFIDFVectorStore, HybridVectorStore, HybridOfflineVectorStore
7
+ from .core.retrievers import SimpleRetriever, HybridRetriever
8
+ from .core.agents import RAGAgent
9
+ from .models.openrouter import OpenRouterLLM
10
+ from .utils.document_loaders import load_document, load_json_documents
11
+ from .utils.helpers import logger, validate_config
12
+ from .config import Config, VectorStoreType, ChunkerType, RetrieverType
13
+ from .server import create_app, ServerConfig
14
+ from .cli import main
15
+
16
+ __version__ = "0.1.0"
17
+ __author__ = "Ksschkw"
18
+ __license__ = "MIT"
19
+
20
+ # Your footprint - include your GitHub username and a signature
21
+ __signature__ = "Built with HATE by Ksschkw (github.com/Ksschkw)"
22
+
23
+ # Export the main classes for easy access
24
+ __all__ = [
25
+ 'KSSRAG',
26
+ 'TextChunker',
27
+ 'JSONChunker',
28
+ 'PDFChunker',
29
+ 'BM25VectorStore',
30
+ 'FAISSVectorStore',
31
+ 'TFIDFVectorStore',
32
+ 'HybridVectorStore',
33
+ 'HybridOfflineVectorStore',
34
+ 'SimpleRetriever',
35
+ 'HybridRetriever',
36
+ 'RAGAgent',
37
+ 'OpenRouterLLM',
38
+ 'load_document',
39
+ 'load_json_documents',
40
+ 'Config',
41
+ 'VectorStoreType',
42
+ 'ChunkerType',
43
+ 'RetrieverType',
44
+ 'ServerConfig',
45
+ 'create_app',
46
+ 'main',
47
+ 'logger',
48
+ 'validate_config'
49
+ ]
50
+
51
+ # Initialize configuration validation on import
52
+ validate_config()
53
+
54
+ import platform
55
+ from pathlib import Path
56
+ import os
57
+
58
+ # Windows-specific cache directory handling
59
+ if platform.system() == "Windows":
60
+ # Use local app data directory instead of home directory
61
+ cache_base = os.getenv('LOCALAPPDATA', os.path.expanduser('~'))
62
+ config.CACHE_DIR = os.path.join(cache_base, '.kssrag', 'cache')
63
+ else:
64
+ config.CACHE_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'kssrag')
65
+
66
+ os.makedirs(config.CACHE_DIR, exist_ok=True)
kssrag/cli.py ADDED
@@ -0,0 +1,142 @@
1
+ import argparse
2
+ import sys
3
+ from .utils.document_loaders import load_document, load_json_documents
4
+ from .core.chunkers import TextChunker, JSONChunker, PDFChunker
5
+ from .core.vectorstores import BM25VectorStore, FAISSVectorStore, TFIDFVectorStore, HybridVectorStore, HybridOfflineVectorStore
6
+ from .core.retrievers import SimpleRetriever, HybridRetriever
7
+ from .core.agents import RAGAgent
8
+ from .models.openrouter import OpenRouterLLM
9
+ from .config import config
10
+ from .utils.helpers import logger, validate_config
11
+
12
def _load_and_chunk(file_path, doc_format):
    """Load *file_path* and chunk it per *doc_format* ("text"/"json"/"pdf").

    Returns the list of chunk dicts, or None (after logging) when the format
    is not supported.  Shared by the ``query`` and ``server`` subcommands,
    which previously duplicated this logic verbatim.
    """
    if doc_format == "text":
        content = load_document(file_path)
        chunker = TextChunker(chunk_size=config.CHUNK_SIZE, overlap=config.CHUNK_OVERLAP)
        return chunker.chunk(content, {"source": file_path})
    if doc_format == "json":
        data = load_json_documents(file_path)
        return JSONChunker().chunk(data)
    if doc_format == "pdf":
        chunker = PDFChunker(chunk_size=config.CHUNK_SIZE, overlap=config.CHUNK_OVERLAP)
        return chunker.chunk_pdf(file_path, {"source": file_path})
    logger.error(f"Unsupported format: {doc_format}")
    return None


def _make_vector_store(store_type):
    """Instantiate the vector store named *store_type*, or None (after
    logging) for an unknown name."""
    stores = {
        "bm25": BM25VectorStore,
        "faiss": FAISSVectorStore,
        "tfidf": TFIDFVectorStore,
        "hybrid_online": HybridVectorStore,
        "hybrid_offline": HybridOfflineVectorStore,
    }
    store_cls = stores.get(store_type)
    if store_cls is None:
        logger.error(f"Unsupported vector store: {store_type}")
        return None
    return store_cls()


def _build_agent(args):
    """Build a ready-to-query RAGAgent from parsed CLI *args*.

    Returns None when the document format or vector-store name is invalid
    (errors already logged by the helpers above).
    """
    documents = _load_and_chunk(args.file, args.format)
    if documents is None:
        return None

    vector_store = _make_vector_store(args.vector_store)
    if vector_store is None:
        return None

    vector_store.add_documents(documents)

    retriever = SimpleRetriever(vector_store)
    llm = OpenRouterLLM()
    return RAGAgent(retriever, llm)


def main():
    """Command-line interface for KSS RAG.

    Subcommands:
        query  -- answer a single question against a document.
        server -- serve the RAG agent over HTTP via uvicorn.

    Returns a process exit code (0 on success, 1 on usage/config error).
    """
    parser = argparse.ArgumentParser(description="KSS RAG - Retrieval-Augmented Generation Framework")
    subparsers = parser.add_subparsers(dest="command", help="Command to execute")

    # Query command
    query_parser = subparsers.add_parser("query", help="Query the RAG system")
    query_parser.add_argument("--file", type=str, required=True, help="Path to document file")
    query_parser.add_argument("--query", type=str, required=True, help="Query to ask")
    query_parser.add_argument("--format", type=str, default="text", choices=["text", "json", "pdf"],
                              help="Document format")
    query_parser.add_argument("--vector-store", type=str, default=config.VECTOR_STORE_TYPE,
                              choices=["bm25", "faiss", "tfidf", "hybrid_online", "hybrid_offline"],
                              help="Vector store type")
    query_parser.add_argument("--top-k", type=int, default=config.TOP_K, help="Number of results to retrieve")

    # Server command
    server_parser = subparsers.add_parser("server", help="Start the RAG API server")
    server_parser.add_argument("--file", type=str, required=True, help="Path to document file")
    server_parser.add_argument("--format", type=str, default="text", choices=["text", "json", "pdf"],
                               help="Document format")
    server_parser.add_argument("--vector-store", type=str, default=config.VECTOR_STORE_TYPE,
                               choices=["bm25", "faiss", "tfidf", "hybrid_online", "hybrid_offline"],
                               help="Vector store type")
    server_parser.add_argument("--port", type=int, default=config.SERVER_PORT, help="Port to run server on")
    server_parser.add_argument("--host", type=str, default=config.SERVER_HOST, help="Host to run server on")

    args = parser.parse_args()

    # Validate config
    validate_config()

    if args.command == "query":
        agent = _build_agent(args)
        if agent is None:
            return 1

        # Query and print response
        response = agent.query(args.query, top_k=args.top_k)
        print(f"Query: {args.query}")
        print(f"Response: {response}")

    elif args.command == "server":
        agent = _build_agent(args)
        if agent is None:
            return 1

        # Create and run server (imported lazily so the query path does not
        # require the server dependencies).
        from .server import create_app
        import uvicorn

        app, server_config = create_app(agent)
        logger.info(f"Starting server on {args.host}:{args.port}")
        uvicorn.run(app, host=args.host, port=args.port)

    else:
        parser.print_help()
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
kssrag/config.py ADDED
@@ -0,0 +1,193 @@
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from typing import List, Optional, Dict, Any
4
+ from pydantic import Field, validator
5
+ from pydantic_settings import BaseSettings
6
+ from enum import Enum
7
+
8
+ load_dotenv()
9
+
10
class VectorStoreType(str, Enum):
    """Names of the supported vector-store backends.

    The string values double as the CLI/env-var spellings (they match the
    ``--vector-store`` choices in the CLI); ``CUSTOM`` selects a
    user-supplied class via the CUSTOM_VECTOR_STORE setting.
    """
    BM25 = "bm25"
    FAISS = "faiss"
    TFIDF = "tfidf"
    HYBRID_ONLINE = "hybrid_online"
    HYBRID_OFFLINE = "hybrid_offline"
    CUSTOM = "custom"
17
+
18
class ChunkerType(str, Enum):
    """Names of the supported document chunkers; ``CUSTOM`` selects a
    user-supplied class via the CUSTOM_CHUNKER setting."""
    TEXT = "text"
    JSON = "json"
    PDF = "pdf"
    CUSTOM = "custom"
23
+
24
class RetrieverType(str, Enum):
    """Names of the supported retrievers; ``CUSTOM`` selects a user-supplied
    class via the CUSTOM_RETRIEVER setting."""
    SIMPLE = "simple"
    HYBRID = "hybrid"
    CUSTOM = "custom"
28
+
29
class Config(BaseSettings):
    """Configuration settings for KSS RAG with extensive customization options.

    Every field also reads its default from the environment (via os.getenv at
    class-definition time) in addition to pydantic-settings' own .env support.

    NOTE(review): the ``@validator`` decorator and the inner ``class Config``
    are pydantic v1-style APIs; under pydantic v2 they work but are
    deprecated (``field_validator`` / ``model_config`` are the replacements)
    — confirm the pinned pydantic version before migrating.
    """

    # OpenRouter settings
    OPENROUTER_API_KEY: str = Field(
        default=os.getenv("OPENROUTER_API_KEY", ""),
        description="Your OpenRouter API key for accessing LLMs"
    )

    DEFAULT_MODEL: str = Field(
        default=os.getenv("DEFAULT_MODEL", "deepseek/deepseek-chat-v3.1:free"),
        description="Default model to use for LLM responses"
    )

    FALLBACK_MODELS: List[str] = Field(
        default=os.getenv("FALLBACK_MODELS", "deepseek/deepseek-r1-0528:free,deepseek/deepseek-chat,deepseek/deepseek-r1:free").split(","),
        description="List of fallback models to try if the default model fails"
    )

    # Chunking settings
    CHUNK_SIZE: int = Field(
        default=int(os.getenv("CHUNK_SIZE", 500)),
        ge=100,
        le=2000,
        description="Size of text chunks in characters"
    )

    CHUNK_OVERLAP: int = Field(
        default=int(os.getenv("CHUNK_OVERLAP", 50)),
        ge=0,
        le=500,
        description="Overlap between chunks in characters"
    )

    CHUNKER_TYPE: ChunkerType = Field(
        default=os.getenv("CHUNKER_TYPE", ChunkerType.TEXT),
        description="Type of chunker to use"
    )

    # Vector store settings
    VECTOR_STORE_TYPE: VectorStoreType = Field(
        default=os.getenv("VECTOR_STORE_TYPE", VectorStoreType.HYBRID_OFFLINE),
        description="Type of vector store to use"
    )

    FAISS_MODEL_NAME: str = Field(
        default=os.getenv("FAISS_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2"),
        description="SentenceTransformer model name for FAISS embeddings"
    )

    # Retrieval settings
    RETRIEVER_TYPE: RetrieverType = Field(
        default=os.getenv("RETRIEVER_TYPE", RetrieverType.SIMPLE),
        description="Type of retriever to use"
    )

    TOP_K: int = Field(
        default=int(os.getenv("TOP_K", 5)),
        ge=1,
        le=20,
        description="Number of results to retrieve"
    )

    FUZZY_MATCH_THRESHOLD: int = Field(
        default=int(os.getenv("FUZZY_MATCH_THRESHOLD", 80)),
        ge=0,
        le=100,
        description="Threshold for fuzzy matching (0-100)"
    )

    # Performance settings
    BATCH_SIZE: int = Field(
        default=int(os.getenv("BATCH_SIZE", 64)),
        ge=1,
        le=256,
        description="Batch size for processing documents"
    )

    MAX_DOCS_FOR_TESTING: Optional[int] = Field(
        # BUGFIX: os.getenv returns a *string* and pydantic does not validate
        # defaults, so the previous code left this attribute as a str whenever
        # the env var was set.  Parse it to int here (None when unset).
        default=int(os.environ["MAX_DOCS_FOR_TESTING"]) if os.getenv("MAX_DOCS_FOR_TESTING") else None,
        description="Limit documents for testing (None for all)"
    )

    # Server settings
    SERVER_HOST: str = Field(
        default=os.getenv("SERVER_HOST", "localhost"),
        description="Host to run the server on"
    )

    SERVER_PORT: int = Field(
        default=int(os.getenv("SERVER_PORT", 8000)),
        ge=1024,
        le=65535,
        description="Port to run the server on"
    )

    CORS_ORIGINS: List[str] = Field(
        default=os.getenv("CORS_ORIGINS", "*").split(","),
        description="List of CORS origins"
    )

    CORS_ALLOW_CREDENTIALS: bool = Field(
        default=os.getenv("CORS_ALLOW_CREDENTIALS", "True").lower() == "true",
        description="Whether to allow CORS credentials"
    )

    CORS_ALLOW_METHODS: List[str] = Field(
        default=os.getenv("CORS_ALLOW_METHODS", "GET,POST,PUT,DELETE,OPTIONS").split(","),
        description="List of allowed CORS methods"
    )

    CORS_ALLOW_HEADERS: List[str] = Field(
        default=os.getenv("CORS_ALLOW_HEADERS", "Content-Type,Authorization").split(","),
        description="List of allowed CORS headers"
    )

    # Advanced settings
    ENABLE_CACHE: bool = Field(
        default=os.getenv("ENABLE_CACHE", "True").lower() == "true",
        description="Whether to enable caching for vector stores"
    )

    CACHE_DIR: str = Field(
        default=os.getenv("CACHE_DIR", ".cache"),
        description="Directory to store cache files"
    )

    LOG_LEVEL: str = Field(
        default=os.getenv("LOG_LEVEL", "INFO"),
        description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)"
    )

    # Custom components (for advanced users)
    CUSTOM_CHUNKER: Optional[str] = Field(
        default=os.getenv("CUSTOM_CHUNKER"),
        description="Import path to a custom chunker class"
    )

    CUSTOM_VECTOR_STORE: Optional[str] = Field(
        default=os.getenv("CUSTOM_VECTOR_STORE"),
        description="Import path to a custom vector store class"
    )

    CUSTOM_RETRIEVER: Optional[str] = Field(
        default=os.getenv("CUSTOM_RETRIEVER"),
        description="Import path to a custom retriever class"
    )

    CUSTOM_LLM: Optional[str] = Field(
        default=os.getenv("CUSTOM_LLM"),
        description="Import path to a custom LLM class"
    )

    class Config:
        # pydantic v1-style settings configuration (see class docstring).
        env_file = ".env"
        case_sensitive = False
        use_enum_values = True

    @validator('FALLBACK_MODELS', 'CORS_ORIGINS', 'CORS_ALLOW_METHODS', 'CORS_ALLOW_HEADERS', pre=True)
    def split_string(cls, v):
        """Allow list-valued settings to be supplied as comma-separated
        strings (e.g. from env vars); whitespace around items is stripped."""
        if isinstance(v, str):
            return [item.strip() for item in v.split(',')]
        return v

# Shared singleton used throughout the package.
config = Config()
File without changes
kssrag/core/agents.py ADDED
@@ -0,0 +1,68 @@
1
+ from typing import List, Dict, Any, Optional
2
+ from ..utils.helpers import logger
3
+
4
class RAGAgent:
    """Conversational RAG agent.

    For each question it asks the retriever for context documents, prepends
    them to the question, sends the rolling conversation to the LLM, and
    records both sides of the exchange in the history.
    """

    def __init__(self, retriever, llm, system_prompt: Optional[str] = None,
                 conversation_history: Optional[List[Dict[str, str]]] = None):
        self.retriever = retriever
        self.llm = llm
        self.conversation = conversation_history or []
        self.system_prompt = system_prompt or """You are a helpful AI assistant. Use the following context to answer the user's question.
        If you don't know the answer based on the context, say so."""

        # Seed the history with the system prompt unless the caller already
        # supplied a system message.
        has_system = any(msg.get("role") == "system" for msg in self.conversation)
        if not has_system:
            self.add_message("system", self.system_prompt)

    def add_message(self, role: str, content: str):
        """Append a message, then trim the history to 10 entries.

        The first system message (if any) is always retained; the remaining
        slots hold the most recent non-system messages.
        """
        self.conversation.append({"role": role, "content": content})

        if len(self.conversation) <= 10:
            return

        first_system = None
        others = []
        for msg in self.conversation:
            if msg["role"] == "system":
                if first_system is None:
                    first_system = msg
            else:
                others.append(msg)

        if first_system is not None:
            self.conversation = [first_system] + others[-9:]
        else:
            self.conversation = others[-10:]

    def query(self, question: str, top_k: int = 5, include_context: bool = True) -> str:
        """Answer *question*, grounding the LLM on retrieved context.

        Returns the LLM's reply, a canned "no information" reply when
        retrieval comes back empty, or a canned error reply on any exception.
        """
        try:
            docs = self.retriever.retrieve(question, top_k)

            if not docs and include_context:
                logger.warning(f"No context found for query: {question}")
                return "I couldn't find relevant information to answer your question."

            # Render the retrieved documents as a numbered context section.
            context = ""
            if include_context and docs:
                pieces = ["Relevant information:\n"]
                for idx, doc in enumerate(docs, 1):
                    pieces.append(f"\n--- Document {idx} ---\n{doc['content']}\n")
                context = "".join(pieces)

            user_message = f"{context}\n\nQuestion: {question}" if context else question
            self.add_message("user", user_message)

            response = self.llm.predict(self.conversation)
            self.add_message("assistant", response)
            return response

        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            return "I encountered an issue processing your query. Please try again."

    def clear_conversation(self):
        """Drop the whole history except the first system message, if any."""
        for msg in self.conversation:
            if msg["role"] == "system":
                self.conversation = [msg]
                return
        self.conversation = []
@@ -0,0 +1,100 @@
1
+ import json
2
+ import re
3
+ from typing import List, Dict, Any, Optional
4
+ import pypdf
5
+ from ..utils.helpers import logger
6
+
7
class BaseChunker:
    """Common interface for document chunkers.

    Subclasses implement ``chunk`` and return a list of dicts shaped as
    ``{"content": str, "metadata": dict}``.
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        # Target chunk length and inter-chunk overlap, both in characters.
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk(self, content: Any, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Split *content* into chunk dicts; must be overridden."""
        raise NotImplementedError("Subclasses must implement this method")
16
+
17
class TextChunker(BaseChunker):
    """Chunker for plain text documents"""

    def chunk(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Split *content* into overlapping character chunks.

        Each chunk dict carries a copy of *metadata* plus a sequential
        ``chunk_id``.

        Fixes over the previous version:
        * ``overlap >= chunk_size`` used to make the loop spin forever
          (``start`` never advanced); the step is now clamped to >= 1;
        * once the end of the text is reached we stop, instead of emitting
          the final overlap-sized tail a second time as its own chunk.
        """
        if metadata is None:
            metadata = {}

        # Clamp: advance by at least one character per iteration.
        step = max(1, self.chunk_size - self.overlap)

        chunks = []
        content_length = len(content)
        start = 0

        while start < content_length:
            end = min(start + self.chunk_size, content_length)

            chunk_metadata = metadata.copy()
            chunk_metadata["chunk_id"] = len(chunks)
            chunks.append({
                "content": content[start:end],
                "metadata": chunk_metadata
            })

            if end == content_length:
                # The tail is already covered; stepping again would only
                # re-emit a suffix of this chunk.
                break
            start += step

        logger.info(f"Created {len(chunks)} chunks from text")
        return chunks
47
+
48
class JSONChunker(BaseChunker):
    """Chunker for JSON documents (like your drug data)"""

    def chunk(self, data: List[Dict[str, Any]], metadata_field: str = "name") -> List[Dict[str, Any]]:
        """Create one chunk per JSON record.

        Records lacking *metadata_field* are skipped.  The chunk content is a
        human-readable "Key: value" rendering of the record; the metadata
        carries the record's name and its ``url`` field when present.
        """
        chunks = []

        for item in data:
            if metadata_field not in item:
                continue

            # Create a comprehensive text representation
            lines = [f"Item Name: {item.get(metadata_field, 'N/A')}\n"]

            for key, value in item.items():
                if key == metadata_field or not value:
                    continue
                label = key.replace('_', ' ').title()
                if isinstance(value, str):
                    lines.append(f"{label}: {value}\n")
                elif isinstance(value, list):
                    # BUGFIX: stringify elements first — the old
                    # ``', '.join(value)`` raised TypeError whenever the list
                    # contained non-string items (numbers, nested dicts, ...).
                    lines.append(f"{label}: {', '.join(str(v) for v in value)}\n")

            chunks.append({
                "content": "".join(lines),
                "metadata": {
                    "name": item[metadata_field],
                    "source": item.get('url', 'N/A')
                }
            })

        logger.info(f"Created {len(chunks)} chunks from JSON data")
        return chunks
79
+
80
class PDFChunker(TextChunker):
    """Chunker for PDF documents"""

    def extract_text(self, pdf_path: str) -> str:
        """Return the concatenated text of every page in *pdf_path*.

        Extraction errors are logged and re-raised.
        """
        pages = []
        try:
            with open(pdf_path, 'rb') as handle:
                reader = pypdf.PdfReader(handle)
                for page in reader.pages:
                    pages.append(page.extract_text() + "\n")
        except Exception as e:
            logger.error(f"Failed to extract text from PDF: {str(e)}")
            raise

        return "".join(pages)

    def chunk_pdf(self, pdf_path: str, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Extract the PDF's text and delegate chunking to TextChunker.chunk."""
        return self.chunk(self.extract_text(pdf_path), metadata)
@@ -0,0 +1,74 @@
1
+ from typing import List, Dict, Any, Optional
2
+ from ..utils.helpers import logger
3
+
4
class BaseRetriever:
    """Abstract base for retrievers: holds the backing vector store and
    defines the ``retrieve`` contract."""

    def __init__(self, vector_store):
        # Backing store queried by subclasses.
        self.vector_store = vector_store

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to *top_k* documents relevant to *query*; must be
        overridden by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")
13
+
14
class SimpleRetriever(BaseRetriever):
    """Simple retriever using only vector store"""

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Pass the query straight through to the underlying vector store."""
        store = self.vector_store
        return store.retrieve(query, top_k)
19
+
20
class HybridRetriever(BaseRetriever):
    """Hybrid retriever with fuzzy matching for specific entities"""

    def __init__(self, vector_store, entity_names: Optional[List[str]] = None,
                 fuzzy_threshold: int = 80):
        """Args:
            vector_store: backing store used for the initial retrieval.
            entity_names: known entity names to detect in queries.
            fuzzy_threshold: minimum rapidfuzz partial-ratio score (0-100)
                for a fuzzy entity match; defaults to the previously
                hard-coded 80, so existing callers are unaffected.
        """
        super().__init__(vector_store)
        self.entity_names = entity_names or []
        self.fuzzy_threshold = fuzzy_threshold

    def extract_entities(self, query: str) -> List[str]:
        """Return known entity names mentioned in *query*.

        Exact substring matches are preferred; rapidfuzz partial-ratio
        matching is the fallback.  BUGFIX: the old code compared the *raw*
        entity against the lowercased query, so entities containing any
        uppercase letter could never exact-match — both sides are now
        lowercased.
        """
        from rapidfuzz import process, fuzz

        query_lower = query.lower()

        # Exact (case-insensitive) substring matches first.
        extracted = [entity for entity in self.entity_names
                     if entity.lower() in query_lower]

        # Fall back to fuzzy matching only when nothing matched exactly.
        if not extracted and self.entity_names:
            matches = process.extract(query, self.entity_names,
                                      scorer=fuzz.partial_ratio, limit=5)
            extracted = [match[0] for match in matches
                         if match[1] > self.fuzzy_threshold]

        return extracted

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Retrieve *top_k* docs, boosting those that mention a query entity."""
        # Over-fetch so the re-ranking step has candidates to promote.
        results = self.vector_store.retrieve(query, top_k * 2)

        if self.entity_names:
            # Lowercase once; doc content is compared lowercased below.
            entities = [e.lower() for e in self.extract_entities(query)]

            if entities:
                scored_results = []
                for doc in results:
                    score = 1.0
                    content_lower = doc["content"].lower()
                    # Flat +0.5 boost if any extracted entity is mentioned.
                    if any(entity in content_lower for entity in entities):
                        score += 0.5
                    scored_results.append((doc, score))

                # Stable sort preserves vector-store order within equal scores.
                scored_results.sort(key=lambda pair: pair[1], reverse=True)
                return [doc for doc, _ in scored_results[:top_k]]

        return results[:top_k]