vectara-agentic 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vectara-agentic might be problematic. Click here for more details.

vectara_agentic/agent.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """
2
2
  This module contains the Agent class for handling different types of agents and their interactions.
3
3
  """
4
+
4
5
  from typing import List, Callable, Optional, Dict, Any, Union, Tuple
5
6
  import os
6
7
  import re
@@ -21,7 +22,11 @@ from pydantic import Field, create_model, ValidationError
21
22
  from llama_index.core.memory import ChatMemoryBuffer
22
23
  from llama_index.core.llms import ChatMessage, MessageRole
23
24
  from llama_index.core.tools import FunctionTool
24
- from llama_index.core.agent import ReActAgent, StructuredPlannerAgent, FunctionCallingAgent
25
+ from llama_index.core.agent import (
26
+ ReActAgent,
27
+ StructuredPlannerAgent,
28
+ FunctionCallingAgent,
29
+ )
25
30
  from llama_index.core.agent.react.formatter import ReActChatFormatter
26
31
  from llama_index.agent.llm_compiler import LLMCompilerAgentWorker
27
32
  from llama_index.agent.lats import LATSAgentWorker
@@ -33,13 +38,22 @@ from llama_index.core.agent.types import BaseAgent
33
38
  from llama_index.core.workflow import Workflow
34
39
 
35
40
  from .types import (
36
- AgentType, AgentStatusType, LLMRole, ToolType, ModelProvider,
37
- AgentResponse, AgentStreamingResponse, AgentConfigType
41
+ AgentType,
42
+ AgentStatusType,
43
+ LLMRole,
44
+ ToolType,
45
+ ModelProvider,
46
+ AgentResponse,
47
+ AgentStreamingResponse,
48
+ AgentConfigType,
38
49
  )
39
50
  from .utils import get_llm, get_tokenizer_for_model
40
51
  from ._prompts import (
41
- REACT_PROMPT_TEMPLATE, GENERAL_PROMPT_TEMPLATE, GENERAL_INSTRUCTIONS,
42
- STRUCTURED_PLANNER_PLAN_REFINE_PROMPT, STRUCTURED_PLANNER_INITIAL_PLAN_PROMPT
52
+ REACT_PROMPT_TEMPLATE,
53
+ GENERAL_PROMPT_TEMPLATE,
54
+ GENERAL_INSTRUCTIONS,
55
+ STRUCTURED_PLANNER_PLAN_REFINE_PROMPT,
56
+ STRUCTURED_PLANNER_INITIAL_PLAN_PROMPT,
43
57
  )
44
58
  from ._callback import AgentCallbackHandler
45
59
  from ._observability import setup_observer, eval_fcs
@@ -47,10 +61,12 @@ from .tools import VectaraToolFactory, VectaraTool, ToolsFactory
47
61
  from .tools_catalog import get_current_date
48
62
  from .agent_config import AgentConfig
49
63
 
64
+
50
65
  class IgnoreUnpickleableAttributeFilter(logging.Filter):
51
- '''
66
+ """
52
67
  Filter to ignore log messages that contain certain strings
53
- '''
68
+ """
69
+
54
70
  def filter(self, record):
55
71
  msgs_to_ignore = [
56
72
  "Removing unpickleable private attribute _chunking_tokenizer_fn",
@@ -67,12 +83,19 @@ logger.setLevel(logging.CRITICAL)
67
83
 
68
84
  load_dotenv(override=True)
69
85
 
70
- def _get_prompt(prompt_template: str, topic: str, custom_instructions: str):
86
+
87
+ def _get_prompt(
88
+ prompt_template: str,
89
+ general_instructions: str,
90
+ topic: str,
91
+ custom_instructions: str,
92
+ ):
71
93
  """
72
94
  Generate a prompt by replacing placeholders with topic and date.
73
95
 
74
96
  Args:
75
97
  prompt_template (str): The template for the prompt.
98
+ general_instructions (str): General instructions to be included in the prompt.
76
99
  topic (str): The topic to be included in the prompt.
77
100
  custom_instructions(str): The custom instructions to be included in the prompt.
78
101
 
@@ -83,10 +106,13 @@ def _get_prompt(prompt_template: str, topic: str, custom_instructions: str):
83
106
  prompt_template.replace("{chat_topic}", topic)
84
107
  .replace("{today}", date.today().strftime("%A, %B %d, %Y"))
85
108
  .replace("{custom_instructions}", custom_instructions)
109
+ .replace("{INSTRUCTIONS}", general_instructions)
86
110
  )
87
111
 
88
112
 
89
- def _get_llm_compiler_prompt(prompt: str, topic: str, custom_instructions: str) -> str:
113
+ def _get_llm_compiler_prompt(
114
+ prompt: str, general_instructions: str, topic: str, custom_instructions: str
115
+ ) -> str:
90
116
  """
91
117
  Add custom instructions to the prompt.
92
118
 
@@ -98,7 +124,7 @@ def _get_llm_compiler_prompt(prompt: str, topic: str, custom_instructions: str)
98
124
  """
99
125
  prompt += "\nAdditional Instructions:\n"
100
126
  prompt += f"You have experise in {topic}.\n"
101
- prompt += GENERAL_INSTRUCTIONS
127
+ prompt += general_instructions
102
128
  prompt += custom_instructions
103
129
  prompt += f"Today is {date.today().strftime('%A, %B %d, %Y')}"
104
130
  return prompt
@@ -132,6 +158,7 @@ def get_field_type(field_schema: dict) -> Any:
132
158
  else:
133
159
  return Any
134
160
 
161
+
135
162
  class Agent:
136
163
  """
137
164
  Agent class for handling different types of agents and their interactions.
@@ -142,10 +169,13 @@ class Agent:
142
169
  tools: list[FunctionTool],
143
170
  topic: str = "general",
144
171
  custom_instructions: str = "",
172
+ general_instructions: str = GENERAL_INSTRUCTIONS,
145
173
  verbose: bool = True,
146
174
  use_structured_planning: bool = False,
147
175
  update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
148
- agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
176
+ agent_progress_callback: Optional[
177
+ Callable[[AgentStatusType, str], None]
178
+ ] = None,
149
179
  query_logging_callback: Optional[Callable[[str, str], None]] = None,
150
180
  agent_config: Optional[AgentConfig] = None,
151
181
  fallback_agent_config: Optional[AgentConfig] = None,
@@ -162,6 +192,9 @@ class Agent:
162
192
  tools (list[FunctionTool]): A list of tools to be used by the agent.
163
193
  topic (str, optional): The topic for the agent. Defaults to 'general'.
164
194
  custom_instructions (str, optional): Custom instructions for the agent. Defaults to ''.
195
+ general_instructions (str, optional): General instructions for the agent.
196
+ The Agent has a default set of instructions that are crafted to help it operate effectively.
197
+ This allows you to customize the agent's behavior and personality, but use with caution.
165
198
  verbose (bool, optional): Whether the agent should print its steps. Defaults to True.
166
199
  use_structured_planning (bool, optional)
167
200
  Whether or not we want to wrap the agent with LlamaIndex StructuredPlannerAgent.
@@ -181,14 +214,17 @@ class Agent:
181
214
  self.agent_config = agent_config or AgentConfig()
182
215
  self.agent_config_type = AgentConfigType.DEFAULT
183
216
  self.tools = tools
184
- if not any(tool.metadata.name == 'get_current_date' for tool in self.tools):
217
+ if not any(tool.metadata.name == "get_current_date" for tool in self.tools):
185
218
  self.tools += [ToolsFactory().create_tool(get_current_date)]
186
219
  self.agent_type = self.agent_config.agent_type
187
220
  self.use_structured_planning = use_structured_planning
188
221
  self.llm = get_llm(LLMRole.MAIN, config=self.agent_config)
189
222
  self._custom_instructions = custom_instructions
223
+ self._general_instructions = general_instructions
190
224
  self._topic = topic
191
- self.agent_progress_callback = agent_progress_callback if agent_progress_callback else update_func
225
+ self.agent_progress_callback = (
226
+ agent_progress_callback if agent_progress_callback else update_func
227
+ )
192
228
  self.query_logging_callback = query_logging_callback
193
229
 
194
230
  self.workflow_cls = workflow_cls
@@ -204,7 +240,7 @@ class Agent:
204
240
  raise ValueError(f"Duplicate tools detected: {', '.join(duplicates)}")
205
241
 
206
242
  if validate_tools:
207
- prompt = f'''
243
+ prompt = f"""
208
244
  Given the following instructions, and a list of tool names,
209
245
  Please identify tools mentioned in the instructions that do not exist in the list.
210
246
  Instructions:
@@ -212,20 +248,28 @@ class Agent:
212
248
  Tool names: {', '.join(tool_names)}
213
249
  Your response should include a comma separated list of tool names that do not exist in the list.
214
250
  Your response should be an empty string if all tools mentioned in the instructions are in the list.
215
- '''
251
+ """
216
252
  llm = get_llm(LLMRole.MAIN, config=self.agent_config)
217
253
  bad_tools = llm.complete(prompt).text.split(", ")
218
254
  if bad_tools:
219
- raise ValueError(f"The Agent custom instructions mention these invalid tools: {', '.join(bad_tools)}")
255
+ raise ValueError(
256
+ f"The Agent custom instructions mention these invalid tools: {', '.join(bad_tools)}"
257
+ )
220
258
 
221
259
  # Create token counters for the main and tool LLMs
222
260
  main_tok = get_tokenizer_for_model(role=LLMRole.MAIN)
223
- self.main_token_counter = TokenCountingHandler(tokenizer=main_tok) if main_tok else None
261
+ self.main_token_counter = (
262
+ TokenCountingHandler(tokenizer=main_tok) if main_tok else None
263
+ )
224
264
  tool_tok = get_tokenizer_for_model(role=LLMRole.TOOL)
225
- self.tool_token_counter = TokenCountingHandler(tokenizer=tool_tok) if tool_tok else None
265
+ self.tool_token_counter = (
266
+ TokenCountingHandler(tokenizer=tool_tok) if tool_tok else None
267
+ )
226
268
 
227
269
  # Setup callback manager
228
- callbacks: list[BaseCallbackHandler] = [AgentCallbackHandler(self.agent_progress_callback)]
270
+ callbacks: list[BaseCallbackHandler] = [
271
+ AgentCallbackHandler(self.agent_progress_callback)
272
+ ]
229
273
  if self.main_token_counter:
230
274
  callbacks.append(self.main_token_counter)
231
275
  if self.tool_token_counter:
@@ -236,9 +280,17 @@ class Agent:
236
280
  if chat_history:
237
281
  msg_history = []
238
282
  for text_pairs in chat_history:
239
- msg_history.append(ChatMessage.from_str(content=text_pairs[0], role=MessageRole.USER))
240
- msg_history.append(ChatMessage.from_str(content=text_pairs[1], role=MessageRole.ASSISTANT))
241
- self.memory = ChatMemoryBuffer.from_defaults(token_limit=128000, chat_history=msg_history)
283
+ msg_history.append(
284
+ ChatMessage.from_str(content=text_pairs[0], role=MessageRole.USER)
285
+ )
286
+ msg_history.append(
287
+ ChatMessage.from_str(
288
+ content=text_pairs[1], role=MessageRole.ASSISTANT
289
+ )
290
+ )
291
+ self.memory = ChatMemoryBuffer.from_defaults(
292
+ token_limit=128000, chat_history=msg_history
293
+ )
242
294
  else:
243
295
  self.memory = ChatMemoryBuffer.from_defaults(token_limit=128000)
244
296
 
@@ -246,7 +298,9 @@ class Agent:
246
298
  self.agent = self._create_agent(self.agent_config, callback_manager)
247
299
  self.fallback_agent_config = fallback_agent_config
248
300
  if self.fallback_agent_config:
249
- self.fallback_agent = self._create_agent(self.fallback_agent_config, callback_manager)
301
+ self.fallback_agent = self._create_agent(
302
+ self.fallback_agent_config, callback_manager
303
+ )
250
304
  else:
251
305
  self.fallback_agent_config = None
252
306
 
@@ -258,9 +312,7 @@ class Agent:
258
312
  self.observability_enabled = False
259
313
 
260
314
  def _create_agent(
261
- self,
262
- config: AgentConfig,
263
- llm_callback_manager: CallbackManager
315
+ self, config: AgentConfig, llm_callback_manager: CallbackManager
264
316
  ) -> Union[BaseAgent, AgentRunner]:
265
317
  """
266
318
  Creates the agent based on the configuration object.
@@ -282,7 +334,12 @@ class Agent:
282
334
  raise ValueError(
283
335
  "Vectara-agentic: Function calling agent type is not supported with the OpenAI LLM."
284
336
  )
285
- prompt = _get_prompt(GENERAL_PROMPT_TEMPLATE, self._topic, self._custom_instructions)
337
+ prompt = _get_prompt(
338
+ GENERAL_PROMPT_TEMPLATE,
339
+ self._general_instructions,
340
+ self._topic,
341
+ self._custom_instructions,
342
+ )
286
343
  agent = FunctionCallingAgent.from_tools(
287
344
  tools=self.tools,
288
345
  llm=llm,
@@ -294,7 +351,12 @@ class Agent:
294
351
  allow_parallel_tool_calls=True,
295
352
  )
296
353
  elif agent_type == AgentType.REACT:
297
- prompt = _get_prompt(REACT_PROMPT_TEMPLATE, self._topic, self._custom_instructions)
354
+ prompt = _get_prompt(
355
+ REACT_PROMPT_TEMPLATE,
356
+ self._general_instructions,
357
+ self._topic,
358
+ self._custom_instructions,
359
+ )
298
360
  agent = ReActAgent.from_tools(
299
361
  tools=self.tools,
300
362
  llm=llm,
@@ -309,7 +371,12 @@ class Agent:
309
371
  raise ValueError(
310
372
  "Vectara-agentic: OPENAI agent type requires the OpenAI LLM."
311
373
  )
312
- prompt = _get_prompt(GENERAL_PROMPT_TEMPLATE, self._topic, self._custom_instructions)
374
+ prompt = _get_prompt(
375
+ GENERAL_PROMPT_TEMPLATE,
376
+ self._general_instructions,
377
+ self._topic,
378
+ self._custom_instructions,
379
+ )
313
380
  agent = OpenAIAgent.from_tools(
314
381
  tools=self.tools,
315
382
  llm=llm,
@@ -327,12 +394,26 @@ class Agent:
327
394
  callback_manager=llm_callback_manager,
328
395
  )
329
396
  agent_worker.system_prompt = _get_prompt(
330
- _get_llm_compiler_prompt(agent_worker.system_prompt, self._topic, self._custom_instructions),
331
- self._topic, self._custom_instructions
397
+ prompt_template=_get_llm_compiler_prompt(
398
+ prompt=agent_worker.system_prompt,
399
+ general_instructions=self._general_instructions,
400
+ topic=self._topic,
401
+ custom_instructions=self._custom_instructions,
402
+ ),
403
+ general_instructions=self._general_instructions,
404
+ topic=self._topic,
405
+ custom_instructions=self._custom_instructions,
332
406
  )
333
407
  agent_worker.system_prompt_replan = _get_prompt(
334
- _get_llm_compiler_prompt(agent_worker.system_prompt_replan, self._topic, self._custom_instructions),
335
- self._topic, self._custom_instructions
408
+ prompt_template=_get_llm_compiler_prompt(
409
+ prompt=agent_worker.system_prompt_replan,
410
+ general_instructions=GENERAL_INSTRUCTIONS,
411
+ topic=self._topic,
412
+ custom_instructions=self._custom_instructions,
413
+ ),
414
+ general_instructions=GENERAL_INSTRUCTIONS,
415
+ topic=self._topic,
416
+ custom_instructions=self._custom_instructions,
336
417
  )
337
418
  agent = agent_worker.as_agent()
338
419
  elif agent_type == AgentType.LATS:
@@ -344,18 +425,27 @@ class Agent:
344
425
  verbose=self.verbose,
345
426
  callback_manager=llm_callback_manager,
346
427
  )
347
- prompt = _get_prompt(REACT_PROMPT_TEMPLATE, self._topic, self._custom_instructions)
428
+ prompt = _get_prompt(
429
+ REACT_PROMPT_TEMPLATE,
430
+ self._general_instructions,
431
+ self._topic,
432
+ self._custom_instructions,
433
+ )
348
434
  agent_worker.chat_formatter = ReActChatFormatter(system_header=prompt)
349
435
  agent = agent_worker.as_agent()
350
436
  else:
351
437
  raise ValueError(f"Unknown agent type: {agent_type}")
352
438
 
353
439
  # Set up structured planner if needed
354
- if (self.use_structured_planning
355
- or self.agent_type in [AgentType.LLMCOMPILER, AgentType.LATS]):
440
+ if self.use_structured_planning or self.agent_type in [
441
+ AgentType.LLMCOMPILER,
442
+ AgentType.LATS,
443
+ ]:
444
+ planner_llm = get_llm(LLMRole.TOOL, config=config)
356
445
  agent = StructuredPlannerAgent(
357
446
  agent_worker=agent.agent_worker,
358
447
  tools=self.tools,
448
+ llm=planner_llm,
359
449
  memory=self.memory,
360
450
  verbose=self.verbose,
361
451
  initial_plan_prompt=STRUCTURED_PLANNER_INITIAL_PLAN_PROMPT,
@@ -370,14 +460,19 @@ class Agent:
370
460
  """
371
461
  if self.agent_config_type == AgentConfigType.DEFAULT:
372
462
  self.agent.memory.reset()
373
- elif self.agent_config_type == AgentConfigType.FALLBACK and self.fallback_agent_config:
463
+ elif (
464
+ self.agent_config_type == AgentConfigType.FALLBACK
465
+ and self.fallback_agent_config
466
+ ):
374
467
  self.fallback_agent.memory.reset()
375
468
  else:
376
469
  raise ValueError(f"Invalid agent config type {self.agent_config_type}")
377
470
 
378
471
  def __eq__(self, other):
379
472
  if not isinstance(other, Agent):
380
- print(f"Comparison failed: other is not an instance of Agent. (self: {type(self)}, other: {type(other)})")
473
+ print(
474
+ f"Comparison failed: other is not an instance of Agent. (self: {type(self)}, other: {type(other)})"
475
+ )
381
476
  return False
382
477
 
383
478
  # Compare agent_type
@@ -393,12 +488,15 @@ class Agent:
393
488
  print(
394
489
  "Comparison failed: tools differ."
395
490
  f"(self.tools: {[t.metadata.name for t in self.tools]}, "
396
- f"other.tools: {[t.metadata.name for t in other.tools]})")
491
+ f"other.tools: {[t.metadata.name for t in other.tools]})"
492
+ )
397
493
  return False
398
494
 
399
495
  # Compare topic
400
496
  if self._topic != other._topic:
401
- print(f"Comparison failed: topic differs. (self.topic: {self._topic}, other.topic: {other._topic})")
497
+ print(
498
+ f"Comparison failed: topic differs. (self.topic: {self._topic}, other.topic: {other._topic})"
499
+ )
402
500
  return False
403
501
 
404
502
  # Compare custom_instructions
@@ -411,7 +509,9 @@ class Agent:
411
509
 
412
510
  # Compare verbose
413
511
  if self.verbose != other.verbose:
414
- print(f"Comparison failed: verbose differs. (self.verbose: {self.verbose}, other.verbose: {other.verbose})")
512
+ print(
513
+ f"Comparison failed: verbose differs. (self.verbose: {self.verbose}, other.verbose: {other.verbose})"
514
+ )
415
515
  return False
416
516
 
417
517
  # Compare agent memory
@@ -434,7 +534,9 @@ class Agent:
434
534
  custom_instructions: str = "",
435
535
  verbose: bool = True,
436
536
  update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
437
- agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
537
+ agent_progress_callback: Optional[
538
+ Callable[[AgentStatusType, str], None]
539
+ ] = None,
438
540
  query_logging_callback: Optional[Callable[[str, str], None]] = None,
439
541
  agent_config: AgentConfig = AgentConfig(),
440
542
  validate_tools: bool = False,
@@ -467,14 +569,19 @@ class Agent:
467
569
  Agent: An instance of the Agent class.
468
570
  """
469
571
  return cls(
470
- tools=tools, topic=topic, custom_instructions=custom_instructions,
471
- verbose=verbose, agent_progress_callback=agent_progress_callback,
572
+ tools=tools,
573
+ topic=topic,
574
+ custom_instructions=custom_instructions,
575
+ verbose=verbose,
576
+ agent_progress_callback=agent_progress_callback,
472
577
  query_logging_callback=query_logging_callback,
473
- update_func=update_func, agent_config=agent_config,
578
+ update_func=update_func,
579
+ agent_config=agent_config,
474
580
  chat_history=chat_history,
475
581
  validate_tools=validate_tools,
476
582
  fallback_agent_config=fallback_agent_config,
477
- workflow_cls = workflow_cls, workflow_timeout = workflow_timeout,
583
+ workflow_cls=workflow_cls,
584
+ workflow_timeout=workflow_timeout,
478
585
  )
479
586
 
480
587
  @classmethod
@@ -483,9 +590,12 @@ class Agent:
483
590
  tool_name: str,
484
591
  data_description: str,
485
592
  assistant_specialty: str,
593
+ general_instructions: str = GENERAL_INSTRUCTIONS,
486
594
  vectara_corpus_key: str = str(os.environ.get("VECTARA_CORPUS_KEY", "")),
487
595
  vectara_api_key: str = str(os.environ.get("VECTARA_API_KEY", "")),
488
- agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
596
+ agent_progress_callback: Optional[
597
+ Callable[[AgentStatusType, str], None]
598
+ ] = None,
489
599
  query_logging_callback: Optional[Callable[[str, str], None]] = None,
490
600
  agent_config: AgentConfig = AgentConfig(),
491
601
  fallback_agent_config: Optional[AgentConfig] = None,
@@ -530,6 +640,9 @@ class Agent:
530
640
  chat_history (Tuple[str, str], optional): A list of user/agent chat pairs to initialize the agent memory.
531
641
  data_description (str): The description of the data.
532
642
  assistant_specialty (str): The specialty of the assistant.
643
+ general_instructions (str, optional): General instructions for the agent.
644
+ The Agent has a default set of instructions that are crafted to help it operate effectively.
645
+ This allows you to customize the agent's behavior and personality, but use with caution.
533
646
  verbose (bool, optional): Whether to print verbose output.
534
647
  vectara_filter_fields (List[dict], optional): The filterable attributes
535
648
  (each dict maps field name to Tuple[type, description]).
@@ -626,6 +739,7 @@ class Agent:
626
739
  tools=[vectara_tool],
627
740
  topic=assistant_specialty,
628
741
  custom_instructions=assistant_instructions,
742
+ general_instructions=general_instructions,
629
743
  verbose=verbose,
630
744
  agent_progress_callback=agent_progress_callback,
631
745
  query_logging_callback=query_logging_callback,
@@ -635,7 +749,7 @@ class Agent:
635
749
  )
636
750
 
637
751
  def _switch_agent_config(self) -> None:
638
- """"
752
+ """ "
639
753
  Switch the configuration type of the agent.
640
754
  This function is called automatically to switch the agent configuration if the current configuration fails.
641
755
  """
@@ -659,15 +773,19 @@ class Agent:
659
773
  print(f"Topic = {self._topic}")
660
774
  print("Tools:")
661
775
  for tool in self.tools:
662
- if hasattr(tool, 'metadata'):
776
+ if hasattr(tool, "metadata"):
663
777
  if detailed:
664
778
  print(f"- {tool.metadata.description}")
665
779
  else:
666
780
  print(f"- {tool.metadata.name}")
667
781
  else:
668
782
  print("- tool without metadata")
669
- print(f"Agent LLM = {get_llm(LLMRole.MAIN, config=self.agent_config).metadata.model_name}")
670
- print(f"Tool LLM = {get_llm(LLMRole.TOOL, config=self.agent_config).metadata.model_name}")
783
+ print(
784
+ f"Agent LLM = {get_llm(LLMRole.MAIN, config=self.agent_config).metadata.model_name}"
785
+ )
786
+ print(
787
+ f"Tool LLM = {get_llm(LLMRole.TOOL, config=self.agent_config).metadata.model_name}"
788
+ )
671
789
 
672
790
  def token_counts(self) -> dict:
673
791
  """
@@ -677,16 +795,29 @@ class Agent:
677
795
  dict: The token counts for the agent and tools.
678
796
  """
679
797
  return {
680
- "main token count": self.main_token_counter.total_llm_token_count if self.main_token_counter else -1,
681
- "tool token count": self.tool_token_counter.total_llm_token_count if self.tool_token_counter else -1,
798
+ "main token count": (
799
+ self.main_token_counter.total_llm_token_count
800
+ if self.main_token_counter
801
+ else -1
802
+ ),
803
+ "tool token count": (
804
+ self.tool_token_counter.total_llm_token_count
805
+ if self.tool_token_counter
806
+ else -1
807
+ ),
682
808
  }
683
809
 
684
810
  def _get_current_agent(self):
685
- return self.agent if self.agent_config_type == AgentConfigType.DEFAULT else self.fallback_agent
811
+ return (
812
+ self.agent
813
+ if self.agent_config_type == AgentConfigType.DEFAULT
814
+ else self.fallback_agent
815
+ )
686
816
 
687
817
  def _get_current_agent_type(self):
688
818
  return (
689
- self.agent_config.agent_type if self.agent_config_type == AgentConfigType.DEFAULT
819
+ self.agent_config.agent_type
820
+ if self.agent_config_type == AgentConfigType.DEFAULT
690
821
  else self.fallback_agent_config.agent_type
691
822
  )
692
823
 
@@ -703,7 +834,7 @@ class Agent:
703
834
  agent = self._get_current_agent()
704
835
  agent_response.response = str(agent.llm.acomplete(llm_prompt))
705
836
 
706
- def chat(self, prompt: str) -> AgentResponse: # type: ignore
837
+ def chat(self, prompt: str) -> AgentResponse: # type: ignore
707
838
  """
708
839
  Interact with the agent using a chat prompt.
709
840
 
@@ -715,7 +846,7 @@ class Agent:
715
846
  """
716
847
  return asyncio.run(self.achat(prompt))
717
848
 
718
- async def achat(self, prompt: str) -> AgentResponse: # type: ignore
849
+ async def achat(self, prompt: str) -> AgentResponse: # type: ignore
719
850
  """
720
851
  Interact with the agent using a chat prompt.
721
852
 
@@ -744,7 +875,9 @@ class Agent:
744
875
  last_error = e
745
876
  if attempt >= 2:
746
877
  if self.verbose:
747
- print(f"LLM call failed on attempt {attempt}. Switching agent configuration.")
878
+ print(
879
+ f"LLM call failed on attempt {attempt}. Switching agent configuration."
880
+ )
748
881
  self._switch_agent_config()
749
882
  time.sleep(1)
750
883
  attempt += 1
@@ -756,7 +889,7 @@ class Agent:
756
889
  )
757
890
  )
758
891
 
759
- def stream_chat(self, prompt: str) -> AgentStreamingResponse: # type: ignore
892
+ def stream_chat(self, prompt: str) -> AgentStreamingResponse: # type: ignore
760
893
  """
761
894
  Interact with the agent using a chat prompt with streaming.
762
895
  Args:
@@ -766,7 +899,7 @@ class Agent:
766
899
  """
767
900
  return asyncio.run(self.astream_chat(prompt))
768
901
 
769
- async def astream_chat(self, prompt: str) -> AgentStreamingResponse: # type: ignore
902
+ async def astream_chat(self, prompt: str) -> AgentStreamingResponse: # type: ignore
770
903
  """
771
904
  Interact with the agent using a chat prompt asynchronously with streaming.
772
905
  Args:
@@ -794,14 +927,18 @@ class Agent:
794
927
  if self.observability_enabled:
795
928
  eval_fcs()
796
929
 
797
- agent_response.async_response_gen = _stream_response_wrapper # Override the generator
930
+ agent_response.async_response_gen = (
931
+ _stream_response_wrapper # Override the generator
932
+ )
798
933
  return agent_response
799
934
 
800
935
  except Exception as e:
801
936
  last_error = e
802
937
  if attempt >= 2:
803
938
  if self.verbose:
804
- print(f"LLM call failed on attempt {attempt}. Switching agent configuration.")
939
+ print(
940
+ f"LLM call failed on attempt {attempt}. Switching agent configuration."
941
+ )
805
942
  self._switch_agent_config()
806
943
  time.sleep(1)
807
944
  attempt += 1
@@ -818,11 +955,7 @@ class Agent:
818
955
  # workflow will always get these arguments in the StartEvent: agent, tools, llm, verbose
819
956
  # the inputs argument comes from the call to run()
820
957
  #
821
- async def run(
822
- self,
823
- inputs: Any,
824
- verbose: bool = False
825
- ) -> Any:
958
+ async def run(self, inputs: Any, verbose: bool = False) -> Any:
826
959
  """
827
960
  Run a workflow using the agent.
828
961
  workflow class must be provided in the agent constructor.
@@ -886,7 +1019,7 @@ class Agent:
886
1019
  "metadata": {
887
1020
  "module": fn_schema_cls.__module__,
888
1021
  "class": fn_schema_cls.__name__,
889
- }
1022
+ },
890
1023
  }
891
1024
  else:
892
1025
  fn_schema_serialized = None
@@ -895,9 +1028,16 @@ class Agent:
895
1028
  "tool_type": tool.metadata.tool_type.value,
896
1029
  "name": tool.metadata.name,
897
1030
  "description": tool.metadata.description,
898
- "fn": pickle.dumps(getattr(tool, 'fn', None)).decode("latin-1") if getattr(tool, 'fn', None) else None,
899
- "async_fn": pickle.dumps(getattr(tool, 'async_fn', None)).decode("latin-1")
900
- if getattr(tool, 'async_fn', None) else None,
1031
+ "fn": (
1032
+ pickle.dumps(getattr(tool, "fn", None)).decode("latin-1")
1033
+ if getattr(tool, "fn", None)
1034
+ else None
1035
+ ),
1036
+ "async_fn": (
1037
+ pickle.dumps(getattr(tool, "async_fn", None)).decode("latin-1")
1038
+ if getattr(tool, "async_fn", None)
1039
+ else None
1040
+ ),
901
1041
  "fn_schema": fn_schema_serialized,
902
1042
  }
903
1043
  tool_info.append(tool_dict)
@@ -910,7 +1050,11 @@ class Agent:
910
1050
  "custom_instructions": self._custom_instructions,
911
1051
  "verbose": self.verbose,
912
1052
  "agent_config": self.agent_config.to_dict(),
913
- "fallback_agent": self.fallback_agent_config.to_dict() if self.fallback_agent_config else None,
1053
+ "fallback_agent": (
1054
+ self.fallback_agent_config.to_dict()
1055
+ if self.fallback_agent_config
1056
+ else None
1057
+ ),
914
1058
  "workflow_cls": self.workflow_cls if self.workflow_cls else None,
915
1059
  }
916
1060
 
@@ -938,12 +1082,17 @@ class Agent:
938
1082
  except Exception:
939
1083
  # Fallback: rebuild using the JSON schema
940
1084
  field_definitions = {}
941
- for field, values in schema_info.get("schema", {}).get("properties", {}).items():
1085
+ for field, values in (
1086
+ schema_info.get("schema", {}).get("properties", {}).items()
1087
+ ):
942
1088
  field_type = get_field_type(values)
943
1089
  if "default" in values:
944
1090
  field_definitions[field] = (
945
1091
  field_type,
946
- Field(description=values.get("description", ""), default=values["default"]),
1092
+ Field(
1093
+ description=values.get("description", ""),
1094
+ default=values["default"],
1095
+ ),
947
1096
  )
948
1097
  else:
949
1098
  field_definitions[field] = (
@@ -952,13 +1101,21 @@ class Agent:
952
1101
  )
953
1102
  query_args_model = create_model(
954
1103
  schema_info.get("schema", {}).get("title", "QueryArgs"),
955
- **field_definitions
1104
+ **field_definitions,
956
1105
  )
957
1106
  else:
958
1107
  query_args_model = create_model("QueryArgs")
959
1108
 
960
- fn = pickle.loads(tool_data["fn"].encode("latin-1")) if tool_data["fn"] else None
961
- async_fn = pickle.loads(tool_data["async_fn"].encode("latin-1")) if tool_data["async_fn"] else None
1109
+ fn = (
1110
+ pickle.loads(tool_data["fn"].encode("latin-1"))
1111
+ if tool_data["fn"]
1112
+ else None
1113
+ )
1114
+ async_fn = (
1115
+ pickle.loads(tool_data["async_fn"].encode("latin-1"))
1116
+ if tool_data["async_fn"]
1117
+ else None
1118
+ )
962
1119
 
963
1120
  tool = VectaraTool.from_defaults(
964
1121
  name=tool_data["name"],
@@ -979,7 +1136,11 @@ class Agent:
979
1136
  fallback_agent_config=fallback_agent_config,
980
1137
  workflow_cls=data["workflow_cls"],
981
1138
  )
982
- memory = pickle.loads(data["memory"].encode("latin-1")) if data.get("memory") else None
1139
+ memory = (
1140
+ pickle.loads(data["memory"].encode("latin-1"))
1141
+ if data.get("memory")
1142
+ else None
1143
+ )
983
1144
  if memory:
984
1145
  agent.agent.memory = memory
985
1146
  return agent