vectara-agentic 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic.
- tests/__init__.py +1 -0
- tests/benchmark_models.py +547 -372
- tests/conftest.py +14 -12
- tests/endpoint.py +9 -5
- tests/run_tests.py +1 -0
- tests/test_agent.py +22 -9
- tests/test_agent_fallback_memory.py +4 -4
- tests/test_agent_memory_consistency.py +4 -4
- tests/test_agent_type.py +2 -0
- tests/test_api_endpoint.py +13 -13
- tests/test_bedrock.py +9 -1
- tests/test_fallback.py +18 -7
- tests/test_gemini.py +14 -40
- tests/test_groq.py +43 -1
- tests/test_openai.py +160 -0
- tests/test_private_llm.py +19 -6
- tests/test_react_error_handling.py +293 -0
- tests/test_react_memory.py +257 -0
- tests/test_react_streaming.py +135 -0
- tests/test_react_workflow_events.py +395 -0
- tests/test_return_direct.py +1 -0
- tests/test_serialization.py +58 -20
- tests/test_session_memory.py +11 -11
- tests/test_streaming.py +0 -44
- tests/test_together.py +75 -1
- tests/test_tools.py +3 -1
- tests/test_vectara_llms.py +2 -2
- tests/test_vhc.py +7 -2
- tests/test_workflow.py +17 -11
- vectara_agentic/_callback.py +79 -21
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +65 -27
- vectara_agentic/agent_core/serialization.py +5 -9
- vectara_agentic/agent_core/streaming.py +245 -64
- vectara_agentic/agent_core/utils/schemas.py +2 -2
- vectara_agentic/llm_utils.py +64 -15
- vectara_agentic/tools.py +88 -31
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/METADATA +133 -36
- vectara_agentic-0.4.4.dist-info/RECORD +59 -0
- vectara_agentic-0.4.2.dist-info/RECORD +0 -54
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.4.dist-info}/top_level.txt +0 -0
vectara_agentic/agent_core/streaming.py
CHANGED

@@ -7,12 +7,11 @@ for managing asynchronous agent interactions with proper synchronization.
 
 import asyncio
 import logging
-import uuid
 import json
 import traceback
+import uuid
 
 from typing import Callable, Any, Dict, AsyncIterator
-from collections import OrderedDict
 
 from llama_index.core.agent.workflow import (
     ToolCall,
@@ -20,58 +19,28 @@ from llama_index.core.agent.workflow import (
     AgentInput,
     AgentOutput,
 )
-from ..types import AgentResponse
+from ..types import AgentResponse, AgentStatusType
 
-class ToolEventTracker:
-    """
-    Tracks event IDs for tool calls to ensure consistent pairing of tool calls and outputs.
 
-
-    that related tool call and tool output events share the same event_id for proper
-    frontend grouping.
+def get_event_id(event) -> str:
     """
+    Get event ID from LlamaIndex event.
 
-
-
-        self.fallback_counter = 0  # For events without identifiable tool_ids
-
-    def get_event_id(self, event) -> str:
-        """
-        Get a consistent event ID for a tool event.
+    Args:
+        event: The event object from LlamaIndex
 
-
-
+    Returns:
+        str: Event ID from the event, or creates a new one if it does not exist
+    """
+    # Check for direct event_id first
+    if hasattr(event, "event_id") and event.event_id:
+        return event.event_id
 
-
-
-
-        # Try to get tool_id from the event first
-        tool_id = getattr(event, "tool_id", None)
-
-        # If we have a tool_id, use it directly (any format from any LLM provider)
-        if tool_id:
-            pass  # We already have tool_id, just use it
-        # If no tool_id, try to derive one from tool_name (for LlamaIndex events)
-        elif hasattr(event, "tool_name") and event.tool_name:
-            tool_id = f"{event.tool_name}_{self.fallback_counter}"
-            self.fallback_counter += 1
-        # If still no tool_id, create a generic one based on event type
-        else:
-            event_type = type(event).__name__
-            tool_id = f"{event_type.lower()}_{self.fallback_counter}"
-            self.fallback_counter += 1
-
-        # Get or create event_id for this tool_id
-        if tool_id not in self.event_ids:
-            self.event_ids[tool_id] = str(uuid.uuid4())
-
-        return self.event_ids[tool_id]
-
-    def clear_old_entries(self, max_entries: int = 100):
-        """Clear old entries to prevent unbounded memory growth."""
-        while len(self.event_ids) > max_entries // 2:
-            self.event_ids.popitem(last=False)  # Remove oldest entry
+    # Check for tool_id for tool-related events
+    if hasattr(event, "tool_id") and event.tool_id:
+        return event.tool_id
 
+    return str(uuid.uuid4())
 
 class StreamingResponseAdapter:
     """
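The net effect of this hunk is that the stateful `ToolEventTracker` class (an ordered map of tool IDs plus a fallback counter) is replaced by the stateless, module-level `get_event_id` helper added above. A minimal, self-contained sketch of the new lookup order, using a hypothetical `DummyEvent` stand-in for a LlamaIndex workflow event:

```python
import uuid
from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyEvent:
    """Hypothetical stand-in for a LlamaIndex workflow event."""
    event_id: Optional[str] = None
    tool_id: Optional[str] = None


def get_event_id(event) -> str:
    """Same lookup order as the new helper: event_id, then tool_id, then a fresh UUID."""
    if hasattr(event, "event_id") and event.event_id:
        return event.event_id
    if hasattr(event, "tool_id") and event.tool_id:
        return event.tool_id
    return str(uuid.uuid4())


print(get_event_id(DummyEvent(event_id="evt-1")))   # -> evt-1
print(get_event_id(DummyEvent(tool_id="call_42")))  # -> call_42
print(get_event_id(DummyEvent()))                   # -> a new UUID4 string
```

The trade-off is visible in the sketch: events that expose neither `event_id` nor `tool_id` no longer share a tracked ID; each such event now gets its own UUID.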
@@ -284,7 +253,6 @@ class FunctionCallingStreamHandler:
         self.prompt = prompt
         self.final_response_container = {"resp": None}
         self.stream_complete_event = asyncio.Event()
-        self.event_tracker = ToolEventTracker()
 
     async def process_stream_events(self) -> AsyncIterator[str]:
         """
@@ -299,16 +267,22 @@ class FunctionCallingStreamHandler:
         async for ev in self.handler.stream_events():
             # Store tool outputs for VHC regardless of progress callback
             if isinstance(ev, ToolCallResult):
-                if hasattr(self.agent_instance,
+                if hasattr(self.agent_instance, "_add_tool_output"):
                     # pylint: disable=W0212
-                    self.agent_instance._add_tool_output(
+                    self.agent_instance._add_tool_output(
+                        ev.tool_name, str(ev.tool_output)
+                    )
 
             # Handle progress callbacks if available
             if self.agent_instance.agent_progress_callback:
                 # Only track events that are actual tool-related events
                 if self._is_tool_related_event(ev):
-
-
+                    try:
+                        event_id = get_event_id(ev)
+                        await self._handle_progress_callback(ev, event_id)
+                    except ValueError as e:
+                        logging.warning(f"Skipping event due to missing ID: {e}")
+                        continue
 
             # Process streaming text events
             if hasattr(ev, "__class__") and "AgentStream" in str(ev.__class__):
@@ -335,16 +309,25 @@ class FunctionCallingStreamHandler:
         try:
             self.final_response_container["resp"] = await self.handler
         except Exception as e:
-
-
-
-
-
-
-
+            error_str = str(e).lower()
+            if "rate limit" in error_str or "429" in error_str:
+                logging.error(f"🔍 [RATE_LIMIT_ERROR] Rate limit exceeded: {e}")
+                self.final_response_container["resp"] = AgentResponse(
+                    response="Rate limit exceeded. Please try again later.",
+                    source_nodes=[],
+                    metadata={"error_type": "rate_limit", "original_error": str(e)},
+                )
+            else:
+                logging.error(f"🔍 [STREAM_ERROR] Error processing stream events: {e}")
+                logging.error(
+                    f"🔍 [STREAM_ERROR] Full traceback: {traceback.format_exc()}"
+                )
+                self.final_response_container["resp"] = AgentResponse(
+                    response="Response completion Error",
+                    source_nodes=[],
+                    metadata={"error_type": "general", "original_error": str(e)},
+                )
         finally:
-            # Clean up event tracker to prevent memory leaks
-            self.event_tracker.clear_old_entries()
             # Signal that stream processing is complete
             self.stream_complete_event.set()
 
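The rewritten `except` block classifies failures by inspecting the exception text, so rate-limit errors surface with a dedicated message and `metadata["error_type"]` instead of the generic completion error. A small sketch of that classification logic in isolation (the function name is mine, not the package's):

```python
def classify_stream_error(exc: Exception) -> str:
    """Mirror the new branching: '429'/'rate limit' text -> rate_limit, else general."""
    error_str = str(exc).lower()
    if "rate limit" in error_str or "429" in error_str:
        return "rate_limit"
    return "general"


assert classify_stream_error(RuntimeError("Error code: 429 Too Many Requests")) == "rate_limit"
assert classify_stream_error(RuntimeError("Rate limit exceeded, retry later")) == "rate_limit"
assert classify_stream_error(RuntimeError("connection reset by peer")) == "general"
print("classification matches the new branches")
```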
@@ -380,9 +363,6 @@ class FunctionCallingStreamHandler:
 
     async def _handle_progress_callback(self, event, event_id: str):
         """Handle progress callback events for different event types with proper context propagation."""
-        # Import here to avoid circular imports
-        from ..types import AgentStatusType
-
         try:
             if isinstance(event, ToolCall):
                 # Check if callback is async or sync
@@ -477,3 +457,204 @@ class FunctionCallingStreamHandler:
             metadata={},
             post_process_task=post_process_task,
         )
+
+
+class ReActStreamHandler:
+    """
+    Handles streaming for ReAct agents with proper event processing.
+
+    ReAct agents use a workflow-based approach and emit ToolCall/ToolCallResult events
+    that need to be captured and converted to progress callbacks.
+    """
+
+    def __init__(self, agent_instance, handler, prompt: str):
+        self.agent_instance = agent_instance
+        self.handler = handler
+        self.prompt = prompt
+        self.final_response_container = {"resp": None}
+        self.stream_complete_event = asyncio.Event()
+
+    async def process_stream_events(self) -> AsyncIterator[str]:
+        """
+        Process streaming events from ReAct workflow and yield text tokens.
+
+        Yields:
+            str: Text tokens from the streaming response
+        """
+        async for event in self.handler.stream_events():
+            # Store tool outputs for VHC regardless of progress callback
+            if isinstance(event, ToolCallResult):
+                if hasattr(self.agent_instance, "_add_tool_output"):
+                    # pylint: disable=W0212
+                    self.agent_instance._add_tool_output(
+                        event.tool_name, str(event.tool_output)
+                    )
+            # Handle progress callbacks if available - this is the key missing piece!
+            if self.agent_instance.agent_progress_callback:
+                # Only track events that are actual tool-related events
+                if self._is_tool_related_event(event):
+                    try:
+                        # Get event ID from LlamaIndex event
+                        event_id = get_event_id(event)
+
+                        # Handle different types of workflow events using same logic as achat method
+                        if isinstance(event, ToolCall):
+                            # Check if callback is async or sync
+                            if asyncio.iscoroutinefunction(
+                                self.agent_instance.agent_progress_callback
+                            ):
+                                await self.agent_instance.agent_progress_callback(
+                                    status_type=AgentStatusType.TOOL_CALL,
+                                    msg={
+                                        "tool_name": event.tool_name,
+                                        "arguments": json.dumps(event.tool_kwargs),
+                                    },
+                                    event_id=event_id,
+                                )
+                            else:
+                                self.agent_instance.agent_progress_callback(
+                                    status_type=AgentStatusType.TOOL_CALL,
+                                    msg={
+                                        "tool_name": event.tool_name,
+                                        "arguments": json.dumps(event.tool_kwargs),
+                                    },
+                                    event_id=event_id,
+                                )
+                        elif isinstance(event, ToolCallResult):
+                            # Check if callback is async or sync
+                            if asyncio.iscoroutinefunction(
+                                self.agent_instance.agent_progress_callback
+                            ):
+                                await self.agent_instance.agent_progress_callback(
+                                    status_type=AgentStatusType.TOOL_OUTPUT,
+                                    msg={
+                                        "tool_name": event.tool_name,
+                                        "content": str(event.tool_output),
+                                    },
+                                    event_id=event_id,
+                                )
+                            else:
+                                self.agent_instance.agent_progress_callback(
+                                    status_type=AgentStatusType.TOOL_OUTPUT,
+                                    msg={
+                                        "tool_name": event.tool_name,
+                                        "content": str(event.tool_output),
+                                    },
+                                    event_id=event_id,
+                                )
+                        elif isinstance(event, AgentInput):
+                            if asyncio.iscoroutinefunction(
+                                self.agent_instance.agent_progress_callback
+                            ):
+                                await self.agent_instance.agent_progress_callback(
+                                    status_type=AgentStatusType.AGENT_UPDATE,
+                                    msg={"content": f"Agent input: {event.input}"},
+                                    event_id=event_id,
+                                )
+                            else:
+                                self.agent_instance.agent_progress_callback(
+                                    status_type=AgentStatusType.AGENT_UPDATE,
+                                    msg={"content": f"Agent input: {event.input}"},
+                                    event_id=event_id,
+                                )
+                        elif isinstance(event, AgentOutput):
+                            if asyncio.iscoroutinefunction(
+                                self.agent_instance.agent_progress_callback
+                            ):
+                                await self.agent_instance.agent_progress_callback(
+                                    status_type=AgentStatusType.AGENT_UPDATE,
+                                    msg={"content": f"Agent output: {event.response}"},
+                                    event_id=event_id,
+                                )
+                            else:
+                                self.agent_instance.agent_progress_callback(
+                                    status_type=AgentStatusType.AGENT_UPDATE,
+                                    msg={"content": f"Agent output: {event.response}"},
+                                    event_id=event_id,
+                                )
+                    except ValueError as e:
+                        logging.warning(f"Skipping event due to missing ID: {e}")
+                        continue
+                    except Exception as e:
+                        logging.error(f"Exception in ReAct progress callback: {e}")
+                        logging.error(f"Traceback: {traceback.format_exc()}")
+                        # Continue execution despite callback errors
+
+            # For ReAct agents, we typically don't have streaming text like function calling
+            # ReAct usually processes in steps and then provides complete responses
+            # So we just yield empty strings to maintain streaming interface
+            yield ""
+
+        # When stream is done, await the handler to get the final response
+        try:
+            self.final_response_container["resp"] = await self.handler
+        except Exception as e:
+            logging.error(
+                f"🔍 [REACT_STREAM_ERROR] Error processing ReAct stream events: {e}"
+            )
+            logging.error(
+                f"🔍 [REACT_STREAM_ERROR] Full traceback: {traceback.format_exc()}"
+            )
+            self.final_response_container["resp"] = AgentResponse(
+                response="ReAct Response completion Error", source_nodes=[], metadata={}
+            )
+        finally:
+            # Signal that stream processing is complete
+            self.stream_complete_event.set()
+
+    def _is_tool_related_event(self, event) -> bool:
+        """
+        Determine if an event is actually tool-related and should be tracked.
+
+        This should only return True for events that represent actual tool calls or tool outputs,
+        not for streaming text deltas or other LLM response events.
+
+        Args:
+            event: The stream event to check
+
+        Returns:
+            bool: True if this event should be tracked for tool purposes
+        """
+        # Track explicit tool events from LlamaIndex workflow
+        if isinstance(event, (ToolCall, ToolCallResult)):
+            return True
+
+        has_tool_id = hasattr(event, "tool_id") and event.tool_id
+        has_delta = hasattr(event, "delta") and event.delta
+        has_tool_name = hasattr(event, "tool_name") and event.tool_name
+
+        # We're not seeing ToolCall/ToolCallResult events in the stream, so let's be more liberal
+        # but still avoid streaming deltas
+        if (has_tool_id or has_tool_name) and not has_delta:
+            return True
+
+        # Everything else (streaming deltas, agent outputs, workflow events, etc.)
+        # should NOT be tracked as tool events
+        return False
+
+    def create_streaming_response(
+        self, user_metadata: Dict[str, Any]
+    ) -> "StreamingResponseAdapter":
+        """
+        Create a StreamingResponseAdapter for ReAct agents with proper post-processing.
+
+        Args:
+            user_metadata: User metadata dictionary to update
+
+        Returns:
+            StreamingResponseAdapter: Configured streaming adapter
+        """
+        post_process_task = create_stream_post_processing_task(
+            self.stream_complete_event,
+            self.final_response_container,
+            self.prompt,
+            self.agent_instance,
+            user_metadata,
+        )
+
+        return StreamingResponseAdapter(
+            async_response_gen=self.process_stream_events,
+            response="",  # will be filled post-stream
+            metadata={},
+            post_process_task=post_process_task,
+        )
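For context, the new `ReActStreamHandler` forwards `ToolCall`, `ToolCallResult`, `AgentInput`, and `AgentOutput` events to the agent's `agent_progress_callback` using the keyword arguments visible above (`status_type`, `msg`, `event_id`). A sketch of a callback with a compatible shape; only the `AgentStatusType` import path and the argument names are taken from this diff, and the logging body is purely illustrative:

```python
from vectara_agentic.types import AgentStatusType


def my_progress_callback(status_type: AgentStatusType, msg: dict, event_id: str) -> None:
    """Example consumer for the progress events emitted by ReActStreamHandler."""
    if status_type == AgentStatusType.TOOL_CALL:
        print(f"[{event_id}] calling {msg['tool_name']}({msg['arguments']})")
    elif status_type == AgentStatusType.TOOL_OUTPUT:
        print(f"[{event_id}] {msg['tool_name']} -> {msg['content'][:80]}")
    else:  # AgentStatusType.AGENT_UPDATE
        print(f"[{event_id}] {msg.get('content', '')}")
```

Because the handler checks `asyncio.iscoroutinefunction`, an `async def` callback with the same signature would be awaited instead of called synchronously.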
vectara_agentic/agent_core/utils/schemas.py
CHANGED

@@ -78,8 +78,8 @@ def get_field_type(field_schema: dict) -> Any:
     # If only "items" is present (implies array by some conventions, but less standard)
     # Or if it's a schema with other keywords like 'properties' (implying object)
     # For simplicity, if no "type" or "anyOf" at this point, default to Any or add more specific handling.
-    # If 'properties' in field_schema
-    if "properties" in field_schema
+    # If 'properties' in field_schema, it's likely an object.
+    if "properties" in field_schema:
         # This path might need to reconstruct a nested Pydantic model if you encounter such schemas.
         # For now, treating as 'dict' or 'Any' might be a simpler placeholder.
         return dict  # Or Any, or more sophisticated object reconstruction.
vectara_agentic/llm_utils.py
CHANGED
@@ -18,7 +18,7 @@ from .agent_config import AgentConfig
 
 provider_to_default_model_name = {
     ModelProvider.OPENAI: "gpt-4.1-mini",
-    ModelProvider.ANTHROPIC: "claude-sonnet-4-
+    ModelProvider.ANTHROPIC: "claude-sonnet-4-0",
     ModelProvider.TOGETHER: "deepseek-ai/DeepSeek-V3",
     ModelProvider.GROQ: "openai/gpt-oss-20b",
     ModelProvider.BEDROCK: "us.anthropic.claude-sonnet-4-20250514-v1:0",
@@ -26,6 +26,41 @@ provider_to_default_model_name = {
     ModelProvider.GEMINI: "models/gemini-2.5-flash",
 }
 
+models_to_max_tokens = {
+    "gpt-5": 128000,
+    "gpt-4.1": 32768,
+    "gpt-4o": 16384,
+    "gpt-4.1-mini": 32768,
+    "claude-sonnet-4": 65536,
+    "deepseek-ai/deepseek-v3": 8192,
+    "models/gemini-2.5-flash": 65536,
+    "models/gemini-2.5-flash-lite": 65536,
+    "models/gemini-2.5-pro": 65536,
+    "openai/gpt-oss-20b": 65536,
+    "openai/gpt-oss-120b": 65536,
+    "us.anthropic.claude-sonnet-4-20250514-v1:0": 65536,
+    "command-a-03-2025": 8192,
+}
+
+
+def get_max_tokens(model_name: str, model_provider: str) -> int:
+    """Get the maximum token limit for a given model name and provider."""
+    if model_provider in [
+        ModelProvider.GEMINI,
+        ModelProvider.TOGETHER,
+        ModelProvider.OPENAI,
+        ModelProvider.ANTHROPIC,
+        ModelProvider.GROQ,
+        ModelProvider.BEDROCK,
+        ModelProvider.COHERE,
+    ]:
+        # Try exact match first (case-insensitive)
+        max_tokens = models_to_max_tokens.get(model_name, 16384)
+    else:
+        max_tokens = 8192
+    return max_tokens
+
+
 DEFAULT_MODEL_PROVIDER = ModelProvider.OPENAI
 
 # Manual cache for LLM instances to handle mutable AgentConfig objects
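Assuming the import paths below (the `get_max_tokens` helper and the `ModelProvider` values themselves appear in this hunk), the new lookup behaves roughly as sketched. Note that despite the in-line comment, `models_to_max_tokens.get(model_name, 16384)` is an exact, case-sensitive match with a 16384-token default for the listed providers, while every other provider falls back to 8192:

```python
# Import paths are assumptions; the names come from the hunk above.
from vectara_agentic.llm_utils import get_max_tokens
from vectara_agentic.types import ModelProvider

print(get_max_tokens("gpt-4.1-mini", ModelProvider.OPENAI))       # 32768, exact table hit
print(get_max_tokens("gpt-4.1-nano", ModelProvider.OPENAI))       # 16384, table miss for a listed provider
print(get_max_tokens("command-a-03-2025", ModelProvider.COHERE))  # 8192, exact table hit
```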
@@ -87,24 +122,18 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
 
     Uses a cache based on configuration parameters to avoid repeated LLM instantiation.
     """
+    if config is None:
+        config = AgentConfig()
     # Check cache first
     cache_key = _create_llm_cache_key(role, config)
     if cache_key in _llm_cache:
         return _llm_cache[cache_key]
     model_provider, model_name = _get_llm_params_for_role(role, config)
-    max_tokens = (
-        16384
-        if model_provider
-        in [
-            ModelProvider.GEMINI,
-            ModelProvider.TOGETHER,
-            ModelProvider.OPENAI,
-            ModelProvider.ANTHROPIC,
-        ]
-        else 8192
-    )
+    max_tokens = get_max_tokens(model_name, model_provider)
     if model_provider == ModelProvider.OPENAI:
-        additional_kwargs =
+        additional_kwargs = (
+            {"reasoning_effort": "minimal"} if model_name.startswith("gpt-5") else {}
+        )
         llm = OpenAI(
             model=model_name,
             temperature=0,
@@ -112,7 +141,7 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
             strict=False,
             max_tokens=max_tokens,
             pydantic_program_mode="openai",
-            additional_kwargs=additional_kwargs
+            additional_kwargs=additional_kwargs,
         )
     elif model_provider == ModelProvider.ANTHROPIC:
         llm = Anthropic(
@@ -127,11 +156,20 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
             raise ImportError(
                 "google_genai not available. Install with: pip install llama-index-llms-google-genai"
             ) from e
+        import google.genai.types as google_types
+        generation_config = google_types.GenerateContentConfig(
+            temperature=0.0,
+            seed=123,
+            max_output_tokens=max_tokens,
+            thinking_config=google_types.ThinkingConfig(thinking_budget=0, include_thoughts=False),
+        )
         llm = GoogleGenAI(
             model=model_name,
             temperature=0,
             is_function_calling_model=True,
             max_tokens=max_tokens,
+            generation_config=generation_config,
+            context_window=1_000_000,
         )
     elif model_provider == ModelProvider.TOGETHER:
         try:
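The Gemini path now builds an explicit `google.genai` generation config (deterministic temperature, fixed seed, thinking disabled) and declares a 1M-token context window. A standalone sketch of the same config object, with a placeholder output limit in place of the computed `max_tokens`; requires the google-genai package:

```python
import google.genai.types as google_types

generation_config = google_types.GenerateContentConfig(
    temperature=0.0,
    seed=123,
    max_output_tokens=65536,  # placeholder for the per-model max_tokens value
    thinking_config=google_types.ThinkingConfig(
        thinking_budget=0,       # disable internal "thinking" tokens
        include_thoughts=False,  # do not return thought summaries
    ),
)
print(generation_config)
```

Setting `thinking_budget=0` trades away Gemini 2.5's reasoning tokens for lower latency and cost, which fits the deterministic, tool-calling role these LLMs play here.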
@@ -140,11 +178,18 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
             raise ImportError(
                 "together not available. Install with: pip install llama-index-llms-together"
             ) from e
+        additional_kwargs = {"seed": 42}
+        if model_name in [
+            "deepseek-ai/DeepSeek-V3.1", "openai/gpt-oss-120b",
+            "deepseek-ai/DeepSeek-R1", "Qwen/Qwen3-235B-A22B-Thinking-2507"
+        ]:
+            additional_kwargs['reasoning_effort'] = "low"
         llm = TogetherLLM(
             model=model_name,
             temperature=0,
             is_function_calling_model=True,
             max_tokens=max_tokens,
+            additional_kwargs=additional_kwargs,
        )
     elif model_provider == ModelProvider.GROQ:
         try:
@@ -191,7 +236,11 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
             raise ImportError(
                 "openai_like not available. Install with: pip install llama-index-llms-openai-like"
             ) from e
-        if
+        if (
+            not config
+            or not config.private_llm_api_base
+            or not config.private_llm_api_key
+        ):
             raise ValueError(
                 "Private LLM requires both private_llm_api_base and private_llm_api_key to be set in AgentConfig."
             )
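Finally, the private-LLM branch now validates its configuration up front instead of failing later with a less obvious error. A minimal `AgentConfig` that would satisfy the new check might look like the sketch below; the two `private_llm_*` field names come from the hunk above, while the provider value, the extra field name, and the endpoint are placeholder assumptions:

```python
from vectara_agentic.agent_config import AgentConfig
from vectara_agentic.types import ModelProvider

config = AgentConfig(
    main_llm_provider=ModelProvider.PRIVATE,          # assumed provider value
    private_llm_api_base="http://localhost:8000/v1",  # placeholder OpenAI-compatible endpoint
    private_llm_api_key="sk-local-placeholder",       # placeholder key
)
```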