ai-parrot 0.8.3__cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-parrot might be problematic. Click here for more details.

Files changed (128) hide show
  1. ai_parrot-0.8.3.dist-info/LICENSE +21 -0
  2. ai_parrot-0.8.3.dist-info/METADATA +306 -0
  3. ai_parrot-0.8.3.dist-info/RECORD +128 -0
  4. ai_parrot-0.8.3.dist-info/WHEEL +6 -0
  5. ai_parrot-0.8.3.dist-info/top_level.txt +2 -0
  6. parrot/__init__.py +30 -0
  7. parrot/bots/__init__.py +5 -0
  8. parrot/bots/abstract.py +1115 -0
  9. parrot/bots/agent.py +492 -0
  10. parrot/bots/basic.py +9 -0
  11. parrot/bots/bose.py +17 -0
  12. parrot/bots/chatbot.py +271 -0
  13. parrot/bots/cody.py +17 -0
  14. parrot/bots/copilot.py +117 -0
  15. parrot/bots/data.py +730 -0
  16. parrot/bots/dataframe.py +103 -0
  17. parrot/bots/hrbot.py +15 -0
  18. parrot/bots/interfaces/__init__.py +1 -0
  19. parrot/bots/interfaces/retrievers.py +12 -0
  20. parrot/bots/notebook.py +619 -0
  21. parrot/bots/odoo.py +17 -0
  22. parrot/bots/prompts/__init__.py +41 -0
  23. parrot/bots/prompts/agents.py +91 -0
  24. parrot/bots/prompts/data.py +214 -0
  25. parrot/bots/retrievals/__init__.py +1 -0
  26. parrot/bots/retrievals/constitutional.py +19 -0
  27. parrot/bots/retrievals/multi.py +122 -0
  28. parrot/bots/retrievals/retrieval.py +610 -0
  29. parrot/bots/tools/__init__.py +7 -0
  30. parrot/bots/tools/eda.py +325 -0
  31. parrot/bots/tools/pdf.py +50 -0
  32. parrot/bots/tools/plot.py +48 -0
  33. parrot/bots/troc.py +16 -0
  34. parrot/conf.py +170 -0
  35. parrot/crew/__init__.py +3 -0
  36. parrot/crew/tools/__init__.py +22 -0
  37. parrot/crew/tools/bing.py +13 -0
  38. parrot/crew/tools/config.py +43 -0
  39. parrot/crew/tools/duckgo.py +62 -0
  40. parrot/crew/tools/file.py +24 -0
  41. parrot/crew/tools/google.py +168 -0
  42. parrot/crew/tools/gtrends.py +16 -0
  43. parrot/crew/tools/md2pdf.py +25 -0
  44. parrot/crew/tools/rag.py +42 -0
  45. parrot/crew/tools/search.py +32 -0
  46. parrot/crew/tools/url.py +21 -0
  47. parrot/exceptions.cpython-312-x86_64-linux-gnu.so +0 -0
  48. parrot/handlers/__init__.py +4 -0
  49. parrot/handlers/agents.py +292 -0
  50. parrot/handlers/bots.py +196 -0
  51. parrot/handlers/chat.py +192 -0
  52. parrot/interfaces/__init__.py +6 -0
  53. parrot/interfaces/database.py +27 -0
  54. parrot/interfaces/http.py +805 -0
  55. parrot/interfaces/images/__init__.py +0 -0
  56. parrot/interfaces/images/plugins/__init__.py +18 -0
  57. parrot/interfaces/images/plugins/abstract.py +58 -0
  58. parrot/interfaces/images/plugins/exif.py +709 -0
  59. parrot/interfaces/images/plugins/hash.py +52 -0
  60. parrot/interfaces/images/plugins/vision.py +104 -0
  61. parrot/interfaces/images/plugins/yolo.py +66 -0
  62. parrot/interfaces/images/plugins/zerodetect.py +197 -0
  63. parrot/llms/__init__.py +1 -0
  64. parrot/llms/abstract.py +69 -0
  65. parrot/llms/anthropic.py +58 -0
  66. parrot/llms/gemma.py +15 -0
  67. parrot/llms/google.py +44 -0
  68. parrot/llms/groq.py +67 -0
  69. parrot/llms/hf.py +45 -0
  70. parrot/llms/openai.py +61 -0
  71. parrot/llms/pipes.py +114 -0
  72. parrot/llms/vertex.py +89 -0
  73. parrot/loaders/__init__.py +9 -0
  74. parrot/loaders/abstract.py +628 -0
  75. parrot/loaders/files/__init__.py +0 -0
  76. parrot/loaders/files/abstract.py +39 -0
  77. parrot/loaders/files/text.py +63 -0
  78. parrot/loaders/txt.py +26 -0
  79. parrot/manager.py +333 -0
  80. parrot/models.py +504 -0
  81. parrot/py.typed +0 -0
  82. parrot/stores/__init__.py +11 -0
  83. parrot/stores/abstract.py +248 -0
  84. parrot/stores/chroma.py +188 -0
  85. parrot/stores/duck.py +162 -0
  86. parrot/stores/embeddings/__init__.py +10 -0
  87. parrot/stores/embeddings/abstract.py +46 -0
  88. parrot/stores/embeddings/base.py +52 -0
  89. parrot/stores/embeddings/bge.py +20 -0
  90. parrot/stores/embeddings/fastembed.py +17 -0
  91. parrot/stores/embeddings/google.py +18 -0
  92. parrot/stores/embeddings/huggingface.py +20 -0
  93. parrot/stores/embeddings/ollama.py +14 -0
  94. parrot/stores/embeddings/openai.py +26 -0
  95. parrot/stores/embeddings/transformers.py +21 -0
  96. parrot/stores/embeddings/vertexai.py +17 -0
  97. parrot/stores/empty.py +10 -0
  98. parrot/stores/faiss.py +160 -0
  99. parrot/stores/milvus.py +397 -0
  100. parrot/stores/postgres.py +653 -0
  101. parrot/stores/qdrant.py +170 -0
  102. parrot/tools/__init__.py +23 -0
  103. parrot/tools/abstract.py +68 -0
  104. parrot/tools/asknews.py +33 -0
  105. parrot/tools/basic.py +51 -0
  106. parrot/tools/bby.py +359 -0
  107. parrot/tools/bing.py +13 -0
  108. parrot/tools/docx.py +343 -0
  109. parrot/tools/duck.py +62 -0
  110. parrot/tools/execute.py +56 -0
  111. parrot/tools/gamma.py +28 -0
  112. parrot/tools/google.py +170 -0
  113. parrot/tools/gvoice.py +301 -0
  114. parrot/tools/results.py +278 -0
  115. parrot/tools/stack.py +27 -0
  116. parrot/tools/weather.py +70 -0
  117. parrot/tools/wikipedia.py +58 -0
  118. parrot/tools/zipcode.py +198 -0
  119. parrot/utils/__init__.py +2 -0
  120. parrot/utils/parsers/__init__.py +5 -0
  121. parrot/utils/parsers/toml.cpython-312-x86_64-linux-gnu.so +0 -0
  122. parrot/utils/toml.py +11 -0
  123. parrot/utils/types.cpython-312-x86_64-linux-gnu.so +0 -0
  124. parrot/utils/uv.py +11 -0
  125. parrot/version.py +10 -0
  126. resources/users/__init__.py +5 -0
  127. resources/users/handlers.py +13 -0
  128. resources/users/models.py +205 -0
@@ -0,0 +1,628 @@
1
+ from typing import Generator, Union, List, Any, Optional, TypeVar
2
+ from collections.abc import Callable
3
+ from abc import ABC, abstractmethod
4
+ from datetime import datetime
5
+ from pathlib import Path, PurePath
6
+ import asyncio
7
+ import torch
8
+ from transformers import (
9
+ AutoModelForSeq2SeqLM,
10
+ AutoTokenizer,
11
+ pipeline
12
+ )
13
+ from langchain.schema.runnable import RunnablePassthrough
14
+ from langchain.chains.summarize import load_summarize_chain
15
+ from langchain.docstore.document import Document
16
+ from langchain.text_splitter import (
17
+ TokenTextSplitter
18
+ )
19
+ from langchain_core.prompts import PromptTemplate
20
+ from navconfig.logging import logging
21
+ from navigator.libs.json import JSONContent # pylint: disable=E0611
22
+ from parrot.llms.vertex import VertexLLM
23
+ from ..conf import (
24
+ DEFAULT_LLM_MODEL,
25
+ DEFAULT_LLM_TEMPERATURE,
26
+ CUDA_DEFAULT_DEVICE,
27
+ CUDA_DEFAULT_DEVICE_NUMBER
28
+ )
29
+
30
+
31
T = TypeVar('T')


class AbstractLoader(ABC):
    """
    Base class for all loaders. Loaders are responsible for loading data from various sources.

    Subclasses implement :meth:`_load` for a single source (a file path or a
    URL). This base class provides:

    * source discovery (single file, directory with optional recursion,
      URL or list of URLs/paths) via :meth:`load`, :meth:`from_path` and
      :meth:`from_url`;
    * bounded concurrent execution of the per-source loads;
    * metadata/document construction helpers;
    * optional summarization and translation helpers (HuggingFace pipelines
      or an LLM chain).
    """
    # File suffixes accepted by this loader; the special value '.*' means
    # "any file that has an extension".
    extensions: List[str] = ['.*']
    # Directory names (or directory paths) whose contents are never loaded.
    skip_directories: List[str] = []

    def __init__(
        self,
        *args,
        tokenizer: Optional[Union[str, Callable]] = None,
        text_splitter: Optional[Union[str, Callable]] = None,
        source_type: str = 'file',
        **kwargs
    ):
        """Configure the loader from keyword arguments.

        Args:
            tokenizer: Tokenizer name or callable, stored as-is.
            text_splitter: Text splitter name or callable, stored as-is.
            source_type: Default ``source_type`` written into metadata.
            **kwargs: Optional settings read below (``chunk_size``,
                ``semaphore``, ``extensions``, ``skip_directories``,
                ``encoding``, ``recursive``, ``category``, ``doctype``,
                ``path``, LLM / summarization / translation / device options).
        """
        self.chunk_size: int = kwargs.get('chunk_size', 50)
        max_concurrency = kwargs.get('semaphore', 10)
        self.semaphore = asyncio.Semaphore(max_concurrency)
        # Dedicated semaphore bounding concurrent ``_load`` calls. It is kept
        # separate from ``self.semaphore`` so that a subclass ``_load`` which
        # itself uses ``self.semaphore`` (e.g. through ``_async_map``) can
        # never deadlock waiting on permits held by its own callers.
        self._load_semaphore = asyncio.Semaphore(max_concurrency)
        self.extensions = kwargs.get('extensions', self.extensions)
        self.skip_directories = kwargs.get('skip_directories', self.skip_directories)
        self.encoding = kwargs.get('encoding', 'utf-8')
        self._source_type = source_type
        self._recursive: bool = kwargs.get('recursive', False)
        self.category: str = kwargs.get('category', 'document')
        self.doctype: str = kwargs.get('doctype', 'text')
        self._summarization = kwargs.get('summarization', False)
        self._summary_model: Optional[Any] = kwargs.get('summary_model', None)
        self._use_summary_pipeline: bool = kwargs.get('use_summary_pipeline', False)
        self._use_translation_pipeline: bool = kwargs.get('use_translation_pipeline', False)
        self._translation = kwargs.get('translation', False)
        # Tokenizer
        self.tokenizer = tokenizer
        # Text Splitter
        self.text_splitter = text_splitter
        # Summarization Model:
        self.summarization_model = kwargs.get('summarizer', None)
        # Markdown Splitter:
        self.markdown_splitter = kwargs.get('markdown_splitter', None)
        if 'path' in kwargs:
            self.path = kwargs['path']
            if isinstance(self.path, str):
                self.path = Path(self.path).resolve()
        # LLM (if required)
        self._use_llm = kwargs.get('use_llm', False)
        self._llm_model = kwargs.get('llm_model', None)
        self._llm_model_kwargs = kwargs.get('model_kwargs', {})
        self._llm = kwargs.get('llm', None)
        if self._use_llm:
            self._llm = self.get_default_llm(
                model=self._llm_model,
                model_kwargs=self._llm_model_kwargs,
            )
        self.logger = logging.getLogger(
            f"Parrot.Loaders.{self.__class__.__name__}"
        )
        # JSON encoder:
        self._encoder = JSONContent()
        # Device selection (CUDA/MPS/CPU) settings:
        self.device_name = kwargs.get('device', CUDA_DEFAULT_DEVICE)
        self.cuda_number = kwargs.get('cuda_number', CUDA_DEFAULT_DEVICE_NUMBER)
        self._device = None

    def get_default_llm(self, model: str = None, model_kwargs: dict = None):
        """Return a VertexLLM instance.

        Args:
            model: Model name; defaults to ``DEFAULT_LLM_MODEL``.
            model_kwargs: Extra model arguments; when empty, a default
                temperature/top_k/top_p configuration is used.
        """
        if not model_kwargs:
            model_kwargs = {
                "temperature": DEFAULT_LLM_TEMPERATURE,
                "top_k": 30,
                "top_p": 0.5,
            }
        return VertexLLM(
            model=model or DEFAULT_LLM_MODEL,
            **model_kwargs
        )

    def _get_device(
        self,
        device_type: str = None,
        cuda_number: int = 0
    ):
        """Get the default device for Torch and transformers.

        An explicit ``device_type`` of ``'cpu'`` or ``'cuda'`` is honoured
        as-is. Otherwise the device is auto-detected: CPU when configured,
        then CUDA, then Apple MPS, then the configured default.
        """
        if device_type == 'cpu':
            return torch.device('cpu')
        if device_type == 'cuda':
            return torch.device(f'cuda:{cuda_number}')
        if CUDA_DEFAULT_DEVICE == 'cpu':
            # Use CPU if configured explicitly.
            return torch.device('cpu')
        if torch.cuda.is_available():
            # Use CUDA GPU if available
            return torch.device(f'cuda:{cuda_number}')
        if torch.backends.mps.is_available():
            # Apple Metal Performance Shaders (MPS) backend.
            return torch.device("mps")
        if CUDA_DEFAULT_DEVICE == 'cuda':
            # CUDA was configured but is not available on this host: fall
            # back to CPU instead of returning a device that fails on use.
            self.logger.warning(
                "CUDA requested but not available, falling back to CPU."
            )
            return torch.device('cpu')
        return torch.device(CUDA_DEFAULT_DEVICE)

    def clear_cuda(self):
        """Drop tokenizer/splitter references and free cached CUDA memory.

        The CUDA calls are skipped on hosts without CUDA, where
        ``torch.cuda.synchronize()`` would raise.
        """
        self.tokenizer = None  # Reset the tokenizer
        self.text_splitter = None  # Reset the text splitter
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # Wait for all kernels to finish
            torch.cuda.empty_cache()  # Clear unused memory

    async def __aenter__(self):
        """Open the loader if it has an open method."""
        if hasattr(self, "open"):
            await self.open()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the loader if it has a close method.

        Returns False so that exceptions raised inside the ``async with``
        body propagate to the caller instead of being silently suppressed.
        """
        if hasattr(self, "close"):
            await self.close()
        return False

    def supported_extensions(self):
        """Get the supported file extensions."""
        return self.extensions

    def _matches_extension(self, path: Path) -> bool:
        """Return True if ``path``'s suffix is accepted by ``self.extensions``.

        The wildcard '.*' accepts any file that has a suffix.
        """
        if path.suffix in self.extensions:
            return True
        return bool(path.suffix) and '.*' in self.extensions

    def _in_skipped_directory(self, path: Path) -> bool:
        """Return True if ``path`` lies inside any of ``self.skip_directories``.

        Entries may be bare directory names (matched against any path
        component) or directory paths (matched as a path prefix).
        """
        parts = set(path.parts)
        for skip_dir in self.skip_directories:
            if str(skip_dir) in parts or path.is_relative_to(skip_dir):
                return True
        return False

    def is_valid_path(self, path: Union[str, PurePath]) -> bool:
        """Check if a path exists and should be loaded.

        Rejects missing paths, skipped directories, files with unsupported
        extensions, hidden files, empty files and anything located inside a
        skipped directory.
        """
        path = Path(path)
        if not path.exists():
            return False
        if path.is_dir() and path.name in self.skip_directories:
            return False
        if path.is_file():
            if not self._matches_extension(path):
                return False
            if path.name.startswith("."):
                return False
            # check if file is empty
            if path.stat().st_size == 0:
                return False
        # check if the path is inside of skip directories:
        if self._in_skipped_directory(path):
            return False
        return True

    @abstractmethod
    async def _load(self, source: Union[str, PurePath], **kwargs) -> List[Document]:
        """Load a single data/url/file from a source and return it as a Langchain Document.

        Args:
            source (str): The source of the data.

        Returns:
            List[Document]: A list of Langchain Documents.
        """
        pass

    async def _bounded_load(self, source: Union[str, PurePath], **kwargs):
        """Run ``_load`` while holding a permit of the load semaphore."""
        async with self._load_semaphore:
            return await self._load(source, **kwargs)

    async def from_path(
        self,
        path: Union[str, PurePath],
        recursive: bool = False,
        **kwargs
    ) -> List[asyncio.Task]:
        """
        Schedule loading of a file, or of every matching file in a directory.

        Args:
            path: File or directory path (``str``, ``Path`` or ``PurePath``).
            recursive: When ``path`` is a directory, also descend into
                sub-directories.
            **kwargs: Forwarded to :meth:`_load`.

        Returns:
            List[asyncio.Task]: One scheduled task per accepted file. Each
            task waits for a load-semaphore permit before calling ``_load``.
        """
        tasks = []
        # PurePath has no filesystem methods (is_dir, glob, ...): use Path.
        path = Path(path)
        if path.is_dir():
            glob_method = path.rglob if recursive else path.glob
            seen = set()
            for ext in self.extensions:
                for item in glob_method(f'*{ext}'):
                    # Overlapping extensions (e.g. '.*' and '.txt') may yield
                    # the same file more than once; load it only once.
                    if item in seen:
                        continue
                    seen.add(item)
                    # Skip anything located inside a skipped directory.
                    if not set(item.parts).isdisjoint(self.skip_directories):
                        continue
                    if self.is_valid_path(item):
                        tasks.append(
                            asyncio.create_task(self._bounded_load(item, **kwargs))
                        )
        elif path.is_file():
            if self.is_valid_path(path):
                tasks.append(
                    asyncio.create_task(self._bounded_load(path, **kwargs))
                )
        else:
            self.logger.warning(f"Path {path} is not valid.")
        return tasks

    async def from_url(
        self,
        url: Union[str, List[str]],
        **kwargs
    ) -> List[asyncio.Task]:
        """
        Schedule loading of one URL or a list of URLs.

        Subclasses may override this for URL-specific behavior.

        Returns:
            List[asyncio.Task]: One scheduled task per URL, each bounded by
            the load semaphore.
        """
        urls = [url] if isinstance(url, str) else url
        return [
            asyncio.create_task(self._bounded_load(item, **kwargs))
            for item in urls
        ]

    def chunkify(self, lst: List[T], n: int = 50) -> Generator[List[T], None, None]:
        """Split a List of objects into chunks of size n.

        Args:
            lst: The list to split into chunks
            n: The maximum size of each chunk

        Yields:
            List[T]: Chunks of the original list, each of size at most n
        """
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    async def _async_map(self, func: Callable, iterable: list) -> list:
        """Run an async function on every item, at most ``semaphore`` at a time."""
        async def async_func(item):
            async with self.semaphore:
                return await func(item)

        tasks = [async_func(item) for item in iterable]
        return await asyncio.gather(*tasks)

    async def _load_tasks(self, tasks: list) -> list:
        """Await loading tasks/coroutines and flatten their results.

        Already-scheduled tasks/futures are simply awaited (their own
        concurrency limit applies); bare coroutines are run under the load
        semaphore. Failures are logged once and skipped; list results are
        flattened and ``None`` results are dropped.
        """
        results = []
        if not tasks:
            return results

        async def _await_one(item):
            if asyncio.isfuture(item):
                return await item
            async with self._load_semaphore:
                return await item

        for chunk in self.chunkify(tasks, self.chunk_size):
            outcomes = await asyncio.gather(
                *(_await_one(item) for item in chunk),
                return_exceptions=True
            )
            for outcome in outcomes:
                # BaseException also covers asyncio.CancelledError.
                if isinstance(outcome, BaseException):
                    self.logger.error(f"Error loading {outcome}")
                elif isinstance(outcome, list):
                    # Handle both single documents and lists of documents
                    results.extend(outcome)
                elif outcome is not None:
                    results.append(outcome)
        return results

    @staticmethod
    def _is_url(value: Any) -> bool:
        """Return True if ``value`` is an http(s) URL string."""
        return isinstance(value, str) and value.startswith(('http://', 'https://'))

    async def load(
        self,
        source: Optional[Any] = None,
        **kwargs
    ) -> List[Document]:
        """Load data from a source and return it as a list of Langchain Documents.

        The source can be:
        - None: Uses self.path attribute if available
        - Path or str: Treated as file path or directory
        - URL string: Treated as a URL
        - List of URLs and/or paths: each item is dispatched accordingly

        Args:
            source (Optional[Any]): The source of the data.

        Returns:
            List[Document]: A list of Langchain Documents.

        Raises:
            ValueError: If no source is available or its type is unsupported.
        """
        tasks = []
        # If no source is provided, use self.path
        if source is None:
            if not hasattr(self, 'path') or self.path is None:
                raise ValueError(
                    "No source provided and self.path is not set"
                )
            source = self.path

        if isinstance(source, (str, Path, PurePath)):
            if self._is_url(source):
                tasks = await self.from_url(source, **kwargs)
            else:
                # Assume it's a file path or directory
                tasks = await self.from_path(source, recursive=self._recursive, **kwargs)
        elif isinstance(source, list):
            if source and all(self._is_url(item) for item in source):
                tasks = await self.from_url(source, **kwargs)
            else:
                # Mixed or path-only list: dispatch each item individually.
                for item in source:
                    if self._is_url(item):
                        tasks.extend(await self.from_url(item, **kwargs))
                    else:
                        tasks.extend(
                            await self.from_path(item, recursive=self._recursive, **kwargs)
                        )
        else:
            raise ValueError(
                f"Unsupported source type: {type(source)}"
            )
        if not tasks:
            return []
        return await self._load_tasks(tasks)

    def create_metadata(
        self,
        path: Union[str, PurePath],
        doctype: str = 'document',
        source_type: str = 'source',
        doc_metadata: Optional[dict] = None,
        summary: Optional[str] = '',
        **kwargs
    ):
        """Build the metadata dict attached to loaded documents.

        Args:
            path: Source path (``PurePath``) or a string such as a URL.
            doctype: Stored under ``"type"``.
            source_type: Stored under ``"source_type"``; falls back to the
                loader's default when falsy.
            doc_metadata: Copied under ``"document_meta"``.
            summary: Optional summary text.
            **kwargs: Extra top-level metadata entries.
        """
        doc_metadata = doc_metadata or {}
        if isinstance(path, PurePath):
            origin = path.name
            # NOTE(review): this URL only carries the file name, not the full
            # path; kept unchanged since stored metadata may depend on it.
            url = f'file://{path.name}'
            filename = path
        else:
            origin = path
            url = path
            filename = f'file://{path}'
        return {
            "url": url,
            "source": origin,
            "filename": str(filename),
            "type": doctype,
            "summary": summary,
            "source_type": source_type or self._source_type,
            "created_at": datetime.now().strftime("%Y-%m-%d, %H:%M:%S"),
            "category": self.category,
            "document_meta": {
                **doc_metadata
            },
            **kwargs
        }

    def create_document(self, content: Any, path: Union[str, PurePath]) -> Document:
        """Create a Langchain Document from the content.

        Args:
            content (Any): The content to create the document from.
            path: The source path/URL used to build the metadata.

        Returns:
            Document: A Langchain Document.
        """
        return Document(
            page_content=content,
            metadata=self.create_metadata(
                path=path,
                doctype=self.doctype,
                source_type=self._source_type
            )
        )

    def summary_from_text(self, text: str, max_length: int = 500, min_length: int = 50) -> str:
        """
        Get a summary of a text.

        Uses a HuggingFace summarization pipeline when
        ``use_summary_pipeline`` is enabled, otherwise a LangChain "refine"
        summarize chain. Errors are logged and an empty string is returned.
        """
        if not text:
            return ''
        try:
            summarizer = self.get_summarization_model()
            if self._use_summary_pipeline:
                # Use Huggingface pipeline
                content = summarizer(
                    text,
                    max_length=max_length,
                    min_length=min_length,
                    do_sample=False,
                    truncation=True
                )
                return content[0].get('summary_text', '')
            # Use Summarize Chain from Langchain
            doc = Document(page_content=text)
            summary = summarizer.invoke(
                {"input_documents": [doc]}, return_only_outputs=True
            )
            return summary.get('output_text', '')
        except Exception as e:
            self.logger.error(
                f'ERROR on summary_from_text: {e}'
            )
            return ""

    def get_summarization_model(
        self,
        model_name: str = 'facebook/bart-large-cnn'
    ):
        """Return (and cache) the summarizer used by ``summary_from_text``.

        Args:
            model_name: HuggingFace model, used only in pipeline mode.
        """
        if not self._summary_model:
            if self._use_summary_pipeline:
                summarize_model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_name,
                )
                summarize_tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    padding_side="left"
                )
                self._summary_model = pipeline(
                    "summarization",
                    model=summarize_model,
                    tokenizer=summarize_tokenizer
                )
            else:
                # Use Summarize Chain from Langchain
                prompt_template = (
                    "Write a summary of the following, please also identify the main theme:\n"
                    "{text}\n"
                    "SUMMARY:"
                )
                prompt = PromptTemplate.from_template(prompt_template)
                refine_template = (
                    "Your job is to produce a final summary\n"
                    "We have provided an existing summary up to a certain point: {existing_answer}\n"
                    "We have the opportunity to refine the existing summary "
                    "(only if needed) with some more context below.\n"
                    "------------\n"
                    "{text}\n"
                    "------------\n"
                    "Given the new context, refine the original summary adding more explanation. "
                    "If the context isn't useful, return the original summary."
                )
                refine_prompt = PromptTemplate.from_template(refine_template)
                llm = self.get_default_llm().get_llm()
                self._summary_model = load_summarize_chain(
                    llm=llm,
                    chain_type="refine",
                    question_prompt=prompt,
                    refine_prompt=refine_prompt,
                    return_intermediate_steps=False,
                    input_key="input_documents",
                    output_key="output_text",
                )
        return self._summary_model

    @staticmethod
    def _default_translation_model(source_lang: str, target_lang: str) -> tuple:
        """Pick a Helsinki-NLP model name and input prefix for a language pair.

        The multi-target ``opus-mt-en-ROMANCE`` model requires a
        ``>>lang<<`` token at the start of the input to select the target
        language; other models need no prefix. German and Russian are not
        Romance languages and use their dedicated pair models.
        """
        romance = {"es", "fr", "it", "pt", "ro"}
        if source_lang == "en" and target_lang in romance:
            return "Helsinki-NLP/opus-mt-en-ROMANCE", f">>{target_lang}<< "
        if source_lang in romance and target_lang == "en":
            return "Helsinki-NLP/opus-mt-ROMANCE-en", ""
        return f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}", ""

    def translate_text(self, text: str, source_lang: str = "en", target_lang: str = "es") -> str:
        """
        Translate text from source language to target language.

        Args:
            text: Text to translate
            source_lang: Source language code (default: 'en')
            target_lang: Target language code (default: 'es')

        Returns:
            Translated text, or an empty string on error.
        """
        if not text:
            return ''
        try:
            translator = self.get_translation_model(source_lang, target_lang)
            cache_key = f"{source_lang}_{target_lang}"
            # Decide per language pair: a pipeline load may have failed for
            # this pair (LLM fallback) while other pairs use pipelines.
            mode = getattr(self, '_translation_modes', {}).get(cache_key)
            use_pipeline = (
                self._use_translation_pipeline if mode is None else mode == 'pipeline'
            )
            if use_pipeline:
                prefix = getattr(self, '_translation_prefixes', {}).get(cache_key, '')
                content = translator(
                    f"{prefix}{text}",
                    max_length=len(text) * 2,  # Allow for expansion in target language
                    truncation=True
                )
                return content[0].get('translation_text', '')
            # Use LLM for translation
            translation = translator.invoke(
                {
                    "text": text,
                    "source_lang": source_lang,
                    "target_lang": target_lang
                }
            )
            return translation.get('text', '')
        except Exception as e:
            self.logger.error(f'ERROR on translate_text: {e}')
            return ""

    def _build_llm_translation_chain(self, source_lang: str, target_lang: str):
        """Build an LLM chain mapping ``{"text": ...}`` to ``{"text": translation}``."""
        prompt_template = (
            "Translate the following text from {source_lang} to {target_lang}:\n\n"
            "Text: {text}\n\n"
            "Translation:"
        )
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["text", "source_lang", "target_lang"]
        )
        llm = self.get_default_llm().get_llm()

        def _text_only(value):
            # The chain is invoked with a dict; only its "text" entry belongs
            # in the {text} slot (plain strings are accepted as well).
            if isinstance(value, dict):
                return value.get("text", "")
            return value

        return (
            {
                "text": _text_only,
                "source_lang": lambda _: source_lang,
                "target_lang": lambda _: target_lang,
            }
            | prompt
            | llm
            # Chat models return a message object; keep only its text.
            | (lambda x: {"text": getattr(x, "content", x)})
        )

    def get_translation_model(
        self,
        source_lang: str = "en",
        target_lang: str = "es",
        model_name: str = None
    ):
        """
        Get or create a translation model.

        Models are cached per language pair. In pipeline mode a HuggingFace
        translation pipeline is loaded; if that fails, this pair falls back
        to an LLM chain (other pairs are unaffected).

        Args:
            source_lang: Source language code
            target_lang: Target language code
            model_name: Optional model name override

        Returns:
            Translation model/chain
        """
        cache_key = f"{source_lang}_{target_lang}"
        if not hasattr(self, '_translation_models'):
            self._translation_models = {}
        if not hasattr(self, '_translation_modes'):
            self._translation_modes = {}
        if not hasattr(self, '_translation_prefixes'):
            self._translation_prefixes = {}

        if cache_key in self._translation_models:
            return self._translation_models[cache_key]

        if self._use_translation_pipeline:
            prefix = ''
            if model_name is None:
                model_name, prefix = self._default_translation_model(
                    source_lang, target_lang
                )
            try:
                translate_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
                translate_tokenizer = AutoTokenizer.from_pretrained(model_name)
                self._translation_models[cache_key] = pipeline(
                    "translation",
                    model=translate_model,
                    tokenizer=translate_tokenizer
                )
                self._translation_modes[cache_key] = 'pipeline'
                self._translation_prefixes[cache_key] = prefix
                return self._translation_models[cache_key]
            except Exception as e:
                self.logger.error(f"Error loading translation model {model_name}: {e}")
                # Fall through: use an LLM chain for this language pair.

        self._translation_models[cache_key] = self._build_llm_translation_chain(
            source_lang, target_lang
        )
        self._translation_modes[cache_key] = 'llm'
        return self._translation_models[cache_key]

    def create_translated_document(
        self,
        content: str,
        metadata: dict,
        source_lang: str = "en",
        target_lang: str = "es"
    ) -> Document:
        """
        Create a document with translated content.

        Args:
            content: Original content
            metadata: Document metadata (shallow-copied, not mutated)
            source_lang: Source language code
            target_lang: Target language code

        Returns:
            Document with translated content
        """
        translated_content = self.translate_text(content, source_lang, target_lang)
        # Clone the metadata and add translation info
        translation_metadata = {
            **metadata,
            "original_language": source_lang,
            "language": target_lang,
            "is_translation": True,
        }
        return Document(page_content=translated_content, metadata=translation_metadata)
File without changes
@@ -0,0 +1,39 @@
1
+ from typing import Optional, Any
2
+ from abc import ABC, abstractmethod
3
+ from navconfig.logging import logging
4
+
5
+
6
class FilePlugin(ABC):
    """
    FilePlugin is the base class for file readers.

    It provides a common interface for opening and reading all kinds of files.
    Subclasses must implement the async `read` method and may define async
    `open`/`close` methods, which are called automatically when the plugin
    is used as an async context manager.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the plugin and its logger.

        Positional and keyword arguments are accepted so subclasses can pass
        their own options through; they are not used here.
        """
        self.logger = logging.getLogger(
            f'parrot.FileLoader.{self.__class__.__name__}'
        )

    async def __aenter__(self):
        """Call ``open()`` if the subclass defines it, then return self."""
        if hasattr(self, "open"):
            await self.open()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Call ``close()`` if the subclass defines it.

        Returns False so exceptions raised inside the ``async with`` body
        propagate instead of being silently swallowed.
        """
        if hasattr(self, "close"):
            await self.close()
        return False

    @abstractmethod
    async def read(self):
        """
        Return the content of the file; must be implemented in the subclass.
        """