vectara-agentic 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vectara-agentic might be problematic. Click here for more details.

Files changed (56)
  1. tests/__init__.py +7 -0
  2. tests/conftest.py +316 -0
  3. tests/endpoint.py +54 -17
  4. tests/run_tests.py +112 -0
  5. tests/test_agent.py +35 -33
  6. tests/test_agent_fallback_memory.py +270 -0
  7. tests/test_agent_memory_consistency.py +229 -0
  8. tests/test_agent_type.py +86 -143
  9. tests/test_api_endpoint.py +4 -0
  10. tests/test_bedrock.py +50 -31
  11. tests/test_fallback.py +4 -0
  12. tests/test_gemini.py +27 -59
  13. tests/test_groq.py +50 -31
  14. tests/test_private_llm.py +11 -2
  15. tests/test_return_direct.py +6 -2
  16. tests/test_serialization.py +7 -6
  17. tests/test_session_memory.py +252 -0
  18. tests/test_streaming.py +109 -0
  19. tests/test_together.py +62 -0
  20. tests/test_tools.py +10 -82
  21. tests/test_vectara_llms.py +4 -0
  22. tests/test_vhc.py +67 -0
  23. tests/test_workflow.py +13 -28
  24. vectara_agentic/__init__.py +27 -4
  25. vectara_agentic/_callback.py +65 -67
  26. vectara_agentic/_observability.py +30 -30
  27. vectara_agentic/_version.py +1 -1
  28. vectara_agentic/agent.py +565 -859
  29. vectara_agentic/agent_config.py +15 -14
  30. vectara_agentic/agent_core/__init__.py +22 -0
  31. vectara_agentic/agent_core/factory.py +383 -0
  32. vectara_agentic/{_prompts.py → agent_core/prompts.py} +21 -46
  33. vectara_agentic/agent_core/serialization.py +348 -0
  34. vectara_agentic/agent_core/streaming.py +483 -0
  35. vectara_agentic/agent_core/utils/__init__.py +29 -0
  36. vectara_agentic/agent_core/utils/hallucination.py +157 -0
  37. vectara_agentic/agent_core/utils/logging.py +52 -0
  38. vectara_agentic/agent_core/utils/schemas.py +87 -0
  39. vectara_agentic/agent_core/utils/tools.py +125 -0
  40. vectara_agentic/agent_endpoint.py +4 -6
  41. vectara_agentic/db_tools.py +37 -12
  42. vectara_agentic/llm_utils.py +42 -43
  43. vectara_agentic/sub_query_workflow.py +9 -14
  44. vectara_agentic/tool_utils.py +138 -83
  45. vectara_agentic/tools.py +36 -21
  46. vectara_agentic/tools_catalog.py +16 -16
  47. vectara_agentic/types.py +106 -8
  48. {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/METADATA +111 -31
  49. vectara_agentic-0.4.1.dist-info/RECORD +53 -0
  50. tests/test_agent_planning.py +0 -64
  51. tests/test_hhem.py +0 -100
  52. vectara_agentic/hhem.py +0 -82
  53. vectara_agentic-0.3.3.dist-info/RECORD +0 -39
  54. {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/WHEEL +0 -0
  55. {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/licenses/LICENSE +0 -0
  56. {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,52 @@
1
+ """
2
+ Logging configuration and utilities for agent functionality.
3
+
4
+ This module provides logging filters, configuration, and setup utilities
5
+ specifically tailored for agent operations and debugging.
6
+ """
7
+
8
+ import logging
9
+ from dotenv import load_dotenv
10
+
11
+
12
class IgnoreUnpickleableAttributeFilter(logging.Filter):
    """
    Logging filter that drops the noisy "unpickleable private attribute" warnings.

    Such warnings show up while agents are serialized/deserialized. A record is
    rejected when its rendered message contains any of the known noise
    fragments; every other record passes through untouched.
    """

    # Substrings identifying messages that should never be emitted.
    _SUPPRESSED_FRAGMENTS = (
        "Removing unpickleable private attribute _split_fns",
        "Removing unpickleable private attribute _sub_sentence_split_fns",
    )

    def filter(self, record):
        """
        Decide whether ``record`` may be emitted.

        Args:
            record: the ``logging.LogRecord`` being considered.

        Returns:
            bool: False if the formatted message contains a suppressed
            fragment, True otherwise.
        """
        message = record.getMessage()
        for fragment in self._SUPPRESSED_FRAGMENTS:
            if fragment in message:
                return False
        return True
35
+
36
+
37
def setup_agent_logging():
    """
    Set up process-wide logging configuration for agent operations.

    Side effects:
      - Installs an ``IgnoreUnpickleableAttributeFilter`` on the root logger.
        It is only added when no filter of that type is already attached, so
        calling this function repeatedly does not stack duplicate filters.
      - Sets the OTLP HTTP trace-exporter logger to CRITICAL, suppressing all
        of its records below that level.
      - Calls ``load_dotenv(override=True)``, so values loaded from the
        dotenv file replace variables already present in the environment.
    """
    root_logger = logging.getLogger()
    # Filterer.addFilter only skips a filter that is already in the list by
    # equality (identity for Filter instances), so a fresh instance would be
    # appended on every call without this check.
    # NOTE(review): a filter on the root *logger* only sees records logged
    # directly on the root logger (e.g. module-level logging.warning(...));
    # records propagated from named child loggers bypass logger-level filters.
    # Presumably the suppressed messages are emitted via the root logger —
    # confirm, or attach the filter to the root handlers instead.
    already_installed = any(
        isinstance(existing, IgnoreUnpickleableAttributeFilter)
        for existing in root_logger.filters
    )
    if not already_installed:
        root_logger.addFilter(IgnoreUnpickleableAttributeFilter())

    # Silence the OTLP trace exporter below CRITICAL to reduce noise.
    logger = logging.getLogger("opentelemetry.exporter.otlp.proto.http.trace_exporter")
    logger.setLevel(logging.CRITICAL)

    # Load environment variables, overriding values already set.
    load_dotenv(override=True)
@@ -0,0 +1,87 @@
1
+ """
2
+ Schema and type conversion utilities for agent functionality.
3
+
4
+ This module handles JSON schema to Python type conversion,
5
+ Pydantic model reconstruction, and type mapping operations.
6
+ """
7
+
8
+ from typing import Any, Union, List
9
+
10
+
11
# JSON-schema primitive type name -> Python type.
JSON_TYPE_TO_PYTHON = dict(
    string=str,
    integer=int,
    boolean=bool,
    array=list,
    object=dict,
    number=float,
    null=type(None),
)

# Builtin type name (as a string) -> the builtin type object.
PY_TYPES = {builtin.__name__: builtin for builtin in (str, int, float, bool, dict, list)}


def _list_of_items(schema: dict) -> Any:
    """Return ``List[<item type>]`` for an array schema; missing ``items`` gives ``List[Any]``."""
    return List[get_field_type(schema.get("items", {}))]


def get_field_type(field_schema: dict) -> Any:
    """
    Translate a JSON-schema field definition into a Python type annotation.

    Forms are checked in this order:
      * falsy schema (e.g. ``{}`` or ``None``)   -> ``Any``
      * ``anyOf`` list                           -> ``Union`` of each option
                                                    (``Any`` if the list is empty)
      * ``type`` given as a list of names        -> ``Union`` of each mapped type
                                                    (``Any`` if the list is empty)
      * ``type`` given as a single name          -> the mapped Python type
      * no ``type``/``anyOf`` but ``properties`` or ``additionalProperties``
                                                 -> ``dict``
      * anything else                            -> ``Any``

    ``"array"`` becomes ``List[<items type>]``. Unrecognised type names map
    to ``Any``.

    Args:
        field_schema: JSON-schema fragment describing a single field.

    Returns:
        Any: the corresponding Python type or typing construct.
    """
    if not field_schema:
        return Any

    if "anyOf" in field_schema:
        alternatives = [get_field_type(option) for option in field_schema["anyOf"]]
        return Union[tuple(alternatives)] if alternatives else Any

    if "type" in field_schema:
        declared = field_schema["type"]
        if isinstance(declared, list):
            members = [
                _list_of_items(field_schema)
                if name == "array"
                else JSON_TYPE_TO_PYTHON.get(name, Any)
                for name in declared
            ]
            return Union[tuple(members)] if members else Any  # type: ignore
        if declared == "array":
            return _list_of_items(field_schema)
        return JSON_TYPE_TO_PYTHON.get(declared, Any)

    # Object-shaped schema without an explicit type: nested models are not
    # reconstructed, so the field is typed as a plain dict.
    if "properties" in field_schema or "additionalProperties" in field_schema:
        return dict
    return Any
@@ -0,0 +1,125 @@
1
+ """
2
+ Tool processing and validation utilities for agent functionality.
3
+
4
+ This module provides utilities for tool validation, processing, and
5
+ compatibility adjustments for different LLM providers.
6
+ """
7
+
8
+ import inspect
9
+ from typing import Any, List
10
+ from inspect import Signature, Parameter, ismethod
11
+ from collections import Counter
12
+
13
+ from pydantic import Field, create_model
14
+ from llama_index.core.tools import FunctionTool
15
+ from ...llm_utils import get_llm
16
+ from ...types import LLMRole
17
+
18
+
19
def _drop_signature_defaults(func) -> None:
    """Overwrite ``func``'s ``__signature__`` with a copy whose parameters have no defaults."""
    original = inspect.signature(func)
    stripped = Signature(
        [param.replace(default=Parameter.empty) for param in original.parameters.values()],
        return_annotation=original.return_annotation,
    )
    # Bound method objects reject new attributes, so the signature is stored
    # on the underlying function instead.
    # NOTE(review): that function object is shared by every instance of the
    # method's class, so the change is visible beyond this tool.
    holder = func.__func__ if ismethod(func) else func
    holder.__signature__ = stripped


def _all_required_schema(schema_cls):
    """Build a new pydantic model named like ``schema_cls`` where every field is required."""
    required_fields: dict[str, tuple[type, Any]] = {
        field_name: (
            info.annotation,
            # Field(...) marks the field as required; only the description is carried over.
            Field(..., description=getattr(info, "description", "")),
        )
        for field_name, info in schema_cls.model_fields.items()
    }
    rebuilt = create_model(schema_cls.__name__, **required_fields)  # type: ignore
    # Explicit __signature__ so inspect.signature(rebuilt) reports no defaults.
    rebuilt.__signature__ = Signature(
        [
            Parameter(field_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation)
            for field_name, (annotation, _) in required_fields.items()
        ]
    )
    return rebuilt


def sanitize_tools_for_gemini(tools: List[FunctionTool]) -> List[FunctionTool]:
    """
    Remove every parameter default from ``tools`` for Gemini LLM compatibility.

    Gemini requires that tools only show required parameters without defaults.
    For each tool this rewrites, in place:
      - the ``__signature__`` of ``tool.fn`` and ``tool.async_fn`` (when set);
      - ``tool.metadata.fn_schema``, replaced by a model in which every field
        is required (only when the schema exposes ``model_fields``).

    Args:
        tools: FunctionTool objects to sanitize.

    Returns:
        List[FunctionTool]: the same list object, whose tools were modified in place.
    """
    for tool in tools:
        for target in (tool.fn, tool.async_fn):
            if target:
                _drop_signature_defaults(target)

        schema_cls = getattr(tool.metadata, "fn_schema", None)
        if schema_cls and hasattr(schema_cls, "model_fields"):
            tool.metadata.fn_schema = _all_required_schema(schema_cls)

    return tools
80
+
81
+
82
def validate_tool_consistency(
    tools: List[FunctionTool], custom_instructions: str, agent_config
) -> None:
    """
    Validate that the tool list has no duplicates and that every tool
    mentioned in the custom instructions actually exists.

    The second check is delegated to the main LLM (obtained via
    ``get_llm(LLMRole.MAIN, config=agent_config)``). The LLM is asked to list
    invalid tool names, or reply ``<OKAY>`` when there are none. It is only
    performed when ``custom_instructions`` is non-empty.

    Args:
        tools: available tools.
        custom_instructions: agent instructions that may reference tools by name.
        agent_config: agent configuration used to obtain the LLM.

    Raises:
        ValueError: if two tools share a name, or if the LLM reports tool
            names that are not in ``tools``.
    """
    tool_names = [tool.metadata.name for tool in tools]

    # Check for duplicate tools
    duplicates = [name for name, count in Counter(tool_names).items() if count > 1]
    if duplicates:
        raise ValueError(f"Duplicate tools detected: {', '.join(duplicates)}")

    # Validate tools mentioned in instructions exist
    if custom_instructions:
        prompt = f"""
        You are provided these tools:
        <tools>{','.join(tool_names)}</tools>
        And these instructions:
        <instructions>
        {custom_instructions}
        </instructions>
        Your task is to identify invalid tools.
        A tool is invalid if it is mentioned in the instructions but not in the tools list.
        A tool's name must have at least two characters.
        Your response should be a comma-separated list of the invalid tools.
        If no invalid tools exist, respond with "<OKAY>" (and nothing else).
        """
        llm = get_llm(LLMRole.MAIN, config=agent_config)
        # Strip all surrounding whitespace (not only "\n") so replies such as
        # " <OKAY>" or "<OKAY>\r\n" are recognised as a clean result.
        bad_tools_str = llm.complete(prompt).text.strip()
        if bad_tools_str and bad_tools_str != "<OKAY>":
            known_names = set(tool_names)
            # Drop empty entries (e.g. from a trailing comma) and names that are
            # in fact registered tools, which the LLM may report by mistake.
            bad_tools = [
                name
                for name in (part.strip() for part in bad_tools_str.split(","))
                if name and name not in known_names
            ]
            if bad_tools:
                numbered = ", ".join(
                    f"({i}) {tool}" for i, tool in enumerate(bad_tools, 1)
                )
                raise ValueError(
                    f"The Agent custom instructions mention these invalid tools: {numbered}"
                )
@@ -16,12 +16,6 @@ from .agent import Agent
16
16
  from .agent_config import AgentConfig
17
17
 
18
18
 
19
- class ChatRequest(BaseModel):
20
- """Request schema for the /chat endpoint."""
21
-
22
- message: str
23
-
24
-
25
19
  class CompletionRequest(BaseModel):
26
20
  """Request schema for the /v1/completions endpoint."""
27
21
 
@@ -64,12 +58,14 @@ class CompletionResponse(BaseModel):
64
58
 
65
59
# Roles a chat message may carry.
_ChatRole = Literal["system", "user", "assistant"]


class ChatMessage(BaseModel):
    """
    A single chat message.

    Used as an entry of ``ChatCompletionRequest.messages`` and as the
    ``message`` of each ``ChatCompletionChoice``.
    """

    # Author of the message; restricted to the roles in _ChatRole.
    role: _ChatRole
    # Text body of the message.
    content: str
69
64
 
70
65
 
71
66
  class ChatCompletionRequest(BaseModel):
72
67
  """Request schema for the /v1/chat endpoint."""
68
+
73
69
  model: str
74
70
  messages: List[ChatMessage]
75
71
  temperature: Optional[float] = Field(1.0, ge=0.0, le=2.0)
@@ -79,6 +75,7 @@ class ChatCompletionRequest(BaseModel):
79
75
 
80
76
  class ChatCompletionChoice(BaseModel):
81
77
  """Choice schema returned in ChatCompletionResponse."""
78
+
82
79
  index: int
83
80
  message: ChatMessage
84
81
  finish_reason: Literal["stop", "length", "error", None]
@@ -86,6 +83,7 @@ class ChatCompletionChoice(BaseModel):
86
83
 
87
84
  class ChatCompletionResponse(BaseModel):
88
85
  """Response schema for the /v1/chat endpoint."""
86
+
89
87
  id: str
90
88
  object: Literal["chat.completion"]
91
89
  created: int
@@ -6,6 +6,7 @@ It makes the following adjustments:
6
6
  * Makes sure the load_data method returns a list of text values from the database (and not Document[] objects).
7
7
  * Limits the returned rows to self.max_rows.
8
8
  """
9
+
9
10
  from typing import Any, Optional, List, Awaitable, Callable
10
11
  import asyncio
11
12
  from inspect import signature
@@ -24,15 +25,20 @@ from llama_index.core.tools.utils import create_schema_from_function
24
25
 
25
26
  AsyncCallable = Callable[..., Awaitable[Any]]
26
27
 
28
+
27
29
  class DatabaseTools:
28
30
  """Database tools for vectara-agentic
29
31
  This class provides a set of tools to interact with a database.
30
32
  It allows you to load data, list tables, describe tables, and load unique values.
31
33
  It also provides a method to load sample data from a specified table.
32
34
  """
35
+
33
36
  spec_functions = [
34
- "load_data", "load_sample_data", "list_tables",
35
- "describe_tables", "load_unique_values",
37
+ "load_data",
38
+ "load_sample_data",
39
+ "list_tables",
40
+ "describe_tables",
41
+ "load_unique_values",
36
42
  ]
37
43
 
38
44
  def __init__(
@@ -61,7 +67,7 @@ class DatabaseTools:
61
67
  elif uri:
62
68
  self.uri = uri
63
69
  self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
64
- elif (scheme and host and port and user and password and dbname):
70
+ elif scheme and host and port and user and password and dbname:
65
71
  uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
66
72
  self.uri = uri
67
73
  self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
@@ -76,7 +82,8 @@ class DatabaseTools:
76
82
  self._metadata.reflect(bind=self.sql_database.engine)
77
83
 
78
84
  def _get_metadata_from_fn_name(
79
- self, fn_name: Callable,
85
+ self,
86
+ fn_name: str,
80
87
  ) -> Optional[ToolMetadata]:
81
88
  """Return map from function name.
82
89
 
@@ -87,7 +94,9 @@ class DatabaseTools:
87
94
  func = getattr(self, fn_name)
88
95
  except AttributeError:
89
96
  return None
90
- name = self.tool_name_prefix + "_" + fn_name if self.tool_name_prefix else fn_name
97
+ name = (
98
+ self.tool_name_prefix + "_" + fn_name if self.tool_name_prefix else fn_name
99
+ )
91
100
  docstring = func.__doc__ or ""
92
101
  description = f"{name}{signature(func)}\n{docstring}"
93
102
  fn_schema = create_schema_from_function(fn_name, getattr(self, fn_name))
@@ -118,7 +127,9 @@ class DatabaseTools:
118
127
  try:
119
128
  count_rows = self._load_data(count_query)
120
129
  except Exception as e:
121
- return [f"Error ({str(e)}) occurred while counting number of rows, check your query."]
130
+ return [
131
+ f"Error ({str(e)}) occurred while counting number of rows, check your query."
132
+ ]
122
133
  num_rows = int(count_rows[0].text)
123
134
  if num_rows > self.max_rows:
124
135
  return [
@@ -128,7 +139,9 @@ class DatabaseTools:
128
139
  try:
129
140
  res = self._load_data(sql_query)
130
141
  except Exception as e:
131
- return [f"Error ({str(e)}) occurred while executing the query {sql_query}, check your query."]
142
+ return [
143
+ f"Error ({str(e)}) occurred while executing the query {sql_query}, check your query."
144
+ ]
132
145
  return [d.text for d in res]
133
146
 
134
147
  def load_sample_data(self, table_name: str, num_rows: int = 25) -> Any:
@@ -149,7 +162,9 @@ class DatabaseTools:
149
162
  try:
150
163
  res = self._load_data(f"SELECT * FROM {table_name} LIMIT {num_rows}")
151
164
  except Exception as e:
152
- return [f"Error ({str(e)}) occurred while loading sample data for table {table_name}"]
165
+ return [
166
+ f"Error ({str(e)}) occurred while loading sample data for table {table_name}"
167
+ ]
153
168
  return [d.text for d in res]
154
169
 
155
170
  def list_tables(self) -> List[str]:
@@ -179,7 +194,11 @@ class DatabaseTools:
179
194
  table_schemas = []
180
195
  for table_name in table_names:
181
196
  table = next(
182
- (table for table in self._metadata.sorted_tables if table.name == table_name),
197
+ (
198
+ table
199
+ for table in self._metadata.sorted_tables
200
+ if table.name == table_name
201
+ ),
183
202
  None,
184
203
  )
185
204
  if table is None:
@@ -188,7 +207,9 @@ class DatabaseTools:
188
207
  table_schemas.append(f"{schema}\n")
189
208
  return "\n".join(table_schemas)
190
209
 
191
- def load_unique_values(self, table_name: str, columns: list[str], num_vals: int = 200) -> Any:
210
+ def load_unique_values(
211
+ self, table_name: str, columns: list[str], num_vals: int = 200
212
+ ) -> Any:
192
213
  """
193
214
  Fetches the first num_vals unique values from the specified columns of the database table.
194
215
 
@@ -209,10 +230,14 @@ class DatabaseTools:
209
230
  res = {}
210
231
  try:
211
232
  for column in columns:
212
- unique_vals = self._load_data(f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}')
233
+ unique_vals = self._load_data(
234
+ f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}'
235
+ )
213
236
  res[column] = [d.text for d in unique_vals]
214
237
  except Exception as e:
215
- return {f"Error ({str(e)}) occurred while loading unique values for table {table_name}"}
238
+ return {
239
+ f"Error ({str(e)}) occurred while loading unique values for table {table_name}"
240
+ }
216
241
  return res
217
242
 
218
243
  def to_tool_list(self) -> List[FunctionTool]:
@@ -2,10 +2,10 @@
2
2
  Utilities for the Vectara agentic.
3
3
  """
4
4
 
5
- from typing import Tuple, Callable, Optional
5
+ from typing import Tuple, Optional
6
6
  import os
7
7
  from functools import lru_cache
8
- import tiktoken
8
+ import hashlib
9
9
 
10
10
  from llama_index.core.llms import LLM
11
11
  from llama_index.llms.openai import OpenAI
@@ -13,15 +13,14 @@ from llama_index.llms.anthropic import Anthropic
13
13
 
14
14
  # LLM provider imports are now lazy-loaded in get_llm() function
15
15
 
16
- from .types import LLMRole, AgentType, ModelProvider
16
+ from .types import LLMRole, ModelProvider
17
17
  from .agent_config import AgentConfig
18
18
 
19
19
  provider_to_default_model_name = {
20
20
  ModelProvider.OPENAI: "gpt-4.1",
21
21
  ModelProvider.ANTHROPIC: "claude-sonnet-4-20250514",
22
22
  ModelProvider.TOGETHER: "deepseek-ai/DeepSeek-V3",
23
- ModelProvider.GROQ: "deepseek-r1-distill-llama-70b",
24
- ModelProvider.FIREWORKS: "accounts/fireworks/models/firefunction-v2",
23
+ ModelProvider.GROQ: "openai/gpt-oss-20b",
25
24
  ModelProvider.BEDROCK: "us.anthropic.claude-sonnet-4-20250514-v1:0",
26
25
  ModelProvider.COHERE: "command-a-03-2025",
27
26
  ModelProvider.GEMINI: "models/gemini-2.5-flash",
@@ -29,6 +28,30 @@ provider_to_default_model_name = {
29
28
 
30
29
  DEFAULT_MODEL_PROVIDER = ModelProvider.OPENAI
31
30
 
31
+ # Manual cache for LLM instances to handle mutable AgentConfig objects
32
+ _llm_cache = {}
33
+
34
+
35
def _create_llm_cache_key(role: LLMRole, config: Optional[AgentConfig] = None) -> str:
    """
    Create a hash-based cache key for LLM instances.

    Only the fields listed below participate in the key, so two distinct
    ``AgentConfig`` objects with equal values for these fields map to the
    same key.

    NOTE(review): any other AgentConfig field that influences LLM construction
    is not part of the key; keep this list in sync with ``get_llm``.

    Args:
        role: the LLM role being requested.
        config: agent configuration; ``None`` means a default ``AgentConfig()``.

    Returns:
        str: hex SHA-256 digest of the selected configuration values.
    """
    if config is None:
        config = AgentConfig()

    # Extract only the relevant config parameters for the cache key
    cache_data = {
        "role": role.value,
        "main_llm_provider": config.main_llm_provider.value,
        "main_llm_model_name": config.main_llm_model_name,
        "tool_llm_provider": config.tool_llm_provider.value,
        "tool_llm_model_name": config.tool_llm_model_name,
        "private_llm_api_base": config.private_llm_api_base,
        "private_llm_api_key": config.private_llm_api_key,
    }

    # Stable textual form (sorted by field name), then hashed so the API key
    # is not kept in plaintext as a dict key. SHA-256 is used instead of MD5,
    # which hashlib rejects on FIPS-restricted builds.
    cache_str = str(sorted(cache_data.items()))
    return hashlib.sha256(cache_str.encode("utf-8")).hexdigest()
54
+
32
55
 
33
56
  @lru_cache(maxsize=None)
34
57
  def _get_llm_params_for_role(
@@ -54,42 +77,20 @@ def _get_llm_params_for_role(
54
77
  model_provider
55
78
  )
56
79
 
57
- # If the agent type is OpenAI, check that the main LLM provider is also OpenAI.
58
- if role == LLMRole.MAIN and config.agent_type == AgentType.OPENAI:
59
- if model_provider != ModelProvider.OPENAI:
60
- raise ValueError(
61
- "OpenAI agent requested but main model provider is not OpenAI."
62
- )
63
-
64
80
  return model_provider, model_name
65
81
 
66
82
 
67
- @lru_cache(maxsize=None)
68
- def get_tokenizer_for_model(
69
- role: LLMRole, config: Optional[AgentConfig] = None
70
- ) -> Optional[Callable]:
71
- """
72
- Get the tokenizer for the specified model, as determined by the role & config.
73
- """
74
- model_name = "Unknown model"
75
- try:
76
- model_provider, model_name = _get_llm_params_for_role(role, config)
77
- if model_provider == ModelProvider.OPENAI:
78
- return tiktoken.encoding_for_model("gpt-4o").encode
79
- if model_provider == ModelProvider.ANTHROPIC:
80
- return Anthropic().tokenizer
81
- except Exception:
82
- print(f"Error getting tokenizer for model {model_name}, ignoring")
83
- return None
84
- return None
85
-
86
-
87
- @lru_cache(maxsize=None)
88
83
  def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
89
84
  """
90
85
  Get the LLM for the specified role, using the provided config
91
86
  or a default if none is provided.
87
+
88
+ Uses a cache based on configuration parameters to avoid repeated LLM instantiation.
92
89
  """
90
+ # Check cache first
91
+ cache_key = _create_llm_cache_key(role, config)
92
+ if cache_key in _llm_cache:
93
+ return _llm_cache[cache_key]
93
94
  model_provider, model_name = _get_llm_params_for_role(role, config)
94
95
  max_tokens = (
95
96
  16384
@@ -107,7 +108,7 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
107
108
  model=model_name,
108
109
  temperature=0,
109
110
  is_function_calling_model=True,
110
- strict=True,
111
+ strict=False,
111
112
  max_tokens=max_tokens,
112
113
  pydantic_program_mode="openai",
113
114
  )
@@ -128,7 +129,6 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
128
129
  model=model_name,
129
130
  temperature=0,
130
131
  is_function_calling_model=True,
131
- allow_parallel_tool_calls=True,
132
132
  max_tokens=max_tokens,
133
133
  )
134
134
  elif model_provider == ModelProvider.TOGETHER:
@@ -157,14 +157,6 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
157
157
  is_function_calling_model=True,
158
158
  max_tokens=max_tokens,
159
159
  )
160
- elif model_provider == ModelProvider.FIREWORKS:
161
- try:
162
- from llama_index.llms.fireworks import Fireworks
163
- except ImportError as e:
164
- raise ImportError(
165
- "fireworks not available. Install with: pip install llama-index-llms-fireworks"
166
- ) from e
167
- llm = Fireworks(model=model_name, temperature=0, max_tokens=max_tokens)
168
160
  elif model_provider == ModelProvider.BEDROCK:
169
161
  try:
170
162
  from llama_index.llms.bedrock_converse import BedrockConverse
@@ -197,6 +189,10 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
197
189
  raise ImportError(
198
190
  "openai_like not available. Install with: pip install llama-index-llms-openai-like"
199
191
  ) from e
192
+ if not config or not config.private_llm_api_base or not config.private_llm_api_key:
193
+ raise ValueError(
194
+ "Private LLM requires both private_llm_api_base and private_llm_api_key to be set in AgentConfig."
195
+ )
200
196
  llm = OpenAILike(
201
197
  model=model_name,
202
198
  temperature=0,
@@ -209,4 +205,7 @@ def get_llm(role: LLMRole, config: Optional[AgentConfig] = None) -> LLM:
209
205
 
210
206
  else:
211
207
  raise ValueError(f"Unknown LLM provider: {model_provider}")
208
+
209
+ # Cache the created LLM instance
210
+ _llm_cache[cache_key] = llm
212
211
  return llm