vectara-agentic 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vectara-agentic might be problematic. Click here for more details.

vectara_agentic/agent.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """
2
2
  This module contains the Agent class for handling different types of agents and their interactions.
3
3
  """
4
+
4
5
  from typing import List, Callable, Optional, Dict, Any, Union, Tuple
5
6
  import os
6
7
  import re
@@ -11,6 +12,8 @@ import logging
11
12
  import asyncio
12
13
  import importlib
13
14
  from collections import Counter
15
+ import inspect
16
+ from inspect import Signature, Parameter, ismethod
14
17
 
15
18
  import cloudpickle as pickle
16
19
 
@@ -18,10 +21,15 @@ from dotenv import load_dotenv
18
21
 
19
22
  from pydantic import Field, create_model, ValidationError
20
23
 
24
+
21
25
  from llama_index.core.memory import ChatMemoryBuffer
22
26
  from llama_index.core.llms import ChatMessage, MessageRole
23
27
  from llama_index.core.tools import FunctionTool
24
- from llama_index.core.agent import ReActAgent, StructuredPlannerAgent, FunctionCallingAgent
28
+ from llama_index.core.agent import (
29
+ ReActAgent,
30
+ StructuredPlannerAgent,
31
+ FunctionCallingAgent,
32
+ )
25
33
  from llama_index.core.agent.react.formatter import ReActChatFormatter
26
34
  from llama_index.agent.llm_compiler import LLMCompilerAgentWorker
27
35
  from llama_index.agent.lats import LATSAgentWorker
@@ -33,13 +41,22 @@ from llama_index.core.agent.types import BaseAgent
33
41
  from llama_index.core.workflow import Workflow
34
42
 
35
43
  from .types import (
36
- AgentType, AgentStatusType, LLMRole, ToolType, ModelProvider,
37
- AgentResponse, AgentStreamingResponse, AgentConfigType
44
+ AgentType,
45
+ AgentStatusType,
46
+ LLMRole,
47
+ ToolType,
48
+ ModelProvider,
49
+ AgentResponse,
50
+ AgentStreamingResponse,
51
+ AgentConfigType,
38
52
  )
39
- from .utils import get_llm, get_tokenizer_for_model
53
+ from .llm_utils import get_llm, get_tokenizer_for_model
40
54
  from ._prompts import (
41
- REACT_PROMPT_TEMPLATE, GENERAL_PROMPT_TEMPLATE, GENERAL_INSTRUCTIONS,
42
- STRUCTURED_PLANNER_PLAN_REFINE_PROMPT, STRUCTURED_PLANNER_INITIAL_PLAN_PROMPT
55
+ REACT_PROMPT_TEMPLATE,
56
+ GENERAL_PROMPT_TEMPLATE,
57
+ GENERAL_INSTRUCTIONS,
58
+ STRUCTURED_PLANNER_PLAN_REFINE_PROMPT,
59
+ STRUCTURED_PLANNER_INITIAL_PLAN_PROMPT,
43
60
  )
44
61
  from ._callback import AgentCallbackHandler
45
62
  from ._observability import setup_observer, eval_fcs
@@ -47,10 +64,12 @@ from .tools import VectaraToolFactory, VectaraTool, ToolsFactory
47
64
  from .tools_catalog import get_current_date
48
65
  from .agent_config import AgentConfig
49
66
 
67
+
50
68
  class IgnoreUnpickleableAttributeFilter(logging.Filter):
51
- '''
69
+ """
52
70
  Filter to ignore log messages that contain certain strings
53
- '''
71
+ """
72
+
54
73
  def filter(self, record):
55
74
  msgs_to_ignore = [
56
75
  "Removing unpickleable private attribute _chunking_tokenizer_fn",
@@ -67,12 +86,19 @@ logger.setLevel(logging.CRITICAL)
67
86
 
68
87
  load_dotenv(override=True)
69
88
 
70
- def _get_prompt(prompt_template: str, topic: str, custom_instructions: str):
89
+
90
+ def _get_prompt(
91
+ prompt_template: str,
92
+ general_instructions: str,
93
+ topic: str,
94
+ custom_instructions: str,
95
+ ):
71
96
  """
72
97
  Generate a prompt by replacing placeholders with topic and date.
73
98
 
74
99
  Args:
75
100
  prompt_template (str): The template for the prompt.
101
+ general_instructions (str): General instructions to be included in the prompt.
76
102
  topic (str): The topic to be included in the prompt.
77
103
  custom_instructions(str): The custom instructions to be included in the prompt.
78
104
 
@@ -83,10 +109,13 @@ def _get_prompt(prompt_template: str, topic: str, custom_instructions: str):
83
109
  prompt_template.replace("{chat_topic}", topic)
84
110
  .replace("{today}", date.today().strftime("%A, %B %d, %Y"))
85
111
  .replace("{custom_instructions}", custom_instructions)
112
+ .replace("{INSTRUCTIONS}", general_instructions)
86
113
  )
87
114
 
88
115
 
89
- def _get_llm_compiler_prompt(prompt: str, topic: str, custom_instructions: str) -> str:
116
+ def _get_llm_compiler_prompt(
117
+ prompt: str, general_instructions: str, topic: str, custom_instructions: str
118
+ ) -> str:
90
119
  """
91
120
  Add custom instructions to the prompt.
92
121
 
@@ -98,7 +127,7 @@ def _get_llm_compiler_prompt(prompt: str, topic: str, custom_instructions: str)
98
127
  """
99
128
  prompt += "\nAdditional Instructions:\n"
100
129
  prompt += f"You have experise in {topic}.\n"
101
- prompt += GENERAL_INSTRUCTIONS
130
+ prompt += general_instructions
102
131
  prompt += custom_instructions
103
132
  prompt += f"Today is {date.today().strftime('%A, %B %d, %Y')}"
104
133
  return prompt
@@ -132,6 +161,7 @@ def get_field_type(field_schema: dict) -> Any:
132
161
  else:
133
162
  return Any
134
163
 
164
+
135
165
  class Agent:
136
166
  """
137
167
  Agent class for handling different types of agents and their interactions.
@@ -142,10 +172,13 @@ class Agent:
142
172
  tools: list[FunctionTool],
143
173
  topic: str = "general",
144
174
  custom_instructions: str = "",
175
+ general_instructions: str = GENERAL_INSTRUCTIONS,
145
176
  verbose: bool = True,
146
177
  use_structured_planning: bool = False,
147
178
  update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
148
- agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
179
+ agent_progress_callback: Optional[
180
+ Callable[[AgentStatusType, str], None]
181
+ ] = None,
149
182
  query_logging_callback: Optional[Callable[[str, str], None]] = None,
150
183
  agent_config: Optional[AgentConfig] = None,
151
184
  fallback_agent_config: Optional[AgentConfig] = None,
@@ -162,6 +195,9 @@ class Agent:
162
195
  tools (list[FunctionTool]): A list of tools to be used by the agent.
163
196
  topic (str, optional): The topic for the agent. Defaults to 'general'.
164
197
  custom_instructions (str, optional): Custom instructions for the agent. Defaults to ''.
198
+ general_instructions (str, optional): General instructions for the agent.
199
+ The Agent has a default set of instructions that are crafted to help it operate effectively.
200
+ This allows you to customize the agent's behavior and personality, but use with caution.
165
201
  verbose (bool, optional): Whether the agent should print its steps. Defaults to True.
166
202
  use_structured_planning (bool, optional)
167
203
  Whether or not we want to wrap the agent with LlamaIndex StructuredPlannerAgent.
@@ -181,19 +217,26 @@ class Agent:
181
217
  self.agent_config = agent_config or AgentConfig()
182
218
  self.agent_config_type = AgentConfigType.DEFAULT
183
219
  self.tools = tools
184
- if not any(tool.metadata.name == 'get_current_date' for tool in self.tools):
220
+ if not any(tool.metadata.name == "get_current_date" for tool in self.tools):
185
221
  self.tools += [ToolsFactory().create_tool(get_current_date)]
186
222
  self.agent_type = self.agent_config.agent_type
187
223
  self.use_structured_planning = use_structured_planning
188
224
  self.llm = get_llm(LLMRole.MAIN, config=self.agent_config)
189
225
  self._custom_instructions = custom_instructions
226
+ self._general_instructions = general_instructions
190
227
  self._topic = topic
191
- self.agent_progress_callback = agent_progress_callback if agent_progress_callback else update_func
228
+ self.agent_progress_callback = (
229
+ agent_progress_callback if agent_progress_callback else update_func
230
+ )
192
231
  self.query_logging_callback = query_logging_callback
193
232
 
194
233
  self.workflow_cls = workflow_cls
195
234
  self.workflow_timeout = workflow_timeout
196
235
 
236
+ # Sanitize tools for Gemini if needed
237
+ if self.agent_config.main_llm_provider == ModelProvider.GEMINI:
238
+ self.tools = self._sanitize_tools_for_gemini(self.tools)
239
+
197
240
  # Validate tools
198
241
  # Check for:
199
242
  # 1. multiple copies of the same tool
@@ -204,7 +247,7 @@ class Agent:
204
247
  raise ValueError(f"Duplicate tools detected: {', '.join(duplicates)}")
205
248
 
206
249
  if validate_tools:
207
- prompt = f'''
250
+ prompt = f"""
208
251
  Given the following instructions, and a list of tool names,
209
252
  Please identify tools mentioned in the instructions that do not exist in the list.
210
253
  Instructions:
@@ -212,20 +255,28 @@ class Agent:
212
255
  Tool names: {', '.join(tool_names)}
213
256
  Your response should include a comma separated list of tool names that do not exist in the list.
214
257
  Your response should be an empty string if all tools mentioned in the instructions are in the list.
215
- '''
258
+ """
216
259
  llm = get_llm(LLMRole.MAIN, config=self.agent_config)
217
260
  bad_tools = llm.complete(prompt).text.split(", ")
218
261
  if bad_tools:
219
- raise ValueError(f"The Agent custom instructions mention these invalid tools: {', '.join(bad_tools)}")
262
+ raise ValueError(
263
+ f"The Agent custom instructions mention these invalid tools: {', '.join(bad_tools)}"
264
+ )
220
265
 
221
266
  # Create token counters for the main and tool LLMs
222
267
  main_tok = get_tokenizer_for_model(role=LLMRole.MAIN)
223
- self.main_token_counter = TokenCountingHandler(tokenizer=main_tok) if main_tok else None
268
+ self.main_token_counter = (
269
+ TokenCountingHandler(tokenizer=main_tok) if main_tok else None
270
+ )
224
271
  tool_tok = get_tokenizer_for_model(role=LLMRole.TOOL)
225
- self.tool_token_counter = TokenCountingHandler(tokenizer=tool_tok) if tool_tok else None
272
+ self.tool_token_counter = (
273
+ TokenCountingHandler(tokenizer=tool_tok) if tool_tok else None
274
+ )
226
275
 
227
276
  # Setup callback manager
228
- callbacks: list[BaseCallbackHandler] = [AgentCallbackHandler(self.agent_progress_callback)]
277
+ callbacks: list[BaseCallbackHandler] = [
278
+ AgentCallbackHandler(self.agent_progress_callback)
279
+ ]
229
280
  if self.main_token_counter:
230
281
  callbacks.append(self.main_token_counter)
231
282
  if self.tool_token_counter:
@@ -236,9 +287,17 @@ class Agent:
236
287
  if chat_history:
237
288
  msg_history = []
238
289
  for text_pairs in chat_history:
239
- msg_history.append(ChatMessage.from_str(content=text_pairs[0], role=MessageRole.USER))
240
- msg_history.append(ChatMessage.from_str(content=text_pairs[1], role=MessageRole.ASSISTANT))
241
- self.memory = ChatMemoryBuffer.from_defaults(token_limit=128000, chat_history=msg_history)
290
+ msg_history.append(
291
+ ChatMessage.from_str(content=text_pairs[0], role=MessageRole.USER)
292
+ )
293
+ msg_history.append(
294
+ ChatMessage.from_str(
295
+ content=text_pairs[1], role=MessageRole.ASSISTANT
296
+ )
297
+ )
298
+ self.memory = ChatMemoryBuffer.from_defaults(
299
+ token_limit=128000, chat_history=msg_history
300
+ )
242
301
  else:
243
302
  self.memory = ChatMemoryBuffer.from_defaults(token_limit=128000)
244
303
 
@@ -246,7 +305,9 @@ class Agent:
246
305
  self.agent = self._create_agent(self.agent_config, callback_manager)
247
306
  self.fallback_agent_config = fallback_agent_config
248
307
  if self.fallback_agent_config:
249
- self.fallback_agent = self._create_agent(self.fallback_agent_config, callback_manager)
308
+ self.fallback_agent = self._create_agent(
309
+ self.fallback_agent_config, callback_manager
310
+ )
250
311
  else:
251
312
  self.fallback_agent_config = None
252
313
 
@@ -257,10 +318,65 @@ class Agent:
257
318
  print(f"Failed to set up observer ({e}), ignoring")
258
319
  self.observability_enabled = False
259
320
 
321
+ def _sanitize_tools_for_gemini(
322
+ self, tools: list[FunctionTool]
323
+ ) -> list[FunctionTool]:
324
+ """
325
+ Strip all default values from:
326
+ - tool.fn
327
+ - tool.async_fn
328
+ - tool.metadata.fn_schema
329
+ so Gemini sees *only* required parameters, no defaults.
330
+ """
331
+ for tool in tools:
332
+ # 1) strip defaults off the actual callables
333
+ for func in (tool.fn, tool.async_fn):
334
+ if not func:
335
+ continue
336
+ orig_sig = inspect.signature(func)
337
+ new_params = [
338
+ p.replace(default=Parameter.empty)
339
+ for p in orig_sig.parameters.values()
340
+ ]
341
+ new_sig = Signature(
342
+ new_params, return_annotation=orig_sig.return_annotation
343
+ )
344
+ if ismethod(func):
345
+ func.__func__.__signature__ = new_sig
346
+ else:
347
+ func.__signature__ = new_sig
348
+
349
+ # 2) rebuild the Pydantic schema so that *every* field is required
350
+ schema_cls = getattr(tool.metadata, "fn_schema", None)
351
+ if schema_cls and hasattr(schema_cls, "model_fields"):
352
+ # collect (name → (type, Field(...))) for all fields
353
+ new_fields: dict[str, tuple[type, Any]] = {}
354
+ for name, mf in schema_cls.model_fields.items():
355
+ typ = mf.annotation
356
+ desc = getattr(mf, "description", "")
357
+ # force required (no default) with Field(...)
358
+ new_fields[name] = (typ, Field(..., description=desc))
359
+
360
+ # make a brand-new schema class where every field is required
361
+ no_default_schema = create_model(
362
+ f"{schema_cls.__name__}", # new class name
363
+ **new_fields, # type: ignore
364
+ )
365
+
366
+ # give it a clean __signature__ so inspect.signature sees no defaults
367
+ params = [
368
+ Parameter(n, Parameter.POSITIONAL_OR_KEYWORD, annotation=typ)
369
+ for n, (typ, _) in new_fields.items()
370
+ ]
371
+ no_default_schema.__signature__ = Signature(params)
372
+
373
+ # swap it back onto the tool
374
+ tool.metadata.fn_schema = no_default_schema
375
+
376
+ return tools
377
+
260
378
  def _create_agent(
261
- self,
262
- config: AgentConfig,
263
- llm_callback_manager: CallbackManager
379
+ self, config: AgentConfig, llm_callback_manager: CallbackManager
264
380
  ) -> Union[BaseAgent, AgentRunner]:
265
381
  """
266
382
  Creates the agent based on the configuration object.
@@ -282,7 +398,12 @@ class Agent:
282
398
  raise ValueError(
283
399
  "Vectara-agentic: Function calling agent type is not supported with the OpenAI LLM."
284
400
  )
285
- prompt = _get_prompt(GENERAL_PROMPT_TEMPLATE, self._topic, self._custom_instructions)
401
+ prompt = _get_prompt(
402
+ GENERAL_PROMPT_TEMPLATE,
403
+ self._general_instructions,
404
+ self._topic,
405
+ self._custom_instructions,
406
+ )
286
407
  agent = FunctionCallingAgent.from_tools(
287
408
  tools=self.tools,
288
409
  llm=llm,
@@ -294,7 +415,12 @@ class Agent:
294
415
  allow_parallel_tool_calls=True,
295
416
  )
296
417
  elif agent_type == AgentType.REACT:
297
- prompt = _get_prompt(REACT_PROMPT_TEMPLATE, self._topic, self._custom_instructions)
418
+ prompt = _get_prompt(
419
+ REACT_PROMPT_TEMPLATE,
420
+ self._general_instructions,
421
+ self._topic,
422
+ self._custom_instructions,
423
+ )
298
424
  agent = ReActAgent.from_tools(
299
425
  tools=self.tools,
300
426
  llm=llm,
@@ -309,7 +435,12 @@ class Agent:
309
435
  raise ValueError(
310
436
  "Vectara-agentic: OPENAI agent type requires the OpenAI LLM."
311
437
  )
312
- prompt = _get_prompt(GENERAL_PROMPT_TEMPLATE, self._topic, self._custom_instructions)
438
+ prompt = _get_prompt(
439
+ GENERAL_PROMPT_TEMPLATE,
440
+ self._general_instructions,
441
+ self._topic,
442
+ self._custom_instructions,
443
+ )
313
444
  agent = OpenAIAgent.from_tools(
314
445
  tools=self.tools,
315
446
  llm=llm,
@@ -327,12 +458,26 @@ class Agent:
327
458
  callback_manager=llm_callback_manager,
328
459
  )
329
460
  agent_worker.system_prompt = _get_prompt(
330
- _get_llm_compiler_prompt(agent_worker.system_prompt, self._topic, self._custom_instructions),
331
- self._topic, self._custom_instructions
461
+ prompt_template=_get_llm_compiler_prompt(
462
+ prompt=agent_worker.system_prompt,
463
+ general_instructions=self._general_instructions,
464
+ topic=self._topic,
465
+ custom_instructions=self._custom_instructions,
466
+ ),
467
+ general_instructions=self._general_instructions,
468
+ topic=self._topic,
469
+ custom_instructions=self._custom_instructions,
332
470
  )
333
471
  agent_worker.system_prompt_replan = _get_prompt(
334
- _get_llm_compiler_prompt(agent_worker.system_prompt_replan, self._topic, self._custom_instructions),
335
- self._topic, self._custom_instructions
472
+ prompt_template=_get_llm_compiler_prompt(
473
+ prompt=agent_worker.system_prompt_replan,
474
+ general_instructions=GENERAL_INSTRUCTIONS,
475
+ topic=self._topic,
476
+ custom_instructions=self._custom_instructions,
477
+ ),
478
+ general_instructions=GENERAL_INSTRUCTIONS,
479
+ topic=self._topic,
480
+ custom_instructions=self._custom_instructions,
336
481
  )
337
482
  agent = agent_worker.as_agent()
338
483
  elif agent_type == AgentType.LATS:
@@ -344,18 +489,27 @@ class Agent:
344
489
  verbose=self.verbose,
345
490
  callback_manager=llm_callback_manager,
346
491
  )
347
- prompt = _get_prompt(REACT_PROMPT_TEMPLATE, self._topic, self._custom_instructions)
492
+ prompt = _get_prompt(
493
+ REACT_PROMPT_TEMPLATE,
494
+ self._general_instructions,
495
+ self._topic,
496
+ self._custom_instructions,
497
+ )
348
498
  agent_worker.chat_formatter = ReActChatFormatter(system_header=prompt)
349
499
  agent = agent_worker.as_agent()
350
500
  else:
351
501
  raise ValueError(f"Unknown agent type: {agent_type}")
352
502
 
353
503
  # Set up structured planner if needed
354
- if (self.use_structured_planning
355
- or self.agent_type in [AgentType.LLMCOMPILER, AgentType.LATS]):
504
+ if self.use_structured_planning or self.agent_type in [
505
+ AgentType.LLMCOMPILER,
506
+ AgentType.LATS,
507
+ ]:
508
+ planner_llm = get_llm(LLMRole.TOOL, config=config)
356
509
  agent = StructuredPlannerAgent(
357
510
  agent_worker=agent.agent_worker,
358
511
  tools=self.tools,
512
+ llm=planner_llm,
359
513
  memory=self.memory,
360
514
  verbose=self.verbose,
361
515
  initial_plan_prompt=STRUCTURED_PLANNER_INITIAL_PLAN_PROMPT,
@@ -370,14 +524,19 @@ class Agent:
370
524
  """
371
525
  if self.agent_config_type == AgentConfigType.DEFAULT:
372
526
  self.agent.memory.reset()
373
- elif self.agent_config_type == AgentConfigType.FALLBACK and self.fallback_agent_config:
527
+ elif (
528
+ self.agent_config_type == AgentConfigType.FALLBACK
529
+ and self.fallback_agent_config
530
+ ):
374
531
  self.fallback_agent.memory.reset()
375
532
  else:
376
533
  raise ValueError(f"Invalid agent config type {self.agent_config_type}")
377
534
 
378
535
  def __eq__(self, other):
379
536
  if not isinstance(other, Agent):
380
- print(f"Comparison failed: other is not an instance of Agent. (self: {type(self)}, other: {type(other)})")
537
+ print(
538
+ f"Comparison failed: other is not an instance of Agent. (self: {type(self)}, other: {type(other)})"
539
+ )
381
540
  return False
382
541
 
383
542
  # Compare agent_type
@@ -393,12 +552,15 @@ class Agent:
393
552
  print(
394
553
  "Comparison failed: tools differ."
395
554
  f"(self.tools: {[t.metadata.name for t in self.tools]}, "
396
- f"other.tools: {[t.metadata.name for t in other.tools]})")
555
+ f"other.tools: {[t.metadata.name for t in other.tools]})"
556
+ )
397
557
  return False
398
558
 
399
559
  # Compare topic
400
560
  if self._topic != other._topic:
401
- print(f"Comparison failed: topic differs. (self.topic: {self._topic}, other.topic: {other._topic})")
561
+ print(
562
+ f"Comparison failed: topic differs. (self.topic: {self._topic}, other.topic: {other._topic})"
563
+ )
402
564
  return False
403
565
 
404
566
  # Compare custom_instructions
@@ -411,7 +573,9 @@ class Agent:
411
573
 
412
574
  # Compare verbose
413
575
  if self.verbose != other.verbose:
414
- print(f"Comparison failed: verbose differs. (self.verbose: {self.verbose}, other.verbose: {other.verbose})")
576
+ print(
577
+ f"Comparison failed: verbose differs. (self.verbose: {self.verbose}, other.verbose: {other.verbose})"
578
+ )
415
579
  return False
416
580
 
417
581
  # Compare agent memory
@@ -434,7 +598,9 @@ class Agent:
434
598
  custom_instructions: str = "",
435
599
  verbose: bool = True,
436
600
  update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
437
- agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
601
+ agent_progress_callback: Optional[
602
+ Callable[[AgentStatusType, str], None]
603
+ ] = None,
438
604
  query_logging_callback: Optional[Callable[[str, str], None]] = None,
439
605
  agent_config: AgentConfig = AgentConfig(),
440
606
  validate_tools: bool = False,
@@ -467,14 +633,19 @@ class Agent:
467
633
  Agent: An instance of the Agent class.
468
634
  """
469
635
  return cls(
470
- tools=tools, topic=topic, custom_instructions=custom_instructions,
471
- verbose=verbose, agent_progress_callback=agent_progress_callback,
636
+ tools=tools,
637
+ topic=topic,
638
+ custom_instructions=custom_instructions,
639
+ verbose=verbose,
640
+ agent_progress_callback=agent_progress_callback,
472
641
  query_logging_callback=query_logging_callback,
473
- update_func=update_func, agent_config=agent_config,
642
+ update_func=update_func,
643
+ agent_config=agent_config,
474
644
  chat_history=chat_history,
475
645
  validate_tools=validate_tools,
476
646
  fallback_agent_config=fallback_agent_config,
477
- workflow_cls = workflow_cls, workflow_timeout = workflow_timeout,
647
+ workflow_cls=workflow_cls,
648
+ workflow_timeout=workflow_timeout,
478
649
  )
479
650
 
480
651
  @classmethod
@@ -483,9 +654,12 @@ class Agent:
483
654
  tool_name: str,
484
655
  data_description: str,
485
656
  assistant_specialty: str,
657
+ general_instructions: str = GENERAL_INSTRUCTIONS,
486
658
  vectara_corpus_key: str = str(os.environ.get("VECTARA_CORPUS_KEY", "")),
487
659
  vectara_api_key: str = str(os.environ.get("VECTARA_API_KEY", "")),
488
- agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
660
+ agent_progress_callback: Optional[
661
+ Callable[[AgentStatusType, str], None]
662
+ ] = None,
489
663
  query_logging_callback: Optional[Callable[[str, str], None]] = None,
490
664
  agent_config: AgentConfig = AgentConfig(),
491
665
  fallback_agent_config: Optional[AgentConfig] = None,
@@ -530,6 +704,9 @@ class Agent:
530
704
  chat_history (Tuple[str, str], optional): A list of user/agent chat pairs to initialize the agent memory.
531
705
  data_description (str): The description of the data.
532
706
  assistant_specialty (str): The specialty of the assistant.
707
+ general_instructions (str, optional): General instructions for the agent.
708
+ The Agent has a default set of instructions that are crafted to help it operate effectively.
709
+ This allows you to customize the agent's behavior and personality, but use with caution.
533
710
  verbose (bool, optional): Whether to print verbose output.
534
711
  vectara_filter_fields (List[dict], optional): The filterable attributes
535
712
  (each dict maps field name to Tuple[type, description]).
@@ -626,6 +803,7 @@ class Agent:
626
803
  tools=[vectara_tool],
627
804
  topic=assistant_specialty,
628
805
  custom_instructions=assistant_instructions,
806
+ general_instructions=general_instructions,
629
807
  verbose=verbose,
630
808
  agent_progress_callback=agent_progress_callback,
631
809
  query_logging_callback=query_logging_callback,
@@ -635,7 +813,7 @@ class Agent:
635
813
  )
636
814
 
637
815
  def _switch_agent_config(self) -> None:
638
- """"
816
+ """ "
639
817
  Switch the configuration type of the agent.
640
818
  This function is called automatically to switch the agent configuration if the current configuration fails.
641
819
  """
@@ -659,15 +837,19 @@ class Agent:
659
837
  print(f"Topic = {self._topic}")
660
838
  print("Tools:")
661
839
  for tool in self.tools:
662
- if hasattr(tool, 'metadata'):
840
+ if hasattr(tool, "metadata"):
663
841
  if detailed:
664
842
  print(f"- {tool.metadata.description}")
665
843
  else:
666
844
  print(f"- {tool.metadata.name}")
667
845
  else:
668
846
  print("- tool without metadata")
669
- print(f"Agent LLM = {get_llm(LLMRole.MAIN, config=self.agent_config).metadata.model_name}")
670
- print(f"Tool LLM = {get_llm(LLMRole.TOOL, config=self.agent_config).metadata.model_name}")
847
+ print(
848
+ f"Agent LLM = {get_llm(LLMRole.MAIN, config=self.agent_config).metadata.model_name}"
849
+ )
850
+ print(
851
+ f"Tool LLM = {get_llm(LLMRole.TOOL, config=self.agent_config).metadata.model_name}"
852
+ )
671
853
 
672
854
  def token_counts(self) -> dict:
673
855
  """
@@ -677,16 +859,29 @@ class Agent:
677
859
  dict: The token counts for the agent and tools.
678
860
  """
679
861
  return {
680
- "main token count": self.main_token_counter.total_llm_token_count if self.main_token_counter else -1,
681
- "tool token count": self.tool_token_counter.total_llm_token_count if self.tool_token_counter else -1,
862
+ "main token count": (
863
+ self.main_token_counter.total_llm_token_count
864
+ if self.main_token_counter
865
+ else -1
866
+ ),
867
+ "tool token count": (
868
+ self.tool_token_counter.total_llm_token_count
869
+ if self.tool_token_counter
870
+ else -1
871
+ ),
682
872
  }
683
873
 
684
874
  def _get_current_agent(self):
685
- return self.agent if self.agent_config_type == AgentConfigType.DEFAULT else self.fallback_agent
875
+ return (
876
+ self.agent
877
+ if self.agent_config_type == AgentConfigType.DEFAULT
878
+ else self.fallback_agent
879
+ )
686
880
 
687
881
  def _get_current_agent_type(self):
688
882
  return (
689
- self.agent_config.agent_type if self.agent_config_type == AgentConfigType.DEFAULT
883
+ self.agent_config.agent_type
884
+ if self.agent_config_type == AgentConfigType.DEFAULT
690
885
  else self.fallback_agent_config.agent_type
691
886
  )
692
887
 
@@ -703,7 +898,7 @@ class Agent:
703
898
  agent = self._get_current_agent()
704
899
  agent_response.response = str(agent.llm.acomplete(llm_prompt))
705
900
 
706
- def chat(self, prompt: str) -> AgentResponse: # type: ignore
901
+ def chat(self, prompt: str) -> AgentResponse: # type: ignore
707
902
  """
708
903
  Interact with the agent using a chat prompt.
709
904
 
@@ -715,7 +910,7 @@ class Agent:
715
910
  """
716
911
  return asyncio.run(self.achat(prompt))
717
912
 
718
- async def achat(self, prompt: str) -> AgentResponse: # type: ignore
913
+ async def achat(self, prompt: str) -> AgentResponse: # type: ignore
719
914
  """
720
915
  Interact with the agent using a chat prompt.
721
916
 
@@ -744,7 +939,9 @@ class Agent:
744
939
  last_error = e
745
940
  if attempt >= 2:
746
941
  if self.verbose:
747
- print(f"LLM call failed on attempt {attempt}. Switching agent configuration.")
942
+ print(
943
+ f"LLM call failed on attempt {attempt}. Switching agent configuration."
944
+ )
748
945
  self._switch_agent_config()
749
946
  time.sleep(1)
750
947
  attempt += 1
@@ -756,7 +953,7 @@ class Agent:
756
953
  )
757
954
  )
758
955
 
759
- def stream_chat(self, prompt: str) -> AgentStreamingResponse: # type: ignore
956
+ def stream_chat(self, prompt: str) -> AgentStreamingResponse: # type: ignore
760
957
  """
761
958
  Interact with the agent using a chat prompt with streaming.
762
959
  Args:
@@ -766,7 +963,7 @@ class Agent:
766
963
  """
767
964
  return asyncio.run(self.astream_chat(prompt))
768
965
 
769
- async def astream_chat(self, prompt: str) -> AgentStreamingResponse: # type: ignore
966
+ async def astream_chat(self, prompt: str) -> AgentStreamingResponse: # type: ignore
770
967
  """
771
968
  Interact with the agent using a chat prompt asynchronously with streaming.
772
969
  Args:
@@ -794,14 +991,18 @@ class Agent:
794
991
  if self.observability_enabled:
795
992
  eval_fcs()
796
993
 
797
- agent_response.async_response_gen = _stream_response_wrapper # Override the generator
994
+ agent_response.async_response_gen = (
995
+ _stream_response_wrapper # Override the generator
996
+ )
798
997
  return agent_response
799
998
 
800
999
  except Exception as e:
801
1000
  last_error = e
802
1001
  if attempt >= 2:
803
1002
  if self.verbose:
804
- print(f"LLM call failed on attempt {attempt}. Switching agent configuration.")
1003
+ print(
1004
+ f"LLM call failed on attempt {attempt}. Switching agent configuration."
1005
+ )
805
1006
  self._switch_agent_config()
806
1007
  time.sleep(1)
807
1008
  attempt += 1
@@ -818,11 +1019,7 @@ class Agent:
818
1019
  # workflow will always get these arguments in the StartEvent: agent, tools, llm, verbose
819
1020
  # the inputs argument comes from the call to run()
820
1021
  #
821
- async def run(
822
- self,
823
- inputs: Any,
824
- verbose: bool = False
825
- ) -> Any:
1022
+ async def run(self, inputs: Any, verbose: bool = False) -> Any:
826
1023
  """
827
1024
  Run a workflow using the agent.
828
1025
  workflow class must be provided in the agent constructor.
@@ -886,7 +1083,7 @@ class Agent:
886
1083
  "metadata": {
887
1084
  "module": fn_schema_cls.__module__,
888
1085
  "class": fn_schema_cls.__name__,
889
- }
1086
+ },
890
1087
  }
891
1088
  else:
892
1089
  fn_schema_serialized = None
@@ -895,9 +1092,16 @@ class Agent:
895
1092
  "tool_type": tool.metadata.tool_type.value,
896
1093
  "name": tool.metadata.name,
897
1094
  "description": tool.metadata.description,
898
- "fn": pickle.dumps(getattr(tool, 'fn', None)).decode("latin-1") if getattr(tool, 'fn', None) else None,
899
- "async_fn": pickle.dumps(getattr(tool, 'async_fn', None)).decode("latin-1")
900
- if getattr(tool, 'async_fn', None) else None,
1095
+ "fn": (
1096
+ pickle.dumps(getattr(tool, "fn", None)).decode("latin-1")
1097
+ if getattr(tool, "fn", None)
1098
+ else None
1099
+ ),
1100
+ "async_fn": (
1101
+ pickle.dumps(getattr(tool, "async_fn", None)).decode("latin-1")
1102
+ if getattr(tool, "async_fn", None)
1103
+ else None
1104
+ ),
901
1105
  "fn_schema": fn_schema_serialized,
902
1106
  }
903
1107
  tool_info.append(tool_dict)
@@ -910,7 +1114,11 @@ class Agent:
910
1114
  "custom_instructions": self._custom_instructions,
911
1115
  "verbose": self.verbose,
912
1116
  "agent_config": self.agent_config.to_dict(),
913
- "fallback_agent": self.fallback_agent_config.to_dict() if self.fallback_agent_config else None,
1117
+ "fallback_agent": (
1118
+ self.fallback_agent_config.to_dict()
1119
+ if self.fallback_agent_config
1120
+ else None
1121
+ ),
914
1122
  "workflow_cls": self.workflow_cls if self.workflow_cls else None,
915
1123
  }
916
1124
 
@@ -938,12 +1146,17 @@ class Agent:
938
1146
  except Exception:
939
1147
  # Fallback: rebuild using the JSON schema
940
1148
  field_definitions = {}
941
- for field, values in schema_info.get("schema", {}).get("properties", {}).items():
1149
+ for field, values in (
1150
+ schema_info.get("schema", {}).get("properties", {}).items()
1151
+ ):
942
1152
  field_type = get_field_type(values)
943
1153
  if "default" in values:
944
1154
  field_definitions[field] = (
945
1155
  field_type,
946
- Field(description=values.get("description", ""), default=values["default"]),
1156
+ Field(
1157
+ description=values.get("description", ""),
1158
+ default=values["default"],
1159
+ ),
947
1160
  )
948
1161
  else:
949
1162
  field_definitions[field] = (
@@ -952,13 +1165,21 @@ class Agent:
952
1165
  )
953
1166
  query_args_model = create_model(
954
1167
  schema_info.get("schema", {}).get("title", "QueryArgs"),
955
- **field_definitions
1168
+ **field_definitions,
956
1169
  )
957
1170
  else:
958
1171
  query_args_model = create_model("QueryArgs")
959
1172
 
960
- fn = pickle.loads(tool_data["fn"].encode("latin-1")) if tool_data["fn"] else None
961
- async_fn = pickle.loads(tool_data["async_fn"].encode("latin-1")) if tool_data["async_fn"] else None
1173
+ fn = (
1174
+ pickle.loads(tool_data["fn"].encode("latin-1"))
1175
+ if tool_data["fn"]
1176
+ else None
1177
+ )
1178
+ async_fn = (
1179
+ pickle.loads(tool_data["async_fn"].encode("latin-1"))
1180
+ if tool_data["async_fn"]
1181
+ else None
1182
+ )
962
1183
 
963
1184
  tool = VectaraTool.from_defaults(
964
1185
  name=tool_data["name"],
@@ -979,7 +1200,11 @@ class Agent:
979
1200
  fallback_agent_config=fallback_agent_config,
980
1201
  workflow_cls=data["workflow_cls"],
981
1202
  )
982
- memory = pickle.loads(data["memory"].encode("latin-1")) if data.get("memory") else None
1203
+ memory = (
1204
+ pickle.loads(data["memory"].encode("latin-1"))
1205
+ if data.get("memory")
1206
+ else None
1207
+ )
983
1208
  if memory:
984
1209
  agent.agent.memory = memory
985
1210
  return agent