aiecs 1.7.17__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- aiecs/__init__.py +1 -1
- aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +5 -1
- aiecs/application/knowledge_graph/retrieval/query_intent_classifier.py +7 -5
- aiecs/config/config.py +3 -0
- aiecs/domain/agent/hybrid_agent.py +93 -12
- aiecs/domain/agent/knowledge_aware_agent.py +3 -2
- aiecs/domain/agent/llm_agent.py +2 -0
- aiecs/llm/callbacks/custom_callbacks.py +9 -4
- aiecs/llm/client_factory.py +14 -6
- aiecs/llm/clients/base_client.py +45 -4
- aiecs/llm/clients/googleai_client.py +105 -4
- aiecs/llm/clients/openai_client.py +12 -0
- aiecs/llm/clients/openai_compatible_mixin.py +42 -2
- aiecs/llm/clients/openrouter_client.py +272 -0
- aiecs/llm/clients/vertex_client.py +79 -5
- aiecs/llm/clients/xai_client.py +41 -3
- aiecs/llm/protocols.py +19 -1
- aiecs/llm/utils/image_utils.py +179 -0
- aiecs/main.py +2 -2
- aiecs/tools/task_tools/scraper_tool.py +39 -2
- {aiecs-1.7.17.dist-info → aiecs-1.8.4.dist-info}/METADATA +4 -2
- {aiecs-1.7.17.dist-info → aiecs-1.8.4.dist-info}/RECORD +26 -24
- {aiecs-1.7.17.dist-info → aiecs-1.8.4.dist-info}/WHEEL +0 -0
- {aiecs-1.7.17.dist-info → aiecs-1.8.4.dist-info}/entry_points.txt +0 -0
- {aiecs-1.7.17.dist-info → aiecs-1.8.4.dist-info}/licenses/LICENSE +0 -0
- {aiecs-1.7.17.dist-info → aiecs-1.8.4.dist-info}/top_level.txt +0 -0
aiecs/__init__.py
CHANGED
@@ -5,7 +5,7 @@ A powerful Python middleware framework for building AI-powered applications
 with tool orchestration, task execution, and multi-provider LLM support.
 """
 
-__version__ = "1.7.17"
+__version__ = "1.8.4"
 __author__ = "AIECS Team"
 __email__ = "iretbl@gmail.com"
aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py
CHANGED

@@ -162,7 +162,7 @@ class LLMEntityExtractor(EntityExtractor):
         Args:
             text: Input text to extract entities from
             entity_types: Optional filter for specific entity types
-            **kwargs: Additional parameters (e.g., custom prompt, examples)
+            **kwargs: Additional parameters (e.g., custom prompt, examples, context)
 
         Returns:
             List of extracted Entity objects

@@ -174,6 +174,9 @@ class LLMEntityExtractor(EntityExtractor):
         if not text or not text.strip():
             raise ValueError("Input text cannot be empty")
 
+        # Extract context from kwargs if provided
+        context = kwargs.get("context")
+
         # Build extraction prompt
         prompt = self._build_extraction_prompt(text, entity_types)
 
@@ -189,6 +192,7 @@ class LLMEntityExtractor(EntityExtractor):
                 model=self.model,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
+                context=context,
             )
         # Otherwise use LLM manager with provider
         else:
aiecs/application/knowledge_graph/retrieval/query_intent_classifier.py
CHANGED

@@ -6,7 +6,7 @@ Uses a lightweight LLM to determine the best retrieval approach based on query c
 """
 
 import logging
-from typing import Optional, Dict, TYPE_CHECKING
+from typing import Optional, Dict, Any, TYPE_CHECKING
 from aiecs.application.knowledge_graph.retrieval.strategy_types import RetrievalStrategy
 
 if TYPE_CHECKING:

@@ -58,12 +58,13 @@ class QueryIntentClassifier:
         self.enable_caching = enable_caching
         self._cache: Dict[str, RetrievalStrategy] = {}
 
-    async def classify_intent(self, query: str) -> RetrievalStrategy:
+    async def classify_intent(self, query: str, context: Optional[Dict[str, Any]] = None) -> RetrievalStrategy:
         """
         Classify query intent and return optimal retrieval strategy.
 
         Args:
             query: Query string to classify
+            context: Optional context dictionary for tracking/observability
 
         Returns:
             RetrievalStrategy enum value

@@ -82,7 +83,7 @@ class QueryIntentClassifier:
         # Use LLM classification if client is available
         if self.llm_client is not None:
             try:
-                strategy = await self._classify_with_llm(query)
+                strategy = await self._classify_with_llm(query, context)
             except Exception as e:
                 logger.warning(f"LLM classification failed: {e}, falling back to rule-based")
                 strategy = self._classify_with_rules(query)

@@ -96,7 +97,7 @@ class QueryIntentClassifier:
 
         return strategy
 
-    async def _classify_with_llm(self, query: str) -> RetrievalStrategy:
+    async def _classify_with_llm(self, query: str, context: Optional[Dict[str, Any]] = None) -> RetrievalStrategy:
         """
         Classify query using LLM.
 

@@ -127,11 +128,12 @@ Respond with ONLY the strategy name (e.g., "MULTI_HOP"). No explanation needed."
         if self.llm_client is None:
             # Fallback to rule-based classification if no LLM client
             return self._classify_with_rules(query)
-
+
         response = await self.llm_client.generate_text(
             messages=messages,
             temperature=0.0,  # Deterministic classification
             max_tokens=20,  # Short response
+            context=context,
         )
 
         # Parse response
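Both knowledge-graph changes above thread a new optional `context` dict from the public entry points (the extractor's `**kwargs` and `classify_intent`) into `llm_client.generate_text`. A minimal sketch of a caller taking advantage of this, assuming an already-constructed classifier (its constructor is not part of this diff); the context keys mirror the examples in the base_client docstring added by this release and are illustrative only:

from typing import Any, Dict

from aiecs.application.knowledge_graph.retrieval.query_intent_classifier import QueryIntentClassifier
from aiecs.application.knowledge_graph.retrieval.strategy_types import RetrievalStrategy

async def classify_with_tracing(
    classifier: QueryIntentClassifier, query: str, request_id: str
) -> RetrievalStrategy:
    # Illustrative metadata; forwarded unchanged to llm_client.generate_text
    context: Dict[str, Any] = {"request_id": request_id, "user_id": "u-123"}
    return await classifier.classify_intent(query, context=context)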
aiecs/config/config.py
CHANGED
@@ -47,6 +47,9 @@ class Settings(BaseSettings):
     google_cse_id: str = Field(default="", alias="GOOGLE_CSE_ID")
     xai_api_key: str = Field(default="", alias="XAI_API_KEY")
     grok_api_key: str = Field(default="", alias="GROK_API_KEY")  # Backward compatibility
+    openrouter_api_key: str = Field(default="", alias="OPENROUTER_API_KEY")
+    openrouter_http_referer: str = Field(default="", alias="OPENROUTER_HTTP_REFERER")
+    openrouter_x_title: str = Field(default="", alias="OPENROUTER_X_TITLE")
 
     # LLM Models Configuration
     llm_models_config_path: str = Field(
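The three new settings back the OpenRouter client added in this release; the referer/title values correspond to OpenRouter's optional attribution headers. A minimal sketch, assuming the usual pydantic-settings behaviour of resolving fields by alias and that the remaining settings have defaults or are already present in the environment:

import os

# Environment variable names come straight from the Field aliases in the diff
os.environ["OPENROUTER_API_KEY"] = "sk-or-..."                   # placeholder key
os.environ["OPENROUTER_HTTP_REFERER"] = "https://example.com"    # optional attribution header
os.environ["OPENROUTER_X_TITLE"] = "My App"                      # optional attribution header

from aiecs.config.config import Settings

settings = Settings()
print(settings.openrouter_api_key, settings.openrouter_http_referer, settings.openrouter_x_title)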
aiecs/domain/agent/hybrid_agent.py
CHANGED

@@ -242,7 +242,7 @@ class HybridAgent(BaseAIAgent):
         config: AgentConfiguration,
         description: Optional[str] = None,
         version: str = "1.0.0",
-        max_iterations: int =
+        max_iterations: Optional[int] = None,
         config_manager: Optional["ConfigManagerProtocol"] = None,
         checkpointer: Optional["CheckpointerProtocol"] = None,
         context_engine: Optional[Any] = None,

@@ -262,7 +262,7 @@ class HybridAgent(BaseAIAgent):
             config: Agent configuration
             description: Optional description
             version: Agent version
-            max_iterations: Maximum ReAct iterations
+            max_iterations: Maximum ReAct iterations (if None, uses config.max_iterations)
             config_manager: Optional configuration manager for dynamic config
             checkpointer: Optional checkpointer for state persistence
             context_engine: Optional context engine for persistent storage
@@ -316,7 +316,17 @@ class HybridAgent(BaseAIAgent):
 
         # Store LLM client reference (from BaseAIAgent or local)
         self.llm_client = self._llm_client if self._llm_client else llm_client
-
+
+        # Use config.max_iterations if constructor parameter is None
+        # This makes max_iterations consistent with max_tokens (both configurable via config)
+        # If max_iterations is explicitly provided, it takes precedence over config
+        if max_iterations is None:
+            # Use config value (defaults to 10 if not set in config)
+            self._max_iterations = config.max_iterations
+        else:
+            # Constructor parameter explicitly provided, use it
+            self._max_iterations = max_iterations
+
         self._system_prompt: Optional[str] = None
         self._conversation_history: List[LLMMessage] = []
         self._tool_schemas: List[Dict[str, Any]] = []

@@ -456,6 +466,24 @@ class HybridAgent(BaseAIAgent):
             agent_id=self.agent_id,
         )
 
+        # Extract images from task dict and merge into context
+        task_images = task.get("images")
+        if task_images:
+            # Merge images from task into context
+            # If context already has images, combine them
+            if "images" in context:
+                existing_images = context["images"]
+                if isinstance(existing_images, list) and isinstance(task_images, list):
+                    context["images"] = existing_images + task_images
+                elif isinstance(existing_images, list):
+                    context["images"] = existing_images + [task_images]
+                elif isinstance(task_images, list):
+                    context["images"] = [existing_images] + task_images
+                else:
+                    context["images"] = [existing_images, task_images]
+            else:
+                context["images"] = task_images
+
         # Transition to busy state
         self._transition_state(self.state.__class__.BUSY)
         self._current_task_id = task.get("task_id")
@@ -572,6 +600,24 @@ class HybridAgent(BaseAIAgent):
             }
             return
 
+        # Extract images from task dict and merge into context
+        task_images = task.get("images")
+        if task_images:
+            # Merge images from task into context
+            # If context already has images, combine them
+            if "images" in context:
+                existing_images = context["images"]
+                if isinstance(existing_images, list) and isinstance(task_images, list):
+                    context["images"] = existing_images + task_images
+                elif isinstance(existing_images, list):
+                    context["images"] = existing_images + [task_images]
+                elif isinstance(task_images, list):
+                    context["images"] = [existing_images] + task_images
+                else:
+                    context["images"] = [existing_images, task_images]
+            else:
+                context["images"] = task_images
+
         # Transition to busy state
         self._transition_state(self.state.__class__.BUSY)
         self._current_task_id = task.get("task_id")
@@ -712,6 +758,7 @@ class HybridAgent(BaseAIAgent):
             model=self._config.llm_model,
             temperature=self._config.temperature,
             max_tokens=self._config.max_tokens,
+            context=context,
             tools=tools,
             tool_choice="auto",
             return_chunks=True,  # Enable tool_calls accumulation

@@ -723,6 +770,7 @@ class HybridAgent(BaseAIAgent):
             model=self._config.llm_model,
             temperature=self._config.temperature,
             max_tokens=self._config.max_tokens,
+            context=context,
         )
 
         # Stream tokens and collect tool calls

@@ -1064,6 +1112,7 @@ class HybridAgent(BaseAIAgent):
             model=self._config.llm_model,
             temperature=self._config.temperature,
             max_tokens=self._config.max_tokens,
+            context=context,
             tools=tools,
             tool_choice="auto",
         )

@@ -1074,6 +1123,7 @@ class HybridAgent(BaseAIAgent):
             model=self._config.llm_model,
             temperature=self._config.temperature,
             max_tokens=self._config.max_tokens,
+            context=context,
         )
 
         thought_raw = response.content or ""
@@ -1330,6 +1380,9 @@ class HybridAgent(BaseAIAgent):
             )
         )
 
+        # Collect images from context to attach to task message
+        task_images = []
+
         # Add context if provided
         if context:
             # Special handling: if context contains 'history' as a list of messages,

@@ -1340,18 +1393,40 @@ class HybridAgent(BaseAIAgent):
                 for msg in history:
                     if isinstance(msg, dict) and "role" in msg and "content" in msg:
                         # Valid message format - add as separate message
-                        messages.append(
-                            LLMMessage(
-                                role=msg["role"],
-                                content=msg["content"],
+                        # Extract images if present
+                        msg_images = msg.get("images", [])
+                        if msg_images:
+                            messages.append(
+                                LLMMessage(
+                                    role=msg["role"],
+                                    content=msg["content"],
+                                    images=msg_images if isinstance(msg_images, list) else [msg_images],
+                                )
+                            )
+                        else:
+                            messages.append(
+                                LLMMessage(
+                                    role=msg["role"],
+                                    content=msg["content"],
+                                )
                             )
-                        )
                     elif isinstance(msg, LLMMessage):
-                        # Already an LLMMessage instance
+                        # Already an LLMMessage instance (may already have images)
                         messages.append(msg)
 
-            # Format context (excluding history) as Additional Context
-            context_without_history = {k: v for k, v in context.items() if k != "history"}
+            # Extract images from context if present
+            context_images = context.get("images")
+            if context_images:
+                if isinstance(context_images, list):
+                    task_images.extend(context_images)
+                else:
+                    task_images.append(context_images)
+
+            # Format remaining context fields (excluding history and images) as Additional Context
+            context_without_history = {
+                k: v for k, v in context.items()
+                if k not in ("history", "images")
+            }
             if context_without_history:
                 context_str = self._format_context(context_without_history)
                 if context_str:

@@ -1367,7 +1442,13 @@ class HybridAgent(BaseAIAgent):
             f"Task: {task}\n\n"
             f"[Iteration 1/{self._max_iterations}, remaining: {self._max_iterations - 1}]"
         )
-        messages.append(LLMMessage(role="user", content=task_message))
+        messages.append(
+            LLMMessage(
+                role="user",
+                content=task_message,
+                images=task_images if task_images else [],
+            )
+        )
 
         return messages
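The two identical blocks added to hybrid_agent's task-execution paths normalise `task["images"]` into `context["images"]`, handling list/scalar combinations on either side before the images are attached to the task message. The same logic restated below as a standalone function for illustration (not an aiecs API call):

def merge_task_images(context: dict, task_images) -> None:
    # Mirror of the merge added in the diff: combine task images with any
    # images already present in the context, promoting scalars to lists.
    if not task_images:
        return
    if "images" in context:
        existing = context["images"]
        if isinstance(existing, list) and isinstance(task_images, list):
            context["images"] = existing + task_images
        elif isinstance(existing, list):
            context["images"] = existing + [task_images]
        elif isinstance(task_images, list):
            context["images"] = [existing] + task_images
        else:
            context["images"] = [existing, task_images]
    else:
        context["images"] = task_images

ctx = {"images": ["https://example.com/a.png"]}
merge_task_images(ctx, "data:image/png;base64,iVBORw0...")
print(ctx["images"])  # ['https://example.com/a.png', 'data:image/png;base64,iVBORw0...']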
aiecs/domain/agent/knowledge_aware_agent.py
CHANGED

@@ -95,7 +95,7 @@ class KnowledgeAwareAgent(HybridAgent):
         graph_store: Optional[GraphStore] = None,
         description: Optional[str] = None,
         version: str = "1.0.0",
-        max_iterations: int =
+        max_iterations: Optional[int] = None,
         enable_graph_reasoning: bool = True,
         config_manager: Optional["ConfigManagerProtocol"] = None,
         checkpointer: Optional["CheckpointerProtocol"] = None,

@@ -118,7 +118,7 @@ class KnowledgeAwareAgent(HybridAgent):
             graph_store: Optional knowledge graph store
             description: Optional description
             version: Agent version
-            max_iterations: Maximum ReAct iterations
+            max_iterations: Maximum ReAct iterations (if None, uses config.max_iterations)
             enable_graph_reasoning: Whether to enable graph reasoning capabilities
             config_manager: Optional configuration manager for dynamic config
             checkpointer: Optional checkpointer for state persistence

@@ -745,6 +745,7 @@ Use graph reasoning proactively when questions involve:
             model=self._config.llm_model,
             temperature=self._config.temperature,
             max_tokens=self._config.max_tokens,
+            context=context,
         )
 
         thought = response.content
aiecs/domain/agent/llm_agent.py
CHANGED
@@ -376,6 +376,7 @@ class LLMAgent(BaseAIAgent):
             model=self._config.llm_model,
             temperature=self._config.temperature,
             max_tokens=self._config.max_tokens,
+            context=context,
         )
 
         # Extract result

@@ -513,6 +514,7 @@ class LLMAgent(BaseAIAgent):
             model=self._config.llm_model,
             temperature=self._config.temperature,
             max_tokens=self._config.max_tokens,
+            context=context,
         ):
             output_tokens.append(token)
             yield {
aiecs/llm/callbacks/custom_callbacks.py
CHANGED

@@ -33,7 +33,9 @@ class RedisTokenCallbackHandler(CustomAsyncCallbackHandler):
         self.start_time = time.time()
         self.messages = messages
 
-        logger.info(f"[Callback] LLM call started for user '{self.user_id}' with {len(messages)} messages")
+        # Defensive check for None messages
+        message_count = len(messages) if messages is not None else 0
+        logger.info(f"[Callback] LLM call started for user '{self.user_id}' with {message_count} messages")
 
     async def on_llm_end(self, response: dict, **kwargs: Any) -> None:
         """Triggered when LLM call ends successfully"""

@@ -93,8 +95,8 @@ class DetailedRedisTokenCallbackHandler(CustomAsyncCallbackHandler):
         self.start_time = time.time()
         self.messages = messages
 
-        # Estimate input token count
-        self.prompt_tokens = self._estimate_prompt_tokens(messages)
+        # Estimate input token count with None check
+        self.prompt_tokens = self._estimate_prompt_tokens(messages) if messages else 0
 
         logger.info(f"[DetailedCallback] LLM call started for user '{self.user_id}' with estimated {self.prompt_tokens} prompt tokens")
 

@@ -144,7 +146,10 @@ class DetailedRedisTokenCallbackHandler(CustomAsyncCallbackHandler):
 
     def _estimate_prompt_tokens(self, messages: List[dict]) -> int:
         """Estimate token count for input messages"""
-        total_chars = sum(len(msg.get("content", "")) for msg in messages)
+        if not messages:
+            return 0
+        # Use `or ""` to handle both missing key AND None value
+        total_chars = sum(len(msg.get("content") or "") for msg in messages)
         # Rough estimation: 4 characters ≈ 1 token
         return total_chars // 4
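The callback changes make message handling defensive: a `None` message list no longer breaks `on_llm_start`, and the token estimate treats a missing or `None` content value as empty text. The estimator restated below for illustration, using the same rough heuristic of total characters // 4 (about 4 characters per token):

from typing import List, Optional

def estimate_prompt_tokens(messages: Optional[List[dict]]) -> int:
    if not messages:
        return 0
    # `or ""` covers both a missing "content" key and an explicit None value
    total_chars = sum(len(msg.get("content") or "") for msg in messages)
    return total_chars // 4  # rough estimate: 4 characters per token

print(estimate_prompt_tokens(None))                                     # 0
print(estimate_prompt_tokens([{"role": "user", "content": None}]))      # 0
print(estimate_prompt_tokens([{"role": "user", "content": "x" * 40}]))  # 10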
aiecs/llm/client_factory.py
CHANGED
@@ -7,6 +7,7 @@ from .clients.openai_client import OpenAIClient
 from .clients.vertex_client import VertexAIClient
 from .clients.googleai_client import GoogleAIClient
 from .clients.xai_client import XAIClient
+from .clients.openrouter_client import OpenRouterClient
 from .clients.openai_compatible_mixin import StreamChunk
 from .callbacks.custom_callbacks import CustomAsyncCallbackHandler

@@ -21,6 +22,7 @@ class AIProvider(str, Enum):
     VERTEX = "Vertex"
     GOOGLEAI = "GoogleAI"
     XAI = "xAI"
+    OPENROUTER = "OpenRouter"
 
 
 class LLMClientFactory:

@@ -133,6 +135,8 @@ class LLMClientFactory:
             return GoogleAIClient()
         elif provider == AIProvider.XAI:
             return XAIClient()
+        elif provider == AIProvider.OPENROUTER:
+            return OpenRouterClient()
         else:
             raise ValueError(f"Unsupported provider: {provider}")
 

@@ -262,14 +266,16 @@ class LLMClientManager:
         final_provider = context_provider or provider or AIProvider.OPENAI
         final_model = context_model or model
 
-        # Convert string prompt to messages format
-        if isinstance(messages, str):
+        # Convert string prompt to messages format and handle None
+        if messages is None:
+            messages = []
+        elif isinstance(messages, str):
             messages = [LLMMessage(role="user", content=messages)]
 
         # Execute on_llm_start callbacks
         if callbacks:
             # Convert LLMMessage objects to dictionaries for callbacks
-            messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
+            messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages] if messages else []
             for callback in callbacks:
                 try:
                     await callback.on_llm_start(

@@ -372,14 +378,16 @@ class LLMClientManager:
         final_provider = context_provider or provider or AIProvider.OPENAI
         final_model = context_model or model
 
-        # Convert string prompt to messages format
-        if isinstance(messages, str):
+        # Convert string prompt to messages format and handle None
+        if messages is None:
+            messages = []
+        elif isinstance(messages, str):
             messages = [LLMMessage(role="user", content=messages)]
 
         # Execute on_llm_start callbacks
         if callbacks:
             # Convert LLMMessage objects to dictionaries for callbacks
-            messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
+            messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages] if messages else []
             for callback in callbacks:
                 try:
                     await callback.on_llm_start(
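With the new enum member and factory branch, OpenRouter becomes selectable like any other provider. A sketch under stated assumptions: the no-argument construction mirrors the factory branch above, and the note about where credentials come from is an inference from the config.py change, not something this diff shows:

from aiecs.llm.client_factory import AIProvider
from aiecs.llm.clients.openrouter_client import OpenRouterClient

print(AIProvider.OPENROUTER.value)  # "OpenRouter"

# Same no-arg construction the factory uses; configuration presumably comes
# from the new OPENROUTER_* settings (assumption, not confirmed by this diff).
client = OpenRouterClient()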
aiecs/llm/clients/base_client.py
CHANGED
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Any, Optional, List, AsyncGenerator
-from dataclasses import dataclass
+from typing import Dict, Any, Optional, List, AsyncGenerator, Union
+from dataclasses import dataclass, field
 import logging
 
 logger = logging.getLogger(__name__)

@@ -35,6 +35,7 @@ class LLMMessage:
     Attributes:
         role: Message role - "system", "user", "assistant", or "tool"
         content: Text content of the message (None when using tool calls)
+        images: List of image sources (URLs, base64 data URIs, or file paths) for vision support
         tool_calls: Tool call information for assistant messages
         tool_call_id: Tool call ID for tool response messages
         cache_control: Cache control marker for prompt caching support

@@ -42,6 +43,7 @@ class LLMMessage:
 
     role: str  # "system", "user", "assistant", "tool"
     content: Optional[str] = None  # None when using tool calls
+    images: List[Union[str, Dict[str, Any]]] = field(default_factory=list)  # Image sources for vision support
     tool_calls: Optional[List[Dict[str, Any]]] = None  # For assistant messages with tool calls
     tool_call_id: Optional[str] = None  # For tool messages
     cache_control: Optional[CacheControl] = None  # Cache control for prompt caching

@@ -163,9 +165,28 @@ class BaseLLMClient(ABC):
         model: Optional[str] = None,
         temperature: float = 0.7,
         max_tokens: Optional[int] = None,
+        context: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> LLMResponse:
-        """
+        """
+        Generate text using the provider's API.
+
+        Args:
+            messages: List of conversation messages
+            model: Model name (optional, uses default if not provided)
+            temperature: Sampling temperature (0.0 to 1.0)
+            max_tokens: Maximum tokens to generate
+            context: Optional context dictionary containing metadata such as:
+                - user_id: User identifier for tracking/billing
+                - tenant_id: Tenant identifier for multi-tenant setups
+                - request_id: Request identifier for tracing
+                - session_id: Session identifier
+                - Any other custom metadata for observability or middleware
+            **kwargs: Additional provider-specific parameters
+
+        Returns:
+            LLMResponse with generated text and metadata
+        """
 
     @abstractmethod
     async def stream_text(

@@ -174,9 +195,28 @@ class BaseLLMClient(ABC):
         model: Optional[str] = None,
         temperature: float = 0.7,
         max_tokens: Optional[int] = None,
+        context: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> AsyncGenerator[str, None]:
-        """
+        """
+        Stream text generation using the provider's API.
+
+        Args:
+            messages: List of conversation messages
+            model: Model name (optional, uses default if not provided)
+            temperature: Sampling temperature (0.0 to 1.0)
+            max_tokens: Maximum tokens to generate
+            context: Optional context dictionary containing metadata such as:
+                - user_id: User identifier for tracking/billing
+                - tenant_id: Tenant identifier for multi-tenant setups
+                - request_id: Request identifier for tracing
+                - session_id: Session identifier
+                - Any other custom metadata for observability or middleware
+            **kwargs: Additional provider-specific parameters
+
+        Yields:
+            Text tokens as they are generated
+        """
 
     @abstractmethod
     async def close(self):

@@ -221,6 +261,7 @@ class BaseLLMClient(ABC):
             LLMMessage(
                 role=msg.role,
                 content=msg.content,
+                images=msg.images,
                 tool_calls=msg.tool_calls,
                 tool_call_id=msg.tool_call_id,
                 cache_control=CacheControl(type="ephemeral"),