ai-parrot 0.3.4__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-parrot might be problematic.
- ai_parrot-0.3.4.dist-info/LICENSE +21 -0
- ai_parrot-0.3.4.dist-info/METADATA +319 -0
- ai_parrot-0.3.4.dist-info/RECORD +109 -0
- ai_parrot-0.3.4.dist-info/WHEEL +6 -0
- ai_parrot-0.3.4.dist-info/top_level.txt +3 -0
- parrot/__init__.py +21 -0
- parrot/chatbots/__init__.py +7 -0
- parrot/chatbots/abstract.py +728 -0
- parrot/chatbots/asktroc.py +16 -0
- parrot/chatbots/base.py +366 -0
- parrot/chatbots/basic.py +9 -0
- parrot/chatbots/bose.py +17 -0
- parrot/chatbots/cody.py +17 -0
- parrot/chatbots/copilot.py +83 -0
- parrot/chatbots/dataframe.py +103 -0
- parrot/chatbots/hragents.py +15 -0
- parrot/chatbots/odoo.py +17 -0
- parrot/chatbots/retrievals/__init__.py +578 -0
- parrot/chatbots/retrievals/constitutional.py +19 -0
- parrot/conf.py +110 -0
- parrot/crew/__init__.py +3 -0
- parrot/crew/tools/__init__.py +22 -0
- parrot/crew/tools/bing.py +13 -0
- parrot/crew/tools/config.py +43 -0
- parrot/crew/tools/duckgo.py +62 -0
- parrot/crew/tools/file.py +24 -0
- parrot/crew/tools/google.py +168 -0
- parrot/crew/tools/gtrends.py +16 -0
- parrot/crew/tools/md2pdf.py +25 -0
- parrot/crew/tools/rag.py +42 -0
- parrot/crew/tools/search.py +32 -0
- parrot/crew/tools/url.py +21 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/bots.py +196 -0
- parrot/handlers/chat.py +162 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/database.py +29 -0
- parrot/llms/__init__.py +137 -0
- parrot/llms/abstract.py +47 -0
- parrot/llms/anthropic.py +42 -0
- parrot/llms/google.py +42 -0
- parrot/llms/groq.py +45 -0
- parrot/llms/hf.py +45 -0
- parrot/llms/openai.py +59 -0
- parrot/llms/pipes.py +114 -0
- parrot/llms/vertex.py +78 -0
- parrot/loaders/__init__.py +20 -0
- parrot/loaders/abstract.py +456 -0
- parrot/loaders/audio.py +106 -0
- parrot/loaders/basepdf.py +102 -0
- parrot/loaders/basevideo.py +280 -0
- parrot/loaders/csv.py +42 -0
- parrot/loaders/dir.py +37 -0
- parrot/loaders/excel.py +349 -0
- parrot/loaders/github.py +65 -0
- parrot/loaders/handlers/__init__.py +5 -0
- parrot/loaders/handlers/data.py +213 -0
- parrot/loaders/image.py +119 -0
- parrot/loaders/json.py +52 -0
- parrot/loaders/pdf.py +437 -0
- parrot/loaders/pdfchapters.py +142 -0
- parrot/loaders/pdffn.py +112 -0
- parrot/loaders/pdfimages.py +207 -0
- parrot/loaders/pdfmark.py +88 -0
- parrot/loaders/pdftables.py +145 -0
- parrot/loaders/ppt.py +30 -0
- parrot/loaders/qa.py +81 -0
- parrot/loaders/repo.py +103 -0
- parrot/loaders/rtd.py +65 -0
- parrot/loaders/txt.py +92 -0
- parrot/loaders/utils/__init__.py +1 -0
- parrot/loaders/utils/models.py +25 -0
- parrot/loaders/video.py +96 -0
- parrot/loaders/videolocal.py +120 -0
- parrot/loaders/vimeo.py +106 -0
- parrot/loaders/web.py +216 -0
- parrot/loaders/web_base.py +112 -0
- parrot/loaders/word.py +125 -0
- parrot/loaders/youtube.py +192 -0
- parrot/manager.py +166 -0
- parrot/models.py +372 -0
- parrot/py.typed +0 -0
- parrot/stores/__init__.py +48 -0
- parrot/stores/abstract.py +171 -0
- parrot/stores/milvus.py +632 -0
- parrot/stores/qdrant.py +153 -0
- parrot/tools/__init__.py +12 -0
- parrot/tools/abstract.py +53 -0
- parrot/tools/asknews.py +32 -0
- parrot/tools/bing.py +13 -0
- parrot/tools/duck.py +62 -0
- parrot/tools/google.py +170 -0
- parrot/tools/stack.py +26 -0
- parrot/tools/weather.py +70 -0
- parrot/tools/wikipedia.py +59 -0
- parrot/tools/zipcode.py +179 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- resources/users/__init__.py +5 -0
- resources/users/handlers.py +13 -0
- resources/users/models.py +205 -0
- settings/__init__.py +0 -0
- settings/settings.py +51 -0
parrot/llms/pipes.py
ADDED
@@ -0,0 +1,114 @@
import torch
from langchain_community.llms import HuggingFacePipeline  # pylint: disable=import-error, E0611
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    LlavaForConditionalGeneration,
    AutoTokenizer,
    GenerationConfig,
    pipeline
)
from .abstract import AbstractLLM


class PipelineLLM(AbstractLLM):
    """PipelineLLM.

    Load a LLM (Language Model) from HuggingFace Hub.

    Returns:
        _type_: an instance of HuggingFace LLM Model.
    """
    model: str = "databricks/dolly-v2-3b"
    embed_model: str = None
    max_tokens: int = 1024
    supported_models: list = [
        "databricks/dolly-v2-3b",
        "gpt2",
        "bigscience/bloom-1b7",
        "meta-llama/Llama-2-7b-hf",
        'llava-hf/llava-1.5-7b-hf'
    ]

    def __init__(self, *args, **kwargs):
        self.batch_size = kwargs.get('batch_size', 4)
        self.use_llava: bool = kwargs.get('use_llava', False)
        self.model_args = kwargs.get('model_args', {})
        super().__init__(*args, **kwargs)
        dtype = kwargs.get('dtype', 'float16')
        if dtype == 'bfloat16':
            torch_dtype = torch.bfloat16
        if dtype == 'float16':
            torch_dtype = torch.float16
        elif dtype == 'float32':
            torch_dtype = torch.float32
        elif dtype == 'float8':
            torch_dtype = torch.float8
        else:
            torch_dtype = "auto"
        use_fast = kwargs.get('use_fast', True)
        if self.use_llava is False:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model,
                chunk_size=self.max_tokens
            )
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model,
                device_map="auto",
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
            config = GenerationConfig(
                do_sample=True,
                temperature=self.temperature,
                max_new_tokens=self.max_tokens,
                top_p=self.top_p,
                top_k=self.top_k,
                repetition_penalty=1.15,
            )
            self._pipe = pipeline(
                task=self.task,
                model=self._model,
                tokenizer=self.tokenizer,
                return_full_text=True,
                use_fast=use_fast,
                device_map='auto',
                batch_size=self.batch_size,
                generation_config=config,
                pad_token_id=50256,
                framework="pt"
            )
        else:
            self._model = LlavaForConditionalGeneration.from_pretrained(
                self.model,
                device_map="auto",
                torch_dtype=torch_dtype,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.model)
            processor = AutoProcessor.from_pretrained(self.model)
            self._pipe = pipeline(
                task=self.task,
                model=self._model,
                tokenizer=self.tokenizer,
                use_fast=use_fast,
                device_map='auto',
                batch_size=self.batch_size,
                image_processor=processor.image_processor,
                framework="pt",
                **self.model_args
            )
            self._pipe.tokenizer.pad_token_id = self._pipe.model.config.eos_token_id
        self._llm = HuggingFacePipeline(
            model_id=self.model,
            pipeline=self._pipe,
            verbose=True
        )

    def pipe(self, *args, **kwargs):
        return self._pipe(
            *args,
            **kwargs,
            generate_kwargs={"max_new_tokens": self.max_tokens}
        )
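For orientation, here is a minimal usage sketch of the PipelineLLM class added above. It is illustrative only: the model, task, temperature, top_p and top_k attributes consumed in __init__ come from the AbstractLLM base class in parrot/llms/abstract.py, which is not shown in this diff, so the exact constructor keywords are assumptions rather than documented API.

# Hypothetical usage sketch; keyword names other than batch_size, use_llava,
# model_args, dtype and use_fast (read explicitly in __init__ above) are
# assumptions about what AbstractLLM accepts.
from parrot.llms.pipes import PipelineLLM

llm = PipelineLLM(
    model="databricks/dolly-v2-3b",  # one of supported_models
    dtype="float16",                 # mapped to torch.float16 above
    use_llava=False,                 # takes the AutoModelForCausalLM branch
    batch_size=4,
)
result = llm.pipe("Explain what a HuggingFace pipeline is.")  # forwards to the transformers pipeline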
parrot/llms/vertex.py
ADDED
@@ -0,0 +1,78 @@
import os
from navconfig import config, BASE_DIR
from google.cloud import aiplatform
from langchain_google_vertexai import (
    ChatVertexAI,
    VertexAI,
    VertexAIModelGarden,
    VertexAIEmbeddings
)
from .abstract import AbstractLLM

class VertexLLM(AbstractLLM):
    """VertexLLM.

    Interact with VertexAI Language Model.

    Returns:
        _type_: VertexAI LLM.
    """
    model: str = "gemini-1.0-pro"
    embed_model: str = "textembedding-gecko@003"
    max_tokens: int = 1024
    supported_models: list = [
        "gemini-1.0-pro",
        "gemini-1.5-pro-001",
        "gemini-1.5-pro-exp-0801",
        "gemini-1.5-flash-preview-0514",
        "gemini-1.5-flash-001",
        "chat-bison@001",
        "claude-3-opus@20240229",
        'claude-3-5-sonnet@20240620'
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        use_garden: bool = kwargs.get("use_garden", False)
        project_id = config.get("VERTEX_PROJECT_ID")
        region = config.get("VERTEX_REGION")
        config_file = config.get('GOOGLE_CREDENTIALS_FILE', 'env/google/vertexai.json')
        config_dir = BASE_DIR.joinpath(config_file)
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(config_dir)
        self.args = {
            "project": project_id,
            "location": region,
            "max_output_tokens": self.max_tokens,
            "temperature": self.temperature,
            "max_retries": 4,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "verbose": True,
        }
        if use_garden is True:
            base_llm = VertexAIModelGarden
            self.args['endpoint_id'] = self.model
        elif self.model == "chat":
            self.args['model_name'] = "chat-bison@001"
            base_llm = ChatVertexAI
        else:
            self.args['model_name'] = self.model
            base_llm = VertexAI
        # LLM
        self._llm = base_llm(
            system_prompt="Always respond in the same language as the user's question. If the user's language is not English, translate your response into their language.",
            **self.args
        )
        # Embedding Model:
        embed_model = kwargs.get("embed_model", self.embed_model)
        self._embed = VertexAIEmbeddings(
            model_name=embed_model,
            project=project_id,
            location=region,
            request_parallelism=5,
            max_retries=4,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
        )
        self._version_ = aiplatform.__version__
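As an illustrative, non-authoritative sketch, VertexLLM above could be used roughly as follows. It assumes VERTEX_PROJECT_ID, VERTEX_REGION and GOOGLE_CREDENTIALS_FILE resolve through navconfig, and that the model keyword is handled by the AbstractLLM base class (not shown in this diff).

# Hypothetical usage sketch; configuration is read via navconfig exactly as in
# the constructor above, so these keys must exist in the environment/config.
from parrot.llms.vertex import VertexLLM

llm = VertexLLM(
    model="gemini-1.5-pro-001",  # must appear in supported_models
    use_garden=False,            # plain VertexAI rather than VertexAIModelGarden
)
# Per the constructor, llm._llm is a langchain_google_vertexai.VertexAI instance
# and llm._embed a VertexAIEmbeddings instance bound to textembedding-gecko@003.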
parrot/loaders/__init__.py
ADDED
@@ -0,0 +1,20 @@
from .dir import load_directory
from .pdf import PDFLoader
from .web import WebLoader
from .youtube import YoutubeLoader
from .vimeo import VimeoLoader
from .word import MSWordLoader
from .ppt import PPTXLoader
from .repo import RepositoryLoader
from .github import GithubLoader
from .json import JSONLoader
from .excel import ExcelLoader
from .web_base import WebBaseLoader
from .pdfmark import PDFMarkdownLoader
from .pdfimages import PDFImageLoader
from .pdftables import PDFTablesLoader
from .pdfchapters import PDFChapterLoader
from .txt import TXTLoader
from .qa import QAFileLoader
from .rtd import ReadTheDocsLoader
from .videolocal import VideoLocalLoader
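These re-exports define the public import surface of parrot.loaders, so a loader can be imported from the package root rather than its module; a sketch only, assuming the concrete loaders keep the AbstractLoader interface shown in parrot/loaders/abstract.py below:

# Hypothetical import sketch based on the re-exports above.
from parrot.loaders import PDFLoader, WebLoader, TXTLoader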
parrot/loaders/abstract.py
ADDED
@@ -0,0 +1,456 @@
"""Loaders are classes that are responsible for loading data from a source
and returning it as a Langchain Document.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import List, Union, Optional, Any
from pathlib import Path, PurePath
import torch
from langchain.docstore.document import Document
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    TokenTextSplitter
)
from langchain_core.prompts import PromptTemplate
from langchain_core.document_loaders.blob_loaders import Blob
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

from navconfig.logging import logging
from navigator.libs.json import JSONContent  # pylint: disable=E0611
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoConfig,
    AutoModel,
    pipeline
)
from ..conf import EMBEDDING_DEVICE, EMBEDDING_DEFAULT_MODEL


logging.getLogger(name='httpx').setLevel(logging.WARNING)
logging.getLogger(name='httpcore').setLevel(logging.WARNING)
logging.getLogger(name='pdfminer').setLevel(logging.WARNING)
logging.getLogger(name='langchain_community').setLevel(logging.WARNING)
logging.getLogger(name='numba').setLevel(logging.WARNING)
logging.getLogger(name='PIL').setLevel(level=logging.WARNING)


def as_string(self) -> str:
    """Read data as a string."""
    if self.data is None and self.path:
        with open(str(self.path), "r", encoding=self.encoding) as f:
            try:
                return f.read()
            except UnicodeDecodeError:
                try:
                    with open(str(self.path), "r", encoding="latin-1") as f:
                        return f.read()
                except UnicodeDecodeError:
                    with open(str(self.path), "rb") as f:
                        return f.read().decode("utf-8", "replace")
    elif isinstance(self.data, bytes):
        return self.data.decode(self.encoding)
    elif isinstance(self.data, str):
        return self.data
    else:
        raise ValueError(f"Unable to get string for blob {self}")

# Monkey patch the Blob class's as_string method
Blob.as_string = as_string


class AbstractLoader(ABC):
    """
    Abstract class for Document loaders.
    """
    _extension: List[str] = ['.txt']
    encoding: str = 'utf-8'
    skip_directories: List[str] = []
    _chunk_size: int = 768

    def __init__(
        self,
        tokenizer: Union[str, Callable] = None,
        text_splitter: Union[str, Callable] = None,
        source_type: str = 'file',
        **kwargs
    ):
        self.tokenizer = tokenizer
        self._summary_model = None
        self.text_splitter = text_splitter
        self._device = self._get_device()
        self._chunk_size = kwargs.get('chunk_size', 768)
        self._no_summarization = bool(
            kwargs.get('no_summarization', False)
        )
        self.summarization_model = kwargs.get(
            'summarization_model',
            "facebook/bart-large-cnn"
        )
        self._no_summarization = bool(
            kwargs.get('no_summarization', False)
        )
        self._source_type = source_type
        self.logger = logging.getLogger(
            f"Loader.{self.__class__.__name__}"
        )
        if 'extension' in kwargs:
            self._extension = kwargs['extension']
        self.encoding = kwargs.get('encoding', 'utf-8')
        self.skip_directories: List[str] = kwargs.get('skip_directories', [])
        # LLM (if required)
        self._llm = kwargs.get('llm', None)
        if not self.tokenizer:
            self.tokenizer = self.default_tokenizer()
        elif isinstance(self.tokenizer, str):
            self.tokenizer = self.get_tokenizer(
                model_name=self.tokenizer
            )
        if not text_splitter:
            self.text_splitter = self.default_splitter(
                model=self.tokenizer
            )
        # JSON encoder:
        self._encoder = JSONContent()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.post_load()

    def default_tokenizer(self):
        return self.get_tokenizer(
            EMBEDDING_DEFAULT_MODEL,
            chunk_size=768
        )

    def get_tokenizer(self, model_name: str, chunk_size: int = 768):
        return AutoTokenizer.from_pretrained(
            model_name,
            chunk_size=chunk_size
        )

    def _get_device(self, cuda_number: int = 0):
        if torch.cuda.is_available():
            # Use CUDA GPU if available
            device = torch.device(f'cuda:{cuda_number}')
        elif torch.backends.mps.is_available():
            # Use CUDA Multi-Processing Service if available
            device = torch.device("mps")
        elif EMBEDDING_DEVICE == 'cuda':
            device = torch.device(f'cuda:{cuda_number}')
        else:
            device = torch.device(EMBEDDING_DEVICE)
        return device

    def get_model(self, model_name: str):
        self._model_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=True
        )
        return AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            config=self._model_config,
            unpad_inputs=True,
            use_memory_efficient_attention=True,
        ).to(self._device)

    def get_summarization_model(self, model_name: str = 'facebook/bart-large-cnn'):
        if self._no_summarization is True:
            return None
        if not self._summary_model:
            summarize_model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            )
            summarize_tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                padding_side="left"
            )
            pipe_summary = pipeline(
                "summarization",
                model=summarize_model,
                tokenizer=summarize_tokenizer,
                batch_size=True,
                max_new_tokens=500,
                min_new_tokens=300,
                use_fast=True
            )
            self._summary_model = HuggingFacePipeline(
                model_id=model_name,
                pipeline=pipe_summary,
                verbose=True
            )
        return self._summary_model

    def get_text_splitter(self, model, chunk_size: int = 2000, overlap: int = 100):
        return RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
            model,
            chunk_size=chunk_size,
            chunk_overlap=overlap,
            add_start_index=True,  # If `True`, includes chunk's start index in metadata
            strip_whitespace=True,  # strips whitespace from the start and end
            separators=["\n\n", "\n", "\r\n", "\r", "\f", "\v", "\x0b", "\x0c"],
        )

    def default_splitter(self, model: Callable):
        """Get the text splitter."""
        return self.get_text_splitter(
            model,
            chunk_size=2000,
            overlap=100
        )

    def get_summary_from_text(self, text: str) -> str:
        """
        Get a summary of a text.
        """
        if not text:
            # NO data to be summarized
            return ''
        try:
            splitter = TokenTextSplitter(
                chunk_size=5000,
                chunk_overlap=100,
            )
            prompt_template = """Write a summary of the following, please also identify the main theme:
            {text}
            SUMMARY:"""
            prompt = PromptTemplate.from_template(prompt_template)
            refine_template = (
                "Your job is to produce a final summary\n"
                "We have provided an existing summary up to a certain point: {existing_answer}\n"
                "We have the opportunity to refine the existing summary"
                "(only if needed) with some more context below.\n"
                "------------\n"
                "{text}\n"
                "------------\n"
                "Given the new context, refine the original summary adding more explanation."
                "If the context isn't useful, return the original summary."
            )
            refine_prompt = PromptTemplate.from_template(refine_template)
            if self._llm:
                llm = self._llm
            else:
                llm = self.get_summarization_model(
                    self.summarization_model
                )
            if not llm:
                return ''
            summarize_chain = load_summarize_chain(
                llm=llm,
                chain_type="refine",
                question_prompt=prompt,
                refine_prompt=refine_prompt,
                return_intermediate_steps=True,
                input_key="input_documents",
                output_key="output_text",
            )
            chunks = splitter.split_text(text)
            documents = [Document(page_content=chunk) for chunk in chunks]
            summary = summarize_chain.invoke(
                {"input_documents": documents}, return_only_outputs=True
            )
            return summary['output_text']
        except Exception as e:
            print('ERROR in get_summary_from_text:', e)
            return ""

    def split_documents(self, documents: List[Document], max_tokens: int = None) -> List[Document]:
        """Split the documents into chunks."""
        if not max_tokens:
            max_tokens = self._chunk_size
        split_documents = []
        for doc in documents:
            metadata = doc.metadata.copy()
            chunks = self.text_splitter.split_text(doc.page_content)
            for chunk in chunks:
                split_documents.append(
                    Document(page_content=chunk, metadata=metadata)
                )
        return split_documents

    def split_by_tokens(self, documents: List[Document], max_tokens: int = 768) -> List[Document]:
        """Split the documents into chunks."""
        split_documents = []
        current_chunk = []
        for doc in documents:
            metadata = doc.metadata.copy()
            tokens = self.tokenizer.tokenize(doc.page_content)
            with torch.no_grad():
                current_chunk = []
                for token in tokens:
                    current_chunk.append(token)
                    if len(current_chunk) >= max_tokens:
                        chunk_text = self.tokenizer.convert_tokens_to_string(current_chunk)
                        # Create a new Document for this chunk, preserving metadata
                        split_doc = Document(
                            page_content=chunk_text,
                            metadata=metadata
                        )
                        split_documents.append(split_doc)
                        current_chunk = []  # Reset for the next chunk
                # Handle the last chunk if it didn't reach the max_tokens limit
                if current_chunk:
                    chunk_text = self.tokenizer.convert_tokens_to_string(current_chunk)
                    split_documents.append(
                        Document(page_content=chunk_text, metadata=metadata)
                    )
            del tokens, current_chunk
            torch.cuda.empty_cache()
        return split_documents

    def post_load(self):
        self.tokenizer = None  # Reset the tokenizer
        self.text_splitter = None  # Reset the text splitter
        torch.cuda.synchronize()  # Wait for all kernels to finish
        torch.cuda.empty_cache()  # Clear unused memory

    def read_bytes(self, path: Union[str, PurePath]) -> bytes:
        """Read the bytes from a file.

        Args:
            path (Union[str, PurePath]): The path to the file.

        Returns:
            bytes: The bytes of the file.
        """
        if isinstance(path, str):
            path = PurePath(path)
        with open(str(path), 'rb') as f:
            return f.read()

    def open_bytes(self, path: Union[str, PurePath]) -> Any:
        """Open the bytes from a file.

        Args:
            path (Union[str, PurePath]): The path to the file.

        Returns:
            Any: The bytes of the file.
        """
        if isinstance(path, str):
            path = PurePath(path)
        return open(str(path), 'rb')

    def read_string(self, path: Union[str, PurePath]) -> str:
        """Read the string from a file.

        Args:
            path (Union[str, PurePath]): The path to the file.

        Returns:
            str: The string of the file.
        """
        if isinstance(path, str):
            path = PurePath(path)
        with open(str(path), 'r', encoding=self.encoding) as f:
            return f.read()

    def _check_path(
        self,
        path: PurePath,
        suffix: Optional[List[str]] = None
    ) -> bool:
        """Check if the file path exists.
        Args:
            path (PurePath): The path to the file.
        Returns:
            bool: True if the file exists, False otherwise.
        """
        if isinstance(path, str):
            path = Path(path).resolve()
        if not suffix:
            suffix = self._extension
        return path.exists() and path.is_file() and path.suffix in suffix

    @abstractmethod
    def load(self, path: Union[str, PurePath]) -> List[Document]:
        """Load data from a source and return it as a Langchain Document.

        Args:
            path (str): The source of the data.

        Returns:
            List[Document]: A list of Langchain Documents.
        """
        pass

    @abstractmethod
    def parse(self, source: Any) -> List[Document]:
        """Parse data from a source and return it as a Langchain Document.

        Args:
            source (Any): The source of the data.

        Returns:
            List[Document]: A list of Langchain Documents.
        """
        pass

    @classmethod
    def from_path(
        cls,
        path: Union[str, PurePath],
        text_splitter: Callable,
        source_type: str = 'file',
        **kwargs
    ) -> List[Document]:
        """Load Multiple documents from a Path

        Args:
            path (Union[str, PurePath]): The path to the file.

        Returns:
            -> List[Document]: A list of Langchain Documents.
        """
        if isinstance(path, str):
            path = PurePath(path)
        if path.is_dir():
            documents = []
            obj = cls(
                tokenizer=kwargs.pop('tokenizer', None),
                text_splitter=text_splitter,
                source_type=source_type,
            )
            for ext in cls._extension:
                for item in path.glob(f'*{ext}'):
                    if set(item.parts).isdisjoint(cls.skip_directories):
                        documents += obj.load(path=item, **kwargs)
                        # documents += cls.load(cls, path=item, **kwargs)
            return documents

    @classmethod
    def from_url(
        cls,
        urls: List[str],
        text_splitter: Callable,
        source_type: str = 'content',
        **kwargs
    ) -> List[Document]:
        """Load Multiple documents from a URL

        Args:
            urls (List[str]): The list of URLs.

        Returns:
            -> List[Document]: A list of Langchain Documents.
        """
        documents = []
        cls.tokenizer = kwargs.pop('tokenizer', None),
        cls.text_splitter = text_splitter
        cls._source_type = source_type
        cls.summarization_model = kwargs.pop(
            'summarization_model',
            "facebook/bart-large-cnn"
        )
        for url in urls:
            documents += cls.load(url, **kwargs)
        return documents
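To make the AbstractLoader contract concrete, here is a minimal, hypothetical subclass sketch: only the load/parse signatures and the helper methods (_check_path, read_string, split_documents) come from the code above; the class name and behaviour are illustrative assumptions, not part of the package.

# Hypothetical minimal loader built on AbstractLoader (sketch only).
from typing import Any, List, Union
from pathlib import PurePath
from langchain.docstore.document import Document

class PlainTextLoader(AbstractLoader):
    _extension: List[str] = ['.txt']

    def load(self, path: Union[str, PurePath]) -> List[Document]:
        # Reject missing files or unsupported extensions via the base helper.
        if not self._check_path(path):
            return []
        text = self.read_string(path)
        doc = Document(page_content=text, metadata={"source": str(path)})
        # Re-use the base text splitter to chunk the document.
        return self.split_documents([doc])

    def parse(self, source: Any) -> List[Document]:
        # Wrap already-loaded content directly.
        return [Document(page_content=str(source))]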