llm-dialog-manager 0.3.2__tar.gz → 0.3.5__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/PKG-INFO +2 -2
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager/__init__.py +1 -1
- llm_dialog_manager-0.3.5/llm_dialog_manager/agent.py +514 -0
- llm_dialog_manager-0.3.5/llm_dialog_manager/chat_history.py +233 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/PKG-INFO +2 -2
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/pyproject.toml +1 -1
- llm_dialog_manager-0.3.2/llm_dialog_manager/agent.py +0 -333
- llm_dialog_manager-0.3.2/llm_dialog_manager/chat_history.py +0 -146
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/LICENSE +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/README.md +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager/key_manager.py +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/SOURCES.txt +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/dependency_links.txt +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/requires.txt +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/top_level.txt +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/setup.cfg +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/tests/test_agent.py +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/tests/test_chat_history.py +0 -0
- {llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/tests/test_key_manager.py +0 -0
@@ -0,0 +1,514 @@
|
|
1
|
+
# Standard library imports
|
2
|
+
import json
|
3
|
+
import os
|
4
|
+
import uuid
|
5
|
+
from typing import List, Dict, Optional, Union
|
6
|
+
import logging
|
7
|
+
from pathlib import Path
|
8
|
+
import random
|
9
|
+
import requests
|
10
|
+
import zipfile
|
11
|
+
import io
|
12
|
+
import base64
|
13
|
+
from PIL import Image
|
14
|
+
|
15
|
+
# Third-party imports
|
16
|
+
import anthropic
|
17
|
+
from anthropic import AnthropicVertex
|
18
|
+
import google.generativeai as genai
|
19
|
+
import openai
|
20
|
+
from dotenv import load_dotenv
|
21
|
+
|
22
|
+
# Local imports
|
23
|
+
from .chat_history import ChatHistory
|
24
|
+
from .key_manager import key_manager
|
25
|
+
|
26
|
+
# Set up logging
|
27
|
+
logging.basicConfig(level=logging.INFO)
|
28
|
+
logger = logging.getLogger(__name__)
|
29
|
+
|
30
|
+
# Load environment variables
|
31
|
+
def load_env_vars():
    """Populate the process environment from the package-local ``.env`` file.

    The file is looked up in the directory that contains this module. When it
    is missing, a warning is logged and nothing is loaded, so only variables
    already present in the environment remain available.
    """
    dotenv_file = Path(__file__).parent.joinpath('.env')
    if not dotenv_file.exists():
        logger.warning(".env file not found. Using system environment variables.")
        return
    load_dotenv(dotenv_file)
|
38
|
+
|
39
|
+
load_env_vars()
|
40
|
+
|
41
|
+
def _pil_image_to_base64_png(image):
    """Return the PIL image *image* encoded as a base64 PNG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _anthropic_response_text(response):
    """Concatenate the text of every content block in a Messages API response.

    The Messages API returns its output as a list of content blocks in
    ``response.content``; blocks without a ``text`` attribute are skipped.
    """
    blocks = getattr(response, "content", None) or []
    return "".join(getattr(block, "text", "") or "" for block in blocks)


def _format_messages_for_api(model, messages):
    """Convert ChatHistory-style messages to the shape sent to the provider.

    Args:
        model: Model name; substring matching selects the output format.
        messages: List of ``{"role", "content"}`` dicts. It is not mutated.

    Returns:
        ``(system_msg, formatted)``. For Claude models ``system_msg`` is the
        content of a leading system message (``""`` if there is none) and that
        message is left out of ``formatted``. For all other models
        ``system_msg`` is ``None``.
    """
    if "claude" in model:
        system_msg = ""
        if messages and messages[0]["role"] == "system":
            system_msg = messages[0]["content"]
            messages = messages[1:]

        formatted = []
        for msg in messages:
            content = msg["content"]
            if isinstance(content, str):
                formatted.append({"role": msg["role"], "content": content})
            elif isinstance(content, list):
                # Combine content blocks into a single message
                combined_content = []
                for block in content:
                    if isinstance(block, str):
                        combined_content.append({"type": "text", "text": block})
                    elif isinstance(block, Image.Image):
                        # For Claude, convert PIL.Image to base64
                        combined_content.append({
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": _pil_image_to_base64_png(block)
                            }
                        })
                    elif isinstance(block, dict):
                        if block.get("type") == "image_url":
                            combined_content.append({
                                "type": "image",
                                "source": {
                                    "type": "url",
                                    "url": block["image_url"]["url"]
                                }
                            })
                        elif block.get("type") == "image_base64":
                            combined_content.append({
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": block["image_base64"]["media_type"],
                                    "data": block["image_base64"]["data"]
                                }
                            })
                formatted.append({"role": msg["role"], "content": combined_content})
        return system_msg, formatted

    if "gemini" in model or "gpt" in model or "grok" in model:
        # NOTE(review): this "parts" layout is what the original code produced
        # for these models, but the result is later handed to OpenAI-style
        # clients, which normally expect a "content" key - confirm the intent.
        formatted = []
        for msg in messages:
            content = msg["content"]
            if isinstance(content, str):
                formatted.append({"role": msg["role"], "parts": [content]})
            elif isinstance(content, list):
                parts = []
                for block in content:
                    if isinstance(block, (str, Image.Image)):
                        parts.append(block)
                    elif isinstance(block, dict):
                        if block.get("type") == "image_url":
                            parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}})
                        elif block.get("type") == "image_base64":
                            parts.append({"type": "image_base64", "image_base64": {"data": block["image_base64"]["data"], "media_type": block["image_base64"]["media_type"]}})
                formatted.append({"role": msg["role"], "parts": parts})
        return None, formatted

    # Any other model: plain string content, images become text placeholders.
    formatted = []
    for msg in messages:
        content = msg["content"]
        if isinstance(content, str):
            formatted.append({"role": msg["role"], "content": content})
        elif isinstance(content, list):
            pieces = []
            for block in content:
                if isinstance(block, str):
                    pieces.append(block)
                elif isinstance(block, Image.Image):
                    pieces.append(f"[Image Base64: {_pil_image_to_base64_png(block)[:30]}...]")
                elif isinstance(block, dict):
                    if block.get("type") == "image_url":
                        pieces.append(f"[Image: {block['image_url']['url']}]")
                    elif block.get("type") == "image_base64":
                        pieces.append(f"[Image Base64: {block['image_base64']['data'][:30]}...]")
            formatted.append({"role": msg["role"], "content": "\n".join(pieces).strip()})
    return None, formatted


def _to_gemini_messages(messages):
    """Build ``{"role", "parts"}`` dicts for the google.generativeai fallback.

    System messages are folded into the parts of the next non-system message,
    so the system prompt is no longer lost when it comes first. The role
    ``"assistant"`` is renamed to ``"model"``. New lists are always built, so
    the caller's message contents are never mutated.
    """
    converted = []
    pending_system = []
    for msg in messages:
        content = msg["content"]
        parts = list(content) if isinstance(content, list) else [content]
        if msg["role"] == "system":
            pending_system.extend(parts)
            continue
        if pending_system:
            parts = pending_system + parts
            pending_system = []
        role = "model" if msg["role"] == "assistant" else msg["role"]
        converted.append({"role": role, "parts": parts})
    return converted


def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]], max_tokens: int = 1000,
               temperature: float = 0.5, api_key: Optional[str] = None,
               base_url: Optional[str] = None, json_format: bool = False) -> str:
    """Generate a completion using the specified model and messages.

    The provider is chosen by substring of ``model``: ``claude`` -> Anthropic
    (or Anthropic on Vertex when ``VERTEX_PROJECT_ID`` and ``VERTEX_REGION``
    are set), ``gemini`` -> Google, ``grok`` -> x.ai, anything else -> OpenAI.

    Args:
        model: Model name.
        messages: Conversation as ``{"role", "content"}`` dicts. The list is
            not modified.
        max_tokens: Maximum tokens to generate per request.
        temperature: Sampling temperature.
        api_key: API key. When falsy, both the key and ``base_url`` are read
            from ``<SERVICE>_API_KEY`` / ``<SERVICE>_BASE_URL``.
        base_url: Optional API base URL (Anthropic and OpenAI branches only).
        json_format: Request JSON output from the OpenAI-style and Gemini
            endpoints.

    Returns:
        The generated text.

    Raises:
        Exception: Any provider or client error is logged and re-raised.
    """
    try:
        if "claude" in model:
            service = "anthropic"
        elif "gemini" in model:
            service = "gemini"
        elif "grok" in model:
            service = "x"
        else:
            service = "openai"

        # Fall back to environment configuration when no key was passed in.
        if not api_key:
            api_key = os.getenv(f"{service.upper()}_API_KEY")
            base_url = os.getenv(f"{service.upper()}_BASE_URL")

        system_msg, formatted_messages = _format_messages_for_api(model, list(messages))

        # "plain_text" is not an accepted response_format type, so the
        # parameter is only sent when JSON output is requested; omitting it
        # leaves the endpoint on its default text output.
        response_format_kwargs = {"response_format": {"type": "json_object"}} if json_format else {}

        if "claude" in model:
            # Check for Vertex configuration
            vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
            vertex_region = os.getenv('VERTEX_REGION')

            if vertex_project_id and vertex_region:
                client = AnthropicVertex(
                    region=vertex_region,
                    project_id=vertex_project_id
                )
            else:
                client = anthropic.Anthropic(api_key=api_key, base_url=base_url)

            response = client.messages.create(
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                messages=formatted_messages,
                system=system_msg
            )

            # Keep asking for more while the reply is cut off by the token
            # limit, accumulating the partial text in a trailing assistant turn.
            while response.stop_reason == "max_tokens":
                chunk = _anthropic_response_text(response)
                if formatted_messages[-1]['role'] == "user":
                    formatted_messages.append({"role": "assistant", "content": chunk})
                else:
                    formatted_messages[-1]['content'] += chunk

                response = client.messages.create(
                    model=model,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    messages=formatted_messages,
                    system=system_msg
                )

            final_text = _anthropic_response_text(response)
            if formatted_messages[-1]['role'] == "assistant" and response.stop_reason == "end_turn":
                formatted_messages[-1]['content'] += final_text
                return formatted_messages[-1]['content']

            return final_text

        elif "gemini" in model:
            try:
                # First try OpenAI-style API
                client = openai.OpenAI(
                    api_key=api_key,
                    base_url="https://generativelanguage.googleapis.com/v1beta/"
                )
                response = client.chat.completions.create(
                    model=model,
                    messages=formatted_messages,
                    temperature=temperature,
                    **response_format_kwargs
                )
                return response.choices[0].message.content

            except Exception as e:
                # Deliberate best-effort: any failure of the OpenAI-style call
                # falls back to Google's genai library.
                logger.info("Falling back to Google's genai library")
                logger.debug("OpenAI-style Gemini request failed: %s", e)
                genai.configure(api_key=api_key)

                gemini_messages = _to_gemini_messages(messages)

                # Set response_mime_type based on json_format
                mime_type = "application/json" if json_format else "text/plain"

                model_instance = genai.GenerativeModel(model_name=model)
                response = model_instance.generate_content(
                    gemini_messages,
                    generation_config=genai.types.GenerationConfig(
                        temperature=temperature,
                        response_mime_type=mime_type,
                        max_output_tokens=max_tokens
                    )
                )

                return response.text

        elif "grok" in model:
            # Randomly choose between OpenAI and Anthropic SDK
            use_anthropic = random.choice([True, False])

            if use_anthropic:
                logger.info("Using Anthropic for Grok model")
                client = anthropic.Anthropic(
                    api_key=api_key,
                    base_url="https://api.x.ai"
                )

                # Read the system prompt without popping it, so the caller's
                # message list is left intact.
                grok_system_msg = ""
                if messages and messages[0]["role"] == "system":
                    grok_system_msg = messages[0]["content"]

                response = client.messages.create(
                    model=model,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    messages=formatted_messages,
                    system=grok_system_msg
                )
                return _anthropic_response_text(response)
            else:
                logger.info("Using OpenAI for Grok model")
                client = openai.OpenAI(
                    api_key=api_key,
                    base_url="https://api.x.ai/v1"
                )
                response = client.chat.completions.create(
                    model=model,
                    messages=formatted_messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    **response_format_kwargs
                )
                return response.choices[0].message.content

        else:  # OpenAI models
            client = openai.OpenAI(api_key=api_key, base_url=base_url)
            response = client.chat.completions.create(
                model=model,
                messages=formatted_messages,
                max_tokens=max_tokens,
                temperature=temperature,
                **response_format_kwargs
            )
            return response.choices[0].message.content

    except Exception as e:
        logger.error(f"Error in completion: {str(e)}")
        raise
|
323
|
+
|
324
|
+
class Agent:
    """A model name paired with a ChatHistory, plus helpers to talk to it.

    Attributes are set in ``__init__``: ``id`` (model name plus a random
    8-hex-digit suffix), ``model_name``, ``history``, ``memory_enabled``,
    ``api_key`` and ``repo_content`` (texts collected by ``add_repo``).
    """

    def __init__(self, model_name: str, messages: Optional[Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]]] = None,
                 memory_enabled: bool = False, api_key: Optional[str] = None) -> None:
        """Initialize an Agent instance.

        Args:
            model_name: Name of the model passed on to ``completion``.
            messages: Initial history, forwarded to ``ChatHistory`` when truthy.
            memory_enabled: When True, generated replies are appended to the
                history as assistant messages.
            api_key: Optional API key forwarded to ``completion``.
        """
        self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
        self.model_name = model_name
        self.history = ChatHistory(messages) if messages else ChatHistory()
        self.memory_enabled = memory_enabled
        self.api_key = api_key
        self.repo_content = []

    def add_message(self, role: str, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """Add a message with the given role to the conversation."""
        self.history.add_message(content, role)

    def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """Add a user message."""
        self.history.add_user_message(content)

    def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """Add an assistant message."""
        self.history.add_assistant_message(content)

    def add_image(self, image_path: Optional[str] = None, image_url: Optional[str] = None, media_type: Optional[str] = "image/jpeg"):
        """Add an image to the conversation.

        Either image_path or image_url must be provided; image_path wins when
        both are given. The image is appended to the last message when that
        message is a user message, otherwise a new user message is started.

        Args:
            image_path: Path of a local image file.
            image_url: URL of a remote image.
            media_type: Media type recorded for base64-encoded local files.
                NOTE(review): the default is "image/jpeg" regardless of the
                file's actual format - callers should pass the real type.

        Raises:
            ValueError: If neither image_path nor image_url is given.
            FileNotFoundError: If image_path does not exist.
        """
        if not image_path and not image_url:
            raise ValueError("Either image_path or image_url must be provided.")

        if image_path:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file {image_path} does not exist.")
            if "gemini" in self.model_name:
                # For Gemini, load as PIL.Image
                image_block = Image.open(image_path)
            else:
                # For Claude and others, use base64 encoding
                with open(image_path, "rb") as img_file:
                    image_data = base64.standard_b64encode(img_file.read()).decode("utf-8")
                image_block = {
                    "type": "image_base64",
                    "image_base64": {
                        "media_type": media_type,
                        "data": image_data
                    }
                }
        else:
            # URL images use the same block shape for every model.
            image_block = {"type": "image_url", "image_url": {"url": image_url}}

        # Add the image block to the last user message or as a new user message
        if self.history.last_role == "user":
            current_content = self.history.messages[-1]["content"]
            if isinstance(current_content, list):
                current_content.append(image_block)
            else:
                self.history.messages[-1]["content"] = [current_content, image_block]
        else:
            self.history.add_message([image_block], "user")

    def generate_response(self, max_tokens=3585, temperature=0.7, json_format: bool = False) -> str:
        """Generate a response from the agent.

        If the last history message is an assistant message, the reply is
        merged into it: appended as an extra content part for Gemini/Grok
        models, replacing the content for all others. Otherwise the reply is
        stored as a new assistant message only when memory is enabled.

        Args:
            max_tokens (int, optional): Maximum number of tokens. Defaults to 3585.
            temperature (float, optional): Sampling temperature. Defaults to 0.7.
            json_format (bool, optional): Whether to enable JSON output format. Defaults to False.

        Returns:
            str: The generated response.

        Raises:
            ValueError: If the history is empty.
        """
        if not self.history.messages:
            raise ValueError("No messages in history to generate response from")

        response_text = completion(
            model=self.model_name,
            # Pass a copy so the callee can never reorder or drop entries of
            # the stored history.
            messages=list(self.history.messages),
            max_tokens=max_tokens,
            temperature=temperature,
            api_key=self.api_key,
            json_format=json_format
        )

        name = self.model_name
        appends_part = (
            not name.startswith("openai")
            and "claude" not in name
            and ("gemini" in name or "grok" in name)
        )

        last_message = self.history.messages[-1]
        if last_message["role"] == "assistant":
            if not appends_part:
                last_message["content"] = response_text
            elif isinstance(last_message["content"], list):
                last_message["content"].append(response_text)
            else:
                last_message["content"] = [last_message["content"], response_text]
        elif self.memory_enabled:
            self.add_message("assistant", response_text)

        return response_text

    @staticmethod
    def _json_default(obj):
        """Serialize PIL images as ``image_base64`` blocks (PNG) for JSON."""
        if isinstance(obj, Image.Image):
            buffered = io.BytesIO()
            obj.save(buffered, format="PNG")
            return {
                "type": "image_base64",
                "image_base64": {
                    "media_type": "image/png",
                    "data": base64.b64encode(buffered.getvalue()).decode("utf-8")
                }
            }
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_conversation(self):
        """Write the history to ``<agent id>.json`` in the working directory.

        PIL image objects in the history are stored as base64 PNG blocks.
        """
        filename = f"{self.id}.json"
        # Serialize before opening the file so a serialization error cannot
        # leave a truncated file behind.
        payload = json.dumps(self.history.messages, ensure_ascii=False, indent=4,
                             default=self._json_default)
        with open(filename, 'w', encoding='utf-8') as file:
            file.write(payload)

    def load_conversation(self, filename: Optional[str] = None):
        """Replace the history with messages read from a JSON file.

        Args:
            filename: File to read; defaults to ``<agent id>.json``.
        """
        if filename is None:
            filename = f"{self.id}.json"
        with open(filename, 'r', encoding='utf-8') as file:
            messages = json.load(file)
        self.history = ChatHistory(messages)

    def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = None, repo_name: Optional[str] = None, commit_hash: Optional[str] = None):
        """Download a repository zip and store its ``.py``/``.txt`` files as text.

        When username and repo_name are both given, a GitHub archive URL is
        built from them (for commit_hash if given, else the ``main`` branch)
        and overrides repo_url. The collected text is appended to
        ``self.repo_content``.

        Raises:
            ValueError: If no URL can be determined or the download does not
                return HTTP 200.
        """
        if username and repo_name:
            if commit_hash:
                repo_url = f"https://github.com/{username}/{repo_name}/archive/{commit_hash}.zip"
            else:
                repo_url = f"https://github.com/{username}/{repo_name}/archive/refs/heads/main.zip"

        if not repo_url:
            raise ValueError("Either repo_url or both username and repo_name must be provided")

        # A timeout keeps an unresponsive server from blocking forever.
        response = requests.get(repo_url, timeout=30)
        if response.status_code != 200:
            raise ValueError(f"Failed to download repository from {repo_url}")

        sections = []
        with zipfile.ZipFile(io.BytesIO(response.content)) as z:
            for file_info in z.infolist():
                if file_info.is_dir() or not file_info.filename.endswith(('.py', '.txt')):
                    continue
                with z.open(file_info) as f:
                    # Undecodable bytes are replaced rather than aborting the
                    # whole import because of one odd file.
                    content = f.read().decode('utf-8', errors='replace')
                sections.append(f"{file_info.filename}\n```\n{content}\n```\n")
        self.repo_content.append("".join(sections))
|
485
|
+
|
486
|
+
if __name__ == "__main__":
    # Demo run: a Gemini-backed agent with memory answers a question about a
    # local image, first in JSON mode and then in plain mode.
    demo_agent = Agent("gemini-1.5-flash", "you are an assistant", memory_enabled=True)

    # Attach the picture, then ask about it.
    demo_agent.add_image(image_path="/Users/junfan/Projects/Personal/oneapi/dialog_manager/example.png")
    demo_agent.add_message("user", "What's in this image?")

    # First request asks for JSON output; a failure is logged, not raised.
    try:
        reply = demo_agent.generate_response(json_format=True)
        print("Response:", reply)
    except Exception as e:
        logger.error(f"Failed to generate response: {e}")

    # Dump everything said so far.
    print("Conversation History:")
    print(demo_agent.history)

    # Remove the most recent message and show its content.
    removed_content = demo_agent.history.pop()
    print("Last Message:", removed_content)

    # Second request uses the default (non-JSON) output.
    reply = demo_agent.generate_response()
    print("Response:", reply)
|
@@ -0,0 +1,233 @@
|
|
1
|
+
from typing import List, Dict, Optional, Union
|
2
|
+
from PIL import Image
|
3
|
+
|
4
|
+
class ChatHistory:
|
5
|
+
def __init__(self, input_data: Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]] = "") -> None:
|
6
|
+
self.messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]] = []
|
7
|
+
if isinstance(input_data, str) and input_data:
|
8
|
+
self.add_message(input_data, "system")
|
9
|
+
elif isinstance(input_data, list):
|
10
|
+
self.load_messages(input_data)
|
11
|
+
self.last_role: str = "system" if not self.messages else self.get_last_role()
|
12
|
+
|
13
|
+
def load_messages(self, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]) -> None:
|
14
|
+
for message in messages:
|
15
|
+
if not ("role" in message and "content" in message):
|
16
|
+
raise ValueError("Each message must have a 'role' and 'content'.")
|
17
|
+
if message["role"] not in ["user", "assistant", "system"]:
|
18
|
+
raise ValueError(f"Invalid role: {message['role']}")
|
19
|
+
self.messages.append(message)
|
20
|
+
self.last_role = self.get_last_role()
|
21
|
+
|
22
|
+
def get_last_role(self):
|
23
|
+
return self.messages[-1]["role"] if self.messages else "system"
|
24
|
+
|
25
|
+
def pop(self):
|
26
|
+
if not self.messages:
|
27
|
+
return None
|
28
|
+
|
29
|
+
popped_message = self.messages.pop()
|
30
|
+
|
31
|
+
if self.messages:
|
32
|
+
self.last_role = self.get_last_role()
|
33
|
+
else:
|
34
|
+
self.last_role = "system"
|
35
|
+
|
36
|
+
return popped_message["content"]
|
37
|
+
|
38
|
+
def __len__(self):
|
39
|
+
return len(self.messages)
|
40
|
+
|
41
|
+
def __str__(self):
|
42
|
+
formatted_messages = []
|
43
|
+
for i, msg in enumerate(self.messages):
|
44
|
+
role = msg['role']
|
45
|
+
content = msg['content']
|
46
|
+
if isinstance(content, str):
|
47
|
+
formatted_content = content
|
48
|
+
elif isinstance(content, list):
|
49
|
+
parts = []
|
50
|
+
for block in content:
|
51
|
+
if isinstance(block, str):
|
52
|
+
parts.append(block)
|
53
|
+
elif isinstance(block, Image.Image):
|
54
|
+
parts.append(f"[Image Object: {block.filename}]")
|
55
|
+
elif isinstance(block, dict):
|
56
|
+
if block.get("type") == "image_url":
|
57
|
+
parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
|
58
|
+
elif block.get("type") == "image_base64":
|
59
|
+
parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
|
60
|
+
formatted_content = "\n".join(parts)
|
61
|
+
else:
|
62
|
+
formatted_content = str(content)
|
63
|
+
formatted_messages.append(f"Message {i} ({role}): {formatted_content}")
|
64
|
+
return '\n'.join(formatted_messages)
|
65
|
+
|
66
|
+
def __getitem__(self, key):
|
67
|
+
if isinstance(key, slice):
|
68
|
+
sliced_messages = self.messages[key]
|
69
|
+
formatted = []
|
70
|
+
for msg in sliced_messages:
|
71
|
+
role = msg['role']
|
72
|
+
content = msg['content']
|
73
|
+
if isinstance(content, str):
|
74
|
+
formatted_content = content
|
75
|
+
elif isinstance(content, list):
|
76
|
+
parts = []
|
77
|
+
for block in content:
|
78
|
+
if isinstance(block, str):
|
79
|
+
parts.append(block)
|
80
|
+
elif isinstance(block, Image.Image):
|
81
|
+
parts.append(f"[Image Object: {block.filename}]")
|
82
|
+
elif isinstance(block, dict):
|
83
|
+
if block.get("type") == "image_url":
|
84
|
+
parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
|
85
|
+
elif block.get("type") == "image_base64":
|
86
|
+
parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
|
87
|
+
formatted_content = "\n".join(parts)
|
88
|
+
else:
|
89
|
+
formatted_content = str(content)
|
90
|
+
formatted.append(f"({role}): {formatted_content}")
|
91
|
+
print('\n'.join(formatted))
|
92
|
+
return sliced_messages
|
93
|
+
elif isinstance(key, int):
|
94
|
+
# Adjust for negative indices
|
95
|
+
if key < 0:
|
96
|
+
key += len(self.messages)
|
97
|
+
if 0 <= key < len(self.messages):
|
98
|
+
msg = self.messages[key]
|
99
|
+
role = msg['role']
|
100
|
+
content = msg['content']
|
101
|
+
if isinstance(content, str):
|
102
|
+
formatted_content = content
|
103
|
+
elif isinstance(content, list):
|
104
|
+
parts = []
|
105
|
+
for block in content:
|
106
|
+
if isinstance(block, str):
|
107
|
+
parts.append(block)
|
108
|
+
elif isinstance(block, Image.Image):
|
109
|
+
parts.append(f"[Image Object: {block.filename}]")
|
110
|
+
elif isinstance(block, dict):
|
111
|
+
if block.get("type") == "image_url":
|
112
|
+
parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
|
113
|
+
elif block.get("type") == "image_base64":
|
114
|
+
parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
|
115
|
+
formatted_content = "\n".join(parts)
|
116
|
+
else:
|
117
|
+
formatted_content = str(content)
|
118
|
+
snippet = self.get_conversation_snippet(key)
|
119
|
+
print('\n'.join([f"({v['role']}): {v['content']}" for k, v in snippet.items() if v]))
|
120
|
+
return self.messages[key]
|
121
|
+
else:
|
122
|
+
raise IndexError("Message index out of range.")
|
123
|
+
else:
|
124
|
+
raise TypeError("Invalid argument type.")
|
125
|
+
|
126
|
+
def __setitem__(self, index, value):
|
127
|
+
if not isinstance(value, (str, list)):
|
128
|
+
raise ValueError("Message content must be a string or a list of content blocks.")
|
129
|
+
role = "system" if index % 2 == 0 else "user"
|
130
|
+
self.messages[index] = {"role": role, "content": value}
|
131
|
+
|
132
|
+
def __add__(self, message):
|
133
|
+
if self.last_role == "system":
|
134
|
+
self.add_user_message(message)
|
135
|
+
else:
|
136
|
+
next_role = "assistant" if self.last_role == "user" else "user"
|
137
|
+
self.add_message(message, next_role)
|
138
|
+
|
139
|
+
def __contains__(self, item):
|
140
|
+
for message in self.messages:
|
141
|
+
content = message['content']
|
142
|
+
if isinstance(content, str) and item in content:
|
143
|
+
return True
|
144
|
+
elif isinstance(content, list):
|
145
|
+
for block in content:
|
146
|
+
if isinstance(block, str) and item in block:
|
147
|
+
return True
|
148
|
+
return False
|
149
|
+
|
150
|
+
def add_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]], role: str):
|
151
|
+
self.messages.append({"role": role, "content": content})
|
152
|
+
self.last_role = role
|
153
|
+
|
154
|
+
def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
|
155
|
+
if self.last_role in ["system", "assistant"]:
|
156
|
+
self.add_message(content, "user")
|
157
|
+
else:
|
158
|
+
raise ValueError("A user message must follow a system or assistant message.")
|
159
|
+
|
160
|
+
def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
|
161
|
+
if self.last_role == "user":
|
162
|
+
self.add_message(content, "assistant")
|
163
|
+
else:
|
164
|
+
raise ValueError("An assistant message must follow a user message.")
|
165
|
+
|
166
|
+
def add_marker(self, marker, index=None):
|
167
|
+
if not isinstance(marker, str):
|
168
|
+
raise ValueError("Marker must be a string.")
|
169
|
+
if index is None:
|
170
|
+
index = len(self.messages) - 1
|
171
|
+
if 0 <= index < len(self.messages):
|
172
|
+
self.messages[index]["marker"] = marker
|
173
|
+
else:
|
174
|
+
raise IndexError("Invalid index for marker.")
|
175
|
+
|
176
|
+
def conversation_status(self):
|
177
|
+
return {
|
178
|
+
"last_message_role": self.last_role,
|
179
|
+
"total_messages": len(self.messages),
|
180
|
+
"last_message_content": self.messages[-1]["content"] if self.messages else "No messages",
|
181
|
+
}
|
182
|
+
|
183
|
+
def display_conversation_status(self):
|
184
|
+
status = self.conversation_status()
|
185
|
+
print(f"Role of the last message: {status['last_message_role']}")
|
186
|
+
print(f"Total number of messages: {status['total_messages']}")
|
187
|
+
print(f"Content of the last message: {status['last_message_content']}")
|
188
|
+
|
189
|
+
def search_for_keyword(self, keyword):
|
190
|
+
results = []
|
191
|
+
for msg in self.messages:
|
192
|
+
content = msg['content']
|
193
|
+
if isinstance(content, str) and keyword.lower() in content.lower():
|
194
|
+
results.append(msg)
|
195
|
+
elif isinstance(content, list):
|
196
|
+
for block in content:
|
197
|
+
if isinstance(block, str) and keyword.lower() in block.lower():
|
198
|
+
results.append(msg)
|
199
|
+
break
|
200
|
+
return results
|
201
|
+
|
202
|
+
def has_user_or_assistant_spoken_since_last_system(self):
|
203
|
+
for msg in reversed(self.messages):
|
204
|
+
if msg["role"] == "system":
|
205
|
+
return False
|
206
|
+
if msg["role"] in ["user", "assistant"]:
|
207
|
+
return True
|
208
|
+
return False
|
209
|
+
|
210
|
+
def get_conversation_snippet(self, index):
|
211
|
+
snippet = {"previous": None, "current": None, "next": None}
|
212
|
+
if 0 <= index < len(self.messages):
|
213
|
+
snippet['current'] = self.messages[index]
|
214
|
+
if index > 0:
|
215
|
+
snippet['previous'] = self.messages[index - 1]
|
216
|
+
if index + 1 < len(self.messages):
|
217
|
+
snippet['next'] = self.messages[index + 1]
|
218
|
+
else:
|
219
|
+
raise IndexError("Invalid index.")
|
220
|
+
return snippet
|
221
|
+
|
222
|
+
def display_snippet(self, index):
    """Print the previous, current and next message around ``index``.

    The three entries come from ``get_conversation_snippet``. A missing
    (falsy) neighbour is printed as ``"<Position>: None"``.

    Args:
        index: Position passed through to ``get_conversation_snippet``,
            which raises ``IndexError`` for an out-of-range value.
    """
    for position, message in self.get_conversation_snippet(index).items():
        title = position.capitalize()
        if not message:
            print(f"{title}: None")
            continue
        print(f"{title} Message ({message['role']}): {message['content']}")
|
229
|
+
|
230
|
+
@staticmethod
|
231
|
+
def color_text(text, color):
|
232
|
+
colors = {"green": "\033[92m", "red": "\033[91m", "end": "\033[0m"}
|
233
|
+
return f"{colors.get(color, '')}{text}{colors.get('end', '')}"
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "llm_dialog_manager"
|
7
|
-
version = "0.3.2"
|
7
|
+
version = "0.3.5"
|
8
8
|
description = "A Python package for managing LLM chat conversation history"
|
9
9
|
readme = "README.md"
|
10
10
|
authors = [{ name = "xihajun", email = "work@2333.fun" }]
|
@@ -1,333 +0,0 @@
|
|
1
|
-
# Standard library imports
|
2
|
-
import json
|
3
|
-
import os
|
4
|
-
import uuid
|
5
|
-
from typing import List, Dict, Optional
|
6
|
-
import logging
|
7
|
-
from pathlib import Path
|
8
|
-
import random
|
9
|
-
import requests
|
10
|
-
import zipfile
|
11
|
-
import io
|
12
|
-
|
13
|
-
# Third-party imports
|
14
|
-
import anthropic
|
15
|
-
from anthropic import AnthropicVertex
|
16
|
-
import google.generativeai as genai
|
17
|
-
import openai
|
18
|
-
from dotenv import load_dotenv
|
19
|
-
|
20
|
-
# Local imports
|
21
|
-
from llm_dialog_manager.chat_history import ChatHistory
|
22
|
-
from llm_dialog_manager.key_manager import key_manager
|
23
|
-
|
24
|
-
# Set up logging
|
25
|
-
logging.basicConfig(level=logging.INFO)
|
26
|
-
logger = logging.getLogger(__name__)
|
27
|
-
|
28
|
-
# Load environment variables
|
29
|
-
def load_env_vars():
    """Populate the environment from a ``.env`` file, if one is present.

    The file is looked up two directory levels above this module's own path
    (i.e. next to the package directory). When it does not exist a warning
    is logged and nothing is loaded.
    """
    dotenv_file = Path(__file__).parents[1] / '.env'
    if not dotenv_file.exists():
        logger.warning(".env file not found. Using system environment variables.")
        return
    load_dotenv(dotenv_file)
|
36
|
-
|
37
|
-
load_env_vars()
|
38
|
-
|
39
|
-
def create_and_send_message(client, model, max_tokens, temperature, messages, system_msg):
    """Issue one ``client.messages.create`` request and return its result.

    Args:
        client: Object exposing ``messages.create(**kwargs)``.
        model: Model identifier forwarded as ``model``.
        max_tokens: Forwarded as ``max_tokens``.
        temperature: Forwarded as ``temperature``.
        messages: Conversation turns forwarded as ``messages``.
        system_msg: System prompt forwarded as ``system``.

    Returns:
        Whatever ``client.messages.create`` returns.

    Raises:
        Exception: Any error from the client is logged and re-raised as is.
    """
    request = {
        "model": model,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "messages": messages,
        "system": system_msg,
    }
    try:
        return client.messages.create(**request)
    except Exception as exc:
        logger.error(f"Error sending message: {str(exc)}")
        raise
|
53
|
-
|
54
|
-
def _detect_service(model: str) -> str:
    """Map a model name to the service identifier used by the key manager."""
    if "claude" in model:
        return "anthropic"
    if "gemini" in model:
        return "gemini"
    if "grok" in model:
        return "x"
    return "openai"


def _split_system_message(messages):
    """Remove a leading system message from ``messages`` and return its text ("" if none)."""
    if messages and messages[0]["role"] == "system":
        return messages.pop(0)["content"]
    return ""


def _complete_claude(model, messages, max_tokens, temperature, api_key, base_url):
    """Run a Claude request, re-requesting while the reply is cut off by ``max_tokens``."""
    vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
    vertex_region = os.getenv('VERTEX_REGION')

    # Vertex is used only when both environment variables are set.
    if vertex_project_id and vertex_region:
        client = AnthropicVertex(region=vertex_region, project_id=vertex_project_id)
    else:
        client = anthropic.Anthropic(api_key=api_key, base_url=base_url)

    system_msg = _split_system_message(messages)

    def _request():
        return client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=messages,
            system=system_msg,
        )

    response = _request()

    # A truncated reply is stored as (or appended to) a trailing assistant
    # message and the request is repeated so the model continues from there.
    while response.stop_reason == "max_tokens":
        if messages[-1]['role'] == "user":
            messages.append({"role": "assistant", "content": response.content[0].text})
        else:
            messages[-1]['content'] += response.content[0].text
        response = _request()

    # With a trailing assistant message (caller prefill or accumulated
    # continuation) the full text is that message plus the final chunk.
    if messages[-1]['role'] == "assistant" and response.stop_reason == "end_turn":
        messages[-1]['content'] += response.content[0].text
        return messages[-1]['content']

    return response.content[0].text


def _complete_gemini(model, messages, max_tokens, temperature, api_key):
    """Run a Gemini request via the OpenAI-style endpoint, falling back to ``genai``."""
    try:
        client = openai.OpenAI(
            api_key=api_key,
            base_url="https://generativelanguage.googleapis.com/v1beta/"
        )
        # A leading system message is folded into the first remaining message.
        if messages and messages[0]["role"] == "system":
            system_msg = messages.pop(0)
            if messages:
                messages[0]["content"] = f"{system_msg['content']}\n\n{messages[0]['content']}"

        # NOTE(review): max_tokens is not forwarded on this path (unchanged
        # from the original); confirm whether that is intended.
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature
        )
        return response.choices[0].message.content

    except Exception as exc:
        # Keep the failure visible instead of silently switching backends.
        logger.warning("OpenAI-style Gemini request failed: %s", exc)
        logger.info("Falling back to Google's genai library")
        genai.configure(api_key=api_key)

        # Convert to the genai content format. System text is collected and
        # prepended to the first converted message; the original code tried
        # attribute access on plain dicts here and dropped a leading system
        # message entirely.
        system_parts = []
        gemini_messages = []
        for msg in messages:
            if msg["role"] == "system":
                system_parts.append(msg["content"])
            else:
                # The genai SDK names the assistant side "model".
                role = "model" if msg["role"] == "assistant" else msg["role"]
                gemini_messages.append({"role": role, "parts": [{"text": msg["content"]}]})

        if system_parts and gemini_messages:
            first_part = gemini_messages[0]["parts"][0]
            first_part["text"] = "\n\n".join(system_parts + [first_part["text"]])

        gemini_model = genai.GenerativeModel(model_name=model)
        response = gemini_model.generate_content(
            gemini_messages,
            generation_config=genai.types.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text


def _complete_grok(model, messages, max_tokens, temperature, api_key):
    """Run a Grok request through a randomly chosen SDK (Anthropic- or OpenAI-style)."""
    use_anthropic = random.choice([True, False])

    if use_anthropic:
        logger.debug("using anthropic")
        client = anthropic.Anthropic(
            api_key=api_key,
            base_url="https://api.x.ai"
        )
        system_msg = _split_system_message(messages)
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=messages,
            system=system_msg
        )
        return response.content[0].text

    logger.debug("using openai")
    client = openai.OpenAI(
        api_key=api_key,
        base_url="https://api.x.ai/v1"
    )
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature
    )
    return response.choices[0].message.content


def _complete_openai(model, messages, max_tokens, temperature, api_key, base_url):
    """Run a chat completion against an OpenAI-compatible endpoint."""
    client = openai.OpenAI(api_key=api_key, base_url=base_url)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return response.choices[0].message.content


def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 1000,
               temperature: float = 0.5, api_key: Optional[str] = None,
               base_url: Optional[str] = None) -> str:
    """
    Generate a completion using the specified model and messages.

    The backend is chosen by substring of ``model``: "claude", "gemini",
    "grok", otherwise OpenAI. When no ``api_key`` is given, a key and base
    URL are taken from ``key_manager``; that key is released after a
    successful call and reported to the key manager after a failed one.

    Args:
        model: Model name; also selects the backend.
        messages: Conversation as ``{"role": ..., "content": ...}`` dicts.
            The list and its dicts are copied, so the caller's data is not
            modified.
        max_tokens: Output token limit forwarded to the backend.
        temperature: Sampling temperature forwarded to the backend.
        api_key: Explicit API key; bypasses the key manager when given.
        base_url: Explicit base URL, used by the Claude and OpenAI backends.

    Returns:
        The generated text.

    Raises:
        Exception: Any backend or key-manager error is logged and re-raised.
    """
    try:
        service = _detect_service(model)

        # Record where the key comes from *before* api_key is overwritten;
        # testing ``not api_key`` afterwards would always be False and the
        # key would never be released or reported.
        key_is_managed = not api_key
        if key_is_managed:
            api_key, base_url = key_manager.get_config(service)

        # The backends pop, append and edit messages; work on a copy.
        working_messages = [dict(message) for message in messages]

        try:
            if service == "anthropic":
                result = _complete_claude(model, working_messages, max_tokens,
                                          temperature, api_key, base_url)
            elif service == "gemini":
                result = _complete_gemini(model, working_messages, max_tokens,
                                          temperature, api_key)
            elif service == "x":
                result = _complete_grok(model, working_messages, max_tokens,
                                        temperature, api_key)
            else:
                result = _complete_openai(model, working_messages, max_tokens,
                                          temperature, api_key, base_url)
        except Exception:
            # Report error to key manager
            if key_is_managed:
                key_manager.report_error(service, api_key)
            raise

        # Release the API key after successful use
        if key_is_managed:
            key_manager.release_config(service, api_key)

        return result

    except Exception as e:
        logger.error(f"Error in completion: {str(e)}")
        raise
|
226
|
-
|
227
|
-
class Agent:
    """A conversational agent bound to one model name and one ``ChatHistory``.

    Attributes are set in ``__init__``: ``id`` (model name plus a random
    8-hex-digit suffix), ``model_name``, ``history``, ``memory_enabled``,
    ``api_key`` and ``repo_content`` (repository dumps waiting to be attached
    to the next message).
    """

    def __init__(self, model_name: str, messages: Optional[str] = None,
                 memory_enabled: bool = False, api_key: Optional[str] = None) -> None:
        """Initialize an Agent instance.

        Args:
            model_name: Model identifier passed to ``completion``; it is not
                validated against a list of known models.
            messages: Initial input handed to ``ChatHistory``.
            memory_enabled: When true, generated replies are appended to the
                history as assistant messages.
            api_key: Optional explicit API key forwarded to ``completion``.
        """
        self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
        self.model_name = model_name
        self.history = ChatHistory(messages)
        self.memory_enabled = memory_enabled
        self.api_key = api_key
        self.repo_content = []

    def add_message(self, role, content):
        """Append a message, prefixed by any pending repository dumps.

        Every entry of ``self.repo_content`` is consumed (most recently
        loaded first), wrapped in ``<repo>`` tags and placed before
        ``content``.

        Args:
            role: Role recorded for the message.
            content: Message text.
        """
        repo_blocks = []
        while self.repo_content:
            repo_blocks.append(f"<repo>\n{self.repo_content.pop()}\n</repo>\n")

        self.history.add_message("".join(repo_blocks) + content, role)

    def generate_response(self, max_tokens=3585, temperature=0.7):
        """Ask the model for a reply to the current history.

        If the history ends with an assistant message, that message's content
        is replaced by the returned text. Otherwise the reply is appended as
        a new assistant message when ``memory_enabled`` is true.

        Args:
            max_tokens: Output token limit forwarded to ``completion``.
            temperature: Sampling temperature forwarded to ``completion``.

        Returns:
            The text returned by ``completion``.

        Raises:
            ValueError: If the history is empty.
        """
        if not self.history.messages:
            raise ValueError("No messages in history to generate response from")

        messages = [{"role": msg["role"], "content": msg["content"]} for msg in self.history.messages]

        # Decide this before the call: the request list is handed to another
        # function and must not be relied on to be unchanged afterwards.
        # Checking its last role after the call could overwrite the final
        # *user* message in the history with the model's reply.
        ends_with_assistant = messages[-1]["role"] == "assistant"

        response_text = completion(
            model=self.model_name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            api_key=self.api_key
        )
        if ends_with_assistant:
            self.history.messages[-1]["content"] = response_text

        elif self.memory_enabled:
            self.add_message("assistant", response_text)

        return response_text

    def save_conversation(self):
        """Write the history's messages as JSON to ``<agent id>.json`` in the working directory."""
        filename = f"{self.id}.json"
        with open(filename, 'w', encoding='utf-8') as file:
            json.dump(self.history.messages, file, ensure_ascii=False, indent=4)

    def load_conversation(self, filename=None):
        """Replace the history with messages read from a JSON file.

        Args:
            filename: File to read; defaults to ``<agent id>.json``.
        """
        if filename is None:
            filename = f"{self.id}.json"
        with open(filename, 'r', encoding='utf-8') as file:
            messages = json.load(file)
        self.history = ChatHistory(messages)

    def load_repo(self, repo_name: Optional[str] = None, commit_hash: Optional[str] = None):
        """Download a GitHub repository archive and queue its text files.

        The ``.py`` and ``.txt`` files of the archive are concatenated, each
        as its file name followed by a fenced block, and the result is
        appended to ``self.repo_content`` for the next ``add_message`` call.

        Args:
            repo_name: ``owner/repository``, eg: ``xihajun/llm_dialog_manager``.
            commit_hash: Commit to download; the ``main`` branch when omitted.

        Raises:
            ValueError: If ``repo_name`` is missing or the download does not
                return HTTP 200.
        """
        # Previously a missing repo_name left repo_url unassigned and the
        # check below it failed with UnboundLocalError instead of ValueError.
        if not repo_name:
            raise ValueError("repo_name must be provided, e.g. 'xihajun/llm_dialog_manager'")

        if commit_hash:
            repo_url = f"https://github.com/{repo_name}/archive/{commit_hash}.zip"
        else:
            repo_url = f"https://github.com/{repo_name}/archive/refs/heads/main.zip"

        # Without a timeout an unresponsive server would block forever.
        response = requests.get(repo_url, timeout=30)
        if response.status_code != 200:
            raise ValueError(f"Failed to download repository from {repo_url}")

        sections = []
        with zipfile.ZipFile(io.BytesIO(response.content)) as z:
            for file_info in z.infolist():
                if file_info.is_dir() or not file_info.filename.endswith(('.py', '.txt')):
                    continue
                with z.open(file_info) as f:
                    # One badly encoded file should not abort the whole load.
                    content = f.read().decode('utf-8', errors='replace')
                sections.append(f"{file_info.filename}\n```\n{content}\n```\n")
        self.repo_content.append("".join(sections))
|
307
|
-
|
308
|
-
def _run_demo():
    """Manual smoke test: one Gemini round trip, printing the history around it."""
    text = "I think the answer is 42"

    demo_agent = Agent("gemini-1.5-pro-002", "you are an assistant", memory_enabled=True)

    # The user turn asks the model to repeat ``text``; the assistant turn
    # that follows is left as the last message before generating.
    demo_agent.add_message("user", f"Say: {text}\n")
    demo_agent.add_message("assistant", "the answer")

    print(demo_agent.generate_response())
    print(demo_agent.history[:])

    # Drop the newest message and show what was removed and what remains.
    removed = demo_agent.history.pop()
    print(removed)
    print(demo_agent.history[:])


if __name__ == "__main__":
    _run_demo()
|
@@ -1,146 +0,0 @@
|
|
1
|
-
from typing import List, Dict, Optional, Union
|
2
|
-
|
3
|
-
class ChatHistory:
    """Ordered list of chat messages with role tracking and inspection helpers.

    Each message is a dict with at least ``role`` and ``content``.
    ``last_role`` mirrors the role of the most recently added message and is
    ``"system"`` while the history is empty.
    """

    def __init__(self, input_data: Union[str, List[Dict[str, str]]] = "") -> None:
        """Create a history from a system prompt or a list of messages.

        Args:
            input_data: A non-empty string becomes the first (system)
                message; a list is validated and loaded via
                ``load_messages``; anything else yields an empty history.
        """
        self.messages: List[Dict[str, str]] = []
        if isinstance(input_data, str) and input_data:
            self.add_message(input_data, "system")
        elif isinstance(input_data, list):
            self.load_messages(input_data)
        self.last_role: str = "system" if not self.messages else self.messages[-1]["role"]

    def load_messages(self, messages: List[Dict[str, str]]) -> None:
        """Validate and append ``messages``, keeping ``last_role`` in sync.

        Raises:
            ValueError: If a message lacks ``role``/``content`` or has a role
                other than user, assistant or system. Messages before the
                offending one stay appended.
        """
        for message in messages:
            if not ("role" in message and "content" in message):
                raise ValueError("Each message must have a 'role' and 'content'.")
            if message["role"] not in ["user", "assistant", "system"]:
                raise ValueError(f"Invalid role: {message['role']}")
            self.messages.append(message)
            # Previously left stale, which broke the turn-order checks in
            # add_user_message / add_assistant_message after a later load.
            self.last_role = message["role"]

    def pop(self):
        """Remove the last message and return its content (``None`` if empty)."""
        if not self.messages:
            return None

        popped_message = self.messages.pop()
        self.last_role = self.messages[-1]["role"] if self.messages else "system"
        return popped_message["content"]

    def __len__(self):
        """Return the number of stored messages."""
        return len(self.messages)

    def __str__(self):
        """Return one ``Message <i> (<role>): <content>`` line per message."""
        return '\n'.join([f"Message {i} ({msg['role']}): {msg['content']}" for i, msg in enumerate(self.messages)])

    def __getitem__(self, key):
        """Return message(s) at ``key`` and, as a side effect, print them.

        A slice prints and returns the selected messages. An integer (which
        may be negative) prints the message with its neighbours, the selected
        one in green, and returns that message.

        Raises:
            IndexError: If an integer key is out of range.
            TypeError: If ``key`` is neither a slice nor an int.
        """
        if isinstance(key, slice):
            # Negative bounds are handled by list slicing itself.
            print('\n'.join([f"({msg['role']}): {msg['content']}" for msg in self.messages[key]]))
            return self.messages[key]
        elif isinstance(key, int):
            if key < 0:
                key += len(self.messages)  # Convert negative index to positive
            if 0 <= key < len(self.messages):
                snippet = self.get_conversation_snippet(key)
                print('\n'.join([f"({v['role']}): {self.color_text(v['content'], 'green') if k == 'current' else v['content']}" for k, v in snippet.items() if v]))
                return self.messages[key]
            else:
                raise IndexError("Message index out of range.")
        else:
            raise TypeError("Invalid argument type.")

    def __setitem__(self, index, value):
        """Replace the content of the message at ``index``, keeping its role.

        Raises:
            ValueError: If ``value`` is not a string.
            IndexError: If ``index`` is out of range.
        """
        if not isinstance(value, str):
            raise ValueError("Message content must be a string.")
        # The role used to be derived from index parity, which relabelled
        # assistant messages (index 2, 4, ...) as "system".
        role = self.messages[index]["role"]
        self.messages[index] = {"role": role, "content": value}

    def __add__(self, message):
        """Append ``message`` with the next role in turn order and return ``self``.

        After a system message the next role is user; otherwise it is
        assistant after user, and user in every other case. Returning
        ``self`` keeps ``history = history + text`` and ``history += text``
        from rebinding the name to ``None``.
        """
        if self.last_role == "system":
            self.add_user_message(message)
        else:
            next_role = "assistant" if self.last_role == "user" else "user"
            self.add_message(message, next_role)
        return self

    def __contains__(self, item):
        """Return whether ``item`` occurs in the content of any message."""
        return any(item in message['content'] for message in self.messages)

    def add_message(self, content, role):
        """Append a message without any turn-order validation."""
        self.messages.append({"role": role, "content": content})
        self.last_role = role

    def add_user_message(self, content):
        """Append a user message.

        Raises:
            ValueError: If the previous role is not system or assistant.
        """
        if self.last_role in ["system", "assistant"]:
            self.add_message(content, "user")
        else:
            raise ValueError("A user message must follow a system or assistant message.")

    def add_assistant_message(self, content):
        """Append an assistant message.

        Raises:
            ValueError: If the previous role is not user.
        """
        if self.last_role == "user":
            self.add_message(content, "assistant")
        else:
            raise ValueError("An assistant message must follow a user message.")

    def add_marker(self, marker, index=None):
        """Attach a string ``marker`` to a message (the last one by default).

        Raises:
            ValueError: If ``marker`` is not a string.
            IndexError: If ``index`` is out of range, including the default
                on an empty history.
        """
        if not isinstance(marker, str):
            raise ValueError("Marker must be a string.")
        if index is None:
            index = len(self.messages) - 1
        if 0 <= index < len(self.messages):
            self.messages[index]["marker"] = marker
        else:
            raise IndexError("Invalid index for marker.")

    def conversation_status(self):
        """Return the last role, message count and last content as a dict.

        ``last_message_content`` is ``"No messages"`` for an empty history.
        """
        return {
            "last_message_role": self.last_role,
            "total_messages": len(self.messages),
            "last_message_content": self.messages[-1]["content"] if self.messages else "No messages",
        }

    def display_conversation_status(self):
        """Print the values from ``conversation_status``, one per line."""
        status = self.conversation_status()
        print(f"Role of the last message: {status['last_message_role']}")
        print(f"Total number of messages: {status['total_messages']}")
        print(f"Content of the last message: {status['last_message_content']}")

    def search_for_keyword(self, keyword):
        """Return the messages whose content contains ``keyword``, ignoring case."""
        # Lower-case the search term once rather than once per message.
        needle = keyword.lower()
        return [msg for msg in self.messages if needle in msg["content"].lower()]

    def has_user_or_assistant_spoken_since_last_system(self):
        """Return True if the newest system/user/assistant message is not a system one.

        Messages with other roles are skipped; an empty history gives False.
        """
        for msg in reversed(self.messages):
            if msg["role"] == "system":
                return False
            if msg["role"] in ["user", "assistant"]:
                return True
        return False

    def get_conversation_snippet(self, index):
        """Return the message at ``index`` with its neighbours.

        Returns:
            dict: ``previous``, ``current`` and ``next``; missing neighbours
            are ``None``.

        Raises:
            IndexError: If ``index`` is outside ``0 .. len(messages) - 1``
                (negative values are not accepted here).
        """
        snippet = {"previous": None, "current": None, "next": None}
        if 0 <= index < len(self.messages):
            snippet['current'] = self.messages[index]
            if index > 0:
                snippet['previous'] = self.messages[index - 1]
            if index + 1 < len(self.messages):
                snippet['next'] = self.messages[index + 1]
        else:
            raise IndexError("Invalid index.")
        return snippet

    def display_snippet(self, index):
        """Print the previous, current and next message around ``index``."""
        snippet = self.get_conversation_snippet(index)
        for key, value in snippet.items():
            if value:
                print(f"{key.capitalize()} Message ({value['role']}): {value['content']}")
            else:
                print(f"{key.capitalize()}: None")

    @staticmethod
    def color_text(text, color):
        """Wrap ``text`` in ANSI colour codes.

        ``"green"`` and ``"red"`` are defined; an unknown colour name adds no
        prefix instead of raising ``KeyError``. The reset code is always
        appended.
        """
        colors = {"green": "\033[92m", "red": "\033[91m", "end": "\033[0m"}
        return f"{colors.get(color, '')}{text}{colors.get('end', '')}"
|
File without changes
|
File without changes
|
File without changes
|
{llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/SOURCES.txt
RENAMED
File without changes
|
File without changes
|
{llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/requires.txt
RENAMED
File without changes
|
{llm_dialog_manager-0.3.2 → llm_dialog_manager-0.3.5}/llm_dialog_manager.egg-info/top_level.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|