vectara-agentic 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +1 -0
- tests/benchmark_models.py +547 -372
- tests/conftest.py +14 -12
- tests/endpoint.py +9 -5
- tests/run_tests.py +1 -0
- tests/test_agent.py +22 -9
- tests/test_agent_fallback_memory.py +4 -4
- tests/test_agent_memory_consistency.py +4 -4
- tests/test_agent_type.py +2 -0
- tests/test_api_endpoint.py +13 -13
- tests/test_bedrock.py +9 -1
- tests/test_fallback.py +18 -7
- tests/test_gemini.py +14 -40
- tests/test_groq.py +9 -1
- tests/test_private_llm.py +19 -6
- tests/test_react_error_handling.py +293 -0
- tests/test_react_memory.py +257 -0
- tests/test_react_streaming.py +135 -0
- tests/test_react_workflow_events.py +395 -0
- tests/test_return_direct.py +1 -0
- tests/test_serialization.py +58 -20
- tests/test_session_memory.py +11 -11
- tests/test_together.py +9 -1
- tests/test_tools.py +3 -1
- tests/test_vectara_llms.py +2 -2
- tests/test_vhc.py +7 -2
- tests/test_workflow.py +17 -11
- vectara_agentic/_callback.py +79 -21
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +65 -27
- vectara_agentic/agent_core/serialization.py +5 -9
- vectara_agentic/agent_core/streaming.py +245 -64
- vectara_agentic/agent_core/utils/schemas.py +2 -2
- vectara_agentic/llm_utils.py +4 -2
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.3.dist-info}/METADATA +127 -31
- vectara_agentic-0.4.3.dist-info/RECORD +58 -0
- vectara_agentic-0.4.2.dist-info/RECORD +0 -54
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.3.dist-info}/WHEEL +0 -0
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.3.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.3.dist-info}/top_level.txt +0 -0
tests/test_vhc.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# Suppress external dependency warnings before any other imports
|
|
2
2
|
import warnings
|
|
3
|
+
|
|
3
4
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
5
|
|
|
5
6
|
import unittest
|
|
@@ -10,6 +11,7 @@ from vectara_agentic.tools import ToolsFactory
|
|
|
10
11
|
from vectara_agentic.types import ModelProvider
|
|
11
12
|
|
|
12
13
|
import nest_asyncio
|
|
14
|
+
|
|
13
15
|
nest_asyncio.apply()
|
|
14
16
|
|
|
15
17
|
statements = [
|
|
@@ -20,6 +22,8 @@ statements = [
|
|
|
20
22
|
"Chocolate is the best ice cream flavor.",
|
|
21
23
|
]
|
|
22
24
|
st_inx = 0
|
|
25
|
+
|
|
26
|
+
|
|
23
27
|
def get_statement() -> str:
|
|
24
28
|
"Generate next statement"
|
|
25
29
|
global st_inx
|
|
@@ -34,7 +38,8 @@ fc_config = AgentConfig(
|
|
|
34
38
|
tool_llm_provider=ModelProvider.OPENAI,
|
|
35
39
|
)
|
|
36
40
|
|
|
37
|
-
vectara_api_key =
|
|
41
|
+
vectara_api_key = "zqt_UXrBcnI2UXINZkrv4g1tQPhzj02vfdtqYJIDiA"
|
|
42
|
+
|
|
38
43
|
|
|
39
44
|
class TestVHC(unittest.TestCase):
|
|
40
45
|
|
|
@@ -59,7 +64,7 @@ class TestVHC(unittest.TestCase):
|
|
|
59
64
|
vhc_corrections = vhc_res.get("corrections", [])
|
|
60
65
|
self.assertTrue(
|
|
61
66
|
len(vhc_corrections) >= 0 and len(vhc_corrections) <= 2,
|
|
62
|
-
"Corrections should be between 0 and 2"
|
|
67
|
+
"Corrections should be between 0 and 2",
|
|
63
68
|
)
|
|
64
69
|
|
|
65
70
|
|
tests/test_workflow.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# Suppress external dependency warnings before any other imports
|
|
2
2
|
import warnings
|
|
3
|
+
|
|
3
4
|
warnings.simplefilter("ignore", DeprecationWarning)
|
|
4
5
|
|
|
5
6
|
import unittest
|
|
@@ -7,9 +8,13 @@ import unittest
|
|
|
7
8
|
from vectara_agentic.agent import Agent
|
|
8
9
|
from vectara_agentic.agent_config import AgentConfig
|
|
9
10
|
from vectara_agentic.tools import ToolsFactory
|
|
10
|
-
from vectara_agentic.sub_query_workflow import
|
|
11
|
+
from vectara_agentic.sub_query_workflow import (
|
|
12
|
+
SubQuestionQueryWorkflow,
|
|
13
|
+
SequentialSubQuestionsWorkflow,
|
|
14
|
+
)
|
|
11
15
|
from conftest import mult, add, STANDARD_TEST_TOPIC, WORKFLOW_TEST_INSTRUCTIONS
|
|
12
16
|
|
|
17
|
+
|
|
13
18
|
class TestWorkflowPackage(unittest.IsolatedAsyncioTestCase):
|
|
14
19
|
|
|
15
20
|
async def test_sub_query_workflow(self):
|
|
@@ -18,8 +23,8 @@ class TestWorkflowPackage(unittest.IsolatedAsyncioTestCase):
|
|
|
18
23
|
tools=tools,
|
|
19
24
|
topic=STANDARD_TEST_TOPIC,
|
|
20
25
|
custom_instructions=WORKFLOW_TEST_INSTRUCTIONS,
|
|
21
|
-
agent_config
|
|
22
|
-
workflow_cls
|
|
26
|
+
agent_config=AgentConfig(),
|
|
27
|
+
workflow_cls=SubQuestionQueryWorkflow,
|
|
23
28
|
)
|
|
24
29
|
|
|
25
30
|
inputs = SubQuestionQueryWorkflow.InputsModel(
|
|
@@ -41,8 +46,8 @@ class TestWorkflowPackage(unittest.IsolatedAsyncioTestCase):
|
|
|
41
46
|
tools=tools,
|
|
42
47
|
topic=STANDARD_TEST_TOPIC,
|
|
43
48
|
custom_instructions=WORKFLOW_TEST_INSTRUCTIONS,
|
|
44
|
-
agent_config
|
|
45
|
-
workflow_cls
|
|
49
|
+
agent_config=AgentConfig(),
|
|
50
|
+
workflow_cls=SequentialSubQuestionsWorkflow,
|
|
46
51
|
)
|
|
47
52
|
|
|
48
53
|
inputs = SequentialSubQuestionsWorkflow.InputsModel(
|
|
@@ -51,6 +56,7 @@ class TestWorkflowPackage(unittest.IsolatedAsyncioTestCase):
|
|
|
51
56
|
res = await agent.run(inputs=inputs, verbose=True)
|
|
52
57
|
self.assertIn("22", res.response)
|
|
53
58
|
|
|
59
|
+
|
|
54
60
|
class TestWorkflowFailure(unittest.IsolatedAsyncioTestCase):
|
|
55
61
|
|
|
56
62
|
async def test_workflow_failure_sub_question(self):
|
|
@@ -59,9 +65,9 @@ class TestWorkflowFailure(unittest.IsolatedAsyncioTestCase):
|
|
|
59
65
|
tools=tools,
|
|
60
66
|
topic=STANDARD_TEST_TOPIC,
|
|
61
67
|
custom_instructions=WORKFLOW_TEST_INSTRUCTIONS,
|
|
62
|
-
agent_config
|
|
63
|
-
workflow_cls
|
|
64
|
-
workflow_timeout
|
|
68
|
+
agent_config=AgentConfig(),
|
|
69
|
+
workflow_cls=SubQuestionQueryWorkflow,
|
|
70
|
+
workflow_timeout=1,
|
|
65
71
|
)
|
|
66
72
|
|
|
67
73
|
inputs = SubQuestionQueryWorkflow.InputsModel(
|
|
@@ -76,9 +82,9 @@ class TestWorkflowFailure(unittest.IsolatedAsyncioTestCase):
|
|
|
76
82
|
tools=tools,
|
|
77
83
|
topic=STANDARD_TEST_TOPIC,
|
|
78
84
|
custom_instructions=WORKFLOW_TEST_INSTRUCTIONS,
|
|
79
|
-
agent_config
|
|
80
|
-
workflow_cls
|
|
81
|
-
workflow_timeout
|
|
85
|
+
agent_config=AgentConfig(),
|
|
86
|
+
workflow_cls=SequentialSubQuestionsWorkflow,
|
|
87
|
+
workflow_timeout=1,
|
|
82
88
|
)
|
|
83
89
|
|
|
84
90
|
inputs = SequentialSubQuestionsWorkflow.InputsModel(
|
vectara_agentic/_callback.py
CHANGED
|
@@ -38,6 +38,46 @@ def wrap_callback_fn(callback):
|
|
|
38
38
|
return new_callback
|
|
39
39
|
|
|
40
40
|
|
|
41
|
+
def _extract_content_from_response(response) -> str:
|
|
42
|
+
"""
|
|
43
|
+
Extract text content from various LLM response formats.
|
|
44
|
+
|
|
45
|
+
Handles different provider response objects and extracts the text content consistently.
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
response: Response object from LLM provider
|
|
49
|
+
|
|
50
|
+
Returns:
|
|
51
|
+
str: Extracted text content
|
|
52
|
+
"""
|
|
53
|
+
# Handle case where response is a string
|
|
54
|
+
if isinstance(response, str):
|
|
55
|
+
return response
|
|
56
|
+
|
|
57
|
+
# Handle ChatMessage objects with blocks (Anthropic, etc.)
|
|
58
|
+
if hasattr(response, "blocks") and response.blocks:
|
|
59
|
+
text_parts = []
|
|
60
|
+
for block in response.blocks:
|
|
61
|
+
if hasattr(block, "text"):
|
|
62
|
+
text_parts.append(block.text)
|
|
63
|
+
return "".join(text_parts)
|
|
64
|
+
|
|
65
|
+
# Handle responses with content attribute
|
|
66
|
+
if hasattr(response, "content"):
|
|
67
|
+
return str(response.content)
|
|
68
|
+
|
|
69
|
+
# Handle responses with message attribute that has content
|
|
70
|
+
if hasattr(response, "message") and hasattr(response.message, "content"):
|
|
71
|
+
return str(response.message.content)
|
|
72
|
+
|
|
73
|
+
# Handle delta attribute for streaming responses
|
|
74
|
+
if hasattr(response, "delta"):
|
|
75
|
+
return str(response.delta)
|
|
76
|
+
|
|
77
|
+
# Fallback to string conversion
|
|
78
|
+
return str(response)
|
|
79
|
+
|
|
80
|
+
|
|
41
81
|
class AgentCallbackHandler(BaseCallbackHandler):
|
|
42
82
|
"""
|
|
43
83
|
Callback handler to track agent status
|
|
@@ -151,26 +191,36 @@ class AgentCallbackHandler(BaseCallbackHandler):
|
|
|
151
191
|
def _handle_event(
|
|
152
192
|
self, event_type: CBEventType, payload: Dict[str, Any], event_id: str
|
|
153
193
|
) -> None:
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
194
|
+
try:
|
|
195
|
+
if event_type == CBEventType.LLM:
|
|
196
|
+
self._handle_llm(payload, event_id)
|
|
197
|
+
elif event_type == CBEventType.FUNCTION_CALL:
|
|
198
|
+
self._handle_function_call(payload, event_id)
|
|
199
|
+
elif event_type == CBEventType.AGENT_STEP:
|
|
200
|
+
self._handle_agent_step(payload, event_id)
|
|
201
|
+
else:
|
|
202
|
+
pass
|
|
203
|
+
except Exception as e:
|
|
204
|
+
logging.error(f"Exception in callback handler: {e}")
|
|
205
|
+
logging.error(f"Traceback: {traceback.format_exc()}")
|
|
206
|
+
# Continue execution to prevent callback failures from breaking the agent
|
|
162
207
|
|
|
163
208
|
async def _ahandle_event(
|
|
164
209
|
self, event_type: CBEventType, payload: Dict[str, Any], event_id: str
|
|
165
210
|
) -> None:
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
211
|
+
try:
|
|
212
|
+
if event_type == CBEventType.LLM:
|
|
213
|
+
await self._ahandle_llm(payload, event_id)
|
|
214
|
+
elif event_type == CBEventType.FUNCTION_CALL:
|
|
215
|
+
await self._ahandle_function_call(payload, event_id)
|
|
216
|
+
elif event_type == CBEventType.AGENT_STEP:
|
|
217
|
+
await self._ahandle_agent_step(payload, event_id)
|
|
218
|
+
else:
|
|
219
|
+
pass
|
|
220
|
+
except Exception as e:
|
|
221
|
+
logging.error(f"Exception in async callback handler: {e}")
|
|
222
|
+
logging.error(f"Traceback: {traceback.format_exc()}")
|
|
223
|
+
# Continue execution to prevent callback failures from breaking the agent
|
|
174
224
|
|
|
175
225
|
# Synchronous handlers
|
|
176
226
|
def _handle_llm(
|
|
@@ -182,17 +232,21 @@ class AgentCallbackHandler(BaseCallbackHandler):
|
|
|
182
232
|
response = payload.get(EventPayload.RESPONSE)
|
|
183
233
|
if response and str(response) not in ["None", "assistant: None"]:
|
|
184
234
|
if self.fn:
|
|
235
|
+
# Convert response to consistent dict format
|
|
236
|
+
content = _extract_content_from_response(response)
|
|
185
237
|
self.fn(
|
|
186
238
|
status_type=AgentStatusType.AGENT_UPDATE,
|
|
187
|
-
msg=
|
|
239
|
+
msg={"content": content},
|
|
188
240
|
event_id=event_id,
|
|
189
241
|
)
|
|
190
242
|
elif EventPayload.PROMPT in payload:
|
|
191
243
|
prompt = payload.get(EventPayload.PROMPT)
|
|
192
244
|
if self.fn:
|
|
245
|
+
# Convert prompt to consistent dict format
|
|
246
|
+
content = str(prompt) if prompt else ""
|
|
193
247
|
self.fn(
|
|
194
248
|
status_type=AgentStatusType.AGENT_UPDATE,
|
|
195
|
-
msg=
|
|
249
|
+
msg={"content": content},
|
|
196
250
|
event_id=event_id,
|
|
197
251
|
)
|
|
198
252
|
else:
|
|
@@ -253,24 +307,28 @@ class AgentCallbackHandler(BaseCallbackHandler):
|
|
|
253
307
|
response = payload.get(EventPayload.RESPONSE)
|
|
254
308
|
if response and str(response) not in ["None", "assistant: None"]:
|
|
255
309
|
if self.fn:
|
|
310
|
+
# Convert response to consistent dict format
|
|
311
|
+
content = _extract_content_from_response(response)
|
|
256
312
|
if inspect.iscoroutinefunction(self.fn):
|
|
257
313
|
await self.fn(
|
|
258
314
|
status_type=AgentStatusType.AGENT_UPDATE,
|
|
259
|
-
msg=
|
|
315
|
+
msg={"content": content},
|
|
260
316
|
event_id=event_id,
|
|
261
317
|
)
|
|
262
318
|
else:
|
|
263
319
|
self.fn(
|
|
264
320
|
status_type=AgentStatusType.AGENT_UPDATE,
|
|
265
|
-
msg=
|
|
321
|
+
msg={"content": content},
|
|
266
322
|
event_id=event_id,
|
|
267
323
|
)
|
|
268
324
|
elif EventPayload.PROMPT in payload:
|
|
269
325
|
prompt = payload.get(EventPayload.PROMPT)
|
|
270
326
|
if self.fn:
|
|
327
|
+
# Convert prompt to consistent dict format
|
|
328
|
+
content = str(prompt) if prompt else ""
|
|
271
329
|
self.fn(
|
|
272
330
|
status_type=AgentStatusType.AGENT_UPDATE,
|
|
273
|
-
msg=
|
|
331
|
+
msg={"content": content},
|
|
274
332
|
event_id=event_id,
|
|
275
333
|
)
|
|
276
334
|
|
vectara_agentic/_version.py
CHANGED
vectara_agentic/agent.py
CHANGED
|
@@ -22,8 +22,8 @@ from dotenv import load_dotenv
|
|
|
22
22
|
# Runtime imports for components used at module level
|
|
23
23
|
from llama_index.core.llms import MessageRole, ChatMessage
|
|
24
24
|
from llama_index.core.callbacks import CallbackManager
|
|
25
|
-
from llama_index.core.memory import
|
|
26
|
-
|
|
25
|
+
from llama_index.core.memory import Memory
|
|
26
|
+
|
|
27
27
|
|
|
28
28
|
# Heavy llama_index imports moved to TYPE_CHECKING for lazy loading
|
|
29
29
|
if TYPE_CHECKING:
|
|
@@ -53,6 +53,7 @@ from .agent_config import AgentConfig
|
|
|
53
53
|
# Import utilities from agent core modules
|
|
54
54
|
from .agent_core.streaming import (
|
|
55
55
|
FunctionCallingStreamHandler,
|
|
56
|
+
ReActStreamHandler,
|
|
56
57
|
execute_post_stream_processing,
|
|
57
58
|
)
|
|
58
59
|
from .agent_core.factory import create_agent_from_config, create_agent_from_corpus
|
|
@@ -168,11 +169,8 @@ class Agent:
|
|
|
168
169
|
or f"{topic}:{date.today().isoformat()}"
|
|
169
170
|
)
|
|
170
171
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
chat_store=chat_store,
|
|
174
|
-
chat_store_key=self.session_id,
|
|
175
|
-
token_limit=65536
|
|
172
|
+
self.memory = Memory.from_defaults(
|
|
173
|
+
session_id=self.session_id, token_limit=65536
|
|
176
174
|
)
|
|
177
175
|
if chat_history:
|
|
178
176
|
msgs = []
|
|
@@ -491,6 +489,14 @@ class Agent:
|
|
|
491
489
|
# Clear the main agent so it gets recreated with current memory
|
|
492
490
|
self._agent = None
|
|
493
491
|
|
|
492
|
+
def _reset_agent_state(self) -> None:
|
|
493
|
+
"""
|
|
494
|
+
Reset agent state to recover from workflow runtime errors.
|
|
495
|
+
Clears both agent instances to force recreation with fresh state.
|
|
496
|
+
"""
|
|
497
|
+
self._agent = None
|
|
498
|
+
self._fallback_agent = None
|
|
499
|
+
|
|
494
500
|
def report(self, detailed: bool = False) -> None:
|
|
495
501
|
"""
|
|
496
502
|
Get a report from the agent.
|
|
@@ -546,11 +552,14 @@ class Agent:
|
|
|
546
552
|
AgentResponse: The response from the agent.
|
|
547
553
|
"""
|
|
548
554
|
try:
|
|
549
|
-
|
|
550
|
-
|
|
555
|
+
loop = asyncio.get_running_loop()
|
|
556
|
+
if hasattr(loop, "_nest_level"):
|
|
557
|
+
return asyncio.run(self.achat(prompt))
|
|
558
|
+
except (RuntimeError, ImportError):
|
|
559
|
+
# No running loop or nest_asyncio not available
|
|
551
560
|
return asyncio.run(self.achat(prompt))
|
|
552
561
|
|
|
553
|
-
# We are inside a running loop
|
|
562
|
+
# We are inside a running loop without nest_asyncio
|
|
554
563
|
raise RuntimeError(
|
|
555
564
|
"Use `await agent.achat(...)` inside an event loop (e.g. Jupyter)."
|
|
556
565
|
)
|
|
@@ -565,8 +574,8 @@ class Agent:
|
|
|
565
574
|
Returns:
|
|
566
575
|
AgentResponse: The response from the agent.
|
|
567
576
|
"""
|
|
568
|
-
if not prompt:
|
|
569
|
-
return AgentResponse(response="")
|
|
577
|
+
if not prompt or not prompt.strip():
|
|
578
|
+
return AgentResponse(response="Please provide a valid prompt.")
|
|
570
579
|
|
|
571
580
|
max_attempts = 4 if self.fallback_agent_config else 2
|
|
572
581
|
attempt = 0
|
|
@@ -593,14 +602,12 @@ class Agent:
|
|
|
593
602
|
|
|
594
603
|
# Listen to workflow events if progress callback is set
|
|
595
604
|
if self.agent_progress_callback:
|
|
596
|
-
#
|
|
597
|
-
from .agent_core.streaming import
|
|
598
|
-
|
|
599
|
-
event_tracker = ToolEventTracker()
|
|
605
|
+
# Import the event ID utility function
|
|
606
|
+
from .agent_core.streaming import get_event_id
|
|
600
607
|
|
|
601
608
|
async for event in handler.stream_events():
|
|
602
609
|
# Use consistent event ID tracking to ensure tool calls and outputs are paired
|
|
603
|
-
event_id =
|
|
610
|
+
event_id = get_event_id(event)
|
|
604
611
|
|
|
605
612
|
# Handle different types of workflow events using same logic as FunctionCallingStreamHandler
|
|
606
613
|
from llama_index.core.agent.workflow import (
|
|
@@ -831,6 +838,27 @@ class Agent:
|
|
|
831
838
|
base=streaming_adapter, metadata=user_meta
|
|
832
839
|
)
|
|
833
840
|
|
|
841
|
+
# Deal with ReAct agent type
|
|
842
|
+
elif self._get_current_agent_type() == AgentType.REACT:
|
|
843
|
+
from llama_index.core.workflow import Context
|
|
844
|
+
|
|
845
|
+
# Create context and pass memory to the workflow agent
|
|
846
|
+
ctx = Context(current_agent)
|
|
847
|
+
|
|
848
|
+
handler = current_agent.run(
|
|
849
|
+
user_msg=prompt, memory=self.memory, ctx=ctx
|
|
850
|
+
)
|
|
851
|
+
|
|
852
|
+
# Create a streaming adapter for ReAct with event handling
|
|
853
|
+
react_stream_handler = ReActStreamHandler(self, handler, prompt)
|
|
854
|
+
streaming_adapter = react_stream_handler.create_streaming_response(
|
|
855
|
+
user_meta
|
|
856
|
+
)
|
|
857
|
+
|
|
858
|
+
return AgentStreamingResponse(
|
|
859
|
+
base=streaming_adapter, metadata=user_meta
|
|
860
|
+
)
|
|
861
|
+
|
|
834
862
|
#
|
|
835
863
|
# For other agent types, use the standard async chat method
|
|
836
864
|
#
|
|
@@ -870,16 +898,20 @@ class Agent:
|
|
|
870
898
|
def _add_tool_output(self, tool_name: str, content: str):
|
|
871
899
|
"""Add a tool output to the current collection for VHC."""
|
|
872
900
|
tool_output = {
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
901
|
+
"status_type": "TOOL_OUTPUT",
|
|
902
|
+
"content": content,
|
|
903
|
+
"tool_name": tool_name,
|
|
876
904
|
}
|
|
877
905
|
self._current_tool_outputs.append(tool_output)
|
|
878
|
-
logging.info(
|
|
906
|
+
logging.info(
|
|
907
|
+
f"🔧 [TOOL_STORAGE] Added tool output from '{tool_name}': {len(content)} chars"
|
|
908
|
+
)
|
|
879
909
|
|
|
880
910
|
def _get_stored_tool_outputs(self) -> List[dict]:
|
|
881
911
|
"""Get the stored tool outputs from the current query."""
|
|
882
|
-
logging.info(
|
|
912
|
+
logging.info(
|
|
913
|
+
f"🔧 [TOOL_STORAGE] Retrieved {len(self._current_tool_outputs)} stored tool outputs"
|
|
914
|
+
)
|
|
883
915
|
return self._current_tool_outputs.copy()
|
|
884
916
|
|
|
885
917
|
async def acompute_vhc(self) -> Dict[str, Any]:
|
|
@@ -926,7 +958,9 @@ class Agent:
|
|
|
926
958
|
)
|
|
927
959
|
|
|
928
960
|
if not last_response:
|
|
929
|
-
logging.info(
|
|
961
|
+
logging.info(
|
|
962
|
+
"🔍 [VHC_AGENT] Returning early - no last assistant response found"
|
|
963
|
+
)
|
|
930
964
|
return {"corrected_text": None, "corrections": []}
|
|
931
965
|
|
|
932
966
|
# Update stored response for caching
|
|
@@ -944,7 +978,9 @@ class Agent:
|
|
|
944
978
|
f"🔍 [VHC_AGENT] acompute_vhc called with vectara_api_key={'set' if self.vectara_api_key else 'None'}"
|
|
945
979
|
)
|
|
946
980
|
if not self.vectara_api_key:
|
|
947
|
-
logging.info(
|
|
981
|
+
logging.info(
|
|
982
|
+
"🔍 [VHC_AGENT] No vectara_api_key - returning early with None"
|
|
983
|
+
)
|
|
948
984
|
return {"corrected_text": None, "corrections": []}
|
|
949
985
|
|
|
950
986
|
# Compute VHC using existing library function
|
|
@@ -953,7 +989,9 @@ class Agent:
|
|
|
953
989
|
try:
|
|
954
990
|
# Use stored tool outputs from current query
|
|
955
991
|
stored_tool_outputs = self._get_stored_tool_outputs()
|
|
956
|
-
logging.info(
|
|
992
|
+
logging.info(
|
|
993
|
+
f"🔧 [VHC_AGENT] Using {len(stored_tool_outputs)} stored tool outputs for VHC"
|
|
994
|
+
)
|
|
957
995
|
|
|
958
996
|
corrected_text, corrections = analyze_hallucinations(
|
|
959
997
|
query=self._last_query,
|
|
@@ -1111,9 +1149,9 @@ class Agent:
|
|
|
1111
1149
|
"""Clean up resources used by the agent."""
|
|
1112
1150
|
from ._observability import shutdown_observer
|
|
1113
1151
|
|
|
1114
|
-
if hasattr(self,
|
|
1152
|
+
if hasattr(self, "agent") and hasattr(self.agent, "_llm"):
|
|
1115
1153
|
llm = self.agent._llm
|
|
1116
|
-
if hasattr(llm,
|
|
1154
|
+
if hasattr(llm, "client") and hasattr(llm.client, "close"):
|
|
1117
1155
|
try:
|
|
1118
1156
|
if asyncio.iscoroutinefunction(llm.client.close):
|
|
1119
1157
|
asyncio.run(llm.client.close())
|
|
@@ -13,8 +13,7 @@ from typing import Dict, Any, List, Optional, Callable
|
|
|
13
13
|
|
|
14
14
|
import cloudpickle as pickle
|
|
15
15
|
from pydantic import Field, create_model, BaseModel
|
|
16
|
-
from llama_index.core.memory import
|
|
17
|
-
from llama_index.core.storage.chat_store import SimpleChatStore
|
|
16
|
+
from llama_index.core.memory import Memory
|
|
18
17
|
from llama_index.core.llms import ChatMessage
|
|
19
18
|
from llama_index.core.tools import FunctionTool
|
|
20
19
|
|
|
@@ -23,8 +22,7 @@ from ..tools import VectaraTool
|
|
|
23
22
|
from ..types import ToolType
|
|
24
23
|
from .utils.schemas import get_field_type
|
|
25
24
|
|
|
26
|
-
|
|
27
|
-
def restore_memory_from_dict(data: Dict[str, Any], session_id: str, token_limit: int = 65536) -> ChatMemoryBuffer:
|
|
25
|
+
def restore_memory_from_dict(data: Dict[str, Any], session_id: str, token_limit: int = 65536) -> Memory:
|
|
28
26
|
"""
|
|
29
27
|
Restore agent memory from serialized dictionary data.
|
|
30
28
|
|
|
@@ -36,12 +34,10 @@ def restore_memory_from_dict(data: Dict[str, Any], session_id: str, token_limit:
|
|
|
36
34
|
token_limit: Token limit for the memory instance
|
|
37
35
|
|
|
38
36
|
Returns:
|
|
39
|
-
|
|
37
|
+
Memory: Restored memory instance
|
|
40
38
|
"""
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
chat_store=chat_store,
|
|
44
|
-
chat_store_key=session_id,
|
|
39
|
+
mem = Memory.from_defaults(
|
|
40
|
+
session_id=session_id,
|
|
45
41
|
token_limit=token_limit
|
|
46
42
|
)
|
|
47
43
|
|