ai-parrot 0.8.3 (cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-parrot might be problematic.
- ai_parrot-0.8.3.dist-info/LICENSE +21 -0
- ai_parrot-0.8.3.dist-info/METADATA +306 -0
- ai_parrot-0.8.3.dist-info/RECORD +128 -0
- ai_parrot-0.8.3.dist-info/WHEEL +6 -0
- ai_parrot-0.8.3.dist-info/top_level.txt +2 -0
- parrot/__init__.py +30 -0
- parrot/bots/__init__.py +5 -0
- parrot/bots/abstract.py +1115 -0
- parrot/bots/agent.py +492 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/bose.py +17 -0
- parrot/bots/chatbot.py +271 -0
- parrot/bots/cody.py +17 -0
- parrot/bots/copilot.py +117 -0
- parrot/bots/data.py +730 -0
- parrot/bots/dataframe.py +103 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/interfaces/__init__.py +1 -0
- parrot/bots/interfaces/retrievers.py +12 -0
- parrot/bots/notebook.py +619 -0
- parrot/bots/odoo.py +17 -0
- parrot/bots/prompts/__init__.py +41 -0
- parrot/bots/prompts/agents.py +91 -0
- parrot/bots/prompts/data.py +214 -0
- parrot/bots/retrievals/__init__.py +1 -0
- parrot/bots/retrievals/constitutional.py +19 -0
- parrot/bots/retrievals/multi.py +122 -0
- parrot/bots/retrievals/retrieval.py +610 -0
- parrot/bots/tools/__init__.py +7 -0
- parrot/bots/tools/eda.py +325 -0
- parrot/bots/tools/pdf.py +50 -0
- parrot/bots/tools/plot.py +48 -0
- parrot/bots/troc.py +16 -0
- parrot/conf.py +170 -0
- parrot/crew/__init__.py +3 -0
- parrot/crew/tools/__init__.py +22 -0
- parrot/crew/tools/bing.py +13 -0
- parrot/crew/tools/config.py +43 -0
- parrot/crew/tools/duckgo.py +62 -0
- parrot/crew/tools/file.py +24 -0
- parrot/crew/tools/google.py +168 -0
- parrot/crew/tools/gtrends.py +16 -0
- parrot/crew/tools/md2pdf.py +25 -0
- parrot/crew/tools/rag.py +42 -0
- parrot/crew/tools/search.py +32 -0
- parrot/crew/tools/url.py +21 -0
- parrot/exceptions.cpython-39-x86_64-linux-gnu.so +0 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agents.py +292 -0
- parrot/handlers/bots.py +196 -0
- parrot/handlers/chat.py +192 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/http.py +805 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +18 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/exif.py +709 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/llms/__init__.py +1 -0
- parrot/llms/abstract.py +69 -0
- parrot/llms/anthropic.py +58 -0
- parrot/llms/gemma.py +15 -0
- parrot/llms/google.py +44 -0
- parrot/llms/groq.py +67 -0
- parrot/llms/hf.py +45 -0
- parrot/llms/openai.py +61 -0
- parrot/llms/pipes.py +114 -0
- parrot/llms/vertex.py +89 -0
- parrot/loaders/__init__.py +9 -0
- parrot/loaders/abstract.py +628 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/txt.py +26 -0
- parrot/manager.py +333 -0
- parrot/models.py +504 -0
- parrot/py.typed +0 -0
- parrot/stores/__init__.py +11 -0
- parrot/stores/abstract.py +248 -0
- parrot/stores/chroma.py +188 -0
- parrot/stores/duck.py +162 -0
- parrot/stores/embeddings/__init__.py +10 -0
- parrot/stores/embeddings/abstract.py +46 -0
- parrot/stores/embeddings/base.py +52 -0
- parrot/stores/embeddings/bge.py +20 -0
- parrot/stores/embeddings/fastembed.py +17 -0
- parrot/stores/embeddings/google.py +18 -0
- parrot/stores/embeddings/huggingface.py +20 -0
- parrot/stores/embeddings/ollama.py +14 -0
- parrot/stores/embeddings/openai.py +26 -0
- parrot/stores/embeddings/transformers.py +21 -0
- parrot/stores/embeddings/vertexai.py +17 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss.py +160 -0
- parrot/stores/milvus.py +397 -0
- parrot/stores/postgres.py +653 -0
- parrot/stores/qdrant.py +170 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +68 -0
- parrot/tools/asknews.py +33 -0
- parrot/tools/basic.py +51 -0
- parrot/tools/bby.py +359 -0
- parrot/tools/bing.py +13 -0
- parrot/tools/docx.py +343 -0
- parrot/tools/duck.py +62 -0
- parrot/tools/execute.py +56 -0
- parrot/tools/gamma.py +28 -0
- parrot/tools/google.py +170 -0
- parrot/tools/gvoice.py +301 -0
- parrot/tools/results.py +278 -0
- parrot/tools/stack.py +27 -0
- parrot/tools/weather.py +70 -0
- parrot/tools/wikipedia.py +58 -0
- parrot/tools/zipcode.py +198 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.cpython-39-x86_64-linux-gnu.so +0 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpython-39-x86_64-linux-gnu.so +0 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- resources/users/__init__.py +5 -0
- resources/users/handlers.py +13 -0
- resources/users/models.py +205 -0
parrot/loaders/abstract.py
@@ -0,0 +1,628 @@
+from typing import Generator, Union, List, Any, Optional, TypeVar
+from collections.abc import Callable
+from abc import ABC, abstractmethod
+from datetime import datetime
+from pathlib import Path, PurePath
+import asyncio
+import torch
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    pipeline
+)
+from langchain.schema.runnable import RunnablePassthrough
+from langchain.chains.summarize import load_summarize_chain
+from langchain.docstore.document import Document
+from langchain.text_splitter import (
+    TokenTextSplitter
+)
+from langchain_core.prompts import PromptTemplate
+from navconfig.logging import logging
+from navigator.libs.json import JSONContent  # pylint: disable=E0611
+from parrot.llms.vertex import VertexLLM
+from ..conf import (
+    DEFAULT_LLM_MODEL,
+    DEFAULT_LLM_TEMPERATURE,
+    CUDA_DEFAULT_DEVICE,
+    CUDA_DEFAULT_DEVICE_NUMBER
+)
+
+
+T = TypeVar('T')
+
+
+class AbstractLoader(ABC):
+    """
+    Base class for all loaders. Loaders are responsible for loading data from various sources.
+    """
+    extensions: List[str] = ['.*']
+    skip_directories: List[str] = []
+
+    def __init__(
+        self,
+        *args,
+        tokenizer: Union[str, Callable] = None,
+        text_splitter: Union[str, Callable] = None,
+        source_type: str = 'file',
+        **kwargs
+    ):
+        self.chunk_size: int = kwargs.get('chunk_size', 50)
+        self.semaphore = asyncio.Semaphore(kwargs.get('semaphore', 10))
+        self.extensions = kwargs.get('extensions', self.extensions)
+        self.skip_directories = kwargs.get('skip_directories', self.skip_directories)
+        self.encoding = kwargs.get('encoding', 'utf-8')
+        self._source_type = source_type
+        self._recursive: bool = kwargs.get('recursive', False)
+        self.category: str = kwargs.get('category', 'document')
+        self.doctype: str = kwargs.get('doctype', 'text')
+        self._summarization = kwargs.get('summarization', False)
+        self._summary_model: Optional[Any] = kwargs.get('summary_model', None)
+        self._use_summary_pipeline: bool = kwargs.get('use_summary_pipeline', False)
+        self._use_translation_pipeline: bool = kwargs.get('use_translation_pipeline', False)
+        self._translation = kwargs.get('translation', False)
+        # Tokenizer
+        self.tokenizer = tokenizer
+        # Text Splitter
+        self.text_splitter = text_splitter
+        # Summarization Model:
+        self.summarization_model = kwargs.get('summarizer', None)
+        # Markdown Splitter:
+        self.markdown_splitter = kwargs.get('markdown_splitter', None)
+        if 'path' in kwargs:
+            self.path = kwargs['path']
+            if isinstance(self.path, str):
+                self.path = Path(self.path).resolve()
+        # LLM (if required)
+        self._use_llm = kwargs.get('use_llm', False)
+        self._llm_model = kwargs.get('llm_model', None)
+        self._llm_model_kwargs = kwargs.get('model_kwargs', {})
+        self._llm = kwargs.get('llm', None)
+        if self._use_llm:
+            self._llm = self.get_default_llm(
+                model=self._llm_model,
+                model_kwargs=self._llm_model_kwargs,
+            )
+        self.logger = logging.getLogger(
+            f"Parrot.Loaders.{self.__class__.__name__}"
+        )
+        # JSON encoder:
+        self._encoder = JSONContent()
+        # Use CUDA if available:
+        self.device_name = kwargs.get('device', CUDA_DEFAULT_DEVICE)
+        self.cuda_number = kwargs.get('cuda_number', CUDA_DEFAULT_DEVICE_NUMBER)
+        self._device = None
+
+    def get_default_llm(self, model: str = None, model_kwargs: dict = None):
+        """Return a VertexLLM instance."""
+        if not model_kwargs:
+            model_kwargs = {
+                "temperature": DEFAULT_LLM_TEMPERATURE,
+                "top_k": 30,
+                "top_p": 0.5,
+            }
+        return VertexLLM(
+            model=model or DEFAULT_LLM_MODEL,
+            **model_kwargs
+        )
+
+    def _get_device(
+        self,
+        device_type: str = None,
+        cuda_number: int = 0
+    ):
+        """Get the default device for torch and transformers.
+
+        """
+        if device_type == 'cpu':
+            return torch.device('cpu')
+        if device_type == 'cuda':
+            return torch.device(f'cuda:{cuda_number}')
+        if CUDA_DEFAULT_DEVICE == 'cpu':
+            # Use CPU if CUDA is not available
+            return torch.device('cpu')
+        if torch.cuda.is_available():
+            # Use CUDA GPU if available
+            return torch.device(f'cuda:{cuda_number}')
+        if torch.backends.mps.is_available():
+            # Use Apple Metal Performance Shaders (MPS) if available
+            return torch.device("mps")
+        if CUDA_DEFAULT_DEVICE == 'cuda':
+            return torch.device(f'cuda:{cuda_number}')
+        else:
+            return torch.device(CUDA_DEFAULT_DEVICE)
+
+    def clear_cuda(self):
+        self.tokenizer = None  # Reset the tokenizer
+        self.text_splitter = None  # Reset the text splitter
+        torch.cuda.synchronize()  # Wait for all kernels to finish
+        torch.cuda.empty_cache()  # Clear unused memory
+
+    async def __aenter__(self):
+        """Open the loader if it has an open method."""
+        # Check if the loader has an open method and call it
+        if hasattr(self, "open"):
+            await self.open()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Close the loader if it has a close method."""
+        if hasattr(self, "close"):
+            await self.close()
+        return True
+
+    def supported_extensions(self):
+        """Get the supported file extensions."""
+        return self.extensions
+
+    def is_valid_path(self, path: Union[str, Path]) -> bool:
+        """Check if a path is valid."""
+        if isinstance(path, str):
+            path = Path(path)
+        if not path.exists():
+            return False
+        if path.is_dir() and path.name in self.skip_directories:
+            return False
+        if path.is_file():
+            if path.suffix not in self.extensions:
+                return False
+            if path.name.startswith("."):
+                return False
+            # check if file is empty
+            if path.stat().st_size == 0:
+                return False
+            # check if file is inside of skip directories:
+            for skip_dir in self.skip_directories:
+                if path.is_relative_to(skip_dir):
+                    return False
+        return True
+
+    @abstractmethod
+    async def _load(self, source: Union[str, PurePath], **kwargs) -> List[Document]:
+        """Load a single data/url/file from a source and return it as a Langchain Document.
+
+        Args:
+            source (str): The source of the data.
+
+        Returns:
+            List[Document]: A list of Langchain Documents.
+        """
+        pass
+
+    async def from_path(self, path: Union[str, Path], recursive: bool = False, **kwargs) -> List[asyncio.Task]:
+        """
+        Load data from a path. This method should be overridden by subclasses.
+        """
+        tasks = []
+        if isinstance(path, str):
+            path = Path(path)  # Path (not PurePath): the is_dir()/is_file() checks below need filesystem access
+        if path.is_dir():
+            for ext in self.extensions:
+                glob_method = path.rglob if recursive else path.glob
+                # Use glob to find all files with the specified extension
+                for item in glob_method(f'*{ext}'):
+                    # Check if the item is a directory and if it should be skipped
+                    if set(item.parts).isdisjoint(self.skip_directories):
+                        if self.is_valid_path(item):
+                            tasks.append(
+                                asyncio.create_task(self._load(item, **kwargs))
+                            )
+        elif path.is_file():
+            if self.is_valid_path(path):
+                tasks.append(
+                    asyncio.create_task(self._load(path, **kwargs))
+                )
+        else:
+            self.logger.warning(f"Path {path} is not valid.")
+        return tasks
+
+    async def from_url(
+        self,
+        url: Union[str, List[str]],
+        **kwargs
+    ) -> List[asyncio.Task]:
+        """
+        Load data from a URL. This method should be overridden by subclasses.
+        """
+        tasks = []
+        if isinstance(url, str):
+            url = [url]
+        for item in url:
+            tasks.append(
+                asyncio.create_task(self._load(item, **kwargs))
+            )
+        return tasks
+
+    def chunkify(self, lst: List[T], n: int = 50) -> Generator[List[T], None, None]:
+        """Split a List of objects into chunks of size n.
+
+        Args:
+            lst: The list to split into chunks
+            n: The maximum size of each chunk
+
+        Yields:
+            List[T]: Chunks of the original list, each of size at most n
+        """
+        for i in range(0, len(lst), n):
+            yield lst[i:i + n]
+
+    async def _async_map(self, func: Callable, iterable: list) -> list:
+        """Run a function on a list of items asynchronously."""
+        async def async_func(item):
+            async with self.semaphore:
+                return await func(item)
+
+        tasks = [async_func(item) for item in iterable]
+        return await asyncio.gather(*tasks)
+
+    async def _load_tasks(self, tasks: list) -> list:
+        """Load a list of tasks asynchronously."""
+        results = []
+
+        if not tasks:
+            return results
+
+        # Create a controlled task function to limit concurrency
+        async def controlled_task(task):
+            async with self.semaphore:
+                try:
+                    return await task
+                except Exception as e:
+                    self.logger.error(f"Task error: {e}")
+                    return e
+
+        for chunk in self.chunkify(tasks, self.chunk_size):
+            # Wrap each task with semaphore control
+            controlled_tasks = [controlled_task(task) for task in chunk]
+            result = await asyncio.gather(*controlled_tasks, return_exceptions=True)
+            if result:
+                for res in result:
+                    if isinstance(res, Exception):
+                        # Handle the exception
+                        self.logger.error(f"Error loading {res}")
+                    else:
+                        # Handle both single documents and lists of documents
+                        if isinstance(res, list):
+                            results.extend(res)
+                        else:
+                            results.append(res)
+        return results
+
+    async def load(
+        self,
+        source: Optional[Any] = None,
+        **kwargs
+    ) -> List[Document]:
+        """Load data from a source and return it as a list of Langchain Documents.
+
+        The source can be:
+        - None: Uses self.path attribute if available
+        - Path or str: Treated as file path or directory
+        - List[str/Path]: Treated as list of file paths
+        - URL string: Treated as a URL
+        - List of URLs: Treated as list of URLs
+
+        Args:
+            source (Optional[Any]): The source of the data.
+
+        Returns:
+            List[Document]: A list of Langchain Documents.
+        """
+        tasks = []
+        # If no source is provided, use self.path
+        if source is None:
+            if not hasattr(self, 'path') or self.path is None:
+                raise ValueError(
+                    "No source provided and self.path is not set"
+                )
+            source = self.path
+
+        if isinstance(source, (str, Path, PurePath)):
+            # Check if it's a URL
+            if isinstance(source, str) and (source.startswith('http://') or
+                                            source.startswith('https://')):
+                tasks = await self.from_url(source, **kwargs)
+            else:
+                # Assume it's a file path or directory
+                tasks = await self.from_path(source, recursive=self._recursive, **kwargs)
+        elif isinstance(source, list):
+            # Check if it's a list of URLs or paths
+            if all(isinstance(item, str) and (item.startswith('http://') or
+                                              item.startswith('https://'))
+                   for item in source):
+                tasks = await self.from_url(source, **kwargs)
+            else:
+                # Assume it's a list of file paths
+                path_tasks = []
+                for path in source:
+                    path_tasks.extend(await self.from_path(path, recursive=self._recursive, **kwargs))
+                tasks = path_tasks
+        else:
+            raise ValueError(
+                f"Unsupported source type: {type(source)}"
+            )
+        # Load tasks
+        if tasks:
+            results = await self._load_tasks(tasks)
+            return results
+
+        return []
+
+    def create_metadata(
+        self,
+        path: Union[str, PurePath],
+        doctype: str = 'document',
+        source_type: str = 'source',
+        doc_metadata: Optional[dict] = None,
+        summary: Optional[str] = '',
+        **kwargs
+    ):
+        if not doc_metadata:
+            doc_metadata = {}
+        if isinstance(path, PurePath):
+            origin = path.name
+            url = f'file://{path.name}'
+            filename = path
+        else:
+            origin = path
+            url = path
+            filename = f'file://{path}'
+        metadata = {
+            "url": url,
+            "source": origin,
+            "filename": str(filename),
+            "type": doctype,
+            "summary": summary,
+            "source_type": source_type or self._source_type,
+            "created_at": datetime.now().strftime("%Y-%m-%d, %H:%M:%S"),
+            "category": self.category,
+            "document_meta": {
+                **doc_metadata
+            },
+            **kwargs
+        }
+        return metadata
+
+    def create_document(self, content: Any, path: Union[str, PurePath]) -> Document:
+        """Create a Langchain Document from the content.
+        Args:
+            content (Any): The content to create the document from.
+        Returns:
+            Document: A Langchain Document.
+        """
+        return Document(
+            page_content=content,
+            metadata=self.create_metadata(
+                path=path,
+                doctype=self.doctype,
+                source_type=self._source_type
+            )
+        )
+
+
+    def summary_from_text(self, text: str, max_length: int = 500, min_length: int = 50) -> str:
+        """
+        Get a summary of a text.
+        """
+        if not text:
+            return ''
+        try:
+            summarizer = self.get_summarization_model()
+            if self._use_summary_pipeline:
+                # Use Huggingface pipeline
+                content = summarizer(
+                    text,
+                    max_length=max_length,
+                    min_length=min_length,
+                    do_sample=False,
+                    truncation=True
+                )
+                return content[0].get('summary_text', '')
+            # Use Summarize Chain from Langchain
+            doc = Document(page_content=text)
+            summary = summarizer.invoke(
+                {"input_documents": [doc]}, return_only_outputs=True
+            )
+            return summary.get('output_text', '')
+        except Exception as e:
+            self.logger.error(
+                f'ERROR on summary_from_text: {e}'
+            )
+            return ""
+
+    def get_summarization_model(
+        self,
+        model_name: str = 'facebook/bart-large-cnn'
+    ):
+        if not self._summary_model:
+            if self._use_summary_pipeline:
+                summarize_model = AutoModelForSeq2SeqLM.from_pretrained(
+                    model_name,
+                )
+                summarize_tokenizer = AutoTokenizer.from_pretrained(
+                    model_name,
+                    padding_side="left"
+                )
+                self._summary_model = pipeline(
+                    "summarization",
+                    model=summarize_model,
+                    tokenizer=summarize_tokenizer
+                )
+            else:
+                # Use Summarize Chain from Langchain
+                prompt_template = """Write a summary of the following, please also identify the main theme:
+{text}
+SUMMARY:"""
+                prompt = PromptTemplate.from_template(prompt_template)
+                refine_template = (
+                    "Your job is to produce a final summary\n"
+                    "We have provided an existing summary up to a certain point: {existing_answer}\n"
+                    "We have the opportunity to refine the existing summary "
+                    "(only if needed) with some more context below.\n"
+                    "------------\n"
+                    "{text}\n"
+                    "------------\n"
+                    "Given the new context, refine the original summary adding more explanation. "
+                    "If the context isn't useful, return the original summary."
+                )
+                refine_prompt = PromptTemplate.from_template(refine_template)
+                llm = self.get_default_llm()
+                llm = llm.get_llm()
+                summarize_chain = load_summarize_chain(
+                    llm=llm,
+                    chain_type="refine",
+                    question_prompt=prompt,
+                    refine_prompt=refine_prompt,
+                    return_intermediate_steps=False,
+                    input_key="input_documents",
+                    output_key="output_text",
+                )
+                self._summary_model = summarize_chain
+        return self._summary_model
+
+    def translate_text(self, text: str, source_lang: str = "en", target_lang: str = "es") -> str:
+        """
+        Translate text from source language to target language.
+
+        Args:
+            text: Text to translate
+            source_lang: Source language code (default: 'en')
+            target_lang: Target language code (default: 'es')
+
+        Returns:
+            Translated text
+        """
+        if not text:
+            return ''
+        try:
+            translator = self.get_translation_model(source_lang, target_lang)
+            if self._use_translation_pipeline:
+                # Use Huggingface pipeline
+                content = translator(
+                    text,
+                    max_length=len(text) * 2,  # Allow for expansion in target language
+                    truncation=True
+                )
+                return content[0].get('translation_text', '')
+            else:
+                # Use LLM for translation
+                translation = translator.invoke(
+                    {
+                        "text": text,
+                        "source_lang": source_lang,
+                        "target_lang": target_lang
+                    }
+                )
+                return translation.get('text', '')
+        except Exception as e:
+            self.logger.error(f'ERROR on translate_text: {e}')
+            return ""
+
+    def get_translation_model(
+        self,
+        source_lang: str = "en",
+        target_lang: str = "es",
+        model_name: str = None
+    ):
+        """
+        Get or create a translation model.
+
+        Args:
+            source_lang: Source language code
+            target_lang: Target language code
+            model_name: Optional model name override
+
+        Returns:
+            Translation model/chain
+        """
+        # Create a cache key for the language pair
+        cache_key = f"{source_lang}_{target_lang}"
+
+        # Check if we already have a model for this language pair
+        if not hasattr(self, '_translation_models'):
+            self._translation_models = {}
+
+        if cache_key not in self._translation_models:
+            if self._use_translation_pipeline:
+                # Select appropriate model based on language pair if not specified
+                if model_name is None:
+                    if source_lang == "en" and target_lang in ["es", "fr", "de", "it", "pt", "ru"]:
+                        model_name = "Helsinki-NLP/opus-mt-en-ROMANCE"
+                    elif source_lang in ["es", "fr", "de", "it", "pt"] and target_lang == "en":
+                        model_name = "Helsinki-NLP/opus-mt-ROMANCE-en"
+                    else:
+                        # Default to a specific model for the language pair
+                        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
+
+                try:
+                    translate_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+                    translate_tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+                    self._translation_models[cache_key] = pipeline(
+                        "translation",
+                        model=translate_model,
+                        tokenizer=translate_tokenizer
+                    )
+                except Exception as e:
+                    self.logger.error(f"Error loading translation model {model_name}: {e}")
+                    # Fallback to using LLM for translation
+                    self._use_translation_pipeline = False
+
+            if not self._use_translation_pipeline:
+                # Use LLM Chain for translation
+                prompt_template = """Translate the following text from {source_lang} to {target_lang}:
+
+Text: {text}
+
+Translation:"""
+
+                prompt = PromptTemplate(
+                    template=prompt_template,
+                    input_variables=["text", "source_lang", "target_lang"]
+                )
+
+                llm = self.get_default_llm().get_llm()
+                # Create a simple translation chain
+                translation_chain = (
+                    {
+                        "text": RunnablePassthrough(),
+                        "source_lang": lambda x: source_lang,
+                        "target_lang": lambda x: target_lang,
+                    }
+                    | prompt
+                    | llm
+                    | (lambda x: {"text": x})
+                )
+                self._translation_models[cache_key] = translation_chain
+
+        return self._translation_models[cache_key]
+
+    def create_translated_document(
+        self,
+        content: str,
+        metadata: dict,
+        source_lang: str = "en",
+        target_lang: str = "es"
+    ) -> Document:
+        """
+        Create a document with translated content.
+
+        Args:
+            content: Original content
+            metadata: Document metadata
+            source_lang: Source language code
+            target_lang: Target language code
+
+        Returns:
+            Document with translated content
+        """
+        translated_content = self.translate_text(content, source_lang, target_lang)
+
+        # Clone the metadata and add translation info
+        translation_metadata = metadata.copy()
+        translation_metadata.update({
+            "original_language": source_lang,
+            "language": target_lang,
+            "is_translation": True
+        })
+
+        return Document(page_content=translated_content, metadata=translation_metadata)
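
For orientation, here is a minimal sketch of how a concrete loader could build on AbstractLoader: the subclass implements the abstract _load coroutine, and the inherited load() entry point discovers files via from_path() and gathers one _load() task per match, bounded by the semaphore. PlainTextLoader, its '.txt' filter, and the 'docs/' path are illustrative assumptions, not part of the package.

import asyncio
from pathlib import Path, PurePath
from typing import List, Union

from langchain.docstore.document import Document
from parrot.loaders.abstract import AbstractLoader


class PlainTextLoader(AbstractLoader):
    # Hypothetical subclass: only load .txt files; everything else is inherited.
    extensions: List[str] = ['.txt']

    async def _load(self, source: Union[str, PurePath], **kwargs) -> List[Document]:
        # Read the whole file and wrap it in a single Document carrying the
        # standard metadata built by create_document()/create_metadata().
        text = Path(source).read_text(encoding=self.encoding)
        return [self.create_document(content=text, path=Path(source))]


async def main():
    # 'docs/' is a placeholder directory.
    loader = PlainTextLoader(path='docs/', recursive=True)
    documents = await loader.load()
    print(f"Loaded {len(documents)} documents")

asyncio.run(main())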
parrot/loaders/files/__init__.py
File without changes.
parrot/loaders/files/abstract.py
@@ -0,0 +1,39 @@
+from typing import Optional, Any
+from abc import ABC, abstractmethod
+from navconfig.logging import logging
+
+
+class FilePlugin(ABC):
+    """
+    FilePlugin is a base class for opening files.
+    It provides a common interface for opening all kinds of files.
+    Subclasses should implement the `read` method to define
+    the specific file processing logic.
+    """
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initialize the FilePlugin.
+
+        Accepts arbitrary positional and keyword arguments for subclasses.
+        """
+        self.logger = logging.getLogger(
+            f'parrot.FileLoader.{self.__class__.__name__}'
+        )
+
+    async def __aenter__(self):
+        if hasattr(self, "open"):
+            await self.open()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        if hasattr(self, "close"):
+            await self.close()
+        return True
+
+    @abstractmethod
+    async def read(self):
+        """
+        Return the content of the file, need to be implemented in the subclass.
+        """
+        pass