llm-dialog-manager 0.3.2__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_dialog_manager/__init__.py +1 -1
- llm_dialog_manager/agent.py +374 -193
- llm_dialog_manager/chat_history.py +110 -23
- {llm_dialog_manager-0.3.2.dist-info → llm_dialog_manager-0.3.5.dist-info}/METADATA +2 -2
- llm_dialog_manager-0.3.5.dist-info/RECORD +9 -0
- {llm_dialog_manager-0.3.2.dist-info → llm_dialog_manager-0.3.5.dist-info}/WHEEL +1 -1
- llm_dialog_manager-0.3.2.dist-info/RECORD +0 -9
- {llm_dialog_manager-0.3.2.dist-info → llm_dialog_manager-0.3.5.dist-info}/LICENSE +0 -0
- {llm_dialog_manager-0.3.2.dist-info → llm_dialog_manager-0.3.5.dist-info}/top_level.txt +0 -0
llm_dialog_manager/__init__.py
CHANGED
llm_dialog_manager/agent.py
CHANGED
@@ -2,13 +2,15 @@
|
|
2
2
|
import json
|
3
3
|
import os
|
4
4
|
import uuid
|
5
|
-
from typing import List, Dict, Optional
|
5
|
+
from typing import List, Dict, Optional, Union
|
6
6
|
import logging
|
7
7
|
from pathlib import Path
|
8
8
|
import random
|
9
9
|
import requests
|
10
10
|
import zipfile
|
11
11
|
import io
|
12
|
+
import base64
|
13
|
+
from PIL import Image
|
12
14
|
|
13
15
|
# Third-party imports
|
14
16
|
import anthropic
|
@@ -18,8 +20,8 @@ import openai
|
|
18
20
|
from dotenv import load_dotenv
|
19
21
|
|
20
22
|
# Local imports
|
21
|
-
from
|
22
|
-
from
|
23
|
+
from .chat_history import ChatHistory
|
24
|
+
from .key_manager import key_manager
|
23
25
|
|
24
26
|
# Set up logging
|
25
27
|
logging.basicConfig(level=logging.INFO)
|
@@ -28,7 +30,7 @@ logger = logging.getLogger(__name__)
|
|
28
30
|
# Load environment variables
|
29
31
|
def load_env_vars():
|
30
32
|
"""Load environment variables from .env file"""
|
31
|
-
env_path = Path(__file__).parent
|
33
|
+
env_path = Path(__file__).parent / '.env'
|
32
34
|
if env_path.exists():
|
33
35
|
load_dotenv(env_path)
|
34
36
|
else:
|
@@ -36,24 +38,9 @@ def load_env_vars():
|
|
36
38
|
|
37
39
|
load_env_vars()
|
38
40
|
|
39
|
-
def
|
40
|
-
"""Function to send a message to the Anthropic API and handle the response."""
|
41
|
-
try:
|
42
|
-
response = client.messages.create(
|
43
|
-
model=model,
|
44
|
-
max_tokens=max_tokens,
|
45
|
-
temperature=temperature,
|
46
|
-
messages=messages,
|
47
|
-
system=system_msg
|
48
|
-
)
|
49
|
-
return response
|
50
|
-
except Exception as e:
|
51
|
-
logger.error(f"Error sending message: {str(e)}")
|
52
|
-
raise
|
53
|
-
|
54
|
-
def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 1000,
|
41
|
+
def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]], max_tokens: int = 1000,
|
55
42
|
temperature: float = 0.5, api_key: Optional[str] = None,
|
56
|
-
base_url: Optional[str] = None) -> str:
|
43
|
+
base_url: Optional[str] = None, json_format: bool = False) -> str:
|
57
44
|
"""
|
58
45
|
Generate a completion using the specified model and messages.
|
59
46
|
"""
|
@@ -70,202 +57,393 @@ def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 100
|
|
70
57
|
|
71
58
|
# Get API key and base URL from key manager if not provided
|
72
59
|
if not api_key:
|
73
|
-
api_key, base_url = key_manager.get_config(service)
|
60
|
+
# api_key, base_url = key_manager.get_config(service)
|
61
|
+
# Placeholder for key_manager
|
62
|
+
api_key = os.getenv(f"{service.upper()}_API_KEY")
|
63
|
+
base_url = os.getenv(f"{service.upper()}_BASE_URL")
|
74
64
|
|
75
|
-
|
65
|
+
def format_messages_for_api(model, messages):
|
66
|
+
"""Convert ChatHistory messages to the format required by the specific API."""
|
76
67
|
if "claude" in model:
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
)
|
68
|
+
formatted = []
|
69
|
+
system_msg = ""
|
70
|
+
if messages and messages[0]["role"] == "system":
|
71
|
+
system_msg = messages.pop(0)["content"]
|
72
|
+
for msg in messages:
|
73
|
+
content = msg["content"]
|
74
|
+
if isinstance(content, str):
|
75
|
+
formatted.append({"role": msg["role"], "content": content})
|
76
|
+
elif isinstance(content, list):
|
77
|
+
# Combine content blocks into a single message
|
78
|
+
combined_content = []
|
79
|
+
for block in content:
|
80
|
+
if isinstance(block, str):
|
81
|
+
combined_content.append({"type": "text", "text": block})
|
82
|
+
elif isinstance(block, Image.Image):
|
83
|
+
# For Claude, convert PIL.Image to base64
|
84
|
+
buffered = io.BytesIO()
|
85
|
+
block.save(buffered, format="PNG")
|
86
|
+
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
87
|
+
combined_content.append({
|
88
|
+
"type": "image",
|
89
|
+
"source": {
|
90
|
+
"type": "base64",
|
91
|
+
"media_type": "image/png",
|
92
|
+
"data": image_base64
|
93
|
+
}
|
94
|
+
})
|
95
|
+
elif isinstance(block, dict):
|
96
|
+
if block.get("type") == "image_url":
|
97
|
+
combined_content.append({
|
98
|
+
"type": "image",
|
99
|
+
"source": {
|
100
|
+
"type": "url",
|
101
|
+
"url": block["image_url"]["url"]
|
102
|
+
}
|
103
|
+
})
|
104
|
+
elif block.get("type") == "image_base64":
|
105
|
+
combined_content.append({
|
106
|
+
"type": "image",
|
107
|
+
"source": {
|
108
|
+
"type": "base64",
|
109
|
+
"media_type": block["image_base64"]["media_type"],
|
110
|
+
"data": block["image_base64"]["data"]
|
111
|
+
}
|
112
|
+
})
|
113
|
+
formatted.append({"role": msg["role"], "content": combined_content})
|
114
|
+
return system_msg, formatted
|
115
|
+
|
116
|
+
elif "gemini" in model or "gpt" in model or "grok" in model:
|
117
|
+
formatted = []
|
118
|
+
for msg in messages:
|
119
|
+
content = msg["content"]
|
120
|
+
if isinstance(content, str):
|
121
|
+
formatted.append({"role": msg["role"], "parts": [content]})
|
122
|
+
elif isinstance(content, list):
|
123
|
+
parts = []
|
124
|
+
for block in content:
|
125
|
+
if isinstance(block, str):
|
126
|
+
parts.append(block)
|
127
|
+
elif isinstance(block, Image.Image):
|
128
|
+
parts.append(block)
|
129
|
+
elif isinstance(block, dict):
|
130
|
+
if block.get("type") == "image_url":
|
131
|
+
parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}})
|
132
|
+
elif block.get("type") == "image_base64":
|
133
|
+
parts.append({"type": "image_base64", "image_base64": {"data": block["image_base64"]["data"], "media_type": block["image_base64"]["media_type"]}})
|
134
|
+
formatted.append({"role": msg["role"], "parts": parts})
|
135
|
+
return None, formatted
|
136
|
+
|
137
|
+
else: # OpenAI models
|
138
|
+
formatted = []
|
139
|
+
for msg in messages:
|
140
|
+
content = msg["content"]
|
141
|
+
if isinstance(content, str):
|
142
|
+
formatted.append({"role": msg["role"], "content": content})
|
143
|
+
elif isinstance(content, list):
|
144
|
+
# OpenAI expects 'content' as string; images are not directly supported
|
145
|
+
# You can convert images to URLs or descriptions if needed
|
146
|
+
combined_content = ""
|
147
|
+
for block in content:
|
148
|
+
if isinstance(block, str):
|
149
|
+
combined_content += block + "\n"
|
150
|
+
elif isinstance(block, Image.Image):
|
151
|
+
# Convert PIL.Image to base64 or upload and use URL
|
152
|
+
buffered = io.BytesIO()
|
153
|
+
block.save(buffered, format="PNG")
|
154
|
+
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
155
|
+
combined_content += f"[Image Base64: {image_base64[:30]}...]\n"
|
156
|
+
elif isinstance(block, dict):
|
157
|
+
if block.get("type") == "image_url":
|
158
|
+
combined_content += f"[Image: {block['image_url']['url']}]\n"
|
159
|
+
elif block.get("type") == "image_base64":
|
160
|
+
combined_content += f"[Image Base64: {block['image_base64']['data'][:30]}...]\n"
|
161
|
+
formatted.append({"role": msg["role"], "content": combined_content.strip()})
|
162
|
+
return None, formatted
|
163
|
+
|
164
|
+
system_msg, formatted_messages = format_messages_for_api(model, messages.copy())
|
165
|
+
|
166
|
+
if "claude" in model:
|
167
|
+
# Check for Vertex configuration
|
168
|
+
vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
|
169
|
+
vertex_region = os.getenv('VERTEX_REGION')
|
170
|
+
|
171
|
+
if vertex_project_id and vertex_region:
|
172
|
+
client = AnthropicVertex(
|
173
|
+
region=vertex_region,
|
174
|
+
project_id=vertex_project_id
|
175
|
+
)
|
176
|
+
else:
|
177
|
+
client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
|
178
|
+
|
179
|
+
response = client.messages.create(
|
180
|
+
model=model,
|
181
|
+
max_tokens=max_tokens,
|
182
|
+
temperature=temperature,
|
183
|
+
messages=formatted_messages,
|
184
|
+
system=system_msg
|
185
|
+
)
|
186
|
+
|
187
|
+
while response.stop_reason == "max_tokens":
|
188
|
+
if formatted_messages[-1]['role'] == "user":
|
189
|
+
formatted_messages.append({"role": "assistant", "content": response.completion})
|
86
190
|
else:
|
87
|
-
|
191
|
+
formatted_messages[-1]['content'] += response.completion
|
88
192
|
|
89
|
-
system_msg = messages.pop(0)["content"] if messages and messages[0]["role"] == "system" else ""
|
90
193
|
response = client.messages.create(
|
91
194
|
model=model,
|
92
195
|
max_tokens=max_tokens,
|
93
196
|
temperature=temperature,
|
94
|
-
messages=
|
197
|
+
messages=formatted_messages,
|
95
198
|
system=system_msg
|
96
199
|
)
|
97
|
-
|
98
|
-
while response.stop_reason == "max_tokens":
|
99
|
-
if messages[-1]['role'] == "user":
|
100
|
-
messages.append({"role": "assistant", "content": response.content[0].text})
|
101
|
-
else:
|
102
|
-
messages[-1]['content'] += response.content[0].text
|
103
200
|
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
temperature=temperature,
|
108
|
-
messages=messages,
|
109
|
-
system=system_msg
|
110
|
-
)
|
201
|
+
if formatted_messages[-1]['role'] == "assistant" and response.stop_reason == "end_turn":
|
202
|
+
formatted_messages[-1]['content'] += response.completion
|
203
|
+
return formatted_messages[-1]['content']
|
111
204
|
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
205
|
+
return response.completion
|
206
|
+
|
207
|
+
elif "gemini" in model:
|
208
|
+
try:
|
209
|
+
# First try OpenAI-style API
|
210
|
+
client = openai.OpenAI(
|
211
|
+
api_key=api_key,
|
212
|
+
base_url="https://generativelanguage.googleapis.com/v1beta/"
|
213
|
+
)
|
214
|
+
# Set response_format based on json_format
|
215
|
+
response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}
|
216
|
+
|
217
|
+
response = client.chat.completions.create(
|
218
|
+
model=model,
|
219
|
+
messages=formatted_messages,
|
220
|
+
temperature=temperature,
|
221
|
+
response_format=response_format # Added response_format
|
222
|
+
)
|
223
|
+
return response.choices[0].message.content
|
224
|
+
|
225
|
+
except Exception as e:
|
226
|
+
# If OpenAI-style API fails, fall back to Google's genai library
|
227
|
+
logger.info("Falling back to Google's genai library")
|
228
|
+
genai.configure(api_key=api_key)
|
229
|
+
|
230
|
+
# Convert messages to Gemini format
|
231
|
+
gemini_messages = []
|
232
|
+
for msg in messages:
|
233
|
+
if msg["role"] == "system":
|
128
234
|
# Prepend system message to first user message if exists
|
129
|
-
if
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
genai.
|
144
|
-
|
145
|
-
# Convert messages to Gemini format
|
146
|
-
gemini_messages = []
|
147
|
-
for msg in messages:
|
148
|
-
if msg["role"] == "system":
|
149
|
-
# Prepend system message to first user message if exists
|
150
|
-
if gemini_messages:
|
151
|
-
gemini_messages[0].parts[0].text = f"{msg['content']}\n\n{gemini_messages[0].parts[0].text}"
|
152
|
-
else:
|
153
|
-
gemini_messages.append({"role": msg["role"], "parts": [{"text": msg["content"]}]})
|
154
|
-
|
155
|
-
# Create Gemini model and generate response
|
156
|
-
model = genai.GenerativeModel(model_name=model)
|
157
|
-
response = model.generate_content(
|
158
|
-
gemini_messages,
|
159
|
-
generation_config=genai.types.GenerationConfig(
|
160
|
-
temperature=temperature,
|
161
|
-
max_output_tokens=max_tokens
|
162
|
-
)
|
163
|
-
)
|
164
|
-
|
165
|
-
return response.text
|
166
|
-
|
167
|
-
elif "grok" in model:
|
168
|
-
# Randomly choose between OpenAI and Anthropic SDK
|
169
|
-
use_anthropic = random.choice([True, False])
|
170
|
-
|
171
|
-
if use_anthropic:
|
172
|
-
print("using anthropic")
|
173
|
-
client = anthropic.Anthropic(
|
174
|
-
api_key=api_key,
|
175
|
-
base_url="https://api.x.ai"
|
176
|
-
)
|
177
|
-
|
178
|
-
system_msg = messages.pop(0)["content"] if messages and messages[0]["role"] == "system" else ""
|
179
|
-
response = client.messages.create(
|
180
|
-
model=model,
|
181
|
-
max_tokens=max_tokens,
|
235
|
+
if gemini_messages:
|
236
|
+
first_msg = gemini_messages[0]
|
237
|
+
if "parts" in first_msg and len(first_msg["parts"]) > 0:
|
238
|
+
first_msg["parts"][0] = f"{msg['content']}\n\n{first_msg['parts'][0]}"
|
239
|
+
else:
|
240
|
+
gemini_messages.append({"role": msg["role"], "parts": msg["content"]})
|
241
|
+
|
242
|
+
# Set response_mime_type based on json_format
|
243
|
+
mime_type = "application/json" if json_format else "text/plain"
|
244
|
+
|
245
|
+
# Create Gemini model and generate response
|
246
|
+
model_instance = genai.GenerativeModel(model_name=model)
|
247
|
+
response = model_instance.generate_content(
|
248
|
+
gemini_messages,
|
249
|
+
generation_config=genai.types.GenerationConfig(
|
182
250
|
temperature=temperature,
|
183
|
-
|
184
|
-
|
185
|
-
)
|
186
|
-
return response.content[0].text
|
187
|
-
else:
|
188
|
-
print("using openai")
|
189
|
-
client = openai.OpenAI(
|
190
|
-
api_key=api_key,
|
191
|
-
base_url="https://api.x.ai/v1"
|
192
|
-
)
|
193
|
-
response = client.chat.completions.create(
|
194
|
-
model=model,
|
195
|
-
messages=messages,
|
196
|
-
max_tokens=max_tokens,
|
197
|
-
temperature=temperature
|
251
|
+
response_mime_type=mime_type, # Modified based on json_format
|
252
|
+
max_output_tokens=max_tokens
|
198
253
|
)
|
199
|
-
|
254
|
+
)
|
255
|
+
|
256
|
+
return response.text
|
257
|
+
|
258
|
+
elif "grok" in model:
|
259
|
+
# Randomly choose between OpenAI and Anthropic SDK
|
260
|
+
use_anthropic = random.choice([True, False])
|
261
|
+
|
262
|
+
if use_anthropic:
|
263
|
+
logger.info("Using Anthropic for Grok model")
|
264
|
+
client = anthropic.Anthropic(
|
265
|
+
api_key=api_key,
|
266
|
+
base_url="https://api.x.ai"
|
267
|
+
)
|
268
|
+
|
269
|
+
system_msg = ""
|
270
|
+
if messages and messages[0]["role"] == "system":
|
271
|
+
system_msg = messages.pop(0)["content"]
|
272
|
+
|
273
|
+
response = client.messages.create(
|
274
|
+
model=model,
|
275
|
+
max_tokens=max_tokens,
|
276
|
+
temperature=temperature,
|
277
|
+
messages=formatted_messages,
|
278
|
+
system=system_msg
|
279
|
+
)
|
280
|
+
return response.completion
|
281
|
+
else:
|
282
|
+
logger.info("Using OpenAI for Grok model")
|
283
|
+
client = openai.OpenAI(
|
284
|
+
api_key=api_key,
|
285
|
+
base_url="https://api.x.ai/v1"
|
286
|
+
)
|
287
|
+
# Set response_format based on json_format
|
288
|
+
response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}
|
200
289
|
|
201
|
-
else: # OpenAI models
|
202
|
-
client = openai.OpenAI(api_key=api_key, base_url=base_url)
|
203
290
|
response = client.chat.completions.create(
|
204
291
|
model=model,
|
205
|
-
messages=
|
292
|
+
messages=formatted_messages,
|
206
293
|
max_tokens=max_tokens,
|
207
294
|
temperature=temperature,
|
295
|
+
response_format=response_format # Added response_format
|
208
296
|
)
|
209
297
|
return response.choices[0].message.content
|
210
298
|
|
211
|
-
|
212
|
-
|
213
|
-
|
299
|
+
else: # OpenAI models
|
300
|
+
client = openai.OpenAI(api_key=api_key, base_url=base_url)
|
301
|
+
# Set response_format based on json_format
|
302
|
+
response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}
|
214
303
|
|
215
|
-
|
304
|
+
response = client.chat.completions.create(
|
305
|
+
model=model,
|
306
|
+
messages=formatted_messages,
|
307
|
+
max_tokens=max_tokens,
|
308
|
+
temperature=temperature,
|
309
|
+
response_format=response_format # Added response_format
|
310
|
+
)
|
311
|
+
return response.choices[0].message.content
|
216
312
|
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
313
|
+
# Release the API key after successful use
|
314
|
+
if not api_key:
|
315
|
+
# key_manager.release_config(service, api_key)
|
316
|
+
pass
|
317
|
+
|
318
|
+
return response
|
222
319
|
|
223
320
|
except Exception as e:
|
224
321
|
logger.error(f"Error in completion: {str(e)}")
|
225
322
|
raise
|
226
323
|
|
227
324
|
class Agent:
|
228
|
-
def __init__(self, model_name: str, messages: Optional[str] = None,
|
325
|
+
def __init__(self, model_name: str, messages: Optional[Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]]] = None,
|
229
326
|
memory_enabled: bool = False, api_key: Optional[str] = None) -> None:
|
230
327
|
"""Initialize an Agent instance."""
|
231
|
-
# valid_models = ['gpt-3.5-turbo', 'gpt-4', 'claude-2.1', 'gemini-1.5-pro', 'gemini-1.5-flash', 'grok-beta', 'claude-3-5-sonnet-20241022']
|
232
|
-
# if model_name not in valid_models:
|
233
|
-
# raise ValueError(f"Model {model_name} not supported. Supported models: {valid_models}")
|
234
|
-
|
235
328
|
self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
|
236
329
|
self.model_name = model_name
|
237
|
-
self.history = ChatHistory(messages)
|
330
|
+
self.history = ChatHistory(messages) if messages else ChatHistory()
|
238
331
|
self.memory_enabled = memory_enabled
|
239
332
|
self.api_key = api_key
|
240
333
|
self.repo_content = []
|
241
334
|
|
242
|
-
def add_message(self, role, content):
|
243
|
-
|
244
|
-
while self.repo_content:
|
245
|
-
repo = self.repo_content.pop()
|
246
|
-
repo_content += f"<repo>\n{repo}\n</repo>\n"
|
247
|
-
|
248
|
-
content = repo_content + content
|
335
|
+
def add_message(self, role: str, content: Union[str, List[Union[str, Image.Image, Dict]]]):
|
336
|
+
"""Add a message to the conversation."""
|
249
337
|
self.history.add_message(content, role)
|
250
338
|
|
251
|
-
def
|
339
|
+
def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
|
340
|
+
"""Add a user message."""
|
341
|
+
self.history.add_user_message(content)
|
342
|
+
|
343
|
+
def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
|
344
|
+
"""Add an assistant message."""
|
345
|
+
self.history.add_assistant_message(content)
|
346
|
+
|
347
|
+
def add_image(self, image_path: Optional[str] = None, image_url: Optional[str] = None, media_type: Optional[str] = "image/jpeg"):
|
348
|
+
"""
|
349
|
+
Add an image to the conversation.
|
350
|
+
Either image_path or image_url must be provided.
|
351
|
+
"""
|
352
|
+
if not image_path and not image_url:
|
353
|
+
raise ValueError("Either image_path or image_url must be provided.")
|
354
|
+
|
355
|
+
if image_path:
|
356
|
+
if not os.path.exists(image_path):
|
357
|
+
raise FileNotFoundError(f"Image file {image_path} does not exist.")
|
358
|
+
if "gemini" in self.model_name:
|
359
|
+
# For Gemini, load as PIL.Image
|
360
|
+
image_pil = Image.open(image_path)
|
361
|
+
image_block = image_pil
|
362
|
+
else:
|
363
|
+
# For Claude and others, use base64 encoding
|
364
|
+
with open(image_path, "rb") as img_file:
|
365
|
+
image_data = base64.standard_b64encode(img_file.read()).decode("utf-8")
|
366
|
+
image_block = {
|
367
|
+
"type": "image_base64",
|
368
|
+
"image_base64": {
|
369
|
+
"media_type": media_type,
|
370
|
+
"data": image_data
|
371
|
+
}
|
372
|
+
}
|
373
|
+
else:
|
374
|
+
# If image_url is provided
|
375
|
+
if "gemini" in self.model_name:
|
376
|
+
# For Gemini, you can pass image URLs directly
|
377
|
+
image_block = {"type": "image_url", "image_url": {"url": image_url}}
|
378
|
+
else:
|
379
|
+
# For Claude and others, use image URLs
|
380
|
+
image_block = {
|
381
|
+
"type": "image_url",
|
382
|
+
"image_url": {
|
383
|
+
"url": image_url
|
384
|
+
}
|
385
|
+
}
|
386
|
+
|
387
|
+
# Add the image block to the last user message or as a new user message
|
388
|
+
if self.history.last_role == "user":
|
389
|
+
current_content = self.history.messages[-1]["content"]
|
390
|
+
if isinstance(current_content, list):
|
391
|
+
current_content.append(image_block)
|
392
|
+
else:
|
393
|
+
self.history.messages[-1]["content"] = [current_content, image_block]
|
394
|
+
else:
|
395
|
+
# Start a new user message with the image
|
396
|
+
self.history.add_message([image_block], "user")
|
397
|
+
|
398
|
+
def generate_response(self, max_tokens=3585, temperature=0.7, json_format: bool = False) -> str:
|
399
|
+
"""Generate a response from the agent.
|
400
|
+
|
401
|
+
Args:
|
402
|
+
max_tokens (int, optional): Maximum number of tokens. Defaults to 3585.
|
403
|
+
temperature (float, optional): Sampling temperature. Defaults to 0.7.
|
404
|
+
json_format (bool, optional): Whether to enable JSON output format. Defaults to False.
|
405
|
+
|
406
|
+
Returns:
|
407
|
+
str: The generated response.
|
408
|
+
"""
|
252
409
|
if not self.history.messages:
|
253
410
|
raise ValueError("No messages in history to generate response from")
|
254
411
|
|
255
|
-
messages =
|
256
|
-
|
412
|
+
messages = self.history.messages
|
413
|
+
|
257
414
|
response_text = completion(
|
258
415
|
model=self.model_name,
|
259
416
|
messages=messages,
|
260
417
|
max_tokens=max_tokens,
|
261
418
|
temperature=temperature,
|
262
|
-
api_key=self.api_key
|
419
|
+
api_key=self.api_key,
|
420
|
+
json_format=json_format # Pass json_format to completion
|
263
421
|
)
|
264
|
-
if
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
self.
|
422
|
+
if self.model_name.startswith("openai"):
|
423
|
+
# OpenAI does not support images, so responses are simple strings
|
424
|
+
if self.history.messages[-1]["role"] == "assistant":
|
425
|
+
self.history.messages[-1]["content"] = response_text
|
426
|
+
elif self.memory_enabled:
|
427
|
+
self.add_message("assistant", response_text)
|
428
|
+
elif "claude" in self.model_name:
|
429
|
+
if self.history.messages[-1]["role"] == "assistant":
|
430
|
+
self.history.messages[-1]["content"] = response_text
|
431
|
+
elif self.memory_enabled:
|
432
|
+
self.add_message("assistant", response_text)
|
433
|
+
elif "gemini" in self.model_name or "grok" in self.model_name:
|
434
|
+
if self.history.messages[-1]["role"] == "assistant":
|
435
|
+
if isinstance(self.history.messages[-1]["content"], list):
|
436
|
+
self.history.messages[-1]["content"].append(response_text)
|
437
|
+
else:
|
438
|
+
self.history.messages[-1]["content"] = [self.history.messages[-1]["content"], response_text]
|
439
|
+
elif self.memory_enabled:
|
440
|
+
self.add_message("assistant", response_text)
|
441
|
+
else:
|
442
|
+
# Handle other models similarly
|
443
|
+
if self.history.messages[-1]["role"] == "assistant":
|
444
|
+
self.history.messages[-1]["content"] = response_text
|
445
|
+
elif self.memory_enabled:
|
446
|
+
self.add_message("assistant", response_text)
|
269
447
|
|
270
448
|
return response_text
|
271
449
|
|
@@ -274,20 +452,20 @@ class Agent:
|
|
274
452
|
with open(filename, 'w', encoding='utf-8') as file:
|
275
453
|
json.dump(self.history.messages, file, ensure_ascii=False, indent=4)
|
276
454
|
|
277
|
-
def load_conversation(self, filename=None):
|
455
|
+
def load_conversation(self, filename: Optional[str] = None):
|
278
456
|
if filename is None:
|
279
457
|
filename = f"{self.id}.json"
|
280
458
|
with open(filename, 'r', encoding='utf-8') as file:
|
281
459
|
messages = json.load(file)
|
460
|
+
# Handle deserialization of images if necessary
|
282
461
|
self.history = ChatHistory(messages)
|
283
462
|
|
284
|
-
def
|
285
|
-
|
286
|
-
if repo_name:
|
463
|
+
def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = None, repo_name: Optional[str] = None, commit_hash: Optional[str] = None):
|
464
|
+
if username and repo_name:
|
287
465
|
if commit_hash:
|
288
|
-
repo_url = f"https://github.com/{repo_name}/archive/{commit_hash}.zip"
|
466
|
+
repo_url = f"https://github.com/{username}/{repo_name}/archive/{commit_hash}.zip"
|
289
467
|
else:
|
290
|
-
repo_url = f"https://github.com/{repo_name}/archive/refs/heads/main.zip"
|
468
|
+
repo_url = f"https://github.com/{username}/{repo_name}/archive/refs/heads/main.zip"
|
291
469
|
|
292
470
|
if not repo_url:
|
293
471
|
raise ValueError("Either repo_url or both username and repo_name must be provided")
|
@@ -306,28 +484,31 @@ class Agent:
|
|
306
484
|
raise ValueError(f"Failed to download repository from {repo_url}")
|
307
485
|
|
308
486
|
if __name__ == "__main__":
|
309
|
-
|
310
|
-
#
|
311
|
-
|
312
|
-
|
313
|
-
# from agent.messageloader import information_detector_messages
|
487
|
+
# Example Usage
|
488
|
+
# Create an Agent instance (Gemini model)
|
489
|
+
agent = Agent("gemini-1.5-flash", "you are an assistant", memory_enabled=True)
|
314
490
|
|
315
|
-
#
|
316
|
-
|
317
|
-
# information_detector_agent.add_message("user", text)
|
318
|
-
# response = information_detector_agent.generate_response()
|
319
|
-
# print(response)
|
320
|
-
agent = Agent("gemini-1.5-pro-002", "you are an assistant", memory_enabled=True)
|
491
|
+
# Add an image
|
492
|
+
agent.add_image(image_path="/Users/junfan/Projects/Personal/oneapi/dialog_manager/example.png")
|
321
493
|
|
322
|
-
#
|
323
|
-
|
494
|
+
# Add a user message
|
495
|
+
agent.add_message("user", "What's in this image?")
|
324
496
|
|
325
|
-
#
|
326
|
-
|
327
|
-
|
497
|
+
# Generate response with JSON format enabled
|
498
|
+
try:
|
499
|
+
response = agent.generate_response(json_format=True) # json_format set to True
|
500
|
+
print("Response:", response)
|
501
|
+
except Exception as e:
|
502
|
+
logger.error(f"Failed to generate response: {e}")
|
328
503
|
|
329
|
-
|
330
|
-
print(
|
504
|
+
# Print the entire conversation history
|
505
|
+
print("Conversation History:")
|
506
|
+
print(agent.history)
|
507
|
+
|
508
|
+
# Pop the last message
|
331
509
|
last_message = agent.history.pop()
|
332
|
-
print(last_message)
|
333
|
-
|
510
|
+
print("Last Message:", last_message)
|
511
|
+
|
512
|
+
# Generate another response without JSON format
|
513
|
+
response = agent.generate_response()
|
514
|
+
print("Response:", response)
|
@@ -1,53 +1,122 @@
|
|
1
1
|
from typing import List, Dict, Optional, Union
|
2
|
+
from PIL import Image
|
2
3
|
|
3
4
|
class ChatHistory:
|
4
|
-
def __init__(self, input_data: Union[str, List[Dict[str, str]]] = "") -> None:
|
5
|
-
self.messages: List[Dict[str, str]] = []
|
5
|
+
def __init__(self, input_data: Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]] = "") -> None:
|
6
|
+
self.messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]] = []
|
6
7
|
if isinstance(input_data, str) and input_data:
|
7
8
|
self.add_message(input_data, "system")
|
8
9
|
elif isinstance(input_data, list):
|
9
10
|
self.load_messages(input_data)
|
10
|
-
self.last_role: str = "system" if not self.messages else self.
|
11
|
+
self.last_role: str = "system" if not self.messages else self.get_last_role()
|
11
12
|
|
12
|
-
def load_messages(self, messages: List[Dict[str, str]]) -> None:
|
13
|
+
def load_messages(self, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]) -> None:
|
13
14
|
for message in messages:
|
14
15
|
if not ("role" in message and "content" in message):
|
15
16
|
raise ValueError("Each message must have a 'role' and 'content'.")
|
16
17
|
if message["role"] not in ["user", "assistant", "system"]:
|
17
18
|
raise ValueError(f"Invalid role: {message['role']}")
|
18
19
|
self.messages.append(message)
|
19
|
-
|
20
|
+
self.last_role = self.get_last_role()
|
21
|
+
|
22
|
+
def get_last_role(self):
|
23
|
+
return self.messages[-1]["role"] if self.messages else "system"
|
24
|
+
|
20
25
|
def pop(self):
|
21
26
|
if not self.messages:
|
22
27
|
return None
|
23
|
-
|
28
|
+
|
24
29
|
popped_message = self.messages.pop()
|
25
|
-
|
30
|
+
|
26
31
|
if self.messages:
|
27
|
-
self.last_role = self.
|
32
|
+
self.last_role = self.get_last_role()
|
28
33
|
else:
|
29
34
|
self.last_role = "system"
|
30
|
-
|
35
|
+
|
31
36
|
return popped_message["content"]
|
32
37
|
|
33
38
|
def __len__(self):
|
34
39
|
return len(self.messages)
|
35
40
|
|
36
41
|
def __str__(self):
|
37
|
-
|
42
|
+
formatted_messages = []
|
43
|
+
for i, msg in enumerate(self.messages):
|
44
|
+
role = msg['role']
|
45
|
+
content = msg['content']
|
46
|
+
if isinstance(content, str):
|
47
|
+
formatted_content = content
|
48
|
+
elif isinstance(content, list):
|
49
|
+
parts = []
|
50
|
+
for block in content:
|
51
|
+
if isinstance(block, str):
|
52
|
+
parts.append(block)
|
53
|
+
elif isinstance(block, Image.Image):
|
54
|
+
parts.append(f"[Image Object: {block.filename}]")
|
55
|
+
elif isinstance(block, dict):
|
56
|
+
if block.get("type") == "image_url":
|
57
|
+
parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
|
58
|
+
elif block.get("type") == "image_base64":
|
59
|
+
parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
|
60
|
+
formatted_content = "\n".join(parts)
|
61
|
+
else:
|
62
|
+
formatted_content = str(content)
|
63
|
+
formatted_messages.append(f"Message {i} ({role}): {formatted_content}")
|
64
|
+
return '\n'.join(formatted_messages)
|
38
65
|
|
39
66
|
def __getitem__(self, key):
|
40
67
|
if isinstance(key, slice):
|
41
|
-
|
42
|
-
|
43
|
-
|
68
|
+
sliced_messages = self.messages[key]
|
69
|
+
formatted = []
|
70
|
+
for msg in sliced_messages:
|
71
|
+
role = msg['role']
|
72
|
+
content = msg['content']
|
73
|
+
if isinstance(content, str):
|
74
|
+
formatted_content = content
|
75
|
+
elif isinstance(content, list):
|
76
|
+
parts = []
|
77
|
+
for block in content:
|
78
|
+
if isinstance(block, str):
|
79
|
+
parts.append(block)
|
80
|
+
elif isinstance(block, Image.Image):
|
81
|
+
parts.append(f"[Image Object: {block.filename}]")
|
82
|
+
elif isinstance(block, dict):
|
83
|
+
if block.get("type") == "image_url":
|
84
|
+
parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
|
85
|
+
elif block.get("type") == "image_base64":
|
86
|
+
parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
|
87
|
+
formatted_content = "\n".join(parts)
|
88
|
+
else:
|
89
|
+
formatted_content = str(content)
|
90
|
+
formatted.append(f"({role}): {formatted_content}")
|
91
|
+
print('\n'.join(formatted))
|
92
|
+
return sliced_messages
|
44
93
|
elif isinstance(key, int):
|
45
94
|
# Adjust for negative indices
|
46
95
|
if key < 0:
|
47
|
-
key += len(self.messages)
|
96
|
+
key += len(self.messages)
|
48
97
|
if 0 <= key < len(self.messages):
|
98
|
+
msg = self.messages[key]
|
99
|
+
role = msg['role']
|
100
|
+
content = msg['content']
|
101
|
+
if isinstance(content, str):
|
102
|
+
formatted_content = content
|
103
|
+
elif isinstance(content, list):
|
104
|
+
parts = []
|
105
|
+
for block in content:
|
106
|
+
if isinstance(block, str):
|
107
|
+
parts.append(block)
|
108
|
+
elif isinstance(block, Image.Image):
|
109
|
+
parts.append(f"[Image Object: {block.filename}]")
|
110
|
+
elif isinstance(block, dict):
|
111
|
+
if block.get("type") == "image_url":
|
112
|
+
parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
|
113
|
+
elif block.get("type") == "image_base64":
|
114
|
+
parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
|
115
|
+
formatted_content = "\n".join(parts)
|
116
|
+
else:
|
117
|
+
formatted_content = str(content)
|
49
118
|
snippet = self.get_conversation_snippet(key)
|
50
|
-
print('\n'.join([f"({v['role']}): {
|
119
|
+
print('\n'.join([f"({v['role']}): {v['content']}" for k, v in snippet.items() if v]))
|
51
120
|
return self.messages[key]
|
52
121
|
else:
|
53
122
|
raise IndexError("Message index out of range.")
|
@@ -55,8 +124,8 @@ class ChatHistory:
|
|
55
124
|
raise TypeError("Invalid argument type.")
|
56
125
|
|
57
126
|
def __setitem__(self, index, value):
|
58
|
-
if not isinstance(value, str):
|
59
|
-
raise ValueError("Message content must be a string.")
|
127
|
+
if not isinstance(value, (str, list)):
|
128
|
+
raise ValueError("Message content must be a string or a list of content blocks.")
|
60
129
|
role = "system" if index % 2 == 0 else "user"
|
61
130
|
self.messages[index] = {"role": role, "content": value}
|
62
131
|
|
@@ -68,19 +137,27 @@ class ChatHistory:
|
|
68
137
|
self.add_message(message, next_role)
|
69
138
|
|
70
139
|
def __contains__(self, item):
|
71
|
-
|
140
|
+
for message in self.messages:
|
141
|
+
content = message['content']
|
142
|
+
if isinstance(content, str) and item in content:
|
143
|
+
return True
|
144
|
+
elif isinstance(content, list):
|
145
|
+
for block in content:
|
146
|
+
if isinstance(block, str) and item in block:
|
147
|
+
return True
|
148
|
+
return False
|
72
149
|
|
73
|
-
def add_message(self, content, role):
|
150
|
+
def add_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]], role: str):
|
74
151
|
self.messages.append({"role": role, "content": content})
|
75
152
|
self.last_role = role
|
76
153
|
|
77
|
-
def add_user_message(self, content):
|
154
|
+
def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
|
78
155
|
if self.last_role in ["system", "assistant"]:
|
79
156
|
self.add_message(content, "user")
|
80
157
|
else:
|
81
158
|
raise ValueError("A user message must follow a system or assistant message.")
|
82
159
|
|
83
|
-
def add_assistant_message(self, content):
|
160
|
+
def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
|
84
161
|
if self.last_role == "user":
|
85
162
|
self.add_message(content, "assistant")
|
86
163
|
else:
|
@@ -110,7 +187,17 @@ class ChatHistory:
|
|
110
187
|
print(f"Content of the last message: {status['last_message_content']}")
|
111
188
|
|
112
189
|
def search_for_keyword(self, keyword):
|
113
|
-
|
190
|
+
results = []
|
191
|
+
for msg in self.messages:
|
192
|
+
content = msg['content']
|
193
|
+
if isinstance(content, str) and keyword.lower() in content.lower():
|
194
|
+
results.append(msg)
|
195
|
+
elif isinstance(content, list):
|
196
|
+
for block in content:
|
197
|
+
if isinstance(block, str) and keyword.lower() in block.lower():
|
198
|
+
results.append(msg)
|
199
|
+
break
|
200
|
+
return results
|
114
201
|
|
115
202
|
def has_user_or_assistant_spoken_since_last_system(self):
|
116
203
|
for msg in reversed(self.messages):
|
@@ -143,4 +230,4 @@ class ChatHistory:
|
|
143
230
|
@staticmethod
|
144
231
|
def color_text(text, color):
|
145
232
|
colors = {"green": "\033[92m", "red": "\033[91m", "end": "\033[0m"}
|
146
|
-
return f"{colors
|
233
|
+
return f"{colors.get(color, '')}{text}{colors.get('end', '')}"
|
@@ -0,0 +1,9 @@
|
|
1
|
+
llm_dialog_manager/__init__.py,sha256=J7L76hDTCNM56mprhbMclqCG04IacKiIgaHm8Ty7shQ,86
|
2
|
+
llm_dialog_manager/agent.py,sha256=aMeSL7rV7sSGgJzCkXp_ahiq569eTy-9Jfepam4pKUU,23064
|
3
|
+
llm_dialog_manager/chat_history.py,sha256=DKKRnj_M6h-4JncnH6KekMTghX7vMgdN3J9uOwXKzMU,10347
|
4
|
+
llm_dialog_manager/key_manager.py,sha256=shvxmn4zUtQx_p-x1EFyOmnk-WlhigbpKtxTKve-zXk,4421
|
5
|
+
llm_dialog_manager-0.3.5.dist-info/LICENSE,sha256=vWGbYgGuWpWrXL8-xi6pNcX5UzD6pWoIAZmcetyfbus,1064
|
6
|
+
llm_dialog_manager-0.3.5.dist-info/METADATA,sha256=y9rfQ9rcmwrScQJYK-0PjRPCKFWqAKFsm_8NoSwiloI,4152
|
7
|
+
llm_dialog_manager-0.3.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
8
|
+
llm_dialog_manager-0.3.5.dist-info/top_level.txt,sha256=u2EQEXW0NGAt0AAHT7jx1odXZ4rZfjcgbmJhvKFuMkI,19
|
9
|
+
llm_dialog_manager-0.3.5.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
llm_dialog_manager/__init__.py,sha256=u9m9_3UUibXqNSRzkf63lAbkw_l5RmGcxlVOhmeShlg,86
|
2
|
-
llm_dialog_manager/agent.py,sha256=o0ayPp9eS_8Vx78GYMPlIxHrK1UHTbu_I3dEsk4Agdo,13268
|
3
|
-
llm_dialog_manager/chat_history.py,sha256=xKA-oQCv8jv_g8EhXrG9h1S8Icbj2FfqPIhbty5vra4,6033
|
4
|
-
llm_dialog_manager/key_manager.py,sha256=shvxmn4zUtQx_p-x1EFyOmnk-WlhigbpKtxTKve-zXk,4421
|
5
|
-
llm_dialog_manager-0.3.2.dist-info/LICENSE,sha256=vWGbYgGuWpWrXL8-xi6pNcX5UzD6pWoIAZmcetyfbus,1064
|
6
|
-
llm_dialog_manager-0.3.2.dist-info/METADATA,sha256=-DNhRHechHF38K6DGrHCeHVTj2JX3sxw4ZYFQr6Sj5I,4152
|
7
|
-
llm_dialog_manager-0.3.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
8
|
-
llm_dialog_manager-0.3.2.dist-info/top_level.txt,sha256=u2EQEXW0NGAt0AAHT7jx1odXZ4rZfjcgbmJhvKFuMkI,19
|
9
|
-
llm_dialog_manager-0.3.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|