vectara-agentic 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of vectara-agentic has been flagged as potentially problematic.
- tests/__init__.py +7 -0
- tests/conftest.py +316 -0
- tests/endpoint.py +54 -17
- tests/run_tests.py +112 -0
- tests/test_agent.py +35 -33
- tests/test_agent_fallback_memory.py +270 -0
- tests/test_agent_memory_consistency.py +229 -0
- tests/test_agent_type.py +86 -143
- tests/test_api_endpoint.py +4 -0
- tests/test_bedrock.py +50 -31
- tests/test_fallback.py +4 -0
- tests/test_gemini.py +27 -59
- tests/test_groq.py +50 -31
- tests/test_private_llm.py +11 -2
- tests/test_return_direct.py +6 -2
- tests/test_serialization.py +7 -6
- tests/test_session_memory.py +252 -0
- tests/test_streaming.py +109 -0
- tests/test_together.py +62 -0
- tests/test_tools.py +10 -82
- tests/test_vectara_llms.py +4 -0
- tests/test_vhc.py +67 -0
- tests/test_workflow.py +13 -28
- vectara_agentic/__init__.py +27 -4
- vectara_agentic/_callback.py +65 -67
- vectara_agentic/_observability.py +30 -30
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +565 -859
- vectara_agentic/agent_config.py +15 -14
- vectara_agentic/agent_core/__init__.py +22 -0
- vectara_agentic/agent_core/factory.py +383 -0
- vectara_agentic/{_prompts.py → agent_core/prompts.py} +21 -46
- vectara_agentic/agent_core/serialization.py +348 -0
- vectara_agentic/agent_core/streaming.py +483 -0
- vectara_agentic/agent_core/utils/__init__.py +29 -0
- vectara_agentic/agent_core/utils/hallucination.py +157 -0
- vectara_agentic/agent_core/utils/logging.py +52 -0
- vectara_agentic/agent_core/utils/schemas.py +87 -0
- vectara_agentic/agent_core/utils/tools.py +125 -0
- vectara_agentic/agent_endpoint.py +4 -6
- vectara_agentic/db_tools.py +37 -12
- vectara_agentic/llm_utils.py +42 -43
- vectara_agentic/sub_query_workflow.py +9 -14
- vectara_agentic/tool_utils.py +138 -83
- vectara_agentic/tools.py +36 -21
- vectara_agentic/tools_catalog.py +16 -16
- vectara_agentic/types.py +106 -8
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/METADATA +111 -31
- vectara_agentic-0.4.1.dist-info/RECORD +53 -0
- tests/test_agent_planning.py +0 -64
- tests/test_hhem.py +0 -100
- vectara_agentic/hhem.py +0 -82
- vectara_agentic-0.3.3.dist-info/RECORD +0 -39
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.3.3.dist-info → vectara_agentic-0.4.1.dist-info}/top_level.txt +0 -0
vectara_agentic/agent_core/streaming.py (new file)
@@ -0,0 +1,483 @@
+"""
+Streaming utilities for agent responses.
+
+This module provides streaming response handling, adapters, and utilities
+for managing asynchronous agent interactions with proper synchronization.
+"""
+
+import asyncio
+import logging
+import uuid
+import json
+import traceback
+
+from typing import Callable, Any, Dict, AsyncIterator
+from collections import OrderedDict
+
+from llama_index.core.agent.workflow import (
+    ToolCall,
+    ToolCallResult,
+    AgentInput,
+    AgentOutput,
+)
+from ..types import AgentResponse
+
+class ToolEventTracker:
+    """
+    Tracks event IDs for tool calls to ensure consistent pairing of tool calls and outputs.
+
+    This class maintains a mapping between tool identifiers and event IDs to ensure
+    that related tool call and tool output events share the same event_id for proper
+    frontend grouping.
+    """
+
+    def __init__(self):
+        self.event_ids = OrderedDict()  # tool_call_id -> event_id mapping
+        self.fallback_counter = 0  # For events without identifiable tool_ids
+
+    def get_event_id(self, event) -> str:
+        """
+        Get a consistent event ID for a tool event.
+
+        Args:
+            event: The tool event object
+
+        Returns:
+            str: Consistent event ID for this tool execution
+        """
+        # Try to get tool_id from the event first
+        tool_id = getattr(event, "tool_id", None)
+
+        # If we have a tool_id, use it directly (any format from any LLM provider)
+        if tool_id:
+            pass  # We already have tool_id, just use it
+        # If no tool_id, try to derive one from tool_name (for LlamaIndex events)
+        elif hasattr(event, "tool_name") and event.tool_name:
+            tool_id = f"{event.tool_name}_{self.fallback_counter}"
+            self.fallback_counter += 1
+        # If still no tool_id, create a generic one based on event type
+        else:
+            event_type = type(event).__name__
+            tool_id = f"{event_type.lower()}_{self.fallback_counter}"
+            self.fallback_counter += 1
+
+        # Get or create event_id for this tool_id
+        if tool_id not in self.event_ids:
+            self.event_ids[tool_id] = str(uuid.uuid4())
+
+        return self.event_ids[tool_id]
+
+    def clear_old_entries(self, max_entries: int = 100):
+        """Clear old entries to prevent unbounded memory growth."""
+        while len(self.event_ids) > max_entries // 2:
+            self.event_ids.popitem(last=False)  # Remove oldest entry
+
+
+class StreamingResponseAdapter:
+    """
+    Adapter class that provides a LlamaIndex-compatible streaming response interface.
+
+    This class bridges custom streaming logic with AgentStreamingResponse expectations
+    by implementing the required protocol methods and properties.
+    """
+
+    def __init__(
+        self,
+        async_response_gen: Callable[[], Any] | None = None,
+        response: str = "",
+        metadata: Dict[str, Any] | None = None,
+        post_process_task: Any = None,
+    ) -> None:
+        """
+        Initialize the streaming response adapter.
+
+        Args:
+            async_response_gen: Async generator function for streaming tokens
+            response: Final response text (filled after streaming completes)
+            metadata: Response metadata dictionary
+            post_process_task: Async task that will populate response/metadata
+        """
+        self.async_response_gen = async_response_gen
+        self.response = response
+        self.metadata = metadata or {}
+        self.post_process_task = post_process_task
+
+    async def aget_response(self) -> AgentResponse:
+        """
+        Async version that waits for post-processing to complete.
+        """
+        if self.post_process_task:
+            final_response = await self.post_process_task
+            # Update our state with the final response
+            self.response = final_response.response
+            self.metadata = final_response.metadata or {}
+        return AgentResponse(response=self.response, metadata=self.metadata)
+
+    def get_response(self) -> AgentResponse:
+        """
+        Return an AgentResponse using the current state.
+
+        Required by the _StreamProto protocol for AgentStreamingResponse compatibility.
+        """
+        return AgentResponse(response=self.response, metadata=self.metadata)
+
+    def wait_for_completion(self) -> None:
+        """
+        Wait for post-processing to complete and update metadata.
+        This should be called after streaming finishes but before accessing metadata.
+        """
+        if self.post_process_task and not self.post_process_task.done():
+            return
+        if self.post_process_task and self.post_process_task.done():
+            try:
+                final_response = self.post_process_task.result()
+                if hasattr(final_response, "metadata") and final_response.metadata:
+                    # Update our metadata from the completed task
+                    self.metadata.update(final_response.metadata)
+            except Exception as e:
+                logging.error(
+                    f"Error during post-processing: {e}. "
+                    "Ensure the post-processing task is correctly implemented."
+                )
+
+
+def extract_response_text_from_chat_message(response_text: Any) -> str:
+    """
+    Extract text content from various response formats.
+
+    Handles ChatMessage objects with blocks, content attributes, or plain strings.
+
+    Args:
+        response_text: Response object that may be ChatMessage, string, or other format
+
+    Returns:
+        str: Extracted text content
+    """
+    # Handle case where response is a ChatMessage object
+    if hasattr(response_text, "content"):
+        return response_text.content
+    elif hasattr(response_text, "blocks"):
+        # Extract text from ChatMessage blocks
+        text_parts = []
+        for block in response_text.blocks:
+            if hasattr(block, "text"):
+                text_parts.append(block.text)
+        return "".join(text_parts)
+    elif not isinstance(response_text, str):
+        return str(response_text)
+
+    return response_text
+
+
+async def execute_post_stream_processing(
+    result: Any,
+    prompt: str,
+    agent_instance,
+    user_metadata: Dict[str, Any],
+) -> AgentResponse:
+    """
+    Execute post-stream processing on a completed result.
+
+    This function consolidates the common post-processing steps that happen
+    after streaming completes, including response extraction, formatting,
+    callbacks, and FCS calculation.
+
+    Args:
+        result: The completed result object from streaming
+        prompt: Original user prompt
+        agent_instance: Agent instance for callbacks and processing
+        user_metadata: User metadata to update with FCS scores
+
+    Returns:
+        AgentResponse: Processed final response
+    """
+    if result is None:
+        logging.warning(
+            "Received None result from streaming, returning empty response."
+        )
+        return AgentResponse(
+            response="No response generated",
+            metadata=getattr(result, "metadata", {}),
+        )
+
+    # Ensure we have an AgentResponse object with a string response
+    if hasattr(result, "response"):
+        response_text = result.response
+    else:
+        response_text = str(result)
+
+    # Extract text from various response formats
+    response_text = extract_response_text_from_chat_message(response_text)
+
+    final = AgentResponse(
+        response=response_text,
+        metadata=getattr(result, "metadata", {}),
+    )
+
+    # Post-processing steps
+
+    if agent_instance.query_logging_callback:
+        agent_instance.query_logging_callback(prompt, final.response)
+
+    # Let LlamaIndex handle agent memory naturally - no custom capture needed
+
+    if not final.metadata:
+        final.metadata = {}
+    final.metadata.update(user_metadata)
+
+    if agent_instance.observability_enabled:
+        from .._observability import eval_fcs
+
+        eval_fcs()
+
+    return final
+
+
+def create_stream_post_processing_task(
+    stream_complete_event: asyncio.Event,
+    final_response_container: Dict[str, Any],
+    prompt: str,
+    agent_instance,
+    user_metadata: Dict[str, Any],
+) -> asyncio.Task:
+    """
+    Create an async task for post-stream processing.
+
+    Args:
+        stream_complete_event: Event to wait for stream completion
+        final_response_container: Container with final response data
+        prompt: Original user prompt
+        agent_instance: Agent instance for callbacks and processing
+        user_metadata: User metadata to update with FCS scores
+
+    Returns:
+        asyncio.Task: Task that will process the final response
+    """
+
+    async def _post_process():
+        # Wait until the generator has finished and final response is populated
+        await stream_complete_event.wait()
+        result = final_response_container.get("resp")
+        return await execute_post_stream_processing(
+            result, prompt, agent_instance, user_metadata
+        )
+
+    async def _safe_post_process():
+        try:
+            return await _post_process()
+        except Exception:
+            traceback.print_exc()
+            # Return empty response on error
+            return AgentResponse(response="", metadata={})
+
+    return asyncio.create_task(_safe_post_process())
+
+
+class FunctionCallingStreamHandler:
+    """
+    Handles streaming for function calling agents with proper event processing.
+    """
+
+    def __init__(self, agent_instance, handler, prompt: str):
+        self.agent_instance = agent_instance
+        self.handler = handler
+        self.prompt = prompt
+        self.final_response_container = {"resp": None}
+        self.stream_complete_event = asyncio.Event()
+        self.event_tracker = ToolEventTracker()
+
+    async def process_stream_events(self) -> AsyncIterator[str]:
+        """
+        Process streaming events and yield text tokens.
+
+        Yields:
+            str: Text tokens from the streaming response
+        """
+        had_tool_calls = False
+        transitioned_to_prose = False
+
+        async for ev in self.handler.stream_events():
+            # Store tool outputs for VHC regardless of progress callback
+            if isinstance(ev, ToolCallResult):
+                if hasattr(self.agent_instance, '_add_tool_output'):
+                    # pylint: disable=W0212
+                    self.agent_instance._add_tool_output(ev.tool_name, str(ev.tool_output))
+
+            # Handle progress callbacks if available
+            if self.agent_instance.agent_progress_callback:
+                # Only track events that are actual tool-related events
+                if self._is_tool_related_event(ev):
+                    event_id = self.event_tracker.get_event_id(ev)
+                    await self._handle_progress_callback(ev, event_id)
+
+            # Process streaming text events
+            if hasattr(ev, "__class__") and "AgentStream" in str(ev.__class__):
+                if hasattr(ev, "tool_calls") and ev.tool_calls:
+                    had_tool_calls = True
+                elif (
+                    hasattr(ev, "tool_calls")
+                    and not ev.tool_calls
+                    and had_tool_calls
+                    and not transitioned_to_prose
+                ):
+                    yield "\n\n"
+                    transitioned_to_prose = True
+                    if hasattr(ev, "delta"):
+                        yield ev.delta
+                elif (
+                    hasattr(ev, "tool_calls")
+                    and not ev.tool_calls
+                    and hasattr(ev, "delta")
+                ):
+                    yield ev.delta
+
+        # When stream is done, await the handler to get the final response
+        try:
+            self.final_response_container["resp"] = await self.handler
+        except Exception as e:
+            logging.error(f"🔍 [STREAM_ERROR] Error processing stream events: {e}")
+            logging.error(f"🔍 [STREAM_ERROR] Full traceback: {traceback.format_exc()}")
+            self.final_response_container["resp"] = type(
+                "AgentResponse",
+                (),
+                {
+                    "response": "Response completion Error",
+                    "source_nodes": [],
+                    "metadata": None,
+                },
+            )()
+        finally:
+            # Clean up event tracker to prevent memory leaks
+            self.event_tracker.clear_old_entries()
+            # Signal that stream processing is complete
+            self.stream_complete_event.set()
+
+    def _is_tool_related_event(self, event) -> bool:
+        """
+        Determine if an event is actually tool-related and should be tracked.
+
+        This should only return True for events that represent actual tool calls or tool outputs,
+        not for streaming text deltas or other LLM response events.
+
+        Args:
+            event: The stream event to check
+
+        Returns:
+            bool: True if this event should be tracked for tool purposes
+        """
+        # Track explicit tool events from LlamaIndex workflow
+        if isinstance(event, (ToolCall, ToolCallResult)):
+            return True
+
+        has_tool_id = hasattr(event, "tool_id") and event.tool_id
+        has_delta = hasattr(event, "delta") and event.delta
+        has_tool_name = hasattr(event, "tool_name") and event.tool_name
+
+        # We're not seeing ToolCall/ToolCallResult events in the stream, so let's be more liberal
+        # but still avoid streaming deltas
+        if (has_tool_id or has_tool_name) and not has_delta:
+            return True
+
+        # Everything else (streaming deltas, agent outputs, workflow events, etc.)
+        # should NOT be tracked as tool events
+        return False
+
+    async def _handle_progress_callback(self, event, event_id: str):
+        """Handle progress callback events for different event types with proper context propagation."""
+        # Import here to avoid circular imports
+        from ..types import AgentStatusType
+
+        try:
+            if isinstance(event, ToolCall):
+                # Check if callback is async or sync
+                if asyncio.iscoroutinefunction(
+                    self.agent_instance.agent_progress_callback
+                ):
+                    await self.agent_instance.agent_progress_callback(
+                        status_type=AgentStatusType.TOOL_CALL,
+                        msg={
+                            "tool_name": event.tool_name,
+                            "arguments": json.dumps(event.tool_kwargs),
+                        },
+                        event_id=event_id,
+                    )
+                else:
+                    # For sync callbacks, ensure we call them properly
+                    self.agent_instance.agent_progress_callback(
+                        status_type=AgentStatusType.TOOL_CALL,
+                        msg={
+                            "tool_name": event.tool_name,
+                            "arguments": json.dumps(event.tool_kwargs),
+                        },
+                        event_id=event_id,
+                    )
+
+            elif isinstance(event, ToolCallResult):
+                # Check if callback is async or sync
+                if asyncio.iscoroutinefunction(
+                    self.agent_instance.agent_progress_callback
+                ):
+                    await self.agent_instance.agent_progress_callback(
+                        status_type=AgentStatusType.TOOL_OUTPUT,
+                        msg={
+                            "tool_name": event.tool_name,
+                            "content": str(event.tool_output),
+                        },
+                        event_id=event_id,
+                    )
+                else:
+                    self.agent_instance.agent_progress_callback(
+                        status_type=AgentStatusType.TOOL_OUTPUT,
+                        msg={
+                            "tool_name": event.tool_name,
+                            "content": str(event.tool_output),
+                        },
+                        event_id=event_id,
+                    )
+
+            elif isinstance(event, AgentInput):
+                self.agent_instance.agent_progress_callback(
+                    status_type=AgentStatusType.AGENT_UPDATE,
+                    msg={"content": f"Agent input: {event.input}"},
+                    event_id=event_id,
+                )
+
+            elif isinstance(event, AgentOutput):
+                self.agent_instance.agent_progress_callback(
+                    status_type=AgentStatusType.AGENT_UPDATE,
+                    msg={"content": f"Agent output: {event.response}"},
+                    event_id=event_id,
+                )
+
+        except Exception as e:
+            logging.error(f"Exception in progress callback: {e}")
+            logging.error(f"Traceback: {traceback.format_exc()}")
+            # Continue execution despite callback errors
+
+    def create_streaming_response(
+        self, user_metadata: Dict[str, Any]
+    ) -> "StreamingResponseAdapter":
+        """
+        Create a StreamingResponseAdapter with proper post-processing.
+
+        Args:
+            user_metadata: User metadata dictionary to update
+
+        Returns:
+            StreamingResponseAdapter: Configured streaming adapter
+        """
+        post_process_task = create_stream_post_processing_task(
+            self.stream_complete_event,
+            self.final_response_container,
+            self.prompt,
+            self.agent_instance,
+            user_metadata,
+        )
+
+        return StreamingResponseAdapter(
+            async_response_gen=self.process_stream_events,
+            response="",  # will be filled post-stream
+            metadata={},
+            post_process_task=post_process_task,
+        )
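The pieces above compose in a fixed order: FunctionCallingStreamHandler wraps a workflow handler, create_streaming_response() wires up the post-processing task, the caller drains async_response_gen(), and only then is the final AgentResponse available (the generator's finally block sets stream_complete_event, which the post-process task awaits). A minimal consumption sketch, not part of the diff: `agent_instance` is assumed to be whatever agent.py passes in (an object exposing agent_progress_callback, query_logging_callback, and observability_enabled), and `handler` a LlamaIndex workflow handler with stream_events().

    from vectara_agentic.agent_core.streaming import FunctionCallingStreamHandler

    async def consume(agent_instance, handler, prompt: str):
        # Wrap the workflow handler and build the adapter plus its post-processing task
        stream_handler = FunctionCallingStreamHandler(agent_instance, handler, prompt)
        adapter = stream_handler.create_streaming_response(user_metadata={})

        # Drain the token stream; exhausting it sets stream_complete_event
        async for token in adapter.async_response_gen():
            print(token, end="", flush=True)

        # Only now can post-processing finish and populate response/metadata
        return await adapter.aget_response()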
vectara_agentic/agent_core/utils/__init__.py (new file)
@@ -0,0 +1,29 @@
+"""
+Shared utilities for agent functionality.
+
+This sub-module contains smaller, focused utility functions:
+- schemas: Type conversion and schema handling
+- tools: Tool validation and processing
+- logging: Logging configuration and filters
+"""
+
+# Import utilities for easy access
+from .schemas import get_field_type, JSON_TYPE_TO_PYTHON, PY_TYPES
+from .tools import (
+    sanitize_tools_for_gemini,
+    validate_tool_consistency,
+)
+from .logging import IgnoreUnpickleableAttributeFilter, setup_agent_logging
+
+__all__ = [
+    # Schemas
+    "get_field_type",
+    "JSON_TYPE_TO_PYTHON",
+    "PY_TYPES",
+    # Tools
+    "sanitize_tools_for_gemini",
+    "validate_tool_consistency",
+    # Logging
+    "IgnoreUnpickleableAttributeFilter",
+    "setup_agent_logging",
+]
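The effect of these re-exports is that downstream code can import the helpers from the sub-package rather than the leaf modules; a hypothetical caller (import path assumed from the wheel layout):

    from vectara_agentic.agent_core.utils import (
        get_field_type,             # re-exported from .schemas
        sanitize_tools_for_gemini,  # re-exported from .tools
        setup_agent_logging,        # re-exported from .logging
    )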
vectara_agentic/agent_core/utils/hallucination.py (new file)
@@ -0,0 +1,157 @@
+"""Vectara Hallucination Detection and Correction client."""
+
+import logging
+from typing import List, Optional, Tuple
+import requests
+
+from llama_index.core.llms import MessageRole
+
+
+class Hallucination:
+    """Vectara Hallucination Correction."""
+
+    def __init__(self, vectara_api_key: str):
+        self._vectara_api_key = vectara_api_key
+
+    def compute(
+        self, query: str, context: list[str], hypothesis: str
+    ) -> Tuple[str, list[str]]:
+        """
+        Calls the Vectara VHC (Vectara Hallucination Correction)
+
+        Returns:
+            str: The corrected hypothesis text.
+            list[str]: the list of corrections from VHC
+        """
+
+        payload = {
+            "generated_text": hypothesis,
+            "query": query,
+            "documents": [{"text": c} for c in context],
+            "model_name": "vhc-large-1.0",
+        }
+        headers = {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+            "x-api-key": self._vectara_api_key,
+        }
+
+        response = requests.post(
+            "https://api.vectara.io/v2/hallucination_correctors/correct_hallucinations",
+            json=payload,
+            headers=headers,
+            timeout=30,
+        )
+        response.raise_for_status()
+        data = response.json()
+        corrected_text = data.get("corrected_text", "")
+        corrections = data.get("corrections", [])
+
+        logging.info(f"VHC: query={query}\n")
+        logging.info(f"VHC: response={hypothesis}\n")
+        logging.info("VHC: Context:")
+        for i, ctx in enumerate(context):
+            logging.info(f"VHC: context {i}: {ctx[:200]}\n\n")
+
+        logging.info(f"VHC: outputs: {len(corrections)} corrections")
+        logging.info(f"VHC: corrected_text: {corrected_text}\n")
+        for correction in corrections:
+            logging.info(f"VHC: correction: {correction}\n")
+
+        return corrected_text, corrections
+
+
+def check_tool_eligibility(tool_name: Optional[str], tools: List) -> bool:
+    """Check if a tool output is eligible to be included in VHC, by looking up in tools list."""
+    if not tool_name or not tools:
+        return False
+
+    # Try to find the tool and check its VHC eligibility
+    for tool in tools:
+        if (
+            hasattr(tool, "metadata")
+            and hasattr(tool.metadata, "name")
+            and tool.metadata.name == tool_name
+        ):
+            if hasattr(tool.metadata, "vhc_eligible"):
+                is_vhc_eligible = tool.metadata.vhc_eligible
+                return is_vhc_eligible
+            break
+
+    return True
+
+
+def analyze_hallucinations(
+    query: str,
+    chat_history: List,
+    agent_response: str,
+    tools: List,
+    vectara_api_key: str,
+    tool_outputs: Optional[List[dict]] = None,
+) -> Tuple[Optional[str], List[str]]:
+    """Use VHC to compute corrected_text and corrections using provided tool data."""
+
+    if not vectara_api_key:
+        logging.warning("VHC: No Vectara API key - returning None")
+        return None, []
+
+    context = []
+
+    # Process tool outputs if provided
+    if tool_outputs:
+        tool_output_count = 0
+        for tool_output in tool_outputs:
+            if tool_output.get("status_type") == "TOOL_OUTPUT" and tool_output.get(
+                "content"
+            ):
+                tool_output_count += 1
+                tool_name = tool_output.get("tool_name")
+                is_vhc_eligible = check_tool_eligibility(tool_name, tools)
+
+                if is_vhc_eligible:
+                    content = str(tool_output["content"])
+                    if content and content.strip():
+                        context.append(content)
+
+        logging.info(
+            f"VHC: Processed {tool_output_count} tool outputs, added {len(context)} to context so far"
+        )
+    else:
+        logging.info("VHC: No tool outputs provided")
+
+    # Add user messages and previous assistant messages from chat_history for context
+    last_assistant_index = -1
+    for i, msg in enumerate(chat_history):
+        if msg.role == MessageRole.ASSISTANT and msg.content:
+            last_assistant_index = i
+
+    for i, msg in enumerate(chat_history):
+        if msg.role == MessageRole.USER and msg.content:
+            # Don't include the current query in context since it's passed separately as query parameter
+            if msg.content != query:
+                context.append(msg.content)
+
+        elif msg.role == MessageRole.ASSISTANT and msg.content:
+            if i != last_assistant_index:  # do not include the last assistant message
+                context.append(msg.content)
+
+    logging.info(f"VHC: Final VHC context has {len(context)} items")
+
+    # If no context, we cannot compute VHC
+    if len(context) == 0:
+        logging.info("VHC: No context available for VHC - returning None")
+        return None, []
+
+    try:
+        h = Hallucination(vectara_api_key)
+        corrected_text, corrections = h.compute(
+            query=query, context=context, hypothesis=agent_response
+        )
+        return corrected_text, corrections
+
+    except Exception as e:
+        logging.warning(
+            f"VHC call failed: {e}. "
+            "Ensure you have a valid Vectara API key and the Hallucination Correction service is available."
+        )
+        return None, []
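A minimal calling sketch, not part of the diff: tool_outputs entries must carry status_type == "TOOL_OUTPUT" and a content key, and check_tool_eligibility only reads metadata.name and metadata.vhc_eligible from each tool, so a SimpleNamespace stand-in is used here for illustration; the tool name, message contents, and VECTARA_API_KEY environment variable are all assumptions.

    import os
    from types import SimpleNamespace
    from llama_index.core.llms import ChatMessage, MessageRole
    from vectara_agentic.agent_core.utils.hallucination import analyze_hallucinations

    # Stand-in for a real vectara-agentic tool: only .metadata.name and
    # .metadata.vhc_eligible are inspected by check_tool_eligibility
    tools = [SimpleNamespace(metadata=SimpleNamespace(name="ask_policies", vhc_eligible=True))]

    chat_history = [
        ChatMessage(role=MessageRole.USER, content="Summarize our refund policy."),
        ChatMessage(role=MessageRole.ASSISTANT, content="Refunds are issued within 30 days."),
    ]

    tool_outputs = [
        {
            "status_type": "TOOL_OUTPUT",  # only TOOL_OUTPUT entries with content are used
            "tool_name": "ask_policies",
            "content": "Policy doc: refunds are available within 30 days of purchase.",
        }
    ]

    corrected_text, corrections = analyze_hallucinations(
        query="Do we offer refunds?",
        chat_history=chat_history,
        agent_response="Yes, refunds are available within 30 days.",
        tools=tools,
        vectara_api_key=os.environ["VECTARA_API_KEY"],  # assumed to be set
        tool_outputs=tool_outputs,
    )
    # corrected_text is None if no context was assembled or the VHC call failed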