mb-rag 1.0.123__tar.gz → 1.0.125__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mb-rag might be problematic.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mb_rag
-Version: 1.0.123
+Version: 1.0.125
 Summary: RAG function file
 Author: ['Malav Bateriwala']
 Requires-Python: >=3.8
@@ -220,13 +220,14 @@ class ConversationModel:
         self.chatbot = ModelFactory(model_type, model_name, **kwargs)
 
     def initialize_conversation(self,
-                                file_path: Optional[str],
-                                question: Optional[str],
-                                context: Optional[str]) -> None:
+                                question: Optional[str],
+                                context: Optional[str] = None,
+                                file_path: Optional[str]=None) -> None:
         """Initialize conversation state"""
         if file_path:
-            self.file_path = file_path
+            self.file_path = file_path
             self.load_conversation(file_path)
+
         else:
             if not question:
                 raise ValueError("Question is required.")
@@ -371,7 +372,7 @@ class ConversationModel:
             raise ValueError(f"Error loading conversation from s3: {e}")
 
     def _load_from_file(self, file_path: str) -> List[Any]:
-        """Load conversation from file"""
+        """Load conversation from file"""
         try:
             with open(file_path, 'r') as f:
                 lines = f.readlines()
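Note on the `initialize_conversation` change above: `file_path` moved from the first positional parameter to an optional keyword with a `None` default, so callers that passed it positionally against 1.0.123 will now bind their path to `question`. A minimal before/after sketch, using a stand-in stub rather than the real mb_rag class:

```python
# Hedged sketch of how call sites change with the reordered signature in 1.0.125.
# `ConversationModel` below is a stub for illustration only, not the mb_rag class.
from typing import Optional

class ConversationModel:  # stub mirroring only the new parameter order
    def initialize_conversation(self, question: Optional[str],
                                context: Optional[str] = None,
                                file_path: Optional[str] = None) -> None:
        print(question, context, file_path)

conv = ConversationModel()
# 1.0.123 accepted: conv.initialize_conversation("chat.txt", "Hi", None)
# 1.0.125 expects the question first; pass file_path by keyword instead:
conv.initialize_conversation("Hi", file_path="chat.txt")
```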
@@ -0,0 +1,729 @@
+"""
+RAG (Retrieval-Augmented Generation) Embeddings Module
+
+This module provides functionality for generating and managing embeddings for RAG models.
+It supports multiple embedding models (OpenAI, Ollama, Google, Anthropic) and includes
+features for text processing, embedding generation, vector store management, and
+conversation handling.
+
+Example Usage:
+    ```python
+    # Initialize embedding generator
+    em_gen = embedding_generator(
+        model="openai",
+        model_type="text-embedding-3-small",
+        vector_store_type="chroma"
+    )
+
+    # Generate embeddings from text
+    em_gen.generate_text_embeddings(
+        text_data_path=['./data/text.txt'],
+        chunk_size=500,
+        chunk_overlap=5,
+        folder_save_path='./embeddings'
+    )
+
+    # Load embeddings and create retriever
+    em_loading = em_gen.load_embeddings('./embeddings')
+    em_retriever = em_gen.load_retriever(
+        './embeddings',
+        search_params=[{"k": 2, "score_threshold": 0.1}]
+    )
+
+    # Query embeddings
+    results = em_retriever.invoke("What is the text about?")
+
+    # Generate RAG chain for conversation
+    rag_chain = em_gen.generate_rag_chain(retriever=em_retriever)
+    response = em_gen.conversation_chain("Tell me more", rag_chain)
+    ```
+
+Features:
+    - Multiple model support (OpenAI, Ollama, Google, Anthropic)
+    - Text processing and chunking
+    - Embedding generation and storage
+    - Vector store management
+    - Retrieval operations
+    - Conversation chains
+    - Web crawling integration
+
+Classes:
+    - ModelProvider: Base class for model loading and validation
+    - TextProcessor: Handles text processing operations
+    - embedding_generator: Main class for RAG operations
+"""
+
+import os
+import shutil
+import importlib.util
+from typing import List, Dict, Optional, Union, Any
+from langchain.text_splitter import (
+    CharacterTextSplitter,
+    RecursiveCharacterTextSplitter,
+    SentenceTransformersTokenTextSplitter,
+    TokenTextSplitter)
+from langchain_community.document_loaders import TextLoader, FireCrawlLoader
+from langchain_community.vectorstores import Chroma
+from ..utils.extra import load_env_file
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
+load_env_file()
+
+__all__ = ['embedding_generator', 'load_embedding_model']
+
+class ModelProvider:
+    """
+    Base class for managing different model providers and their loading logic.
+
+    This class provides static methods for loading different types of embedding models
+    and checking package dependencies.
+
+    Methods:
+        check_package: Check if a Python package is installed
+        get_rag_openai: Load OpenAI embedding model
+        get_rag_ollama: Load Ollama embedding model
+        get_rag_anthropic: Load Anthropic model
+        get_rag_google: Load Google embedding model
+
+    Example:
+        ```python
+        # Check if a package is installed
+        has_openai = ModelProvider.check_package("langchain_openai")
+
+        # Load an OpenAI model
+        model = ModelProvider.get_rag_openai("text-embedding-3-small")
+        ```
+    """
+
+    @staticmethod
+    def check_package(package_name: str) -> bool:
+        """
+        Check if a Python package is installed.
+
+        Args:
+            package_name (str): Name of the package to check
+
+        """
+        return importlib.util.find_spec(package_name) is not None
+
+    @staticmethod
+    def get_rag_openai(model_type: str = 'text-embedding-3-small', **kwargs):
+        """
+        Load OpenAI embedding model.
+
+        Args:
+            model_type (str): Model identifier (default: 'text-embedding-3-small')
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            OpenAIEmbeddings: Initialized OpenAI embeddings model
+        """
+        if not ModelProvider.check_package("langchain_openai"):
+            raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
+        from langchain_openai import OpenAIEmbeddings
+        return OpenAIEmbeddings(model=model_type, **kwargs)
+
+    @staticmethod
+    def get_rag_ollama(model_type: str = 'llama3', **kwargs):
+        """
+        Load Ollama embedding model.
+
+        Args:
+            model_type (str): Model identifier (default: 'llama3')
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            OllamaEmbeddings: Initialized Ollama embeddings model
+        """
+        if not ModelProvider.check_package("langchain_ollama"):
+            raise ImportError("Ollama package not found. Please install: pip install langchain-ollama")
+        from langchain_ollama import OllamaEmbeddings
+        return OllamaEmbeddings(model=model_type, **kwargs)
+
+    @staticmethod
+    def get_rag_anthropic(model_name: str = "claude-3-opus-20240229", **kwargs):
+        """
+        Load Anthropic model.
+
+        Args:
+            model_name (str): Model identifier (default: "claude-3-opus-20240229")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            ChatAnthropic: Initialized Anthropic chat model
+
+        """
+        if not ModelProvider.check_package("langchain_anthropic"):
+            raise ImportError("Anthropic package not found. Please install: pip install langchain-anthropic")
+        from langchain_anthropic import ChatAnthropic
+        kwargs["model_name"] = model_name
+        return ChatAnthropic(**kwargs)
+
+    @staticmethod
+    def get_rag_google(model_name: str = "gemini-1.5-flash", **kwargs):
+        """
+        Load Google embedding model.
+
+        Args:
+            model_name (str): Model identifier (default: "gemini-1.5-flash")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            GoogleGenerativeAIEmbeddings: Initialized Google embeddings model
+        """
+        if not ModelProvider.check_package("google.generativeai"):
+            raise ImportError("Google Generative AI package not found. Please install: pip install langchain-google-genai")
+        from langchain_google_genai import GoogleGenerativeAIEmbeddings
+        kwargs["model"] = model_name
+        return GoogleGenerativeAIEmbeddings(**kwargs)
+
+def load_embedding_model(model_name: str = 'openai', model_type: str = "text-embedding-ada-002", **kwargs):
+    """
+    Load a RAG model based on provider and type.
+
+    Args:
+        model_name (str): Name of the model provider (default: 'openai')
+        model_type (str): Type/identifier of the model (default: "text-embedding-ada-002")
+        **kwargs: Additional arguments for model initialization
+
+    Returns:
+        Any: Initialized model instance
+
+    Example:
+        ```python
+        model = load_embedding_model('openai', 'text-embedding-3-small')
+        ```
+    """
+    try:
+        if model_name == 'openai':
+            return ModelProvider.get_rag_openai(model_type, **kwargs)
+        elif model_name == 'ollama':
+            return ModelProvider.get_rag_ollama(model_type, **kwargs)
+        elif model_name == 'google':
+            return ModelProvider.get_rag_google(model_type, **kwargs)
+        elif model_name == 'anthropic':
+            return ModelProvider.get_rag_anthropic(model_type, **kwargs)
+        else:
+            raise ValueError(f"Invalid model name: {model_name}")
+    except ImportError as e:
+        print(f"Error loading model: {str(e)}")
+        return None
+
+class TextProcessor:
+    """
+    Handles text processing operations including file checking and tokenization.
+
+    This class provides methods for loading text files, processing them into chunks,
+    and preparing them for embedding generation.
+
+    Args:
+        logger: Optional logger instance for logging operations
+
+    Example:
+        ```python
+        processor = TextProcessor()
+        docs = processor.tokenize(
+            ['./data.txt'],
+            'recursive_character',
+            chunk_size=1000,
+            chunk_overlap=5
+        )
+        ```
+    """
+
+    def __init__(self, logger=None):
+        self.logger = logger
+
+    def check_file(self, file_path: str) -> bool:
+        """Check if file exists."""
+        return os.path.exists(file_path)
+
+    def tokenize(self, text_data_path: List[str], text_splitter_type: str,
+                 chunk_size: int, chunk_overlap: int) -> List:
+        """
+        Process and tokenize text data from files.
+
+        Args:
+            text_data_path (List[str]): List of paths to text files
+            text_splitter_type (str): Type of text splitter to use
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+
+        Returns:
+            List: List of processed document chunks
+
+        """
+        doc_data = []
+        for path in text_data_path:
+            if self.check_file(path):
+                text_loader = TextLoader(path)
+                get_text = text_loader.load()
+                file_name = path.split('/')[-1]
+                metadata = {'source': file_name}
+                if metadata is not None:
+                    for doc in get_text:
+                        doc.metadata = metadata
+                        doc_data.append(doc)
+                if self.logger:
+                    self.logger.info(f"Text data loaded from {file_name}")
+            else:
+                return f"File {path} not found"
+
+        splitters = {
+            'character': CharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separator=["\n", "\n\n", "\n\n\n", " "]
+            ),
+            'recursive_character': RecursiveCharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separators=["\n", "\n\n", "\n\n\n", " "]
+            ),
+            'sentence_transformers_token': SentenceTransformersTokenTextSplitter(
+                chunk_size=chunk_size
+            ),
+            'token': TokenTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap
+            )
+        }
+
+        if text_splitter_type not in splitters:
+            raise ValueError(f"Invalid text splitter type: {text_splitter_type}")
+
+        text_splitter = splitters[text_splitter_type]
+        docs = text_splitter.split_documents(doc_data)
+
+        if self.logger:
+            self.logger.info(f"Text data splitted into {len(docs)} chunks")
+        else:
+            print(f"Text data splitted into {len(docs)} chunks")
+        return docs
+
+class embedding_generator:
+    """
+    Main class for generating embeddings and managing RAG operations.
+
+    This class provides comprehensive functionality for generating embeddings,
+    managing vector stores, handling retrievers, and managing conversations.
+
+    Args:
+        model (str): Model provider name (default: 'openai')
+        model_type (str): Model type/identifier (default: 'text-embedding-3-small')
+        vector_store_type (str): Type of vector store (default: 'chroma')
+        collection_name (str): Name of the collection (default: 'test')
+        logger: Optional logger instance
+        model_kwargs (dict): Additional arguments for model initialization
+        vector_store_kwargs (dict): Additional arguments for vector store initialization
+
+    Example:
+        ```python
+        # Initialize generator
+        gen = embedding_generator(
+            model="openai",
+            model_type="text-embedding-3-small"
+        )
+
+        # Generate embeddings
+        gen.generate_text_embeddings(
+            text_data_path=['./data.txt'],
+            folder_save_path='./embeddings'
+        )
+
+        # Load retriever
+        retriever = gen.load_retriever('./embeddings')
+
+        # Query embeddings
+        results = gen.query_embeddings("What is this about?")
+        ```
+    """
+
+    def __init__(self, model: str = 'openai', model_type: str = 'text-embedding-3-small',
+                 vector_store_type: str = 'chroma', collection_name: str = 'test',
+                 logger=None, model_kwargs: dict = None, vector_store_kwargs: dict = None) -> None:
+        """Initialize the embedding generator with specified configuration."""
+        self.logger = logger
+        self.model = load_embedding_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
+        if self.model is None:
+            raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
+        self.vector_store_type = vector_store_type
+        self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
+        self.collection_name = collection_name
+        self.text_processor = TextProcessor(logger)
+
+    def check_file(self, file_path: str) -> bool:
+        """Check if file exists."""
+        return self.text_processor.check_file(file_path)
+
+    def tokenize(self, text_data_path: List[str], text_splitter_type: str,
+                 chunk_size: int, chunk_overlap: int) -> List:
+        """Process and tokenize text data."""
+        return self.text_processor.tokenize(text_data_path, text_splitter_type,
+                                            chunk_size, chunk_overlap)
+
+    def generate_text_embeddings(self, text_data_path: List[str] = None,
+                                 text_splitter_type: str = 'recursive_character',
+                                 chunk_size: int = 1000, chunk_overlap: int = 5,
+                                 folder_save_path: str = './text_embeddings',
+                                 replace_existing: bool = False) -> str:
+        """
+        Generate text embeddings from input files.
+
+        Args:
+            text_data_path (List[str]): List of paths to text files
+            text_splitter_type (str): Type of text splitter
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+            folder_save_path (str): Path to save embeddings
+            replace_existing (bool): Whether to replace existing embeddings
+
+        Returns:
+            str: Status message
+
+        Example:
+            ```python
+            gen.generate_text_embeddings(
+                text_data_path=['./data.txt'],
+                folder_save_path='./embeddings'
+            )
+            ```
+        """
+        if self.logger:
+            self.logger.info("Performing basic checks")
+
+        if self.check_file(folder_save_path) and not replace_existing:
+            return "File already exists"
+        elif self.check_file(folder_save_path) and replace_existing:
+            shutil.rmtree(folder_save_path)
+
+        if text_data_path is None:
+            return "Please provide text data path"
+
+        if not isinstance(text_data_path, list):
+            raise ValueError("text_data_path should be a list")
+
+        if self.logger:
+            self.logger.info(f"Loading text data from {text_data_path}")
+
+        docs = self.tokenize(text_data_path, text_splitter_type, chunk_size, chunk_overlap)
+
+        if self.logger:
+            self.logger.info(f"Generating embeddings for {len(docs)} documents")
+
+        self.vector_store.from_documents(docs, self.model, collection_name=self.collection_name,
+                                         persist_directory=folder_save_path)
+
+        if self.logger:
+            self.logger.info(f"Embeddings generated and saved at {folder_save_path}")
+
+    def load_vectorstore(self, **kwargs):
+        """Load vector store."""
+        if self.vector_store_type == 'chroma':
+            vector_store = Chroma()
+            if self.logger:
+                self.logger.info(f"Loaded vector store {self.vector_store_type}")
+            return vector_store
+        else:
+            return "Vector store not found"
+
+    def load_embeddings(self, embeddings_folder_path: str):
+        """
+        Load embeddings from folder.
+
+        Args:
+            embeddings_folder_path (str): Path to embeddings folder
+
+        Returns:
+            Optional[Chroma]: Loaded vector store or None if not found
+        """
+        if self.check_file(embeddings_folder_path):
+            if self.vector_store_type == 'chroma':
+                return Chroma(persist_directory=embeddings_folder_path,
+                              embedding_function=self.model)
+        else:
+            if self.logger:
+                self.logger.info("Embeddings file not found")
+            return None
+
+    def load_retriever(self, embeddings_folder_path: str,
+                       search_type: List[str] = ["similarity_score_threshold"],
+                       search_params: List[Dict] = [{"k": 3, "score_threshold": 0.9}]):
+        """
+        Load retriever with search configuration.
+
+        Args:
+            embeddings_folder_path (str): Path to embeddings folder
+            search_type (List[str]): List of search types
+            search_params (List[Dict]): List of search parameters
+
+        Returns:
+            Union[Any, List[Any]]: Single retriever or list of retrievers
+
+        Example:
+            ```python
+            retriever = gen.load_retriever(
+                './embeddings',
+                search_type=["similarity_score_threshold"],
+                search_params=[{"k": 3, "score_threshold": 0.9}]
+            )
+            ```
+        """
+        db = self.load_embeddings(embeddings_folder_path)
+        if db is not None:
+            if self.vector_store_type == 'chroma':
+                if len(search_type) != len(search_params):
+                    raise ValueError("Length of search_type and search_params should be equal")
+                if len(search_type) == 1:
+                    self.retriever = db.as_retriever(search_type=search_type[0],
+                                                     search_kwargs=search_params[0])
+                    if self.logger:
+                        self.logger.info("Retriever loaded")
+                    return self.retriever
+                else:
+                    retriever_list = []
+                    for i in range(len(search_type)):
+                        retriever_list.append(db.as_retriever(search_type=search_type[i],
+                                                              search_kwargs=search_params[i]))
+                    if self.logger:
+                        self.logger.info("List of Retriever loaded")
+                    return retriever_list
+        else:
+            return "Embeddings file not found"
+
+    def add_data(self, embeddings_folder_path: str, data: List[str],
+                 text_splitter_type: str = 'recursive_character',
+                 chunk_size: int = 1000, chunk_overlap: int = 5):
+        """
+        Add data to existing embeddings.
+
+        Args:
+            embeddings_folder_path (str): Path to embeddings folder
+            data (List[str]): List of text data to add
+            text_splitter_type (str): Type of text splitter
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+        """
+        if self.vector_store_type == 'chroma':
+            db = self.load_embeddings(embeddings_folder_path)
+            if db is not None:
+                docs = self.tokenize(data, text_splitter_type, chunk_size, chunk_overlap)
+                db.add_documents(docs)
+                if self.logger:
+                    self.logger.info("Data added to the existing db/embeddings")
+
+    def query_embeddings(self, query: str, retriever=None):
+        """
+        Query embeddings.
+
+        Args:
+            query (str): Query string
+            retriever: Optional retriever instance
+
+        Returns:
+            Any: Query results
+        """
+        if retriever is None:
+            retriever = self.retriever
+        return retriever.invoke(query)
+
+    def get_relevant_documents(self, query: str, retriever=None):
+        """
+        Get relevant documents for query.
+
+        Args:
+            query (str): Query string
+            retriever: Optional retriever instance
+
+        Returns:
+            List: List of relevant documents
+        """
+        if retriever is None:
+            retriever = self.retriever
+        return retriever.get_relevant_documents(query)
+
+    def generate_rag_chain(self, context_prompt: str = None, retriever=None, llm=None):
+        """
+        Generate RAG chain for conversation.
+
+        Args:
+            context_prompt (str): Optional context prompt
+            retriever: Optional retriever instance
+            llm: Optional language model instance
+
+        Returns:
+            Any: Generated RAG chain
+
+        Example:
+            ```python
+            rag_chain = gen.generate_rag_chain(retriever=retriever)
+            ```
+        """
+        if context_prompt is None:
+            context_prompt = ("You are an assistant for question-answering tasks. "
+                              "Use the following pieces of retrieved context to answer the question. "
+                              "If you don't know the answer, just say that you don't know. "
+                              "Use three sentences maximum and keep the answer concise.\n\n{context}")
+
+        contextualize_q_system_prompt = ("Given a chat history and the latest user question "
+                                         "which might reference context in the chat history, "
+                                         "formulate a standalone question which can be understood "
+                                         "without the chat history. Do NOT answer the question, "
+                                         "just reformulate it if needed and otherwise return it as is.")
+
+        contextualize_q_prompt = ChatPromptTemplate.from_messages([
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ])
+
+        if retriever is None:
+            retriever = self.retriever
+        if llm is None:
+            if not ModelProvider.check_package("langchain_openai"):
+                raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
+            from langchain_openai import ChatOpenAI
+            llm = ChatOpenAI(model="gpt-4")
+
+        history_aware_retriever = create_history_aware_retriever(llm, retriever,
+                                                                 contextualize_q_prompt)
+        qa_prompt = ChatPromptTemplate.from_messages([
+            ("system", context_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ])
+        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+        return rag_chain
+
+    def conversation_chain(self, query: str, rag_chain, file: str = None):
+        """
+        Create conversation chain.
+
+        Args:
+            query (str): User query
+            rag_chain: RAG chain instance
+            file (str): Optional file to save conversation
+
+        Returns:
+            List: Conversation history
+
+        Example:
+            ```python
+            history = gen.conversation_chain(
+                "Tell me about...",
+                rag_chain,
+                file='conversation.txt'
+            )
+            ```
+        """
+        if file is not None:
+            try:
+                chat_history = self.load_conversation(file, list_type=True)
+                if len(chat_history) == 0:
+                    chat_history = []
+            except:
+                chat_history = []
+        else:
+            chat_history = []
+
+        query = "You : " + query
+        res = rag_chain.invoke({"input": query, "chat_history": chat_history})
+        print(f"Response: {res['answer']}")
+        chat_history.append(HumanMessage(content=query))
+        chat_history.append(SystemMessage(content=res['answer']))
+        if file is not None:
+            self.save_conversation(chat_history, file)
+        return chat_history
+
+    def load_conversation(self, file: str, list_type: bool = False):
+        """
+        Load conversation history.
+
+        Args:
+            file (str): Path to conversation file
+            list_type (bool): Whether to return as list
+
+        Returns:
+            Union[str, List]: Conversation history
+        """
+        if list_type:
+            chat_history = []
+            with open(file, 'r') as f:
+                for line in f:
+                    chat_history.append(line.strip())
+        else:
+            with open(file, "r") as f:
+                chat_history = f.read()
+        return chat_history
+
+    def save_conversation(self, chat: Union[str, List], file: str):
+        """
+        Save conversation history.
+
+        Args:
+            chat (Union[str, List]): Conversation to save
+            file (str): Path to save file
+        """
+        if isinstance(chat, str):
+            with open(file, "a") as f:
+                f.write(chat)
+        elif isinstance(chat, list):
+            with open(file, "a") as f:
+                for i in chat[-2:]:
+                    f.write("%s\n" % i)
+        print(f"Saved file : {file}")
+
+    def firecrawl_web(self, website: str, api_key: str = None, mode: str = "scrape",
+                      file_to_save: str = './firecrawl_embeddings', **kwargs):
+        """
+        Get data from website using FireCrawl.
+
+        Args:
+            website (str): Website URL to crawl
+            api_key (str): Optional FireCrawl API key
+            mode (str): Crawl mode (default: "scrape")
+            file_to_save (str): Path to save embeddings
+            **kwargs: Additional arguments for FireCrawl
+
+        Returns:
+            Chroma: Vector store with crawled data
+
+        Example:
+            ```python
+            db = gen.firecrawl_web(
+                "https://example.com",
+                mode="scrape",
+                file_to_save='./crawl_embeddings'
+            )
+            ```
+        """
+        if not ModelProvider.check_package("firecrawl"):
+            raise ImportError("Firecrawl package not found. Please install: pip install firecrawl")
+
+        if api_key is None:
+            api_key = os.getenv("FIRECRAWL_API_KEY")
+
+        loader = FireCrawlLoader(api_key=api_key, url=website, mode=mode)
+        docs = loader.load()
+
+        for doc in docs:
+            for key, value in doc.metadata.items():
+                if isinstance(value, list):
+                    doc.metadata[key] = ", ".join(map(str, value))
+
+        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+        split_docs = text_splitter.split_documents(docs)
+
+        print("\n--- Document Chunks Information ---")
+        print(f"Number of document chunks: {len(split_docs)}")
+        print(f"Sample chunk:\n{split_docs[0].page_content}\n")
+
+        embeddings = self.model
+        db = Chroma.from_documents(split_docs, embeddings,
+                                   persist_directory=file_to_save)
+        print(f"Retriever saved at {file_to_save}")
+        return db
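One behavioural detail of the new module worth noting: `generate_text_embeddings` reports some failure modes ("File already exists", "Please provide text data path") by returning a status string rather than raising, so callers should check the return value. A minimal usage sketch under that assumption; the import path is inferred from the package layout visible in this diff, and the data paths are illustrative:

```python
# Hedged usage sketch for the embedding_generator added in 1.0.125.
# Requires langchain-openai and an OpenAI API key; paths are illustrative.
from mb_rag.rag.embeddings import embedding_generator  # module path assumed from the diff

gen = embedding_generator(model="openai", model_type="text-embedding-3-small")
status = gen.generate_text_embeddings(
    text_data_path=["./data/text.txt"],
    folder_save_path="./embeddings",
    replace_existing=False,
)
# Some failures come back as strings instead of exceptions; check explicitly.
if isinstance(status, str):
    print(f"Embedding generation did not run: {status}")
```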
@@ -1,5 +1,5 @@
 MAJOR_VERSION = 1
 MINOR_VERSION = 0
-PATCH_VERSION = 123
+PATCH_VERSION = 125
 version = '{}.{}.{}'.format(MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION)
 __all__ = ['MAJOR_VERSION', 'MINOR_VERSION', 'PATCH_VERSION', 'version']
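The version module assembles the dotted version string from these three constants, so after this change `version` evaluates to '1.0.125':

```python
# Reproduces the version-string construction from the module above.
MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION = 1, 0, 125
version = '{}.{}.{}'.format(MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION)
assert version == '1.0.125'
```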
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mb_rag
-Version: 1.0.123
+Version: 1.0.125
 Summary: RAG function file
 Author: ['Malav Bateriwala']
 Requires-Python: >=3.8
@@ -1,480 +0,0 @@
-## Function to generate embeddings for the RAG model
-
-import os
-import shutil
-import importlib.util
-from langchain.text_splitter import (
-    CharacterTextSplitter,
-    RecursiveCharacterTextSplitter,
-    SentenceTransformersTokenTextSplitter,
-    TokenTextSplitter)
-from langchain_community.document_loaders import TextLoader,FireCrawlLoader
-from langchain_community.vectorstores import Chroma
-from ..utils.extra import load_env_file
-from langchain.chains import create_history_aware_retriever, create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-import time
-
-load_env_file()
-
-__all__ = ['embedding_generator', 'load_rag_model']
-
-def check_package(package_name):
-    """
-    Check if a package is installed
-    Args:
-        package_name (str): Name of the package
-    Returns:
-        bool: True if package is installed, False otherwise
-    """
-    return importlib.util.find_spec(package_name) is not None
-
-def get_rag_openai(model_type: str = 'text-embedding-3-small',**kwargs):
-    """
-    Load model from openai for RAG
-    Args:
-        model_type (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatOpenAI: Chatbot model
-    """
-    if not check_package("langchain_openai"):
-        raise ImportError("OpenAI package not found. Please install it using: pip install langchain-openai")
-
-    from langchain_openai import OpenAIEmbeddings
-    return OpenAIEmbeddings(model = model_type,**kwargs)
-
-def get_rag_ollama(model_type: str = 'llama3',**kwargs):
-    """
-    Load model from ollama for RAG
-    Args:
-        model_type (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        OllamaEmbeddings: Embeddings model
-    """
-    if not check_package("langchain_ollama"):
-        raise ImportError("Ollama package not found. Please install it using: pip install langchain-ollama")
-
-    from langchain_ollama import OllamaEmbeddings
-    return OllamaEmbeddings(model = model_type,**kwargs)
-
-def get_rag_anthropic(model_name: str = "claude-3-opus-20240229",**kwargs):
-    """
-    Load the chatbot model from Anthropic
-    Args:
-        model_name (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatAnthropic: Chatbot model
-    """
-    if not check_package("langchain_anthropic"):
-        raise ImportError("Anthropic package not found. Please install it using: pip install langchain-anthropic")
-
-    from langchain_anthropic import ChatAnthropic
-    kwargs["model_name"] = model_name
-    return ChatAnthropic(**kwargs)
-
-def get_rag_google(model_name: str = "gemini-1.5-flash",**kwargs):
-    """
-    Load the chatbot model from Google Generative AI
-    Args:
-        model_name (str): Name of the model
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        ChatGoogleGenerativeAI: Chatbot model
-    """
-    if not check_package("google.generativeai"):
-        raise ImportError("Google Generative AI package not found. Please install it using: pip install langchain-google-genai")
-
-    from langchain_google_genai import GoogleGenerativeAIEmbeddings
-    kwargs["model"] = model_name
-    return GoogleGenerativeAIEmbeddings(**kwargs)
-
-def load_rag_model(model_name: str ='openai', model_type: str = "text-embedding-ada-002", **kwargs):
-    """
-    Load a RAG model from a given model name and type
-    Args:
-        model_name (str): Name of the model. Default is openai.
-        model_type (str): Type of the model. Default is text-embedding-ada-002.
-        **kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
-    Returns:
-        RAGModel: RAG model
-    """
-    try:
-        if model_name == 'openai':
-            return get_rag_openai(model_type, **(kwargs or {}))
-        elif model_name == 'ollama':
-            return get_rag_ollama(model_type, **(kwargs or {}))
-        elif model_name == 'google':
-            return get_rag_google(model_type, **(kwargs or {}))
-        elif model_name == 'anthropic':
-            return get_rag_anthropic(model_type, **(kwargs or {}))
-        else:
-            raise ValueError(f"Invalid model name: {model_name}")
-    except ImportError as e:
-        print(f"Error loading model: {str(e)}")
-        return None
-
-class embedding_generator:
-    """
-    Class to generate embeddings for the RAG model abnd chat with data
-    Args:
-        model: type of model. Default is openai. Options are openai, anthropic, google, ollama
-        model_type: type of model. Default is text-embedding-3-small. Options are text-embedding-3-small, text-embedding-3-large, text-embedding-ada-002 for openai.
-        vector_store_type: type of vector store. Default is chroma
-        logger: logger
-        model_kwargs: additional arguments for the model
-        vector_store_kwargs: additional arguments for the vector store
-        collection_name: name of the collection (default : test)
-    """
-
-    def __init__(self,model: str = 'openai',model_type: str = 'text-embedding-3-small',vector_store_type:str = 'chroma' ,collection_name: str = 'test',logger= None,model_kwargs: dict = None, vector_store_kwargs: dict = None) -> None:
-        self.logger = logger
-        self.model = load_rag_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
-        if self.model is None:
-            raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
-        self.vector_store_type = vector_store_type
-        self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
-        self.collection_name = collection_name
-
-    def check_file(self, file_path):
-        """
-        Check if the file exists
-        """
-        if os.path.exists(file_path):
-            return True
-        else:
-            return False
-
-    def tokenize(self,text_data_path :list,text_splitter_type: str,chunk_size: int,chunk_overlap: int):
-        """
-        Function to tokenize the text
-        Args:
-            text: text to tokenize
-        Returns:
-            tokens
-        """
-        doc_data = []
-        for i in text_data_path:
-            if self.check_file(i):
-                text_loader = TextLoader(i)
-                get_text = text_loader.load()
-                # print(get_text) ## testing - Need to remove
-                file_name = i.split('/')[-1]
-                metadata = {'source': file_name}
-                if metadata is not None:
-                    for j in get_text:
-                        j.metadata = metadata
-                        doc_data.append(j)
-                if self.logger is not None:
-                    self.logger.info(f"Text data loaded from {file_name}")
-            else:
-                return f"File {i} not found"
-
-        if self.logger is not None:
-            self.logger.info(f"Splitting text data into chunks of size {chunk_size} with overlap {chunk_overlap}")
-        if text_splitter_type == 'character':
-            text_splitter = CharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=chunk_overlap, separator=["\n","\n\n","\n\n\n"," "])
-        if text_splitter_type == 'recursive_character':
-            text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=chunk_overlap,separators=["\n","\n\n","\n\n\n"," "])
-        if text_splitter_type == 'sentence_transformers_token':
-            text_splitter = SentenceTransformersTokenTextSplitter(chunk_size=chunk_size)
-        if text_splitter_type == 'token':
-            text_splitter = TokenTextSplitter(chunk_size=chunk_size,chunk_overlap=chunk_overlap)
-        docs = text_splitter.split_documents(doc_data)
-        if self.logger is not None:
-            self.logger.info(f"Text data splitted into {len(docs)} chunks")
-        else:
-            print(f"Text data splitted into {len(docs)} chunks")
-        return docs
-
-    def generate_text_embeddings(self,text_data_path: list = None,text_splitter_type: str = 'recursive_character',
-                                 chunk_size: int = 1000,chunk_overlap: int = 5,folder_save_path: str = './text_embeddings',
-                                 replace_existing: bool = False):
-        """
-        Function to generate text embeddings
-        Args:
-            text_data_path: list of text files
-            # metadata: list of metadata for each text file. Dictionary format
-            text_splitter_type: type of text splitter. Default is recursive_character
-            chunk_size: size of the chunk
-            chunk_overlap: overlap between chunks
-            folder_save_path: path to save the embeddings
-            replace_existing: if True, replace the existing embeddings
-        Returns:
-            None
-        """
-
-        if self.logger is not None:
-            self.logger.info("Perforing basic checks")
-
-        if self.check_file(folder_save_path) and replace_existing==False:
-            return "File already exists"
-        elif self.check_file(folder_save_path) and replace_existing:
-            shutil.rmtree(folder_save_path)
-
-        if text_data_path is None:
-            return "Please provide text data path"
-
-        assert isinstance(text_data_path, list), "text_data_path should be a list"
-        # if metadata is not None:
-        #     assert isinstance(metadata, list), "metadata should be a list"
-        #     assert len(text_data_path) == len(metadata), "Number of text files and metadata should be equal"
-
-        if self.logger is not None:
-            self.logger.info(f"Loading text data from {text_data_path}")
-
-        docs = self.tokenize(text_data_path,text_splitter_type,chunk_size,chunk_overlap)
-
-        if self.logger is not None:
-            self.logger.info(f"Generating embeddings for {len(docs)} documents")
-
-        self.vector_store.from_documents(docs, self.model,collection_name=self.collection_name,persist_directory=folder_save_path)
-
-        if self.logger is not None:
-            self.logger.info(f"Embeddings generated and saved at {folder_save_path}")
-
-    def load_vectorstore(self):
-        """
-        Function to load vector store
-        Args:
-            vector_store_type: type of vector store
-        Returns:
-            vector store
-        """
-        if self.vector_store_type == 'chroma':
-            vector_store = Chroma()
-            if self.logger is not None:
-                self.logger.info(f"Loaded vector store {self.vector_store_type}")
-            return vector_store
-        else:
-            return "Vector store not found"
-
-    def load_embeddings(self,embeddings_folder_path: str):
-        """
-        Function to load embeddings from the folder
-        Args:
-            embeddings_path: path to the embeddings
-        Returns:
-            embeddings
-        """
-        if self.check_file(embeddings_folder_path):
-            if self.vector_store_type == 'chroma':
-                # embeddings_path = os.path.join(embeddings_folder_path)
-                return Chroma(persist_directory = embeddings_folder_path,embedding_function=self.model)
-        else:
-            if self.logger:
-                self.logger.info("Embeddings file not found")
-            return None
-
-    def load_retriever(self,embeddings_folder_path: str,search_type: list = ["similarity_score_threshold"],search_params: list = [{"k": 3, "score_threshold": 0.9}]):
-        """
-        Function to load retriever
-        Args:
-            embeddings_path: path to the embeddings
-            search_type: list of str: type of search. Default : ["similarity_score_threshold"]
-            search_params: list of dict: parameters for the search. Default : [{"k": 3, "score_threshold": 0.9}]
-        Returns:
-            Retriever. If multiple search types are provided, a list of retrievers is returned
-        """
-        db = self.load_embeddings(embeddings_folder_path)
-        if db is not None:
-            if self.vector_store_type == 'chroma':
-                assert len(search_type) == len(search_params), "Length of search_type and search_params should be equal"
-                if len(search_type) == 1:
-                    self.retriever = db.as_retriever(search_type = search_type[0],search_kwargs=search_params[0])
-                    if self.logger:
-                        self.logger.info("Retriever loaded")
-                    return self.retriever
-                else:
-                    retriever_list = []
-                    for i in range(len(search_type)):
-                        retriever_list.append(db.as_retriever(search_type = search_type[i],search_kwargs=search_params[i]))
-                    if self.logger:
-                        self.logger.info("List of Retriever loaded")
-                    return retriever_list
-        else:
-            return "Embeddings file not found"
-
-    def add_data(self,embeddings_folder_path: str, data: list,text_splitter_type: str = 'recursive_character',
-                 chunk_size: int = 1000,chunk_overlap: int = 5):
-        """
-        Function to add data to the existing db/embeddings
-        Args:
-            embeddings_path: path to the embeddings
-            data: list of data to add
-            text_splitter_type: type of text splitter. Default is recursive_character
-            chunk_size: size of the chunk
-            chunk_overlap: overlap between chunks
-        Returns:
-            None
-        """
-        if self.vector_store_type == 'chroma':
-            db = self.load_embeddings(embeddings_folder_path)
-            if db is not None:
-                docs = self.tokenize(data,text_splitter_type,chunk_size,chunk_overlap)
-                db.add_documents(docs)
-                if self.logger:
-                    self.logger.info("Data added to the existing db/embeddings")
-
-    def query_embeddings(self,query: str,retriever = None):
-        """
-        Function to query embeddings
-        Args:
-            search_type: type of search
-            query: query to search
-        Returns:
-            results
-        """
-        # if self.vector_store_type == 'chroma':
-        if retriever is None:
-            retriever = self.retriever
-        return retriever.invoke(query)
-        # else:
-        #     return "Vector store not found"
-
-    def get_relevant_documents(self,query: str,retriever = None):
-        """
-        Function to get relevant documents
-        Args:
-            query: query to search
-        Returns:
-            results
-        """
-        return self.retriever.get_relevant_documents(query)
-
-    def generate_rag_chain(self,context_prompt: str = None,retriever = None,llm= None):
-        """
-        Function to start a conversation chain with a rag data. Call this to load a rag_chain module.
-        Args:
-            context_prompt: prompt to context
-            retriever: retriever. Default is None.
-            llm: language model. Default is openai. Need chat model llm. "ChatOpenAI", "ChatAnthropic" etc. like chatbot
-        Returns:
-            rag_chain_model.
-        """
-        if context_prompt is None:
-            context_prompt = ("You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. "
-                              "If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. "
-                              "\n\n {context}")
-        contextualize_q_system_prompt = ("Given a chat history and the latest user question which might reference context in the chat history, formulate a standalone question which can be understood "
-                                         "without the chat history. Do NOT answer the question, just reformulate it if needed and otherwise return it as is.")
-        contextualize_q_prompt = ChatPromptTemplate.from_messages([("system", contextualize_q_system_prompt),MessagesPlaceholder("chat_history"),("human", "{input}"),])
-
-        if retriever is None:
-            retriever = self.retriever
-        if llm is None:
-            if not check_package("langchain_openai"):
-                raise ImportError("OpenAI package not found. Please install it using: pip install langchain-openai")
-            from langchain_openai import ChatOpenAI
-            llm = ChatOpenAI(model="gpt-4")
-
-        history_aware_retriever = create_history_aware_retriever(llm,retriever, contextualize_q_prompt)
-        qa_prompt = ChatPromptTemplate.from_messages([("system", context_prompt),MessagesPlaceholder("chat_history"),("human", "{input}"),])
-        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-        return rag_chain
-
-    def conversation_chain(self,query: str,rag_chain,file:str =None):
-        """
-        Function to create a conversation chain
-        Args:
-            query: query to search
-            rag_chain : rag_chain model
-            file: load a file and update it with the conversation. If None it will not be saved.
-        Returns:
-            results
-        """
-        if file is not None:
-            try:
-                chat_history = self.load_conversation(file,list_type=True)
-                if len(chat_history) == 0:
-                    chat_history = []
-            except:
-                chat_history = []
-        else:
-            chat_history = []
-        query = "You : " + query
-        res = rag_chain.invoke({"input": query,"chat_history": chat_history})
-        print(f"Response: {res['answer']}")
-        chat_history.append(HumanMessage(content=query))
-        chat_history.append(SystemMessage(content=res['answer']))
-        if file is not None:
-            self.save_conversation(chat_history,file)
-        return chat_history
-
-    def load_conversation(self,file: str,list_type: bool = False):
-        """
-        Function to load the conversation
-        Args:
-            file: file to load
-            list_type: if True, return the chat_history as a list. Default is False.
-        Returns:
-            chat_history
-        """
-        if list_type:
-            chat_history = []
-            with open(file,'r') as f:
-                for line in f:
-                    # inner_list = [elt.strip() for elt in line.split(',')]
-                    chat_history.append(line.strip())
-        else:
-            with open(file, "r") as f:
-                chat_history = f.read()
-        return chat_history
-
-    def save_conversation(self,chat: str,file: str):
-        """
-        Function to save the conversation
-        Args:
-            chat: chat results
-            file: file to save
-        Returns:
-            None
-        """
-        if isinstance(chat,str):
-            with open(file, "a") as f:
-                f.write(chat)
-        elif isinstance(chat,list):
-            with open(file, "a") as f:
-                for i in chat[-2:]:
-                    f.write("%s\n" % i)
-        print(f"Saved file : {file}")
-
-    def firecrawl_web(self, website, api_key: str = None, mode="scrape", file_to_save: str = './firecrawl_embeddings',**kwargs):
-        """
-        Function to get data from website. Use this to get data from a website and save it as embeddings/retriever. To ask questions from the website,
-        use the load_retriever and query_embeddings function.
-        Args:
-            website : str - link to website.
-            api_key : api key of firecrawl, if None environment variable "FIRECRAWL_API_KEY" will be used.
-            mode(str) : 'scrape' default to just use the same page. Not the whole website.
-            file_to_save: path to save the embeddings
-            **kwargs: additional arguments
-        Returns:
-            retriever
-        """
-        if not check_package("firecrawl"):
-            raise ImportError("Firecrawl package not found. Please install it using: pip install firecrawl")
-
-        if api_key is None:
-            api_key = os.getenv("FIRECRAWL_API_KEY")
-        loader = FireCrawlLoader(api_key=api_key, url=website, mode=mode)
-        docs = loader.load()
-        for doc in docs:
-            for key, value in doc.metadata.items():
-                if isinstance(value, list):
-                    doc.metadata[key] = ", ".join(map(str, value))
-        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-        split_docs = text_splitter.split_documents(docs)
-        print("\n--- Document Chunks Information ---")
-        print(f"Number of document chunks: {len(split_docs)}")
-        print(f"Sample chunk:\n{split_docs[0].page_content}\n")
-        embeddings = self.model
-        db = Chroma.from_documents(
-            split_docs, embeddings, persist_directory=file_to_save)
-        print(f"Retriever saved at {file_to_save}")
-        return db
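Taken together with the new module above, this removed file shows the public surface that changed between the two versions: the module-level helpers (`check_package`, `get_rag_*`) were folded into the `ModelProvider` class, and the exported loader was renamed from `load_rag_model` to `load_embedding_model` (compare the two `__all__` lines). A hedged migration sketch for code written against 1.0.123; the import paths are assumptions based on the package layout visible in this diff:

```python
# Hedged migration sketch: 1.0.123 -> 1.0.125 public API changes.

# 1.0.123:
# from mb_rag.rag.embeddings import load_rag_model, check_package
# model = load_rag_model("openai", "text-embedding-3-small")

# 1.0.125: the loader was renamed, and the provider helpers are static methods
from mb_rag.rag.embeddings import ModelProvider, load_embedding_model

model = load_embedding_model("openai", "text-embedding-3-small")
has_openai = ModelProvider.check_package("langchain_openai")  # replaces check_package()
```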
6 files without changes