ai_parrot-0.1.0-cp311-cp311-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-parrot might be problematic.

Files changed (108)
  1. ai_parrot-0.1.0.dist-info/LICENSE +21 -0
  2. ai_parrot-0.1.0.dist-info/METADATA +299 -0
  3. ai_parrot-0.1.0.dist-info/RECORD +108 -0
  4. ai_parrot-0.1.0.dist-info/WHEEL +5 -0
  5. ai_parrot-0.1.0.dist-info/top_level.txt +3 -0
  6. parrot/__init__.py +18 -0
  7. parrot/chatbots/__init__.py +7 -0
  8. parrot/chatbots/abstract.py +965 -0
  9. parrot/chatbots/asktroc.py +16 -0
  10. parrot/chatbots/base.py +257 -0
  11. parrot/chatbots/basic.py +9 -0
  12. parrot/chatbots/bose.py +17 -0
  13. parrot/chatbots/cody.py +17 -0
  14. parrot/chatbots/copilot.py +100 -0
  15. parrot/chatbots/dataframe.py +103 -0
  16. parrot/chatbots/hragents.py +15 -0
  17. parrot/chatbots/oddie.py +17 -0
  18. parrot/chatbots/retrievals/__init__.py +515 -0
  19. parrot/chatbots/retrievals/constitutional.py +19 -0
  20. parrot/conf.py +108 -0
  21. parrot/crew/__init__.py +3 -0
  22. parrot/crew/tools/__init__.py +22 -0
  23. parrot/crew/tools/bing.py +13 -0
  24. parrot/crew/tools/config.py +43 -0
  25. parrot/crew/tools/duckgo.py +62 -0
  26. parrot/crew/tools/file.py +24 -0
  27. parrot/crew/tools/google.py +168 -0
  28. parrot/crew/tools/gtrends.py +16 -0
  29. parrot/crew/tools/md2pdf.py +25 -0
  30. parrot/crew/tools/rag.py +42 -0
  31. parrot/crew/tools/search.py +32 -0
  32. parrot/crew/tools/url.py +21 -0
  33. parrot/exceptions.cpython-311-x86_64-linux-gnu.so +0 -0
  34. parrot/handlers/__init__.py +4 -0
  35. parrot/handlers/bots.py +196 -0
  36. parrot/handlers/chat.py +169 -0
  37. parrot/interfaces/__init__.py +6 -0
  38. parrot/interfaces/database.py +29 -0
  39. parrot/llms/__init__.py +0 -0
  40. parrot/llms/abstract.py +41 -0
  41. parrot/llms/anthropic.py +36 -0
  42. parrot/llms/google.py +37 -0
  43. parrot/llms/groq.py +33 -0
  44. parrot/llms/hf.py +39 -0
  45. parrot/llms/openai.py +49 -0
  46. parrot/llms/pipes.py +103 -0
  47. parrot/llms/vertex.py +68 -0
  48. parrot/loaders/__init__.py +20 -0
  49. parrot/loaders/abstract.py +456 -0
  50. parrot/loaders/basepdf.py +102 -0
  51. parrot/loaders/basevideo.py +280 -0
  52. parrot/loaders/csv.py +42 -0
  53. parrot/loaders/dir.py +37 -0
  54. parrot/loaders/excel.py +349 -0
  55. parrot/loaders/github.py +65 -0
  56. parrot/loaders/handlers/__init__.py +5 -0
  57. parrot/loaders/handlers/data.py +213 -0
  58. parrot/loaders/image.py +119 -0
  59. parrot/loaders/json.py +52 -0
  60. parrot/loaders/pdf.py +187 -0
  61. parrot/loaders/pdfchapters.py +142 -0
  62. parrot/loaders/pdffn.py +112 -0
  63. parrot/loaders/pdfimages.py +207 -0
  64. parrot/loaders/pdfmark.py +88 -0
  65. parrot/loaders/pdftables.py +145 -0
  66. parrot/loaders/ppt.py +30 -0
  67. parrot/loaders/qa.py +81 -0
  68. parrot/loaders/repo.py +103 -0
  69. parrot/loaders/rtd.py +65 -0
  70. parrot/loaders/txt.py +92 -0
  71. parrot/loaders/utils/__init__.py +1 -0
  72. parrot/loaders/utils/models.py +25 -0
  73. parrot/loaders/video.py +96 -0
  74. parrot/loaders/videolocal.py +107 -0
  75. parrot/loaders/vimeo.py +106 -0
  76. parrot/loaders/web.py +216 -0
  77. parrot/loaders/web_base.py +112 -0
  78. parrot/loaders/word.py +125 -0
  79. parrot/loaders/youtube.py +192 -0
  80. parrot/manager.py +152 -0
  81. parrot/models.py +347 -0
  82. parrot/py.typed +0 -0
  83. parrot/stores/__init__.py +0 -0
  84. parrot/stores/abstract.py +170 -0
  85. parrot/stores/milvus.py +540 -0
  86. parrot/stores/qdrant.py +153 -0
  87. parrot/tools/__init__.py +16 -0
  88. parrot/tools/abstract.py +53 -0
  89. parrot/tools/asknews.py +32 -0
  90. parrot/tools/bing.py +13 -0
  91. parrot/tools/duck.py +62 -0
  92. parrot/tools/google.py +170 -0
  93. parrot/tools/stack.py +26 -0
  94. parrot/tools/weather.py +70 -0
  95. parrot/tools/wikipedia.py +59 -0
  96. parrot/tools/zipcode.py +179 -0
  97. parrot/utils/__init__.py +2 -0
  98. parrot/utils/parsers/__init__.py +5 -0
  99. parrot/utils/parsers/toml.cpython-311-x86_64-linux-gnu.so +0 -0
  100. parrot/utils/toml.py +11 -0
  101. parrot/utils/types.cpython-311-x86_64-linux-gnu.so +0 -0
  102. parrot/utils/uv.py +11 -0
  103. parrot/version.py +10 -0
  104. resources/users/__init__.py +5 -0
  105. resources/users/handlers.py +13 -0
  106. resources/users/models.py +205 -0
  107. settings/__init__.py +0 -0
  108. settings/settings.py +51 -0
parrot/handlers/chat.py ADDED
@@ -0,0 +1,169 @@
+ from navigator_auth.decorators import (
+     is_authenticated,
+     user_session
+ )
+ from navigator.views import BaseView
+
+
+ @is_authenticated()
+ @user_session()
+ class ChatHandler(BaseView):
+     """
+     ChatHandler.
+     description: ChatHandler for Parrot Application.
+     """
+
+     async def get(self, **kwargs):
+         """
+         get.
+         description: Get method for ChatHandler.
+         """
+         name = self.request.match_info.get('chatbot_name', None)
+         if not name:
+             return self.json_response({
+                 "message": "Welcome to Parrot Chatbot Service."
+             })
+         else:
+             # retrieve chatbot information:
+             manager = self.request.app['chatbot_manager']
+             chatbot = manager.get_chatbot(name)
+             if not chatbot:
+                 return self.error(
+                     f"Chatbot {name} not found.",
+                     status=404
+                 )
+             return self.json_response({
+                 "chatbot": chatbot.name,
+                 "description": chatbot.description,
+                 "role": chatbot.role,
+                 "embedding_model_name": chatbot.embedding_model_name,
+                 "llm": f"{chatbot.get_llm()!r}",
+                 "temperature": chatbot.get_llm().temperature,
+                 "args": chatbot.get_llm().args,
+                 "config_file": chatbot.config_file
+             })
+
+     async def post(self, *args, **kwargs):
+         """
+         post.
+         description: Post method for ChatHandler.
+         """
+         print('ARGS > ', args, kwargs)
+         app = self.request.app
+         name = self.request.match_info.get('chatbot_name', None)
+         qs = self.query_parameters(self.request)
+         data = await self.request.json()
+         if 'query' not in data:
+             return self.json_response(
+                 {
+                     "message": "No query was found."
+                 },
+                 status=400
+             )
+         if 'use_llm' in qs:
+             # passing another LLM to the Chatbot:
+             llm = qs.get('use_llm')
+         else:
+             llm = None
+         try:
+             manager = app['chatbot_manager']
+         except KeyError:
+             return self.json_response(
+                 {
+                     "message": "Chatbot Manager is not installed."
+                 },
+                 status=404
+             )
+         try:
+             chatbot = manager.get_chatbot(name)
+             if not chatbot:
+                 raise KeyError(
+                     f"Chatbot {name} not found."
+                 )
+         except (TypeError, KeyError):
+             return self.json_response(
+                 {
+                     "message": f"Chatbot {name} not found."
+                 },
+                 status=404
+             )
+         # getting the question:
+         question = data.get('query')
+         session = self.request.session
+         try:
+             session_id = session.get('session_id', None)
+             # print('SESSION ID > ', session_id)
+             memory_key = f'{session.session_id}_{name}_message_store'
+             # print('MEM STORAGE > ', memory_key)
+             memory = chatbot.get_memory(session_id=memory_key)
+             # print('MEMORY >> ', memory)
+             async with chatbot.get_retrieval(request=self.request) as retrieval:
+                 qa = retrieval.conversation(
+                     question=question,
+                     search_kwargs={"k": 10},
+                     use_llm=llm,
+                     memory=memory
+                 )
+                 result = await qa.invoke(question)
+                 # Drop "memory" information:
+                 result.chat_history = None
+                 result.source_documents = None
+                 return self.json_response(response=result)
+         except ValueError as exc:
+             return self.error(
+                 f"{exc}",
+                 exception=exc,
+                 status=400
+             )
+         except Exception as exc:
+             return self.error(
+                 f"Error invoking chatbot {name}: {exc}",
+                 exception=exc,
+                 status=400
+             )
+
+
+ @is_authenticated()
+ @user_session()
+ class BotHandler(BaseView):
+     """BotHandler.
+
+     Use this handler to interact with a brand-new chatbot, consuming a configuration.
+     """
+     async def put(self):
+         """Create a New Bot (passing a configuration)."""
+         try:
+             manager = self.request.app['chatbot_manager']
+         except KeyError:
+             return self.json_response(
+                 {
+                     "message": "Chatbot Manager is not installed."
+                 },
+                 status=404
+             )
+         # TODO: validate the incoming payload
+         data = await self.request.json()
+         name = data.pop('name', None)
+         if not name:
+             return self.json_response(
+                 {
+                     "message": "Name for Bot Creation is required."
+                 },
+                 status=400
+             )
+         try:
+             chatbot = manager.create_chatbot(name=name, **data)
+             await chatbot.configure(name=name, app=self.request.app)
+             return self.json_response(
+                 {
+                     "message": f"Chatbot {name} created successfully."
+                 }
+             )
+         except Exception as exc:
+             return self.error(
+                 f"Error creating chatbot {name}: {exc}",
+                 exception=exc,
+                 status=400
+             )
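For readers skimming this diff, a minimal client-side sketch of ChatHandler's POST contract. The route prefix is an assumption; only the `chatbot_name` match-info key, the required `query` JSON field, and the optional `use_llm` query parameter are visible in the handler above.

import asyncio
import aiohttp

async def ask_bot(base_url: str, bot: str, question: str) -> dict:
    # The /api/v1/chat prefix is hypothetical; adjust to the app's real routes.
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{base_url}/api/v1/chat/{bot}",
            json={"query": question},      # required: the handler returns 400 without it
            params={"use_llm": "openai"},  # optional: swap the bot's configured LLM
        ) as resp:
            return await resp.json()

# asyncio.run(ask_bot("http://localhost:5000", "asktroc", "What does Parrot do?"))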
parrot/interfaces/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .database import DBInterface
+
+
+ __all__ = (
+     'DBInterface',
+ )
parrot/interfaces/database.py ADDED
@@ -0,0 +1,29 @@
+ """DB (asyncdb) Extension.
+ DB connection for any Application.
+ """
+ from abc import ABCMeta
+ from asyncdb import AsyncDB
+
+
+ class DBInterface(metaclass=ABCMeta):
+     """
+     Interface for using database connections in an Application using AsyncDB.
+     """
+
+     def get_database(
+         self,
+         driver: str,
+         dsn: str = None,
+         params: dict = None,
+         timeout: int = 60,
+         **kwargs
+     ) -> AsyncDB:
+         """Get the driver."""
+         return AsyncDB(
+             driver,
+             dsn=dsn,
+             params=params,
+             timeout=timeout,
+             **kwargs
+         )
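A brief usage sketch for DBInterface. The driver name and DSN are illustrative assumptions, not values from this package, and querying afterwards follows asyncdb's own API rather than anything shown in this diff.

from parrot.interfaces import DBInterface

class UserService(DBInterface):
    """Mixing in DBInterface gives any class access to AsyncDB drivers."""

svc = UserService()
# 'pg' and the DSN below are placeholders for whatever backend you run
db = svc.get_database('pg', dsn='postgres://user:pass@localhost:5432/parrot')
# `db` is an AsyncDB instance; connecting and querying then follow asyncdb's
# documented idiom, e.g. `async with await db.connection() as conn: ...`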
parrot/llms/__init__.py — File without changes
parrot/llms/abstract.py ADDED
@@ -0,0 +1,41 @@
+ from abc import ABC, abstractmethod
+ from langchain_core.prompts import ChatPromptTemplate
+
+
+ class AbstractLLM(ABC):
+     """Abstract Language Model class."""
+
+     model: str = "databricks/dolly-v2-3b"
+     embed_model: str = None
+     max_tokens: int = 1024
+     max_retries: int = 3
+
+     def __init__(self, *args, **kwargs):
+         self.model = kwargs.get("model", "databricks/dolly-v2-3b")
+         self.task = kwargs.get("task", "text-generation")
+         self.temperature: float = kwargs.get('temperature', 0.6)
+         self.max_tokens: int = kwargs.get('max_tokens', 1024)
+         self.top_k: float = kwargs.get('top_k', 10)
+         self.top_p: float = kwargs.get('top_p', 0.90)
+         self.args = {
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+         }
+         self._llm = None
+         self._embed = None
+
+     def get_llm(self):
+         return self._llm
+
+     def get_embedding(self):
+         return self._embed
+
+     def __call__(self, text: str, **kwargs):
+         return self._llm.invoke(text, **kwargs)
+
+     def get_prompt(self, system: str, human: str) -> ChatPromptTemplate:
+         """Get a prompt for the LLM."""
+         return ChatPromptTemplate.from_messages(
+             [("system", system), ("human", human)]
+         )
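The contract here is thin: a subclass populates `self._llm` (and optionally `self._embed`) and inherits `__call__`, `get_llm`, and `get_prompt`. A runnable toy sketch using langchain's FakeListLLM, so no credentials are needed; the import path is an assumption about the installed langchain_community version.

from langchain_community.llms.fake import FakeListLLM
from parrot.llms.abstract import AbstractLLM

class EchoLLM(AbstractLLM):
    """Toy subclass satisfying the AbstractLLM contract without an API key."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._llm = FakeListLLM(responses=["a canned answer"])

llm = EchoLLM(temperature=0.1)
print(llm("Hello?"))  # __call__ delegates to self._llm.invoke(...)
prompt = llm.get_prompt("You are terse.", "{question}")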
parrot/llms/anthropic.py ADDED
@@ -0,0 +1,36 @@
+ from navconfig import config
+ from navconfig.logging import logging
+ from langchain_anthropic import ChatAnthropic
+ from langchain_core.prompts import ChatPromptTemplate
+ from .abstract import AbstractLLM
+
+ logging.getLogger(name='anthropic').setLevel(logging.WARNING)
+
+ class Anthropic(AbstractLLM):
+     """Anthropic.
+
+     Interact with the Anthropic (Claude) Language Model.
+
+     Returns:
+         _type_: an instance of the Anthropic (Claude) LLM.
+     """
+     model: str = 'claude-3-opus-20240229'
+     embed_model: str = None
+     max_tokens: int = 1024
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.model = kwargs.get("model", 'claude-3-opus-20240229')
+         self._api_key = kwargs.pop('api_key', config.get('ANTHROPIC_API_KEY'))
+         args = {
+             "temperature": self.temperature,
+             "max_retries": 4,
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+             "verbose": True,
+         }
+         self._llm = ChatAnthropic(
+             model_name=self.model,
+             api_key=self._api_key,
+             **args
+         )
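A hedged instantiation sketch; the key falls back to ANTHROPIC_API_KEY from navconfig when not passed explicitly.

from parrot.llms.anthropic import Anthropic

bot = Anthropic(model='claude-3-opus-20240229', temperature=0.2)
msg = bot("Summarize retrieval-augmented generation in one sentence.")
print(msg.content)  # ChatAnthropic.invoke returns an AIMessage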
parrot/llms/google.py ADDED
@@ -0,0 +1,37 @@
+ from langchain_google_genai import (
+     GoogleGenerativeAI,
+     ChatGoogleGenerativeAI,
+     GoogleGenerativeAIEmbeddings
+ )
+ from navconfig import config
+ from .abstract import AbstractLLM
+
+
+ class GoogleGenAI(AbstractLLM):
+     """GoogleGenAI.
+     Using Google Generative AI models with Google Cloud AI Platform.
+     """
+     model: str = "gemini-pro"
+
+     def __init__(self, *args, **kwargs):
+         self.model_type = kwargs.get("model_type", "chat")
+         super().__init__(*args, **kwargs)
+         self._api_key = kwargs.pop('api_key', config.get('GOOGLE_API_KEY'))
+         if self.model_type == 'chat':
+             base_llm = ChatGoogleGenerativeAI
+         else:
+             base_llm = GoogleGenerativeAI
+         self._llm = base_llm(
+             model=self.model,
+             api_key=self._api_key,
+             temperature=self.temperature,
+             **self.args
+         )
+         embed_model = kwargs.get("embed_model", "models/embedding-001")
+         self._embed = GoogleGenerativeAIEmbeddings(
+             model=embed_model,
+             google_api_key=self._api_key,
+             temperature=self.temperature,
+             top_p=self.top_p,
+             top_k=self.top_k,
+         )
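A similar sketch for GoogleGenAI, assuming GOOGLE_API_KEY is resolvable via navconfig.

from parrot.llms.google import GoogleGenAI

llm = GoogleGenAI(model="gemini-pro", model_type="chat", temperature=0.3)
print(llm("Name three parrot species.").content)   # AIMessage in chat mode
vector = llm.get_embedding().embed_query("hello")  # embedding size depends on the model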
parrot/llms/groq.py ADDED
@@ -0,0 +1,33 @@
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_groq import ChatGroq
+ from navconfig import config
+ from .abstract import AbstractLLM
+
+
+ class GroqLLM(AbstractLLM):
+     """GroqLLM.
+     Using Groq open-source models.
+     """
+     model: str = "mixtral-8x7b-32768"
+
+     def __init__(self, *args, **kwargs):
+         self.model_type = kwargs.get("model_type", "text")
+         system = kwargs.pop('system_prompt', "You are a helpful assistant.")
+         human = kwargs.pop('human_prompt', "{question}")
+         super().__init__(*args, **kwargs)
+         self._api_key = kwargs.pop('api_key', config.get('GROQ_API_KEY'))
+         self._llm = ChatGroq(
+             model_name=self.model,
+             groq_api_key=self._api_key,
+             temperature=self.temperature,
+             max_retries=self.max_retries,
+             max_tokens=self.max_tokens,
+             model_kwargs={
+                 "top_p": self.top_p,
+                 # "top_k": self.top_k,
+             },
+         )
+         self._embed = None  # Not supported
+         self.prompt = ChatPromptTemplate.from_messages(
+             [("system", system), ("human", human)]
+         )
parrot/llms/hf.py ADDED
@@ -0,0 +1,39 @@
+ from langchain_community.llms import HuggingFacePipeline  # pylint: disable=import-error, E0611
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from .abstract import AbstractLLM
+
+ class HuggingFace(AbstractLLM):
+     """HuggingFace.
+
+     Load an LLM (Language Model) from the HuggingFace Hub.
+
+     Only supports text-generation, text2text-generation, summarization and translation for now.
+
+     Returns:
+         _type_: an instance of a HuggingFace LLM Model.
+     """
+     model: str = "databricks/dolly-v2-3b"
+     embed_model: str = None
+     max_tokens: int = 1024
+
+     def __init__(self, *args, **kwargs):
+         self.batch_size = kwargs.get('batch_size', 4)
+         super().__init__(*args, **kwargs)
+         self._tokenizer = AutoTokenizer.from_pretrained(self.model, chunk_size=self.max_tokens)
+         self._model = AutoModelForCausalLM.from_pretrained(self.model, trust_remote_code=True)
+         self._llm = HuggingFacePipeline.from_model_id(
+             model_id=self.model,
+             task=self.task,
+             device_map='auto',
+             batch_size=self.batch_size,
+             model_kwargs={
+                 "max_length": self.max_tokens,
+                 "trust_remote_code": True
+             },
+             pipeline_kwargs={
+                 "temperature": self.temperature,
+                 "repetition_penalty": 1.1,
+                 "max_new_tokens": self.max_tokens,
+                 **self.args
+             }
+         )
parrot/llms/openai.py ADDED
@@ -0,0 +1,49 @@
+ from langchain_openai import (
+     OpenAI,
+     ChatOpenAI,
+     OpenAIEmbeddings
+ )
+ from navconfig import config
+ from .abstract import AbstractLLM
+
+
+ class OpenAILLM(AbstractLLM):
+     """OpenAI.
+     Interact with the OpenAI Language Model.
+
+     Returns:
+         _type_: an instance of an OpenAI LLM Model.
+     """
+     model: str = "gpt-4-turbo"
+     embed_model: str = "text-embedding-3-large"
+     max_tokens: int = 1024
+
+     def __init__(self, *args, **kwargs):
+         self.model_type = kwargs.get("model_type", "text")
+         super().__init__(*args, **kwargs)
+         self.model = kwargs.get("model", "davinci")
+         self._api_key = kwargs.pop('api_key', config.get('OPENAI_API_KEY'))
+         organization = config.get("OPENAI_ORGANIZATION")
+         if self.model_type == 'chat':
+             base_llm = ChatOpenAI
+         else:
+             base_llm = OpenAI
+         self._llm = base_llm(
+             model_name=self.model,
+             api_key=self._api_key,
+             organization=organization,
+             temperature=self.temperature,
+             max_tokens=self.max_tokens,
+             **self.args
+         )
+         # Embedding
+         embed_model = kwargs.get("embed_model", "text-embedding-3-large")
+         self._embed = OpenAIEmbeddings(
+             model=embed_model,
+             dimensions=self.max_tokens,
+             api_key=self._api_key,
+             organization=organization,
+             temperature=self.temperature,
+             top_p=self.top_p,
+             top_k=self.top_k,
+         )
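An OpenAILLM sketch. Note that without an explicit `model` kwarg the constructor falls back to the legacy "davinci" name despite the class default of "gpt-4-turbo", so passing the model explicitly is advisable; OPENAI_API_KEY is assumed to be in config.

from parrot.llms.openai import OpenAILLM

llm = OpenAILLM(model="gpt-4-turbo", model_type="chat", temperature=0.1)
print(llm("Give one use case for vector stores.").content)
vectors = llm.get_embedding().embed_documents(["first doc", "second doc"])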
parrot/llms/pipes.py ADDED
@@ -0,0 +1,103 @@
+ import torch
+ from langchain_community.llms import HuggingFacePipeline  # pylint: disable=import-error, E0611
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoProcessor,
+     LlavaForConditionalGeneration,
+     AutoTokenizer,
+     GenerationConfig,
+     pipeline
+ )
+ from .abstract import AbstractLLM
+
+
+ class PipelineLLM(AbstractLLM):
+     """PipelineLLM.
+
+     Load an LLM (Language Model) from the HuggingFace Hub.
+
+     Returns:
+         _type_: an instance of a HuggingFace LLM Model.
+     """
+     model: str = "databricks/dolly-v2-3b"
+     embed_model: str = None
+     max_tokens: int = 1024
+
+     def __init__(self, *args, **kwargs):
+         self.batch_size = kwargs.get('batch_size', 4)
+         self.use_llava: bool = kwargs.get('use_llava', False)
+         self.model_args = kwargs.get('model_args', {})
+         super().__init__(*args, **kwargs)
+         dtype = kwargs.get('dtype', 'float16')
+         if dtype == 'bfloat16':
+             torch_dtype = torch.bfloat16
+         elif dtype == 'float16':
+             torch_dtype = torch.float16
+         elif dtype == 'float32':
+             torch_dtype = torch.float32
+         elif dtype == 'float8':
+             # torch has no plain `torch.float8`; float8_e4m3fn is the nearest dtype
+             torch_dtype = torch.float8_e4m3fn
+         else:
+             torch_dtype = "auto"
+         use_fast = kwargs.get('use_fast', True)
+         if self.use_llava is False:
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 self.model,
+                 chunk_size=self.max_tokens
+             )
+             self._model = AutoModelForCausalLM.from_pretrained(
+                 self.model,
+                 device_map="auto",
+                 torch_dtype=torch_dtype,
+                 trust_remote_code=True,
+             )
+             config = GenerationConfig(
+                 do_sample=True,
+                 temperature=self.temperature,
+                 max_new_tokens=self.max_tokens,
+                 top_p=self.top_p,
+                 top_k=self.top_k,
+                 repetition_penalty=1.15,
+             )
+             self._pipe = pipeline(
+                 task=self.task,
+                 model=self._model,
+                 tokenizer=self.tokenizer,
+                 return_full_text=True,
+                 use_fast=use_fast,
+                 device_map='auto',
+                 batch_size=self.batch_size,
+                 generation_config=config,
+                 pad_token_id=50256,
+                 framework="pt"
+             )
+         else:
+             self._model = LlavaForConditionalGeneration.from_pretrained(
+                 self.model,
+                 device_map="auto",
+                 torch_dtype=torch_dtype,
+                 trust_remote_code=True,
+                 low_cpu_mem_usage=True,
+             )
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model)
+             processor = AutoProcessor.from_pretrained(self.model)
+             self._pipe = pipeline(
+                 task=self.task,
+                 model=self._model,
+                 tokenizer=self.tokenizer,
+                 use_fast=use_fast,
+                 device_map='auto',
+                 batch_size=self.batch_size,
+                 image_processor=processor.image_processor,
+                 framework="pt",
+                 **self.model_args
+             )
+         self._pipe.tokenizer.pad_token_id = self._pipe.model.config.eos_token_id
+         self._llm = HuggingFacePipeline(
+             model_id=self.model,
+             pipeline=self._pipe,
+             verbose=True
+         )
+
+     def pipe(self, *args, **kwargs):
+         return self._pipe(*args, **kwargs, generate_kwargs={"max_new_tokens": self.max_tokens})
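PipelineLLM exposes both the LangChain wrapper and the raw transformers pipeline; a hedged sketch where the model name, dtype, and batch size are illustrative.

from parrot.llms.pipes import PipelineLLM

llm = PipelineLLM(model="databricks/dolly-v2-3b", task="text-generation",
                  dtype="float16", batch_size=2, max_tokens=128)
print(llm("What is a transformer?"))      # LangChain path via HuggingFacePipeline
raw = llm.pipe("What is a transformer?")  # raw transformers pipeline output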
parrot/llms/vertex.py ADDED
@@ -0,0 +1,68 @@
+ import os
+ from navconfig import config, BASE_DIR
+ from google.cloud import aiplatform
+ from langchain_google_vertexai import (
+     ChatVertexAI,
+     VertexAI,
+     VertexAIModelGarden,
+     VertexAIEmbeddings
+ )
+ from .abstract import AbstractLLM
+
+ class VertexLLM(AbstractLLM):
+     """VertexLLM.
+
+     Interact with the VertexAI Language Model.
+
+     Returns:
+         _type_: VertexAI LLM.
+     """
+     model: str = "gemini-1.0-pro"
+     embed_model: str = "textembedding-gecko@003"
+     max_tokens: int = 1024
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         use_garden: bool = kwargs.get("use_garden", False)
+         project_id = config.get("VERTEX_PROJECT_ID")
+         region = config.get("VERTEX_REGION")
+         config_file = config.get('GOOGLE_CREDENTIALS_FILE', 'env/google/vertexai.json')
+         config_dir = BASE_DIR.joinpath(config_file)
+         os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(config_dir)
+         self.args = {
+             "project": project_id,
+             "location": region,
+             "max_output_tokens": self.max_tokens,
+             "temperature": self.temperature,
+             "max_retries": 4,
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+             "verbose": True,
+         }
+         if use_garden is True:
+             base_llm = VertexAIModelGarden
+             self.args['endpoint_id'] = self.model
+         elif self.model == "chat":
+             self.args['model_name'] = "chat-bison@001"
+             base_llm = ChatVertexAI
+         else:
+             self.args['model_name'] = self.model
+             base_llm = VertexAI
+         # LLM
+         self._llm = base_llm(
+             system_prompt="Always respond in the same language as the user's question. If the user's language is not English, translate your response into their language.",
+             **self.args
+         )
+         # Embedding Model:
+         embed_model = kwargs.get("embed_model", self.embed_model)
+         self._embed = VertexAIEmbeddings(
+             model_name=embed_model,
+             project=project_id,
+             location=region,
+             request_parallelism=5,
+             max_retries=4,
+             temperature=self.temperature,
+             top_p=self.top_p,
+             top_k=self.top_k,
+         )
+         self._version_ = aiplatform.__version__
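VertexLLM needs project configuration rather than an API key; this sketch assumes VERTEX_PROJECT_ID, VERTEX_REGION, and the service-account JSON referenced above are in place.

from parrot.llms.vertex import VertexLLM

llm = VertexLLM(model="gemini-1.0-pro", temperature=0.2)
print(llm("Translate 'hello parrot' into Spanish."))  # VertexAI.invoke returns text
query_vec = llm.get_embedding().embed_query("vector search")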
parrot/loaders/__init__.py ADDED
@@ -0,0 +1,20 @@
+ from .dir import load_directory
+ from .pdf import PDFLoader
+ from .web import WebLoader
+ from .youtube import YoutubeLoader
+ from .vimeo import VimeoLoader
+ from .word import MSWordLoader
+ from .ppt import PPTXLoader
+ from .repo import RepositoryLoader
+ from .github import GithubLoader
+ from .json import JSONLoader
+ from .excel import ExcelLoader
+ from .web_base import WebBaseLoader
+ from .pdfmark import PDFMarkdownLoader
+ from .pdfimages import PDFImageLoader
+ from .pdftables import PDFTablesLoader
+ from .pdfchapters import PDFChapterLoader
+ from .txt import TXTLoader
+ from .qa import QAFileLoader
+ from .rtd import ReadTheDocsLoader
+ from .videolocal import VideoLocalLoader
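The loader implementations live elsewhere in this wheel (e.g. parrot/loaders/pdf.py, +187 above), so only the import surface is visible here; a heavily hedged sketch assuming the usual LangChain-style loader contract.

from parrot.loaders import PDFLoader

# Constructor arguments and method names are assumptions; this diff only
# shows that PDFLoader is re-exported from parrot.loaders.
loader = PDFLoader("docs/manual.pdf")
documents = loader.load()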