mb-rag 1.1.47__py3-none-any.whl → 1.1.56.post0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mb-rag might be problematic.
- mb_rag/basic.py +306 -0
- mb_rag/chatbot/chains.py +206 -206
- mb_rag/chatbot/conversation.py +185 -0
- mb_rag/chatbot/prompts.py +58 -58
- mb_rag/rag/embeddings.py +810 -810
- mb_rag/utils/all_data_extract.py +64 -64
- mb_rag/utils/bounding_box.py +231 -231
- mb_rag/utils/document_extract.py +354 -354
- mb_rag/utils/extra.py +73 -73
- mb_rag/utils/pdf_extract.py +428 -428
- mb_rag/version.py +1 -1
- {mb_rag-1.1.47.dist-info → mb_rag-1.1.56.post0.dist-info}/METADATA +11 -11
- mb_rag-1.1.56.post0.dist-info/RECORD +19 -0
- mb_rag/chatbot/basic.py +0 -644
- mb_rag-1.1.47.dist-info/RECORD +0 -18
- {mb_rag-1.1.47.dist-info → mb_rag-1.1.56.post0.dist-info}/WHEEL +0 -0
- {mb_rag-1.1.47.dist-info → mb_rag-1.1.56.post0.dist-info}/top_level.txt +0 -0
mb_rag/version.py
CHANGED
{mb_rag-1.1.47.dist-info → mb_rag-1.1.56.post0.dist-info}/METADATA
CHANGED
@@ -1,11 +1,11 @@
-Metadata-Version: 2.4
-Name: mb_rag
-Version: 1.1.47
-Summary: RAG function file
-Author: ['Malav Bateriwala']
-Requires-Python: >=3.8
-Requires-Dist: mb_base
-Dynamic: author
-Dynamic: requires-dist
-Dynamic: requires-python
-Dynamic: summary
+Metadata-Version: 2.4
+Name: mb_rag
+Version: 1.1.56.post0
+Summary: RAG function file
+Author: ['Malav Bateriwala']
+Requires-Python: >=3.8
+Requires-Dist: mb_base
+Dynamic: author
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
mb_rag-1.1.56.post0.dist-info/RECORD
ADDED
@@ -0,0 +1,19 @@
+mb_rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mb_rag/basic.py,sha256=v-D27FAJZTpUHsKHTNw2-cZiLHbwlKYRzrOlFNyAudY,12216
+mb_rag/version.py,sha256=pCIU6hqXriGFDx2MivR_HZSvVwlwBMSh-2M7HpwDDyI,207
+mb_rag/chatbot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mb_rag/chatbot/chains.py,sha256=0Ir1nUlNdmzs7jPFiES2z3EaGErDTmIMZ1Lx7R4ajLg,7429
+mb_rag/chatbot/conversation.py,sha256=uboW4p_MBXbD3u4b79O0cnuKRrr-6fyeQgnv4C4DrA0,7179
+mb_rag/chatbot/prompts.py,sha256=-DGEV-a9kaPhsRTMrRPeg0kJxOo9um315SHpHoAijbg,1807
+mb_rag/rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mb_rag/rag/embeddings.py,sha256=vqU9Yt-6aXu9XN8gBcXkhYHBTZvERH87HiFTzYtJHKM,31325
+mb_rag/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mb_rag/utils/all_data_extract.py,sha256=CCxEPwD52qAAbMESW9QbPNtXlQsSTviSlcUOVQqPOC8,2596
+mb_rag/utils/bounding_box.py,sha256=vjy1BzqebEz0dtM5H5S_94UkmZemCnM-_AnBItOWK3o,8597
+mb_rag/utils/document_extract.py,sha256=AHvC2YGo60snXXB9C15QyMRi3ZfkfcRAfV8t_IRD4IM,12912
+mb_rag/utils/extra.py,sha256=SWbZmTttK20GNkt9AM0xoOzhQEYwJiPmR9rgHnt1_RU,2565
+mb_rag/utils/pdf_extract.py,sha256=aHk61kdzSM6uptzQHuGin8Kgf9DfaQjcLX5-q9hNQok,16026
+mb_rag-1.1.56.post0.dist-info/METADATA,sha256=bM5gnocs1dQczGE0UPl3fWvJmPTQwD-WagHQYzW05zM,251
+mb_rag-1.1.56.post0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+mb_rag-1.1.56.post0.dist-info/top_level.txt,sha256=FIK1eAa5uYnurgXZquBG-s3PIy-HDTC5yJBW4lTH_pM,7
+mb_rag-1.1.56.post0.dist-info/RECORD,,
mb_rag/chatbot/basic.py
DELETED
@@ -1,644 +0,0 @@
-import os
-from dotenv import load_dotenv
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
-from langchain.document_loaders import TextLoader
-from IPython.display import display, HTML
-from typing import Optional, List, Dict, Any, Union
-from mb_rag.utils.extra import check_package
-import base64
-
-__all__ = [
-    'ChatbotBase',
-    'ModelFactory',
-    'ConversationModel',
-    'IPythonStreamHandler',
-    'AgentFactory'
-]
-
-class ChatbotBase:
-    """Base class for chatbot functionality"""
-
-    @staticmethod
-    def load_env(file_path: str) -> None:
-        """
-        Load environment variables from a file
-        Args:
-            file_path (str): Path to the environment file
-        """
-        load_dotenv(file_path)
-
-    @staticmethod
-    def add_os_key(name: str, key: str) -> None:
-        """
-        Add an API key to the environment
-        Args:
-            name (str): Name of the API key
-            key (str): API key
-        """
-        os.environ[name] = key
-
-    @staticmethod
-    def get_client():
-        """
-        Returns a boto3 client for S3
-        """
-        if not check_package("boto3"):
-            raise ImportError("Boto3 package not found. Please install it using: pip install boto3")
-
-        import boto3
-        return boto3.client('s3')
-
-class ModelFactory:
-    """Factory class for creating different types of chatbot models"""
-
-    def __init__(self, model_type: str = 'openai', model_name: str = "gpt-4o", **kwargs) -> Any:
-        """
-        Factory method to create any type of model
-        Args:
-            model_type (str): Type of model to create. Default is OpenAI. Options are openai, anthropic, google, ollama, groq
-            model_name (str): Name of the model
-            **kwargs: Additional arguments
-        Returns:
-            Any: Chatbot model
-        """
-        creators = {
-            'openai': self.create_openai,
-            'anthropic': self.create_anthropic,
-            'google': self.create_google,
-            'ollama': self.create_ollama,
-            'groq': self.create_groq,
-            'deepseek': self.create_deepseek,
-            'qwen' : self.create_qwen,
-            'hugging_face': self.create_hugging_face
-        }
-
-        self.model_type = model_type
-        self.model_name = model_name
-        model_data = creators.get(model_type)
-        if not model_data:
-            raise ValueError(f"Unsupported model type: {model_type}")
-
-        try:
-            self.model = model_data(model_name, **kwargs)
-        except Exception as e:
-            raise ValueError(f"Error creating {model_type} model: {str(e)}")
-
-    @classmethod
-    def create_openai(cls, model_name: str = "gpt-4o", **kwargs) -> Any:
-        """
-        Create OpenAI chatbot model
-        Args:
-            model_name (str): Name of the model
-            **kwargs: Additional arguments
-        Returns:
-            ChatOpenAI: Chatbot model
-        """
-        if not check_package("openai"):
-            raise ImportError("OpenAI package not found. Please install it using: pip install openai langchain-openai")
-
-        from langchain_openai import ChatOpenAI
-        kwargs["model_name"] = model_name
-        return ChatOpenAI(**kwargs)
-
-    @classmethod
-    def create_anthropic(cls, model_name: str = "claude-3-opus-20240229", **kwargs) -> Any:
-        """
-        Create Anthropic chatbot model
-        Args:
-            model_name (str): Name of the model
-            **kwargs: Additional arguments
-        Returns:
-            ChatAnthropic: Chatbot model
-        """
-        if not check_package("anthropic"):
-            raise ImportError("Anthropic package not found. Please install it using: pip install anthropic langchain-anthropic")
-
-        from langchain_anthropic import ChatAnthropic
-        kwargs["model_name"] = model_name
-        return ChatAnthropic(**kwargs)
-
-    @classmethod
-    def create_google(cls, model_name: str = "gemini-2.0-flash", **kwargs) -> Any:
-        """
-        Create Google chatbot model
-        Args:
-            model_name (str): Name of the model
-            **kwargs: Additional arguments
-        Returns:
-            ChatGoogleGenerativeAI: Chatbot model
-        """
-        if not check_package("langchain_google_genai"):
-            raise ImportError("langchain_google_genai package not found. Please install it using: pip install google-generativeai langchain-google-genai")
-
-        from langchain_google_genai import ChatGoogleGenerativeAI
-        kwargs["model"] = model_name
-        return ChatGoogleGenerativeAI(**kwargs)
-
-    @classmethod
-    def create_ollama(cls, model_name: str = "llama3", **kwargs) -> Any:
-        """
-        Create Ollama chatbot model
-        Args:
-            model_name (str): Name of the model
-            **kwargs: Additional arguments
-        Returns:
-            Ollama: Chatbot model
-        """
-        if not check_package("langchain_ollama"):
-            raise ImportError("Langchain Community package not found. Please install it using: pip install langchain_ollama")
-
-        from langchain_ollama import ChatOllama
-
-        print(f"Current Ollama serve model is {os.system('ollama ps')}")
-        kwargs["model"] = model_name
-        return ChatOllama(**kwargs)
-
-    @classmethod
-    def create_groq(cls, model_name: str = "llama-3.3-70b-versatile", **kwargs) -> Any:
-        """
-        Create Groq chatbot model
-        Args:
-            model_name (str): Name of the model
-            **kwargs: Additional arguments. Options are: temperature, groq_api_key, model_name
-        Returns:
-            ChatGroq: Chatbot model
-        """
-        if not check_package("langchain_groq"):
-            raise ImportError("Langchain Groq package not found. Please install it using: pip install langchain-groq")
-
-        from langchain_groq import ChatGroq
-        kwargs["model"] = model_name
-        return ChatGroq(**kwargs)
-
-    @classmethod
-    def create_deepseek(cls, model_name: str = "deepseek-chat", **kwargs) -> Any:
-        """
-        Create Deepseek chatbot model
-        Args:
-            model_name (str): Name of the model
-            **kwargs: Additional arguments
-        Returns:
-            ChatDeepseek: Chatbot model
-        """
-        if not check_package("langchain_deepseek"):
-            raise ImportError("Langchain Deepseek package not found. Please install it using: pip install langchain-deepseek")
-
-        from langchain_deepseek import ChatDeepSeek
-        kwargs["model"] = model_name
-        return ChatDeepSeek(**kwargs)
-
-    @classmethod
-    def create_qwen(cls, model_name: str = "qwen", **kwargs) -> Any:
-        """
-        Create Qwen chatbot model
-        Args:
-            model_name (str): Name of the model
-            **kwargs: Additional arguments
-        Returns:
-            ChatQwen: Chatbot model
-        """
-        if not check_package("langchain_community"):
-            raise ImportError("Langchain Qwen package not found. Please install it using: pip install langchain_community")
-
-        from langchain_community.chat_models.tongyi import ChatTongyi
-        kwargs["model"] = model_name
-        return ChatTongyi(streaming=True,**kwargs)
-
-    @classmethod
-    def create_hugging_face(cls, model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct",model_function: str = "image-text-to-text",
-                            device='cpu',**kwargs) -> Any:
-        """
-        Create and load hugging face model.
-        Args:
-            model_name (str): Name of the model
-            model_function (str): model function. Default is image-text-to-text.
-            device (str): Device to use. Default is cpu
-            **kwargs: Additional arguments
-        Returns:
-            ChatHuggingFace: Chatbot model
-        """
-        if not check_package("transformers"):
-            raise ImportError("Transformers package not found. Please install it using: pip install transformers")
-        if not check_package("langchain_huggingface"):
-            raise ImportError("langchain_huggingface package not found. Please install it using: pip install langchain_huggingface")
-        if not check_package("torch"):
-            raise ImportError("Torch package not found. Please install it using: pip install torch")
-
-        from langchain_huggingface import HuggingFacePipeline
-        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForImageTextToText,AutoProcessor
-        import torch
-
-        device = torch.device(device) if torch.cuda.is_available() else torch.device("cpu")
-
-        temperature = kwargs.pop("temperature", 0.7)
-        max_length = kwargs.pop("max_length", 1024)
-
-        if model_function == "image-text-to-text":
-            tokenizer = AutoProcessor.from_pretrained(model_name,trust_remote_code=True)
-            model = AutoModelForImageTextToText.from_pretrained(
-                model_name,
-                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-                device_map=device,
-                trust_remote_code=True,
-                **kwargs
-            )
-        else:
-            tokenizer = AutoTokenizer.from_pretrained(model_name,trust_remote_code=True)
-            model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-                device_map=device,
-                trust_remote_code=True,
-                **kwargs)
-
-        # Create pipeline
-        pipe = pipeline(
-            model_function,
-            model=model,
-            tokenizer=tokenizer,
-            max_length=max_length,
-            temperature=temperature
-        )
-
-        # Create and return LangChain HuggingFacePipeline
-        return HuggingFacePipeline(pipeline=pipe)
-
-    def _reset_model(self):
-        """Reset the model"""
-        self.model = self.model.reset()
-
-    def invoke_query(self,query: str,file_path: str = None,get_content_only: bool = True,images: list = None,pydantic_model = None) -> str:
-        """
-        Invoke the model
-        Args:
-            query (str): Query to send to the model
-            file_path (str): Path to text file to load. Default is None
-            get_content_only (bool): Whether to return only content
-            images (list): List of images to send to the model
-            pydantic_model: Pydantic model for structured output
-        Returns:
-            str: Response from the model
-        """
-        if file_path:
-            loader = TextLoader(file_path)
-            document = loader.load()
-            query = document.content
-
-        if pydantic_model is not None:
-            if hasattr(self.model, 'with_structured_output'):
-                try:
-                    self.model = self.model.with_structured_output(pydantic_model)
-                except Exception as e:
-                    raise ValueError(f"Error with pydantic_model: {e}")
-        if images:
-            res = self._model_invoke_images(images=images,prompt=query,pydantic_model=pydantic_model,get_content_only=get_content_only)
-        else:
-            res = self.model.invoke(query)
-        if get_content_only:
-            try:
-                return res.content
-            except Exception:
-                return res
-        return res
-
-    def _image_to_base64(self,image):
-        with open(image, "rb") as f:
-            return base64.b64encode(f.read()).decode('utf-8')
-
-    def _model_invoke_images(self,images: list, prompt: str,pydantic_model = None,get_content_only: bool = True) -> str:
-        """
-        Function to invoke the model with images
-        Args:
-            images (list): List of images
-            prompt (str): Prompt
-            pydantic_model (PydanticModel): Pydantic model
-        Returns:
-            str: Output from the model
-        """
-        base64_images = [self._image_to_base64(image) for image in images]
-        image_prompt_create = [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_images[i]}"}} for i in range(len(images))]
-        prompt_new = [{"type": "text", "text": prompt},
-                      *image_prompt_create,]
-        if pydantic_model is not None:
-            try:
-                self.model = self.model.with_structured_output(pydantic_model)
-            except Exception as e:
-                print(f"Error with pydantic_model: {e}")
-                print("Continuing without structured output")
-        message= HumanMessage(content=prompt_new,)
-        response = self.model.invoke([message])
-
-        if get_content_only:
-            try:
-                return response.content
-            except Exception:
-                print("Failed to get content from response. Returning response object")
-                return response
-        else:
-            return response
-
-    def _get_llm_metadata(self):
-        """
-        Returns Basic metadata about the LLM
-        """
-        print("Model Name: ", self.model)
-        print("Model Temperature: ", self.model.temperature)
-        print("Model Max Tokens: ", self.model.max_output_tokens)
-        print("Model Top P: ", self.model.top_p)
-        print("Model Top K: ", self.model.top_k)
-        print("Model Input Schema:",self.model.input_schema)
-
-class ConversationModel:
-    """
-    A class to handle conversation with AI models
-
-    Attributes:
-        chatbot: The AI model for conversation
-        message_list (List): List of conversation messages
-        file_path (str): Path to save/load conversations. Can be local or S3
-    """
-
-    def __init__(self,
-                 model_name: str = "gpt-4o",
-                 model_type: str = 'openai',
-                 **kwargs) -> None:
-        """Initialize conversation model"""
-        self.chatbot = ModelFactory(model_type, model_name, **kwargs)
-
-    def initialize_conversation(self,
-                                question: Optional[str],
-                                context: Optional[str] = None,
-                                file_path: Optional[str]=None) -> None:
-        """Initialize conversation state"""
-        if file_path:
-            self.file_path = file_path
-            self.load_conversation(file_path)
-
-        else:
-            if not question:
-                raise ValueError("Question is required.")
-
-            if context:
-                self.context = context
-            else:
-                self.context = "Answer question to the point and don't hallucinate."
-            self.message_list = [
-                SystemMessage(content=context),
-                HumanMessage(content=question)
-            ]
-
-            res = self._ask_question(self.message_list)
-            print(res)
-            self.message_list.append(AIMessage(content=res))
-
-    def _ask_question(self,messages: List[Union[SystemMessage, HumanMessage, AIMessage]],
-                      get_content_only: bool = True) -> str:
-        """
-        Ask a question and get response
-        Args:
-            messages: List of messages
-            get_content_only: Whether to return only content
-        Returns:
-            str: Response from the model
-        """
-        res = self.chatbot.invoke_query(messages)
-        if get_content_only:
-            try:
-                return res.content
-            except Exception:
-                return res
-        return res
-
-    def add_message(self, message: str) -> str:
-        """
-        Add a message to the conversation
-        Args:
-            message (str): Message to add
-        Returns:
-            str: Response from the chatbot
-        """
-        self.message_list.append(HumanMessage(content=message))
-        res = self._ask_question(self.message_list)
-        self.message_list.append(AIMessage(content=res))
-        return res
-
-    @property
-    def all_messages(self) -> List[Union[SystemMessage, HumanMessage, AIMessage]]:
-        """Get all messages"""
-        return self.message_list
-
-    @property
-    def last_message(self) -> str:
-        """Get the last message"""
-        return self.message_list[-1].content
-
-    @property
-    def all_messages_content(self) -> List[str]:
-        """Get content of all messages"""
-        return [message.content for message in self.message_list]
-
-    def _is_s3_path(self, path: str) -> bool:
-        """
-        Check if path is an S3 path
-        Args:
-            path (str): Path to check
-        Returns:
-            bool: True if S3 path
-        """
-        return path.startswith("s3://")
-
-    def save_conversation(self, file_path: Optional[str] = None, **kwargs) -> bool:
-        """
-        Save the conversation
-        Args:
-            file_path: Path to save the conversation
-            **kwargs: Additional arguments for S3
-        Returns:
-            bool: Success status
-        """
-        if self._is_s3_path(file_path or self.file_path):
-            print("Saving conversation to S3.")
-            self.save_file_path = file_path
-            return self._save_to_s3(self.file_path,**kwargs)
-        return self._save_to_file(file_path or self.file_path)
-
-    def _save_to_s3(self,**kwargs) -> bool:
-        """Save conversation to S3"""
-        try:
-            client = kwargs.get('client', self.client)
-            bucket = kwargs.get('bucket', self.bucket)
-            client.put_object(
-                Body=str(self.message_list),
-                Bucket=bucket,
-                Key=self.save_file_path
-            )
-            print(f"Conversation saved to s3_path: {self.s3_path}")
-            return True
-        except Exception as e:
-            raise ValueError(f"Error saving conversation to s3: {e}")
-
-    def _save_to_file(self, file_path: str) -> bool:
-        """Save conversation to file"""
-        try:
-            with open(file_path, 'w') as f:
-                for message in self.message_list:
-                    f.write(f"{message.content}\n")
-            print(f"Conversation saved to file: {file_path}")
-            return True
-        except Exception as e:
-            raise ValueError(f"Error saving conversation to file: {e}")
-
-    def load_conversation(self, file_path: Optional[str] = None, **kwargs) -> List[Any]:
-        """
-        Load a conversation
-        Args:
-            file_path: Path to load from
-            **kwargs: Additional arguments for S3
-        Returns:
-            List: Loaded messages
-        """
-        self.message_list = []
-        if self._is_s3_path(file_path or self.file_path):
-            print("Loading conversation from S3.")
-            self.file_path = file_path
-            return self._load_from_s3(**kwargs)
-        return self._load_from_file(file_path or self.file_path)
-
-    def _load_from_s3(self, **kwargs) -> List[Any]:
-        """Load conversation from S3"""
-        try:
-            client = kwargs.get('client', self.client)
-            bucket = kwargs.get('bucket', self.bucket)
-            res = client.get_response(client, bucket, self.s3_path)
-            res_str = eval(res['Body'].read().decode('utf-8'))
-            self.message_list = [SystemMessage(content=res_str)]
-            print(f"Conversation loaded from s3_path: {self.file_path}")
-            return self.message_list
-        except Exception as e:
-            raise ValueError(f"Error loading conversation from s3: {e}")
-
-    def _load_from_file(self, file_path: str) -> List[Any]:
-        """Load conversation from file"""
-        try:
-            with open(file_path, 'r') as f:
-                lines = f.readlines()
-                for line in lines:
-                    self.message_list.append(SystemMessage(content=line))
-            print(f"Conversation loaded from file: {file_path}")
-            return self.message_list
-        except Exception as e:
-            raise ValueError(f"Error loading conversation from file: {e}")
-
-class IPythonStreamHandler(StreamingStdOutCallbackHandler):
-    """Handler for IPython display"""
-
-    def __init__(self):
-        self.output = ""
-
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        """Handle new token"""
-        self.output += token
-        display(HTML(self.output), clear=True)
-
-
-class AgentFactory:
-    """Factory class for creating different types of agents"""
-
-    def __init__(self, agent_type: str = 'basic', model_name: str = "gpt-4o", **kwargs) -> Any:
-        """
-        Factory method to create any type of agent
-        Args:
-            agent_type (str): Type of agent to create. Default is basic.
-            model_name (str): Name of the model
-            **kwargs: Additional arguments
-        Returns:
-            Any: Agent
-        """
-        creators = {
-            'basic': self.create_basic_agent,
-            'langgraph': self.create_langgraph_agent,
-        }
-
-        agent_data = creators.get(agent_type)
-        if not agent_data:
-            raise ValueError(f"Unsupported agent type: {agent_type}")
-
-        try:
-            self.agent = agent_data(model_name, **kwargs)
-        except Exception as e:
-            raise ValueError(f"Error creating {agent_type} agent: {str(e)}")
-
-    @classmethod
-    def create_basic_agent(cls, model_name: str = "gpt-4o", **kwargs) -> Any:
-        """
-        Create basic agent
-        Args:
-            model_name (str): Name of the model
-            **kwargs: Additional arguments
-        Returns:
-            Runnable: Agent
-        """
-        # Basic agent creation logic here
-        llm = ModelFactory(model_name=model_name, **kwargs).model
-        from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-        prompt = ChatPromptTemplate.from_messages([
-            ("system", "You are a helpful AI assistant"),
-            MessagesPlaceholder(variable_name="messages")
-        ])
-        from langchain_core.runnables import chain
-        agent = prompt | llm
-        return agent
-
-    @classmethod
-    def create_langgraph_agent(cls, model_name: str = "gpt-4o", **kwargs) -> Any:
-        """
-        Create LangGraph agent
-        Args:
-            model_name (str): Name of the model
-            **kwargs: Additional arguments
-        Returns:
-            Graph: LangGraph agent
-        """
-        if not check_package("langgraph"):
-            raise ImportError("LangGraph package not found. Please install it using: pip install langgraph")
-
-        from langgraph.graph import StateGraph, MessageGraph
-        from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-        from langchain_core.runnables import chain
-        from langchain_core.messages import BaseMessage
-
-        llm = ModelFactory(model_name=model_name, **kwargs).model
-
-        # Define the state of the graph
-        class GraphState:
-            messages: List[BaseMessage]
-            agent_state: Dict[str, Any]
-
-        # Define the nodes
-        def agent(state: GraphState):
-            prompt = ChatPromptTemplate.from_messages([
-                ("system", "You are a helpful AI assistant"),
-                MessagesPlaceholder(variable_name="messages")
-            ])
-            return (prompt | llm).invoke({"messages": state.messages})
-
-        def user(state: GraphState, input: str):
-            return HumanMessage(content=input)
-
-        # Define the graph
-        graph = MessageGraph()
-
-        # Add the nodes
-        graph.add_node("agent", agent)
-        graph.add_node("user", user)
-
-        # Set the entrypoint
-        graph.set_entry_point("user")
-
-        # Add the edges
-        graph.add_edge("user", "agent")
-        graph.add_edge("agent", "user")
-
-        # Compile the graph
-        return graph.compile()