mb-rag 1.0.123__tar.gz → 1.0.125__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mb-rag might be problematic.
- {mb_rag-1.0.123 → mb_rag-1.0.125}/PKG-INFO +1 -1
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/chatbot/basic.py +6 -5
- mb_rag-1.0.125/mb_rag/rag/embeddings.py +729 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/version.py +1 -1
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag.egg-info/PKG-INFO +1 -1
- mb_rag-1.0.123/mb_rag/rag/embeddings.py +0 -480
- {mb_rag-1.0.123 → mb_rag-1.0.125}/README.md +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/__init__.py +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/chatbot/__init__.py +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/chatbot/chains.py +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/chatbot/prompts.py +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/rag/__init__.py +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/utils/__init__.py +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/utils/bounding_box.py +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/utils/extra.py +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag.egg-info/SOURCES.txt +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag.egg-info/dependency_links.txt +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag.egg-info/requires.txt +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag.egg-info/top_level.txt +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/pyproject.toml +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/setup.cfg +0 -0
- {mb_rag-1.0.123 → mb_rag-1.0.125}/setup.py +0 -0
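The headline change in this release is a rewrite of mb_rag/rag/embeddings.py (480 lines removed, 729 added). The module-level helpers (check_package, get_rag_openai, get_rag_ollama, get_rag_anthropic, get_rag_google) move onto a ModelProvider class, load_rag_model is renamed load_embedding_model, and a new TextProcessor class takes over file loading and chunking. A minimal sketch of the new surface, assuming the package and the relevant provider API key are installed and configured (paths and values are illustrative, taken from the diff below):

    from mb_rag.rag.embeddings import embedding_generator, load_embedding_model

    # load_embedding_model replaces 1.0.123's load_rag_model
    model = load_embedding_model(model_name='openai', model_type='text-embedding-3-small')

    # The high-level class wires the model, vector store, and text processor together
    gen = embedding_generator(model="openai", model_type="text-embedding-3-small")
    gen.generate_text_embeddings(text_data_path=['./data.txt'], folder_save_path='./embeddings')
    retriever = gen.load_retriever('./embeddings')
    print(gen.query_embeddings("What is this about?"))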
{mb_rag-1.0.123 → mb_rag-1.0.125}/mb_rag/chatbot/basic.py

@@ -220,13 +220,14 @@ class ConversationModel:
         self.chatbot = ModelFactory(model_type, model_name, **kwargs)
 
     def initialize_conversation(self,
-
-
-
+                                question: Optional[str],
+                                context: Optional[str] = None,
+                                file_path: Optional[str]=None) -> None:
         """Initialize conversation state"""
         if file_path:
-            self.file_path = file_path
+            self.file_path = file_path
             self.load_conversation(file_path)
+
         else:
             if not question:
                 raise ValueError("Question is required.")
@@ -371,7 +372,7 @@ class ConversationModel:
             raise ValueError(f"Error loading conversation from s3: {e}")
 
     def _load_from_file(self, file_path: str) -> List[Any]:
-        """Load conversation from file"""
+        """Load conversation from file"""
         try:
             with open(file_path, 'r') as f:
                 lines = f.readlines()
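Both hunks above touch ConversationModel in mb_rag/chatbot/basic.py: the first reflows the initialize_conversation parameters (question, context, file_path as explicit Optionals) and adjusts indentation around self.file_path, while the second is a whitespace-only docstring fix in _load_from_file. A sketch of calling the new signature; the ConversationModel constructor arguments here are assumptions, not shown in this diff:

    from mb_rag.chatbot.basic import ConversationModel

    # Constructor args assumed from the ModelFactory(model_type, model_name, **kwargs)
    # call visible in the first hunk; the actual parameter names may differ.
    conv = ConversationModel("openai", "gpt-4o")
    # question is required unless file_path points at a saved conversation
    conv.initialize_conversation(question="What changed in 1.0.125?", context=None)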
mb_rag-1.0.125/mb_rag/rag/embeddings.py

@@ -0,0 +1,729 @@
+"""
+RAG (Retrieval-Augmented Generation) Embeddings Module
+
+This module provides functionality for generating and managing embeddings for RAG models.
+It supports multiple embedding models (OpenAI, Ollama, Google, Anthropic) and includes
+features for text processing, embedding generation, vector store management, and
+conversation handling.
+
+Example Usage:
+    ```python
+    # Initialize embedding generator
+    em_gen = embedding_generator(
+        model="openai",
+        model_type="text-embedding-3-small",
+        vector_store_type="chroma"
+    )
+
+    # Generate embeddings from text
+    em_gen.generate_text_embeddings(
+        text_data_path=['./data/text.txt'],
+        chunk_size=500,
+        chunk_overlap=5,
+        folder_save_path='./embeddings'
+    )
+
+    # Load embeddings and create retriever
+    em_loading = em_gen.load_embeddings('./embeddings')
+    em_retriever = em_gen.load_retriever(
+        './embeddings',
+        search_params=[{"k": 2, "score_threshold": 0.1}]
+    )
+
+    # Query embeddings
+    results = em_retriever.invoke("What is the text about?")
+
+    # Generate RAG chain for conversation
+    rag_chain = em_gen.generate_rag_chain(retriever=em_retriever)
+    response = em_gen.conversation_chain("Tell me more", rag_chain)
+    ```
+
+Features:
+    - Multiple model support (OpenAI, Ollama, Google, Anthropic)
+    - Text processing and chunking
+    - Embedding generation and storage
+    - Vector store management
+    - Retrieval operations
+    - Conversation chains
+    - Web crawling integration
+
+Classes:
+    - ModelProvider: Base class for model loading and validation
+    - TextProcessor: Handles text processing operations
+    - embedding_generator: Main class for RAG operations
+"""
+
+import os
+import shutil
+import importlib.util
+from typing import List, Dict, Optional, Union, Any
+from langchain.text_splitter import (
+    CharacterTextSplitter,
+    RecursiveCharacterTextSplitter,
+    SentenceTransformersTokenTextSplitter,
+    TokenTextSplitter)
+from langchain_community.document_loaders import TextLoader, FireCrawlLoader
+from langchain_community.vectorstores import Chroma
+from ..utils.extra import load_env_file
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
+load_env_file()
+
+__all__ = ['embedding_generator', 'load_embedding_model']
+
+class ModelProvider:
+    """
+    Base class for managing different model providers and their loading logic.
+
+    This class provides static methods for loading different types of embedding models
+    and checking package dependencies.
+
+    Methods:
+        check_package: Check if a Python package is installed
+        get_rag_openai: Load OpenAI embedding model
+        get_rag_ollama: Load Ollama embedding model
+        get_rag_anthropic: Load Anthropic model
+        get_rag_google: Load Google embedding model
+
+    Example:
+        ```python
+        # Check if a package is installed
+        has_openai = ModelProvider.check_package("langchain_openai")
+
+        # Load an OpenAI model
+        model = ModelProvider.get_rag_openai("text-embedding-3-small")
+        ```
+    """
+
+    @staticmethod
+    def check_package(package_name: str) -> bool:
+        """
+        Check if a Python package is installed.
+
+        Args:
+            package_name (str): Name of the package to check
+
+        """
+        return importlib.util.find_spec(package_name) is not None
+
+    @staticmethod
+    def get_rag_openai(model_type: str = 'text-embedding-3-small', **kwargs):
+        """
+        Load OpenAI embedding model.
+
+        Args:
+            model_type (str): Model identifier (default: 'text-embedding-3-small')
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            OpenAIEmbeddings: Initialized OpenAI embeddings model
+        """
+        if not ModelProvider.check_package("langchain_openai"):
+            raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
+        from langchain_openai import OpenAIEmbeddings
+        return OpenAIEmbeddings(model=model_type, **kwargs)
+
+    @staticmethod
+    def get_rag_ollama(model_type: str = 'llama3', **kwargs):
+        """
+        Load Ollama embedding model.
+
+        Args:
+            model_type (str): Model identifier (default: 'llama3')
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            OllamaEmbeddings: Initialized Ollama embeddings model
+        """
+        if not ModelProvider.check_package("langchain_ollama"):
+            raise ImportError("Ollama package not found. Please install: pip install langchain-ollama")
+        from langchain_ollama import OllamaEmbeddings
+        return OllamaEmbeddings(model=model_type, **kwargs)
+
+    @staticmethod
+    def get_rag_anthropic(model_name: str = "claude-3-opus-20240229", **kwargs):
+        """
+        Load Anthropic model.
+
+        Args:
+            model_name (str): Model identifier (default: "claude-3-opus-20240229")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            ChatAnthropic: Initialized Anthropic chat model
+
+        """
+        if not ModelProvider.check_package("langchain_anthropic"):
+            raise ImportError("Anthropic package not found. Please install: pip install langchain-anthropic")
+        from langchain_anthropic import ChatAnthropic
+        kwargs["model_name"] = model_name
+        return ChatAnthropic(**kwargs)
+
+    @staticmethod
+    def get_rag_google(model_name: str = "gemini-1.5-flash", **kwargs):
+        """
+        Load Google embedding model.
+
+        Args:
+            model_name (str): Model identifier (default: "gemini-1.5-flash")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            GoogleGenerativeAIEmbeddings: Initialized Google embeddings model
+        """
+        if not ModelProvider.check_package("google.generativeai"):
+            raise ImportError("Google Generative AI package not found. Please install: pip install langchain-google-genai")
+        from langchain_google_genai import GoogleGenerativeAIEmbeddings
+        kwargs["model"] = model_name
+        return GoogleGenerativeAIEmbeddings(**kwargs)
+
+def load_embedding_model(model_name: str = 'openai', model_type: str = "text-embedding-ada-002", **kwargs):
+    """
+    Load a RAG model based on provider and type.
+
+    Args:
+        model_name (str): Name of the model provider (default: 'openai')
+        model_type (str): Type/identifier of the model (default: "text-embedding-ada-002")
+        **kwargs: Additional arguments for model initialization
+
+    Returns:
+        Any: Initialized model instance
+
+    Example:
+        ```python
+        model = load_embedding_model('openai', 'text-embedding-3-small')
+        ```
+    """
+    try:
+        if model_name == 'openai':
+            return ModelProvider.get_rag_openai(model_type, **kwargs)
+        elif model_name == 'ollama':
+            return ModelProvider.get_rag_ollama(model_type, **kwargs)
+        elif model_name == 'google':
+            return ModelProvider.get_rag_google(model_type, **kwargs)
+        elif model_name == 'anthropic':
+            return ModelProvider.get_rag_anthropic(model_type, **kwargs)
+        else:
+            raise ValueError(f"Invalid model name: {model_name}")
+    except ImportError as e:
+        print(f"Error loading model: {str(e)}")
+        return None
+
+class TextProcessor:
+    """
+    Handles text processing operations including file checking and tokenization.
+
+    This class provides methods for loading text files, processing them into chunks,
+    and preparing them for embedding generation.
+
+    Args:
+        logger: Optional logger instance for logging operations
+
+    Example:
+        ```python
+        processor = TextProcessor()
+        docs = processor.tokenize(
+            ['./data.txt'],
+            'recursive_character',
+            chunk_size=1000,
+            chunk_overlap=5
+        )
+        ```
+    """
+
+    def __init__(self, logger=None):
+        self.logger = logger
+
+    def check_file(self, file_path: str) -> bool:
+        """Check if file exists."""
+        return os.path.exists(file_path)
+
+    def tokenize(self, text_data_path: List[str], text_splitter_type: str,
+                 chunk_size: int, chunk_overlap: int) -> List:
+        """
+        Process and tokenize text data from files.
+
+        Args:
+            text_data_path (List[str]): List of paths to text files
+            text_splitter_type (str): Type of text splitter to use
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+
+        Returns:
+            List: List of processed document chunks
+
+        """
+        doc_data = []
+        for path in text_data_path:
+            if self.check_file(path):
+                text_loader = TextLoader(path)
+                get_text = text_loader.load()
+                file_name = path.split('/')[-1]
+                metadata = {'source': file_name}
+                if metadata is not None:
+                    for doc in get_text:
+                        doc.metadata = metadata
+                        doc_data.append(doc)
+                if self.logger:
+                    self.logger.info(f"Text data loaded from {file_name}")
+            else:
+                return f"File {path} not found"
+
+        splitters = {
+            'character': CharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separator=["\n", "\n\n", "\n\n\n", " "]
+            ),
+            'recursive_character': RecursiveCharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separators=["\n", "\n\n", "\n\n\n", " "]
+            ),
+            'sentence_transformers_token': SentenceTransformersTokenTextSplitter(
+                chunk_size=chunk_size
+            ),
+            'token': TokenTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap
+            )
+        }
+
+        if text_splitter_type not in splitters:
+            raise ValueError(f"Invalid text splitter type: {text_splitter_type}")
+
+        text_splitter = splitters[text_splitter_type]
+        docs = text_splitter.split_documents(doc_data)
+
+        if self.logger:
+            self.logger.info(f"Text data splitted into {len(docs)} chunks")
+        else:
+            print(f"Text data splitted into {len(docs)} chunks")
+        return docs
+
+class embedding_generator:
+    """
+    Main class for generating embeddings and managing RAG operations.
+
+    This class provides comprehensive functionality for generating embeddings,
+    managing vector stores, handling retrievers, and managing conversations.
+
+    Args:
+        model (str): Model provider name (default: 'openai')
+        model_type (str): Model type/identifier (default: 'text-embedding-3-small')
+        vector_store_type (str): Type of vector store (default: 'chroma')
+        collection_name (str): Name of the collection (default: 'test')
+        logger: Optional logger instance
+        model_kwargs (dict): Additional arguments for model initialization
+        vector_store_kwargs (dict): Additional arguments for vector store initialization
+
+    Example:
+        ```python
+        # Initialize generator
+        gen = embedding_generator(
+            model="openai",
+            model_type="text-embedding-3-small"
+        )
+
+        # Generate embeddings
+        gen.generate_text_embeddings(
+            text_data_path=['./data.txt'],
+            folder_save_path='./embeddings'
+        )
+
+        # Load retriever
+        retriever = gen.load_retriever('./embeddings')
+
+        # Query embeddings
+        results = gen.query_embeddings("What is this about?")
+        ```
+    """
+
+    def __init__(self, model: str = 'openai', model_type: str = 'text-embedding-3-small',
+                 vector_store_type: str = 'chroma', collection_name: str = 'test',
+                 logger=None, model_kwargs: dict = None, vector_store_kwargs: dict = None) -> None:
+        """Initialize the embedding generator with specified configuration."""
+        self.logger = logger
+        self.model = load_embedding_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
+        if self.model is None:
+            raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
+        self.vector_store_type = vector_store_type
+        self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
+        self.collection_name = collection_name
+        self.text_processor = TextProcessor(logger)
+
+    def check_file(self, file_path: str) -> bool:
+        """Check if file exists."""
+        return self.text_processor.check_file(file_path)
+
+    def tokenize(self, text_data_path: List[str], text_splitter_type: str,
+                 chunk_size: int, chunk_overlap: int) -> List:
+        """Process and tokenize text data."""
+        return self.text_processor.tokenize(text_data_path, text_splitter_type,
+                                            chunk_size, chunk_overlap)
+
+    def generate_text_embeddings(self, text_data_path: List[str] = None,
+                                 text_splitter_type: str = 'recursive_character',
+                                 chunk_size: int = 1000, chunk_overlap: int = 5,
+                                 folder_save_path: str = './text_embeddings',
+                                 replace_existing: bool = False) -> str:
+        """
+        Generate text embeddings from input files.
+
+        Args:
+            text_data_path (List[str]): List of paths to text files
+            text_splitter_type (str): Type of text splitter
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+            folder_save_path (str): Path to save embeddings
+            replace_existing (bool): Whether to replace existing embeddings
+
+        Returns:
+            str: Status message
+
+        Example:
+            ```python
+            gen.generate_text_embeddings(
+                text_data_path=['./data.txt'],
+                folder_save_path='./embeddings'
+            )
+            ```
+        """
+        if self.logger:
+            self.logger.info("Performing basic checks")
+
+        if self.check_file(folder_save_path) and not replace_existing:
+            return "File already exists"
+        elif self.check_file(folder_save_path) and replace_existing:
+            shutil.rmtree(folder_save_path)
+
+        if text_data_path is None:
+            return "Please provide text data path"
+
+        if not isinstance(text_data_path, list):
+            raise ValueError("text_data_path should be a list")
+
+        if self.logger:
+            self.logger.info(f"Loading text data from {text_data_path}")
+
+        docs = self.tokenize(text_data_path, text_splitter_type, chunk_size, chunk_overlap)
+
+        if self.logger:
+            self.logger.info(f"Generating embeddings for {len(docs)} documents")
+
+        self.vector_store.from_documents(docs, self.model, collection_name=self.collection_name,
+                                         persist_directory=folder_save_path)
+
+        if self.logger:
+            self.logger.info(f"Embeddings generated and saved at {folder_save_path}")
+
+    def load_vectorstore(self, **kwargs):
+        """Load vector store."""
+        if self.vector_store_type == 'chroma':
+            vector_store = Chroma()
+            if self.logger:
+                self.logger.info(f"Loaded vector store {self.vector_store_type}")
+            return vector_store
+        else:
+            return "Vector store not found"
+
+    def load_embeddings(self, embeddings_folder_path: str):
+        """
+        Load embeddings from folder.
+
+        Args:
+            embeddings_folder_path (str): Path to embeddings folder
+
+        Returns:
+            Optional[Chroma]: Loaded vector store or None if not found
+        """
+        if self.check_file(embeddings_folder_path):
+            if self.vector_store_type == 'chroma':
+                return Chroma(persist_directory=embeddings_folder_path,
+                              embedding_function=self.model)
+        else:
+            if self.logger:
+                self.logger.info("Embeddings file not found")
+            return None
+
+    def load_retriever(self, embeddings_folder_path: str,
+                       search_type: List[str] = ["similarity_score_threshold"],
+                       search_params: List[Dict] = [{"k": 3, "score_threshold": 0.9}]):
+        """
+        Load retriever with search configuration.
+
+        Args:
+            embeddings_folder_path (str): Path to embeddings folder
+            search_type (List[str]): List of search types
+            search_params (List[Dict]): List of search parameters
+
+        Returns:
+            Union[Any, List[Any]]: Single retriever or list of retrievers
+
+        Example:
+            ```python
+            retriever = gen.load_retriever(
+                './embeddings',
+                search_type=["similarity_score_threshold"],
+                search_params=[{"k": 3, "score_threshold": 0.9}]
+            )
+            ```
+        """
+        db = self.load_embeddings(embeddings_folder_path)
+        if db is not None:
+            if self.vector_store_type == 'chroma':
+                if len(search_type) != len(search_params):
+                    raise ValueError("Length of search_type and search_params should be equal")
+                if len(search_type) == 1:
+                    self.retriever = db.as_retriever(search_type=search_type[0],
+                                                     search_kwargs=search_params[0])
+                    if self.logger:
+                        self.logger.info("Retriever loaded")
+                    return self.retriever
+                else:
+                    retriever_list = []
+                    for i in range(len(search_type)):
+                        retriever_list.append(db.as_retriever(search_type=search_type[i],
+                                                              search_kwargs=search_params[i]))
+                    if self.logger:
+                        self.logger.info("List of Retriever loaded")
+                    return retriever_list
+        else:
+            return "Embeddings file not found"
+
+    def add_data(self, embeddings_folder_path: str, data: List[str],
+                 text_splitter_type: str = 'recursive_character',
+                 chunk_size: int = 1000, chunk_overlap: int = 5):
+        """
+        Add data to existing embeddings.
+
+        Args:
+            embeddings_folder_path (str): Path to embeddings folder
+            data (List[str]): List of text data to add
+            text_splitter_type (str): Type of text splitter
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+        """
+        if self.vector_store_type == 'chroma':
+            db = self.load_embeddings(embeddings_folder_path)
+            if db is not None:
+                docs = self.tokenize(data, text_splitter_type, chunk_size, chunk_overlap)
+                db.add_documents(docs)
+                if self.logger:
+                    self.logger.info("Data added to the existing db/embeddings")
+
+    def query_embeddings(self, query: str, retriever=None):
+        """
+        Query embeddings.
+
+        Args:
+            query (str): Query string
+            retriever: Optional retriever instance
+
+        Returns:
+            Any: Query results
+        """
+        if retriever is None:
+            retriever = self.retriever
+        return retriever.invoke(query)
+
+    def get_relevant_documents(self, query: str, retriever=None):
+        """
+        Get relevant documents for query.
+
+        Args:
+            query (str): Query string
+            retriever: Optional retriever instance
+
+        Returns:
+            List: List of relevant documents
+        """
+        if retriever is None:
+            retriever = self.retriever
+        return retriever.get_relevant_documents(query)
+
+    def generate_rag_chain(self, context_prompt: str = None, retriever=None, llm=None):
+        """
+        Generate RAG chain for conversation.
+
+        Args:
+            context_prompt (str): Optional context prompt
+            retriever: Optional retriever instance
+            llm: Optional language model instance
+
+        Returns:
+            Any: Generated RAG chain
+
+        Example:
+            ```python
+            rag_chain = gen.generate_rag_chain(retriever=retriever)
+            ```
+        """
+        if context_prompt is None:
+            context_prompt = ("You are an assistant for question-answering tasks. "
+                              "Use the following pieces of retrieved context to answer the question. "
+                              "If you don't know the answer, just say that you don't know. "
+                              "Use three sentences maximum and keep the answer concise.\n\n{context}")
+
+        contextualize_q_system_prompt = ("Given a chat history and the latest user question "
+                                         "which might reference context in the chat history, "
+                                         "formulate a standalone question which can be understood "
+                                         "without the chat history. Do NOT answer the question, "
+                                         "just reformulate it if needed and otherwise return it as is.")
+
+        contextualize_q_prompt = ChatPromptTemplate.from_messages([
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ])
+
+        if retriever is None:
+            retriever = self.retriever
+        if llm is None:
+            if not ModelProvider.check_package("langchain_openai"):
+                raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
+            from langchain_openai import ChatOpenAI
+            llm = ChatOpenAI(model="gpt-4")
+
+        history_aware_retriever = create_history_aware_retriever(llm, retriever,
+                                                                 contextualize_q_prompt)
+        qa_prompt = ChatPromptTemplate.from_messages([
+            ("system", context_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ])
+        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+        return rag_chain
+
+    def conversation_chain(self, query: str, rag_chain, file: str = None):
+        """
+        Create conversation chain.
+
+        Args:
+            query (str): User query
+            rag_chain: RAG chain instance
+            file (str): Optional file to save conversation
+
+        Returns:
+            List: Conversation history
+
+        Example:
+            ```python
+            history = gen.conversation_chain(
+                "Tell me about...",
+                rag_chain,
+                file='conversation.txt'
+            )
+            ```
+        """
+        if file is not None:
+            try:
+                chat_history = self.load_conversation(file, list_type=True)
+                if len(chat_history) == 0:
+                    chat_history = []
+            except:
+                chat_history = []
+        else:
+            chat_history = []
+
+        query = "You : " + query
+        res = rag_chain.invoke({"input": query, "chat_history": chat_history})
+        print(f"Response: {res['answer']}")
+        chat_history.append(HumanMessage(content=query))
+        chat_history.append(SystemMessage(content=res['answer']))
+        if file is not None:
+            self.save_conversation(chat_history, file)
+        return chat_history
+
+    def load_conversation(self, file: str, list_type: bool = False):
+        """
+        Load conversation history.
+
+        Args:
+            file (str): Path to conversation file
+            list_type (bool): Whether to return as list
+
+        Returns:
+            Union[str, List]: Conversation history
+        """
+        if list_type:
+            chat_history = []
+            with open(file, 'r') as f:
+                for line in f:
+                    chat_history.append(line.strip())
+        else:
+            with open(file, "r") as f:
+                chat_history = f.read()
+        return chat_history
+
+    def save_conversation(self, chat: Union[str, List], file: str):
+        """
+        Save conversation history.
+
+        Args:
+            chat (Union[str, List]): Conversation to save
+            file (str): Path to save file
+        """
+        if isinstance(chat, str):
+            with open(file, "a") as f:
+                f.write(chat)
+        elif isinstance(chat, list):
+            with open(file, "a") as f:
+                for i in chat[-2:]:
+                    f.write("%s\n" % i)
+        print(f"Saved file : {file}")
+
+    def firecrawl_web(self, website: str, api_key: str = None, mode: str = "scrape",
+                      file_to_save: str = './firecrawl_embeddings', **kwargs):
+        """
+        Get data from website using FireCrawl.
+
+        Args:
+            website (str): Website URL to crawl
+            api_key (str): Optional FireCrawl API key
+            mode (str): Crawl mode (default: "scrape")
+            file_to_save (str): Path to save embeddings
+            **kwargs: Additional arguments for FireCrawl
+
+        Returns:
+            Chroma: Vector store with crawled data
+
+        Example:
+            ```python
+            db = gen.firecrawl_web(
+                "https://example.com",
+                mode="scrape",
+                file_to_save='./crawl_embeddings'
+            )
+            ```
+        """
+        if not ModelProvider.check_package("firecrawl"):
+            raise ImportError("Firecrawl package not found. Please install: pip install firecrawl")
+
+        if api_key is None:
+            api_key = os.getenv("FIRECRAWL_API_KEY")
+
+        loader = FireCrawlLoader(api_key=api_key, url=website, mode=mode)
+        docs = loader.load()
+
+        for doc in docs:
+            for key, value in doc.metadata.items():
+                if isinstance(value, list):
+                    doc.metadata[key] = ", ".join(map(str, value))
+
+        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+        split_docs = text_splitter.split_documents(docs)
+
+        print("\n--- Document Chunks Information ---")
+        print(f"Number of document chunks: {len(split_docs)}")
+        print(f"Sample chunk:\n{split_docs[0].page_content}\n")
+
+        embeddings = self.model
+        db = Chroma.from_documents(split_docs, embeddings,
+                                   persist_directory=file_to_save)
+        print(f"Retriever saved at {file_to_save}")
+        return db
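One behavioral detail in the new TextProcessor.tokenize shown above: splitter selection is now a dict dispatch followed by an explicit membership check, so an unknown text_splitter_type raises a ValueError instead of leaving text_splitter unbound (which produced an UnboundLocalError in the old if-chain removed below). A caller-side sketch, with a placeholder file path:

    from mb_rag.rag.embeddings import TextProcessor

    processor = TextProcessor()
    docs = processor.tokenize(['./notes.txt'], 'token', chunk_size=500, chunk_overlap=5)
    # processor.tokenize(['./notes.txt'], 'bogus', 500, 5)
    # -> ValueError: Invalid text splitter type: bogus

Note that the dict builds all four splitters eagerly on every call, so SentenceTransformersTokenTextSplitter is instantiated even when another type is requested.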
mb_rag-1.0.123/mb_rag/rag/embeddings.py

@@ -1,480 +0,0 @@
-## Function to generate embeddings for the RAG model
-
-import os
-import shutil
-import importlib.util
-from langchain.text_splitter import (
-    CharacterTextSplitter,
-    RecursiveCharacterTextSplitter,
-    SentenceTransformersTokenTextSplitter,
-    TokenTextSplitter)
-from langchain_community.document_loaders import TextLoader,FireCrawlLoader
-from langchain_community.vectorstores import Chroma
-from ..utils.extra import load_env_file
-from langchain.chains import create_history_aware_retriever, create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-import time
-
-load_env_file()
-
-__all__ = ['embedding_generator', 'load_rag_model']
-
-def check_package(package_name):
-    """
-    Check if a package is installed
-    Args:
-        package_name (str): Name of the package
-    Returns:
-        bool: True if package is installed, False otherwise
-    """
-    return importlib.util.find_spec(package_name) is not None
-
-def get_rag_openai(model_type: str = 'text-embedding-3-small',**kwargs):
-    """
-    Load model from openai for RAG
-    Args:
-        model_type (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatOpenAI: Chatbot model
-    """
-    if not check_package("langchain_openai"):
-        raise ImportError("OpenAI package not found. Please install it using: pip install langchain-openai")
-
-    from langchain_openai import OpenAIEmbeddings
-    return OpenAIEmbeddings(model = model_type,**kwargs)
-
-def get_rag_ollama(model_type: str = 'llama3',**kwargs):
-    """
-    Load model from ollama for RAG
-    Args:
-        model_type (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        OllamaEmbeddings: Embeddings model
-    """
-    if not check_package("langchain_ollama"):
-        raise ImportError("Ollama package not found. Please install it using: pip install langchain-ollama")
-
-    from langchain_ollama import OllamaEmbeddings
-    return OllamaEmbeddings(model = model_type,**kwargs)
-
-def get_rag_anthropic(model_name: str = "claude-3-opus-20240229",**kwargs):
-    """
-    Load the chatbot model from Anthropic
-    Args:
-        model_name (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatAnthropic: Chatbot model
-    """
-    if not check_package("langchain_anthropic"):
-        raise ImportError("Anthropic package not found. Please install it using: pip install langchain-anthropic")
-
-    from langchain_anthropic import ChatAnthropic
-    kwargs["model_name"] = model_name
-    return ChatAnthropic(**kwargs)
-
-def get_rag_google(model_name: str = "gemini-1.5-flash",**kwargs):
-    """
-    Load the chatbot model from Google Generative AI
-    Args:
-        model_name (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatGoogleGenerativeAI: Chatbot model
-    """
-    if not check_package("google.generativeai"):
-        raise ImportError("Google Generative AI package not found. Please install it using: pip install langchain-google-genai")
-
-    from langchain_google_genai import GoogleGenerativeAIEmbeddings
-    kwargs["model"] = model_name
-    return GoogleGenerativeAIEmbeddings(**kwargs)
-
-def load_rag_model(model_name: str ='openai', model_type: str = "text-embedding-ada-002", **kwargs):
-    """
-    Load a RAG model from a given model name and type
-    Args:
-        model_name (str): Name of the model. Default is openai.
-        model_type (str): Type of the model. Default is text-embedding-ada-002.
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        RAGModel: RAG model
-    """
-    try:
-        if model_name == 'openai':
-            return get_rag_openai(model_type, **(kwargs or {}))
-        elif model_name == 'ollama':
-            return get_rag_ollama(model_type, **(kwargs or {}))
-        elif model_name == 'google':
-            return get_rag_google(model_type, **(kwargs or {}))
-        elif model_name == 'anthropic':
-            return get_rag_anthropic(model_type, **(kwargs or {}))
-        else:
-            raise ValueError(f"Invalid model name: {model_name}")
-    except ImportError as e:
-        print(f"Error loading model: {str(e)}")
-        return None
-
-class embedding_generator:
-    """
-    Class to generate embeddings for the RAG model abnd chat with data
-    Args:
-        model: type of model. Default is openai. Options are openai, anthropic, google, ollama
-        model_type: type of model. Default is text-embedding-3-small. Options are text-embedding-3-small, text-embedding-3-large, text-embedding-ada-002 for openai.
-        vector_store_type: type of vector store. Default is chroma
-        logger: logger
-        model_kwargs: additional arguments for the model
-        vector_store_kwargs: additional arguments for the vector store
-        collection_name: name of the collection (default : test)
-    """
-
-    def __init__(self,model: str = 'openai',model_type: str = 'text-embedding-3-small',vector_store_type:str = 'chroma' ,collection_name: str = 'test',logger= None,model_kwargs: dict = None, vector_store_kwargs: dict = None) -> None:
-        self.logger = logger
-        self.model = load_rag_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
-        if self.model is None:
-            raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
-        self.vector_store_type = vector_store_type
-        self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
-        self.collection_name = collection_name
-
-    def check_file(self, file_path):
-        """
-        Check if the file exists
-        """
-        if os.path.exists(file_path):
-            return True
-        else:
-            return False
-
-    def tokenize(self,text_data_path :list,text_splitter_type: str,chunk_size: int,chunk_overlap: int):
-        """
-        Function to tokenize the text
-        Args:
-            text: text to tokenize
-        Returns:
-            tokens
-        """
-        doc_data = []
-        for i in text_data_path:
-            if self.check_file(i):
-                text_loader = TextLoader(i)
-                get_text = text_loader.load()
-                # print(get_text) ## testing - Need to remove
-                file_name = i.split('/')[-1]
-                metadata = {'source': file_name}
-                if metadata is not None:
-                    for j in get_text:
-                        j.metadata = metadata
-                        doc_data.append(j)
-                if self.logger is not None:
-                    self.logger.info(f"Text data loaded from {file_name}")
-            else:
-                return f"File {i} not found"
-
-        if self.logger is not None:
-            self.logger.info(f"Splitting text data into chunks of size {chunk_size} with overlap {chunk_overlap}")
-        if text_splitter_type == 'character':
-            text_splitter = CharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=chunk_overlap, separator=["\n","\n\n","\n\n\n"," "])
-        if text_splitter_type == 'recursive_character':
-            text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=chunk_overlap,separators=["\n","\n\n","\n\n\n"," "])
-        if text_splitter_type == 'sentence_transformers_token':
-            text_splitter = SentenceTransformersTokenTextSplitter(chunk_size=chunk_size)
-        if text_splitter_type == 'token':
-            text_splitter = TokenTextSplitter(chunk_size=chunk_size,chunk_overlap=chunk_overlap)
-        docs = text_splitter.split_documents(doc_data)
-        if self.logger is not None:
-            self.logger.info(f"Text data splitted into {len(docs)} chunks")
-        else:
-            print(f"Text data splitted into {len(docs)} chunks")
-        return docs
-
-    def generate_text_embeddings(self,text_data_path: list = None,text_splitter_type: str = 'recursive_character',
-                                 chunk_size: int = 1000,chunk_overlap: int = 5,folder_save_path: str = './text_embeddings',
-                                 replace_existing: bool = False):
-        """
-        Function to generate text embeddings
-        Args:
-            text_data_path: list of text files
-            # metadata: list of metadata for each text file. Dictionary format
-            text_splitter_type: type of text splitter. Default is recursive_character
-            chunk_size: size of the chunk
-            chunk_overlap: overlap between chunks
-            folder_save_path: path to save the embeddings
-            replace_existing: if True, replace the existing embeddings
-        Returns:
-            None
-        """
-
-        if self.logger is not None:
-            self.logger.info("Perforing basic checks")
-
-        if self.check_file(folder_save_path) and replace_existing==False:
-            return "File already exists"
-        elif self.check_file(folder_save_path) and replace_existing:
-            shutil.rmtree(folder_save_path)
-
-        if text_data_path is None:
-            return "Please provide text data path"
-
-        assert isinstance(text_data_path, list), "text_data_path should be a list"
-        # if metadata is not None:
-        #     assert isinstance(metadata, list), "metadata should be a list"
-        #     assert len(text_data_path) == len(metadata), "Number of text files and metadata should be equal"
-
-        if self.logger is not None:
-            self.logger.info(f"Loading text data from {text_data_path}")
-
-        docs = self.tokenize(text_data_path,text_splitter_type,chunk_size,chunk_overlap)
-
-        if self.logger is not None:
-            self.logger.info(f"Generating embeddings for {len(docs)} documents")
-
-        self.vector_store.from_documents(docs, self.model,collection_name=self.collection_name,persist_directory=folder_save_path)
-
-        if self.logger is not None:
-            self.logger.info(f"Embeddings generated and saved at {folder_save_path}")
-
-    def load_vectorstore(self):
-        """
-        Function to load vector store
-        Args:
-            vector_store_type: type of vector store
-        Returns:
-            vector store
-        """
-        if self.vector_store_type == 'chroma':
-            vector_store = Chroma()
-            if self.logger is not None:
-                self.logger.info(f"Loaded vector store {self.vector_store_type}")
-            return vector_store
-        else:
-            return "Vector store not found"
-
-    def load_embeddings(self,embeddings_folder_path: str):
-        """
-        Function to load embeddings from the folder
-        Args:
-            embeddings_path: path to the embeddings
-        Returns:
-            embeddings
-        """
-        if self.check_file(embeddings_folder_path):
-            if self.vector_store_type == 'chroma':
-                # embeddings_path = os.path.join(embeddings_folder_path)
-                return Chroma(persist_directory = embeddings_folder_path,embedding_function=self.model)
-        else:
-            if self.logger:
-                self.logger.info("Embeddings file not found")
-            return None
-
-    def load_retriever(self,embeddings_folder_path: str,search_type: list = ["similarity_score_threshold"],search_params: list = [{"k": 3, "score_threshold": 0.9}]):
-        """
-        Function to load retriever
-        Args:
-            embeddings_path: path to the embeddings
-            search_type: list of str: type of search. Default : ["similarity_score_threshold"]
-            search_params: list of dict: parameters for the search. Default : [{"k": 3, "score_threshold": 0.9}]
-        Returns:
-            Retriever. If multiple search types are provided, a list of retrievers is returned
-        """
-        db = self.load_embeddings(embeddings_folder_path)
-        if db is not None:
-            if self.vector_store_type == 'chroma':
-                assert len(search_type) == len(search_params), "Length of search_type and search_params should be equal"
-                if len(search_type) == 1:
-                    self.retriever = db.as_retriever(search_type = search_type[0],search_kwargs=search_params[0])
-                    if self.logger:
-                        self.logger.info("Retriever loaded")
-                    return self.retriever
-                else:
-                    retriever_list = []
-                    for i in range(len(search_type)):
-                        retriever_list.append(db.as_retriever(search_type = search_type[i],search_kwargs=search_params[i]))
-                    if self.logger:
-                        self.logger.info("List of Retriever loaded")
-                    return retriever_list
-        else:
-            return "Embeddings file not found"
-
-    def add_data(self,embeddings_folder_path: str, data: list,text_splitter_type: str = 'recursive_character',
-                 chunk_size: int = 1000,chunk_overlap: int = 5):
-        """
-        Function to add data to the existing db/embeddings
-        Args:
-            embeddings_path: path to the embeddings
-            data: list of data to add
-            text_splitter_type: type of text splitter. Default is recursive_character
-            chunk_size: size of the chunk
-            chunk_overlap: overlap between chunks
-        Returns:
-            None
-        """
-        if self.vector_store_type == 'chroma':
-            db = self.load_embeddings(embeddings_folder_path)
-            if db is not None:
-                docs = self.tokenize(data,text_splitter_type,chunk_size,chunk_overlap)
-                db.add_documents(docs)
-                if self.logger:
-                    self.logger.info("Data added to the existing db/embeddings")
-
-    def query_embeddings(self,query: str,retriever = None):
-        """
-        Function to query embeddings
-        Args:
-            search_type: type of search
-            query: query to search
-        Returns:
-            results
-        """
-        # if self.vector_store_type == 'chroma':
-        if retriever is None:
-            retriever = self.retriever
-        return retriever.invoke(query)
-        # else:
-        #     return "Vector store not found"
-
-    def get_relevant_documents(self,query: str,retriever = None):
-        """
-        Function to get relevant documents
-        Args:
-            query: query to search
-        Returns:
-            results
-        """
-        return self.retriever.get_relevant_documents(query)
-
-    def generate_rag_chain(self,context_prompt: str = None,retriever = None,llm= None):
-        """
-        Function to start a conversation chain with a rag data. Call this to load a rag_chain module.
-        Args:
-            context_prompt: prompt to context
-            retriever: retriever. Default is None.
-            llm: language model. Default is openai. Need chat model llm. "ChatOpenAI", "ChatAnthropic" etc. like chatbot
-        Returns:
-            rag_chain_model.
-        """
-        if context_prompt is None:
-            context_prompt = ("You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. "
-                              "If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. "
-                              "\n\n {context}")
-        contextualize_q_system_prompt = ("Given a chat history and the latest user question which might reference context in the chat history, formulate a standalone question which can be understood "
-                                         "without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is.")
-        contextualize_q_prompt = ChatPromptTemplate.from_messages([("system", contextualize_q_system_prompt),MessagesPlaceholder("chat_history"),("human", "{input}"),])
-
-        if retriever is None:
-            retriever = self.retriever
-        if llm is None:
-            if not check_package("langchain_openai"):
-                raise ImportError("OpenAI package not found. Please install it using: pip install langchain-openai")
-            from langchain_openai import ChatOpenAI
-            llm = ChatOpenAI(model="gpt-4")
-
-        history_aware_retriever = create_history_aware_retriever(llm,retriever, contextualize_q_prompt)
-        qa_prompt = ChatPromptTemplate.from_messages([("system", context_prompt),MessagesPlaceholder("chat_history"),("human", "{input}"),])
-        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-        return rag_chain
-
-    def conversation_chain(self,query: str,rag_chain,file:str =None):
-        """
-        Function to create a conversation chain
-        Args:
-            query: query to search
-            rag_chain : rag_chain model
-            file: load a file and update it with the conversation. If None it will not be saved.
-        Returns:
-            results
-        """
-        if file is not None:
-            try:
-                chat_history = self.load_conversation(file,list_type=True)
-                if len(chat_history) == 0:
-                    chat_history = []
-            except:
-                chat_history = []
-        else:
-            chat_history = []
-        query = "You : " + query
-        res = rag_chain.invoke({"input": query,"chat_history": chat_history})
-        print(f"Response: {res['answer']}")
-        chat_history.append(HumanMessage(content=query))
-        chat_history.append(SystemMessage(content=res['answer']))
-        if file is not None:
-            self.save_conversation(chat_history,file)
-        return chat_history
-
-    def load_conversation(self,file: str,list_type: bool = False):
-        """
-        Function to load the conversation
-        Args:
-            file: file to load
-            list_type: if True, return the chat_history as a list. Default is False.
-        Returns:
-            chat_history
-        """
-        if list_type:
-            chat_history = []
-            with open(file,'r') as f:
-                for line in f:
-                    # inner_list = [elt.strip() for elt in line.split(',')]
-                    chat_history.append(line.strip())
-        else:
-            with open(file, "r") as f:
-                chat_history = f.read()
-        return chat_history
-
-    def save_conversation(self,chat: str,file: str):
-        """
-        Function to save the conversation
-        Args:
-            chat: chat results
-            file: file to save
-        Returns:
-            None
-        """
-        if isinstance(chat,str):
-            with open(file, "a") as f:
-                f.write(chat)
-        elif isinstance(chat,list):
-            with open(file, "a") as f:
-                for i in chat[-2:]:
-                    f.write("%s\n" % i)
-        print(f"Saved file : {file}")
-
-    def firecrawl_web(self, website, api_key: str = None, mode="scrape", file_to_save: str = './firecrawl_embeddings',**kwargs):
-        """
-        Function to get data from website. Use this to get data from a website and save it as embeddings/retriever. To ask questions from the website,
-        use the load_retriever and query_embeddings function.
-        Args:
-            website : str - link to website.
-            api_key : api key of firecrawl, if None environment variable "FIRECRAWL_API_KEY" will be used.
-            mode(str) : 'scrape' default to just use the same page. Not the whole website.
-            file_to_save: path to save the embeddings
-            **kwargs: additional arguments
-        Returns:
-            retriever
-        """
-        if not check_package("firecrawl"):
-            raise ImportError("Firecrawl package not found. Please install it using: pip install firecrawl")
-
-        if api_key is None:
-            api_key = os.getenv("FIRECRAWL_API_KEY")
-        loader = FireCrawlLoader(api_key=api_key, url=website, mode=mode)
-        docs = loader.load()
-        for doc in docs:
-            for key, value in doc.metadata.items():
-                if isinstance(value, list):
-                    doc.metadata[key] = ", ".join(map(str, value))
-        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-        split_docs = text_splitter.split_documents(docs)
-        print("\n--- Document Chunks Information ---")
-        print(f"Number of document chunks: {len(split_docs)}")
-        print(f"Sample chunk:\n{split_docs[0].page_content}\n")
-        embeddings = self.model
-        db = Chroma.from_documents(
-            split_docs, embeddings, persist_directory=file_to_save)
-        print(f"Retriever saved at {file_to_save}")
-        return db
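Taken together, the addition and removal above amount to an API reorganization more than a behavior change: the same operations exist in 1.0.125, but under new entry points. A hedged migration sketch for downstream callers (names taken from the two hunks; everything else is illustrative):

    # 1.0.123 (removed):
    #   from mb_rag.rag.embeddings import load_rag_model, check_package
    #   model = load_rag_model('openai', 'text-embedding-3-small')

    # 1.0.125 (added):
    from mb_rag.rag.embeddings import load_embedding_model, ModelProvider

    if ModelProvider.check_package("langchain_openai"):
        model = load_embedding_model('openai', 'text-embedding-3-small')

Note that __all__ now exports only embedding_generator and load_embedding_model; ModelProvider and TextProcessor remain importable by name but are excluded from star-imports. Also gone in 1.0.125: the unused `import time` and the stray testing comments.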