llm-dialog-manager 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- llm_dialog_manager/__init__.py +1 -1
- llm_dialog_manager/agent.py +405 -192
- llm_dialog_manager/chat_history.py +110 -23
- {llm_dialog_manager-0.3.4.dist-info → llm_dialog_manager-0.4.1.dist-info}/METADATA +3 -2
- llm_dialog_manager-0.4.1.dist-info/RECORD +9 -0
- {llm_dialog_manager-0.3.4.dist-info → llm_dialog_manager-0.4.1.dist-info}/WHEEL +1 -1
- llm_dialog_manager-0.3.4.dist-info/RECORD +0 -9
- {llm_dialog_manager-0.3.4.dist-info → llm_dialog_manager-0.4.1.dist-info}/LICENSE +0 -0
- {llm_dialog_manager-0.3.4.dist-info → llm_dialog_manager-0.4.1.dist-info}/top_level.txt +0 -0
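The headline change in 0.4.1 is vision input and JSON-mode output on the `Agent` API. A minimal sketch of the new surface, adapted from the `__main__` example in agent.py below (the model name and image path are placeholders, and it assumes `Agent` is importable from the package root):

    # Sketch of the 0.4.1 Agent API; model name and image path are placeholders.
    from llm_dialog_manager import Agent  # assumes the package root re-exports Agent

    agent = Agent("gemini-1.5-flash", "you are a helpful assistant", memory_enabled=True)
    agent.add_image(image_path="example.png")               # new in 0.4.1
    agent.add_message("user", "What's in this image?")
    response = agent.generate_response(json_format=True)    # new JSON-mode flag
    print(response)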
llm_dialog_manager/__init__.py
CHANGED
llm_dialog_manager/agent.py
CHANGED
@@ -2,13 +2,15 @@
 import json
 import os
 import uuid
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Union
 import logging
 from pathlib import Path
 import random
 import requests
 import zipfile
 import io
+import base64
+from PIL import Image
 
 # Third-party imports
 import anthropic
@@ -18,8 +20,8 @@ import openai
 from dotenv import load_dotenv
 
 # Local imports
-from
-from
+from .chat_history import ChatHistory
+from .key_manager import key_manager
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -28,7 +30,7 @@ logger = logging.getLogger(__name__)
 # Load environment variables
 def load_env_vars():
     """Load environment variables from .env file"""
-    env_path = Path(__file__).parent
+    env_path = Path(__file__).parent / '.env'
     if env_path.exists():
         load_dotenv(env_path)
     else:
@@ -36,24 +38,37 @@ def load_env_vars():
 
 load_env_vars()
 
-def
-    """
-
-
-
-
-
-            messages=messages,
-            system=system_msg
-        )
-        return response
-    except Exception as e:
-        logger.error(f"Error sending message: {str(e)}")
-        raise
+def format_messages_for_gemini(messages):
+    """
+    Convert standardized messages to the Gemini format.
+    System messages are passed via the GenerativeModel system_instruction
+    parameter and are not handled in this function.
+    """
+    gemini_messages = []
 
-
-
-
+    for msg in messages:
+        role = msg["role"]
+        content = msg["content"]
+
+        # Skip system messages; they are set via system_instruction
+        if role == "system":
+            continue
+
+        # Handle user/assistant messages:
+        # if content is a single object, convert it to a list
+        if not isinstance(content, list):
+            content = [content]
+
+        gemini_messages.append({
+            "role": role,
+            "parts": content  # content may contain text and FileMedia
+        })
+
+    return gemini_messages
+
+def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]], max_tokens: int = 1000,
+               temperature: float = 0.5, top_p: float = 1.0, top_k: int = 40, api_key: Optional[str] = None,
+               base_url: Optional[str] = None, json_format: bool = False) -> str:
     """
     Generate a completion using the specified model and messages.
     """
@@ -70,202 +85,396 @@ def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 100
 
         # Get API key and base URL from key manager if not provided
         if not api_key:
-            api_key, base_url = key_manager.get_config(service)
+            # api_key, base_url = key_manager.get_config(service)
+            # Placeholder for key_manager
+            api_key = os.getenv(f"{service.upper()}_API_KEY")
+            base_url = os.getenv(f"{service.upper()}_BASE_URL")
 
-
+        def format_messages_for_api(model, messages):
+            """Convert ChatHistory messages to the format required by the specific API."""
             if "claude" in model:
-
-
-
-
-
-
-
-
-        )
+                formatted = []
+                system_msg = ""
+                if messages and messages[0]["role"] == "system":
+                    system_msg = messages.pop(0)["content"]
+                for msg in messages:
+                    content = msg["content"]
+                    if isinstance(content, str):
+                        formatted.append({"role": msg["role"], "content": content})
+                    elif isinstance(content, list):
+                        # Combine content blocks into a single message
+                        combined_content = []
+                        for block in content:
+                            if isinstance(block, str):
+                                combined_content.append({"type": "text", "text": block})
+                            elif isinstance(block, Image.Image):
+                                # For Claude, convert PIL.Image to base64
+                                buffered = io.BytesIO()
+                                block.save(buffered, format="PNG")
+                                image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                                combined_content.append({
+                                    "type": "image",
+                                    "source": {
+                                        "type": "base64",
+                                        "media_type": "image/png",
+                                        "data": image_base64
+                                    }
+                                })
+                            elif isinstance(block, dict):
+                                if block.get("type") == "image_url":
+                                    combined_content.append({
+                                        "type": "image",
+                                        "source": {
+                                            "type": "url",
+                                            "url": block["image_url"]["url"]
+                                        }
+                                    })
+                                elif block.get("type") == "image_base64":
+                                    combined_content.append({
+                                        "type": "image",
+                                        "source": {
+                                            "type": "base64",
+                                            "media_type": block["image_base64"]["media_type"],
+                                            "data": block["image_base64"]["data"]
+                                        }
+                                    })
+                        formatted.append({"role": msg["role"], "content": combined_content})
+                return system_msg, formatted
+
+            elif "gemini" in model or "gpt" in model or "grok" in model:
+                formatted = []
+                for msg in messages:
+                    content = msg["content"]
+                    if isinstance(content, str):
+                        formatted.append({"role": msg["role"], "parts": [content]})
+                    elif isinstance(content, list):
+                        parts = []
+                        for block in content:
+                            if isinstance(block, str):
+                                parts.append(block)
+                            elif isinstance(block, Image.Image):
+                                parts.append(block)
+                            elif isinstance(block, dict):
+                                if block.get("type") == "image_url":
+                                    parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}})
+                                elif block.get("type") == "image_base64":
+                                    parts.append({"type": "image_base64", "image_base64": {"data": block["image_base64"]["data"], "media_type": block["image_base64"]["media_type"]}})
+                        formatted.append({"role": msg["role"], "parts": parts})
+                return None, formatted
+
+            else: # OpenAI models
+                formatted = []
+                for msg in messages:
+                    content = msg["content"]
+                    if isinstance(content, str):
+                        formatted.append({"role": msg["role"], "content": content})
+                    elif isinstance(content, list):
+                        # OpenAI expects 'content' as string; images are not directly supported
+                        # You can convert images to URLs or descriptions if needed
+                        combined_content = ""
+                        for block in content:
+                            if isinstance(block, str):
+                                combined_content += block + "\n"
+                            elif isinstance(block, Image.Image):
+                                # Convert PIL.Image to base64 or upload and use URL
+                                buffered = io.BytesIO()
+                                block.save(buffered, format="PNG")
+                                image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                                combined_content += f"[Image Base64: {image_base64[:30]}...]\n"
+                            elif isinstance(block, dict):
+                                if block.get("type") == "image_url":
+                                    combined_content += f"[Image: {block['image_url']['url']}]\n"
+                                elif block.get("type") == "image_base64":
+                                    combined_content += f"[Image Base64: {block['image_base64']['data'][:30]}...]\n"
+                        formatted.append({"role": msg["role"], "content": combined_content.strip()})
+                return None, formatted
+
+        system_msg, formatted_messages = format_messages_for_api(model, messages.copy())
+
+        if "claude" in model:
+            # Check for Vertex configuration
+            vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
+            vertex_region = os.getenv('VERTEX_REGION')
+
+            if vertex_project_id and vertex_region:
+                client = AnthropicVertex(
+                    region=vertex_region,
+                    project_id=vertex_project_id
+                )
+            else:
+                client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
+
+            response = client.messages.create(
+                model=model,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                messages=formatted_messages,
+                system=system_msg
+            )
+
+            while response.stop_reason == "max_tokens":
+                if formatted_messages[-1]['role'] == "user":
+                    formatted_messages.append({"role": "assistant", "content": response.completion})
                 else:
-
+                    formatted_messages[-1]['content'] += response.completion
 
-        system_msg = messages.pop(0)["content"] if messages and messages[0]["role"] == "system" else ""
                 response = client.messages.create(
                     model=model,
                     max_tokens=max_tokens,
                     temperature=temperature,
-            messages=
+                    messages=formatted_messages,
                     system=system_msg
                 )
+
+            if formatted_messages[-1]['role'] == "assistant" and response.stop_reason == "end_turn":
+                formatted_messages[-1]['content'] += response.completion
+                return formatted_messages[-1]['content']
+
+            return response.completion
+
+        elif "gemini" in model:
+            try:
+                # First try OpenAI-style API
+                client = openai.OpenAI(
+                    api_key=api_key,
+                    base_url="https://generativelanguage.googleapis.com/v1beta/"
+                )
+                # Set response_format based on json_format
+                response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}
+
+                response = client.chat.completions.create(
+                    model=model,
+                    max_tokens=max_tokens,
+                    top_p=top_p,
+                    top_k=top_k,
+                    messages=formatted_messages,
+                    temperature=temperature,
+                    response_format=response_format  # Added response_format
+                )
+                return response.choices[0].message.content
+
+            except Exception as e:
+                # If OpenAI-style API fails, fall back to Google's genai library
+                logger.info("Falling back to Google's genai library")
+                genai.configure(api_key=api_key)
+                system_instruction = ""
+                for msg in messages:
+                    if msg["role"] == "system":
+                        system_instruction = msg["content"]
+                        break
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if gemini_messages:
-                gemini_messages[0].parts[0].text = f"{msg['content']}\n\n{gemini_messages[0].parts[0].text}"
-            else:
-                gemini_messages.append({"role": msg["role"], "parts": [{"text": msg["content"]}]})
-
-        # Create Gemini model and generate response
-        model = genai.GenerativeModel(model_name=model)
-        response = model.generate_content(
-            gemini_messages,
-            generation_config=genai.types.GenerationConfig(
-                temperature=temperature,
-                max_output_tokens=max_tokens
-            )
-        )
-
-        return response.text
-
-        elif "grok" in model:
-            # Randomly choose between OpenAI and Anthropic SDK
-            use_anthropic = random.choice([True, False])
-
-            if use_anthropic:
-                print("using anthropic")
-                client = anthropic.Anthropic(
-                    api_key=api_key,
-                    base_url="https://api.x.ai"
-                )
-
-                system_msg = messages.pop(0)["content"] if messages and messages[0]["role"] == "system" else ""
-                response = client.messages.create(
-                    model=model,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    messages=messages,
-                    system=system_msg
-                )
-                return response.content[0].text
-            else:
-                print("using openai")
-                client = openai.OpenAI(
-                    api_key=api_key,
-                    base_url="https://api.x.ai/v1"
-                )
-                response = client.chat.completions.create(
-                    model=model,
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    temperature=temperature
-                )
-                return response.choices[0].message.content
+                # Convert the remaining messages to the Gemini format
+                gemini_messages = format_messages_for_gemini(messages)
+                mime_type = "application/json" if json_format else "text/plain"
+                generation_config = genai.types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    max_output_tokens=max_tokens,
+                    response_mime_type=mime_type
+                )
+
+                model_instance = genai.GenerativeModel(
+                    model_name=model,
+                    system_instruction=system_instruction,  # system messages are passed in here
+                    generation_config=generation_config
+                )
+
+                response = model_instance.generate_content(gemini_messages, generation_config=generation_config)
+
+                return response.text
+
+        elif "grok" in model:
+            # Randomly choose between OpenAI and Anthropic SDK
+            use_anthropic = random.choice([True, False])
+
+            if use_anthropic:
+                logger.info("Using Anthropic for Grok model")
+                client = anthropic.Anthropic(
+                    api_key=api_key,
+                    base_url="https://api.x.ai"
+                )
+
+                system_msg = ""
+                if messages and messages[0]["role"] == "system":
+                    system_msg = messages.pop(0)["content"]
+
+                response = client.messages.create(
+                    model=model,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    messages=formatted_messages,
+                    system=system_msg
+                )
+                return response.completion
+            else:
+                logger.info("Using OpenAI for Grok model")
+                client = openai.OpenAI(
+                    api_key=api_key,
+                    base_url="https://api.x.ai/v1"
+                )
+                # Set response_format based on json_format
+                response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}
 
-        else: # OpenAI models
-            client = openai.OpenAI(api_key=api_key, base_url=base_url)
                 response = client.chat.completions.create(
                     model=model,
-            messages=
+                    messages=formatted_messages,
                     max_tokens=max_tokens,
                     temperature=temperature,
+                    response_format=response_format  # Added response_format
                 )
                 return response.choices[0].message.content
 
-
-
-
+        else: # OpenAI models
+            client = openai.OpenAI(api_key=api_key, base_url=base_url)
+            # Set response_format based on json_format
+            response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}
 
-
+            response = client.chat.completions.create(
+                model=model,
+                messages=formatted_messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                response_format=response_format  # Added response_format
+            )
+            return response.choices[0].message.content
 
-
-
-
-
-
+        # Release the API key after successful use
+        if not api_key:
+            # key_manager.release_config(service, api_key)
+            pass
+
+        return response
 
     except Exception as e:
         logger.error(f"Error in completion: {str(e)}")
         raise
 
 class Agent:
-    def __init__(self, model_name: str, messages: Optional[str] = None,
+    def __init__(self, model_name: str, messages: Optional[Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]]] = None,
                  memory_enabled: bool = False, api_key: Optional[str] = None) -> None:
         """Initialize an Agent instance."""
-        # valid_models = ['gpt-3.5-turbo', 'gpt-4', 'claude-2.1', 'gemini-1.5-pro', 'gemini-1.5-flash', 'grok-beta', 'claude-3-5-sonnet-20241022']
-        # if model_name not in valid_models:
-        #     raise ValueError(f"Model {model_name} not supported. Supported models: {valid_models}")
-
         self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
         self.model_name = model_name
-        self.history = ChatHistory(messages)
+        self.history = ChatHistory(messages) if messages else ChatHistory()
         self.memory_enabled = memory_enabled
         self.api_key = api_key
         self.repo_content = []
 
-    def add_message(self, role, content):
-
-        while self.repo_content:
-            repo = self.repo_content.pop()
-            repo_content += f"<repo>\n{repo}\n</repo>\n"
-
-        content = repo_content + content
+    def add_message(self, role: str, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+        """Add a message to the conversation."""
         self.history.add_message(content, role)
 
-    def
+    def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+        """Add a user message."""
+        self.history.add_user_message(content)
+
+    def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+        """Add an assistant message."""
+        self.history.add_assistant_message(content)
+
+    def add_image(self, image_path: Optional[str] = None, image_url: Optional[str] = None, media_type: Optional[str] = "image/jpeg"):
+        """
+        Add an image to the conversation.
+        Either image_path or image_url must be provided.
+        """
+        if not image_path and not image_url:
+            raise ValueError("Either image_path or image_url must be provided.")
+
+        if image_path:
+            if not os.path.exists(image_path):
+                raise FileNotFoundError(f"Image file {image_path} does not exist.")
+            if "gemini" in self.model_name:
+                # For Gemini, load as PIL.Image
+                image_pil = Image.open(image_path)
+                image_block = image_pil
+            else:
+                # For Claude and others, use base64 encoding
+                with open(image_path, "rb") as img_file:
+                    image_data = base64.standard_b64encode(img_file.read()).decode("utf-8")
+                image_block = {
+                    "type": "image_base64",
+                    "image_base64": {
+                        "media_type": media_type,
+                        "data": image_data
+                    }
+                }
+        else:
+            # If image_url is provided
+            if "gemini" in self.model_name:
+                # For Gemini, you can pass image URLs directly
+                image_block = {"type": "image_url", "image_url": {"url": image_url}}
+            else:
+                # For Claude and others, use image URLs
+                image_block = {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url
+                    }
+                }
+
+        # Add the image block to the last user message or as a new user message
+        if self.history.last_role == "user":
+            current_content = self.history.messages[-1]["content"]
+            if isinstance(current_content, list):
+                current_content.append(image_block)
+            else:
+                self.history.messages[-1]["content"] = [current_content, image_block]
+        else:
+            # Start a new user message with the image
+            self.history.add_message([image_block], "user")
+
+    def generate_response(self, max_tokens=3585, temperature=0.7, top_p=1.0, top_k=40, json_format: bool = False) -> str:
+        """Generate a response from the agent.
+
+        Args:
+            max_tokens (int, optional): Maximum number of tokens. Defaults to 3585.
+            temperature (float, optional): Sampling temperature. Defaults to 0.7.
+            json_format (bool, optional): Whether to enable JSON output format. Defaults to False.
+
+        Returns:
+            str: The generated response.
+        """
         if not self.history.messages:
             raise ValueError("No messages in history to generate response from")
 
-        messages =
-
+        messages = self.history.messages
+        print(self.model_name)
         response_text = completion(
             model=self.model_name,
             messages=messages,
             max_tokens=max_tokens,
             temperature=temperature,
-
+            top_p=top_p,
+            top_k=top_k,
+            api_key=self.api_key,
+            json_format=json_format  # Pass json_format to completion
         )
-        if
-
-
-
-        self.
+        if self.model_name.startswith("openai"):
+            # OpenAI does not support images, so responses are simple strings
+            if self.history.messages[-1]["role"] == "assistant":
+                self.history.messages[-1]["content"] = response_text
+            elif self.memory_enabled:
+                self.add_message("assistant", response_text)
+        elif "claude" in self.model_name:
+            if self.history.messages[-1]["role"] == "assistant":
+                self.history.messages[-1]["content"] = response_text
+            elif self.memory_enabled:
+                self.add_message("assistant", response_text)
+        elif "gemini" in self.model_name or "grok" in self.model_name:
+            if self.history.messages[-1]["role"] == "assistant":
+                if isinstance(self.history.messages[-1]["content"], list):
+                    self.history.messages[-1]["content"].append(response_text)
+                else:
+                    self.history.messages[-1]["content"] = [self.history.messages[-1]["content"], response_text]
+            elif self.memory_enabled:
+                self.add_message("assistant", response_text)
+        else:
+            # Handle other models similarly
+            if self.history.messages[-1]["role"] == "assistant":
+                self.history.messages[-1]["content"] = response_text
+            elif self.memory_enabled:
+                self.add_message("assistant", response_text)
 
         return response_text
 
@@ -274,11 +483,12 @@ class Agent:
         with open(filename, 'w', encoding='utf-8') as file:
             json.dump(self.history.messages, file, ensure_ascii=False, indent=4)
 
-    def load_conversation(self, filename=None):
+    def load_conversation(self, filename: Optional[str] = None):
         if filename is None:
             filename = f"{self.id}.json"
         with open(filename, 'r', encoding='utf-8') as file:
            messages = json.load(file)
+        # Handle deserialization of images if necessary
         self.history = ChatHistory(messages)
 
     def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = None, repo_name: Optional[str] = None, commit_hash: Optional[str] = None):
@@ -305,28 +515,31 @@ class Agent:
         raise ValueError(f"Failed to download repository from {repo_url}")
 
 if __name__ == "__main__":
-
-    #
-
-
-    # from agent.messageloader import information_detector_messages
+    # Example Usage
+    # Create an Agent instance (Gemini model)
+    agent = Agent("gemini-1.5-flash", "you are Jack101", memory_enabled=True)
 
-    #
-
-    # information_detector_agent.add_message("user", text)
-    # response = information_detector_agent.generate_response()
-    # print(response)
-    agent = Agent("gemini-1.5-pro-002", "you are an assistant", memory_enabled=True)
+    # Add an image
+    agent.add_image(image_path="/Users/junfan/Projects/Personal/oneapi/dialog_manager/example.png")
 
-    #
-
+    # Add a user message
+    agent.add_message("user", "Who are you? What's in this image?")
 
-    #
-
-
+    # Generate response with JSON format enabled
+    try:
+        response = agent.generate_response(json_format=True)  # json_format set to True
+        print("Response:", response)
+    except Exception as e:
+        logger.error(f"Failed to generate response: {e}")
 
-
-    print(
+    # Print the entire conversation history
+    print("Conversation History:")
+    print(agent.history)
+
+    # Pop the last message
     last_message = agent.history.pop()
-    print(last_message)
-
+    print("Last Message:", last_message)
+
+    # Generate another response without JSON format
+    response = agent.generate_response()
+    print("Response:", response)
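As the new `completion()` type hints above indicate, a message's `content` may now be either a plain string or a list of blocks: strings, `PIL.Image` objects, or `image_url`/`image_base64` dicts. A sketch of that shape (the text and URL are placeholder values):

    # Content-block message shape accepted by the new completion(), per its type hints.
    from PIL import Image

    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": [
            "Describe this image:",  # text block
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},  # URL block
            # Image.open("local.png") is also accepted: Gemini receives PIL images
            # directly as parts, while Claude re-encodes them to base64.
        ]},
    ]

format_messages_for_api then adapts these blocks per backend: Claude receives base64/url source dicts, Gemini receives parts, and plain OpenAI-style chat endpoints receive a flattened text rendering.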
llm_dialog_manager/chat_history.py
CHANGED
@@ -1,53 +1,122 @@
 from typing import List, Dict, Optional, Union
+from PIL import Image
 
 class ChatHistory:
-    def __init__(self, input_data: Union[str, List[Dict[str, str]]] = "") -> None:
-        self.messages: List[Dict[str, str]] = []
+    def __init__(self, input_data: Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]] = "") -> None:
+        self.messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]] = []
         if isinstance(input_data, str) and input_data:
             self.add_message(input_data, "system")
         elif isinstance(input_data, list):
             self.load_messages(input_data)
-        self.last_role: str = "system" if not self.messages else self.
+        self.last_role: str = "system" if not self.messages else self.get_last_role()
 
-    def load_messages(self, messages: List[Dict[str, str]]) -> None:
+    def load_messages(self, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]) -> None:
         for message in messages:
             if not ("role" in message and "content" in message):
                 raise ValueError("Each message must have a 'role' and 'content'.")
             if message["role"] not in ["user", "assistant", "system"]:
                 raise ValueError(f"Invalid role: {message['role']}")
             self.messages.append(message)
-
+        self.last_role = self.get_last_role()
+
+    def get_last_role(self):
+        return self.messages[-1]["role"] if self.messages else "system"
+
     def pop(self):
         if not self.messages:
             return None
-
+
         popped_message = self.messages.pop()
-
+
         if self.messages:
-            self.last_role = self.
+            self.last_role = self.get_last_role()
         else:
             self.last_role = "system"
-
+
         return popped_message["content"]
 
     def __len__(self):
         return len(self.messages)
 
     def __str__(self):
-
+        formatted_messages = []
+        for i, msg in enumerate(self.messages):
+            role = msg['role']
+            content = msg['content']
+            if isinstance(content, str):
+                formatted_content = content
+            elif isinstance(content, list):
+                parts = []
+                for block in content:
+                    if isinstance(block, str):
+                        parts.append(block)
+                    elif isinstance(block, Image.Image):
+                        parts.append(f"[Image Object: {block.filename}]")
+                    elif isinstance(block, dict):
+                        if block.get("type") == "image_url":
+                            parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                        elif block.get("type") == "image_base64":
+                            parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                formatted_content = "\n".join(parts)
+            else:
+                formatted_content = str(content)
+            formatted_messages.append(f"Message {i} ({role}): {formatted_content}")
+        return '\n'.join(formatted_messages)
 
     def __getitem__(self, key):
         if isinstance(key, slice):
-
-
-
+            sliced_messages = self.messages[key]
+            formatted = []
+            for msg in sliced_messages:
+                role = msg['role']
+                content = msg['content']
+                if isinstance(content, str):
+                    formatted_content = content
+                elif isinstance(content, list):
+                    parts = []
+                    for block in content:
+                        if isinstance(block, str):
+                            parts.append(block)
+                        elif isinstance(block, Image.Image):
+                            parts.append(f"[Image Object: {block.filename}]")
+                        elif isinstance(block, dict):
+                            if block.get("type") == "image_url":
+                                parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                            elif block.get("type") == "image_base64":
+                                parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                    formatted_content = "\n".join(parts)
+                else:
+                    formatted_content = str(content)
+                formatted.append(f"({role}): {formatted_content}")
+            print('\n'.join(formatted))
+            return sliced_messages
         elif isinstance(key, int):
             # Adjust for negative indices
             if key < 0:
-                 key += len(self.messages)
+                key += len(self.messages)
             if 0 <= key < len(self.messages):
+                msg = self.messages[key]
+                role = msg['role']
+                content = msg['content']
+                if isinstance(content, str):
+                    formatted_content = content
+                elif isinstance(content, list):
+                    parts = []
+                    for block in content:
+                        if isinstance(block, str):
+                            parts.append(block)
+                        elif isinstance(block, Image.Image):
+                            parts.append(f"[Image Object: {block.filename}]")
+                        elif isinstance(block, dict):
+                            if block.get("type") == "image_url":
+                                parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                            elif block.get("type") == "image_base64":
+                                parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                    formatted_content = "\n".join(parts)
+                else:
+                    formatted_content = str(content)
                 snippet = self.get_conversation_snippet(key)
-                print('\n'.join([f"({v['role']}): {
+                print('\n'.join([f"({v['role']}): {v['content']}" for k, v in snippet.items() if v]))
                 return self.messages[key]
             else:
                 raise IndexError("Message index out of range.")
@@ -55,8 +124,8 @@ class ChatHistory:
             raise TypeError("Invalid argument type.")
 
     def __setitem__(self, index, value):
-        if not isinstance(value, str):
-            raise ValueError("Message content must be a string.")
+        if not isinstance(value, (str, list)):
+            raise ValueError("Message content must be a string or a list of content blocks.")
         role = "system" if index % 2 == 0 else "user"
         self.messages[index] = {"role": role, "content": value}
 
@@ -68,19 +137,27 @@ class ChatHistory:
         self.add_message(message, next_role)
 
     def __contains__(self, item):
-
+        for message in self.messages:
+            content = message['content']
+            if isinstance(content, str) and item in content:
+                return True
+            elif isinstance(content, list):
+                for block in content:
+                    if isinstance(block, str) and item in block:
+                        return True
+        return False
 
-    def add_message(self, content, role):
+    def add_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]], role: str):
         self.messages.append({"role": role, "content": content})
         self.last_role = role
 
-    def add_user_message(self, content):
+    def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
         if self.last_role in ["system", "assistant"]:
             self.add_message(content, "user")
         else:
             raise ValueError("A user message must follow a system or assistant message.")
 
-    def add_assistant_message(self, content):
+    def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
         if self.last_role == "user":
             self.add_message(content, "assistant")
         else:
@@ -110,7 +187,17 @@ class ChatHistory:
         print(f"Content of the last message: {status['last_message_content']}")
 
     def search_for_keyword(self, keyword):
-
+        results = []
+        for msg in self.messages:
+            content = msg['content']
+            if isinstance(content, str) and keyword.lower() in content.lower():
+                results.append(msg)
+            elif isinstance(content, list):
+                for block in content:
+                    if isinstance(block, str) and keyword.lower() in block.lower():
+                        results.append(msg)
+                        break
+        return results
 
     def has_user_or_assistant_spoken_since_last_system(self):
         for msg in reversed(self.messages):
@@ -143,4 +230,4 @@ class ChatHistory:
     @staticmethod
     def color_text(text, color):
         colors = {"green": "\033[92m", "red": "\033[91m", "end": "\033[0m"}
-        return f"{colors
+        return f"{colors.get(color, '')}{text}{colors.get('end', '')}"
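The reworked ChatHistory handles mixed content throughout: `__str__`, `__contains__`, and `search_for_keyword` all walk string blocks inside list-valued content. A small usage sketch (the message text and URL are placeholders):

    from llm_dialog_manager.chat_history import ChatHistory

    history = ChatHistory("you are a helpful assistant")    # seeds a system message
    history.add_user_message([
        "Here is a diagram:",
        {"type": "image_url", "image_url": {"url": "https://example.com/diagram.png"}},
    ])
    history.add_assistant_message("It shows a flowchart.")

    print("diagram" in history)                          # __contains__ scans string blocks
    print(len(history.search_for_keyword("flowchart")))  # matches plain-string content too
    print(history)                                       # image blocks render as [Image URL: ...]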
{llm_dialog_manager-0.3.4.dist-info → llm_dialog_manager-0.4.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.2
 Name: llm_dialog_manager
-Version: 0.3.4
+Version: 0.4.1
 Summary: A Python package for managing LLM chat conversation history
 Author-email: xihajun <work@2333.fun>
 License: MIT
@@ -64,6 +64,7 @@ A Python package for managing AI chat conversation history with support for mult
 - Memory management options
 - Conversation search and indexing
 - Rich conversation display options
+- Vision & Json Output enabled [20240111]
 
 ## Installation
 
llm_dialog_manager-0.4.1.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+llm_dialog_manager/__init__.py,sha256=hTHvsXzvD5geKgv2XERYcp2f-T3LoVVc3arXfPtNS1k,86
+llm_dialog_manager/agent.py,sha256=ZKO3eKHTKcbmYpVRRIpzDy7Tlp_VgQ90ewr1758Ozgs,23931
+llm_dialog_manager/chat_history.py,sha256=DKKRnj_M6h-4JncnH6KekMTghX7vMgdN3J9uOwXKzMU,10347
+llm_dialog_manager/key_manager.py,sha256=shvxmn4zUtQx_p-x1EFyOmnk-WlhigbpKtxTKve-zXk,4421
+llm_dialog_manager-0.4.1.dist-info/LICENSE,sha256=vWGbYgGuWpWrXL8-xi6pNcX5UzD6pWoIAZmcetyfbus,1064
+llm_dialog_manager-0.4.1.dist-info/METADATA,sha256=LER5FN6lFQFPs_8A-fIM7VYmqN-fh0nCD6Dt8vslsiY,4194
+llm_dialog_manager-0.4.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+llm_dialog_manager-0.4.1.dist-info/top_level.txt,sha256=u2EQEXW0NGAt0AAHT7jx1odXZ4rZfjcgbmJhvKFuMkI,19
+llm_dialog_manager-0.4.1.dist-info/RECORD,,
llm_dialog_manager-0.3.4.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-llm_dialog_manager/__init__.py,sha256=L-vxURJQtPwajtaFFl8gFI-xAsggFXNNY3aOqlVIViU,86
-llm_dialog_manager/agent.py,sha256=ruZWlSZZ6jBr8x8I-lWdsjiSP44BTYTR9fBS0FuE-Rc,13310
-llm_dialog_manager/chat_history.py,sha256=xKA-oQCv8jv_g8EhXrG9h1S8Icbj2FfqPIhbty5vra4,6033
-llm_dialog_manager/key_manager.py,sha256=shvxmn4zUtQx_p-x1EFyOmnk-WlhigbpKtxTKve-zXk,4421
-llm_dialog_manager-0.3.4.dist-info/LICENSE,sha256=vWGbYgGuWpWrXL8-xi6pNcX5UzD6pWoIAZmcetyfbus,1064
-llm_dialog_manager-0.3.4.dist-info/METADATA,sha256=6XPZd3oGCVC8YS-EVHII1EFFsnml1kS2VgRYoVsNaIQ,4152
-llm_dialog_manager-0.3.4.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-llm_dialog_manager-0.3.4.dist-info/top_level.txt,sha256=u2EQEXW0NGAt0AAHT7jx1odXZ4rZfjcgbmJhvKFuMkI,19
-llm_dialog_manager-0.3.4.dist-info/RECORD,,
{llm_dialog_manager-0.3.4.dist-info → llm_dialog_manager-0.4.1.dist-info}/LICENSE
File without changes
{llm_dialog_manager-0.3.4.dist-info → llm_dialog_manager-0.4.1.dist-info}/top_level.txt
File without changes