vectara-agentic 0.1.23__py3-none-any.whl → 0.1.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vectara-agentic might be problematic. (The original page links to more details here; the link target is not preserved in this text export.)

vectara_agentic/tools.py CHANGED
@@ -9,6 +9,7 @@ import os
9
9
 
10
10
  from typing import Callable, List, Dict, Any, Optional, Type
11
11
  from pydantic import BaseModel, Field
12
+ from pydantic_core import PydanticUndefined
12
13
 
13
14
  from llama_index.core.tools import FunctionTool
14
15
  from llama_index.core.tools.function_tool import AsyncCallable
@@ -125,6 +126,122 @@ class VectaraTool(FunctionTool):
125
126
  break
126
127
  return is_equal
127
128
 
129
+ def _build_filter_string(kwargs: Dict[str, Any], tool_args_type: Dict[str, str], fixed_filter: str) -> str:
130
+ """
131
+ Build filter string for Vectara from kwargs
132
+ """
133
+ filter_parts = []
134
+ comparison_operators = [">=", "<=", "!=", ">", "<", "="]
135
+ numeric_only_ops = {">", "<", ">=", "<="}
136
+
137
+ for key, value in kwargs.items():
138
+ if value is None or value == "":
139
+ continue
140
+
141
+ # Determine the prefix for the key. Valid values are "doc" or "part"
142
+ # default to 'doc' if not specified
143
+ prefix = tool_args_type.get(key, "doc")
144
+
145
+ if prefix not in ["doc", "part"]:
146
+ raise ValueError(
147
+ f'Unrecognized prefix {prefix}. Please make sure to use either "doc" or "part" for the prefix.'
148
+ )
149
+
150
+ if value is PydanticUndefined:
151
+ raise ValueError(
152
+ f"Value of argument {key} is undefined, and this is invalid. "
153
+ "Please form proper arguments and try again."
154
+ )
155
+
156
+ # value of the argument
157
+ val_str = str(value).strip()
158
+
159
+ # Special handling for range operator
160
+ if val_str.startswith(("[", "(")) and val_str.endswith(("]", ")")):
161
+ # Extract the boundary types
162
+ start_inclusive = val_str.startswith("[")
163
+ end_inclusive = val_str.endswith("]")
164
+
165
+ # Remove the boundaries and strip whitespace
166
+ val_str = val_str[1:-1].strip()
167
+
168
+ if "," in val_str:
169
+ val_str = val_str.split(",")
170
+ if len(val_str) != 2:
171
+ raise ValueError(
172
+ f"Range operator requires two values for {key}: {value}"
173
+ )
174
+
175
+ # Validate both bounds as numeric or empty (for unbounded ranges)
176
+ start_val, end_val = val_str[0].strip(), val_str[1].strip()
177
+ if start_val and not (start_val.isdigit() or is_float(start_val)):
178
+ raise ValueError(
179
+ f"Range operator requires numeric operands for {key}: {value}"
180
+ )
181
+ if end_val and not (end_val.isdigit() or is_float(end_val)):
182
+ raise ValueError(
183
+ f"Range operator requires numeric operands for {key}: {value}"
184
+ )
185
+
186
+ # Build the SQL condition
187
+ range_conditions = []
188
+ if start_val:
189
+ operator = ">=" if start_inclusive else ">"
190
+ range_conditions.append(f"{prefix}.{key} {operator} {start_val}")
191
+ if end_val:
192
+ operator = "<=" if end_inclusive else "<"
193
+ range_conditions.append(f"{prefix}.{key} {operator} {end_val}")
194
+
195
+ # Join the range conditions with AND
196
+ filter_parts.append('( ' + " AND ".join(range_conditions) + ' )')
197
+ continue
198
+
199
+ raise ValueError(
200
+ f"Range operator requires two values for {key}: {value}"
201
+ )
202
+
203
+ # Check if value contains a known comparison operator at the start
204
+ matched_operator = None
205
+ for op in comparison_operators:
206
+ if val_str.startswith(op):
207
+ matched_operator = op
208
+ break
209
+
210
+ # Break down operator from value
211
+ # e.g. val_str = ">2022" --> operator = ">", rhs = "2022"
212
+ if matched_operator:
213
+ rhs = val_str[len(matched_operator):].strip()
214
+
215
+ if matched_operator in numeric_only_ops:
216
+ # Must be numeric
217
+ if not (rhs.isdigit() or is_float(rhs)):
218
+ raise ValueError(
219
+ f"Operator {matched_operator} requires a numeric operand for {key}: {val_str}"
220
+ )
221
+ filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
222
+ else:
223
+ # = and != operators can be numeric or string
224
+ if rhs.isdigit() or is_float(rhs):
225
+ filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
226
+ elif rhs.lower() in ["true", "false"]:
227
+ filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs.lower()}")
228
+ else:
229
+ # For string operands, wrap them in quotes
230
+ filter_parts.append(f"{prefix}.{key}{matched_operator}'{rhs}'")
231
+ else:
232
+ if val_str.isdigit() or is_float(val_str):
233
+ filter_parts.append(f"{prefix}.{key}={val_str}")
234
+ elif val_str.lower() in ["true", "false"]:
235
+ # This is to handle boolean values.
236
+ # This is not complete solution - the best solution would be to test if the field is boolean
237
+ # That can be done after we move to APIv2
238
+ filter_parts.append(f"{prefix}.{key}={val_str.lower()}")
239
+ else:
240
+ filter_parts.append(f"{prefix}.{key}='{val_str}'")
241
+
242
+ filter_str = " AND ".join(filter_parts)
243
+ return f"({fixed_filter}) AND ({filter_str})" if fixed_filter else filter_str
244
+
128
245
  class VectaraToolFactory:
129
246
  """
130
247
  A factory class for creating Vectara RAG tools.
@@ -148,13 +265,159 @@ class VectaraToolFactory:
148
265
  self.vectara_api_key = vectara_api_key
149
266
  self.num_corpora = len(vectara_corpus_id.split(","))
150
267
 
268
+ def create_search_tool(
269
+ self,
270
+ tool_name: str,
271
+ tool_description: str,
272
+ tool_args_schema: type[BaseModel],
273
+ tool_args_type: Dict[str, str] = {},
274
+ fixed_filter: str = "",
275
+ lambda_val: float = 0.005,
276
+ reranker: str = "mmr",
277
+ rerank_k: int = 50,
278
+ mmr_diversity_bias: float = 0.2,
279
+ udf_expression: str = None,
280
+ rerank_chain: List[Dict] = None,
281
+ verbose: bool = False,
282
+ ) -> VectaraTool:
283
+ """
284
+ Creates a Vectara search/retrieval tool
285
+
286
+ Args:
287
+ tool_name (str): The name of the tool.
288
+ tool_description (str): The description of the tool.
289
+ tool_args_schema (BaseModel): The schema for the tool arguments.
290
+ tool_args_type (Dict[str, str], optional): The type of each argument (doc or part).
291
+ fixed_filter (str, optional): A fixed Vectara filter condition to apply to all queries.
292
+ lambda_val (float, optional): Lambda value for the Vectara query.
293
+ reranker (str, optional): The reranker mode.
294
+ rerank_k (int, optional): Number of top-k documents for reranking.
295
+ mmr_diversity_bias (float, optional): MMR diversity bias.
296
+ udf_expression (str, optional): the user defined expression for reranking results.
297
+ rerank_chain (List[Dict], optional): A list of rerankers to be applied sequentially.
298
+ Each dictionary should specify the "type" of reranker (mmr, slingshot, udf)
299
+ and any other parameters (e.g. "limit" or "cutoff" for any type,
300
+ "diversity_bias" for mmr, and "user_function" for udf).
301
+ If using slingshot/multilingual_reranker_v1, it must be first in the list.
302
+ verbose (bool, optional): Whether to print verbose output.
303
+
304
+ Returns:
305
+ VectaraTool: A VectaraTool object.
306
+ """
307
+
308
+ vectara = VectaraIndex(
309
+ vectara_api_key=self.vectara_api_key,
310
+ vectara_customer_id=self.vectara_customer_id,
311
+ vectara_corpus_id=self.vectara_corpus_id,
312
+ x_source_str="vectara-agentic",
313
+ )
314
+
315
+ # Dynamically generate the search function
316
+ def search_function(*args, **kwargs) -> ToolOutput:
317
+ """
318
+ Dynamically generated function for semantic search Vectara.
319
+ """
320
+ # Convert args to kwargs using the function signature
321
+ sig = inspect.signature(search_function)
322
+ bound_args = sig.bind_partial(*args, **kwargs)
323
+ bound_args.apply_defaults()
324
+ kwargs = bound_args.arguments
325
+
326
+ query = kwargs.pop("query")
327
+ top_k = kwargs.pop("top_k", 10)
328
+ try:
329
+ filter_string = _build_filter_string(kwargs, tool_args_type, fixed_filter)
330
+ except ValueError as e:
331
+ return ToolOutput(
332
+ tool_name=search_function.__name__,
333
+ content=str(e),
334
+ raw_input={"args": args, "kwargs": kwargs},
335
+ raw_output={"response": str(e)},
336
+ )
337
+
338
+ vectara_retriever = vectara.as_retriever(
339
+ summary_enabled=False,
340
+ similarity_top_k=top_k,
341
+ reranker=reranker,
342
+ rerank_k=rerank_k if rerank_k * self.num_corpora <= 100 else int(100 / self.num_corpora),
343
+ mmr_diversity_bias=mmr_diversity_bias,
344
+ udf_expression=udf_expression,
345
+ rerank_chain=rerank_chain,
346
+ lambda_val=lambda_val,
347
+ filter=filter_string,
348
+ x_source_str="vectara-agentic",
349
+ verbose=verbose,
350
+ )
351
+ response = vectara_retriever.retrieve(query)
352
+
353
+ if len(response) == 0:
354
+ msg = "Vectara Tool failed to retreive any results for the query."
355
+ return ToolOutput(
356
+ tool_name=search_function.__name__,
357
+ content=msg,
358
+ raw_input={"args": args, "kwargs": kwargs},
359
+ raw_output={"response": msg},
360
+ )
361
+ tool_output = "Matching documents:\n"
362
+ unique_ids = set()
363
+ for doc in response:
364
+ if doc.id_ in unique_ids:
365
+ continue
366
+ unique_ids.add(doc.id_)
367
+ tool_output += f"document '{doc.id_}' metadata: {doc.metadata}\n"
368
+ out = ToolOutput(
369
+ tool_name=search_function.__name__,
370
+ content=tool_output,
371
+ raw_input={"args": args, "kwargs": kwargs},
372
+ raw_output=response,
373
+ )
374
+ return out
375
+
376
+ fields = tool_args_schema.model_fields
377
+ params = [
378
+ inspect.Parameter(
379
+ name=field_name,
380
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
381
+ default=field_info.default,
382
+ annotation=field_info,
383
+ )
384
+ for field_name, field_info in fields.items()
385
+ ]
386
+
387
+ # Create a new signature using the extracted parameters
388
+ sig = inspect.Signature(params)
389
+ search_function.__signature__ = sig
390
+ search_function.__annotations__["return"] = dict[str, Any]
391
+ search_function.__name__ = "_" + re.sub(r"[^A-Za-z0-9_]", "_", tool_name)
392
+
393
+ # Create the tool function signature string
394
+ fields = []
395
+ for name, field in tool_args_schema.__fields__.items():
396
+ annotation = field.annotation
397
+ type_name = annotation.__name__ if hasattr(annotation, '__name__') else str(annotation)
398
+ fields.append(f"{name}: {type_name}")
399
+ args_str = ", ".join(fields)
400
+ function_str = f"{tool_name}({args_str}) -> str"
401
+
402
+ # Create the tool
403
+ tool = VectaraTool.from_defaults(
404
+ fn=search_function,
405
+ name=tool_name,
406
+ description=function_str + ". " + tool_description,
407
+ fn_schema=tool_args_schema,
408
+ tool_type=ToolType.QUERY,
409
+ )
410
+ return tool
411
+
151
412
  def create_rag_tool(
152
413
  self,
153
414
  tool_name: str,
154
415
  tool_description: str,
155
416
  tool_args_schema: type[BaseModel],
156
417
  tool_args_type: Dict[str, str] = {},
418
+ fixed_filter: str = "",
157
419
  vectara_summarizer: str = "vectara-summary-ext-24-05-sml",
420
+ vectara_prompt_text: str = None,
158
421
  summary_num_results: int = 5,
159
422
  summary_response_lang: str = "eng",
160
423
  n_sentences_before: int = 2,
@@ -177,7 +440,9 @@ class VectaraToolFactory:
177
440
  tool_description (str): The description of the tool.
178
441
  tool_args_schema (BaseModel): The schema for the tool arguments.
179
442
  tool_args_type (Dict[str, str], optional): The type of each argument (doc or part).
443
+ fixed_filter (str, optional): A fixed Vectara filter condition to apply to all queries.
180
444
  vectara_summarizer (str, optional): The Vectara summarizer to use.
445
+ vectara_prompt_text (str, optional): The prompt text for the Vectara summarizer.
181
446
  summary_num_results (int, optional): The number of summary results.
182
447
  summary_response_lang (str, optional): The response language for the summary.
183
448
  n_sentences_before (int, optional): Number of sentences before the summary.
@@ -209,66 +474,6 @@ class VectaraToolFactory:
209
474
  x_source_str="vectara-agentic",
210
475
  )
211
476
 
212
- def _build_filter_string(kwargs: Dict[str, Any], tool_args_type: Dict[str, str]) -> str:
213
- filter_parts = []
214
- comparison_operators = [">=", "<=", "!=", ">", "<", "="]
215
- numeric_only_ops = {">", "<", ">=", "<="}
216
-
217
- for key, value in kwargs.items():
218
- if value is None or value == "":
219
- continue
220
-
221
- # Determine the prefix for the key. Valid values are "doc" or "part"
222
- # default to 'doc' if not specified
223
- prefix = tool_args_type.get(key, "doc")
224
-
225
- if prefix not in ["doc", "part"]:
226
- raise ValueError(
227
- f'Unrecognized prefix {prefix}. Please make sure to use either "doc" or "part" for the prefix.'
228
- )
229
-
230
- # Check if value contains a known comparison operator at the start
231
- val_str = str(value).strip()
232
- matched_operator = None
233
- for op in comparison_operators:
234
- if val_str.startswith(op):
235
- matched_operator = op
236
- break
237
-
238
- # Break down operator from value
239
- # e.g. val_str = ">2022" --> operator = ">", rhs = "2022"
240
- if matched_operator:
241
- rhs = val_str[len(matched_operator):].strip()
242
-
243
- if matched_operator in numeric_only_ops:
244
- # Must be numeric
245
- if not (rhs.isdigit() or is_float(rhs)):
246
- raise ValueError(
247
- f"Operator {matched_operator} requires a numeric operand for {key}: {val_str}"
248
- )
249
- filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
250
- else:
251
- # = and != operators can be numeric or string
252
- if rhs.isdigit() or is_float(rhs):
253
- filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
254
- elif rhs.lower() in ["true", "false"]:
255
- filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs.lower()}")
256
- else:
257
- # For string operands, wrap them in quotes
258
- filter_parts.append(f"{prefix}.{key}{matched_operator}'{rhs}'")
259
- else:
260
- if val_str.isdigit() or is_float(val_str):
261
- filter_parts.append(f"{prefix}.{key}={val_str}")
262
- elif val_str.lower() in ["true", "false"]:
263
- # This is to handle boolean values.
264
- # This is not complete solution - the best solution would be to test if the field is boolean
265
- # That can be done after we move to APIv2
266
- filter_parts.append(f"{prefix}.{key}={val_str.lower()}")
267
- else:
268
- filter_parts.append(f"{prefix}.{key}='{val_str}'")
269
-
270
- return " AND ".join(filter_parts)
271
-
272
477
  # Dynamically generate the RAG function
273
478
  def rag_function(*args, **kwargs) -> ToolOutput:
274
479
  """
@@ -281,13 +486,22 @@ class VectaraToolFactory:
281
486
  kwargs = bound_args.arguments
282
487
 
283
488
  query = kwargs.pop("query")
284
- filter_string = _build_filter_string(kwargs, tool_args_type)
489
+ try:
490
+ filter_string = _build_filter_string(kwargs, tool_args_type, fixed_filter)
491
+ except ValueError as e:
492
+ return ToolOutput(
493
+ tool_name=rag_function.__name__,
494
+ content=str(e),
495
+ raw_input={"args": args, "kwargs": kwargs},
496
+ raw_output={"response": str(e)},
497
+ )
285
498
 
286
499
  vectara_query_engine = vectara.as_query_engine(
287
500
  summary_enabled=True,
288
501
  summary_num_results=summary_num_results,
289
502
  summary_response_lang=summary_response_lang,
290
503
  summary_prompt_name=vectara_summarizer,
504
+ prompt_text=vectara_prompt_text,
291
505
  reranker=reranker,
292
506
  rerank_k=rerank_k if rerank_k * self.num_corpora <= 100 else int(100 / self.num_corpora),
293
507
  mmr_diversity_bias=mmr_diversity_bias,
vectara_agentic/tools_catalog.py CHANGED
@@ -3,9 +3,11 @@ This module contains the tools catalog for the Vectara Agentic.
3
3
  """
4
4
  from typing import List
5
5
  from functools import lru_cache
6
- from pydantic import Field
6
+ from datetime import date
7
7
  import requests
8
8
 
9
+ from pydantic import Field
10
+
9
11
  from .types import LLMRole
10
12
  from .utils import get_llm
11
13
 
@@ -19,6 +21,11 @@ get_headers = {
19
21
  "Connection": "keep-alive",
20
22
  }
21
23
 
24
+ def get_current_date() -> str:
25
+ """
26
+ Returns: the current date.
27
+ """
28
+ return date.today().strftime("%A, %B %d, %Y")
22
29
 
23
30
  #
24
31
  # Standard Tools
vectara_agentic/types.py CHANGED
@@ -3,6 +3,9 @@ This module contains the types used in the Vectara Agentic.
3
3
  """
4
4
  from enum import Enum
5
5
 
6
+ from llama_index.core.tools.types import ToolOutput as LI_ToolOutput
7
+ from llama_index.core.chat_engine.types import AgentChatResponse as LI_AgentChatResponse
8
+ from llama_index.core.chat_engine.types import StreamingAgentChatResponse as LI_StreamingAgentChatResponse
6
9
 
7
10
  class AgentType(Enum):
8
11
  """Enumeration for different types of agents."""
@@ -29,6 +32,7 @@ class ModelProvider(Enum):
29
32
  FIREWORKS = "FIREWORKS"
30
33
  COHERE = "COHERE"
31
34
  GEMINI = "GEMINI"
35
+ BEDROCK = "BEDROCK"
32
36
 
33
37
 
34
38
  class AgentStatusType(Enum):
@@ -51,3 +55,9 @@ class ToolType(Enum):
51
55
  """Enumeration for different types of tools."""
52
56
  QUERY = "query"
53
57
  ACTION = "action"
58
+
59
+
60
+ # classes for Agent responses
61
+ ToolOutput = LI_ToolOutput
62
+ AgentResponse = LI_AgentChatResponse
63
+ AgentStreamingResponse = LI_StreamingAgentChatResponse
vectara_agentic/utils.py CHANGED
@@ -2,7 +2,6 @@
2
2
  Utilities for the Vectara agentic.
3
3
  """
4
4
 
5
- import os
6
5
  from typing import Tuple, Callable, Optional
7
6
  from functools import lru_cache
8
7
 
@@ -13,6 +12,7 @@ from llama_index.llms.openai import OpenAI
13
12
  from llama_index.llms.anthropic import Anthropic
14
13
 
15
14
  from .types import LLMRole, AgentType, ModelProvider
15
+ from .agent_config import AgentConfig
16
16
 
17
17
  provider_to_default_model_name = {
18
18
  ModelProvider.OPENAI: "gpt-4o",
@@ -20,6 +20,7 @@ provider_to_default_model_name = {
20
20
  ModelProvider.TOGETHER: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
21
21
  ModelProvider.GROQ: "llama-3.3-70b-versatile",
22
22
  ModelProvider.FIREWORKS: "accounts/fireworks/models/firefunction-v2",
23
+ ModelProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
23
24
  ModelProvider.COHERE: "command-r-plus",
24
25
  ModelProvider.GEMINI: "models/gemini-1.5-flash",
25
26
  }
@@ -27,44 +28,52 @@ provider_to_default_model_name = {
27
28
  DEFAULT_MODEL_PROVIDER = ModelProvider.OPENAI
28
29
 
29
30
  @lru_cache(maxsize=None)
30
- def _get_llm_params_for_role(role: LLMRole) -> Tuple[ModelProvider, str]:
31
- """Get the model provider and model name for the specified role."""
31
+ def _get_llm_params_for_role(
32
+ role: LLMRole,
33
+ config: Optional[AgentConfig] = None
34
+ ) -> Tuple[ModelProvider, str]:
35
+ """
36
+ Get the model provider and model name for the specified role.
37
+
38
+ If config is None, a new AgentConfig() is instantiated using environment defaults.
39
+ """
40
+ config = config or AgentConfig() # fallback to default config
41
+
32
42
  if role == LLMRole.TOOL:
33
- model_provider = ModelProvider(
34
- os.getenv("VECTARA_AGENTIC_TOOL_LLM_PROVIDER", DEFAULT_MODEL_PROVIDER.value)
35
- )
36
- model_name = os.getenv(
37
- "VECTARA_AGENTIC_TOOL_MODEL_NAME",
38
- provider_to_default_model_name.get(model_provider),
43
+ model_provider = config.tool_llm_provider
44
+ # If the user hasn’t explicitly set a tool_llm_model_name,
45
+ # fallback to provider default from provider_to_default_model_name
46
+ model_name = (
47
+ config.tool_llm_model_name
48
+ or provider_to_default_model_name.get(model_provider)
39
49
  )
40
50
  else:
41
- model_provider = ModelProvider(
42
- os.getenv("VECTARA_AGENTIC_MAIN_LLM_PROVIDER", DEFAULT_MODEL_PROVIDER.value)
43
- )
44
- model_name = os.getenv(
45
- "VECTARA_AGENTIC_MAIN_MODEL_NAME",
46
- provider_to_default_model_name.get(model_provider),
51
+ model_provider = config.main_llm_provider
52
+ model_name = (
53
+ config.main_llm_model_name
54
+ or provider_to_default_model_name.get(model_provider)
47
55
  )
48
56
 
49
- agent_type = AgentType(
50
- os.getenv("VECTARA_AGENTIC_AGENT_TYPE", AgentType.OPENAI.value)
51
- )
52
- if (
53
- role == LLMRole.MAIN
54
- and agent_type == AgentType.OPENAI
55
- and model_provider != ModelProvider.OPENAI
56
- ):
57
- raise ValueError(
58
- "OpenAI agent requested but main model provider is not OpenAI."
59
- )
57
+ # If the agent type is OpenAI, check that the main LLM provider is also OpenAI.
58
+ if role == LLMRole.MAIN and config.agent_type == AgentType.OPENAI:
59
+ if model_provider != ModelProvider.OPENAI:
60
+ raise ValueError(
61
+ "OpenAI agent requested but main model provider is not OpenAI."
62
+ )
60
63
 
61
64
  return model_provider, model_name
62
65
 
63
66
  @lru_cache(maxsize=None)
64
- def get_tokenizer_for_model(role: LLMRole) -> Optional[Callable]:
65
- """Get the tokenizer for the specified model."""
66
- model_provider, model_name = _get_llm_params_for_role(role)
67
+ def get_tokenizer_for_model(
68
+ role: LLMRole,
69
+ config: Optional[AgentConfig] = None
70
+ ) -> Optional[Callable]:
71
+ """
72
+ Get the tokenizer for the specified model, as determined by the role & config.
73
+ """
74
+ model_provider, model_name = _get_llm_params_for_role(role, config)
67
75
  if model_provider == ModelProvider.OPENAI:
76
+ # This might raise an exception if the model_name is unknown to tiktoken
68
77
  return tiktoken.encoding_for_model(model_name).encode
69
78
  if model_provider == ModelProvider.ANTHROPIC:
70
79
  return Anthropic().tokenizer
@@ -72,10 +81,15 @@ def get_tokenizer_for_model(role: LLMRole) -> Optional[Callable]:
72
81
 
73
82
 
74
83
  @lru_cache(maxsize=None)
75
- def get_llm(role: LLMRole) -> LLM:
76
- """Get the LLM for the specified role."""
77
- model_provider, model_name = _get_llm_params_for_role(role)
78
-
84
+ def get_llm(
85
+ role: LLMRole,
86
+ config: Optional[AgentConfig] = None
87
+ ) -> LLM:
88
+ """
89
+ Get the LLM for the specified role, using the provided config
90
+ or a default if none is provided.
91
+ """
92
+ model_provider, model_name = _get_llm_params_for_role(role, config)
79
93
  if model_provider == ModelProvider.OPENAI:
80
94
  llm = OpenAI(model=model_name, temperature=0, is_function_calling_model=True)
81
95
  elif model_provider == ModelProvider.ANTHROPIC:
@@ -92,12 +106,14 @@ def get_llm(role: LLMRole) -> LLM:
92
106
  elif model_provider == ModelProvider.FIREWORKS:
93
107
  from llama_index.llms.fireworks import Fireworks
94
108
  llm = Fireworks(model=model_name, temperature=0)
109
+ elif model_provider == ModelProvider.BEDROCK:
110
+ from llama_index.llms.bedrock import Bedrock
111
+ llm = Bedrock(model=model_name, temperature=0)
95
112
  elif model_provider == ModelProvider.COHERE:
96
113
  from llama_index.llms.cohere import Cohere
97
114
  llm = Cohere(model=model_name, temperature=0)
98
115
  else:
99
116
  raise ValueError(f"Unknown LLM provider: {model_provider}")
100
-
101
117
  return llm
102
118
 
103
119
  def is_float(value: str) -> bool: