ai-parrot 0.1.0__cp311-cp311-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-parrot might be problematic.
- ai_parrot-0.1.0.dist-info/LICENSE +21 -0
- ai_parrot-0.1.0.dist-info/METADATA +299 -0
- ai_parrot-0.1.0.dist-info/RECORD +108 -0
- ai_parrot-0.1.0.dist-info/WHEEL +5 -0
- ai_parrot-0.1.0.dist-info/top_level.txt +3 -0
- parrot/__init__.py +18 -0
- parrot/chatbots/__init__.py +7 -0
- parrot/chatbots/abstract.py +965 -0
- parrot/chatbots/asktroc.py +16 -0
- parrot/chatbots/base.py +257 -0
- parrot/chatbots/basic.py +9 -0
- parrot/chatbots/bose.py +17 -0
- parrot/chatbots/cody.py +17 -0
- parrot/chatbots/copilot.py +100 -0
- parrot/chatbots/dataframe.py +103 -0
- parrot/chatbots/hragents.py +15 -0
- parrot/chatbots/oddie.py +17 -0
- parrot/chatbots/retrievals/__init__.py +515 -0
- parrot/chatbots/retrievals/constitutional.py +19 -0
- parrot/conf.py +108 -0
- parrot/crew/__init__.py +3 -0
- parrot/crew/tools/__init__.py +22 -0
- parrot/crew/tools/bing.py +13 -0
- parrot/crew/tools/config.py +43 -0
- parrot/crew/tools/duckgo.py +62 -0
- parrot/crew/tools/file.py +24 -0
- parrot/crew/tools/google.py +168 -0
- parrot/crew/tools/gtrends.py +16 -0
- parrot/crew/tools/md2pdf.py +25 -0
- parrot/crew/tools/rag.py +42 -0
- parrot/crew/tools/search.py +32 -0
- parrot/crew/tools/url.py +21 -0
- parrot/exceptions.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/bots.py +196 -0
- parrot/handlers/chat.py +169 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/database.py +29 -0
- parrot/llms/__init__.py +0 -0
- parrot/llms/abstract.py +41 -0
- parrot/llms/anthropic.py +36 -0
- parrot/llms/google.py +37 -0
- parrot/llms/groq.py +33 -0
- parrot/llms/hf.py +39 -0
- parrot/llms/openai.py +49 -0
- parrot/llms/pipes.py +103 -0
- parrot/llms/vertex.py +68 -0
- parrot/loaders/__init__.py +20 -0
- parrot/loaders/abstract.py +456 -0
- parrot/loaders/basepdf.py +102 -0
- parrot/loaders/basevideo.py +280 -0
- parrot/loaders/csv.py +42 -0
- parrot/loaders/dir.py +37 -0
- parrot/loaders/excel.py +349 -0
- parrot/loaders/github.py +65 -0
- parrot/loaders/handlers/__init__.py +5 -0
- parrot/loaders/handlers/data.py +213 -0
- parrot/loaders/image.py +119 -0
- parrot/loaders/json.py +52 -0
- parrot/loaders/pdf.py +187 -0
- parrot/loaders/pdfchapters.py +142 -0
- parrot/loaders/pdffn.py +112 -0
- parrot/loaders/pdfimages.py +207 -0
- parrot/loaders/pdfmark.py +88 -0
- parrot/loaders/pdftables.py +145 -0
- parrot/loaders/ppt.py +30 -0
- parrot/loaders/qa.py +81 -0
- parrot/loaders/repo.py +103 -0
- parrot/loaders/rtd.py +65 -0
- parrot/loaders/txt.py +92 -0
- parrot/loaders/utils/__init__.py +1 -0
- parrot/loaders/utils/models.py +25 -0
- parrot/loaders/video.py +96 -0
- parrot/loaders/videolocal.py +107 -0
- parrot/loaders/vimeo.py +106 -0
- parrot/loaders/web.py +216 -0
- parrot/loaders/web_base.py +112 -0
- parrot/loaders/word.py +125 -0
- parrot/loaders/youtube.py +192 -0
- parrot/manager.py +152 -0
- parrot/models.py +347 -0
- parrot/py.typed +0 -0
- parrot/stores/__init__.py +0 -0
- parrot/stores/abstract.py +170 -0
- parrot/stores/milvus.py +540 -0
- parrot/stores/qdrant.py +153 -0
- parrot/tools/__init__.py +16 -0
- parrot/tools/abstract.py +53 -0
- parrot/tools/asknews.py +32 -0
- parrot/tools/bing.py +13 -0
- parrot/tools/duck.py +62 -0
- parrot/tools/google.py +170 -0
- parrot/tools/stack.py +26 -0
- parrot/tools/weather.py +70 -0
- parrot/tools/wikipedia.py +59 -0
- parrot/tools/zipcode.py +179 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- resources/users/__init__.py +5 -0
- resources/users/handlers.py +13 -0
- resources/users/models.py +205 -0
- settings/__init__.py +0 -0
- settings/settings.py +51 -0
parrot/loaders/abstract.py
@@ -0,0 +1,456 @@
+"""Loaders are classes that are responsible for loading data from a source
+and returning it as a Langchain Document.
+"""
+from __future__ import annotations
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import List, Union, Optional, Any
+from pathlib import Path, PurePath
+import torch
+from langchain.docstore.document import Document
+from langchain.chains.summarize import load_summarize_chain
+from langchain.text_splitter import (
+    RecursiveCharacterTextSplitter,
+    TokenTextSplitter
+)
+from langchain_core.prompts import PromptTemplate
+from langchain_core.document_loaders.blob_loaders import Blob
+from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+
+from navconfig.logging import logging
+from navigator.libs.json import JSONContent  # pylint: disable=E0611
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM,
+    AutoConfig,
+    AutoModel,
+    pipeline
+)
+from ..conf import EMBEDDING_DEVICE, EMBEDDING_DEFAULT_MODEL
+
+
+logging.getLogger(name='httpx').setLevel(logging.WARNING)
+logging.getLogger(name='httpcore').setLevel(logging.WARNING)
+logging.getLogger(name='pdfminer').setLevel(logging.WARNING)
+logging.getLogger(name='langchain_community').setLevel(logging.WARNING)
+logging.getLogger(name='numba').setLevel(logging.WARNING)
+logging.getLogger(name='PIL').setLevel(level=logging.WARNING)
+
+
+def as_string(self) -> str:
+    """Read data as a string."""
+    if self.data is None and self.path:
+        with open(str(self.path), "r", encoding=self.encoding) as f:
+            try:
+                return f.read()
+            except UnicodeDecodeError:
+                try:
+                    with open(str(self.path), "r", encoding="latin-1") as f:
+                        return f.read()
+                except UnicodeDecodeError:
+                    with open(str(self.path), "rb") as f:
+                        return f.read().decode("utf-8", "replace")
+    elif isinstance(self.data, bytes):
+        return self.data.decode(self.encoding)
+    elif isinstance(self.data, str):
+        return self.data
+    else:
+        raise ValueError(f"Unable to get string for blob {self}")
+
+# Monkey patch the Blob class's as_string method
+Blob.as_string = as_string
+
+
+class AbstractLoader(ABC):
+    """
+    Abstract class for Document loaders.
+    """
+    _extension: List[str] = ['.txt']
+    encoding: str = 'utf-8'
+    skip_directories: List[str] = []
+    _chunk_size: int = 768
+
+    def __init__(
+        self,
+        tokenizer: Union[str, Callable] = None,
+        text_splitter: Union[str, Callable] = None,
+        source_type: str = 'file',
+        **kwargs
+    ):
+        self.tokenizer = tokenizer
+        self._summary_model = None
+        self.text_splitter = text_splitter
+        self._device = self._get_device()
+        self._chunk_size = kwargs.get('chunk_size', 768)
+        self._no_summarization = bool(
+            kwargs.get('no_summarization', False)
+        )
+        self.summarization_model = kwargs.get(
+            'summarization_model',
+            "facebook/bart-large-cnn"
+        )
+        self._no_summarization = bool(
+            kwargs.get('no_summarization', False)
+        )
+        self._source_type = source_type
+        self.logger = logging.getLogger(
+            f"Loader.{self.__class__.__name__}"
+        )
+        if 'extension' in kwargs:
+            self._extension = kwargs['extension']
+        self.encoding = kwargs.get('encoding', 'utf-8')
+        self.skip_directories: List[str] = kwargs.get('skip_directories', [])
+        # LLM (if required)
+        self._llm = kwargs.get('llm', None)
+        if not self.tokenizer:
+            self.tokenizer = self.default_tokenizer()
+        elif isinstance(self.tokenizer, str):
+            self.tokenizer = self.get_tokenizer(
+                model_name=self.tokenizer
+            )
+        if not text_splitter:
+            self.text_splitter = self.default_splitter(
+                model=self.tokenizer
+            )
+        # JSON encoder:
+        self._encoder = JSONContent()
+
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc_info):
+        self.post_load()
+
+    def default_tokenizer(self):
+        return self.get_tokenizer(
+            EMBEDDING_DEFAULT_MODEL,
+            chunk_size=768
+        )
+
+    def get_tokenizer(self, model_name: str, chunk_size: int = 768):
+        return AutoTokenizer.from_pretrained(
+            model_name,
+            chunk_size=chunk_size
+        )
+
+    def _get_device(self, cuda_number: int = 0):
+        if torch.cuda.is_available():
+            # Use CUDA GPU if available
+            device = torch.device(f'cuda:{cuda_number}')
+        elif torch.backends.mps.is_available():
+            # Use CUDA Multi-Processing Service if available
+            device = torch.device("mps")
+        elif EMBEDDING_DEVICE == 'cuda':
+            device = torch.device(f'cuda:{cuda_number}')
+        else:
+            device = torch.device(EMBEDDING_DEVICE)
+        return device
+
+    def get_model(self, model_name: str):
+        self._model_config = AutoConfig.from_pretrained(
+            model_name, trust_remote_code=True
+        )
+        return AutoModel.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            config=self._model_config,
+            unpad_inputs=True,
+            use_memory_efficient_attention=True,
+        ).to(self._device)
+
+    def get_summarization_model(self, model_name: str = 'facebook/bart-large-cnn'):
+        if self._no_summarization is True:
+            return None
+        if not self._summary_model:
+            summarize_model = AutoModelForSeq2SeqLM.from_pretrained(
+                model_name,
+                device_map="auto",
+                torch_dtype=torch.bfloat16,
+                trust_remote_code=True
+            )
+            summarize_tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                padding_side="left"
+            )
+            pipe_summary = pipeline(
+                "summarization",
+                model=summarize_model,
+                tokenizer=summarize_tokenizer,
+                batch_size=True,
+                max_new_tokens=500,
+                min_new_tokens=300,
+                use_fast=True
+            )
+            self._summary_model = HuggingFacePipeline(
+                model_id=model_name,
+                pipeline=pipe_summary,
+                verbose=True
+            )
+        return self._summary_model
+
+    def get_text_splitter(self, model, chunk_size: int = 2000, overlap: int = 100):
+        return RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
+            model,
+            chunk_size=chunk_size,
+            chunk_overlap=overlap,
+            add_start_index=True,  # If `True`, includes chunk's start index in metadata
+            strip_whitespace=True,  # strips whitespace from the start and end
+            separators=["\n\n", "\n", "\r\n", "\r", "\f", "\v", "\x0b", "\x0c"],
+        )
+
+    def default_splitter(self, model: Callable):
+        """Get the text splitter."""
+        return self.get_text_splitter(
+            model,
+            chunk_size=2000,
+            overlap=100
+        )
+
+    def get_summary_from_text(self, text: str) -> str:
+        """
+        Get a summary of a text.
+        """
+        if not text:
+            # NO data to be summarized
+            return ''
+        try:
+            splitter = TokenTextSplitter(
+                chunk_size=5000,
+                chunk_overlap=100,
+            )
+            prompt_template = """Write a summary of the following, please also identify the main theme:
+            {text}
+            SUMMARY:"""
+            prompt = PromptTemplate.from_template(prompt_template)
+            refine_template = (
+                "Your job is to produce a final summary\n"
+                "We have provided an existing summary up to a certain point: {existing_answer}\n"
+                "We have the opportunity to refine the existing summary"
+                "(only if needed) with some more context below.\n"
+                "------------\n"
+                "{text}\n"
+                "------------\n"
+                "Given the new context, refine the original summary adding more explanation."
+                "If the context isn't useful, return the original summary."
+            )
+            refine_prompt = PromptTemplate.from_template(refine_template)
+            if self._llm:
+                llm = self._llm
+            else:
+                llm = self.get_summarization_model(
+                    self.summarization_model
+                )
+            if not llm:
+                return ''
+            summarize_chain = load_summarize_chain(
+                llm=llm,
+                chain_type="refine",
+                question_prompt=prompt,
+                refine_prompt=refine_prompt,
+                return_intermediate_steps=True,
+                input_key="input_documents",
+                output_key="output_text",
+            )
+            chunks = splitter.split_text(text)
+            documents = [Document(page_content=chunk) for chunk in chunks]
+            summary = summarize_chain.invoke(
+                {"input_documents": documents}, return_only_outputs=True
+            )
+            return summary['output_text']
+        except Exception as e:
+            print('ERROR in get_summary_from_text:', e)
+            return ""
+
+    def split_documents(self, documents: List[Document], max_tokens: int = None) -> List[Document]:
+        """Split the documents into chunks."""
+        if not max_tokens:
+            max_tokens = self._chunk_size
+        split_documents = []
+        for doc in documents:
+            metadata = doc.metadata.copy()
+            chunks = self.text_splitter.split_text(doc.page_content)
+            for chunk in chunks:
+                split_documents.append(
+                    Document(page_content=chunk, metadata=metadata)
+                )
+        return split_documents
+
+    def split_by_tokens(self, documents: List[Document], max_tokens: int = 768) -> List[Document]:
+        """Split the documents into chunks."""
+        split_documents = []
+        current_chunk = []
+        for doc in documents:
+            metadata = doc.metadata.copy()
+            tokens = self.tokenizer.tokenize(doc.page_content)
+            with torch.no_grad():
+                current_chunk = []
+                for token in tokens:
+                    current_chunk.append(token)
+                    if len(current_chunk) >= max_tokens:
+                        chunk_text = self.tokenizer.convert_tokens_to_string(current_chunk)
+                        # Create a new Document for this chunk, preserving metadata
+                        split_doc = Document(
+                            page_content=chunk_text,
+                            metadata=metadata
+                        )
+                        split_documents.append(split_doc)
+                        current_chunk = []  # Reset for the next chunk
+                # Handle the last chunk if it didn't reach the max_tokens limit
+                if current_chunk:
+                    chunk_text = self.tokenizer.convert_tokens_to_string(current_chunk)
+                    split_documents.append(
+                        Document(page_content=chunk_text, metadata=metadata)
+                    )
+            del tokens, current_chunk
+            torch.cuda.empty_cache()
+        return split_documents
+
+    def post_load(self):
+        self.tokenizer = None  # Reset the tokenizer
+        self.text_splitter = None  # Reset the text splitter
+        torch.cuda.synchronize()  # Wait for all kernels to finish
+        torch.cuda.empty_cache()  # Clear unused memory
+
+    def read_bytes(self, path: Union[str, PurePath]) -> bytes:
+        """Read the bytes from a file.
+
+        Args:
+            path (Union[str, PurePath]): The path to the file.
+
+        Returns:
+            bytes: The bytes of the file.
+        """
+        if isinstance(path, str):
+            path = PurePath(path)
+        with open(str(path), 'rb') as f:
+            return f.read()
+
+    def open_bytes(self, path: Union[str, PurePath]) -> Any:
+        """Open the bytes from a file.
+
+        Args:
+            path (Union[str, PurePath]): The path to the file.
+
+        Returns:
+            Any: The bytes of the file.
+        """
+        if isinstance(path, str):
+            path = PurePath(path)
+        return open(str(path), 'rb')
+
+    def read_string(self, path: Union[str, PurePath]) -> str:
+        """Read the string from a file.
+
+        Args:
+            path (Union[str, PurePath]): The path to the file.
+
+        Returns:
+            str: The string of the file.
+        """
+        if isinstance(path, str):
+            path = PurePath(path)
+        with open(str(path), 'r', encoding=self.encoding) as f:
+            return f.read()
+
+    def _check_path(
+        self,
+        path: PurePath,
+        suffix: Optional[List[str]] = None
+    ) -> bool:
+        """Check if the file path exists.
+        Args:
+            path (PurePath): The path to the file.
+        Returns:
+            bool: True if the file exists, False otherwise.
+        """
+        if isinstance(path, str):
+            path = Path(path).resolve()
+        if not suffix:
+            suffix = self._extension
+        return path.exists() and path.is_file() and path.suffix in suffix
+
+
+    @abstractmethod
+    def load(self, path: Union[str, PurePath]) -> List[Document]:
+        """Load data from a source and return it as a Langchain Document.
+
+        Args:
+            path (str): The source of the data.
+
+        Returns:
+            List[Document]: A list of Langchain Documents.
+        """
+        pass
+
+    @abstractmethod
+    def parse(self, source: Any) -> List[Document]:
+        """Parse data from a source and return it as a Langchain Document.
+
+        Args:
+            source (Any): The source of the data.
+
+        Returns:
+            List[Document]: A list of Langchain Documents.
+        """
+        pass
+
+    @classmethod
+    def from_path(
+        cls,
+        path: Union[str, PurePath],
+        text_splitter: Callable,
+        source_type: str = 'file',
+        **kwargs
+    ) -> List[Document]:
+        """Load Multiple documents from a Path
+
+        Args:
+            path (Union[str, PurePath]): The path to the file.
+
+        Returns:
+            -> List[Document]: A list of Langchain Documents.
+        """
+        if isinstance(path, str):
+            path = PurePath(path)
+        if path.is_dir():
+            documents = []
+            obj = cls(
+                tokenizer=kwargs.pop('tokenizer', None),
+                text_splitter=text_splitter,
+                source_type=source_type,
+            )
+            for ext in cls._extension:
+                for item in path.glob(f'*{ext}'):
+                    if set(item.parts).isdisjoint(cls.skip_directories):
+                        documents += obj.load(path=item, **kwargs)
+                        # documents += cls.load(cls, path=item, **kwargs)
+            return documents
+
+    @classmethod
+    def from_url(
+        cls,
+        urls: List[str],
+        text_splitter: Callable,
+        source_type: str = 'content',
+        **kwargs
+    ) -> List[Document]:
+        """Load Multiple documents from a URL
+
+        Args:
+            urls (List[str]): The list of URLs.
+
+        Returns:
+            -> List[Document]: A list of Langchain Documents.
+        """
+        documents = []
+        cls.tokenizer = kwargs.pop('tokenizer', None),
+        cls.text_splitter = text_splitter
+        cls._source_type = source_type
+        cls.summarization_model = kwargs.pop(
+            'summarization_model',
+            "facebook/bart-large-cnn"
+        )
+        for url in urls:
+            documents += cls.load(url, **kwargs)
+        return documents
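For context, the sketch below shows one way a concrete loader could subclass AbstractLoader; it is illustrative only and not part of the package contents shown above. The class name PlainTextLoader, the notes.txt path, and the metadata keys are hypothetical; the helpers it calls (read_string, split_documents, post_load) and the no_summarization keyword come from the code above. Running it assumes the wheel and its navconfig/navigator dependencies are installed on a CUDA-capable machine, since post_load() calls torch.cuda.synchronize().

# Illustrative sketch only -- not part of the package diff.
from pathlib import PurePath
from typing import Any, List, Union

from langchain.docstore.document import Document

from parrot.loaders.abstract import AbstractLoader


class PlainTextLoader(AbstractLoader):
    """Hypothetical loader that reads .txt files and chunks them."""

    _extension = ['.txt']

    def load(self, path: Union[str, PurePath]) -> List[Document]:
        # read_string() honours the encoding configured on the loader
        text = self.read_string(path)
        doc = Document(
            page_content=text,
            metadata={"source": str(path), "source_type": self._source_type},
        )
        # split_documents() applies the tokenizer-aware text splitter
        return self.split_documents([doc])

    def parse(self, source: Any) -> List[Document]:
        doc = Document(page_content=str(source), metadata={"source": "inline"})
        return self.split_documents([doc])


if __name__ == "__main__":
    # The context manager calls post_load() on exit, which drops the
    # tokenizer/splitter and clears the CUDA cache.
    with PlainTextLoader(no_summarization=True) as loader:
        chunks = loader.load("notes.txt")
        print(f"{len(chunks)} chunks")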
parrot/loaders/basepdf.py
@@ -0,0 +1,102 @@
+from collections.abc import Callable
+from typing import Any
+from abc import abstractmethod
+from pathlib import Path, PurePath
+from PIL import Image
+from .abstract import AbstractLoader
+from ..conf import STATIC_DIR
+
+
+class BasePDF(AbstractLoader):
+    """
+    Base Abstract loader for all PDF files.
+    """
+    _extension = ['.pdf']
+
+    def __init__(
+        self,
+        path: PurePath,
+        tokenizer: Callable[..., Any] = None,
+        text_splitter: Callable[..., Any] = None,
+        source_type: str = 'pdf',
+        language: str = "eng",
+        **kwargs
+    ):
+        super().__init__(tokenizer, text_splitter, source_type=source_type, **kwargs)
+        self.path = path
+        if isinstance(path, str):
+            self.path = Path(path).resolve()
+        self.save_images: bool = bool(kwargs.get('save_images', False))
+        self._imgdir = STATIC_DIR.joinpath('images')
+        if self.save_images is True:
+            if self._imgdir.exists() is False:
+                self._imgdir.mkdir(parents=True, exist_ok=True)
+        if language == 'en':
+            language = 'eng'
+        self._lang = language
+
+    def save_image(self, img_stream: Image, image_name: str, save_path: Path):
+        # Create the image directory if it does not exist
+        if save_path.exists() is False:
+            save_path.mkdir(parents=True, exist_ok=True)
+        img_path = save_path.joinpath(image_name)
+        self.logger.notice(
+            f"Saving Image Page on {img_path}"
+        )
+        if not img_path.exists():
+            # Save the image
+            img_stream.save(img_path, format="PNG", optimize=True)
+        return img_path
+
+    @abstractmethod
+    def _load_pdf(self, path: Path) -> list:
+        """
+        Load a PDF file using Fitz.
+
+        Args:
+            path (Path): The path to the PDF file.
+
+        Returns:
+            list: A list of Langchain Documents.
+        """
+        pass
+
+    def load(self) -> list:
+        """
+        Load data from a PDF file.
+
+        Args:
+            source (str): The path to the PDF file.
+
+        Returns:
+            list: A list of Langchain Documents.
+        """
+        if isinstance(self.path, list):
+            # list of files:
+            documents = []
+            for p in self.path:
+                if self._check_path(p):
+                    documents.extend(self._load_pdf(p))
+        if not self.path.exists():
+            raise FileNotFoundError(
+                f"PDF file/directory not found: {self.path}"
+            )
+        if self.path.is_dir():
+            documents = []
+            # iterate over the files in the directory
+            for ext in self._extension:
+                for item in self.path.glob(f'*{ext}'):
+                    if set(item.parts).isdisjoint(self.skip_directories):
+                        documents.extend(self._load_pdf(item))
+        elif self.path.is_file():
+            documents = self._load_pdf(self.path)
+        else:
+            raise ValueError(
+                f"PDF Loader: Invalid path: {self.path}"
+            )
+        return self.split_documents(documents)
+
+    def parse(self, source):
+        raise NotImplementedError(
+            "Parser method is not implemented for PDFLoader."
+        )
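BasePDF defers the actual parsing to the abstract _load_pdf() hook, whose docstring points at Fitz (PyMuPDF). As an illustration only, and not part of the package, the sketch below shows how a concrete subclass might implement that hook with PyMuPDF; the class name SimplePDFLoader, the report.pdf path, and the metadata keys are hypothetical.

# Illustrative sketch only -- not part of the package diff.
from pathlib import Path
from typing import List

import fitz  # PyMuPDF, assumed here because the docstring mentions Fitz
from langchain.docstore.document import Document

from parrot.loaders.basepdf import BasePDF


class SimplePDFLoader(BasePDF):
    """Hypothetical loader: plain-text extraction, one Document per page."""

    def _load_pdf(self, path: Path) -> List[Document]:
        documents = []
        pdf = fitz.open(str(path))
        try:
            for page in pdf:
                text = page.get_text("text")
                if not text.strip():
                    continue  # skip empty pages; OCR is out of scope here
                documents.append(
                    Document(
                        page_content=text,
                        metadata={
                            "source": str(path),
                            "page": page.number + 1,
                            "language": self._lang,
                        },
                    )
                )
        finally:
            pdf.close()
        return documents


if __name__ == "__main__":
    # BasePDF.load() validates the path, walks directories if needed, and
    # returns the per-page documents already chunked by split_documents().
    loader = SimplePDFLoader(path="report.pdf", no_summarization=True)
    print(len(loader.load()))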