mb-rag 1.0.117__py3-none-any.whl → 1.0.124__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mb-rag might be problematic. Click here for more details.
- mb_rag/chatbot/basic.py +329 -295
- mb_rag/chatbot/chains.py +163 -202
- mb_rag/utils/bounding_box.py +203 -76
- mb_rag/utils/extra.py +1 -1
- mb_rag/version.py +1 -1
- {mb_rag-1.0.117.dist-info → mb_rag-1.0.124.dist-info}/METADATA +1 -1
- mb_rag-1.0.124.dist-info/RECORD +15 -0
- mb_rag-1.0.117.dist-info/RECORD +0 -15
- {mb_rag-1.0.117.dist-info → mb_rag-1.0.124.dist-info}/WHEEL +0 -0
- {mb_rag-1.0.117.dist-info → mb_rag-1.0.124.dist-info}/top_level.txt +0 -0
mb_rag/chatbot/basic.py
CHANGED
|
@@ -1,361 +1,395 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import importlib.util
|
|
3
2
|
from dotenv import load_dotenv
|
|
4
3
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
|
5
4
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
6
5
|
from IPython.display import display, HTML
|
|
6
|
+
from typing import Optional, List, Dict, Any, Union
|
|
7
|
+
from mb_rag.utils.extra import check_package
|
|
7
8
|
import base64
|
|
8
|
-
# from langchain_core import pydantic_v1
|
|
9
9
|
|
|
10
|
-
__all__ = [
|
|
11
|
-
|
|
12
|
-
|
|
10
|
+
__all__ = [
|
|
11
|
+
'ChatbotBase',
|
|
12
|
+
'ModelFactory',
|
|
13
|
+
'ConversationModel',
|
|
14
|
+
'IPythonStreamHandler'
|
|
15
|
+
]
|
|
13
16
|
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Check if a package is installed
|
|
17
|
-
Args:
|
|
18
|
-
package_name (str): Name of the package
|
|
19
|
-
Returns:
|
|
20
|
-
bool: True if package is installed, False otherwise
|
|
21
|
-
"""
|
|
22
|
-
return importlib.util.find_spec(package_name) is not None
|
|
23
|
-
|
|
24
|
-
def load_env(file_path: str):
|
|
25
|
-
"""
|
|
26
|
-
Load environment variables from a file
|
|
27
|
-
Args:
|
|
28
|
-
file_path (str): Path to the environment file
|
|
29
|
-
Returns:
|
|
30
|
-
None
|
|
31
|
-
"""
|
|
32
|
-
load_dotenv(file_path)
|
|
17
|
+
class ChatbotBase:
|
|
18
|
+
"""Base class for chatbot functionality"""
|
|
33
19
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
"""
|
|
43
|
-
os.environ[name] = key
|
|
44
|
-
|
|
45
|
-
def get_chatbot_openai(model_name: str = "gpt-4o",**kwargs):
|
|
46
|
-
"""
|
|
47
|
-
Load the chatbot model from OpenAI
|
|
48
|
-
Args:
|
|
49
|
-
model_name (str): Name of the model
|
|
50
|
-
**kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
|
|
51
|
-
Returns:
|
|
52
|
-
ChatOpenAI: Chatbot model
|
|
53
|
-
"""
|
|
54
|
-
if not check_package("langchain_openai"):
|
|
55
|
-
raise ImportError("OpenAI package not found. Please install it using: pip install openai langchain-openai")
|
|
56
|
-
|
|
57
|
-
from langchain_openai import ChatOpenAI
|
|
58
|
-
kwargs["model_name"] = model_name
|
|
59
|
-
return ChatOpenAI(**kwargs)
|
|
60
|
-
|
|
61
|
-
def get_chatbot_anthropic(model_name: str = "claude-3-opus-20240229",**kwargs):
|
|
62
|
-
"""
|
|
63
|
-
Load the chatbot model from Anthropic
|
|
64
|
-
Args:
|
|
65
|
-
model_name (str): Name of the model
|
|
66
|
-
**kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
|
|
67
|
-
Returns:
|
|
68
|
-
ChatAnthropic: Chatbot model
|
|
69
|
-
"""
|
|
70
|
-
if not check_package("langchain_anthropic"):
|
|
71
|
-
raise ImportError("Anthropic package not found. Please install it using: pip install anthropic langchain-anthropic")
|
|
20
|
+
@staticmethod
|
|
21
|
+
def load_env(file_path: str) -> None:
|
|
22
|
+
"""
|
|
23
|
+
Load environment variables from a file
|
|
24
|
+
Args:
|
|
25
|
+
file_path (str): Path to the environment file
|
|
26
|
+
"""
|
|
27
|
+
load_dotenv(file_path)
|
|
72
28
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
29
|
+
@staticmethod
|
|
30
|
+
def add_os_key(name: str, key: str) -> None:
|
|
31
|
+
"""
|
|
32
|
+
Add an API key to the environment
|
|
33
|
+
Args:
|
|
34
|
+
name (str): Name of the API key
|
|
35
|
+
key (str): API key
|
|
36
|
+
"""
|
|
37
|
+
os.environ[name] = key
|
|
76
38
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
raise ImportError("Google Generative AI package not found. Please install it using: pip install google-generativeai langchain-google-genai")
|
|
88
|
-
|
|
89
|
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
90
|
-
kwargs["model"] = model_name
|
|
91
|
-
return ChatGoogleGenerativeAI(**kwargs)
|
|
39
|
+
@staticmethod
|
|
40
|
+
def get_client():
|
|
41
|
+
"""
|
|
42
|
+
Returns a boto3 client for S3
|
|
43
|
+
"""
|
|
44
|
+
if not check_package("boto3"):
|
|
45
|
+
raise ImportError("Boto3 package not found. Please install it using: pip install boto3")
|
|
46
|
+
|
|
47
|
+
import boto3
|
|
48
|
+
return boto3.client('s3')
|
|
92
49
|
|
|
93
|
-
|
|
94
|
-
"""
|
|
95
|
-
Load the chatbot model from Ollama
|
|
96
|
-
Args:
|
|
97
|
-
model_name (str): Name of the model
|
|
98
|
-
**kwargs: Additional arguments (temperature, max_tokens, timeout, max_retries, api_key etc.)
|
|
99
|
-
Returns:
|
|
100
|
-
ChatOllama: Chatbot model
|
|
101
|
-
"""
|
|
102
|
-
if not check_package("langchain_community"):
|
|
103
|
-
raise ImportError("Langchain Community package not found. Please install it using: pip install langchain-community")
|
|
50
|
+
class ModelFactory:
|
|
51
|
+
"""Factory class for creating different types of chatbot models"""
|
|
104
52
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
53
|
+
def __init__(self, model_type: str = 'openai', model_name: str = "gpt-4o", **kwargs) -> Any:
|
|
54
|
+
"""
|
|
55
|
+
Factory method to create any type of model
|
|
56
|
+
Args:
|
|
57
|
+
model_type (str): Type of model to create
|
|
58
|
+
model_name (str): Name of the model
|
|
59
|
+
**kwargs: Additional arguments
|
|
60
|
+
Returns:
|
|
61
|
+
Any: Chatbot model
|
|
62
|
+
"""
|
|
63
|
+
creators = {
|
|
64
|
+
'openai': self.create_openai,
|
|
65
|
+
'anthropic': self.create_anthropic,
|
|
66
|
+
'google': self.create_google,
|
|
67
|
+
'ollama': self.create_ollama
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
model_data = creators.get(model_type)
|
|
71
|
+
if not model_data:
|
|
72
|
+
raise ValueError(f"Unsupported model type: {model_type}")
|
|
73
|
+
|
|
121
74
|
try:
|
|
122
|
-
|
|
75
|
+
self.model = model_data(model_name, **kwargs)
|
|
123
76
|
except Exception as e:
|
|
124
|
-
raise ValueError(f"Error
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
77
|
+
raise ValueError(f"Error creating {model_type} model: {str(e)}")
|
|
78
|
+
|
|
79
|
+
@classmethod
|
|
80
|
+
def create_openai(cls, model_name: str = "gpt-4o", **kwargs) -> Any:
|
|
81
|
+
"""
|
|
82
|
+
Create OpenAI chatbot model
|
|
83
|
+
Args:
|
|
84
|
+
model_name (str): Name of the model
|
|
85
|
+
**kwargs: Additional arguments
|
|
86
|
+
Returns:
|
|
87
|
+
ChatOpenAI: Chatbot model
|
|
88
|
+
"""
|
|
89
|
+
if not check_package("openai"):
|
|
90
|
+
raise ImportError("OpenAI package not found. Please install it using: pip install openai langchain-openai")
|
|
91
|
+
|
|
92
|
+
from langchain_openai import ChatOpenAI
|
|
93
|
+
kwargs["model_name"] = model_name
|
|
94
|
+
return ChatOpenAI(**kwargs)
|
|
142
95
|
|
|
96
|
+
@classmethod
|
|
97
|
+
def create_anthropic(cls, model_name: str = "claude-3-opus-20240229", **kwargs) -> Any:
|
|
98
|
+
"""
|
|
99
|
+
Create Anthropic chatbot model
|
|
100
|
+
Args:
|
|
101
|
+
model_name (str): Name of the model
|
|
102
|
+
**kwargs: Additional arguments
|
|
103
|
+
Returns:
|
|
104
|
+
ChatAnthropic: Chatbot model
|
|
105
|
+
"""
|
|
106
|
+
if not check_package("anthropic"):
|
|
107
|
+
raise ImportError("Anthropic package not found. Please install it using: pip install anthropic langchain-anthropic")
|
|
108
|
+
|
|
109
|
+
from langchain_anthropic import ChatAnthropic
|
|
110
|
+
kwargs["model_name"] = model_name
|
|
111
|
+
return ChatAnthropic(**kwargs)
|
|
143
112
|
|
|
144
|
-
|
|
145
|
-
def
|
|
146
|
-
|
|
113
|
+
@classmethod
|
|
114
|
+
def create_google(cls, model_name: str = "gemini-1.5-flash", **kwargs) -> Any:
|
|
115
|
+
"""
|
|
116
|
+
Create Google chatbot model
|
|
117
|
+
Args:
|
|
118
|
+
model_name (str): Name of the model
|
|
119
|
+
**kwargs: Additional arguments
|
|
120
|
+
Returns:
|
|
121
|
+
ChatGoogleGenerativeAI: Chatbot model
|
|
122
|
+
"""
|
|
123
|
+
if not check_package("google.generativeai"):
|
|
124
|
+
raise ImportError("Google Generative AI package not found. Please install it using: pip install google-generativeai langchain-google-genai")
|
|
147
125
|
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
126
|
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
127
|
+
kwargs["model"] = model_name
|
|
128
|
+
return ChatGoogleGenerativeAI(**kwargs)
|
|
151
129
|
|
|
130
|
+
@classmethod
|
|
131
|
+
def create_ollama(cls, model_name: str = "llama3", **kwargs) -> Any:
|
|
132
|
+
"""
|
|
133
|
+
Create Ollama chatbot model
|
|
134
|
+
Args:
|
|
135
|
+
model_name (str): Name of the model
|
|
136
|
+
**kwargs: Additional arguments
|
|
137
|
+
Returns:
|
|
138
|
+
Ollama: Chatbot model
|
|
139
|
+
"""
|
|
140
|
+
if not check_package("langchain_community"):
|
|
141
|
+
raise ImportError("Langchain Community package not found. Please install it using: pip install langchain-community")
|
|
142
|
+
|
|
143
|
+
from langchain_community.llms import Ollama
|
|
144
|
+
kwargs["model"] = model_name
|
|
145
|
+
return Ollama(**kwargs)
|
|
152
146
|
|
|
153
|
-
def
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
147
|
+
def invoke_query(self,query: str,get_content_only: bool = True,images: list = None,pydantic_model = None) -> str:
|
|
148
|
+
"""
|
|
149
|
+
Invoke the model
|
|
150
|
+
Args:
|
|
151
|
+
query (str): Query to send to the model
|
|
152
|
+
get_content_only (bool): Whether to return only content
|
|
153
|
+
images (list): List of images to send to the model
|
|
154
|
+
pydantic_model: Pydantic model for structured output
|
|
155
|
+
Returns:
|
|
156
|
+
str: Response from the model
|
|
157
|
+
"""
|
|
158
|
+
|
|
159
|
+
if pydantic_model is not None:
|
|
160
|
+
try:
|
|
161
|
+
self.model = self.model.with_structured_output(pydantic_model)
|
|
162
|
+
except Exception as e:
|
|
163
|
+
raise ValueError(f"Error with pydantic_model: {e}")
|
|
164
|
+
if images:
|
|
165
|
+
res = self._model_invoke_images(images=images,prompt=query,pydantic_model=pydantic_model)
|
|
172
166
|
else:
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
167
|
+
res = self.model.invoke(query)
|
|
168
|
+
if get_content_only:
|
|
169
|
+
try:
|
|
170
|
+
return res.content
|
|
171
|
+
except Exception:
|
|
172
|
+
return res
|
|
173
|
+
return res
|
|
177
174
|
|
|
175
|
+
def _image_to_base64(self,image):
|
|
176
|
+
with open(image, "rb") as f:
|
|
177
|
+
return base64.b64encode(f.read()).decode('utf-8')
|
|
178
178
|
|
|
179
|
-
def
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
pydantic_model (PydanticModel): Pydantic model
|
|
179
|
+
def _model_invoke_images(self,images: list, prompt: str,pydantic_model = None):
|
|
180
|
+
"""
|
|
181
|
+
Function to invoke the model with images
|
|
182
|
+
Args:
|
|
183
|
+
model (ChatOpenAI): Chatbot model
|
|
184
|
+
images (list): List of images
|
|
185
|
+
prompt (str): Prompt
|
|
186
|
+
pydantic_model (PydanticModel): Pydantic model
|
|
188
187
|
Returns:
|
|
189
188
|
str: Output from the model
|
|
190
189
|
"""
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
190
|
+
base64_images = [self._image_to_base64(image) for image in images]
|
|
191
|
+
image_prompt_create = [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_images[i]}"}} for i in range(len(images))]
|
|
192
|
+
prompt_new = [{"type": "text", "text": prompt},
|
|
193
|
+
*image_prompt_create,]
|
|
194
|
+
if pydantic_model is not None:
|
|
195
|
+
try:
|
|
196
|
+
self.model = self.model.with_structured_output(pydantic_model)
|
|
197
|
+
except Exception as e:
|
|
198
|
+
print(f"Error with pydantic_model: {e}")
|
|
199
|
+
print("Continuing without structured output")
|
|
200
|
+
message= HumanMessage(content=prompt_new,)
|
|
201
|
+
response = self.model.invoke([message])
|
|
202
|
+
return response.content
|
|
194
203
|
|
|
195
|
-
def model_invoke_images(model, images: list, prompt: str,pydantic_model = None):
|
|
196
|
-
"""
|
|
197
|
-
Function to invoke the model with images
|
|
198
|
-
Args:
|
|
199
|
-
model (ChatOpenAI): Chatbot model
|
|
200
|
-
images (list): List of images
|
|
201
|
-
prompt (str): Prompt
|
|
202
|
-
get_content_only (bool): Get content only. Default is True. If False then returns the full response
|
|
203
|
-
pydantic_model (PydanticModel): Pydantic model
|
|
204
|
-
Returns:
|
|
205
|
-
str: Output from the model
|
|
206
|
-
"""
|
|
207
|
-
def image_to_base64(image):
|
|
208
|
-
with open(image, "rb") as f:
|
|
209
|
-
return base64.b64encode(f.read()).decode('utf-8')
|
|
210
|
-
base64_images = [image_to_base64(image) for image in images]
|
|
211
|
-
image_prompt_create = [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_images[i]}"}} for i in range(len(images))]
|
|
212
|
-
prompt_new = [{"type": "text", "text": prompt},
|
|
213
|
-
*image_prompt_create,]
|
|
214
|
-
message= HumanMessage(content=prompt_new,)
|
|
215
|
-
if pydantic_model is not None:
|
|
216
|
-
try:
|
|
217
|
-
model = model.with_structured_output(pydantic_model)
|
|
218
|
-
except Exception as e:
|
|
219
|
-
raise ValueError(f"Error with pydantic_model: {e}")
|
|
220
|
-
response = model.invoke([message])
|
|
221
|
-
return response.content
|
|
222
204
|
|
|
223
|
-
class
|
|
205
|
+
class ConversationModel:
|
|
224
206
|
"""
|
|
225
|
-
A class to
|
|
207
|
+
A class to handle conversation with AI models
|
|
208
|
+
|
|
226
209
|
Attributes:
|
|
227
|
-
chatbot
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
message_list (list): List of messages in the conversation
|
|
231
|
-
file_path (str): Path to the conversation file (if s3_path then add s3_path='loc' and client and bucket)
|
|
210
|
+
chatbot: The AI model for conversation
|
|
211
|
+
message_list (List): List of conversation messages
|
|
212
|
+
file_path (str): Path to save/load conversations. Can be local or S3
|
|
232
213
|
"""
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
print(self.s3_path)
|
|
241
|
-
if self.s3_path is not None:
|
|
242
|
-
self.client = kwargs['client']
|
|
243
|
-
self.bucket = kwargs['bucket']
|
|
244
|
-
except Exception:
|
|
245
|
-
self.s3_path = None
|
|
214
|
+
|
|
215
|
+
def __init__(self,
|
|
216
|
+
model_name: str = "gpt-4o",
|
|
217
|
+
model_type: str = 'openai',
|
|
218
|
+
**kwargs) -> None:
|
|
219
|
+
"""Initialize conversation model"""
|
|
220
|
+
self.chatbot = ModelFactory(model_type, model_name, **kwargs)
|
|
246
221
|
|
|
247
|
-
|
|
248
|
-
|
|
222
|
+
def initialize_conversation(self,
|
|
223
|
+
question: Optional[str],
|
|
224
|
+
context: Optional[str] = None,
|
|
225
|
+
file_path: Optional[str]=None) -> None:
|
|
226
|
+
"""Initialize conversation state"""
|
|
227
|
+
if file_path:
|
|
228
|
+
self.file_path = file_path
|
|
249
229
|
self.load_conversation(file_path)
|
|
230
|
+
|
|
250
231
|
else:
|
|
251
|
-
if
|
|
232
|
+
if not question:
|
|
233
|
+
raise ValueError("Question is required.")
|
|
234
|
+
|
|
235
|
+
if context:
|
|
252
236
|
self.context = context
|
|
253
237
|
else:
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
238
|
+
self.context = "Answer question to the point and don't hallucinate."
|
|
239
|
+
self.message_list = [
|
|
240
|
+
SystemMessage(content=context),
|
|
241
|
+
HumanMessage(content=question)
|
|
242
|
+
]
|
|
243
|
+
|
|
244
|
+
res = self._ask_question(self.message_list)
|
|
245
|
+
print(res)
|
|
246
|
+
self.message_list.append(AIMessage(content=res))
|
|
260
247
|
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
248
|
+
def _ask_question(self,messages: List[Union[SystemMessage, HumanMessage, AIMessage]],
|
|
249
|
+
get_content_only: bool = True) -> str:
|
|
250
|
+
"""
|
|
251
|
+
Ask a question and get response
|
|
252
|
+
Args:
|
|
253
|
+
messages: List of messages
|
|
254
|
+
get_content_only: Whether to return only content
|
|
255
|
+
Returns:
|
|
256
|
+
str: Response from the model
|
|
257
|
+
"""
|
|
258
|
+
res = self.chatbot.invoke_query(messages)
|
|
259
|
+
if get_content_only:
|
|
260
|
+
try:
|
|
261
|
+
return res.content
|
|
262
|
+
except Exception:
|
|
263
|
+
return res
|
|
264
|
+
return res
|
|
266
265
|
|
|
267
|
-
def add_message(self, message: str):
|
|
266
|
+
def add_message(self, message: str) -> str:
|
|
268
267
|
"""
|
|
269
268
|
Add a message to the conversation
|
|
270
269
|
Args:
|
|
271
270
|
message (str): Message to add
|
|
272
271
|
Returns:
|
|
273
|
-
str:
|
|
272
|
+
str: Response from the chatbot
|
|
274
273
|
"""
|
|
275
274
|
self.message_list.append(HumanMessage(content=message))
|
|
276
|
-
res =
|
|
275
|
+
res = self._ask_question(self.message_list)
|
|
277
276
|
self.message_list.append(AIMessage(content=res))
|
|
278
277
|
return res
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
Get all messages
|
|
283
|
-
Returns:
|
|
284
|
-
list: List of messages
|
|
285
|
-
"""
|
|
278
|
+
|
|
279
|
+
@property
|
|
280
|
+
def all_messages(self) -> List[Union[SystemMessage, HumanMessage, AIMessage]]:
|
|
281
|
+
"""Get all messages"""
|
|
286
282
|
return self.message_list
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
Get the last message
|
|
291
|
-
Returns:
|
|
292
|
-
str: Last message
|
|
293
|
-
"""
|
|
283
|
+
|
|
284
|
+
@property
|
|
285
|
+
def last_message(self) -> str:
|
|
286
|
+
"""Get the last message"""
|
|
294
287
|
return self.message_list[-1].content
|
|
295
|
-
|
|
296
|
-
|
|
288
|
+
|
|
289
|
+
@property
|
|
290
|
+
def all_messages_content(self) -> List[str]:
|
|
291
|
+
"""Get content of all messages"""
|
|
292
|
+
return [message.content for message in self.message_list]
|
|
293
|
+
|
|
294
|
+
def _is_s3_path(self, path: str) -> bool:
|
|
297
295
|
"""
|
|
298
|
-
|
|
296
|
+
Check if path is an S3 path
|
|
297
|
+
Args:
|
|
298
|
+
path (str): Path to check
|
|
299
299
|
Returns:
|
|
300
|
-
|
|
300
|
+
bool: True if S3 path
|
|
301
301
|
"""
|
|
302
|
-
return
|
|
302
|
+
return path.startswith("s3://")
|
|
303
303
|
|
|
304
|
-
def save_conversation(self, file_path: str = None, **kwargs):
|
|
304
|
+
def save_conversation(self, file_path: Optional[str] = None, **kwargs) -> bool:
|
|
305
305
|
"""
|
|
306
|
-
Save the conversation
|
|
306
|
+
Save the conversation
|
|
307
307
|
Args:
|
|
308
|
-
file_path
|
|
309
|
-
**kwargs: Additional arguments
|
|
308
|
+
file_path: Path to save the conversation
|
|
309
|
+
**kwargs: Additional arguments for S3
|
|
310
310
|
Returns:
|
|
311
|
-
bool:
|
|
311
|
+
bool: Success status
|
|
312
312
|
"""
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
313
|
+
if self._is_s3_path(file_path or self.file_path):
|
|
314
|
+
print("Saving conversation to S3.")
|
|
315
|
+
self.save_file_path = file_path
|
|
316
|
+
return self._save_to_s3(self.file_path,**kwargs)
|
|
317
|
+
return self._save_to_file(file_path or self.file_path)
|
|
318
|
+
|
|
319
|
+
def _save_to_s3(self,**kwargs) -> bool:
|
|
320
|
+
"""Save conversation to S3"""
|
|
321
|
+
try:
|
|
322
|
+
client = kwargs.get('client', self.client)
|
|
323
|
+
bucket = kwargs.get('bucket', self.bucket)
|
|
324
|
+
client.put_object(
|
|
325
|
+
Body=str(self.message_list),
|
|
326
|
+
Bucket=bucket,
|
|
327
|
+
Key=self.save_file_path
|
|
328
|
+
)
|
|
322
329
|
print(f"Conversation saved to s3_path: {self.s3_path}")
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
330
|
+
return True
|
|
331
|
+
except Exception as e:
|
|
332
|
+
raise ValueError(f"Error saving conversation to s3: {e}")
|
|
333
|
+
|
|
334
|
+
def _save_to_file(self, file_path: str) -> bool:
|
|
335
|
+
"""Save conversation to file"""
|
|
336
|
+
try:
|
|
337
|
+
with open(file_path, 'w') as f:
|
|
338
|
+
for message in self.message_list:
|
|
339
|
+
f.write(f"{message.content}\n")
|
|
340
|
+
print(f"Conversation saved to file: {file_path}")
|
|
341
|
+
return True
|
|
342
|
+
except Exception as e:
|
|
343
|
+
raise ValueError(f"Error saving conversation to file: {e}")
|
|
344
|
+
|
|
345
|
+
def load_conversation(self, file_path: Optional[str] = None, **kwargs) -> List[Any]:
|
|
336
346
|
"""
|
|
337
|
-
Load
|
|
347
|
+
Load a conversation
|
|
338
348
|
Args:
|
|
339
|
-
file_path
|
|
340
|
-
**kwargs: Additional arguments
|
|
349
|
+
file_path: Path to load from
|
|
350
|
+
**kwargs: Additional arguments for S3
|
|
341
351
|
Returns:
|
|
342
|
-
|
|
352
|
+
List: Loaded messages
|
|
343
353
|
"""
|
|
344
354
|
self.message_list = []
|
|
345
|
-
if self.
|
|
346
|
-
|
|
347
|
-
|
|
355
|
+
if self._is_s3_path(file_path or self.file_path):
|
|
356
|
+
print("Loading conversation from S3.")
|
|
357
|
+
self.file_path = file_path
|
|
358
|
+
return self._load_from_s3(**kwargs)
|
|
359
|
+
return self._load_from_file(file_path or self.file_path)
|
|
360
|
+
|
|
361
|
+
def _load_from_s3(self, **kwargs) -> List[Any]:
|
|
362
|
+
"""Load conversation from S3"""
|
|
363
|
+
try:
|
|
364
|
+
client = kwargs.get('client', self.client)
|
|
365
|
+
bucket = kwargs.get('bucket', self.bucket)
|
|
348
366
|
res = client.get_response(client, bucket, self.s3_path)
|
|
349
367
|
res_str = eval(res['Body'].read().decode('utf-8'))
|
|
350
368
|
self.message_list = [SystemMessage(content=res_str)]
|
|
351
|
-
print(f"Conversation loaded from s3_path: {self.
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
369
|
+
print(f"Conversation loaded from s3_path: {self.file_path}")
|
|
370
|
+
return self.message_list
|
|
371
|
+
except Exception as e:
|
|
372
|
+
raise ValueError(f"Error loading conversation from s3: {e}")
|
|
373
|
+
|
|
374
|
+
def _load_from_file(self, file_path: str) -> List[Any]:
|
|
375
|
+
"""Load conversation from file"""
|
|
376
|
+
try:
|
|
377
|
+
with open(file_path, 'r') as f:
|
|
378
|
+
lines = f.readlines()
|
|
379
|
+
for line in lines:
|
|
380
|
+
self.message_list.append(SystemMessage(content=line))
|
|
381
|
+
print(f"Conversation loaded from file: {file_path}")
|
|
382
|
+
return self.message_list
|
|
383
|
+
except Exception as e:
|
|
384
|
+
raise ValueError(f"Error loading conversation from file: {e}")
|
|
385
|
+
|
|
386
|
+
class IPythonStreamHandler(StreamingStdOutCallbackHandler):
|
|
387
|
+
"""Handler for IPython display"""
|
|
388
|
+
|
|
389
|
+
def __init__(self):
|
|
390
|
+
self.output = ""
|
|
391
|
+
|
|
392
|
+
def on_llm_new_token(self, token: str, **kwargs) -> None:
|
|
393
|
+
"""Handle new token"""
|
|
394
|
+
self.output += token
|
|
395
|
+
display(HTML(self.output), clear=True)
|