ai_parrot-0.8.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-parrot has been flagged as potentially problematic.
- ai_parrot-0.8.3.dist-info/LICENSE +21 -0
- ai_parrot-0.8.3.dist-info/METADATA +306 -0
- ai_parrot-0.8.3.dist-info/RECORD +128 -0
- ai_parrot-0.8.3.dist-info/WHEEL +6 -0
- ai_parrot-0.8.3.dist-info/top_level.txt +2 -0
- parrot/__init__.py +30 -0
- parrot/bots/__init__.py +5 -0
- parrot/bots/abstract.py +1115 -0
- parrot/bots/agent.py +492 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/bose.py +17 -0
- parrot/bots/chatbot.py +271 -0
- parrot/bots/cody.py +17 -0
- parrot/bots/copilot.py +117 -0
- parrot/bots/data.py +730 -0
- parrot/bots/dataframe.py +103 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/interfaces/__init__.py +1 -0
- parrot/bots/interfaces/retrievers.py +12 -0
- parrot/bots/notebook.py +619 -0
- parrot/bots/odoo.py +17 -0
- parrot/bots/prompts/__init__.py +41 -0
- parrot/bots/prompts/agents.py +91 -0
- parrot/bots/prompts/data.py +214 -0
- parrot/bots/retrievals/__init__.py +1 -0
- parrot/bots/retrievals/constitutional.py +19 -0
- parrot/bots/retrievals/multi.py +122 -0
- parrot/bots/retrievals/retrieval.py +610 -0
- parrot/bots/tools/__init__.py +7 -0
- parrot/bots/tools/eda.py +325 -0
- parrot/bots/tools/pdf.py +50 -0
- parrot/bots/tools/plot.py +48 -0
- parrot/bots/troc.py +16 -0
- parrot/conf.py +170 -0
- parrot/crew/__init__.py +3 -0
- parrot/crew/tools/__init__.py +22 -0
- parrot/crew/tools/bing.py +13 -0
- parrot/crew/tools/config.py +43 -0
- parrot/crew/tools/duckgo.py +62 -0
- parrot/crew/tools/file.py +24 -0
- parrot/crew/tools/google.py +168 -0
- parrot/crew/tools/gtrends.py +16 -0
- parrot/crew/tools/md2pdf.py +25 -0
- parrot/crew/tools/rag.py +42 -0
- parrot/crew/tools/search.py +32 -0
- parrot/crew/tools/url.py +21 -0
- parrot/exceptions.cpython-312-x86_64-linux-gnu.so +0 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agents.py +292 -0
- parrot/handlers/bots.py +196 -0
- parrot/handlers/chat.py +192 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/http.py +805 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +18 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/exif.py +709 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/llms/__init__.py +1 -0
- parrot/llms/abstract.py +69 -0
- parrot/llms/anthropic.py +58 -0
- parrot/llms/gemma.py +15 -0
- parrot/llms/google.py +44 -0
- parrot/llms/groq.py +67 -0
- parrot/llms/hf.py +45 -0
- parrot/llms/openai.py +61 -0
- parrot/llms/pipes.py +114 -0
- parrot/llms/vertex.py +89 -0
- parrot/loaders/__init__.py +9 -0
- parrot/loaders/abstract.py +628 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/txt.py +26 -0
- parrot/manager.py +333 -0
- parrot/models.py +504 -0
- parrot/py.typed +0 -0
- parrot/stores/__init__.py +11 -0
- parrot/stores/abstract.py +248 -0
- parrot/stores/chroma.py +188 -0
- parrot/stores/duck.py +162 -0
- parrot/stores/embeddings/__init__.py +10 -0
- parrot/stores/embeddings/abstract.py +46 -0
- parrot/stores/embeddings/base.py +52 -0
- parrot/stores/embeddings/bge.py +20 -0
- parrot/stores/embeddings/fastembed.py +17 -0
- parrot/stores/embeddings/google.py +18 -0
- parrot/stores/embeddings/huggingface.py +20 -0
- parrot/stores/embeddings/ollama.py +14 -0
- parrot/stores/embeddings/openai.py +26 -0
- parrot/stores/embeddings/transformers.py +21 -0
- parrot/stores/embeddings/vertexai.py +17 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss.py +160 -0
- parrot/stores/milvus.py +397 -0
- parrot/stores/postgres.py +653 -0
- parrot/stores/qdrant.py +170 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +68 -0
- parrot/tools/asknews.py +33 -0
- parrot/tools/basic.py +51 -0
- parrot/tools/bby.py +359 -0
- parrot/tools/bing.py +13 -0
- parrot/tools/docx.py +343 -0
- parrot/tools/duck.py +62 -0
- parrot/tools/execute.py +56 -0
- parrot/tools/gamma.py +28 -0
- parrot/tools/google.py +170 -0
- parrot/tools/gvoice.py +301 -0
- parrot/tools/results.py +278 -0
- parrot/tools/stack.py +27 -0
- parrot/tools/weather.py +70 -0
- parrot/tools/wikipedia.py +58 -0
- parrot/tools/zipcode.py +198 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.cpython-312-x86_64-linux-gnu.so +0 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpython-312-x86_64-linux-gnu.so +0 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- resources/users/__init__.py +5 -0
- resources/users/handlers.py +13 -0
- resources/users/models.py +205 -0
parrot/bots/abstract.py
ADDED
@@ -0,0 +1,1115 @@
"""
Abstract Bot interface.
"""
from abc import ABC
import importlib
from typing import Any, List, Union, Optional
from collections.abc import Callable
import os
import uuid
from string import Template
import asyncio
from aiohttp import web
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.memory import (
    ConversationBufferMemory,
    ConversationBufferWindowMemory
)
from langchain.prompts import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    PromptTemplate
)
from langchain.retrievers import (
    EnsembleRetriever,
)
from langchain.docstore.document import Document
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chains.conversational_retrieval.base import (
    ConversationalRetrievalChain
)
from langchain_community.chat_message_histories import (
    RedisChatMessageHistory
)
from langchain_community.retrievers import BM25Retriever
from pydantic_core._pydantic_core import ValidationError

import backoff  # for exponential backoff
from datamodel.exceptions import ValidationError as DataError  # pylint: disable=E0611
from navconfig.logging import logging
from navigator_auth.conf import AUTH_SESSION_OBJECT
from ..interfaces import DBInterface
from ..exceptions import ConfigError
from ..conf import (
    REDIS_HISTORY_URL,
    EMBEDDING_DEFAULT_MODEL
)

## LLM configuration
from ..llms import LLM_PRESETS, AbstractLLM

# Vertex
try:
    from ..llms.vertex import VertexLLM
    VERTEX_ENABLED = True
except (ModuleNotFoundError, ImportError):
    VERTEX_ENABLED = False

# Google
try:
    from ..llms.google import GoogleGenAI
    GOOGLE_ENABLED = True
except (ModuleNotFoundError, ImportError):
    GOOGLE_ENABLED = False

# Anthropic:
try:
    from ..llms.anthropic import AnthropicLLM
    ANTHROPIC_ENABLED = True
except (ModuleNotFoundError, ImportError):
    ANTHROPIC_ENABLED = False

# OpenAI
try:
    from ..llms.openai import OpenAILLM
    OPENAI_ENABLED = True
except (ModuleNotFoundError, ImportError):
    OPENAI_ENABLED = False

# Groq
try:
    from ..llms.groq import GroqLLM
    GROQ_ENABLED = True
except (ModuleNotFoundError, ImportError):
    GROQ_ENABLED = False

from ..utils import SafeDict
# Chat Response:
from ..models import ChatResponse
from .prompts import (
    BASIC_SYSTEM_PROMPT,
    BASIC_HUMAN_PROMPT,
    DEFAULT_GOAL,
    DEFAULT_ROLE,
    DEFAULT_CAPABILITIES,
    DEFAULT_BACKHISTORY
)
from .interfaces import EmptyRetriever
## Vector Stores:
from ..stores import AbstractStore, supported_stores, EmptyStore
from .retrievals import MultiVectorStoreRetriever


os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Hide TensorFlow logs if present

logging.getLogger(name='primp').setLevel(logging.INFO)
logging.getLogger(name='rquest').setLevel(logging.INFO)
logging.getLogger("grpc").setLevel(logging.CRITICAL)
logging.getLogger("tensorflow").setLevel(logging.CRITICAL)
logging.getLogger("transformers").setLevel(logging.CRITICAL)
logging.getLogger("pymilvus").setLevel(logging.INFO)


def predicate(exception):
    """Return True if we should retry, False otherwise."""
    return not isinstance(exception, (ValidationError, RuntimeError, DataError))

class AbstractBot(DBInterface, ABC):
    """AbstractBot.

    Base abstraction for all Chatbots.
    """
    # TODO: make tensor and embeddings optional.
    # Define system prompt template
    system_prompt_template = BASIC_SYSTEM_PROMPT

    # Define human prompt template
    human_prompt_template = BASIC_HUMAN_PROMPT

    def __init__(
        self,
        name: str = 'Nav',
        system_prompt: str = None,
        human_prompt: str = None,
        **kwargs
    ):
        """Initialize the Chatbot with the given configuration."""
        self._request: Optional[web.Request] = None
        if system_prompt:
            self.system_prompt_template = system_prompt or BASIC_SYSTEM_PROMPT
        if human_prompt:
            self.human_prompt_template = human_prompt or BASIC_HUMAN_PROMPT
        # Chatbot ID (a hex string, generated when not provided):
        self.chatbot_id: str = kwargs.get(
            'chatbot_id',
            str(uuid.uuid4().hex)
        )
        if self.chatbot_id is None:
            self.chatbot_id = str(uuid.uuid4().hex)
        # Basic Information:
        self.name: str = name
        ## Logging:
        self.logger = logging.getLogger(
            f'{self.name}.Bot'
        )
        # Optional aiohttp Application:
        self.app: Optional[web.Application] = None
        # Optional Redis Memory Saver:
        self.memory_saver: Optional[MemorySaver] = None
        # Start initialization:
        self.kb = None
        self.knowledge_base: list = []
        self.return_sources: bool = kwargs.pop('return_sources', False)
        self.description = self._get_default_attr(
            'description',
            'Navigator Chatbot',
            **kwargs
        )
        self.role = kwargs.get('role', DEFAULT_ROLE)
        self.goal = kwargs.get('goal', DEFAULT_GOAL)
        self.capabilities = kwargs.get('capabilities', DEFAULT_CAPABILITIES)
        self.backstory = kwargs.get('backstory', DEFAULT_BACKHISTORY)
        self.rationale = kwargs.get('rationale', self.default_rationale())
        self.context = kwargs.get('use_context', True)
        # Definition of LLM
        self._llm_class: str = None
        self._default_llm: str = kwargs.get('use_llm', 'vertexai')
        self._use_chat: bool = kwargs.get('use_chat', False)
        self._llm_model = kwargs.get('model_name', 'gemini-1.5-pro')
        self._llm_preset: str = kwargs.get('preset', None)
        if self._llm_preset:
            try:
                presetting = LLM_PRESETS[self._llm_preset]
            except KeyError:
                self.logger.warning(
                    f"Invalid preset: {self._llm_preset}, default to 'analytical'"
                )
                presetting = LLM_PRESETS['analytical']
            self._llm_temp = presetting.get('temperature', 0.2)
            self._max_tokens = presetting.get('max_tokens', 4096)
        else:
            # Default LLM presetting:
            self._llm_temp = kwargs.get('temperature', 0.2)
            self._max_tokens = kwargs.get('max_tokens', 4096)
        self._llm_top_k = kwargs.get('top_k', 41)
        self._llm_top_p = kwargs.get('top_p', 0.9)
        self._llm_config = kwargs.get('model_config', {})
        if self._llm_config:
            self._llm_model = self._llm_config.pop('model', self._llm_model)
            self._llm_class = self._llm_config.pop('name', None)
        # Overriding LLM object
        self._llm_obj: Callable = kwargs.get('llm', None)
        # LLM base Object:
        self._llm: Callable = None
        # NOTE: this overrides the 'use_context' value assigned above.
        self.context = kwargs.pop('context', '')

        # Pre-Instructions:
        self.pre_instructions: list = kwargs.get(
            'pre_instructions',
            []
        )

        # Knowledge base:
        self.knowledge_base: list = []
        self._documents_: list = []
        # Models, Embeddings and collections
        # Vector information:
        self._use_vector: bool = kwargs.get('use_vectorstore', False)
        self._vector_info_: dict = kwargs.get('vector_info', {})
        self._vector_store: dict = kwargs.get('vector_store', None)
        self.chunk_size: int = int(kwargs.get('chunk_size', 2048))
        self.dimension: int = int(kwargs.get('dimension', 768))
        self.store: Callable = None
        self.stores: List[AbstractStore] = []
        self.memory: Callable = None
        # Embedding Model Name
        self.embedding_model = kwargs.get(
            'embedding_model',
            {
                'model_name': EMBEDDING_DEFAULT_MODEL,
                'model_type': 'huggingface'
            }
        )
        # embedding object:
        self.embeddings = kwargs.get('embeddings', None)
        self.rag_model = kwargs.get(
            'rag_model',
            "rlm/rag-prompt-llama"
        )
        # Summarization and Classification Models
        # Bot Security and Permissions:
        _default = self.default_permissions()
        _permissions = kwargs.get('permissions', _default)
        if _permissions is None:
            _permissions = {}
        self._permissions = {**_default, **_permissions}

    def default_permissions(self) -> dict:
        """
        Returns the default permissions for the bot.

        This function defines and returns a dictionary containing the default
        permission settings for the bot. These permissions are used to control
        access and functionality of the bot across different organizational
        structures and user groups.

        Returns:
            dict: A dictionary containing the following keys, each with an empty list as its value:
                - "organizations": List of organizations the bot has access to.
                - "programs": List of programs the bot is allowed to interact with.
                - "job_codes": List of job codes the bot is authorized for.
                - "users": List of specific users granted access to the bot.
                - "groups": List of user groups with bot access permissions.
        """
        return {
            "organizations": [],
            "programs": [],
            "job_codes": [],
            "users": [],
            "groups": [],
        }
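
    # A minimal sketch of a custom permissions dict (values are hypothetical);
    # per retrieval() further below, the literal string "*" lifts the
    # restriction for that category, and omitted keys keep their defaults:
    #
    #     bot = MyBot(permissions={
    #         "users": ["jdoe"],   # only this username
    #         "groups": "*",       # any group
    #         "job_codes": [],     # no job-code based access
    #     })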

    def permissions(self):
        return self._permissions

    def get_supported_models(self) -> List[str]:
        return self._llm_obj.get_supported_models()

    def _get_default_attr(self, key, default: Any = None, **kwargs):
        if key in kwargs:
            return kwargs.get(key)
        if hasattr(self, key):
            return getattr(self, key)
        return default

    def __repr__(self):
        return f"<Bot.{self.__class__.__name__}:{self.name}>"

    def default_rationale(self) -> str:
        # TODO: read rationale from a file
        return (
            "** Your Style: **\n"
            "- When responding to user queries, ensure that you provide accurate and up-to-date information.\n"
            "- Be polite, clear and concise in your explanations.\n"
            "- Ensure that responses are based only on verified information from owned sources.\n"
            "- If you are unsure, let the user know and avoid making assumptions. Maintain a professional tone in all responses.\n"
            "- Use simple language for complex topics to ensure user understanding.\n"
            "- You are a fluent speaker; you can talk and respond fluently in English or Spanish, and you must answer in the same language as the user's question. If the user's language is not English, you should translate your response into their language.\n"
        )


    @property
    def llm(self):
        return self._llm

    @llm.setter
    def llm(self, model):
        self._llm_obj = model
        self._llm = model.get_llm()

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=3,
        max_time=60,
        giveup=lambda e: not predicate(e)  # Don't retry if predicate returns False
    )
    def llm_chain(
        self,
        llm: str = "vertexai",
        model: str = None,
        **kwargs
    ) -> AbstractLLM:
        """llm_chain.

        Args:
            llm (str): The name of the language model backend to use.
            model (str, optional): A specific model identifier; each backend
                falls back to a sensible default when omitted.

        Returns:
            AbstractLLM: The language model to use.
        """
        if llm == 'openai' and OPENAI_ENABLED:
            mdl = OpenAILLM(model=model or "gpt-4.1", **kwargs)
        elif llm in ('vertexai', 'VertexLLM') and VERTEX_ENABLED:
            mdl = VertexLLM(model=model or "gemini-1.5-pro", **kwargs)
        elif llm == 'anthropic' and ANTHROPIC_ENABLED:
            mdl = AnthropicLLM(model=model or 'claude-3-5-sonnet-20240620', **kwargs)
        elif llm in ('groq', 'Groq') and GROQ_ENABLED:
            mdl = GroqLLM(model=model or "meta-llama/llama-4-maverick-17b-128e-instruct", **kwargs)
        elif llm == 'llama3' and GROQ_ENABLED:
            mdl = GroqLLM(model=model or "llama-3.3-70b-versatile", **kwargs)
        elif llm == 'gemma' and GROQ_ENABLED:
            mdl = GroqLLM(model=model or "gemma2-9b-it", **kwargs)
        elif llm == 'mistral' and GROQ_ENABLED:
            mdl = GroqLLM(model=model or "mistral-saba-24b", **kwargs)
        elif llm == 'google' and GOOGLE_ENABLED:
            mdl = GoogleGenAI(model=model or "models/gemini-2.5-pro-preview-03-25", **kwargs)
        else:
            raise ValueError(f"Invalid llm: {llm}")
        # return the configured LLM wrapper:
        return mdl

    def configure_llm(
        self,
        llm: Union[str, Callable] = None,
        config: Optional[dict] = None,
        use_chat: bool = False,
        **kwargs
    ):
        """
        Configuration of LLM.
        """
        if isinstance(llm, str):
            # Get the LLM by name (guard against a None config):
            self._llm_obj = self.llm_chain(
                llm,
                **(config or {})
            )
            # getting langchain LLM from Obj:
            self._llm = self._llm_obj.get_llm()
        elif isinstance(llm, AbstractLLM):
            self._llm_obj = llm
            self._llm = llm.get_llm()
        elif callable(llm):
            self._llm = llm()
        elif isinstance(self._llm_obj, str):
            # is the name of the LLM object to be used:
            self._llm_obj = self.llm_chain(
                llm=self._llm_obj,
                **kwargs
            )
            self._llm = self._llm_obj.get_llm()
        elif isinstance(self._llm_obj, AbstractLLM):
            self._llm = self._llm_obj.get_llm()
        else:
            # Fall back to the default LLM with the stored configuration:
            try:
                self._llm_obj = self.llm_chain(
                    llm=self._default_llm,
                    model=self._llm_model,
                    temperature=self._llm_temp,
                    top_k=self._llm_top_k,
                    top_p=self._llm_top_p,
                    max_tokens=self._max_tokens,
                    use_chat=use_chat
                )
            except Exception as e:
                self.logger.error(
                    f"Error configuring Default LLM {self._llm_model}: {e}"
                )
                raise ConfigError(
                    f"Error configuring Default LLM {self._llm_model}: {e}"
                ) from e
            self._llm = self._llm_obj.get_llm()
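
    # A hedged usage sketch (the option values are illustrative): pick a
    # backend by name and pass sampling options through 'config', which is
    # forwarded as keyword arguments to llm_chain():
    #
    #     bot.configure_llm(llm="groq", config={"temperature": 0.1})
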
    def create_kb(self, documents: list):
        new_docs = []
        for doc in documents:
            content = doc.pop('content')
            source = doc.pop('source', 'knowledge-base')
            if doc:
                meta = {
                    'source': source,
                    **doc
                }
            else:
                meta = {'source': source}
            if content:
                new_docs.append(
                    Document(
                        page_content=content,
                        metadata=meta
                    )
                )
        return new_docs

    def safe_format_template(self, template, **kwargs):
        """
        Format a template string while preserving content inside triple backticks.
        """
        # Split the template by triple backticks
        parts = template.split("```")

        # Format only the even-indexed parts (outside triple backticks)
        for i in range(0, len(parts), 2):
            parts[i] = parts[i].format_map(SafeDict(**kwargs))

        # Rejoin with triple backticks
        return "```".join(parts)
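
    # A quick illustration of the behavior above (hypothetical input):
    # placeholders outside the fence are filled, the fenced body is untouched:
    #
    #     bot.safe_format_template("Hi {name}: ```{not_touched}```", name="Ana")
    #     # -> "Hi Ana: ```{not_touched}```"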

    def _define_prompt(self, config: Optional[dict] = None):
        """
        Define the System Prompt and replace variables.
        """
        # setup the prompt variables:
        if config:
            for key, val in config.items():
                setattr(self, key, val)

        pre_context = "\n".join(f"- {a}." for a in self.pre_instructions)
        context = "{context}"
        if self.context:
            context = """
            Here is a brief summary of relevant information:
            Context: {context}
            End of Context.

            Given this information, please provide answers to the following question adding detailed and useful insights.
            """
        tmpl = Template(self.system_prompt_template)
        final_prompt = tmpl.safe_substitute(
            name=self.name,
            role=self.role,
            goal=self.goal,
            capabilities=self.capabilities,
            backstory=self.backstory,
            rationale=self.rationale,
            pre_context=pre_context,
            context=context
        )
        self.logger.debug(f'Template Prompt: \n{final_prompt}')
        self.system_prompt_template = final_prompt

    async def configure(self, app=None) -> None:
        """Basic Configuration of Bot."""
        self.app = None
        if app:
            if isinstance(app, web.Application):
                self.app = app  # register the app into the Extension
            else:
                self.app = app.get_app()  # Nav Application
        # adding this configured chatbot to app:
        if self.app:
            self.app[f"{self.name.lower()}_bot"] = self
        # Configure LLM:
        try:
            self.configure_llm()
        except Exception as e:
            self.logger.error(
                f"Error configuring LLM: {e}"
            )
            raise
        # And define Prompt:
        try:
            self._define_prompt()
        except Exception as e:
            self.logger.error(
                f"Error defining prompt: {e}"
            )
            raise
        # Configure VectorStore if enabled:
        if self._use_vector:
            try:
                self.configure_store()
            except Exception as e:
                self.logger.error(
                    f"Error configuring VectorStore: {e}"
                )
                raise

    def _get_database_store(self, store: dict) -> AbstractStore:
        name = store.get('name', 'milvus')
        store_cls = supported_stores.get(name)
        cls_path = f"parrot.stores.{name}"
        try:
            module = importlib.import_module(cls_path, package=name)
            store_cls = getattr(module, store_cls)
            return store_cls(
                embedding_model=self.embedding_model,
                embedding=self.embeddings,
                **store
            )
        except (ModuleNotFoundError, ImportError) as e:
            self.logger.error(
                f"Error importing VectorStore: {e}"
            )
            raise

    def configure_store(self, **kwargs):
        if isinstance(self._vector_store, list):
            # A list of vector-store configurations:
            for st in self._vector_store:
                try:
                    store_cls = self._get_database_store(st)
                    store_cls.use_database = self._use_vector
                    self.stores.append(store_cls)
                except ImportError:
                    continue
        elif isinstance(self._vector_store, dict):
            # A single vector-store configuration:
            store_cls = self._get_database_store(self._vector_store)
            store_cls.use_database = self._use_vector
            self.stores.append(store_cls)
        else:
            raise ValueError(
                f"Invalid Vector Store Config: {self._vector_store}"
            )
        self.logger.info(
            f"Configured Vector Stores: {self.stores}"
        )
        if self.stores:
            self.store = self.stores[0]
        self.logger.debug(f"END STORES >> {self.stores}, {self.store}")

    def get_memory(
        self,
        session_id: str = None,
        key: str = 'chat_history',
        input_key: str = 'question',
        output_key: str = 'answer',
        size: int = 5,
        ttl: int = 86400
    ):
        args = {
            'memory_key': key,
            'input_key': input_key,
            'output_key': output_key,
            'return_messages': True,
            'max_len': size,
            'k': 10
        }
        if session_id:
            message_history = RedisChatMessageHistory(
                url=REDIS_HISTORY_URL,
                session_id=session_id,
                ttl=ttl
            )
            args['chat_memory'] = message_history
        return ConversationBufferWindowMemory(
            **args
        )
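
    # Minimal sketch of Redis-backed history (the session id is hypothetical);
    # without a session_id the returned memory is purely in-process:
    #
    #     memory = bot.get_memory(session_id="user-1234", ttl=3600)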

    def clean_history(
        self,
        session_id: str = None
    ):
        try:
            redis_client = RedisChatMessageHistory(
                url=REDIS_HISTORY_URL,
                session_id=session_id,
                ttl=60
            )
            redis_client.clear()
        except Exception as e:
            self.logger.error(
                f"Error clearing chat history: {e}"
            )

    def get_response(self, response: dict, query: str = None):
        if 'error' in response:
            return response  # return this error directly
        try:
            response = ChatResponse(**response)
            response.query = query
            response.response = self.as_markdown(
                response,
                return_sources=self.return_sources
            )
            return response
        except (ValueError, TypeError) as exc:
            self.logger.error(
                f"Error validating response: {exc}"
            )
            return response
        except DataError as exc:  # datamodel validation errors carry a payload
            self.logger.error(
                f"Error on response: {exc.payload}"
            )
            return response

    async def conversation(
        self,
        question: str,
        chain_type: str = 'stuff',
        search_type: str = 'similarity',
        search_kwargs: dict = {"k": 4, "fetch_k": 10, "lambda_mult": 0.89},
        return_docs: bool = True,
        metric_type: str = None,
        memory: Any = None,
        **kwargs
    ):
        # re-configure LLM:
        new_llm = kwargs.pop('llm', None)
        llm_config = kwargs.pop(
            'llm_config',
            {
                "temperature": 0.2,
                "top_k": 41,
                "top_p": 0.9
            }
        )
        if new_llm:
            self.configure_llm(llm=new_llm, config=llm_config)
        # define the Pre-Context
        pre_context = "\n".join(f"- {a}." for a in self.pre_instructions)
        custom_template = self.safe_format_template(
            self.system_prompt_template,
            summaries=pre_context
        )

        # Create prompt templates
        self.system_prompt = SystemMessagePromptTemplate.from_template(
            custom_template
        )
        self.human_prompt = HumanMessagePromptTemplate.from_template(
            self.human_prompt_template,
            input_variables=['question', 'chat_history']
        )
        # Combine into a ChatPromptTemplate
        chat_prompt = ChatPromptTemplate.from_messages([
            self.system_prompt,
            self.human_prompt
        ])
        if not memory:
            memory = self.memory
            if not self.memory:
                # Initialize memory
                self.memory = ConversationBufferMemory(
                    memory_key="chat_history",
                    input_key="question",
                    output_key='answer',
                    return_messages=True
                )
        try:
            if self._use_vector:
                async with self.store as store:  # pylint: disable=E1101
                    vector = store.get_vector(metric_type=metric_type)
                    retriever = VectorStoreRetriever(
                        vectorstore=vector,
                        search_type=search_type,
                        chain_type=chain_type,
                        search_kwargs=search_kwargs
                    )
                    # Create the ConversationalRetrievalChain with custom prompt
                    chain = ConversationalRetrievalChain.from_llm(
                        llm=self._llm,
                        retriever=retriever,
                        memory=self.memory,
                        chain_type=chain_type,  # e.g., 'stuff', 'map_reduce', etc.
                        verbose=True,
                        return_source_documents=return_docs,
                        return_generated_question=True,
                        combine_docs_chain_kwargs={
                            'prompt': chat_prompt
                        },
                        **kwargs
                    )
                    response = await chain.ainvoke(
                        {"question": question}
                    )
            else:
                retriever = EmptyRetriever()
                # Create the ConversationalRetrievalChain with custom prompt
                chain = ConversationalRetrievalChain.from_llm(
                    llm=self._llm,
                    retriever=retriever,
                    memory=self.memory,
                    chain_type=chain_type,  # e.g., 'stuff', 'map_reduce', etc.
                    verbose=True,
                    return_source_documents=return_docs,
                    return_generated_question=True,
                    combine_docs_chain_kwargs={
                        'prompt': chat_prompt
                    },
                    **kwargs
                )
                response = await chain.ainvoke(
                    {"question": question}
                )
        except asyncio.CancelledError:
            # Propagate task cancellation; 'response' would be unbound otherwise.
            self.logger.warning("Conversation task was cancelled.")
            raise
        except Exception as e:
            self.logger.error(
                f"Error in conversation: {e}"
            )
            raise
        return self.get_response(response, question)

    async def question(
        self,
        question: str,
        chain_type: str = 'stuff',
        search_type: str = 'similarity',
        search_kwargs: dict = {"k": 4, "fetch_k": 10, "lambda_mult": 0.89},
        return_docs: bool = True,
        metric_type: str = None,
        **kwargs
    ):
        system_prompt = self.system_prompt_template
        human_prompt = self.human_prompt_template.replace(
            '**Chat History:**', ''
        )
        human_prompt = human_prompt.format_map(
            SafeDict(
                chat_history=''
            )
        )
        # re-configure LLM:
        new_llm = kwargs.pop('llm', None)
        if new_llm:
            llm_config = kwargs.pop(
                'llm_config',
                {
                    "temperature": 0.2,
                    "top_k": 41,
                    "top_p": 0.9
                }
            )
            self.configure_llm(llm=new_llm, config=llm_config)
        # Combine into a single PromptTemplate
        prompt = PromptTemplate(
            template=system_prompt + '\n' + human_prompt,
            input_variables=['context', 'question']
        )
        try:
            if self._use_vector:
                async with self.store as store:  # pylint: disable=E1101
                    vector = store.get_vector(metric_type=metric_type)
                    retriever = VectorStoreRetriever(
                        vectorstore=vector,
                        search_type=search_type,
                        chain_type=chain_type,
                        search_kwargs=search_kwargs
                    )
                    chain = RetrievalQA.from_chain_type(
                        llm=self._llm,
                        chain_type=chain_type,  # e.g., 'stuff', 'map_reduce', etc.
                        retriever=retriever,
                        chain_type_kwargs={
                            'prompt': prompt,
                        },
                        return_source_documents=return_docs,
                        **kwargs
                    )
                    response = await chain.ainvoke(
                        question
                    )
            else:
                retriever = EmptyRetriever()
                # Create the RetrievalQA chain with custom prompt
                chain = RetrievalQA.from_chain_type(
                    llm=self._llm,
                    chain_type=chain_type,  # e.g., 'stuff', 'map_reduce', etc.
                    retriever=retriever,
                    chain_type_kwargs={
                        'prompt': prompt,
                    },
                    return_source_documents=return_docs,
                    **kwargs
                )
                response = await chain.ainvoke(
                    question
                )
        except (RuntimeError, asyncio.CancelledError):
            # fallback when the event loop is closed: invoke synchronously
            response = chain.invoke(
                question
            )
        except Exception as e:
            # Handle exceptions
            self.logger.error(
                f"An error occurred: {e}"
            )
            response = {
                "query": question,
                "error": str(e)
            }
        return self.get_response(response, question)

    def as_markdown(self, response: ChatResponse, return_sources: bool = False) -> str:
        markdown_output = f"**Question**: {response.question} \n"
        markdown_output += f"**Answer**: \n {response.answer} \n"
        if return_sources is True and response.source_documents:
            source_documents = response.source_documents
            current_sources = []
            block_sources = []
            count = 0
            d = {}
            for source in source_documents:
                if count >= 20:
                    break  # Exit loop after processing 20 documents
                metadata = source.metadata
                if 'url' in metadata:
                    src = metadata.get('url')
                elif 'filename' in metadata:
                    src = metadata.get('filename')
                else:
                    src = metadata.get('source', 'unknown')
                if src == 'knowledge-base':
                    continue  # avoid attaching kb documents
                source_title = metadata.get('title', src)
                if source_title in current_sources:
                    continue
                current_sources.append(source_title)
                count += 1  # only count unique, attachable sources
                if src:
                    d[src] = metadata.get('document_meta', {})
                source_filename = metadata.get('filename', src)
                if src:
                    block_sources.append(f"- [{source_title}]({src})")
                else:
                    if 'page_number' in metadata:
                        block_sources.append(
                            f"- {source_filename} (Page {metadata.get('page_number')})"
                        )
                    else:
                        block_sources.append(f"- {source_filename}")
            if block_sources:
                markdown_output += "**Sources**: \n"
                markdown_output += "\n".join(block_sources)
            if d:
                response.documents = d
        return markdown_output

    async def __aenter__(self):
        if not self.store:
            self.store = EmptyStore()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        pass

    def retrieval(self, request: web.Request = None) -> "AbstractBot":
        """
        Configure the retrieval chain for the Chatbot, returning `self` if allowed,
        or raise HTTPUnauthorized if not. A permissions dictionary can specify:
        * users
        * groups
        * job_codes
        * programs
        * organizations
        If a permission list is the literal string "*", it means "unrestricted" for that category.

        Args:
            request (web.Request, optional): The request object. Defaults to None.
        Returns:
            AbstractBot: The Chatbot object, or raise HTTPUnauthorized.
        """
        self._request = request
        session = request.session
        try:
            userinfo = session[AUTH_SESSION_OBJECT]
        except KeyError:
            userinfo = {}

        # decode your user from session
        try:
            user = session.decode("user")
        except (KeyError, TypeError):
            raise web.HTTPUnauthorized(
                reason="Invalid user"
            )

        # 1: superuser is always allowed
        if userinfo.get('superuser', False) is True:
            return self

        # convenience references
        users_allowed = self._permissions.get('users', [])
        groups_allowed = self._permissions.get('groups', [])
        job_codes_allowed = self._permissions.get('job_codes', [])
        programs_allowed = self._permissions.get('programs', [])
        orgs_allowed = self._permissions.get('organizations', [])

        # 2: check if 'users' == "*" or user.username in 'users'
        if users_allowed == "*":
            return self
        if user.get('username') in users_allowed:
            return self

        # 3: check job_code
        if job_codes_allowed == "*":
            return self
        try:
            if user.job_code in job_codes_allowed:
                return self
        except AttributeError:
            pass

        # 4: check groups
        # If groups_allowed == "*", no restriction on groups
        if groups_allowed == "*":
            return self
        # otherwise, see if there's an intersection
        user_groups = set(userinfo.get("groups", []))
        if not user_groups.isdisjoint(groups_allowed):
            return self

        # 5: check programs
        if programs_allowed == "*":
            return self
        try:
            user_programs = set(userinfo.get("programs", []))
            if not user_programs.isdisjoint(programs_allowed):
                return self
        except AttributeError:
            pass

        # 6: check organizations
        if orgs_allowed == "*":
            return self
        try:
            user_orgs = set(userinfo.get("organizations", []))
            if not user_orgs.isdisjoint(orgs_allowed):
                return self
        except AttributeError:
            pass

        # If none of the conditions pass, raise unauthorized:
        raise web.HTTPUnauthorized(
            reason=f"User {user.username} is not authorized"
        )

    async def shutdown(self, **kwargs) -> None:
        """
        Shutdown.

        Optional shutdown method to clean up resources.
        This method can be overridden in subclasses to perform any necessary cleanup tasks,
        such as closing database connections, releasing resources, etc.
        Args:
            **kwargs: Additional keyword arguments.
        """

    async def invoke(
        self,
        question: str,
        chain_type: str = 'stuff',
        search_type: str = 'similarity',
        search_kwargs: dict = {"k": 4, "fetch_k": 10, "lambda_mult": 0.89},
        return_docs: bool = True,
        metric_type: str = None,
        memory: Any = None,
        **kwargs
    ) -> ChatResponse:
        """Build a Chain to answer Questions using AI Models."""
        new_llm = kwargs.pop('llm', None)
        if new_llm is not None:
            # re-configure LLM:
            llm_config = kwargs.pop(
                'llm_config',
                {
                    "temperature": 0.2,
                    "top_k": 41,
                    "top_p": 0.9
                }
            )
            self.configure_llm(llm=new_llm, config=llm_config)
        # define the Pre-Context
        pre_context = "\n".join(f"- {a}." for a in self.pre_instructions)
        custom_template = self.system_prompt_template.format_map(
            SafeDict(
                summaries=pre_context
            )
        )
        # Create prompt templates
        self.system_prompt = SystemMessagePromptTemplate.from_template(
            custom_template
        )
        self.human_prompt = HumanMessagePromptTemplate.from_template(
            self.human_prompt_template,
            input_variables=['question', 'chat_history']
        )
        # Combine into a ChatPromptTemplate
        chat_prompt = ChatPromptTemplate.from_messages([
            self.system_prompt,
            self.human_prompt
        ])
        if not memory:
            memory = self.memory
            if not self.memory:
                # Initialize memory
                self.memory = ConversationBufferMemory(
                    memory_key="chat_history",
                    input_key="question",
                    output_key='answer',
                    return_messages=True
                )
        async with self.store as store:  # pylint: disable=E1101
            # Check if we have multiple stores:
            if self._use_vector:
                if len(self.stores) > 1:
                    # TODO: combine stores with MultiVectorStoreRetriever;
                    # for now only the first store is queried.
                    store = self.stores[0]
                vector = store.get_vector(metric_type=metric_type)
                retriever = VectorStoreRetriever(
                    vectorstore=vector,
                    search_type=search_type,
                    chain_type=chain_type,
                    search_kwargs=search_kwargs
                )
            else:
                retriever = EmptyRetriever()
            self.logger.debug(f'Retriever: {retriever}')
            # TODO: when self.kb is set, blend a BM25Retriever over the
            # knowledge base into an EnsembleRetriever (weights ~0.8/0.2).
            try:
                # Create the ConversationalRetrievalChain with custom prompt
                chain = ConversationalRetrievalChain.from_llm(
                    llm=self._llm,
                    retriever=retriever,
                    memory=self.memory,
                    chain_type=chain_type,  # e.g., 'stuff', 'map_reduce', etc.
                    verbose=True,
                    return_source_documents=return_docs,
                    return_generated_question=True,
                    combine_docs_chain_kwargs={
                        'prompt': chat_prompt
                    },
                    **kwargs
                )
                response = await chain.ainvoke(
                    {"question": question}
                )
                return self.get_response(response, question)
            except asyncio.CancelledError:
                # Propagate task cancellation instead of returning None.
                self.logger.warning("Conversation task was cancelled.")
                raise
            except Exception as e:
                self.logger.error(
                    f"Error in conversation: {e}"
                )
                raise
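
# A rough end-to-end sketch (subclass name and the question are hypothetical):
#
#     class MyBot(AbstractBot):
#         pass
#
#     bot = MyBot(name="Docs", use_llm="openai")
#     await bot.configure()
#     async with bot:  # ensures an EmptyStore when no vector store is set
#         answer = await bot.invoke("What does this package do?")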