kssrag 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kssrag/__init__.py +66 -0
- kssrag/cli.py +142 -0
- kssrag/config.py +193 -0
- kssrag/core/__init__.py +0 -0
- kssrag/core/agents.py +68 -0
- kssrag/core/chunkers.py +100 -0
- kssrag/core/retrievers.py +74 -0
- kssrag/core/vectorstores.py +397 -0
- kssrag/kssrag.py +116 -0
- kssrag/models/__init__.py +0 -0
- kssrag/models/local_llms.py +30 -0
- kssrag/models/openrouter.py +85 -0
- kssrag/server.py +116 -0
- kssrag/utils/__init__.py +0 -0
- kssrag/utils/document_loaders.py +40 -0
- kssrag/utils/helpers.py +55 -0
- kssrag/utils/preprocessors.py +30 -0
- kssrag-0.1.0.dist-info/METADATA +407 -0
- kssrag-0.1.0.dist-info/RECORD +26 -0
- kssrag-0.1.0.dist-info/WHEEL +5 -0
- kssrag-0.1.0.dist-info/entry_points.txt +2 -0
- kssrag-0.1.0.dist-info/licenses/LICENSE +0 -0
- kssrag-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +0 -0
- tests/test_basic.py +43 -0
- tests/test_vectorstores.py +35 -0
kssrag/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""
|
|
2
|
+
KSS RAG - A flexible Retrieval-Augmented Generation framework by Ksschkw
|
|
3
|
+
"""
|
|
4
|
+
from .kssrag import KSSRAG
|
|
5
|
+
from .core.chunkers import TextChunker, JSONChunker, PDFChunker
|
|
6
|
+
from .core.vectorstores import BM25VectorStore, FAISSVectorStore, TFIDFVectorStore, HybridVectorStore, HybridOfflineVectorStore
|
|
7
|
+
from .core.retrievers import SimpleRetriever, HybridRetriever
|
|
8
|
+
from .core.agents import RAGAgent
|
|
9
|
+
from .models.openrouter import OpenRouterLLM
|
|
10
|
+
from .utils.document_loaders import load_document, load_json_documents
|
|
11
|
+
from .utils.helpers import logger, validate_config
|
|
12
|
+
from .config import Config, VectorStoreType, ChunkerType, RetrieverType
|
|
13
|
+
from .server import create_app, ServerConfig
|
|
14
|
+
from .cli import main
|
|
15
|
+
|
|
16
|
+
__version__ = "0.1.0"
|
|
17
|
+
__author__ = "Ksschkw"
|
|
18
|
+
__license__ = "MIT"
|
|
19
|
+
|
|
20
|
+
# Your footprint - include your GitHub username and a signature
|
|
21
|
+
__signature__ = "Built with HATE by Ksschkw (github.com/Ksschkw)"
|
|
22
|
+
|
|
23
|
+
# Export the main classes for easy access
|
|
24
|
+
__all__ = [
|
|
25
|
+
'KSSRAG',
|
|
26
|
+
'TextChunker',
|
|
27
|
+
'JSONChunker',
|
|
28
|
+
'PDFChunker',
|
|
29
|
+
'BM25VectorStore',
|
|
30
|
+
'FAISSVectorStore',
|
|
31
|
+
'TFIDFVectorStore',
|
|
32
|
+
'HybridVectorStore',
|
|
33
|
+
'HybridOfflineVectorStore',
|
|
34
|
+
'SimpleRetriever',
|
|
35
|
+
'HybridRetriever',
|
|
36
|
+
'RAGAgent',
|
|
37
|
+
'OpenRouterLLM',
|
|
38
|
+
'load_document',
|
|
39
|
+
'load_json_documents',
|
|
40
|
+
'Config',
|
|
41
|
+
'VectorStoreType',
|
|
42
|
+
'ChunkerType',
|
|
43
|
+
'RetrieverType',
|
|
44
|
+
'ServerConfig',
|
|
45
|
+
'create_app',
|
|
46
|
+
'main',
|
|
47
|
+
'logger',
|
|
48
|
+
'validate_config'
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
# Initialize configuration validation on import
|
|
52
|
+
validate_config()
|
|
53
|
+
|
|
54
|
+
import platform
|
|
55
|
+
from pathlib import Path
|
|
56
|
+
import os
|
|
57
|
+
|
|
58
|
+
# Windows-specific cache directory handling
|
|
59
|
+
if platform.system() == "Windows":
|
|
60
|
+
# Use local app data directory instead of home directory
|
|
61
|
+
cache_base = os.getenv('LOCALAPPDATA', os.path.expanduser('~'))
|
|
62
|
+
config.CACHE_DIR = os.path.join(cache_base, '.kssrag', 'cache')
|
|
63
|
+
else:
|
|
64
|
+
config.CACHE_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'kssrag')
|
|
65
|
+
|
|
66
|
+
os.makedirs(config.CACHE_DIR, exist_ok=True)
|
kssrag/cli.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import sys
|
|
3
|
+
from .utils.document_loaders import load_document, load_json_documents
|
|
4
|
+
from .core.chunkers import TextChunker, JSONChunker, PDFChunker
|
|
5
|
+
from .core.vectorstores import BM25VectorStore, FAISSVectorStore, TFIDFVectorStore, HybridVectorStore, HybridOfflineVectorStore
|
|
6
|
+
from .core.retrievers import SimpleRetriever, HybridRetriever
|
|
7
|
+
from .core.agents import RAGAgent
|
|
8
|
+
from .models.openrouter import OpenRouterLLM
|
|
9
|
+
from .config import config
|
|
10
|
+
from .utils.helpers import logger, validate_config
|
|
11
|
+
|
|
12
|
+
# Maps the --vector-store CLI value to the store class to instantiate.
_VECTOR_STORE_CLASSES = {
    "bm25": BM25VectorStore,
    "faiss": FAISSVectorStore,
    "tfidf": TFIDFVectorStore,
    "hybrid_online": HybridVectorStore,
    "hybrid_offline": HybridOfflineVectorStore,
}


def _load_and_chunk(args):
    """Load ``args.file`` and chunk it according to ``args.format``.

    Returns a list of chunk dicts, or None when the format is unsupported
    (the error is logged before returning).
    """
    if args.format == "text":
        content = load_document(args.file)
        chunker = TextChunker(chunk_size=config.CHUNK_SIZE, overlap=config.CHUNK_OVERLAP)
        return chunker.chunk(content, {"source": args.file})
    elif args.format == "json":
        data = load_json_documents(args.file)
        return JSONChunker().chunk(data)
    elif args.format == "pdf":
        chunker = PDFChunker(chunk_size=config.CHUNK_SIZE, overlap=config.CHUNK_OVERLAP)
        return chunker.chunk_pdf(args.file, {"source": args.file})
    logger.error(f"Unsupported format: {args.format}")
    return None


def _build_agent(args):
    """Build a ready-to-query RAGAgent from the parsed CLI arguments.

    Loads and chunks the document, fills the requested vector store, and
    wires a SimpleRetriever + OpenRouterLLM into a RAGAgent. Returns None
    (after logging) when the document format or vector store type is
    unsupported.
    """
    documents = _load_and_chunk(args)
    if documents is None:
        return None

    store_cls = _VECTOR_STORE_CLASSES.get(args.vector_store)
    if store_cls is None:
        logger.error(f"Unsupported vector store: {args.vector_store}")
        return None

    vector_store = store_cls()
    vector_store.add_documents(documents)

    retriever = SimpleRetriever(vector_store)
    llm = OpenRouterLLM()
    return RAGAgent(retriever, llm)


def main():
    """Command-line interface for KSS RAG.

    Returns a process exit code: 0 on success, 1 on bad input or when no
    command is given. (Previously the query/server branches duplicated the
    document-loading and vector-store setup verbatim; that now lives in
    _load_and_chunk/_build_agent.)
    """
    parser = argparse.ArgumentParser(description="KSS RAG - Retrieval-Augmented Generation Framework")
    subparsers = parser.add_subparsers(dest="command", help="Command to execute")

    # Query command
    query_parser = subparsers.add_parser("query", help="Query the RAG system")
    query_parser.add_argument("--file", type=str, required=True, help="Path to document file")
    query_parser.add_argument("--query", type=str, required=True, help="Query to ask")
    query_parser.add_argument("--format", type=str, default="text", choices=["text", "json", "pdf"],
                              help="Document format")
    query_parser.add_argument("--vector-store", type=str, default=config.VECTOR_STORE_TYPE,
                              choices=["bm25", "faiss", "tfidf", "hybrid_online", "hybrid_offline"],
                              help="Vector store type")
    query_parser.add_argument("--top-k", type=int, default=config.TOP_K, help="Number of results to retrieve")

    # Server command
    server_parser = subparsers.add_parser("server", help="Start the RAG API server")
    server_parser.add_argument("--file", type=str, required=True, help="Path to document file")
    server_parser.add_argument("--format", type=str, default="text", choices=["text", "json", "pdf"],
                               help="Document format")
    server_parser.add_argument("--vector-store", type=str, default=config.VECTOR_STORE_TYPE,
                               choices=["bm25", "faiss", "tfidf", "hybrid_online", "hybrid_offline"],
                               help="Vector store type")
    server_parser.add_argument("--port", type=int, default=config.SERVER_PORT, help="Port to run server on")
    server_parser.add_argument("--host", type=str, default=config.SERVER_HOST, help="Host to run server on")

    args = parser.parse_args()

    # Validate config
    validate_config()

    if args.command == "query":
        agent = _build_agent(args)
        if agent is None:
            return 1

        # Query and print response
        response = agent.query(args.query, top_k=args.top_k)
        print(f"Query: {args.query}")
        print(f"Response: {response}")

    elif args.command == "server":
        agent = _build_agent(args)
        if agent is None:
            return 1

        # Create and run server (imported lazily so `query` works even if
        # the server extras are not installed)
        from .server import create_app
        import uvicorn

        app, server_config = create_app(agent)
        logger.info(f"Starting server on {args.host}:{args.port}")
        uvicorn.run(app, host=args.host, port=args.port)

    else:
        parser.print_help()
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
kssrag/config.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from dotenv import load_dotenv
|
|
3
|
+
from typing import List, Optional, Dict, Any
|
|
4
|
+
from pydantic import Field, validator
|
|
5
|
+
from pydantic_settings import BaseSettings
|
|
6
|
+
from enum import Enum
|
|
7
|
+
|
|
8
|
+
load_dotenv()
|
|
9
|
+
|
|
10
|
+
class VectorStoreType(str, Enum):
    """Supported vector store backends.

    String-valued so members compare equal to their env/CLI spellings
    (e.g. "bm25"); CUSTOM selects a user-supplied class via CUSTOM_VECTOR_STORE.
    """
    BM25 = "bm25"
    FAISS = "faiss"
    TFIDF = "tfidf"
    HYBRID_ONLINE = "hybrid_online"
    HYBRID_OFFLINE = "hybrid_offline"
    CUSTOM = "custom"
|
|
17
|
+
|
|
18
|
+
class ChunkerType(str, Enum):
    """Supported document chunker kinds; values match the env/CLI spellings.

    CUSTOM selects a user-supplied class via CUSTOM_CHUNKER.
    """
    TEXT = "text"
    JSON = "json"
    PDF = "pdf"
    CUSTOM = "custom"
|
|
23
|
+
|
|
24
|
+
class RetrieverType(str, Enum):
    """Supported retriever kinds; values match the env/CLI spellings.

    CUSTOM selects a user-supplied class via CUSTOM_RETRIEVER.
    """
    SIMPLE = "simple"
    HYBRID = "hybrid"
    CUSTOM = "custom"
|
|
28
|
+
|
|
29
|
+
class Config(BaseSettings):
    """Configuration settings for KSS RAG with extensive customization options.

    Every field carries an explicit ``default=os.getenv(...)`` even though
    BaseSettings itself reads the environment; the os.getenv defaults are
    therefore resolved once at class-definition time. NOTE(review): this
    duplicates pydantic-settings' own env handling — presumably intentional,
    but worth confirming against the pydantic-settings docs.
    """

    # OpenRouter settings
    OPENROUTER_API_KEY: str = Field(
        default=os.getenv("OPENROUTER_API_KEY", ""),
        description="Your OpenRouter API key for accessing LLMs"
    )

    DEFAULT_MODEL: str = Field(
        default=os.getenv("DEFAULT_MODEL", "deepseek/deepseek-chat-v3.1:free"),
        description="Default model to use for LLM responses"
    )

    # Comma-separated in the environment; split both here (for the default)
    # and in the split_string validator (for values pydantic reads itself).
    FALLBACK_MODELS: List[str] = Field(
        default=os.getenv("FALLBACK_MODELS", "deepseek/deepseek-r1-0528:free,deepseek/deepseek-chat,deepseek/deepseek-r1:free").split(","),
        description="List of fallback models to try if the default model fails"
    )

    # Chunking settings
    CHUNK_SIZE: int = Field(
        default=int(os.getenv("CHUNK_SIZE", 500)),
        ge=100,
        le=2000,
        description="Size of text chunks in characters"
    )

    # NOTE(review): le=500 permits CHUNK_OVERLAP >= CHUNK_SIZE when
    # CHUNK_SIZE is small (min 100) — downstream chunkers must tolerate that.
    CHUNK_OVERLAP: int = Field(
        default=int(os.getenv("CHUNK_OVERLAP", 50)),
        ge=0,
        le=500,
        description="Overlap between chunks in characters"
    )

    CHUNKER_TYPE: ChunkerType = Field(
        default=os.getenv("CHUNKER_TYPE", ChunkerType.TEXT),
        description="Type of chunker to use"
    )

    # Vector store settings
    VECTOR_STORE_TYPE: VectorStoreType = Field(
        default=os.getenv("VECTOR_STORE_TYPE", VectorStoreType.HYBRID_OFFLINE),
        description="Type of vector store to use"
    )

    FAISS_MODEL_NAME: str = Field(
        default=os.getenv("FAISS_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2"),
        description="SentenceTransformer model name for FAISS embeddings"
    )

    # Retrieval settings
    RETRIEVER_TYPE: RetrieverType = Field(
        default=os.getenv("RETRIEVER_TYPE", RetrieverType.SIMPLE),
        description="Type of retriever to use"
    )

    TOP_K: int = Field(
        default=int(os.getenv("TOP_K", 5)),
        ge=1,
        le=20,
        description="Number of results to retrieve"
    )

    FUZZY_MATCH_THRESHOLD: int = Field(
        default=int(os.getenv("FUZZY_MATCH_THRESHOLD", 80)),
        ge=0,
        le=100,
        description="Threshold for fuzzy matching (0-100)"
    )

    # Performance settings
    BATCH_SIZE: int = Field(
        default=int(os.getenv("BATCH_SIZE", 64)),
        ge=1,
        le=256,
        description="Batch size for processing documents"
    )

    # Env value arrives as a string; pydantic coerces it to int (or None when unset).
    MAX_DOCS_FOR_TESTING: Optional[int] = Field(
        default=os.getenv("MAX_DOCS_FOR_TESTING"),
        description="Limit documents for testing (None for all)"
    )

    # Server settings
    SERVER_HOST: str = Field(
        default=os.getenv("SERVER_HOST", "localhost"),
        description="Host to run the server on"
    )

    SERVER_PORT: int = Field(
        default=int(os.getenv("SERVER_PORT", 8000)),
        ge=1024,
        le=65535,
        description="Port to run the server on"
    )

    CORS_ORIGINS: List[str] = Field(
        default=os.getenv("CORS_ORIGINS", "*").split(","),
        description="List of CORS origins"
    )

    CORS_ALLOW_CREDENTIALS: bool = Field(
        default=os.getenv("CORS_ALLOW_CREDENTIALS", "True").lower() == "true",
        description="Whether to allow CORS credentials"
    )

    CORS_ALLOW_METHODS: List[str] = Field(
        default=os.getenv("CORS_ALLOW_METHODS", "GET,POST,PUT,DELETE,OPTIONS").split(","),
        description="List of allowed CORS methods"
    )

    CORS_ALLOW_HEADERS: List[str] = Field(
        default=os.getenv("CORS_ALLOW_HEADERS", "Content-Type,Authorization").split(","),
        description="List of allowed CORS headers"
    )

    # Advanced settings
    ENABLE_CACHE: bool = Field(
        default=os.getenv("ENABLE_CACHE", "True").lower() == "true",
        description="Whether to enable caching for vector stores"
    )

    # The package __init__ overwrites this on import with a per-user,
    # platform-appropriate directory; ".cache" is only the raw default.
    CACHE_DIR: str = Field(
        default=os.getenv("CACHE_DIR", ".cache"),
        description="Directory to store cache files"
    )

    LOG_LEVEL: str = Field(
        default=os.getenv("LOG_LEVEL", "INFO"),
        description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)"
    )

    # Custom components (for advanced users)
    CUSTOM_CHUNKER: Optional[str] = Field(
        default=os.getenv("CUSTOM_CHUNKER"),
        description="Import path to a custom chunker class"
    )

    CUSTOM_VECTOR_STORE: Optional[str] = Field(
        default=os.getenv("CUSTOM_VECTOR_STORE"),
        description="Import path to a custom vector store class"
    )

    CUSTOM_RETRIEVER: Optional[str] = Field(
        default=os.getenv("CUSTOM_RETRIEVER"),
        description="Import path to a custom retriever class"
    )

    CUSTOM_LLM: Optional[str] = Field(
        default=os.getenv("CUSTOM_LLM"),
        description="Import path to a custom LLM class"
    )

    # Pydantic v1-style inner config (shadows the outer Config name on
    # purpose — that's the pydantic convention, not a mistake).
    class Config:
        env_file = ".env"
        case_sensitive = False
        use_enum_values = True

    # NOTE(review): `validator` is the pydantic v1 API; with pydantic v2 /
    # pydantic-settings this goes through a deprecation shim.
    @validator('FALLBACK_MODELS', 'CORS_ORIGINS', 'CORS_ALLOW_METHODS', 'CORS_ALLOW_HEADERS', pre=True)
    def split_string(cls, v):
        # Accept either a comma-separated string (from the environment)
        # or an already-split list; always return a list of trimmed items.
        if isinstance(v, str):
            return [item.strip() for item in v.split(',')]
        return v

# Shared singleton used throughout the package (and mutated by __init__).
config = Config()
|
kssrag/core/__init__.py
ADDED
|
File without changes
|
kssrag/core/agents.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from typing import List, Dict, Any, Optional
|
|
2
|
+
from ..utils.helpers import logger
|
|
3
|
+
|
|
4
|
+
class RAGAgent:
    """Conversational RAG agent.

    Holds a retriever + LLM pair plus a rolling conversation history
    (capped at 10 messages, the system prompt always retained).
    """

    def __init__(self, retriever, llm, system_prompt: Optional[str] = None,
                 conversation_history: Optional[List[Dict[str, str]]] = None):
        self.retriever = retriever
        self.llm = llm
        self.conversation = conversation_history or []
        self.system_prompt = system_prompt or """You are a helpful AI assistant. Use the following context to answer the user's question.
        If you don't know the answer based on the context, say so."""

        # Seed the history with the system prompt unless one is already there.
        has_system = any(m.get("role") == "system" for m in self.conversation)
        if not has_system:
            self.add_message("system", self.system_prompt)

    def add_message(self, role: str, content: str):
        """Append a message, then trim history to the 10 most recent entries
        (the first system message, if any, is always kept)."""
        self.conversation.append({"role": role, "content": content})

        if len(self.conversation) <= 10:
            return

        # Separate the first system message from everything else.
        system_msg = None
        others = []
        for msg in self.conversation:
            if msg["role"] == "system":
                if system_msg is None:
                    system_msg = msg
            else:
                others.append(msg)

        if system_msg is not None:
            self.conversation = [system_msg] + others[-9:]
        else:
            self.conversation = others[-10:]

    def query(self, question: str, top_k: int = 5, include_context: bool = True) -> str:
        """Retrieve context for *question*, ask the LLM, and return its reply.

        Returns a canned apology string when retrieval finds nothing (with
        include_context=True) or when anything raises along the way.
        """
        try:
            docs = self.retriever.retrieve(question, top_k)

            if include_context and not docs:
                logger.warning(f"No context found for query: {question}")
                return "I couldn't find relevant information to answer your question."

            # Render retrieved documents into a single context header.
            context = ""
            if include_context and docs:
                pieces = ["Relevant information:\n"]
                for idx, doc in enumerate(docs, 1):
                    pieces.append(f"\n--- Document {idx} ---\n{doc['content']}\n")
                context = "".join(pieces)

            user_message = f"{context}\n\nQuestion: {question}" if context else question
            self.add_message("user", user_message)

            reply = self.llm.predict(self.conversation)
            self.add_message("assistant", reply)
            return reply

        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            return "I encountered an issue processing your query. Please try again."

    def clear_conversation(self):
        """Drop the whole history, keeping the first system message if present."""
        for msg in self.conversation:
            if msg["role"] == "system":
                self.conversation = [msg]
                return
        self.conversation = []
|
kssrag/core/chunkers.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
from typing import List, Dict, Any, Optional
|
|
4
|
+
import pypdf
|
|
5
|
+
from ..utils.helpers import logger
|
|
6
|
+
|
|
7
|
+
class BaseChunker:
    """Abstract base for document chunkers.

    Stores the shared chunking parameters; subclasses implement chunk().
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        # Target chunk length and inter-chunk overlap, both in characters.
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk(self, content: Any, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Split *content* into chunk dicts. Must be overridden."""
        raise NotImplementedError("Subclasses must implement this method")
|
|
16
|
+
|
|
17
|
+
class TextChunker(BaseChunker):
    """Chunker for plain text documents"""

    def chunk(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Split text into chunks of ``chunk_size`` chars with ``overlap``.

        Returns a list of ``{"content", "metadata"}`` dicts; each chunk's
        metadata is a copy of *metadata* plus a sequential ``chunk_id``.
        """
        if metadata is None:
            metadata = {}

        # BUGFIX: with overlap >= chunk_size the stride was <= 0 and the
        # loop never advanced (infinite loop). Clamp the stride to >= 1.
        step = max(1, self.chunk_size - self.overlap)

        chunks = []
        start = 0
        content_length = len(content)

        while start < content_length:
            end = min(start + self.chunk_size, content_length)

            chunk_metadata = metadata.copy()
            chunk_metadata["chunk_id"] = len(chunks)
            chunks.append({
                "content": content[start:end],
                "metadata": chunk_metadata
            })

            # BUGFIX: once the final character is covered, stop. Previously
            # the loop could keep emitting trailing sub-chunks that were
            # wholly contained in the previous chunk.
            if end == content_length:
                break
            start += step

        logger.info(f"Created {len(chunks)} chunks from text")
        return chunks
|
|
47
|
+
|
|
48
|
+
class JSONChunker(BaseChunker):
    """Chunker for JSON documents (like your drug data)"""

    def chunk(self, data: List[Dict[str, Any]], metadata_field: str = "name") -> List[Dict[str, Any]]:
        """Create one chunk per JSON record.

        Each record becomes a flat "Key: value" text rendering; records
        missing *metadata_field* are skipped. Returns chunk dicts with
        ``content`` and ``metadata`` (name + source URL) keys.
        """
        chunks = []

        for item in data:
            if metadata_field not in item:
                continue

            # Create a comprehensive text representation
            item_text = f"Item Name: {item.get(metadata_field, 'N/A')}\n"

            for key, value in item.items():
                # Falsy values (empty strings/lists, None, 0) are omitted.
                if key != metadata_field and value:
                    label = key.replace('_', ' ').title()
                    if isinstance(value, str):
                        item_text += f"{label}: {value}\n"
                    elif isinstance(value, list):
                        # BUGFIX: str() each element — joining a list that
                        # contains non-strings used to raise TypeError.
                        item_text += f"{label}: {', '.join(str(v) for v in value)}\n"

            chunks.append({
                "content": item_text,
                "metadata": {
                    "name": item[metadata_field],
                    "source": item.get('url', 'N/A')
                }
            })

        logger.info(f"Created {len(chunks)} chunks from JSON data")
        return chunks
|
|
79
|
+
|
|
80
|
+
class PDFChunker(TextChunker):
    """Chunker for PDF documents"""

    def extract_text(self, pdf_path: str) -> str:
        """Return the concatenated text of every page in *pdf_path*.

        Extraction failures are logged and re-raised.
        """
        try:
            with open(pdf_path, 'rb') as handle:
                reader = pypdf.PdfReader(handle)
                pages = [page.extract_text() + "\n" for page in reader.pages]
        except Exception as e:
            logger.error(f"Failed to extract text from PDF: {str(e)}")
            raise

        return "".join(pages)

    def chunk_pdf(self, pdf_path: str, metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Extract the PDF's text and delegate to TextChunker.chunk."""
        extracted = self.extract_text(pdf_path)
        return self.chunk(extracted, metadata)
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from typing import List, Dict, Any, Optional
|
|
2
|
+
from ..utils.helpers import logger
|
|
3
|
+
|
|
4
|
+
class BaseRetriever:
    """Abstract base for retrievers.

    Wraps a vector store; subclasses implement retrieve().
    """

    def __init__(self, vector_store):
        self.vector_store = vector_store

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to *top_k* documents relevant to *query*. Must be overridden."""
        raise NotImplementedError("Subclasses must implement this method")
|
|
13
|
+
|
|
14
|
+
class SimpleRetriever(BaseRetriever):
    """Retriever that delegates straight to the underlying vector store."""

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return the vector store's top *top_k* matches for *query*."""
        matches = self.vector_store.retrieve(query, top_k)
        return matches
|
|
19
|
+
|
|
20
|
+
class HybridRetriever(BaseRetriever):
    """Hybrid retriever with fuzzy matching for specific entities"""

    def __init__(self, vector_store, entity_names: Optional[List[str]] = None):
        super().__init__(vector_store)
        self.entity_names = entity_names or []

    def extract_entities(self, query: str) -> List[str]:
        """Extract known entity names mentioned in *query*.

        Case-insensitive substring matches are preferred; when none are
        found, fuzzy partial-ratio matches above 80 are used instead.
        """
        from rapidfuzz import process, fuzz

        query_lower = query.lower()

        # BUGFIX: compare lowercase-to-lowercase. Entity names containing
        # capitals (e.g. "Aspirin") could never substring-match the
        # lowercased query before.
        extracted_entities = [
            entity for entity in self.entity_names
            if entity.lower() in query_lower
        ]

        # Use fuzzy matching for partial matches / misspellings.
        if not extracted_entities and self.entity_names:
            matches = process.extract(query, self.entity_names, scorer=fuzz.partial_ratio, limit=5)
            extracted_entities = [match[0] for match in matches if match[1] > 80]

        return extracted_entities

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Retrieve *top_k* documents, boosting ones that mention query entities."""
        # Over-fetch so entity boosting has candidates to reorder.
        results = self.vector_store.retrieve(query, top_k * 2)

        # If we have entity names, boost documents that mention extracted entities
        if self.entity_names:
            extracted_entities = self.extract_entities(query)

            if extracted_entities:
                # BUGFIX: lowercase the entity terms before comparing against
                # the lowercased document content (same casing issue as above).
                entity_terms = [entity.lower() for entity in extracted_entities]

                scored_results = []
                for doc in results:
                    score = 1.0
                    content_lower = doc["content"].lower()

                    # Boost score if any extracted entity is mentioned
                    for term in entity_terms:
                        if term in content_lower:
                            score += 0.5
                            break

                    scored_results.append((doc, score))

                # Stable sort keeps the vector store's ranking within ties.
                scored_results.sort(key=lambda x: x[1], reverse=True)
                return [doc for doc, _ in scored_results[:top_k]]

        return results[:top_k]
|