llm-dialog-manager 0.3.2__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
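
The headline changes in 0.3.5 are multimodal message support (message content may now be a list mixing strings, PIL images, and image dicts), a new Agent.add_image() helper, and a json_format flag on generate_response()/completion(). A minimal usage sketch, assembled from the package's own __main__ example shown in the diff below (the model name and image path are illustrative):

    from llm_dialog_manager import Agent

    agent = Agent("gemini-1.5-flash", "you are an assistant", memory_enabled=True)
    agent.add_image(image_path="example.png")           # attach an image block to the history
    agent.add_message("user", "What's in this image?")  # plain-text user turn
    response = agent.generate_response(json_format=True)  # request JSON-formatted output
    print(response)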

llm_dialog_manager/__init__.py
@@ -1,4 +1,4 @@
  from .chat_history import ChatHistory
  from .agent import Agent
 
- __version__ = "0.3.2"
+ __version__ = "0.3.5"

llm_dialog_manager/agent.py
@@ -2,13 +2,15 @@
  import json
  import os
  import uuid
- from typing import List, Dict, Optional
+ from typing import List, Dict, Optional, Union
  import logging
  from pathlib import Path
  import random
  import requests
  import zipfile
  import io
+ import base64
+ from PIL import Image
 
  # Third-party imports
  import anthropic
@@ -18,8 +20,8 @@ import openai
  from dotenv import load_dotenv
 
  # Local imports
- from llm_dialog_manager.chat_history import ChatHistory
- from llm_dialog_manager.key_manager import key_manager
+ from .chat_history import ChatHistory
+ from .key_manager import key_manager
 
  # Set up logging
  logging.basicConfig(level=logging.INFO)
@@ -28,7 +30,7 @@ logger = logging.getLogger(__name__)
  # Load environment variables
  def load_env_vars():
      """Load environment variables from .env file"""
-     env_path = Path(__file__).parent.parent / '.env'
+     env_path = Path(__file__).parent / '.env'
      if env_path.exists():
          load_dotenv(env_path)
      else:
@@ -36,24 +38,9 @@ def load_env_vars():
 
  load_env_vars()
 
- def create_and_send_message(client, model, max_tokens, temperature, messages, system_msg):
-     """Function to send a message to the Anthropic API and handle the response."""
-     try:
-         response = client.messages.create(
-             model=model,
-             max_tokens=max_tokens,
-             temperature=temperature,
-             messages=messages,
-             system=system_msg
-         )
-         return response
-     except Exception as e:
-         logger.error(f"Error sending message: {str(e)}")
-         raise
-
- def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 1000,
+ def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]], max_tokens: int = 1000,
                temperature: float = 0.5, api_key: Optional[str] = None,
-               base_url: Optional[str] = None) -> str:
+               base_url: Optional[str] = None, json_format: bool = False) -> str:
      """
      Generate a completion using the specified model and messages.
      """
@@ -70,202 +57,393 @@ def completion(model: str, messages: List[Dict[str, str]], max_tokens: int = 100
 
      # Get API key and base URL from key manager if not provided
      if not api_key:
-         api_key, base_url = key_manager.get_config(service)
+         # api_key, base_url = key_manager.get_config(service)
+         # Placeholder for key_manager
+         api_key = os.getenv(f"{service.upper()}_API_KEY")
+         base_url = os.getenv(f"{service.upper()}_BASE_URL")
 
-     try:
+     def format_messages_for_api(model, messages):
+         """Convert ChatHistory messages to the format required by the specific API."""
          if "claude" in model:
-             # Check for Vertex configuration
-             vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
-             vertex_region = os.getenv('VERTEX_REGION')
-
-             if vertex_project_id and vertex_region:
-                 client = AnthropicVertex(
-                     region=vertex_region,
-                     project_id=vertex_project_id
-                 )
+             formatted = []
+             system_msg = ""
+             if messages and messages[0]["role"] == "system":
+                 system_msg = messages.pop(0)["content"]
+             for msg in messages:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     formatted.append({"role": msg["role"], "content": content})
+                 elif isinstance(content, list):
+                     # Combine content blocks into a single message
+                     combined_content = []
+                     for block in content:
+                         if isinstance(block, str):
+                             combined_content.append({"type": "text", "text": block})
+                         elif isinstance(block, Image.Image):
+                             # For Claude, convert PIL.Image to base64
+                             buffered = io.BytesIO()
+                             block.save(buffered, format="PNG")
+                             image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                             combined_content.append({
+                                 "type": "image",
+                                 "source": {
+                                     "type": "base64",
+                                     "media_type": "image/png",
+                                     "data": image_base64
+                                 }
+                             })
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 combined_content.append({
+                                     "type": "image",
+                                     "source": {
+                                         "type": "url",
+                                         "url": block["image_url"]["url"]
+                                     }
+                                 })
+                             elif block.get("type") == "image_base64":
+                                 combined_content.append({
+                                     "type": "image",
+                                     "source": {
+                                         "type": "base64",
+                                         "media_type": block["image_base64"]["media_type"],
+                                         "data": block["image_base64"]["data"]
+                                     }
+                                 })
+                     formatted.append({"role": msg["role"], "content": combined_content})
+             return system_msg, formatted
+
+         elif "gemini" in model or "gpt" in model or "grok" in model:
+             formatted = []
+             for msg in messages:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     formatted.append({"role": msg["role"], "parts": [content]})
+                 elif isinstance(content, list):
+                     parts = []
+                     for block in content:
+                         if isinstance(block, str):
+                             parts.append(block)
+                         elif isinstance(block, Image.Image):
+                             parts.append(block)
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}})
+                             elif block.get("type") == "image_base64":
+                                 parts.append({"type": "image_base64", "image_base64": {"data": block["image_base64"]["data"], "media_type": block["image_base64"]["media_type"]}})
+                     formatted.append({"role": msg["role"], "parts": parts})
+             return None, formatted
+
+         else:  # OpenAI models
+             formatted = []
+             for msg in messages:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     formatted.append({"role": msg["role"], "content": content})
+                 elif isinstance(content, list):
+                     # OpenAI expects 'content' as string; images are not directly supported
+                     # You can convert images to URLs or descriptions if needed
+                     combined_content = ""
+                     for block in content:
+                         if isinstance(block, str):
+                             combined_content += block + "\n"
+                         elif isinstance(block, Image.Image):
+                             # Convert PIL.Image to base64 or upload and use URL
+                             buffered = io.BytesIO()
+                             block.save(buffered, format="PNG")
+                             image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                             combined_content += f"[Image Base64: {image_base64[:30]}...]\n"
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 combined_content += f"[Image: {block['image_url']['url']}]\n"
+                             elif block.get("type") == "image_base64":
+                                 combined_content += f"[Image Base64: {block['image_base64']['data'][:30]}...]\n"
+                     formatted.append({"role": msg["role"], "content": combined_content.strip()})
+             return None, formatted
+
+     system_msg, formatted_messages = format_messages_for_api(model, messages.copy())
+
+     if "claude" in model:
+         # Check for Vertex configuration
+         vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
+         vertex_region = os.getenv('VERTEX_REGION')
+
+         if vertex_project_id and vertex_region:
+             client = AnthropicVertex(
+                 region=vertex_region,
+                 project_id=vertex_project_id
+             )
+         else:
+             client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
+
+         response = client.messages.create(
+             model=model,
+             max_tokens=max_tokens,
+             temperature=temperature,
+             messages=formatted_messages,
+             system=system_msg
+         )
+
+         while response.stop_reason == "max_tokens":
+             if formatted_messages[-1]['role'] == "user":
+                 formatted_messages.append({"role": "assistant", "content": response.completion})
              else:
-                 client = anthropic.Anthropic(api_key=api_key, base_url=base_url)
+                 formatted_messages[-1]['content'] += response.completion
 
-         system_msg = messages.pop(0)["content"] if messages and messages[0]["role"] == "system" else ""
              response = client.messages.create(
                  model=model,
                  max_tokens=max_tokens,
                  temperature=temperature,
-                 messages=messages,
+                 messages=formatted_messages,
                  system=system_msg
              )
-
-         while response.stop_reason == "max_tokens":
-             if messages[-1]['role'] == "user":
-                 messages.append({"role": "assistant", "content": response.content[0].text})
-             else:
-                 messages[-1]['content'] += response.content[0].text
 
-             response = client.messages.create(
-                 model=model,
-                 max_tokens=max_tokens,
-                 temperature=temperature,
-                 messages=messages,
-                 system=system_msg
-             )
+         if formatted_messages[-1]['role'] == "assistant" and response.stop_reason == "end_turn":
+             formatted_messages[-1]['content'] += response.completion
+             return formatted_messages[-1]['content']
 
-         if messages[-1]['role'] == "assistant" and response.stop_reason == "end_turn":
-             messages[-1]['content'] += response.content[0].text
-             return messages[-1]['content']
-
-         return response.content[0].text
-
-     elif "gemini" in model:
-         try:
-             # First try OpenAI-style API
-             client = openai.OpenAI(
-                 api_key=api_key,
-                 base_url="https://generativelanguage.googleapis.com/v1beta/"
-             )
-             # Remove any system message from the beginning if present
-             if messages and messages[0]["role"] == "system":
-                 system_msg = messages.pop(0)
+         return response.completion
+
+     elif "gemini" in model:
+         try:
+             # First try OpenAI-style API
+             client = openai.OpenAI(
+                 api_key=api_key,
+                 base_url="https://generativelanguage.googleapis.com/v1beta/"
+             )
+             # Set response_format based on json_format
+             response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}
+
+             response = client.chat.completions.create(
+                 model=model,
+                 messages=formatted_messages,
+                 temperature=temperature,
+                 response_format=response_format  # Added response_format
+             )
+             return response.choices[0].message.content
+
+         except Exception as e:
+             # If OpenAI-style API fails, fall back to Google's genai library
+             logger.info("Falling back to Google's genai library")
+             genai.configure(api_key=api_key)
+
+             # Convert messages to Gemini format
+             gemini_messages = []
+             for msg in messages:
+                 if msg["role"] == "system":
                      # Prepend system message to first user message if exists
-                 if messages:
-                     messages[0]["content"] = f"{system_msg['content']}\n\n{messages[0]['content']}"
-
-             response = client.chat.completions.create(
-                 model=model,
-                 messages=messages,
-                 temperature=temperature
-             )
-
-             return response.choices[0].message.content
-
-         except Exception as e:
-             # If OpenAI-style API fails, fall back to Google's genai library
-             logger.info("Falling back to Google's genai library")
-             genai.configure(api_key=api_key)
-
-             # Convert messages to Gemini format
-             gemini_messages = []
-             for msg in messages:
-                 if msg["role"] == "system":
-                     # Prepend system message to first user message if exists
-                     if gemini_messages:
-                         gemini_messages[0].parts[0].text = f"{msg['content']}\n\n{gemini_messages[0].parts[0].text}"
-                     else:
-                         gemini_messages.append({"role": msg["role"], "parts": [{"text": msg["content"]}]})
-
-             # Create Gemini model and generate response
-             model = genai.GenerativeModel(model_name=model)
-             response = model.generate_content(
-                 gemini_messages,
-                 generation_config=genai.types.GenerationConfig(
-                     temperature=temperature,
-                     max_output_tokens=max_tokens
-                 )
-             )
-
-             return response.text
-
-     elif "grok" in model:
-         # Randomly choose between OpenAI and Anthropic SDK
-         use_anthropic = random.choice([True, False])
-
-         if use_anthropic:
-             print("using anthropic")
-             client = anthropic.Anthropic(
-                 api_key=api_key,
-                 base_url="https://api.x.ai"
-             )
-
-             system_msg = messages.pop(0)["content"] if messages and messages[0]["role"] == "system" else ""
-             response = client.messages.create(
-                 model=model,
-                 max_tokens=max_tokens,
+                     if gemini_messages:
+                         first_msg = gemini_messages[0]
+                         if "parts" in first_msg and len(first_msg["parts"]) > 0:
+                             first_msg["parts"][0] = f"{msg['content']}\n\n{first_msg['parts'][0]}"
+                 else:
+                     gemini_messages.append({"role": msg["role"], "parts": msg["content"]})
+
+             # Set response_mime_type based on json_format
+             mime_type = "application/json" if json_format else "text/plain"
+
+             # Create Gemini model and generate response
+             model_instance = genai.GenerativeModel(model_name=model)
+             response = model_instance.generate_content(
+                 gemini_messages,
+                 generation_config=genai.types.GenerationConfig(
                      temperature=temperature,
-                 messages=messages,
-                 system=system_msg
-             )
-             return response.content[0].text
-         else:
-             print("using openai")
-             client = openai.OpenAI(
-                 api_key=api_key,
-                 base_url="https://api.x.ai/v1"
-             )
-             response = client.chat.completions.create(
-                 model=model,
-                 messages=messages,
-                 max_tokens=max_tokens,
-                 temperature=temperature
+                     response_mime_type=mime_type,  # Modified based on json_format
+                     max_output_tokens=max_tokens
                  )
-             return response.choices[0].message.content
+             )
+
+             return response.text
+
+     elif "grok" in model:
+         # Randomly choose between OpenAI and Anthropic SDK
+         use_anthropic = random.choice([True, False])
+
+         if use_anthropic:
+             logger.info("Using Anthropic for Grok model")
+             client = anthropic.Anthropic(
+                 api_key=api_key,
+                 base_url="https://api.x.ai"
+             )
+
+             system_msg = ""
+             if messages and messages[0]["role"] == "system":
+                 system_msg = messages.pop(0)["content"]
+
+             response = client.messages.create(
+                 model=model,
+                 max_tokens=max_tokens,
+                 temperature=temperature,
+                 messages=formatted_messages,
+                 system=system_msg
+             )
+             return response.completion
+         else:
+             logger.info("Using OpenAI for Grok model")
+             client = openai.OpenAI(
+                 api_key=api_key,
+                 base_url="https://api.x.ai/v1"
+             )
+             # Set response_format based on json_format
+             response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}
 
-     else:  # OpenAI models
-         client = openai.OpenAI(api_key=api_key, base_url=base_url)
          response = client.chat.completions.create(
              model=model,
-             messages=messages,
+             messages=formatted_messages,
              max_tokens=max_tokens,
              temperature=temperature,
+             response_format=response_format  # Added response_format
          )
          return response.choices[0].message.content
 
-     # Release the API key after successful use
-     if not api_key:
-         key_manager.release_config(service, api_key)
+     else:  # OpenAI models
+         client = openai.OpenAI(api_key=api_key, base_url=base_url)
+         # Set response_format based on json_format
+         response_format = {"type": "json_object"} if json_format else {"type": "plain_text"}
 
-     return response
+         response = client.chat.completions.create(
+             model=model,
+             messages=formatted_messages,
+             max_tokens=max_tokens,
+             temperature=temperature,
+             response_format=response_format  # Added response_format
+         )
+         return response.choices[0].message.content
 
-     except Exception as e:
-         # Report error to key manager
-         if not api_key:
-             key_manager.report_error(service, api_key)
-         raise
+     # Release the API key after successful use
+     if not api_key:
+         # key_manager.release_config(service, api_key)
+         pass
+
+     return response
 
      except Exception as e:
          logger.error(f"Error in completion: {str(e)}")
          raise
 
  class Agent:
-     def __init__(self, model_name: str, messages: Optional[str] = None,
+     def __init__(self, model_name: str, messages: Optional[Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]]] = None,
                   memory_enabled: bool = False, api_key: Optional[str] = None) -> None:
          """Initialize an Agent instance."""
-         # valid_models = ['gpt-3.5-turbo', 'gpt-4', 'claude-2.1', 'gemini-1.5-pro', 'gemini-1.5-flash', 'grok-beta', 'claude-3-5-sonnet-20241022']
-         # if model_name not in valid_models:
-         #     raise ValueError(f"Model {model_name} not supported. Supported models: {valid_models}")
-
          self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
          self.model_name = model_name
-         self.history = ChatHistory(messages)
+         self.history = ChatHistory(messages) if messages else ChatHistory()
          self.memory_enabled = memory_enabled
          self.api_key = api_key
          self.repo_content = []
 
-     def add_message(self, role, content):
-         repo_content = ""
-         while self.repo_content:
-             repo = self.repo_content.pop()
-             repo_content += f"<repo>\n{repo}\n</repo>\n"
-
-         content = repo_content + content
+     def add_message(self, role: str, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+         """Add a message to the conversation."""
          self.history.add_message(content, role)
 
-     def generate_response(self, max_tokens=3585, temperature=0.7):
+     def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+         """Add a user message."""
+         self.history.add_user_message(content)
+
+     def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
+         """Add an assistant message."""
+         self.history.add_assistant_message(content)
+
+     def add_image(self, image_path: Optional[str] = None, image_url: Optional[str] = None, media_type: Optional[str] = "image/jpeg"):
+         """
+         Add an image to the conversation.
+         Either image_path or image_url must be provided.
+         """
+         if not image_path and not image_url:
+             raise ValueError("Either image_path or image_url must be provided.")
+
+         if image_path:
+             if not os.path.exists(image_path):
+                 raise FileNotFoundError(f"Image file {image_path} does not exist.")
+             if "gemini" in self.model_name:
+                 # For Gemini, load as PIL.Image
+                 image_pil = Image.open(image_path)
+                 image_block = image_pil
+             else:
+                 # For Claude and others, use base64 encoding
+                 with open(image_path, "rb") as img_file:
+                     image_data = base64.standard_b64encode(img_file.read()).decode("utf-8")
+                 image_block = {
+                     "type": "image_base64",
+                     "image_base64": {
+                         "media_type": media_type,
+                         "data": image_data
+                     }
+                 }
+         else:
+             # If image_url is provided
+             if "gemini" in self.model_name:
+                 # For Gemini, you can pass image URLs directly
+                 image_block = {"type": "image_url", "image_url": {"url": image_url}}
+             else:
+                 # For Claude and others, use image URLs
+                 image_block = {
+                     "type": "image_url",
+                     "image_url": {
+                         "url": image_url
+                     }
+                 }
+
+         # Add the image block to the last user message or as a new user message
+         if self.history.last_role == "user":
+             current_content = self.history.messages[-1]["content"]
+             if isinstance(current_content, list):
+                 current_content.append(image_block)
+             else:
+                 self.history.messages[-1]["content"] = [current_content, image_block]
+         else:
+             # Start a new user message with the image
+             self.history.add_message([image_block], "user")
+
+     def generate_response(self, max_tokens=3585, temperature=0.7, json_format: bool = False) -> str:
+         """Generate a response from the agent.
+
+         Args:
+             max_tokens (int, optional): Maximum number of tokens. Defaults to 3585.
+             temperature (float, optional): Sampling temperature. Defaults to 0.7.
+             json_format (bool, optional): Whether to enable JSON output format. Defaults to False.
+
+         Returns:
+             str: The generated response.
+         """
          if not self.history.messages:
              raise ValueError("No messages in history to generate response from")
 
-         messages = [{"role": msg["role"], "content": msg["content"]} for msg in self.history.messages]
-
+         messages = self.history.messages
+
          response_text = completion(
              model=self.model_name,
              messages=messages,
              max_tokens=max_tokens,
              temperature=temperature,
-             api_key=self.api_key
+             api_key=self.api_key,
+             json_format=json_format  # Pass json_format to completion
          )
-         if messages[-1]["role"] == "assistant":
-             self.history.messages[-1]["content"] = response_text
-
-         elif self.memory_enabled:
-             self.add_message("assistant", response_text)
+         if self.model_name.startswith("openai"):
+             # OpenAI does not support images, so responses are simple strings
+             if self.history.messages[-1]["role"] == "assistant":
+                 self.history.messages[-1]["content"] = response_text
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
+         elif "claude" in self.model_name:
+             if self.history.messages[-1]["role"] == "assistant":
+                 self.history.messages[-1]["content"] = response_text
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
+         elif "gemini" in self.model_name or "grok" in self.model_name:
+             if self.history.messages[-1]["role"] == "assistant":
+                 if isinstance(self.history.messages[-1]["content"], list):
+                     self.history.messages[-1]["content"].append(response_text)
+                 else:
+                     self.history.messages[-1]["content"] = [self.history.messages[-1]["content"], response_text]
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
+         else:
+             # Handle other models similarly
+             if self.history.messages[-1]["role"] == "assistant":
+                 self.history.messages[-1]["content"] = response_text
+             elif self.memory_enabled:
+                 self.add_message("assistant", response_text)
 
          return response_text
 
@@ -274,20 +452,20 @@ class Agent:
          with open(filename, 'w', encoding='utf-8') as file:
              json.dump(self.history.messages, file, ensure_ascii=False, indent=4)
 
-     def load_conversation(self, filename=None):
+     def load_conversation(self, filename: Optional[str] = None):
          if filename is None:
              filename = f"{self.id}.json"
          with open(filename, 'r', encoding='utf-8') as file:
              messages = json.load(file)
+         # Handle deserialization of images if necessary
          self.history = ChatHistory(messages)
 
-     def load_repo(self, repo_name: Optional[str] = None, commit_hash: Optional[str] = None):
-         """eg: repo_name: xihajun/llm_dialog_manager"""
-         if repo_name:
+     def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = None, repo_name: Optional[str] = None, commit_hash: Optional[str] = None):
+         if username and repo_name:
              if commit_hash:
-                 repo_url = f"https://github.com/{repo_name}/archive/{commit_hash}.zip"
+                 repo_url = f"https://github.com/{username}/{repo_name}/archive/{commit_hash}.zip"
              else:
-                 repo_url = f"https://github.com/{repo_name}/archive/refs/heads/main.zip"
+                 repo_url = f"https://github.com/{username}/{repo_name}/archive/refs/heads/main.zip"
 
          if not repo_url:
              raise ValueError("Either repo_url or both username and repo_name must be provided")
@@ -306,28 +484,31 @@ class Agent:
              raise ValueError(f"Failed to download repository from {repo_url}")
 
  if __name__ == "__main__":
-
-     # write a test for detect finding agent
-     text = "I think the answer is 42"
-
-     # from agent.messageloader import information_detector_messages
+     # Example Usage
+     # Create an Agent instance (Gemini model)
+     agent = Agent("gemini-1.5-flash", "you are an assistant", memory_enabled=True)
 
-     # # Now you can print or use information_detector_messages as needed
-     # information_detector_agent = Agent("gemini-1.5-pro", information_detector_messages)
-     # information_detector_agent.add_message("user", text)
-     # response = information_detector_agent.generate_response()
-     # print(response)
-     agent = Agent("gemini-1.5-pro-002", "you are an assistant", memory_enabled=True)
+     # Add an image
+     agent.add_image(image_path="/Users/junfan/Projects/Personal/oneapi/dialog_manager/example.png")
 
-     # Format the prompt to check if the section is the last one in the outline
-     prompt = f"Say: {text}\n"
+     # Add a user message
+     agent.add_message("user", "What's in this image?")
 
-     # Add the prompt as a message from the user
-     agent.add_message("user", prompt)
-     agent.add_message("assistant", "the answer")
+     # Generate response with JSON format enabled
+     try:
+         response = agent.generate_response(json_format=True)  # json_format set to True
+         print("Response:", response)
+     except Exception as e:
+         logger.error(f"Failed to generate response: {e}")
 
-     print(agent.generate_response())
-     print(agent.history[:])
+     # Print the entire conversation history
+     print("Conversation History:")
+     print(agent.history)
+
+     # Pop the last message
      last_message = agent.history.pop()
-     print(last_message)
-     print(agent.history[:])
+     print("Last Message:", last_message)
+
+     # Generate another response without JSON format
+     response = agent.generate_response()
+     print("Response:", response)

llm_dialog_manager/chat_history.py
@@ -1,53 +1,122 @@
  from typing import List, Dict, Optional, Union
+ from PIL import Image
 
  class ChatHistory:
-     def __init__(self, input_data: Union[str, List[Dict[str, str]]] = "") -> None:
-         self.messages: List[Dict[str, str]] = []
+     def __init__(self, input_data: Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]] = "") -> None:
+         self.messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]] = []
          if isinstance(input_data, str) and input_data:
              self.add_message(input_data, "system")
          elif isinstance(input_data, list):
             self.load_messages(input_data)
-         self.last_role: str = "system" if not self.messages else self.messages[-1]["role"]
+         self.last_role: str = "system" if not self.messages else self.get_last_role()
 
-     def load_messages(self, messages: List[Dict[str, str]]) -> None:
+     def load_messages(self, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]) -> None:
          for message in messages:
              if not ("role" in message and "content" in message):
                  raise ValueError("Each message must have a 'role' and 'content'.")
              if message["role"] not in ["user", "assistant", "system"]:
                  raise ValueError(f"Invalid role: {message['role']}")
              self.messages.append(message)
-
+         self.last_role = self.get_last_role()
+
+     def get_last_role(self):
+         return self.messages[-1]["role"] if self.messages else "system"
+
      def pop(self):
          if not self.messages:
              return None
-
+
          popped_message = self.messages.pop()
-
+
          if self.messages:
-             self.last_role = self.messages[-1]["role"]
+             self.last_role = self.get_last_role()
          else:
              self.last_role = "system"
-
+
          return popped_message["content"]
 
      def __len__(self):
          return len(self.messages)
 
      def __str__(self):
-         return '\n'.join([f"Message {i} ({msg['role']}): {msg['content']}" for i, msg in enumerate(self.messages)])
+         formatted_messages = []
+         for i, msg in enumerate(self.messages):
+             role = msg['role']
+             content = msg['content']
+             if isinstance(content, str):
+                 formatted_content = content
+             elif isinstance(content, list):
+                 parts = []
+                 for block in content:
+                     if isinstance(block, str):
+                         parts.append(block)
+                     elif isinstance(block, Image.Image):
+                         parts.append(f"[Image Object: {block.filename}]")
+                     elif isinstance(block, dict):
+                         if block.get("type") == "image_url":
+                             parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                         elif block.get("type") == "image_base64":
+                             parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                 formatted_content = "\n".join(parts)
+             else:
+                 formatted_content = str(content)
+             formatted_messages.append(f"Message {i} ({role}): {formatted_content}")
+         return '\n'.join(formatted_messages)
 
      def __getitem__(self, key):
          if isinstance(key, slice):
-             # Handle slice object, no change needed here for negative indices as slices are handled by list itself
-             print('\n'.join([f"({msg['role']}): {msg['content']}" for i, msg in enumerate(self.messages[key])]))
-             return self.messages[key]
+             sliced_messages = self.messages[key]
+             formatted = []
+             for msg in sliced_messages:
+                 role = msg['role']
+                 content = msg['content']
+                 if isinstance(content, str):
+                     formatted_content = content
+                 elif isinstance(content, list):
+                     parts = []
+                     for block in content:
+                         if isinstance(block, str):
+                             parts.append(block)
+                         elif isinstance(block, Image.Image):
+                             parts.append(f"[Image Object: {block.filename}]")
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                             elif block.get("type") == "image_base64":
+                                 parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                     formatted_content = "\n".join(parts)
+                 else:
+                     formatted_content = str(content)
+                 formatted.append(f"({role}): {formatted_content}")
+             print('\n'.join(formatted))
+             return sliced_messages
          elif isinstance(key, int):
              # Adjust for negative indices
              if key < 0:
-                 key += len(self.messages)  # Convert negative index to positive
+                 key += len(self.messages)
              if 0 <= key < len(self.messages):
+                 msg = self.messages[key]
+                 role = msg['role']
+                 content = msg['content']
+                 if isinstance(content, str):
+                     formatted_content = content
+                 elif isinstance(content, list):
+                     parts = []
+                     for block in content:
+                         if isinstance(block, str):
+                             parts.append(block)
+                         elif isinstance(block, Image.Image):
+                             parts.append(f"[Image Object: {block.filename}]")
+                         elif isinstance(block, dict):
+                             if block.get("type") == "image_url":
+                                 parts.append(f"[Image URL: {block.get('image_url', {}).get('url', '')}]")
+                             elif block.get("type") == "image_base64":
+                                 parts.append(f"[Image Base64: {block.get('image_base64', {}).get('data', '')[:30]}...]")
+                     formatted_content = "\n".join(parts)
+                 else:
+                     formatted_content = str(content)
                  snippet = self.get_conversation_snippet(key)
-                 print('\n'.join([f"({v['role']}): {self.color_text(v['content'], 'green') if k == 'current' else v['content']}" for k, v in snippet.items() if v]))
+                 print('\n'.join([f"({v['role']}): {v['content']}" for k, v in snippet.items() if v]))
                  return self.messages[key]
              else:
                  raise IndexError("Message index out of range.")
@@ -55,8 +124,8 @@ class ChatHistory:
              raise TypeError("Invalid argument type.")
 
      def __setitem__(self, index, value):
-         if not isinstance(value, str):
-             raise ValueError("Message content must be a string.")
+         if not isinstance(value, (str, list)):
+             raise ValueError("Message content must be a string or a list of content blocks.")
          role = "system" if index % 2 == 0 else "user"
          self.messages[index] = {"role": role, "content": value}
 
@@ -68,19 +137,27 @@ class ChatHistory:
          self.add_message(message, next_role)
 
      def __contains__(self, item):
-         return any(item in message['content'] for message in self.messages)
+         for message in self.messages:
+             content = message['content']
+             if isinstance(content, str) and item in content:
+                 return True
+             elif isinstance(content, list):
+                 for block in content:
+                     if isinstance(block, str) and item in block:
+                         return True
+         return False
 
-     def add_message(self, content, role):
+     def add_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]], role: str):
          self.messages.append({"role": role, "content": content})
          self.last_role = role
 
-     def add_user_message(self, content):
+     def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
          if self.last_role in ["system", "assistant"]:
              self.add_message(content, "user")
          else:
              raise ValueError("A user message must follow a system or assistant message.")
 
-     def add_assistant_message(self, content):
+     def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
          if self.last_role == "user":
              self.add_message(content, "assistant")
          else:
@@ -110,7 +187,17 @@ class ChatHistory:
          print(f"Content of the last message: {status['last_message_content']}")
 
      def search_for_keyword(self, keyword):
-         return [msg for msg in self.messages if keyword.lower() in msg["content"].lower()]
+         results = []
+         for msg in self.messages:
+             content = msg['content']
+             if isinstance(content, str) and keyword.lower() in content.lower():
+                 results.append(msg)
+             elif isinstance(content, list):
+                 for block in content:
+                     if isinstance(block, str) and keyword.lower() in block.lower():
+                         results.append(msg)
+                         break
+         return results
 
      def has_user_or_assistant_spoken_since_last_system(self):
          for msg in reversed(self.messages):
@@ -143,4 +230,4 @@ class ChatHistory:
      @staticmethod
      def color_text(text, color):
          colors = {"green": "\033[92m", "red": "\033[91m", "end": "\033[0m"}
-         return f"{colors[color]}{text}{colors['end']}"
+         return f"{colors.get(color, '')}{text}{colors.get('end', '')}"

llm_dialog_manager-0.3.5.dist-info/METADATA
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: llm_dialog_manager
- Version: 0.3.2
+ Version: 0.3.5
  Summary: A Python package for managing LLM chat conversation history
  Author-email: xihajun <work@2333.fun>
  License: MIT

llm_dialog_manager-0.3.5.dist-info/RECORD (added)
@@ -0,0 +1,9 @@
+ llm_dialog_manager/__init__.py,sha256=J7L76hDTCNM56mprhbMclqCG04IacKiIgaHm8Ty7shQ,86
+ llm_dialog_manager/agent.py,sha256=aMeSL7rV7sSGgJzCkXp_ahiq569eTy-9Jfepam4pKUU,23064
+ llm_dialog_manager/chat_history.py,sha256=DKKRnj_M6h-4JncnH6KekMTghX7vMgdN3J9uOwXKzMU,10347
+ llm_dialog_manager/key_manager.py,sha256=shvxmn4zUtQx_p-x1EFyOmnk-WlhigbpKtxTKve-zXk,4421
+ llm_dialog_manager-0.3.5.dist-info/LICENSE,sha256=vWGbYgGuWpWrXL8-xi6pNcX5UzD6pWoIAZmcetyfbus,1064
+ llm_dialog_manager-0.3.5.dist-info/METADATA,sha256=y9rfQ9rcmwrScQJYK-0PjRPCKFWqAKFsm_8NoSwiloI,4152
+ llm_dialog_manager-0.3.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ llm_dialog_manager-0.3.5.dist-info/top_level.txt,sha256=u2EQEXW0NGAt0AAHT7jx1odXZ4rZfjcgbmJhvKFuMkI,19
+ llm_dialog_manager-0.3.5.dist-info/RECORD,,

llm_dialog_manager-0.3.5.dist-info/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.6.0)
+ Generator: setuptools (75.8.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 

llm_dialog_manager-0.3.2.dist-info/RECORD (removed)
@@ -1,9 +0,0 @@
- llm_dialog_manager/__init__.py,sha256=u9m9_3UUibXqNSRzkf63lAbkw_l5RmGcxlVOhmeShlg,86
- llm_dialog_manager/agent.py,sha256=o0ayPp9eS_8Vx78GYMPlIxHrK1UHTbu_I3dEsk4Agdo,13268
- llm_dialog_manager/chat_history.py,sha256=xKA-oQCv8jv_g8EhXrG9h1S8Icbj2FfqPIhbty5vra4,6033
- llm_dialog_manager/key_manager.py,sha256=shvxmn4zUtQx_p-x1EFyOmnk-WlhigbpKtxTKve-zXk,4421
- llm_dialog_manager-0.3.2.dist-info/LICENSE,sha256=vWGbYgGuWpWrXL8-xi6pNcX5UzD6pWoIAZmcetyfbus,1064
- llm_dialog_manager-0.3.2.dist-info/METADATA,sha256=-DNhRHechHF38K6DGrHCeHVTj2JX3sxw4ZYFQr6Sj5I,4152
- llm_dialog_manager-0.3.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- llm_dialog_manager-0.3.2.dist-info/top_level.txt,sha256=u2EQEXW0NGAt0AAHT7jx1odXZ4rZfjcgbmJhvKFuMkI,19
- llm_dialog_manager-0.3.2.dist-info/RECORD,,