mb-rag 1.0.124__py3-none-any.whl → 1.0.126__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mb-rag might be problematic.

mb_rag/rag/embeddings.py CHANGED
@@ -1,480 +1,728 @@
-## Function to generate embeddings for the RAG model
+"""
+RAG (Retrieval-Augmented Generation) Embeddings Module
+
+This module provides functionality for generating and managing embeddings for RAG models.
+It supports multiple embedding models (OpenAI, Ollama, Google, Anthropic) and includes
+features for text processing, embedding generation, vector store management, and
+conversation handling.
+
+Example Usage:
+    ```python
+    # Initialize embedding generator
+    em_gen = embedding_generator(
+        model="openai",
+        model_type="text-embedding-3-small",
+        vector_store_type="chroma"
+    )
+
+    # Generate embeddings from text
+    em_gen.generate_text_embeddings(
+        text_data_path=['./data/text.txt'],
+        chunk_size=500,
+        chunk_overlap=5,
+        folder_save_path='./embeddings'
+    )
+
+    # Load embeddings and create retriever
+    em_loading = em_gen.load_embeddings('./embeddings')
+    em_retriever = em_gen.load_retriever(
+        './embeddings',
+        search_params=[{"k": 2, "score_threshold": 0.1}]
+    )
+
+    # Query embeddings
+    results = em_retriever.invoke("What is the text about?")
+
+    # Generate RAG chain for conversation
+    rag_chain = em_gen.generate_rag_chain(retriever=em_retriever)
+    response = em_gen.conversation_chain("Tell me more", rag_chain)
+    ```
+
+Features:
+    - Multiple model support (OpenAI, Ollama, Google, Anthropic)
+    - Text processing and chunking
+    - Embedding generation and storage
+    - Vector store management
+    - Retrieval operations
+    - Conversation chains
+    - Web crawling integration
+
+Classes:
+    - ModelProvider: Base class for model loading and validation
+    - TextProcessor: Handles text processing operations
+    - embedding_generator: Main class for RAG operations
+"""
 
 import os
 import shutil
 import importlib.util
+from typing import List, Dict, Optional, Union, Any
 from langchain.text_splitter import (
     CharacterTextSplitter,
     RecursiveCharacterTextSplitter,
     SentenceTransformersTokenTextSplitter,
     TokenTextSplitter)
-from langchain_community.document_loaders import TextLoader,FireCrawlLoader
-from langchain_community.vectorstores import Chroma
-from ..utils.extra import load_env_file
+from langchain_community.document_loaders import TextLoader, FireCrawlLoader
+from langchain_chroma import Chroma
+from ..utils.extra import load_env_file
 from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-import time
 
 load_env_file()
 
-__all__ = ['embedding_generator', 'load_rag_model']
+__all__ = ['embedding_generator', 'load_embedding_model']
 
-def check_package(package_name):
+class ModelProvider:
     """
-    Check if a package is installed
-    Args:
-        package_name (str): Name of the package
-    Returns:
-        bool: True if package is installed, False otherwise
+    Base class for managing different model providers and their loading logic.
+
+    This class provides static methods for loading different types of embedding models
+    and checking package dependencies.
+
+    Methods:
+        check_package: Check if a Python package is installed
+        get_rag_openai: Load OpenAI embedding model
+        get_rag_ollama: Load Ollama embedding model
+        get_rag_anthropic: Load Anthropic model
+        get_rag_google: Load Google embedding model
+
+    Example:
+        ```python
+        # Check if a package is installed
+        has_openai = ModelProvider.check_package("langchain_openai")
+
+        # Load an OpenAI model
+        model = ModelProvider.get_rag_openai("text-embedding-3-small")
+        ```
     """
-    return importlib.util.find_spec(package_name) is not None
-
-def get_rag_openai(model_type: str = 'text-embedding-3-small',**kwargs):
-    """
-    Load model from openai for RAG
-    Args:
-        model_type (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatOpenAI: Chatbot model
-    """
-    if not check_package("langchain_openai"):
-        raise ImportError("OpenAI package not found. Please install it using: pip install langchain-openai")
 
-    from langchain_openai import OpenAIEmbeddings
-    return OpenAIEmbeddings(model = model_type,**kwargs)
+    @staticmethod
+    def check_package(package_name: str) -> bool:
+        """
+        Check if a Python package is installed.
 
-def get_rag_ollama(model_type: str = 'llama3',**kwargs):
-    """
-    Load model from ollama for RAG
-    Args:
-        model_type (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        OllamaEmbeddings: Embeddings model
-    """
-    if not check_package("langchain_ollama"):
-        raise ImportError("Ollama package not found. Please install it using: pip install langchain-ollama")
-
-    from langchain_ollama import OllamaEmbeddings
-    return OllamaEmbeddings(model = model_type,**kwargs)
+        Args:
+            package_name (str): Name of the package to check
 
-def get_rag_anthropic(model_name: str = "claude-3-opus-20240229",**kwargs):
-    """
-    Load the chatbot model from Anthropic
-    Args:
-        model_name (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatAnthropic: Chatbot model
-    """
-    if not check_package("langchain_anthropic"):
-        raise ImportError("Anthropic package not found. Please install it using: pip install langchain-anthropic")
-
-    from langchain_anthropic import ChatAnthropic
-    kwargs["model_name"] = model_name
-    return ChatAnthropic(**kwargs)
+        """
+        return importlib.util.find_spec(package_name) is not None
 
-def get_rag_google(model_name: str = "gemini-1.5-flash",**kwargs):
-    """
-    Load the chatbot model from Google Generative AI
-    Args:
-        model_name (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatGoogleGenerativeAI: Chatbot model
-    """
-    if not check_package("google.generativeai"):
-        raise ImportError("Google Generative AI package not found. Please install it using: pip install langchain-google-genai")
-
-    from langchain_google_genai import GoogleGenerativeAIEmbeddings
-    kwargs["model"] = model_name
-    return GoogleGenerativeAIEmbeddings(**kwargs)
+    @staticmethod
+    def get_rag_openai(model_type: str = 'text-embedding-3-small', **kwargs):
+        """
+        Load OpenAI embedding model.
+
+        Args:
+            model_type (str): Model identifier (default: 'text-embedding-3-small')
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            OpenAIEmbeddings: Initialized OpenAI embeddings model
+        """
+        if not ModelProvider.check_package("langchain_openai"):
+            raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
+        from langchain_openai import OpenAIEmbeddings
+        return OpenAIEmbeddings(model=model_type, **kwargs)
+
+    @staticmethod
+    def get_rag_ollama(model_type: str = 'llama3', **kwargs):
+        """
+        Load Ollama embedding model.
 
-def load_rag_model(model_name: str ='openai', model_type: str = "text-embedding-ada-002", **kwargs):
+        Args:
+            model_type (str): Model identifier (default: 'llama3')
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            OllamaEmbeddings: Initialized Ollama embeddings model
+        """
+        if not ModelProvider.check_package("langchain_ollama"):
+            raise ImportError("Ollama package not found. Please install: pip install langchain-ollama")
+        from langchain_ollama import OllamaEmbeddings
+        return OllamaEmbeddings(model=model_type, **kwargs)
+
+    @staticmethod
+    def get_rag_anthropic(model_name: str = "claude-3-opus-20240229", **kwargs):
+        """
+        Load Anthropic model.
+
+        Args:
+            model_name (str): Model identifier (default: "claude-3-opus-20240229")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            ChatAnthropic: Initialized Anthropic chat model
+
+        """
+        if not ModelProvider.check_package("langchain_anthropic"):
+            raise ImportError("Anthropic package not found. Please install: pip install langchain-anthropic")
+        from langchain_anthropic import ChatAnthropic
+        kwargs["model_name"] = model_name
+        return ChatAnthropic(**kwargs)
+
+    @staticmethod
+    def get_rag_google(model_name: str = "gemini-1.5-flash", **kwargs):
+        """
+        Load Google embedding model.
+
+        Args:
+            model_name (str): Model identifier (default: "gemini-1.5-flash")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            GoogleGenerativeAIEmbeddings: Initialized Google embeddings model
+        """
+        if not ModelProvider.check_package("google.generativeai"):
+            raise ImportError("Google Generative AI package not found. Please install: pip install langchain-google-genai")
+        from langchain_google_genai import GoogleGenerativeAIEmbeddings
+        kwargs["model"] = model_name
+        return GoogleGenerativeAIEmbeddings(**kwargs)
+
+def load_embedding_model(model_name: str = 'openai', model_type: str = "text-embedding-ada-002", **kwargs):
     """
-    Load a RAG model from a given model name and type
+    Load a RAG model based on provider and type.
+
     Args:
-        model_name (str): Name of the model. Default is openai.
-        model_type (str): Type of the model. Default is text-embedding-ada-002.
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
+        model_name (str): Name of the model provider (default: 'openai')
+        model_type (str): Type/identifier of the model (default: "text-embedding-ada-002")
+        **kwargs: Additional arguments for model initialization
+
    Returns:
-        RAGModel: RAG model
+        Any: Initialized model instance
+
+    Example:
+        ```python
+        model = load_embedding_model('openai', 'text-embedding-3-small')
+        ```
    """
    try:
        if model_name == 'openai':
-            return get_rag_openai(model_type, **(kwargs or {}))
+            return ModelProvider.get_rag_openai(model_type, **kwargs)
        elif model_name == 'ollama':
-            return get_rag_ollama(model_type, **(kwargs or {}))
+            return ModelProvider.get_rag_ollama(model_type, **kwargs)
        elif model_name == 'google':
-            return get_rag_google(model_type, **(kwargs or {}))
+            return ModelProvider.get_rag_google(model_type, **kwargs)
        elif model_name == 'anthropic':
-            return get_rag_anthropic(model_type, **(kwargs or {}))
+            return ModelProvider.get_rag_anthropic(model_type, **kwargs)
        else:
            raise ValueError(f"Invalid model name: {model_name}")
    except ImportError as e:
        print(f"Error loading model: {str(e)}")
        return None
 
-class embedding_generator:
+class TextProcessor:
     """
-    Class to generate embeddings for the RAG model abnd chat with data
+    Handles text processing operations including file checking and tokenization.
+
+    This class provides methods for loading text files, processing them into chunks,
+    and preparing them for embedding generation.
+
     Args:
-        model: type of model. Default is openai. Options are openai, anthropic, google, ollama
-        model_type: type of model. Default is text-embedding-3-small. Options are text-embedding-3-small, text-embedding-3-large, text-embedding-ada-002 for openai.
-        vector_store_type: type of vector store. Default is chroma
-        logger: logger
-        model_kwargs: additional arguments for the model
-        vector_store_kwargs: additional arguments for the vector store
-        collection_name: name of the collection (default : test)
+        logger: Optional logger instance for logging operations
+
+    Example:
+        ```python
+        processor = TextProcessor()
+        docs = processor.tokenize(
+            ['./data.txt'],
+            'recursive_character',
+            chunk_size=1000,
+            chunk_overlap=5
+        )
+        ```
     """
-
-    def __init__(self,model: str = 'openai',model_type: str = 'text-embedding-3-small',vector_store_type:str = 'chroma' ,collection_name: str = 'test',logger= None,model_kwargs: dict = None, vector_store_kwargs: dict = None) -> None:
+
+    def __init__(self, logger=None):
         self.logger = logger
-        self.model = load_rag_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
-        if self.model is None:
-            raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
-        self.vector_store_type = vector_store_type
-        self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
-        self.collection_name = collection_name
 
-    def check_file(self, file_path):
-        """
-        Check if the file exists
-        """
-        if os.path.exists(file_path):
-            return True
-        else:
-            return False
+    def check_file(self, file_path: str) -> bool:
+        """Check if file exists."""
+        return os.path.exists(file_path)
 
-    def tokenize(self,text_data_path :list,text_splitter_type: str,chunk_size: int,chunk_overlap: int):
+    def tokenize(self, text_data_path: List[str], text_splitter_type: str,
+                 chunk_size: int, chunk_overlap: int) -> List:
         """
-        Function to tokenize the text
+        Process and tokenize text data from files.
+
         Args:
-            text: text to tokenize
+            text_data_path (List[str]): List of paths to text files
+            text_splitter_type (str): Type of text splitter to use
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+
         Returns:
-            tokens
+            List: List of processed document chunks
+
         """
-        doc_data = []
-        for i in text_data_path:
-            if self.check_file(i):
-                text_loader = TextLoader(i)
+        doc_data = []
+        for path in text_data_path:
+            if self.check_file(path):
+                text_loader = TextLoader(path)
                 get_text = text_loader.load()
-                # print(get_text) ## testing - Need to remove
-                file_name = i.split('/')[-1]
+                file_name = path.split('/')[-1]
                 metadata = {'source': file_name}
                 if metadata is not None:
-                    for j in get_text:
-                        j.metadata = metadata
-                        doc_data.append(j)
-                if self.logger is not None:
+                    for doc in get_text:
+                        doc.metadata = metadata
+                        doc_data.append(doc)
+                if self.logger:
                     self.logger.info(f"Text data loaded from {file_name}")
             else:
-                return f"File {i} not found"
-
-        if self.logger is not None:
-            self.logger.info(f"Splitting text data into chunks of size {chunk_size} with overlap {chunk_overlap}")
-        if text_splitter_type == 'character':
-            text_splitter = CharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=chunk_overlap, separator=["\n","\n\n","\n\n\n"," "])
-        if text_splitter_type == 'recursive_character':
-            text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=chunk_overlap,separators=["\n","\n\n","\n\n\n"," "])
-        if text_splitter_type == 'sentence_transformers_token':
-            text_splitter = SentenceTransformersTokenTextSplitter(chunk_size=chunk_size)
-        if text_splitter_type == 'token':
-            text_splitter = TokenTextSplitter(chunk_size=chunk_size,chunk_overlap=chunk_overlap)
-        docs = text_splitter.split_documents(doc_data)
-        if self.logger is not None:
+                return f"File {path} not found"
+
+        splitters = {
+            'character': CharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separator=["\n", "\n\n", "\n\n\n", " "]
+            ),
+            'recursive_character': RecursiveCharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separators=["\n", "\n\n", "\n\n\n", " "]
+            ),
+            'sentence_transformers_token': SentenceTransformersTokenTextSplitter(
+                chunk_size=chunk_size
+            ),
+            'token': TokenTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap
+            )
+        }
+
+        if text_splitter_type not in splitters:
+            raise ValueError(f"Invalid text splitter type: {text_splitter_type}")
+
+        text_splitter = splitters[text_splitter_type]
+        docs = text_splitter.split_documents(doc_data)
+
+        if self.logger:
             self.logger.info(f"Text data splitted into {len(docs)} chunks")
         else:
             print(f"Text data splitted into {len(docs)} chunks")
-        return docs
+        return docs
+
+class embedding_generator:
+    """
+    Main class for generating embeddings and managing RAG operations.
+
+    This class provides comprehensive functionality for generating embeddings,
+    managing vector stores, handling retrievers, and managing conversations.
+
+    Args:
+        model (str): Model provider name (default: 'openai')
+        model_type (str): Model type/identifier (default: 'text-embedding-3-small')
+        vector_store_type (str): Type of vector store (default: 'chroma')
+        collection_name (str): Name of the collection (default: 'test')
+        logger: Optional logger instance
+        model_kwargs (dict): Additional arguments for model initialization
+        vector_store_kwargs (dict): Additional arguments for vector store initialization
+
+    Example:
+        ```python
+        # Initialize generator
+        gen = embedding_generator(
+            model="openai",
+            model_type="text-embedding-3-small"
+        )
+
+        # Generate embeddings
+        gen.generate_text_embeddings(
+            text_data_path=['./data.txt'],
+            folder_save_path='./embeddings'
+        )
+
+        # Load retriever
+        retriever = gen.load_retriever('./embeddings')
+
+        # Query embeddings
+        results = gen.query_embeddings("What is this about?")
+        ```
+    """
 
-    def generate_text_embeddings(self,text_data_path: list = None,text_splitter_type: str = 'recursive_character',
-                                 chunk_size: int = 1000,chunk_overlap: int = 5,folder_save_path: str = './text_embeddings',
-                                 replace_existing: bool = False):
+    def __init__(self, model: str = 'openai', model_type: str = 'text-embedding-3-small',
+                 vector_store_type: str = 'chroma', collection_name: str = 'test',
+                 logger=None, model_kwargs: dict = None, vector_store_kwargs: dict = None) -> None:
+        """Initialize the embedding generator with specified configuration."""
+        self.logger = logger
+        self.model = load_embedding_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
+        if self.model is None:
+            raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
+        self.vector_store_type = vector_store_type
+        self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
+        self.collection_name = collection_name
+        self.text_processor = TextProcessor(logger)
+
+    def check_file(self, file_path: str) -> bool:
+        """Check if file exists."""
+        return self.text_processor.check_file(file_path)
+
+    def tokenize(self, text_data_path: List[str], text_splitter_type: str,
+                 chunk_size: int, chunk_overlap: int) -> List:
+        """Process and tokenize text data."""
+        return self.text_processor.tokenize(text_data_path, text_splitter_type,
+                                            chunk_size, chunk_overlap)
+
+    def generate_text_embeddings(self, text_data_path: List[str] = None,
+                                 text_splitter_type: str = 'recursive_character',
+                                 chunk_size: int = 1000, chunk_overlap: int = 5,
+                                 folder_save_path: str = './text_embeddings',
+                                 replace_existing: bool = False) -> str:
         """
-        Function to generate text embeddings
+        Generate text embeddings from input files.
+
         Args:
-            text_data_path: list of text files
-            # metadata: list of metadata for each text file. Dictionary format
-            text_splitter_type: type of text splitter. Default is recursive_character
-            chunk_size: size of the chunk
-            chunk_overlap: overlap between chunks
-            folder_save_path: path to save the embeddings
-            replace_existing: if True, replace the existing embeddings
+            text_data_path (List[str]): List of paths to text files
+            text_splitter_type (str): Type of text splitter
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+            folder_save_path (str): Path to save embeddings
+            replace_existing (bool): Whether to replace existing embeddings
+
        Returns:
-            None
+            str: Status message
+
+        Example:
+            ```python
+            gen.generate_text_embeddings(
+                text_data_path=['./data.txt'],
+                folder_save_path='./embeddings'
+            )
+            ```
        """
+        if self.logger:
+            self.logger.info("Performing basic checks")
 
-        if self.logger is not None:
-            self.logger.info("Perforing basic checks")
-
-        if self.check_file(folder_save_path) and replace_existing==False:
+        if self.check_file(folder_save_path) and not replace_existing:
            return "File already exists"
        elif self.check_file(folder_save_path) and replace_existing:
-            shutil.rmtree(folder_save_path)
+            shutil.rmtree(folder_save_path)
 
        if text_data_path is None:
            return "Please provide text data path"
 
-        assert isinstance(text_data_path, list), "text_data_path should be a list"
-        # if metadata is not None:
-        #     assert isinstance(metadata, list), "metadata should be a list"
-        #     assert len(text_data_path) == len(metadata), "Number of text files and metadata should be equal"
+        if not isinstance(text_data_path, list):
+            raise ValueError("text_data_path should be a list")
 
-        if self.logger is not None:
+        if self.logger:
            self.logger.info(f"Loading text data from {text_data_path}")
 
-        docs = self.tokenize(text_data_path,text_splitter_type,chunk_size,chunk_overlap)
+        docs = self.tokenize(text_data_path, text_splitter_type, chunk_size, chunk_overlap)
 
-        if self.logger is not None:
-            self.logger.info(f"Generating embeddings for {len(docs)} documents")
+        if self.logger:
+            self.logger.info(f"Generating embeddings for {len(docs)} documents")
 
-        self.vector_store.from_documents(docs, self.model,collection_name=self.collection_name,persist_directory=folder_save_path)
+        self.vector_store.from_documents(docs, self.model, collection_name=self.collection_name,
+                                         persist_directory=folder_save_path)
 
-        if self.logger is not None:
+        if self.logger:
            self.logger.info(f"Embeddings generated and saved at {folder_save_path}")
 
-    def load_vectorstore(self):
-        """
-        Function to load vector store
-        Args:
-            vector_store_type: type of vector store
-        Returns:
-            vector store
-        """
+    def load_vectorstore(self, **kwargs):
+        """Load vector store."""
        if self.vector_store_type == 'chroma':
            vector_store = Chroma()
-            if self.logger is not None:
+            if self.logger:
                self.logger.info(f"Loaded vector store {self.vector_store_type}")
            return vector_store
        else:
            return "Vector store not found"
 
-    def load_embeddings(self,embeddings_folder_path: str):
+    def load_embeddings(self, embeddings_folder_path: str):
        """
-        Function to load embeddings from the folder
+        Load embeddings from folder.
+
        Args:
-            embeddings_path: path to the embeddings
+            embeddings_folder_path (str): Path to embeddings folder
+
        Returns:
-            embeddings
+            Optional[Chroma]: Loaded vector store or None if not found
        """
        if self.check_file(embeddings_folder_path):
            if self.vector_store_type == 'chroma':
-                # embeddings_path = os.path.join(embeddings_folder_path)
-                return Chroma(persist_directory = embeddings_folder_path,embedding_function=self.model)
+                return Chroma(persist_directory=embeddings_folder_path,
+                              embedding_function=self.model)
        else:
            if self.logger:
-                self.logger.info("Embeddings file not found")
-            return None
-
-    def load_retriever(self,embeddings_folder_path: str,search_type: list = ["similarity_score_threshold"],search_params: list = [{"k": 3, "score_threshold": 0.9}]):
+                self.logger.info("Embeddings file not found")
+            return None
+
+    def load_retriever(self, embeddings_folder_path: str,
+                       search_type: List[str] = ["similarity_score_threshold"],
+                       search_params: List[Dict] = [{"k": 3, "score_threshold": 0.9}]):
        """
-        Function to load retriever
+        Load retriever with search configuration.
+
        Args:
-            embeddings_path: path to the embeddings
-            search_type: list of str: type of search. Default : ["similarity_score_threshold"]
-            search_params: list of dict: parameters for the search. Default : [{"k": 3, "score_threshold": 0.9}]
+            embeddings_folder_path (str): Path to embeddings folder
+            search_type (List[str]): List of search types
+            search_params (List[Dict]): List of search parameters
+
        Returns:
-            Retriever. If multiple search types are provided, a list of retrievers is returned
+            Union[Any, List[Any]]: Single retriever or list of retrievers
+
+        Example:
+            ```python
+            retriever = gen.load_retriever(
+                './embeddings',
+                search_type=["similarity_score_threshold"],
+                search_params=[{"k": 3, "score_threshold": 0.9}]
+            )
+            ```
        """
        db = self.load_embeddings(embeddings_folder_path)
        if db is not None:
            if self.vector_store_type == 'chroma':
-                assert len(search_type) == len(search_params), "Length of search_type and search_params should be equal"
+                if len(search_type) != len(search_params):
+                    raise ValueError("Length of search_type and search_params should be equal")
                if len(search_type) == 1:
-                    self.retriever = db.as_retriever(search_type = search_type[0],search_kwargs=search_params[0])
+                    self.retriever = db.as_retriever(search_type=search_type[0],
+                                                     search_kwargs=search_params[0])
                    if self.logger:
                        self.logger.info("Retriever loaded")
                    return self.retriever
                else:
                    retriever_list = []
                    for i in range(len(search_type)):
-                        retriever_list.append(db.as_retriever(search_type = search_type[i],search_kwargs=search_params[i]))
+                        retriever_list.append(db.as_retriever(search_type=search_type[i],
+                                                              search_kwargs=search_params[i]))
                    if self.logger:
-                        self.logger.info("List of Retriever loaded")
+                        self.logger.info("List of Retriever loaded")
                    return retriever_list
        else:
            return "Embeddings file not found"
-
-    def add_data(self,embeddings_folder_path: str, data: list,text_splitter_type: str = 'recursive_character',
-                 chunk_size: int = 1000,chunk_overlap: int = 5):
+
+    def add_data(self, embeddings_folder_path: str, data: List[str],
+                 text_splitter_type: str = 'recursive_character',
+                 chunk_size: int = 1000, chunk_overlap: int = 5):
        """
-        Function to add data to the existing db/embeddings
+        Add data to existing embeddings.
+
        Args:
-            embeddings_path: path to the embeddings
-            data: list of data to add
-            text_splitter_type: type of text splitter. Default is recursive_character
-            chunk_size: size of the chunk
-            chunk_overlap: overlap between chunks
-        Returns:
-            None
+            embeddings_folder_path (str): Path to embeddings folder
+            data (List[str]): List of text data to add
+            text_splitter_type (str): Type of text splitter
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
        """
        if self.vector_store_type == 'chroma':
            db = self.load_embeddings(embeddings_folder_path)
            if db is not None:
-                docs = self.tokenize(data,text_splitter_type,chunk_size,chunk_overlap)
+                docs = self.tokenize(data, text_splitter_type, chunk_size, chunk_overlap)
                db.add_documents(docs)
                if self.logger:
                    self.logger.info("Data added to the existing db/embeddings")
-
-    def query_embeddings(self,query: str,retriever = None):
+
+    def query_embeddings(self, query: str, retriever=None):
        """
-        Function to query embeddings
+        Query embeddings.
+
        Args:
-            search_type: type of search
-            query: query to search
+            query (str): Query string
+            retriever: Optional retriever instance
+
        Returns:
-            results
+            Any: Query results
        """
-        # if self.vector_store_type == 'chroma':
        if retriever is None:
            retriever = self.retriever
        return retriever.invoke(query)
-        # else:
-        #     return "Vector store not found"
 
-    def get_relevant_documents(self,query: str,retriever = None):
+    def get_relevant_documents(self, query: str, retriever=None):
        """
-        Function to get relevant documents
+        Get relevant documents for query.
+
        Args:
-            query: query to search
+            query (str): Query string
+            retriever: Optional retriever instance
+
        Returns:
-            results
+            List: List of relevant documents
        """
-        return self.retriever.get_relevant_documents(query)
-
-    def generate_rag_chain(self,context_prompt: str = None,retriever = None,llm= None):
+        if retriever is None:
+            retriever = self.retriever
+        return retriever.get_relevant_documents(query)
+
+    def generate_rag_chain(self, context_prompt: str = None, retriever=None, llm=None):
        """
-        Function to start a conversation chain with a rag data. Call this to load a rag_chain module.
+        Generate RAG chain for conversation.
+
        Args:
-            context_prompt: prompt to context
-            retriever: retriever. Default is None.
-            llm: language model. Default is openai. Need chat model llm. "ChatOpenAI", "ChatAnthropic" etc. like chatbot
+            context_prompt (str): Optional context prompt
+            retriever: Optional retriever instance
+            llm: Optional language model instance
+
        Returns:
-            rag_chain_model.
+            Any: Generated RAG chain
+
+        Example:
+            ```python
+            rag_chain = gen.generate_rag_chain(retriever=retriever)
+            ```
        """
        if context_prompt is None:
-            context_prompt = ("You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. "
-                              "If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. "
-                              "\n\n {context}")
-        contextualize_q_system_prompt = ("Given a chat history and the latest user question which might reference context in the chat history, formulate a standalone question which can be understood "
-                                         "without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is.")
-        contextualize_q_prompt = ChatPromptTemplate.from_messages([("system", contextualize_q_system_prompt),MessagesPlaceholder("chat_history"),("human", "{input}"),])
+            context_prompt = ("You are an assistant for question-answering tasks. "
+                              "Use the following pieces of retrieved context to answer the question. "
+                              "If you don't know the answer, just say that you don't know. "
+                              "Use three sentences maximum and keep the answer concise.\n\n{context}")
+
+        contextualize_q_system_prompt = ("Given a chat history and the latest user question "
+                                         "which might reference context in the chat history, "
+                                         "formulate a standalone question which can be understood, "
+                                         "just reformulate it if needed and otherwise return it as is.")
+
+        contextualize_q_prompt = ChatPromptTemplate.from_messages([
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ])
 
        if retriever is None:
            retriever = self.retriever
        if llm is None:
-            if not check_package("langchain_openai"):
-                raise ImportError("OpenAI package not found. Please install it using: pip install langchain-openai")
+            if not ModelProvider.check_package("langchain_openai"):
+                raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
            from langchain_openai import ChatOpenAI
            llm = ChatOpenAI(model="gpt-4")
 
-        history_aware_retriever = create_history_aware_retriever(llm,retriever, contextualize_q_prompt)
-        qa_prompt = ChatPromptTemplate.from_messages([("system", context_prompt),MessagesPlaceholder("chat_history"),("human", "{input}"),])
-        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+        history_aware_retriever = create_history_aware_retriever(llm, retriever,
+                                                                 contextualize_q_prompt)
+        qa_prompt = ChatPromptTemplate.from_messages([
+            ("system", context_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ])
+        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
        return rag_chain
 
-    def conversation_chain(self,query: str,rag_chain,file:str =None):
+    def conversation_chain(self, query: str, rag_chain, file: str = None):
        """
-        Function to create a conversation chain
+        Create conversation chain.
+
        Args:
-            query: query to search
-            rag_chain : rag_chain model
-            file: load a file and update it with the conversation. If None it will not be saved.
+            query (str): User query
+            rag_chain: RAG chain instance
+            file (str): Optional file to save conversation
+
        Returns:
-            results
+            List: Conversation history
+
+        Example:
+            ```python
+            history = gen.conversation_chain(
+                "Tell me about...",
+                rag_chain,
+                file='conversation.txt'
+            )
+            ```
        """
        if file is not None:
            try:
-                chat_history = self.load_conversation(file,list_type=True)
+                chat_history = self.load_conversation(file, list_type=True)
                if len(chat_history) == 0:
                    chat_history = []
            except:
                chat_history = []
        else:
            chat_history = []
-        query = "You : " + query
-        res = rag_chain.invoke({"input": query,"chat_history": chat_history})
+
+        query = "You : " + query
+        res = rag_chain.invoke({"input": query, "chat_history": chat_history})
        print(f"Response: {res['answer']}")
        chat_history.append(HumanMessage(content=query))
        chat_history.append(SystemMessage(content=res['answer']))
        if file is not None:
-            self.save_conversation(chat_history,file)
+            self.save_conversation(chat_history, file)
        return chat_history
 
-    def load_conversation(self,file: str,list_type: bool = False):
+    def load_conversation(self, file: str, list_type: bool = False):
        """
-        Function to load the conversation
+        Load conversation history.
+
        Args:
-            file: file to load
-            list_type: if True, return the chat_history as a list. Default is False.
+            file (str): Path to conversation file
+            list_type (bool): Whether to return as list
+
        Returns:
-            chat_history
+            Union[str, List]: Conversation history
        """
        if list_type:
            chat_history = []
-            with open(file,'r') as f:
+            with open(file, 'r') as f:
                for line in f:
-                    # inner_list = [elt.strip() for elt in line.split(',')]
                    chat_history.append(line.strip())
        else:
            with open(file, "r") as f:
                chat_history = f.read()
        return chat_history
 
-    def save_conversation(self,chat: str,file: str):
+    def save_conversation(self, chat: Union[str, List], file: str):
        """
-        Function to save the conversation
+        Save conversation history.
+
        Args:
-            chat: chat results
-            file: file to save
-        Returns:
-            None
+            chat (Union[str, List]): Conversation to save
+            file (str): Path to save file
        """
-        if isinstance(chat,str):
+        if isinstance(chat, str):
            with open(file, "a") as f:
                f.write(chat)
-        elif isinstance(chat,list):
+        elif isinstance(chat, list):
            with open(file, "a") as f:
                for i in chat[-2:]:
                    f.write("%s\n" % i)
            print(f"Saved file : {file}")
 
-    def firecrawl_web(self, website, api_key: str = None, mode="scrape", file_to_save: str = './firecrawl_embeddings',**kwargs):
+    def firecrawl_web(self, website: str, api_key: str = None, mode: str = "scrape",
+                      file_to_save: str = './firecrawl_embeddings', **kwargs):
        """
-        Function to get data from website. Use this to get data from a website and save it as embeddings/retriever. To ask questions from the website,
-        use the load_retriever and query_embeddings function.
+        Get data from website using FireCrawl.
+
        Args:
-            website : str - link to website.
-            api_key : api key of firecrawl, if None environment variable "FIRECRAWL_API_KEY" will be used.
-            mode(str) : 'scrape' default to just use the same page. Not the whole website.
-            file_to_save: path to save the embeddings
-            **kwargs: additional arguments
+            website (str): Website URL to crawl
+            api_key (str): Optional FireCrawl API key
+            mode (str): Crawl mode (default: "scrape")
+            file_to_save (str): Path to save embeddings
+            **kwargs: Additional arguments for FireCrawl
+
        Returns:
-            retriever
+            Chroma: Vector store with crawled data
+
+        Example:
+            ```python
+            db = gen.firecrawl_web(
+                "https://example.com",
+                mode="scrape",
+                file_to_save='./crawl_embeddings'
+            )
+            ```
        """
-        if not check_package("firecrawl"):
-            raise ImportError("Firecrawl package not found. Please install it using: pip install firecrawl")
-
+        if not ModelProvider.check_package("firecrawl"):
+            raise ImportError("Firecrawl package not found. Please install: pip install firecrawl")
+
        if api_key is None:
            api_key = os.getenv("FIRECRAWL_API_KEY")
+
        loader = FireCrawlLoader(api_key=api_key, url=website, mode=mode)
        docs = loader.load()
+
        for doc in docs:
            for key, value in doc.metadata.items():
                if isinstance(value, list):
                    doc.metadata[key] = ", ".join(map(str, value))
+
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        split_docs = text_splitter.split_documents(docs)
+
        print("\n--- Document Chunks Information ---")
        print(f"Number of document chunks: {len(split_docs)}")
        print(f"Sample chunk:\n{split_docs[0].page_content}\n")
+
        embeddings = self.model
-        db = Chroma.from_documents(
-            split_docs, embeddings, persist_directory=file_to_save)
+        db = Chroma.from_documents(split_docs, embeddings,
+                                   persist_directory=file_to_save)
        print(f"Retriever saved at {file_to_save}")
        return db
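The user-visible API change in this diff is the rename of `load_rag_model` to `load_embedding_model` (also reflected in `__all__`), with the per-provider loaders moving onto the new `ModelProvider` class and the Chroma import switching from `langchain_community.vectorstores` to `langchain_chroma`. A minimal migration sketch, assuming mb-rag 1.0.126 with langchain-openai and langchain-chroma installed and OPENAI_API_KEY set (file paths here are hypothetical):

```python
from mb_rag.rag.embeddings import embedding_generator, load_embedding_model

# 1.0.124's load_rag_model is gone; the equivalent call now goes through
# load_embedding_model, which dispatches to ModelProvider's static methods.
model = load_embedding_model(model_name="openai", model_type="text-embedding-3-small")

# embedding_generator keeps its constructor signature across the bump,
# so existing call sites should work unchanged after upgrading.
gen = embedding_generator(model="openai", model_type="text-embedding-3-small")
gen.generate_text_embeddings(
    text_data_path=["./data/text.txt"],  # hypothetical input file
    folder_save_path="./embeddings",
)
retriever = gen.load_retriever("./embeddings")
print(gen.query_embeddings("What is the text about?", retriever=retriever))
```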
mb_rag/version.py CHANGED
@@ -1,5 +1,5 @@
 MAJOR_VERSION = 1
 MINOR_VERSION = 0
-PATCH_VERSION = 124
+PATCH_VERSION = 126
 version = '{}.{}.{}'.format(MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION)
 __all__ = ['MAJOR_VERSION', 'MINOR_VERSION', 'PATCH_VERSION', 'version']
mb_rag-1.0.126.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mb_rag
-Version: 1.0.124
+Version: 1.0.126
 Summary: RAG function file
 Author: ['Malav Bateriwala']
 Requires-Python: >=3.8
mb_rag-1.0.126.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
 mb_rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mb_rag/version.py,sha256=g3pkzzTRM6lWK4_vY_dwd2hYFDmXrp2VRW8Uj2krk4k,208
+mb_rag/version.py,sha256=i5tH3RRI4rh19s3yPr1vEfDicqpWdyWzE-Orh1cEInA,208
 mb_rag/chatbot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 mb_rag/chatbot/basic.py,sha256=OR2IvDg-Sy968C2Mna6lxmFfh7Czj8yCEkCfyvtxBwI,14223
 mb_rag/chatbot/chains.py,sha256=vDbLX5R29sWN1pcFqJ5fyxJEgMCM81JAikunAEvMC9A,7223
 mb_rag/chatbot/prompts.py,sha256=n1PyiLbU-5fkslRv6aVOzt0dDlwya_cEdQ7kRnRhMuY,1749
 mb_rag/rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mb_rag/rag/embeddings.py,sha256=kOnHjrbi0GRVErfcjML8fZz-KttipObUfa5fW9tGOoY,21196
+mb_rag/rag/embeddings.py,sha256=KjBdekFDb5M3dRMco4r3dDMXMsG5dxdzKImuVIipsd0,27091
 mb_rag/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 mb_rag/utils/bounding_box.py,sha256=XnuhnLrsGvsI8P8VtOwlBrDlFE2It1HEZOcLlK6kusE,7931
 mb_rag/utils/extra.py,sha256=spbFrGgdruNyYQ5PzgvpSIa6Nm0rn9bb4qc8W9g582o,2492
-mb_rag-1.0.124.dist-info/METADATA,sha256=XyQ055JYBEiv5OU8m9FBhcOnHCoipo69LbnMgSa4bmM,154
-mb_rag-1.0.124.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-mb_rag-1.0.124.dist-info/top_level.txt,sha256=FIK1eAa5uYnurgXZquBG-s3PIy-HDTC5yJBW4lTH_pM,7
-mb_rag-1.0.124.dist-info/RECORD,,
+mb_rag-1.0.126.dist-info/METADATA,sha256=Sy6LYAyBqo9QgDUcP3UODBTEje9YBPSj3kHiLrz2Hs8,154
+mb_rag-1.0.126.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+mb_rag-1.0.126.dist-info/top_level.txt,sha256=FIK1eAa5uYnurgXZquBG-s3PIy-HDTC5yJBW4lTH_pM,7
+mb_rag-1.0.126.dist-info/RECORD,,