llm-dialog-manager 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl

llm_dialog_manager/__init__.py
@@ -1,4 +1,4 @@
  from .chat_history import ChatHistory
  from .agent import Agent

- __version__ = "0.3.4"
+ __version__ = "0.4.1"
llm_dialog_manager/agent.py
@@ -2,13 +2,15 @@
  import json
  import os
  import uuid
- from typing import List, Dict, Optional
+ from typing import List, Dict, Optional, Union
  import logging
  from pathlib import Path
  import random
  import requests
  import zipfile
  import io
+ import base64
+ from PIL import Image

  # Third-party imports
  import anthropic
@@ -18,8 +20,8 @@ import openai
  from dotenv import load_dotenv

  # Local imports
- from llm_dialog_manager.chat_history import ChatHistory
- from llm_dialog_manager.key_manager import key_manager
+ from .chat_history import ChatHistory
+ from .key_manager import key_manager

  # Set up logging
  logging.basicConfig(level=logging.INFO)
@@ -28,7 +30,7 @@ logger = logging.getLogger(__name__)
  # Load environment variables
  def load_env_vars():
      """Load environment variables from .env file"""
-     env_path = Path(__file__).parent.parent / '.env'
+     env_path = Path(__file__).parent / '.env'
      if env_path.exists():
          load_dotenv(env_path)
      else:
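
The .env lookup now resolves inside the installed package directory rather than one level above it. A quick sketch of the difference, assuming a hypothetical install location (the path below is illustrative only):

```python
from pathlib import Path

# Hypothetical installed location of agent.py
pkg_file = Path("/site-packages/llm_dialog_manager/agent.py")

print(pkg_file.parent.parent / ".env")  # 0.3.4 looked here: /site-packages/.env
print(pkg_file.parent / ".env")         # 0.4.1 looks here: /site-packages/llm_dialog_manager/.env
```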
@@ -36,24 +38,37 @@ def load_env_vars():

  load_env_vars()

- def create_and_send_message(client, model, max_tokens, temperature, messages, system_msg):
-     """Function to send a message to the Anthropic API and handle the response."""
-     try:
-         response = client.messages.create(
-             model=model,
-             max_tokens=max_tokens,
-             temperature=temperature,
-             messages=messages,
-             system=system_msg
-         )
-         return response
-     except Exception as e:
-         logger.error(f"Error sending message: {str(e)}")
-         raise
+ def format_messages_for_gemini(messages):
+     """
+     Convert standardized messages into the Gemini format.
+     System messages should be passed through the GenerativeModel's
+     system_instruction parameter and are not handled in this function.
+     """
+     gemini_messages = []

- def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 1000,
-                temperature: float = 0.5, api_key: Optional[str] = None,
-                base_url: Optional[str] = None) -> str:
+     for msg in messages:
+         role = msg["role"]
+         content = msg["content"]
+
+         # Skip system messages; they are supplied via system_instruction
+         if role == "system":
+             continue
+
+         # Handle user/assistant messages:
+         # if content is a single object, wrap it in a list
+         if not isinstance(content, list):
+             content = [content]
+
+         gemini_messages.append({
+             "role": role,
+             "parts": content  # content may contain text and FileMedia
+         })
+
+     return gemini_messages
+
+ def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]], max_tokens: int = 1000,
+                temperature: float = 0.5, top_p: float = 1.0, top_k: int = 40, api_key: Optional[str] = None,
+                base_url: Optional[str] = None, json_format: bool = False) -> str:
      """
      Generate a completion using the specified model and messages.
      """
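
For orientation, a minimal runnable sketch of what the new format_messages_for_gemini helper produces; the sample messages are illustrative and the logic mirrors the function added above:

```python
def format_messages_for_gemini(messages):
    # Mirrors the helper added above: drop system messages (they travel
    # via system_instruction) and wrap content in a "parts" list.
    gemini_messages = []
    for msg in messages:
        if msg["role"] == "system":
            continue
        content = msg["content"]
        if not isinstance(content, list):
            content = [content]
        gemini_messages.append({"role": msg["role"], "parts": content})
    return gemini_messages

print(format_messages_for_gemini([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Describe this image."},
]))
# -> [{'role': 'user', 'parts': ['Describe this image.']}]
```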
@@ -70,202 +85,396 @@ def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 100

      # Get API key and base URL from key manager if not provided
      if not api_key:
-         api_key, base_url = key_manager.get_config(service)
+         # api_key, base_url = key_manager.get_config(service)
+         # Placeholder for key_manager
+         api_key = os.getenv(f"{service.upper()}_API_KEY")
+         base_url = os.getenv(f"{service.upper()}_BASE_URL")

-     try:
+     def format_messages_for_api(model, messages):
+         """Convert ChatHistory messages to the format required by the specific API."""
          if "claude" in model:
-             # Check for Vertex configuration
-             vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
-             vertex_region = os.getenv('VERTEX_REGION')
-
-             if vertex_project_id and vertex_region:
-                 client = AnthropicVertex(
-                     region=vertex_region,
-                     project_id=vertex_project_id
-                 )
+             formatted = []
+             system_msg = ""
+             if messages and messages[0]["role"] == "system":
+                 system_msg = messages.pop(0)["content"]
+             for msg in messages:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     formatted.append({"role": msg["role"], "content": content})
+                 elif isinstance(content, list):
+                     # Combine content blocks into a single message
+                     combined_content = []
+                     for block in content:
+                         if isinstance(block, str):
+                             combined_content.append({"type": "text", "text": block})
+                         elif isinstance(block, Image.Image):
+                             # For Claude, convert PIL.Image to base64
+                             buffered = io.BytesIO()
+                             block.save(buffered, format="PNG")
+                             image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                             combined_content.append({
+                                 "type": "image",
+                                 "source": {
+                                     "type": "base64",
+                                     "media_type": "image/png",
+                                     "data": image_base64
+                                 }
+                             })
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 combined_content.append({
+                                     "type": "image",
+                                     "source": {
+                                         "type": "url",
+                                         "url": block["image_url"]["url"]
+                                     }
+                                 })
+                             elif block.get("type") == "image_base64":
+                                 combined_content.append({
+                                     "type": "image",
+                                     "source": {
+                                         "type": "base64",
+                                         "media_type": block["image_base64"]["media_type"],
+                                         "data": block["image_base64"]["data"]
+                                     }
+                                 })
+                     formatted.append({"role": msg["role"], "content": combined_content})
+             return system_msg, formatted
+
+         elif "gemini" in model or "gpt" in model or "grok" in model:
+             formatted = []
+             for msg in messages:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     formatted.append({"role": msg["role"], "parts": [content]})
+                 elif isinstance(content, list):
+                     parts = []
+                     for block in content:
+                         if isinstance(block, str):
+                             parts.append(block)
+                         elif isinstance(block, Image.Image):
+                             parts.append(block)
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}})
+                             elif block.get("type") == "image_base64":
+                                 parts.append({"type": "image_base64", "image_base64": {"data": block["image_base64"]["data"], "media_type": block["image_base64"]["media_type"]}})
+                     formatted.append({"role": msg["role"], "parts": parts})
+             return None, formatted
+
+         else:  # OpenAI models
+             formatted = []
+             for msg in messages:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     formatted.append({"role": msg["role"], "content": content})
+                 elif isinstance(content, list):
+                     # OpenAI expects 'content' as a string; images are not directly supported
+                     # You can convert images to URLs or descriptions if needed
+                     combined_content = ""
+                     for block in content:
+                         if isinstance(block, str):
+                             combined_content += block + "\n"
+                         elif isinstance(block, Image.Image):
+                             # Convert PIL.Image to base64 or upload and use a URL
+                             buffered = io.BytesIO()
+                             block.save(buffered, format="PNG")
+                             image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                             combined_content += f"[Image Base64: {image_base64[:30]}...]\n"
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 combined_content += f"[Image: {block['image_url']['url']}]\n"
+                             elif block.get("type") == "image_base64":
+                                 combined_content += f"[Image Base64: {block['image_base64']['data'][:30]}...]\n"
+                     formatted.append({"role": msg["role"], "content": combined_content.strip()})
+             return None, formatted
+
+     system_msg, formatted_messages = format_messages_for_api(model, messages.copy())
+
+     if "claude" in model:
+         # Check for Vertex configuration
+         vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
+         vertex_region = os.getenv('VERTEX_REGION')
+
+         if vertex_project_id and vertex_region:
+             client = AnthropicVertex(
+                 region=vertex_region,
+                 project_id=vertex_project_id
+             )
+         else:
+             client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
+
+         response = client.messages.create(
+             model=model,
+             max_tokens=max_tokens,
+             temperature=temperature,
+             messages=formatted_messages,
+             system=system_msg
+         )
+
+         while response.stop_reason == "max_tokens":
+             if formatted_messages[-1]['role'] == "user":
+                 formatted_messages.append({"role": "assistant", "content": response.completion})
              else:
-                 client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
+                 formatted_messages[-1]['content'] += response.completion

-             system_msg = messages.pop(0)["content"] if messages and messages[0]["role"] == "system" else ""
              response = client.messages.create(
                  model=model,
                  max_tokens=max_tokens,
                  temperature=temperature,
-                 messages=messages,
+                 messages=formatted_messages,
                  system=system_msg
              )
+
+         if formatted_messages[-1]['role'] == "assistant" and response.stop_reason == "end_turn":
+             formatted_messages[-1]['content'] += response.completion
+             return formatted_messages[-1]['content']
+
+         return response.completion
+
+     elif "gemini" in model:
+         try:
+             # First try the OpenAI-style API
+             client = openai.OpenAI(
+                 api_key=api_key,
+                 base_url="https://generativelanguage.googleapis.com/v1beta/"
+             )
+             # Set response_format based on json_format
+             response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}
+
+             response = client.chat.completions.create(
+                 model=model,
+                 max_tokens=max_tokens,
+                 top_p=top_p,
+                 top_k=top_k,
+                 messages=formatted_messages,
+                 temperature=temperature,
+                 response_format=response_format  # Added response_format
+             )
+             return response.choices[0].message.content
+
+         except Exception as e:
+             # If the OpenAI-style API fails, fall back to Google's genai library
+             logger.info("Falling back to Google's genai library")
+             genai.configure(api_key=api_key)
+             system_instruction = ""
+             for msg in messages:
+                 if msg["role"] == "system":
+                     system_instruction = msg["content"]
+                     break

-         while response.stop_reason == "max_tokens":
-             if messages[-1]['role'] == "user":
-                 messages.append({"role": "assistant", "content": response.content[0].text})
-             else:
-                 messages[-1]['content'] += response.content[0].text
-
-             response = client.messages.create(
-                 model=model,
-                 max_tokens=max_tokens,
-                 temperature=temperature,
-                 messages=messages,
-                 system=system_msg
-             )
-
-         if messages[-1]['role'] == "assistant" and response.stop_reason == "end_turn":
-             messages[-1]['content'] += response.content[0].text
-             return messages[-1]['content']
-
-         return response.content[0].text
-
-     elif "gemini" in model:
-         try:
-             # First try OpenAI-style API
-             client = openai.OpenAI(
-                 api_key=api_key,
-                 base_url="https://generativelanguage.googleapis.com/v1beta/"
-             )
-             # Remove any system message from the beginning if present
-             if messages and messages[0]["role"] == "system":
-                 system_msg = messages.pop(0)
-                 # Prepend system message to first user message if exists
-                 if messages:
-                     messages[0]["content"] = f"{system_msg['content']}\n\n{messages[0]['content']}"
-
-             response = client.chat.completions.create(
-                 model=model,
-                 messages=messages,
-                 temperature=temperature
-             )
-
-             return response.choices[0].message.content
-
-         except Exception as e:
-             # If OpenAI-style API fails, fall back to Google's genai library
-             logger.info("Falling back to Google's genai library")
-             genai.configure(api_key=api_key)
-
-             # Convert messages to Gemini format
-             gemini_messages = []
-             for msg in messages:
-                 if msg["role"] == "system":
-                     # Prepend system message to first user message if exists
-                     if gemini_messages:
-                         gemini_messages[0].parts[0].text = f"{msg['content']}\n\n{gemini_messages[0].parts[0].text}"
-                 else:
-                     gemini_messages.append({"role": msg["role"], "parts": [{"text": msg["content"]}]})
-
-             # Create Gemini model and generate response
-             model = genai.GenerativeModel(model_name=model)
-             response = model.generate_content(
-                 gemini_messages,
-                 generation_config=genai.types.GenerationConfig(
-                     temperature=temperature,
-                     max_output_tokens=max_tokens
-                 )
-             )
-
-             return response.text
-
-     elif "grok" in model:
-         # Randomly choose between OpenAI and Anthropic SDK
-         use_anthropic = random.choice([True, False])
-
-         if use_anthropic:
-             print("using anthropic")
-             client = anthropic.Anthropic(
-                 api_key=api_key,
-                 base_url="https://api.x.ai"
-             )
-
-             system_msg = messages.pop(0)["content"] if messages and messages[0]["role"] == "system" else ""
-             response = client.messages.create(
-                 model=model,
-                 max_tokens=max_tokens,
-                 temperature=temperature,
-                 messages=messages,
-                 system=system_msg
-             )
-             return response.content[0].text
-         else:
-             print("using openai")
-             client = openai.OpenAI(
-                 api_key=api_key,
-                 base_url="https://api.x.ai/v1"
-             )
-             response = client.chat.completions.create(
-                 model=model,
-                 messages=messages,
-                 max_tokens=max_tokens,
-                 temperature=temperature
-             )
-             return response.choices[0].message.content
+             # Convert the remaining messages to the Gemini format
+             gemini_messages = format_messages_for_gemini(messages)
+             mime_type = "application/json" if json_format else "text/plain"
+             generation_config = genai.types.GenerationConfig(
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 max_output_tokens=max_tokens,
+                 response_mime_type=mime_type
+             )
+
+             model_instance = genai.GenerativeModel(
+                 model_name=model,
+                 system_instruction=system_instruction,  # the system message is passed in here
+                 generation_config=generation_config
+             )
+
+             response = model_instance.generate_content(gemini_messages, generation_config=generation_config)
+
+             return response.text
+
+     elif "grok" in model:
+         # Randomly choose between OpenAI and Anthropic SDK
+         use_anthropic = random.choice([True, False])
+
+         if use_anthropic:
+             logger.info("Using Anthropic for Grok model")
+             client = anthropic.Anthropic(
+                 api_key=api_key,
+                 base_url="https://api.x.ai"
+             )
+
+             system_msg = ""
+             if messages and messages[0]["role"] == "system":
+                 system_msg = messages.pop(0)["content"]
+
+             response = client.messages.create(
+                 model=model,
+                 max_tokens=max_tokens,
+                 temperature=temperature,
+                 messages=formatted_messages,
+                 system=system_msg
+             )
+             return response.completion
+         else:
+             logger.info("Using OpenAI for Grok model")
+             client = openai.OpenAI(
+                 api_key=api_key,
+                 base_url="https://api.x.ai/v1"
+             )
+             # Set response_format based on json_format
+             response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}

-         else: # OpenAI models
-             client = openai.OpenAI(api_key=api_key, base_url=base_url)
              response = client.chat.completions.create(
                  model=model,
-                 messages=messages,
+                 messages=formatted_messages,
                  max_tokens=max_tokens,
                  temperature=temperature,
+                 response_format=response_format  # Added response_format
              )
              return response.choices[0].message.content

-         # Release the API key after successful use
-         if not api_key:
-             key_manager.release_config(service, api_key)
+     else:  # OpenAI models
+         client = openai.OpenAI(api_key=api_key, base_url=base_url)
+         # Set response_format based on json_format
+         response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}

-         return response
+         response = client.chat.completions.create(
+             model=model,
+             messages=formatted_messages,
+             max_tokens=max_tokens,
+             temperature=temperature,
+             response_format=response_format  # Added response_format
+         )
+         return response.choices[0].message.content

-     except Exception as e:
-         # Report error to key manager
-         if not api_key:
-             key_manager.report_error(service, api_key)
-         raise
+     # Release the API key after successful use
+     if not api_key:
+         # key_manager.release_config(service, api_key)
+         pass
+
+     return response

      except Exception as e:
          logger.error(f"Error in completion: {str(e)}")
          raise

  class Agent:
-     def __init__(self, model_name: str, messages: Optional[str] = None,
+     def __init__(self, model_name: str, messages: Optional[Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]]] = None,
                   memory_enabled: bool = False, api_key: Optional[str] = None) -> None:
          """Initialize an Agent instance."""
-         # valid_models = ['gpt-3.5-turbo', 'gpt-4', 'claude-2.1', 'gemini-1.5-pro', 'gemini-1.5-flash', 'grok-beta', 'claude-3-5-sonnet-20241022']
-         # if model_name not in valid_models:
-         #     raise ValueError(f"Model {model_name} not supported. Supported models: {valid_models}")
-
          self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
          self.model_name = model_name
-         self.history = ChatHistory(messages)
+         self.history = ChatHistory(messages) if messages else ChatHistory()
          self.memory_enabled = memory_enabled
          self.api_key = api_key
          self.repo_content = []

-     def add_message(self, role, content):
-         repo_content = ""
-         while self.repo_content:
-             repo = self.repo_content.pop()
-             repo_content += f"<repo>\n{repo}\n</repo>\n"
-
-         content = repo_content + content
+     def add_message(self, role: str, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+         """Add a message to the conversation."""
          self.history.add_message(content, role)

-     def generate_response(self, max_tokens=3585, temperature=0.7):
+     def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+         """Add a user message."""
+         self.history.add_user_message(content)
+
+     def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+         """Add an assistant message."""
+         self.history.add_assistant_message(content)
+
+     def add_image(self, image_path: Optional[str] = None, image_url: Optional[str] = None, media_type: Optional[str] = "image/jpeg"):
+         """
+         Add an image to the conversation.
+         Either image_path or image_url must be provided.
+         """
+         if not image_path and not image_url:
+             raise ValueError("Either image_path or image_url must be provided.")
+
+         if image_path:
+             if not os.path.exists(image_path):
+                 raise FileNotFoundError(f"Image file {image_path} does not exist.")
+             if "gemini" in self.model_name:
+                 # For Gemini, load as PIL.Image
+                 image_pil = Image.open(image_path)
+                 image_block = image_pil
+             else:
+                 # For Claude and others, use base64 encoding
+                 with open(image_path, "rb") as img_file:
+                     image_data = base64.standard_b64encode(img_file.read()).decode("utf-8")
+                 image_block = {
+                     "type": "image_base64",
+                     "image_base64": {
+                         "media_type": media_type,
+                         "data": image_data
+                     }
+                 }
+         else:
+             # If image_url is provided
+             if "gemini" in self.model_name:
+                 # For Gemini, you can pass image URLs directly
+                 image_block = {"type": "image_url", "image_url": {"url": image_url}}
+             else:
+                 # For Claude and others, use image URLs
+                 image_block = {
+                     "type": "image_url",
+                     "image_url": {
+                         "url": image_url
+                     }
+                 }
+
+         # Add the image block to the last user message or as a new user message
+         if self.history.last_role == "user":
+             current_content = self.history.messages[-1]["content"]
+             if isinstance(current_content, list):
+                 current_content.append(image_block)
+             else:
+                 self.history.messages[-1]["content"] = [current_content, image_block]
+         else:
+             # Start a new user message with the image
+             self.history.add_message([image_block], "user")
+
+     def generate_response(self, max_tokens=3585, temperature=0.7, top_p=1.0, top_k=40, json_format: bool = False) -> str:
+         """Generate a response from the agent.
+
+         Args:
+             max_tokens (int, optional): Maximum number of tokens. Defaults to 3585.
+             temperature (float, optional): Sampling temperature. Defaults to 0.7.
+             json_format (bool, optional): Whether to enable JSON output format. Defaults to False.
+
+         Returns:
+             str: The generated response.
+         """
          if not self.history.messages:
              raise ValueError("No messages in history to generate response from")

-         messages = [{"role": msg["role"], "content": msg["content"]} for msg in self.history.messages]
-
+         messages = self.history.messages
+         print(self.model_name)
          response_text = completion(
              model=self.model_name,
              messages=messages,
              max_tokens=max_tokens,
              temperature=temperature,
-             api_key=self.api_key
+             top_p=top_p,
+             top_k=top_k,
+             api_key=self.api_key,
+             json_format=json_format  # Pass json_format to completion
          )
-         if messages[-1]["role"] == "assistant":
-             self.history.messages[-1]["content"] = response_text
-
-         elif self.memory_enabled:
-             self.add_message("assistant", response_text)
+         if self.model_name.startswith("openai"):
+             # OpenAI does not support images, so responses are simple strings
+             if self.history.messages[-1]["role"] == "assistant":
+                 self.history.messages[-1]["content"] = response_text
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
+         elif "claude" in self.model_name:
+             if self.history.messages[-1]["role"] == "assistant":
+                 self.history.messages[-1]["content"] = response_text
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
+         elif "gemini" in self.model_name or "grok" in self.model_name:
+             if self.history.messages[-1]["role"] == "assistant":
+                 if isinstance(self.history.messages[-1]["content"], list):
+                     self.history.messages[-1]["content"].append(response_text)
+                 else:
+                     self.history.messages[-1]["content"] = [self.history.messages[-1]["content"], response_text]
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
+         else:
+             # Handle other models similarly
+             if self.history.messages[-1]["role"] == "assistant":
+                 self.history.messages[-1]["content"] = response_text
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)

          return response_text

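
Taken together, the rewritten completion dispatches on the model name and now accepts top_p, top_k, and json_format. A hedged usage sketch against the new signature; the model name, prompt, and placeholder key are examples, not part of the package:

```python
import os
from llm_dialog_manager.agent import completion

# The {SERVICE}_API_KEY fallback shown above reads this when api_key is not passed
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")

reply = completion(
    model="gpt-4o-mini",  # no claude/gemini/grok substring, so the OpenAI branch runs
    messages=[{"role": "user", "content": 'Reply with {"ok": true} as JSON.'}],
    max_tokens=200,
    temperature=0.0,
    json_format=True,  # new in 0.4.1: sends response_format={"type": "json_object"}
)
print(reply)
```

One caveat: when json_format is False these branches send response_format={"type": "plain_text"}, while the OpenAI API documents {"type": "text"} for this field, so the non-JSON path may be rejected by the server. Likewise, top_k is not a parameter of the OpenAI SDK's chat.completions.create, which in practice can push Gemini calls into the genai fallback.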
@@ -274,11 +483,12 @@ class Agent:
          with open(filename, 'w', encoding='utf-8') as file:
              json.dump(self.history.messages, file, ensure_ascii=False, indent=4)

-     def load_conversation(self, filename=None):
+     def load_conversation(self, filename: Optional[str] = None):
          if filename is None:
              filename = f"{self.id}.json"
          with open(filename, 'r', encoding='utf-8') as file:
              messages = json.load(file)
+             # Handle deserialization of images if necessary
          self.history = ChatHistory(messages)

      def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = None, repo_name: Optional[str] = None, commit_hash: Optional[str] = None):
@@ -305,28 +515,31 @@ class Agent:
          raise ValueError(f"Failed to download repository from {repo_url}")

  if __name__ == "__main__":
-
-     # write a test for detect finding agent
-     text = "I think the answer is 42"
-
-     # from agent.messageloader import information_detector_messages
+     # Example Usage
+     # Create an Agent instance (Gemini model)
+     agent = Agent("gemini-1.5-flash", "you are Jack101", memory_enabled=True)

-     # # Now you can print or use information_detector_messages as needed
-     # information_detector_agent = Agent("gemini-1.5-pro", information_detector_messages)
-     # information_detector_agent.add_message("user", text)
-     # response = information_detector_agent.generate_response()
-     # print(response)
-     agent = Agent("gemini-1.5-pro-002", "you are an assistant", memory_enabled=True)
+     # Add an image
+     agent.add_image(image_path="/Users/junfan/Projects/Personal/oneapi/dialog_manager/example.png")

-     # Format the prompt to check if the section is the last one in the outline
-     prompt = f"Say: {text}\n"
+     # Add a user message
+     agent.add_message("user", "Who are you? What's in this image?")

-     # Add the prompt as a message from the user
-     agent.add_message("user", prompt)
-     agent.add_message("assistant", "the answer")
+     # Generate a response with JSON format enabled
+     try:
+         response = agent.generate_response(json_format=True)  # json_format set to True
+         print("Response:", response)
+     except Exception as e:
+         logger.error(f"Failed to generate response: {e}")

-     print(agent.generate_response())
-     print(agent.history[:])
+     # Print the entire conversation history
+     print("Conversation History:")
+     print(agent.history)
+
+     # Pop the last message
      last_message = agent.history.pop()
-     print(last_message)
-     print(agent.history[:])
+     print("Last Message:", last_message)
+
+     # Generate another response without JSON format
+     response = agent.generate_response()
+     print("Response:", response)
llm_dialog_manager/chat_history.py
@@ -1,53 +1,122 @@
  from typing import List, Dict, Optional, Union
+ from PIL import Image

  class ChatHistory:
-     def __init__(self, input_data: Union[str, List[Dict[str, str]]] = "") -> None:
-         self.messages: List[Dict[str, str]] = []
+     def __init__(self, input_data: Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]] = "") -> None:
+         self.messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]] = []
          if isinstance(input_data, str) and input_data:
              self.add_message(input_data, "system")
          elif isinstance(input_data, list):
              self.load_messages(input_data)
-         self.last_role: str = "system" if not self.messages else self.messages[-1]["role"]
+         self.last_role: str = "system" if not self.messages else self.get_last_role()

-     def load_messages(self, messages: List[Dict[str, str]]) -> None:
+     def load_messages(self, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]) -> None:
          for message in messages:
              if not ("role" in message and "content" in message):
                  raise ValueError("Each message must have a 'role' and 'content'.")
              if message["role"] not in ["user", "assistant", "system"]:
                  raise ValueError(f"Invalid role: {message['role']}")
              self.messages.append(message)
-
+         self.last_role = self.get_last_role()
+
+     def get_last_role(self):
+         return self.messages[-1]["role"] if self.messages else "system"
+
      def pop(self):
          if not self.messages:
              return None
-
+
          popped_message = self.messages.pop()
-
+
          if self.messages:
-             self.last_role = self.messages[-1]["role"]
+             self.last_role = self.get_last_role()
          else:
              self.last_role = "system"
-
+
          return popped_message["content"]

      def __len__(self):
          return len(self.messages)

      def __str__(self):
-         return '\n'.join([f"Message {i} ({msg['role']}): {msg['content']}" for i, msg in enumerate(self.messages)])
+         formatted_messages = []
+         for i, msg in enumerate(self.messages):
+             role = msg['role']
+             content = msg['content']
+             if isinstance(content, str):
+                 formatted_content = content
+             elif isinstance(content, list):
+                 parts = []
+                 for block in content:
+                     if isinstance(block, str):
+                         parts.append(block)
+                     elif isinstance(block, Image.Image):
+                         parts.append(f"[Image Object: {block.filename}]")
+                     elif isinstance(block, dict):
+                         if block.get("type") == "image_url":
+                             parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                         elif block.get("type") == "image_base64":
+                             parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                 formatted_content = "\n".join(parts)
+             else:
+                 formatted_content = str(content)
+             formatted_messages.append(f"Message {i} ({role}): {formatted_content}")
+         return '\n'.join(formatted_messages)

      def __getitem__(self, key):
          if isinstance(key, slice):
-             # Handle slice object, no change needed here for negative indices as slices are handled by list itself
-             print('\n'.join([f"({msg['role']}): {msg['content']}" for i, msg in enumerate(self.messages[key])]))
-             return self.messages[key]
+             sliced_messages = self.messages[key]
+             formatted = []
+             for msg in sliced_messages:
+                 role = msg['role']
+                 content = msg['content']
+                 if isinstance(content, str):
+                     formatted_content = content
+                 elif isinstance(content, list):
+                     parts = []
+                     for block in content:
+                         if isinstance(block, str):
+                             parts.append(block)
+                         elif isinstance(block, Image.Image):
+                             parts.append(f"[Image Object: {block.filename}]")
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                             elif block.get("type") == "image_base64":
+                                 parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                     formatted_content = "\n".join(parts)
+                 else:
+                     formatted_content = str(content)
+                 formatted.append(f"({role}): {formatted_content}")
+             print('\n'.join(formatted))
+             return sliced_messages
          elif isinstance(key, int):
              # Adjust for negative indices
              if key < 0:
-                 key += len(self.messages)  # Convert negative index to positive
+                 key += len(self.messages)
              if 0 <= key < len(self.messages):
+                 msg = self.messages[key]
+                 role = msg['role']
+                 content = msg['content']
+                 if isinstance(content, str):
+                     formatted_content = content
+                 elif isinstance(content, list):
+                     parts = []
+                     for block in content:
+                         if isinstance(block, str):
+                             parts.append(block)
+                         elif isinstance(block, Image.Image):
+                             parts.append(f"[Image Object: {block.filename}]")
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                             elif block.get("type") == "image_base64":
+                                 parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                     formatted_content = "\n".join(parts)
+                 else:
+                     formatted_content = str(content)
                  snippet = self.get_conversation_snippet(key)
-                 print('\n'.join([f"({v['role']}): {self.color_text(v['content'], 'green') if k == 'current' else v['content']}" for k, v in snippet.items() if v]))
+                 print('\n'.join([f"({v['role']}): {v['content']}" for k, v in snippet.items() if v]))
                  return self.messages[key]
              else:
                  raise IndexError("Message index out of range.")
@@ -55,8 +124,8 @@ class ChatHistory:
              raise TypeError("Invalid argument type.")

      def __setitem__(self, index, value):
-         if not isinstance(value, str):
-             raise ValueError("Message content must be a string.")
+         if not isinstance(value, (str, list)):
+             raise ValueError("Message content must be a string or a list of content blocks.")
          role = "system" if index % 2 == 0 else "user"
          self.messages[index] = {"role": role, "content": value}

@@ -68,19 +137,27 @@ class ChatHistory:
          self.add_message(message, next_role)

      def __contains__(self, item):
-         return any(item in message['content'] for message in self.messages)
+         for message in self.messages:
+             content = message['content']
+             if isinstance(content, str) and item in content:
+                 return True
+             elif isinstance(content, list):
+                 for block in content:
+                     if isinstance(block, str) and item in block:
+                         return True
+         return False

-     def add_message(self, content, role):
+     def add_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]], role: str):
          self.messages.append({"role": role, "content": content})
          self.last_role = role

-     def add_user_message(self, content):
+     def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
          if self.last_role in ["system", "assistant"]:
              self.add_message(content, "user")
          else:
              raise ValueError("A user message must follow a system or assistant message.")

-     def add_assistant_message(self, content):
+     def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
          if self.last_role == "user":
              self.add_message(content, "assistant")
          else:
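
ChatHistory now stores either plain strings or lists of content blocks (strings, PIL images, or image dicts). A short sketch of the accepted shapes; the URL and messages are placeholders:

```python
from llm_dialog_manager.chat_history import ChatHistory

history = ChatHistory("You are a vision assistant.")  # string input becomes the system message
history.add_user_message([
    "What is in this picture?",
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
])

print(len(history))          # 2
print("picture" in history)  # True: __contains__ now searches string blocks inside lists
```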
@@ -110,7 +187,17 @@ class ChatHistory:
          print(f"Content of the last message: {status['last_message_content']}")

      def search_for_keyword(self, keyword):
-         return [msg for msg in self.messages if keyword.lower() in msg["content"].lower()]
+         results = []
+         for msg in self.messages:
+             content = msg['content']
+             if isinstance(content, str) and keyword.lower() in content.lower():
+                 results.append(msg)
+             elif isinstance(content, list):
+                 for block in content:
+                     if isinstance(block, str) and keyword.lower() in block.lower():
+                         results.append(msg)
+                         break
+         return results

      def has_user_or_assistant_spoken_since_last_system(self):
          for msg in reversed(self.messages):
@@ -143,4 +230,4 @@ class ChatHistory:
      @staticmethod
      def color_text(text, color):
          colors = {"green": "\033[92m", "red": "\033[91m", "end": "\033[0m"}
-         return f"{colors[color]}{text}{colors['end']}"
+         return f"{colors.get(color, '')}{text}{colors.get('end', '')}"
llm_dialog_manager-0.4.1.dist-info/METADATA
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: llm_dialog_manager
- Version: 0.3.4
+ Version: 0.4.1
  Summary: A Python package for managing LLM chat conversation history
  Author-email: xihajun <work@2333.fun>
  License: MIT
@@ -64,6 +64,7 @@ A Python package for managing AI chat conversation history with support for mult
  - Memory management options
  - Conversation search and indexing
  - Rich conversation display options
+ - Vision & Json Output enabled [20240111]

  ## Installation

llm_dialog_manager-0.4.1.dist-info/RECORD
@@ -0,0 +1,9 @@
+ llm_dialog_manager/__init__.py,sha256=hTHvsXzvD5geKgv2XERYcp2f-T3LoVVc3arXfPtNS1k,86
+ llm_dialog_manager/agent.py,sha256=ZKO3eKHTKcbmYpVRRIpzDy7Tlp_VgQ90ewr1758Ozgs,23931
+ llm_dialog_manager/chat_history.py,sha256=DKKRnj_M6h-4JncnH6KekMTghX7vMgdN3J9uOwXKzMU,10347
+ llm_dialog_manager/key_manager.py,sha256=shvxmn4zUtQx_p-x1EFyOmnk-WlhigbpKtxTKve-zXk,4421
+ llm_dialog_manager-0.4.1.dist-info/LICENSE,sha256=vWGbYgGuWpWrXL8-xi6pNcX5UzD6pWoIAZmcetyfbus,1064
+ llm_dialog_manager-0.4.1.dist-info/METADATA,sha256=LER5FN6lFQFPs_8A-fIM7VYmqN-fh0nCD6Dt8vslsiY,4194
+ llm_dialog_manager-0.4.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ llm_dialog_manager-0.4.1.dist-info/top_level.txt,sha256=u2EQEXW0NGAt0AAHT7jx1odXZ4rZfjcgbmJhvKFuMkI,19
+ llm_dialog_manager-0.4.1.dist-info/RECORD,,
llm_dialog_manager-0.4.1.dist-info/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.6.0)
+ Generator: setuptools (75.8.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

llm_dialog_manager-0.3.4.dist-info/RECORD
@@ -1,9 +0,0 @@
- llm_dialog_manager/__init__.py,sha256=L-vxURJQtPwajtaFFl8gFI-xAsggFXNNY3aOqlVIViU,86
- llm_dialog_manager/agent.py,sha256=ruZWlSZZ6jBr8x8I-lWdsjiSP44BTYTR9fBS0FuE-Rc,13310
- llm_dialog_manager/chat_history.py,sha256=xKA-oQCv8jv_g8EhXrG9h1S8Icbj2FfqPIhbty5vra4,6033
- llm_dialog_manager/key_manager.py,sha256=shvxmn4zUtQx_p-x1EFyOmnk-WlhigbpKtxTKve-zXk,4421
- llm_dialog_manager-0.3.4.dist-info/LICENSE,sha256=vWGbYgGuWpWrXL8-xi6pNcX5UzD6pWoIAZmcetyfbus,1064
- llm_dialog_manager-0.3.4.dist-info/METADATA,sha256=6XPZd3oGCVC8YS-EVHII1EFFsnml1kS2VgRYoVsNaIQ,4152
- llm_dialog_manager-0.3.4.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- llm_dialog_manager-0.3.4.dist-info/top_level.txt,sha256=u2EQEXW0NGAt0AAHT7jx1odXZ4rZfjcgbmJhvKFuMkI,19
- llm_dialog_manager-0.3.4.dist-info/RECORD,,