praisonaiagents 0.0.156__py3-none-any.whl → 0.0.157__py3-none-any.whl
This diff compares publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- praisonaiagents/llm/llm.py +84 -0
- praisonaiagents/llm/openai_client.py +76 -14
- {praisonaiagents-0.0.156.dist-info → praisonaiagents-0.0.157.dist-info}/METADATA +1 -1
- {praisonaiagents-0.0.156.dist-info → praisonaiagents-0.0.157.dist-info}/RECORD +6 -6
- {praisonaiagents-0.0.156.dist-info → praisonaiagents-0.0.157.dist-info}/WHEEL +0 -0
- {praisonaiagents-0.0.156.dist-info → praisonaiagents-0.0.157.dist-info}/top_level.txt +0 -0
praisonaiagents/llm/llm.py
CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Union, Literal, Callable
 from pydantic import BaseModel
 import time
 import json
+import xml.etree.ElementTree as ET
 from ..main import (
     display_error,
     display_tool_call,
@@ -281,6 +282,8 @@ class LLM:
         self.min_reflect = extra_settings.get('min_reflect', 1)
         self.reasoning_steps = extra_settings.get('reasoning_steps', False)
         self.metrics = extra_settings.get('metrics', False)
+        # Auto-detect XML tool format for known models, or allow manual override
+        self.xml_tool_format = extra_settings.get('xml_tool_format', 'auto')
 
         # Token tracking
         self.last_token_metrics: Optional[TokenMetrics] = None
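For orientation, a hedged sketch of how a caller might exercise the new setting. Only the xml_tool_format key and its 'auto'/boolean values come from this diff; the constructor call below and the way keyword arguments reach extra_settings are assumptions.

```python
# Hypothetical usage; only the xml_tool_format key is confirmed by this diff.
from praisonaiagents.llm.llm import LLM

llm_auto = LLM(model="qwen2.5-72b-instruct")                     # default 'auto': Qwen detected, XML parsing on
llm_forced = LLM(model="my/custom-model", xml_tool_format=True)  # force XML parsing for an unknown model
llm_off = LLM(model="qwen2.5-7b", xml_tool_format=False)         # explicit opt-out
```

Note that even with the setting off, the response parser added later in this diff still attempts XML extraction whenever the raw response contains a literal <tool_call> tag.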
@@ -359,6 +362,25 @@ class LLM:
 
         return False
 
+    def _is_qwen_provider(self) -> bool:
+        """Detect if this is a Qwen provider"""
+        if not self.model:
+            return False
+
+        # Check for Qwen patterns in model name
+        model_lower = self.model.lower()
+        return any(pattern in model_lower for pattern in ["qwen", "qwen2", "qwen2.5"])
+
+    def _supports_xml_tool_format(self) -> bool:
+        """Check if the model should use XML tool format"""
+        if self.xml_tool_format == 'auto':
+            # Auto-detect based on known models that use XML format
+            return self._is_qwen_provider()
+        elif self.xml_tool_format in [True, 'true', 'True']:
+            return True
+        else:
+            return False
+
     def _generate_ollama_tool_summary(self, tool_results: List[Any], response_text: str) -> Optional[str]:
         """
         Generate a summary from tool results for Ollama to prevent infinite loops.
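Because detection is a plain substring check, it is easy to see which model strings will match. A standalone mirror of the check above; the model names are illustrative:

```python
def matches_qwen(model: str) -> bool:
    # Same substring test as _is_qwen_provider
    model_lower = model.lower()
    return any(pattern in model_lower for pattern in ["qwen", "qwen2", "qwen2.5"])

for name in ["Qwen/Qwen2.5-7B-Instruct", "openrouter/qwen-turbo", "gpt-4o-mini"]:
    print(f"{name}: {matches_qwen(name)}")   # True, True, False
```

Since "qwen" is a substring of both "qwen2" and "qwen2.5", the first pattern alone decides the outcome; the longer patterns are redundant but harmless.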
@@ -658,6 +680,10 @@ class LLM:
         if any(self.model.startswith(prefix) for prefix in ["gemini-", "gemini/"]):
             return True
 
+        # Models with XML tool format support streaming with tools
+        if self._supports_xml_tool_format():
+            return True
+
         # For other providers, default to False to be safe
         # This ensures we make a single non-streaming call rather than risk
         # missing tool calls or making duplicate calls
@@ -1427,6 +1453,64 @@ class LLM:
                             except (json.JSONDecodeError, KeyError) as e:
                                 logging.debug(f"Could not parse Ollama tool call from response: {e}")
 
+                # Parse tool calls from XML format in response text
+                # Try for known XML models first, or fallback for any model that might output XML
+                if not tool_calls and response_text and formatted_tools:
+                    # Check if this model is known to use XML format, or try as fallback
+                    should_try_xml = (self._supports_xml_tool_format() or
+                                      # Fallback: try XML if response contains XML-like tool call tags
+                                      '<tool_call>' in response_text)
+
+                    if should_try_xml:
+                        tool_calls = []
+
+                        # Try proper XML parsing first
+                        try:
+                            # Wrap in root element if multiple tool_call tags exist
+                            xml_content = f"<root>{response_text}</root>"
+                            root = ET.fromstring(xml_content)
+                            tool_call_elements = root.findall('.//tool_call')
+
+                            for idx, element in enumerate(tool_call_elements):
+                                if element.text:
+                                    try:
+                                        tool_json = json.loads(element.text.strip())
+                                        if isinstance(tool_json, dict) and "name" in tool_json:
+                                            tool_calls.append({
+                                                "id": f"tool_{iteration_count}_{idx}",
+                                                "type": "function",
+                                                "function": {
+                                                    "name": tool_json["name"],
+                                                    "arguments": json.dumps(tool_json.get("arguments", {}))
+                                                }
+                                            })
+                                    except (json.JSONDecodeError, KeyError) as e:
+                                        logging.debug(f"Could not parse tool call JSON: {e}")
+                                        continue
+                        except ET.ParseError:
+                            # Fallback to regex if XML parsing fails
+                            tool_call_pattern = r'<tool_call>\s*(\{(?:[^{}]|{[^{}]*})*\})\s*</tool_call>'
+                            matches = re.findall(tool_call_pattern, response_text, re.DOTALL)
+
+                            for idx, match in enumerate(matches):
+                                try:
+                                    tool_json = json.loads(match.strip())
+                                    if isinstance(tool_json, dict) and "name" in tool_json:
+                                        tool_calls.append({
+                                            "id": f"tool_{iteration_count}_{idx}",
+                                            "type": "function",
+                                            "function": {
+                                                "name": tool_json["name"],
+                                                "arguments": json.dumps(tool_json.get("arguments", {}))
+                                            }
+                                        })
+                                except (json.JSONDecodeError, KeyError) as e:
+                                    logging.debug(f"Could not parse XML tool call: {e}")
+                                    continue
+
+                        if tool_calls:
+                            logging.debug(f"Parsed {len(tool_calls)} tool call(s) from XML format")
+
                 # For Ollama, if response is empty but we have tools, prompt for tool usage
                 if self._is_ollama_provider() and (not response_text or response_text.strip() == "") and formatted_tools and iteration_count == 0:
                     messages.append({
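The extraction logic above is self-contained enough to exercise in isolation. A minimal runnable sketch using only the standard library; the tool id scheme mirrors the diff's tool_{iteration}_{index} pattern, while the free-function boundary is a simplification rather than the class's actual API:

```python
import json
import re
import xml.etree.ElementTree as ET

def parse_xml_tool_calls(response_text, iteration_count=0):
    # Collect the JSON payloads of every <tool_call>...</tool_call> block.
    try:
        # Wrapping in a root element lets multiple tool_call tags parse as one document.
        root = ET.fromstring(f"<root>{response_text}</root>")
        payloads = [el.text for el in root.findall('.//tool_call') if el.text]
    except ET.ParseError:
        # Regex fallback for responses that are not well-formed XML overall.
        pattern = r'<tool_call>\s*(\{(?:[^{}]|{[^{}]*})*\})\s*</tool_call>'
        payloads = re.findall(pattern, response_text, re.DOTALL)

    tool_calls = []
    for idx, payload in enumerate(payloads):
        try:
            tool_json = json.loads(payload.strip())
        except json.JSONDecodeError:
            continue
        if isinstance(tool_json, dict) and "name" in tool_json:
            tool_calls.append({
                "id": f"tool_{iteration_count}_{idx}",
                "type": "function",
                "function": {
                    "name": tool_json["name"],
                    "arguments": json.dumps(tool_json.get("arguments", {})),
                },
            })
    return tool_calls

print(parse_xml_tool_calls('<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'))
```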
praisonaiagents/llm/openai_client.py
CHANGED
@@ -230,19 +230,34 @@ class OpenAIClient:
                 f"(e.g., 'http://localhost:1234/v1') and you can use a placeholder API key by setting OPENAI_API_KEY='{LOCAL_SERVER_API_KEY_PLACEHOLDER}'"
             )
 
-        # Initialize
-        self._sync_client =
+        # Initialize clients lazily
+        self._sync_client = None
         self._async_client = None
 
         # Set up logging
         self.logger = logging.getLogger(__name__)
 
-        # Initialize console
-        self.
+        # Initialize console lazily
+        self._console = None
+
+        # Cache for formatted tools and fixed schemas
+        self._formatted_tools_cache = {}
+        self._fixed_schema_cache = {}
+        self._max_cache_size = 100
+
+    @property
+    def console(self):
+        """Lazily initialize Rich Console only when needed."""
+        if self._console is None:
+            from rich.console import Console
+            self._console = Console()
+        return self._console
 
     @property
     def sync_client(self) -> OpenAI:
-        """Get the synchronous OpenAI client."""
+        """Get the synchronous OpenAI client (lazy initialization)."""
+        if self._sync_client is None:
+            self._sync_client = OpenAI(api_key=self.api_key, base_url=self.base_url)
         return self._sync_client
 
     @property
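Both the console and the SDK clients now follow the same lazy-property idiom: construction cost (and the rich import) is deferred until first access, so merely instantiating the client stays cheap. A distilled sketch of the pattern, assuming rich is installed; the class name is illustrative:

```python
class LazyConsoleHolder:
    def __init__(self):
        # Nothing heavy happens at construction time.
        self._console = None

    @property
    def console(self):
        if self._console is None:
            # The import is paid only on first use.
            from rich.console import Console
            self._console = Console()
        return self._console

holder = LazyConsoleHolder()   # cheap: rich not imported yet
holder.console.print("ready")  # first access triggers import + construction
```

The same None-guard is why close() and aclose() further down gain explicit checks: a client that was never accessed is still None.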
@@ -350,6 +365,35 @@ class OpenAIClient:
 
         return fixed_schema
 
+    def _get_tools_cache_key(self, tools: List[Any]) -> str:
+        """Generate a cache key for tools."""
+        parts = []
+        for tool in tools:
+            if isinstance(tool, dict):
+                # For dict tools, use sorted JSON representation
+                parts.append(json.dumps(tool, sort_keys=True))
+            elif callable(tool):
+                # For functions, use module.name
+                parts.append(f"{tool.__module__}.{tool.__name__}")
+            elif isinstance(tool, str):
+                # For string tools, use as-is
+                parts.append(tool)
+            elif isinstance(tool, list):
+                # For lists, recursively process
+                subparts = []
+                for subtool in tool:
+                    if isinstance(subtool, dict):
+                        subparts.append(json.dumps(subtool, sort_keys=True))
+                    elif callable(subtool):
+                        subparts.append(f"{subtool.__module__}.{subtool.__name__}")
+                    else:
+                        subparts.append(str(subtool))
+                parts.append(f"[{','.join(subparts)}]")
+            else:
+                # For other types, use string representation
+                parts.append(str(tool))
+        return "|".join(parts)
+
     def format_tools(self, tools: Optional[List[Any]]) -> Optional[List[Dict]]:
         """
         Format tools for OpenAI API.
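A quick check of the key function's behavior, reproduced standalone with only the dict and callable branches, which is enough to show the intent:

```python
import json

def tools_cache_key(tools):
    # Simplified copy of _get_tools_cache_key for illustration.
    parts = []
    for tool in tools:
        if isinstance(tool, dict):
            parts.append(json.dumps(tool, sort_keys=True))
        elif callable(tool):
            parts.append(f"{tool.__module__}.{tool.__name__}")
        else:
            parts.append(str(tool))
    return "|".join(parts)

def get_weather(city: str) -> str:
    return f"sunny in {city}"

# sort_keys makes the key insensitive to dict key order,
# so logically identical tool lists hit the same cache entry.
a = tools_cache_key([{"type": "function", "name": "x"}, get_weather])
b = tools_cache_key([{"name": "x", "type": "function"}, get_weather])
assert a == b
print(a)
```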
@@ -370,6 +414,11 @@ class OpenAIClient:
         """
         if not tools:
             return None
+
+        # Check cache first
+        cache_key = self._get_tools_cache_key(tools)
+        if cache_key in self._formatted_tools_cache:
+            return self._formatted_tools_cache[cache_key]
 
         formatted_tools = []
         for tool in tools:
@@ -424,8 +473,13 @@ class OpenAIClient:
         except (TypeError, ValueError) as e:
             logging.error(f"Tools are not JSON serializable: {e}")
             return None
+
+        # Cache the result
+        result = formatted_tools if formatted_tools else None
+        if result is not None and len(self._formatted_tools_cache) < self._max_cache_size:
+            self._formatted_tools_cache[cache_key] = result
 
-        return
+        return result
 
     def _generate_tool_definition(self, func: Callable) -> Optional[Dict]:
         """Generate a tool definition from a callable function."""
@@ -546,7 +600,7 @@ class OpenAIClient:
         console = self.console
 
         # Create the response stream
-        response_stream = self.
+        response_stream = self.sync_client.chat.completions.create(
             model=model,
             messages=messages,
             temperature=temperature,
@@ -723,7 +777,7 @@ class OpenAIClient:
             params["tool_choice"] = tool_choice
 
         try:
-            return self.
+            return self.sync_client.chat.completions.create(**params)
         except Exception as e:
             self.logger.error(f"Error creating completion: {e}")
             raise
@@ -1173,7 +1227,7 @@ class OpenAIClient:
         while iteration_count < max_iterations:
             try:
                 # Create streaming response
-                response_stream = self.
+                response_stream = self.sync_client.chat.completions.create(
                     model=model,
                     messages=messages,
                     temperature=temperature,
@@ -1298,7 +1352,7 @@ class OpenAIClient:
             Parsed response according to the response_format
         """
         try:
-            response = self.
+            response = self.sync_client.beta.chat.completions.parse(
                 model=model,
                 messages=messages,
                 temperature=temperature,
@@ -1346,14 +1400,14 @@ class OpenAIClient:
 
     def close(self):
         """Close the OpenAI clients."""
-        if hasattr(self._sync_client, 'close'):
+        if self._sync_client and hasattr(self._sync_client, 'close'):
             self._sync_client.close()
         if self._async_client and hasattr(self._async_client, 'close'):
             self._async_client.close()
 
     async def aclose(self):
         """Asynchronously close the OpenAI clients."""
-        if hasattr(self._sync_client, 'close'):
+        if self._sync_client and hasattr(self._sync_client, 'close'):
             await asyncio.to_thread(self._sync_client.close)
         if self._async_client and hasattr(self._async_client, 'aclose'):
             await self._async_client.aclose()
@@ -1361,6 +1415,7 @@ class OpenAIClient:
 
 # Global client instance (similar to main.py pattern)
 _global_client = None
+_global_client_params = None
 
 def get_openai_client(api_key: Optional[str] = None, base_url: Optional[str] = None) -> OpenAIClient:
     """
@@ -1373,9 +1428,16 @@ def get_openai_client(api_key: Optional[str] = None, base_url: Optional[str] = None) -> OpenAIClient:
     Returns:
         OpenAIClient instance
     """
-    global _global_client
+    global _global_client, _global_client_params
+
+    # Normalize parameters for comparison
+    normalized_api_key = api_key or os.getenv("OPENAI_API_KEY")
+    normalized_base_url = base_url
+    current_params = (normalized_api_key, normalized_base_url)
 
-    if
+    # Only create new client if parameters changed or first time
+    if _global_client is None or _global_client_params != current_params:
         _global_client = OpenAIClient(api_key=api_key, base_url=base_url)
+        _global_client_params = current_params
 
     return _global_client
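The result is a parameter-aware singleton: repeated calls with the same credentials return the same client, while a change in either parameter rebuilds it. A hedged sketch of the expected behavior; the module path is taken from this diff, and the key values are placeholders:

```python
from praisonaiagents.llm.openai_client import get_openai_client

c1 = get_openai_client(api_key="sk-test", base_url="http://localhost:1234/v1")
c2 = get_openai_client(api_key="sk-test", base_url="http://localhost:1234/v1")
assert c1 is c2            # identical params: cached instance reused

c3 = get_openai_client(api_key="sk-other", base_url="http://localhost:1234/v1")
assert c3 is not c1        # changed params: a new client is constructed
```

Because the api_key is normalized against OPENAI_API_KEY, passing the key explicitly or relying on the environment variable yields the same cache entry.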
{praisonaiagents-0.0.156.dist-info → praisonaiagents-0.0.157.dist-info}/RECORD
CHANGED
@@ -21,10 +21,10 @@ praisonaiagents/knowledge/__init__.py,sha256=xL1Eh-a3xsHyIcU4foOWF-JdWYIYBALJH9b
 praisonaiagents/knowledge/chunking.py,sha256=G6wyHa7_8V0_7VpnrrUXbEmUmptlT16ISJYaxmkSgmU,7678
 praisonaiagents/knowledge/knowledge.py,sha256=tog38b0SjFMoLuFBo0M1zHl9Dzzxa9YRv9FO7OZSpns,30587
 praisonaiagents/llm/__init__.py,sha256=SqdU1pRqPrR6jZeWYyDeTvmZKCACywk0v4P0k5Fuowk,1107
-praisonaiagents/llm/llm.py,sha256=
+praisonaiagents/llm/llm.py,sha256=C4C1xrR_qgInbgF1I-YhgPLI1C1YYI-5u3vn6Gp8sVc,184239
 praisonaiagents/llm/model_capabilities.py,sha256=cxOvZcjZ_PIEpUYKn3S2FMyypfOSfbGpx4vmV7Y5vhI,3967
 praisonaiagents/llm/model_router.py,sha256=Jy2pShlkLxqXF3quz-MRB3-6L9vaUSgUrf2YJs_Tsg0,13995
-praisonaiagents/llm/openai_client.py,sha256=
+praisonaiagents/llm/openai_client.py,sha256=Qn4z_ld8IYe-R8yKDRuek_4CP8lCJz2blJIRTm-mfDg,59882
 praisonaiagents/mcp/__init__.py,sha256=ibbqe3_7XB7VrIcUcetkZiUZS1fTVvyMy_AqCSFG8qc,240
 praisonaiagents/mcp/mcp.py,sha256=ChaSwLCcFBB9b8eNuj0DoKbK1EqpyF1T_7xz0FX-5-A,23264
 praisonaiagents/mcp/mcp_http_stream.py,sha256=TDFWMJMo8VqLXtXCW73REpmkU3t9n7CAGMa9b4dhI-c,23366
@@ -67,7 +67,7 @@ praisonaiagents/tools/xml_tools.py,sha256=iYTMBEk5l3L3ryQ1fkUnNVYK-Nnua2Kx2S0dxN
 praisonaiagents/tools/yaml_tools.py,sha256=uogAZrhXV9O7xvspAtcTfpKSQYL2nlOTvCQXN94-G9A,14215
 praisonaiagents/tools/yfinance_tools.py,sha256=s2PBj_1v7oQnOobo2fDbQBACEHl61ftG4beG6Z979ZE,8529
 praisonaiagents/tools/train/data/generatecot.py,sha256=H6bNh-E2hqL5MW6kX3hqZ05g9ETKN2-kudSjiuU_SD8,19403
-praisonaiagents-0.0.
-praisonaiagents-0.0.
-praisonaiagents-0.0.
-praisonaiagents-0.0.
+praisonaiagents-0.0.157.dist-info/METADATA,sha256=Ww1hB8QIFxzqu7UXFL7B5OscEFYl8UIGPnR8FuiUFLU,2146
+praisonaiagents-0.0.157.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+praisonaiagents-0.0.157.dist-info/top_level.txt,sha256=_HsRddrJ23iDx5TTqVUVvXG2HeHBL5voshncAMDGjtA,16
+praisonaiagents-0.0.157.dist-info/RECORD,,
{praisonaiagents-0.0.156.dist-info → praisonaiagents-0.0.157.dist-info}/WHEEL
File without changes
{praisonaiagents-0.0.156.dist-info → praisonaiagents-0.0.157.dist-info}/top_level.txt
File without changes