letta-nightly 0.11.6.dev20250903104037__py3-none-any.whl → 0.11.7.dev20250904104046__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. letta/__init__.py +1 -1
  2. letta/agent.py +10 -14
  3. letta/agents/base_agent.py +18 -0
  4. letta/agents/helpers.py +32 -7
  5. letta/agents/letta_agent.py +953 -762
  6. letta/agents/voice_agent.py +1 -1
  7. letta/client/streaming.py +0 -1
  8. letta/constants.py +11 -8
  9. letta/errors.py +9 -0
  10. letta/functions/function_sets/base.py +77 -69
  11. letta/functions/function_sets/builtin.py +41 -22
  12. letta/functions/function_sets/multi_agent.py +1 -2
  13. letta/functions/schema_generator.py +0 -1
  14. letta/helpers/converters.py +8 -3
  15. letta/helpers/datetime_helpers.py +5 -4
  16. letta/helpers/message_helper.py +1 -2
  17. letta/helpers/pinecone_utils.py +0 -1
  18. letta/helpers/tool_rule_solver.py +10 -0
  19. letta/helpers/tpuf_client.py +848 -0
  20. letta/interface.py +8 -8
  21. letta/interfaces/anthropic_streaming_interface.py +7 -0
  22. letta/interfaces/openai_streaming_interface.py +29 -6
  23. letta/llm_api/anthropic_client.py +188 -18
  24. letta/llm_api/azure_client.py +0 -1
  25. letta/llm_api/bedrock_client.py +1 -2
  26. letta/llm_api/deepseek_client.py +319 -5
  27. letta/llm_api/google_vertex_client.py +75 -17
  28. letta/llm_api/groq_client.py +0 -1
  29. letta/llm_api/helpers.py +2 -2
  30. letta/llm_api/llm_api_tools.py +1 -50
  31. letta/llm_api/llm_client.py +6 -8
  32. letta/llm_api/mistral.py +1 -1
  33. letta/llm_api/openai.py +16 -13
  34. letta/llm_api/openai_client.py +31 -16
  35. letta/llm_api/together_client.py +0 -1
  36. letta/llm_api/xai_client.py +0 -1
  37. letta/local_llm/chat_completion_proxy.py +7 -6
  38. letta/local_llm/settings/settings.py +1 -1
  39. letta/orm/__init__.py +1 -0
  40. letta/orm/agent.py +8 -6
  41. letta/orm/archive.py +9 -1
  42. letta/orm/block.py +3 -4
  43. letta/orm/block_history.py +3 -1
  44. letta/orm/group.py +2 -3
  45. letta/orm/identity.py +1 -2
  46. letta/orm/job.py +1 -2
  47. letta/orm/llm_batch_items.py +1 -2
  48. letta/orm/message.py +8 -4
  49. letta/orm/mixins.py +18 -0
  50. letta/orm/organization.py +2 -0
  51. letta/orm/passage.py +8 -1
  52. letta/orm/passage_tag.py +55 -0
  53. letta/orm/sandbox_config.py +1 -3
  54. letta/orm/step.py +1 -2
  55. letta/orm/tool.py +1 -0
  56. letta/otel/resource.py +2 -2
  57. letta/plugins/plugins.py +1 -1
  58. letta/prompts/prompt_generator.py +10 -2
  59. letta/schemas/agent.py +11 -0
  60. letta/schemas/archive.py +4 -0
  61. letta/schemas/block.py +13 -0
  62. letta/schemas/embedding_config.py +0 -1
  63. letta/schemas/enums.py +24 -7
  64. letta/schemas/group.py +12 -0
  65. letta/schemas/letta_message.py +55 -1
  66. letta/schemas/letta_message_content.py +28 -0
  67. letta/schemas/letta_request.py +21 -4
  68. letta/schemas/letta_stop_reason.py +9 -1
  69. letta/schemas/llm_config.py +24 -8
  70. letta/schemas/mcp.py +0 -3
  71. letta/schemas/memory.py +14 -0
  72. letta/schemas/message.py +245 -141
  73. letta/schemas/openai/chat_completion_request.py +2 -1
  74. letta/schemas/passage.py +1 -0
  75. letta/schemas/providers/bedrock.py +1 -1
  76. letta/schemas/providers/openai.py +2 -2
  77. letta/schemas/tool.py +11 -5
  78. letta/schemas/tool_execution_result.py +0 -1
  79. letta/schemas/tool_rule.py +71 -0
  80. letta/serialize_schemas/marshmallow_agent.py +1 -2
  81. letta/server/rest_api/app.py +3 -3
  82. letta/server/rest_api/auth/index.py +0 -1
  83. letta/server/rest_api/interface.py +3 -11
  84. letta/server/rest_api/redis_stream_manager.py +3 -4
  85. letta/server/rest_api/routers/v1/agents.py +143 -84
  86. letta/server/rest_api/routers/v1/blocks.py +1 -1
  87. letta/server/rest_api/routers/v1/folders.py +1 -1
  88. letta/server/rest_api/routers/v1/groups.py +23 -22
  89. letta/server/rest_api/routers/v1/internal_templates.py +68 -0
  90. letta/server/rest_api/routers/v1/sandbox_configs.py +11 -5
  91. letta/server/rest_api/routers/v1/sources.py +1 -1
  92. letta/server/rest_api/routers/v1/tools.py +167 -15
  93. letta/server/rest_api/streaming_response.py +4 -3
  94. letta/server/rest_api/utils.py +75 -18
  95. letta/server/server.py +24 -35
  96. letta/services/agent_manager.py +359 -45
  97. letta/services/agent_serialization_manager.py +23 -3
  98. letta/services/archive_manager.py +72 -3
  99. letta/services/block_manager.py +1 -2
  100. letta/services/context_window_calculator/token_counter.py +11 -6
  101. letta/services/file_manager.py +1 -3
  102. letta/services/files_agents_manager.py +2 -4
  103. letta/services/group_manager.py +73 -12
  104. letta/services/helpers/agent_manager_helper.py +5 -5
  105. letta/services/identity_manager.py +8 -3
  106. letta/services/job_manager.py +2 -14
  107. letta/services/llm_batch_manager.py +1 -3
  108. letta/services/mcp/base_client.py +1 -2
  109. letta/services/mcp_manager.py +5 -6
  110. letta/services/message_manager.py +536 -15
  111. letta/services/organization_manager.py +1 -2
  112. letta/services/passage_manager.py +287 -12
  113. letta/services/provider_manager.py +1 -3
  114. letta/services/sandbox_config_manager.py +12 -7
  115. letta/services/source_manager.py +1 -2
  116. letta/services/step_manager.py +0 -1
  117. letta/services/summarizer/summarizer.py +4 -2
  118. letta/services/telemetry_manager.py +1 -3
  119. letta/services/tool_executor/builtin_tool_executor.py +136 -316
  120. letta/services/tool_executor/core_tool_executor.py +231 -74
  121. letta/services/tool_executor/files_tool_executor.py +2 -2
  122. letta/services/tool_executor/mcp_tool_executor.py +0 -1
  123. letta/services/tool_executor/multi_agent_tool_executor.py +2 -2
  124. letta/services/tool_executor/sandbox_tool_executor.py +0 -1
  125. letta/services/tool_executor/tool_execution_sandbox.py +2 -3
  126. letta/services/tool_manager.py +181 -64
  127. letta/services/tool_sandbox/modal_deployment_manager.py +2 -2
  128. letta/services/user_manager.py +1 -2
  129. letta/settings.py +5 -3
  130. letta/streaming_interface.py +3 -3
  131. letta/system.py +1 -1
  132. letta/utils.py +0 -1
  133. {letta_nightly-0.11.6.dev20250903104037.dist-info → letta_nightly-0.11.7.dev20250904104046.dist-info}/METADATA +11 -7
  134. {letta_nightly-0.11.6.dev20250903104037.dist-info → letta_nightly-0.11.7.dev20250904104046.dist-info}/RECORD +137 -135
  135. letta/llm_api/deepseek.py +0 -303
  136. {letta_nightly-0.11.6.dev20250903104037.dist-info → letta_nightly-0.11.7.dev20250904104046.dist-info}/WHEEL +0 -0
  137. {letta_nightly-0.11.6.dev20250903104037.dist-info → letta_nightly-0.11.7.dev20250904104046.dist-info}/entry_points.txt +0 -0
  138. {letta_nightly-0.11.6.dev20250903104037.dist-info → letta_nightly-0.11.7.dev20250904104046.dist-info}/licenses/LICENSE +0 -0
letta/helpers/tpuf_client.py (new file)
@@ -0,0 +1,848 @@
+"""Turbopuffer utilities for archival memory storage."""
+
+import logging
+from datetime import datetime, timezone
+from typing import Any, Callable, List, Optional, Tuple
+
+from letta.otel.tracing import trace_method
+from letta.schemas.enums import MessageRole, TagMatchMode
+from letta.schemas.passage import Passage as PydanticPassage
+from letta.settings import settings
+
+logger = logging.getLogger(__name__)
+
+
+def should_use_tpuf() -> bool:
+    return bool(settings.use_tpuf) and bool(settings.tpuf_api_key)
+
+
+def should_use_tpuf_for_messages() -> bool:
+    """Check if Turbopuffer should be used for messages."""
+    return should_use_tpuf() and bool(settings.embed_all_messages)
+
+
+class TurbopufferClient:
+    """Client for managing archival memory with Turbopuffer vector database."""
+
+    def __init__(self, api_key: str = None, region: str = None):
+        """Initialize Turbopuffer client."""
+        self.api_key = api_key or settings.tpuf_api_key
+        self.region = region or settings.tpuf_region
+
+        from letta.services.agent_manager import AgentManager
+        from letta.services.archive_manager import ArchiveManager
+
+        self.archive_manager = ArchiveManager()
+        self.agent_manager = AgentManager()
+
+        if not self.api_key:
+            raise ValueError("Turbopuffer API key not provided")
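
For orientation, a minimal usage sketch (illustrative only, not part of the diff; assumes `use_tpuf`, `tpuf_api_key`, and `tpuf_region` are configured in Letta's settings):

    from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf

    if should_use_tpuf():
        # falls back to settings.tpuf_api_key / settings.tpuf_region when no args are given
        client = TurbopufferClient()
    else:
        client = None  # callers keep using the relational store
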
+
+    @trace_method
+    async def _get_archive_namespace_name(self, archive_id: str) -> str:
+        """Get namespace name for a specific archive."""
+        return await self.archive_manager.get_or_set_vector_db_namespace_async(archive_id)
+
+    @trace_method
+    async def _get_message_namespace_name(self, agent_id: str, organization_id: str) -> str:
+        """Get namespace name for messages (org-scoped).
+
+        Args:
+            agent_id: Agent ID (stored for future sharding)
+            organization_id: Organization ID for namespace generation
+
+        Returns:
+            The org-scoped namespace name for messages
+        """
+        return await self.agent_manager.get_or_set_vector_db_namespace_async(agent_id, organization_id)
+
+    @trace_method
+    async def insert_archival_memories(
+        self,
+        archive_id: str,
+        text_chunks: List[str],
+        embeddings: List[List[float]],
+        passage_ids: List[str],
+        organization_id: str,
+        tags: Optional[List[str]] = None,
+        created_at: Optional[datetime] = None,
+    ) -> List[PydanticPassage]:
+        """Insert passages into Turbopuffer.
+
+        Args:
+            archive_id: ID of the archive
+            text_chunks: List of text chunks to store
+            embeddings: List of embedding vectors corresponding to text chunks
+            passage_ids: List of passage IDs (must match 1:1 with text_chunks)
+            organization_id: Organization ID for the passages
+            tags: Optional list of tags to attach to all passages
+            created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
+
+        Returns:
+            List of PydanticPassage objects that were inserted
+        """
+        from turbopuffer import AsyncTurbopuffer
+
+        namespace_name = await self._get_archive_namespace_name(archive_id)
+
+        # handle timestamp - ensure UTC
+        if created_at is None:
+            timestamp = datetime.now(timezone.utc)
+        else:
+            # ensure the provided timestamp is timezone-aware and in UTC
+            if created_at.tzinfo is None:
+                # assume UTC if no timezone provided
+                timestamp = created_at.replace(tzinfo=timezone.utc)
+            else:
+                # convert to UTC if in different timezone
+                timestamp = created_at.astimezone(timezone.utc)
+
+        # passage_ids must be provided for dual-write consistency
+        if not passage_ids:
+            raise ValueError("passage_ids must be provided for Turbopuffer insertion")
+        if len(passage_ids) != len(text_chunks):
+            raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")
+        if len(passage_ids) != len(embeddings):
+            raise ValueError(f"passage_ids length ({len(passage_ids)}) must match embeddings length ({len(embeddings)})")
+
+        # prepare column-based data for turbopuffer - optimized for batch insert
+        ids = []
+        vectors = []
+        texts = []
+        organization_ids = []
+        archive_ids = []
+        created_ats = []
+        tags_arrays = []  # Store tags as arrays
+        passages = []
+
+        for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
+            passage_id = passage_ids[idx]
+
+            # append to columns
+            ids.append(passage_id)
+            vectors.append(embedding)
+            texts.append(text)
+            organization_ids.append(organization_id)
+            archive_ids.append(archive_id)
+            created_ats.append(timestamp)
+            tags_arrays.append(tags or [])  # Store tags as array
+
+            # Create PydanticPassage object
+            passage = PydanticPassage(
+                id=passage_id,
+                text=text,
+                organization_id=organization_id,
+                archive_id=archive_id,
+                created_at=timestamp,
+                metadata_={},
+                tags=tags or [],  # Include tags in the passage
+                embedding=embedding,
+                embedding_config=None,  # Will be set by caller if needed
+            )
+            passages.append(passage)
+
+        # build column-based upsert data
+        upsert_columns = {
+            "id": ids,
+            "vector": vectors,
+            "text": texts,
+            "organization_id": organization_ids,
+            "archive_id": archive_ids,
+            "created_at": created_ats,
+            "tags": tags_arrays,  # Add tags as array column
+        }
+
+        try:
+            # Use AsyncTurbopuffer as a context manager for proper resource cleanup
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # turbopuffer recommends column-based writes for performance
+                await namespace.write(
+                    upsert_columns=upsert_columns,
+                    distance_metric="cosine_distance",
+                    schema={"text": {"type": "string", "full_text_search": True}},
+                )
+                logger.info(f"Successfully inserted {len(ids)} passages to Turbopuffer for archive {archive_id}")
+                return passages
+
+        except Exception as e:
+            logger.error(f"Failed to insert passages to Turbopuffer: {e}")
+            # check if it's a duplicate ID error
+            if "duplicate" in str(e).lower():
+                logger.error("Duplicate passage IDs detected in batch")
+            raise
+
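
A hypothetical call site for the method above (toy IDs and 2-d vectors standing in for real embeddings; illustrative, not part of the diff):

    async def archive_notes(client: TurbopufferClient):
        return await client.insert_archival_memories(
            archive_id="archive-123",
            text_chunks=["note one", "note two"],
            embeddings=[[0.1, 0.2], [0.3, 0.4]],     # real vectors match the embedding model's dimension
            passage_ids=["passage-1", "passage-2"],  # must be 1:1 with text_chunks and embeddings
            organization_id="org-1",
            tags=["meeting"],
        )
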
+    @trace_method
+    async def insert_messages(
+        self,
+        agent_id: str,
+        message_texts: List[str],
+        embeddings: List[List[float]],
+        message_ids: List[str],
+        organization_id: str,
+        roles: List[MessageRole],
+        created_ats: List[datetime],
+    ) -> bool:
+        """Insert messages into Turbopuffer.
+
+        Args:
+            agent_id: ID of the agent
+            message_texts: List of message text content to store
+            embeddings: List of embedding vectors corresponding to message texts
+            message_ids: List of message IDs (must match 1:1 with message_texts)
+            organization_id: Organization ID for the messages
+            roles: List of message roles corresponding to each message
+            created_ats: List of creation timestamps for each message
+
+        Returns:
+            True if successful
+        """
+        from turbopuffer import AsyncTurbopuffer
+
+        namespace_name = await self._get_message_namespace_name(agent_id, organization_id)
+
+        # validation checks
+        if not message_ids:
+            raise ValueError("message_ids must be provided for Turbopuffer insertion")
+        if len(message_ids) != len(message_texts):
+            raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
+        if len(message_ids) != len(embeddings):
+            raise ValueError(f"message_ids length ({len(message_ids)}) must match embeddings length ({len(embeddings)})")
+        if len(message_ids) != len(roles):
+            raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
+        if len(message_ids) != len(created_ats):
+            raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")
+
+        # prepare column-based data for turbopuffer - optimized for batch insert
+        ids = []
+        vectors = []
+        texts = []
+        organization_ids = []
+        agent_ids = []
+        message_roles = []
+        created_at_timestamps = []
+
+        for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)):
+            message_id = message_ids[idx]
+
+            # ensure the provided timestamp is timezone-aware and in UTC
+            if created_at.tzinfo is None:
+                # assume UTC if no timezone provided
+                timestamp = created_at.replace(tzinfo=timezone.utc)
+            else:
+                # convert to UTC if in different timezone
+                timestamp = created_at.astimezone(timezone.utc)
+
+            # append to columns
+            ids.append(message_id)
+            vectors.append(embedding)
+            texts.append(text)
+            organization_ids.append(organization_id)
+            agent_ids.append(agent_id)
+            message_roles.append(role.value)
+            created_at_timestamps.append(timestamp)
+
+        # build column-based upsert data
+        upsert_columns = {
+            "id": ids,
+            "vector": vectors,
+            "text": texts,
+            "organization_id": organization_ids,
+            "agent_id": agent_ids,
+            "role": message_roles,
+            "created_at": created_at_timestamps,
+        }
+
+        try:
+            # Use AsyncTurbopuffer as a context manager for proper resource cleanup
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # turbopuffer recommends column-based writes for performance
+                await namespace.write(
+                    upsert_columns=upsert_columns,
+                    distance_metric="cosine_distance",
+                    schema={"text": {"type": "string", "full_text_search": True}},
+                )
+                logger.info(f"Successfully inserted {len(ids)} messages to Turbopuffer for agent {agent_id}")
+                return True
+
+        except Exception as e:
+            logger.error(f"Failed to insert messages to Turbopuffer: {e}")
+            # check if it's a duplicate ID error
+            if "duplicate" in str(e).lower():
+                logger.error("Duplicate message IDs detected in batch")
+            raise
+
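
Both insert paths apply the same timestamp rule: naive datetimes are assumed to be UTC, aware ones are converted. A standalone sketch of that rule (stdlib only, illustrative, not part of the diff):

    from datetime import datetime, timedelta, timezone

    def to_utc(dt: datetime) -> datetime:
        # naive -> assume UTC; aware -> convert (same coercion as the insert methods above)
        return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt.astimezone(timezone.utc)

    assert to_utc(datetime(2025, 9, 4, 12, 0)).tzinfo == timezone.utc
    est = timezone(timedelta(hours=-5))
    assert to_utc(datetime(2025, 9, 4, 12, 0, tzinfo=est)).hour == 17  # 12:00-05:00 -> 17:00 UTC
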
+    @trace_method
+    async def _execute_query(
+        self,
+        namespace_name: str,
+        search_mode: str,
+        query_embedding: Optional[List[float]],
+        query_text: Optional[str],
+        top_k: int,
+        include_attributes: List[str],
+        filters: Optional[Any] = None,
+        vector_weight: float = 0.5,
+        fts_weight: float = 0.5,
+    ) -> Any:
+        """Generic query execution for Turbopuffer.
+
+        Args:
+            namespace_name: Turbopuffer namespace to query
+            search_mode: "vector", "fts", "hybrid", or "timestamp"
+            query_embedding: Embedding for vector search
+            query_text: Text for full-text search
+            top_k: Number of results to return
+            include_attributes: Attributes to include in results
+            filters: Turbopuffer filter expression
+            vector_weight: Weight for vector search in hybrid mode
+            fts_weight: Weight for FTS in hybrid mode
+
+        Returns:
+            Raw Turbopuffer query results or multi-query response
+        """
+        from turbopuffer import AsyncTurbopuffer
+        from turbopuffer.types import QueryParam
+
+        # validate inputs based on search mode
+        if search_mode == "vector" and query_embedding is None:
+            raise ValueError("query_embedding is required for vector search mode")
+        if search_mode == "fts" and query_text is None:
+            raise ValueError("query_text is required for FTS search mode")
+        if search_mode == "hybrid":
+            if query_embedding is None or query_text is None:
+                raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
+        if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
+            raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")
+
+        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+            namespace = client.namespace(namespace_name)
+
+            if search_mode == "timestamp":
+                # retrieve most recent items by timestamp
+                query_params = {
+                    "rank_by": ("created_at", "desc"),
+                    "top_k": top_k,
+                    "include_attributes": include_attributes,
+                }
+                if filters:
+                    query_params["filters"] = filters
+                return await namespace.query(**query_params)
+
+            elif search_mode == "vector":
+                # vector search query
+                query_params = {
+                    "rank_by": ("vector", "ANN", query_embedding),
+                    "top_k": top_k,
+                    "include_attributes": include_attributes,
+                }
+                if filters:
+                    query_params["filters"] = filters
+                return await namespace.query(**query_params)
+
+            elif search_mode == "fts":
+                # full-text search query
+                query_params = {
+                    "rank_by": ("text", "BM25", query_text),
+                    "top_k": top_k,
+                    "include_attributes": include_attributes,
+                }
+                if filters:
+                    query_params["filters"] = filters
+                return await namespace.query(**query_params)
+
+            else:  # hybrid mode
+                queries = []
+
+                # vector search query
+                vector_query = {
+                    "rank_by": ("vector", "ANN", query_embedding),
+                    "top_k": top_k,
+                    "include_attributes": include_attributes,
+                }
+                if filters:
+                    vector_query["filters"] = filters
+                queries.append(vector_query)
+
+                # full-text search query
+                fts_query = {
+                    "rank_by": ("text", "BM25", query_text),
+                    "top_k": top_k,
+                    "include_attributes": include_attributes,
+                }
+                if filters:
+                    fts_query["filters"] = filters
+                queries.append(fts_query)
+
+                # execute multi-query
+                return await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
+
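
The three single-query branches differ only in their `rank_by` expression; a condensed sketch of the mapping (illustrative, not part of the diff; `query_embedding` and `query_text` are the method's parameters). Note also the ordering contract the callers below rely on: the vector query is submitted first, so `result.results[0]` holds vector hits and `result.results[1]` holds FTS hits.

    rank_by_for_mode = {
        "timestamp": ("created_at", "desc"),           # newest first
        "vector": ("vector", "ANN", query_embedding),  # approximate nearest neighbors
        "fts": ("text", "BM25", query_text),           # full-text relevance
    }
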
+    @trace_method
+    async def query_passages(
+        self,
+        archive_id: str,
+        query_embedding: Optional[List[float]] = None,
+        query_text: Optional[str] = None,
+        search_mode: str = "vector",  # "vector", "fts", "hybrid"
+        top_k: int = 10,
+        tags: Optional[List[str]] = None,
+        tag_match_mode: TagMatchMode = TagMatchMode.ANY,
+        vector_weight: float = 0.5,
+        fts_weight: float = 0.5,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+    ) -> List[Tuple[PydanticPassage, float]]:
+        """Query passages from Turbopuffer using vector search, full-text search, or hybrid search.
+
+        Args:
+            archive_id: ID of the archive
+            query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes)
+            query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
+            search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
+            top_k: Number of results to return
+            tags: Optional list of tags to filter by
+            tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY
+            vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
+            fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
+            start_date: Optional datetime to filter passages created after this date
+            end_date: Optional datetime to filter passages created before this date
+
+        Returns:
+            List of (passage, score) tuples
+        """
+        # Check if we should fallback to timestamp-based retrieval
+        if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
+            # Fallback to retrieving most recent passages when no search query is provided
+            search_mode = "timestamp"
+
+        namespace_name = await self._get_archive_namespace_name(archive_id)
+
+        # build tag filter conditions
+        tag_filter = None
+        if tags:
+            if tag_match_mode == TagMatchMode.ALL:
+                # For ALL mode, need to check each tag individually with Contains
+                tag_conditions = []
+                for tag in tags:
+                    tag_conditions.append(("tags", "Contains", tag))
+                if len(tag_conditions) == 1:
+                    tag_filter = tag_conditions[0]
+                else:
+                    tag_filter = ("And", tag_conditions)
+            else:  # tag_match_mode == TagMatchMode.ANY
+                # For ANY mode, use ContainsAny to match any of the tags
+                tag_filter = ("tags", "ContainsAny", tags)
+
+        # build date filter conditions
+        date_filters = []
+        if start_date:
+            date_filters.append(("created_at", "Gte", start_date))
+        if end_date:
+            date_filters.append(("created_at", "Lte", end_date))
+
+        # combine all filters
+        all_filters = []
+        if tag_filter:
+            all_filters.append(tag_filter)
+        if date_filters:
+            all_filters.extend(date_filters)
+
+        # create final filter expression
+        final_filter = None
+        if len(all_filters) == 1:
+            final_filter = all_filters[0]
+        elif len(all_filters) > 1:
+            final_filter = ("And", all_filters)
+
+        try:
+            # use generic query executor
+            result = await self._execute_query(
+                namespace_name=namespace_name,
+                search_mode=search_mode,
+                query_embedding=query_embedding,
+                query_text=query_text,
+                top_k=top_k,
+                include_attributes=["text", "organization_id", "archive_id", "created_at", "tags"],
+                filters=final_filter,
+                vector_weight=vector_weight,
+                fts_weight=fts_weight,
+            )
+
+            # process results based on search mode
+            if search_mode == "hybrid":
+                # for hybrid mode, we get a multi-query response
+                vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
+                fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
+                # use RRF and return only (passage, score) for backwards compatibility
+                results_with_metadata = self._reciprocal_rank_fusion(
+                    vector_results=[passage for passage, _ in vector_results],
+                    fts_results=[passage for passage, _ in fts_results],
+                    get_id_func=lambda p: p.id,
+                    vector_weight=vector_weight,
+                    fts_weight=fts_weight,
+                    top_k=top_k,
+                )
+                return [(passage, rrf_score) for passage, rrf_score, metadata in results_with_metadata]
+            else:
+                # for single queries (vector, fts, timestamp)
+                is_fts = search_mode == "fts"
+                return self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
+
+        except Exception as e:
+            logger.error(f"Failed to query passages from Turbopuffer: {e}")
+            raise
+
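
To make the filter construction concrete: tags=["a", "b"] with TagMatchMode.ALL plus a start date combine into the following nested tuple expression (toy values; same shapes as built above, illustrative, not part of the diff):

    from datetime import datetime, timezone

    final_filter = (
        "And",
        [
            ("And", [("tags", "Contains", "a"), ("tags", "Contains", "b")]),
            ("created_at", "Gte", datetime(2025, 1, 1, tzinfo=timezone.utc)),
        ],
    )
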
+    @trace_method
+    async def query_messages(
+        self,
+        agent_id: str,
+        organization_id: str,
+        query_embedding: Optional[List[float]] = None,
+        query_text: Optional[str] = None,
+        search_mode: str = "vector",  # "vector", "fts", "hybrid", "timestamp"
+        top_k: int = 10,
+        roles: Optional[List[MessageRole]] = None,
+        vector_weight: float = 0.5,
+        fts_weight: float = 0.5,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+    ) -> List[Tuple[dict, float, dict]]:
+        """Query messages from Turbopuffer using vector search, full-text search, or hybrid search.
+
+        Args:
+            agent_id: ID of the agent (used for filtering results)
+            organization_id: Organization ID for namespace lookup
+            query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes)
+            query_text: Text query for full-text search (required for "fts" and "hybrid" modes)
+            search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector")
+            top_k: Number of results to return
+            roles: Optional list of message roles to filter by
+            vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
+            fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
+            start_date: Optional datetime to filter messages created after this date
+            end_date: Optional datetime to filter messages created before this date
+
+        Returns:
+            List of (message_dict, score, metadata) tuples where:
+            - message_dict contains id, text, role, created_at
+            - score is the final relevance score
+            - metadata contains individual scores and ranking information
+        """
+        # Check if we should fallback to timestamp-based retrieval
+        if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
+            # Fallback to retrieving most recent messages when no search query is provided
+            search_mode = "timestamp"
+
+        namespace_name = await self._get_message_namespace_name(agent_id, organization_id)
+
+        # build agent_id filter
+        agent_filter = ("agent_id", "Eq", agent_id)
+
+        # build role filter conditions
+        role_filter = None
+        if roles:
+            role_values = [r.value for r in roles]
+            if len(role_values) == 1:
+                role_filter = ("role", "Eq", role_values[0])
+            else:
+                role_filter = ("role", "In", role_values)
+
+        # build date filter conditions
+        date_filters = []
+        if start_date:
+            date_filters.append(("created_at", "Gte", start_date))
+        if end_date:
+            date_filters.append(("created_at", "Lte", end_date))
+
+        # combine all filters
+        all_filters = [agent_filter]  # always include agent_id filter
+        if role_filter:
+            all_filters.append(role_filter)
+        if date_filters:
+            all_filters.extend(date_filters)
+
+        # create final filter expression
+        final_filter = None
+        if len(all_filters) == 1:
+            final_filter = all_filters[0]
+        elif len(all_filters) > 1:
+            final_filter = ("And", all_filters)
+
+        try:
+            # use generic query executor
+            result = await self._execute_query(
+                namespace_name=namespace_name,
+                search_mode=search_mode,
+                query_embedding=query_embedding,
+                query_text=query_text,
+                top_k=top_k,
+                include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
+                filters=final_filter,
+                vector_weight=vector_weight,
+                fts_weight=fts_weight,
+            )
+
+            # process results based on search mode
+            if search_mode == "hybrid":
+                # for hybrid mode, we get a multi-query response
+                vector_results = self._process_message_query_results(result.results[0])
+                fts_results = self._process_message_query_results(result.results[1])
+                # use RRF with lambda to extract ID from dict - returns metadata
+                results_with_metadata = self._reciprocal_rank_fusion(
+                    vector_results=vector_results,
+                    fts_results=fts_results,
+                    get_id_func=lambda msg_dict: msg_dict["id"],
+                    vector_weight=vector_weight,
+                    fts_weight=fts_weight,
+                    top_k=top_k,
+                )
+                # return results with metadata
+                return results_with_metadata
+            else:
+                # for single queries (vector, fts, timestamp)
+                results = self._process_message_query_results(result)
+                # add simple metadata for single search modes
+                results_with_metadata = []
+                for idx, msg_dict in enumerate(results):
+                    metadata = {
+                        "combined_score": 1.0 / (idx + 1),  # Use rank-based score for single mode
+                        "search_mode": search_mode,
+                        f"{search_mode}_rank": idx + 1,  # Add the rank for this search mode
+                    }
+                    results_with_metadata.append((msg_dict, metadata["combined_score"], metadata))
+                return results_with_metadata
+
+        except Exception as e:
+            logger.error(f"Failed to query messages from Turbopuffer: {e}")
+            raise
+
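
A hypothetical hybrid query against the method above (toy embedding and IDs; illustrative, not part of the diff; assumes `MessageRole` has a `user` member, as imported at the top of this file):

    async def search_recent_user_messages(client: TurbopufferClient):
        results = await client.query_messages(
            agent_id="agent-1",
            organization_id="org-1",
            query_embedding=[0.1, 0.2],  # toy; must come from the same model used at insert time
            query_text="deployment checklist",
            search_mode="hybrid",
            top_k=5,
            roles=[MessageRole.user],
        )
        for message_dict, score, metadata in results:
            print(message_dict["id"], score, metadata.get("vector_rank"), metadata.get("fts_rank"))
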
+    def _process_message_query_results(self, result) -> List[dict]:
+        """Process results from a message query into message dicts.
+
+        For RRF, we only need the rank order - scores are not used.
+        """
+        messages = []
+
+        for row in result.rows:
+            # Build message dict with key fields
+            message_dict = {
+                "id": row.id,
+                "text": getattr(row, "text", ""),
+                "organization_id": getattr(row, "organization_id", None),
+                "agent_id": getattr(row, "agent_id", None),
+                "role": getattr(row, "role", None),
+                "created_at": getattr(row, "created_at", None),
+            }
+            messages.append(message_dict)
+
+        return messages
+
+    def _process_single_query_results(
+        self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
+    ) -> List[Tuple[PydanticPassage, float]]:
+        """Process results from a single query into passage objects with scores."""
+        passages_with_scores = []
+
+        for row in result.rows:
+            # Extract tags from the result row
+            passage_tags = getattr(row, "tags", []) or []
+
+            # Build metadata
+            metadata = {}
+
+            # Create a passage with minimal fields - embeddings are not returned from Turbopuffer
+            passage = PydanticPassage(
+                id=row.id,
+                text=getattr(row, "text", ""),
+                organization_id=getattr(row, "organization_id", None),
+                archive_id=archive_id,  # use the archive_id from the query
+                created_at=getattr(row, "created_at", None),
+                metadata_=metadata,
+                tags=passage_tags,  # Set the actual tags from the passage
+                # Set required fields to empty/default values since we don't store embeddings
+                embedding=[],  # Empty embedding since we don't return it from Turbopuffer
+                embedding_config=None,  # No embedding config needed for retrieved passages
+            )
+
+            # handle score based on search type
+            if is_fts:
+                # for FTS, use the BM25 score directly (higher is better)
+                score = getattr(row, "$score", 0.0)
+            else:
+                # for vector search, convert distance to similarity score
+                distance = getattr(row, "$dist", 0.0)
+                score = 1.0 - distance
+
+            passages_with_scores.append((passage, score))
+
+        return passages_with_scores
+
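
The score conversion in concrete terms: vector hits carry a cosine distance, so a `$dist` of 0.12 becomes a similarity of 1.0 - 0.12 = 0.88, while FTS hits carry a BM25 `$score` that is used unchanged (higher is better). As a two-line worked example (illustrative, not part of the diff):

    distance = 0.12              # cosine distance returned as $dist
    similarity = 1.0 - distance  # 0.88; BM25 $score values pass through unchanged
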
+    def _reciprocal_rank_fusion(
+        self,
+        vector_results: List[Any],
+        fts_results: List[Any],
+        get_id_func: Callable[[Any], str],
+        vector_weight: float,
+        fts_weight: float,
+        top_k: int,
+    ) -> List[Tuple[Any, float, dict]]:
+        """RRF implementation that works with any object type.
+
+        RRF score = vector_weight * (1/(k + rank)) + fts_weight * (1/(k + rank))
+        where k is a constant (typically 60) to avoid division by zero
+
+        This is a pure rank-based fusion following the standard RRF algorithm.
+
+        Args:
+            vector_results: List of items from vector search (ordered by relevance)
+            fts_results: List of items from FTS (ordered by relevance)
+            get_id_func: Function to extract ID from an item
+            vector_weight: Weight for vector search results
+            fts_weight: Weight for FTS results
+            top_k: Number of results to return
+
+        Returns:
+            List of (item, score, metadata) tuples sorted by RRF score
+            metadata contains ranks from each result list
+        """
+        k = 60  # standard RRF constant from Cormack et al. (2009)
+
+        # create rank mappings based on position in result lists
+        # rank starts at 1, not 0
+        vector_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(vector_results)}
+        fts_ranks = {get_id_func(item): rank + 1 for rank, item in enumerate(fts_results)}
+
+        # combine all unique items from both result sets
+        all_items = {}
+        for item in vector_results:
+            all_items[get_id_func(item)] = item
+        for item in fts_results:
+            all_items[get_id_func(item)] = item
+
+        # calculate RRF scores based purely on ranks
+        rrf_scores = {}
+        score_metadata = {}
+        for item_id in all_items:
+            # RRF formula: sum of 1/(k + rank) across result lists
+            # If item not in a list, we don't add anything (equivalent to rank = infinity)
+            vector_rrf_score = 0.0
+            fts_rrf_score = 0.0
+
+            if item_id in vector_ranks:
+                vector_rrf_score = vector_weight / (k + vector_ranks[item_id])
+            if item_id in fts_ranks:
+                fts_rrf_score = fts_weight / (k + fts_ranks[item_id])
+
+            combined_score = vector_rrf_score + fts_rrf_score
+
+            rrf_scores[item_id] = combined_score
+            score_metadata[item_id] = {
+                "combined_score": combined_score,  # Final RRF score
+                "vector_rank": vector_ranks.get(item_id),
+                "fts_rank": fts_ranks.get(item_id),
+            }
+
+        # sort by RRF score and return with metadata
+        sorted_results = sorted(
+            [(all_items[iid], score, score_metadata[iid]) for iid, score in rrf_scores.items()], key=lambda x: x[1], reverse=True
+        )
+
+        return sorted_results[:top_k]
+
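
A worked instance of the formula with k = 60 and equal weights (illustrative, not part of the diff):

    k = 60
    vector_weight = fts_weight = 0.5
    # item ranked 1st by vector search and 3rd by FTS:
    both = vector_weight / (k + 1) + fts_weight / (k + 3)  # ~0.00820 + ~0.00794 = ~0.01614
    # item ranked 2nd by vector search and absent from FTS results:
    vector_only = vector_weight / (k + 2)                  # ~0.00806
    assert both > vector_only  # presence in both lists outweighs one slightly better rank
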
+    @trace_method
+    async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
+        """Delete a passage from Turbopuffer."""
+        from turbopuffer import AsyncTurbopuffer
+
+        namespace_name = await self._get_archive_namespace_name(archive_id)
+
+        try:
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # Use write API with deletes parameter as per Turbopuffer docs
+                await namespace.write(deletes=[passage_id])
+                logger.info(f"Successfully deleted passage {passage_id} from Turbopuffer archive {archive_id}")
+                return True
+        except Exception as e:
+            logger.error(f"Failed to delete passage from Turbopuffer: {e}")
+            raise
+
+    @trace_method
+    async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
+        """Delete multiple passages from Turbopuffer."""
+        from turbopuffer import AsyncTurbopuffer
+
+        if not passage_ids:
+            return True
+
+        namespace_name = await self._get_archive_namespace_name(archive_id)
+
+        try:
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # Use write API with deletes parameter as per Turbopuffer docs
+                await namespace.write(deletes=passage_ids)
+                logger.info(f"Successfully deleted {len(passage_ids)} passages from Turbopuffer archive {archive_id}")
+                return True
+        except Exception as e:
+            logger.error(f"Failed to delete passages from Turbopuffer: {e}")
+            raise
+
+    @trace_method
+    async def delete_all_passages(self, archive_id: str) -> bool:
+        """Delete all passages for an archive from Turbopuffer."""
+        from turbopuffer import AsyncTurbopuffer
+
+        namespace_name = await self._get_archive_namespace_name(archive_id)
+
+        try:
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # Turbopuffer has a delete_all() method on namespace
+                await namespace.delete_all()
+                logger.info(f"Successfully deleted all passages for archive {archive_id}")
+                return True
+        except Exception as e:
+            logger.error(f"Failed to delete all passages from Turbopuffer: {e}")
+            raise
+
+    @trace_method
+    async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool:
+        """Delete multiple messages from Turbopuffer."""
+        from turbopuffer import AsyncTurbopuffer
+
+        if not message_ids:
+            return True
+
+        namespace_name = await self._get_message_namespace_name(agent_id, organization_id)
+
+        try:
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # Use write API with deletes parameter as per Turbopuffer docs
+                await namespace.write(deletes=message_ids)
+                logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer for agent {agent_id}")
+                return True
+        except Exception as e:
+            logger.error(f"Failed to delete messages from Turbopuffer: {e}")
+            raise
+
+    @trace_method
+    async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool:
+        """Delete all messages for an agent from Turbopuffer."""
+        from turbopuffer import AsyncTurbopuffer
+
+        namespace_name = await self._get_message_namespace_name(agent_id, organization_id)
+
+        try:
+            async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
+                namespace = client.namespace(namespace_name)
+                # Use delete_by_filter to only delete messages for this agent
+                # since namespace is now org-scoped
+                result = await namespace.write(delete_by_filter=("agent_id", "Eq", agent_id))
+                logger.info(f"Successfully deleted all messages for agent {agent_id} (deleted {result.rows_affected} rows)")
+                return True
+        except Exception as e:
+            logger.error(f"Failed to delete all messages from Turbopuffer: {e}")
+            raise