ai-parrot 0.1.0__cp311-cp311-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-parrot might be problematic.
- ai_parrot-0.1.0.dist-info/LICENSE +21 -0
- ai_parrot-0.1.0.dist-info/METADATA +299 -0
- ai_parrot-0.1.0.dist-info/RECORD +108 -0
- ai_parrot-0.1.0.dist-info/WHEEL +5 -0
- ai_parrot-0.1.0.dist-info/top_level.txt +3 -0
- parrot/__init__.py +18 -0
- parrot/chatbots/__init__.py +7 -0
- parrot/chatbots/abstract.py +965 -0
- parrot/chatbots/asktroc.py +16 -0
- parrot/chatbots/base.py +257 -0
- parrot/chatbots/basic.py +9 -0
- parrot/chatbots/bose.py +17 -0
- parrot/chatbots/cody.py +17 -0
- parrot/chatbots/copilot.py +100 -0
- parrot/chatbots/dataframe.py +103 -0
- parrot/chatbots/hragents.py +15 -0
- parrot/chatbots/oddie.py +17 -0
- parrot/chatbots/retrievals/__init__.py +515 -0
- parrot/chatbots/retrievals/constitutional.py +19 -0
- parrot/conf.py +108 -0
- parrot/crew/__init__.py +3 -0
- parrot/crew/tools/__init__.py +22 -0
- parrot/crew/tools/bing.py +13 -0
- parrot/crew/tools/config.py +43 -0
- parrot/crew/tools/duckgo.py +62 -0
- parrot/crew/tools/file.py +24 -0
- parrot/crew/tools/google.py +168 -0
- parrot/crew/tools/gtrends.py +16 -0
- parrot/crew/tools/md2pdf.py +25 -0
- parrot/crew/tools/rag.py +42 -0
- parrot/crew/tools/search.py +32 -0
- parrot/crew/tools/url.py +21 -0
- parrot/exceptions.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/bots.py +196 -0
- parrot/handlers/chat.py +169 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/database.py +29 -0
- parrot/llms/__init__.py +0 -0
- parrot/llms/abstract.py +41 -0
- parrot/llms/anthropic.py +36 -0
- parrot/llms/google.py +37 -0
- parrot/llms/groq.py +33 -0
- parrot/llms/hf.py +39 -0
- parrot/llms/openai.py +49 -0
- parrot/llms/pipes.py +103 -0
- parrot/llms/vertex.py +68 -0
- parrot/loaders/__init__.py +20 -0
- parrot/loaders/abstract.py +456 -0
- parrot/loaders/basepdf.py +102 -0
- parrot/loaders/basevideo.py +280 -0
- parrot/loaders/csv.py +42 -0
- parrot/loaders/dir.py +37 -0
- parrot/loaders/excel.py +349 -0
- parrot/loaders/github.py +65 -0
- parrot/loaders/handlers/__init__.py +5 -0
- parrot/loaders/handlers/data.py +213 -0
- parrot/loaders/image.py +119 -0
- parrot/loaders/json.py +52 -0
- parrot/loaders/pdf.py +187 -0
- parrot/loaders/pdfchapters.py +142 -0
- parrot/loaders/pdffn.py +112 -0
- parrot/loaders/pdfimages.py +207 -0
- parrot/loaders/pdfmark.py +88 -0
- parrot/loaders/pdftables.py +145 -0
- parrot/loaders/ppt.py +30 -0
- parrot/loaders/qa.py +81 -0
- parrot/loaders/repo.py +103 -0
- parrot/loaders/rtd.py +65 -0
- parrot/loaders/txt.py +92 -0
- parrot/loaders/utils/__init__.py +1 -0
- parrot/loaders/utils/models.py +25 -0
- parrot/loaders/video.py +96 -0
- parrot/loaders/videolocal.py +107 -0
- parrot/loaders/vimeo.py +106 -0
- parrot/loaders/web.py +216 -0
- parrot/loaders/web_base.py +112 -0
- parrot/loaders/word.py +125 -0
- parrot/loaders/youtube.py +192 -0
- parrot/manager.py +152 -0
- parrot/models.py +347 -0
- parrot/py.typed +0 -0
- parrot/stores/__init__.py +0 -0
- parrot/stores/abstract.py +170 -0
- parrot/stores/milvus.py +540 -0
- parrot/stores/qdrant.py +153 -0
- parrot/tools/__init__.py +16 -0
- parrot/tools/abstract.py +53 -0
- parrot/tools/asknews.py +32 -0
- parrot/tools/bing.py +13 -0
- parrot/tools/duck.py +62 -0
- parrot/tools/google.py +170 -0
- parrot/tools/stack.py +26 -0
- parrot/tools/weather.py +70 -0
- parrot/tools/wikipedia.py +59 -0
- parrot/tools/zipcode.py +179 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- resources/users/__init__.py +5 -0
- resources/users/handlers.py +13 -0
- resources/users/models.py +205 -0
- settings/__init__.py +0 -0
- settings/settings.py +51 -0
@@ -0,0 +1,965 @@
from abc import ABC
from collections.abc import Callable
from typing import Any, Union
from pathlib import Path, PurePath
import uuid
from aiohttp import web
import torch
from transformers import (
    AutoModel,
    AutoConfig,
    AutoTokenizer,
    # AutoModelForSeq2SeqLM
)
# Langchain
from langchain import hub
from langchain.docstore.document import Document
from langchain.memory import (
    # ConversationSummaryMemory,
    ConversationBufferMemory
)
# from langchain.retrievers import (
#     EnsembleRetriever,
#     ContextualCompressionRetriever
# )
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter
)
# from langchain.chains.retrieval_qa.base import RetrievalQA
# from langchain.chains.conversational_retrieval.base import (
#     ConversationalRetrievalChain
# )
# from langchain_core.runnables import (
#     RunnablePassthrough,
#     RunnablePick,
#     RunnableParallel
# )
# from langchain_core.output_parsers import StrOutputParser
# from langchain_core.prompts import (
#     PromptTemplate,
#     ChatPromptTemplate
# )
# from langchain_core.vectorstores import VectorStoreRetriever
# from langchain_community.retrievers import BM25Retriever
from langchain_community.chat_message_histories import RedisChatMessageHistory

# Navconfig
from navconfig import BASE_DIR
from navconfig.exceptions import ConfigError  # pylint: disable=E0611
from navconfig.logging import logging
from asyncdb.exceptions import NoDataFound

try:
    from ..stores.qdrant import QdrantStore
    QDRANT_ENABLED = True
except (ModuleNotFoundError, ImportError):
    QDRANT_ENABLED = False

try:
    from ..stores.milvus import MilvusStore
    MILVUS_ENABLED = True
except (ModuleNotFoundError, ImportError):
    MILVUS_ENABLED = False

from ..utils import SafeDict, parse_toml_config


## LLM configuration
# Vertex
try:
    from ..llms.vertex import VertexLLM
    VERTEX_ENABLED = True
except (ModuleNotFoundError, ImportError):
    VERTEX_ENABLED = False

# Anthropic:
try:
    from ..llms.anthropic import Anthropic
    ANTHROPIC_ENABLED = True
except (ModuleNotFoundError, ImportError):
    ANTHROPIC_ENABLED = False

# OpenAI
try:
    from ..llms.openai import OpenAILLM
    OPENAI_ENABLED = True
except (ModuleNotFoundError, ImportError):
    OPENAI_ENABLED = False

# LLM Transformers
try:
    from ..llms.pipes import PipelineLLM
    TRANSFORMERS_ENABLED = True
except (ModuleNotFoundError, ImportError):
    TRANSFORMERS_ENABLED = False

# HuggingFace Hub:
try:
    from ..llms.hf import HuggingFace
    HF_ENABLED = True
except (ModuleNotFoundError, ImportError):
    HF_ENABLED = False

# Groq:
try:
    from ..llms.groq import GroqLLM
    GROQ_ENABLED = True
except (ModuleNotFoundError, ImportError):
    GROQ_ENABLED = False

from ..loaders import (
    PDFLoader,
    PDFTablesLoader,
    GithubLoader,
    RepositoryLoader,
    WebLoader,
    VimeoLoader,
    YoutubeLoader,
    PPTXLoader,
    MSWordLoader
)
from .retrievals import RetrievalManager
from ..conf import (
    DEFAULT_LLM_MODEL_NAME,
    EMBEDDING_DEVICE,
    MAX_VRAM_AVAILABLE,
    RAM_AVAILABLE,
    default_dsn,
    REDIS_HISTORY_URL
)
from ..interfaces import DBInterface
from ..models import ChatbotModel


logging.getLogger(name='selenium.webdriver').setLevel(logging.WARNING)
logging.getLogger(name='selenium').setLevel(logging.INFO)

class AbstractChatbot(ABC, DBInterface):
    """Represents a Chatbot in Navigator.

    Each Chatbot has a name, a role, a goal, a backstory,
    and an optional language model (llm).
    """

    template_prompt: str = (
        "You are {name}, an expert AI assistant and {role} working at {company}.\n\n"
        "Your primary function is to {goal}\n"
        "Use the provided context of the documents you have processed or extracted from other provided tools or sources to provide informative, detailed and accurate responses.\n"
        "I am here to help with {role}, {backstory}.\n\n"
        "Focus on answering the question directly and in detail. Do not include an introduction or greeting in your response.\n\n"
        "{company_information}\n\n"
        "Context: {context}\n\n"
        "Given this information, please provide answers to the following question, adding detailed and useful insights:\n\n"
        "Chat History: {chat_history}\n"
        "Human: {question}\n"
        "Here is a brief summary of relevant information:\n"
        "{summaries}\n\n"
        "Assistant Answer:\n"
        "{rationale}\n"
        "You are a fluent speaker; you can talk and respond fluently in English and Spanish, and you must answer in the same language as the user's question. If the user's language is not English, you should translate your response into their language.\n"
    )

    def _get_default_attr(self, key, default: Any = None, **kwargs):
        if key in kwargs:
            return kwargs.get(key)
        if hasattr(self, key):
            return getattr(self, key)
        return default

    def __init__(self, **kwargs):
        """Initialize the Chatbot with the given configuration."""
        # Chatbot ID:
        self.chatbot_id: uuid.UUID = kwargs.get(
            'chatbot_id',
            None
        )
        # Basic Information:
        self.name = self._get_default_attr(
            'name', 'NAV', **kwargs
        )
        ## Logging:
        self.logger = logging.getLogger(f'{self.name}.Chatbot')
        self.description = self._get_default_attr(
            'description', 'Navigator Chatbot', **kwargs
        )
        self.role = self._get_default_attr(
            'role', 'Chatbot', **kwargs
        )
        self.goal = self._get_default_attr(
            'goal', 'provide helpful information to users', **kwargs
        )
        self.backstory = self._get_default_attr(
            'backstory',
            default=self.default_backstory(),
            **kwargs
        )
        self.rationale = self._get_default_attr(
            'rationale',
            default=self.default_rationale(),
            **kwargs
        )
        # Configuration File:
        self.config_file: PurePath = kwargs.get('config_file', None)
        # Other Configuration
        self.confidence_threshold: float = kwargs.get('threshold', 0.5)
        self.context = kwargs.pop('context', '')

        # Company Information:
        self.company_information: dict = kwargs.pop('company_information', {})

        # Pre-Instructions:
        self.pre_instructions: list = kwargs.get(
            'pre_instructions',
            []
        )

        # Knowledge base:
        self.knowledge_base: list = []
        self._documents_: list = []

        # Text Documents
        self.documents_dir = kwargs.get(
            'documents_dir',
            None
        )
        if isinstance(self.documents_dir, str):
            self.documents_dir = Path(self.documents_dir)
        if not self.documents_dir:
            self.documents_dir = BASE_DIR.joinpath('documents')
        if not self.documents_dir.exists():
            self.documents_dir.mkdir(
                parents=True,
                exist_ok=True
            )
        # Models, Embed and collections
        # Vector information:
        self.chunk_size: int = int(kwargs.get('chunk_size', 768))
        self.dimension: int = int(kwargs.get('dimension', 768))
        self._database: dict = kwargs.get('database', {})
        self._store: Callable = None
        # Embedding Model Name
        self.use_bge: bool = bool(
            kwargs.get('use_bge', False)
        )
        self.use_fastembed: bool = bool(
            kwargs.get('use_fastembed', False)
        )
        self.embedding_model_name = kwargs.get(
            'embedding_model_name', None
        )
        # embedding object:
        self.embeddings = kwargs.get('embeddings', None)
        self.tokenizer_model_name = kwargs.get(
            'tokenizer', None
        )
        self.summarization_model = kwargs.get(
            'summarization_model',
            "facebook/bart-large-cnn"
        )
        self.rag_model = kwargs.get(
            'rag_model',
            "rlm/rag-prompt-llama"
        )
        self._text_splitter_model = kwargs.get(
            'text_splitter',
            'mixedbread-ai/mxbai-embed-large-v1'
        )
        # Definition of LLM
        self._llm: Callable = None
        self._llm_obj: Callable = kwargs.get('llm', None)

        # Max VRAM usage:
        self._max_vram = int(kwargs.get('max_vram', MAX_VRAM_AVAILABLE))

    def get_llm(self):
        return self._llm_obj

    def __repr__(self):
        return f"<Chatbot.{self.__class__.__name__}:{self.name}>"

    # Database:
    @property
    def store(self):
        if not self._store.connected:
            self._store.connect()
        return self._store

    def default_rationale(self) -> str:
        # TODO: read rationale from a file
        return """
        I am a language model trained by Google.
        I am designed to provide helpful information to users.
        Remember to maintain a professional tone.
        If I cannot find relevant information in the documents,
        I will indicate this and suggest alternative avenues for the user to find an answer.
        """

    def default_backstory(self) -> str:
        return (
            "help with Human Resources related queries or knowledge-based questions about T-ROC Global.\n"
            "You can ask me about the company's products and services, the company's culture, the company's clients.\n"
            "You have the capability to read and understand various Human Resources documents, "
            "such as employee handbooks, policy documents, onboarding materials, company's website, and more.\n"
            "I can also provide information about the company's policies and procedures, benefits, and other HR-related topics."
        )

    def load_llm(self, llm_name: str, model_name: str = None, **kwargs):
        """Load the Language Model for the Chatbot."""
        self.logger.debug(f'LLM > {llm_name}')
        if llm_name == 'VertexLLM':
            if VERTEX_ENABLED is False:
                raise ConfigError(
                    "VertexAI is enabled but not installed."
                )
            return VertexLLM(model=model_name, **kwargs)
        elif llm_name == 'Anthropic':
            if ANTHROPIC_ENABLED is False:
                raise ConfigError(
                    "Anthropic is enabled but not installed."
                )
            return Anthropic(model=model_name, **kwargs)
        elif llm_name == 'OpenAI':
            if OPENAI_ENABLED is False:
                raise ConfigError(
                    "OpenAI is enabled but not installed."
                )
            return OpenAILLM(model=model_name, **kwargs)
        elif llm_name == 'hf':
            if HF_ENABLED is False:
                raise ConfigError(
                    "HuggingFace Hub is enabled but not installed."
                )
            return HuggingFace(model=model_name, **kwargs)
        elif llm_name == 'pipe':
            if TRANSFORMERS_ENABLED is False:
                raise ConfigError(
                    "Transformers Pipelines are enabled but not installed."
                )
            return PipelineLLM(model=model_name, **kwargs)
        elif llm_name == 'Groq':
            if GROQ_ENABLED is False:
                raise ConfigError(
                    "Groq is enabled but not installed."
                )
            return GroqLLM(model=model_name, **kwargs)
        # TODO: Add more LLMs
        return hub.pull(llm_name)

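A minimal sketch of how this dispatcher might be invoked from a configured bot; the model name and the extra keyword argument are assumptions here, since each wrapper's constructor is defined elsewhere in parrot.llms:

    # Hypothetical call into load_llm(); 'gpt-4o' and temperature are placeholders.
    llm_obj = bot.load_llm('OpenAI', model_name='gpt-4o', temperature=0.2)
    llm = llm_obj.get_llm()  # unwrap the underlying LangChain-compatible model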
    async def configure(self, app=None) -> None:
        if isinstance(app, web.Application):
            self.app = app  # register the app into the Extension
        elif app is None:
            self.app = None
        else:
            self.app = app.get_app()  # Nav Application
        # Config File:
        config_file = BASE_DIR.joinpath(
            'etc',
            'config',
            'chatbots',
            self.name.lower(),
            "config.toml"
        )
        if config_file.exists():
            self.logger.notice(
                f"Using Bot config {config_file}"
            )
        else:
            config_file = None
        # Database-based Bot Configuration
        if self.chatbot_id is not None:
            # Configure from the Database
            await self.from_database(config_file)
        elif config_file:
            # Configure from the TOML file
            await self.from_config_file(config_file)
        # else:
        #     # Configure from a default configuration
        #     vector_config = {
        #         "vector_database": self.vector_database,
        #         "collection_name": self.collection_name
        #     }
        #     # configure vector database:
        #     await self.store_configuration(
        #         config=vector_config
        #     )
        #     # Get the Embeddings:
        #     if not self.embedding_model_name:
        #         self.embeddings = self._llm_obj.get_embedding()
        #     # Config Prompt:
        #     self._define_prompt(
        #         config={}
        #     )
        # adding this configured chatbot to app:
        if self.app:
            self.app[f"{self.name.lower()}_chatbot"] = self

    def _configure_llm(self, llm, config):
        if self._llm_obj:
            self._llm = self._llm_obj.get_llm()
        else:
            if llm:
                # LLM:
                self._llm_obj = self.load_llm(
                    llm,
                    **config
                )
                self._llm = self._llm_obj.get_llm()
            else:
                raise ValueError(
                    "LLM is not defined in the configuration."
                )

    def _from_bot(self, bot, key, config, default) -> Any:
        value = getattr(bot, key, None)
        file_value = config.get(key, default)
        return value if value else file_value

    async def from_database(self, config_file: PurePath = None) -> None:
        """Load the Chatbot Configuration from the Database."""
        file_config = await parse_toml_config(config_file)
        db = self.get_database('pg', dsn=default_dsn)
        bot = None
        async with await db.connection() as conn:  # pylint: disable=E1101
            # import model
            ChatbotModel.Meta.connection = conn
            try:
                if self.chatbot_id:
                    try:
                        bot = await ChatbotModel.get(chatbot_id=self.chatbot_id)
                    except Exception:
                        bot = await ChatbotModel.get(name=self.name)
                else:
                    bot = await ChatbotModel.get(name=self.name)
            except NoDataFound:
                # Fallback to File configuration:
                if file_config:
                    await self.from_config_file(config_file)
                    return
                else:
                    raise ConfigError(
                        f"Chatbot {self.name} not found in the database."
                    )
            if not bot:
                raise ConfigError(
                    f"Chatbot {self.name} not found in the database."
                )
            # Start Bot configuration from Database:
            config_file = Path(bot.config_file).resolve()
            if config_file:
                file_config = await parse_toml_config(config_file)
            # basic configuration
            basic = file_config.get('chatbot', {})
            self.name = self._from_bot(bot, 'name', basic, self.name)
            self.description = self._from_bot(bot, 'description', basic, self.description)
            self.role = self._from_bot(bot, 'role', basic, self.role)
            self.goal = self._from_bot(bot, 'goal', basic, self.goal)
            self.rationale = self._from_bot(bot, 'rationale', basic, self.rationale)
            self.backstory = self._from_bot(bot, 'backstory', basic, self.backstory)
            # company information:
            self.company_information = self._from_bot(
                bot, 'company_information', basic, self.company_information
            )
            # Contextual knowledge-base
            self.kb = file_config.get('knowledge-base', [])
            if self.kb:
                self.knowledge_base = self.create_kb(
                    self.kb.get('data', [])
                )
            # Model Information:
            models = file_config.get('llm', {})
            # LLM Configuration (from file and from db)
            llm_config = models.get('config', bot.llm_config)
            llm = self._from_bot(bot, 'llm', models, 'VertexLLM')
            # Configuration of LLM:
            self._configure_llm(llm, llm_config)
            # Other models:
            models = file_config.get('models', {})
            self.embedding_model_name = self._from_bot(
                bot, 'embedding_name', models, None
            )
            self.tokenizer_model_name = self._from_bot(
                bot, 'tokenizer', models, None
            )
            self.summarization_model = self._from_bot(
                bot, 'summarize_model', models, "facebook/bart-large-cnn"
            )
            self.classification_model = self._from_bot(
                bot, 'classification_model', models, None
            )
            # Database Configuration:
            vector_config = file_config.get('database', {})
            db_config = bot.database
            db_config = {**vector_config, **db_config}
            vector_db = db_config.pop('vector_database')
            await self.store_configuration(vector_db, db_config)
            # after configuration, setup the chatbot
            if bot.template_prompt:
                self.template_prompt = bot.template_prompt
            self._define_prompt(
                config={}
            )

    async def from_config_file(self, config_file: PurePath) -> None:
        """Load the Chatbot Configuration from the TOML file."""
        self.logger.debug(
            f"Using Config File: {config_file}"
        )
        file_config = await parse_toml_config(config_file)
        # getting the configuration from config
        self.config_file = config_file
        # basic config
        basic = file_config.get('chatbot', {})
        # Chatbot Name:
        self.name = basic.get('name', self.name)
        self.description = basic.get('description', self.description)
        self.role = basic.get('role', self.role)
        self.goal = basic.get('goal', self.goal)
        self.rationale = basic.get('rationale', self.rationale)
        self.backstory = basic.get('backstory', self.backstory)
        # Company Information:
        self.company_information = basic.get(
            'company_information',
            self.company_information
        )
        # Model Information:
        llminfo = file_config.get('llm')
        llm = llminfo.get('llm', 'VertexLLM')
        cfg = llminfo.get('config', {})
        # Configuration of LLM:
        self._configure_llm(llm, cfg)

        # Other models:
        models = file_config.get('models', {})
        if not self.embedding_model_name:
            self.embedding_model_name = models.get(
                'embedding_name', None
            )
        if not self.tokenizer_model_name:
            self.tokenizer_model_name = models.get('tokenizer')
        if not self.embedding_model_name:
            # Getting the Embedding Model from the LLM
            self.embeddings = self._llm_obj.get_embedding()
        self.use_bge = models.get('use_bge', False)
        self.use_fastembed = models.get('use_fastembed', False)
        self.summarization_model = models.get(
            'summarize_model',
            "facebook/bart-large-cnn"
        )
        self.classification_model = models.get(
            'classification_model',
            None
        )
        # pre-instructions
        instructions = file_config.get('pre-instructions')
        if instructions:
            self.pre_instructions = instructions.get('instructions', [])
        # Contextual knowledge-base
        self.kb = file_config.get('knowledge-base', [])
        if self.kb:
            self.knowledge_base = self.create_kb(
                self.kb.get('data', [])
            )
        vector_config = file_config.get('database', {})
        vector_db = vector_config.pop('vector_database')
        # configure vector database:
        await self.store_configuration(
            vector_db,
            vector_config
        )
        # after configuration, setup the chatbot
        if 'template_prompt' in basic:
            self.template_prompt = basic.get('template_prompt')
        self._define_prompt(
            config=basic
        )

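For reference, these are the TOML sections the loader above reads, shown as the dict that parse_toml_config presumably returns; only the keys are taken from the code, every value is an invented placeholder:

    file_config = {
        'chatbot': {
            'name': 'NAV',
            'role': 'Chatbot',
            'goal': 'provide helpful information to users',
            'template_prompt': '...',  # optional override
        },
        'llm': {'llm': 'VertexLLM', 'config': {}},
        'models': {'embedding_name': None, 'tokenizer': None, 'use_bge': False},
        'pre-instructions': {'instructions': []},
        'knowledge-base': {'data': []},
        'database': {'vector_database': 'MilvusStore', 'collection_name': 'navigator'},
    }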
    def create_kb(self, documents: list):
        new_docs = []
        for doc in documents:
            content = doc.pop('content')
            source = doc.pop('source', 'knowledge-base')
            if doc:
                meta = {
                    'source': source,
                    **doc
                }
            else:
                meta = {'source': source}
            if content:
                new_docs.append(
                    Document(
                        page_content=content,
                        metadata=meta
                    )
                )
        return new_docs

    async def store_configuration(self, vector_db: str, config: dict):
        """Create the Vector Store Configuration."""
        self.collection_name = config.get('collection_name')
        if not self.embeddings:
            embed = self.embedding_model_name
        else:
            embed = self.embeddings
        # TODO: add dynamic configuration of VectorStore
        if vector_db == 'QdrantStore':
            if QDRANT_ENABLED is True:
                ## TODO: support pluggable vector store
                self._store = QdrantStore(  # pylint: disable=E0110
                    embeddings=embed,
                    use_bge=self.use_bge,
                    use_fastembed=self.use_fastembed,
                    **config
                )
            else:
                raise ConfigError(
                    (
                        "Qdrant is enabled but not installed. "
                        "Hint: please install with pip install -e .[qdrant]"
                    )
                )
        elif vector_db == 'MilvusStore':
            if MILVUS_ENABLED is True:
                self._store = MilvusStore(
                    embeddings=embed,
                    use_bge=self.use_bge,
                    use_fastembed=self.use_fastembed,
                    **config
                )
            else:
                raise ConfigError(
                    (
                        "Milvus is enabled but not installed. "
                        "Hint: please install with pip install -e .[milvus]"
                    )
                )
        else:
            raise ValueError(
                f"Invalid Vector Store {vector_db}"
            )

    def _define_prompt(self, config: dict):
        # setup the prompt variables:
        for key, val in config.items():
            setattr(self, key, val)
        if self.company_information:
            self.template_prompt = self.template_prompt.format_map(
                SafeDict(
                    company_information=(
                        "For further inquiries or detailed information, you can contact us at:\n"
                        "- Contact Information: {contact_email}\n"
                        "- Use our contact form: {contact_form}\n"
                        "- or Visit our website: {company_website}\n"
                    )
                )
            )
        # Parsing the Template:
        self.template_prompt = self.template_prompt.format_map(
            SafeDict(
                name=self.name,
                role=self.role,
                goal=self.goal,
                backstory=self.backstory,
                rationale=self.rationale,
                threshold=self.confidence_threshold,
                **self.company_information
            )
        )
        # print('Template Prompt:', self.template_prompt)

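_define_prompt leans on str.format_map with a SafeDict so that placeholders not yet known (e.g. {question}, {context}, {summaries}) survive as literal text for a later formatting pass. The real SafeDict lives in parrot.utils; it is presumably implemented along these lines:

    class SafeDict(dict):
        # Leave unknown placeholders intact instead of raising KeyError.
        def __missing__(self, key):
            return '{' + key + '}'

    "Hi {name}, question: {question}".format_map(SafeDict(name='NAV'))
    # -> 'Hi NAV, question: {question}'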
    @property
    def llm(self):
        return self._llm

    @llm.setter
    def llm(self, model):
        self._llm_obj = model
        self._llm = model.get_llm()

    def _get_device(self, cuda_number: int = 0):
        torch.backends.cudnn.deterministic = True
        if torch.cuda.is_available():
            # Use CUDA GPU if available
            device = torch.device(f'cuda:{cuda_number}')
        elif torch.backends.mps.is_available():
            # Use Apple Metal Performance Shaders (MPS) if available
            device = torch.device("mps")
        elif EMBEDDING_DEVICE == 'cuda':
            device = torch.device(f'cuda:{cuda_number}')
        else:
            device = torch.device(EMBEDDING_DEVICE)
        return device

    def get_tokenizer(self, model_name: str, chunk_size: int = 768):
        return AutoTokenizer.from_pretrained(
            model_name,
            chunk_size=chunk_size
        )

    def get_model(self, model_name: str):
        device = self._get_device()
        self._model_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=True
        )
        return AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            config=self._model_config,
            unpad_inputs=True,
            use_memory_efficient_attention=True,
        ).to(device)

    def get_text_splitter(self, model, chunk_size: int = 1024, overlap: int = 100):
        return RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
            model,
            chunk_size=chunk_size,
            chunk_overlap=overlap,
            add_start_index=True,  # If `True`, includes chunk's start index in metadata
            strip_whitespace=True,  # strips whitespace from the start and end
            separators=["\n\n", "\n", "\r\n", "\r", "\f", "\v"],
        )

    def chunk_documents(self, documents, chunk_size):
        # Yield successive n-sized chunks from documents.
        for i in range(0, len(documents), chunk_size):
            yield documents[i:i + chunk_size]

    def get_available_vram(self):
        """
        Returns available VRAM in megabytes.
        """
        try:
            # Clear any unused memory to get a fresher estimate
            torch.cuda.empty_cache()
            # Convert to MB
            total_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 2)
            reserved_memory = torch.cuda.memory_reserved(0) / (1024 ** 2)
            available_memory = total_memory - reserved_memory
            self.logger.notice(f'Available VRAM: {available_memory}')
            # Limit by predefined max usage
            return min(available_memory, self._max_vram)
        except RuntimeError:
            # Limit by predefined max usage
            return min(RAM_AVAILABLE, self._max_vram)

    def _estimate_chunk_size(self):
        """Estimate chunk size based on VRAM usage.

        This is a simplistic heuristic and might need tuning based on empirical data.
        """
        available_vram = self.get_available_vram()
        # Estimated VRAM in megabytes per document; adjust based on empirical observation
        estimated_vram_per_doc = 50
        chunk_size = max(1, int(available_vram / estimated_vram_per_doc))
        self.logger.notice(
            f'Chunk size for Load Documents: {chunk_size}'
        )
        return chunk_size

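A worked instance of the heuristic above, with illustrative figures: if get_available_vram() reports 8192 MB free, the assumed 50 MB per document yields batches of 163 documents (floored, never below 1):

    available_vram = 8192          # MB, as reported by get_available_vram()
    estimated_vram_per_doc = 50    # MB, the constant used in _estimate_chunk_size()
    chunk_size = max(1, int(available_vram / estimated_vram_per_doc))  # -> 163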
    ## Utility Loaders

    async def load_documents(
        self,
        documents: list,
        collection: str = None,
        delete: bool = False
    ):
        # Load Raw Documents into the Vectorstore
        self.logger.debug(f'Documents to load: {len(documents)}')
        if len(documents) < 1:
            self.logger.warning(
                "There are no documents to be loaded, skipping."
            )
            return

        self._documents_.extend(documents)
        if not collection:
            collection = self.collection_name

        self.logger.notice(f'Loading Documents: {len(documents)}')
        document_chunks = self.chunk_documents(
            documents,
            self._estimate_chunk_size()
        )
        async with self._store as store:
            # if delete is True, then delete the collection
            if delete is True:
                await store.delete_collection(collection)
                fdoc = documents.pop(0)
                await store.create_collection(
                    collection,
                    fdoc
                )
            for chunk in document_chunks:
                await store.load_documents(
                    chunk,
                    collection=collection
                )

    def load_pdf(self, path: Path, source_type: str = 'pdf', **kwargs):
        loader = PDFLoader(path, source_type=source_type, no_summarization=True, **kwargs)
        return loader.load()

    def load_github(
        self,
        url: str,
        github_token: str,
        lang: str = 'python',
        branch: str = 'master',
        source_type: str = 'code'
    ) -> list:
        git = GithubLoader(
            url,
            github_token=github_token,
            lang=lang,
            branch=branch,
            source_type=source_type
        )
        return git.load()

    def load_repository(
        self,
        path: Path,
        lang: str = 'python',
        source_type: str = 'code',
        **kwargs
    ) -> list:
        repo = RepositoryLoader(
            source_type=source_type,
            **kwargs
        )
        return repo.load(path, lang=lang)

    def process_websites(
        self,
        websites: list,
        source_type: str = 'website',
        **kwargs
    ) -> list:
        loader = WebLoader(
            urls=websites,
            source_type=source_type
        )
        return loader.load()

    def load_youtube_videos(
        self,
        urls: list,
        video_path: Union[str, Path],
        source_type: str = 'youtube',
        priority: str = 'high',
        language: str = 'en',
        **kwargs
    ) -> list:
        yt = YoutubeLoader(
            urls=urls,
            video_path=video_path,
            source_type=source_type,
            priority=priority,
            language=language,
            llm=self._llm,
            **kwargs
        )
        return yt.load()

    def load_vimeo_videos(
        self,
        urls: list,
        video_path: Union[str, Path],
        source_type: str = 'vimeo',
        priority: str = 'high',
        language: str = 'en',
        **kwargs
    ) -> list:
        loader = VimeoLoader(
            urls=urls,
            video_path=video_path,
            source_type=source_type,
            priority=priority,
            language=language,
            llm=self._llm,
            **kwargs
        )
        return loader.load()

    def load_directory(
        self,
        path: Union[str, Path],
        source_type: str = 'documents',
    ) -> list:
        # TODO: not implemented yet
        return None

    def load_docx(
        self,
        path: Path,
        source_type: str = 'docx',
        **kwargs
    ) -> list:
        return MSWordLoader.from_path(
            path=path,
            source_type=source_type,
            **kwargs
        )

    def load_pptx(
        self,
        path: Path,
        source_type: str = 'pptx',
        **kwargs
    ) -> list:
        return PPTXLoader.from_path(
            path=path,
            source_type=source_type,
            **kwargs
        )

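Taken together, the loaders above produce Document lists that feed load_documents(). A minimal sketch (inside an async context, on an already-configured bot; the PDF path and collection name are placeholders):

    docs = bot.load_pdf(Path('documents/handbook.pdf'))
    # delete=True drops and recreates the collection before loading:
    await bot.load_documents(docs, collection='handbook', delete=True)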
    def get_memory(
        self,
        session_id: str = None,
        key: str = 'chat_history',
        input_key: str = 'question',
        output_key: str = 'answer',
        size: int = 30,
        ttl: int = 86400
    ):
        args = {
            'memory_key': key,
            'input_key': input_key,
            'output_key': output_key,
            'return_messages': True,
            'max_len': size
        }
        if session_id:
            message_history = RedisChatMessageHistory(
                url=REDIS_HISTORY_URL,
                session_id=session_id,
                ttl=ttl
            )
            args['chat_memory'] = message_history
        return ConversationBufferMemory(
            **args
        )

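Example use of the factory above: with a session_id, history is persisted in Redis under REDIS_HISTORY_URL with a 24-hour TTL; without one, the buffer lives only in process. The session id shown is a placeholder:

    memory = bot.get_memory(session_id='user-1234', size=30, ttl=86400)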
    def get_retrieval(self, source_path: str = 'web', request: web.Request = None):
        pre_context = "\n".join(f"- {a}." for a in self.pre_instructions)
        custom_template = self.template_prompt.format_map(
            SafeDict(
                summaries=pre_context
            )
        )
        # Generate the Retrieval
        rm = RetrievalManager(
            chatbot_id=self.chatbot_id,
            chatbot_name=self.name,
            source_path=source_path,
            model=self._llm,
            store=self._store,
            memory=None,
            template=custom_template,
            kb=self.knowledge_base,
            request=request
        )
        return rm
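A minimal end-to-end sketch of the lifecycle this class implies, assuming a concrete subclass (the shipped ones live under parrot/chatbots/) and an already-populated vector store; all names and arguments are illustrative:

    import asyncio

    class HRBot(AbstractChatbot):
        """Hypothetical concrete bot relying entirely on the plumbing above."""

    async def main():
        bot = HRBot(name='hrbot', chunk_size=768)
        await bot.configure()                      # resolves TOML and/or database config
        rm = bot.get_retrieval(source_path='web')  # RetrievalManager ready to answer
        return rm

    asyncio.run(main())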