ai-parrot 0.3.4__cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-parrot might be problematic. Click here for more details.

Files changed (109) hide show
  1. ai_parrot-0.3.4.dist-info/LICENSE +21 -0
  2. ai_parrot-0.3.4.dist-info/METADATA +319 -0
  3. ai_parrot-0.3.4.dist-info/RECORD +109 -0
  4. ai_parrot-0.3.4.dist-info/WHEEL +6 -0
  5. ai_parrot-0.3.4.dist-info/top_level.txt +3 -0
  6. parrot/__init__.py +21 -0
  7. parrot/chatbots/__init__.py +7 -0
  8. parrot/chatbots/abstract.py +728 -0
  9. parrot/chatbots/asktroc.py +16 -0
  10. parrot/chatbots/base.py +366 -0
  11. parrot/chatbots/basic.py +9 -0
  12. parrot/chatbots/bose.py +17 -0
  13. parrot/chatbots/cody.py +17 -0
  14. parrot/chatbots/copilot.py +83 -0
  15. parrot/chatbots/dataframe.py +103 -0
  16. parrot/chatbots/hragents.py +15 -0
  17. parrot/chatbots/odoo.py +17 -0
  18. parrot/chatbots/retrievals/__init__.py +578 -0
  19. parrot/chatbots/retrievals/constitutional.py +19 -0
  20. parrot/conf.py +110 -0
  21. parrot/crew/__init__.py +3 -0
  22. parrot/crew/tools/__init__.py +22 -0
  23. parrot/crew/tools/bing.py +13 -0
  24. parrot/crew/tools/config.py +43 -0
  25. parrot/crew/tools/duckgo.py +62 -0
  26. parrot/crew/tools/file.py +24 -0
  27. parrot/crew/tools/google.py +168 -0
  28. parrot/crew/tools/gtrends.py +16 -0
  29. parrot/crew/tools/md2pdf.py +25 -0
  30. parrot/crew/tools/rag.py +42 -0
  31. parrot/crew/tools/search.py +32 -0
  32. parrot/crew/tools/url.py +21 -0
  33. parrot/exceptions.cpython-311-x86_64-linux-gnu.so +0 -0
  34. parrot/handlers/__init__.py +4 -0
  35. parrot/handlers/bots.py +196 -0
  36. parrot/handlers/chat.py +162 -0
  37. parrot/interfaces/__init__.py +6 -0
  38. parrot/interfaces/database.py +29 -0
  39. parrot/llms/__init__.py +137 -0
  40. parrot/llms/abstract.py +47 -0
  41. parrot/llms/anthropic.py +42 -0
  42. parrot/llms/google.py +42 -0
  43. parrot/llms/groq.py +45 -0
  44. parrot/llms/hf.py +45 -0
  45. parrot/llms/openai.py +59 -0
  46. parrot/llms/pipes.py +114 -0
  47. parrot/llms/vertex.py +78 -0
  48. parrot/loaders/__init__.py +20 -0
  49. parrot/loaders/abstract.py +456 -0
  50. parrot/loaders/audio.py +106 -0
  51. parrot/loaders/basepdf.py +102 -0
  52. parrot/loaders/basevideo.py +280 -0
  53. parrot/loaders/csv.py +42 -0
  54. parrot/loaders/dir.py +37 -0
  55. parrot/loaders/excel.py +349 -0
  56. parrot/loaders/github.py +65 -0
  57. parrot/loaders/handlers/__init__.py +5 -0
  58. parrot/loaders/handlers/data.py +213 -0
  59. parrot/loaders/image.py +119 -0
  60. parrot/loaders/json.py +52 -0
  61. parrot/loaders/pdf.py +437 -0
  62. parrot/loaders/pdfchapters.py +142 -0
  63. parrot/loaders/pdffn.py +112 -0
  64. parrot/loaders/pdfimages.py +207 -0
  65. parrot/loaders/pdfmark.py +88 -0
  66. parrot/loaders/pdftables.py +145 -0
  67. parrot/loaders/ppt.py +30 -0
  68. parrot/loaders/qa.py +81 -0
  69. parrot/loaders/repo.py +103 -0
  70. parrot/loaders/rtd.py +65 -0
  71. parrot/loaders/txt.py +92 -0
  72. parrot/loaders/utils/__init__.py +1 -0
  73. parrot/loaders/utils/models.py +25 -0
  74. parrot/loaders/video.py +96 -0
  75. parrot/loaders/videolocal.py +120 -0
  76. parrot/loaders/vimeo.py +106 -0
  77. parrot/loaders/web.py +216 -0
  78. parrot/loaders/web_base.py +112 -0
  79. parrot/loaders/word.py +125 -0
  80. parrot/loaders/youtube.py +192 -0
  81. parrot/manager.py +166 -0
  82. parrot/models.py +372 -0
  83. parrot/py.typed +0 -0
  84. parrot/stores/__init__.py +48 -0
  85. parrot/stores/abstract.py +171 -0
  86. parrot/stores/milvus.py +632 -0
  87. parrot/stores/qdrant.py +153 -0
  88. parrot/tools/__init__.py +12 -0
  89. parrot/tools/abstract.py +53 -0
  90. parrot/tools/asknews.py +32 -0
  91. parrot/tools/bing.py +13 -0
  92. parrot/tools/duck.py +62 -0
  93. parrot/tools/google.py +170 -0
  94. parrot/tools/stack.py +26 -0
  95. parrot/tools/weather.py +70 -0
  96. parrot/tools/wikipedia.py +59 -0
  97. parrot/tools/zipcode.py +179 -0
  98. parrot/utils/__init__.py +2 -0
  99. parrot/utils/parsers/__init__.py +5 -0
  100. parrot/utils/parsers/toml.cpython-311-x86_64-linux-gnu.so +0 -0
  101. parrot/utils/toml.py +11 -0
  102. parrot/utils/types.cpython-311-x86_64-linux-gnu.so +0 -0
  103. parrot/utils/uv.py +11 -0
  104. parrot/version.py +10 -0
  105. resources/users/__init__.py +5 -0
  106. resources/users/handlers.py +13 -0
  107. resources/users/models.py +205 -0
  108. settings/__init__.py +0 -0
  109. settings/settings.py +51 -0
@@ -0,0 +1,728 @@
1
+ """
2
+ Foundational base of every Chatbot and Agent in ai-parrot.
3
+ """
4
+ from abc import ABC
5
+ from collections.abc import Callable
6
+ from typing import Any, Union
7
+ from pathlib import Path, PurePath
8
+ import uuid
9
+ from aiohttp import web
10
+ import torch
11
+ from transformers import (
12
+ AutoModel,
13
+ AutoConfig,
14
+ AutoTokenizer,
15
+ )
16
+ # Langchain
17
+ from langchain.docstore.document import Document
18
+ from langchain.memory import (
19
+ ConversationBufferMemory
20
+ )
21
+ from langchain.text_splitter import (
22
+ RecursiveCharacterTextSplitter
23
+ )
24
+ from langchain_community.chat_message_histories import (
25
+ RedisChatMessageHistory
26
+ )
27
+ # Navconfig
28
+ from navconfig import BASE_DIR
29
+ from navconfig.exceptions import ConfigError # pylint: disable=E0611
30
+ from navconfig.logging import logging
31
+ from asyncdb.exceptions import NoDataFound
32
+
33
+ ## LLM configuration
34
+ from ..llms import get_llm, AbstractLLM
35
+
36
+ ## Vector Database configuration:
37
+ from ..stores import get_vectordb
38
+
39
+ from ..utils import SafeDict, parse_toml_config
40
+ from .retrievals import RetrievalManager
41
+ from ..conf import (
42
+ EMBEDDING_DEVICE,
43
+ MAX_VRAM_AVAILABLE,
44
+ RAM_AVAILABLE,
45
+ default_dsn,
46
+ REDIS_HISTORY_URL,
47
+ EMBEDDING_DEFAULT_MODEL
48
+ )
49
+ from ..interfaces import DBInterface
50
+ from ..models import ChatbotModel
51
+
52
+
53
+ class AbstractChatbot(ABC, DBInterface):
54
+ """Represents an Chatbot in Navigator.
55
+
56
+ Each Chatbot has a name, a role, a goal, a backstory,
57
+ and an optional language model (llm).
58
+ """
59
+
60
+ template_prompt: str = (
61
+ "You are {name}, an expert AI assistant and {role} Working at {company}.\n\n"
62
+ "Your primary function is to {goal}\n"
63
+ "Use the provided context of the documents you have processed or extracted from other provided tools or sources to provide informative, detailed and accurate responses.\n"
64
+ "I am here to help with {role}.\n"
65
+ "**Backstory:**\n"
66
+ "{backstory}.\n\n"
67
+ "Focus on answering the question directly but detailed. Do not include an introduction or greeting in your response.\n\n"
68
+ "{company_information}\n\n"
69
+ "Here is a brief summary of relevant information:\n"
70
+ "Context: {context}\n\n"
71
+ "Given this information, please provide answers to the following question adding detailed and useful insights:\n\n"
72
+ "**Chat History:** {chat_history}\n\n"
73
+ "**Human Question:** {question}\n"
74
+ "Assistant Answer:\n\n"
75
+ "{rationale}\n"
76
+ "You are a fluent speaker, you can talk and respond fluently in English and Spanish, and you must answer in the same language as the user's question. If the user's language is not English, you should translate your response into their language.\n"
77
+ )
78
+
79
+ def _get_default_attr(self, key, default: Any = None, **kwargs):
80
+ if key in kwargs:
81
+ return kwargs.get(key)
82
+ if hasattr(self, key):
83
+ return getattr(self, key)
84
+ if not hasattr(self, key):
85
+ return default
86
+ return getattr(self, key)
87
+
88
def __init__(self, **kwargs):
    """Initialize the Chatbot with the given configuration.

    Every setting is optional and read from ``kwargs``. The identity
    fields (``name``, ``description``, ``role``, ``goal``, ``backstory``
    and ``rationale``) fall back to an attribute already present on the
    instance and then to a built-in default.

    Keyword Args:
        chatbot_id: Identifier of the bot; a random UUID hex string is
            generated when omitted.
        config_file: Path of the configuration file, if any.
        threshold: Confidence threshold (default ``0.5``).
        context: Initial context string (default ``''``).
        company_information: Mapping with company details.
        pre_instructions: List of pre-instructions.
        documents_dir: Directory for text documents; defaults to
            ``BASE_DIR/documents`` and is created when it does not exist.
        chunk_size, dimension: Vector settings (default ``768`` each).
        database: Vector database configuration mapping.
        use_bge, use_fastembed: Embedding flags (default ``False``);
            booleans or strings such as ``"true"`` / ``"false"``.
        embedding_model_name, embeddings, tokenizer, summarization_model,
            rag_model, text_splitter: Model names / objects.
        llm: Pre-built LLM object overriding the configured one.
        max_vram: Upper bound used when estimating available VRAM.
    """
    def _as_bool(value: Any) -> bool:
        # bool('False') is True, so textual flags need explicit parsing.
        if isinstance(value, str):
            return value.strip().lower() in ('1', 'true', 't', 'yes', 'y', 'on')
        return bool(value)

    # Start initialization:
    self.kb = None
    # Chatbot ID:
    self.chatbot_id: Union[str, uuid.UUID] = kwargs.get(
        'chatbot_id',
        str(uuid.uuid4().hex)
    )
    # Basic Information:
    self.name = self._get_default_attr(
        'name', 'NAV', **kwargs
    )
    ## Logging:
    self.logger = logging.getLogger(f'{self.name}.Chatbot')
    self.description = self._get_default_attr(
        'description', 'Navigator Chatbot', **kwargs
    )
    self.role = self._get_default_attr(
        'role', 'Chatbot', **kwargs
    )
    self.goal = self._get_default_attr(
        'goal',
        'provide helpful information to users',
        **kwargs
    )
    self.backstory = self._get_default_attr(
        'backstory',
        default=self.default_backstory(),
        **kwargs
    )
    self.rationale = self._get_default_attr(
        'rationale',
        default=self.default_rationale(),
        **kwargs
    )
    # Configuration File:
    self.config_file: PurePath = kwargs.get('config_file', None)
    # Other Configuration
    self.confidence_threshold: float = kwargs.get('threshold', 0.5)
    self.context = kwargs.pop('context', '')

    # Company Information:
    self.company_information: dict = kwargs.pop('company_information', {})

    # Pre-Instructions:
    self.pre_instructions: list = kwargs.get(
        'pre_instructions',
        []
    )

    # Knowledge base:
    self.knowledge_base: list = []
    self._documents_: list = []

    # Text Documents
    self.documents_dir = kwargs.get(
        'documents_dir',
        None
    )
    if isinstance(self.documents_dir, str):
        self.documents_dir = Path(self.documents_dir)
    if not self.documents_dir:
        self.documents_dir = BASE_DIR.joinpath('documents')
    if not self.documents_dir.exists():
        self.documents_dir.mkdir(
            parents=True,
            exist_ok=True
        )
    # Models, Embed and collections
    # Vector information:
    self.chunk_size: int = int(kwargs.get('chunk_size', 768))
    self.dimension: int = int(kwargs.get('dimension', 768))
    self._database: dict = kwargs.get('database', {})
    self._store: Callable = None
    # Embedding Model Name
    # Both flags default to False; the previous 'False' string default
    # evaluated as True under bool().
    self.use_bge: bool = _as_bool(
        kwargs.get('use_bge', False)
    )
    self.use_fastembed: bool = _as_bool(
        kwargs.get('use_fastembed', False)
    )
    self.embedding_model_name = kwargs.get(
        'embedding_model_name', None
    )
    # embedding object:
    self.embeddings = kwargs.get('embeddings', None)
    self.tokenizer_model_name = kwargs.get(
        'tokenizer', None
    )
    self.summarization_model = kwargs.get(
        'summarization_model',
        "facebook/bart-large-cnn"
    )
    self.rag_model = kwargs.get(
        'rag_model',
        "rlm/rag-prompt-llama"
    )
    self._text_splitter_model = kwargs.get(
        'text_splitter',
        'mixedbread-ai/mxbai-embed-large-v1'
    )
    # Definition of LLM
    # Overriding LLM object
    self._llm_obj: Callable = kwargs.get('llm', None)
    # LLM base Object:
    self._llm: Callable = None

    # Max VRAM usage:
    self._max_vram = int(
        kwargs.get('max_vram', MAX_VRAM_AVAILABLE)
    )
201
+
202
def get_llm(self):
    """Return the LLM object stored in ``_llm_obj``."""
    llm_object = self._llm_obj
    return llm_object
204
+
205
+ def __repr__(self):
206
+ return f"<Chatbot.{self.__class__.__name__}:{self.name}>"
207
+
208
+ # Database:
209
@property
def store(self):
    """Vector store handle; ``connect()`` is called first when it reports not connected."""
    backend = self._store
    if backend.connected:
        return backend
    backend.connect()
    return backend
214
+
215
def default_rationale(self) -> str:
    """Return the built-in rationale text used when none is configured.

    Sentences are separated by newlines so they do not run together when
    the text is inserted into the prompt template.
    """
    # TODO: read rationale from a file
    return (
        "I am a language model trained by Google.\n"
        "I am designed to provide helpful information to users.\n"
        "Remember to maintain a professional tone.\n"
        "If I cannot find relevant information in the documents, "
        "I will indicate this and suggest alternative avenues for the user to find an answer."
    )
224
+
225
def default_backstory(self) -> str:
    """Return the built-in backstory text used when none is configured."""
    paragraphs = (
        "help with Human Resources related queries or knowledge-based questions about T-ROC Global.",
        "You can ask me about the company's products and services, the company's culture, the company's clients.",
        (
            "You have the capability to read and understand various Human Resources documents, "
            "such as employee handbooks, policy documents, onboarding materials, company's website, and more."
        ),
        "I can also provide information about the company's policies and procedures, benefits, and other HR-related topics.",
    )
    return "\n".join(paragraphs)
233
+
234
async def configure(self, app = None) -> None:
    """Configure the bot from the database or, failing that, from its TOML file.

    Args:
        app: ``None``, an aiohttp ``web.Application``, or an object exposing
            ``get_app()`` that returns the application.

    Raises:
        ValueError: If the bot is neither found in the database nor has a
            ``etc/config/chatbots/<name>/config.toml`` file under ``BASE_DIR``.
    """
    if app is None:
        resolved_app = None
    elif isinstance(app, web.Application):
        # register the app into the Extension
        resolved_app = app
    else:
        # Nav Application
        resolved_app = app.get_app()
    self.app = resolved_app

    # Config File:
    config_file = BASE_DIR.joinpath(
        'etc', 'config', 'chatbots', self.name.lower(), "config.toml"
    )
    if config_file.exists():
        self.logger.notice(
            f"Loading Bot {self.name} from config: {config_file.name}"
        )

    bot = await self.bot_exists(name=self.name, uuid=self.chatbot_id)
    if bot:
        self.logger.notice(
            f"Loading Bot {self.name} from Database: {bot.chatbot_id}"
        )
        # Bot exists on Database, Configure from the Database
        await self.from_database(bot, config_file)
    elif config_file.exists():
        # Configure from the TOML file
        await self.from_config_file(config_file)
    else:
        raise ValueError(
            f'Bad configuration procedure for bot {self.name}'
        )

    # adding this configured chatbot to app:
    if self.app:
        registry_key = f"{self.name.lower()}_chatbot"
        self.app[registry_key] = self
270
+
271
def _configure_llm(self, llm, config):
    """Resolve ``self._llm`` from a preset object or from the given LLM name.

    A preset ``_llm_obj`` wins: an ``AbstractLLM`` is unwrapped through
    ``get_llm()``, any other non-None object is used as is. Otherwise the
    LLM is built with ``get_llm(llm, **config)``.

    Raises:
        ValueError: If there is no preset object and ``llm`` is falsy.
    """
    preset = self._llm_obj
    if isinstance(preset, AbstractLLM):
        self._llm = preset.get_llm()
        return
    if preset is not None:
        self._llm = preset
        return
    if not llm:
        raise ValueError(
            f"{self.name}: LLM is not defined in bot Configuration."
        )
    # Build the wrapper, then get the langchain LLM from it:
    self._llm_obj = get_llm(llm, **config)
    self._llm = self._llm_obj.get_llm()
292
+
293
+ def _from_bot(self, bot, key, config, default) -> Any:
294
+ value = getattr(bot, key, None)
295
+ file_value = config.get(key, default)
296
+ return value if value else file_value
297
+
298
+ def _from_db(self, botobj, key, default = None) -> Any:
299
+ value = getattr(botobj, key, default)
300
+ return value if value else default
301
+
302
async def bot_exists(
    self,
    name: str = None,
    uuid: uuid.UUID = None
) -> Union[ChatbotModel, bool]:
    """Check if the Chatbot exists in the Database.

    The bot is looked up by id first and, if that lookup fails for any
    reason, by name.

    Args:
        name: Bot name to search for; defaults to ``self.name``.
        uuid: Bot id to search for; defaults to ``self.chatbot_id``.

    Returns:
        The ``ChatbotModel`` row when found, otherwise ``False``.
    """
    # Fall back to the instance values so the arguments used for the
    # decision and for the query are always the same ones.
    lookup_id = uuid if uuid is not None else self.chatbot_id
    lookup_name = name if name is not None else self.name
    db = self.get_database('pg', dsn=default_dsn)
    async with await db.connection() as conn:  # pylint: disable=E1101
        ChatbotModel.Meta.connection = conn
        try:
            if lookup_id:
                try:
                    bot = await ChatbotModel.get(chatbot_id=lookup_id)
                except Exception:  # pylint: disable=W0703
                    # Best effort: any failure of the id lookup falls
                    # back to a lookup by name.
                    bot = await ChatbotModel.get(name=lookup_name)
            else:
                bot = await ChatbotModel.get(name=lookup_name)
        except NoDataFound:
            return False
        # Always honour the declared return type (no implicit None).
        return bot if bot else False
323
+
324
async def from_database(
    self,
    bot: Union[ChatbotModel, None] = None,
    config_file: PurePath = None
) -> None:
    """Load the Chatbot Configuration from the Database.

    Args:
        bot: Already-fetched ``ChatbotModel`` row. When falsy, the row is
            looked up by ``self.chatbot_id`` and, failing that, by
            ``self.name``.
        config_file: Optional TOML file; only its ``knowledge-base``
            section is read here.

    Raises:
        ConfigError: If the bot has to be fetched and is not found.
        KeyError: If the bot's database settings have no
            ``vector_database`` entry.
    """
    if not bot:
        db = self.get_database('pg', dsn=default_dsn)
        async with await db.connection() as conn:  # pylint: disable=E1101
            # import model
            ChatbotModel.Meta.connection = conn
            try:
                if self.chatbot_id:
                    try:
                        bot = await ChatbotModel.get(chatbot_id=self.chatbot_id)
                    except Exception:  # pylint: disable=W0703
                        # Best effort: fall back to a lookup by name.
                        bot = await ChatbotModel.get(name=self.name)
                else:
                    bot = await ChatbotModel.get(name=self.name)
            except NoDataFound as err:
                raise ConfigError(
                    f"Chatbot {self.name} not found in the database."
                ) from err
    # Start Bot configuration from Database:
    if config_file and config_file.exists():
        file_config = await parse_toml_config(config_file)
        # Knowledge Base come from file:
        # Contextual knowledge-base
        self.kb = file_config.get('knowledge-base', [])
        if self.kb:
            self.knowledge_base = self.create_kb(
                self.kb.get('data', [])
            )
    self.name = self._from_db(bot, 'name', default=self.name)
    self.chatbot_id = str(self._from_db(bot, 'chatbot_id', default=self.chatbot_id))
    self.description = self._from_db(bot, 'description', default=self.description)
    self.role = self._from_db(bot, 'role', default=self.role)
    self.goal = self._from_db(bot, 'goal', default=self.goal)
    self.rationale = self._from_db(bot, 'rationale', default=self.rationale)
    self.backstory = self._from_db(bot, 'backstory', default=self.backstory)
    # company information:
    self.company_information = self._from_db(
        bot, 'company_information', default=self.company_information
    )
    # LLM Configuration:
    llm = self._from_db(bot, 'llm', default='VertexLLM')
    llm_config = self._from_db(bot, 'llm_config', default={})
    # Configuration of LLM:
    self._configure_llm(llm, llm_config)
    # Other models:
    self.embedding_model_name = self._from_db(
        bot, 'embedding_name', None
    )
    self.tokenizer_model_name = self._from_db(
        bot, 'tokenizer', None
    )
    self.summarization_model = self._from_db(
        bot, 'summarize_model', "facebook/bart-large-cnn"
    )
    self.classification_model = self._from_db(
        bot, 'classification_model', None
    )
    # Database Configuration:
    # Work on a copy: popping from bot.database itself would strip the
    # 'vector_database' key from the model and break a second call.
    db_config = dict(bot.database)
    vector_db = db_config.pop('vector_database')
    await self.store_configuration(vector_db, db_config)
    # after configuration, setup the chatbot
    if bot.template_prompt:
        self.template_prompt = bot.template_prompt
    self._define_prompt(
        config={}
    )
397
+
398
async def from_config_file(self, config_file: PurePath) -> None:
    """Load the Chatbot Configuration from the TOML file.

    Reads the ``chatbot``, ``llm``, ``models``, ``pre-instructions``,
    ``knowledge-base`` and ``database`` sections of the parsed file.
    A missing ``llm`` section falls back to ``'VertexLLM'`` with an empty
    configuration.

    Args:
        config_file: Path of the TOML configuration file.

    Raises:
        KeyError: If the ``database`` section has no ``vector_database``
            entry.
    """
    self.logger.debug(
        f"Using Config File: {config_file}"
    )
    file_config = await parse_toml_config(config_file)
    # getting the configuration from config
    self.config_file = config_file
    # basic config
    basic = file_config.get('chatbot', {})
    # Chatbot Name:
    self.name = basic.get('name', self.name)
    self.description = basic.get('description', self.description)
    self.role = basic.get('role', self.role)
    self.goal = basic.get('goal', self.goal)
    self.rationale = basic.get('rationale', self.rationale)
    self.backstory = basic.get('backstory', self.backstory)
    # Company Information:
    self.company_information = basic.get(
        'company_information',
        self.company_information
    )
    # Model Information:
    # `or {}` keeps the defaults below reachable when the file has no
    # [llm] section (previously an AttributeError on None).
    llminfo = file_config.get('llm') or {}
    llm = llminfo.get('llm', 'VertexLLM')
    cfg = llminfo.get('config', {})
    # Configuration of LLM:
    self._configure_llm(llm, cfg)

    # Other models:
    models = file_config.get('models', {})
    if not self.embedding_model_name:
        self.embedding_model_name = models.get(
            'embedding', EMBEDDING_DEFAULT_MODEL
        )
    if not self.tokenizer_model_name:
        self.tokenizer_model_name = models.get('tokenizer')
    if not self.embedding_model_name:
        # Getting the Embedding Model from the LLM
        self.embeddings = self._llm_obj.get_embedding()
    self.use_bge = models.get('use_bge', False)
    self.use_fastembed = models.get('use_fastembed', False)
    self.summarization_model = models.get(
        'summarize_model',
        "facebook/bart-large-cnn"
    )
    self.classification_model = models.get(
        'classification_model',
        None
    )
    # pre-instructions
    instructions = file_config.get('pre-instructions')
    if instructions:
        self.pre_instructions = instructions.get('instructions', [])
    # Contextual knowledge-base
    self.kb = file_config.get('knowledge-base', [])
    if self.kb:
        self.knowledge_base = self.create_kb(
            self.kb.get('data', [])
        )
    vector_config = file_config.get('database', {})
    vector_db = vector_config.pop('vector_database')
    # configure vector database:
    await self.store_configuration(
        vector_db,
        vector_config
    )
    # after configuration, setup the chatbot
    if 'template_prompt' in basic:
        self.template_prompt = basic.get('template_prompt')
    self._define_prompt(
        config=basic
    )
471
+
472
def create_kb(self, documents: list):
    """Build ``Document`` objects from knowledge-base entries.

    Each entry is a mapping with a ``content`` key, an optional ``source``
    key (default ``'knowledge-base'``) and any number of extra keys that
    are copied into the document metadata. Entries with missing or empty
    content are skipped. The input mappings are not modified.

    Args:
        documents: List of knowledge-base entries.

    Returns:
        List of ``Document`` objects, one per entry with content.
    """
    reserved = ('content', 'source')
    new_docs = []
    for entry in documents:
        content = entry.get('content')
        if not content:
            continue
        meta = {
            'source': entry.get('source', 'knowledge-base'),
            **{k: v for k, v in entry.items() if k not in reserved},
        }
        new_docs.append(
            Document(
                page_content=content,
                metadata=meta
            )
        )
    return new_docs
492
+
493
async def store_configuration(self, vector_db: str, config: dict):
    """Create the Vector Store Configuration.

    Stores ``config['collection_name']`` on the instance and builds
    ``self._store`` through ``get_vectordb``. The embedding object is
    passed when one is set; otherwise the embedding model name is passed.
    """
    self.collection_name = config.get('collection_name')
    embedding_source = self.embeddings or self.embedding_model_name
    # TODO: add dynamic configuration of VectorStore
    store_options = {
        'embeddings': embedding_source,
        'use_bge': self.use_bge,
        'use_fastembed': self.use_fastembed,
    }
    self._store = get_vectordb(vector_db, **store_options, **config)
508
+
509
def _define_prompt(self, config: dict):
    """Fill the prompt template with the bot's identity and company details.

    Every item of ``config`` is first set as an attribute on the instance.
    When company information is available, the ``{company_information}``
    placeholder is replaced by a contact block; afterwards the remaining
    placeholders are formatted through ``SafeDict``.
    """
    # setup the prompt variables:
    for attr_name, attr_value in config.items():
        setattr(self, attr_name, attr_value)

    if self.company_information:
        contact_block = (
            "For further inquiries or detailed information, you can contact us at:\n"
            "- Contact Information: {contact_email}\n"
            "- Use our contact form: {contact_form}\n"
            "- or Visit our website: {company_website}\n"
        )
        self.template_prompt = self.template_prompt.format_map(
            SafeDict(company_information=contact_block)
        )

    # Parsing the Template:
    placeholders = dict(
        name=self.name,
        role=self.role,
        goal=self.goal,
        backstory=self.backstory,
        rationale=self.rationale,
        threshold=self.confidence_threshold,
        **self.company_information
    )
    self.template_prompt = self.template_prompt.format_map(
        SafeDict(**placeholders)
    )
537
+
538
@property
def llm(self):
    """The LLM instance currently stored in ``_llm``."""
    current = self._llm
    return current

@llm.setter
def llm(self, model):
    """Store ``model`` as the LLM wrapper and refresh ``_llm`` from its ``get_llm()``."""
    self._llm_obj = model
    derived = model.get_llm()
    self._llm = derived
546
+
547
def _get_device(self, cuda_number: int = 0):
    """Pick the torch device: CUDA first, then MPS, then ``EMBEDDING_DEVICE``.

    Also switches ``torch.backends.cudnn.deterministic`` on.

    Args:
        cuda_number: Index of the CUDA device to select.

    Returns:
        The selected ``torch.device``.
    """
    torch.backends.cudnn.deterministic = True
    cuda_name = f'cuda:{cuda_number}'
    if torch.cuda.is_available():
        # Use CUDA GPU if available
        return torch.device(cuda_name)
    if torch.backends.mps.is_available():
        # Use the MPS (Apple Metal) backend if available
        return torch.device("mps")
    if EMBEDDING_DEVICE == 'cuda':
        return torch.device(cuda_name)
    return torch.device(EMBEDDING_DEVICE)
560
+
561
def get_tokenizer(self, model_name: str, chunk_size: int = 768):
    """Load the tokenizer for ``model_name``, forwarding ``chunk_size`` as a keyword."""
    tokenizer_options = {'chunk_size': chunk_size}
    return AutoTokenizer.from_pretrained(model_name, **tokenizer_options)
566
+
567
def get_model(self, model_name: str):
    """Load ``model_name`` with its config and move it to the selected device.

    The loaded configuration is kept in ``self._model_config``.
    """
    target_device = self._get_device()
    # NOTE(review): trust_remote_code=True lets the model repository run
    # its own code; only use with trusted model names.
    self._model_config = AutoConfig.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        config=self._model_config,
        unpad_inputs=True,
        use_memory_efficient_attention=True,
    )
    return model.to(target_device)
579
+
580
def get_text_splitter(self, model, chunk_size: int = 1024, overlap: int = 100):
    """Build a ``RecursiveCharacterTextSplitter`` from a HuggingFace tokenizer.

    Args:
        model: Tokenizer handed to ``from_huggingface_tokenizer``.
        chunk_size: Value passed as ``chunk_size``.
        overlap: Value passed as ``chunk_overlap``.
    """
    splitter_options = {
        'chunk_size': chunk_size,
        'chunk_overlap': overlap,
        # If `True`, includes chunk's start index in metadata
        'add_start_index': True,
        # strips whitespace from the start and end
        'strip_whitespace': True,
        'separators': ["\n\n", "\n", "\r\n", "\r", "\f", "\v", "\x0b", "\x0c"],
    }
    return RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        model,
        **splitter_options
    )
589
+
590
def chunk_documents(self, documents, chunk_size):
    """Yield consecutive slices of ``documents`` holding up to ``chunk_size`` items."""
    yield from (
        documents[start:start + chunk_size]
        for start in range(0, len(documents), chunk_size)
    )
594
+
595
def get_available_vram(self):
    """
    Returns available VRAM in megabytes.

    The value is the total memory of CUDA device 0 minus the memory
    reserved by torch, capped by ``self._max_vram``. When CUDA is not
    available, or querying it fails with ``RuntimeError``, the result is
    ``min(RAM_AVAILABLE, self._max_vram)`` instead.
    """
    if not torch.cuda.is_available():
        # Without this guard a CPU-only torch build can raise
        # AssertionError from the torch.cuda calls below, which the
        # RuntimeError handler would not catch.
        return min(RAM_AVAILABLE, self._max_vram)
    try:
        # Clear any unused memory to get a fresher estimate
        torch.cuda.empty_cache()
        # Convert to MB
        bytes_per_mb = 1024 ** 2
        total_memory = torch.cuda.get_device_properties(0).total_memory / bytes_per_mb
        reserved_memory = torch.cuda.memory_reserved(0) / bytes_per_mb
        available_memory = total_memory - reserved_memory
        self.logger.notice(f'Available VRAM : {available_memory}')
        # Limit by predefined max usage
        return min(available_memory, self._max_vram)
    except RuntimeError:
        # Limit by predefined max usage
        return min(RAM_AVAILABLE, self._max_vram)
612
+
613
+ def _estimate_chunk_size(self):
614
+ """Estimate chunk size based on VRAM usage.
615
+ This is a simplistic heuristic and might need tuning based on empirical data
616
+ """
617
+ available_vram = self.get_available_vram()
618
+ estimated_vram_per_doc = 50 # Estimated VRAM in megabytes per document, adjust based on empirical observation
619
+ chunk_size = max(1, int(available_vram / estimated_vram_per_doc))
620
+ self.logger.notice(
621
+ f'Chunk size for Load Documents: {chunk_size}'
622
+ )
623
+ return chunk_size
624
+
625
+ ## Utility Loaders
626
+ ##
627
+
628
async def load_documents(
    self,
    documents: list,
    collection: str = None,
    delete: bool = False
):
    """Load raw documents into the vector store in chunks.

    Args:
        documents: Documents to load; an empty or ``None`` value is
            skipped with a warning. The list itself is not modified.
        collection: Target collection; defaults to ``self.collection_name``.
        delete: When ``True``, the collection is deleted and re-created
            from the first document before the remaining ones are loaded.
    """
    if not documents:
        self.logger.warning(
            "There is no documents to be loaded, skipping."
        )
        return
    self.logger.debug(
        f'Documents received: {len(documents)} ({type(documents)})'
    )

    self._documents_.extend(documents)
    if not collection:
        collection = self.collection_name

    self.logger.notice(f'Loading Documents: {len(documents)}')
    # Work on a copy so that taking the first document out (see below)
    # does not alter the caller's list.
    pending = list(documents)
    chunk_size = self._estimate_chunk_size()
    async with self._store.connection(alias='default') as store:
        # if delete is True, then delete the collection
        if delete is True:
            await store.delete_collection(collection)
            first_document = pending.pop(0)
            await store.create_collection(
                collection,
                first_document
            )
        for chunk in self.chunk_documents(pending, chunk_size):
            await store.load_documents(
                chunk,
                collection=collection
            )
665
+
666
def clean_history(self, session_id: str = None):
    """Clear the Redis-backed chat history of ``session_id``.

    Best effort: any failure is logged as an error and not re-raised.
    """
    try:
        RedisChatMessageHistory(
            url=REDIS_HISTORY_URL,
            session_id=session_id,
            ttl=60
        ).clear()
    except Exception as err:
        self.logger.error(
            f"Error clearing chat history: {err}"
        )
681
+
682
def get_memory(
    self,
    session_id: str = None,
    key: str = 'chat_history',
    input_key: str = 'question',
    output_key: str = 'answer',
    size: int = 5,
    ttl: int = 86400
):
    """Build a ``ConversationBufferMemory`` for a conversation.

    Args:
        session_id: When given, the memory is backed by a
            ``RedisChatMessageHistory`` for that session.
        key: Passed as ``memory_key``.
        input_key: Passed as ``input_key``.
        output_key: Passed as ``output_key``.
        size: Passed as ``max_len``.
        ttl: Passed as ``ttl`` to the Redis history.
    """
    memory_options = dict(
        memory_key=key,
        input_key=input_key,
        output_key=output_key,
        return_messages=True,
        max_len=size,
    )
    if session_id:
        memory_options['chat_memory'] = RedisChatMessageHistory(
            url=REDIS_HISTORY_URL,
            session_id=session_id,
            ttl=ttl
        )
    return ConversationBufferMemory(**memory_options)
708
+
709
def get_retrieval(self, source_path: str = 'web', request: web.Request = None):
    """Create a ``RetrievalManager`` for this bot.

    The pre-instructions are rendered as a bullet list and substituted for
    the ``summaries`` placeholder of the prompt template before the
    template is handed to the manager.
    """
    bullet_lines = [f"- {item}." for item in self.pre_instructions]
    pre_context = "\n".join(bullet_lines)
    custom_template = self.template_prompt.format_map(
        SafeDict(summaries=pre_context)
    )
    # Generate the Retrieval
    return RetrievalManager(
        chatbot_id=self.chatbot_id,
        chatbot_name=self.name,
        source_path=source_path,
        model=self._llm,
        store=self._store,
        memory=None,
        template=custom_template,
        kb=self.knowledge_base,
        request=request
    )