llm-dialog-manager 0.3.2__tar.gz → 0.3.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/PKG-INFO +2 -2
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager/__init__.py +1 -1
- llm_dialog_manager-0.3.5/llm_dialog_manager/agent.py +514 -0
- llm_dialog_manager-0.3.5/llm_dialog_manager/chat_history.py +233 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/PKG-INFO +2 -2
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/pyproject.toml +1 -1
- llm_dialog_manager-0.3.2/llm_dialog_manager/agent.py +0 -333
- llm_dialog_manager-0.3.2/llm_dialog_manager/chat_history.py +0 -146
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/LICENSE +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/README.md +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager/key_manager.py +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/SOURCES.txt +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/dependency_links.txt +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/requires.txt +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/top_level.txt +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/setup.cfg +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/tests/test_agent.py +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/tests/test_chat_history.py +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/tests/test_key_manager.py +0 -0
@@ -0,0 +1,514 @@
|
|
1
|
+
# Standard library imports
|
2
|
+
import json
|
3
|
+
import os
|
4
|
+
import uuid
|
5
|
+
from typing import List, Dict, Optional, Union
|
6
|
+
import logging
|
7
|
+
from pathlib import Path
|
8
|
+
import random
|
9
|
+
import requests
|
10
|
+
import zipfile
|
11
|
+
import io
|
12
|
+
import base64
|
13
|
+
from PIL import Image
|
14
|
+
|
15
|
+
# Third-party imports
|
16
|
+
import anthropic
|
17
|
+
from anthropic import AnthropicVertex
|
18
|
+
import google.generativeai as genai
|
19
|
+
import openai
|
20
|
+
from dotenv import load_dotenv
|
21
|
+
|
22
|
+
# Local imports
|
23
|
+
from .chat_history import ChatHistory
|
24
|
+
from .key_manager import key_manager
|
25
|
+
|
26
|
+
# Set up logging: a module-scoped logger, plus root-logger configuration at INFO.
# NOTE(review): calling basicConfig() at import time configures logging for the
# whole host application; libraries usually leave that decision to the caller.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
29
|
+
|
30
|
+
# Load environment variables
def load_env_vars():
    """Populate the process environment from a ``.env`` file beside this module.

    The file is looked up in the directory containing this source file. When it
    is absent, a warning is logged and the existing system environment is used
    unchanged.
    """
    dotenv_file = Path(__file__).parent / '.env'
    if not dotenv_file.exists():
        logger.warning(".env file not found. Using system environment variables.")
        return
    load_dotenv(dotenv_file)


load_env_vars()
|
40
|
+
|
41
|
+
def _image_to_png_base64(image: Image.Image) -> str:
    """Return *image* encoded as PNG bytes, then as base64 text."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _text_of(content) -> str:
    """Flatten message content to text, keeping only its string parts."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "\n".join(block for block in content if isinstance(block, str))
    return ""


def _split_system(messages):
    """Split *messages* into ``(system_text, other_messages)`` without mutating it.

    Every system message is collected, not just a leading one. The Anthropic
    and Gemini request formats only take user/assistant (resp. user/model)
    roles in the turn list, so system text has to travel separately.
    """
    system_parts = []
    turns = []
    for msg in messages:
        if msg["role"] == "system":
            text = _text_of(msg["content"])
            if text:
                system_parts.append(text)
        else:
            turns.append(msg)
    return "\n\n".join(system_parts), turns


def _format_anthropic_messages(messages):
    """Convert ChatHistory-style messages to Anthropic Messages API input.

    Returns:
        tuple: ``(system_text, formatted_messages)``. Content that is neither a
        string nor a list is skipped.
    """
    system_msg, turns = _split_system(messages)
    formatted = []
    for msg in turns:
        content = msg["content"]
        if isinstance(content, str):
            formatted.append({"role": msg["role"], "content": content})
            continue
        if not isinstance(content, list):
            continue
        blocks = []
        for block in content:
            if isinstance(block, str):
                blocks.append({"type": "text", "text": block})
            elif isinstance(block, Image.Image):
                blocks.append({
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": _image_to_png_base64(block),
                    },
                })
            elif isinstance(block, dict) and block.get("type") == "image_url":
                blocks.append({
                    "type": "image",
                    "source": {"type": "url", "url": block["image_url"]["url"]},
                })
            elif isinstance(block, dict) and block.get("type") == "image_base64":
                blocks.append({
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": block["image_base64"]["media_type"],
                        "data": block["image_base64"]["data"],
                    },
                })
        formatted.append({"role": msg["role"], "content": blocks})
    return system_msg, formatted


def _format_openai_messages(messages):
    """Convert ChatHistory-style messages to OpenAI Chat Completions input.

    List content becomes an array of ``text`` / ``image_url`` parts. PIL images
    and ``image_base64`` blocks are sent as ``data:`` URLs. This format is also
    used for the OpenAI-compatible Gemini and xAI endpoints.
    """
    formatted = []
    for msg in messages:
        content = msg["content"]
        if isinstance(content, str):
            formatted.append({"role": msg["role"], "content": content})
            continue
        if not isinstance(content, list):
            continue
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append({"type": "text", "text": block})
            elif isinstance(block, Image.Image):
                url = f"data:image/png;base64,{_image_to_png_base64(block)}"
                parts.append({"type": "image_url", "image_url": {"url": url}})
            elif isinstance(block, dict) and block.get("type") == "image_url":
                parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}})
            elif isinstance(block, dict) and block.get("type") == "image_base64":
                image = block["image_base64"]
                url = f"data:{image['media_type']};base64,{image['data']}"
                parts.append({"type": "image_url", "image_url": {"url": url}})
        formatted.append({"role": msg["role"], "content": parts})
    return formatted


def _format_gemini_contents(messages):
    """Convert ChatHistory-style messages to ``google.generativeai`` contents.

    Assistant turns are sent with the ``model`` role. System text is inserted
    as the first part of the first user turn, or as a leading user turn if
    there is none. Turns that end up with no parts are dropped.
    """
    system_msg, turns = _split_system(messages)
    contents = []
    for msg in turns:
        content = msg["content"]
        if isinstance(content, str):
            blocks = [content]
        elif isinstance(content, list):
            blocks = content
        else:
            blocks = []
        parts = []
        for block in blocks:
            if isinstance(block, (str, Image.Image)):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "image_base64":
                image = block["image_base64"]
                parts.append({"mime_type": image["media_type"], "data": base64.b64decode(image["data"])})
            elif isinstance(block, dict) and block.get("type") == "image_url":
                # NOTE(review): remote URLs are passed as plain text here; upload
                # the image instead if the model must actually see it.
                parts.append(f"[Image: {block['image_url']['url']}]")
        if parts:
            role = "model" if msg["role"] == "assistant" else "user"
            contents.append({"role": role, "parts": parts})
    if system_msg:
        first_user = next((c for c in contents if c["role"] == "user"), None)
        if first_user is None:
            contents.insert(0, {"role": "user", "parts": [system_msg]})
        else:
            first_user["parts"].insert(0, system_msg)
    return contents


def _anthropic_text(response) -> str:
    """Concatenate the text blocks of an Anthropic Messages API response."""
    return "".join(
        getattr(block, "text", "")
        for block in response.content
        if getattr(block, "type", None) == "text"
    )


def _anthropic_complete(client, model, system_msg, messages, max_tokens, temperature, max_rounds=20):
    """Call an Anthropic-style client, continuing while output hits ``max_tokens``.

    A trailing assistant message with string content is treated as a prefill.
    The returned text starts with that prefill, followed by everything
    generated. At most *max_rounds* requests are made.
    """
    request = {"model": model, "max_tokens": max_tokens, "temperature": temperature}
    if system_msg:
        request["system"] = system_msg
    history = list(messages)  # shallow copy: the caller's list is never mutated
    generated = ""
    if history and history[-1]["role"] == "assistant" and isinstance(history[-1]["content"], str):
        generated = history.pop()["content"]
    for _ in range(max_rounds):
        turns = list(history)
        if generated:
            # The API rejects an assistant prefill that ends in whitespace.
            generated = generated.rstrip()
            turns.append({"role": "assistant", "content": generated})
        response = client.messages.create(messages=turns, **request)
        generated += _anthropic_text(response)
        if response.stop_reason != "max_tokens":
            return generated
    logger.warning("Stopped continuing %s output after %d max_tokens rounds", model, max_rounds)
    return generated


def _openai_complete(client, model, messages, max_tokens, temperature, json_format):
    """Call an OpenAI-compatible Chat Completions endpoint and return the text."""
    request = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    if json_format:
        # Only "json_object" is requested; plain text is the API default.
        request["response_format"] = {"type": "json_object"}
    response = client.chat.completions.create(**request)
    return response.choices[0].message.content


def _gemini_complete(model, messages, max_tokens, temperature, api_key, json_format):
    """Query Gemini through its OpenAI-compatible endpoint, falling back to genai."""
    try:
        client = openai.OpenAI(
            api_key=api_key,
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        )
        return _openai_complete(
            client, model, _format_openai_messages(messages), max_tokens, temperature, json_format
        )
    except Exception as exc:
        logger.warning("Gemini OpenAI-compatible request failed (%s); falling back to Google's genai library", exc)
    genai.configure(api_key=api_key)
    model_instance = genai.GenerativeModel(model_name=model)
    response = model_instance.generate_content(
        _format_gemini_contents(messages),
        generation_config=genai.types.GenerationConfig(
            temperature=temperature,
            response_mime_type="application/json" if json_format else "text/plain",
            max_output_tokens=max_tokens,
        ),
    )
    return response.text


def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]], max_tokens: int = 1000,
               temperature: float = 0.5, api_key: Optional[str] = None,
               base_url: Optional[str] = None, json_format: bool = False) -> str:
    """Generate a completion using the specified model and messages.

    The provider is chosen from substrings of *model*:
      - "claude": Anthropic, or Anthropic on Vertex AI when VERTEX_PROJECT_ID
        and VERTEX_REGION are set.
      - "gemini": Google Gemini.
      - "grok": xAI, using a randomly chosen Anthropic- or OpenAI-compatible
        endpoint.
      - anything else: OpenAI.

    Args:
        model: Model name passed through to the provider.
        messages: ChatHistory-style messages. Each has a ``role`` and a
            ``content``, which is a string or a list of strings, PIL images and
            ``image_url`` / ``image_base64`` dicts. The list is not modified.
        max_tokens: Output token limit per request.
        temperature: Sampling temperature.
        api_key: Provider key. When omitted, read from ``<SERVICE>_API_KEY``.
        base_url: Endpoint override. When both it and *api_key* are omitted,
            read from ``<SERVICE>_BASE_URL``.
        json_format: Request JSON output. Honoured by the OpenAI-compatible and
            Gemini paths; ignored for Claude.

    Returns:
        str: The generated text.

    Raises:
        Exception: Any provider/SDK error, re-raised after being logged.
    """
    try:
        if "claude" in model:
            service = "anthropic"
        elif "gemini" in model:
            service = "gemini"
        elif "grok" in model:
            service = "x"
        else:
            service = "openai"

        if not api_key:
            # TODO: use key_manager.get_config(service) / release_config once wired in.
            api_key = os.getenv(f"{service.upper()}_API_KEY")
            base_url = base_url or os.getenv(f"{service.upper()}_BASE_URL")

        if service == "anthropic":
            system_msg, formatted = _format_anthropic_messages(messages)
            vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
            vertex_region = os.getenv('VERTEX_REGION')
            if vertex_project_id and vertex_region:
                client = AnthropicVertex(region=vertex_region, project_id=vertex_project_id)
            else:
                client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
            return _anthropic_complete(client, model, system_msg, formatted, max_tokens, temperature)

        if service == "gemini":
            return _gemini_complete(model, messages, max_tokens, temperature, api_key, json_format)

        if service == "x":
            # Randomly choose between xAI's Anthropic- and OpenAI-compatible APIs.
            if random.choice([True, False]):
                logger.info("Using Anthropic for Grok model")
                client = anthropic.Anthropic(api_key=api_key, base_url="https://api.x.ai")
                system_msg, formatted = _format_anthropic_messages(messages)
                return _anthropic_complete(client, model, system_msg, formatted, max_tokens, temperature)
            logger.info("Using OpenAI for Grok model")
            client = openai.OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
            return _openai_complete(
                client, model, _format_openai_messages(messages), max_tokens, temperature, json_format
            )

        # OpenAI models
        client = openai.OpenAI(api_key=api_key, base_url=base_url)
        return _openai_complete(
            client, model, _format_openai_messages(messages), max_tokens, temperature, json_format
        )

    except Exception as e:
        logger.error("Error in completion: %s", e)
        raise
|
323
|
+
|
324
|
+
class Agent:
    """A chat agent bound to one model and backed by a ``ChatHistory``.

    Messages and images are appended to ``self.history``, and ``completion``
    produces replies. With ``memory_enabled`` a reply to a user turn is stored
    as a new assistant message. A trailing assistant message is always updated
    in place, whether or not memory is enabled.
    """

    def __init__(self, model_name: str, messages: Optional[Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]]] = None,
                 memory_enabled: bool = False, api_key: Optional[str] = None) -> None:
        """Initialize an Agent instance.

        Args:
            model_name: Model identifier. Substrings such as "claude", "gemini"
                or "grok" select the provider in ``completion``.
            messages: A system prompt string or a list of message dicts used to
                seed the history.
            memory_enabled: Whether generated replies are appended to history.
            api_key: Explicit API key. When omitted, ``completion`` reads one
                from the environment.
        """
        self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
        self.model_name = model_name
        self.history = ChatHistory(messages) if messages else ChatHistory()
        self.memory_enabled = memory_enabled
        self.api_key = api_key
        # Text snapshots collected by add_repo.
        # NOTE(review): nothing in this class feeds them into the conversation yet.
        self.repo_content = []

    def add_message(self, role: str, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """Add a message with an explicit role (no alternation check)."""
        self.history.add_message(content, role)

    def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """Add a user message. ChatHistory raises if the previous turn was a user turn."""
        self.history.add_user_message(content)

    def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """Add an assistant message. ChatHistory raises unless it follows a user turn."""
        self.history.add_assistant_message(content)

    def add_image(self, image_path: Optional[str] = None, image_url: Optional[str] = None, media_type: Optional[str] = None):
        """Add an image to the conversation.

        Either *image_path* or *image_url* must be provided; *image_path* wins
        if both are given. The image is attached to the last message when that
        is a user message, otherwise it starts a new user message.

        Args:
            image_path: Local image file. For Gemini models it is stored as a
                PIL image; otherwise it is stored as an ``image_base64`` block.
            image_url: Remote image URL, stored as an ``image_url`` block.
            media_type: MIME type for the base64 block. When omitted it is
                guessed from the file extension, falling back to "image/jpeg".

        Raises:
            ValueError: If neither *image_path* nor *image_url* is given.
            FileNotFoundError: If *image_path* does not exist.
        """
        if not image_path and not image_url:
            raise ValueError("Either image_path or image_url must be provided.")

        if image_path:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file {image_path} does not exist.")
            if "gemini" in self.model_name:
                image_block = Image.open(image_path)
                # Read the pixel data now; for single-frame images Pillow then
                # closes the file instead of holding it open lazily.
                image_block.load()
            else:
                if media_type is None:
                    import mimetypes
                    guessed, _ = mimetypes.guess_type(image_path)
                    media_type = guessed if guessed and guessed.startswith("image/") else "image/jpeg"
                with open(image_path, "rb") as img_file:
                    image_data = base64.standard_b64encode(img_file.read()).decode("utf-8")
                image_block = {
                    "type": "image_base64",
                    "image_base64": {
                        "media_type": media_type,
                        "data": image_data,
                    },
                }
        else:
            # The same URL block is used for every provider; completion()
            # converts it to each provider's format.
            image_block = {"type": "image_url", "image_url": {"url": image_url}}

        if self.history.last_role == "user":
            current_content = self.history.messages[-1]["content"]
            if isinstance(current_content, list):
                current_content.append(image_block)
            else:
                self.history.messages[-1]["content"] = [current_content, image_block]
        else:
            self.history.add_message([image_block], "user")

    def generate_response(self, max_tokens=3585, temperature=0.7, json_format: bool = False) -> str:
        """Generate a response from the agent.

        Args:
            max_tokens (int, optional): Maximum number of tokens. Defaults to 3585.
            temperature (float, optional): Sampling temperature. Defaults to 0.7.
            json_format (bool, optional): Whether to enable JSON output format. Defaults to False.

        Returns:
            str: The generated response.

        Raises:
            ValueError: If the history is empty.
        """
        if not self.history.messages:
            raise ValueError("No messages in history to generate response from")

        response_text = completion(
            model=self.model_name,
            messages=self.history.messages,
            max_tokens=max_tokens,
            temperature=temperature,
            api_key=self.api_key,
            json_format=json_format,
        )

        last_message = self.history.messages[-1]
        if last_message["role"] == "assistant":
            if self._extends_assistant_prefill():
                content = last_message["content"]
                if isinstance(content, list):
                    content.append(response_text)
                else:
                    last_message["content"] = [content, response_text]
            else:
                last_message["content"] = response_text
        elif self.memory_enabled:
            self.add_message("assistant", response_text)

        return response_text

    def _extends_assistant_prefill(self) -> bool:
        """Return True when a reply is appended to, not substituted for, a trailing assistant turn.

        Names starting with "openai" or containing "claude" replace the turn.
        Names containing "gemini" or "grok" append to it. Anything else
        replaces it.
        """
        name = self.model_name
        if name.startswith("openai") or "claude" in name:
            return False
        return "gemini" in name or "grok" in name

    def save_conversation(self):
        """Write the history as JSON to ``<agent id>.json`` in the working directory.

        PIL images (stored for Gemini models) are written as PNG
        ``image_base64`` blocks so the history is JSON-serializable.

        Returns:
            str: The file name that was written.
        """
        filename = f"{self.id}.json"
        with open(filename, 'w', encoding='utf-8') as file:
            json.dump(self.history.messages, file, ensure_ascii=False, indent=4,
                      default=self._json_default)
        return filename

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback that encodes PIL images as base64 PNG image blocks."""
        if isinstance(obj, Image.Image):
            buffered = io.BytesIO()
            obj.save(buffered, format="PNG")
            return {
                "type": "image_base64",
                "image_base64": {
                    "media_type": "image/png",
                    "data": base64.b64encode(buffered.getvalue()).decode("utf-8"),
                },
            }
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def load_conversation(self, filename: Optional[str] = None):
        """Replace the history with messages read from a JSON file.

        Defaults to ``<agent id>.json``. Images come back as plain dicts.
        """
        if filename is None:
            filename = f"{self.id}.json"
        with open(filename, 'r', encoding='utf-8') as file:
            messages = json.load(file)
        self.history = ChatHistory(messages)

    def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = None, repo_name: Optional[str] = None,
                 commit_hash: Optional[str] = None, timeout: float = 60):
        """Download a repository ZIP and store its ``.py``/``.txt`` files as one text snapshot.

        When *username* and *repo_name* are given, the GitHub archive URL is
        built from them, overriding *repo_url*. Without *commit_hash* the
        ``main`` branch is used. Non-UTF-8 bytes are replaced rather than
        failing the whole import.

        Args:
            timeout: Seconds to wait for the HTTP request.

        Raises:
            ValueError: If no URL can be determined or the download is not HTTP 200.
        """
        if username and repo_name:
            if commit_hash:
                repo_url = f"https://github.com/{username}/{repo_name}/archive/{commit_hash}.zip"
            else:
                repo_url = f"https://github.com/{username}/{repo_name}/archive/refs/heads/main.zip"

        if not repo_url:
            raise ValueError("Either repo_url or both username and repo_name must be provided")

        response = requests.get(repo_url, timeout=timeout)
        if response.status_code != 200:
            raise ValueError(f"Failed to download repository from {repo_url}")

        sections = []
        with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
            for file_info in archive.infolist():
                if file_info.is_dir() or not file_info.filename.endswith(('.py', '.txt')):
                    continue
                with archive.open(file_info) as f:
                    content = f.read().decode('utf-8', errors='replace')
                sections.append(f"{file_info.filename}\n```\n{content}\n```\n")
        self.repo_content.append("".join(sections))
|
485
|
+
|
486
|
+
if __name__ == "__main__":
    import sys

    # Configure logging in the entry point; no-op if already configured.
    logging.basicConfig(level=logging.INFO)

    # Example usage:  python -m llm_dialog_manager.agent [path/to/image.png]
    # Create an Agent instance (Gemini model) with a system prompt.
    agent = Agent("gemini-1.5-flash", "you are an assistant", memory_enabled=True)

    image_path = sys.argv[1] if len(sys.argv) > 1 else None
    if image_path:
        # Add the question first so the image joins the same user turn.
        agent.add_message("user", "What's in this image?")
        agent.add_image(image_path=image_path)
    else:
        agent.add_message("user", "Say hello in one short sentence.")

    # Generate a response with JSON format enabled.
    try:
        response = agent.generate_response(json_format=True)
        print("Response:", response)
    except Exception as e:
        logger.error(f"Failed to generate response: {e}")
        sys.exit(1)

    # Print the entire conversation history.
    print("Conversation History:")
    print(agent.history)

    # Pop the stored assistant reply and ask again without JSON format.
    last_message = agent.history.pop()
    print("Last Message:", last_message)

    response = agent.generate_response()
    print("Response:", response)
|
@@ -0,0 +1,233 @@
|
|
1
|
+
from typing import List, Dict, Optional, Union
|
2
|
+
from PIL import Image
|
3
|
+
|
4
|
+
class ChatHistory:
|
5
|
+
def __init__(self, input_data: Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]] = "") -> None:
|
6
|
+
self.messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]] = []
|
7
|
+
if isinstance(input_data, str) and input_data:
|
8
|
+
self.add_message(input_data, "system")
|
9
|
+
elif isinstance(input_data, list):
|
10
|
+
self.load_messages(input_data)
|
11
|
+
self.last_role: str = "system" if not self.messages else self.get_last_role()
|
12
|
+
|
13
|
+
def load_messages(self, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]) -> None:
|
14
|
+
for message in messages:
|
15
|
+
if not ("role" in message and "content" in message):
|
16
|
+
raise ValueError("Each message must have a 'role' and 'content'.")
|
17
|
+
if message["role"] not in ["user", "assistant", "system"]:
|
18
|
+
raise ValueError(f"Invalid role: {message['role']}")
|
19
|
+
self.messages.append(message)
|
20
|
+
self.last_role = self.get_last_role()
|
21
|
+
|
22
|
+
def get_last_role(self):
|
23
|
+
return self.messages[-1]["role"] if self.messages else "system"
|
24
|
+
|
25
|
+
def pop(self):
|
26
|
+
if not self.messages:
|
27
|
+
return None
|
28
|
+
|
29
|
+
popped_message = self.messages.pop()
|
30
|
+
|
31
|
+
if self.messages:
|
32
|
+
self.last_role = self.get_last_role()
|
33
|
+
else:
|
34
|
+
self.last_role = "system"
|
35
|
+
|
36
|
+
return popped_message["content"]
|
37
|
+
|
38
|
+
def __len__(self):
|
39
|
+
return len(self.messages)
|
40
|
+
|
41
|
+
def __str__(self):
|
42
|
+
formatted_messages = []
|
43
|
+
for i, msg in enumerate(self.messages):
|
44
|
+
role = msg['role']
|
45
|
+
content = msg['content']
|
46
|
+
if isinstance(content, str):
|
47
|
+
formatted_content = content
|
48
|
+
elif isinstance(content, list):
|
49
|
+
parts = []
|
50
|
+
for block in content:
|
51
|
+
if isinstance(block, str):
|
52
|
+
parts.append(block)
|
53
|
+
elif isinstance(block, Image.Image):
|
54
|
+
parts.append(f"[Image Object: {block.filename}]")
|
55
|
+
elif isinstance(block, dict):
|
56
|
+
if block.get("type") == "image_url":
|
57
|
+
parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
|
58
|
+
elif block.get("type") == "image_base64":
|
59
|
+
parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
|
60
|
+
formatted_content = "\n".join(parts)
|
61
|
+
else:
|
62
|
+
formatted_content = str(content)
|
63
|
+
formatted_messages.append(f"Message {i} ({role}): {formatted_content}")
|
64
|
+
return '\n'.join(formatted_messages)
|
65
|
+
|
66
|
+
def __getitem__(self, key):
|
67
|
+
if isinstance(key, slice):
|
68
|
+
sliced_messages = self.messages[key]
|
69
|
+
formatted = []
|
70
|
+
for msg in sliced_messages:
|
71
|
+
role = msg['role']
|
72
|
+
content = msg['content']
|
73
|
+
if isinstance(content, str):
|
74
|
+
formatted_content = content
|
75
|
+
elif isinstance(content, list):
|
76
|
+
parts = []
|
77
|
+
for block in content:
|
78
|
+
if isinstance(block, str):
|
79
|
+
parts.append(block)
|
80
|
+
elif isinstance(block, Image.Image):
|
81
|
+
parts.append(f"[Image Object: {block.filename}]")
|
82
|
+
elif isinstance(block, dict):
|
83
|
+
if block.get("type") == "image_url":
|
84
|
+
parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
|
85
|
+
elif block.get("type") == "image_base64":
|
86
|
+
parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
|
87
|
+
formatted_content = "\n".join(parts)
|
88
|
+
else:
|
89
|
+
formatted_content = str(content)
|
90
|
+
formatted.append(f"({role}): {formatted_content}")
|
91
|
+
print('\n'.join(formatted))
|
92
|
+
return sliced_messages
|
93
|
+
elif isinstance(key, int):
|
94
|
+
# Adjust for negative indices
|
95
|
+
if key < 0:
|
96
|
+
key += len(self.messages)
|
97
|
+
if 0 <= key < len(self.messages):
|
98
|
+
msg = self.messages[key]
|
99
|
+
role = msg['role']
|
100
|
+
content = msg['content']
|
101
|
+
if isinstance(content, str):
|
102
|
+
formatted_content = content
|
103
|
+
elif isinstance(content, list):
|
104
|
+
parts = []
|
105
|
+
for block in content:
|
106
|
+
if isinstance(block, str):
|
107
|
+
parts.append(block)
|
108
|
+
elif isinstance(block, Image.Image):
|
109
|
+
parts.append(f"[Image Object: {block.filename}]")
|
110
|
+
elif isinstance(block, dict):
|
111
|
+
if block.get("type") == "image_url":
|
112
|
+
parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
|
113
|
+
elif block.get("type") == "image_base64":
|
114
|
+
parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
|
115
|
+
formatted_content = "\n".join(parts)
|
116
|
+
else:
|
117
|
+
formatted_content = str(content)
|
118
|
+
snippet = self.get_conversation_snippet(key)
|
119
|
+
print('\n'.join([f"({v['role']}): {v['content']}" for k, v in snippet.items() if v]))
|
120
|
+
return self.messages[key]
|
121
|
+
else:
|
122
|
+
raise IndexError("Message index out of range.")
|
123
|
+
else:
|
124
|
+
raise TypeError("Invalid argument type.")
|
125
|
+
|
126
|
+
def __setitem__(self, index, value):
|
127
|
+
if not isinstance(value, (str, list)):
|
128
|
+
raise ValueError("Message content must be a string or a list of content blocks.")
|
129
|
+
role = "system" if index % 2 == 0 else "user"
|
130
|
+
self.messages[index] = {"role": role, "content": value}
|
131
|
+
|
132
|
+
def __add__(self, message):
|
133
|
+
if self.last_role == "system":
|
134
|
+
self.add_user_message(message)
|
135
|
+
else:
|
136
|
+
next_role = "assistant" if self.last_role == "user" else "user"
|
137
|
+
self.add_message(message, next_role)
|
138
|
+
|
139
|
+
def __contains__(self, item):
|
140
|
+
for message in self.messages:
|
141
|
+
content = message['content']
|
142
|
+
if isinstance(content, str) and item in content:
|
143
|
+
return True
|
144
|
+
elif isinstance(content, list):
|
145
|
+
for block in content:
|
146
|
+
if isinstance(block, str) and item in block:
|
147
|
+
return True
|
148
|
+
return False
|
149
|
+
|
150
|
+
def add_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]], role: str):
|
151
|
+
self.messages.append({"role": role, "content": content})
|
152
|
+
self.last_role = role
|
153
|
+
|
154
|
+
def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
|
155
|
+
if self.last_role in ["system", "assistant"]:
|
156
|
+
self.add_message(content, "user")
|
157
|
+
else:
|
158
|
+
raise ValueError("A user message must follow a system or assistant message.")
|
159
|
+
|
160
|
+
def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
|
161
|
+
if self.last_role == "user":
|
162
|
+
self.add_message(content, "assistant")
|
163
|
+
else:
|
164
|
+
raise ValueError("An assistant message must follow a user message.")
|
165
|
+
|
166
|
+
def add_marker(self, marker, index=None):
|
167
|
+
if not isinstance(marker, str):
|
168
|
+
raise ValueError("Marker must be a string.")
|
169
|
+
if index is None:
|
170
|
+
index = len(self.messages) - 1
|
171
|
+
if 0 <= index < len(self.messages):
|
172
|
+
self.messages[index]["marker"] = marker
|
173
|
+
else:
|
174
|
+
raise IndexError("Invalid index for marker.")
|
175
|
+
|
176
|
+
def conversation_status(self):
|
177
|
+
return {
|
178
|
+
"last_message_role": self.last_role,
|
179
|
+
"total_messages": len(self.messages),
|
180
|
+
"last_message_content": self.messages[-1]["content"] if self.messages else "No messages",
|
181
|
+
}
|
182
|
+
|
183
|
+
def display_conversation_status(self):
|
184
|
+
status = self.conversation_status()
|
185
|
+
print(f"Role of the last message: {status['last_message_role']}")
|
186
|
+
print(f"Total number of messages: {status['total_messages']}")
|
187
|
+
print(f"Content of the last message: {status['last_message_content']}")
|
188
|
+
|
189
|
+
def search_for_keyword(self, keyword):
|
190
|
+
results = []
|
191
|
+
for msg in self.messages:
|
192
|
+
content = msg['content']
|
193
|
+
if isinstance(content, str) and keyword.lower() in content.lower():
|
194
|
+
results.append(msg)
|
195
|
+
elif isinstance(content, list):
|
196
|
+
for block in content:
|
197
|
+
if isinstance(block, str) and keyword.lower() in block.lower():
|
198
|
+
results.append(msg)
|
199
|
+
break
|
200
|
+
return results
|
201
|
+
|
202
|
+
def has_user_or_assistant_spoken_since_last_system(self):
|
203
|
+
for msg in reversed(self.messages):
|
204
|
+
if msg["role"] == "system":
|
205
|
+
return False
|
206
|
+
if msg["role"] in ["user", "assistant"]:
|
207
|
+
return True
|
208
|
+
return False
|
209
|
+
|
210
|
+
def get_conversation_snippet(self, index):
    """Return the message at *index* with its neighbours.

    The result has keys ``"previous"``, ``"current"`` and ``"next"``, in that
    order. A missing neighbour is None.

    Raises:
        IndexError: unless ``0 <= index < len(self.messages)``.
    """
    total = len(self.messages)
    if not 0 <= index < total:
        raise IndexError("Invalid index.")
    return {
        "previous": self.messages[index - 1] if index > 0 else None,
        "current": self.messages[index],
        "next": self.messages[index + 1] if index + 1 < total else None,
    }
|
221
|
+
|
222
|
+
def display_snippet(self, index):
    """Print the previous/current/next messages around *index*, one per line."""
    for position, message in self.get_conversation_snippet(index).items():
        label = position.capitalize()
        if not message:
            print(f"{label}: None")
            continue
        print(f"{label} Message ({message['role']}): {message['content']}")
|
229
|
+
|
230
|
+
@staticmethod
def color_text(text, color):
    """Wrap *text* in an ANSI colour code followed by a reset code.

    Supported colours are ``"green"`` and ``"red"``. An unknown colour adds
    no prefix, but the reset suffix is still appended.
    """
    ansi = {"green": "\033[92m", "red": "\033[91m", "end": "\033[0m"}
    prefix = ansi.get(color, '')
    reset = ansi.get('end', '')
    return f"{prefix}{text}{reset}"
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "llm_dialog_manager"
|
7
|
-
version = "0.3.2"
|
7
|
+
version = "0.3.5"
|
8
8
|
description = "A Python package for managing LLM chat conversation history"
|
9
9
|
readme = "README.md"
|
10
10
|
authors = [{ name = "xihajun", email = "work@2333.fun" }]
|
@@ -1,333 +0,0 @@
|
|
1
|
-
# Standard library imports
|
2
|
-
import json
|
3
|
-
import os
|
4
|
-
import uuid
|
5
|
-
from typing import List, Dict, Optional
|
6
|
-
import logging
|
7
|
-
from pathlib import Path
|
8
|
-
import random
|
9
|
-
import requests
|
10
|
-
import zipfile
|
11
|
-
import io
|
12
|
-
|
13
|
-
# Third-party imports
|
14
|
-
import anthropic
|
15
|
-
from anthropic import AnthropicVertex
|
16
|
-
import google.generativeai as genai
|
17
|
-
import openai
|
18
|
-
from dotenv import load_dotenv
|
19
|
-
|
20
|
-
# Local imports
|
21
|
-
from llm_dialog_manager.chat_history import ChatHistory
|
22
|
-
from llm_dialog_manager.key_manager import key_manager
|
23
|
-
|
24
|
-
# Set up logging
logger = logging.getLogger(__name__)
# NOTE(review): configuring the root logger at import time affects every
# application that imports this module; basicConfig does nothing if the root
# logger already has handlers.
logging.basicConfig(level=logging.INFO)
|
27
|
-
|
28
|
-
# Load environment variables
def load_env_vars():
    """Load variables from the ``.env`` file next to the package directory.

    The file is looked up at ``Path(__file__).parent.parent / '.env'``. When
    it is missing, a warning is logged and the process environment is used
    as-is.
    """
    dotenv_file = Path(__file__).parent.parent / '.env'
    if not dotenv_file.exists():
        logger.warning(".env file not found. Using system environment variables.")
        return
    load_dotenv(dotenv_file)

load_env_vars()
|
38
|
-
|
39
|
-
def create_and_send_message(client, model, max_tokens, temperature, messages, system_msg):
    """Send one Messages API request through *client* and return the raw response.

    Args:
        client: An Anthropic-style client exposing ``messages.create``.
        model, max_tokens, temperature, messages: Forwarded unchanged.
        system_msg: Passed as the ``system`` parameter.

    Raises:
        Exception: Anything raised by the client is logged and re-raised.
    """
    request = dict(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        messages=messages,
        system=system_msg,
    )
    try:
        return client.messages.create(**request)
    except Exception as exc:
        logger.error(f"Error sending message: {str(exc)}")
        raise
|
53
|
-
|
54
|
-
def _service_for_model(model: str) -> str:
    """Return the key-manager service name for *model*, matched by substring."""
    if "claude" in model:
        return "anthropic"
    if "gemini" in model:
        return "gemini"
    if "grok" in model:
        return "x"
    return "openai"


def _pop_system_message(messages: List[Dict[str, str]]) -> str:
    """Remove a leading system message from *messages* in place and return its content ('' if none)."""
    if messages and messages[0]["role"] == "system":
        return messages.pop(0)["content"]
    return ""


def _complete_anthropic(api_key, base_url, model, messages, max_tokens, temperature) -> str:
    """Call Claude (direct API, or Vertex when configured), continuing while the reply hits max_tokens."""
    vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
    vertex_region = os.getenv('VERTEX_REGION')
    if vertex_project_id and vertex_region:
        client = AnthropicVertex(region=vertex_region, project_id=vertex_project_id)
    else:
        client = anthropic.Anthropic(api_key=api_key, base_url=base_url)

    system_msg = _pop_system_message(messages)

    def _send():
        # Reads the current contents of ``messages``, including continuation chunks.
        return client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=messages,
            system=system_msg,
        )

    response = _send()
    # A truncated reply is fed back as an assistant prefill so the model
    # resumes where it stopped; chunks accumulate in messages[-1].
    while response.stop_reason == "max_tokens":
        if messages[-1]['role'] == "user":
            messages.append({"role": "assistant", "content": response.content[0].text})
        else:
            messages[-1]['content'] += response.content[0].text
        response = _send()

    # If the conversation ends in an assistant turn (caller prefill or
    # accumulated chunks), return the stitched text whatever the final
    # stop reason, so earlier chunks are never dropped.
    if messages[-1]['role'] == "assistant":
        messages[-1]['content'] += response.content[0].text
        return messages[-1]['content']
    return response.content[0].text


def _complete_gemini(api_key, model, messages, max_tokens, temperature) -> str:
    """Call Gemini via its OpenAI-compatible endpoint, falling back to google.generativeai."""
    # Fold a leading system prompt into the first remaining message.
    if messages and messages[0]["role"] == "system":
        system_msg = messages.pop(0)
        if messages:
            messages[0]["content"] = f"{system_msg['content']}\n\n{messages[0]['content']}"

    try:
        client = openai.OpenAI(
            api_key=api_key,
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        )
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response.choices[0].message.content
    except Exception:
        logger.info("Falling back to Google's genai library")

    genai.configure(api_key=api_key)
    gemini_messages = []
    for msg in messages:
        if msg["role"] == "system":
            # A non-leading system message is prepended to the first turn's text.
            if gemini_messages:
                first_part = gemini_messages[0]["parts"][0]
                first_part["text"] = f"{msg['content']}\n\n{first_part['text']}"
            continue
        # The Gemini API only accepts the roles "user" and "model".
        role = "model" if msg["role"] == "assistant" else msg["role"]
        gemini_messages.append({"role": role, "parts": [{"text": msg["content"]}]})

    gemini_model = genai.GenerativeModel(model_name=model)
    response = gemini_model.generate_content(
        gemini_messages,
        generation_config=genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_tokens,
        ),
    )
    return response.text


def _complete_grok(api_key, model, messages, max_tokens, temperature) -> str:
    """Call xAI Grok through a randomly chosen SDK (Anthropic- or OpenAI-compatible endpoint)."""
    if random.choice([True, False]):
        logger.debug("using anthropic")
        client = anthropic.Anthropic(api_key=api_key, base_url="https://api.x.ai")
        system_msg = _pop_system_message(messages)
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=messages,
            system=system_msg,
        )
        return response.content[0].text

    logger.debug("using openai")
    client = openai.OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return response.choices[0].message.content


def _complete_openai(api_key, base_url, model, messages, max_tokens, temperature) -> str:
    """Call an OpenAI-compatible chat completions endpoint."""
    client = openai.OpenAI(api_key=api_key, base_url=base_url)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return response.choices[0].message.content


def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 1000,
               temperature: float = 0.5, api_key: Optional[str] = None,
               base_url: Optional[str] = None) -> str:
    """
    Generate a completion using the specified model and messages.

    The provider is chosen by substring of *model*, checked in this order:
    ``claude`` → Anthropic (or Vertex when ``VERTEX_PROJECT_ID`` and
    ``VERTEX_REGION`` are set), ``gemini`` → Google, ``grok`` → xAI,
    anything else → OpenAI-compatible.

    Args:
        model: Provider model name.
        messages: ``{"role", "content"}`` dicts. Not modified; a copy is used.
        max_tokens: Output token budget per request.
        temperature: Sampling temperature.
        api_key: Explicit key. When omitted, a key and base URL are borrowed
            from ``key_manager`` and released (or reported on error) afterwards.
        base_url: Optional endpoint override (used by the Anthropic and
            OpenAI paths).

    Returns:
        The generated text.

    Raises:
        Exception: Any provider or key-manager error is logged and re-raised.
    """
    try:
        service = _service_for_model(model)
        # Copy each message so popping the system prompt or stitching
        # continuation chunks never mutates the caller's list or dicts.
        messages = [dict(msg) for msg in messages]

        # Remember whether the key is borrowed *before* api_key is overwritten,
        # so release/report actually run for key-manager keys.
        borrowed_key = not api_key
        if borrowed_key:
            api_key, base_url = key_manager.get_config(service)

        try:
            if service == "anthropic":
                text = _complete_anthropic(api_key, base_url, model, messages, max_tokens, temperature)
            elif service == "gemini":
                text = _complete_gemini(api_key, model, messages, max_tokens, temperature)
            elif service == "x":
                text = _complete_grok(api_key, model, messages, max_tokens, temperature)
            else:
                text = _complete_openai(api_key, base_url, model, messages, max_tokens, temperature)
        except Exception:
            # Report error to key manager
            if borrowed_key:
                key_manager.report_error(service, api_key)
            raise

        # Release the API key after successful use
        if borrowed_key:
            key_manager.release_config(service, api_key)
        return text

    except Exception as e:
        logger.error(f"Error in completion: {str(e)}")
        raise
|
226
|
-
|
227
|
-
class Agent:
    """A chat agent bound to one model, with its own ``ChatHistory``.

    Repository snapshots loaded with ``load_repo`` are queued and prepended
    to the next message added through ``add_message``.
    """

    def __init__(self, model_name: str, messages: Optional[str] = None,
                 memory_enabled: bool = False, api_key: Optional[str] = None) -> None:
        """Initialize an Agent instance.

        Args:
            model_name: Model identifier passed to ``completion``.
            messages: Initial history forwarded to ``ChatHistory``; a non-empty
                string becomes the system prompt. NOTE(review): ChatHistory also
                accepts a list of message dicts, which this annotation omits.
            memory_enabled: When true, generated replies are appended to history.
            api_key: Explicit provider key; when None, ``completion`` uses the
                key manager.
        """
        self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
        self.model_name = model_name
        self.history = ChatHistory(messages)
        self.memory_enabled = memory_enabled
        self.api_key = api_key
        # Pending repository dumps, consumed by the next add_message call.
        self.repo_content = []

    def add_message(self, role, content):
        """Append a message, prefixing it with all queued repository dumps.

        Queued repos are emitted most-recently-loaded first, each wrapped in
        ``<repo>...</repo>`` tags. The queue is then emptied.
        """
        repo_prefix = "".join(
            f"<repo>\n{repo}\n</repo>\n" for repo in reversed(self.repo_content)
        )
        self.repo_content.clear()
        self.history.add_message(repo_prefix + content, role)

    def generate_response(self, max_tokens=3585, temperature=0.7):
        """Ask the model for the next reply to the current history.

        If the history ends with an assistant message, that message's content
        is replaced by the returned text. Otherwise the reply is appended as
        an assistant turn only when ``memory_enabled`` is set.

        Raises:
            ValueError: if the history is empty.
        """
        if not self.history.messages:
            raise ValueError("No messages in history to generate response from")

        # Send plain role/content pairs; extra keys such as "marker" stay local.
        messages = [{"role": msg["role"], "content": msg["content"]} for msg in self.history.messages]

        response_text = completion(
            model=self.model_name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            api_key=self.api_key
        )
        if messages[-1]["role"] == "assistant":
            self.history.messages[-1]["content"] = response_text
        elif self.memory_enabled:
            # Write straight to history: going through self.add_message would
            # splice any queued repo dumps into the model's reply.
            self.history.add_message(response_text, "assistant")

        return response_text

    def save_conversation(self):
        """Write the history messages as JSON to ``<agent id>.json`` in the working directory."""
        filename = f"{self.id}.json"
        with open(filename, 'w', encoding='utf-8') as file:
            json.dump(self.history.messages, file, ensure_ascii=False, indent=4)

    def load_conversation(self, filename=None):
        """Replace the history with messages read from a JSON file (default ``<agent id>.json``)."""
        if filename is None:
            filename = f"{self.id}.json"
        with open(filename, 'r', encoding='utf-8') as file:
            messages = json.load(file)
        self.history = ChatHistory(messages)

    def load_repo(self, repo_name: Optional[str] = None, commit_hash: Optional[str] = None,
                  timeout: float = 60):
        """Download a GitHub repository archive and queue its ``.py``/``.txt`` files.

        eg: repo_name: xihajun/llm_dialog_manager

        Args:
            repo_name: ``owner/repo`` on GitHub.
            commit_hash: Commit to download; defaults to the ``main`` branch.
            timeout: Seconds to wait for the HTTP download.

        Raises:
            ValueError: if *repo_name* is missing or the download does not
                return HTTP 200.
        """
        if not repo_name:
            raise ValueError("repo_name must be provided, e.g. 'owner/repo'")

        ref = commit_hash if commit_hash else "refs/heads/main"
        repo_url = f"https://github.com/{repo_name}/archive/{ref}.zip"

        response = requests.get(repo_url, timeout=timeout)
        if response.status_code != 200:
            raise ValueError(f"Failed to download repository from {repo_url}")

        parts = []
        with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
            for file_info in archive.infolist():
                if file_info.is_dir() or not file_info.filename.endswith(('.py', '.txt')):
                    continue
                with archive.open(file_info) as f:
                    # Tolerate files that are not valid UTF-8 instead of aborting the load.
                    content = f.read().decode('utf-8', errors='replace')
                parts.append(f"{file_info.filename}\n```\n{content}\n```\n")
        self.repo_content.append("".join(parts))
|
307
|
-
|
308
|
-
if __name__ == "__main__":
    # Manual smoke test: calls the live model, so it needs network access and
    # a configured key for the Gemini service.
    text = "I think the answer is 42"

    agent = Agent("gemini-1.5-pro-002", "you are an assistant", memory_enabled=True)

    # Ask the model to repeat the text. The trailing assistant message makes
    # generate_response() overwrite that message with the reply.
    prompt = f"Say: {text}\n"
    agent.add_message("user", prompt)
    agent.add_message("assistant", "the answer")

    print(agent.generate_response())
    # Slicing ChatHistory also prints the messages as a side effect.
    print(agent.history[:])
    last_message = agent.history.pop()
    print(last_message)
    print(agent.history[:])
|
@@ -1,146 +0,0 @@
|
|
1
|
-
from typing import List, Dict, Optional, Union
|
2
|
-
|
3
|
-
class ChatHistory:
    """An ordered list of chat messages with role-ordering helpers.

    Each message is a dict with ``"role"`` (``"user"``, ``"assistant"`` or
    ``"system"``) and ``"content"`` keys. ``add_marker`` may also add a
    ``"marker"`` key.
    """

    def __init__(self, input_data: Union[str, List[Dict[str, str]]] = "") -> None:
        """Create a history from a system prompt string or a list of message dicts.

        A non-empty string becomes a single system message, and a list is
        validated and loaded. Anything else (including None or "") yields an
        empty history.
        """
        self.messages: List[Dict[str, str]] = []
        if isinstance(input_data, str) and input_data:
            self.add_message(input_data, "system")
        elif isinstance(input_data, list):
            self.load_messages(input_data)
        # An empty history behaves as if a system prompt was just seen, so the
        # next turn added through add_user_message/__add__ is a user turn.
        self.last_role: str = "system" if not self.messages else self.messages[-1]["role"]

    def load_messages(self, messages: List[Dict[str, str]]) -> None:
        """Validate and append each message dict in *messages*.

        Raises:
            ValueError: if a message lacks ``role``/``content`` or has an
                unknown role. Messages before the invalid one remain appended.
        """
        for message in messages:
            if not ("role" in message and "content" in message):
                raise ValueError("Each message must have a 'role' and 'content'.")
            if message["role"] not in ["user", "assistant", "system"]:
                raise ValueError(f"Invalid role: {message['role']}")
            self.messages.append(message)
        # Keep last_role in sync so the role-ordering checks see the loaded tail.
        if self.messages:
            self.last_role = self.messages[-1]["role"]

    def pop(self):
        """Remove the last message and return its content (None if empty)."""
        if not self.messages:
            return None

        popped_message = self.messages.pop()

        if self.messages:
            self.last_role = self.messages[-1]["role"]
        else:
            self.last_role = "system"

        return popped_message["content"]

    def __len__(self):
        return len(self.messages)

    def __str__(self):
        return '\n'.join([f"Message {i} ({msg['role']}): {msg['content']}" for i, msg in enumerate(self.messages)])

    def __getitem__(self, key):
        """Return message(s) by int or slice, printing them as a side effect.

        For an int key, the message is printed with its neighbours and the
        selected one highlighted in green.

        Raises:
            IndexError: for an out-of-range int key.
            TypeError: for any other key type.
        """
        if isinstance(key, slice):
            # Slices are handled by the underlying list, negative bounds included.
            print('\n'.join([f"({msg['role']}): {msg['content']}" for msg in self.messages[key]]))
            return self.messages[key]
        elif isinstance(key, int):
            # Adjust for negative indices
            if key < 0:
                key += len(self.messages)  # Convert negative index to positive
            if 0 <= key < len(self.messages):
                snippet = self.get_conversation_snippet(key)
                print('\n'.join([f"({v['role']}): {self.color_text(v['content'], 'green') if k == 'current' else v['content']}" for k, v in snippet.items() if v]))
                return self.messages[key]
            else:
                raise IndexError("Message index out of range.")
        else:
            raise TypeError("Invalid argument type.")

    def __setitem__(self, index, value):
        """Replace the content of message *index*, keeping its role and any marker.

        Raises:
            ValueError: if *value* is not a string.
            IndexError: if *index* is out of range.
        """
        if not isinstance(value, str):
            raise ValueError("Message content must be a string.")
        # Keep the stored role; the previous index-parity guess relabelled
        # assistant turns as "system".
        self.messages[index] = {**self.messages[index], "content": value}

    def __add__(self, message):
        """Append *message* as the next turn and return ``self``.

        After a system message the turn is a user turn; otherwise user and
        assistant alternate. Returning ``self`` keeps ``history += "text"``
        from rebinding the name to None.
        """
        if self.last_role == "system":
            self.add_user_message(message)
        else:
            next_role = "assistant" if self.last_role == "user" else "user"
            self.add_message(message, next_role)
        return self

    def __contains__(self, item):
        return any(item in message['content'] for message in self.messages)

    def add_message(self, content, role):
        """Append a message without any role-ordering check."""
        self.messages.append({"role": role, "content": content})
        self.last_role = role

    def add_user_message(self, content):
        """Append a user turn; only legal after a system or assistant message."""
        if self.last_role in ["system", "assistant"]:
            self.add_message(content, "user")
        else:
            raise ValueError("A user message must follow a system or assistant message.")

    def add_assistant_message(self, content):
        """Append an assistant turn; only legal after a user message."""
        if self.last_role == "user":
            self.add_message(content, "assistant")
        else:
            raise ValueError("An assistant message must follow a user message.")

    def add_marker(self, marker, index=None):
        """Attach a string marker to message *index* (default: the last message)."""
        if not isinstance(marker, str):
            raise ValueError("Marker must be a string.")
        if index is None:
            index = len(self.messages) - 1
        if 0 <= index < len(self.messages):
            self.messages[index]["marker"] = marker
        else:
            raise IndexError("Invalid index for marker.")

    def conversation_status(self):
        """Return the last role, message count and last content as a dict."""
        return {
            "last_message_role": self.last_role,
            "total_messages": len(self.messages),
            "last_message_content": self.messages[-1]["content"] if self.messages else "No messages",
        }

    def display_conversation_status(self):
        """Print ``conversation_status()`` one field per line."""
        status = self.conversation_status()
        print(f"Role of the last message: {status['last_message_role']}")
        print(f"Total number of messages: {status['total_messages']}")
        print(f"Content of the last message: {status['last_message_content']}")

    def search_for_keyword(self, keyword):
        """Return messages whose (string) content contains *keyword*, case-insensitively."""
        return [msg for msg in self.messages if keyword.lower() in msg["content"].lower()]

    def has_user_or_assistant_spoken_since_last_system(self):
        """True if a user/assistant message appears after the most recent system message."""
        for msg in reversed(self.messages):
            if msg["role"] == "system":
                return False
            if msg["role"] in ["user", "assistant"]:
                return True
        return False

    def get_conversation_snippet(self, index):
        """Return ``{"previous", "current", "next"}`` messages around *index* (None where absent)."""
        snippet = {"previous": None, "current": None, "next": None}
        if 0 <= index < len(self.messages):
            snippet['current'] = self.messages[index]
            if index > 0:
                snippet['previous'] = self.messages[index - 1]
            if index + 1 < len(self.messages):
                snippet['next'] = self.messages[index + 1]
        else:
            raise IndexError("Invalid index.")
        return snippet

    def display_snippet(self, index):
        """Print the snippet around *index*, one message per line."""
        snippet = self.get_conversation_snippet(index)
        for key, value in snippet.items():
            if value:
                print(f"{key.capitalize()} Message ({value['role']}): {value['content']}")
            else:
                print(f"{key.capitalize()}: None")

    @staticmethod
    def color_text(text, color):
        """Wrap *text* in an ANSI colour code; unknown colours add no prefix."""
        colors = {"green": "\033[92m", "red": "\033[91m", "end": "\033[0m"}
        return f"{colors.get(color, '')}{text}{colors.get('end', '')}"
|
File without changes
|
File without changes
|
File without changes
|
{llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/SOURCES.txt
RENAMED
File without changes
|
File without changes
|
{llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/requires.txt
RENAMED
File without changes
|
{llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/top_level.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|