ai-parrot 0.3.4__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ai-parrot might be problematic.

Files changed (109)
  1. ai_parrot-0.3.4.dist-info/LICENSE +21 -0
  2. ai_parrot-0.3.4.dist-info/METADATA +319 -0
  3. ai_parrot-0.3.4.dist-info/RECORD +109 -0
  4. ai_parrot-0.3.4.dist-info/WHEEL +6 -0
  5. ai_parrot-0.3.4.dist-info/top_level.txt +3 -0
  6. parrot/__init__.py +21 -0
  7. parrot/chatbots/__init__.py +7 -0
  8. parrot/chatbots/abstract.py +728 -0
  9. parrot/chatbots/asktroc.py +16 -0
  10. parrot/chatbots/base.py +366 -0
  11. parrot/chatbots/basic.py +9 -0
  12. parrot/chatbots/bose.py +17 -0
  13. parrot/chatbots/cody.py +17 -0
  14. parrot/chatbots/copilot.py +83 -0
  15. parrot/chatbots/dataframe.py +103 -0
  16. parrot/chatbots/hragents.py +15 -0
  17. parrot/chatbots/odoo.py +17 -0
  18. parrot/chatbots/retrievals/__init__.py +578 -0
  19. parrot/chatbots/retrievals/constitutional.py +19 -0
  20. parrot/conf.py +110 -0
  21. parrot/crew/__init__.py +3 -0
  22. parrot/crew/tools/__init__.py +22 -0
  23. parrot/crew/tools/bing.py +13 -0
  24. parrot/crew/tools/config.py +43 -0
  25. parrot/crew/tools/duckgo.py +62 -0
  26. parrot/crew/tools/file.py +24 -0
  27. parrot/crew/tools/google.py +168 -0
  28. parrot/crew/tools/gtrends.py +16 -0
  29. parrot/crew/tools/md2pdf.py +25 -0
  30. parrot/crew/tools/rag.py +42 -0
  31. parrot/crew/tools/search.py +32 -0
  32. parrot/crew/tools/url.py +21 -0
  33. parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
  34. parrot/handlers/__init__.py +4 -0
  35. parrot/handlers/bots.py +196 -0
  36. parrot/handlers/chat.py +162 -0
  37. parrot/interfaces/__init__.py +6 -0
  38. parrot/interfaces/database.py +29 -0
  39. parrot/llms/__init__.py +137 -0
  40. parrot/llms/abstract.py +47 -0
  41. parrot/llms/anthropic.py +42 -0
  42. parrot/llms/google.py +42 -0
  43. parrot/llms/groq.py +45 -0
  44. parrot/llms/hf.py +45 -0
  45. parrot/llms/openai.py +59 -0
  46. parrot/llms/pipes.py +114 -0
  47. parrot/llms/vertex.py +78 -0
  48. parrot/loaders/__init__.py +20 -0
  49. parrot/loaders/abstract.py +456 -0
  50. parrot/loaders/audio.py +106 -0
  51. parrot/loaders/basepdf.py +102 -0
  52. parrot/loaders/basevideo.py +280 -0
  53. parrot/loaders/csv.py +42 -0
  54. parrot/loaders/dir.py +37 -0
  55. parrot/loaders/excel.py +349 -0
  56. parrot/loaders/github.py +65 -0
  57. parrot/loaders/handlers/__init__.py +5 -0
  58. parrot/loaders/handlers/data.py +213 -0
  59. parrot/loaders/image.py +119 -0
  60. parrot/loaders/json.py +52 -0
  61. parrot/loaders/pdf.py +437 -0
  62. parrot/loaders/pdfchapters.py +142 -0
  63. parrot/loaders/pdffn.py +112 -0
  64. parrot/loaders/pdfimages.py +207 -0
  65. parrot/loaders/pdfmark.py +88 -0
  66. parrot/loaders/pdftables.py +145 -0
  67. parrot/loaders/ppt.py +30 -0
  68. parrot/loaders/qa.py +81 -0
  69. parrot/loaders/repo.py +103 -0
  70. parrot/loaders/rtd.py +65 -0
  71. parrot/loaders/txt.py +92 -0
  72. parrot/loaders/utils/__init__.py +1 -0
  73. parrot/loaders/utils/models.py +25 -0
  74. parrot/loaders/video.py +96 -0
  75. parrot/loaders/videolocal.py +120 -0
  76. parrot/loaders/vimeo.py +106 -0
  77. parrot/loaders/web.py +216 -0
  78. parrot/loaders/web_base.py +112 -0
  79. parrot/loaders/word.py +125 -0
  80. parrot/loaders/youtube.py +192 -0
  81. parrot/manager.py +166 -0
  82. parrot/models.py +372 -0
  83. parrot/py.typed +0 -0
  84. parrot/stores/__init__.py +48 -0
  85. parrot/stores/abstract.py +171 -0
  86. parrot/stores/milvus.py +632 -0
  87. parrot/stores/qdrant.py +153 -0
  88. parrot/tools/__init__.py +12 -0
  89. parrot/tools/abstract.py +53 -0
  90. parrot/tools/asknews.py +32 -0
  91. parrot/tools/bing.py +13 -0
  92. parrot/tools/duck.py +62 -0
  93. parrot/tools/google.py +170 -0
  94. parrot/tools/stack.py +26 -0
  95. parrot/tools/weather.py +70 -0
  96. parrot/tools/wikipedia.py +59 -0
  97. parrot/tools/zipcode.py +179 -0
  98. parrot/utils/__init__.py +2 -0
  99. parrot/utils/parsers/__init__.py +5 -0
  100. parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
  101. parrot/utils/toml.py +11 -0
  102. parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
  103. parrot/utils/uv.py +11 -0
  104. parrot/version.py +10 -0
  105. resources/users/__init__.py +5 -0
  106. resources/users/handlers.py +13 -0
  107. resources/users/models.py +205 -0
  108. settings/__init__.py +0 -0
  109. settings/settings.py +51 -0
parrot/llms/pipes.py ADDED
@@ -0,0 +1,114 @@
+ import torch
+ from langchain_community.llms import HuggingFacePipeline  # pylint: disable=import-error, E0611
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoProcessor,
+     LlavaForConditionalGeneration,
+     AutoTokenizer,
+     GenerationConfig,
+     pipeline
+ )
+ from .abstract import AbstractLLM
+
+
+ class PipelineLLM(AbstractLLM):
+     """PipelineLLM.
+
+     Load an LLM (Language Model) from the HuggingFace Hub.
+
+     Returns:
+         _type_: an instance of a HuggingFace LLM model.
+     """
+     model: str = "databricks/dolly-v2-3b"
+     embed_model: str = None
+     max_tokens: int = 1024
+     supported_models: list = [
+         "databricks/dolly-v2-3b",
+         "gpt2",
+         "bigscience/bloom-1b7",
+         "meta-llama/Llama-2-7b-hf",
+         'llava-hf/llava-1.5-7b-hf'
+     ]
+
+     def __init__(self, *args, **kwargs):
+         self.batch_size = kwargs.get('batch_size', 4)
+         self.use_llava: bool = kwargs.get('use_llava', False)
+         self.model_args = kwargs.get('model_args', {})
+         super().__init__(*args, **kwargs)
+         dtype = kwargs.get('dtype', 'float16')
+         if dtype == 'bfloat16':
+             torch_dtype = torch.bfloat16
+         elif dtype == 'float16':
+             torch_dtype = torch.float16
+         elif dtype == 'float32':
+             torch_dtype = torch.float32
+         elif dtype == 'float8':
+             # torch exposes no plain float8; e4m3fn is the closest dtype (torch >= 2.1)
+             torch_dtype = torch.float8_e4m3fn
+         else:
+             torch_dtype = "auto"
+         use_fast = kwargs.get('use_fast', True)
+         if self.use_llava is False:
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 self.model,
+                 chunk_size=self.max_tokens
+             )
+             self._model = AutoModelForCausalLM.from_pretrained(
+                 self.model,
+                 device_map="auto",
+                 torch_dtype=torch_dtype,
+                 trust_remote_code=True,
+             )
+             config = GenerationConfig(
+                 do_sample=True,
+                 temperature=self.temperature,
+                 max_new_tokens=self.max_tokens,
+                 top_p=self.top_p,
+                 top_k=self.top_k,
+                 repetition_penalty=1.15,
+             )
+             self._pipe = pipeline(
+                 task=self.task,
+                 model=self._model,
+                 tokenizer=self.tokenizer,
+                 return_full_text=True,
+                 use_fast=use_fast,
+                 device_map='auto',
+                 batch_size=self.batch_size,
+                 generation_config=config,
+                 pad_token_id=50256,
+                 framework="pt"
+             )
+         else:
+             self._model = LlavaForConditionalGeneration.from_pretrained(
+                 self.model,
+                 device_map="auto",
+                 torch_dtype=torch_dtype,
+                 trust_remote_code=True,
+                 low_cpu_mem_usage=True,
+             )
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model)
+             processor = AutoProcessor.from_pretrained(self.model)
+             self._pipe = pipeline(
+                 task=self.task,
+                 model=self._model,
+                 tokenizer=self.tokenizer,
+                 use_fast=use_fast,
+                 device_map='auto',
+                 batch_size=self.batch_size,
+                 image_processor=processor.image_processor,
+                 framework="pt",
+                 **self.model_args
+             )
+         self._pipe.tokenizer.pad_token_id = self._pipe.model.config.eos_token_id
+         self._llm = HuggingFacePipeline(
+             model_id=self.model,
+             pipeline=self._pipe,
+             verbose=True
+         )
+
+     def pipe(self, *args, **kwargs):
+         return self._pipe(
+             *args,
+             **kwargs,
+             generate_kwargs={"max_new_tokens": self.max_tokens}
+         )
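
As a reading aid (not part of the package diff): a minimal usage sketch for PipelineLLM. The keyword names task, temperature, top_p and top_k are assumptions here; they are read from self in the constructor above but are defined by AbstractLLM (parrot/llms/abstract.py), which this hunk does not show.

    # Hypothetical usage sketch; keyword names inferred, not documented.
    from parrot.llms.pipes import PipelineLLM

    llm = PipelineLLM(
        model="databricks/dolly-v2-3b",  # one of PipelineLLM.supported_models
        task="text-generation",          # assumed to be consumed by AbstractLLM
        dtype="float16",                 # selects torch_dtype in __init__
        batch_size=2,
    )
    # _llm wraps the transformers pipeline in a LangChain HuggingFacePipeline,
    # so the standard Runnable interface should apply:
    print(llm._llm.invoke("Explain what a tokenizer does in one sentence."))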
parrot/llms/vertex.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ from navconfig import config, BASE_DIR
+ from google.cloud import aiplatform
+ from langchain_google_vertexai import (
+     ChatVertexAI,
+     VertexAI,
+     VertexAIModelGarden,
+     VertexAIEmbeddings
+ )
+ from .abstract import AbstractLLM
+
+
+ class VertexLLM(AbstractLLM):
+     """VertexLLM.
+
+     Interact with the VertexAI Language Model.
+
+     Returns:
+         _type_: VertexAI LLM.
+     """
+     model: str = "gemini-1.0-pro"
+     embed_model: str = "textembedding-gecko@003"
+     max_tokens: int = 1024
+     supported_models: list = [
+         "gemini-1.0-pro",
+         "gemini-1.5-pro-001",
+         "gemini-1.5-pro-exp-0801",
+         "gemini-1.5-flash-preview-0514",
+         "gemini-1.5-flash-001",
+         "chat-bison@001",
+         "claude-3-opus@20240229",
+         'claude-3-5-sonnet@20240620'
+     ]
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         use_garden: bool = kwargs.get("use_garden", False)
+         project_id = config.get("VERTEX_PROJECT_ID")
+         region = config.get("VERTEX_REGION")
+         config_file = config.get('GOOGLE_CREDENTIALS_FILE', 'env/google/vertexai.json')
+         config_dir = BASE_DIR.joinpath(config_file)
+         os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(config_dir)
+         self.args = {
+             "project": project_id,
+             "location": region,
+             "max_output_tokens": self.max_tokens,
+             "temperature": self.temperature,
+             "max_retries": 4,
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+             "verbose": True,
+         }
+         if use_garden is True:
+             base_llm = VertexAIModelGarden
+             self.args['endpoint_id'] = self.model
+         elif self.model == "chat":
+             self.args['model_name'] = "chat-bison@001"
+             base_llm = ChatVertexAI
+         else:
+             self.args['model_name'] = self.model
+             base_llm = VertexAI
+         # LLM
+         self._llm = base_llm(
+             system_prompt="Always respond in the same language as the user's question. If the user's language is not English, translate your response into their language.",
+             **self.args
+         )
+         # Embedding Model:
+         embed_model = kwargs.get("embed_model", self.embed_model)
+         self._embed = VertexAIEmbeddings(
+             model_name=embed_model,
+             project=project_id,
+             location=region,
+             request_parallelism=5,
+             max_retries=4,
+             temperature=self.temperature,
+             top_p=self.top_p,
+             top_k=self.top_k,
+         )
+         self._version_ = aiplatform.__version__
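
Likewise, a hedged construction sketch for VertexLLM. It assumes VERTEX_PROJECT_ID, VERTEX_REGION and the service-account JSON are resolvable through navconfig, as the constructor above requires.

    # Hypothetical usage sketch; configuration keys are the ones read in __init__.
    from parrot.llms.vertex import VertexLLM

    llm = VertexLLM(
        model="gemini-1.5-pro-001",  # one of VertexLLM.supported_models
    )
    # _llm holds the LangChain VertexAI model built above; _embed the embeddings.
    answer = llm._llm.invoke("Summarize the purpose of this package in one line.")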
parrot/loaders/__init__.py ADDED
@@ -0,0 +1,20 @@
+ from .dir import load_directory
+ from .pdf import PDFLoader
+ from .web import WebLoader
+ from .youtube import YoutubeLoader
+ from .vimeo import VimeoLoader
+ from .word import MSWordLoader
+ from .ppt import PPTXLoader
+ from .repo import RepositoryLoader
+ from .github import GithubLoader
+ from .json import JSONLoader
+ from .excel import ExcelLoader
+ from .web_base import WebBaseLoader
+ from .pdfmark import PDFMarkdownLoader
+ from .pdfimages import PDFImageLoader
+ from .pdftables import PDFTablesLoader
+ from .pdfchapters import PDFChapterLoader
+ from .txt import TXTLoader
+ from .qa import QAFileLoader
+ from .rtd import ReadTheDocsLoader
+ from .videolocal import VideoLocalLoader
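
These re-exports make every loader importable from the package root; for example (hypothetical usage, names as imported above):

    from parrot.loaders import PDFLoader, WebLoader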
parrot/loaders/abstract.py ADDED
@@ -0,0 +1,456 @@
+ """Loaders are classes that are responsible for loading data from a source
+ and returning it as a Langchain Document.
+ """
+ from __future__ import annotations
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable
+ from typing import List, Union, Optional, Any
+ from pathlib import Path, PurePath
+ import torch
+ from langchain.docstore.document import Document
+ from langchain.chains.summarize import load_summarize_chain
+ from langchain.text_splitter import (
+     RecursiveCharacterTextSplitter,
+     TokenTextSplitter
+ )
+ from langchain_core.prompts import PromptTemplate
+ from langchain_core.document_loaders.blob_loaders import Blob
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+
+ from navconfig.logging import logging
+ from navigator.libs.json import JSONContent  # pylint: disable=E0611
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSeq2SeqLM,
+     AutoConfig,
+     AutoModel,
+     pipeline
+ )
+ from ..conf import EMBEDDING_DEVICE, EMBEDDING_DEFAULT_MODEL
+
+
+ logging.getLogger(name='httpx').setLevel(logging.WARNING)
+ logging.getLogger(name='httpcore').setLevel(logging.WARNING)
+ logging.getLogger(name='pdfminer').setLevel(logging.WARNING)
+ logging.getLogger(name='langchain_community').setLevel(logging.WARNING)
+ logging.getLogger(name='numba').setLevel(logging.WARNING)
+ logging.getLogger(name='PIL').setLevel(level=logging.WARNING)
+
+
+ def as_string(self) -> str:
+     """Read data as a string."""
+     if self.data is None and self.path:
+         with open(str(self.path), "r", encoding=self.encoding) as f:
+             try:
+                 return f.read()
+             except UnicodeDecodeError:
+                 try:
+                     with open(str(self.path), "r", encoding="latin-1") as f:
+                         return f.read()
+                 except UnicodeDecodeError:
+                     with open(str(self.path), "rb") as f:
+                         return f.read().decode("utf-8", "replace")
+     elif isinstance(self.data, bytes):
+         return self.data.decode(self.encoding)
+     elif isinstance(self.data, str):
+         return self.data
+     else:
+         raise ValueError(f"Unable to get string for blob {self}")
+
+ # Monkey patch the Blob class's as_string method
+ Blob.as_string = as_string
+
+
+ class AbstractLoader(ABC):
+     """
+     Abstract class for Document loaders.
+     """
+     _extension: List[str] = ['.txt']
+     encoding: str = 'utf-8'
+     skip_directories: List[str] = []
+     _chunk_size: int = 768
+
+     def __init__(
+         self,
+         tokenizer: Union[str, Callable] = None,
+         text_splitter: Union[str, Callable] = None,
+         source_type: str = 'file',
+         **kwargs
+     ):
+         self.tokenizer = tokenizer
+         self._summary_model = None
+         self.text_splitter = text_splitter
+         self._device = self._get_device()
+         self._chunk_size = kwargs.get('chunk_size', 768)
+         self._no_summarization = bool(
+             kwargs.get('no_summarization', False)
+         )
+         self.summarization_model = kwargs.get(
+             'summarization_model',
+             "facebook/bart-large-cnn"
+         )
+         self._source_type = source_type
+         self.logger = logging.getLogger(
+             f"Loader.{self.__class__.__name__}"
+         )
+         if 'extension' in kwargs:
+             self._extension = kwargs['extension']
+         self.encoding = kwargs.get('encoding', 'utf-8')
+         self.skip_directories: List[str] = kwargs.get('skip_directories', [])
+         # LLM (if required)
+         self._llm = kwargs.get('llm', None)
+         if not self.tokenizer:
+             self.tokenizer = self.default_tokenizer()
+         elif isinstance(self.tokenizer, str):
+             self.tokenizer = self.get_tokenizer(
+                 model_name=self.tokenizer
+             )
+         if not text_splitter:
+             self.text_splitter = self.default_splitter(
+                 model=self.tokenizer
+             )
+         # JSON encoder:
+         self._encoder = JSONContent()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, *exc_info):
+         self.post_load()
+
+     def default_tokenizer(self):
+         return self.get_tokenizer(
+             EMBEDDING_DEFAULT_MODEL,
+             chunk_size=768
+         )
+
+     def get_tokenizer(self, model_name: str, chunk_size: int = 768):
+         return AutoTokenizer.from_pretrained(
+             model_name,
+             chunk_size=chunk_size
+         )
+
+     def _get_device(self, cuda_number: int = 0):
+         if torch.cuda.is_available():
+             # Use CUDA GPU if available
+             device = torch.device(f'cuda:{cuda_number}')
+         elif torch.backends.mps.is_available():
+             # Use Apple Metal Performance Shaders (MPS) if available
+             device = torch.device("mps")
+         elif EMBEDDING_DEVICE == 'cuda':
+             device = torch.device(f'cuda:{cuda_number}')
+         else:
+             device = torch.device(EMBEDDING_DEVICE)
+         return device
+
+     def get_model(self, model_name: str):
+         self._model_config = AutoConfig.from_pretrained(
+             model_name, trust_remote_code=True
+         )
+         return AutoModel.from_pretrained(
+             model_name,
+             trust_remote_code=True,
+             config=self._model_config,
+             unpad_inputs=True,
+             use_memory_efficient_attention=True,
+         ).to(self._device)
+
+     def get_summarization_model(self, model_name: str = 'facebook/bart-large-cnn'):
+         if self._no_summarization is True:
+             return None
+         if not self._summary_model:
+             summarize_model = AutoModelForSeq2SeqLM.from_pretrained(
+                 model_name,
+                 device_map="auto",
+                 torch_dtype=torch.bfloat16,
+                 trust_remote_code=True
+             )
+             summarize_tokenizer = AutoTokenizer.from_pretrained(
+                 model_name,
+                 padding_side="left"
+             )
+             pipe_summary = pipeline(
+                 "summarization",
+                 model=summarize_model,
+                 tokenizer=summarize_tokenizer,
+                 batch_size=True,
+                 max_new_tokens=500,
+                 min_new_tokens=300,
+                 use_fast=True
+             )
+             self._summary_model = HuggingFacePipeline(
+                 model_id=model_name,
+                 pipeline=pipe_summary,
+                 verbose=True
+             )
+         return self._summary_model
+
+     def get_text_splitter(self, model, chunk_size: int = 2000, overlap: int = 100):
+         return RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
+             model,
+             chunk_size=chunk_size,
+             chunk_overlap=overlap,
+             add_start_index=True,  # If `True`, includes chunk's start index in metadata
+             strip_whitespace=True,  # strips whitespace from the start and end
+             separators=["\n\n", "\n", "\r\n", "\r", "\f", "\v", "\x0b", "\x0c"],
+         )
+
+     def default_splitter(self, model: Callable):
+         """Get the text splitter."""
+         return self.get_text_splitter(
+             model,
+             chunk_size=2000,
+             overlap=100
+         )
+
+     def get_summary_from_text(self, text: str) -> str:
+         """
+         Get a summary of a text.
+         """
+         if not text:
+             # No data to be summarized
+             return ''
+         try:
+             splitter = TokenTextSplitter(
+                 chunk_size=5000,
+                 chunk_overlap=100,
+             )
+             prompt_template = """Write a summary of the following, please also identify the main theme:
+             {text}
+             SUMMARY:"""
+             prompt = PromptTemplate.from_template(prompt_template)
+             refine_template = (
+                 "Your job is to produce a final summary.\n"
+                 "We have provided an existing summary up to a certain point: {existing_answer}\n"
+                 "We have the opportunity to refine the existing summary "
+                 "(only if needed) with some more context below.\n"
+                 "------------\n"
+                 "{text}\n"
+                 "------------\n"
+                 "Given the new context, refine the original summary, adding more explanation. "
+                 "If the context isn't useful, return the original summary."
+             )
+             refine_prompt = PromptTemplate.from_template(refine_template)
+             if self._llm:
+                 llm = self._llm
+             else:
+                 llm = self.get_summarization_model(
+                     self.summarization_model
+                 )
+             if not llm:
+                 return ''
+             summarize_chain = load_summarize_chain(
+                 llm=llm,
+                 chain_type="refine",
+                 question_prompt=prompt,
+                 refine_prompt=refine_prompt,
+                 return_intermediate_steps=True,
+                 input_key="input_documents",
+                 output_key="output_text",
+             )
+             chunks = splitter.split_text(text)
+             documents = [Document(page_content=chunk) for chunk in chunks]
+             summary = summarize_chain.invoke(
+                 {"input_documents": documents}, return_only_outputs=True
+             )
+             return summary['output_text']
+         except Exception as e:
+             self.logger.error(f'ERROR in get_summary_from_text: {e}')
+             return ""
+
+     def split_documents(self, documents: List[Document], max_tokens: int = None) -> List[Document]:
+         """Split the documents into chunks."""
+         if not max_tokens:
+             max_tokens = self._chunk_size
+         split_documents = []
+         for doc in documents:
+             metadata = doc.metadata.copy()
+             chunks = self.text_splitter.split_text(doc.page_content)
+             for chunk in chunks:
+                 split_documents.append(
+                     Document(page_content=chunk, metadata=metadata)
+                 )
+         return split_documents
+
+     def split_by_tokens(self, documents: List[Document], max_tokens: int = 768) -> List[Document]:
+         """Split the documents into chunks."""
+         split_documents = []
+         current_chunk = []
+         for doc in documents:
+             metadata = doc.metadata.copy()
+             tokens = self.tokenizer.tokenize(doc.page_content)
+             with torch.no_grad():
+                 current_chunk = []
+                 for token in tokens:
+                     current_chunk.append(token)
+                     if len(current_chunk) >= max_tokens:
+                         chunk_text = self.tokenizer.convert_tokens_to_string(current_chunk)
+                         # Create a new Document for this chunk, preserving metadata
+                         split_doc = Document(
+                             page_content=chunk_text,
+                             metadata=metadata
+                         )
+                         split_documents.append(split_doc)
+                         current_chunk = []  # Reset for the next chunk
+                 # Handle the last chunk if it didn't reach the max_tokens limit
+                 if current_chunk:
+                     chunk_text = self.tokenizer.convert_tokens_to_string(current_chunk)
+                     split_documents.append(
+                         Document(page_content=chunk_text, metadata=metadata)
+                     )
+             del tokens, current_chunk
+         torch.cuda.empty_cache()
+         return split_documents
+
+     def post_load(self):
+         self.tokenizer = None  # Reset the tokenizer
+         self.text_splitter = None  # Reset the text splitter
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()  # Wait for all kernels to finish
+             torch.cuda.empty_cache()  # Clear unused memory
+
+     def read_bytes(self, path: Union[str, PurePath]) -> bytes:
+         """Read the bytes from a file.
+
+         Args:
+             path (Union[str, PurePath]): The path to the file.
+
+         Returns:
+             bytes: The bytes of the file.
+         """
+         if isinstance(path, str):
+             path = PurePath(path)
+         with open(str(path), 'rb') as f:
+             return f.read()
+
+     def open_bytes(self, path: Union[str, PurePath]) -> Any:
+         """Open the bytes from a file.
+
+         Args:
+             path (Union[str, PurePath]): The path to the file.
+
+         Returns:
+             Any: The bytes of the file.
+         """
+         if isinstance(path, str):
+             path = PurePath(path)
+         return open(str(path), 'rb')
+
+     def read_string(self, path: Union[str, PurePath]) -> str:
+         """Read the string from a file.
+
+         Args:
+             path (Union[str, PurePath]): The path to the file.
+
+         Returns:
+             str: The string of the file.
+         """
+         if isinstance(path, str):
+             path = PurePath(path)
+         with open(str(path), 'r', encoding=self.encoding) as f:
+             return f.read()
+
+     def _check_path(
+         self,
+         path: PurePath,
+         suffix: Optional[List[str]] = None
+     ) -> bool:
+         """Check if the file path exists.
+
+         Args:
+             path (PurePath): The path to the file.
+
+         Returns:
+             bool: True if the file exists, False otherwise.
+         """
+         if isinstance(path, str):
+             path = Path(path).resolve()
+         if not suffix:
+             suffix = self._extension
+         return path.exists() and path.is_file() and path.suffix in suffix
+
+     @abstractmethod
+     def load(self, path: Union[str, PurePath]) -> List[Document]:
+         """Load data from a source and return it as a Langchain Document.
+
+         Args:
+             path (str): The source of the data.
+
+         Returns:
+             List[Document]: A list of Langchain Documents.
+         """
+
+     @abstractmethod
+     def parse(self, source: Any) -> List[Document]:
+         """Parse data from a source and return it as a Langchain Document.
+
+         Args:
+             source (Any): The source of the data.
+
+         Returns:
+             List[Document]: A list of Langchain Documents.
+         """
+
+     @classmethod
+     def from_path(
+         cls,
+         path: Union[str, PurePath],
+         text_splitter: Callable,
+         source_type: str = 'file',
+         **kwargs
+     ) -> List[Document]:
+         """Load multiple documents from a Path.
+
+         Args:
+             path (Union[str, PurePath]): The path to the file.
+
+         Returns:
+             List[Document]: A list of Langchain Documents.
+         """
+         if isinstance(path, str):
+             # Path (not PurePath) is required: is_dir() touches the filesystem
+             path = Path(path)
+         obj = cls(
+             tokenizer=kwargs.pop('tokenizer', None),
+             text_splitter=text_splitter,
+             source_type=source_type,
+         )
+         documents = []
+         if path.is_dir():
+             for ext in cls._extension:
+                 for item in path.glob(f'*{ext}'):
+                     if set(item.parts).isdisjoint(cls.skip_directories):
+                         documents += obj.load(path=item, **kwargs)
+         else:
+             # Single file: load it directly instead of returning nothing
+             documents = obj.load(path=path, **kwargs)
+         return documents
+
+     @classmethod
+     def from_url(
+         cls,
+         urls: List[str],
+         text_splitter: Callable,
+         source_type: str = 'content',
+         **kwargs
+     ) -> List[Document]:
+         """Load multiple documents from a URL.
+
+         Args:
+             urls (List[str]): The list of URLs.
+
+         Returns:
+             List[Document]: A list of Langchain Documents.
+         """
+         documents = []
+         # Build an instance so loader state is not set on the class itself
+         obj = cls(
+             tokenizer=kwargs.pop('tokenizer', None),
+             text_splitter=text_splitter,
+             source_type=source_type,
+             summarization_model=kwargs.pop(
+                 'summarization_model',
+                 "facebook/bart-large-cnn"
+             ),
+         )
+         for url in urls:
+             documents += obj.load(url, **kwargs)
+         return documents
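
To make the abstract contract concrete, a minimal subclass sketch. It is hypothetical (the package already ships a real TXTLoader); it only exercises the hooks AbstractLoader defines: a load() implementation, the text splitter configured in __init__, and post_load() cleanup via the context manager.

    # Hypothetical example; not part of the ai-parrot package.
    from pathlib import PurePath
    from typing import Any, List, Union
    from langchain.docstore.document import Document
    from parrot.loaders.abstract import AbstractLoader

    class PlainTextLoader(AbstractLoader):
        _extension = ['.txt']

        def load(self, path: Union[str, PurePath]) -> List[Document]:
            text = self.read_string(path)
            doc = Document(page_content=text, metadata={"source": str(path)})
            # Chunk with the splitter configured in __init__
            return self.split_documents([doc])

        def parse(self, source: Any) -> List[Document]:
            return [Document(page_content=str(source))]

    # The context manager triggers post_load() on exit, releasing GPU memory.
    with PlainTextLoader() as loader:
        docs = loader.load("notes.txt")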