vectara-agentic 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of vectara-agentic might be problematic.
- tests/__init__.py +7 -0
- tests/conftest.py +316 -0
- tests/endpoint.py +54 -17
- tests/run_tests.py +112 -0
- tests/test_agent.py +35 -33
- tests/test_agent_fallback_memory.py +270 -0
- tests/test_agent_memory_consistency.py +229 -0
- tests/test_agent_type.py +86 -143
- tests/test_api_endpoint.py +4 -0
- tests/test_bedrock.py +50 -31
- tests/test_fallback.py +4 -0
- tests/test_gemini.py +27 -59
- tests/test_groq.py +50 -31
- tests/test_private_llm.py +11 -2
- tests/test_return_direct.py +6 -2
- tests/test_serialization.py +7 -6
- tests/test_session_memory.py +252 -0
- tests/test_streaming.py +109 -0
- tests/test_together.py +62 -0
- tests/test_tools.py +10 -82
- tests/test_vectara_llms.py +4 -0
- tests/test_vhc.py +67 -0
- tests/test_workflow.py +13 -28
- vectara_agentic/__init__.py +27 -4
- vectara_agentic/_callback.py +65 -67
- vectara_agentic/_observability.py +30 -30
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +565 -859
- vectara_agentic/agent_config.py +15 -14
- vectara_agentic/agent_core/__init__.py +22 -0
- vectara_agentic/agent_core/factory.py +383 -0
- vectara_agentic/{_prompts.py → agent_core/prompts.py} +21 -46
- vectara_agentic/agent_core/serialization.py +348 -0
- vectara_agentic/agent_core/streaming.py +483 -0
- vectara_agentic/agent_core/utils/__init__.py +29 -0
- vectara_agentic/agent_core/utils/hallucination.py +157 -0
- vectara_agentic/agent_core/utils/logging.py +52 -0
- vectara_agentic/agent_core/utils/schemas.py +87 -0
- vectara_agentic/agent_core/utils/tools.py +125 -0
- vectara_agentic/agent_endpoint.py +4 -6
- vectara_agentic/db_tools.py +37 -12
- vectara_agentic/llm_utils.py +42 -43
- vectara_agentic/sub_query_workflow.py +9 -14
- vectara_agentic/tool_utils.py +138 -83
- vectara_agentic/tools.py +36 -21
- vectara_agentic/tools_catalog.py +16 -16
- vectara_agentic/types.py +106 -8
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/METADATA +111 -31
- vectara_agentic-0.4.1.dist-info/RECORD +53 -0
- tests/test_agent_planning.py +0 -64
- tests/test_hhem.py +0 -100
- vectara_agentic/hhem.py +0 -82
- vectara_agentic-0.3.3.dist-info/RECORD +0 -39
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/top_level.txt +0 -0
vectara_agentic/agent_core/utils/logging.py
ADDED

@@ -0,0 +1,52 @@
+"""
+Logging configuration and utilities for agent functionality.
+
+This module provides logging filters, configuration, and setup utilities
+specifically tailored for agent operations and debugging.
+"""
+
+import logging
+from dotenv import load_dotenv
+
+
+class IgnoreUnpickleableAttributeFilter(logging.Filter):
+    """
+    Filter to ignore log messages that contain certain strings.
+
+    This filter is used to suppress common unpickleable attribute warnings
+    that occur during agent serialization/deserialization operations.
+    """
+
+    def filter(self, record):
+        """
+        Filter log records based on message content.
+
+        Args:
+            record: LogRecord to evaluate
+
+        Returns:
+            bool: True if record should be logged, False if it should be ignored
+        """
+        msgs_to_ignore = [
+            "Removing unpickleable private attribute _split_fns",
+            "Removing unpickleable private attribute _sub_sentence_split_fns",
+        ]
+        return all(msg not in record.getMessage() for msg in msgs_to_ignore)
+
+
+def setup_agent_logging():
+    """
+    Set up logging configuration for agent operations.
+
+    This configures logging filters and levels to reduce noise from
+    agent-related operations while maintaining useful debug information.
+    """
+    # Add filter to suppress unpickleable attribute warnings
+    logging.getLogger().addFilter(IgnoreUnpickleableAttributeFilter())
+
+    # Set critical level for OTLP trace exporter to reduce noise
+    logger = logging.getLogger("opentelemetry.exporter.otlp.proto.http.trace_exporter")
+    logger.setLevel(logging.CRITICAL)
+
+    # Load environment variables with override
+    load_dotenv(override=True)
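For context, a minimal sketch (not part of the package) of how the new logging helpers might be used; it assumes the import path introduced in this release:

```python
import logging

from vectara_agentic.agent_core.utils.logging import setup_agent_logging

# Installs IgnoreUnpickleableAttributeFilter on the root logger, silences the
# OTLP trace exporter, and reloads environment variables from .env.
setup_agent_logging()

root = logging.getLogger()
# This record matches one of the hard-coded messages and is filtered out.
root.warning("Removing unpickleable private attribute _split_fns")
# Unrelated records still pass through.
root.warning("agent ready")
```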
vectara_agentic/agent_core/utils/schemas.py
ADDED

@@ -0,0 +1,87 @@
+"""
+Schema and type conversion utilities for agent functionality.
+
+This module handles JSON schema to Python type conversion,
+Pydantic model reconstruction, and type mapping operations.
+"""
+
+from typing import Any, Union, List
+
+
+# Type mapping constants
+JSON_TYPE_TO_PYTHON = {
+    "string": str,
+    "integer": int,
+    "boolean": bool,
+    "array": list,
+    "object": dict,
+    "number": float,
+    "null": type(None),
+}
+
+PY_TYPES = {
+    "str": str,
+    "int": int,
+    "float": float,
+    "bool": bool,
+    "dict": dict,
+    "list": list,
+}
+
+
+def get_field_type(field_schema: dict) -> Any:
+    """
+    Convert a JSON schema field definition to a Python type.
+    Handles 'type' and 'anyOf' cases.
+
+    Args:
+        field_schema: JSON schema field definition
+
+    Returns:
+        Any: Corresponding Python type
+    """
+    if not field_schema:  # Handles empty schema {}
+        return Any
+
+    if "anyOf" in field_schema:
+        types = []
+        for option_schema in field_schema["anyOf"]:
+            types.append(get_field_type(option_schema))  # Recursive call
+        if not types:
+            return Any
+        return Union[tuple(types)]
+
+    if "type" in field_schema and isinstance(field_schema["type"], list):
+        types = []
+        for type_name in field_schema["type"]:
+            if type_name == "array":
+                item_schema = field_schema.get("items", {})
+                types.append(List[get_field_type(item_schema)])
+            elif type_name in JSON_TYPE_TO_PYTHON:
+                types.append(JSON_TYPE_TO_PYTHON[type_name])
+            else:
+                types.append(Any)  # Fallback for unknown types in the list
+        if not types:
+            return Any
+        return Union[tuple(types)]  # type: ignore
+
+    if "type" in field_schema:
+        schema_type_name = field_schema["type"]
+        if schema_type_name == "array":
+            item_schema = field_schema.get(
+                "items", {}
+            )  # Default to Any if "items" is missing
+            return List[get_field_type(item_schema)]
+
+        return JSON_TYPE_TO_PYTHON.get(schema_type_name, Any)
+
+    # If only "items" is present (implies array by some conventions, but less standard)
+    # Or if it's a schema with other keywords like 'properties' (implying object)
+    # For simplicity, if no "type" or "anyOf" at this point, default to Any or add more specific handling.
+    # If 'properties' in field_schema or 'additionalProperties' in field_schema, it's likely an object.
+    if "properties" in field_schema or "additionalProperties" in field_schema:
+        # This path might need to reconstruct a nested Pydantic model if you encounter such schemas.
+        # For now, treating as 'dict' or 'Any' might be a simpler placeholder.
+        return dict  # Or Any, or more sophisticated object reconstruction.
+
+    return Any
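The conversion rules above can be summarized with a few illustrative calls; this sketch simply exercises the shipped logic (recall that `Union[str, None]` is the same type as `Optional[str]`):

```python
from typing import Any, List, Optional

from vectara_agentic.agent_core.utils.schemas import get_field_type

assert get_field_type({"type": "string"}) is str
assert get_field_type({"type": "array", "items": {"type": "integer"}}) == List[int]
# "anyOf" members collapse into a Union.
assert get_field_type({"anyOf": [{"type": "string"}, {"type": "null"}]}) == Optional[str]
# An empty schema carries no type information, so it maps to Any.
assert get_field_type({}) is Any
```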
vectara_agentic/agent_core/utils/tools.py
ADDED

@@ -0,0 +1,125 @@
+"""
+Tool processing and validation utilities for agent functionality.
+
+This module provides utilities for tool validation, processing, and
+compatibility adjustments for different LLM providers.
+"""
+
+import inspect
+from typing import Any, List
+from inspect import Signature, Parameter, ismethod
+from collections import Counter
+
+from pydantic import Field, create_model
+from llama_index.core.tools import FunctionTool
+from ...llm_utils import get_llm
+from ...types import LLMRole
+
+
+def sanitize_tools_for_gemini(tools: List[FunctionTool]) -> List[FunctionTool]:
+    """
+    Strip all default values from tools for Gemini LLM compatibility.
+
+    Gemini requires that tools only show required parameters without defaults.
+    This function modifies:
+    - tool.fn signature
+    - tool.async_fn signature
+    - tool.metadata.fn_schema
+
+    Args:
+        tools: List of FunctionTool objects to sanitize
+
+    Returns:
+        List[FunctionTool]: Sanitized tools with no default values
+    """
+    for tool in tools:
+        # 1) Strip defaults off the actual callables
+        for func in (tool.fn, tool.async_fn):
+            if not func:
+                continue
+            orig_sig = inspect.signature(func)
+            new_params = [
+                p.replace(default=Parameter.empty) for p in orig_sig.parameters.values()
+            ]
+            new_sig = Signature(
+                new_params, return_annotation=orig_sig.return_annotation
+            )
+            if ismethod(func):
+                func.__func__.__signature__ = new_sig
+            else:
+                func.__signature__ = new_sig
+
+        # 2) Rebuild the Pydantic schema so that *every* field is required
+        schema_cls = getattr(tool.metadata, "fn_schema", None)
+        if schema_cls and hasattr(schema_cls, "model_fields"):
+            # Collect (name → (type, Field(...))) for all fields
+            new_fields: dict[str, tuple[type, Any]] = {}
+            for name, mf in schema_cls.model_fields.items():
+                typ = mf.annotation
+                desc = getattr(mf, "description", "")
+                # Force required (no default) with Field(...)
+                new_fields[name] = (typ, Field(..., description=desc))
+
+            # Make a brand-new schema class where every field is required
+            no_default_schema = create_model(
+                f"{schema_cls.__name__}",  # new class name
+                **new_fields,  # type: ignore
+            )
+
+            # Give it a clean __signature__ so inspect.signature sees no defaults
+            params = [
+                Parameter(n, Parameter.POSITIONAL_OR_KEYWORD, annotation=typ)
+                for n, (typ, _) in new_fields.items()
+            ]
+            no_default_schema.__signature__ = Signature(params)
+
+            # Swap it back onto the tool
+            tool.metadata.fn_schema = no_default_schema
+
+    return tools
+
+
+def validate_tool_consistency(
+    tools: List[FunctionTool], custom_instructions: str, agent_config
+) -> None:
+    """
+    Validate that tools mentioned in instructions actually exist.
+
+    Args:
+        tools: List of available tools
+        custom_instructions: Custom instructions that may reference tools
+        agent_config: Agent configuration for LLM access
+
+    Raises:
+        ValueError: If invalid tools are referenced in instructions
+    """
+    tool_names = [tool.metadata.name for tool in tools]
+
+    # Check for duplicate tools
+    duplicates = [tool for tool, count in Counter(tool_names).items() if count > 1]
+    if duplicates:
+        raise ValueError(f"Duplicate tools detected: {', '.join(duplicates)}")
+
+    # Validate tools mentioned in instructions exist
+    if custom_instructions:
+        prompt = f"""
+        You are provided these tools:
+        <tools>{','.join(tool_names)}</tools>
+        And these instructions:
+        <instructions>
+        {custom_instructions}
+        </instructions>
+        Your task is to identify invalid tools.
+        A tool is invalid if it is mentioned in the instructions but not in the tools list.
+        A tool's name must have at least two characters.
+        Your response should be a comma-separated list of the invalid tools.
+        If no invalid tools exist, respond with "<OKAY>" (and nothing else).
+        """
+        llm = get_llm(LLMRole.MAIN, config=agent_config)
+        bad_tools_str = llm.complete(prompt).text.strip("\n")
+        if bad_tools_str and bad_tools_str != "<OKAY>":
+            bad_tools = [tool.strip() for tool in bad_tools_str.split(",")]
+            numbered = ", ".join(f"({i}) {tool}" for i, tool in enumerate(bad_tools, 1))
+            raise ValueError(
+                f"The Agent custom instructions mention these invalid tools: {numbered}"
+            )
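A short sketch of the Gemini sanitizer on a hypothetical tool; the `search` function below is invented for illustration:

```python
import inspect

from llama_index.core.tools import FunctionTool

from vectara_agentic.agent_core.utils.tools import sanitize_tools_for_gemini


def search(query: str, top_k: int = 10) -> str:
    """Search the index for a query (hypothetical example tool)."""
    return f"results for {query} (top {top_k})"


tool = FunctionTool.from_defaults(fn=search)
sanitize_tools_for_gemini([tool])

# The default on top_k has been stripped, as Gemini requires.
sig = inspect.signature(tool.fn)
assert sig.parameters["top_k"].default is inspect.Parameter.empty
```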
vectara_agentic/agent_endpoint.py
CHANGED

@@ -16,12 +16,6 @@ from .agent import Agent
 from .agent_config import AgentConfig
 
 
-class ChatRequest(BaseModel):
-    """Request schema for the /chat endpoint."""
-
-    message: str
-
-
 class CompletionRequest(BaseModel):
     """Request schema for the /v1/completions endpoint."""
 
@@ -64,12 +58,14 @@ class CompletionResponse(BaseModel):
 
 class ChatMessage(BaseModel):
     """Schema for individual chat messages in ChatCompletionRequest."""
+
     role: Literal["system", "user", "assistant"]
     content: str
 
 
 class ChatCompletionRequest(BaseModel):
     """Request schema for the /v1/chat endpoint."""
+
     model: str
     messages: List[ChatMessage]
     temperature: Optional[float] = Field(1.0, ge=0.0, le=2.0)
@@ -79,6 +75,7 @@ class ChatCompletionRequest(BaseModel):
 
 class ChatCompletionChoice(BaseModel):
     """Choice schema returned in ChatCompletionResponse."""
+
     index: int
     message: ChatMessage
     finish_reason: Literal["stop", "length", "error", None]
@@ -86,6 +83,7 @@ class ChatCompletionChoice(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     """Response schema for the /v1/chat endpoint."""
+
     id: str
     object: Literal["chat.completion"]
     created: int
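With the standalone `/chat` request schema removed, clients target the OpenAI-compatible models. A minimal sketch of building a `/v1/chat` request from the fields visible above, assuming the remaining (unshown) fields are optional and using a placeholder model name:

```python
from vectara_agentic.agent_endpoint import ChatCompletionRequest, ChatMessage

request = ChatCompletionRequest(
    model="vectara-agentic",  # placeholder model name for illustration
    messages=[
        ChatMessage(role="system", content="You are a helpful assistant."),
        ChatMessage(role="user", content="Summarize the latest release."),
    ],
    temperature=0.2,
)
print(request.model_dump_json(indent=2))
```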
vectara_agentic/db_tools.py
CHANGED

@@ -6,6 +6,7 @@ It makes the following adjustments:
 * Makes sure the load_data method returns a list of text values from the database (and not Document[] objects).
 * Limits the returned rows to self.max_rows.
 """
+
 from typing import Any, Optional, List, Awaitable, Callable
 import asyncio
 from inspect import signature
@@ -24,15 +25,20 @@ from llama_index.core.tools.utils import create_schema_from_function
 
 AsyncCallable = Callable[..., Awaitable[Any]]
 
+
 class DatabaseTools:
     """Database tools for vectara-agentic
     This class provides a set of tools to interact with a database.
     It allows you to load data, list tables, describe tables, and load unique values.
     It also provides a method to load sample data from a specified table.
     """
+
     spec_functions = [
-        "load_data",
-        "
+        "load_data",
+        "load_sample_data",
+        "list_tables",
+        "describe_tables",
+        "load_unique_values",
     ]
 
     def __init__(
@@ -61,7 +67,7 @@ class DatabaseTools:
         elif uri:
             self.uri = uri
             self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
-        elif
+        elif scheme and host and port and user and password and dbname:
             uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
             self.uri = uri
             self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
@@ -76,7 +82,8 @@ class DatabaseTools:
         self._metadata.reflect(bind=self.sql_database.engine)
 
     def _get_metadata_from_fn_name(
-        self,
+        self,
+        fn_name: str,
     ) -> Optional[ToolMetadata]:
         """Return map from function name.
 
@@ -87,7 +94,9 @@ class DatabaseTools:
             func = getattr(self, fn_name)
         except AttributeError:
             return None
-        name =
+        name = (
+            self.tool_name_prefix + "_" + fn_name if self.tool_name_prefix else fn_name
+        )
         docstring = func.__doc__ or ""
         description = f"{name}{signature(func)}\n{docstring}"
         fn_schema = create_schema_from_function(fn_name, getattr(self, fn_name))
@@ -118,7 +127,9 @@ class DatabaseTools:
         try:
             count_rows = self._load_data(count_query)
         except Exception as e:
-            return [
+            return [
+                f"Error ({str(e)}) occurred while counting number of rows, check your query."
+            ]
         num_rows = int(count_rows[0].text)
         if num_rows > self.max_rows:
             return [
@@ -128,7 +139,9 @@ class DatabaseTools:
         try:
             res = self._load_data(sql_query)
         except Exception as e:
-            return [
+            return [
+                f"Error ({str(e)}) occurred while executing the query {sql_query}, check your query."
+            ]
         return [d.text for d in res]
 
     def load_sample_data(self, table_name: str, num_rows: int = 25) -> Any:
@@ -149,7 +162,9 @@ class DatabaseTools:
         try:
             res = self._load_data(f"SELECT * FROM {table_name} LIMIT {num_rows}")
         except Exception as e:
-            return [
+            return [
+                f"Error ({str(e)}) occurred while loading sample data for table {table_name}"
+            ]
         return [d.text for d in res]
 
     def list_tables(self) -> List[str]:
@@ -179,7 +194,11 @@ class DatabaseTools:
         table_schemas = []
         for table_name in table_names:
             table = next(
-                (
+                (
+                    table
+                    for table in self._metadata.sorted_tables
+                    if table.name == table_name
+                ),
                 None,
             )
             if table is None:
@@ -188,7 +207,9 @@ class DatabaseTools:
             table_schemas.append(f"{schema}\n")
         return "\n".join(table_schemas)
 
-    def load_unique_values(
+    def load_unique_values(
+        self, table_name: str, columns: list[str], num_vals: int = 200
+    ) -> Any:
         """
         Fetches the first num_vals unique values from the specified columns of the database table.
 
@@ -209,10 +230,14 @@ class DatabaseTools:
         res = {}
         try:
             for column in columns:
-                unique_vals = self._load_data(
+                unique_vals = self._load_data(
+                    f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}'
+                )
                 res[column] = [d.text for d in unique_vals]
         except Exception as e:
-            return {
+            return {
+                f"Error ({str(e)}) occurred while loading unique values for table {table_name}"
+            }
         return res
 
     def to_tool_list(self) -> List[FunctionTool]:
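A hedged sketch of the expanded tool set; the SQLite URI and the table/column names are invented, and it assumes `uri` can be passed as a keyword to the constructor, as the `elif uri:` branch above suggests:

```python
from vectara_agentic.db_tools import DatabaseTools

db = DatabaseTools(uri="sqlite:///example.db")  # hypothetical local database

# Each name in spec_functions is exposed as a FunctionTool.
for tool in db.to_tool_list():
    print(tool.metadata.name)

print(db.list_tables())
# Capped at num_vals distinct values per column.
print(db.load_unique_values("customers", columns=["country"], num_vals=5))
```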
vectara_agentic/llm_utils.py
CHANGED

@@ -2,10 +2,10 @@
 Utilities for the Vectara agentic.
 """
 
-from typing import Tuple,
+from typing import Tuple, Optional
 import os
 from functools import lru_cache
-import
+import hashlib
 
 from llama_index.core.llms import LLM
 from llama_index.llms.openai import OpenAI
@@ -13,15 +13,14 @@ from llama_index.llms.anthropic import Anthropic
 
 # LLM provider imports are now lazy-loaded in get_llm() function
 
-from .types import LLMRole,
+from .types import LLMRole, ModelProvider
 from .agent_config import AgentConfig
 
 provider_to_default_model_name = {
     ModelProvider.OPENAI: "gpt-4.1",
     ModelProvider.ANTHROPIC: "claude-sonnet-4-20250514",
     ModelProvider.TOGETHER: "deepseek-ai/DeepSeek-V3",
-    ModelProvider.GROQ: "
-    ModelProvider.FIREWORKS: "accounts/fireworks/models/firefunction-v2",
+    ModelProvider.GROQ: "openai/gpt-oss-20b",
     ModelProvider.BEDROCK: "us.anthropic.claude-sonnet-4-20250514-v1:0",
     ModelProvider.COHERE: "command-a-03-2025",
     ModelProvider.GEMINI: "models/gemini-2.5-flash",
@@ -29,6 +28,30 @@ provider_to_default_model_name = {
 
 DEFAULT_MODEL_PROVIDER = ModelProvider.OPENAI
 
+# Manual cache for LLM instances to handle mutable AgentConfig objects
+_llm_cache = {}
+
+
+def _create_llm_cache_key(role: LLMRole, config: Optional[AgentConfig] = None) -> str:
+    """Create a hash-based cache key for LLM instances."""
+    if config is None:
+        config = AgentConfig()
+
+    # Extract only the relevant config parameters for the cache key
+    cache_data = {
+        "role": role.value,
+        "main_llm_provider": config.main_llm_provider.value,
+        "main_llm_model_name": config.main_llm_model_name,
+        "tool_llm_provider": config.tool_llm_provider.value,
+        "tool_llm_model_name": config.tool_llm_model_name,
+        "private_llm_api_base": config.private_llm_api_base,
+        "private_llm_api_key": config.private_llm_api_key,
+    }
+
+    # Create a stable hash from the cache data
+    cache_str = str(sorted(cache_data.items()))
+    return hashlib.md5(cache_str.encode()).hexdigest()
+
 
 @lru_cache(maxsize=None)
 def _get_llm_params_for_role(
@@ -54,42 +77,20 @@ def _get_llm_params_for_role(
         model_provider
     )
 
-    # If the agent type is OpenAI, check that the main LLM provider is also OpenAI.
-    if role == LLMRole.MAIN and config.agent_type == AgentType.OPENAI:
-        if model_provider != ModelProvider.OPENAI:
-            raise ValueError(
-                "OpenAI agent requested but main model provider is not OpenAI."
-            )
-
     return model_provider, model_name
 
 
-@lru_cache(maxsize=None)
-def get_tokenizer_for_model(
-    role: LLMRole, config: Optional[AgentConfig] = None
-) -> Optional[Callable]:
-    """
-    Get the tokenizer for the specified model, as determined by the role & config.
-    """
-    model_name = "Unknown model"
-    try:
-        model_provider, model_name = _get_llm_params_for_role(role, config)
-        if model_provider == ModelProvider.OPENAI:
-            return tiktoken.encoding_for_model("gpt-4o").encode
-        if model_provider == ModelProvider.ANTHROPIC:
-            return Anthropic().tokenizer
-    except Exception:
-        print(f"Error getting tokenizer for model {model_name}, ignoring")
-        return None
-    return None
-
-
-@lru_cache(maxsize=None)
 def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
     """
     Get the LLM for the specified role, using the provided config
     or a default if none is provided.
+
+    Uses a cache based on configuration parameters to avoid repeated LLM instantiation.
     """
+    # Check cache first
+    cache_key = _create_llm_cache_key(role, config)
+    if cache_key in _llm_cache:
+        return _llm_cache[cache_key]
     model_provider, model_name = _get_llm_params_for_role(role, config)
     max_tokens = (
         16384
@@ -107,7 +108,7 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
             model=model_name,
             temperature=0,
             is_function_calling_model=True,
-            strict=
+            strict=False,
             max_tokens=max_tokens,
             pydantic_program_mode="openai",
         )
@@ -128,7 +129,6 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
             model=model_name,
             temperature=0,
             is_function_calling_model=True,
-            allow_parallel_tool_calls=True,
             max_tokens=max_tokens,
         )
     elif model_provider == ModelProvider.TOGETHER:
@@ -157,14 +157,6 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
             is_function_calling_model=True,
             max_tokens=max_tokens,
         )
-    elif model_provider == ModelProvider.FIREWORKS:
-        try:
-            from llama_index.llms.fireworks import Fireworks
-        except ImportError as e:
-            raise ImportError(
-                "fireworks not available. Install with: pip install llama-index-llms-fireworks"
-            ) from e
-        llm = Fireworks(model=model_name, temperature=0, max_tokens=max_tokens)
     elif model_provider == ModelProvider.BEDROCK:
         try:
             from llama_index.llms.bedrock_converse import BedrockConverse
@@ -197,6 +189,10 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
             raise ImportError(
                 "openai_like not available. Install with: pip install llama-index-llms-openai-like"
             ) from e
+        if not config or not config.private_llm_api_base or not config.private_llm_api_key:
+            raise ValueError(
+                "Private LLM requires both private_llm_api_base and private_llm_api_key to be set in AgentConfig."
+            )
         llm = OpenAILike(
             model=model_name,
             temperature=0,
@@ -209,4 +205,7 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
 
     else:
         raise ValueError(f"Unknown LLM provider: {model_provider}")
+
+    # Cache the created LLM instance
+    _llm_cache[cache_key] = llm
     return llm
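The practical effect of the new cache: two calls with equivalent configurations return the same instance even when the `AgentConfig` objects differ, since only the fields hashed in `_create_llm_cache_key` matter. A small illustrative check:

```python
from vectara_agentic.agent_config import AgentConfig
from vectara_agentic.llm_utils import get_llm
from vectara_agentic.types import LLMRole

config_a = AgentConfig()
config_b = AgentConfig()  # distinct object, identical parameters

llm_1 = get_llm(LLMRole.MAIN, config=config_a)
llm_2 = get_llm(LLMRole.MAIN, config=config_b)

# Served from _llm_cache: same hash key, same LLM instance.
assert llm_1 is llm_2
```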