llm-dialog-manager 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

llm_dialog_manager/__init__.py
@@ -1,4 +1,4 @@
  from .chat_history import ChatHistory
  from .agent import Agent
 
- __version__ = "0.3.4"
+ __version__ = "0.4.1"

llm_dialog_manager/agent.py
@@ -2,13 +2,15 @@
  import json
  import os
  import uuid
- from typing import List, Dict, Optional
+ from typing import List, Dict, Optional, Union
  import logging
  from pathlib import Path
  import random
  import requests
  import zipfile
  import io
+ import base64
+ from PIL import Image
 
  # Third-party imports
  import anthropic
@@ -18,8 +20,8 @@ import openai
  from dotenv import load_dotenv
 
  # Local imports
- from llm_dialog_manager.chat_history import ChatHistory
- from llm_dialog_manager.key_manager import key_manager
+ from .chat_history import ChatHistory
+ from .key_manager import key_manager
 
  # Set up logging
  logging.basicConfig(level=logging.INFO)
@@ -28,7 +30,7 @@ logger = logging.getLogger(__name__)
  # Load environment variables
  def load_env_vars():
      """Load environment variables from .env file"""
-     env_path = Path(__file__).parent.parent / '.env'
+     env_path = Path(__file__).parent / '.env'
      if env_path.exists():
          load_dotenv(env_path)
      else:
@@ -36,24 +38,37 @@ def load_env_vars():
 
  load_env_vars()
 
- def create_and_send_message(client, model, max_tokens, temperature, messages, system_msg):
-     """Function to send a message to the Anthropic API and handle the response."""
-     try:
-         response = client.messages.create(
-             model=model,
-             max_tokens=max_tokens,
-             temperature=temperature,
-             messages=messages,
-             system=system_msg
-         )
-         return response
-     except Exception as e:
-         logger.error(f"Error sending message: {str(e)}")
-         raise
+ def format_messages_for_gemini(messages):
+     """
+     Convert standardized messages into the Gemini format.
+     System messages should be passed in via the GenerativeModel
+     system_instruction parameter and are not handled in this function.
+     """
+     gemini_messages = []
 
- def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 1000,
-                temperature: float = 0.5, api_key: Optional[str] = None,
-                base_url: Optional[str] = None) -> str:
+     for msg in messages:
+         role = msg["role"]
+         content = msg["content"]
+
+         # Skip system messages; they are set via system_instruction
+         if role == "system":
+             continue
+
+         # Handle user/assistant messages:
+         # if content is a single object, wrap it in a list
+         if not isinstance(content, list):
+             content = [content]
+
+         gemini_messages.append({
+             "role": role,
+             "parts": content  # content may contain text and FileMedia
+         })
+
+     return gemini_messages
+
+ def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]], max_tokens: int = 1000,
+                temperature: float = 0.5, top_p: float = 1.0, top_k: int = 40, api_key: Optional[str] = None,
+                base_url: Optional[str] = None, json_format: bool = False) -> str:
      """
      Generate a completion using the specified model and messages.
      """
@@ -70,202 +85,396 @@ def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 100
 
      # Get API key and base URL from key manager if not provided
      if not api_key:
-         api_key, base_url = key_manager.get_config(service)
+         # api_key, base_url = key_manager.get_config(service)
+         # Placeholder for key_manager
+         api_key = os.getenv(f"{service.upper()}_API_KEY")
+         base_url = os.getenv(f"{service.upper()}_BASE_URL")
 
-     try:
+     def format_messages_for_api(model, messages):
+         """Convert ChatHistory messages to the format required by the specific API."""
          if "claude" in model:
-             # Check for Vertex configuration
-             vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
-             vertex_region = os.getenv('VERTEX_REGION')
-
-             if vertex_project_id and vertex_region:
-                 client = AnthropicVertex(
-                     region=vertex_region,
-                     project_id=vertex_project_id
-                 )
+             formatted = []
+             system_msg = ""
+             if messages and messages[0]["role"] == "system":
+                 system_msg = messages.pop(0)["content"]
+             for msg in messages:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     formatted.append({"role": msg["role"], "content": content})
+                 elif isinstance(content, list):
+                     # Combine content blocks into a single message
+                     combined_content = []
+                     for block in content:
+                         if isinstance(block, str):
+                             combined_content.append({"type": "text", "text": block})
+                         elif isinstance(block, Image.Image):
+                             # For Claude, convert PIL.Image to base64
+                             buffered = io.BytesIO()
+                             block.save(buffered, format="PNG")
+                             image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                             combined_content.append({
+                                 "type": "image",
+                                 "source": {
+                                     "type": "base64",
+                                     "media_type": "image/png",
+                                     "data": image_base64
+                                 }
+                             })
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 combined_content.append({
+                                     "type": "image",
+                                     "source": {
+                                         "type": "url",
+                                         "url": block["image_url"]["url"]
+                                     }
+                                 })
+                             elif block.get("type") == "image_base64":
+                                 combined_content.append({
+                                     "type": "image",
+                                     "source": {
+                                         "type": "base64",
+                                         "media_type": block["image_base64"]["media_type"],
+                                         "data": block["image_base64"]["data"]
+                                     }
+                                 })
+                     formatted.append({"role": msg["role"], "content": combined_content})
+             return system_msg, formatted
+
+         elif "gemini" in model or "gpt" in model or "grok" in model:
+             formatted = []
+             for msg in messages:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     formatted.append({"role": msg["role"], "parts": [content]})
+                 elif isinstance(content, list):
+                     parts = []
+                     for block in content:
+                         if isinstance(block, str):
+                             parts.append(block)
+                         elif isinstance(block, Image.Image):
+                             parts.append(block)
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}})
+                             elif block.get("type") == "image_base64":
+                                 parts.append({"type": "image_base64", "image_base64": {"data": block["image_base64"]["data"], "media_type": block["image_base64"]["media_type"]}})
+                     formatted.append({"role": msg["role"], "parts": parts})
+             return None, formatted
+
+         else: # OpenAI models
+             formatted = []
+             for msg in messages:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     formatted.append({"role": msg["role"], "content": content})
+                 elif isinstance(content, list):
+                     # OpenAI expects 'content' as string; images are not directly supported
+                     # You can convert images to URLs or descriptions if needed
+                     combined_content = ""
+                     for block in content:
+                         if isinstance(block, str):
+                             combined_content += block + "\n"
+                         elif isinstance(block, Image.Image):
+                             # Convert PIL.Image to base64 or upload and use URL
+                             buffered = io.BytesIO()
+                             block.save(buffered, format="PNG")
+                             image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                             combined_content += f"[Image Base64: {image_base64[:30]}...]\n"
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 combined_content += f"[Image: {block['image_url']['url']}]\n"
+                             elif block.get("type") == "image_base64":
+                                 combined_content += f"[Image Base64: {block['image_base64']['data'][:30]}...]\n"
+                     formatted.append({"role": msg["role"], "content": combined_content.strip()})
+             return None, formatted
+
+     system_msg, formatted_messages = format_messages_for_api(model, messages.copy())
+
+     if "claude" in model:
+         # Check for Vertex configuration
+         vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
+         vertex_region = os.getenv('VERTEX_REGION')
+
+         if vertex_project_id and vertex_region:
+             client = AnthropicVertex(
+                 region=vertex_region,
+                 project_id=vertex_project_id
+             )
+         else:
+             client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
+
+         response = client.messages.create(
+             model=model,
+             max_tokens=max_tokens,
+             temperature=temperature,
+             messages=formatted_messages,
+             system=system_msg
+         )
+
+         while response.stop_reason == "max_tokens":
+             if formatted_messages[-1]['role'] == "user":
+                 formatted_messages.append({"role": "assistant", "content": response.content[0].text})
              else:
-                 client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
+                 formatted_messages[-1]['content'] += response.content[0].text
 
-             system_msg = messages.pop(0)["content"] if messages and messages[0]["role"] == "system" else ""
              response = client.messages.create(
                  model=model,
                  max_tokens=max_tokens,
                  temperature=temperature,
-                 messages=messages,
+                 messages=formatted_messages,
                  system=system_msg
              )
+
+         if formatted_messages[-1]['role'] == "assistant" and response.stop_reason == "end_turn":
+             formatted_messages[-1]['content'] += response.content[0].text
+             return formatted_messages[-1]['content']
+
+         return response.content[0].text
+
+     elif "gemini" in model:
+         try:
+             # First try OpenAI-style API
+             client = openai.OpenAI(
+                 api_key=api_key,
+                 base_url="https://generativelanguage.googleapis.com/v1beta/"
+             )
+             # Set response_format based on json_format
+             response_format = {"type": "json_object"} if json_format else {"type": "text"}
+
+             response = client.chat.completions.create(
+                 model=model,
+                 max_tokens=max_tokens,
+                 top_p=top_p,
+                 top_k=top_k,
+                 messages=formatted_messages,
+                 temperature=temperature,
+                 response_format=response_format  # Added response_format
+             )
+             return response.choices[0].message.content
+
+         except Exception as e:
+             # If OpenAI-style API fails, fall back to Google's genai library
+             logger.info("Falling back to Google's genai library")
+             genai.configure(api_key=api_key)
+             system_instruction = ""
+             for msg in messages:
+                 if msg["role"] == "system":
+                     system_instruction = msg["content"]
+                     break
 
-             while response.stop_reason == "max_tokens":
-                 if messages[-1]['role'] == "user":
-                     messages.append({"role": "assistant", "content": response.content[0].text})
-                 else:
-                     messages[-1]['content'] += response.content[0].text
-
-                 response = client.messages.create(
-                     model=model,
-                     max_tokens=max_tokens,
-                     temperature=temperature,
-                     messages=messages,
-                     system=system_msg
-                 )
-
-             if messages[-1]['role'] == "assistant" and response.stop_reason == "end_turn":
-                 messages[-1]['content'] += response.content[0].text
-                 return messages[-1]['content']
-
-             return response.content[0].text
-
-         elif "gemini" in model:
-             try:
-                 # First try OpenAI-style API
-                 client = openai.OpenAI(
-                     api_key=api_key,
-                     base_url="https://generativelanguage.googleapis.com/v1beta/"
-                 )
-                 # Remove any system message from the beginning if present
-                 if messages and messages[0]["role"] == "system":
-                     system_msg = messages.pop(0)
-                     # Prepend system message to first user message if exists
-                     if messages:
-                         messages[0]["content"] = f"{system_msg['content']}\n\n{messages[0]['content']}"
-
-                 response = client.chat.completions.create(
-                     model=model,
-                     messages=messages,
-                     temperature=temperature
-                 )
-
-                 return response.choices[0].message.content
-
-             except Exception as e:
-                 # If OpenAI-style API fails, fall back to Google's genai library
-                 logger.info("Falling back to Google's genai library")
-                 genai.configure(api_key=api_key)
-
-                 # Convert messages to Gemini format
-                 gemini_messages = []
-                 for msg in messages:
-                     if msg["role"] == "system":
-                         # Prepend system message to first user message if exists
-                         if gemini_messages:
-                             gemini_messages[0].parts[0].text = f"{msg['content']}\n\n{gemini_messages[0].parts[0].text}"
-                     else:
-                         gemini_messages.append({"role": msg["role"], "parts": [{"text": msg["content"]}]})
-
-                 # Create Gemini model and generate response
-                 model = genai.GenerativeModel(model_name=model)
-                 response = model.generate_content(
-                     gemini_messages,
-                     generation_config=genai.types.GenerationConfig(
-                         temperature=temperature,
-                         max_output_tokens=max_tokens
-                     )
-                 )
-
-                 return response.text
-
-         elif "grok" in model:
-             # Randomly choose between OpenAI and Anthropic SDK
-             use_anthropic = random.choice([True, False])
-
-             if use_anthropic:
-                 print("using anthropic")
-                 client = anthropic.Anthropic(
-                     api_key=api_key,
-                     base_url="https://api.x.ai"
-                 )
-
-                 system_msg = messages.pop(0)["content"] if messages and messages[0]["role"] == "system" else ""
-                 response = client.messages.create(
-                     model=model,
-                     max_tokens=max_tokens,
-                     temperature=temperature,
-                     messages=messages,
-                     system=system_msg
-                 )
-                 return response.content[0].text
-             else:
-                 print("using openai")
-                 client = openai.OpenAI(
-                     api_key=api_key,
-                     base_url="https://api.x.ai/v1"
-                 )
-                 response = client.chat.completions.create(
-                     model=model,
-                     messages=messages,
-                     max_tokens=max_tokens,
-                     temperature=temperature
-                 )
-                 return response.choices[0].message.content
+             # Convert the remaining messages to Gemini format
+             gemini_messages = format_messages_for_gemini(messages)
+             mime_type = "application/json" if json_format else "text/plain"
+             generation_config = genai.types.GenerationConfig(
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 max_output_tokens=max_tokens,
+                 response_mime_type=mime_type
+             )
+
+             model_instance = genai.GenerativeModel(
+                 model_name=model,
+                 system_instruction=system_instruction,  # system message is passed in here
+                 generation_config=generation_config
+             )
+
+             response = model_instance.generate_content(gemini_messages, generation_config=generation_config)
+
+             return response.text
+
+     elif "grok" in model:
+         # Randomly choose between OpenAI and Anthropic SDK
+         use_anthropic = random.choice([True, False])
+
+         if use_anthropic:
+             logger.info("Using Anthropic for Grok model")
+             client = anthropic.Anthropic(
+                 api_key=api_key,
+                 base_url="https://api.x.ai"
+             )
+
+             system_msg = ""
+             if messages and messages[0]["role"] == "system":
+                 system_msg = messages.pop(0)["content"]
+
+             response = client.messages.create(
+                 model=model,
+                 max_tokens=max_tokens,
+                 temperature=temperature,
+                 messages=formatted_messages,
+                 system=system_msg
+             )
+             return response.content[0].text
+         else:
+             logger.info("Using OpenAI for Grok model")
+             client = openai.OpenAI(
+                 api_key=api_key,
+                 base_url="https://api.x.ai/v1"
+             )
+             # Set response_format based on json_format
+             response_format = {"type": "json_object"} if json_format else {"type": "text"}
 
-         else: # OpenAI models
-             client = openai.OpenAI(api_key=api_key, base_url=base_url)
              response = client.chat.completions.create(
                  model=model,
-                 messages=messages,
+                 messages=formatted_messages,
                  max_tokens=max_tokens,
                  temperature=temperature,
+                 response_format=response_format  # Added response_format
              )
              return response.choices[0].message.content
 
-         # Release the API key after successful use
-         if not api_key:
-             key_manager.release_config(service, api_key)
+     else: # OpenAI models
+         client = openai.OpenAI(api_key=api_key, base_url=base_url)
+         # Set response_format based on json_format
+         response_format = {"type": "json_object"} if json_format else {"type": "text"}
 
-         return response
+         response = client.chat.completions.create(
+             model=model,
+             messages=formatted_messages,
+             max_tokens=max_tokens,
+             temperature=temperature,
+             response_format=response_format  # Added response_format
+         )
+         return response.choices[0].message.content
 
-     except Exception as e:
-         # Report error to key manager
-         if not api_key:
-             key_manager.report_error(service, api_key)
-         raise
+     # Release the API key after successful use
+     if not api_key:
+         # key_manager.release_config(service, api_key)
+         pass
+
+     return response
 
      except Exception as e:
          logger.error(f"Error in completion: {str(e)}")
          raise
 
  class Agent:
-     def __init__(self, model_name: str, messages: Optional[str] = None,
+     def __init__(self, model_name: str, messages: Optional[Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]]] = None,
                   memory_enabled: bool = False, api_key: Optional[str] = None) -> None:
          """Initialize an Agent instance."""
-         # valid_models = ['gpt-3.5-turbo', 'gpt-4', 'claude-2.1', 'gemini-1.5-pro', 'gemini-1.5-flash', 'grok-beta', 'claude-3-5-sonnet-20241022']
-         # if model_name not in valid_models:
-         #     raise ValueError(f"Model {model_name} not supported. Supported models: {valid_models}")
-
          self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
          self.model_name = model_name
-         self.history = ChatHistory(messages)
+         self.history = ChatHistory(messages) if messages else ChatHistory()
          self.memory_enabled = memory_enabled
          self.api_key = api_key
          self.repo_content = []
 
-     def add_message(self, role, content):
-         repo_content = ""
-         while self.repo_content:
-             repo = self.repo_content.pop()
-             repo_content += f"<repo>\n{repo}\n</repo>\n"
-
-         content = repo_content + content
+     def add_message(self, role: str, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+         """Add a message to the conversation."""
          self.history.add_message(content, role)
 
-     def generate_response(self, max_tokens=3585, temperature=0.7):
+     def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+         """Add a user message."""
+         self.history.add_user_message(content)
+
+     def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+         """Add an assistant message."""
+         self.history.add_assistant_message(content)
+
+     def add_image(self, image_path: Optional[str] = None, image_url: Optional[str] = None, media_type: Optional[str] = "image/jpeg"):
+         """
+         Add an image to the conversation.
+         Either image_path or image_url must be provided.
+         """
+         if not image_path and not image_url:
+             raise ValueError("Either image_path or image_url must be provided.")
+
+         if image_path:
+             if not os.path.exists(image_path):
+                 raise FileNotFoundError(f"Image file {image_path} does not exist.")
+             if "gemini" in self.model_name:
+                 # For Gemini, load as PIL.Image
+                 image_pil = Image.open(image_path)
+                 image_block = image_pil
+             else:
+                 # For Claude and others, use base64 encoding
+                 with open(image_path, "rb") as img_file:
+                     image_data = base64.standard_b64encode(img_file.read()).decode("utf-8")
+                 image_block = {
+                     "type": "image_base64",
+                     "image_base64": {
+                         "media_type": media_type,
+                         "data": image_data
+                     }
+                 }
+         else:
+             # If image_url is provided
+             if "gemini" in self.model_name:
+                 # For Gemini, you can pass image URLs directly
+                 image_block = {"type": "image_url", "image_url": {"url": image_url}}
+             else:
+                 # For Claude and others, use image URLs
+                 image_block = {
+                     "type": "image_url",
+                     "image_url": {
+                         "url": image_url
+                     }
+                 }
+
+         # Add the image block to the last user message or as a new user message
+         if self.history.last_role == "user":
+             current_content = self.history.messages[-1]["content"]
+             if isinstance(current_content, list):
+                 current_content.append(image_block)
+             else:
+                 self.history.messages[-1]["content"] = [current_content, image_block]
+         else:
+             # Start a new user message with the image
+             self.history.add_message([image_block], "user")
+
+     def generate_response(self, max_tokens=3585, temperature=0.7, top_p=1.0, top_k=40, json_format: bool = False) -> str:
+         """Generate a response from the agent.
+
+         Args:
+             max_tokens (int, optional): Maximum number of tokens. Defaults to 3585.
+             temperature (float, optional): Sampling temperature. Defaults to 0.7.
+             json_format (bool, optional): Whether to enable JSON output format. Defaults to False.
+
+         Returns:
+             str: The generated response.
+         """
          if not self.history.messages:
              raise ValueError("No messages in history to generate response from")
 
-         messages = [{"role": msg["role"], "content": msg["content"]} for msg in self.history.messages]
-
+         messages = self.history.messages
+         print(self.model_name)
          response_text = completion(
              model=self.model_name,
              messages=messages,
              max_tokens=max_tokens,
              temperature=temperature,
-             api_key=self.api_key
+             top_p=top_p,
+             top_k=top_k,
+             api_key=self.api_key,
+             json_format=json_format  # Pass json_format to completion
          )
-         if messages[-1]["role"] == "assistant":
-             self.history.messages[-1]["content"] = response_text
-
-         elif self.memory_enabled:
-             self.add_message("assistant", response_text)
+         if self.model_name.startswith("openai"):
+             # OpenAI does not support images, so responses are simple strings
+             if self.history.messages[-1]["role"] == "assistant":
+                 self.history.messages[-1]["content"] = response_text
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
+         elif "claude" in self.model_name:
+             if self.history.messages[-1]["role"] == "assistant":
+                 self.history.messages[-1]["content"] = response_text
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
+         elif "gemini" in self.model_name or "grok" in self.model_name:
+             if self.history.messages[-1]["role"] == "assistant":
+                 if isinstance(self.history.messages[-1]["content"], list):
+                     self.history.messages[-1]["content"].append(response_text)
+                 else:
+                     self.history.messages[-1]["content"] = [self.history.messages[-1]["content"], response_text]
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
+         else:
+             # Handle other models similarly
+             if self.history.messages[-1]["role"] == "assistant":
+                 self.history.messages[-1]["content"] = response_text
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
 
          return response_text
 
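
Taken together, the rewritten `completion()` accepts mixed text-and-image content plus new sampling and formatting knobs (`top_p`, `top_k`, `json_format`). A hedged usage sketch, assuming a `GEMINI_API_KEY` environment variable and a local `example.png` (both illustrative; the key_manager placeholder reads `{SERVICE}_API_KEY`/`{SERVICE}_BASE_URL` from the environment):

```python
from PIL import Image
from llm_dialog_manager.agent import completion

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": [
        "What is in this image?",
        Image.open("example.png"),  # hypothetical local file
    ]},
]

reply = completion(
    model="gemini-1.5-flash",
    messages=messages,
    max_tokens=500,
    temperature=0.2,
    json_format=False,  # True requests application/json output
)
print(reply)
```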
@@ -274,11 +483,12 @@ class Agent:
          with open(filename, 'w', encoding='utf-8') as file:
              json.dump(self.history.messages, file, ensure_ascii=False, indent=4)
 
-     def load_conversation(self, filename=None):
+     def load_conversation(self, filename: Optional[str] = None):
          if filename is None:
              filename = f"{self.id}.json"
          with open(filename, 'r', encoding='utf-8') as file:
              messages = json.load(file)
+             # Handle deserialization of images if necessary
          self.history = ChatHistory(messages)
 
      def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = None, repo_name: Optional[str] = None, commit_hash: Optional[str] = None):
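
A short round-trip sketch for the (de)serialization above. The save method's name is not visible in this hunk and is assumed from earlier releases; note that only text content survives `json.dump`, and PIL image blocks would need the custom handling hinted at by the new comment:

```python
from llm_dialog_manager import Agent

agent = Agent("gemini-1.5-flash", "you are a helpful assistant")
agent.add_message("user", "Hello!")
agent.save_conversation()  # assumed name; writes f"{agent.id}.json"
agent.load_conversation()  # reloads the file into a fresh ChatHistory
```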
@@ -305,28 +515,31 @@ class Agent:
              raise ValueError(f"Failed to download repository from {repo_url}")
 
  if __name__ == "__main__":
-
-     # write a test for detect finding agent
-     text = "I think the answer is 42"
-
-     # from agent.messageloader import information_detector_messages
+     # Example Usage
+     # Create an Agent instance (Gemini model)
+     agent = Agent("gemini-1.5-flash", "you are Jack101", memory_enabled=True)
 
-     # # Now you can print or use information_detector_messages as needed
-     # information_detector_agent = Agent("gemini-1.5-pro", information_detector_messages)
-     # information_detector_agent.add_message("user", text)
-     # response = information_detector_agent.generate_response()
-     # print(response)
-     agent = Agent("gemini-1.5-pro-002", "you are an assistant", memory_enabled=True)
+     # Add an image
+     agent.add_image(image_path="/Users/junfan/Projects/Personal/oneapi/dialog_manager/example.png")
 
-     # Format the prompt to check if the section is the last one in the outline
-     prompt = f"Say: {text}\n"
+     # Add a user message
+     agent.add_message("user", "Who are you? What's in this image?")
 
-     # Add the prompt as a message from the user
-     agent.add_message("user", prompt)
-     agent.add_message("assistant", "the answer")
+     # Generate response with JSON format enabled
+     try:
+         response = agent.generate_response(json_format=True)  # json_format set to True
+         print("Response:", response)
+     except Exception as e:
+         logger.error(f"Failed to generate response: {e}")
 
-     print(agent.generate_response())
-     print(agent.history[:])
+     # Print the entire conversation history
+     print("Conversation History:")
+     print(agent.history)
+
+     # Pop the last message
      last_message = agent.history.pop()
-     print(last_message)
-     print(agent.history[:])
+     print("Last Message:", last_message)
+
+     # Generate another response without JSON format
+     response = agent.generate_response()
+     print("Response:", response)

llm_dialog_manager/chat_history.py
@@ -1,53 +1,122 @@
  from typing import List, Dict, Optional, Union
+ from PIL import Image
 
  class ChatHistory:
-     def __init__(self, input_data: Union[str, List[Dict[str, str]]] = "") -> None:
-         self.messages: List[Dict[str, str]] = []
+     def __init__(self, input_data: Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]] = "") -> None:
+         self.messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]] = []
          if isinstance(input_data, str) and input_data:
              self.add_message(input_data, "system")
          elif isinstance(input_data, list):
              self.load_messages(input_data)
-         self.last_role: str = "system" if not self.messages else self.messages[-1]["role"]
+         self.last_role: str = "system" if not self.messages else self.get_last_role()
 
-     def load_messages(self, messages: List[Dict[str, str]]) -> None:
+     def load_messages(self, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]) -> None:
          for message in messages:
              if not ("role" in message and "content" in message):
                  raise ValueError("Each message must have a 'role' and 'content'.")
              if message["role"] not in ["user", "assistant", "system"]:
                  raise ValueError(f"Invalid role: {message['role']}")
              self.messages.append(message)
-
+         self.last_role = self.get_last_role()
+
+     def get_last_role(self):
+         return self.messages[-1]["role"] if self.messages else "system"
+
      def pop(self):
          if not self.messages:
              return None
-
+
          popped_message = self.messages.pop()
-
+
          if self.messages:
-             self.last_role = self.messages[-1]["role"]
+             self.last_role = self.get_last_role()
          else:
              self.last_role = "system"
-
+
          return popped_message["content"]
 
      def __len__(self):
          return len(self.messages)
 
      def __str__(self):
-         return '\n'.join([f"Message {i} ({msg['role']}): {msg['content']}" for i, msg in enumerate(self.messages)])
+         formatted_messages = []
+         for i, msg in enumerate(self.messages):
+             role = msg['role']
+             content = msg['content']
+             if isinstance(content, str):
+                 formatted_content = content
+             elif isinstance(content, list):
+                 parts = []
+                 for block in content:
+                     if isinstance(block, str):
+                         parts.append(block)
+                     elif isinstance(block, Image.Image):
+                         parts.append(f"[Image Object: {block.filename}]")
+                     elif isinstance(block, dict):
+                         if block.get("type") == "image_url":
+                             parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                         elif block.get("type") == "image_base64":
+                             parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                 formatted_content = "\n".join(parts)
+             else:
+                 formatted_content = str(content)
+             formatted_messages.append(f"Message {i} ({role}): {formatted_content}")
+         return '\n'.join(formatted_messages)
 
      def __getitem__(self, key):
          if isinstance(key, slice):
-             # Handle slice object, no change needed here for negative indices as slices are handled by list itself
-             print('\n'.join([f"({msg['role']}): {msg['content']}" for i, msg in enumerate(self.messages[key])]))
-             return self.messages[key]
+             sliced_messages = self.messages[key]
+             formatted = []
+             for msg in sliced_messages:
+                 role = msg['role']
+                 content = msg['content']
+                 if isinstance(content, str):
+                     formatted_content = content
+                 elif isinstance(content, list):
+                     parts = []
+                     for block in content:
+                         if isinstance(block, str):
+                             parts.append(block)
+                         elif isinstance(block, Image.Image):
+                             parts.append(f"[Image Object: {block.filename}]")
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                             elif block.get("type") == "image_base64":
+                                 parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                     formatted_content = "\n".join(parts)
+                 else:
+                     formatted_content = str(content)
+                 formatted.append(f"({role}): {formatted_content}")
+             print('\n'.join(formatted))
+             return sliced_messages
          elif isinstance(key, int):
              # Adjust for negative indices
              if key < 0:
-                 key += len(self.messages) # Convert negative index to positive
+                 key += len(self.messages)
              if 0 <= key < len(self.messages):
+                 msg = self.messages[key]
+                 role = msg['role']
+                 content = msg['content']
+                 if isinstance(content, str):
+                     formatted_content = content
+                 elif isinstance(content, list):
+                     parts = []
+                     for block in content:
+                         if isinstance(block, str):
+                             parts.append(block)
+                         elif isinstance(block, Image.Image):
+                             parts.append(f"[Image Object: {block.filename}]")
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                             elif block.get("type") == "image_base64":
+                                 parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                     formatted_content = "\n".join(parts)
+                 else:
+                     formatted_content = str(content)
                  snippet = self.get_conversation_snippet(key)
-                 print('\n'.join([f"({v['role']}): {self.color_text(v['content'], 'green') if k == 'current' else v['content']}" for k, v in snippet.items() if v]))
+                 print('\n'.join([f"({v['role']}): {v['content']}" for k, v in snippet.items() if v]))
                  return self.messages[key]
              else:
                  raise IndexError("Message index out of range.")
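
The widened content model means a message's `content` may now be a plain string or a list mixing strings, `PIL.Image.Image` objects, and image dicts. A small sketch of how `ChatHistory` renders such a message (the URL is illustrative):

```python
from llm_dialog_manager.chat_history import ChatHistory

history = ChatHistory("you are a helpful assistant")  # seeds a system message
history.add_user_message([
    "Here is the screenshot:",
    {"type": "image_url", "image_url": {"url": "https://example.com/shot.png"}},
])

print(history.get_last_role())  # "user"
print(history)                  # the image block prints as "[Image URL: ...]"
```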
@@ -55,8 +124,8 @@ class ChatHistory:
              raise TypeError("Invalid argument type.")
 
      def __setitem__(self, index, value):
-         if not isinstance(value, str):
-             raise ValueError("Message content must be a string.")
+         if not isinstance(value, (str, list)):
+             raise ValueError("Message content must be a string or a list of content blocks.")
          role = "system" if index % 2 == 0 else "user"
          self.messages[index] = {"role": role, "content": value}
 
@@ -68,19 +137,27 @@ class ChatHistory:
          self.add_message(message, next_role)
 
      def __contains__(self, item):
-         return any(item in message['content'] for message in self.messages)
+         for message in self.messages:
+             content = message['content']
+             if isinstance(content, str) and item in content:
+                 return True
+             elif isinstance(content, list):
+                 for block in content:
+                     if isinstance(block, str) and item in block:
+                         return True
+         return False
 
-     def add_message(self, content, role):
+     def add_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]], role: str):
          self.messages.append({"role": role, "content": content})
          self.last_role = role
 
-     def add_user_message(self, content):
+     def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
          if self.last_role in ["system", "assistant"]:
              self.add_message(content, "user")
          else:
              raise ValueError("A user message must follow a system or assistant message.")
 
-     def add_assistant_message(self, content):
+     def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
          if self.last_role == "user":
              self.add_message(content, "assistant")
          else:
@@ -110,7 +187,17 @@ class ChatHistory:
          print(f"Content of the last message: {status['last_message_content']}")
 
      def search_for_keyword(self, keyword):
-         return [msg for msg in self.messages if keyword.lower() in msg["content"].lower()]
+         results = []
+         for msg in self.messages:
+             content = msg['content']
+             if isinstance(content, str) and keyword.lower() in content.lower():
+                 results.append(msg)
+             elif isinstance(content, list):
+                 for block in content:
+                     if isinstance(block, str) and keyword.lower() in block.lower():
+                         results.append(msg)
+                         break
+         return results
 
      def has_user_or_assistant_spoken_since_last_system(self):
          for msg in reversed(self.messages):
@@ -143,4 +230,4 @@ class ChatHistory:
      @staticmethod
      def color_text(text, color):
          colors = {"green": "\033[92m", "red": "\033[91m", "end": "\033[0m"}
-         return f"{colors[color]}{text}{colors['end']}"
+         return f"{colors.get(color, '')}{text}{colors.get('end', '')}"

llm_dialog_manager-0.4.1.dist-info/METADATA
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: llm_dialog_manager
- Version: 0.3.4
+ Version: 0.4.1
  Summary: A Python package for managing LLM chat conversation history
  Author-email: xihajun <work@2333.fun>
  License: MIT
@@ -64,6 +64,7 @@ A Python package for managing AI chat conversation history with support for mult
  - Memory management options
  - Conversation search and indexing
  - Rich conversation display options
+ - Vision & JSON output enabled [20240111]
 
  ## Installation
 

llm_dialog_manager-0.4.1.dist-info/RECORD
@@ -0,0 +1,9 @@
+ llm_dialog_manager/__init__.py,sha256=hTHvsXzvD5geKgv2XERYcp2f-T3LoVVc3arXfPtNS1k,86
+ llm_dialog_manager/agent.py,sha256=ZKO3eKHTKcbmYpVRRIpzDy7Tlp_VgQ90ewr1758Ozgs,23931
+ llm_dialog_manager/chat_history.py,sha256=DKKRnj_M6h-4JncnH6KekMTghX7vMgdN3J9uOwXKzMU,10347
+ llm_dialog_manager/key_manager.py,sha256=shvxmn4zUtQx_p-x1EFyOmnk-WlhigbpKtxTKve-zXk,4421
+ llm_dialog_manager-0.4.1.dist-info/LICENSE,sha256=vWGbYgGuWpWrXL8-xi6pNcX5UzD6pWoIAZmcetyfbus,1064
+ llm_dialog_manager-0.4.1.dist-info/METADATA,sha256=LER5FN6lFQFPs_8A-fIM7VYmqN-fh0nCD6Dt8vslsiY,4194
+ llm_dialog_manager-0.4.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ llm_dialog_manager-0.4.1.dist-info/top_level.txt,sha256=u2EQEXW0NGAt0AAHT7jx1odXZ4rZfjcgbmJhvKFuMkI,19
+ llm_dialog_manager-0.4.1.dist-info/RECORD,,

llm_dialog_manager-0.4.1.dist-info/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.6.0)
+ Generator: setuptools (75.8.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 

llm_dialog_manager-0.3.4.dist-info/RECORD
@@ -1,9 +0,0 @@
- llm_dialog_manager/__init__.py,sha256=L-vxURJQtPwajtaFFl8gFI-xAsggFXNNY3aOqlVIViU,86
- llm_dialog_manager/agent.py,sha256=ruZWlSZZ6jBr8x8I-lWdsjiSP44BTYTR9fBS0FuE-Rc,13310
- llm_dialog_manager/chat_history.py,sha256=xKA-oQCv8jv_g8EhXrG9h1S8Icbj2FfqPIhbty5vra4,6033
- llm_dialog_manager/key_manager.py,sha256=shvxmn4zUtQx_p-x1EFyOmnk-WlhigbpKtxTKve-zXk,4421
- llm_dialog_manager-0.3.4.dist-info/LICENSE,sha256=vWGbYgGuWpWrXL8-xi6pNcX5UzD6pWoIAZmcetyfbus,1064
- llm_dialog_manager-0.3.4.dist-info/METADATA,sha256=6XPZd3oGCVC8YS-EVHII1EFFsnml1kS2VgRYoVsNaIQ,4152
- llm_dialog_manager-0.3.4.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- llm_dialog_manager-0.3.4.dist-info/top_level.txt,sha256=u2EQEXW0NGAt0AAHT7jx1odXZ4rZfjcgbmJhvKFuMkI,19
- llm_dialog_manager-0.3.4.dist-info/RECORD,,