vectara-agentic: vectara_agentic-0.2.12-py3-none-any.whl → vectara_agentic-0.2.14-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic. Click here for more details.
- tests/test_agent.py +18 -1
- tests/test_agent_planning.py +0 -9
- tests/test_agent_type.py +40 -0
- tests/test_groq.py +120 -0
- tests/test_tools.py +176 -42
- tests/test_vectara_llms.py +66 -0
- vectara_agentic/_prompts.py +6 -8
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +304 -79
- vectara_agentic/llm_utils.py +174 -0
- vectara_agentic/tool_utils.py +513 -0
- vectara_agentic/tools.py +73 -452
- vectara_agentic/tools_catalog.py +2 -1
- vectara_agentic/utils.py +25 -150
- {vectara_agentic-0.2.12.dist-info → vectara_agentic-0.2.14.dist-info}/METADATA +355 -236
- vectara_agentic-0.2.14.dist-info/RECORD +33 -0
- {vectara_agentic-0.2.12.dist-info → vectara_agentic-0.2.14.dist-info}/WHEEL +1 -1
- vectara_agentic-0.2.12.dist-info/RECORD +0 -29
- {vectara_agentic-0.2.12.dist-info → vectara_agentic-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.2.12.dist-info → vectara_agentic-0.2.14.dist-info}/top_level.txt +0 -0
vectara_agentic/agent.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
2
|
This module contains the Agent class for handling different types of agents and their interactions.
|
|
3
3
|
"""
|
|
4
|
+
|
|
4
5
|
from typing import List, Callable, Optional, Dict, Any, Union, Tuple
|
|
5
6
|
import os
|
|
6
7
|
import re
|
|
@@ -11,6 +12,8 @@ import logging
|
|
|
11
12
|
import asyncio
|
|
12
13
|
import importlib
|
|
13
14
|
from collections import Counter
|
|
15
|
+
import inspect
|
|
16
|
+
from inspect import Signature, Parameter, ismethod
|
|
14
17
|
|
|
15
18
|
import cloudpickle as pickle
|
|
16
19
|
|
|
@@ -18,10 +21,15 @@ from dotenv import load_dotenv
|
|
|
18
21
|
|
|
19
22
|
from pydantic import Field, create_model, ValidationError
|
|
20
23
|
|
|
24
|
+
|
|
21
25
|
from llama_index.core.memory import ChatMemoryBuffer
|
|
22
26
|
from llama_index.core.llms import ChatMessage, MessageRole
|
|
23
27
|
from llama_index.core.tools import FunctionTool
|
|
24
|
-
from llama_index.core.agent import
|
|
28
|
+
from llama_index.core.agent import (
|
|
29
|
+
ReActAgent,
|
|
30
|
+
StructuredPlannerAgent,
|
|
31
|
+
FunctionCallingAgent,
|
|
32
|
+
)
|
|
25
33
|
from llama_index.core.agent.react.formatter import ReActChatFormatter
|
|
26
34
|
from llama_index.agent.llm_compiler import LLMCompilerAgentWorker
|
|
27
35
|
from llama_index.agent.lats import LATSAgentWorker
|
|
@@ -33,13 +41,22 @@ from llama_index.core.agent.types import BaseAgent
|
|
|
33
41
|
from llama_index.core.workflow import Workflow
|
|
34
42
|
|
|
35
43
|
from .types import (
|
|
36
|
-
AgentType,
|
|
37
|
-
|
|
44
|
+
AgentType,
|
|
45
|
+
AgentStatusType,
|
|
46
|
+
LLMRole,
|
|
47
|
+
ToolType,
|
|
48
|
+
ModelProvider,
|
|
49
|
+
AgentResponse,
|
|
50
|
+
AgentStreamingResponse,
|
|
51
|
+
AgentConfigType,
|
|
38
52
|
)
|
|
39
|
-
from .
|
|
53
|
+
from .llm_utils import get_llm, get_tokenizer_for_model
|
|
40
54
|
from ._prompts import (
|
|
41
|
-
REACT_PROMPT_TEMPLATE,
|
|
42
|
-
|
|
55
|
+
REACT_PROMPT_TEMPLATE,
|
|
56
|
+
GENERAL_PROMPT_TEMPLATE,
|
|
57
|
+
GENERAL_INSTRUCTIONS,
|
|
58
|
+
STRUCTURED_PLANNER_PLAN_REFINE_PROMPT,
|
|
59
|
+
STRUCTURED_PLANNER_INITIAL_PLAN_PROMPT,
|
|
43
60
|
)
|
|
44
61
|
from ._callback import AgentCallbackHandler
|
|
45
62
|
from ._observability import setup_observer, eval_fcs
|
|
@@ -47,10 +64,12 @@ from .tools import VectaraToolFactory, VectaraTool, ToolsFactory
|
|
|
47
64
|
from .tools_catalog import get_current_date
|
|
48
65
|
from .agent_config import AgentConfig
|
|
49
66
|
|
|
67
|
+
|
|
50
68
|
class IgnoreUnpickleableAttributeFilter(logging.Filter):
|
|
51
|
-
|
|
69
|
+
"""
|
|
52
70
|
Filter to ignore log messages that contain certain strings
|
|
53
|
-
|
|
71
|
+
"""
|
|
72
|
+
|
|
54
73
|
def filter(self, record):
|
|
55
74
|
msgs_to_ignore = [
|
|
56
75
|
"Removing unpickleable private attribute _chunking_tokenizer_fn",
|
|
@@ -67,12 +86,19 @@ logger.setLevel(logging.CRITICAL)
|
|
|
67
86
|
|
|
68
87
|
load_dotenv(override=True)
|
|
69
88
|
|
|
70
|
-
|
|
89
|
+
|
|
90
|
+
def _get_prompt(
|
|
91
|
+
prompt_template: str,
|
|
92
|
+
general_instructions: str,
|
|
93
|
+
topic: str,
|
|
94
|
+
custom_instructions: str,
|
|
95
|
+
):
|
|
71
96
|
"""
|
|
72
97
|
Generate a prompt by replacing placeholders with topic and date.
|
|
73
98
|
|
|
74
99
|
Args:
|
|
75
100
|
prompt_template (str): The template for the prompt.
|
|
101
|
+
general_instructions (str): General instructions to be included in the prompt.
|
|
76
102
|
topic (str): The topic to be included in the prompt.
|
|
77
103
|
custom_instructions(str): The custom instructions to be included in the prompt.
|
|
78
104
|
|
|
@@ -83,10 +109,13 @@ def _get_prompt(prompt_template: str, topic: str, custom_instructions: str):
|
|
|
83
109
|
prompt_template.replace("{chat_topic}", topic)
|
|
84
110
|
.replace("{today}", date.today().strftime("%A, %B %d, %Y"))
|
|
85
111
|
.replace("{custom_instructions}", custom_instructions)
|
|
112
|
+
.replace("{INSTRUCTIONS}", general_instructions)
|
|
86
113
|
)
|
|
87
114
|
|
|
88
115
|
|
|
89
|
-
def _get_llm_compiler_prompt(
|
|
116
|
+
def _get_llm_compiler_prompt(
|
|
117
|
+
prompt: str, general_instructions: str, topic: str, custom_instructions: str
|
|
118
|
+
) -> str:
|
|
90
119
|
"""
|
|
91
120
|
Add custom instructions to the prompt.
|
|
92
121
|
|
|
@@ -98,7 +127,7 @@ def _get_llm_compiler_prompt(prompt: str, topic: str, custom_instructions: str)
|
|
|
98
127
|
"""
|
|
99
128
|
prompt += "\nAdditional Instructions:\n"
|
|
100
129
|
prompt += f"You have experise in {topic}.\n"
|
|
101
|
-
prompt +=
|
|
130
|
+
prompt += general_instructions
|
|
102
131
|
prompt += custom_instructions
|
|
103
132
|
prompt += f"Today is {date.today().strftime('%A, %B %d, %Y')}"
|
|
104
133
|
return prompt
|
|
@@ -132,6 +161,7 @@ def get_field_type(field_schema: dict) -> Any:
|
|
|
132
161
|
else:
|
|
133
162
|
return Any
|
|
134
163
|
|
|
164
|
+
|
|
135
165
|
class Agent:
|
|
136
166
|
"""
|
|
137
167
|
Agent class for handling different types of agents and their interactions.
|
|
@@ -142,10 +172,13 @@ class Agent:
|
|
|
142
172
|
tools: list[FunctionTool],
|
|
143
173
|
topic: str = "general",
|
|
144
174
|
custom_instructions: str = "",
|
|
175
|
+
general_instructions: str = GENERAL_INSTRUCTIONS,
|
|
145
176
|
verbose: bool = True,
|
|
146
177
|
use_structured_planning: bool = False,
|
|
147
178
|
update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
|
|
148
|
-
agent_progress_callback: Optional[
|
|
179
|
+
agent_progress_callback: Optional[
|
|
180
|
+
Callable[[AgentStatusType, str], None]
|
|
181
|
+
] = None,
|
|
149
182
|
query_logging_callback: Optional[Callable[[str, str], None]] = None,
|
|
150
183
|
agent_config: Optional[AgentConfig] = None,
|
|
151
184
|
fallback_agent_config: Optional[AgentConfig] = None,
|
|
@@ -162,6 +195,9 @@ class Agent:
|
|
|
162
195
|
tools (list[FunctionTool]): A list of tools to be used by the agent.
|
|
163
196
|
topic (str, optional): The topic for the agent. Defaults to 'general'.
|
|
164
197
|
custom_instructions (str, optional): Custom instructions for the agent. Defaults to ''.
|
|
198
|
+
general_instructions (str, optional): General instructions for the agent.
|
|
199
|
+
The Agent has a default set of instructions that are crafted to help it operate effectively.
|
|
200
|
+
This allows you to customize the agent's behavior and personality, but use with caution.
|
|
165
201
|
verbose (bool, optional): Whether the agent should print its steps. Defaults to True.
|
|
166
202
|
use_structured_planning (bool, optional)
|
|
167
203
|
Whether or not we want to wrap the agent with LlamaIndex StructuredPlannerAgent.
|
|
@@ -181,19 +217,26 @@ class Agent:
|
|
|
181
217
|
self.agent_config = agent_config or AgentConfig()
|
|
182
218
|
self.agent_config_type = AgentConfigType.DEFAULT
|
|
183
219
|
self.tools = tools
|
|
184
|
-
if not any(tool.metadata.name ==
|
|
220
|
+
if not any(tool.metadata.name == "get_current_date" for tool in self.tools):
|
|
185
221
|
self.tools += [ToolsFactory().create_tool(get_current_date)]
|
|
186
222
|
self.agent_type = self.agent_config.agent_type
|
|
187
223
|
self.use_structured_planning = use_structured_planning
|
|
188
224
|
self.llm = get_llm(LLMRole.MAIN, config=self.agent_config)
|
|
189
225
|
self._custom_instructions = custom_instructions
|
|
226
|
+
self._general_instructions = general_instructions
|
|
190
227
|
self._topic = topic
|
|
191
|
-
self.agent_progress_callback =
|
|
228
|
+
self.agent_progress_callback = (
|
|
229
|
+
agent_progress_callback if agent_progress_callback else update_func
|
|
230
|
+
)
|
|
192
231
|
self.query_logging_callback = query_logging_callback
|
|
193
232
|
|
|
194
233
|
self.workflow_cls = workflow_cls
|
|
195
234
|
self.workflow_timeout = workflow_timeout
|
|
196
235
|
|
|
236
|
+
# Sanitize tools for Gemini if needed
|
|
237
|
+
if self.agent_config.main_llm_provider == ModelProvider.GEMINI:
|
|
238
|
+
self.tools = self._sanitize_tools_for_gemini(self.tools)
|
|
239
|
+
|
|
197
240
|
# Validate tools
|
|
198
241
|
# Check for:
|
|
199
242
|
# 1. multiple copies of the same tool
|
|
@@ -204,7 +247,7 @@ class Agent:
|
|
|
204
247
|
raise ValueError(f"Duplicate tools detected: {', '.join(duplicates)}")
|
|
205
248
|
|
|
206
249
|
if validate_tools:
|
|
207
|
-
prompt = f
|
|
250
|
+
prompt = f"""
|
|
208
251
|
Given the following instructions, and a list of tool names,
|
|
209
252
|
Please identify tools mentioned in the instructions that do not exist in the list.
|
|
210
253
|
Instructions:
|
|
@@ -212,20 +255,28 @@ class Agent:
|
|
|
212
255
|
Tool names: {', '.join(tool_names)}
|
|
213
256
|
Your response should include a comma separated list of tool names that do not exist in the list.
|
|
214
257
|
Your response should be an empty string if all tools mentioned in the instructions are in the list.
|
|
215
|
-
|
|
258
|
+
"""
|
|
216
259
|
llm = get_llm(LLMRole.MAIN, config=self.agent_config)
|
|
217
260
|
bad_tools = llm.complete(prompt).text.split(", ")
|
|
218
261
|
if bad_tools:
|
|
219
|
-
raise ValueError(
|
|
262
|
+
raise ValueError(
|
|
263
|
+
f"The Agent custom instructions mention these invalid tools: {', '.join(bad_tools)}"
|
|
264
|
+
)
|
|
220
265
|
|
|
221
266
|
# Create token counters for the main and tool LLMs
|
|
222
267
|
main_tok = get_tokenizer_for_model(role=LLMRole.MAIN)
|
|
223
|
-
self.main_token_counter =
|
|
268
|
+
self.main_token_counter = (
|
|
269
|
+
TokenCountingHandler(tokenizer=main_tok) if main_tok else None
|
|
270
|
+
)
|
|
224
271
|
tool_tok = get_tokenizer_for_model(role=LLMRole.TOOL)
|
|
225
|
-
self.tool_token_counter =
|
|
272
|
+
self.tool_token_counter = (
|
|
273
|
+
TokenCountingHandler(tokenizer=tool_tok) if tool_tok else None
|
|
274
|
+
)
|
|
226
275
|
|
|
227
276
|
# Setup callback manager
|
|
228
|
-
callbacks: list[BaseCallbackHandler] = [
|
|
277
|
+
callbacks: list[BaseCallbackHandler] = [
|
|
278
|
+
AgentCallbackHandler(self.agent_progress_callback)
|
|
279
|
+
]
|
|
229
280
|
if self.main_token_counter:
|
|
230
281
|
callbacks.append(self.main_token_counter)
|
|
231
282
|
if self.tool_token_counter:
|
|
@@ -236,9 +287,17 @@ class Agent:
|
|
|
236
287
|
if chat_history:
|
|
237
288
|
msg_history = []
|
|
238
289
|
for text_pairs in chat_history:
|
|
239
|
-
msg_history.append(
|
|
240
|
-
|
|
241
|
-
|
|
290
|
+
msg_history.append(
|
|
291
|
+
ChatMessage.from_str(content=text_pairs[0], role=MessageRole.USER)
|
|
292
|
+
)
|
|
293
|
+
msg_history.append(
|
|
294
|
+
ChatMessage.from_str(
|
|
295
|
+
content=text_pairs[1], role=MessageRole.ASSISTANT
|
|
296
|
+
)
|
|
297
|
+
)
|
|
298
|
+
self.memory = ChatMemoryBuffer.from_defaults(
|
|
299
|
+
token_limit=128000, chat_history=msg_history
|
|
300
|
+
)
|
|
242
301
|
else:
|
|
243
302
|
self.memory = ChatMemoryBuffer.from_defaults(token_limit=128000)
|
|
244
303
|
|
|
@@ -246,7 +305,9 @@ class Agent:
|
|
|
246
305
|
self.agent = self._create_agent(self.agent_config, callback_manager)
|
|
247
306
|
self.fallback_agent_config = fallback_agent_config
|
|
248
307
|
if self.fallback_agent_config:
|
|
249
|
-
self.fallback_agent = self._create_agent(
|
|
308
|
+
self.fallback_agent = self._create_agent(
|
|
309
|
+
self.fallback_agent_config, callback_manager
|
|
310
|
+
)
|
|
250
311
|
else:
|
|
251
312
|
self.fallback_agent_config = None
|
|
252
313
|
|
|
@@ -257,10 +318,65 @@ class Agent:
|
|
|
257
318
|
print(f"Failed to set up observer ({e}), ignoring")
|
|
258
319
|
self.observability_enabled = False
|
|
259
320
|
|
|
321
|
+
def _sanitize_tools_for_gemini(
|
|
322
|
+
self, tools: list[FunctionTool]
|
|
323
|
+
) -> list[FunctionTool]:
|
|
324
|
+
"""
|
|
325
|
+
Strip all default values from:
|
|
326
|
+
- tool.fn
|
|
327
|
+
- tool.async_fn
|
|
328
|
+
- tool.metadata.fn_schema
|
|
329
|
+
so Gemini sees *only* required parameters, no defaults.
|
|
330
|
+
"""
|
|
331
|
+
for tool in tools:
|
|
332
|
+
# 1) strip defaults off the actual callables
|
|
333
|
+
for func in (tool.fn, tool.async_fn):
|
|
334
|
+
if not func:
|
|
335
|
+
continue
|
|
336
|
+
orig_sig = inspect.signature(func)
|
|
337
|
+
new_params = [
|
|
338
|
+
p.replace(default=Parameter.empty)
|
|
339
|
+
for p in orig_sig.parameters.values()
|
|
340
|
+
]
|
|
341
|
+
new_sig = Signature(
|
|
342
|
+
new_params, return_annotation=orig_sig.return_annotation
|
|
343
|
+
)
|
|
344
|
+
if ismethod(func):
|
|
345
|
+
func.__func__.__signature__ = new_sig
|
|
346
|
+
else:
|
|
347
|
+
func.__signature__ = new_sig
|
|
348
|
+
|
|
349
|
+
# 2) rebuild the Pydantic schema so that *every* field is required
|
|
350
|
+
schema_cls = getattr(tool.metadata, "fn_schema", None)
|
|
351
|
+
if schema_cls and hasattr(schema_cls, "model_fields"):
|
|
352
|
+
# collect (name → (type, Field(...))) for all fields
|
|
353
|
+
new_fields: dict[str, tuple[type, Any]] = {}
|
|
354
|
+
for name, mf in schema_cls.model_fields.items():
|
|
355
|
+
typ = mf.annotation
|
|
356
|
+
desc = getattr(mf, "description", "")
|
|
357
|
+
# force required (no default) with Field(...)
|
|
358
|
+
new_fields[name] = (typ, Field(..., description=desc))
|
|
359
|
+
|
|
360
|
+
# make a brand-new schema class where every field is required
|
|
361
|
+
no_default_schema = create_model(
|
|
362
|
+
f"{schema_cls.__name__}", # new class name
|
|
363
|
+
**new_fields, # type: ignore
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
# give it a clean __signature__ so inspect.signature sees no defaults
|
|
367
|
+
params = [
|
|
368
|
+
Parameter(n, Parameter.POSITIONAL_OR_KEYWORD, annotation=typ)
|
|
369
|
+
for n, (typ, _) in new_fields.items()
|
|
370
|
+
]
|
|
371
|
+
no_default_schema.__signature__ = Signature(params)
|
|
372
|
+
|
|
373
|
+
# swap it back onto the tool
|
|
374
|
+
tool.metadata.fn_schema = no_default_schema
|
|
375
|
+
|
|
376
|
+
return tools
|
|
377
|
+
|
|
260
378
|
def _create_agent(
|
|
261
|
-
self,
|
|
262
|
-
config: AgentConfig,
|
|
263
|
-
llm_callback_manager: CallbackManager
|
|
379
|
+
self, config: AgentConfig, llm_callback_manager: CallbackManager
|
|
264
380
|
) -> Union[BaseAgent, AgentRunner]:
|
|
265
381
|
"""
|
|
266
382
|
Creates the agent based on the configuration object.
|
|
@@ -282,7 +398,12 @@ class Agent:
|
|
|
282
398
|
raise ValueError(
|
|
283
399
|
"Vectara-agentic: Function calling agent type is not supported with the OpenAI LLM."
|
|
284
400
|
)
|
|
285
|
-
prompt = _get_prompt(
|
|
401
|
+
prompt = _get_prompt(
|
|
402
|
+
GENERAL_PROMPT_TEMPLATE,
|
|
403
|
+
self._general_instructions,
|
|
404
|
+
self._topic,
|
|
405
|
+
self._custom_instructions,
|
|
406
|
+
)
|
|
286
407
|
agent = FunctionCallingAgent.from_tools(
|
|
287
408
|
tools=self.tools,
|
|
288
409
|
llm=llm,
|
|
@@ -294,7 +415,12 @@ class Agent:
|
|
|
294
415
|
allow_parallel_tool_calls=True,
|
|
295
416
|
)
|
|
296
417
|
elif agent_type == AgentType.REACT:
|
|
297
|
-
prompt = _get_prompt(
|
|
418
|
+
prompt = _get_prompt(
|
|
419
|
+
REACT_PROMPT_TEMPLATE,
|
|
420
|
+
self._general_instructions,
|
|
421
|
+
self._topic,
|
|
422
|
+
self._custom_instructions,
|
|
423
|
+
)
|
|
298
424
|
agent = ReActAgent.from_tools(
|
|
299
425
|
tools=self.tools,
|
|
300
426
|
llm=llm,
|
|
@@ -309,7 +435,12 @@ class Agent:
|
|
|
309
435
|
raise ValueError(
|
|
310
436
|
"Vectara-agentic: OPENAI agent type requires the OpenAI LLM."
|
|
311
437
|
)
|
|
312
|
-
prompt = _get_prompt(
|
|
438
|
+
prompt = _get_prompt(
|
|
439
|
+
GENERAL_PROMPT_TEMPLATE,
|
|
440
|
+
self._general_instructions,
|
|
441
|
+
self._topic,
|
|
442
|
+
self._custom_instructions,
|
|
443
|
+
)
|
|
313
444
|
agent = OpenAIAgent.from_tools(
|
|
314
445
|
tools=self.tools,
|
|
315
446
|
llm=llm,
|
|
@@ -327,12 +458,26 @@ class Agent:
|
|
|
327
458
|
callback_manager=llm_callback_manager,
|
|
328
459
|
)
|
|
329
460
|
agent_worker.system_prompt = _get_prompt(
|
|
330
|
-
_get_llm_compiler_prompt(
|
|
331
|
-
|
|
461
|
+
prompt_template=_get_llm_compiler_prompt(
|
|
462
|
+
prompt=agent_worker.system_prompt,
|
|
463
|
+
general_instructions=self._general_instructions,
|
|
464
|
+
topic=self._topic,
|
|
465
|
+
custom_instructions=self._custom_instructions,
|
|
466
|
+
),
|
|
467
|
+
general_instructions=self._general_instructions,
|
|
468
|
+
topic=self._topic,
|
|
469
|
+
custom_instructions=self._custom_instructions,
|
|
332
470
|
)
|
|
333
471
|
agent_worker.system_prompt_replan = _get_prompt(
|
|
334
|
-
_get_llm_compiler_prompt(
|
|
335
|
-
|
|
472
|
+
prompt_template=_get_llm_compiler_prompt(
|
|
473
|
+
prompt=agent_worker.system_prompt_replan,
|
|
474
|
+
general_instructions=GENERAL_INSTRUCTIONS,
|
|
475
|
+
topic=self._topic,
|
|
476
|
+
custom_instructions=self._custom_instructions,
|
|
477
|
+
),
|
|
478
|
+
general_instructions=GENERAL_INSTRUCTIONS,
|
|
479
|
+
topic=self._topic,
|
|
480
|
+
custom_instructions=self._custom_instructions,
|
|
336
481
|
)
|
|
337
482
|
agent = agent_worker.as_agent()
|
|
338
483
|
elif agent_type == AgentType.LATS:
|
|
@@ -344,18 +489,27 @@ class Agent:
|
|
|
344
489
|
verbose=self.verbose,
|
|
345
490
|
callback_manager=llm_callback_manager,
|
|
346
491
|
)
|
|
347
|
-
prompt = _get_prompt(
|
|
492
|
+
prompt = _get_prompt(
|
|
493
|
+
REACT_PROMPT_TEMPLATE,
|
|
494
|
+
self._general_instructions,
|
|
495
|
+
self._topic,
|
|
496
|
+
self._custom_instructions,
|
|
497
|
+
)
|
|
348
498
|
agent_worker.chat_formatter = ReActChatFormatter(system_header=prompt)
|
|
349
499
|
agent = agent_worker.as_agent()
|
|
350
500
|
else:
|
|
351
501
|
raise ValueError(f"Unknown agent type: {agent_type}")
|
|
352
502
|
|
|
353
503
|
# Set up structured planner if needed
|
|
354
|
-
if
|
|
355
|
-
|
|
504
|
+
if self.use_structured_planning or self.agent_type in [
|
|
505
|
+
AgentType.LLMCOMPILER,
|
|
506
|
+
AgentType.LATS,
|
|
507
|
+
]:
|
|
508
|
+
planner_llm = get_llm(LLMRole.TOOL, config=config)
|
|
356
509
|
agent = StructuredPlannerAgent(
|
|
357
510
|
agent_worker=agent.agent_worker,
|
|
358
511
|
tools=self.tools,
|
|
512
|
+
llm=planner_llm,
|
|
359
513
|
memory=self.memory,
|
|
360
514
|
verbose=self.verbose,
|
|
361
515
|
initial_plan_prompt=STRUCTURED_PLANNER_INITIAL_PLAN_PROMPT,
|
|
@@ -370,14 +524,19 @@ class Agent:
|
|
|
370
524
|
"""
|
|
371
525
|
if self.agent_config_type == AgentConfigType.DEFAULT:
|
|
372
526
|
self.agent.memory.reset()
|
|
373
|
-
elif
|
|
527
|
+
elif (
|
|
528
|
+
self.agent_config_type == AgentConfigType.FALLBACK
|
|
529
|
+
and self.fallback_agent_config
|
|
530
|
+
):
|
|
374
531
|
self.fallback_agent.memory.reset()
|
|
375
532
|
else:
|
|
376
533
|
raise ValueError(f"Invalid agent config type {self.agent_config_type}")
|
|
377
534
|
|
|
378
535
|
def __eq__(self, other):
|
|
379
536
|
if not isinstance(other, Agent):
|
|
380
|
-
print(
|
|
537
|
+
print(
|
|
538
|
+
f"Comparison failed: other is not an instance of Agent. (self: {type(self)}, other: {type(other)})"
|
|
539
|
+
)
|
|
381
540
|
return False
|
|
382
541
|
|
|
383
542
|
# Compare agent_type
|
|
@@ -393,12 +552,15 @@ class Agent:
|
|
|
393
552
|
print(
|
|
394
553
|
"Comparison failed: tools differ."
|
|
395
554
|
f"(self.tools: {[t.metadata.name for t in self.tools]}, "
|
|
396
|
-
f"other.tools: {[t.metadata.name for t in other.tools]})"
|
|
555
|
+
f"other.tools: {[t.metadata.name for t in other.tools]})"
|
|
556
|
+
)
|
|
397
557
|
return False
|
|
398
558
|
|
|
399
559
|
# Compare topic
|
|
400
560
|
if self._topic != other._topic:
|
|
401
|
-
print(
|
|
561
|
+
print(
|
|
562
|
+
f"Comparison failed: topic differs. (self.topic: {self._topic}, other.topic: {other._topic})"
|
|
563
|
+
)
|
|
402
564
|
return False
|
|
403
565
|
|
|
404
566
|
# Compare custom_instructions
|
|
@@ -411,7 +573,9 @@ class Agent:
|
|
|
411
573
|
|
|
412
574
|
# Compare verbose
|
|
413
575
|
if self.verbose != other.verbose:
|
|
414
|
-
print(
|
|
576
|
+
print(
|
|
577
|
+
f"Comparison failed: verbose differs. (self.verbose: {self.verbose}, other.verbose: {other.verbose})"
|
|
578
|
+
)
|
|
415
579
|
return False
|
|
416
580
|
|
|
417
581
|
# Compare agent memory
|
|
@@ -434,7 +598,9 @@ class Agent:
|
|
|
434
598
|
custom_instructions: str = "",
|
|
435
599
|
verbose: bool = True,
|
|
436
600
|
update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
|
|
437
|
-
agent_progress_callback: Optional[
|
|
601
|
+
agent_progress_callback: Optional[
|
|
602
|
+
Callable[[AgentStatusType, str], None]
|
|
603
|
+
] = None,
|
|
438
604
|
query_logging_callback: Optional[Callable[[str, str], None]] = None,
|
|
439
605
|
agent_config: AgentConfig = AgentConfig(),
|
|
440
606
|
validate_tools: bool = False,
|
|
@@ -467,14 +633,19 @@ class Agent:
|
|
|
467
633
|
Agent: An instance of the Agent class.
|
|
468
634
|
"""
|
|
469
635
|
return cls(
|
|
470
|
-
tools=tools,
|
|
471
|
-
|
|
636
|
+
tools=tools,
|
|
637
|
+
topic=topic,
|
|
638
|
+
custom_instructions=custom_instructions,
|
|
639
|
+
verbose=verbose,
|
|
640
|
+
agent_progress_callback=agent_progress_callback,
|
|
472
641
|
query_logging_callback=query_logging_callback,
|
|
473
|
-
update_func=update_func,
|
|
642
|
+
update_func=update_func,
|
|
643
|
+
agent_config=agent_config,
|
|
474
644
|
chat_history=chat_history,
|
|
475
645
|
validate_tools=validate_tools,
|
|
476
646
|
fallback_agent_config=fallback_agent_config,
|
|
477
|
-
workflow_cls
|
|
647
|
+
workflow_cls=workflow_cls,
|
|
648
|
+
workflow_timeout=workflow_timeout,
|
|
478
649
|
)
|
|
479
650
|
|
|
480
651
|
@classmethod
|
|
@@ -483,9 +654,12 @@ class Agent:
|
|
|
483
654
|
tool_name: str,
|
|
484
655
|
data_description: str,
|
|
485
656
|
assistant_specialty: str,
|
|
657
|
+
general_instructions: str = GENERAL_INSTRUCTIONS,
|
|
486
658
|
vectara_corpus_key: str = str(os.environ.get("VECTARA_CORPUS_KEY", "")),
|
|
487
659
|
vectara_api_key: str = str(os.environ.get("VECTARA_API_KEY", "")),
|
|
488
|
-
agent_progress_callback: Optional[
|
|
660
|
+
agent_progress_callback: Optional[
|
|
661
|
+
Callable[[AgentStatusType, str], None]
|
|
662
|
+
] = None,
|
|
489
663
|
query_logging_callback: Optional[Callable[[str, str], None]] = None,
|
|
490
664
|
agent_config: AgentConfig = AgentConfig(),
|
|
491
665
|
fallback_agent_config: Optional[AgentConfig] = None,
|
|
@@ -530,6 +704,9 @@ class Agent:
|
|
|
530
704
|
chat_history (Tuple[str, str], optional): A list of user/agent chat pairs to initialize the agent memory.
|
|
531
705
|
data_description (str): The description of the data.
|
|
532
706
|
assistant_specialty (str): The specialty of the assistant.
|
|
707
|
+
general_instructions (str, optional): General instructions for the agent.
|
|
708
|
+
The Agent has a default set of instructions that are crafted to help it operate effectively.
|
|
709
|
+
This allows you to customize the agent's behavior and personality, but use with caution.
|
|
533
710
|
verbose (bool, optional): Whether to print verbose output.
|
|
534
711
|
vectara_filter_fields (List[dict], optional): The filterable attributes
|
|
535
712
|
(each dict maps field name to Tuple[type, description]).
|
|
@@ -626,6 +803,7 @@ class Agent:
|
|
|
626
803
|
tools=[vectara_tool],
|
|
627
804
|
topic=assistant_specialty,
|
|
628
805
|
custom_instructions=assistant_instructions,
|
|
806
|
+
general_instructions=general_instructions,
|
|
629
807
|
verbose=verbose,
|
|
630
808
|
agent_progress_callback=agent_progress_callback,
|
|
631
809
|
query_logging_callback=query_logging_callback,
|
|
@@ -635,7 +813,7 @@ class Agent:
|
|
|
635
813
|
)
|
|
636
814
|
|
|
637
815
|
def _switch_agent_config(self) -> None:
|
|
638
|
-
""""
|
|
816
|
+
""" "
|
|
639
817
|
Switch the configuration type of the agent.
|
|
640
818
|
This function is called automatically to switch the agent configuration if the current configuration fails.
|
|
641
819
|
"""
|
|
@@ -659,15 +837,19 @@ class Agent:
|
|
|
659
837
|
print(f"Topic = {self._topic}")
|
|
660
838
|
print("Tools:")
|
|
661
839
|
for tool in self.tools:
|
|
662
|
-
if hasattr(tool,
|
|
840
|
+
if hasattr(tool, "metadata"):
|
|
663
841
|
if detailed:
|
|
664
842
|
print(f"- {tool.metadata.description}")
|
|
665
843
|
else:
|
|
666
844
|
print(f"- {tool.metadata.name}")
|
|
667
845
|
else:
|
|
668
846
|
print("- tool without metadata")
|
|
669
|
-
print(
|
|
670
|
-
|
|
847
|
+
print(
|
|
848
|
+
f"Agent LLM = {get_llm(LLMRole.MAIN, config=self.agent_config).metadata.model_name}"
|
|
849
|
+
)
|
|
850
|
+
print(
|
|
851
|
+
f"Tool LLM = {get_llm(LLMRole.TOOL, config=self.agent_config).metadata.model_name}"
|
|
852
|
+
)
|
|
671
853
|
|
|
672
854
|
def token_counts(self) -> dict:
|
|
673
855
|
"""
|
|
@@ -677,16 +859,29 @@ class Agent:
|
|
|
677
859
|
dict: The token counts for the agent and tools.
|
|
678
860
|
"""
|
|
679
861
|
return {
|
|
680
|
-
"main token count":
|
|
681
|
-
|
|
862
|
+
"main token count": (
|
|
863
|
+
self.main_token_counter.total_llm_token_count
|
|
864
|
+
if self.main_token_counter
|
|
865
|
+
else -1
|
|
866
|
+
),
|
|
867
|
+
"tool token count": (
|
|
868
|
+
self.tool_token_counter.total_llm_token_count
|
|
869
|
+
if self.tool_token_counter
|
|
870
|
+
else -1
|
|
871
|
+
),
|
|
682
872
|
}
|
|
683
873
|
|
|
684
874
|
def _get_current_agent(self):
|
|
685
|
-
return
|
|
875
|
+
return (
|
|
876
|
+
self.agent
|
|
877
|
+
if self.agent_config_type == AgentConfigType.DEFAULT
|
|
878
|
+
else self.fallback_agent
|
|
879
|
+
)
|
|
686
880
|
|
|
687
881
|
def _get_current_agent_type(self):
|
|
688
882
|
return (
|
|
689
|
-
self.agent_config.agent_type
|
|
883
|
+
self.agent_config.agent_type
|
|
884
|
+
if self.agent_config_type == AgentConfigType.DEFAULT
|
|
690
885
|
else self.fallback_agent_config.agent_type
|
|
691
886
|
)
|
|
692
887
|
|
|
@@ -703,7 +898,7 @@ class Agent:
|
|
|
703
898
|
agent = self._get_current_agent()
|
|
704
899
|
agent_response.response = str(agent.llm.acomplete(llm_prompt))
|
|
705
900
|
|
|
706
|
-
def chat(self, prompt: str) -> AgentResponse:
|
|
901
|
+
def chat(self, prompt: str) -> AgentResponse: # type: ignore
|
|
707
902
|
"""
|
|
708
903
|
Interact with the agent using a chat prompt.
|
|
709
904
|
|
|
@@ -715,7 +910,7 @@ class Agent:
|
|
|
715
910
|
"""
|
|
716
911
|
return asyncio.run(self.achat(prompt))
|
|
717
912
|
|
|
718
|
-
async def achat(self, prompt: str) -> AgentResponse:
|
|
913
|
+
async def achat(self, prompt: str) -> AgentResponse: # type: ignore
|
|
719
914
|
"""
|
|
720
915
|
Interact with the agent using a chat prompt.
|
|
721
916
|
|
|
@@ -744,7 +939,9 @@ class Agent:
|
|
|
744
939
|
last_error = e
|
|
745
940
|
if attempt >= 2:
|
|
746
941
|
if self.verbose:
|
|
747
|
-
print(
|
|
942
|
+
print(
|
|
943
|
+
f"LLM call failed on attempt {attempt}. Switching agent configuration."
|
|
944
|
+
)
|
|
748
945
|
self._switch_agent_config()
|
|
749
946
|
time.sleep(1)
|
|
750
947
|
attempt += 1
|
|
@@ -756,7 +953,7 @@ class Agent:
|
|
|
756
953
|
)
|
|
757
954
|
)
|
|
758
955
|
|
|
759
|
-
def stream_chat(self, prompt: str) -> AgentStreamingResponse:
|
|
956
|
+
def stream_chat(self, prompt: str) -> AgentStreamingResponse: # type: ignore
|
|
760
957
|
"""
|
|
761
958
|
Interact with the agent using a chat prompt with streaming.
|
|
762
959
|
Args:
|
|
@@ -766,7 +963,7 @@ class Agent:
|
|
|
766
963
|
"""
|
|
767
964
|
return asyncio.run(self.astream_chat(prompt))
|
|
768
965
|
|
|
769
|
-
async def astream_chat(self, prompt: str) -> AgentStreamingResponse:
|
|
966
|
+
async def astream_chat(self, prompt: str) -> AgentStreamingResponse: # type: ignore
|
|
770
967
|
"""
|
|
771
968
|
Interact with the agent using a chat prompt asynchronously with streaming.
|
|
772
969
|
Args:
|
|
@@ -794,14 +991,18 @@ class Agent:
|
|
|
794
991
|
if self.observability_enabled:
|
|
795
992
|
eval_fcs()
|
|
796
993
|
|
|
797
|
-
agent_response.async_response_gen =
|
|
994
|
+
agent_response.async_response_gen = (
|
|
995
|
+
_stream_response_wrapper # Override the generator
|
|
996
|
+
)
|
|
798
997
|
return agent_response
|
|
799
998
|
|
|
800
999
|
except Exception as e:
|
|
801
1000
|
last_error = e
|
|
802
1001
|
if attempt >= 2:
|
|
803
1002
|
if self.verbose:
|
|
804
|
-
print(
|
|
1003
|
+
print(
|
|
1004
|
+
f"LLM call failed on attempt {attempt}. Switching agent configuration."
|
|
1005
|
+
)
|
|
805
1006
|
self._switch_agent_config()
|
|
806
1007
|
time.sleep(1)
|
|
807
1008
|
attempt += 1
|
|
@@ -818,11 +1019,7 @@ class Agent:
|
|
|
818
1019
|
# workflow will always get these arguments in the StartEvent: agent, tools, llm, verbose
|
|
819
1020
|
# the inputs argument comes from the call to run()
|
|
820
1021
|
#
|
|
821
|
-
async def run(
|
|
822
|
-
self,
|
|
823
|
-
inputs: Any,
|
|
824
|
-
verbose: bool = False
|
|
825
|
-
) -> Any:
|
|
1022
|
+
async def run(self, inputs: Any, verbose: bool = False) -> Any:
|
|
826
1023
|
"""
|
|
827
1024
|
Run a workflow using the agent.
|
|
828
1025
|
workflow class must be provided in the agent constructor.
|
|
@@ -886,7 +1083,7 @@ class Agent:
|
|
|
886
1083
|
"metadata": {
|
|
887
1084
|
"module": fn_schema_cls.__module__,
|
|
888
1085
|
"class": fn_schema_cls.__name__,
|
|
889
|
-
}
|
|
1086
|
+
},
|
|
890
1087
|
}
|
|
891
1088
|
else:
|
|
892
1089
|
fn_schema_serialized = None
|
|
@@ -895,9 +1092,16 @@ class Agent:
|
|
|
895
1092
|
"tool_type": tool.metadata.tool_type.value,
|
|
896
1093
|
"name": tool.metadata.name,
|
|
897
1094
|
"description": tool.metadata.description,
|
|
898
|
-
"fn":
|
|
899
|
-
|
|
900
|
-
|
|
1095
|
+
"fn": (
|
|
1096
|
+
pickle.dumps(getattr(tool, "fn", None)).decode("latin-1")
|
|
1097
|
+
if getattr(tool, "fn", None)
|
|
1098
|
+
else None
|
|
1099
|
+
),
|
|
1100
|
+
"async_fn": (
|
|
1101
|
+
pickle.dumps(getattr(tool, "async_fn", None)).decode("latin-1")
|
|
1102
|
+
if getattr(tool, "async_fn", None)
|
|
1103
|
+
else None
|
|
1104
|
+
),
|
|
901
1105
|
"fn_schema": fn_schema_serialized,
|
|
902
1106
|
}
|
|
903
1107
|
tool_info.append(tool_dict)
|
|
@@ -910,7 +1114,11 @@ class Agent:
|
|
|
910
1114
|
"custom_instructions": self._custom_instructions,
|
|
911
1115
|
"verbose": self.verbose,
|
|
912
1116
|
"agent_config": self.agent_config.to_dict(),
|
|
913
|
-
"fallback_agent":
|
|
1117
|
+
"fallback_agent": (
|
|
1118
|
+
self.fallback_agent_config.to_dict()
|
|
1119
|
+
if self.fallback_agent_config
|
|
1120
|
+
else None
|
|
1121
|
+
),
|
|
914
1122
|
"workflow_cls": self.workflow_cls if self.workflow_cls else None,
|
|
915
1123
|
}
|
|
916
1124
|
|
|
@@ -938,12 +1146,17 @@ class Agent:
|
|
|
938
1146
|
except Exception:
|
|
939
1147
|
# Fallback: rebuild using the JSON schema
|
|
940
1148
|
field_definitions = {}
|
|
941
|
-
for field, values in
|
|
1149
|
+
for field, values in (
|
|
1150
|
+
schema_info.get("schema", {}).get("properties", {}).items()
|
|
1151
|
+
):
|
|
942
1152
|
field_type = get_field_type(values)
|
|
943
1153
|
if "default" in values:
|
|
944
1154
|
field_definitions[field] = (
|
|
945
1155
|
field_type,
|
|
946
|
-
Field(
|
|
1156
|
+
Field(
|
|
1157
|
+
description=values.get("description", ""),
|
|
1158
|
+
default=values["default"],
|
|
1159
|
+
),
|
|
947
1160
|
)
|
|
948
1161
|
else:
|
|
949
1162
|
field_definitions[field] = (
|
|
@@ -952,13 +1165,21 @@ class Agent:
|
|
|
952
1165
|
)
|
|
953
1166
|
query_args_model = create_model(
|
|
954
1167
|
schema_info.get("schema", {}).get("title", "QueryArgs"),
|
|
955
|
-
**field_definitions
|
|
1168
|
+
**field_definitions,
|
|
956
1169
|
)
|
|
957
1170
|
else:
|
|
958
1171
|
query_args_model = create_model("QueryArgs")
|
|
959
1172
|
|
|
960
|
-
fn =
|
|
961
|
-
|
|
1173
|
+
fn = (
|
|
1174
|
+
pickle.loads(tool_data["fn"].encode("latin-1"))
|
|
1175
|
+
if tool_data["fn"]
|
|
1176
|
+
else None
|
|
1177
|
+
)
|
|
1178
|
+
async_fn = (
|
|
1179
|
+
pickle.loads(tool_data["async_fn"].encode("latin-1"))
|
|
1180
|
+
if tool_data["async_fn"]
|
|
1181
|
+
else None
|
|
1182
|
+
)
|
|
962
1183
|
|
|
963
1184
|
tool = VectaraTool.from_defaults(
|
|
964
1185
|
name=tool_data["name"],
|
|
@@ -979,7 +1200,11 @@ class Agent:
|
|
|
979
1200
|
fallback_agent_config=fallback_agent_config,
|
|
980
1201
|
workflow_cls=data["workflow_cls"],
|
|
981
1202
|
)
|
|
982
|
-
memory =
|
|
1203
|
+
memory = (
|
|
1204
|
+
pickle.loads(data["memory"].encode("latin-1"))
|
|
1205
|
+
if data.get("memory")
|
|
1206
|
+
else None
|
|
1207
|
+
)
|
|
983
1208
|
if memory:
|
|
984
1209
|
agent.agent.memory = memory
|
|
985
1210
|
return agent
|