vectara-agentic 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. tests/__init__.py +1 -0
  2. tests/benchmark_models.py +547 -372
  3. tests/conftest.py +14 -12
  4. tests/endpoint.py +9 -5
  5. tests/run_tests.py +1 -0
  6. tests/test_agent.py +22 -9
  7. tests/test_agent_fallback_memory.py +4 -4
  8. tests/test_agent_memory_consistency.py +4 -4
  9. tests/test_agent_type.py +2 -0
  10. tests/test_api_endpoint.py +13 -13
  11. tests/test_bedrock.py +9 -1
  12. tests/test_fallback.py +18 -7
  13. tests/test_gemini.py +14 -40
  14. tests/test_groq.py +9 -1
  15. tests/test_private_llm.py +19 -6
  16. tests/test_react_error_handling.py +293 -0
  17. tests/test_react_memory.py +257 -0
  18. tests/test_react_streaming.py +135 -0
  19. tests/test_react_workflow_events.py +395 -0
  20. tests/test_return_direct.py +1 -0
  21. tests/test_serialization.py +58 -20
  22. tests/test_session_memory.py +11 -11
  23. tests/test_together.py +9 -1
  24. tests/test_tools.py +3 -1
  25. tests/test_vectara_llms.py +2 -2
  26. tests/test_vhc.py +7 -2
  27. tests/test_workflow.py +17 -11
  28. vectara_agentic/_callback.py +79 -21
  29. vectara_agentic/_version.py +1 -1
  30. vectara_agentic/agent.py +65 -27
  31. vectara_agentic/agent_core/serialization.py +5 -9
  32. vectara_agentic/agent_core/streaming.py +245 -64
  33. vectara_agentic/agent_core/utils/schemas.py +2 -2
  34. vectara_agentic/llm_utils.py +4 -2
  35. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.3.dist-info}/METADATA +127 -31
  36. vectara_agentic-0.4.3.dist-info/RECORD +58 -0
  37. vectara_agentic-0.4.2.dist-info/RECORD +0 -54
  38. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.3.dist-info}/WHEEL +0 -0
  39. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.3.dist-info}/licenses/LICENSE +0 -0
  40. {vectara_agentic-0.4.2.dist-info → vectara_agentic-0.4.3.dist-info}/top_level.txt +0 -0
tests/test_vhc.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # Suppress external dependency warnings before any other imports
2
2
  import warnings
3
+
3
4
  warnings.simplefilter("ignore", DeprecationWarning)
4
5
 
5
6
  import unittest
@@ -10,6 +11,7 @@ from vectara_agentic.tools import ToolsFactory
10
11
  from vectara_agentic.types import ModelProvider
11
12
 
12
13
  import nest_asyncio
14
+
13
15
  nest_asyncio.apply()
14
16
 
15
17
  statements = [
@@ -20,6 +22,8 @@ statements = [
20
22
  "Chocolate is the best ice cream flavor.",
21
23
  ]
22
24
  st_inx = 0
25
+
26
+
23
27
  def get_statement() -> str:
24
28
  "Generate next statement"
25
29
  global st_inx
@@ -34,7 +38,8 @@ fc_config = AgentConfig(
34
38
  tool_llm_provider=ModelProvider.OPENAI,
35
39
  )
36
40
 
37
import os

# SECURITY: never commit a live Vectara API key to the repository. The literal
# that used to be here was published with the package and should be treated as
# compromised (rotate it). Read the key from the environment instead; it is an
# empty string when VECTARA_API_KEY is not set.
vectara_api_key = os.environ.get("VECTARA_API_KEY", "")
42
+
38
43
 
39
44
  class TestVHC(unittest.TestCase):
40
45
 
@@ -59,7 +64,7 @@ class TestVHC(unittest.TestCase):
59
64
  vhc_corrections = vhc_res.get("corrections", [])
60
65
  self.assertTrue(
61
66
  len(vhc_corrections) >= 0 and len(vhc_corrections) <= 2,
62
- "Corrections should be between 0 and 2"
67
+ "Corrections should be between 0 and 2",
63
68
  )
64
69
 
65
70
 
tests/test_workflow.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # Suppress external dependency warnings before any other imports
2
2
  import warnings
3
+
3
4
  warnings.simplefilter("ignore", DeprecationWarning)
4
5
 
5
6
  import unittest
@@ -7,9 +8,13 @@ import unittest
7
8
  from vectara_agentic.agent import Agent
8
9
  from vectara_agentic.agent_config import AgentConfig
9
10
  from vectara_agentic.tools import ToolsFactory
10
- from vectara_agentic.sub_query_workflow import SubQuestionQueryWorkflow, SequentialSubQuestionsWorkflow
11
+ from vectara_agentic.sub_query_workflow import (
12
+ SubQuestionQueryWorkflow,
13
+ SequentialSubQuestionsWorkflow,
14
+ )
11
15
  from conftest import mult, add, STANDARD_TEST_TOPIC, WORKFLOW_TEST_INSTRUCTIONS
12
16
 
17
+
13
18
  class TestWorkflowPackage(unittest.IsolatedAsyncioTestCase):
14
19
 
15
20
  async def test_sub_query_workflow(self):
@@ -18,8 +23,8 @@ class TestWorkflowPackage(unittest.IsolatedAsyncioTestCase):
18
23
  tools=tools,
19
24
  topic=STANDARD_TEST_TOPIC,
20
25
  custom_instructions=WORKFLOW_TEST_INSTRUCTIONS,
21
- agent_config = AgentConfig(),
22
- workflow_cls = SubQuestionQueryWorkflow,
26
+ agent_config=AgentConfig(),
27
+ workflow_cls=SubQuestionQueryWorkflow,
23
28
  )
24
29
 
25
30
  inputs = SubQuestionQueryWorkflow.InputsModel(
@@ -41,8 +46,8 @@ class TestWorkflowPackage(unittest.IsolatedAsyncioTestCase):
41
46
  tools=tools,
42
47
  topic=STANDARD_TEST_TOPIC,
43
48
  custom_instructions=WORKFLOW_TEST_INSTRUCTIONS,
44
- agent_config = AgentConfig(),
45
- workflow_cls = SequentialSubQuestionsWorkflow,
49
+ agent_config=AgentConfig(),
50
+ workflow_cls=SequentialSubQuestionsWorkflow,
46
51
  )
47
52
 
48
53
  inputs = SequentialSubQuestionsWorkflow.InputsModel(
@@ -51,6 +56,7 @@ class TestWorkflowPackage(unittest.IsolatedAsyncioTestCase):
51
56
  res = await agent.run(inputs=inputs, verbose=True)
52
57
  self.assertIn("22", res.response)
53
58
 
59
+
54
60
  class TestWorkflowFailure(unittest.IsolatedAsyncioTestCase):
55
61
 
56
62
  async def test_workflow_failure_sub_question(self):
@@ -59,9 +65,9 @@ class TestWorkflowFailure(unittest.IsolatedAsyncioTestCase):
59
65
  tools=tools,
60
66
  topic=STANDARD_TEST_TOPIC,
61
67
  custom_instructions=WORKFLOW_TEST_INSTRUCTIONS,
62
- agent_config = AgentConfig(),
63
- workflow_cls = SubQuestionQueryWorkflow,
64
- workflow_timeout = 1
68
+ agent_config=AgentConfig(),
69
+ workflow_cls=SubQuestionQueryWorkflow,
70
+ workflow_timeout=1,
65
71
  )
66
72
 
67
73
  inputs = SubQuestionQueryWorkflow.InputsModel(
@@ -76,9 +82,9 @@ class TestWorkflowFailure(unittest.IsolatedAsyncioTestCase):
76
82
  tools=tools,
77
83
  topic=STANDARD_TEST_TOPIC,
78
84
  custom_instructions=WORKFLOW_TEST_INSTRUCTIONS,
79
- agent_config = AgentConfig(),
80
- workflow_cls = SequentialSubQuestionsWorkflow,
81
- workflow_timeout = 1
85
+ agent_config=AgentConfig(),
86
+ workflow_cls=SequentialSubQuestionsWorkflow,
87
+ workflow_timeout=1,
82
88
  )
83
89
 
84
90
  inputs = SequentialSubQuestionsWorkflow.InputsModel(
@@ -38,6 +38,46 @@ def wrap_callback_fn(callback):
38
38
  return new_callback
39
39
 
40
40
 
41
+ def _extract_content_from_response(response) -> str:
42
+ """
43
+ Extract text content from various LLM response formats.
44
+
45
+ Handles different provider response objects and extracts the text content consistently.
46
+
47
+ Args:
48
+ response: Response object from LLM provider
49
+
50
+ Returns:
51
+ str: Extracted text content
52
+ """
53
+ # Handle case where response is a string
54
+ if isinstance(response, str):
55
+ return response
56
+
57
+ # Handle ChatMessage objects with blocks (Anthropic, etc.)
58
+ if hasattr(response, "blocks") and response.blocks:
59
+ text_parts = []
60
+ for block in response.blocks:
61
+ if hasattr(block, "text"):
62
+ text_parts.append(block.text)
63
+ return "".join(text_parts)
64
+
65
+ # Handle responses with content attribute
66
+ if hasattr(response, "content"):
67
+ return str(response.content)
68
+
69
+ # Handle responses with message attribute that has content
70
+ if hasattr(response, "message") and hasattr(response.message, "content"):
71
+ return str(response.message.content)
72
+
73
+ # Handle delta attribute for streaming responses
74
+ if hasattr(response, "delta"):
75
+ return str(response.delta)
76
+
77
+ # Fallback to string conversion
78
+ return str(response)
79
+
80
+
41
81
  class AgentCallbackHandler(BaseCallbackHandler):
42
82
  """
43
83
  Callback handler to track agent status
@@ -151,26 +191,36 @@ class AgentCallbackHandler(BaseCallbackHandler):
151
191
  def _handle_event(
152
192
  self, event_type: CBEventType, payload: Dict[str, Any], event_id: str
153
193
  ) -> None:
154
- if event_type == CBEventType.LLM:
155
- self._handle_llm(payload, event_id)
156
- elif event_type == CBEventType.FUNCTION_CALL:
157
- self._handle_function_call(payload, event_id)
158
- elif event_type == CBEventType.AGENT_STEP:
159
- self._handle_agent_step(payload, event_id)
160
- else:
161
- pass
194
+ try:
195
+ if event_type == CBEventType.LLM:
196
+ self._handle_llm(payload, event_id)
197
+ elif event_type == CBEventType.FUNCTION_CALL:
198
+ self._handle_function_call(payload, event_id)
199
+ elif event_type == CBEventType.AGENT_STEP:
200
+ self._handle_agent_step(payload, event_id)
201
+ else:
202
+ pass
203
+ except Exception as e:
204
+ logging.error(f"Exception in callback handler: {e}")
205
+ logging.error(f"Traceback: {traceback.format_exc()}")
206
+ # Continue execution to prevent callback failures from breaking the agent
162
207
 
163
208
  async def _ahandle_event(
164
209
  self, event_type: CBEventType, payload: Dict[str, Any], event_id: str
165
210
  ) -> None:
166
- if event_type == CBEventType.LLM:
167
- await self._ahandle_llm(payload, event_id)
168
- elif event_type == CBEventType.FUNCTION_CALL:
169
- await self._ahandle_function_call(payload, event_id)
170
- elif event_type == CBEventType.AGENT_STEP:
171
- await self._ahandle_agent_step(payload, event_id)
172
- else:
173
- pass
211
+ try:
212
+ if event_type == CBEventType.LLM:
213
+ await self._ahandle_llm(payload, event_id)
214
+ elif event_type == CBEventType.FUNCTION_CALL:
215
+ await self._ahandle_function_call(payload, event_id)
216
+ elif event_type == CBEventType.AGENT_STEP:
217
+ await self._ahandle_agent_step(payload, event_id)
218
+ else:
219
+ pass
220
+ except Exception as e:
221
+ logging.error(f"Exception in async callback handler: {e}")
222
+ logging.error(f"Traceback: {traceback.format_exc()}")
223
+ # Continue execution to prevent callback failures from breaking the agent
174
224
 
175
225
  # Synchronous handlers
176
226
  def _handle_llm(
@@ -182,17 +232,21 @@ class AgentCallbackHandler(BaseCallbackHandler):
182
232
  response = payload.get(EventPayload.RESPONSE)
183
233
  if response and str(response) not in ["None", "assistant: None"]:
184
234
  if self.fn:
235
+ # Convert response to consistent dict format
236
+ content = _extract_content_from_response(response)
185
237
  self.fn(
186
238
  status_type=AgentStatusType.AGENT_UPDATE,
187
- msg=response,
239
+ msg={"content": content},
188
240
  event_id=event_id,
189
241
  )
190
242
  elif EventPayload.PROMPT in payload:
191
243
  prompt = payload.get(EventPayload.PROMPT)
192
244
  if self.fn:
245
+ # Convert prompt to consistent dict format
246
+ content = str(prompt) if prompt else ""
193
247
  self.fn(
194
248
  status_type=AgentStatusType.AGENT_UPDATE,
195
- msg=prompt,
249
+ msg={"content": content},
196
250
  event_id=event_id,
197
251
  )
198
252
  else:
@@ -253,24 +307,28 @@ class AgentCallbackHandler(BaseCallbackHandler):
253
307
  response = payload.get(EventPayload.RESPONSE)
254
308
  if response and str(response) not in ["None", "assistant: None"]:
255
309
  if self.fn:
310
+ # Convert response to consistent dict format
311
+ content = _extract_content_from_response(response)
256
312
  if inspect.iscoroutinefunction(self.fn):
257
313
  await self.fn(
258
314
  status_type=AgentStatusType.AGENT_UPDATE,
259
- msg=response,
315
+ msg={"content": content},
260
316
  event_id=event_id,
261
317
  )
262
318
  else:
263
319
  self.fn(
264
320
  status_type=AgentStatusType.AGENT_UPDATE,
265
- msg=response,
321
+ msg={"content": content},
266
322
  event_id=event_id,
267
323
  )
268
324
  elif EventPayload.PROMPT in payload:
269
325
  prompt = payload.get(EventPayload.PROMPT)
270
326
  if self.fn:
327
+ # Convert prompt to consistent dict format
328
+ content = str(prompt) if prompt else ""
271
329
  self.fn(
272
330
  status_type=AgentStatusType.AGENT_UPDATE,
273
- msg=prompt,
331
+ msg={"content": content},
274
332
  event_id=event_id,
275
333
  )
276
334
 
@@ -1,4 +1,4 @@
1
1
  """
2
2
  Define the version of the package.
3
3
  """
4
- __version__ = "0.4.2"
4
+ __version__ = "0.4.3"
vectara_agentic/agent.py CHANGED
@@ -22,8 +22,8 @@ from dotenv import load_dotenv
22
22
  # Runtime imports for components used at module level
23
23
  from llama_index.core.llms import MessageRole, ChatMessage
24
24
  from llama_index.core.callbacks import CallbackManager
25
- from llama_index.core.memory import ChatMemoryBuffer
26
- from llama_index.core.storage.chat_store import SimpleChatStore
25
+ from llama_index.core.memory import Memory
26
+
27
27
 
28
28
  # Heavy llama_index imports moved to TYPE_CHECKING for lazy loading
29
29
  if TYPE_CHECKING:
@@ -53,6 +53,7 @@ from .agent_config import AgentConfig
53
53
  # Import utilities from agent core modules
54
54
  from .agent_core.streaming import (
55
55
  FunctionCallingStreamHandler,
56
+ ReActStreamHandler,
56
57
  execute_post_stream_processing,
57
58
  )
58
59
  from .agent_core.factory import create_agent_from_config, create_agent_from_corpus
@@ -168,11 +169,8 @@ class Agent:
168
169
  or f"{topic}:{date.today().isoformat()}"
169
170
  )
170
171
 
171
- chat_store = SimpleChatStore()
172
- self.memory = ChatMemoryBuffer.from_defaults(
173
- chat_store=chat_store,
174
- chat_store_key=self.session_id,
175
- token_limit=65536
172
+ self.memory = Memory.from_defaults(
173
+ session_id=self.session_id, token_limit=65536
176
174
  )
177
175
  if chat_history:
178
176
  msgs = []
@@ -491,6 +489,14 @@ class Agent:
491
489
  # Clear the main agent so it gets recreated with current memory
492
490
  self._agent = None
493
491
 
492
+ def _reset_agent_state(self) -> None:
493
+ """
494
+ Reset agent state to recover from workflow runtime errors.
495
+ Clears both agent instances to force recreation with fresh state.
496
+ """
497
+ self._agent = None
498
+ self._fallback_agent = None
499
+
494
500
  def report(self, detailed: bool = False) -> None:
495
501
  """
496
502
  Get a report from the agent.
@@ -546,11 +552,14 @@ class Agent:
546
552
  AgentResponse: The response from the agent.
547
553
  """
548
554
  try:
549
- _ = asyncio.get_running_loop()
550
- except RuntimeError:
555
+ loop = asyncio.get_running_loop()
556
+ if hasattr(loop, "_nest_level"):
557
+ return asyncio.run(self.achat(prompt))
558
+ except (RuntimeError, ImportError):
559
+ # No running loop or nest_asyncio not available
551
560
  return asyncio.run(self.achat(prompt))
552
561
 
553
- # We are inside a running loop (Jupyter, uvicorn, etc.)
562
+ # We are inside a running loop without nest_asyncio
554
563
  raise RuntimeError(
555
564
  "Use `await agent.achat(...)` inside an event loop (e.g. Jupyter)."
556
565
  )
@@ -565,8 +574,8 @@ class Agent:
565
574
  Returns:
566
575
  AgentResponse: The response from the agent.
567
576
  """
568
- if not prompt:
569
- return AgentResponse(response="")
577
+ if not prompt or not prompt.strip():
578
+ return AgentResponse(response="Please provide a valid prompt.")
570
579
 
571
580
  max_attempts = 4 if self.fallback_agent_config else 2
572
581
  attempt = 0
@@ -593,14 +602,12 @@ class Agent:
593
602
 
594
603
  # Listen to workflow events if progress callback is set
595
604
  if self.agent_progress_callback:
596
- # Create event tracker for consistent event ID generation
597
- from .agent_core.streaming import ToolEventTracker
598
-
599
- event_tracker = ToolEventTracker()
605
+ # Import the event ID utility function
606
+ from .agent_core.streaming import get_event_id
600
607
 
601
608
  async for event in handler.stream_events():
602
609
  # Use consistent event ID tracking to ensure tool calls and outputs are paired
603
- event_id = event_tracker.get_event_id(event)
610
+ event_id = get_event_id(event)
604
611
 
605
612
  # Handle different types of workflow events using same logic as FunctionCallingStreamHandler
606
613
  from llama_index.core.agent.workflow import (
@@ -831,6 +838,27 @@ class Agent:
831
838
  base=streaming_adapter, metadata=user_meta
832
839
  )
833
840
 
841
+ # Deal with ReAct agent type
842
+ elif self._get_current_agent_type() == AgentType.REACT:
843
+ from llama_index.core.workflow import Context
844
+
845
+ # Create context and pass memory to the workflow agent
846
+ ctx = Context(current_agent)
847
+
848
+ handler = current_agent.run(
849
+ user_msg=prompt, memory=self.memory, ctx=ctx
850
+ )
851
+
852
+ # Create a streaming adapter for ReAct with event handling
853
+ react_stream_handler = ReActStreamHandler(self, handler, prompt)
854
+ streaming_adapter = react_stream_handler.create_streaming_response(
855
+ user_meta
856
+ )
857
+
858
+ return AgentStreamingResponse(
859
+ base=streaming_adapter, metadata=user_meta
860
+ )
861
+
834
862
  #
835
863
  # For other agent types, use the standard async chat method
836
864
  #
@@ -870,16 +898,20 @@ class Agent:
870
898
  def _add_tool_output(self, tool_name: str, content: str):
871
899
  """Add a tool output to the current collection for VHC."""
872
900
  tool_output = {
873
- 'status_type': 'TOOL_OUTPUT',
874
- 'content': content,
875
- 'tool_name': tool_name
901
+ "status_type": "TOOL_OUTPUT",
902
+ "content": content,
903
+ "tool_name": tool_name,
876
904
  }
877
905
  self._current_tool_outputs.append(tool_output)
878
- logging.info(f"🔧 [TOOL_STORAGE] Added tool output from '{tool_name}': {len(content)} chars")
906
+ logging.info(
907
+ f"🔧 [TOOL_STORAGE] Added tool output from '{tool_name}': {len(content)} chars"
908
+ )
879
909
 
880
910
  def _get_stored_tool_outputs(self) -> List[dict]:
881
911
  """Get the stored tool outputs from the current query."""
882
- logging.info(f"🔧 [TOOL_STORAGE] Retrieved {len(self._current_tool_outputs)} stored tool outputs")
912
+ logging.info(
913
+ f"🔧 [TOOL_STORAGE] Retrieved {len(self._current_tool_outputs)} stored tool outputs"
914
+ )
883
915
  return self._current_tool_outputs.copy()
884
916
 
885
917
  async def acompute_vhc(self) -> Dict[str, Any]:
@@ -926,7 +958,9 @@ class Agent:
926
958
  )
927
959
 
928
960
  if not last_response:
929
- logging.info("🔍 [VHC_AGENT] Returning early - no last assistant response found")
961
+ logging.info(
962
+ "🔍 [VHC_AGENT] Returning early - no last assistant response found"
963
+ )
930
964
  return {"corrected_text": None, "corrections": []}
931
965
 
932
966
  # Update stored response for caching
@@ -944,7 +978,9 @@ class Agent:
944
978
  f"🔍 [VHC_AGENT] acompute_vhc called with vectara_api_key={'set' if self.vectara_api_key else 'None'}"
945
979
  )
946
980
  if not self.vectara_api_key:
947
- logging.info("🔍 [VHC_AGENT] No vectara_api_key - returning early with None")
981
+ logging.info(
982
+ "🔍 [VHC_AGENT] No vectara_api_key - returning early with None"
983
+ )
948
984
  return {"corrected_text": None, "corrections": []}
949
985
 
950
986
  # Compute VHC using existing library function
@@ -953,7 +989,9 @@ class Agent:
953
989
  try:
954
990
  # Use stored tool outputs from current query
955
991
  stored_tool_outputs = self._get_stored_tool_outputs()
956
- logging.info(f"🔧 [VHC_AGENT] Using {len(stored_tool_outputs)} stored tool outputs for VHC")
992
+ logging.info(
993
+ f"🔧 [VHC_AGENT] Using {len(stored_tool_outputs)} stored tool outputs for VHC"
994
+ )
957
995
 
958
996
  corrected_text, corrections = analyze_hallucinations(
959
997
  query=self._last_query,
@@ -1111,9 +1149,9 @@ class Agent:
1111
1149
  """Clean up resources used by the agent."""
1112
1150
  from ._observability import shutdown_observer
1113
1151
 
1114
- if hasattr(self, 'agent') and hasattr(self.agent, '_llm'):
1152
+ if hasattr(self, "agent") and hasattr(self.agent, "_llm"):
1115
1153
  llm = self.agent._llm
1116
- if hasattr(llm, 'client') and hasattr(llm.client, 'close'):
1154
+ if hasattr(llm, "client") and hasattr(llm.client, "close"):
1117
1155
  try:
1118
1156
  if asyncio.iscoroutinefunction(llm.client.close):
1119
1157
  asyncio.run(llm.client.close())
@@ -13,8 +13,7 @@ from typing import Dict, Any, List, Optional, Callable
13
13
 
14
14
  import cloudpickle as pickle
15
15
  from pydantic import Field, create_model, BaseModel
16
- from llama_index.core.memory import ChatMemoryBuffer
17
- from llama_index.core.storage.chat_store import SimpleChatStore
16
+ from llama_index.core.memory import Memory
18
17
  from llama_index.core.llms import ChatMessage
19
18
  from llama_index.core.tools import FunctionTool
20
19
 
@@ -23,8 +22,7 @@ from ..tools import VectaraTool
23
22
  from ..types import ToolType
24
23
  from .utils.schemas import get_field_type
25
24
 
26
-
27
- def restore_memory_from_dict(data: Dict[str, Any], session_id: str, token_limit: int = 65536) -> ChatMemoryBuffer:
25
+ def restore_memory_from_dict(data: Dict[str, Any], session_id: str, token_limit: int = 65536) -> Memory:
28
26
  """
29
27
  Restore agent memory from serialized dictionary data.
30
28
 
@@ -36,12 +34,10 @@ def restore_memory_from_dict(data: Dict[str, Any], session_id: str, token_limit:
36
34
  token_limit: Token limit for the memory instance
37
35
 
38
36
  Returns:
39
- ChatMemoryBuffer: Restored memory instance
37
+ Memory: Restored memory instance
40
38
  """
41
- chat_store = SimpleChatStore()
42
- mem = ChatMemoryBuffer.from_defaults(
43
- chat_store=chat_store,
44
- chat_store_key=session_id,
39
+ mem = Memory.from_defaults(
40
+ session_id=session_id,
45
41
  token_limit=token_limit
46
42
  )
47
43