agentrun-sdk 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agentrun-sdk might be problematic. Click here for more details.

Files changed (115)
  1. agentrun_operation_sdk/cli/__init__.py +1 -0
  2. agentrun_operation_sdk/cli/cli.py +19 -0
  3. agentrun_operation_sdk/cli/common.py +21 -0
  4. agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
  5. agentrun_operation_sdk/cli/runtime/commands.py +203 -0
  6. agentrun_operation_sdk/client/client.py +75 -0
  7. agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
  8. agentrun_operation_sdk/operations/runtime/configure.py +101 -0
  9. agentrun_operation_sdk/operations/runtime/launch.py +82 -0
  10. agentrun_operation_sdk/operations/runtime/models.py +31 -0
  11. agentrun_operation_sdk/services/runtime.py +152 -0
  12. agentrun_operation_sdk/utils/logging_config.py +72 -0
  13. agentrun_operation_sdk/utils/runtime/config.py +94 -0
  14. agentrun_operation_sdk/utils/runtime/container.py +280 -0
  15. agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
  16. agentrun_operation_sdk/utils/runtime/schema.py +56 -0
  17. agentrun_sdk/__init__.py +7 -0
  18. agentrun_sdk/agent/__init__.py +25 -0
  19. agentrun_sdk/agent/agent.py +696 -0
  20. agentrun_sdk/agent/agent_result.py +46 -0
  21. agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
  22. agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
  23. agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
  24. agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
  25. agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
  26. agentrun_sdk/agent/state.py +97 -0
  27. agentrun_sdk/event_loop/__init__.py +9 -0
  28. agentrun_sdk/event_loop/event_loop.py +499 -0
  29. agentrun_sdk/event_loop/streaming.py +319 -0
  30. agentrun_sdk/experimental/__init__.py +4 -0
  31. agentrun_sdk/experimental/hooks/__init__.py +15 -0
  32. agentrun_sdk/experimental/hooks/events.py +123 -0
  33. agentrun_sdk/handlers/__init__.py +10 -0
  34. agentrun_sdk/handlers/callback_handler.py +70 -0
  35. agentrun_sdk/hooks/__init__.py +49 -0
  36. agentrun_sdk/hooks/events.py +80 -0
  37. agentrun_sdk/hooks/registry.py +247 -0
  38. agentrun_sdk/models/__init__.py +10 -0
  39. agentrun_sdk/models/anthropic.py +432 -0
  40. agentrun_sdk/models/bedrock.py +649 -0
  41. agentrun_sdk/models/litellm.py +225 -0
  42. agentrun_sdk/models/llamaapi.py +438 -0
  43. agentrun_sdk/models/mistral.py +539 -0
  44. agentrun_sdk/models/model.py +95 -0
  45. agentrun_sdk/models/ollama.py +357 -0
  46. agentrun_sdk/models/openai.py +436 -0
  47. agentrun_sdk/models/sagemaker.py +598 -0
  48. agentrun_sdk/models/writer.py +449 -0
  49. agentrun_sdk/multiagent/__init__.py +22 -0
  50. agentrun_sdk/multiagent/a2a/__init__.py +15 -0
  51. agentrun_sdk/multiagent/a2a/executor.py +148 -0
  52. agentrun_sdk/multiagent/a2a/server.py +252 -0
  53. agentrun_sdk/multiagent/base.py +92 -0
  54. agentrun_sdk/multiagent/graph.py +555 -0
  55. agentrun_sdk/multiagent/swarm.py +656 -0
  56. agentrun_sdk/py.typed +1 -0
  57. agentrun_sdk/session/__init__.py +18 -0
  58. agentrun_sdk/session/file_session_manager.py +216 -0
  59. agentrun_sdk/session/repository_session_manager.py +152 -0
  60. agentrun_sdk/session/s3_session_manager.py +272 -0
  61. agentrun_sdk/session/session_manager.py +73 -0
  62. agentrun_sdk/session/session_repository.py +51 -0
  63. agentrun_sdk/telemetry/__init__.py +21 -0
  64. agentrun_sdk/telemetry/config.py +194 -0
  65. agentrun_sdk/telemetry/metrics.py +476 -0
  66. agentrun_sdk/telemetry/metrics_constants.py +15 -0
  67. agentrun_sdk/telemetry/tracer.py +563 -0
  68. agentrun_sdk/tools/__init__.py +17 -0
  69. agentrun_sdk/tools/decorator.py +569 -0
  70. agentrun_sdk/tools/executor.py +137 -0
  71. agentrun_sdk/tools/loader.py +152 -0
  72. agentrun_sdk/tools/mcp/__init__.py +13 -0
  73. agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
  74. agentrun_sdk/tools/mcp/mcp_client.py +423 -0
  75. agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
  76. agentrun_sdk/tools/mcp/mcp_types.py +63 -0
  77. agentrun_sdk/tools/registry.py +607 -0
  78. agentrun_sdk/tools/structured_output.py +421 -0
  79. agentrun_sdk/tools/tools.py +217 -0
  80. agentrun_sdk/tools/watcher.py +136 -0
  81. agentrun_sdk/types/__init__.py +5 -0
  82. agentrun_sdk/types/collections.py +23 -0
  83. agentrun_sdk/types/content.py +188 -0
  84. agentrun_sdk/types/event_loop.py +48 -0
  85. agentrun_sdk/types/exceptions.py +81 -0
  86. agentrun_sdk/types/guardrails.py +254 -0
  87. agentrun_sdk/types/media.py +89 -0
  88. agentrun_sdk/types/session.py +152 -0
  89. agentrun_sdk/types/streaming.py +201 -0
  90. agentrun_sdk/types/tools.py +258 -0
  91. agentrun_sdk/types/traces.py +5 -0
  92. agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
  93. agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
  94. agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
  95. agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
  96. agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
  97. agentrun_wrapper/__init__.py +11 -0
  98. agentrun_wrapper/_utils/__init__.py +6 -0
  99. agentrun_wrapper/_utils/endpoints.py +16 -0
  100. agentrun_wrapper/identity/__init__.py +5 -0
  101. agentrun_wrapper/identity/auth.py +211 -0
  102. agentrun_wrapper/memory/__init__.py +6 -0
  103. agentrun_wrapper/memory/client.py +1697 -0
  104. agentrun_wrapper/memory/constants.py +103 -0
  105. agentrun_wrapper/memory/controlplane.py +626 -0
  106. agentrun_wrapper/py.typed +1 -0
  107. agentrun_wrapper/runtime/__init__.py +13 -0
  108. agentrun_wrapper/runtime/app.py +473 -0
  109. agentrun_wrapper/runtime/context.py +34 -0
  110. agentrun_wrapper/runtime/models.py +25 -0
  111. agentrun_wrapper/services/__init__.py +1 -0
  112. agentrun_wrapper/services/identity.py +192 -0
  113. agentrun_wrapper/tools/__init__.py +6 -0
  114. agentrun_wrapper/tools/browser_client.py +325 -0
  115. agentrun_wrapper/tools/code_interpreter_client.py +186 -0
@@ -0,0 +1,1697 @@
1
+ """AgentCore Memory SDK - High-level client for memory operations.
2
+
3
+ This SDK handles the asymmetric API where:
4
+ - Input parameters use old field names (memoryStrategies, memoryStrategyId, etc.)
5
+ - Output responses use new field names (strategies, strategyId, etc.)
6
+
7
+ The SDK automatically normalizes responses to provide both field names for
8
+ backward compatibility.
9
+ """
10
+
11
+ import copy
12
+ import logging
13
+ import time
14
+ import uuid
15
+ import warnings
16
+ from datetime import datetime
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple
18
+
19
+ import boto3
20
+ from botocore.exceptions import ClientError
21
+
22
+ from .constants import (
23
+ CUSTOM_CONSOLIDATION_WRAPPER_KEYS,
24
+ CUSTOM_EXTRACTION_WRAPPER_KEYS,
25
+ DEFAULT_NAMESPACES,
26
+ EXTRACTION_WRAPPER_KEYS,
27
+ MemoryStatus,
28
+ MemoryStrategyTypeEnum,
29
+ MessageRole,
30
+ OverrideType,
31
+ Role,
32
+ StrategyType,
33
+ )
34
+
35
# Module-level logger named after this module, following the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
36
+
37
+
38
+ class MemoryClient:
39
+ """High-level Bedrock AgentCore Memory client with essential operations."""
40
+
41
def __init__(self, region_name: Optional[str] = None):
    """Set up the control-plane and data-plane boto3 clients.

    Args:
        region_name: AWS region to use. When omitted, the region of the default
            boto3 session is used, falling back to "us-west-2".
    """
    resolved_region = region_name or boto3.Session().region_name or "us-west-2"
    self.region_name = resolved_region

    # Control-plane client: used for memory resources (create_memory / get_memory).
    self.gmcp_client = boto3.client("bedrock-agentcore-control", region_name=resolved_region)
    # Data-plane client: used for events and memory-record retrieval.
    self.gmdp_client = boto3.client("bedrock-agentcore", region_name=resolved_region)

    control_region = self.gmcp_client.meta.region_name
    data_region = self.gmdp_client.meta.region_name
    logger.info(
        "Initialized MemoryClient for control plane: %s, data plane: %s",
        control_region,
        data_region,
    )
53
+
54
def create_memory(
    self,
    name: str,
    strategies: Optional[List[Dict[str, Any]]] = None,
    description: Optional[str] = None,
    event_expiry_days: int = 90,
    memory_execution_role_arn: Optional[str] = None,
) -> Dict[str, Any]:
    """Create a memory resource through the control plane.

    Args:
        name: Name of the memory resource.
        strategies: Strategy configurations; ``None`` means no strategies. They
            are passed through ``_add_default_namespaces`` before being sent.
        description: Optional description, sent only when not ``None``.
        event_expiry_days: Value sent as ``eventExpiryDuration`` (default 90).
        memory_execution_role_arn: Optional IAM role ARN, sent only when not ``None``.

    Returns:
        The created memory, normalized by ``_normalize_memory_response``.

    Raises:
        ClientError: If the ``create_memory`` API call fails (logged, then re-raised).
    """
    strategy_list = [] if strategies is None else strategies

    try:
        request: Dict[str, Any] = {
            "name": name,
            "eventExpiryDuration": event_expiry_days,
            # The API input still uses the legacy "memoryStrategies" field name.
            "memoryStrategies": self._add_default_namespaces(strategy_list),
            "clientToken": str(uuid.uuid4()),
        }
        optional_fields = (
            ("description", description),
            ("memoryExecutionRoleArn", memory_execution_role_arn),
        )
        for field_name, field_value in optional_fields:
            if field_value is not None:
                request[field_name] = field_value

        raw_memory = self.gmcp_client.create_memory(**request)["memory"]
        # Responses use newer field names; normalize so both spellings are available.
        created = self._normalize_memory_response(raw_memory)

        logger.info("Created memory: %s", created["memoryId"])
        return created

    except ClientError as e:
        logger.error("Failed to create memory: %s", e)
        raise
94
+
95
def create_memory_and_wait(
    self,
    name: str,
    strategies: List[Dict[str, Any]],
    description: Optional[str] = None,
    event_expiry_days: int = 90,
    memory_execution_role_arn: Optional[str] = None,
    max_wait: int = 300,
    poll_interval: int = 10,
) -> Dict[str, Any]:
    """Create a memory and wait for it to become ACTIVE.

    This method creates a memory and polls until it reaches ACTIVE status,
    providing a convenient way to ensure the memory is ready for use.

    Args:
        name: Name for the memory resource
        strategies: List of strategy configurations
        description: Optional description
        event_expiry_days: How long to retain events (default: 90 days)
        memory_execution_role_arn: IAM role ARN for memory execution
        max_wait: Maximum seconds to wait (default: 300)
        poll_interval: Seconds between status checks (default: 10)

    Returns:
        Created memory object in ACTIVE status (freshly fetched and normalized)

    Raises:
        TimeoutError: If memory doesn't become ACTIVE within max_wait
        RuntimeError: If memory creation fails (FAILED status), or the create
            response carries no memory ID
        ClientError: If a status/describe call fails (logged, then re-raised)
    """
    memory = self.create_memory(
        name=name,
        strategies=strategies,
        description=description,
        event_expiry_days=event_expiry_days,
        memory_execution_role_arn=memory_execution_role_arn,
    )

    # Handle both the new ("memoryId") and legacy ("id") field names.
    memory_id = memory.get("memoryId", memory.get("id"))
    if not memory_id:
        # Polling with an empty ID cannot succeed; fail fast with a clear error.
        raise RuntimeError("Memory creation did not return a memory ID")
    logger.info("Created memory %s, waiting for ACTIVE status...", memory_id)

    # A monotonic clock is unaffected by system clock adjustments during the wait.
    start_time = time.monotonic()
    while time.monotonic() - start_time < max_wait:
        elapsed = int(time.monotonic() - start_time)

        try:
            status = self.get_memory_status(memory_id)

            if status == MemoryStatus.ACTIVE.value:
                logger.info("Memory %s is now ACTIVE (took %d seconds)", memory_id, elapsed)
                # Fetch fresh details now that the memory is ACTIVE.
                response = self.gmcp_client.get_memory(memoryId=memory_id)  # Input uses old field name
                return self._normalize_memory_response(response["memory"])
            if status == MemoryStatus.FAILED.value:
                # Surface the service-provided failure reason when available.
                response = self.gmcp_client.get_memory(memoryId=memory_id)  # Input uses old field name
                failure_reason = response["memory"].get("failureReason", "Unknown")
                raise RuntimeError("Memory creation failed: %s" % failure_reason)
            logger.debug("Memory status: %s (%d seconds elapsed)", status, elapsed)

        except ClientError as e:
            logger.error("Error checking memory status: %s", e)
            raise

        # Cap the sleep at the remaining budget so the total wait does not
        # overshoot max_wait by up to a full poll_interval.
        remaining = max_wait - (time.monotonic() - start_time)
        time.sleep(max(0.0, min(poll_interval, remaining)))

    raise TimeoutError("Memory %s did not become ACTIVE within %d seconds" % (memory_id, max_wait))
168
+
169
def retrieve_memories(
    self, memory_id: str, namespace: str, query: str, actor_id: Optional[str] = None, top_k: int = 3
) -> List[Dict[str, Any]]:
    """Retrieve relevant memories from a namespace.

    Note: Wildcards (*) are NOT supported in namespaces. You must provide the
    exact namespace path with all variables resolved.

    This is a best-effort lookup: a wildcard namespace, or any ``ClientError``
    from the service, is logged and an empty list is returned instead of raising.

    Args:
        memory_id: Memory resource ID
        namespace: Exact namespace path (no wildcards)
        query: Search query
        actor_id: Optional actor ID (deprecated, use namespace). Accepted for
            backward compatibility only; it is not sent to the service.
        top_k: Number of results to return

    Returns:
        List of memory records (``memoryRecordSummaries``), or ``[]`` on failure

    Example:
        # Correct - exact namespace
        memories = client.retrieve_memories(
            memory_id="mem-123",
            namespace="support/facts/session-456",
            query="customer preferences"
        )

        # Incorrect - wildcards not supported
        # memories = client.retrieve_memories(..., namespace="support/facts/*", ...)
    """
    if "*" in namespace:
        logger.error("Wildcards are not supported in namespaces. Please provide exact namespace.")
        return []

    try:
        # Let service handle all namespace validation
        response = self.gmdp_client.retrieve_memory_records(
            memoryId=memory_id, namespace=namespace, searchCriteria={"searchQuery": query, "topK": top_k}
        )
    except ClientError as e:
        # Error responses may omit "Message" (or "Error" altogether); read them
        # defensively so this best-effort path never raises a KeyError.
        error = e.response.get("Error", {})
        error_code = error.get("Code", "Unknown")
        error_msg = error.get("Message", str(e))

        if error_code == "ResourceNotFoundException":
            logger.warning(
                "Memory or namespace not found. Ensure memory %s exists and namespace '%s' is configured",
                memory_id,
                namespace,
            )
        elif error_code == "ValidationException":
            logger.warning("Invalid search parameters: %s", error_msg)
        elif error_code == "ServiceException":
            logger.warning("Service error: %s. This may be temporary - try again later", error_msg)
        else:
            logger.warning("Memory retrieval failed (%s): %s", error_code, error_msg)

        return []

    memories = response.get("memoryRecordSummaries", [])
    logger.info("Retrieved %d memories from namespace: %s", len(memories), namespace)
    return memories
230
+
231
def create_event(
    self,
    memory_id: str,
    actor_id: str,
    session_id: str,
    messages: List[Tuple[str, str]],
    event_timestamp: Optional[datetime] = None,
    branch: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """Save an event of an agent interaction or conversation with a user.

    This is the basis of short-term memory. If you configured your Memory resource
    to have MemoryStrategies, then events that are saved in short-term memory via
    create_event will be used to extract long-term memory records.

    Args:
        memory_id: Memory resource ID
        actor_id: Actor identifier (could be id of your user or an agent)
        session_id: Session identifier (meant to logically group a series of events)
        messages: List of (text, role) tuples. Role can be USER, ASSISTANT, TOOL, etc.
            Roles are matched case-insensitively against MessageRole values.
        event_timestamp: timestamp for the entire event (not per message).
            Defaults to the current time as a timezone-aware UTC datetime.
        branch: Optional branch info. For new branches: {"rootEventId": "...", "name": "..."}
            For continuing existing branch: {"name": "..."} or {"name": "...", "rootEventId": "..."}
            A branch is used when you want to have a different history of events.

    Returns:
        Created event

    Raises:
        ValueError: If messages is empty, a message is not a (text, role) pair,
            or a role is not a valid MessageRole.
        ClientError: If the create_event API call fails (logged, then re-raised).

    Example:
        event = client.create_event(
            memory_id=memory.get("id"),
            actor_id="weatherWorrier",
            session_id="WeatherSession",
            messages=[
                ("What's the weather?", "USER"),
                ("Today is sunny", "ASSISTANT")
            ]
        )
        root_event_id = event.get("eventId")
        print(event)

        # Continue the conversation
        event = client.create_event(
            memory_id=memory.get("id"),
            actor_id="weatherWorrier",
            session_id="WeatherSession",
            messages=[
                ("How about the weather tomorrow", "USER"),
                ("Tomorrow is cold!", "ASSISTANT")
            ]
        )
        print(event)

        # branch the conversation so that the previous message is not part of the history
        # (suppose you did not mean to ask about the weather tomorrow and want to undo
        # that, and replace with a new message)
        event = client.create_event(
            memory_id=memory.get("id"),
            actor_id="weatherWorrier",
            session_id="WeatherSession",
            branch={"name": "differentWeatherQuestion", "rootEventId": root_event_id},
            messages=[
                ("How about the weather a year from now", "USER"),
                ("I can't predict that far into the future!", "ASSISTANT")
            ]
        )
        print(event)
    """
    from datetime import timezone

    try:
        if not messages:
            raise ValueError("At least one message is required")

        payload = []
        for msg in messages:
            if len(msg) != 2:
                raise ValueError("Each message must be (text, role)")

            text, role = msg

            try:
                role_enum = MessageRole(role.upper())
            except ValueError as err:
                raise ValueError(
                    "Invalid role '%s'. Must be one of: %s" % (role, ", ".join([r.value for r in MessageRole]))
                ) from err

            payload.append({"conversational": {"content": {"text": text}, "role": role_enum.value}})

        if event_timestamp is None:
            # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
            # value; use an explicit timezone-aware UTC "now" instead.
            event_timestamp = datetime.now(timezone.utc)

        params = {
            "memoryId": memory_id,
            "actorId": actor_id,
            "sessionId": session_id,
            "eventTimestamp": event_timestamp,
            "payload": payload,
        }

        if branch:
            params["branch"] = branch

        response = self.gmdp_client.create_event(**params)

        event = response["event"]
        logger.info("Created event: %s", event["eventId"])

        return event

    except ClientError as e:
        logger.error("Failed to create event: %s", e)
        raise
344
+
345
def save_conversation(
    self,
    memory_id: str,
    actor_id: str,
    session_id: str,
    messages: List[Tuple[str, str]],
    event_timestamp: Optional[datetime] = None,
    branch: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """DEPRECATED: Use create_event() instead.

    Emits a DeprecationWarning on every call, consistent with the other
    deprecated helpers (save_turn, process_turn). Unlike create_event, this
    method also sends a freshly generated ``clientToken`` with the request.

    Args:
        memory_id: Memory resource ID
        actor_id: Actor identifier
        session_id: Session identifier
        messages: List of (text, role) tuples. Role can be USER, ASSISTANT, TOOL, etc.
        event_timestamp: Optional timestamp for the entire event (not per message).
            Defaults to the current time as a timezone-aware UTC datetime.
        branch: Optional branch info. For new branches: {"rootEventId": "...", "name": "..."}
            For continuing existing branch: {"name": "..."} or {"name": "...", "rootEventId": "..."}

    Returns:
        Created event

    Raises:
        ValueError: If messages is empty, a message is not a (text, role) pair,
            or a role is not a valid MessageRole.
        ClientError: If the create_event API call fails (logged, then re-raised).

    Example:
        # Save multi-turn conversation
        event = client.save_conversation(
            memory_id="mem-xyz",
            actor_id="user-123",
            session_id="session-456",
            messages=[
                ("What's the weather?", "USER"),
                ("And tomorrow?", "USER"),
                ("Checking weather...", "TOOL"),
                ("Today sunny, tomorrow rain", "ASSISTANT")
            ]
        )

        # Continue existing branch (only name required)
        event = client.save_conversation(
            memory_id="mem-xyz",
            actor_id="user-123",
            session_id="session-456",
            messages=[("Continue conversation", "USER")],
            branch={"name": "existing-branch"}
        )
    """
    from datetime import timezone

    # The docstring has always declared this method deprecated; make that
    # visible at runtime like the sibling deprecated methods do.
    warnings.warn(
        "save_conversation() is deprecated. Use create_event() instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    try:
        if not messages:
            raise ValueError("At least one message is required")

        # Build payload
        payload = []

        for msg in messages:
            if len(msg) != 2:
                raise ValueError("Each message must be (text, role)")

            text, role = msg

            # Validate role
            try:
                role_enum = MessageRole(role.upper())
            except ValueError as err:
                raise ValueError(
                    "Invalid role '%s'. Must be one of: %s" % (role, ", ".join([r.value for r in MessageRole]))
                ) from err

            payload.append({"conversational": {"content": {"text": text}, "role": role_enum.value}})

        if event_timestamp is None:
            # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
            # value; use an explicit timezone-aware UTC "now" instead.
            event_timestamp = datetime.now(timezone.utc)

        params = {
            "memoryId": memory_id,
            "actorId": actor_id,
            "sessionId": session_id,
            "eventTimestamp": event_timestamp,
            "payload": payload,
            "clientToken": str(uuid.uuid4()),
        }

        if branch:
            params["branch"] = branch

        response = self.gmdp_client.create_event(**params)

        event = response["event"]
        logger.info("Created event: %s", event["eventId"])

        return event

    except ClientError as e:
        logger.error("Failed to create event: %s", e)
        raise
440
+
441
def save_turn(
    self,
    memory_id: str,
    actor_id: str,
    session_id: str,
    user_input: str,
    agent_response: str,
    event_timestamp: Optional[datetime] = None,
) -> Dict[str, Any]:
    """DEPRECATED: store a single user/assistant exchange as one event.

    Emits a DeprecationWarning (scheduled for removal in v1.0.0), then records
    ``user_input`` as a USER message followed by ``agent_response`` as an
    ASSISTANT message via ``create_event``. New code should call
    ``create_event`` directly.

    Returns:
        The event returned by ``create_event``.
    """
    deprecation_message = (
        "save_turn() is deprecated and will be removed in v1.0.0. "
        "Use save_conversation() for flexible message handling."
    )
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)

    return self.create_event(
        memory_id=memory_id,
        actor_id=actor_id,
        session_id=session_id,
        messages=[(user_input, "USER"), (agent_response, "ASSISTANT")],
        event_timestamp=event_timestamp,
    )
470
+
471
def process_turn(
    self,
    memory_id: str,
    actor_id: str,
    session_id: str,
    user_input: str,
    agent_response: str,
    event_timestamp: Optional[datetime] = None,
    retrieval_namespace: Optional[str] = None,
    retrieval_query: Optional[str] = None,
    top_k: int = 3,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """DEPRECATED: optionally retrieve memories, then store the turn.

    Emits a DeprecationWarning (scheduled for removal in v1.0.0). When
    ``retrieval_namespace`` is truthy, memories are searched with
    ``retrieval_query`` (falling back to ``user_input``); the exchange is then
    stored through ``save_turn``.

    Returns:
        Tuple of (retrieved_memories, created_event); memories is ``[]`` when
        no namespace was given.
    """
    warnings.warn(
        "process_turn() is deprecated and will be removed in v1.0.0. "
        "Use retrieve_memories() and save_conversation() separately, or use process_turn_with_llm().",
        DeprecationWarning,
        stacklevel=2,
    )

    # Retrieval is skipped entirely when no namespace is given.
    memories: List[Dict[str, Any]] = (
        self.retrieve_memories(
            memory_id=memory_id,
            namespace=retrieval_namespace,
            query=retrieval_query or user_input,
            top_k=top_k,
        )
        if retrieval_namespace
        else []
    )

    saved_event = self.save_turn(
        memory_id=memory_id,
        actor_id=actor_id,
        session_id=session_id,
        user_input=user_input,
        agent_response=agent_response,
        event_timestamp=event_timestamp,
    )
    return memories, saved_event
512
+
513
def process_turn_with_llm(
    self,
    memory_id: str,
    actor_id: str,
    session_id: str,
    user_input: str,
    llm_callback: Callable[[str, List[Dict[str, Any]]], str],
    retrieval_namespace: Optional[str] = None,
    retrieval_query: Optional[str] = None,
    top_k: int = 3,
    event_timestamp: Optional[datetime] = None,
) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]:
    """Run one conversation turn: retrieve memories, ask the LLM, store the exchange.

    Args:
        memory_id: Memory resource ID
        actor_id: Actor identifier (e.g., "user-123")
        session_id: Session identifier
        user_input: The user's message
        llm_callback: Called as ``llm_callback(user_input, memories)``; must
            return the agent's reply as a ``str``
        retrieval_namespace: Namespace to search; retrieval is skipped when falsy
        retrieval_query: Search query override (defaults to ``user_input``)
        top_k: Number of memories to retrieve
        event_timestamp: Optional timestamp for the stored event

    Returns:
        Tuple of (retrieved_memories, agent_response, created_event)

    Raises:
        ValueError: If the callback returns something other than a ``str``.
        Exception: Anything raised by the callback is logged and re-raised.

    Example:
        def my_llm(user_input, memories):
            context = " ".join(m["content"]["text"] for m in memories)
            return call_your_model(context, user_input)

        memories, response, event = client.process_turn_with_llm(
            memory_id="mem-xyz",
            actor_id="user-123",
            session_id="session-456",
            user_input="What did we discuss yesterday?",
            llm_callback=my_llm,
            retrieval_namespace="support/facts/session-456",
        )
    """
    # Retrieval is optional: only search when a namespace is supplied.
    retrieved_memories: List[Dict[str, Any]] = []
    if retrieval_namespace:
        retrieved_memories = self.retrieve_memories(
            memory_id=memory_id,
            namespace=retrieval_namespace,
            query=retrieval_query or user_input,
            top_k=top_k,
        )
        logger.info("Retrieved %d memories for LLM context", len(retrieved_memories))

    # Any callback failure (including a non-string reply) is logged, then propagated.
    try:
        agent_response = llm_callback(user_input, retrieved_memories)
        if not isinstance(agent_response, str):
            raise ValueError("LLM callback must return a string response")
    except Exception as e:
        logger.error("LLM callback failed: %s", e)
        raise
    else:
        logger.info("LLM callback generated response")

    turn_messages = [(user_input, "USER"), (agent_response, "ASSISTANT")]
    event = self.create_event(
        memory_id=memory_id,
        actor_id=actor_id,
        session_id=session_id,
        messages=turn_messages,
        event_timestamp=event_timestamp,
    )

    logger.info("Completed full conversation turn with LLM")
    return retrieved_memories, agent_response, event
599
+
600
def list_events(
    self,
    memory_id: str,
    actor_id: str,
    session_id: str,
    branch_name: Optional[str] = None,
    include_parent_events: bool = False,
    max_results: int = 100,
    include_payload: bool = True,
) -> List[Dict[str, Any]]:
    """List all events in a session with pagination support.

    This method provides direct access to the raw events API, allowing developers
    to retrieve all events without the turn grouping logic of get_last_k_turns.

    Args:
        memory_id: Memory resource ID
        actor_id: Actor identifier
        session_id: Session identifier
        branch_name: Optional branch name to filter events (None for all branches).
            "main" returns only events without branch info, filtered client-side.
        include_parent_events: Whether to include parent branch events (only applies
            with a named, non-main branch)
        max_results: Maximum number of events to return
        include_payload: Whether to include event payloads in response
            (sent to the API as ``includePayloads``)

    Returns:
        List of event dictionaries in the order the API returned them, at most
        ``max_results`` long

    Raises:
        ClientError: If a ``list_events`` call fails (logged, then re-raised).

    Example:
        # Get all events
        events = client.list_events(memory_id, actor_id, session_id)

        # Get only main branch events
        main_events = client.list_events(memory_id, actor_id, session_id, branch_name="main")

        # Get events from a specific branch
        branch_events = client.list_events(memory_id, actor_id, session_id, branch_name="test-branch")
    """
    # Main-branch events are the ones without a "branch" entry, so "main" is
    # filtered locally instead of being sent as an API branch filter.
    main_only = branch_name == "main"

    try:
        all_events: List[Dict[str, Any]] = []
        next_token = None

        while len(all_events) < max_results:
            params: Dict[str, Any] = {
                "memoryId": memory_id,
                "actorId": actor_id,
                "sessionId": session_id,
                "maxResults": min(100, max_results - len(all_events)),
                # Previously this parameter was accepted but never sent.
                "includePayloads": include_payload,
            }

            if next_token:
                params["nextToken"] = next_token

            if branch_name and not main_only:
                params["filter"] = {"branch": {"name": branch_name, "includeParentBranches": include_parent_events}}

            response = self.gmdp_client.list_events(**params)

            events = response.get("events", [])
            if main_only:
                events = [event for event in events if not event.get("branch")]
            all_events.extend(events)

            next_token = response.get("nextToken")
            if not next_token:
                break

        logger.info("Retrieved total of %d events", len(all_events))
        return all_events[:max_results]

    except ClientError as e:
        logger.error("Failed to list events: %s", e)
        raise
671
+
672
def list_branches(self, memory_id: str, actor_id: str, session_id: str) -> List[Dict[str, Any]]:
    """Summarize every branch of a session's conversation.

    Pages through ``list_events`` until no ``nextToken`` remains, then groups
    the events: events without ``branch`` info form the "main" branch, and all
    others are grouped by ``branch["name"]``.

    Args:
        memory_id: Memory resource ID
        actor_id: Actor identifier
        session_id: Session identifier

    Returns:
        One summary dict per branch with ``name``, ``rootEventId``,
        ``firstEventId``, ``eventCount`` and ``created`` (timestamp of the first
        event seen for that branch). "main" comes first and is included only
        when it has events; named branches follow in first-seen order.

    Raises:
        ClientError: If a ``list_events`` call fails (logged, then re-raised).
    """
    try:
        base_request = {"memoryId": memory_id, "actorId": actor_id, "sessionId": session_id, "maxResults": 100}
        session_events: List[Dict[str, Any]] = []
        page_token = None
        while True:
            request = dict(base_request)
            if page_token:
                request["nextToken"] = page_token
            page = self.gmdp_client.list_events(**request)
            session_events.extend(page.get("events", []))
            page_token = page.get("nextToken")
            if not page_token:
                break

        main_events: List[Dict[str, Any]] = []
        named: Dict[str, Dict[str, Any]] = {}
        for event in session_events:
            branch_info = event.get("branch")
            if not branch_info:
                main_events.append(event)
                continue
            label = branch_info["name"]
            summary = named.get(label)
            if summary is not None:
                summary["eventCount"] += 1
                continue
            named[label] = {
                "name": label,
                "rootEventId": branch_info.get("rootEventId"),
                "firstEventId": event["eventId"],
                "eventCount": 1,
                "created": event["eventTimestamp"],
            }

        summaries: List[Dict[str, Any]] = []
        if main_events:
            first_main = main_events[0]
            summaries.append(
                {
                    "name": "main",
                    "rootEventId": None,
                    "firstEventId": first_main["eventId"],
                    "eventCount": len(main_events),
                    "created": first_main["eventTimestamp"],
                }
            )
        summaries.extend(named.values())

        logger.info("Found %d branches in session %s", len(summaries), session_id)
        return summaries

    except ClientError as e:
        logger.error("Failed to list branches: %s", e)
        raise
744
+
745
def list_branch_events(
    self,
    memory_id: str,
    actor_id: str,
    session_id: str,
    branch_name: Optional[str] = None,
    include_parent_events: bool = False,
    max_results: int = 100,
) -> List[Dict[str, Any]]:
    """List events in a specific branch.

    This method provides complex filtering and pagination that would require
    significant boilerplate code with raw boto3. It handles:
    - Automatic pagination across multiple API calls
    - Branch filtering with parent event inclusion logic
    - Main branch isolation (events without branch info)

    For the main branch the filter is applied page by page, and pagination
    continues until ``max_results`` main-branch events are collected or the
    session is exhausted. Filtering only after the raw-event cap was reached
    could return far fewer (even zero) main events than exist.

    Args:
        memory_id: Memory resource ID
        actor_id: Actor identifier
        session_id: Session identifier
        branch_name: Branch name (None for main branch)
        include_parent_events: Whether to include events from parent branches
        max_results: Maximum events to return; values <= 0 return ``[]``
            without calling the API

    Returns:
        List of events in the branch, at most ``max_results`` long

    Raises:
        ClientError: If a ``list_events`` call fails (logged, then re-raised).
    """
    main_only = not branch_name

    try:
        events: List[Dict[str, Any]] = []
        next_token = None

        while len(events) < max_results:
            params: Dict[str, Any] = {
                "memoryId": memory_id,
                "actorId": actor_id,
                "sessionId": session_id,
                "maxResults": min(100, max_results - len(events)),
            }
            if next_token:
                params["nextToken"] = next_token

            # Only add filter when we have a specific branch name
            if branch_name:
                params["filter"] = {"branch": {"name": branch_name, "includeParentBranches": include_parent_events}}

            response = self.gmdp_client.list_events(**params)
            page_events = response.get("events", [])
            if main_only:
                # Main branch = events that carry no branch info.
                page_events = [e for e in page_events if not e.get("branch")]
            events.extend(page_events)

            next_token = response.get("nextToken")
            if not next_token:
                break

        events = events[:max_results]
        logger.info("Retrieved %d events from branch '%s'", len(events), branch_name or "main")
        return events

    except ClientError as e:
        logger.error("Failed to list branch events: %s", e)
        raise
807
+
808
def get_conversation_tree(
    self, memory_id: str, actor_id: str, session_id: str, preview_chars: int = 50
) -> Dict[str, Any]:
    """Get a tree structure of the conversation with all branches.

    This method transforms a flat list of events into a hierarchical tree structure,
    providing visualization-ready data that would be complex to build from raw events.
    It handles:
    - Full pagination to get all events
    - Grouping by branches
    - Message summarization
    - Tree structure building

    Args:
        memory_id: Memory resource ID
        actor_id: Actor identifier
        session_id: Session identifier
        preview_chars: Maximum characters of message text kept per message;
            longer text is cut and suffixed with ``"..."`` (default: 50)

    Returns:
        Dictionary representing the conversation tree structure::

            {
                "session_id": session_id,
                "actor_id": actor_id,
                "main_branch": {
                    "events": [summary, ...],  # events without branch info
                    "branches": {name: {"root_event_id": ..., "events": [summary, ...]}},
                },
            }

        Each summary is ``{"eventId", "timestamp", "messages"}``, and each message
        is ``{"role", "text"}``. Every named branch sits directly under
        ``main_branch["branches"]``; ``root_event_id`` comes from the first event
        seen for that branch.

    Raises:
        ClientError: Logged and re-raised if listing events fails.
    """
    try:
        # Get all events - need to handle pagination for complete list
        all_events: List[Dict[str, Any]] = []
        next_token = None
        while True:
            params = {"memoryId": memory_id, "actorId": actor_id, "sessionId": session_id, "maxResults": 100}
            if next_token:
                params["nextToken"] = next_token

            response = self.gmdp_client.list_events(**params)
            all_events.extend(response.get("events", []))

            next_token = response.get("nextToken")
            if not next_token:
                break

        # Build tree structure
        tree: Dict[str, Any] = {
            "session_id": session_id,
            "actor_id": actor_id,
            "main_branch": {"events": [], "branches": {}},
        }
        main_branch = tree["main_branch"]

        # Group events by branch
        for event in all_events:
            event_summary = {"eventId": event["eventId"], "timestamp": event["eventTimestamp"], "messages": []}

            # Extract message summaries; a missing or null payload yields no messages.
            for payload_item in event.get("payload") or []:
                if "conversational" not in payload_item:
                    continue
                conv = payload_item["conversational"]
                text = (conv.get("content") or {}).get("text") or ""
                # Only add the ellipsis when text was actually cut off.
                if len(text) > preview_chars:
                    text = text[:preview_chars] + "..."
                event_summary["messages"].append({"role": conv.get("role"), "text": text})

            branch_info = event.get("branch")
            if branch_info:
                branch = main_branch["branches"].setdefault(
                    branch_info["name"],
                    # Use .get() to handle a missing rootEventId field
                    {"root_event_id": branch_info.get("rootEventId"), "events": []},
                )
                branch["events"].append(event_summary)
            else:
                main_branch["events"].append(event_summary)

        logger.info("Built conversation tree with %d branches", len(main_branch["branches"]))
        return tree

    except ClientError as e:
        logger.error("Failed to build conversation tree: %s", e)
        raise
874
+
875
def merge_branch_context(
    self,
    memory_id: str,
    actor_id: str,
    session_id: str,
    branch_name: str,
    include_parent: bool = True,
    max_results: int = 100,
) -> List[Dict[str, Any]]:
    """Get all messages from a branch for context building.

    Events are fetched with ``list_branch_events``. Each conversational payload
    item becomes a flat message dict, and the result is sorted by event timestamp.

    Args:
        memory_id: Memory resource ID
        actor_id: Actor identifier
        session_id: Session identifier
        branch_name: Branch to get context from
        include_parent: Whether to include parent branch events
        max_results: Maximum events to fetch (default: 100, the previous fixed limit)

    Returns:
        List of all messages in chronological order. Each message has
        ``timestamp``, ``eventId``, ``branch`` (``"main"`` when the event has no
        branch info), ``role`` and ``content`` (the message text, or ``""``).
    """
    events = self.list_branch_events(
        memory_id=memory_id,
        actor_id=actor_id,
        session_id=session_id,
        branch_name=branch_name,
        include_parent_events=include_parent,
        max_results=max_results,
    )

    messages: List[Dict[str, Any]] = []
    for event in events:
        # ``branch`` may be absent or null on main-branch events.
        event_branch = (event.get("branch") or {}).get("name", "main")
        for payload_item in event.get("payload") or []:
            if "conversational" not in payload_item:
                continue
            conv = payload_item["conversational"]
            messages.append(
                {
                    "timestamp": event["eventTimestamp"],
                    "eventId": event["eventId"],
                    "branch": event_branch,
                    "role": conv.get("role"),
                    "content": conv.get("content", {}).get("text", ""),
                }
            )

    # Sort by timestamp (stable, so same-timestamp messages keep event order)
    messages.sort(key=lambda m: m["timestamp"])

    logger.info("Retrieved %d messages from branch '%s'", len(messages), branch_name)
    return messages
920
+
921
def get_last_k_turns(
    self,
    memory_id: str,
    actor_id: str,
    session_id: str,
    k: int = 5,
    branch_name: Optional[str] = None,
    include_branches: bool = False,
    max_results: int = 100,
) -> List[List[Dict[str, Any]]]:
    """Get the last K conversation turns.

    A "turn" typically consists of a user message followed by assistant response(s).
    This method groups messages into logical turns for easier processing. A new
    turn starts at every USER message; any conversational payloads that come
    before the first USER message form a turn of their own.

    Args:
        memory_id: Memory resource ID
        actor_id: Actor identifier
        session_id: Session identifier
        k: Number of most recent turns to return; ``k <= 0`` returns ``[]``
            without querying the service
        branch_name: Branch to read, forwarded to ``list_events``
        include_branches: Accepted for API compatibility; not used by this method
        max_results: Maximum number of events requested from ``list_events``

    Returns:
        List of turns, where each turn is a list of message dictionaries
        (the ``conversational`` payload entries, unchanged)

    Raises:
        ClientError: Logged and re-raised if fetching events fails.
    """
    # Without this guard, k == 0 would hit ``turns[-0:]``, which returns every turn.
    if k <= 0:
        return []

    try:
        # Use the new list_events method
        events = self.list_events(
            memory_id=memory_id,
            actor_id=actor_id,
            session_id=session_id,
            branch_name=branch_name,
            include_parent_events=False,
            max_results=max_results,
        )

        turns: List[List[Dict[str, Any]]] = []
        current_turn: List[Dict[str, Any]] = []

        # NOTE(review): grouping assumes list_events yields events oldest-first — confirm.
        for event in events or []:
            for payload_item in event.get("payload") or []:
                if "conversational" not in payload_item:
                    continue
                message = payload_item["conversational"]

                # Start a new turn when we see a USER message and already have messages
                if message.get("role") == Role.USER.value and current_turn:
                    turns.append(current_turn)
                    current_turn = []

                current_turn.append(message)

        # Don't forget the last turn
        if current_turn:
            turns.append(current_turn)

        # Return the last k turns (all of them when there are fewer than k)
        return turns[-k:]

    except ClientError as e:
        logger.error("Failed to get last K turns: %s", e)
        raise
986
+
987
def fork_conversation(
    self,
    memory_id: str,
    actor_id: str,
    session_id: str,
    root_event_id: str,
    branch_name: str,
    new_messages: List[Tuple[str, str]],
    event_timestamp: Optional[datetime] = None,
) -> Dict[str, Any]:
    """Fork a conversation from a specific event to create a new branch.

    A single event is written through ``self.create_event``. Its ``branch``
    argument is ``{"rootEventId": root_event_id, "name": branch_name}``.

    Args:
        memory_id: Memory resource ID
        actor_id: Actor identifier
        session_id: Session identifier
        root_event_id: Event the new branch forks from
        branch_name: Name of the new branch
        new_messages: Messages for the new event, forwarded unchanged to
            ``create_event`` as ``messages``
        event_timestamp: Optional timestamp forwarded to ``create_event``

    Returns:
        Whatever ``create_event`` returns for the new event.

    Raises:
        ClientError: Logged and re-raised if event creation fails.
    """
    fork_point = {"rootEventId": root_event_id, "name": branch_name}
    try:
        created = self.create_event(
            memory_id=memory_id,
            actor_id=actor_id,
            session_id=session_id,
            messages=new_messages,
            branch=fork_point,
            event_timestamp=event_timestamp,
        )
    except ClientError as err:
        logger.error("Failed to fork conversation: %s", err)
        raise

    logger.info("Created branch '%s' from event %s", branch_name, root_event_id)
    return created
1016
+
1017
def get_memory_strategies(self, memory_id: str) -> List[Dict[str, Any]]:
    """Return every strategy attached to a memory, with field names normalized.

    The list is read from ``strategies`` or, if that key is absent, from
    ``memoryStrategies``. Each entry is returned as a shallow copy. Whenever only
    one spelling of a field is present, the other is added: ``strategyId`` /
    ``memoryStrategyId`` for the ID and ``type`` / ``memoryStrategyType`` for the type.

    Raises:
        ClientError: Logged and re-raised if ``get_memory`` fails.
    """
    # (new field name, old field name) pairs; the present one fills in its twin.
    alias_pairs = (("strategyId", "memoryStrategyId"), ("type", "memoryStrategyType"))

    try:
        response = self.gmcp_client.get_memory(memoryId=memory_id)  # Input uses old field name
        memory = response["memory"]

        # Handle both old and new field names in response
        raw_strategies = memory.get("strategies", memory.get("memoryStrategies", []))

        result = []
        for original in raw_strategies:
            entry = dict(original)
            for new_key, old_key in alias_pairs:
                if new_key in original and old_key not in original:
                    entry[old_key] = original[new_key]
                elif old_key in original and new_key not in original:
                    entry[new_key] = original[old_key]
            result.append(entry)
        return result
    except ClientError as err:
        logger.error("Failed to get memory strategies: %s", err)
        raise
1049
+
1050
def get_memory_status(self, memory_id: str) -> str:
    """Return the ``status`` field of a memory as reported by ``get_memory``.

    Raises:
        ClientError: Logged and re-raised if ``get_memory`` fails.
    """
    try:
        memory = self.gmcp_client.get_memory(memoryId=memory_id)["memory"]  # Input uses old field name
        return memory["status"]
    except ClientError as err:
        logger.error("Failed to get memory status: %s", err)
        raise
1058
+
1059
def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]:
    """List all memories for the account.

    Pages through ``list_memories`` (at most 100 per request) until
    ``max_results`` summaries are collected or no ``nextToken`` is returned.
    Each summary is shallow-copied. If only one of ``id`` / ``memoryId`` is
    present, the other is added.

    Args:
        max_results: Maximum number of memories to return (default: 100)

    Returns:
        Up to ``max_results`` normalized memory summaries.

    Raises:
        ClientError: Logged and re-raised if a request fails.
    """
    page_limit = 100  # API limit per request
    try:
        collected: List[Dict[str, Any]] = []
        request: Dict[str, Any] = {"maxResults": min(max_results, page_limit)}
        while True:
            page = self.gmcp_client.list_memories(**request)
            collected.extend(page.get("memories", []))
            token = page.get("nextToken")
            if not token or len(collected) >= max_results:
                break
            request = {"maxResults": min(max_results - len(collected), page_limit), "nextToken": token}

        normalized: List[Dict[str, Any]] = []
        for summary in collected[:max_results]:
            item = dict(summary)
            # Ensure both field name versions exist
            if "id" in summary and "memoryId" not in summary:
                item["memoryId"] = summary["id"]
            elif "memoryId" in summary and "id" not in summary:
                item["id"] = summary["memoryId"]
            normalized.append(item)
        return normalized
    except ClientError as err:
        logger.error("Failed to list memories: %s", err)
        raise
1093
+
1094
def delete_memory(self, memory_id: str) -> Dict[str, Any]:
    """Request deletion of a memory resource and return the raw response.

    A fresh UUID4 string is sent as ``clientToken`` on every call.

    Raises:
        ClientError: Logged and re-raised if the delete request fails.
    """
    token = str(uuid.uuid4())
    try:
        result = self.gmcp_client.delete_memory(memoryId=memory_id, clientToken=token)  # Input uses old field name
    except ClientError as err:
        logger.error("Failed to delete memory: %s", err)
        raise

    logger.info("Deleted memory: %s", memory_id)
    return result
1105
+
1106
def delete_memory_and_wait(self, memory_id: str, max_wait: int = 300, poll_interval: int = 10) -> Dict[str, Any]:
    """Delete a memory and wait for deletion to complete.

    After ``delete_memory`` is called, ``get_memory`` is polled every
    ``poll_interval`` seconds. Deletion counts as finished once that call fails
    with ``ResourceNotFoundException``.

    Args:
        memory_id: Memory resource ID to delete
        max_wait: Maximum seconds to wait (default: 300)
        poll_interval: Seconds between checks (default: 10)

    Returns:
        The response of the initial ``delete_memory`` call.

    Raises:
        ClientError: Any error code other than ``ResourceNotFoundException``
            while polling (logged and re-raised).
        TimeoutError: If deletion doesn't complete within max_wait
    """
    deletion = self.delete_memory(memory_id)
    logger.info("Initiated deletion of memory %s", memory_id)

    started = time.time()
    while time.time() - started < max_wait:
        waited = int(time.time() - started)
        try:
            # Try to get the memory - if it doesn't exist, deletion is complete
            self.gmcp_client.get_memory(memoryId=memory_id)  # Input uses old field name
        except ClientError as err:
            if err.response["Error"]["Code"] != "ResourceNotFoundException":
                logger.error("Error checking memory status: %s", err)
                raise
            logger.info("Memory %s successfully deleted (took %d seconds)", memory_id, waited)
            return deletion
        else:
            logger.debug("Memory still exists, waiting... (%d seconds elapsed)", waited)

        time.sleep(poll_interval)

    raise TimeoutError("Memory %s was not deleted within %d seconds" % (memory_id, max_wait))
1147
+
1148
def add_semantic_strategy(
    self,
    memory_id: str,
    name: str,
    description: Optional[str] = None,
    namespaces: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Add a semantic memory strategy.

    Note: Configuration is no longer provided for built-in strategies as per API changes.

    Args:
        memory_id: Memory resource ID
        name: Strategy name
        description: Optional description; omitted when empty
        namespaces: Optional namespaces list; omitted when empty

    Returns:
        The result of ``_add_strategy`` for the built strategy.
    """
    config: Dict[str, Any] = {"name": name}
    if description:
        config["description"] = description
    if namespaces:
        config["namespaces"] = namespaces
    return self._add_strategy(memory_id, {StrategyType.SEMANTIC.value: config})
1171
+
1172
def add_semantic_strategy_and_wait(
    self,
    memory_id: str,
    name: str,
    description: Optional[str] = None,
    namespaces: Optional[List[str]] = None,
    max_wait: int = 300,
    poll_interval: int = 10,
) -> Dict[str, Any]:
    """Add a semantic strategy, then block until the memory is ACTIVE again.

    Adding a strategy temporarily puts the memory into CREATING state, which
    prevents subsequent operations. This helper therefore polls via
    ``_wait_for_memory_active`` before returning.

    Args:
        memory_id: Memory resource ID
        name: Strategy name
        description: Optional description
        namespaces: Optional namespaces list
        max_wait: Maximum seconds to wait for ACTIVE
        poll_interval: Seconds between status checks

    Returns:
        The memory returned by ``_wait_for_memory_active``.
    """
    self.add_semantic_strategy(memory_id, name, description, namespaces)
    return self._wait_for_memory_active(memory_id, max_wait, poll_interval)
1191
+
1192
def add_summary_strategy(
    self,
    memory_id: str,
    name: str,
    description: Optional[str] = None,
    namespaces: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Add a summary memory strategy.

    Note: Configuration is no longer provided for built-in strategies as per API changes.

    Args:
        memory_id: Memory resource ID
        name: Strategy name
        description: Optional description; omitted when empty
        namespaces: Optional namespaces list; omitted when empty

    Returns:
        The result of ``_add_strategy`` for the built strategy.
    """
    config: Dict[str, Any] = {"name": name}
    if description:
        config["description"] = description
    if namespaces:
        config["namespaces"] = namespaces
    return self._add_strategy(memory_id, {StrategyType.SUMMARY.value: config})
1215
+
1216
def add_summary_strategy_and_wait(
    self,
    memory_id: str,
    name: str,
    description: Optional[str] = None,
    namespaces: Optional[List[str]] = None,
    max_wait: int = 300,
    poll_interval: int = 10,
) -> Dict[str, Any]:
    """Add a summary strategy, then block until the memory is ACTIVE again.

    Args:
        memory_id: Memory resource ID
        name: Strategy name
        description: Optional description
        namespaces: Optional namespaces list
        max_wait: Maximum seconds to wait for ACTIVE
        poll_interval: Seconds between status checks

    Returns:
        The memory returned by ``_wait_for_memory_active``.
    """
    self.add_summary_strategy(memory_id, name, description, namespaces)
    return self._wait_for_memory_active(memory_id, max_wait, poll_interval)
1228
+
1229
def add_user_preference_strategy(
    self,
    memory_id: str,
    name: str,
    description: Optional[str] = None,
    namespaces: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Add a user preference memory strategy.

    Note: Configuration is no longer provided for built-in strategies as per API changes.

    Args:
        memory_id: Memory resource ID
        name: Strategy name
        description: Optional description; omitted when empty
        namespaces: Optional namespaces list; omitted when empty

    Returns:
        The result of ``_add_strategy`` for the built strategy.
    """
    config: Dict[str, Any] = {"name": name}
    if description:
        config["description"] = description
    if namespaces:
        config["namespaces"] = namespaces
    return self._add_strategy(memory_id, {StrategyType.USER_PREFERENCE.value: config})
1252
+
1253
def add_user_preference_strategy_and_wait(
    self,
    memory_id: str,
    name: str,
    description: Optional[str] = None,
    namespaces: Optional[List[str]] = None,
    max_wait: int = 300,
    poll_interval: int = 10,
) -> Dict[str, Any]:
    """Add a user preference strategy, then block until the memory is ACTIVE again.

    Args:
        memory_id: Memory resource ID
        name: Strategy name
        description: Optional description
        namespaces: Optional namespaces list
        max_wait: Maximum seconds to wait for ACTIVE
        poll_interval: Seconds between status checks

    Returns:
        The memory returned by ``_wait_for_memory_active``.
    """
    self.add_user_preference_strategy(memory_id, name, description, namespaces)
    return self._wait_for_memory_active(memory_id, max_wait, poll_interval)
1265
+
1266
def add_custom_semantic_strategy(
    self,
    memory_id: str,
    name: str,
    extraction_config: Dict[str, Any],
    consolidation_config: Dict[str, Any],
    description: Optional[str] = None,
    namespaces: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Add a custom semantic strategy with prompts.

    Each config's ``prompt`` is sent as ``appendToPrompt`` and its ``modelId``
    as ``modelId``, under ``configuration.semanticOverride``.

    Args:
        memory_id: Memory resource ID
        name: Strategy name
        extraction_config: Extraction configuration with prompt and model:
            {"prompt": "...", "modelId": "..."}
        consolidation_config: Consolidation configuration with prompt and model:
            {"prompt": "...", "modelId": "..."}
        description: Optional description; omitted when empty
        namespaces: Optional namespaces list; omitted when empty

    Returns:
        The result of ``_add_strategy`` for the built strategy.

    Raises:
        KeyError: If either config lacks ``prompt`` or ``modelId``.
    """
    custom_key = StrategyType.CUSTOM.value
    semantic_override = {
        "extraction": {
            "appendToPrompt": extraction_config["prompt"],
            "modelId": extraction_config["modelId"],
        },
        "consolidation": {
            "appendToPrompt": consolidation_config["prompt"],
            "modelId": consolidation_config["modelId"],
        },
    }

    body: Dict[str, Any] = {"name": name, "configuration": {"semanticOverride": semantic_override}}
    if description:
        body["description"] = description
    if namespaces:
        body["namespaces"] = namespaces

    return self._add_strategy(memory_id, {custom_key: body})
1311
+
1312
def add_custom_semantic_strategy_and_wait(
    self,
    memory_id: str,
    name: str,
    extraction_config: Dict[str, Any],
    consolidation_config: Dict[str, Any],
    description: Optional[str] = None,
    namespaces: Optional[List[str]] = None,
    max_wait: int = 300,
    poll_interval: int = 10,
) -> Dict[str, Any]:
    """Add a custom semantic strategy, then block until the memory is ACTIVE again.

    Args:
        memory_id: Memory resource ID
        name: Strategy name
        extraction_config: ``{"prompt": ..., "modelId": ...}`` for extraction
        consolidation_config: ``{"prompt": ..., "modelId": ...}`` for consolidation
        description: Optional description
        namespaces: Optional namespaces list
        max_wait: Maximum seconds to wait for ACTIVE
        poll_interval: Seconds between status checks

    Returns:
        The memory returned by ``_wait_for_memory_active``.
    """
    self.add_custom_semantic_strategy(
        memory_id, name, extraction_config, consolidation_config, description, namespaces
    )
    return self._wait_for_memory_active(memory_id, max_wait, poll_interval)
1328
+
1329
def modify_strategy(
    self,
    memory_id: str,
    strategy_id: str,
    description: Optional[str] = None,
    namespaces: Optional[List[str]] = None,
    configuration: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Modify a strategy with full control over configuration.

    Only arguments that are not ``None`` are included in the modification, so
    an empty string or empty list *is* sent.

    Args:
        memory_id: Memory resource ID
        strategy_id: ID of the strategy to modify
        description: New description, or None to leave unchanged
        namespaces: New namespaces, or None to leave unchanged
        configuration: New configuration, or None to leave unchanged

    Returns:
        The result of ``update_memory_strategies``.
    """
    optional_fields = (
        ("description", description),
        ("namespaces", namespaces),
        ("configuration", configuration),
    )
    modify_config: Dict = {"memoryStrategyId": strategy_id}  # Using old field name for input
    modify_config.update({key: value for key, value in optional_fields if value is not None})

    return self.update_memory_strategies(memory_id=memory_id, modify_strategies=[modify_config])
1348
+
1349
def delete_strategy(self, memory_id: str, strategy_id: str) -> Dict[str, Any]:
    """Delete a strategy from a memory via ``update_memory_strategies``.

    Args:
        memory_id: Memory resource ID
        strategy_id: ID of the strategy to remove

    Returns:
        The result of ``update_memory_strategies``.
    """
    doomed = [strategy_id]
    return self.update_memory_strategies(memory_id=memory_id, delete_strategy_ids=doomed)
1352
+
1353
def update_memory_strategies(
    self,
    memory_id: str,
    add_strategies: Optional[List[Dict[str, Any]]] = None,
    modify_strategies: Optional[List[Dict[str, Any]]] = None,
    delete_strategy_ids: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Update memory strategies - add, modify, or delete.

    All requested operations are sent in a single ``update_memory`` call.

    Args:
        memory_id: Memory resource ID
        add_strategies: Strategy definitions to add; each is passed through
            ``_add_default_namespaces`` first
        modify_strategies: Modifications; each must carry ``memoryStrategyId``
            naming a strategy that already exists on the memory. A
            ``configuration`` entry is rewritten with ``_wrap_configuration``,
            using the existing strategy's type and configuration ``type``.
        delete_strategy_ids: IDs of strategies to remove

    Returns:
        The updated memory, normalized by ``_normalize_memory_response``.

    Raises:
        ValueError: No operation was requested, or a modification lacks
            ``memoryStrategyId`` or names an unknown strategy.
        ClientError: Logged and re-raised from the service calls.
    """
    try:
        memory_strategies: Dict[str, Any] = {}

        if add_strategies:
            processed_add = self._add_default_namespaces(add_strategies)
            memory_strategies["addMemoryStrategies"] = processed_add  # Using old field name for input

        if modify_strategies:
            current_strategies = self.get_memory_strategies(memory_id)
            strategy_map = {s["memoryStrategyId"]: s for s in current_strategies}  # Using normalized field

            modify_list = []
            for strategy in modify_strategies:
                if "memoryStrategyId" not in strategy:  # Using old field name
                    raise ValueError("Each modify strategy must include memoryStrategyId")

                strategy_id = strategy["memoryStrategyId"]  # Using old field name
                strategy_info = strategy_map.get(strategy_id)
                if not strategy_info:
                    raise ValueError("Strategy %s not found in memory %s" % (strategy_id, memory_id))

                strategy_type = strategy_info["memoryStrategyType"]  # Using normalized field
                # ``configuration`` may be absent or None on the existing strategy;
                # treat both as "no override type" instead of crashing on None.get().
                override_type = (strategy_info.get("configuration") or {}).get("type")

                # Deep copy so the caller's modification dicts are never mutated.
                strategy_copy = copy.deepcopy(strategy)
                if "configuration" in strategy_copy:
                    strategy_copy["configuration"] = self._wrap_configuration(
                        strategy_copy["configuration"], strategy_type, override_type
                    )

                modify_list.append(strategy_copy)

            memory_strategies["modifyMemoryStrategies"] = modify_list  # Using old field name for input

        if delete_strategy_ids:
            delete_list = [{"memoryStrategyId": sid} for sid in delete_strategy_ids]  # Using old field name
            memory_strategies["deleteMemoryStrategies"] = delete_list  # Using old field name for input

        if not memory_strategies:
            raise ValueError("No strategy operations provided")

        response = self.gmcp_client.update_memory(
            memoryId=memory_id,
            memoryStrategies=memory_strategies,
            clientToken=str(uuid.uuid4()),  # Using old field names for input
        )

        logger.info("Updated memory strategies for: %s", memory_id)
        return self._normalize_memory_response(response["memory"])

    except ClientError as e:
        logger.error("Failed to update memory strategies: %s", e)
        raise
1418
+
1419
def update_memory_strategies_and_wait(
    self,
    memory_id: str,
    add_strategies: Optional[List[Dict[str, Any]]] = None,
    modify_strategies: Optional[List[Dict[str, Any]]] = None,
    delete_strategy_ids: Optional[List[str]] = None,
    max_wait: int = 300,
    poll_interval: int = 10,
) -> Dict[str, Any]:
    """Apply strategy changes, then block until the memory is ACTIVE again.

    Updating strategies leaves the memory temporarily in CREATING state, which
    would make an immediate follow-up update fail. This helper waits it out
    via ``_wait_for_memory_active``.

    Args:
        memory_id: Memory resource ID
        add_strategies: Strategies to add
        modify_strategies: Strategy modifications
        delete_strategy_ids: Strategy IDs to delete
        max_wait: Maximum seconds to wait for ACTIVE
        poll_interval: Seconds between status checks

    Returns:
        The memory returned by ``_wait_for_memory_active``.
    """
    self.update_memory_strategies(memory_id, add_strategies, modify_strategies, delete_strategy_ids)
    return self._wait_for_memory_active(memory_id, max_wait, poll_interval)
1438
+
1439
def wait_for_memories(
    self, memory_id: str, namespace: str, test_query: str = "test", max_wait: int = 180, poll_interval: int = 15
) -> bool:
    """Wait for memory extraction to complete by polling.

    IMPORTANT LIMITATIONS:
    1. This method only works reliably on empty namespaces. If the namespace
       already holds memories, it may return True immediately even if new
       extractions haven't completed.
    2. Wildcards (*) are NOT supported in namespaces. You must provide the exact
       namespace path with all variables resolved (e.g., "support/facts/session-123",
       not "support/facts/*").

    For subsequent extractions in populated namespaces, use a fixed wait time:
        time.sleep(150)  # Wait 2.5 minutes for extraction

    Args:
        memory_id: Memory resource ID
        namespace: Exact namespace to check (no wildcards)
        test_query: Query to test with (default: "test")
        max_wait: Maximum seconds to wait (default: 180)
        poll_interval: Seconds between checks (default: 15)

    Returns:
        True as soon as ``retrieve_memories`` returns a non-empty result;
        False on timeout or when the namespace contains ``*``.

    Note:
        Retrieval errors are deliberately swallowed and only logged (debug);
        errors whose text mentions ``ServiceException`` are counted, and the
        count resets after any successful retrieval.
        This method will be deprecated in future versions once the API
        provides extraction status or timestamps.
    """
    if "*" in namespace:
        logger.error("Wildcards are not supported in namespaces. Please provide exact namespace.")
        return False

    logger.warning(
        "wait_for_memories() only works reliably on empty namespaces. "
        "For populated namespaces, consider using a fixed wait time instead."
    )
    logger.info("Waiting for memory extraction in namespace: %s", namespace)

    began = time.time()
    streak = 0  # consecutive ServiceException failures

    while time.time() - began < max_wait:
        waited = int(time.time() - began)

        try:
            found = self.retrieve_memories(memory_id=memory_id, namespace=namespace, query=test_query, top_k=1)
            if found:
                logger.info("Memory extraction complete after %d seconds", waited)
                return True
            streak = 0
        except Exception as exc:  # best-effort polling: never let a retrieval error escape
            if "ServiceException" in str(exc):
                streak += 1
                if streak >= 3:
                    logger.warning("Multiple service errors - the service may be experiencing issues")
            logger.debug("Retrieval attempt failed: %s", exc)

        # Skip the final sleep once the deadline has passed.
        if time.time() - began < max_wait:
            time.sleep(poll_interval)

    logger.warning("No memories found after %d seconds", max_wait)
    if streak > 0:
        logger.info("Note: Encountered %d service errors during polling", streak)
    return False
1509
+
1510
def add_strategy(self, memory_id: str, strategy: Dict[str, Any]) -> Dict[str, Any]:
    """Add a strategy to a memory (without waiting).

    WARNING: After adding a strategy, the memory enters CREATING state temporarily.
    Use add_*_strategy_and_wait() methods instead to avoid errors. Every call
    emits a ``UserWarning`` attributed to the caller (``stacklevel=2``).

    Args:
        memory_id: Memory resource ID
        strategy: Strategy configuration dictionary

    Returns:
        Updated memory response
    """
    advice = (
        "add_strategy() may leave memory in CREATING state. "
        "Use add_*_strategy_and_wait() methods to avoid subsequent errors."
    )
    warnings.warn(advice, UserWarning, stacklevel=2)
    return self._add_strategy(memory_id, strategy)
1530
+
1531
+ # Private methods
1532
+
1533
def _normalize_memory_response(self, memory: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize memory response to include both old and new field names.

    The API returns new field names but SDK users might expect old ones.
    ``memory`` is modified in place and returned. Both ``id`` / ``memoryId`` and
    both ``strategies`` / ``memoryStrategies`` are made present when one of each
    pair exists. Every strategy is replaced by a shallow copy that carries both
    ``strategyId`` / ``memoryStrategyId`` and ``type`` / ``memoryStrategyType``.
    ``strategies`` and ``memoryStrategies`` then reference the same new list.
    """

    def _mirror(target: Dict[str, Any], source: Dict[str, Any], first: str, second: str) -> None:
        # Fill whichever key of the pair is missing from ``target`` using ``source``.
        if first in source and second not in target:
            target[second] = source[first]
        elif second in source and first not in target:
            target[first] = source[second]

    _mirror(memory, memory, "id", "memoryId")
    _mirror(memory, memory, "strategies", "memoryStrategies")

    if "strategies" in memory:
        rebuilt = []
        for item in memory["strategies"]:
            clone = dict(item)
            _mirror(clone, item, "strategyId", "memoryStrategyId")
            _mirror(clone, item, "type", "memoryStrategyType")
            rebuilt.append(clone)
        memory["strategies"] = memory["memoryStrategies"] = rebuilt

    return memory
1574
+
1575
def _add_strategy(self, memory_id: str, strategy: Dict[str, Any]) -> Dict[str, Any]:
    """Add exactly one strategy through ``update_memory_strategies`` (no waiting)."""
    additions = [strategy]
    return self.update_memory_strategies(memory_id=memory_id, add_strategies=additions)
1578
+
1579
def _wait_for_memory_active(self, memory_id: str, max_wait: int, poll_interval: int) -> Dict[str, Any]:
    """Poll until a memory is ACTIVE again, e.g. after a strategy update.

    The status is read via ``get_memory_status`` every ``poll_interval``
    seconds. ACTIVE returns the freshly fetched, normalized memory. FAILED
    raises with the memory's ``failureReason`` (``"Unknown"`` if absent). Any
    other status keeps polling.

    Args:
        memory_id: Memory resource ID
        max_wait: Maximum seconds to keep polling
        poll_interval: Seconds to sleep between status checks

    Returns:
        The memory as returned by ``_normalize_memory_response``.

    Raises:
        RuntimeError: If the memory reaches the FAILED status.
        TimeoutError: If it is not ACTIVE within ``max_wait`` seconds.
        ClientError: Logged and re-raised from any service call.
    """
    logger.info("Waiting for memory %s to return to ACTIVE state...", memory_id)

    began = time.time()
    while time.time() - began < max_wait:
        waited = int(time.time() - began)
        try:
            current = self.get_memory_status(memory_id)
            if current == MemoryStatus.ACTIVE.value:
                logger.info("Memory %s is ACTIVE again (took %d seconds)", memory_id, waited)
                latest = self.gmcp_client.get_memory(memoryId=memory_id)  # Input uses old field name
                return self._normalize_memory_response(latest["memory"])
            if current == MemoryStatus.FAILED.value:
                latest = self.gmcp_client.get_memory(memoryId=memory_id)  # Input uses old field name
                reason = latest["memory"].get("failureReason", "Unknown")
                raise RuntimeError("Memory update failed: %s" % reason)
            logger.debug("Memory status: %s (%d seconds elapsed)", current, waited)
        except ClientError as err:
            logger.error("Error checking memory status: %s", err)
            raise

        time.sleep(poll_interval)

    raise TimeoutError("Memory %s did not return to ACTIVE state within %d seconds" % (memory_id, max_wait))
1609
+
1610
def _add_default_namespaces(self, strategies: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return deep copies of ``strategies`` with default namespaces filled in.

    Only the first key of each strategy dict is used as its type key. When that
    entry's config lacks ``namespaces``, the default for the matching
    ``StrategyType`` is taken from ``DEFAULT_NAMESPACES``, falling back to
    ``["custom/{actorId}/{sessionId}"]``. Every copy is then passed to
    ``_validate_strategy_config``. The input list and its dicts are left untouched.

    Raises:
        IndexError: If a strategy dict is empty.
        ValueError: presumably, when the type key is not a ``StrategyType`` value
            (standard Enum behavior) — confirm against the enum definition.
    """

    def _prepare(original: Dict[str, Any]) -> Dict[str, Any]:
        clone = copy.deepcopy(original)
        type_key = list(original)[0]
        settings = clone[type_key]
        if "namespaces" not in settings:
            settings["namespaces"] = DEFAULT_NAMESPACES.get(StrategyType(type_key), ["custom/{actorId}/{sessionId}"])
        self._validate_strategy_config(clone, type_key)
        return clone

    return [_prepare(item) for item in strategies]
1629
+
1630
+ def _validate_namespace(self, namespace: str) -> bool:
1631
+ """Validate namespace format - basic check only."""
1632
+ # Only check for template variables in namespace definition
1633
+ # Note: Using memoryStrategyId (old name) as it's still used in input parameters
1634
+ if "{" in namespace and not (
1635
+ "{actorId}" in namespace or "{sessionId}" in namespace or "{memoryStrategyId}" in namespace
1636
+ ):
1637
+ logger.warning("Namespace with templates should contain valid variables: %s", namespace)
1638
+
1639
+ return True
1640
+
1641
+ def _validate_strategy_config(self, strategy: Dict[str, Any], strategy_type: str) -> None:
1642
+ """Validate strategy configuration parameters."""
1643
+ strategy_config = strategy[strategy_type]
1644
+
1645
+ namespaces = strategy_config.get("namespaces", [])
1646
+ for namespace in namespaces:
1647
+ self._validate_namespace(namespace)
1648
+
1649
+ def _wrap_configuration(
1650
+ self, config: Dict[str, Any], strategy_type: str, override_type: Optional[str] = None
1651
+ ) -> Dict[str, Any]:
1652
+ """Wrap configuration based on strategy type."""
1653
+ wrapped_config = {}
1654
+
1655
+ if "extraction" in config:
1656
+ extraction = config["extraction"]
1657
+
1658
+ if any(key in extraction for key in ["triggerEveryNMessages", "historicalContextWindowSize"]):
1659
+ strategy_type_enum = MemoryStrategyTypeEnum(strategy_type)
1660
+
1661
+ if strategy_type == "SEMANTIC":
1662
+ wrapped_config["extraction"] = {EXTRACTION_WRAPPER_KEYS[strategy_type_enum]: extraction}
1663
+ elif strategy_type == "USER_PREFERENCE":
1664
+ wrapped_config["extraction"] = {EXTRACTION_WRAPPER_KEYS[strategy_type_enum]: extraction}
1665
+ elif strategy_type == "CUSTOM" and override_type:
1666
+ override_enum = OverrideType(override_type)
1667
+ if override_type in ["SEMANTIC_OVERRIDE", "USER_PREFERENCE_OVERRIDE"]:
1668
+ wrapped_config["extraction"] = {
1669
+ "customExtractionConfiguration": {CUSTOM_EXTRACTION_WRAPPER_KEYS[override_enum]: extraction}
1670
+ }
1671
+ else:
1672
+ wrapped_config["extraction"] = extraction
1673
+
1674
+ if "consolidation" in config:
1675
+ consolidation = config["consolidation"]
1676
+
1677
+ raw_keys = ["triggerEveryNMessages", "appendToPrompt", "modelId"]
1678
+ if any(key in consolidation for key in raw_keys):
1679
+ if strategy_type == "SUMMARIZATION":
1680
+ if "triggerEveryNMessages" in consolidation:
1681
+ wrapped_config["consolidation"] = {
1682
+ "summaryConsolidationConfiguration": {
1683
+ "triggerEveryNMessages": consolidation["triggerEveryNMessages"]
1684
+ }
1685
+ }
1686
+ elif strategy_type == "CUSTOM" and override_type:
1687
+ override_enum = OverrideType(override_type)
1688
+ if override_enum in CUSTOM_CONSOLIDATION_WRAPPER_KEYS:
1689
+ wrapped_config["consolidation"] = {
1690
+ "customConsolidationConfiguration": {
1691
+ CUSTOM_CONSOLIDATION_WRAPPER_KEYS[override_enum]: consolidation
1692
+ }
1693
+ }
1694
+ else:
1695
+ wrapped_config["consolidation"] = consolidation
1696
+
1697
+ return wrapped_config