llm-dialog-manager 0.4.7__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registries.
- llm_dialog_manager/__init__.py +18 -2
- llm_dialog_manager/agent.py +141 -594
- llm_dialog_manager/clients/__init__.py +31 -0
- llm_dialog_manager/clients/anthropic_client.py +143 -0
- llm_dialog_manager/clients/base.py +65 -0
- llm_dialog_manager/clients/gemini_client.py +78 -0
- llm_dialog_manager/clients/openai_client.py +97 -0
- llm_dialog_manager/clients/x_client.py +83 -0
- llm_dialog_manager/formatters/__init__.py +27 -0
- llm_dialog_manager/formatters/anthropic.py +76 -0
- llm_dialog_manager/formatters/base.py +23 -0
- llm_dialog_manager/formatters/gemini.py +59 -0
- llm_dialog_manager/formatters/openai.py +67 -0
- llm_dialog_manager/formatters/x.py +77 -0
- llm_dialog_manager/utils/__init__.py +3 -0
- llm_dialog_manager/utils/environment.py +66 -0
- llm_dialog_manager/utils/image_tools.py +81 -0
- llm_dialog_manager/utils/logging.py +35 -0
- {llm_dialog_manager-0.4.7.dist-info → llm_dialog_manager-0.5.3.dist-info}/METADATA +2 -2
- llm_dialog_manager-0.5.3.dist-info/RECORD +25 -0
- {llm_dialog_manager-0.4.7.dist-info → llm_dialog_manager-0.5.3.dist-info}/WHEEL +1 -1
- llm_dialog_manager-0.4.7.dist-info/RECORD +0 -9
- {llm_dialog_manager-0.4.7.dist-info → llm_dialog_manager-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {llm_dialog_manager-0.4.7.dist-info → llm_dialog_manager-0.5.3.dist-info}/top_level.txt +0 -0
llm_dialog_manager/clients/__init__.py
@@ -0,0 +1,31 @@
+"""
+LLM API client modules for different services.
+"""
+
+from .base import BaseClient
+from .anthropic_client import AnthropicClient
+from .gemini_client import GeminiClient
+from .openai_client import OpenAIClient
+from .x_client import XClient
+
+def get_client(model_name, api_key=None, base_url=None):
+    """Factory method to get the appropriate client for a given model.
+
+    Args:
+        model_name: Name of the model (e.g., "claude-3-opus", "gemini-1.5-pro")
+        api_key: Optional API key to use
+        base_url: Optional base URL to use
+
+    Returns:
+        An instance of the appropriate client class
+    """
+    if "-openai" in model_name:
+        return OpenAIClient(api_key=api_key, base_url=base_url)
+    if "claude" in model_name:
+        return AnthropicClient(api_key=api_key, base_url=base_url)
+    elif "gemini" in model_name:
+        return GeminiClient(api_key=api_key, base_url=base_url)
+    elif "grok" in model_name:
+        return XClient(api_key=api_key, base_url=base_url)
+    else:  # Default to OpenAI client
+        return OpenAIClient(api_key=api_key, base_url=base_url)
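As a quick orientation to the new module, here is a minimal usage sketch; the model name and key are illustrative, and credentials can also come from the package's key manager:

from llm_dialog_manager.clients import get_client

# Routing is by substring: "claude" -> AnthropicClient, "gemini" -> GeminiClient,
# "grok" -> XClient; a "-openai" suffix (or any other name) -> OpenAIClient.
client = get_client("claude-3-opus", api_key="sk-ant-...")  # illustrative key
reply = client.completion(
    [
        {"role": "system", "content": "You are concise."},
        {"role": "user", "content": "Hello!"},
    ],
    max_tokens=256,
    model="claude-3-opus",  # completion() reads the model name from **kwargs
)
print(reply)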
llm_dialog_manager/clients/anthropic_client.py
@@ -0,0 +1,143 @@
+"""
+Client implementation for Anthropic Claude models
+"""
+import os
+import logging
+import httpx
+from typing import List, Dict, Optional, Union
+
+import anthropic
+from anthropic import AnthropicVertex, AnthropicBedrock
+
+from ..formatters import AnthropicFormatter
+from .base import BaseClient
+
+logger = logging.getLogger(__name__)
+
+class AnthropicClient(BaseClient):
+    """Client for Anthropic Claude API"""
+
+    def _get_service_name(self) -> str:
+        return "anthropic"
+
+    def completion(self, messages, max_tokens=1000, temperature=0.5,
+                   top_p=1.0, top_k=40, json_format=False, **kwargs):
+        """
+        Generate a completion using Anthropic Claude.
+
+        Args:
+            messages: List of message dictionaries
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Nucleus sampling parameter
+            top_k: Top-k sampling parameter
+            json_format: Whether to return JSON
+            **kwargs: Additional model-specific parameters
+
+        Returns:
+            Generated text response
+        """
+        try:
+            # Get API credentials if not set
+            self.get_credentials()
+
+            # Format messages for Anthropic API
+            formatter = AnthropicFormatter()
+            system_message, formatted_messages = formatter.format_messages(messages)
+
+            # Get the model name from kwargs or use default
+            model_name = kwargs.get("model", "claude-3-opus")
+
+            # Check for Vertex configuration
+            vertex_project_id = os.getenv('VERTEX_PROJECT_ID')
+            vertex_region = os.getenv('VERTEX_REGION')
+
+            # Check for AWS Bedrock configuration
+            aws_region = os.getenv('AWS_REGION', 'us-east-1')
+            aws_access_key = os.getenv('AWS_ACCESS_KEY_ID')
+            aws_secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')
+            aws_session_token = os.getenv('AWS_SESSION_TOKEN')
+
+            # Get proxy configuration from environment or default to None
+            http_proxy = os.getenv("HTTP_PROXY")
+            https_proxy = os.getenv("HTTPS_PROXY")
+
+            # Determine if we should use Bedrock based on model name prefix
+            use_bedrock = "anthropic." in model_name
+
+            if use_bedrock:
+                logger.info(f"Using AWS Bedrock for model: {model_name}")
+                # Use AWS Bedrock for Claude
+                bedrock_kwargs = {
+                    "aws_region": aws_region
+                }
+
+                # Only add credentials if explicitly provided
+                if aws_access_key and aws_secret_key:
+                    bedrock_kwargs["aws_access_key"] = aws_access_key
+                    bedrock_kwargs["aws_secret_key"] = aws_secret_key
+
+                client = AnthropicBedrock(**bedrock_kwargs)
+
+                response = client.messages.create(
+                    model=model_name,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    system=system_message,
+                    messages=formatted_messages,
+                    top_p=top_p,
+                    top_k=top_k
+                )
+            elif vertex_project_id and vertex_region:
+                # Use Vertex AI for Claude
+                client = AnthropicVertex(
+                    region=vertex_region,
+                    project_id=vertex_project_id
+                )
+
+                response = client.messages.create(
+                    model=model_name,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    system=system_message,
+                    messages=formatted_messages,
+                    top_p=top_p,
+                    top_k=top_k
+                )
+            else:
+                # Create httpx client with proxy settings if needed
+                http_options = {}
+                if http_proxy or https_proxy:
+                    proxies = {}
+                    if http_proxy:
+                        proxies["http://"] = http_proxy
+                    if https_proxy:
+                        proxies["https://"] = https_proxy
+                    http_options["proxies"] = proxies
+
+                # Use direct Anthropic API with proper http client
+                client = anthropic.Anthropic(
+                    api_key=self.api_key,
+                    base_url=self.base_url,
+                    http_client=httpx.Client(**http_options) if http_options else None
+                )
+
+                response = client.messages.create(
+                    model=model_name,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    system=system_message,
+                    messages=formatted_messages,
+                    top_p=top_p,
+                    top_k=top_k
+                )
+
+            # Release API credentials
+            self.release_credentials()
+
+            return response.content[0].text
+
+        except Exception as e:
+            logger.error(f"Anthropic API error: {e}")
+            self.report_error()
+            raise
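The client selects one of three backends at call time: AWS Bedrock when the model name contains "anthropic.", Vertex AI when both VERTEX_PROJECT_ID and VERTEX_REGION are set, and the direct Anthropic API otherwise. A hedged sketch of the three paths, with all keys and values illustrative:

import os
from llm_dialog_manager.clients import AnthropicClient

messages = [{"role": "user", "content": "Hello"}]
client = AnthropicClient(api_key="sk-ant-...")  # direct-API path by default

# Bedrock path: an "anthropic." prefix in the model name switches to
# AnthropicBedrock, with credentials from the standard AWS variables.
os.environ["AWS_ACCESS_KEY_ID"] = "AKIA..."          # illustrative
os.environ["AWS_SECRET_ACCESS_KEY"] = "..."          # illustrative
client.completion(messages, model="anthropic.claude-3-sonnet-20240229-v1:0")

# Vertex path: setting both variables below switches to AnthropicVertex.
os.environ["VERTEX_PROJECT_ID"] = "my-gcp-project"   # illustrative
os.environ["VERTEX_REGION"] = "us-east5"             # illustrative
client.completion(messages, model="claude-3-opus")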
llm_dialog_manager/clients/base.py
@@ -0,0 +1,65 @@
+"""
+Base client interface for LLM APIs
+"""
+from abc import ABC, abstractmethod
+from typing import List, Dict, Optional, Union
+import logging
+
+from ..key_manager import key_manager
+
+logger = logging.getLogger(__name__)
+
+class BaseClient(ABC):
+    """Base class for LLM API clients"""
+
+    def __init__(self, api_key=None, base_url=None):
+        """
+        Initialize client with optional API key and base URL.
+
+        Args:
+            api_key: Optional API key to use
+            base_url: Optional base URL for API requests
+        """
+        self.api_key = api_key
+        self.base_url = base_url
+        self.service_name = self._get_service_name()
+
+    @abstractmethod
+    def _get_service_name(self) -> str:
+        """Return the service name for this client (e.g., 'openai', 'anthropic')"""
+        pass
+
+    def get_credentials(self):
+        """Get API credentials from key manager if not set"""
+        if not self.api_key:
+            self.api_key, self.base_url = key_manager.get_config(self.service_name)
+
+    def release_credentials(self):
+        """Release API credentials in key manager"""
+        if self.api_key:
+            key_manager.release_config(self.service_name, self.api_key)
+
+    def report_error(self):
+        """Report API error to key manager"""
+        if self.api_key:
+            key_manager.report_error(self.service_name, self.api_key)
+
+    @abstractmethod
+    def completion(self, messages, max_tokens=1000, temperature=0.5,
+                   top_p=1.0, top_k=40, json_format=False, **kwargs):
+        """
+        Generate a completion for the given messages.
+
+        Args:
+            messages: List of message dictionaries
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Nucleus sampling parameter
+            top_k: Top-k sampling parameter
+            json_format: Whether to return JSON
+            **kwargs: Additional model-specific parameters
+
+        Returns:
+            Generated text response
+        """
+        pass
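A new backend only has to implement the two abstract methods; credential pooling, release, and error reporting come from the base class via key_manager. A hedged sketch with a hypothetical EchoClient, mirroring the error-handling pattern of the concrete clients:

from llm_dialog_manager.clients.base import BaseClient

class EchoClient(BaseClient):
    def _get_service_name(self) -> str:
        return "echo"  # key_manager looks up credentials under this name

    def completion(self, messages, max_tokens=1000, temperature=0.5,
                   top_p=1.0, top_k=40, json_format=False, **kwargs):
        try:
            self.get_credentials()          # pull key/base_url from the pool
            text = messages[-1]["content"]  # stand-in for a real API call
            self.release_credentials()      # success: return key to the pool
            return text
        except Exception:
            self.report_error()             # failure: mark the key as bad
            raise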
llm_dialog_manager/clients/gemini_client.py
@@ -0,0 +1,78 @@
+"""
+Client implementation for Google Gemini models
+"""
+import os
+import logging
+from typing import List, Dict, Optional, Union
+
+import google.generativeai as genai
+
+from ..formatters import GeminiFormatter
+from .base import BaseClient
+
+logger = logging.getLogger(__name__)
+
+class GeminiClient(BaseClient):
+    """Client for Google Gemini API"""
+
+    def _get_service_name(self) -> str:
+        return "gemini"
+
+    def completion(self, messages, max_tokens=1000, temperature=0.5,
+                   top_p=1.0, top_k=40, json_format=False, **kwargs):
+        """
+        Generate a completion using Google Gemini.
+
+        Args:
+            messages: List of message dictionaries
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Nucleus sampling parameter
+            top_k: Top-k sampling parameter
+            json_format: Whether to return JSON
+            **kwargs: Additional model-specific parameters
+
+        Returns:
+            Generated text response
+        """
+        try:
+            # Get API credentials if not set
+            self.get_credentials()
+
+            # Configure Google API
+            genai.configure(api_key=self.api_key)
+
+            # Format messages for Gemini API
+            formatter = GeminiFormatter()
+            system_message, formatted_messages = formatter.format_messages(messages)
+
+            # Create model configuration
+            model = genai.GenerativeModel(
+                model_name=kwargs.get("model", "gemini-1.5-pro"),
+                generation_config={
+                    "max_output_tokens": max_tokens,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "top_k": top_k
+                },
+                system_instruction=system_message
+            )
+
+            # Generate response
+            if json_format:
+                response = model.generate_content(
+                    formatted_messages,
+                    generation_config={"response_mime_type": "application/json"}
+                )
+            else:
+                response = model.generate_content(formatted_messages)
+
+            # Release API credentials
+            self.release_credentials()
+
+            return response.text
+
+        except Exception as e:
+            logger.error(f"Gemini API error: {e}")
+            self.report_error()
+            raise
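With json_format=True the Gemini client requests a JSON MIME type, so the returned string should parse directly. A hedged sketch, model name illustrative:

import json
from llm_dialog_manager.clients import GeminiClient

client = GeminiClient()  # key resolved via the key manager if not passed
raw = client.completion(
    [{"role": "user", "content": "List three colors as a JSON array."}],
    json_format=True,
    model="gemini-1.5-pro",
)
colors = json.loads(raw)  # e.g. ["red", "green", "blue"]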
llm_dialog_manager/clients/openai_client.py
@@ -0,0 +1,97 @@
+"""
+Client implementation for OpenAI models
+"""
+import os
+import logging
+import httpx
+from typing import List, Dict, Optional, Union
+
+import openai
+
+from ..formatters import OpenAIFormatter
+from .base import BaseClient
+
+logger = logging.getLogger(__name__)
+
+class OpenAIClient(BaseClient):
+    """Client for OpenAI API"""
+
+    def _get_service_name(self) -> str:
+        return "openai"
+
+    def completion(self, messages, max_tokens=1000, temperature=0.5,
+                   top_p=1.0, top_k=40, json_format=False, **kwargs):
+        """
+        Generate a completion using OpenAI API.
+
+        Args:
+            messages: List of message dictionaries
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Nucleus sampling parameter
+            top_k: Top-k sampling parameter (not used for OpenAI)
+            json_format: Whether to return JSON
+            **kwargs: Additional model-specific parameters
+
+        Returns:
+            Generated text response
+        """
+        try:
+            # Get API credentials if not set
+            self.get_credentials()
+
+            # Format messages for OpenAI API
+            formatter = OpenAIFormatter()
+            _, formatted_messages = formatter.format_messages(messages)
+
+            # Get proxy configuration from environment or default to None
+            http_proxy = os.getenv("HTTP_PROXY")
+            https_proxy = os.getenv("HTTPS_PROXY")
+
+            # Create httpx client with proxy settings if needed
+            http_options = {}
+            if http_proxy or https_proxy:
+                proxies = {}
+                if http_proxy:
+                    proxies["http://"] = http_proxy
+                if https_proxy:
+                    proxies["https://"] = https_proxy
+                http_options["proxies"] = proxies
+
+            # Create OpenAI client with proper configuration
+            client = openai.OpenAI(
+                api_key=self.api_key,
+                base_url=self.base_url,
+                http_client=httpx.Client(**http_options) if http_options else None
+            )
+
+            # Process model name
+            model = kwargs.get("model", "gpt-4")
+            if model.endswith("-openai"):
+                model = model[:-7]  # Remove last 7 characters ("-openai")
+
+            # Create base parameters
+            params = {
+                "model": model,
+                "messages": formatted_messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_p": top_p
+            }
+
+            # Add optional parameters
+            if json_format:
+                params["response_format"] = {"type": "json_object"}
+
+            # Generate completion
+            response = client.chat.completions.create(**params)
+
+            # Release API credentials
+            self.release_credentials()
+
+            return response.choices[0].message.content
+
+        except Exception as e:
+            logger.error(f"OpenAI API error: {e}")
+            self.report_error()
+            raise
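The "-openai" routing suffix is stripped before the request, so the name sent to the API is the plain model id. A hedged sketch:

from llm_dialog_manager.clients import get_client

client = get_client("gpt-4o-openai")  # the suffix forces the OpenAI client
reply = client.completion(
    [{"role": "user", "content": "Hi"}],
    model="gpt-4o-openai",  # sent to the API as "gpt-4o"
)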
llm_dialog_manager/clients/x_client.py
@@ -0,0 +1,83 @@
+"""
+Client implementation for X.AI (Grok) models
+"""
+import os
+import logging
+import openai
+from typing import List, Dict, Optional, Union
+
+from ..formatters import OpenAIFormatter
+from .base import BaseClient
+
+logger = logging.getLogger(__name__)
+
+class XClient(BaseClient):
+    """Client for X.AI (Grok) API"""
+
+    def _get_service_name(self) -> str:
+        return "x"
+
+    def completion(self, messages, max_tokens=1000, temperature=0.5,
+                   top_p=1.0, top_k=40, json_format=False, **kwargs):
+        """
+        Generate a completion using X.AI (Grok) API.
+
+        Args:
+            messages: List of message dictionaries
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Nucleus sampling parameter
+            top_k: Top-k sampling parameter
+            json_format: Whether to return JSON
+            **kwargs: Additional model-specific parameters
+
+        Returns:
+            Generated text response
+        """
+        try:
+            # Get API credentials if not set
+            self.get_credentials()
+
+            # Format messages for X.AI API using OpenAI formatter
+            # (X.AI uses the same message format as OpenAI)
+            formatter = OpenAIFormatter()
+            _, formatted_messages = formatter.format_messages(messages)
+
+            # Set default base URL if not already set
+            if not self.base_url:
+                self.base_url = "https://api.x.ai/v1"
+
+            # Initialize OpenAI client
+            client = openai.OpenAI(
+                api_key=self.api_key,
+                base_url=self.base_url
+            )
+
+            # Process model name
+            model = kwargs.get("model", "grok-3-beta")
+
+            # Create base parameters
+            params = {
+                "model": model,
+                "messages": formatted_messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_p": top_p,
+            }
+
+            # Add optional parameters
+            if json_format:
+                params["response_format"] = {"type": "json_object"}
+
+            # Generate completion using OpenAI client
+            response = client.chat.completions.create(**params)
+
+            # Release API credentials
+            self.release_credentials()
+
+            return response.choices[0].message.content
+
+        except Exception as e:
+            logger.error(f"X.AI API error: {e}")
+            self.report_error()
+            raise
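XClient is effectively the OpenAI SDK pointed at X.AI's endpoint; when no base_url is configured it falls back to https://api.x.ai/v1. A hedged sketch, key illustrative:

from llm_dialog_manager.clients import XClient

client = XClient(api_key="xai-...")  # base_url defaults to https://api.x.ai/v1
reply = client.completion(
    [{"role": "user", "content": "Hello Grok"}],
    model="grok-3-beta",
)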
llm_dialog_manager/formatters/__init__.py
@@ -0,0 +1,27 @@
+"""
+Message formatters for different LLM services.
+"""
+
+from .base import BaseMessageFormatter
+from .anthropic import AnthropicFormatter
+from .gemini import GeminiFormatter
+from .openai import OpenAIFormatter
+from .x import XFormatter
+
+def get_formatter(model_name):
+    """Factory method to get the appropriate formatter for a given model.
+
+    Args:
+        model_name: Name of the model (e.g., "claude-3-opus", "gemini-1.5-pro")
+
+    Returns:
+        An instance of the appropriate formatter class
+    """
+    if "claude" in model_name:
+        return AnthropicFormatter()
+    elif "gemini" in model_name:
+        return GeminiFormatter()
+    elif "grok" in model_name:
+        return XFormatter()
+    else:  # Default to OpenAI formatter
+        return OpenAIFormatter()
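Every formatter returns a (system_message, formatted) pair, so callers can split the system prompt out for APIs that take it separately. A hedged sketch:

from llm_dialog_manager.formatters import get_formatter

messages = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Summarize the release."},
]

system, formatted = get_formatter("claude-3-opus").format_messages(messages)
# system    -> "You are concise."
# formatted -> [{"role": "user", "content": "Summarize the release."}]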
llm_dialog_manager/formatters/anthropic.py
@@ -0,0 +1,76 @@
+"""
+Message formatter for Anthropic Claude models
+"""
+import io
+import base64
+from typing import List, Dict, Union, Optional
+from PIL import Image
+
+from .base import BaseMessageFormatter
+
+class AnthropicFormatter(BaseMessageFormatter):
+    """Formatter for Anthropic Claude API messages"""
+
+    def format_messages(self, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]) -> tuple:
+        """
+        Format messages for the Anthropic Claude API.
+
+        Args:
+            messages: List of message dictionaries in standard format
+
+        Returns:
+            A tuple containing (system_message, formatted_messages)
+            where system_message is extracted as a separate string
+        """
+        formatted = []
+        system_msg = ""
+
+        # Extract system message if present
+        if messages and messages[0]["role"] == "system":
+            system_msg = messages[0]["content"]
+            messages = messages[1:]
+
+        for msg in messages:
+            content = msg["content"]
+            if isinstance(content, str):
+                formatted.append({"role": msg["role"], "content": content})
+            elif isinstance(content, list):
+                # Combine content blocks into a single message
+                combined_content = []
+                for block in content:
+                    if isinstance(block, str):
+                        combined_content.append({"type": "text", "text": block})
+                    elif isinstance(block, Image.Image):
+                        # For Claude, convert PIL.Image to base64
+                        buffered = io.BytesIO()
+                        block.save(buffered, format="PNG")
+                        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                        combined_content.append({
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": "image/png",
+                                "data": image_base64
+                            }
+                        })
+                    elif isinstance(block, dict):
+                        if block.get("type") == "image_url":
+                            combined_content.append({
+                                "type": "image",
+                                "source": {
+                                    "type": "url",
+                                    "url": block["image_url"]["url"]
+                                }
+                            })
+                        elif block.get("type") == "image_base64":
+                            combined_content.append({
+                                "type": "image",
+                                "source": {
+                                    "type": "base64",
+                                    "media_type": block["image_base64"]["media_type"],
+                                    "data": block["image_base64"]["data"]
+                                }
+                            })
+                formatted.append({"role": msg["role"], "content": combined_content})
+
+        return system_msg, formatted
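A hedged sketch of the content-block conversion: a PIL image in a message list becomes a base64 "image" block alongside the text.

from PIL import Image
from llm_dialog_manager.formatters import AnthropicFormatter

img = Image.new("RGB", (8, 8))  # illustrative stand-in for a real image
system, formatted = AnthropicFormatter().format_messages([
    {"role": "user", "content": ["What is in this image?", img]},
])
# formatted[0]["content"] is now:
# [{"type": "text", "text": "What is in this image?"},
#  {"type": "image", "source": {"type": "base64",
#                               "media_type": "image/png", "data": "..."}}]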
llm_dialog_manager/formatters/base.py
@@ -0,0 +1,23 @@
+"""
+Base message formatter interface
+"""
+from abc import ABC, abstractmethod
+from typing import List, Dict, Union, Optional
+from PIL import Image
+
+class BaseMessageFormatter(ABC):
+    """Base class for message formatters"""
+
+    @abstractmethod
+    def format_messages(self, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]) -> tuple:
+        """
+        Format messages for the specific LLM API.
+
+        Args:
+            messages: List of message dictionaries in standard format
+
+        Returns:
+            A tuple containing (system_message, formatted_messages)
+            where system_message can be None if not used by the API
+        """
+        pass
llm_dialog_manager/formatters/gemini.py
@@ -0,0 +1,59 @@
+"""
+Message formatter for Google Gemini models
+"""
+from typing import List, Dict, Union, Optional
+from PIL import Image
+
+from .base import BaseMessageFormatter
+
+class GeminiFormatter(BaseMessageFormatter):
+    """Formatter for Google Gemini API messages"""
+
+    def format_messages(self, messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]) -> tuple:
+        """
+        Format messages for the Google Gemini API.
+
+        Args:
+            messages: List of message dictionaries in standard format
+
+        Returns:
+            A tuple containing (system_message, formatted_messages)
+            where system_message is extracted separately
+        """
+        system_msg = None
+        formatted = []
+
+        for msg in messages:
+            # Extract system message if present
+            if msg["role"] == "system":
+                system_msg = msg["content"] if isinstance(msg["content"], str) else str(msg["content"])
+                continue
+
+            content = msg["content"]
+            if isinstance(content, str):
+                formatted.append({"role": msg["role"], "parts": [content]})
+            elif isinstance(content, list):
+                parts = []
+                for block in content:
+                    if isinstance(block, str):
+                        parts.append(block)
+                    elif isinstance(block, Image.Image):
+                        parts.append(block)  # Gemini supports PIL.Image directly
+                    elif isinstance(block, dict):
+                        if block.get("type") == "image_url":
+                            parts.append({
+                                "inline_data": {
+                                    "mime_type": "image/jpeg",
+                                    "data": block["image_url"]["url"]
+                                }
+                            })
+                        elif block.get("type") == "image_base64":
+                            parts.append({
+                                "inline_data": {
+                                    "mime_type": block["image_base64"]["media_type"],
+                                    "data": block["image_base64"]["data"]
+                                }
+                            })
+                formatted.append({"role": msg["role"], "parts": parts})
+
+        return system_msg, formatted
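A hedged sketch of the Gemini message shape: "parts" lists instead of "content", with PIL images passed through untouched.

from PIL import Image
from llm_dialog_manager.formatters import GeminiFormatter

img = Image.new("RGB", (8, 8))  # illustrative
system, formatted = GeminiFormatter().format_messages([
    {"role": "system", "content": "Describe images."},
    {"role": "user", "content": ["What is this?", img]},
])
# system    -> "Describe images."
# formatted -> [{"role": "user", "parts": ["What is this?", <PIL.Image>]}]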