vectara-agentic 0.1.23__py3-none-any.whl → 0.1.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of vectara-agentic might be problematic.
- tests/test_agent.py +44 -1
- tests/test_tools.py +3 -3
- vectara_agentic/__init__.py +12 -19
- vectara_agentic/_observability.py +3 -3
- vectara_agentic/_prompts.py +5 -3
- vectara_agentic/_version.py +4 -0
- vectara_agentic/agent.py +126 -40
- vectara_agentic/agent_config.py +86 -0
- vectara_agentic/agent_endpoint.py +6 -7
- vectara_agentic/tools.py +275 -61
- vectara_agentic/tools_catalog.py +8 -1
- vectara_agentic/types.py +10 -0
- vectara_agentic/utils.py +50 -34
- {vectara_agentic-0.1.23.dist-info → vectara_agentic-0.1.25.dist-info}/METADATA +122 -38
- vectara_agentic-0.1.25.dist-info/RECORD +21 -0
- {vectara_agentic-0.1.23.dist-info → vectara_agentic-0.1.25.dist-info}/WHEEL +1 -1
- vectara_agentic-0.1.23.dist-info/RECORD +0 -19
- {vectara_agentic-0.1.23.dist-info → vectara_agentic-0.1.25.dist-info}/LICENSE +0 -0
- {vectara_agentic-0.1.23.dist-info → vectara_agentic-0.1.25.dist-info}/top_level.txt +0 -0
tests/test_agent.py
CHANGED
```diff
@@ -2,7 +2,8 @@ import unittest
 from datetime import date
 
 from vectara_agentic.agent import _get_prompt, Agent, AgentType, FunctionTool
-
+from vectara_agentic.agent_config import AgentConfig
+from vectara_agentic.types import ModelProvider, ObserverType
 
 class TestAgentPackage(unittest.TestCase):
     def test_get_prompt(self):
@@ -43,6 +44,48 @@ class TestAgentPackage(unittest.TestCase):
             "50",
         )
 
+    def test_agent_config(self):
+        def mult(x, y):
+            return x * y
+
+        tools = [
+            FunctionTool.from_defaults(
+                fn=mult, name="mult", description="Multiplication functions"
+            )
+        ]
+        topic = "AI topic"
+        instructions = "Always do as your father tells you, if your mother agrees!"
+        config = AgentConfig(
+            agent_type=AgentType.REACT,
+            main_llm_provider=ModelProvider.ANTHROPIC,
+            main_llm_model_name="claude-3-5-sonnet-20241022",
+            tool_llm_provider=ModelProvider.TOGETHER,
+            tool_llm_model_name="meta-llama/Llama-3.3-70B-Instruct-Turbo",
+            observer=ObserverType.ARIZE_PHOENIX
+        )
+
+        agent = Agent(
+            tools=tools,
+            topic=topic,
+            custom_instructions=instructions,
+            agent_config=config
+        )
+        self.assertEqual(agent.tools, tools)
+        self.assertEqual(agent._topic, topic)
+        self.assertEqual(agent._custom_instructions, instructions)
+        self.assertEqual(agent.agent_type, AgentType.REACT)
+        self.assertEqual(agent.agent_config.observer, ObserverType.ARIZE_PHOENIX)
+        self.assertEqual(agent.agent_config.main_llm_provider, ModelProvider.ANTHROPIC)
+        self.assertEqual(agent.agent_config.tool_llm_provider, ModelProvider.TOGETHER)
+
+        # To run this test, you must have OPENAI_API_KEY in your environment
+        self.assertEqual(
+            agent.chat(
+                "What is 5 times 10. Only give the answer, nothing else"
+            ).replace("$", "\\$"),
+            "50",
+        )
+
     def test_from_corpus(self):
         agent = Agent.from_corpus(
             tool_name="RAG Tool",
```
tests/test_tools.py
CHANGED
```diff
@@ -32,7 +32,7 @@ class TestToolsPackage(unittest.TestCase):
 
         self.assertIsInstance(query_tool, VectaraTool)
         self.assertIsInstance(query_tool, FunctionTool)
-        self.assertEqual(query_tool.tool_type, ToolType.QUERY)
+        self.assertEqual(query_tool.metadata.tool_type, ToolType.QUERY)
 
     def test_tool_factory(self):
         def mult(x, y):
@@ -42,7 +42,7 @@ class TestToolsPackage(unittest.TestCase):
         other_tool = tools_factory.create_tool(mult)
         self.assertIsInstance(other_tool, VectaraTool)
         self.assertIsInstance(other_tool, FunctionTool)
-        self.assertEqual(other_tool.tool_type, ToolType.QUERY)
+        self.assertEqual(other_tool.metadata.tool_type, ToolType.QUERY)
 
     def test_llama_index_tools(self):
         tools_factory = ToolsFactory()
@@ -56,7 +56,7 @@ class TestToolsPackage(unittest.TestCase):
 
         self.assertIsInstance(arxiv_tool, VectaraTool)
         self.assertIsInstance(arxiv_tool, FunctionTool)
-        self.assertEqual(arxiv_tool.tool_type, ToolType.QUERY)
+        self.assertEqual(arxiv_tool.metadata.tool_type, ToolType.QUERY)
 
     def test_public_repo(self):
         vectara_customer_id = "1366999410"
```
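All three assertions change the same way: `tool_type` now lives on the tool's `metadata` rather than on the tool object itself. A minimal migration sketch for downstream code (the factory usage mirrors the tests above; the `add` function is illustrative):

```python
from vectara_agentic.tools import ToolsFactory
from vectara_agentic.types import ToolType

def add(x, y):
    # Illustrative stand-in for any Python function you wrap as a tool.
    return x + y

tool = ToolsFactory().create_tool(add)

# Pre-0.1.25 access (removed in this diff): tool.tool_type
# From 0.1.25 on, read the tool type via the tool's metadata:
assert tool.metadata.tool_type == ToolType.QUERY
```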
vectara_agentic/__init__.py
CHANGED
```diff
@@ -2,22 +2,15 @@
 vectara_agentic package.
 """
 
-
-
-
-#
-
-
-
-
-
-
-
-
-initialize_package()
-
-
-# Define the __all__ variable
-# __all__ = ['Class1', 'function1', 'Class2', 'function2']
+from .agent import Agent
+from .tools import VectaraToolFactory, VectaraTool
+
+# Define the __all__ variable for wildcard imports
+__all__ = ['Agent', 'VectaraToolFactory', 'VectaraTool']
+
+# Ensure package version is available
+try:
+    import importlib.metadata
+    __version__ = importlib.metadata.version("vectara_agentic")
+except Exception:
+    __version__ = "0.0.0"  # fallback if not installed
```
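These re-exports flatten the public API: the most common entry points are now importable from the package root, and `__version__` resolves through `importlib.metadata` with a fallback. A short sketch of the new import surface:

```python
# Top-level imports available as of 0.1.25 (per the __all__ above).
from vectara_agentic import Agent, VectaraToolFactory, VectaraTool

import vectara_agentic
# Resolved from installed package metadata; "0.0.0" if the package is not installed.
print(vectara_agentic.__version__)
```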
vectara_agentic/_observability.py
CHANGED

```diff
@@ -6,16 +6,16 @@ import json
 from typing import Optional, Union
 import pandas as pd
 from .types import ObserverType
+from .agent_config import AgentConfig
 
-def setup_observer() -> bool:
+def setup_observer(config: AgentConfig) -> bool:
     '''
     Setup the observer.
     '''
     import phoenix as px
     from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
     from phoenix.otel import register
-    observer
-    if observer == ObserverType.ARIZE_PHOENIX:
+    if config.observer == ObserverType.ARIZE_PHOENIX:
         phoenix_endpoint = os.getenv("PHOENIX_ENDPOINT", None)
         if not phoenix_endpoint:
             px.launch_app()
```
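`setup_observer` now receives the observer choice through `AgentConfig` instead of reading it itself. Normally the `Agent` constructor calls this for you (see the agent.py diff below); a hedged sketch of calling it directly:

```python
from vectara_agentic.agent_config import AgentConfig
from vectara_agentic.types import ObserverType
from vectara_agentic._observability import setup_observer

# Select Arize Phoenix observability via the new config object.
config = AgentConfig(observer=ObserverType.ARIZE_PHOENIX)

# Launches a local Phoenix app unless PHOENIX_ENDPOINT points elsewhere.
enabled = setup_observer(config)
print(f"observability enabled: {enabled}")
```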
vectara_agentic/_prompts.py
CHANGED
```diff
@@ -5,6 +5,7 @@ This file contains the prompt templates for the different types of agents.
 # General (shared) instructions
 GENERAL_INSTRUCTIONS = """
 - Use tools as your main source of information, do not respond without using a tool. Do not respond based on pre-trained knowledge.
+- Always call the 'get_current_date' tool to ensure you know the exact date when a user asks a question.
 - When using a tool with arguments, simplify the query as much as possible if you use the tool with arguments.
   For example, if the original query is "revenue for apple in 2021", you can use the tool with a query "revenue" with arguments year=2021 and company=apple.
 - If a tool responds with "I do not have enough information", try one of the following:
@@ -31,7 +32,8 @@ GENERAL_INSTRUCTIONS = """
 - Use the x_load_unique_values tool to understand the unique values in each column.
   Sometimes the user may ask for a specific column value, but the actual value in the table may be different, and you will need to use the correct value.
 - Use the x_load_sample_data tool to understand the column names, and typical values in each column.
-- For tool arguments that support conditional logic (such as year='>2022'), use
+- For tool arguments that support conditional logic (such as year='>2022'), use one of these operators: [">=", "<=", "!=", ">", "<", "="],
+  or a range operator, with inclusive or exclusive brackets (such as '[2021,2022]' or '[2021,2023)').
 - Do not mention table names or database names in your response.
 """
 
@@ -42,7 +44,7 @@ GENERAL_PROMPT_TEMPLATE = """
 You are a helpful chatbot in conversation with a user, with expertise in {chat_topic}.
 
 ## Date
-
+Your birth date is {today}.
 
 ## INSTRUCTIONS:
 IMPORTANT - FOLLOW THESE INSTRUCTIONS CAREFULLY:
@@ -62,7 +64,7 @@ You are designed to help with a variety of tasks, from answering questions to pr
 You have expertise in {chat_topic}.
 
 ## Date
-
+Your birth date is {today}.
 
 ## Tools
 You have access to a wide variety of tools.
```
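The expanded instruction enumerates the comparison operators and the bracketed range syntax that conditional tool arguments may use. Purely illustrative argument values (actual argument names depend on your corpus metadata):

```python
# Hypothetical filter arguments following the operator syntax in the prompt:
examples = [
    {"year": ">=2021"},                      # comparison operator
    {"year": "[2021,2022]"},                 # inclusive range
    {"year": "[2021,2023)"},                 # inclusive start, exclusive end
    {"company": "apple", "year": "!=2020"},  # combined with an equality filter
]
```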
vectara_agentic/agent.py
CHANGED
```diff
@@ -8,6 +8,7 @@ import time
 import json
 import logging
 import traceback
+import asyncio
 
 import dill
 from dotenv import load_dotenv
@@ -25,12 +26,14 @@ from llama_index.core.callbacks.base_handler import BaseCallbackHandler
 from llama_index.agent.openai import OpenAIAgent
 from llama_index.core.memory import ChatMemoryBuffer
 
-from .types import AgentType, AgentStatusType, LLMRole, ToolType
+from .types import AgentType, AgentStatusType, LLMRole, ToolType, AgentResponse, AgentStreamingResponse
 from .utils import get_llm, get_tokenizer_for_model
 from ._prompts import REACT_PROMPT_TEMPLATE, GENERAL_PROMPT_TEMPLATE, GENERAL_INSTRUCTIONS
 from ._callback import AgentCallbackHandler
 from ._observability import setup_observer, eval_fcs
-from .tools import VectaraToolFactory, VectaraTool
+from .tools import VectaraToolFactory, VectaraTool, ToolsFactory
+from .tools_catalog import get_current_date
+from .agent_config import AgentConfig
 
 logger = logging.getLogger("opentelemetry.exporter.otlp.proto.http.trace_exporter")
 logger.setLevel(logging.CRITICAL)
@@ -91,7 +94,8 @@ class Agent:
         verbose: bool = True,
         update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
         agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
-
+        query_logging_callback: Optional[Callable[[str, str], None]] = None,
+        agent_config: Optional[AgentConfig] = None,
     ) -> None:
         """
         Initialize the agent with the specified type, tools, topic, and system message.
@@ -104,14 +108,18 @@ class Agent:
            verbose (bool, optional): Whether the agent should print its steps. Defaults to True.
            agent_progress_callback (Callable): A callback function the code calls on any agent updates.
            update_func (Callable): old name for agent_progress_callback. Will be deprecated in future.
-
+            query_logging_callback (Callable): A callback function the code calls upon completion of a query
+            agent_config (AgentConfig, optional): The configuration of the agent.
+                Defaults to AgentConfig(), which reads from environment variables.
        """
-        self.
-        self.
-        self.
+        self.agent_config = agent_config or AgentConfig()
+        self.agent_type = self.agent_config.agent_type
+        self.tools = tools + [ToolsFactory().create_tool(get_current_date)]
+        self.llm = get_llm(LLMRole.MAIN, config=self.agent_config)
        self._custom_instructions = custom_instructions
        self._topic = topic
        self.agent_progress_callback = agent_progress_callback if agent_progress_callback else update_func
+        self.query_logging_callback = query_logging_callback
 
        main_tok = get_tokenizer_for_model(role=LLMRole.MAIN)
        self.main_token_counter = TokenCountingHandler(tokenizer=main_tok) if main_tok else None
@@ -131,7 +139,7 @@ class Agent:
         if self.agent_type == AgentType.REACT:
             prompt = _get_prompt(REACT_PROMPT_TEMPLATE, topic, custom_instructions)
             self.agent = ReActAgent.from_tools(
-                tools=tools,
+                tools=self.tools,
                 llm=self.llm,
                 memory=self.memory,
                 verbose=verbose,
@@ -142,7 +150,7 @@ class Agent:
         elif self.agent_type == AgentType.OPENAI:
             prompt = _get_prompt(GENERAL_PROMPT_TEMPLATE, topic, custom_instructions)
             self.agent = OpenAIAgent.from_tools(
-                tools=tools,
+                tools=self.tools,
                 llm=self.llm,
                 memory=self.memory,
                 verbose=verbose,
@@ -151,23 +159,24 @@ class Agent:
                 system_prompt=prompt,
             )
         elif self.agent_type == AgentType.LLMCOMPILER:
-
-                tools=tools,
+            agent_worker = LLMCompilerAgentWorker.from_tools(
+                tools=self.tools,
                 llm=self.llm,
                 verbose=verbose,
                 callable_manager=callback_manager,
-            )
-
-            _get_llm_compiler_prompt(
+            )
+            agent_worker.system_prompt = _get_prompt(
+                _get_llm_compiler_prompt(agent_worker.system_prompt, topic, custom_instructions),
                 topic, custom_instructions
             )
-
-            _get_llm_compiler_prompt(
+            agent_worker.system_prompt_replan = _get_prompt(
+                _get_llm_compiler_prompt(agent_worker.system_prompt_replan, topic, custom_instructions),
                 topic, custom_instructions
             )
+            self.agent = agent_worker.as_agent()
         elif self.agent_type == AgentType.LATS:
             agent_worker = LATSAgentWorker.from_tools(
-                tools=tools,
+                tools=self.tools,
                 llm=self.llm,
                 num_expansions=3,
                 max_rollouts=-1,
@@ -181,7 +190,7 @@ class Agent:
             raise ValueError(f"Unknown agent type: {self.agent_type}")
 
         try:
-            self.observability_enabled = setup_observer()
+            self.observability_enabled = setup_observer(self.agent_config)
         except Exception as e:
             print(f"Failed to set up observer ({e}), ignoring")
             self.observability_enabled = False
@@ -252,7 +261,8 @@ class Agent:
         verbose: bool = True,
         update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
         agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
-
+        query_logging_callback: Optional[Callable[[str, str], None]] = None,
+        agent_config: AgentConfig = AgentConfig(),
     ) -> "Agent":
         """
         Create an agent from tools, agent type, and language model.
@@ -265,7 +275,8 @@ class Agent:
             verbose (bool, optional): Whether the agent should print its steps. Defaults to True.
             agent_progress_callback (Callable): A callback function the code calls on any agent updates.
             update_func (Callable): old name for agent_progress_callback. Will be deprecated in future.
-
+            query_logging_callback (Callable): A callback function the code calls upon completion of a query
+            agent_config (AgentConfig, optional): The configuration of the agent.
 
         Returns:
             Agent: An instance of the Agent class.
@@ -273,7 +284,8 @@ class Agent:
         return cls(
             tools=tools, topic=topic, custom_instructions=custom_instructions,
             verbose=verbose, agent_progress_callback=agent_progress_callback,
-
+            query_logging_callback=query_logging_callback,
+            update_func=update_func, agent_config=agent_config
         )
 
     @classmethod
@@ -286,6 +298,7 @@ class Agent:
         vectara_corpus_id: str = str(os.environ.get("VECTARA_CORPUS_ID", "")),
         vectara_api_key: str = str(os.environ.get("VECTARA_API_KEY", "")),
         agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
+        query_logging_callback: Optional[Callable[[str, str], None]] = None,
         verbose: bool = False,
         vectara_filter_fields: list[dict] = [],
         vectara_lambda_val: float = 0.005,
@@ -305,6 +318,7 @@ class Agent:
             vectara_corpus_id (str): The Vectara corpus ID (or comma separated list of IDs).
             vectara_api_key (str): The Vectara API key.
             agent_progress_callback (Callable): A callback function the code calls on any agent updates.
+            query_logging_callback (Callable): A callback function the code calls upon completion of a query
             data_description (str): The description of the data.
             assistant_specialty (str): The specialty of the assistant.
             verbose (bool, optional): Whether to print verbose output.
@@ -364,6 +378,7 @@ class Agent:
             custom_instructions=assistant_instructions,
             verbose=verbose,
             agent_progress_callback=agent_progress_callback,
+            query_logging_callback=query_logging_callback,
         )
 
     def report(self) -> None:
@@ -378,7 +393,10 @@ class Agent:
         print(f"Topic = {self._topic}")
         print("Tools:")
         for tool in self.tools:
-
+            if hasattr(tool, 'metadata'):
+                print(f"- {tool.metadata.name}")
+            else:
+                print("- tool without metadata")
         print(f"Agent LLM = {get_llm(LLMRole.MAIN).metadata.model_name}")
         print(f"Tool LLM = {get_llm(LLMRole.TOOL).metadata.model_name}")
 
@@ -394,12 +412,32 @@ class Agent:
             "tool token count": self.tool_token_counter.total_llm_token_count if self.tool_token_counter else -1,
         }
 
+    async def _aformat_for_lats(self, prompt, agent_response):
+        llm_prompt = f"""
+        Given the question '{prompt}', and agent response '{agent_response.response}',
+        Please provide a well formatted final response to the query.
+        final response:
+        """
+        agent_response.response = str(self.llm.acomplete(llm_prompt))
+
+    def chat(self, prompt: str) -> AgentResponse:    # type: ignore
+        """
+        Interact with the agent using a chat prompt.
+
+        Args:
+            prompt (str): The chat prompt.
+
+        Returns:
+            AgentResponse: The response from the agent.
+        """
+        return asyncio.run(self.achat(prompt))
+
     @retry(
         retry_on_exception=_retry_if_exception,
         stop_max_attempt_number=3,
         wait_fixed=2000,
     )
-    def
+    async def achat(self, prompt: str) -> AgentResponse:    # type: ignore
         """
         Interact with the agent using a chat prompt.
 
@@ -407,32 +445,79 @@ class Agent:
             prompt (str): The chat prompt.
 
         Returns:
-
+            AgentResponse: The response from the agent.
         """
 
         try:
             st = time.time()
-            agent_response = self.agent.
+            agent_response = await self.agent.achat(prompt)
             if self.agent_type == AgentType.LATS:
-                prompt
-                Given the question '{prompt}', and agent response '{agent_response.response}',
-                Please provide a well formatted final response to the query.
-                final response:
-                """
-                final_response = str(self.llm.complete(prompt))
-            else:
-                final_response = agent_response.response
-
+                await self._aformat_for_lats(prompt, agent_response)
             if self.verbose:
                 print(f"Time taken: {time.time() - st}")
             if self.observability_enabled:
                 eval_fcs()
-
+            if self.query_logging_callback:
+                self.query_logging_callback(prompt, agent_response.response)
+            return agent_response
         except Exception as e:
-            return
+            return AgentResponse(
+                response = (
+                    f"Vectara Agentic: encountered an exception ({e}) at ({traceback.format_exc()})"
+                    ", and can't respond."
+                )
+            )
 
-    #
+    def stream_chat(self, prompt: str) -> AgentStreamingResponse:    # type: ignore
+        """
+        Interact with the agent using a chat prompt with streaming.
+        Args:
+            prompt (str): The chat prompt.
+        Returns:
+            AgentStreamingResponse: The streaming response from the agent.
+        """
+        return asyncio.run(self.astream_chat(prompt))
+
+    @retry(
+        retry_on_exception=_retry_if_exception,
+        stop_max_attempt_number=3,
+        wait_fixed=2000,
+    )
+    async def astream_chat(self, prompt: str) -> AgentStreamingResponse:    # type: ignore
+        """
+        Interact with the agent using a chat prompt asynchronously with streaming.
+        Args:
+            prompt (str): The chat prompt.
+        Returns:
+            AgentStreamingResponse: The streaming response from the agent.
+        """
+        try:
+            agent_response = await self.agent.astream_chat(prompt)
+            original_async_response_gen = agent_response.async_response_gen
+
+            # Wrap async_response_gen
+            async def _stream_response_wrapper():
+                async for token in original_async_response_gen():
+                    yield token  # Yield async token to keep streaming behavior
+
+                # After streaming completes, execute additional logic
+                if self.agent_type == AgentType.LATS:
+                    await self._aformat_for_lats(prompt, agent_response)
+                if self.query_logging_callback:
+                    self.query_logging_callback(prompt, agent_response.response)
+                if self.observability_enabled:
+                    eval_fcs()
+
+            agent_response.async_response_gen = _stream_response_wrapper  # Override method
+            return agent_response
+        except Exception as e:
+            raise ValueError(
+                f"Vectara Agentic: encountered an exception ({e}) at ({traceback.format_exc()}), and can't respond."
+            ) from e
 
+    #
+    # Serialization methods
+    #
     def dumps(self) -> str:
         """Serialize the Agent instance to a JSON string."""
         return json.dumps(self.to_dict())
@@ -449,7 +534,7 @@ class Agent:
         for tool in self.tools:
             # Serialize each tool's metadata, function, and dynamic model schema (QueryArgs)
             tool_dict = {
-                "tool_type": tool.tool_type.value,
+                "tool_type": tool.metadata.tool_type.value,
                 "name": tool.metadata.name,
                 "description": tool.metadata.description,
                 "fn": dill.dumps(tool.fn).decode("latin-1") if tool.fn else None,  # Serialize fn
@@ -469,12 +554,13 @@ class Agent:
             "topic": self._topic,
             "custom_instructions": self._custom_instructions,
             "verbose": self.verbose,
+            "agent_config": self.agent_config.to_dict(),
         }
 
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "Agent":
         """Create an Agent instance from a dictionary."""
-
+        agent_config = AgentConfig.from_dict(data["agent_config"])
         tools = []
 
         json_type_to_python = {
@@ -523,7 +609,7 @@ class Agent:
 
         agent = cls(
             tools=tools,
-
+            agent_config=agent_config,
             topic=data["topic"],
             custom_instructions=data["custom_instructions"],
             verbose=data["verbose"],
```
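Taken together, the agent.py changes make `chat` a thin synchronous wrapper over the new async `achat`, add streaming counterparts (`stream_chat`/`astream_chat`), and invoke the optional `query_logging_callback` once a query completes. A hedged usage sketch based on the signatures in this diff (tool setup and API keys are assumed):

```python
from vectara_agentic.agent import Agent
from vectara_agentic.agent_config import AgentConfig

def log_query(prompt: str, response: str) -> None:
    # Invoked after each completed query (see achat/astream_chat above).
    print(f"Q: {prompt!r} -> A: {response!r}")

agent = Agent(
    tools=[],  # your tools; a get_current_date tool is now appended automatically
    topic="finance",
    custom_instructions="Answer briefly.",
    query_logging_callback=log_query,
    agent_config=AgentConfig(),
)

# chat() now returns an AgentResponse (it runs the async achat internally).
result = agent.chat("What is 5 times 10?")
print(result.response)
```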
vectara_agentic/agent_config.py
ADDED

```diff
@@ -0,0 +1,86 @@
+"""
+Define the AgentConfig dataclass for the Vectara Agentic utilities.
+"""
+
+import os
+from dataclasses import dataclass, field
+from .types import ModelProvider, AgentType, ObserverType
+
+@dataclass(eq=True, frozen=True)
+class AgentConfig:
+    """
+    Centralized configuration for the Vectara Agentic utilities.
+
+    Each field can default to either a hard-coded value or an environment
+    variable. For example, if you have environment variables you want to
+    fall back on, you can default to them here.
+    """
+
+    # Agent type
+    agent_type: AgentType = field(
+        default_factory=lambda: AgentType(
+            os.getenv("VECTARA_AGENTIC_AGENT_TYPE", AgentType.OPENAI.value)
+        )
+    )
+
+    # Main LLM provider & model name
+    main_llm_provider: ModelProvider = field(
+        default_factory=lambda: ModelProvider(
+            os.getenv("VECTARA_AGENTIC_MAIN_LLM_PROVIDER", ModelProvider.OPENAI.value)
+        )
+    )
+
+    main_llm_model_name: str = field(
+        default_factory=lambda: os.getenv("VECTARA_AGENTIC_MAIN_MODEL_NAME", "")
+    )
+
+    # Tool LLM provider & model name
+    tool_llm_provider: ModelProvider = field(
+        default_factory=lambda: ModelProvider(
+            os.getenv("VECTARA_AGENTIC_TOOL_LLM_PROVIDER", ModelProvider.OPENAI.value)
+        )
+    )
+    tool_llm_model_name: str = field(
+        default_factory=lambda: os.getenv("VECTARA_AGENTIC_TOOL_MODEL_NAME", "")
+    )
+
+    # Observer
+    observer: ObserverType = field(
+        default_factory=lambda: ObserverType(
+            os.getenv("VECTARA_AGENTIC_OBSERVER_TYPE", "NO_OBSERVER")
+        )
+    )
+
+    # Endpoint API key
+    endpoint_api_key: str = field(
+        default_factory=lambda: os.getenv("VECTARA_AGENTIC_API_KEY", "dev-api-key")
+    )
+
+    def to_dict(self) -> dict:
+        """
+        Convert the AgentConfig to a dictionary.
+        """
+        return {
+            "agent_type": self.agent_type.value,
+            "main_llm_provider": self.main_llm_provider.value,
+            "main_llm_model_name": self.main_llm_model_name,
+            "tool_llm_provider": self.tool_llm_provider.value,
+            "tool_llm_model_name": self.tool_llm_model_name,
+            "observer": self.observer.value,
+            "endpoint_api_key": self.endpoint_api_key
+        }
+
+    @classmethod
+    def from_dict(cls, config_dict: dict) -> "AgentConfig":
+        """
+        Create an AgentConfig from a dictionary.
+        """
+        return cls(
+            agent_type=AgentType(config_dict["agent_type"]),
+            main_llm_provider=ModelProvider(config_dict["main_llm_provider"]),
+            main_llm_model_name=config_dict["main_llm_model_name"],
+            tool_llm_provider=ModelProvider(config_dict["tool_llm_provider"]),
+            tool_llm_model_name=config_dict["tool_llm_model_name"],
+            observer=ObserverType(config_dict["observer"]),
+            endpoint_api_key=config_dict["endpoint_api_key"]
+        )
```
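`AgentConfig` is a frozen dataclass whose defaults fall back to `VECTARA_AGENTIC_*` environment variables, and whose dict round-trip backs the agent serialization shown above. A short sketch of both construction styles:

```python
from vectara_agentic.agent_config import AgentConfig
from vectara_agentic.types import AgentType, ModelProvider

# 1) Pure defaults: each field reads its VECTARA_AGENTIC_* env var if set.
env_config = AgentConfig()

# 2) Explicit values override the environment.
config = AgentConfig(
    agent_type=AgentType.REACT,
    main_llm_provider=ModelProvider.ANTHROPIC,
    main_llm_model_name="claude-3-5-sonnet-20241022",
)

# Round-trip used by Agent.to_dict()/from_dict().
restored = AgentConfig.from_dict(config.to_dict())
assert restored == config  # eq=True, frozen=True makes configs comparable
```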
vectara_agentic/agent_endpoint.py
CHANGED

```diff
@@ -1,7 +1,6 @@
 """
 This module contains functions to start the agent behind an API endpoint.
 """
-import os
 import logging
 from fastapi import FastAPI, HTTPException, Depends
 from fastapi.security.api_key import APIKeyHeader
@@ -9,10 +8,9 @@ from pydantic import BaseModel
 import uvicorn
 
 from .agent import Agent
+from .agent_config import AgentConfig
 
-
-API_KEY = os.getenv("VECTARA_AGENTIC_API_KEY", "dev-api-key")
-api_key_header = APIKeyHeader(name=API_KEY_NAME)
+api_key_header = APIKeyHeader(name="X-API-Key")
 
 class ChatRequest(BaseModel):
     """
@@ -21,18 +19,19 @@ class ChatRequest(BaseModel):
     message: str
 
 
-def create_app(agent: Agent) -> FastAPI:
+def create_app(agent: Agent, config: AgentConfig) -> FastAPI:
     """
     Create a FastAPI application with a chat endpoint.
     """
     app = FastAPI()
     logger = logging.getLogger("uvicorn.error")
     logging.basicConfig(level=logging.INFO)
+    endpoint_api_key = config.endpoint_api_key
 
     @app.get("/chat", summary="Chat with the agent")
     async def chat(message: str, api_key: str = Depends(api_key_header)):
         logger.info(f"Received message: {message}")
-        if api_key !=
+        if api_key != endpoint_api_key:
             logger.warning("Unauthorized access attempt")
             raise HTTPException(status_code=403, detail="Unauthorized")
 
@@ -59,5 +58,5 @@ def start_app(agent: Agent, host='0.0.0.0', port=8000):
         host (str, optional): The host address for the API. Defaults to '127.0.0.1'.
         port (int, optional): The port for the API. Defaults to 8000.
     """
-    app = create_app(agent)
+    app = create_app(agent, config=AgentConfig())
     uvicorn.run(app, host=host, port=port)
```