vectara-agentic 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

vectara_agentic/__init__.py CHANGED
@@ -3,7 +3,7 @@ vectara_agentic package.
 """
 
 # Define the package version
-__version__ = "0.1.15"
+__version__ = "0.1.17"
 
 # Import classes and functions from modules
 # from .module1 import Class1, function1
@@ -12,6 +12,7 @@ __version__ = "0.1.15"
 
 # Any initialization code
 def initialize_package():
+    """print a message when the package is initialized."""
    print(f"Initializing vectara-agentic version {__version__}...")
 
 
vectara_agentic/_callback.py CHANGED
@@ -66,7 +66,6 @@ class AgentCallbackHandler(BaseCallbackHandler):
 
     def _handle_agent_step(self, payload: dict) -> None:
         """Calls self.fn() with the information about agent step."""
-        print(f"Handling agent step: {payload}")
         if EventPayload.MESSAGES in payload:
             msg = str(payload.get(EventPayload.MESSAGES))
             if self.fn:
vectara_agentic/_observability.py CHANGED
@@ -1,19 +1,21 @@
+"""
+Observability for Vectara Agentic.
+"""
 import os
 import json
+from typing import Optional, Union
 import pandas as pd
-
 from .types import ObserverType
 
 def setup_observer() -> bool:
     '''
     Setup the observer.
     '''
+    import phoenix as px
+    from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
+    from phoenix.otel import register
     observer = ObserverType(os.getenv("VECTARA_AGENTIC_OBSERVER_TYPE", "NO_OBSERVER"))
     if observer == ObserverType.ARIZE_PHOENIX:
-        import phoenix as px
-        from phoenix.otel import register
-        from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
-
         phoenix_endpoint = os.getenv("PHOENIX_ENDPOINT", None)
         if not phoenix_endpoint:
             px.launch_app()
@@ -21,7 +23,7 @@ def setup_observer() -> bool:
         elif 'app.phoenix.arize.com' in phoenix_endpoint:  # hosted on Arize
             phoenix_api_key = os.getenv("PHOENIX_API_KEY", None)
             if not phoenix_api_key:
-                raise Exception("Arize Phoenix API key not set. Please set PHOENIX_API_KEY environment variable.")
+                raise ValueError("Arize Phoenix API key not set. Please set PHOENIX_API_KEY environment variable.")
             os.environ["PHOENIX_CLIENT_HEADERS"] = f"api_key={phoenix_api_key}"
             os.environ["PHOENIX_COLLECTOR_ENDPOINT"] = "https://app.phoenix.arize.com"
         tracer_provider = register(endpoint=phoenix_endpoint, project_name="vectara-agentic")
@@ -29,12 +31,11 @@ def setup_observer() -> bool:
         tracer_provider = register(endpoint=phoenix_endpoint, project_name="vectara-agentic")
         LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
         return True
-    else:
-        print("No observer set.")
-        return False
+    print("No observer set.")
+    return False
 
 
-def _extract_fcs_value(output):
+def _extract_fcs_value(output: Union[str, dict]) -> Optional[float]:
     '''
     Extract the FCS value from the output.
     '''
@@ -49,7 +50,7 @@ def _extract_fcs_value(output):
     return None
 
 
-def _find_top_level_parent_id(row, all_spans):
+def _find_top_level_parent_id(row: pd.Series, all_spans: pd.DataFrame) -> Optional[str]:
     '''
     Find the top level parent id for the given span.
     '''
@@ -67,14 +68,13 @@ def _find_top_level_parent_id(row, all_spans):
     return current_id
 
 
-def eval_fcs():
+def eval_fcs() -> None:
     '''
     Evaluate the FCS score for the VectaraQueryEngine._query span.
     '''
+    import phoenix as px
     from phoenix.trace.dsl import SpanQuery
     from phoenix.trace import SpanEvaluations
-    import phoenix as px
-
     query = SpanQuery().select(
         "output.value",
         "parent_id",
@@ -83,8 +83,10 @@ def eval_fcs():
     client = px.Client()
     all_spans = client.query_spans(query, project_name="vectara-agentic")
     vectara_spans = all_spans[all_spans['name'] == 'VectaraQueryEngine._query'].copy()
-    vectara_spans['top_level_parent_id'] = vectara_spans.apply(lambda row: _find_top_level_parent_id(row, all_spans), axis=1)
-    vectara_spans['score'] = vectara_spans['output.value'].apply(lambda x: _extract_fcs_value(x))
+    vectara_spans['top_level_parent_id'] = vectara_spans.apply(
+        lambda row: _find_top_level_parent_id(row, all_spans), axis=1
+    )
+    vectara_spans['score'] = vectara_spans['output.value'].apply(_extract_fcs_value)
 
     vectara_spans.reset_index(inplace=True)
     top_level_spans = vectara_spans.copy()
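
Since setup_observer is driven entirely by environment variables, enabling tracing requires no code changes. A minimal sketch, assuming the hosted Arize Phoenix endpoint (the variable names come from the code above; the key value is a placeholder):

```
import os

# Select the observer; anything other than ARIZE_PHOENIX falls through
# to the "No observer set." branch and setup_observer returns False.
os.environ["VECTARA_AGENTIC_OBSERVER_TYPE"] = "ARIZE_PHOENIX"

# If PHOENIX_ENDPOINT is unset, px.launch_app() starts a local instance.
# A hosted app.phoenix.arize.com endpoint also requires PHOENIX_API_KEY,
# otherwise setup_observer now raises ValueError instead of a bare Exception.
os.environ["PHOENIX_ENDPOINT"] = "https://app.phoenix.arize.com"
os.environ["PHOENIX_API_KEY"] = "<phoenix-api-key>"  # placeholder

from vectara_agentic._observability import setup_observer
assert setup_observer()
```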
vectara_agentic/_prompts.py CHANGED
@@ -7,10 +7,10 @@ GENERAL_INSTRUCTIONS = """
 - Use tools as your main source of information, do not respond without using a tool. Do not respond based on pre-trained knowledge.
 - When using a tool with arguments, simplify the query as much as possible if you use the tool with arguments.
   For example, if the original query is "revenue for apple in 2021", you can use the tool with a query "revenue" with arguments year=2021 and company=apple.
-- If you can't answer the question with the information provided by the tools, try to rephrase the question and call a tool again,
+- If you can't answer the question with the information provided by a tool, try to rephrase the question and call the tool again,
   or break the question into sub-questions and call a tool for each sub-question, then combine the answers to provide a complete response.
   For example if asked "what is the population of France and Germany", you can call the tool twice, once for each country.
-- If a query tool provides citations or references in markdown as part of its response, include the citations in your response.
+- If a query tool provides citations or references in markdown as part of its response, include the references in your response.
 - When providing links in your response, where possible put the name of the website or source of information for the displayed text. Don't just say 'source'.
 - If after retrying you can't get the information or answer the question, respond with "I don't know".
 - Your response should never be the input to a tool, only the output.
@@ -21,6 +21,14 @@ GENERAL_INSTRUCTIONS = """
 - If including latex equations in the markdown response, make sure the equations are on a separate line and enclosed in double dollar signs.
 - Always respond in the language of the question, and in text (no images, videos or code).
 - Always call the "get_bad_topics" tool to determine the topics you are not allowed to discuss or respond to.
+- If you are provided with database tools, use them for analytical queries (such as counting, calculating max, min, average, sum, or other statistics).
+  For each database, the database tools include: x_list_tables, x_load_data, x_describe_tables, and x_load_sample_data, where 'x' is the database name.
+  The x_list_tables tool provides a list of available tables in the x database.
+  Always use the x_describe_tables tool to understand the schema of each table, before you access data from that table.
+  Always use the x_load_sample_data tool to understand the column names, and the unique values in each column, so you can use them in your queries.
+  Sometimes the user may ask for a specific column value, but the actual value in the table may be different, and you will need to use the correct value.
+- Never call x_load_data to retrieve values from each row in the table.
+- Do not mention table names or database names in your response.
 """
 
 #
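
To make the 'x' prefix in the new database instructions concrete: for a database registered under a hypothetical name such as "sales", the per-database tool names would be derived like this:

```
db_name = "sales"  # hypothetical database name
tool_names = [f"{db_name}_{suffix}" for suffix in
              ("list_tables", "describe_tables", "load_sample_data", "load_data")]
print(tool_names)
# ['sales_list_tables', 'sales_describe_tables',
#  'sales_load_sample_data', 'sales_load_data']
```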
@@ -65,10 +73,7 @@ IMPORTANT - FOLLOW THESE INSTRUCTIONS CAREFULLY:
 {INSTRUCTIONS}
 {custom_instructions}
 
-## Input
-The user will specify a task or a question in text.
-
-### Output Format
+## Output Format
 
 Please answer in the same language as the question and use the following format:
 
@@ -95,12 +100,12 @@ At that point, you MUST respond in the one of the following two formats (and do
 
 ```
 Thought: I can answer without using any more tools. I'll use the user's language to answer
-Answer: [your answer here (In the same language as the user's question, and maintain any references/citations)]
+Answer: [your answer here (In the same language as the user's question, and maintain any references)]
 ```
 
 ```
 Thought: I cannot answer the question with the provided tools.
-Answer: [your answer here (In the same language as the user's question, and maintain any references/citations)]
+Answer: [your answer here (In the same language as the user's question)]
 ```
 
 ## Current Conversation
vectara_agentic/agent.py CHANGED
@@ -6,6 +6,9 @@ import os
 from datetime import date
 import time
 import json
+import logging
+import traceback
+
 import dill
 from dotenv import load_dotenv
 
@@ -21,12 +24,6 @@ from llama_index.core.callbacks.base_handler import BaseCallbackHandler
 from llama_index.agent.openai import OpenAIAgent
 from llama_index.core.memory import ChatMemoryBuffer
 
-import logging
-logger = logging.getLogger('opentelemetry.exporter.otlp.proto.http.trace_exporter')
-logger.setLevel(logging.CRITICAL)
-
-load_dotenv(override=True)
-
 from .types import AgentType, AgentStatusType, LLMRole, ToolType
 from .utils import get_llm, get_tokenizer_for_model
 from ._prompts import REACT_PROMPT_TEMPLATE, GENERAL_PROMPT_TEMPLATE
@@ -34,6 +31,10 @@ from ._callback import AgentCallbackHandler
 from ._observability import setup_observer, eval_fcs
 from .tools import VectaraToolFactory, VectaraTool
 
+logger = logging.getLogger("opentelemetry.exporter.otlp.proto.http.trace_exporter")
+logger.setLevel(logging.CRITICAL)
+
+load_dotenv(override=True)
 
 def _get_prompt(prompt_template: str, topic: str, custom_instructions: str):
     """
@@ -57,15 +58,14 @@ def _get_prompt(prompt_template: str, topic: str, custom_instructions: str):
 
 def _retry_if_exception(exception):
     # Define the condition to retry on certain exceptions
-    return isinstance(
-        exception, (TimeoutError)
-    )
+    return isinstance(exception, (TimeoutError))
 
 
 class Agent:
     """
     Agent class for handling different types of agents and their interactions.
     """
+
     def __init__(
         self,
         tools: list[FunctionTool],
@@ -73,6 +73,7 @@ class Agent:
         custom_instructions: str = "",
         verbose: bool = True,
         update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
+        agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
         agent_type: AgentType = None,
     ) -> None:
         """
@@ -84,35 +85,38 @@ class Agent:
             topic (str, optional): The topic for the agent. Defaults to 'general'.
             custom_instructions (str, optional): Custom instructions for the agent. Defaults to ''.
             verbose (bool, optional): Whether the agent should print its steps. Defaults to True.
-            update_func (Callable): A callback function the code calls on any agent updates.
+            agent_progress_callback (Callable): A callback function the code calls on any agent updates.
+            update_func (Callable): old name for agent_progress_callback. Will be deprecated in future.
+            agent_type (AgentType, optional): The type of agent to be used. Defaults to None.
         """
         self.agent_type = agent_type or AgentType(os.getenv("VECTARA_AGENTIC_AGENT_TYPE", "OPENAI"))
         self.tools = tools
         self.llm = get_llm(LLMRole.MAIN)
         self._custom_instructions = custom_instructions
         self._topic = topic
+        self.agent_progress_callback = agent_progress_callback if agent_progress_callback else update_func
 
         main_tok = get_tokenizer_for_model(role=LLMRole.MAIN)
         self.main_token_counter = TokenCountingHandler(tokenizer=main_tok) if main_tok else None
         tool_tok = get_tokenizer_for_model(role=LLMRole.TOOL)
         self.tool_token_counter = TokenCountingHandler(tokenizer=tool_tok) if tool_tok else None
 
-        callbacks: list[BaseCallbackHandler] = [AgentCallbackHandler(update_func)]
+        callbacks: list[BaseCallbackHandler] = [AgentCallbackHandler(self.agent_progress_callback)]
         if self.main_token_counter:
             callbacks.append(self.main_token_counter)
         if self.tool_token_counter:
             callbacks.append(self.tool_token_counter)
-        callback_manager = CallbackManager(callbacks) # type: ignore
+        callback_manager = CallbackManager(callbacks)  # type: ignore
         self.llm.callback_manager = callback_manager
         self.verbose = verbose
 
-        memory = ChatMemoryBuffer.from_defaults(token_limit=128000)
+        self.memory = ChatMemoryBuffer.from_defaults(token_limit=128000)
         if self.agent_type == AgentType.REACT:
             prompt = _get_prompt(REACT_PROMPT_TEMPLATE, topic, custom_instructions)
             self.agent = ReActAgent.from_tools(
                 tools=tools,
                 llm=self.llm,
-                memory=memory,
+                memory=self.memory,
                 verbose=verbose,
                 react_chat_formatter=ReActChatFormatter(system_header=prompt),
                 max_iterations=30,
@@ -123,7 +127,7 @@ class Agent:
             self.agent = OpenAIAgent.from_tools(
                 tools=tools,
                 llm=self.llm,
-                memory=memory,
+                memory=self.memory,
                 verbose=verbose,
                 callable_manager=callback_manager,
                 max_function_calls=20,
@@ -134,7 +138,7 @@ class Agent:
                 tools=tools,
                 llm=self.llm,
                 verbose=verbose,
-                callable_manager=callback_manager
+                callable_manager=callback_manager,
             ).as_agent()
         else:
             raise ValueError(f"Unknown agent type: {self.agent_type}")
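The new agent_progress_callback parameter supersedes update_func, which is kept for backward compatibility and routed to the same AgentCallbackHandler. A construction sketch, where my_tools stands in for a real list of FunctionTool objects:

```
from vectara_agentic.agent import Agent
from vectara_agentic.types import AgentStatusType

def on_progress(status: AgentStatusType, msg: str) -> None:
    # Invoked by AgentCallbackHandler on each agent update.
    print(f"[{status}] {msg}")

agent = Agent(
    tools=my_tools,                       # placeholder list of FunctionTool
    topic="finance",
    custom_instructions="Be concise.",
    agent_progress_callback=on_progress,  # preferred over update_func
)
```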
@@ -145,14 +149,26 @@ class Agent:
             print(f"Failed to set up observer ({e}), ignoring")
             self.observability_enabled = False
 
+    def clear_memory(self) -> None:
+        """
+        Clear the agent's memory.
+        """
+        self.agent.memory.reset()
+
     def __eq__(self, other):
+        """
+        Compare two Agent instances for equality.
+        """
         if not isinstance(other, Agent):
             print(f"Comparison failed: other is not an instance of Agent. (self: {type(self)}, other: {type(other)})")
             return False
 
         # Compare agent_type
         if self.agent_type != other.agent_type:
-            print(f"Comparison failed: agent_type differs. (self.agent_type: {self.agent_type}, other.agent_type: {other.agent_type})")
+            print(
+                f"Comparison failed: agent_type differs. (self.agent_type: {self.agent_type}, "
+                f"other.agent_type: {other.agent_type})"
+            )
             return False
 
         # Compare tools
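A quick usage sketch for the new clear_memory helper; it resets the ChatMemoryBuffer created in __init__, so unrelated conversations don't share history:

```
print(agent.chat("What was Apple's revenue in 2021?"))

# Start a fresh session: drop all accumulated chat history.
agent.clear_memory()
```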
@@ -167,7 +183,10 @@ class Agent:
 
         # Compare custom_instructions
         if self._custom_instructions != other._custom_instructions:
-            print(f"Comparison failed: custom_instructions differ. (self.custom_instructions: {self._custom_instructions}, other.custom_instructions: {other._custom_instructions})")
+            print(
+                "Comparison failed: custom_instructions differ. (self.custom_instructions: "
+                f"{self._custom_instructions}, other.custom_instructions: {other._custom_instructions})"
+            )
             return False
 
         # Compare verbose
@@ -177,7 +196,10 @@ class Agent:
 
         # Compare agent
         if self.agent.memory.chat_store != other.agent.memory.chat_store:
-            print(f"Comparison failed: agent memory differs. (self.agent: {repr(self.agent.memory.chat_store)}, other.agent: {repr(other.agent.memory.chat_store)})")
+            print(
+                f"Comparison failed: agent memory differs. (self.agent: {repr(self.agent.memory.chat_store)}, "
+                f"other.agent: {repr(other.agent.memory.chat_store)})"
+            )
             return False
 
         # If all comparisons pass
@@ -192,6 +214,7 @@ class Agent:
         custom_instructions: str = "",
         verbose: bool = True,
         update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
+        agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
         agent_type: AgentType = None,
     ) -> "Agent":
         """
@@ -203,13 +226,18 @@ class Agent:
             topic (str, optional): The topic for the agent. Defaults to 'general'.
             custom_instructions (str, optional): custom instructions for the agent. Defaults to ''.
             verbose (bool, optional): Whether the agent should print its steps. Defaults to True.
-            update_func (Callable): A callback function the code calls on any agent updates.
-
+            agent_progress_callback (Callable): A callback function the code calls on any agent updates.
+            update_func (Callable): old name for agent_progress_callback. Will be deprecated in future.
+            agent_type (AgentType, optional): The type of agent to be used. Defaults to None.
 
         Returns:
             Agent: An instance of the Agent class.
         """
-        return cls(tools, topic, custom_instructions, verbose, update_func, agent_type)
+        return cls(
+            tools=tools, topic=topic, custom_instructions=custom_instructions,
+            verbose=verbose, agent_progress_callback=agent_progress_callback,
+            update_func=update_func, agent_type=agent_type
+        )
 
     @classmethod
     def from_corpus(
@@ -220,6 +248,7 @@ class Agent:
         vectara_customer_id: str = str(os.environ.get("VECTARA_CUSTOMER_ID", "")),
         vectara_corpus_id: str = str(os.environ.get("VECTARA_CORPUS_ID", "")),
         vectara_api_key: str = str(os.environ.get("VECTARA_API_KEY", "")),
+        agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
         verbose: bool = False,
         vectara_filter_fields: list[dict] = [],
         vectara_lambda_val: float = 0.005,
@@ -238,10 +267,12 @@ class Agent:
             vectara_customer_id (str): The Vectara customer ID.
             vectara_corpus_id (str): The Vectara corpus ID (or comma separated list of IDs).
             vectara_api_key (str): The Vectara API key.
+            agent_progress_callback (Callable): A callback function the code calls on any agent updates.
             data_description (str): The description of the data.
             assistant_specialty (str): The specialty of the assistant.
             verbose (bool, optional): Whether to print verbose output.
-            vectara_filter_fields (List[dict], optional): The filterable attributes (each dict maps field name to Tuple[type, description]).
+            vectara_filter_fields (List[dict], optional): The filterable attributes
+                (each dict maps field name to Tuple[type, description]).
             vectara_lambda_val (float, optional): The lambda value for Vectara hybrid search.
             vectara_reranker (str, optional): The Vectara reranker name (default "mmr")
             vectara_rerank_k (int, optional): The number of results to use with reranking.
@@ -253,18 +284,19 @@ class Agent:
         Returns:
             Agent: An instance of the Agent class.
         """
-        vec_factory = VectaraToolFactory(vectara_api_key=vectara_api_key,
-                                         vectara_customer_id=vectara_customer_id,
-                                         vectara_corpus_id=vectara_corpus_id)
+        vec_factory = VectaraToolFactory(
+            vectara_api_key=vectara_api_key,
+            vectara_customer_id=vectara_customer_id,
+            vectara_corpus_id=vectara_corpus_id,
+        )
         field_definitions = {}
-        field_definitions['query'] = (str, Field(description="The user query"))  # type: ignore
+        field_definitions["query"] = (str, Field(description="The user query"))  # type: ignore
         for field in vectara_filter_fields:
-            field_definitions[field['name']] = (eval(field['type']),
-                                                Field(description=field['description']))  # type: ignore
-        QueryArgs = create_model(  # type: ignore
-            "QueryArgs",
-            **field_definitions
-        )
+            field_definitions[field["name"]] = (
+                eval(field["type"]),
+                Field(description=field["description"]),
+            )  # type: ignore
+        query_args = create_model("QueryArgs", **field_definitions)  # type: ignore
 
         vectara_tool = vec_factory.create_rag_tool(
             tool_name=tool_name or f"vectara_{vectara_corpus_id}",
@@ -272,8 +304,9 @@ class Agent:
             Given a user query,
             returns a response (str) to a user question about {data_description}.
             """,
-            tool_args_schema=QueryArgs,
-            reranker=vectara_reranker, rerank_k=vectara_rerank_k,
+            tool_args_schema=query_args,
+            reranker=vectara_reranker,
+            rerank_k=vectara_rerank_k,
             n_sentences_before=vectara_n_sentences_before,
             n_sentences_after=vectara_n_sentences_after,
             lambda_val=vectara_lambda_val,
@@ -293,7 +326,7 @@ class Agent:
             topic=assistant_specialty,
             custom_instructions=assistant_instructions,
             verbose=verbose,
-            update_func=None
+            agent_progress_callback=agent_progress_callback,
         )
 
     def report(self) -> None:
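Given the updated from_corpus signature and the way vectara_filter_fields entries are consumed (field['name'], field['type'] passed through eval(), field['description']), a call might look like this sketch; the credentials come from the VECTARA_* environment variables by default, and the field values here are illustrative:

```
agent = Agent.from_corpus(
    tool_name="ask_finance",                      # illustrative name
    data_description="quarterly financial reports",
    assistant_specialty="finance",
    agent_progress_callback=on_progress,          # new in this release
    vectara_filter_fields=[
        # 'type' is a Python type name as a string; it is eval()'d when
        # building the QueryArgs pydantic model.
        {"name": "year", "type": "int", "description": "Fiscal year"},
        {"name": "company", "type": "str", "description": "Company name"},
    ],
)
```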
@@ -308,7 +341,7 @@ class Agent:
         print(f"Topic = {self._topic}")
         print("Tools:")
         for tool in self.tools:
-            print(f"- {tool._metadata.name}")
+            print(f"- {tool.metadata.name}")
         print(f"Agent LLM = {get_llm(LLMRole.MAIN).metadata.model_name}")
         print(f"Tool LLM = {get_llm(LLMRole.TOOL).metadata.model_name}")
 
@@ -349,7 +382,6 @@ class Agent:
             eval_fcs()
             return agent_response.response
         except Exception as e:
-            import traceback
             return f"Vectara Agentic: encountered an exception ({e}) at ({traceback.format_exc()}), and can't respond."
 
     # Serialization methods
@@ -371,17 +403,21 @@ class Agent:
         # Serialize each tool's metadata, function, and dynamic model schema (QueryArgs)
             tool_dict = {
                 "tool_type": tool.tool_type.value,
-                "name": tool._metadata.name,
-                "description": tool._metadata.description,
-                "fn": dill.dumps(tool.fn).decode('latin-1') if tool.fn else None,  # Serialize fn
-                "async_fn": dill.dumps(tool.async_fn).decode('latin-1') if tool.async_fn else None,  # Serialize async_fn
-                "fn_schema": tool._metadata.fn_schema.model_json_schema() if hasattr(tool._metadata, 'fn_schema') else None,  # Serialize schema if available
+                "name": tool.metadata.name,
+                "description": tool.metadata.description,
+                "fn": dill.dumps(tool.fn).decode("latin-1") if tool.fn else None,  # Serialize fn
+                "async_fn": dill.dumps(tool.async_fn).decode("latin-1")
+                if tool.async_fn
+                else None,  # Serialize async_fn
+                "fn_schema": tool.metadata.fn_schema.model_json_schema()
+                if hasattr(tool.metadata, "fn_schema")
+                else None,  # Serialize schema if available
             }
             tool_info.append(tool_dict)
 
         return {
             "agent_type": self.agent_type.value,
-            "memory": dill.dumps(self.agent.memory).decode('latin-1'),
+            "memory": dill.dumps(self.agent.memory).decode("latin-1"),
             "tools": tool_info,
             "topic": self._topic,
             "custom_instructions": self._custom_instructions,
@@ -394,13 +430,13 @@ class Agent:
         agent_type = AgentType(data["agent_type"])
         tools = []
 
-        JSON_TYPE_TO_PYTHON = {
-            "string": "str",
-            "integer": "int",
-            "boolean": "bool",
-            "array": "list",
-            "object": "dict",
-            "number": "float",
+        json_type_to_python = {
+            "string": str,
+            "integer": int,
+            "boolean": bool,
+            "array": list,
+            "object": dict,
+            "number": float,
         }
 
         for tool_data in data["tools"]:
@@ -408,29 +444,33 @@ class Agent:
             if tool_data.get("fn_schema"):
                 field_definitions = {}
                 for field, values in tool_data["fn_schema"]["properties"].items():
-                    if 'default' in values:
-                        field_definitions[field] = (eval(JSON_TYPE_TO_PYTHON.get(values['type'], values['type'])),
-                                                    Field(description=values['description'], default=values['default']))  # type: ignore
+                    if "default" in values:
+                        field_definitions[field] = (
+                            json_type_to_python.get(values["type"], values["type"]),
+                            Field(
+                                description=values["description"],
+                                default=values["default"],
+                            ),
+                        )  # type: ignore
                     else:
-                        field_definitions[field] = (eval(JSON_TYPE_TO_PYTHON.get(values['type'], values['type'])),
-                                                    Field(description=values['description']))  # type: ignore
-                query_args_model = create_model(  # type: ignore
-                    "QueryArgs",
-                    **field_definitions
-                )
+                        field_definitions[field] = (
+                            json_type_to_python.get(values["type"], values["type"]),
+                            Field(description=values["description"]),
+                        )  # type: ignore
+                query_args_model = create_model("QueryArgs", **field_definitions)  # type: ignore
             else:
                 query_args_model = create_model("QueryArgs")
 
-            fn = dill.loads(tool_data["fn"].encode('latin-1')) if tool_data["fn"] else None
-            async_fn = dill.loads(tool_data["async_fn"].encode('latin-1')) if tool_data["async_fn"] else None
+            fn = dill.loads(tool_data["fn"].encode("latin-1")) if tool_data["fn"] else None
+            async_fn = dill.loads(tool_data["async_fn"].encode("latin-1")) if tool_data["async_fn"] else None
 
             tool = VectaraTool.from_defaults(
-                tool_type=ToolType(tool_data["tool_type"]),
                 name=tool_data["name"],
                 description=tool_data["description"],
                 fn=fn,
                 async_fn=async_fn,
-                fn_schema=query_args_model  # Re-assign the recreated dynamic model
+                fn_schema=query_args_model,  # Re-assign the recreated dynamic model
+                tool_type=ToolType(tool_data["tool_type"]),
             )
             tools.append(tool)
@@ -441,7 +481,7 @@ class Agent:
             custom_instructions=data["custom_instructions"],
             verbose=data["verbose"],
         )
-        memory = dill.loads(data["memory"].encode('latin-1')) if data.get("memory") else None
+        memory = dill.loads(data["memory"].encode("latin-1")) if data.get("memory") else None
         if memory:
             agent.agent.memory = memory
         return agent
vectara_agentic/agent_endpoint.py ADDED
@@ -0,0 +1,63 @@
+"""
+This module contains functions to start the agent behind an API endpoint.
+"""
+import os
+import logging
+from fastapi import FastAPI, HTTPException, Depends
+from fastapi.security.api_key import APIKeyHeader
+from pydantic import BaseModel
+import uvicorn
+
+from .agent import Agent
+
+API_KEY_NAME = "X-API-Key"
+API_KEY = os.getenv("VECTARA_AGENTIC_API_KEY", "dev-api-key")
+api_key_header = APIKeyHeader(name=API_KEY_NAME)
+
+class ChatRequest(BaseModel):
+    """
+    A request model for the chat endpoint.
+    """
+    message: str
+
+
+def create_app(agent: Agent) -> FastAPI:
+    """
+    Create a FastAPI application with a chat endpoint.
+    """
+    app = FastAPI()
+    logger = logging.getLogger("uvicorn.error")
+    logging.basicConfig(level=logging.INFO)
+
+    @app.get("/chat", summary="Chat with the agent")
+    async def chat(message: str, api_key: str = Depends(api_key_header)):
+        logger.info(f"Received message: {message}")
+        if api_key != API_KEY:
+            logger.warning("Unauthorized access attempt")
+            raise HTTPException(status_code=403, detail="Unauthorized")
+
+        if not message:
+            logger.error("No message provided in the request")
+            raise HTTPException(status_code=400, detail="No message provided")
+
+        try:
+            response = agent.chat(message)
+            logger.info(f"Generated response: {response}")
+            return {"response": response}
+        except Exception as e:
+            logger.error(f"Error during agent processing: {e}")
+            raise HTTPException(status_code=500, detail="Internal server error") from e
+
+    return app
+
+
+def start_app(agent: Agent, host='0.0.0.0', port=8000):
+    """
+    Start the FastAPI server.
+
+    Args:
+        host (str, optional): The host address for the API. Defaults to '0.0.0.0'.
+        port (int, optional): The port for the API. Defaults to 8000.
+    """
+    app = create_app(agent)
+    uvicorn.run(app, host=host, port=port)
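
Tying the new module together: start_app serves an existing Agent over HTTP, and the /chat route reads the message from a query parameter and authenticates via the X-API-Key header. A usage sketch, with values mirroring the defaults above:

```
from vectara_agentic.agent_endpoint import start_app

# `agent` is an Agent built earlier, e.g. via Agent.from_corpus(...).
start_app(agent, host="127.0.0.1", port=8000)

# Then, from a shell (the key matches the VECTARA_AGENTIC_API_KEY default):
#   curl -H "X-API-Key: dev-api-key" "http://127.0.0.1:8000/chat?message=hello"
```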