llm-dialog-manager 0.3.4__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/PKG-INFO +3 -2
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/README.md +1 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/llm_dialog_manager/__init__.py +1 -1
- llm_dialog_manager-0.4.1/llm_dialog_manager/agent.py +545 -0
- llm_dialog_manager-0.4.1/llm_dialog_manager/chat_history.py +233 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/llm_dialog_manager.egg-info/PKG-INFO +3 -2
- llm_dialog_manager-0.4.1/pyproject.toml +32 -0
- llm_dialog_manager-0.3.4/llm_dialog_manager/agent.py +0 -332
- llm_dialog_manager-0.3.4/llm_dialog_manager/chat_history.py +0 -146
- llm_dialog_manager-0.3.4/pyproject.toml +0 -64
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/LICENSE +0 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/llm_dialog_manager/key_manager.py +0 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/llm_dialog_manager.egg-info/SOURCES.txt +0 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/llm_dialog_manager.egg-info/dependency_links.txt +0 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/llm_dialog_manager.egg-info/requires.txt +0 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/llm_dialog_manager.egg-info/top_level.txt +0 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/setup.cfg +0 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/tests/test_agent.py +0 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/tests/test_chat_history.py +0 -0
- {llm_dialog_manager-0.3.4 → llm_dialog_manager-0.4.1}/tests/test_key_manager.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: llm_dialog_manager
|
3
|
-
Version: 0.3.4
|
3
|
+
Version: 0.4.1
|
4
4
|
Summary: A Python package for managing LLM chat conversation history
|
5
5
|
Author-email: xihajun <work@2333.fun>
|
6
6
|
License: MIT
|
@@ -64,6 +64,7 @@ A Python package for managing AI chat conversation history with support for mult
|
|
64
64
|
- Memory management options
|
65
65
|
- Conversation search and indexing
|
66
66
|
- Rich conversation display options
|
67
|
+
- Vision & JSON Output enabled [20240111]
|
67
68
|
|
68
69
|
## Installation
|
69
70
|
|
@@ -0,0 +1,545 @@
|
|
1
|
+
# Standard library imports
|
2
|
+
import json
|
3
|
+
import os
|
4
|
+
import uuid
|
5
|
+
from typing import List, Dict, Optional, Union
|
6
|
+
import logging
|
7
|
+
from pathlib import Path
|
8
|
+
import random
|
9
|
+
import requests
|
10
|
+
import zipfile
|
11
|
+
import io
|
12
|
+
import base64
|
13
|
+
from PIL import Image
|
14
|
+
|
15
|
+
# Third-party imports
|
16
|
+
import anthropic
|
17
|
+
from anthropic import AnthropicVertex
|
18
|
+
import google.generativeai as genai
|
19
|
+
import openai
|
20
|
+
from dotenv import load_dotenv
|
21
|
+
|
22
|
+
# Local imports
|
23
|
+
from .chat_history import ChatHistory
|
24
|
+
from .key_manager import key_manager
|
25
|
+
|
26
|
+
# Set up logging
|
27
|
+
# Module-level logger; every helper in this file reports through it.
logger = logging.getLogger(name=__name__)

# NOTE(review): calling basicConfig at import time configures the *root*
# logger of any application that imports this package. Kept for backward
# compatibility; libraries usually leave logging configuration to the app.
logging.basicConfig(level=logging.INFO)
|
29
|
+
|
30
|
+
# Load environment variables
|
31
|
+
def load_env_vars():
    """Load a ``.env`` file that sits in the same directory as this module.

    The file is handed to ``load_dotenv``. When it is absent, a warning is
    logged and whatever environment variables are already set are used as-is.
    """
    dotenv_file = Path(__file__).parent.joinpath('.env')
    if not dotenv_file.exists():
        logger.warning(".env file not found. Using system environment variables.")
        return
    load_dotenv(dotenv_file)


# Executed once at import time so the API-key lookups below can see the values.
load_env_vars()
|
40
|
+
|
41
|
+
def format_messages_for_gemini(messages):
    """Convert standardized chat messages into the ``google.generativeai`` format.

    System messages are skipped: they must be supplied through the
    ``system_instruction`` argument of ``GenerativeModel`` instead.

    Args:
        messages: Sequence of ``{"role": ..., "content": ...}`` dicts, where
            ``content`` is either a single part (text, image, ...) or a list
            of parts.

    Returns:
        A new list of ``{"role": ..., "parts": [...]}`` dicts. The
        ``"assistant"`` role is renamed to ``"model"``, the role name Gemini
        uses for its own turns; other roles are passed through unchanged.
        The input messages are not modified.
    """
    gemini_messages = []
    for msg in messages:
        role = msg["role"]
        if role == "system":
            # Provided separately through GenerativeModel(system_instruction=...).
            continue

        content = msg["content"]
        # Gemini expects a list of parts: wrap a single part, and copy an
        # existing list so the result does not alias the caller's history.
        parts = list(content) if isinstance(content, list) else [content]

        gemini_messages.append({
            "role": "model" if role == "assistant" else role,
            "parts": parts,  # parts may mix text and image objects
        })

    return gemini_messages
|
68
|
+
|
69
|
+
def _image_to_png_base64(image) -> str:
    """Encode a PIL image as PNG and return the bytes as a base64 string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _format_messages_for_anthropic(messages):
    """Convert chat messages to Anthropic Messages API input.

    A leading system message is returned separately because the Anthropic API
    takes it as the ``system`` argument. The input list is never mutated.

    Returns:
        ``(system_prompt, formatted_messages)``; ``system_prompt`` is ``""``
        when the first message is not a system message.
    """
    system_msg = ""
    if messages and messages[0]["role"] == "system":
        system_msg = messages[0]["content"]
        messages = messages[1:]  # slice instead of pop: keep caller's list intact

    formatted = []
    for msg in messages:
        content = msg["content"]
        if isinstance(content, str):
            formatted.append({"role": msg["role"], "content": content})
        elif isinstance(content, list):
            blocks = []
            for block in content:
                if isinstance(block, str):
                    blocks.append({"type": "text", "text": block})
                elif isinstance(block, Image.Image):
                    blocks.append({
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": _image_to_png_base64(block),
                        },
                    })
                elif isinstance(block, dict):
                    if block.get("type") == "image_url":
                        blocks.append({
                            "type": "image",
                            "source": {"type": "url", "url": block["image_url"]["url"]},
                        })
                    elif block.get("type") == "image_base64":
                        blocks.append({
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": block["image_base64"]["media_type"],
                                "data": block["image_base64"]["data"],
                            },
                        })
            formatted.append({"role": msg["role"], "content": blocks})
    return system_msg, formatted


def _format_messages_for_openai(messages):
    """Convert chat messages to the OpenAI Chat Completions format.

    Plain-string content is passed through unchanged. A list made only of
    strings is joined with newlines into a single string. A list containing
    images becomes multimodal content parts: text parts plus ``image_url``
    parts, using ``data:`` URLs for inline (PIL / base64) images. Messages
    whose content is neither a string nor a list are dropped.
    """
    formatted = []
    for msg in messages:
        content = msg["content"]
        if isinstance(content, str):
            formatted.append({"role": msg["role"], "content": content})
            continue
        if not isinstance(content, list):
            continue

        if all(isinstance(block, str) for block in content):
            formatted.append({"role": msg["role"], "content": "\n".join(content).strip()})
            continue

        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append({"type": "text", "text": block})
            elif isinstance(block, dict):
                if block.get("type") == "image_url":
                    parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}})
                elif block.get("type") == "image_base64":
                    data = block["image_base64"]
                    parts.append({
                        "type": "image_url",
                        "image_url": {"url": f"data:{data['media_type']};base64,{data['data']}"},
                    })
            elif isinstance(block, Image.Image):
                parts.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{_image_to_png_base64(block)}"},
                })
        formatted.append({"role": msg["role"], "content": parts})
    return formatted


def _anthropic_text(response) -> str:
    """Return the concatenated text blocks of an Anthropic Messages response.

    Messages API responses expose generated output as a list of content
    blocks in ``response.content``; they have no ``completion`` attribute.
    """
    return "".join(
        block.text for block in response.content
        if getattr(block, "type", None) == "text"
    )


def _message_text(content) -> str:
    """Extract plain text from string content or a list of content blocks."""
    if isinstance(content, str):
        return content
    return "".join(
        block.get("text", "") for block in content
        if isinstance(block, dict) and block.get("type") == "text"
    )


def _anthropic_complete(client, model, formatted_messages, system_msg, max_tokens, temperature) -> str:
    """Run an Anthropic Messages request, continuing while output is truncated.

    If the conversation ends with an assistant turn it is treated as a
    prefill and the returned text starts with it. Whenever generation stops
    because of ``max_tokens``, the text so far is sent back as an assistant
    prefill and the model is asked to continue; all chunks are accumulated.
    """
    request = {"model": model, "max_tokens": max_tokens, "temperature": temperature}
    if system_msg:
        request["system"] = system_msg

    history = list(formatted_messages)
    text = ""
    if history and history[-1]["role"] == "assistant":
        text = _message_text(history.pop()["content"])

    while True:
        # A final assistant turn must not end in whitespace, so trim the
        # prefill; the model's continuation supplies any needed spacing.
        text = text.rstrip()
        conversation = history + [{"role": "assistant", "content": text}] if text else history
        response = client.messages.create(messages=conversation, **request)
        chunk = _anthropic_text(response)
        text += chunk
        # Stop on any non-truncation reason, or when a continuation produced
        # nothing, so a misbehaving endpoint cannot loop forever.
        if response.stop_reason != "max_tokens" or not chunk:
            return text


def _openai_chat(client, model, messages, max_tokens, temperature, json_format, **extra) -> str:
    """Call an OpenAI-compatible Chat Completions endpoint and return the text.

    ``response_format`` is only sent for JSON mode; ``"plain_text"`` (used
    previously) is not an accepted format type, and omitting the field gives
    ordinary text output.
    """
    request = dict(model=model, messages=messages, max_tokens=max_tokens,
                   temperature=temperature, **extra)
    if json_format:
        request["response_format"] = {"type": "json_object"}
    response = client.chat.completions.create(**request)
    return response.choices[0].message.content


def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]], max_tokens: int = 1000,
               temperature: float = 0.5, top_p: float = 1.0, top_k: int = 40, api_key: Optional[str] = None,
               base_url: Optional[str] = None, json_format: bool = False) -> str:
    """Generate a completion using the specified model and messages.

    The provider is chosen from substrings of ``model``: "claude" -> Anthropic
    (or Anthropic on Vertex when VERTEX_PROJECT_ID and VERTEX_REGION are set),
    "gemini" -> Google (OpenAI-compatible endpoint, falling back to the
    ``google.generativeai`` SDK), "grok" -> xAI, anything else -> OpenAI.

    Args:
        model: Model name; also selects the provider (see above).
        messages: Chat messages; ``content`` may be a string or a list of
            text / PIL image / ``image_url`` / ``image_base64`` blocks.
            The list is not modified.
        max_tokens: Maximum tokens per request.
        temperature: Sampling temperature.
        top_p: Nucleus sampling value (Gemini only).
        top_k: Top-k sampling value (Gemini SDK fallback only).
        api_key: Explicit key; when omitted, ``<SERVICE>_API_KEY`` is read from
            the environment.
        base_url: Explicit endpoint; when omitted together with ``api_key``,
            ``<SERVICE>_BASE_URL`` is read from the environment.
        json_format: Request JSON output where the provider supports it
            (ignored for Claude).

    Returns:
        The generated text.

    Raises:
        Any provider/SDK error, after logging it.
    """
    try:
        if "claude" in model:
            service = "anthropic"
        elif "gemini" in model:
            service = "gemini"
        elif "grok" in model:
            service = "x"
        else:
            service = "openai"

        if not api_key:
            api_key = os.getenv(f"{service.upper()}_API_KEY")
            # Do not clobber an explicitly supplied base_url.
            base_url = base_url or os.getenv(f"{service.upper()}_BASE_URL")

        if "claude" in model:
            system_msg, formatted = _format_messages_for_anthropic(messages)
            vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
            vertex_region = os.getenv('VERTEX_REGION')
            if vertex_project_id and vertex_region:
                client = AnthropicVertex(region=vertex_region, project_id=vertex_project_id)
            else:
                client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
            return _anthropic_complete(client, model, formatted, system_msg, max_tokens, temperature)

        if "gemini" in model:
            try:
                # First try Gemini's OpenAI-compatible endpoint. top_k is not a
                # Chat Completions argument, so it is only used by the SDK fallback.
                client = openai.OpenAI(
                    api_key=api_key,
                    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
                )
                return _openai_chat(client, model, _format_messages_for_openai(messages),
                                    max_tokens, temperature, json_format, top_p=top_p)
            except Exception as e:
                logger.info("Falling back to Google's genai library (%s)", e)

            genai.configure(api_key=api_key)
            system_instruction = next(
                (msg["content"] for msg in messages if msg["role"] == "system"), ""
            )
            generation_config = genai.types.GenerationConfig(
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                max_output_tokens=max_tokens,
                response_mime_type="application/json" if json_format else "text/plain",
            )
            model_instance = genai.GenerativeModel(
                model_name=model,
                system_instruction=system_instruction or None,  # system prompt goes here
                generation_config=generation_config,
            )
            response = model_instance.generate_content(
                format_messages_for_gemini(messages), generation_config=generation_config
            )
            return response.text

        if "grok" in model:
            # xAI exposes both Anthropic- and OpenAI-compatible APIs; pick one at random.
            if random.choice([True, False]):
                logger.info("Using Anthropic for Grok model")
                client = anthropic.Anthropic(api_key=api_key, base_url="https://api.x.ai")
                system_msg, formatted = _format_messages_for_anthropic(messages)
                request = {"model": model, "max_tokens": max_tokens,
                           "temperature": temperature, "messages": formatted}
                if system_msg:
                    request["system"] = system_msg
                return _anthropic_text(client.messages.create(**request))

            logger.info("Using OpenAI for Grok model")
            client = openai.OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
            return _openai_chat(client, model, _format_messages_for_openai(messages),
                                max_tokens, temperature, json_format)

        # OpenAI models
        client = openai.OpenAI(api_key=api_key, base_url=base_url)
        return _openai_chat(client, model, _format_messages_for_openai(messages),
                            max_tokens, temperature, json_format)

    except Exception as e:
        logger.error(f"Error in completion: {str(e)}")
        raise
|
352
|
+
|
353
|
+
class Agent:
    """A conversation with a single model, backed by a ``ChatHistory``."""

    def __init__(self, model_name: str, messages: Optional[Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]]] = None,
                 memory_enabled: bool = False, api_key: Optional[str] = None) -> None:
        """Initialize an Agent instance.

        Args:
            model_name: Model identifier; ``completion`` routes on substrings
                such as "claude", "gemini" and "grok".
            messages: Initial history forwarded to ``ChatHistory``. A bare
                string is presumably treated as a system prompt — confirm in
                ChatHistory.
            memory_enabled: When True, generated replies are appended to the
                history as assistant messages.
            api_key: Explicit API key; when None, ``completion`` reads one from
                the environment.
        """
        self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
        self.model_name = model_name
        self.history = ChatHistory(messages) if messages else ChatHistory()
        self.memory_enabled = memory_enabled
        self.api_key = api_key
        # Text snapshots of downloaded repositories (see add_repo).
        self.repo_content = []

    def add_message(self, role: str, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """Add a message with an explicit role to the conversation."""
        self.history.add_message(content, role)

    def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """Add a user message."""
        self.history.add_user_message(content)

    def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """Add an assistant message."""
        self.history.add_assistant_message(content)

    def add_image(self, image_path: Optional[str] = None, image_url: Optional[str] = None, media_type: Optional[str] = None):
        """Add an image to the conversation.

        The image is appended to the last message when that message is from
        the user; otherwise it starts a new user message.

        Args:
            image_path: Local file. For Gemini models it is loaded as a PIL
                image; for other models it is base64-encoded.
            image_url: Remote image URL, used when ``image_path`` is not given.
            media_type: MIME type sent with base64-encoded files (non-Gemini
                models). When omitted it is guessed from the file extension,
                falling back to "image/jpeg" (the previous fixed default),
                because a mismatched label can be rejected by the provider.

        Raises:
            ValueError: If neither ``image_path`` nor ``image_url`` is given.
            FileNotFoundError: If ``image_path`` does not exist.
        """
        if not image_path and not image_url:
            raise ValueError("Either image_path or image_url must be provided.")

        if image_path:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file {image_path} does not exist.")
            if "gemini" in self.model_name:
                # Gemini accepts PIL images directly. Force decoding now so a
                # corrupt file fails here rather than later during the request.
                image_block = Image.open(image_path)
                image_block.load()
            else:
                if media_type is None:
                    import mimetypes
                    guessed, _ = mimetypes.guess_type(image_path)
                    media_type = guessed if guessed and guessed.startswith("image/") else "image/jpeg"
                with open(image_path, "rb") as img_file:
                    image_data = base64.standard_b64encode(img_file.read()).decode("utf-8")
                image_block = {
                    "type": "image_base64",
                    "image_base64": {
                        "media_type": media_type,
                        "data": image_data,
                    },
                }
        else:
            # Every provider receives image URLs in the same block shape.
            image_block = {"type": "image_url", "image_url": {"url": image_url}}

        if self.history.last_role == "user":
            current_content = self.history.messages[-1]["content"]
            if isinstance(current_content, list):
                current_content.append(image_block)
            else:
                self.history.messages[-1]["content"] = [current_content, image_block]
        else:
            self.history.add_message([image_block], "user")

    def generate_response(self, max_tokens=3585, temperature=0.7, top_p=1.0, top_k=40, json_format: bool = False) -> str:
        """Generate a response from the agent.

        If the last history message is an assistant turn it is updated with
        the reply: Gemini/Grok models extend its content list, other models
        replace it. Otherwise the reply is appended as a new assistant message
        only when ``memory_enabled`` is True.

        Args:
            max_tokens (int, optional): Maximum number of tokens. Defaults to 3585.
            temperature (float, optional): Sampling temperature. Defaults to 0.7.
            top_p (float, optional): Nucleus sampling value. Defaults to 1.0.
            top_k (int, optional): Top-k sampling value. Defaults to 40.
            json_format (bool, optional): Whether to enable JSON output format. Defaults to False.

        Returns:
            str: The generated response.

        Raises:
            ValueError: If the history is empty.
        """
        if not self.history.messages:
            raise ValueError("No messages in history to generate response from")

        logger.debug("Generating response with model %s", self.model_name)
        response_text = completion(
            model=self.model_name,
            messages=self.history.messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            api_key=self.api_key,
            json_format=json_format,
        )

        # Same precedence as before: "openai..." names and Claude models
        # replace a trailing assistant turn; Gemini/Grok extend it.
        extends_last_turn = (
            not self.model_name.startswith("openai")
            and "claude" not in self.model_name
            and ("gemini" in self.model_name or "grok" in self.model_name)
        )
        last = self.history.messages[-1]
        if last["role"] == "assistant":
            if not extends_last_turn:
                last["content"] = response_text
            elif isinstance(last["content"], list):
                last["content"].append(response_text)
            else:
                last["content"] = [last["content"], response_text]
        elif self.memory_enabled:
            self.add_message("assistant", response_text)

        return response_text

    def save_conversation(self):
        """Write the history messages as JSON to ``<agent id>.json``.

        NOTE(review): PIL images stored in the history (Gemini ``add_image``)
        are not JSON-serializable, so ``json.dump`` will raise for them.
        """
        filename = f"{self.id}.json"
        with open(filename, 'w', encoding='utf-8') as file:
            json.dump(self.history.messages, file, ensure_ascii=False, indent=4)

    def load_conversation(self, filename: Optional[str] = None):
        """Replace the history with messages loaded from a JSON file.

        Args:
            filename: Path to read; defaults to ``<agent id>.json``.
        """
        if filename is None:
            filename = f"{self.id}.json"
        with open(filename, 'r', encoding='utf-8') as file:
            messages = json.load(file)
        # Images saved as dicts are loaded back as dicts; PIL objects are not restored.
        self.history = ChatHistory(messages)

    def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = None, repo_name: Optional[str] = None,
                 commit_hash: Optional[str] = None, timeout: Optional[float] = 30):
        """Download a GitHub repository zip and store its .py/.txt files as text.

        The text (file name followed by a fenced code block, per file) is
        appended to ``self.repo_content``.

        Args:
            repo_url: Direct zip URL; ignored when username and repo_name are given.
            username: GitHub owner, used with ``repo_name``.
            repo_name: GitHub repository name.
            commit_hash: Commit to download; defaults to the ``main`` branch.
            timeout: Seconds before the HTTP request is abandoned (None waits forever).

        Raises:
            ValueError: If no URL can be determined or the download does not return 200.
        """
        if username and repo_name:
            if commit_hash:
                repo_url = f"https://github.com/{username}/{repo_name}/archive/{commit_hash}.zip"
            else:
                repo_url = f"https://github.com/{username}/{repo_name}/archive/refs/heads/main.zip"

        if not repo_url:
            raise ValueError("Either repo_url or both username and repo_name must be provided")

        response = requests.get(repo_url, timeout=timeout)
        if response.status_code != 200:
            raise ValueError(f"Failed to download repository from {repo_url}")

        sections = []
        with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
            for file_info in archive.infolist():
                if file_info.is_dir() or not file_info.filename.endswith(('.py', '.txt')):
                    continue
                with archive.open(file_info) as f:
                    # Replace undecodable bytes instead of aborting the whole import.
                    content = f.read().decode('utf-8', errors='replace')
                sections.append(f"{file_info.filename}\n```\n{content}\n```\n")
        self.repo_content.append("".join(sections))
|
516
|
+
|
517
|
+
if __name__ == "__main__":
    # Example usage. The module uses relative imports, so run it as
    #   python -m llm_dialog_manager.agent [IMAGE_PATH]
    import sys

    # Create an Agent instance (Gemini model)
    agent = Agent("gemini-1.5-flash", "you are Jack101", memory_enabled=True)

    # Attach an image if one is given on the command line (default: ./example.png).
    image_path = sys.argv[1] if len(sys.argv) > 1 else "example.png"
    if os.path.exists(image_path):
        agent.add_image(image_path=image_path)
    else:
        logger.warning("Image %s not found; continuing without an image.", image_path)

    agent.add_message("user", "Who are you? What's in this image?")

    # Generate a response with JSON format enabled.
    try:
        response = agent.generate_response(json_format=True)
        print("Response:", response)
    except Exception as e:
        logger.error(f"Failed to generate response: {e}")

    print("Conversation History:")
    print(agent.history)

    # Pop the last message and regenerate without JSON format.
    last_message = agent.history.pop()
    print("Last Message:", last_message)

    try:
        response = agent.generate_response()
        print("Response:", response)
    except Exception as e:
        logger.error(f"Failed to generate response: {e}")
|