mb-rag 1.0.123__py3-none-any.whl → 1.0.125__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mb-rag might be problematic.
- mb_rag/chatbot/basic.py +6 -5
- mb_rag/rag/embeddings.py +521 -272
- mb_rag/version.py +1 -1
- {mb_rag-1.0.123.dist-info → mb_rag-1.0.125.dist-info}/METADATA +1 -1
- {mb_rag-1.0.123.dist-info → mb_rag-1.0.125.dist-info}/RECORD +7 -7
- {mb_rag-1.0.123.dist-info → mb_rag-1.0.125.dist-info}/WHEEL +0 -0
- {mb_rag-1.0.123.dist-info → mb_rag-1.0.125.dist-info}/top_level.txt +0 -0
mb_rag/chatbot/basic.py
CHANGED
@@ -220,13 +220,14 @@ class ConversationModel:
         self.chatbot = ModelFactory(model_type, model_name, **kwargs)
 
     def initialize_conversation(self,
-
-
-
+                                question: Optional[str],
+                                context: Optional[str] = None,
+                                file_path: Optional[str]=None) -> None:
         """Initialize conversation state"""
         if file_path:
-            self.file_path = file_path
+            self.file_path = file_path
             self.load_conversation(file_path)
+
         else:
             if not question:
                 raise ValueError("Question is required.")
@@ -371,7 +372,7 @@ class ConversationModel:
             raise ValueError(f"Error loading conversation from s3: {e}")
 
     def _load_from_file(self, file_path: str) -> List[Any]:
-        """Load conversation from file"""
+        """Load conversation from file"""
        try:
             with open(file_path, 'r') as f:
                 lines = f.readlines()
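The change above turns the conversation parameters into explicit Optional arguments. A minimal sketch of how a caller would exercise the new signature; the constructor arguments are inferred from the surrounding diff context and all values here are illustrative:

```python
from mb_rag.chatbot.basic import ConversationModel

# Constructor arguments inferred from the diff context (model_type, model_name).
conv = ConversationModel("openai", "gpt-4")

# New in 1.0.125: question, context, and file_path are explicit Optional parameters.
conv.initialize_conversation(question="What does this package do?")

# With file_path set, the saved conversation is loaded and no question is required.
conv.initialize_conversation(question=None, file_path="./conversation.txt")
```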
mb_rag/rag/embeddings.py
CHANGED
@@ -1,480 +1,729 @@
-
+"""
+RAG (Retrieval-Augmented Generation) Embeddings Module
+
+This module provides functionality for generating and managing embeddings for RAG models.
+It supports multiple embedding models (OpenAI, Ollama, Google, Anthropic) and includes
+features for text processing, embedding generation, vector store management, and
+conversation handling.
+
+Example Usage:
+    ```python
+    # Initialize embedding generator
+    em_gen = embedding_generator(
+        model="openai",
+        model_type="text-embedding-3-small",
+        vector_store_type="chroma"
+    )
+
+    # Generate embeddings from text
+    em_gen.generate_text_embeddings(
+        text_data_path=['./data/text.txt'],
+        chunk_size=500,
+        chunk_overlap=5,
+        folder_save_path='./embeddings'
+    )
+
+    # Load embeddings and create retriever
+    em_loading = em_gen.load_embeddings('./embeddings')
+    em_retriever = em_gen.load_retriever(
+        './embeddings',
+        search_params=[{"k": 2, "score_threshold": 0.1}]
+    )
+
+    # Query embeddings
+    results = em_retriever.invoke("What is the text about?")
+
+    # Generate RAG chain for conversation
+    rag_chain = em_gen.generate_rag_chain(retriever=em_retriever)
+    response = em_gen.conversation_chain("Tell me more", rag_chain)
+    ```
+
+Features:
+- Multiple model support (OpenAI, Ollama, Google, Anthropic)
+- Text processing and chunking
+- Embedding generation and storage
+- Vector store management
+- Retrieval operations
+- Conversation chains
+- Web crawling integration
+
+Classes:
+- ModelProvider: Base class for model loading and validation
+- TextProcessor: Handles text processing operations
+- embedding_generator: Main class for RAG operations
+"""
 
 import os
 import shutil
 import importlib.util
+from typing import List, Dict, Optional, Union, Any
 from langchain.text_splitter import (
     CharacterTextSplitter,
     RecursiveCharacterTextSplitter,
     SentenceTransformersTokenTextSplitter,
     TokenTextSplitter)
-from langchain_community.document_loaders import TextLoader,FireCrawlLoader
+from langchain_community.document_loaders import TextLoader, FireCrawlLoader
 from langchain_community.vectorstores import Chroma
-from ..utils.extra
+from ..utils.extra import load_env_file
 from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-import time
 
 load_env_file()
 
-__all__ = ['embedding_generator', '
+__all__ = ['embedding_generator', 'load_embedding_model']
 
-
+class ModelProvider:
     """
-
-
-
-
-
+    Base class for managing different model providers and their loading logic.
+
+    This class provides static methods for loading different types of embedding models
+    and checking package dependencies.
+
+    Methods:
+        check_package: Check if a Python package is installed
+        get_rag_openai: Load OpenAI embedding model
+        get_rag_ollama: Load Ollama embedding model
+        get_rag_anthropic: Load Anthropic model
+        get_rag_google: Load Google embedding model
+
+    Example:
+        ```python
+        # Check if a package is installed
+        has_openai = ModelProvider.check_package("langchain_openai")
+
+        # Load an OpenAI model
+        model = ModelProvider.get_rag_openai("text-embedding-3-small")
+        ```
     """
-    return importlib.util.find_spec(package_name) is not None
-
-def get_rag_openai(model_type: str = 'text-embedding-3-small',**kwargs):
-    """
-    Load model from openai for RAG
-    Args:
-        model_type (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatOpenAI: Chatbot model
-    """
-    if not check_package("langchain_openai"):
-        raise ImportError("OpenAI package not found. Please install it using: pip install langchain-openai")
 
-
-
+    @staticmethod
+    def check_package(package_name: str) -> bool:
+        """
+        Check if a Python package is installed.
 
-
-
-    Load model from ollama for RAG
-    Args:
-        model_type (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        OllamaEmbeddings: Embeddings model
-    """
-    if not check_package("langchain_ollama"):
-        raise ImportError("Ollama package not found. Please install it using: pip install langchain-ollama")
-
-    from langchain_ollama import OllamaEmbeddings
-    return OllamaEmbeddings(model = model_type,**kwargs)
+        Args:
+            package_name (str): Name of the package to check
 
-
-
-    Load the chatbot model from Anthropic
-    Args:
-        model_name (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatAnthropic: Chatbot model
-    """
-    if not check_package("langchain_anthropic"):
-        raise ImportError("Anthropic package not found. Please install it using: pip install langchain-anthropic")
-
-    from langchain_anthropic import ChatAnthropic
-    kwargs["model_name"] = model_name
-    return ChatAnthropic(**kwargs)
+        """
+        return importlib.util.find_spec(package_name) is not None
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    @staticmethod
+    def get_rag_openai(model_type: str = 'text-embedding-3-small', **kwargs):
+        """
+        Load OpenAI embedding model.
+
+        Args:
+            model_type (str): Model identifier (default: 'text-embedding-3-small')
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            OpenAIEmbeddings: Initialized OpenAI embeddings model
+        """
+        if not ModelProvider.check_package("langchain_openai"):
+            raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
+        from langchain_openai import OpenAIEmbeddings
+        return OpenAIEmbeddings(model=model_type, **kwargs)
+
+    @staticmethod
+    def get_rag_ollama(model_type: str = 'llama3', **kwargs):
+        """
+        Load Ollama embedding model.
 
-
+        Args:
+            model_type (str): Model identifier (default: 'llama3')
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            OllamaEmbeddings: Initialized Ollama embeddings model
+        """
+        if not ModelProvider.check_package("langchain_ollama"):
+            raise ImportError("Ollama package not found. Please install: pip install langchain-ollama")
+        from langchain_ollama import OllamaEmbeddings
+        return OllamaEmbeddings(model=model_type, **kwargs)
+
+    @staticmethod
+    def get_rag_anthropic(model_name: str = "claude-3-opus-20240229", **kwargs):
+        """
+        Load Anthropic model.
+
+        Args:
+            model_name (str): Model identifier (default: "claude-3-opus-20240229")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            ChatAnthropic: Initialized Anthropic chat model
+
+        """
+        if not ModelProvider.check_package("langchain_anthropic"):
+            raise ImportError("Anthropic package not found. Please install: pip install langchain-anthropic")
+        from langchain_anthropic import ChatAnthropic
+        kwargs["model_name"] = model_name
+        return ChatAnthropic(**kwargs)
+
+    @staticmethod
+    def get_rag_google(model_name: str = "gemini-1.5-flash", **kwargs):
+        """
+        Load Google embedding model.
+
+        Args:
+            model_name (str): Model identifier (default: "gemini-1.5-flash")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            GoogleGenerativeAIEmbeddings: Initialized Google embeddings model
+        """
+        if not ModelProvider.check_package("google.generativeai"):
+            raise ImportError("Google Generative AI package not found. Please install: pip install langchain-google-genai")
+        from langchain_google_genai import GoogleGenerativeAIEmbeddings
+        kwargs["model"] = model_name
+        return GoogleGenerativeAIEmbeddings(**kwargs)
+
+def load_embedding_model(model_name: str = 'openai', model_type: str = "text-embedding-ada-002", **kwargs):
     """
-    Load a RAG model
+    Load a RAG model based on provider and type.
+
     Args:
-        model_name (str): Name of the model
-        model_type (str): Type of the model
-        **kwargs: Additional arguments
+        model_name (str): Name of the model provider (default: 'openai')
+        model_type (str): Type/identifier of the model (default: "text-embedding-ada-002")
+        **kwargs: Additional arguments for model initialization
+
     Returns:
-
+        Any: Initialized model instance
+
+    Example:
+        ```python
+        model = load_embedding_model('openai', 'text-embedding-3-small')
+        ```
     """
     try:
         if model_name == 'openai':
-            return get_rag_openai(model_type, **
+            return ModelProvider.get_rag_openai(model_type, **kwargs)
         elif model_name == 'ollama':
-            return get_rag_ollama(model_type, **
+            return ModelProvider.get_rag_ollama(model_type, **kwargs)
         elif model_name == 'google':
-            return get_rag_google(model_type, **
+            return ModelProvider.get_rag_google(model_type, **kwargs)
         elif model_name == 'anthropic':
-            return get_rag_anthropic(model_type, **
+            return ModelProvider.get_rag_anthropic(model_type, **kwargs)
         else:
             raise ValueError(f"Invalid model name: {model_name}")
     except ImportError as e:
         print(f"Error loading model: {str(e)}")
         return None
 
-class 
+class TextProcessor:
     """
-
+    Handles text processing operations including file checking and tokenization.
+
+    This class provides methods for loading text files, processing them into chunks,
+    and preparing them for embedding generation.
+
     Args:
-
-
-
-
-
-
-
+        logger: Optional logger instance for logging operations
+
+    Example:
+        ```python
+        processor = TextProcessor()
+        docs = processor.tokenize(
+            ['./data.txt'],
+            'recursive_character',
+            chunk_size=1000,
+            chunk_overlap=5
+        )
+        ```
     """
-
-    def __init__(self,
+
+    def __init__(self, logger=None):
         self.logger = logger
-        self.model = load_rag_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
-        if self.model is None:
-            raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
-        self.vector_store_type = vector_store_type
-        self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
-        self.collection_name = collection_name
 
-    def check_file(self, file_path):
-        """
-
-        """
-        if os.path.exists(file_path):
-            return True
-        else:
-            return False
+    def check_file(self, file_path: str) -> bool:
+        """Check if file exists."""
+        return os.path.exists(file_path)
 
-    def tokenize(self,text_data_path
+    def tokenize(self, text_data_path: List[str], text_splitter_type: str,
+                 chunk_size: int, chunk_overlap: int) -> List:
         """
-
+        Process and tokenize text data from files.
+
         Args:
-
+            text_data_path (List[str]): List of paths to text files
+            text_splitter_type (str): Type of text splitter to use
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+
         Returns:
-
+            List: List of processed document chunks
+
         """
-        doc_data = []
-        for 
-            if self.check_file(
-                text_loader = TextLoader(
+        doc_data = []
+        for path in text_data_path:
+            if self.check_file(path):
+                text_loader = TextLoader(path)
                 get_text = text_loader.load()
-
-                file_name = i.split('/')[-1]
+                file_name = path.split('/')[-1]
                 metadata = {'source': file_name}
                 if metadata is not None:
-                    for 
-
-                    doc_data.append(
-                if self.logger
+                    for doc in get_text:
+                        doc.metadata = metadata
+                        doc_data.append(doc)
+                if self.logger:
                     self.logger.info(f"Text data loaded from {file_name}")
             else:
-                return f"File {
-
-
-
-
-
-
-
-
-
-
-
-
-
+                return f"File {path} not found"
+
+        splitters = {
+            'character': CharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separator=["\n", "\n\n", "\n\n\n", " "]
+            ),
+            'recursive_character': RecursiveCharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separators=["\n", "\n\n", "\n\n\n", " "]
+            ),
+            'sentence_transformers_token': SentenceTransformersTokenTextSplitter(
+                chunk_size=chunk_size
+            ),
+            'token': TokenTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap
+            )
+        }
+
+        if text_splitter_type not in splitters:
+            raise ValueError(f"Invalid text splitter type: {text_splitter_type}")
+
+        text_splitter = splitters[text_splitter_type]
+        docs = text_splitter.split_documents(doc_data)
+
+        if self.logger:
             self.logger.info(f"Text data splitted into {len(docs)} chunks")
         else:
             print(f"Text data splitted into {len(docs)} chunks")
-        return docs
+        return docs
+
+class embedding_generator:
+    """
+    Main class for generating embeddings and managing RAG operations.
+
+    This class provides comprehensive functionality for generating embeddings,
+    managing vector stores, handling retrievers, and managing conversations.
+
+    Args:
+        model (str): Model provider name (default: 'openai')
+        model_type (str): Model type/identifier (default: 'text-embedding-3-small')
+        vector_store_type (str): Type of vector store (default: 'chroma')
+        collection_name (str): Name of the collection (default: 'test')
+        logger: Optional logger instance
+        model_kwargs (dict): Additional arguments for model initialization
+        vector_store_kwargs (dict): Additional arguments for vector store initialization
+
+    Example:
+        ```python
+        # Initialize generator
+        gen = embedding_generator(
+            model="openai",
+            model_type="text-embedding-3-small"
+        )
+
+        # Generate embeddings
+        gen.generate_text_embeddings(
+            text_data_path=['./data.txt'],
+            folder_save_path='./embeddings'
+        )
+
+        # Load retriever
+        retriever = gen.load_retriever('./embeddings')
+
+        # Query embeddings
+        results = gen.query_embeddings("What is this about?")
+        ```
+    """
 
-    def 
-
-
+    def __init__(self, model: str = 'openai', model_type: str = 'text-embedding-3-small',
+                 vector_store_type: str = 'chroma', collection_name: str = 'test',
+                 logger=None, model_kwargs: dict = None, vector_store_kwargs: dict = None) -> None:
+        """Initialize the embedding generator with specified configuration."""
+        self.logger = logger
+        self.model = load_embedding_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
+        if self.model is None:
+            raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
+        self.vector_store_type = vector_store_type
+        self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
+        self.collection_name = collection_name
+        self.text_processor = TextProcessor(logger)
+
+    def check_file(self, file_path: str) -> bool:
+        """Check if file exists."""
+        return self.text_processor.check_file(file_path)
+
+    def tokenize(self, text_data_path: List[str], text_splitter_type: str,
+                 chunk_size: int, chunk_overlap: int) -> List:
+        """Process and tokenize text data."""
+        return self.text_processor.tokenize(text_data_path, text_splitter_type,
+                                            chunk_size, chunk_overlap)
+
+    def generate_text_embeddings(self, text_data_path: List[str] = None,
+                                 text_splitter_type: str = 'recursive_character',
+                                 chunk_size: int = 1000, chunk_overlap: int = 5,
+                                 folder_save_path: str = './text_embeddings',
+                                 replace_existing: bool = False) -> str:
         """
-
+        Generate text embeddings from input files.
+
         Args:
-            text_data_path:
-
-
-
-
-
-
+            text_data_path (List[str]): List of paths to text files
+            text_splitter_type (str): Type of text splitter
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+            folder_save_path (str): Path to save embeddings
+            replace_existing (bool): Whether to replace existing embeddings
+
         Returns:
-
+            str: Status message
+
+        Example:
+            ```python
+            gen.generate_text_embeddings(
+                text_data_path=['./data.txt'],
+                folder_save_path='./embeddings'
+            )
+            ```
         """
+        if self.logger:
+            self.logger.info("Performing basic checks")
 
-        if self.
-            self.logger.info("Perforing basic checks")
-
-        if self.check_file(folder_save_path) and replace_existing==False:
+        if self.check_file(folder_save_path) and not replace_existing:
             return "File already exists"
         elif self.check_file(folder_save_path) and replace_existing:
-            shutil.rmtree(folder_save_path)
+            shutil.rmtree(folder_save_path)
 
         if text_data_path is None:
             return "Please provide text data path"
 
-
-
-        # assert isinstance(metadata, list), "metadata should be a list"
-        # assert len(text_data_path) == len(metadata), "Number of text files and metadata should be equal"
+        if not isinstance(text_data_path, list):
+            raise ValueError("text_data_path should be a list")
 
-        if self.logger
+        if self.logger:
             self.logger.info(f"Loading text data from {text_data_path}")
 
-        docs = self.tokenize(text_data_path,text_splitter_type,chunk_size,chunk_overlap)
+        docs = self.tokenize(text_data_path, text_splitter_type, chunk_size, chunk_overlap)
 
-        if self.logger
-            self.logger.info(f"Generating embeddings for {len(docs)} documents")
+        if self.logger:
+            self.logger.info(f"Generating embeddings for {len(docs)} documents")
 
-        self.vector_store.from_documents(docs, self.model,collection_name=self.collection_name,
+        self.vector_store.from_documents(docs, self.model, collection_name=self.collection_name,
+                                         persist_directory=folder_save_path)
 
-        if self.logger
+        if self.logger:
             self.logger.info(f"Embeddings generated and saved at {folder_save_path}")
 
-    def load_vectorstore(self):
-        """
-        Function to load vector store
-        Args:
-            vector_store_type: type of vector store
-        Returns:
-            vector store
-        """
+    def load_vectorstore(self, **kwargs):
+        """Load vector store."""
         if self.vector_store_type == 'chroma':
             vector_store = Chroma()
-            if self.logger
+            if self.logger:
                 self.logger.info(f"Loaded vector store {self.vector_store_type}")
             return vector_store
         else:
             return "Vector store not found"
 
-    def load_embeddings(self,embeddings_folder_path: str):
+    def load_embeddings(self, embeddings_folder_path: str):
         """
-
+        Load embeddings from folder.
+
         Args:
-
+            embeddings_folder_path (str): Path to embeddings folder
+
         Returns:
-
+            Optional[Chroma]: Loaded vector store or None if not found
         """
         if self.check_file(embeddings_folder_path):
             if self.vector_store_type == 'chroma':
-
-
+                return Chroma(persist_directory=embeddings_folder_path,
+                              embedding_function=self.model)
         else:
             if self.logger:
-            self.logger.info("Embeddings file not found")
-            return None
-
-    def load_retriever(self,embeddings_folder_path: str,
+                self.logger.info("Embeddings file not found")
+            return None
+
+    def load_retriever(self, embeddings_folder_path: str,
+                       search_type: List[str] = ["similarity_score_threshold"],
+                       search_params: List[Dict] = [{"k": 3, "score_threshold": 0.9}]):
         """
-
+        Load retriever with search configuration.
+
         Args:
-
-            search_type
-            search_params:
+            embeddings_folder_path (str): Path to embeddings folder
+            search_type (List[str]): List of search types
+            search_params (List[Dict]): List of search parameters
+
         Returns:
-
+            Union[Any, List[Any]]: Single retriever or list of retrievers
+
+        Example:
+            ```python
+            retriever = gen.load_retriever(
+                './embeddings',
+                search_type=["similarity_score_threshold"],
+                search_params=[{"k": 3, "score_threshold": 0.9}]
+            )
+            ```
         """
         db = self.load_embeddings(embeddings_folder_path)
         if db is not None:
             if self.vector_store_type == 'chroma':
-
+                if len(search_type) != len(search_params):
+                    raise ValueError("Length of search_type and search_params should be equal")
                 if len(search_type) == 1:
-                    self.retriever = db.as_retriever(search_type
+                    self.retriever = db.as_retriever(search_type=search_type[0],
+                                                     search_kwargs=search_params[0])
                     if self.logger:
                         self.logger.info("Retriever loaded")
                     return self.retriever
                 else:
                     retriever_list = []
                     for i in range(len(search_type)):
-                        retriever_list.append(db.as_retriever(search_type
+                        retriever_list.append(db.as_retriever(search_type=search_type[i],
+                                                              search_kwargs=search_params[i]))
                     if self.logger:
-
+                        self.logger.info("List of Retriever loaded")
                     return retriever_list
         else:
             return "Embeddings file not found"
-
-    def add_data(self,embeddings_folder_path: str, data:
-
+
+    def add_data(self, embeddings_folder_path: str, data: List[str],
+                 text_splitter_type: str = 'recursive_character',
+                 chunk_size: int = 1000, chunk_overlap: int = 5):
         """
-
+        Add data to existing embeddings.
+
         Args:
-
-            data:
-            text_splitter_type:
-            chunk_size:
-            chunk_overlap:
-        Returns:
-            None
+            embeddings_folder_path (str): Path to embeddings folder
+            data (List[str]): List of text data to add
+            text_splitter_type (str): Type of text splitter
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
         """
         if self.vector_store_type == 'chroma':
             db = self.load_embeddings(embeddings_folder_path)
             if db is not None:
-                docs = self.tokenize(data,text_splitter_type,chunk_size,chunk_overlap)
+                docs = self.tokenize(data, text_splitter_type, chunk_size, chunk_overlap)
                 db.add_documents(docs)
                 if self.logger:
                     self.logger.info("Data added to the existing db/embeddings")
-
-    def query_embeddings(self,query: str,retriever
+
+    def query_embeddings(self, query: str, retriever=None):
         """
-
+        Query embeddings.
+
         Args:
-
-
+            query (str): Query string
+            retriever: Optional retriever instance
+
         Returns:
-            results
+            Any: Query results
         """
-        # if self.vector_store_type == 'chroma':
         if retriever is None:
             retriever = self.retriever
         return retriever.invoke(query)
-        # else:
-        #     return "Vector store not found"
 
-    def get_relevant_documents(self,query: str,retriever
+    def get_relevant_documents(self, query: str, retriever=None):
         """
-
+        Get relevant documents for query.
+
         Args:
-            query:
+            query (str): Query string
+            retriever: Optional retriever instance
+
         Returns:
-
+            List: List of relevant documents
         """
-
-
-
+        if retriever is None:
+            retriever = self.retriever
+        return retriever.get_relevant_documents(query)
+
+    def generate_rag_chain(self, context_prompt: str = None, retriever=None, llm=None):
         """
-
+        Generate RAG chain for conversation.
+
         Args:
-            context_prompt:
-            retriever: retriever
-            llm: language model
+            context_prompt (str): Optional context prompt
+            retriever: Optional retriever instance
+            llm: Optional language model instance
+
         Returns:
-
+            Any: Generated RAG chain
+
+        Example:
+            ```python
+            rag_chain = gen.generate_rag_chain(retriever=retriever)
+            ```
         """
         if context_prompt is None:
-            context_prompt = ("You are an assistant for question-answering tasks.
-
-
-
-
-
+            context_prompt = ("You are an assistant for question-answering tasks. "
+                              "Use the following pieces of retrieved context to answer the question. "
+                              "If you don't know the answer, just say that you don't know. "
+                              "Use three sentences maximum and keep the answer concise.\n\n{context}")
+
+        contextualize_q_system_prompt = ("Given a chat history and the latest user question "
+                                         "which might reference context in the chat history, "
+                                         "formulate a standalone question which can be understood "
+                                         "without the chat history. Do NOT answer the question, "
+                                         "just reformulate it if needed and otherwise return it as is.")
+
+        contextualize_q_prompt = ChatPromptTemplate.from_messages([
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ])
 
         if retriever is None:
             retriever = self.retriever
         if llm is None:
-            if not check_package("langchain_openai"):
-                raise ImportError("OpenAI package not found. Please install
+            if not ModelProvider.check_package("langchain_openai"):
+                raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
             from langchain_openai import ChatOpenAI
             llm = ChatOpenAI(model="gpt-4")
 
-        history_aware_retriever = create_history_aware_retriever(llm,retriever,
-
-
+        history_aware_retriever = create_history_aware_retriever(llm, retriever,
+                                                                 contextualize_q_prompt)
+        qa_prompt = ChatPromptTemplate.from_messages([
+            ("system", context_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ])
+        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
         rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
         return rag_chain
 
-    def conversation_chain(self,query: str,rag_chain,file:str =None):
+    def conversation_chain(self, query: str, rag_chain, file: str = None):
         """
-
+        Create conversation chain.
+
         Args:
-            query: query
-            rag_chain
-            file:
+            query (str): User query
+            rag_chain: RAG chain instance
+            file (str): Optional file to save conversation
+
         Returns:
-
+            List: Conversation history
+
+        Example:
+            ```python
+            history = gen.conversation_chain(
+                "Tell me about...",
+                rag_chain,
+                file='conversation.txt'
+            )
+            ```
         """
         if file is not None:
             try:
-                chat_history = self.load_conversation(file,list_type=True)
+                chat_history = self.load_conversation(file, list_type=True)
                 if len(chat_history) == 0:
                     chat_history = []
             except:
                 chat_history = []
         else:
             chat_history = []
-
-
+
+        query = "You : " + query
+        res = rag_chain.invoke({"input": query, "chat_history": chat_history})
         print(f"Response: {res['answer']}")
         chat_history.append(HumanMessage(content=query))
         chat_history.append(SystemMessage(content=res['answer']))
         if file is not None:
-            self.save_conversation(chat_history,file)
+            self.save_conversation(chat_history, file)
         return chat_history
 
-    def load_conversation(self,file: str,list_type: bool = False):
+    def load_conversation(self, file: str, list_type: bool = False):
         """
-
+        Load conversation history.
+
         Args:
-            file:
-            list_type:
+            file (str): Path to conversation file
+            list_type (bool): Whether to return as list
+
         Returns:
-
+            Union[str, List]: Conversation history
         """
         if list_type:
             chat_history = []
-            with open(file,'r') as f:
+            with open(file, 'r') as f:
                 for line in f:
-                    # inner_list = [elt.strip() for elt in line.split(',')]
                     chat_history.append(line.strip())
         else:
             with open(file, "r") as f:
                 chat_history = f.read()
         return chat_history
 
-    def save_conversation(self,chat: str,file: str):
+    def save_conversation(self, chat: Union[str, List], file: str):
         """
-
+        Save conversation history.
+
         Args:
-            chat:
-            file:
-        Returns:
-            None
+            chat (Union[str, List]): Conversation to save
+            file (str): Path to save file
         """
-        if isinstance(chat,str):
+        if isinstance(chat, str):
            with open(file, "a") as f:
                 f.write(chat)
-        elif isinstance(chat,list):
+        elif isinstance(chat, list):
             with open(file, "a") as f:
                 for i in chat[-2:]:
                     f.write("%s\n" % i)
         print(f"Saved file : {file}")
 
-    def firecrawl_web(self, website, api_key: str = None, mode
+    def firecrawl_web(self, website: str, api_key: str = None, mode: str = "scrape",
+                      file_to_save: str = './firecrawl_embeddings', **kwargs):
         """
-
-
+        Get data from website using FireCrawl.
+
         Args:
-            website :
-            api_key :
-            mode(str)
-            file_to_save:
-            **kwargs:
+            website (str): Website URL to crawl
+            api_key (str): Optional FireCrawl API key
+            mode (str): Crawl mode (default: "scrape")
+            file_to_save (str): Path to save embeddings
+            **kwargs: Additional arguments for FireCrawl
+
         Returns:
-
+            Chroma: Vector store with crawled data
+
+        Example:
+            ```python
+            db = gen.firecrawl_web(
+                "https://example.com",
+                mode="scrape",
+                file_to_save='./crawl_embeddings'
+            )
+            ```
         """
-        if not check_package("firecrawl"):
-            raise ImportError("Firecrawl package not found. Please install
-
+        if not ModelProvider.check_package("firecrawl"):
+            raise ImportError("Firecrawl package not found. Please install: pip install firecrawl")
+
         if api_key is None:
             api_key = os.getenv("FIRECRAWL_API_KEY")
+
         loader = FireCrawlLoader(api_key=api_key, url=website, mode=mode)
         docs = loader.load()
+
         for doc in docs:
             for key, value in doc.metadata.items():
                 if isinstance(value, list):
                     doc.metadata[key] = ", ".join(map(str, value))
+
         text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
         split_docs = text_splitter.split_documents(docs)
+
         print("\n--- Document Chunks Information ---")
         print(f"Number of document chunks: {len(split_docs)}")
         print(f"Sample chunk:\n{split_docs[0].page_content}\n")
+
         embeddings = self.model
-        db = Chroma.from_documents(
-
+        db = Chroma.from_documents(split_docs, embeddings,
+                                   persist_directory=file_to_save)
         print(f"Retriever saved at {file_to_save}")
         return db
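The refactor above folds the former module-level loaders into `ModelProvider` staticmethods and renames the dispatch function to `load_embedding_model` (now re-exported via `__all__`). A short sketch of the updated call path, assembled from the new docstrings; the file paths and query string are illustrative, and the OpenAI route assumes `langchain-openai` is installed with an `OPENAI_API_KEY` in the environment:

```python
from mb_rag.rag.embeddings import ModelProvider, load_embedding_model, embedding_generator

# Dependency probing is now a staticmethod rather than a bare helper function.
if ModelProvider.check_package("langchain_openai"):
    # The renamed module-level loader dispatches to ModelProvider under the hood.
    embeddings = load_embedding_model(model_name="openai", model_type="text-embedding-3-small")

# End-to-end flow, mirroring the new module docstring (paths are illustrative).
gen = embedding_generator(model="openai", model_type="text-embedding-3-small")
gen.generate_text_embeddings(text_data_path=["./data/text.txt"],
                             folder_save_path="./embeddings")
retriever = gen.load_retriever("./embeddings",
                               search_params=[{"k": 2, "score_threshold": 0.1}])
print(retriever.invoke("What is the text about?"))
```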
{mb_rag-1.0.123.dist-info → mb_rag-1.0.125.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 mb_rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mb_rag/version.py,sha256=
+mb_rag/version.py,sha256=XPsqrBre0W3TXoTeuLAD2FpvSnzAgbTcMVKLDjbbv20,208
 mb_rag/chatbot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mb_rag/chatbot/basic.py,sha256=
+mb_rag/chatbot/basic.py,sha256=OR2IvDg-Sy968C2Mna6lxmFfh7Czj8yCEkCfyvtxBwI,14223
 mb_rag/chatbot/chains.py,sha256=vDbLX5R29sWN1pcFqJ5fyxJEgMCM81JAikunAEvMC9A,7223
 mb_rag/chatbot/prompts.py,sha256=n1PyiLbU-5fkslRv6aVOzt0dDlwya_cEdQ7kRnRhMuY,1749
 mb_rag/rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mb_rag/rag/embeddings.py,sha256=
+mb_rag/rag/embeddings.py,sha256=E1USUON80Y5VzPENWHx9ke2FNpe6w-nxeddfJl9kOR8,27202
 mb_rag/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 mb_rag/utils/bounding_box.py,sha256=XnuhnLrsGvsI8P8VtOwlBrDlFE2It1HEZOcLlK6kusE,7931
 mb_rag/utils/extra.py,sha256=spbFrGgdruNyYQ5PzgvpSIa6Nm0rn9bb4qc8W9g582o,2492
-mb_rag-1.0.
-mb_rag-1.0.
-mb_rag-1.0.
-mb_rag-1.0.
+mb_rag-1.0.125.dist-info/METADATA,sha256=D5VYXjhHmVZE3UREC0cMov6z23X7HAJlyaGbLCpfYYY,154
+mb_rag-1.0.125.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+mb_rag-1.0.125.dist-info/top_level.txt,sha256=FIK1eAa5uYnurgXZquBG-s3PIy-HDTC5yJBW4lTH_pM,7
+mb_rag-1.0.125.dist-info/RECORD,,
{mb_rag-1.0.123.dist-info → mb_rag-1.0.125.dist-info}/WHEEL
File without changes
{mb_rag-1.0.123.dist-info → mb_rag-1.0.125.dist-info}/top_level.txt
File without changes