aiecs 1.7.6__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of aiecs has been flagged as possibly problematic.
- aiecs/__init__.py +1 -1
- aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +5 -1
- aiecs/application/knowledge_graph/retrieval/query_intent_classifier.py +7 -5
- aiecs/config/config.py +3 -0
- aiecs/config/tool_config.py +55 -19
- aiecs/domain/agent/base_agent.py +79 -0
- aiecs/domain/agent/hybrid_agent.py +552 -175
- aiecs/domain/agent/knowledge_aware_agent.py +3 -2
- aiecs/domain/agent/llm_agent.py +2 -0
- aiecs/domain/agent/models.py +10 -0
- aiecs/domain/agent/tools/schema_generator.py +17 -4
- aiecs/llm/callbacks/custom_callbacks.py +9 -4
- aiecs/llm/client_factory.py +20 -7
- aiecs/llm/clients/base_client.py +50 -5
- aiecs/llm/clients/google_function_calling_mixin.py +46 -88
- aiecs/llm/clients/googleai_client.py +183 -9
- aiecs/llm/clients/openai_client.py +12 -0
- aiecs/llm/clients/openai_compatible_mixin.py +42 -2
- aiecs/llm/clients/openrouter_client.py +272 -0
- aiecs/llm/clients/vertex_client.py +385 -22
- aiecs/llm/clients/xai_client.py +41 -3
- aiecs/llm/protocols.py +19 -1
- aiecs/llm/utils/image_utils.py +179 -0
- aiecs/main.py +2 -2
- aiecs/tools/docs/document_creator_tool.py +143 -2
- aiecs/tools/docs/document_parser_tool.py +9 -4
- aiecs/tools/docs/document_writer_tool.py +179 -0
- aiecs/tools/task_tools/image_tool.py +49 -14
- aiecs/tools/task_tools/scraper_tool.py +39 -2
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/METADATA +4 -2
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/RECORD +35 -33
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/WHEEL +0 -0
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/entry_points.txt +0 -0
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/licenses/LICENSE +0 -0
- {aiecs-1.7.6.dist-info → aiecs-1.8.4.dist-info}/top_level.txt +0 -0
aiecs/llm/clients/googleai_client.py:

```diff
@@ -1,6 +1,8 @@
+import json
 import logging
 import os
-from typing import Optional, List, AsyncGenerator
+import base64
+from typing import Optional, List, AsyncGenerator, Dict, Any
 
 from google import genai
 from google.genai import types
```

(The removed `typing` import line is reconstructed from context: the old signatures below use `Optional`, `List`, and `AsyncGenerator`, and the new line adds `Dict` and `Any`.)
```diff
@@ -13,6 +15,7 @@ from aiecs.llm.clients.base_client import (
     RateLimitError,
 )
 from aiecs.config.config import get_settings
+from aiecs.llm.utils.image_utils import parse_image_source, ImageContent
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -45,14 +48,145 @@ class GoogleAIClient(BaseLLMClient):
     def _convert_messages_to_contents(
         self, messages: List[LLMMessage]
     ) -> List[types.Content]:
-        """
+        """
+        Convert LLMMessage list to Google GenAI Content objects.
+
+        This properly handles multi-turn conversations including
+        function/tool responses for Google AI Function Calling.
+
+        Args:
+            messages: List of LLMMessage objects (system messages should be filtered out)
+
+        Returns:
+            List of Content objects for Google AI API
+        """
         contents = []
+
         for msg in messages:
-            #
-
-
-
-
+            # Handle tool/function responses (role="tool")
+            if msg.role == "tool":
+                # Google AI expects function responses as user messages with FunctionResponse parts
+                func_name = msg.tool_call_id or "unknown_function"
+
+                # Parse content as the function response
+                try:
+                    if msg.content and msg.content.strip().startswith('{'):
+                        response_data = json.loads(msg.content)
+                    else:
+                        response_data = {"result": msg.content}
+                except json.JSONDecodeError:
+                    response_data = {"result": msg.content}
+
+                # Create FunctionResponse part
+                func_response_part = types.Part.from_function_response(
+                    name=func_name,
+                    response=response_data
+                )
+
+                contents.append(types.Content(
+                    role="user",  # Function responses are sent as "user" role
+                    parts=[func_response_part]
+                ))
+
+            # Handle assistant messages with tool calls
+            elif msg.role == "assistant" and msg.tool_calls:
+                parts = []
+                if msg.content:
+                    parts.append(types.Part(text=msg.content))
+
+                # Add images if present
+                if msg.images:
+                    for image_source in msg.images:
+                        image_content = parse_image_source(image_source)
+
+                        if image_content.is_url():
+                            # For URLs, use inline_data with downloaded content
+                            # Note: Google AI SDK may support URL directly, but we'll use base64 for compatibility
+                            try:
+                                import urllib.request
+                                with urllib.request.urlopen(image_content.get_url()) as response:
+                                    image_bytes = response.read()
+                                parts.append(types.Part.from_bytes(
+                                    data=image_bytes,
+                                    mime_type=image_content.mime_type
+                                ))
+                            except Exception as e:
+                                logger.warning(f"Failed to download image from URL: {e}")
+                        else:
+                            # Convert to bytes for inline_data
+                            base64_data = image_content.get_base64_data()
+                            image_bytes = base64.b64decode(base64_data)
+                            parts.append(types.Part.from_bytes(
+                                data=image_bytes,
+                                mime_type=image_content.mime_type
+                            ))
+
+                for tool_call in msg.tool_calls:
+                    func = tool_call.get("function", {})
+                    func_name = func.get("name", "")
+                    func_args = func.get("arguments", "{}")
+
+                    # Parse arguments
+                    try:
+                        args_dict = json.loads(func_args) if isinstance(func_args, str) else func_args
+                    except json.JSONDecodeError:
+                        args_dict = {}
+
+                    # Create FunctionCall part using types.FunctionCall
+                    # Note: types.Part.from_function_call() may not exist in google.genai
+                    # Use FunctionCall type directly
+                    function_call = types.FunctionCall(
+                        name=func_name,
+                        args=args_dict
+                    )
+                    parts.append(types.Part(function_call=function_call))
+
+                contents.append(types.Content(
+                    role="model",
+                    parts=parts
+                ))
+
+            # Handle regular messages (user, assistant without tool_calls)
+            else:
+                role = "model" if msg.role == "assistant" else msg.role
+                parts = []
+
+                # Add text content if present
+                if msg.content:
+                    parts.append(types.Part(text=msg.content))
+
+                # Add images if present
+                if msg.images:
+                    for image_source in msg.images:
+                        image_content = parse_image_source(image_source)
+
+                        if image_content.is_url():
+                            # Download URL and convert to bytes
+                            try:
+                                import urllib.request
+                                with urllib.request.urlopen(image_content.get_url()) as response:
+                                    image_bytes = response.read()
+                                parts.append(types.Part.from_bytes(
+                                    data=image_bytes,
+                                    mime_type=image_content.mime_type
+                                ))
+                            except Exception as e:
+                                logger.warning(f"Failed to download image from URL: {e}")
+                        else:
+                            # Convert to bytes for inline_data
+                            base64_data = image_content.get_base64_data()
+                            image_bytes = base64.b64decode(base64_data)
+                            parts.append(types.Part.from_bytes(
+                                data=image_bytes,
+                                mime_type=image_content.mime_type
+                            ))
+
+                if parts:
+                    contents.append(types.Content(
+                        role=role,
+                        parts=parts
+                    ))
+
         return contents
 
     async def generate_text(
```

(The content of the five removed loop-body lines did not survive extraction and is left blank.)
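To make the new conversion path concrete, here is a minimal sketch of a tool-calling round trip as the method above would see it. The `LLMMessage` keyword arguments are inferred from the fields this diff reads (`role`, `content`, `tool_calls`, `tool_call_id`); the actual constructor is not shown here, and no-arg construction of `GoogleAIClient` is assumed.

```python
# Hedged sketch: LLMMessage field names are inferred from this diff, and
# GoogleAIClient() is assumed to be constructible without arguments.
from aiecs.llm.clients.base_client import LLMMessage
from aiecs.llm.clients.googleai_client import GoogleAIClient

history = [
    LLMMessage(role="user", content="What's the weather in Paris?"),
    # Assistant turn that requested a tool call (OpenAI-style dict)
    LLMMessage(
        role="assistant",
        content=None,
        tool_calls=[{"function": {"name": "get_weather",
                                  "arguments": '{"city": "Paris"}'}}],
    ),
    # Tool result; note tool_call_id doubles as the function name here
    LLMMessage(role="tool", tool_call_id="get_weather",
               content='{"temperature_c": 18}'),
]

contents = GoogleAIClient()._convert_messages_to_contents(history)
# Expected shape per the code above:
#   Content(role="user",  parts=[Part(text=...)])
#   Content(role="model", parts=[Part(function_call=FunctionCall(...))])
#   Content(role="user",  parts=[Part(function_response=...)])
```

Note the quirk the code itself exposes: `msg.tool_call_id` is reused as the function name, so callers must set it to the called function's name rather than an opaque call ID.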
```diff
@@ -61,10 +195,30 @@
         model: Optional[str] = None,
         temperature: float = 0.7,
         max_tokens: Optional[int] = None,
+        context: Optional[Dict[str, Any]] = None,
         system_instruction: Optional[str] = None,
         **kwargs,
     ) -> LLMResponse:
-        """
+        """
+        Generate text using Google AI (google.genai SDK).
+
+        Args:
+            messages: List of conversation messages
+            model: Model name (optional, uses default if not provided)
+            temperature: Sampling temperature (0.0 to 1.0)
+            max_tokens: Maximum tokens to generate
+            context: Optional context dictionary containing metadata such as:
+                - user_id: User identifier for tracking/billing
+                - tenant_id: Tenant identifier for multi-tenant setups
+                - request_id: Request identifier for tracing
+                - session_id: Session identifier
+                - Any other custom metadata for observability or middleware
+            system_instruction: System instruction for the model
+            **kwargs: Additional provider-specific parameters
+
+        Returns:
+            LLMResponse with generated text and metadata
+        """
         client = self._init_google_ai()
 
         # Get model name from config if not provided
```
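Nothing in this hunk shows the client acting on `context`; it is accepted and documented as pass-through metadata. A hypothetical call site, with all values invented for illustration:

```python
# Hypothetical values; the `context` keys follow the docstring above.
import asyncio

from aiecs.llm.clients.base_client import LLMMessage
from aiecs.llm.clients.googleai_client import GoogleAIClient


async def demo() -> None:
    client = GoogleAIClient()  # assumed no-arg construction
    response = await client.generate_text(
        messages=[LLMMessage(role="user", content="Summarize this ticket.")],
        temperature=0.3,
        context={
            "user_id": "u-123",
            "tenant_id": "t-9",
            "request_id": "req-abc",
            "session_id": "sess-42",
        },
    )
    print(response.content)  # content field per the Returns section


asyncio.run(demo())
```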
```diff
@@ -164,10 +318,30 @@
         model: Optional[str] = None,
         temperature: float = 0.7,
         max_tokens: Optional[int] = None,
+        context: Optional[Dict[str, Any]] = None,
         system_instruction: Optional[str] = None,
         **kwargs,
     ) -> AsyncGenerator[str, None]:
-        """
+        """
+        Stream text generation using Google AI (google.genai SDK).
+
+        Args:
+            messages: List of conversation messages
+            model: Model name (optional, uses default if not provided)
+            temperature: Sampling temperature (0.0 to 1.0)
+            max_tokens: Maximum tokens to generate
+            context: Optional context dictionary containing metadata such as:
+                - user_id: User identifier for tracking/billing
+                - tenant_id: Tenant identifier for multi-tenant setups
+                - request_id: Request identifier for tracing
+                - session_id: Session identifier
+                - Any other custom metadata for observability or middleware
+            system_instruction: System instruction for the model
+            **kwargs: Additional provider-specific parameters
+
+        Yields:
+            Text tokens as they are generated
+        """
         client = self._init_google_ai()
 
         # Get model name from config if not provided
```
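`stream_text` gains the same `context` parameter; per its Yields section it still produces plain string tokens, so consumption is unchanged. A brief sketch under the same assumptions as above:

```python
import asyncio

from aiecs.llm.clients.base_client import LLMMessage
from aiecs.llm.clients.googleai_client import GoogleAIClient


async def demo() -> None:
    client = GoogleAIClient()  # assumed no-arg construction
    async for token in client.stream_text(
        messages=[LLMMessage(role="user", content="Write a haiku about diffs.")],
        context={"request_id": "req-xyz"},  # hypothetical metadata
    ):
        print(token, end="", flush=True)


asyncio.run(demo())
```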
aiecs/llm/clients/openai_client.py:

```diff
@@ -52,6 +52,7 @@ class OpenAIClient(BaseLLMClient, OpenAICompatibleFunctionCallingMixin):
         model: Optional[str] = None,
         temperature: float = 0.7,
         max_tokens: Optional[int] = None,
+        context: Optional[Dict[str, Any]] = None,
         functions: Optional[List[Dict[str, Any]]] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Any] = None,
```
```diff
@@ -65,6 +66,11 @@
             model: Model name (optional)
             temperature: Temperature for generation
             max_tokens: Maximum tokens to generate
+            context: Optional context dictionary containing metadata such as:
+                - user_id: User identifier for tracking/billing
+                - tenant_id: Tenant identifier for multi-tenant setups
+                - request_id: Request identifier for tracing
+                - session_id: Session identifier
             functions: List of function schemas (legacy format)
             tools: List of tool schemas (new format, recommended)
             tool_choice: Tool choice strategy ("auto", "none", or specific tool)
```
```diff
@@ -103,6 +109,7 @@
         model: Optional[str] = None,
         temperature: float = 0.7,
         max_tokens: Optional[int] = None,
+        context: Optional[Dict[str, Any]] = None,
         functions: Optional[List[Dict[str, Any]]] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Any] = None,
```
```diff
@@ -117,6 +124,11 @@
             model: Model name (optional)
             temperature: Temperature for generation
             max_tokens: Maximum tokens to generate
+            context: Optional context dictionary containing metadata such as:
+                - user_id: User identifier for tracking/billing
+                - tenant_id: Tenant identifier for multi-tenant setups
+                - request_id: Request identifier for tracing
+                - session_id: Session identifier
             functions: List of function schemas (legacy format)
             tools: List of tool schemas (new format, recommended)
             tool_choice: Tool choice strategy ("auto", "none", or specific tool)
```
aiecs/llm/clients/openai_compatible_mixin.py:

```diff
@@ -11,6 +11,7 @@ from dataclasses import dataclass
 from openai import AsyncOpenAI
 
 from .base_client import LLMMessage, LLMResponse
+from aiecs.llm.utils.image_utils import parse_image_source, ImageContent
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -49,7 +50,7 @@ class OpenAICompatibleFunctionCallingMixin:
 
     def _convert_messages_to_openai_format(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
         """
-        Convert LLMMessage list to OpenAI message format (support tool calls).
+        Convert LLMMessage list to OpenAI message format (support tool calls and vision).
 
         Args:
             messages: List of LLMMessage objects
```
```diff
@@ -60,8 +61,47 @@
         openai_messages = []
         for msg in messages:
             msg_dict: Dict[str, Any] = {"role": msg.role}
-
+
+            # Handle multimodal content (text + images)
+            if msg.images:
+                # Build content array with text and images
+                content_array = []
+
+                # Add text content if present
+                if msg.content:
+                    content_array.append({"type": "text", "text": msg.content})
+
+                # Add images
+                for image_source in msg.images:
+                    image_content = parse_image_source(image_source)
+
+                    if image_content.is_url():
+                        # Use URL directly
+                        content_array.append({
+                            "type": "image_url",
+                            "image_url": {
+                                "url": image_content.get_url(),
+                                "detail": image_content.detail,
+                            }
+                        })
+                    else:
+                        # Convert to base64 data URI
+                        base64_data = image_content.get_base64_data()
+                        mime_type = image_content.mime_type
+                        data_uri = f"data:{mime_type};base64,{base64_data}"
+                        content_array.append({
+                            "type": "image_url",
+                            "image_url": {
+                                "url": data_uri,
+                                "detail": image_content.detail,
+                            }
+                        })
+
+                msg_dict["content"] = content_array
+            elif msg.content is not None:
+                # Text-only content
                 msg_dict["content"] = msg.content
+
             if msg.tool_calls:
                 msg_dict["tool_calls"] = msg.tool_calls
             if msg.tool_call_id:
```
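For reference, this is roughly the OpenAI-style payload the branch above builds for a message carrying one URL image. The `detail` value is whatever `parse_image_source` assigns; the `"auto"` shown below is an assumption, not something this diff states, and no-arg construction of `OpenAIClient` is likewise assumed.

```python
from aiecs.llm.clients.base_client import LLMMessage
from aiecs.llm.clients.openai_client import OpenAIClient

msg = LLMMessage(
    role="user",
    content="What is in this picture?",
    images=["https://example.com/cat.png"],  # hypothetical URL
)
formatted = OpenAIClient()._convert_messages_to_openai_format([msg])
# formatted[0] should look like:
# {
#     "role": "user",
#     "content": [
#         {"type": "text", "text": "What is in this picture?"},
#         {"type": "image_url",
#          "image_url": {"url": "https://example.com/cat.png",
#                        "detail": "auto"}},  # default detail is assumed
#     ],
# }
```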
aiecs/llm/clients/openrouter_client.py (new file):

```diff
@@ -0,0 +1,272 @@
+from openai import AsyncOpenAI
+from aiecs.config.config import get_settings
+from aiecs.llm.clients.base_client import (
+    BaseLLMClient,
+    LLMMessage,
+    LLMResponse,
+    ProviderNotAvailableError,
+    RateLimitError,
+)
+from aiecs.llm.clients.openai_compatible_mixin import (
+    OpenAICompatibleFunctionCallingMixin,
+    StreamChunk,
+)
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+)
+import logging
+from typing import Dict, Optional, List, AsyncGenerator, cast, Any
+
+# Lazy import to avoid circular dependency
+
+
+def _get_config_loader():
+    """Lazy import of config loader to avoid circular dependency"""
+    from aiecs.llm.config import get_llm_config_loader
+
+    return get_llm_config_loader()
+
+
+logger = logging.getLogger(__name__)
+
+
+class OpenRouterClient(BaseLLMClient, OpenAICompatibleFunctionCallingMixin):
+    """OpenRouter provider client using OpenAI-compatible API"""
+
+    def __init__(self) -> None:
+        super().__init__("OpenRouter")
+        self.settings = get_settings()
+        self._openai_client: Optional[AsyncOpenAI] = None
+        self._model_map: Optional[Dict[str, str]] = None
+
+    def _get_openai_client(self) -> AsyncOpenAI:
+        """Lazy initialization of OpenAI client for OpenRouter"""
+        if not self._openai_client:
+            api_key = self._get_api_key()
+            self._openai_client = AsyncOpenAI(
+                api_key=api_key,
+                base_url="https://openrouter.ai/api/v1",
+                timeout=360.0,
+            )
+        return self._openai_client
+
+    def _get_api_key(self) -> str:
+        """Get API key from settings"""
+        api_key = getattr(self.settings, "openrouter_api_key", None)
+        if not api_key:
+            raise ProviderNotAvailableError("OpenRouter API key not configured. Set OPENROUTER_API_KEY.")
+        return api_key
+
+    def _get_model_map(self) -> Dict[str, str]:
+        """Get model mappings from configuration"""
+        if self._model_map is None:
+            try:
+                loader = _get_config_loader()
+                provider_config = loader.get_provider_config("OpenRouter")
+                if provider_config and provider_config.model_mappings:
+                    self._model_map = provider_config.model_mappings
+                else:
+                    self._model_map = {}
+            except Exception as e:
+                self.logger.warning(f"Failed to load model mappings from config: {e}")
+                self._model_map = {}
+        return self._model_map
+
+    def _get_extra_headers(self, **kwargs) -> Dict[str, str]:
+        """
+        Get extra headers for OpenRouter API.
+
+        Supports HTTP-Referer and X-Title headers from kwargs or settings.
+
+        Args:
+            **kwargs: May contain http_referer and x_title
+
+        Returns:
+            Dictionary with extra headers
+        """
+        extra_headers: Dict[str, str] = {}
+
+        # Get from kwargs first, then from settings
+        http_referer = kwargs.get("http_referer") or getattr(self.settings, "openrouter_http_referer", None)
+        x_title = kwargs.get("x_title") or getattr(self.settings, "openrouter_x_title", None)
+
+        if http_referer:
+            extra_headers["HTTP-Referer"] = http_referer
+        if x_title:
+            extra_headers["X-Title"] = x_title
+
+        return extra_headers
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type((Exception, RateLimitError)),
+    )
+    async def generate_text(
+        self,
+        messages: List[LLMMessage],
+        model: Optional[str] = None,
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: Optional[Any] = None,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate text using OpenRouter API via OpenAI library.
+
+        OpenRouter API is OpenAI-compatible, so it supports Function Calling and Vision.
+
+        Args:
+            messages: List of LLM messages
+            model: Model name (optional, uses default from config if not provided)
+            temperature: Temperature for generation
+            max_tokens: Maximum tokens to generate
+            functions: List of function schemas (legacy format)
+            tools: List of tool schemas (new format, recommended)
+            tool_choice: Tool choice strategy ("auto", "none", or specific tool)
+            http_referer: Optional HTTP-Referer header for OpenRouter rankings
+            x_title: Optional X-Title header for OpenRouter rankings
+            **kwargs: Additional arguments passed to OpenRouter API
+
+        Returns:
+            LLMResponse with content and optional function_call information
+        """
+        # Check API key availability
+        api_key = self._get_api_key()
+        if not api_key:
+            raise ProviderNotAvailableError("OpenRouter API key is not configured.")
+
+        client = self._get_openai_client()
+
+        # Get model name from config if not provided
+        selected_model = model or self._get_default_model() or "openai/gpt-4o"
+
+        # Get model mappings from config
+        model_map = self._get_model_map()
+        api_model = model_map.get(selected_model, selected_model)
+
+        # Extract extra headers from kwargs
+        extra_headers = self._get_extra_headers(**kwargs)
+
+        # Remove extra header kwargs to avoid passing them to API
+        kwargs_clean = {k: v for k, v in kwargs.items() if k not in ("http_referer", "x_title")}
+
+        # Add extra_headers to kwargs if present
+        if extra_headers:
+            kwargs_clean["extra_headers"] = extra_headers
+
+        try:
+            # Use mixin method for Function Calling support
+            response = await self._generate_text_with_function_calling(
+                client=client,
+                messages=messages,
+                model=api_model,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                functions=functions,
+                tools=tools,
+                tool_choice=tool_choice,
+                **kwargs_clean,
+            )
+
+            # Override provider and model name for OpenRouter
+            response.provider = self.provider_name
+            response.model = selected_model
+
+            return response
+
+        except Exception as e:
+            if "rate limit" in str(e).lower() or "429" in str(e):
+                raise RateLimitError(f"OpenRouter rate limit exceeded: {str(e)}")
+            logger.error(f"OpenRouter API error: {str(e)}")
+            raise
+
+    async def stream_text(  # type: ignore[override]
+        self,
+        messages: List[LLMMessage],
+        model: Optional[str] = None,
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: Optional[Any] = None,
+        return_chunks: bool = False,
+        **kwargs,
+    ) -> AsyncGenerator[Any, None]:
+        """
+        Stream text using OpenRouter API via OpenAI library.
+
+        OpenRouter API is OpenAI-compatible, so it supports Function Calling and Vision.
+
+        Args:
+            messages: List of LLM messages
+            model: Model name (optional, uses default from config if not provided)
+            temperature: Temperature for generation
+            max_tokens: Maximum tokens to generate
+            functions: List of function schemas (legacy format)
+            tools: List of tool schemas (new format, recommended)
+            tool_choice: Tool choice strategy ("auto", "none", or specific tool)
+            return_chunks: If True, returns StreamChunk objects with tool_calls info; if False, returns str tokens only
+            http_referer: Optional HTTP-Referer header for OpenRouter rankings
+            x_title: Optional X-Title header for OpenRouter rankings
+            **kwargs: Additional arguments passed to OpenRouter API
+
+        Yields:
+            str or StreamChunk: Text tokens as they are generated, or StreamChunk objects if return_chunks=True
+        """
+        # Check API key availability
+        api_key = self._get_api_key()
+        if not api_key:
+            raise ProviderNotAvailableError("OpenRouter API key is not configured.")
+
+        client = self._get_openai_client()
+
+        # Get model name from config if not provided
+        selected_model = model or self._get_default_model() or "openai/gpt-4o"
+
+        # Get model mappings from config
+        model_map = self._get_model_map()
+        api_model = model_map.get(selected_model, selected_model)
+
+        # Extract extra headers from kwargs
+        extra_headers = self._get_extra_headers(**kwargs)
+
+        # Remove extra header kwargs to avoid passing them to API
+        kwargs_clean = {k: v for k, v in kwargs.items() if k not in ("http_referer", "x_title")}
+
+        # Add extra_headers to kwargs if present
+        if extra_headers:
+            kwargs_clean["extra_headers"] = extra_headers
+
+        try:
+            # Use mixin method for Function Calling support
+            async for chunk in self._stream_text_with_function_calling(
+                client=client,
+                messages=messages,
+                model=api_model,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                functions=functions,
+                tools=tools,
+                tool_choice=tool_choice,
+                return_chunks=return_chunks,
+                **kwargs_clean,
+            ):
+                yield chunk
+
+        except Exception as e:
+            if "rate limit" in str(e).lower() or "429" in str(e):
+                raise RateLimitError(f"OpenRouter rate limit exceeded: {str(e)}")
+            logger.error(f"OpenRouter API streaming error: {str(e)}")
+            raise
+
+    async def close(self):
+        """Clean up resources"""
+        if self._openai_client:
+            await self._openai_client.close()
+            self._openai_client = None
```