llm-dialog-manager 0.4.7__tar.gz → 0.5.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/PKG-INFO +2 -2
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/README.md +1 -1
- llm_dialog_manager-0.5.3/llm_dialog_manager/__init__.py +20 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/agent.py +189 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/clients/__init__.py +31 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/clients/anthropic_client.py +143 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/clients/base.py +65 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/clients/gemini_client.py +78 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/clients/openai_client.py +97 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/clients/x_client.py +83 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/formatters/__init__.py +27 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/formatters/anthropic.py +76 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/formatters/base.py +23 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/formatters/gemini.py +59 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/formatters/openai.py +67 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/formatters/x.py +77 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/utils/__init__.py +3 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/utils/environment.py +66 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/utils/image_tools.py +81 -0
- llm_dialog_manager-0.5.3/llm_dialog_manager/utils/logging.py +35 -0
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/llm_dialog_manager.egg-info/PKG-INFO +2 -2
- llm_dialog_manager-0.5.3/llm_dialog_manager.egg-info/SOURCES.txt +31 -0
- llm_dialog_manager-0.5.3/pyproject.toml +64 -0
- llm_dialog_manager-0.4.7/llm_dialog_manager/__init__.py +0 -4
- llm_dialog_manager-0.4.7/llm_dialog_manager/agent.py +0 -642
- llm_dialog_manager-0.4.7/llm_dialog_manager.egg-info/SOURCES.txt +0 -15
- llm_dialog_manager-0.4.7/pyproject.toml +0 -32
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/LICENSE +0 -0
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/llm_dialog_manager/chat_history.py +0 -0
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/llm_dialog_manager/key_manager.py +0 -0
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/llm_dialog_manager.egg-info/dependency_links.txt +0 -0
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/llm_dialog_manager.egg-info/requires.txt +0 -0
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/llm_dialog_manager.egg-info/top_level.txt +0 -0
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/setup.cfg +0 -0
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/tests/test_agent.py +0 -0
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/tests/test_chat_history.py +0 -0
- {llm_dialog_manager-0.4.7 → llm_dialog_manager-0.5.3}/tests/test_key_manager.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: llm_dialog_manager
|
3
|
-
Version: 0.4.7
|
3
|
+
Version: 0.5.3
|
4
4
|
Summary: A Python package for managing LLM chat conversation history
|
5
5
|
Author-email: xihajun <work@2333.fun>
|
6
6
|
License: MIT
|
@@ -103,7 +103,7 @@ XAI_API_KEY=your-x-key
|
|
103
103
|
from llm_dialog_manager import Agent
|
104
104
|
|
105
105
|
# Initialize an agent with a specific model
|
106
|
-
agent = Agent("
|
106
|
+
agent = Agent("ep-20250319212209-j6tfj-openai", memory_enabled=True)
|
107
107
|
|
108
108
|
# Add messages and generate responses
|
109
109
|
agent.add_message("system", "You are a helpful assistant")
|
@@ -55,7 +55,7 @@ XAI_API_KEY=your-x-key
|
|
55
55
|
from llm_dialog_manager import Agent
|
56
56
|
|
57
57
|
# Initialize an agent with a specific model
|
58
|
-
agent = Agent("
|
58
|
+
agent = Agent("ep-20250319212209-j6tfj-openai", memory_enabled=True)
|
59
59
|
|
60
60
|
# Add messages and generate responses
|
61
61
|
agent.add_message("system", "You are a helpful assistant")
|
@@ -0,0 +1,20 @@
|
|
1
|
+
"""
|
2
|
+
LLM Dialog Manager
|
3
|
+
|
4
|
+
A modular framework for building conversational AI applications with
|
5
|
+
support for multiple LLM providers.
|
6
|
+
"""
|
7
|
+
|
8
|
+
__version__ = "0.5.3"
|
9
|
+
|
10
|
+
from .agent import Agent
|
11
|
+
from .chat_history import ChatHistory
|
12
|
+
from .key_manager import key_manager
|
13
|
+
|
14
|
+
# Import factory functions for easy access
|
15
|
+
from .clients import get_client
|
16
|
+
from .formatters import get_formatter
|
17
|
+
|
18
|
+
# Setup environment by default
|
19
|
+
from .utils.environment import load_env_vars
|
20
|
+
load_env_vars()
|
@@ -0,0 +1,189 @@
|
|
1
|
+
"""
|
2
|
+
Agent class for managing LLM conversations
|
3
|
+
"""
|
4
|
+
# Standard library imports
|
5
|
+
import uuid
|
6
|
+
import logging
|
7
|
+
from typing import List, Dict, Optional, Union
|
8
|
+
from PIL import Image
|
9
|
+
|
10
|
+
# Local imports
|
11
|
+
from .chat_history import ChatHistory
|
12
|
+
from .clients import get_client
|
13
|
+
from .utils.environment import load_env_vars
|
14
|
+
from .utils.image_tools import load_image_from_path, load_image_from_url, create_image_content_block
|
15
|
+
|
16
|
+
# Setup logging
|
17
|
+
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
# Load environment variables
|
20
|
+
load_env_vars()
|
21
|
+
|
22
|
+
class Agent:
    """
    Agent class for managing conversations with LLMs.

    This class provides a high-level interface for interacting with different
    LLM providers through a unified API: it owns a ChatHistory and a provider
    client selected by get_client() from the model name.
    """

    def __init__(self, model_name: str,
                 messages: Optional[Union[str, List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]]] = None,
                 memory_enabled: bool = False,
                 api_key: Optional[str] = None,
                 base_url: Optional[str] = None) -> None:
        """
        Initialize an Agent instance.

        Args:
            model_name: Name of the LLM model to use; also decides which
                provider client get_client() returns.
            messages: Optional initial messages or system prompt, handed to
                ChatHistory unchanged.
            memory_enabled: Stored as ``self.memory_enabled``; no method of
                this class currently reads it.
            api_key: Optional API key forwarded to the provider client.
            base_url: Optional base URL forwarded to the provider client.
        """
        # The random suffix keeps ids (and default save filenames) distinct
        # between agents created for the same model.
        self.id = f"{model_name}-{uuid.uuid4().hex[:8]}"
        self.model_name = model_name
        self.history = ChatHistory(messages) if messages else ChatHistory()
        self.memory_enabled = memory_enabled
        self.client = get_client(model_name, api_key=api_key, base_url=base_url)
        self.repo_content = []

    def add_message(self, role: str, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """
        Add a message with an explicit role to the conversation.

        Args:
            role: Message role ('system', 'user', or 'assistant')
            content: Message content (text, image, or mixed content)
        """
        # ChatHistory.add_message takes (content, role) - note the order.
        self.history.add_message(content, role)

    def add_user_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """
        Add a user message to the conversation.

        Args:
            content: Message content (text, image, or mixed content)
        """
        self.history.add_user_message(content)

    def add_assistant_message(self, content: Union[str, List[Union[str, Image.Image, Dict]]]):
        """
        Add an assistant message to the conversation.

        Args:
            content: Message content (text, image, or mixed content)
        """
        self.history.add_assistant_message(content)

    def add_image(self, image_path: Optional[str] = None,
                  image_url: Optional[str] = None,
                  media_type: Optional[str] = "image/jpeg"):
        """
        Build an image content block from a local file or a URL.

        Either image_path or image_url must be provided; image_path takes
        precedence when both are given.

        Note: despite its name, this method does NOT modify the conversation
        history. Put the returned block into a message yourself, e.g.
        ``agent.add_user_message(["Describe this", block])``.

        Args:
            image_path: Path to a local image file
            image_url: URL of an image
            media_type: MIME type of the image

        Returns:
            The content block produced by create_image_content_block().

        Raises:
            ValueError: If neither image_path nor image_url is provided.
        """
        if not (image_path or image_url):
            raise ValueError("Either image_path or image_url must be provided.")

        if image_path:
            image = load_image_from_path(image_path)
        else:
            image = load_image_from_url(image_url)

        return create_image_content_block(image, media_type)

    def generate_response(self, max_tokens=3585, temperature=0.7,
                          top_p=1.0, top_k=40, json_format=False, **kwargs):
        """
        Generate a response from the model for the current history.

        Args:
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            json_format: Whether to enable JSON output format
            **kwargs: Additional model-specific parameters, forwarded to the
                client's completion() call

        Returns:
            The response text returned by the provider client.
        """
        response = self.client.completion(
            messages=self.history.messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            json_format=json_format,
            model=self.model_name,
            **kwargs
        )

        # JSON-mode replies are returned to the caller but not recorded in history.
        if not json_format:
            self.add_assistant_message(response)

        return response

    @staticmethod
    def _to_serializable(content):
        """Return content with PIL images replaced by image content blocks; everything else unchanged."""
        if isinstance(content, Image.Image):
            return create_image_content_block(content)
        if isinstance(content, list):
            return [
                create_image_content_block(item) if isinstance(item, Image.Image) else item
                for item in content
            ]
        return content

    def save_conversation(self, filename=None):
        """
        Save the conversation history to a JSON file.

        PIL images (either as a whole message content or inside list content)
        are converted with create_image_content_block() so they can be
        serialized. All other content is written unchanged; earlier versions
        silently dropped content of unexpected types.

        Args:
            filename: Optional path to write. Defaults to
                "conversation_<agent id>.json" in the current directory.

        Returns:
            The path that was written.

        Raises:
            TypeError: If some message content cannot be serialized by json.
                The file is not created/overwritten in that case.
        """
        import json

        if filename is None:
            filename = f"conversation_{self.id}.json"

        serializable_history = [
            {"role": msg["role"], "content": self._to_serializable(msg["content"])}
            for msg in self.history.messages
        ]

        # Serialize first so a failure cannot leave a truncated file behind.
        payload = json.dumps(serializable_history, indent=2)
        with open(filename, "w", encoding="utf-8") as f:
            f.write(payload)

        return filename

    def load_conversation(self, filename):
        """
        Load a conversation from a JSON file, replacing the current history.

        Args:
            filename: Path to the conversation file

        Returns:
            The new ChatHistory instance (also stored on ``self.history``).
        """
        import json

        with open(filename, "r", encoding="utf-8") as f:
            history = json.load(f)

        # Mirror __init__: empty input yields a fresh, default ChatHistory.
        self.history = ChatHistory(history) if history else ChatHistory()

        return self.history
|
@@ -0,0 +1,31 @@
|
|
1
|
+
"""
|
2
|
+
LLM API client modules for different services.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from .base import BaseClient
|
6
|
+
from .anthropic_client import AnthropicClient
|
7
|
+
from .gemini_client import GeminiClient
|
8
|
+
from .openai_client import OpenAIClient
|
9
|
+
from .x_client import XClient
|
10
|
+
|
11
|
+
def get_client(model_name, api_key=None, base_url=None):
    """Return a client instance suited to ``model_name``.

    The name is matched by substring against these markers, in order; the
    first hit wins:

    1. ``"-openai"`` -> OpenAIClient (checked first, so it takes precedence
       over the provider names below)
    2. ``"claude"``  -> AnthropicClient
    3. ``"gemini"``  -> GeminiClient
    4. ``"grok"``    -> XClient

    Any other name falls back to OpenAIClient.

    Args:
        model_name: Name of the model (e.g., "claude-3-opus", "gemini-1.5-pro")
        api_key: Optional API key to use
        base_url: Optional base URL to use

    Returns:
        An instance of the appropriate client class
    """
    routing_table = (
        ("-openai", OpenAIClient),
        ("claude", AnthropicClient),
        ("gemini", GeminiClient),
        ("grok", XClient),
    )
    chosen = next(
        (client_cls for marker, client_cls in routing_table if marker in model_name),
        OpenAIClient,
    )
    return chosen(api_key=api_key, base_url=base_url)
|
@@ -0,0 +1,143 @@
|
|
1
|
+
"""
|
2
|
+
Client implementation for Anthropic Claude models
|
3
|
+
"""
|
4
|
+
import os
|
5
|
+
import logging
|
6
|
+
import httpx
|
7
|
+
from typing import List, Dict, Optional, Union
|
8
|
+
|
9
|
+
import anthropic
|
10
|
+
from anthropic import AnthropicVertex, AnthropicBedrock
|
11
|
+
|
12
|
+
from ..formatters import AnthropicFormatter
|
13
|
+
from .base import BaseClient
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
class AnthropicClient(BaseClient):
    """Client for the Anthropic Claude API (direct API, AWS Bedrock, or Vertex AI)."""

    def _get_service_name(self) -> str:
        return "anthropic"

    @staticmethod
    def _build_http_client():
        """Return an httpx.Client routed through HTTP_PROXY/HTTPS_PROXY, or None if neither is set.

        Proxies are configured via transport ``mounts`` instead of the
        ``proxies=`` argument, which httpx deprecated in 0.26 and removed in
        0.28 (passing it there raises TypeError).
        """
        proxy_env = {
            "http://": os.getenv("HTTP_PROXY"),
            "https://": os.getenv("HTTPS_PROXY"),
        }
        mounts = {
            scheme: httpx.HTTPTransport(proxy=httpx.Proxy(url))
            for scheme, url in proxy_env.items()
            if url
        }
        return httpx.Client(mounts=mounts) if mounts else None

    def _create_sdk_client(self, model_name):
        """Pick and construct the Anthropic SDK client for this request.

        - Model names containing "anthropic." go to AWS Bedrock.
        - Otherwise, if VERTEX_PROJECT_ID and VERTEX_REGION are both set,
          Vertex AI is used.
        - Otherwise the direct Anthropic API is used with self.api_key /
          self.base_url.
        """
        if "anthropic." in model_name:
            logger.info(f"Using AWS Bedrock for model: {model_name}")
            bedrock_kwargs = {"aws_region": os.getenv('AWS_REGION', 'us-east-1')}

            # Only add credentials if explicitly provided; otherwise the SDK
            # falls back to its default AWS credential resolution.
            aws_access_key = os.getenv('AWS_ACCESS_KEY_ID')
            aws_secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')
            if aws_access_key and aws_secret_key:
                bedrock_kwargs["aws_access_key"] = aws_access_key
                bedrock_kwargs["aws_secret_key"] = aws_secret_key
                # Temporary (STS) credentials are rejected without their
                # session token; previously this value was read but dropped.
                aws_session_token = os.getenv('AWS_SESSION_TOKEN')
                if aws_session_token:
                    bedrock_kwargs["aws_session_token"] = aws_session_token

            return AnthropicBedrock(**bedrock_kwargs)

        vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
        vertex_region = os.getenv('VERTEX_REGION')
        if vertex_project_id and vertex_region:
            return AnthropicVertex(
                region=vertex_region,
                project_id=vertex_project_id
            )

        return anthropic.Anthropic(
            api_key=self.api_key,
            base_url=self.base_url,
            http_client=self._build_http_client()
        )

    def completion(self, messages, max_tokens=1000, temperature=0.5,
                   top_p=1.0, top_k=40, json_format=False, **kwargs):
        """
        Generate a completion using Anthropic Claude.

        Args:
            messages: List of message dictionaries
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            json_format: Accepted for interface compatibility; not used by
                this client.
            **kwargs: Additional model-specific parameters; only "model" is
                read (default "claude-3-opus").

        Returns:
            Text of the first content block of the response.

        Raises:
            Exception: Any error from formatting or the SDK is logged,
                reported to the key manager, and re-raised.
        """
        try:
            # Get API credentials if not set
            self.get_credentials()

            # Format messages for Anthropic API
            formatter = AnthropicFormatter()
            system_message, formatted_messages = formatter.format_messages(messages)

            model_name = kwargs.get("model", "claude-3-opus")
            client = self._create_sdk_client(model_name)

            # Closing the SDK client releases its connection pool (including
            # any proxy-configured httpx.Client) instead of leaking one per call.
            with client:
                response = client.messages.create(
                    model=model_name,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    system=system_message,
                    messages=formatted_messages,
                    top_p=top_p,
                    top_k=top_k
                )

            # Release API credentials
            self.release_credentials()

            return response.content[0].text

        except Exception as e:
            logger.error(f"Anthropic API error: {e}")
            self.report_error()
            raise
|
@@ -0,0 +1,65 @@
|
|
1
|
+
"""
|
2
|
+
Base client interface for LLM APIs
|
3
|
+
"""
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from typing import List, Dict, Optional, Union
|
6
|
+
import logging
|
7
|
+
|
8
|
+
from ..key_manager import key_manager
|
9
|
+
|
10
|
+
logger = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
class BaseClient(ABC):
    """Abstract base for provider-specific LLM API clients.

    Subclasses provide ``_get_service_name`` (the service id used with the
    shared key_manager) and ``completion``. The credential helpers here
    delegate to key_manager under that service id.
    """

    def __init__(self, api_key=None, base_url=None):
        """
        Store optional credentials and resolve the service name.

        Args:
            api_key: Optional API key to use
            base_url: Optional base URL for API requests
        """
        self.api_key = api_key
        self.base_url = base_url
        # Resolved once, at construction time, from the subclass hook.
        self.service_name = self._get_service_name()

    @abstractmethod
    def _get_service_name(self) -> str:
        """Return the service name for this client (e.g., 'openai', 'anthropic')."""

    def get_credentials(self):
        """Load (api_key, base_url) from key_manager when no API key is set yet.

        NOTE(review): once a key is stored on the instance it is kept, so later
        calls reuse it instead of asking key_manager again; and a key_manager
        lookup overwrites any base_url passed to __init__.
        """
        if self.api_key:
            return
        self.api_key, self.base_url = key_manager.get_config(self.service_name)

    def release_credentials(self):
        """Tell key_manager the current API key is no longer in use (no-op without a key).

        NOTE(review): this also fires for keys supplied directly by the caller,
        not only for keys handed out by key_manager.
        """
        if not self.api_key:
            return
        key_manager.release_config(self.service_name, self.api_key)

    def report_error(self):
        """Report a failed request for the current API key to key_manager (no-op without a key)."""
        if not self.api_key:
            return
        key_manager.report_error(self.service_name, self.api_key)

    @abstractmethod
    def completion(self, messages, max_tokens=1000, temperature=0.5,
                   top_p=1.0, top_k=40, json_format=False, **kwargs):
        """
        Generate a completion for the given messages.

        Args:
            messages: List of message dictionaries
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            json_format: Whether to return JSON
            **kwargs: Additional model-specific parameters

        Returns:
            Generated text response
        """
|
@@ -0,0 +1,78 @@
|
|
1
|
+
"""
|
2
|
+
Client implementation for Google Gemini models
|
3
|
+
"""
|
4
|
+
import os
|
5
|
+
import logging
|
6
|
+
from typing import List, Dict, Optional, Union
|
7
|
+
|
8
|
+
import google.generativeai as genai
|
9
|
+
|
10
|
+
from ..formatters import GeminiFormatter
|
11
|
+
from .base import BaseClient
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
class GeminiClient(BaseClient):
    """Client for the Google Gemini API via ``google.generativeai``."""

    def _get_service_name(self) -> str:
        return "gemini"

    def completion(self, messages, max_tokens=1000, temperature=0.5,
                   top_p=1.0, top_k=40, json_format=False, **kwargs):
        """
        Generate a completion using Google Gemini.

        Args:
            messages: List of message dictionaries
            max_tokens: Maximum number of tokens to generate (sent as
                max_output_tokens)
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            json_format: When true, request an "application/json" response
                MIME type
            **kwargs: Additional model-specific parameters; only "model" is
                read (default "gemini-1.5-pro").

        Returns:
            ``response.text`` from the Gemini SDK.

        Raises:
            Exception: Any error is logged, reported to the key manager, and
                re-raised.
        """
        try:
            self.get_credentials()

            # Module-level configuration call; presumably process-wide in the
            # genai library. NOTE(review): self.base_url is not used here.
            genai.configure(api_key=self.api_key)

            system_message, formatted_messages = GeminiFormatter().format_messages(messages)

            sampling_config = {
                "max_output_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k
            }
            model = genai.GenerativeModel(
                model_name=kwargs.get("model", "gemini-1.5-pro"),
                generation_config=sampling_config,
                system_instruction=system_message
            )

            # JSON mode adds a per-call generation_config with the MIME type.
            call_options = (
                {"generation_config": {"response_mime_type": "application/json"}}
                if json_format
                else {}
            )
            response = model.generate_content(formatted_messages, **call_options)

            self.release_credentials()
            return response.text

        except Exception as e:
            logger.error(f"Gemini API error: {e}")
            self.report_error()
            raise
|
@@ -0,0 +1,97 @@
|
|
1
|
+
"""
|
2
|
+
Client implementation for OpenAI models
|
3
|
+
"""
|
4
|
+
import os
|
5
|
+
import logging
|
6
|
+
import httpx
|
7
|
+
from typing import List, Dict, Optional, Union
|
8
|
+
|
9
|
+
import openai
|
10
|
+
|
11
|
+
from ..formatters import OpenAIFormatter
|
12
|
+
from .base import BaseClient
|
13
|
+
|
14
|
+
logger = logging.getLogger(__name__)
|
15
|
+
|
16
|
+
class OpenAIClient(BaseClient):
    """Client for the OpenAI API (or an OpenAI-compatible endpoint via base_url)."""

    # Routing suffix stripped from model names before they are sent to the API.
    _MODEL_SUFFIX = "-openai"

    def _get_service_name(self) -> str:
        return "openai"

    @staticmethod
    def _build_http_client():
        """Return an httpx.Client routed through HTTP_PROXY/HTTPS_PROXY, or None if neither is set.

        Proxies are configured via transport ``mounts`` instead of the
        ``proxies=`` argument, which httpx deprecated in 0.26 and removed in
        0.28 (passing it there raises TypeError).
        """
        proxy_env = {
            "http://": os.getenv("HTTP_PROXY"),
            "https://": os.getenv("HTTPS_PROXY"),
        }
        mounts = {
            scheme: httpx.HTTPTransport(proxy=httpx.Proxy(url))
            for scheme, url in proxy_env.items()
            if url
        }
        return httpx.Client(mounts=mounts) if mounts else None

    def completion(self, messages, max_tokens=1000, temperature=0.5,
                   top_p=1.0, top_k=40, json_format=False, **kwargs):
        """
        Generate a completion using OpenAI API.

        Args:
            messages: List of message dictionaries
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter (not used for OpenAI)
            json_format: When true, request ``response_format={"type": "json_object"}``
            **kwargs: Additional model-specific parameters; only "model" is
                read (default "gpt-4"). A trailing "-openai" is removed.

        Returns:
            Content of the first choice's message.

        Raises:
            Exception: Any error is logged, reported to the key manager, and
                re-raised.
        """
        try:
            # Get API credentials if not set
            self.get_credentials()

            # Format messages for OpenAI API (the system part is kept inline)
            formatter = OpenAIFormatter()
            _, formatted_messages = formatter.format_messages(messages)

            client = openai.OpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
                http_client=self._build_http_client()
            )

            model = kwargs.get("model", "gpt-4")
            if model.endswith(self._MODEL_SUFFIX):
                model = model[:-len(self._MODEL_SUFFIX)]

            params = {
                "model": model,
                "messages": formatted_messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p
            }
            if json_format:
                params["response_format"] = {"type": "json_object"}

            # Closing the SDK client releases its connection pool (including
            # any proxy-configured httpx.Client) instead of leaking one per call.
            with client:
                response = client.chat.completions.create(**params)

            # Release API credentials
            self.release_credentials()

            return response.choices[0].message.content

        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            self.report_error()
            raise
|