mb-rag 1.1.46__py3-none-any.whl → 1.1.56.post0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mb-rag might be problematic.

mb_rag/rag/embeddings.py CHANGED
@@ -1,753 +1,810 @@
1
- """
2
- RAG (Retrieval-Augmented Generation) Embeddings Module
3
-
4
- This module provides functionality for generating and managing embeddings for RAG models.
5
- It supports multiple embedding models (OpenAI, Ollama, Google, Anthropic) and includes
6
- features for text processing, embedding generation, vector store management, and
7
- conversation handling.
8
-
9
- Example Usage:
10
- ```python
11
- # Initialize embedding generator
12
- em_gen = embedding_generator(
13
- model="openai",
14
- model_type="text-embedding-3-small",
15
- vector_store_type="chroma"
16
- )
17
-
18
- # Generate embeddings from text
19
- em_gen.generate_text_embeddings(
20
- text_data_path=['./data/text.txt'],
21
- chunk_size=500,
22
- chunk_overlap=5,
23
- folder_save_path='./embeddings'
24
- )
25
-
26
- # Load embeddings and create retriever
27
- em_loading = em_gen.load_embeddings('./embeddings')
28
- em_retriever = em_gen.load_retriever(
29
- './embeddings',
30
- search_params=[{"k": 2, "score_threshold": 0.1}]
31
- )
32
-
33
- # Query embeddings
34
- results = em_retriever.invoke("What is the text about?")
35
-
36
- # Generate RAG chain for conversation
37
- rag_chain = em_gen.generate_rag_chain(retriever=em_retriever)
38
- response = em_gen.conversation_chain("Tell me more", rag_chain)
39
- ```
40
-
41
- Features:
42
- - Multiple model support (OpenAI, Ollama, Google, Anthropic)
43
- - Text processing and chunking
44
- - Embedding generation and storage
45
- - Vector store management
46
- - Retrieval operations
47
- - Conversation chains
48
- - Web crawling integration
49
-
50
- Classes:
51
- - ModelProvider: Base class for model loading and validation
52
- - TextProcessor: Handles text processing operations
53
- - embedding_generator: Main class for RAG operations
54
- """
55
-
56
- import os
57
- import shutil
58
- import importlib.util
59
- from typing import List, Dict, Optional, Union, Any
60
- from langchain.text_splitter import (
61
- CharacterTextSplitter,
62
- RecursiveCharacterTextSplitter,
63
- SentenceTransformersTokenTextSplitter,
64
- TokenTextSplitter)
65
- from langchain_community.document_loaders import TextLoader, FireCrawlLoader
66
- from langchain_chroma import Chroma
67
- from ..utils.extra import load_env_file
68
- from langchain.chains import create_history_aware_retriever, create_retrieval_chain
69
- from langchain.chains.combine_documents import create_stuff_documents_chain
70
- from langchain_core.messages import HumanMessage, SystemMessage
71
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
72
-
73
- load_env_file()
74
-
75
- __all__ = ['embedding_generator', 'load_embedding_model']
76
-
77
- class ModelProvider:
78
- """
79
- Base class for managing different model providers and their loading logic.
80
-
81
- This class provides static methods for loading different types of embedding models
82
- and checking package dependencies.
83
-
84
- Methods:
85
- check_package: Check if a Python package is installed
86
- get_rag_openai: Load OpenAI embedding model
87
- get_rag_ollama: Load Ollama embedding model
88
- get_rag_anthropic: Load Anthropic model
89
- get_rag_google: Load Google embedding model
90
-
91
- Example:
92
- ```python
93
- # Check if a package is installed
94
- has_openai = ModelProvider.check_package("langchain_openai")
95
-
96
- # Load an OpenAI model
97
- model = ModelProvider.get_rag_openai("text-embedding-3-small")
98
- ```
99
- """
100
-
101
- @staticmethod
102
- def check_package(package_name: str) -> bool:
103
- """
104
- Check if a Python package is installed.
105
-
106
- Args:
107
- package_name (str): Name of the package to check
108
-
109
- """
110
- return importlib.util.find_spec(package_name) is not None
111
-
112
- @staticmethod
113
- def get_rag_openai(model_type: str = 'text-embedding-3-small', **kwargs):
114
- """
115
- Load OpenAI embedding model.
116
-
117
- Args:
118
- model_type (str): Model identifier (default: 'text-embedding-3-small')
119
- **kwargs: Additional arguments for model initialization
120
-
121
- Returns:
122
- OpenAIEmbeddings: Initialized OpenAI embeddings model
123
- """
124
- if not ModelProvider.check_package("langchain_openai"):
125
- raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
126
- from langchain_openai import OpenAIEmbeddings
127
- return OpenAIEmbeddings(model=model_type, **kwargs)
128
-
129
- @staticmethod
130
- def get_rag_ollama(model_type: str = 'llama3', **kwargs):
131
- """
132
- Load Ollama embedding model.
133
-
134
- Args:
135
- model_type (str): Model identifier (default: 'llama3')
136
- **kwargs: Additional arguments for model initialization
137
-
138
- Returns:
139
- OllamaEmbeddings: Initialized Ollama embeddings model
140
- """
141
- if not ModelProvider.check_package("langchain_ollama"):
142
- raise ImportError("Ollama package not found. Please install: pip install langchain-ollama")
143
- from langchain_ollama import OllamaEmbeddings
144
- return OllamaEmbeddings(model=model_type, **kwargs)
145
-
146
- @staticmethod
147
- def get_rag_anthropic(model_name: str = "claude-3-opus-20240229", **kwargs):
148
- """
149
- Load Anthropic model.
150
-
151
- Args:
152
- model_name (str): Model identifier (default: "claude-3-opus-20240229")
153
- **kwargs: Additional arguments for model initialization
154
-
155
- Returns:
156
- ChatAnthropic: Initialized Anthropic chat model
157
-
158
- """
159
- if not ModelProvider.check_package("langchain_anthropic"):
160
- raise ImportError("Anthropic package not found. Please install: pip install langchain-anthropic")
161
- from langchain_anthropic import ChatAnthropic
162
- kwargs["model_name"] = model_name
163
- return ChatAnthropic(**kwargs)
164
-
165
- @staticmethod
166
- def get_rag_google(model_name: str = "gemini-1.5-flash", **kwargs):
167
- """
168
- Load Google embedding model.
169
-
170
- Args:
171
- model_name (str): Model identifier (default: "gemini-1.5-flash")
172
- **kwargs: Additional arguments for model initialization
173
-
174
- Returns:
175
- GoogleGenerativeAIEmbeddings: Initialized Google embeddings model
176
- """
177
- if not ModelProvider.check_package("google.generativeai"):
178
- raise ImportError("Google Generative AI package not found. Please install: pip install langchain-google-genai")
179
- from langchain_google_genai import GoogleGenerativeAIEmbeddings
180
- kwargs["model"] = model_name
181
- return GoogleGenerativeAIEmbeddings(**kwargs)
182
-
183
- @staticmethod
184
- def get_rag_qwen(model_name: str = "Qwen/Qwen3-Embedding-0.6B", **kwargs):
185
- """
186
- Load Qwen embedding model.
187
- Uses Transformers for embedding generation.
188
-
189
- Args:
190
- model_name (str): Model identifier (default: "Qwen/Qwen3-Embedding-0.6B")
191
- **kwargs: Additional arguments for model initialization
192
-
193
- Returns:
194
- QwenEmbeddings: Initialized Qwen embeddings model
195
- """
196
- from langchain.embeddings import HuggingFaceEmbeddings
197
-
198
- return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
199
-
200
- def load_embedding_model(model_name: str = 'openai', model_type: str = "text-embedding-ada-002", **kwargs):
201
- """
202
- Load a RAG model based on provider and type.
203
-
204
- Args:
205
- model_name (str): Name of the model provider (default: 'openai')
206
- model_type (str): Type/identifier of the model (default: "text-embedding-ada-002")
207
- **kwargs: Additional arguments for model initialization
208
-
209
- Returns:
210
- Any: Initialized model instance
211
-
212
- Example:
213
- ```python
214
- model = load_embedding_model('openai', 'text-embedding-3-small')
215
- ```
216
- """
217
- try:
218
- if model_name == 'openai':
219
- return ModelProvider.get_rag_openai(model_type, **kwargs)
220
- elif model_name == 'ollama':
221
- return ModelProvider.get_rag_ollama(model_type, **kwargs)
222
- elif model_name == 'google':
223
- return ModelProvider.get_rag_google(model_type, **kwargs)
224
- elif model_name == 'anthropic':
225
- return ModelProvider.get_rag_anthropic(model_type, **kwargs)
226
- elif model_name == 'qwen':
227
- return ModelProvider.get_rag_qwen(model_type, **kwargs)
228
- else:
229
- raise ValueError(f"Invalid model name: {model_name}")
230
- except ImportError as e:
231
- print(f"Error loading model: {str(e)}")
232
- return None
233
-
234
- class TextProcessor:
235
- """
236
- Handles text processing operations including file checking and tokenization.
237
-
238
- This class provides methods for loading text files, processing them into chunks,
239
- and preparing them for embedding generation.
240
-
241
- Args:
242
- logger: Optional logger instance for logging operations
243
-
244
- Example:
245
- ```python
246
- processor = TextProcessor()
247
- docs = processor.tokenize(
248
- ['./data.txt'],
249
- 'recursive_character',
250
- chunk_size=1000,
251
- chunk_overlap=5
252
- )
253
- ```
254
- """
255
-
256
- def __init__(self, logger=None):
257
- self.logger = logger
258
-
259
- def check_file(self, file_path: str) -> bool:
260
- """Check if file exists."""
261
- return os.path.exists(file_path)
262
-
263
- def tokenize(self, text_data_path: List[str], text_splitter_type: str,
264
- chunk_size: int, chunk_overlap: int) -> List:
265
- """
266
- Process and tokenize text data from files.
267
-
268
- Args:
269
- text_data_path (List[str]): List of paths to text files
270
- text_splitter_type (str): Type of text splitter to use
271
- chunk_size (int): Size of text chunks
272
- chunk_overlap (int): Overlap between chunks
273
-
274
- Returns:
275
- List: List of processed document chunks
276
-
277
- """
278
- doc_data = []
279
- for path in text_data_path:
280
- if self.check_file(path):
281
- text_loader = TextLoader(path)
282
- get_text = text_loader.load()
283
- file_name = path.split('/')[-1]
284
- metadata = {'source': file_name}
285
- if metadata is not None:
286
- for doc in get_text:
287
- doc.metadata = metadata
288
- doc_data.append(doc)
289
- if self.logger:
290
- self.logger.info(f"Text data loaded from {file_name}")
291
- else:
292
- return f"File {path} not found"
293
-
294
- splitters = {
295
- 'character': CharacterTextSplitter(
296
- chunk_size=chunk_size,
297
- chunk_overlap=chunk_overlap,
298
- separator=["\n", "\n\n", "\n\n\n", " "]
299
- ),
300
- 'recursive_character': RecursiveCharacterTextSplitter(
301
- chunk_size=chunk_size,
302
- chunk_overlap=chunk_overlap,
303
- separators=["\n", "\n\n", "\n\n\n", " "]
304
- ),
305
- 'sentence_transformers_token': SentenceTransformersTokenTextSplitter(
306
- chunk_size=chunk_size
307
- ),
308
- 'token': TokenTextSplitter(
309
- chunk_size=chunk_size,
310
- chunk_overlap=chunk_overlap
311
- )
312
- }
313
-
314
- if text_splitter_type not in splitters:
315
- raise ValueError(f"Invalid text splitter type: {text_splitter_type}")
316
-
317
- text_splitter = splitters[text_splitter_type]
318
- docs = text_splitter.split_documents(doc_data)
319
-
320
- if self.logger:
321
- self.logger.info(f"Text data splitted into {len(docs)} chunks")
322
- else:
323
- print(f"Text data splitted into {len(docs)} chunks")
324
- return docs
325
-
326
- class embedding_generator:
327
- """
328
- Main class for generating embeddings and managing RAG operations.
329
-
330
- This class provides comprehensive functionality for generating embeddings,
331
- managing vector stores, handling retrievers, and managing conversations.
332
-
333
- Args:
334
- model (str): Model provider name (default: 'openai')
335
- model_type (str): Model type/identifier (default: 'text-embedding-3-small')
336
- vector_store_type (str): Type of vector store (default: 'chroma')
337
- collection_name (str): Name of the collection (default: 'test')
338
- logger: Optional logger instance
339
- model_kwargs (dict): Additional arguments for model initialization
340
- vector_store_kwargs (dict): Additional arguments for vector store initialization
341
-
342
- Example:
343
- ```python
344
- # Initialize generator
345
- gen = embedding_generator(
346
- model="openai",
347
- model_type="text-embedding-3-small",
348
- collection_name='test'
349
- )
350
-
351
- # Generate embeddings
352
- gen.generate_text_embeddings(
353
- text_data_path=['./data.txt'],
354
- folder_save_path='./embeddings'
355
- )
356
-
357
- # Load retriever
358
- retriever = gen.load_retriever('./embeddings', collection_name='test')
359
-
360
- # Query embeddings
361
- results = gen.query_embeddings("What is this about?")
362
- ```
363
- """
364
-
365
- def __init__(self, model: str = 'openai', model_type: str = 'text-embedding-3-small',
366
- vector_store_type: str = 'chroma', collection_name: str = 'test',
367
- logger=None, model_kwargs: dict = None, vector_store_kwargs: dict = None) -> None:
368
- """Initialize the embedding generator with specified configuration."""
369
- self.logger = logger
370
- self.model = load_embedding_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
371
- if self.model is None:
372
- raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
373
- self.vector_store_type = vector_store_type
374
- self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
375
- self.collection_name = collection_name
376
- self.text_processor = TextProcessor(logger)
377
-
378
- def check_file(self, file_path: str) -> bool:
379
- """Check if file exists."""
380
- return self.text_processor.check_file(file_path)
381
-
382
- def tokenize(self, text_data_path: List[str], text_splitter_type: str,
383
- chunk_size: int, chunk_overlap: int) -> List:
384
- """Process and tokenize text data."""
385
- return self.text_processor.tokenize(text_data_path, text_splitter_type,
386
- chunk_size, chunk_overlap)
387
-
388
- def generate_text_embeddings(self, text_data_path: List[str] = None,
389
- text_splitter_type: str = 'recursive_character',
390
- chunk_size: int = 1000, chunk_overlap: int = 5,
391
- folder_save_path: str = './text_embeddings',
392
- replace_existing: bool = False) -> str:
393
- """
394
- Generate text embeddings from input files.
395
-
396
- Args:
397
- text_data_path (List[str]): List of paths to text files
398
- text_splitter_type (str): Type of text splitter
399
- chunk_size (int): Size of text chunks
400
- chunk_overlap (int): Overlap between chunks
401
- folder_save_path (str): Path to save embeddings
402
- replace_existing (bool): Whether to replace existing embeddings
403
-
404
- Returns:
405
- str: Status message
406
-
407
- Example:
408
- ```python
409
- gen.generate_text_embeddings(
410
- text_data_path=['./data.txt'],
411
- folder_save_path='./embeddings'
412
- )
413
- ```
414
- """
415
- if self.logger:
416
- self.logger.info("Performing basic checks")
417
-
418
- if self.check_file(folder_save_path) and not replace_existing:
419
- return "File already exists"
420
- elif self.check_file(folder_save_path) and replace_existing:
421
- shutil.rmtree(folder_save_path)
422
-
423
- if text_data_path is None:
424
- return "Please provide text data path"
425
-
426
- if not isinstance(text_data_path, list):
427
- raise ValueError("text_data_path should be a list")
428
-
429
- if self.logger:
430
- self.logger.info(f"Loading text data from {text_data_path}")
431
-
432
- docs = self.tokenize(text_data_path, text_splitter_type, chunk_size, chunk_overlap)
433
-
434
- if self.logger:
435
- self.logger.info(f"Generating embeddings for {len(docs)} documents")
436
-
437
- self.vector_store.from_documents(docs, self.model, collection_name=self.collection_name,
438
- persist_directory=folder_save_path)
439
-
440
- if self.logger:
441
- self.logger.info(f"Embeddings generated and saved at {folder_save_path}")
442
-
443
- def load_vectorstore(self, **kwargs):
444
- """Load vector store."""
445
- if self.vector_store_type == 'chroma':
446
- vector_store = Chroma()
447
- if self.logger:
448
- self.logger.info(f"Loaded vector store {self.vector_store_type}")
449
- return vector_store
450
- else:
451
- return "Vector store not found"
452
-
453
- def load_embeddings(self, embeddings_folder_path: str,collection_name: str = 'test'):
454
- """
455
- Load embeddings from folder.
456
-
457
- Args:
458
- embeddings_folder_path (str): Path to embeddings folder
459
- collection_name (str): Name of the collection. Default: 'test'
460
-
461
- Returns:
462
- Optional[Chroma]: Loaded vector store or None if not found
463
- """
464
- if self.check_file(embeddings_folder_path):
465
- if self.vector_store_type == 'chroma':
466
- return Chroma(persist_directory=embeddings_folder_path,
467
- embedding_function=self.model,
468
- collection_name=collection_name)
469
- else:
470
- if self.logger:
471
- self.logger.info("Embeddings file not found")
472
- return None
473
-
474
- def load_retriever(self, embeddings_folder_path: str,
475
- search_type: List[str] = ["similarity_score_threshold"],
476
- search_params: List[Dict] = [{"k": 3, "score_threshold": 0.9}],
477
- collection_name: str = 'test'):
478
- """
479
- Load retriever with search configuration.
480
-
481
- Args:
482
- embeddings_folder_path (str): Path to embeddings folder
483
- search_type (List[str]): List of search types
484
- search_params (List[Dict]): List of search parameters
485
- collection_name (str): Name of the collection. Default: 'test'
486
-
487
- Returns:
488
- Union[Any, List[Any]]: Single retriever or list of retrievers
489
-
490
- Example:
491
- ```python
492
- retriever = gen.load_retriever(
493
- './embeddings',
494
- search_type=["similarity_score_threshold"],
495
- search_params=[{"k": 3, "score_threshold": 0.9}]
496
- )
497
- ```
498
- """
499
- db = self.load_embeddings(embeddings_folder_path, collection_name)
500
- if db is not None:
501
- if self.vector_store_type == 'chroma':
502
- if len(search_type) != len(search_params):
503
- raise ValueError("Length of search_type and search_params should be equal")
504
- if len(search_type) == 1:
505
- self.retriever = db.as_retriever(search_type=search_type[0],
506
- search_kwargs=search_params[0])
507
- if self.logger:
508
- self.logger.info("Retriever loaded")
509
- return self.retriever
510
- else:
511
- retriever_list = []
512
- for i in range(len(search_type)):
513
- retriever_list.append(db.as_retriever(search_type=search_type[i],
514
- search_kwargs=search_params[i]))
515
- if self.logger:
516
- self.logger.info("List of Retriever loaded")
517
- return retriever_list
518
- else:
519
- return "Embeddings file not found"
520
-
521
- def add_data(self, embeddings_folder_path: str, data: List[str],
522
- text_splitter_type: str = 'recursive_character',
523
- chunk_size: int = 1000, chunk_overlap: int = 5, collection_name: str = 'test'):
524
- """
525
- Add data to existing embeddings.
526
-
527
- Args:
528
- embeddings_folder_path (str): Path to embeddings folder
529
- data (List[str]): List of text data to add
530
- text_splitter_type (str): Type of text splitter
531
- chunk_size (int): Size of text chunks
532
- chunk_overlap (int): Overlap between chunks
533
- collection_name (str): Name of the collection. Default: 'test'
534
- """
535
- if self.vector_store_type == 'chroma':
536
- db = self.load_embeddings(embeddings_folder_path, collection_name)
537
- if db is not None:
538
- docs = self.tokenize(data, text_splitter_type, chunk_size, chunk_overlap)
539
- db.add_documents(docs)
540
- if self.logger:
541
- self.logger.info("Data added to the existing db/embeddings")
542
-
543
- def query_embeddings(self, query: str, retriever=None):
544
- """
545
- Query embeddings.
546
-
547
- Args:
548
- query (str): Query string
549
- retriever: Optional retriever instance
550
-
551
- Returns:
552
- Any: Query results
553
- """
554
- if retriever is None:
555
- retriever = self.retriever
556
- return retriever.invoke(query)
557
-
558
- def get_relevant_documents(self, query: str, retriever=None):
559
- """
560
- Get relevant documents for query.
561
-
562
- Args:
563
- query (str): Query string
564
- retriever: Optional retriever instance
565
-
566
- Returns:
567
- List: List of relevant documents
568
- """
569
- if retriever is None:
570
- retriever = self.retriever
571
- return retriever.get_relevant_documents(query)
572
-
573
- def generate_rag_chain(self, context_prompt: str = None, retriever=None, llm=None):
574
- """
575
- Generate RAG chain for conversation.
576
-
577
- Args:
578
- context_prompt (str): Optional context prompt
579
- retriever: Optional retriever instance
580
- llm: Optional language model instance
581
-
582
- Returns:
583
- Any: Generated RAG chain
584
-
585
- Example:
586
- ```python
587
- rag_chain = gen.generate_rag_chain(retriever=retriever)
588
- ```
589
- """
590
- if context_prompt is None:
591
- context_prompt = ("You are an assistant for question-answering tasks. "
592
- "Use the following pieces of retrieved context to answer the question. "
593
- "If you don't know the answer, just say that you don't know. "
594
- "Use three sentences maximum and keep the answer concise.\n\n{context}")
595
-
596
- contextualize_q_system_prompt = ("Given a chat history and the latest user question "
597
- "which might reference context in the chat history, "
598
- "formulate a standalone question which can be understood, "
599
- "just reformulate it if needed and otherwise return it as is.")
600
-
601
- contextualize_q_prompt = ChatPromptTemplate.from_messages([
602
- ("system", contextualize_q_system_prompt),
603
- MessagesPlaceholder("chat_history"),
604
- ("human", "{input}"),
605
- ])
606
-
607
- if retriever is None:
608
- retriever = self.retriever
609
- if llm is None:
610
- if not ModelProvider.check_package("langchain_openai"):
611
- raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
612
- from langchain_openai import ChatOpenAI
613
- llm = ChatOpenAI(model="gpt-4o", temperature=0.8)
614
-
615
- history_aware_retriever = create_history_aware_retriever(llm, retriever,
616
- contextualize_q_prompt)
617
- qa_prompt = ChatPromptTemplate.from_messages([
618
- ("system", context_prompt),
619
- MessagesPlaceholder("chat_history"),
620
- ("human", "{input}"),
621
- ])
622
- question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
623
- rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
624
- return rag_chain
625
-
626
- def conversation_chain(self, query: str, rag_chain, file: str = None):
627
- """
628
- Create conversation chain.
629
-
630
- Args:
631
- query (str): User query
632
- rag_chain: RAG chain instance
633
- file (str): Optional file to save conversation
634
-
635
- Returns:
636
- List: Conversation history
637
-
638
- Example:
639
- ```python
640
- history = gen.conversation_chain(
641
- "Tell me about...",
642
- rag_chain,
643
- file='conversation.txt'
644
- )
645
- ```
646
- """
647
- if file is not None:
648
- try:
649
- chat_history = self.load_conversation(file, list_type=True)
650
- if len(chat_history) == 0:
651
- chat_history = []
652
- except:
653
- chat_history = []
654
- else:
655
- chat_history = []
656
-
657
- query = "You : " + query
658
- res = rag_chain.invoke({"input": query, "chat_history": chat_history})
659
- print(f"Response: {res['answer']}")
660
- chat_history.append(HumanMessage(content=query))
661
- chat_history.append(SystemMessage(content=res['answer']))
662
- if file is not None:
663
- self.save_conversation(chat_history, file)
664
- return chat_history
665
-
666
- def load_conversation(self, file: str, list_type: bool = False):
667
- """
668
- Load conversation history.
669
-
670
- Args:
671
- file (str): Path to conversation file
672
- list_type (bool): Whether to return as list
673
-
674
- Returns:
675
- Union[str, List]: Conversation history
676
- """
677
- if list_type:
678
- chat_history = []
679
- with open(file, 'r') as f:
680
- for line in f:
681
- chat_history.append(line.strip())
682
- else:
683
- with open(file, "r") as f:
684
- chat_history = f.read()
685
- return chat_history
686
-
687
- def save_conversation(self, chat: Union[str, List], file: str):
688
- """
689
- Save conversation history.
690
-
691
- Args:
692
- chat (Union[str, List]): Conversation to save
693
- file (str): Path to save file
694
- """
695
- if isinstance(chat, str):
696
- with open(file, "a") as f:
697
- f.write(chat)
698
- elif isinstance(chat, list):
699
- with open(file, "a") as f:
700
- for i in chat[-2:]:
701
- f.write("%s\n" % i)
702
- print(f"Saved file : {file}")
703
-
704
- def firecrawl_web(self, website: str, api_key: str = None, mode: str = "scrape",
705
- file_to_save: str = './firecrawl_embeddings', **kwargs):
706
- """
707
- Get data from website using FireCrawl.
708
-
709
- Args:
710
- website (str): Website URL to crawl
711
- api_key (str): Optional FireCrawl API key
712
- mode (str): Crawl mode (default: "scrape")
713
- file_to_save (str): Path to save embeddings
714
- **kwargs: Additional arguments for FireCrawl
715
-
716
- Returns:
717
- Chroma: Vector store with crawled data
718
-
719
- Example:
720
- ```python
721
- db = gen.firecrawl_web(
722
- "https://example.com",
723
- mode="scrape",
724
- file_to_save='./crawl_embeddings'
725
- )
726
- ```
727
- """
728
- if not ModelProvider.check_package("firecrawl"):
729
- raise ImportError("Firecrawl package not found. Please install: pip install firecrawl")
730
-
731
- if api_key is None:
732
- api_key = os.getenv("FIRECRAWL_API_KEY")
733
-
734
- loader = FireCrawlLoader(api_key=api_key, url=website, mode=mode)
735
- docs = loader.load()
736
-
737
- for doc in docs:
738
- for key, value in doc.metadata.items():
739
- if isinstance(value, list):
740
- doc.metadata[key] = ", ".join(map(str, value))
741
-
742
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
743
- split_docs = text_splitter.split_documents(docs)
744
-
745
- print("\n--- Document Chunks Information ---")
746
- print(f"Number of document chunks: {len(split_docs)}")
747
- print(f"Sample chunk:\n{split_docs[0].page_content}\n")
748
-
749
- embeddings = self.model
750
- db = Chroma.from_documents(split_docs, embeddings,
751
- persist_directory=file_to_save)
752
- print(f"Retriever saved at {file_to_save}")
753
- return db
1
+ """
2
+ RAG (Retrieval-Augmented Generation) Embeddings Module
3
+
4
+ This module provides functionality for generating and managing embeddings for RAG models.
5
+ It supports multiple embedding model providers (OpenAI, Ollama, Google, Anthropic, Qwen) and includes
6
+ features for text processing, embedding generation, vector store management, and
7
+ conversation handling.
8
+
9
+ Example Usage:
10
+ ```python
11
+ # Initialize embedding generator
12
+ em_gen = embedding_generator(
13
+ model="openai",
14
+ model_type="text-embedding-3-small",
15
+ vector_store_type="chroma"
16
+ )
17
+
18
+ # Generate embeddings from text
19
+ em_gen.generate_text_embeddings(
20
+ text_data_path=['./data/text.txt'],
21
+ chunk_size=500,
22
+ chunk_overlap=5,
23
+ folder_save_path='./embeddings'
24
+ )
25
+
26
+ # Load embeddings and create retriever
27
+ em_loading = em_gen.load_embeddings('./embeddings')
28
+ em_retriever = em_gen.load_retriever(
29
+ './embeddings',
30
+ search_params=[{"k": 2, "score_threshold": 0.1}]
31
+ )
32
+
33
+ # Query embeddings
34
+ results = em_retriever.invoke("What is the text about?")
35
+
36
+ # Generate RAG chain for conversation
37
+ rag_chain = em_gen.generate_rag_chain(retriever=em_retriever)
38
+ response = em_gen.conversation_chain("Tell me more", rag_chain)
39
+ ```
40
+
41
+ Features:
42
+ - Multiple model support (OpenAI, Ollama, Google, Anthropic, Qwen)
43
+ - Text processing and chunking
44
+ - Embedding generation and storage
45
+ - Vector store management
46
+ - Retrieval operations
47
+ - Conversation chains
48
+ - Web crawling integration
+ - Flashrank reranking via contextual compression
49
+
50
+ Classes:
51
+ - ModelProvider: Base class for model loading and validation
52
+ - TextProcessor: Handles text processing operations
53
+ - embedding_generator: Main class for RAG operations
54
+ """
55
+
56
+ import os
57
+ import shutil
58
+ import importlib.util
59
+ from typing import List, Dict, Optional, Union, Any
60
+ from langchain.text_splitter import (
61
+ CharacterTextSplitter,
62
+ RecursiveCharacterTextSplitter,
63
+ SentenceTransformersTokenTextSplitter,
64
+ TokenTextSplitter,
65
+ MarkdownHeaderTextSplitter)
+ # SemanticChunker is provided by langchain-experimental, not langchain
+ from langchain_experimental.text_splitter import SemanticChunker
67
+ from langchain_community.document_loaders import TextLoader, FireCrawlLoader
68
+ from langchain_chroma import Chroma
69
+ from ..utils.extra import load_env_file
70
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
71
+ from langchain.chains.combine_documents import create_stuff_documents_chain
72
+ from langchain_core.messages import HumanMessage, SystemMessage
73
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
74
+ from langchain.retrievers import ContextualCompressionRetriever
75
+ from langchain_community.document_compressors import FlashrankRerank
76
+
77
+ load_env_file()
78
+
79
+ __all__ = ['embedding_generator', 'load_embedding_model']
80
+
81
+ class ModelProvider:
82
+ """
83
+ Base class for managing different model providers and their loading logic.
84
+
85
+ This class provides static methods for loading different types of embedding models
86
+ and checking package dependencies.
87
+
88
+ Methods:
89
+ check_package: Check if a Python package is installed
90
+ get_rag_openai: Load OpenAI embedding model
91
+ get_rag_ollama: Load Ollama embedding model
92
+ get_rag_anthropic: Load Anthropic model
93
+ get_rag_google: Load Google embedding model
+ get_rag_qwen: Load Qwen embedding model (via HuggingFace)
94
+
95
+ Example:
96
+ ```python
97
+ # Check if a package is installed
98
+ has_openai = ModelProvider.check_package("langchain_openai")
99
+
100
+ # Load an OpenAI model
101
+ model = ModelProvider.get_rag_openai("text-embedding-3-small")
102
+ ```
103
+ """
104
+
105
+ @staticmethod
106
+ def check_package(package_name: str) -> bool:
107
+ """
108
+ Check if a Python package is installed.
109
+
110
+ Args:
111
+ package_name (str): Name of the package to check
112
+
113
+ """
114
+ return importlib.util.find_spec(package_name) is not None
115
+
116
+ @staticmethod
117
+ def get_rag_openai(model_type: str = 'text-embedding-3-small', **kwargs):
118
+ """
119
+ Load OpenAI embedding model.
120
+
121
+ Args:
122
+ model_type (str): Model identifier (default: 'text-embedding-3-small')
123
+ **kwargs: Additional arguments for model initialization
124
+
125
+ Returns:
126
+ OpenAIEmbeddings: Initialized OpenAI embeddings model
127
+ """
128
+ if not ModelProvider.check_package("langchain_openai"):
129
+ raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
130
+ from langchain_openai import OpenAIEmbeddings
131
+ return OpenAIEmbeddings(model=model_type, **kwargs)
132
+
133
+ @staticmethod
134
+ def get_rag_ollama(model_type: str = 'llama3', **kwargs):
135
+ """
136
+ Load Ollama embedding model.
137
+
138
+ Args:
139
+ model_type (str): Model identifier (default: 'llama3')
140
+ **kwargs: Additional arguments for model initialization
141
+
142
+ Returns:
143
+ OllamaEmbeddings: Initialized Ollama embeddings model
144
+ """
145
+ if not ModelProvider.check_package("langchain_ollama"):
146
+ raise ImportError("Ollama package not found. Please install: pip install langchain-ollama")
147
+ from langchain_ollama import OllamaEmbeddings
148
+ return OllamaEmbeddings(model=model_type, **kwargs)
149
+
150
+ @staticmethod
151
+ def get_rag_anthropic(model_name: str = "claude-3-opus-20240229", **kwargs):
152
+ """
153
+ Load Anthropic model.
154
+
155
+ Args:
156
+ model_name (str): Model identifier (default: "claude-3-opus-20240229")
157
+ **kwargs: Additional arguments for model initialization
158
+
159
+ Returns:
160
+ ChatAnthropic: Initialized Anthropic chat model
161
+
162
+ """
163
+ if not ModelProvider.check_package("langchain_anthropic"):
164
+ raise ImportError("Anthropic package not found. Please install: pip install langchain-anthropic")
165
+ from langchain_anthropic import ChatAnthropic
166
+ kwargs["model_name"] = model_name
167
+ return ChatAnthropic(**kwargs)
168
+
169
+ @staticmethod
170
+ def get_rag_google(model_name: str = "gemini-1.5-flash", **kwargs):
171
+ """
172
+ Load Google embedding model.
173
+
174
+ Args:
175
+ model_name (str): Model identifier (default: "gemini-1.5-flash")
176
+ **kwargs: Additional arguments for model initialization
177
+
178
+ Returns:
179
+ GoogleGenerativeAIEmbeddings: Initialized Google embeddings model
180
+ """
181
+ if not ModelProvider.check_package("google.generativeai"):
182
+ raise ImportError("Google Generative AI package not found. Please install: pip install langchain-google-genai")
183
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
184
+ kwargs["model"] = model_name
185
+ return GoogleGenerativeAIEmbeddings(**kwargs)
186
+
187
+ @staticmethod
188
+ def get_rag_qwen(model_name: str = "Qwen/Qwen3-Embedding-0.6B", **kwargs):
189
+ """
190
+ Load Qwen embedding model.
191
+ Uses HuggingFace sentence-transformers for embedding generation.
192
+
193
+ Args:
194
+ model_name (str): Model identifier (default: "Qwen/Qwen3-Embedding-0.6B")
195
+ **kwargs: Additional arguments for model initialization
196
+
197
+ Returns:
198
+ HuggingFaceEmbeddings: Initialized HuggingFace embeddings model wrapping Qwen
199
+ """
200
+ if not ModelProvider.check_package("langchain_huggingface"):
+ raise ImportError("HuggingFace package not found. Please install: pip install langchain-huggingface")
+ from langchain_huggingface import HuggingFaceEmbeddings
201
+
202
+ return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
203
+
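Since `get_rag_qwen` simply wraps HuggingFace embeddings, a local model can be exercised without any API key. A minimal sketch, assuming `langchain-huggingface` and `sentence-transformers` are installed (the model weights download on first use):

```python
# Minimal sketch: local Qwen embeddings via the HuggingFace wrapper.
# Assumes: pip install langchain-huggingface sentence-transformers
from mb_rag.rag.embeddings import ModelProvider

qwen_emb = ModelProvider.get_rag_qwen("Qwen/Qwen3-Embedding-0.6B")
vec = qwen_emb.embed_query("hello world")  # list[float]
print(len(vec))  # embedding dimensionality
```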
204
+ def load_embedding_model(model_name: str = 'openai', model_type: str = "text-embedding-ada-002", **kwargs):
205
+ """
206
+ Load a RAG model based on provider and type.
207
+
208
+ Args:
209
+ model_name (str): Name of the model provider (default: 'openai')
210
+ model_type (str): Type/identifier of the model (default: "text-embedding-ada-002")
211
+ **kwargs: Additional arguments for model initialization
212
+
213
+ Returns:
214
+ Any: Initialized model instance
215
+
216
+ Example:
217
+ ```python
218
+ model = load_embedding_model('openai', 'text-embedding-3-small')
219
+ ```
220
+ """
221
+ try:
222
+ if model_name == 'openai':
223
+ return ModelProvider.get_rag_openai(model_type, **kwargs)
224
+ elif model_name == 'ollama':
225
+ return ModelProvider.get_rag_ollama(model_type, **kwargs)
226
+ elif model_name == 'google':
227
+ return ModelProvider.get_rag_google(model_type, **kwargs)
228
+ elif model_name == 'anthropic':
229
+ return ModelProvider.get_rag_anthropic(model_type, **kwargs)
230
+ elif model_name == 'qwen':
231
+ return ModelProvider.get_rag_qwen(model_type, **kwargs)
232
+ else:
233
+ raise ValueError(f"Invalid model name: {model_name}")
234
+ except ImportError as e:
235
+ print(f"Error loading model: {str(e)}")
236
+ return None
237
+
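`load_embedding_model` is the single dispatch point for all providers, and it returns `None` (after printing the error) when the backing integration package is missing rather than raising. A sketch of the intended call pattern, assuming the relevant provider package and API key (e.g. `OPENAI_API_KEY`) are available:

```python
# Sketch: provider dispatch. Returns None when the integration
# package for the chosen provider is not installed.
from mb_rag.rag.embeddings import load_embedding_model

emb = load_embedding_model('openai', 'text-embedding-3-small')
if emb is None:
    raise SystemExit("install langchain-openai and set OPENAI_API_KEY")
print(len(emb.embed_query("hello")))
```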
238
+ class TextProcessor:
239
+ """
240
+ Handles text processing operations including file checking and tokenization.
241
+
242
+ This class provides methods for loading text files, processing them into chunks,
243
+ and preparing them for embedding generation.
244
+
245
+ Args:
246
+ logger: Optional logger instance for logging operations
+ embeddings: Optional embeddings model, required by the 'semantic_chunker' splitter
247
+
248
+ Example:
249
+ ```python
250
+ processor = TextProcessor()
251
+ docs = processor.tokenize(
252
+ ['./data.txt'],
253
+ 'recursive_character',
254
+ chunk_size=1000,
255
+ chunk_overlap=5
256
+ )
257
+ ```
258
+ """
259
+
260
+ def __init__(self, logger=None, embeddings=None):
261
+ self.logger = logger
+ self.embeddings = embeddings # used by the 'semantic_chunker' splitter
262
+
263
+ def check_file(self, file_path: str) -> bool:
264
+ """Check if file exists."""
265
+ return os.path.exists(file_path)
266
+
267
+ def tokenize(self, text_data_path: List[str], text_splitter_type: str,
268
+ chunk_size: int, chunk_overlap: int) -> List:
269
+ """
270
+ Process and tokenize text data from files.
271
+
272
+ Args:
273
+ text_data_path (List[str]): List of paths to text files
274
+ text_splitter_type (str): Type of text splitter to use
275
+ chunk_size (int): Size of text chunks
276
+ chunk_overlap (int): Overlap between chunks
277
+
278
+ Returns:
279
+ List: List of processed document chunks
280
+
281
+ """
282
+ doc_data = []
283
+ for path in text_data_path:
284
+ if self.check_file(path):
285
+ text_loader = TextLoader(path)
286
+ get_text = text_loader.load()
287
+ file_name = path.split('/')[-1]
288
+ metadata = {'source': file_name}
289
+ # attach source metadata to each loaded document
290
+ for doc in get_text:
291
+ doc.metadata = metadata
292
+ doc_data.append(doc)
293
+ if self.logger:
294
+ self.logger.info(f"Text data loaded from {file_name}")
295
+ else:
296
+ return f"File {path} not found"
297
+
298
+ splitters = {
299
+ 'character': CharacterTextSplitter(
300
+ chunk_size=chunk_size,
301
+ chunk_overlap=chunk_overlap,
302
+ separator=["\n", "\n\n", "\n\n\n", " "]
303
+ ),
304
+ 'recursive_character': RecursiveCharacterTextSplitter(
305
+ chunk_size=chunk_size,
306
+ chunk_overlap=chunk_overlap,
307
+ separators=["\n", "\n\n", "\n\n\n", " "]
308
+ ),
309
+ 'sentence_transformers_token': SentenceTransformersTokenTextSplitter(
310
+ chunk_size=chunk_size
311
+ ),
312
+ 'token': TokenTextSplitter(
313
+ chunk_size=chunk_size,
314
+ chunk_overlap=chunk_overlap
315
+ ),
316
+ # MarkdownHeaderTextSplitter splits on headers rather than by size
+ # (it exposes split_text for raw markdown, not split_documents)
+ 'markdown_header': MarkdownHeaderTextSplitter(
+ headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")]
+ ),
+ # SemanticChunker chunks by embedding similarity and needs an embeddings model
+ 'semantic_chunker': SemanticChunker(self.embeddings) if self.embeddings is not None else None
324
+ }
325
+
326
+ if text_splitter_type not in splitters or splitters[text_splitter_type] is None:
327
+ raise ValueError(f"Invalid text splitter type: {text_splitter_type}")
328
+
329
+ text_splitter = splitters[text_splitter_type]
330
+ docs = text_splitter.split_documents(doc_data)
331
+
332
+ if self.logger:
333
+ self.logger.info(f"Text data splitted into {len(docs)} chunks")
334
+ else:
335
+ print(f"Text data splitted into {len(docs)} chunks")
336
+ return docs
337
+
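A sketch of `TextProcessor.tokenize` with the default recursive splitter; `./data.txt` is an illustrative path, and `semantic_chunker` additionally requires an embeddings model passed to the constructor:

```python
# Sketch: chunk a local text file. './data.txt' is illustrative.
from mb_rag.rag.embeddings import TextProcessor

processor = TextProcessor()
docs = processor.tokenize(
    ['./data.txt'],
    'recursive_character',   # safe default among the splitter types above
    chunk_size=1000,
    chunk_overlap=50,
)
print(len(docs), docs[0].metadata)  # e.g. {'source': 'data.txt'}
```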
338
+
339
+ class embedding_generator:
340
+ """
341
+ Main class for generating embeddings and managing RAG operations.
342
+
343
+ This class provides comprehensive functionality for generating embeddings,
344
+ managing vector stores, handling retrievers, and managing conversations.
345
+
346
+ Args:
347
+ model (str): Model provider name (default: 'openai')
348
+ model_type (str): Model type/identifier (default: 'text-embedding-3-small')
349
+ vector_store_type (str): Type of vector store (default: 'chroma')
350
+ collection_name (str): Name of the collection (default: 'test')
351
+ logger: Optional logger instance
352
+ model_kwargs (dict): Additional arguments for model initialization
353
+ vector_store_kwargs (dict): Additional arguments for vector store initialization
354
+
355
+ Example:
356
+ ```python
357
+ # Initialize generator
358
+ gen = embedding_generator(
359
+ model="openai",
360
+ model_type="text-embedding-3-small",
361
+ collection_name='test'
362
+ )
363
+
364
+ # Generate embeddings
365
+ gen.generate_text_embeddings(
366
+ text_data_path=['./data.txt'],
367
+ folder_save_path='./embeddings'
368
+ )
369
+
370
+ # Load retriever
371
+ retriever = gen.load_retriever('./embeddings', collection_name='test')
372
+
373
+ # Query embeddings
374
+ results = gen.query_embeddings("What is this about?")
375
+ ```
376
+ """
377
+
378
+ def __init__(self, model: str = 'openai', model_type: str = 'text-embedding-3-small',
379
+ vector_store_type: str = 'chroma', collection_name: str = 'test',
380
+ logger=None, model_kwargs: dict = None, vector_store_kwargs: dict = None) -> None:
381
+ """Initialize the embedding generator with specified configuration."""
382
+ self.logger = logger
383
+ self.model = load_embedding_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
384
+ if self.model is None:
385
+ raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
386
+ self.vector_store_type = vector_store_type
387
+ self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
388
+ self.collection_name = collection_name
389
+ self.text_processor = TextProcessor(logger, embeddings=self.model)
390
+ self.compression_retriever = None
391
+
392
+ def check_file(self, file_path: str) -> bool:
393
+ """Check if file exists."""
394
+ return self.text_processor.check_file(file_path)
395
+
396
+ def tokenize(self, text_data_path: List[str], text_splitter_type: str,
397
+ chunk_size: int, chunk_overlap: int) -> List:
398
+ """Process and tokenize text data."""
399
+ return self.text_processor.tokenize(text_data_path, text_splitter_type,
400
+ chunk_size, chunk_overlap)
401
+
402
+ def generate_text_embeddings(self, text_data_path: List[str] = None,
403
+ text_splitter_type: str = 'recursive_character',
404
+ chunk_size: int = 1000, chunk_overlap: int = 5,
405
+ folder_save_path: str = './text_embeddings',
406
+ replace_existing: bool = False) -> str:
407
+ """
408
+ Generate text embeddings from input files.
409
+
410
+ Args:
411
+ text_data_path (List[str]): List of paths to text files
412
+ text_splitter_type (str): Type of text splitter
413
+ chunk_size (int): Size of text chunks
414
+ chunk_overlap (int): Overlap between chunks
415
+ folder_save_path (str): Path to save embeddings
416
+ replace_existing (bool): Whether to replace existing embeddings
417
+
418
+ Returns:
419
+ str: Status message
420
+
421
+ Example:
422
+ ```python
423
+ gen.generate_text_embeddings(
424
+ text_data_path=['./data.txt'],
425
+ folder_save_path='./embeddings'
426
+ )
427
+ ```
428
+ """
429
+ if self.logger:
430
+ self.logger.info("Performing basic checks")
431
+
432
+ if self.check_file(folder_save_path) and not replace_existing:
433
+ return "File already exists"
434
+ elif self.check_file(folder_save_path) and replace_existing:
435
+ shutil.rmtree(folder_save_path)
436
+
437
+ if text_data_path is None:
438
+ return "Please provide text data path"
439
+
440
+ if not isinstance(text_data_path, list):
441
+ raise ValueError("text_data_path should be a list")
442
+
443
+ if self.logger:
444
+ self.logger.info(f"Loading text data from {text_data_path}")
445
+
446
+ docs = self.tokenize(text_data_path, text_splitter_type, chunk_size, chunk_overlap)
447
+
448
+ if self.logger:
449
+ self.logger.info(f"Generating embeddings for {len(docs)} documents")
450
+
451
+ self.vector_store.from_documents(docs, self.model, collection_name=self.collection_name,
452
+ persist_directory=folder_save_path)
453
+
454
+ if self.logger:
455
+ self.logger.info(f"Embeddings generated and saved at {folder_save_path}")
456
+
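End to end, the generate step persists a Chroma collection on disk. A minimal sketch, assuming `OPENAI_API_KEY` is set and using illustrative file paths:

```python
# Sketch: build and persist embeddings, replacing any previous run.
from mb_rag.rag.embeddings import embedding_generator

gen = embedding_generator(model="openai",
                          model_type="text-embedding-3-small",
                          collection_name="test")
gen.generate_text_embeddings(
    text_data_path=['./notes.txt'],       # illustrative input file
    chunk_size=800,
    chunk_overlap=50,
    folder_save_path='./embeddings',
    replace_existing=True,
)
```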
457
+ def load_vectorstore(self, **kwargs):
458
+ """Load vector store."""
459
+ if self.vector_store_type == 'chroma':
460
+ vector_store = Chroma()
461
+ if self.logger:
462
+ self.logger.info(f"Loaded vector store {self.vector_store_type}")
463
+ return vector_store
464
+ else:
465
+ return "Vector store not found"
466
+
467
+ def load_embeddings(self, embeddings_folder_path: str, collection_name: str = 'test'):
468
+ """
469
+ Load embeddings from folder.
470
+
471
+ Args:
472
+ embeddings_folder_path (str): Path to embeddings folder
473
+ collection_name (str): Name of the collection. Default: 'test'
474
+
475
+ Returns:
476
+ Optional[Chroma]: Loaded vector store or None if not found
477
+ """
478
+ if self.check_file(embeddings_folder_path):
479
+ if self.vector_store_type == 'chroma':
480
+ return Chroma(persist_directory=embeddings_folder_path,
481
+ embedding_function=self.model,
482
+ collection_name=collection_name)
483
+ else:
484
+ if self.logger:
485
+ self.logger.info("Embeddings file not found")
486
+ return None
487
+
488
+ def load_retriever(self, embeddings_folder_path: str,
489
+ search_type: List[str] = ["similarity_score_threshold"],
490
+ search_params: List[Dict] = [{"k": 3, "score_threshold": 0.9}],
491
+ collection_name: str = 'test'):
492
+ """
493
+ Load retriever with search configuration.
494
+
495
+ Args:
496
+ embeddings_folder_path (str): Path to embeddings folder
497
+ search_type (List[str]): List of search types
498
+ search_params (List[Dict]): List of search parameters
499
+ collection_name (str): Name of the collection. Default: 'test'
500
+
501
+ Returns:
502
+ Union[Any, List[Any]]: Single retriever or list of retrievers
503
+
504
+ Example:
505
+ ```python
506
+ retriever = gen.load_retriever(
507
+ './embeddings',
508
+ search_type=["similarity_score_threshold"],
509
+ search_params=[{"k": 3, "score_threshold": 0.9}]
510
+ )
511
+ ```
512
+ """
513
+ db = self.load_embeddings(embeddings_folder_path, collection_name)
514
+ if db is not None:
515
+ if self.vector_store_type == 'chroma':
516
+ if len(search_type) != len(search_params):
517
+ raise ValueError("Length of search_type and search_params should be equal")
518
+ if len(search_type) == 1:
519
+ self.retriever = db.as_retriever(search_type=search_type[0],
520
+ search_kwargs=search_params[0])
521
+ if self.logger:
522
+ self.logger.info("Retriever loaded")
523
+ return self.retriever
524
+ else:
525
+ retriever_list = []
526
+ for i in range(len(search_type)):
527
+ retriever_list.append(db.as_retriever(search_type=search_type[i],
528
+ search_kwargs=search_params[i]))
529
+ if self.logger:
530
+ self.logger.info("List of Retriever loaded")
531
+ return retriever_list
532
+ else:
533
+ return "Embeddings file not found"
534
+
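Because `search_type` and `search_params` are parallel lists, one call can return several differently tuned retrievers over the same collection. A sketch, continuing with the `gen` instance from the sketch above:

```python
# Sketch: one MMR retriever and one score-threshold retriever
# over the same persisted collection.
retrievers = gen.load_retriever(
    './embeddings',
    search_type=["mmr", "similarity_score_threshold"],
    search_params=[{"k": 4}, {"k": 3, "score_threshold": 0.5}],
    collection_name='test',
)
mmr_docs = retrievers[0].invoke("What is the text about?")
```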
535
+ def add_data(self, embeddings_folder_path: str, data: List[str],
536
+ text_splitter_type: str = 'recursive_character',
537
+ chunk_size: int = 1000, chunk_overlap: int = 5, collection_name: str = 'test'):
538
+ """
539
+ Add data to existing embeddings.
540
+
541
+ Args:
542
+ embeddings_folder_path (str): Path to embeddings folder
543
+ data (List[str]): List of text data to add
544
+ text_splitter_type (str): Type of text splitter
545
+ chunk_size (int): Size of text chunks
546
+ chunk_overlap (int): Overlap between chunks
547
+ collection_name (str): Name of the collection. Default: 'test'
548
+ """
549
+ if self.vector_store_type == 'chroma':
550
+ db = self.load_embeddings(embeddings_folder_path, collection_name)
551
+ if db is not None:
552
+ docs = self.tokenize(data, text_splitter_type, chunk_size, chunk_overlap)
553
+ db.add_documents(docs)
554
+ if self.logger:
555
+ self.logger.info("Data added to the existing db/embeddings")
556
+
557
+ def query_embeddings(self, query: str, retriever=None):
558
+ """
559
+ Query embeddings.
560
+
561
+ Args:
562
+ query (str): Query string
563
+ retriever: Optional retriever instance
564
+
565
+ Returns:
566
+ Any: Query results
567
+ """
568
+ if retriever is None:
569
+ retriever = self.retriever
570
+ return retriever.invoke(query)
571
+
572
+ def get_relevant_documents(self, query: str, retriever=None):
573
+ """
574
+ Get relevant documents for query.
575
+
576
+ Args:
577
+ query (str): Query string
578
+ retriever: Optional retriever instance
579
+
580
+ Returns:
581
+ List: List of relevant documents
582
+ """
583
+ if retriever is None:
584
+ retriever = self.retriever
585
+ return retriever.get_relevant_documents(query)
586
+
587
+ def load_flashrank_compression_retriever(self, base_retriever=None, model_name: str = "ms-marco-MiniLM-L-12-v2", top_n: int = 5):
588
+ """
589
+ Load a ContextualCompressionRetriever using FlashrankRerank.
590
+
591
+ Args:
592
+ base_retriever: Existing retriever (if None, uses self.retriever)
593
+ model_name (str): Flashrank model identifier (default: "flashrank/flashrank-base")
594
+ top_n (int): Number of top documents to return after reranking
595
+
596
+ Returns:
597
+ ContextualCompressionRetriever: A compression-based retriever using Flashrank
598
+ """
599
+ if base_retriever is None:
600
+ base_retriever = getattr(self, "retriever", None) # self.retriever may not be set yet
601
+ if base_retriever is None:
602
+ raise ValueError("Base retriever is required.")
603
+
604
+ compressor = FlashrankRerank(model=model_name, top_n=top_n)
605
+ self.compression_retriever = ContextualCompressionRetriever(
606
+ base_compressor=compressor,
607
+ base_retriever=base_retriever
608
+ )
609
+
610
+ if self.logger:
611
+ self.logger.info("Loaded Flashrank compression retriever.")
612
+ return self.compression_retriever
613
+
614
+ def compression_invoke(self, query: str):
615
+ """
616
+ Invoke the compression retriever. Currently only the Flashrank reranker is supported.
617
+
618
+ Args:
619
+ query (str): Query string
620
+
621
+ Returns:
622
+ Any: Query results
623
+ """
624
+
625
+ if self.compression_retriever is None:
626
+ self.compression_retriever = self.load_flashrank_compression_retriever()
627
+ print("Compression retriever loaded.")
628
+ return self.compression_retriever.invoke(query)
629
+
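The reranking path is the main addition in this release: a wide, low-threshold base retrieval feeds Flashrank, which keeps only the `top_n` best chunks. A sketch, assuming `pip install flashrank` and using a standard Flashrank model id:

```python
# Sketch: rerank a deliberately wide base retrieval with Flashrank.
base = gen.load_retriever(
    './embeddings',
    search_params=[{"k": 10, "score_threshold": 0.1}],
)
gen.load_flashrank_compression_retriever(
    base_retriever=base,
    model_name="ms-marco-MiniLM-L-12-v2",  # standard Flashrank checkpoint
    top_n=3,
)
top_docs = gen.compression_invoke("What is the text about?")
```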
630
+ def generate_rag_chain(self, context_prompt: str = None, retriever=None, llm=None):
631
+ """
632
+ Generate RAG chain for conversation.
633
+
634
+ Args:
635
+ context_prompt (str): Optional context prompt
636
+ retriever: Optional retriever instance
637
+ llm: Optional language model instance
638
+
639
+ Returns:
640
+ Any: Generated RAG chain
641
+
642
+ Example:
643
+ ```python
644
+ rag_chain = gen.generate_rag_chain(retriever=retriever)
645
+ ```
646
+ """
647
+ if context_prompt is None:
648
+ context_prompt = ("You are an assistant for question-answering tasks. "
649
+ "Use the following pieces of retrieved context to answer the question. "
650
+ "If you don't know the answer, just say that you don't know. "
651
+ "Use three sentences maximum and keep the answer concise.\n\n{context}")
652
+
653
+ contextualize_q_system_prompt = ("Given a chat history and the latest user question "
654
+ "which might reference context in the chat history, "
655
+ "formulate a standalone question which can be understood, "
656
+ "just reformulate it if needed and otherwise return it as is.")
657
+
658
+ contextualize_q_prompt = ChatPromptTemplate.from_messages([
659
+ ("system", contextualize_q_system_prompt),
660
+ MessagesPlaceholder("chat_history"),
661
+ ("human", "{input}"),
662
+ ])
663
+
664
+ if retriever is None:
665
+ retriever = self.retriever
666
+ if llm is None:
667
+ if not ModelProvider.check_package("langchain_openai"):
668
+ raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
669
+ from langchain_openai import ChatOpenAI
670
+ llm = ChatOpenAI(model="gpt-4o", temperature=0.8)
671
+
672
+ history_aware_retriever = create_history_aware_retriever(llm, retriever,
673
+ contextualize_q_prompt)
674
+ qa_prompt = ChatPromptTemplate.from_messages([
675
+ ("system", context_prompt),
676
+ MessagesPlaceholder("chat_history"),
677
+ ("human", "{input}"),
678
+ ])
679
+ question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
680
+ rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
681
+ return rag_chain
682
+
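`generate_rag_chain` wires the history-aware retriever and the stuff-documents QA chain together; `conversation_chain` then runs one turn at a time. A sketch, assuming `OPENAI_API_KEY` is set (the chain falls back to `ChatOpenAI` by default):

```python
# Sketch: one conversation turn over the persisted embeddings.
retriever = gen.load_retriever(
    './embeddings',
    search_params=[{"k": 3, "score_threshold": 0.2}],
)
rag_chain = gen.generate_rag_chain(retriever=retriever)
history = gen.conversation_chain("What is the text about?", rag_chain)
# history now holds the HumanMessage / SystemMessage pair for this turn
```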
683
+ def conversation_chain(self, query: str, rag_chain, file: str = None):
684
+ """
685
+ Run one conversation turn through the RAG chain and update the history.
686
+
687
+ Args:
688
+ query (str): User query
689
+ rag_chain: RAG chain instance
690
+ file (str): Optional file to save conversation
691
+
692
+ Returns:
693
+ List: Conversation history
694
+
695
+ Example:
696
+ ```python
697
+ history = gen.conversation_chain(
698
+ "Tell me about...",
699
+ rag_chain,
700
+ file='conversation.txt'
701
+ )
702
+ ```
703
+ """
704
+ if file is not None:
705
+ try:
706
+ chat_history = self.load_conversation(file, list_type=True)
707
+ except OSError:
710
+ chat_history = []
711
+ else:
712
+ chat_history = []
713
+
714
+ query = "You : " + query
715
+ res = rag_chain.invoke({"input": query, "chat_history": chat_history})
716
+ print(f"Response: {res['answer']}")
717
+ chat_history.append(HumanMessage(content=query))
718
+ chat_history.append(SystemMessage(content=res['answer']))
719
+ if file is not None:
720
+ self.save_conversation(chat_history, file)
721
+ return chat_history
722
+
723
+ def load_conversation(self, file: str, list_type: bool = False):
724
+ """
725
+ Load conversation history.
726
+
727
+ Args:
728
+ file (str): Path to conversation file
729
+ list_type (bool): Whether to return as list
730
+
731
+ Returns:
732
+ Union[str, List]: Conversation history
733
+ """
734
+ if list_type:
735
+ chat_history = []
736
+ with open(file, 'r') as f:
737
+ for line in f:
738
+ chat_history.append(line.strip())
739
+ else:
740
+ with open(file, "r") as f:
741
+ chat_history = f.read()
742
+ return chat_history
743
+
744
+ def save_conversation(self, chat: Union[str, List], file: str):
745
+ """
746
+ Save conversation history.
747
+
748
+ Args:
749
+ chat (Union[str, List]): Conversation to save
750
+ file (str): Path to save file
751
+ """
752
+ if isinstance(chat, str):
753
+ with open(file, "a") as f:
754
+ f.write(chat)
755
+ elif isinstance(chat, list):
756
+ with open(file, "a") as f:
757
+ for i in chat[-2:]: # append only the newest human/system message pair
758
+ f.write("%s\n" % i)
759
+ print(f"Saved file : {file}")
760
+
761
+ def firecrawl_web(self, website: str, api_key: str = None, mode: str = "scrape",
762
+ file_to_save: str = './firecrawl_embeddings', **kwargs):
763
+ """
764
+ Get data from website using FireCrawl.
765
+
766
+ Args:
767
+ website (str): Website URL to crawl
768
+ api_key (str): Optional FireCrawl API key
769
+ mode (str): Crawl mode (default: "scrape")
770
+ file_to_save (str): Path to save embeddings
771
+ **kwargs: Additional arguments for FireCrawl
772
+
773
+ Returns:
774
+ Chroma: Vector store with crawled data
775
+
776
+ Example:
777
+ ```python
778
+ db = gen.firecrawl_web(
779
+ "https://example.com",
780
+ mode="scrape",
781
+ file_to_save='./crawl_embeddings'
782
+ )
783
+ ```
784
+ """
785
+ if not ModelProvider.check_package("firecrawl"):
786
+ raise ImportError("Firecrawl package not found. Please install: pip install firecrawl")
787
+
788
+ if api_key is None:
789
+ api_key = os.getenv("FIRECRAWL_API_KEY")
790
+
791
+ loader = FireCrawlLoader(api_key=api_key, url=website, mode=mode, **kwargs)
792
+ docs = loader.load()
793
+
794
+ for doc in docs:
795
+ for key, value in doc.metadata.items():
796
+ if isinstance(value, list):
797
+ doc.metadata[key] = ", ".join(map(str, value))
798
+
799
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
800
+ split_docs = text_splitter.split_documents(docs)
801
+
802
+ print("\n--- Document Chunks Information ---")
803
+ print(f"Number of document chunks: {len(split_docs)}")
804
+ print(f"Sample chunk:\n{split_docs[0].page_content}\n")
805
+
806
+ embeddings = self.model
807
+ db = Chroma.from_documents(split_docs, embeddings,
808
+ persist_directory=file_to_save)
809
+ print(f"Retriever saved at {file_to_save}")
810
+ return db
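The returned store can be queried immediately. A sketch, assuming `FIRECRAWL_API_KEY` is set in the environment and using an illustrative URL:

```python
# Sketch: crawl a page, then query the resulting Chroma store.
db = gen.firecrawl_web("https://example.com", mode="scrape",
                       file_to_save='./crawl_embeddings')
hits = db.as_retriever(search_kwargs={"k": 2}).invoke("What is this page about?")
```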