letta-nightly 0.11.6.dev20250902104140__py3-none-any.whl → 0.11.7.dev20250904045700__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agent.py +10 -14
- letta/agents/base_agent.py +18 -0
- letta/agents/helpers.py +32 -7
- letta/agents/letta_agent.py +953 -762
- letta/agents/voice_agent.py +1 -1
- letta/client/streaming.py +0 -1
- letta/constants.py +11 -8
- letta/errors.py +9 -0
- letta/functions/function_sets/base.py +77 -69
- letta/functions/function_sets/builtin.py +41 -22
- letta/functions/function_sets/multi_agent.py +1 -2
- letta/functions/schema_generator.py +0 -1
- letta/helpers/converters.py +8 -3
- letta/helpers/datetime_helpers.py +5 -4
- letta/helpers/message_helper.py +1 -2
- letta/helpers/pinecone_utils.py +0 -1
- letta/helpers/tool_rule_solver.py +10 -0
- letta/helpers/tpuf_client.py +848 -0
- letta/interface.py +8 -8
- letta/interfaces/anthropic_streaming_interface.py +7 -0
- letta/interfaces/openai_streaming_interface.py +29 -6
- letta/llm_api/anthropic_client.py +188 -18
- letta/llm_api/azure_client.py +0 -1
- letta/llm_api/bedrock_client.py +1 -2
- letta/llm_api/deepseek_client.py +319 -5
- letta/llm_api/google_vertex_client.py +75 -17
- letta/llm_api/groq_client.py +0 -1
- letta/llm_api/helpers.py +2 -2
- letta/llm_api/llm_api_tools.py +1 -50
- letta/llm_api/llm_client.py +6 -8
- letta/llm_api/mistral.py +1 -1
- letta/llm_api/openai.py +16 -13
- letta/llm_api/openai_client.py +31 -16
- letta/llm_api/together_client.py +0 -1
- letta/llm_api/xai_client.py +0 -1
- letta/local_llm/chat_completion_proxy.py +7 -6
- letta/local_llm/settings/settings.py +1 -1
- letta/orm/__init__.py +1 -0
- letta/orm/agent.py +8 -6
- letta/orm/archive.py +9 -1
- letta/orm/block.py +3 -4
- letta/orm/block_history.py +3 -1
- letta/orm/group.py +2 -3
- letta/orm/identity.py +1 -2
- letta/orm/job.py +1 -2
- letta/orm/llm_batch_items.py +1 -2
- letta/orm/message.py +8 -4
- letta/orm/mixins.py +18 -0
- letta/orm/organization.py +2 -0
- letta/orm/passage.py +8 -1
- letta/orm/passage_tag.py +55 -0
- letta/orm/sandbox_config.py +1 -3
- letta/orm/step.py +1 -2
- letta/orm/tool.py +1 -0
- letta/otel/resource.py +2 -2
- letta/plugins/plugins.py +1 -1
- letta/prompts/prompt_generator.py +10 -2
- letta/schemas/agent.py +11 -0
- letta/schemas/archive.py +4 -0
- letta/schemas/block.py +13 -0
- letta/schemas/embedding_config.py +0 -1
- letta/schemas/enums.py +24 -7
- letta/schemas/group.py +12 -0
- letta/schemas/letta_message.py +55 -1
- letta/schemas/letta_message_content.py +28 -0
- letta/schemas/letta_request.py +21 -4
- letta/schemas/letta_stop_reason.py +9 -1
- letta/schemas/llm_config.py +24 -8
- letta/schemas/mcp.py +0 -3
- letta/schemas/memory.py +14 -0
- letta/schemas/message.py +245 -141
- letta/schemas/openai/chat_completion_request.py +2 -1
- letta/schemas/passage.py +1 -0
- letta/schemas/providers/bedrock.py +1 -1
- letta/schemas/providers/openai.py +2 -2
- letta/schemas/tool.py +11 -5
- letta/schemas/tool_execution_result.py +0 -1
- letta/schemas/tool_rule.py +71 -0
- letta/serialize_schemas/marshmallow_agent.py +1 -2
- letta/server/rest_api/app.py +3 -3
- letta/server/rest_api/auth/index.py +0 -1
- letta/server/rest_api/interface.py +3 -11
- letta/server/rest_api/redis_stream_manager.py +3 -4
- letta/server/rest_api/routers/v1/agents.py +143 -84
- letta/server/rest_api/routers/v1/blocks.py +1 -1
- letta/server/rest_api/routers/v1/folders.py +1 -1
- letta/server/rest_api/routers/v1/groups.py +23 -22
- letta/server/rest_api/routers/v1/internal_templates.py +68 -0
- letta/server/rest_api/routers/v1/sandbox_configs.py +11 -5
- letta/server/rest_api/routers/v1/sources.py +1 -1
- letta/server/rest_api/routers/v1/tools.py +167 -15
- letta/server/rest_api/streaming_response.py +4 -3
- letta/server/rest_api/utils.py +75 -18
- letta/server/server.py +24 -35
- letta/services/agent_manager.py +359 -45
- letta/services/agent_serialization_manager.py +23 -3
- letta/services/archive_manager.py +72 -3
- letta/services/block_manager.py +1 -2
- letta/services/context_window_calculator/token_counter.py +11 -6
- letta/services/file_manager.py +1 -3
- letta/services/files_agents_manager.py +2 -4
- letta/services/group_manager.py +73 -12
- letta/services/helpers/agent_manager_helper.py +5 -5
- letta/services/identity_manager.py +8 -3
- letta/services/job_manager.py +2 -14
- letta/services/llm_batch_manager.py +1 -3
- letta/services/mcp/base_client.py +1 -2
- letta/services/mcp_manager.py +5 -6
- letta/services/message_manager.py +536 -15
- letta/services/organization_manager.py +1 -2
- letta/services/passage_manager.py +287 -12
- letta/services/provider_manager.py +1 -3
- letta/services/sandbox_config_manager.py +12 -7
- letta/services/source_manager.py +1 -2
- letta/services/step_manager.py +0 -1
- letta/services/summarizer/summarizer.py +4 -2
- letta/services/telemetry_manager.py +1 -3
- letta/services/tool_executor/builtin_tool_executor.py +136 -316
- letta/services/tool_executor/core_tool_executor.py +231 -74
- letta/services/tool_executor/files_tool_executor.py +2 -2
- letta/services/tool_executor/mcp_tool_executor.py +0 -1
- letta/services/tool_executor/multi_agent_tool_executor.py +2 -2
- letta/services/tool_executor/sandbox_tool_executor.py +0 -1
- letta/services/tool_executor/tool_execution_sandbox.py +2 -3
- letta/services/tool_manager.py +181 -64
- letta/services/tool_sandbox/modal_deployment_manager.py +2 -2
- letta/services/user_manager.py +1 -2
- letta/settings.py +5 -3
- letta/streaming_interface.py +3 -3
- letta/system.py +1 -1
- letta/utils.py +0 -1
- {letta_nightly-0.11.6.dev20250902104140.dist-info → letta_nightly-0.11.7.dev20250904045700.dist-info}/METADATA +11 -7
- {letta_nightly-0.11.6.dev20250902104140.dist-info → letta_nightly-0.11.7.dev20250904045700.dist-info}/RECORD +137 -135
- letta/llm_api/deepseek.py +0 -303
- {letta_nightly-0.11.6.dev20250902104140.dist-info → letta_nightly-0.11.7.dev20250904045700.dist-info}/WHEEL +0 -0
- {letta_nightly-0.11.6.dev20250902104140.dist-info → letta_nightly-0.11.7.dev20250904045700.dist-info}/entry_points.txt +0 -0
- {letta_nightly-0.11.6.dev20250902104140.dist-info → letta_nightly-0.11.7.dev20250904045700.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,848 @@
|
|
1
|
+
"""Turbopuffer utilities for archival memory storage."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from datetime import datetime, timezone
|
5
|
+
from typing import Any, Callable, List, Optional, Tuple
|
6
|
+
|
7
|
+
from letta.otel.tracing import trace_method
|
8
|
+
from letta.schemas.enums import MessageRole, TagMatchMode
|
9
|
+
from letta.schemas.passage import Passage as PydanticPassage
|
10
|
+
from letta.settings import settings
|
11
|
+
|
12
|
+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
def should_use_tpuf() -> bool:
    """Report whether Turbopuffer is enabled and has an API key configured.

    Returns:
        True only when both ``settings.use_tpuf`` and ``settings.tpuf_api_key`` are truthy.
    """
    if not settings.use_tpuf:
        return False
    return bool(settings.tpuf_api_key)
|
17
|
+
|
18
|
+
|
19
|
+
def should_use_tpuf_for_messages() -> bool:
    """Report whether messages should be stored in Turbopuffer.

    Returns:
        True only when ``should_use_tpuf()`` is true and
        ``settings.embed_all_messages`` is truthy.
    """
    if not should_use_tpuf():
        return False
    return bool(settings.embed_all_messages)
|
22
|
+
|
23
|
+
|
24
|
+
class TurbopufferClient:
|
25
|
+
"""Client for managing archival memory with Turbopuffer vector database."""
|
26
|
+
|
27
|
+
def __init__(self, api_key: Optional[str] = None, region: Optional[str] = None):
    """Initialize the Turbopuffer client.

    Args:
        api_key: Turbopuffer API key; falls back to ``settings.tpuf_api_key`` when falsy.
        region: Turbopuffer region; falls back to ``settings.tpuf_region`` when falsy.

    Raises:
        ValueError: If no API key was passed and none is configured in settings.
    """
    self.api_key = api_key or settings.tpuf_api_key
    self.region = region or settings.tpuf_region

    # Fail fast: reject a missing key before importing or instantiating the managers.
    if not self.api_key:
        raise ValueError("Turbopuffer API key not provided")

    # NOTE(review): imported here rather than at module top, presumably to avoid
    # a circular import with the service layer; confirm before hoisting.
    from letta.services.agent_manager import AgentManager
    from letta.services.archive_manager import ArchiveManager

    self.archive_manager = ArchiveManager()
    self.agent_manager = AgentManager()
|
40
|
+
|
41
|
+
@trace_method
|
42
|
+
async def _get_archive_namespace_name(self, archive_id: str) -> str:
|
43
|
+
"""Get namespace name for a specific archive."""
|
44
|
+
return await self.archive_manager.get_or_set_vector_db_namespace_async(archive_id)
|
45
|
+
|
46
|
+
@trace_method
|
47
|
+
async def _get_message_namespace_name(self, agent_id: str, organization_id: str) -> str:
|
48
|
+
"""Get namespace name for messages (org-scoped).
|
49
|
+
|
50
|
+
Args:
|
51
|
+
agent_id: Agent ID (stored for future sharding)
|
52
|
+
organization_id: Organization ID for namespace generation
|
53
|
+
|
54
|
+
Returns:
|
55
|
+
The org-scoped namespace name for messages
|
56
|
+
"""
|
57
|
+
return await self.agent_manager.get_or_set_vector_db_namespace_async(agent_id, organization_id)
|
58
|
+
|
59
|
+
@trace_method
|
60
|
+
async def insert_archival_memories(
    self,
    archive_id: str,
    text_chunks: List[str],
    embeddings: List[List[float]],
    passage_ids: List[str],
    organization_id: str,
    tags: Optional[List[str]] = None,
    created_at: Optional[datetime] = None,
) -> List[PydanticPassage]:
    """Insert passages into Turbopuffer.

    Args:
        archive_id: ID of the archive
        text_chunks: List of text chunks to store
        embeddings: List of embedding vectors corresponding to text chunks
        passage_ids: List of passage IDs (must match 1:1 with text_chunks and embeddings)
        organization_id: Organization ID for the passages
        tags: Optional list of tags to attach to all passages
        created_at: Optional timestamp for retroactive entries (defaults to current UTC time);
            a naive datetime is assumed to already be in UTC

    Returns:
        List of PydanticPassage objects that were inserted

    Raises:
        ValueError: If passage_ids is empty or the three input lists differ in length.
    """
    # Validate the batch before the namespace lookup or the turbopuffer import,
    # so a malformed call fails without touching the manager or the SDK.
    # passage_ids must be provided for dual-write consistency.
    if not passage_ids:
        raise ValueError("passage_ids must be provided for Turbopuffer insertion")
    if len(passage_ids) != len(text_chunks):
        raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")
    if len(passage_ids) != len(embeddings):
        raise ValueError(f"passage_ids length ({len(passage_ids)}) must match embeddings length ({len(embeddings)})")

    from turbopuffer import AsyncTurbopuffer

    namespace_name = await self._get_archive_namespace_name(archive_id)

    # Normalize the timestamp to timezone-aware UTC.
    if created_at is None:
        timestamp = datetime.now(timezone.utc)
    elif created_at.tzinfo is None:
        # assume UTC if no timezone provided
        timestamp = created_at.replace(tzinfo=timezone.utc)
    else:
        # convert to UTC if in different timezone
        timestamp = created_at.astimezone(timezone.utc)

    batch_size = len(passage_ids)
    shared_tags = list(tags) if tags else []

    # One PydanticPassage per input row; each gets its own copy of the tag list
    # so later mutation of one passage's tags cannot leak into another.
    passages = [
        PydanticPassage(
            id=passage_id,
            text=text,
            organization_id=organization_id,
            archive_id=archive_id,
            created_at=timestamp,
            metadata_={},
            tags=list(shared_tags),
            embedding=embedding,
            embedding_config=None,  # Will be set by caller if needed
        )
        for passage_id, text, embedding in zip(passage_ids, text_chunks, embeddings)
    ]

    # Column-based upsert payload: every column has exactly batch_size entries.
    upsert_columns = {
        "id": list(passage_ids),
        "vector": list(embeddings),
        "text": list(text_chunks),
        "organization_id": [organization_id] * batch_size,
        "archive_id": [archive_id] * batch_size,
        "created_at": [timestamp] * batch_size,
        "tags": [list(shared_tags) for _ in range(batch_size)],
    }

    try:
        # Use AsyncTurbopuffer as a context manager for proper resource cleanup
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            namespace = client.namespace(namespace_name)
            await namespace.write(
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
            logger.info(f"Successfully inserted {batch_size} passages to Turbopuffer for archive {archive_id}")
            return passages

    except Exception as e:
        logger.error(f"Failed to insert passages to Turbopuffer: {e}")
        # check if it's a duplicate ID error
        if "duplicate" in str(e).lower():
            logger.error("Duplicate passage IDs detected in batch")
        raise
|
174
|
+
|
175
|
+
@trace_method
|
176
|
+
async def insert_messages(
    self,
    agent_id: str,
    message_texts: List[str],
    embeddings: List[List[float]],
    message_ids: List[str],
    organization_id: str,
    roles: List[MessageRole],
    created_ats: List[datetime],
) -> bool:
    """Insert messages into Turbopuffer.

    Args:
        agent_id: ID of the agent
        message_texts: List of message text content to store
        embeddings: List of embedding vectors corresponding to message texts
        message_ids: List of message IDs (must match 1:1 with message_texts)
        organization_id: Organization ID for the messages
        roles: List of message roles corresponding to each message
        created_ats: List of creation timestamps for each message; naive
            datetimes are assumed to already be in UTC

    Returns:
        True if successful

    Raises:
        ValueError: If message_ids is empty or any input list differs in length from it.
    """
    # Validate the batch before the namespace lookup or the turbopuffer import,
    # so a malformed call fails without touching the manager or the SDK.
    if not message_ids:
        raise ValueError("message_ids must be provided for Turbopuffer insertion")
    if len(message_ids) != len(message_texts):
        raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
    if len(message_ids) != len(embeddings):
        raise ValueError(f"message_ids length ({len(message_ids)}) must match embeddings length ({len(embeddings)})")
    if len(message_ids) != len(roles):
        raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
    if len(message_ids) != len(created_ats):
        raise ValueError(f"message_ids length ({len(message_ids)}) must match created_ats length ({len(created_ats)})")

    from turbopuffer import AsyncTurbopuffer

    namespace_name = await self._get_message_namespace_name(agent_id, organization_id)

    batch_size = len(message_ids)

    # Normalize every timestamp to timezone-aware UTC: naive values are assumed
    # to be UTC already, aware values are converted.
    utc_timestamps = [
        ts.replace(tzinfo=timezone.utc) if ts.tzinfo is None else ts.astimezone(timezone.utc)
        for ts in created_ats
    ]

    # Column-based upsert payload: every column has exactly batch_size entries.
    upsert_columns = {
        "id": list(message_ids),
        "vector": list(embeddings),
        "text": list(message_texts),
        "organization_id": [organization_id] * batch_size,
        "agent_id": [agent_id] * batch_size,
        "role": [role.value for role in roles],
        "created_at": utc_timestamps,
    }

    try:
        # Use AsyncTurbopuffer as a context manager for proper resource cleanup
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            namespace = client.namespace(namespace_name)
            await namespace.write(
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
            logger.info(f"Successfully inserted {batch_size} messages to Turbopuffer for agent {agent_id}")
            return True

    except Exception as e:
        logger.error(f"Failed to insert messages to Turbopuffer: {e}")
        # check if it's a duplicate ID error
        if "duplicate" in str(e).lower():
            logger.error("Duplicate message IDs detected in batch")
        raise
|
275
|
+
|
276
|
+
@trace_method
|
277
|
+
async def _execute_query(
|
278
|
+
self,
|
279
|
+
namespace_name: str,
|
280
|
+
search_mode: str,
|
281
|
+
query_embedding: Optional[List[float]],
|
282
|
+
query_text: Optional[str],
|
283
|
+
top_k: int,
|
284
|
+
include_attributes: List[str],
|
285
|
+
filters: Optional[Any] = None,
|
286
|
+
vector_weight: float = 0.5,
|
287
|
+
fts_weight: float = 0.5,
|
288
|
+
) -> Any:
|
289
|
+
"""Generic query execution for Turbopuffer.
|
290
|
+
|
291
|
+
Args:
|
292
|
+
namespace_name: Turbopuffer namespace to query
|
293
|
+
search_mode: "vector", "fts", "hybrid", or "timestamp"
|
294
|
+
query_embedding: Embedding for vector search
|
295
|
+
query_text: Text for full-text search
|
296
|
+
top_k: Number of results to return
|
297
|
+
include_attributes: Attributes to include in results
|
298
|
+
filters: Turbopuffer filter expression
|
299
|
+
vector_weight: Weight for vector search in hybrid mode
|
300
|
+
fts_weight: Weight for FTS in hybrid mode
|
301
|
+
|
302
|
+
Returns:
|
303
|
+
Raw Turbopuffer query results or multi-query response
|
304
|
+
"""
|
305
|
+
from turbopuffer import AsyncTurbopuffer
|
306
|
+
from turbopuffer.types import QueryParam
|
307
|
+
|
308
|
+
# validate inputs based on search mode
|
309
|
+
if search_mode == "vector" and query_embedding is None:
|
310
|
+
raise ValueError("query_embedding is required for vector search mode")
|
311
|
+
if search_mode == "fts" and query_text is None:
|
312
|
+
raise ValueError("query_text is required for FTS search mode")
|
313
|
+
if search_mode == "hybrid":
|
314
|
+
if query_embedding is None or query_text is None:
|
315
|
+
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
|
316
|
+
if search_mode not in ["vector", "fts", "hybrid", "timestamp"]:
|
317
|
+
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', 'hybrid', or 'timestamp'")
|
318
|
+
|
319
|
+
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
320
|
+
namespace = client.namespace(namespace_name)
|
321
|
+
|
322
|
+
if search_mode == "timestamp":
|
323
|
+
# retrieve most recent items by timestamp
|
324
|
+
query_params = {
|
325
|
+
"rank_by": ("created_at", "desc"),
|
326
|
+
"top_k": top_k,
|
327
|
+
"include_attributes": include_attributes,
|
328
|
+
}
|
329
|
+
if filters:
|
330
|
+
query_params["filters"] = filters
|
331
|
+
return await namespace.query(**query_params)
|
332
|
+
|
333
|
+
elif search_mode == "vector":
|
334
|
+
# vector search query
|
335
|
+
query_params = {
|
336
|
+
"rank_by": ("vector", "ANN", query_embedding),
|
337
|
+
"top_k": top_k,
|
338
|
+
"include_attributes": include_attributes,
|
339
|
+
}
|
340
|
+
if filters:
|
341
|
+
query_params["filters"] = filters
|
342
|
+
return await namespace.query(**query_params)
|
343
|
+
|
344
|
+
elif search_mode == "fts":
|
345
|
+
# full-text search query
|
346
|
+
query_params = {
|
347
|
+
"rank_by": ("text", "BM25", query_text),
|
348
|
+
"top_k": top_k,
|
349
|
+
"include_attributes": include_attributes,
|
350
|
+
}
|
351
|
+
if filters:
|
352
|
+
query_params["filters"] = filters
|
353
|
+
return await namespace.query(**query_params)
|
354
|
+
|
355
|
+
else: # hybrid mode
|
356
|
+
queries = []
|
357
|
+
|
358
|
+
# vector search query
|
359
|
+
vector_query = {
|
360
|
+
"rank_by": ("vector", "ANN", query_embedding),
|
361
|
+
"top_k": top_k,
|
362
|
+
"include_attributes": include_attributes,
|
363
|
+
}
|
364
|
+
if filters:
|
365
|
+
vector_query["filters"] = filters
|
366
|
+
queries.append(vector_query)
|
367
|
+
|
368
|
+
# full-text search query
|
369
|
+
fts_query = {
|
370
|
+
"rank_by": ("text", "BM25", query_text),
|
371
|
+
"top_k": top_k,
|
372
|
+
"include_attributes": include_attributes,
|
373
|
+
}
|
374
|
+
if filters:
|
375
|
+
fts_query["filters"] = filters
|
376
|
+
queries.append(fts_query)
|
377
|
+
|
378
|
+
# execute multi-query
|
379
|
+
return await namespace.multi_query(queries=[QueryParam(**q) for q in queries])
|
380
|
+
|
381
|
+
@trace_method
|
382
|
+
async def query_passages(
    self,
    archive_id: str,
    query_embedding: Optional[List[float]] = None,
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid"
    top_k: int = 10,
    tags: Optional[List[str]] = None,
    tag_match_mode: TagMatchMode = TagMatchMode.ANY,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> List[Tuple[PydanticPassage, float]]:
    """Search an archive's passages in Turbopuffer.

    When neither query_embedding nor query_text is given, the search falls
    back to "timestamp" mode regardless of the requested search_mode.

    Args:
        archive_id: ID of the archive
        query_embedding: Embedding vector (required for "vector" and "hybrid" modes)
        query_text: Text query (required for "fts" and "hybrid" modes)
        search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "vector")
        top_k: Number of results to return
        tags: Optional list of tags to filter by
        tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags)
        vector_weight: Weight for vector results in hybrid mode (default: 0.5)
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
        start_date: Optional lower bound on created_at (applied with "Gte")
        end_date: Optional upper bound on created_at (applied with "Lte")

    Returns:
        List of (passage, score) tuples
    """
    # With no query at all, retrieve the most recent passages instead.
    if query_embedding is None and query_text is None and search_mode != "timestamp":
        search_mode = "timestamp"

    namespace_name = await self._get_archive_namespace_name(archive_id)

    # Collect every active filter condition, tag condition first.
    conditions = []
    if tags:
        if tag_match_mode == TagMatchMode.ALL:
            # ALL mode: one Contains condition per tag, And-ed together when several.
            per_tag = [("tags", "Contains", tag) for tag in tags]
            conditions.append(per_tag[0] if len(per_tag) == 1 else ("And", per_tag))
        else:
            # ANY mode: a single ContainsAny over the whole tag list.
            conditions.append(("tags", "ContainsAny", tags))
    if start_date:
        conditions.append(("created_at", "Gte", start_date))
    if end_date:
        conditions.append(("created_at", "Lte", end_date))

    # Collapse to None, the lone condition, or an And over all of them.
    if not conditions:
        final_filter = None
    elif len(conditions) == 1:
        final_filter = conditions[0]
    else:
        final_filter = ("And", conditions)

    try:
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "organization_id", "archive_id", "created_at", "tags"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode != "hybrid":
            # single queries (vector, fts, timestamp)
            return self._process_single_query_results(result, archive_id, tags, is_fts=(search_mode == "fts"))

        # Hybrid: results[0] is the vector query, results[1] the FTS query.
        vector_hits = self._process_single_query_results(result.results[0], archive_id, tags)
        fts_hits = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
        # Fuse by rank; only (passage, score) is returned, the metadata is dropped.
        fused = self._reciprocal_rank_fusion(
            vector_results=[passage for passage, _ in vector_hits],
            fts_results=[passage for passage, _ in fts_hits],
            get_id_func=lambda p: p.id,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
            top_k=top_k,
        )
        return [(passage, rrf_score) for passage, rrf_score, _metadata in fused]

    except Exception as e:
        logger.error(f"Failed to query passages from Turbopuffer: {e}")
        raise
|
495
|
+
|
496
|
+
@trace_method
|
497
|
+
async def query_messages(
    self,
    agent_id: str,
    organization_id: str,
    query_embedding: Optional[List[float]] = None,
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid", "timestamp"
    top_k: int = 10,
    roles: Optional[List[MessageRole]] = None,
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> List[Tuple[dict, float, dict]]:
    """Search an agent's messages in Turbopuffer.

    When neither query_embedding nor query_text is given, the search falls
    back to "timestamp" mode regardless of the requested search_mode.

    Args:
        agent_id: ID of the agent (always applied as an "agent_id" Eq filter)
        organization_id: Organization ID for namespace lookup
        query_embedding: Embedding vector (required for "vector" and "hybrid" modes)
        query_text: Text query (required for "fts" and "hybrid" modes)
        search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "vector")
        top_k: Number of results to return
        roles: Optional list of message roles to filter by
        vector_weight: Weight for vector results in hybrid mode (default: 0.5)
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
        start_date: Optional lower bound on created_at (applied with "Gte")
        end_date: Optional upper bound on created_at (applied with "Lte")

    Returns:
        List of (message_dict, score, metadata) tuples where:
        - message_dict contains id, text, organization_id, agent_id, role, created_at
        - score is the fused score (hybrid) or 1/rank (single modes)
        - metadata carries the score and ranking information
    """
    # With no query at all, retrieve the most recent messages instead.
    if query_embedding is None and query_text is None and search_mode != "timestamp":
        search_mode = "timestamp"

    namespace_name = await self._get_message_namespace_name(agent_id, organization_id)

    # The agent filter is always present; role and date filters are optional.
    conditions = [("agent_id", "Eq", agent_id)]
    if roles:
        role_values = [r.value for r in roles]
        if len(role_values) == 1:
            conditions.append(("role", "Eq", role_values[0]))
        else:
            conditions.append(("role", "In", role_values))
    if start_date:
        conditions.append(("created_at", "Gte", start_date))
    if end_date:
        conditions.append(("created_at", "Lte", end_date))

    # conditions is never empty here, so the filter is either the lone agent
    # condition or an And over everything.
    final_filter = conditions[0] if len(conditions) == 1 else ("And", conditions)

    try:
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode == "hybrid":
            # Hybrid: results[0] is the vector query, results[1] the FTS query;
            # the fusion output (including metadata) is returned unchanged.
            return self._reciprocal_rank_fusion(
                vector_results=self._process_message_query_results(result.results[0]),
                fts_results=self._process_message_query_results(result.results[1]),
                get_id_func=lambda msg_dict: msg_dict["id"],
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )

        # Single modes (vector, fts, timestamp): score each hit by 1/rank.
        ranked = []
        for rank, msg_dict in enumerate(self._process_message_query_results(result), start=1):
            score = 1.0 / rank
            metadata = {
                "combined_score": score,
                "search_mode": search_mode,
                f"{search_mode}_rank": rank,
            }
            ranked.append((msg_dict, score, metadata))
        return ranked

    except Exception as e:
        logger.error(f"Failed to query messages from Turbopuffer: {e}")
        raise
|
619
|
+
|
620
|
+
def _process_message_query_results(self, result) -> List[dict]:
    """Turn the rows of a message query result into plain message dicts.

    Only the row order matters to the RRF step that consumes these dicts,
    so no score is copied over.

    Args:
        result: Query result exposing a ``rows`` iterable. Each row must
            have an ``id`` attribute; all other attributes are optional.

    Returns:
        One dict per row, in row order, with the keys ``id``, ``text``,
        ``organization_id``, ``agent_id``, ``role`` and ``created_at``.
        A missing ``text`` becomes ``""``; any other missing attribute
        becomes ``None``.
    """
    # Attributes that fall back to None when the row does not carry them.
    nullable_fields = ("organization_id", "agent_id", "role", "created_at")

    return [
        {
            "id": row.id,
            "text": getattr(row, "text", ""),
            **{field: getattr(row, field, None) for field in nullable_fields},
        }
        for row in result.rows
    ]
|
640
|
+
|
641
|
+
def _process_single_query_results(
    self, result, archive_id: str, tags: Optional[List[str]], is_fts: bool = False
) -> List[Tuple[PydanticPassage, float]]:
    """Convert the rows of one query result into ``(passage, score)`` pairs.

    Args:
        result: Query result exposing a ``rows`` iterable.
        archive_id: Archive the query ran against; stamped on every passage.
        tags: Accepted for interface compatibility; not read by this method.
        is_fts: When true the row's ``$score`` attribute is used as the score;
            otherwise the score is ``1.0 - $dist``. Either attribute defaults
            to ``0.0`` when absent.

    Returns:
        A list of ``(PydanticPassage, score)`` tuples in row order.
    """
    scored_passages = []

    for hit in result.rows:
        # Rows may lack tags or carry a falsy value; normalise both to [].
        hit_tags = getattr(hit, "tags", []) or []

        # Embeddings are not returned by the query, so the passage carries
        # an empty embedding and no embedding config.
        passage = PydanticPassage(
            id=hit.id,
            text=getattr(hit, "text", ""),
            organization_id=getattr(hit, "organization_id", None),
            archive_id=archive_id,
            created_at=getattr(hit, "created_at", None),
            metadata_={},
            tags=hit_tags,
            embedding=[],
            embedding_config=None,
        )

        if is_fts:
            # BM25 score is used as-is (higher is better).
            relevance = getattr(hit, "$score", 0.0)
        else:
            # Turn the vector distance into a similarity-style score.
            relevance = 1.0 - getattr(hit, "$dist", 0.0)

        scored_passages.append((passage, relevance))

    return scored_passages
|
680
|
+
|
681
|
+
def _reciprocal_rank_fusion(
    self,
    vector_results: List[Any],
    fts_results: List[Any],
    get_id_func: Callable[[Any], str],
    vector_weight: float,
    fts_weight: float,
    top_k: int,
) -> List[Tuple[Any, float, dict]]:
    """Fuse two ranked lists with weighted Reciprocal Rank Fusion (RRF).

    Each item scores
    ``vector_weight / (k + vector_rank) + fts_weight / (k + fts_rank)``
    with ``k = 60``. A list that does not contain the item contributes
    nothing. Only positions in the input lists are used; raw search scores
    are ignored.

    Args:
        vector_results: Items from vector search, best first.
        fts_results: Items from full-text search, best first.
        get_id_func: Extracts the identifier used to match items across
            the two lists.
        vector_weight: Weight applied to the vector-list contribution.
        fts_weight: Weight applied to the FTS-list contribution.
        top_k: Maximum number of fused results to return.

    Returns:
        Up to ``top_k`` ``(item, score, metadata)`` tuples sorted by
        descending score. ``metadata`` holds ``combined_score``,
        ``vector_rank`` and ``fts_rank`` (a rank is ``None`` when the item
        is absent from that list). When an ID appears in both lists, the
        item object from ``fts_results`` is the one returned.
    """
    rrf_k = 60  # standard RRF constant from Cormack et al. (2009)

    # Ranks are 1-based positions. ``items_by_id`` keeps first-seen order
    # (vector hits first), while a later FTS hit replaces the stored object.
    items_by_id = {}
    vector_ranks = {}
    fts_ranks = {}
    for position, entry in enumerate(vector_results, start=1):
        entry_id = get_id_func(entry)
        vector_ranks[entry_id] = position
        items_by_id[entry_id] = entry
    for position, entry in enumerate(fts_results, start=1):
        entry_id = get_id_func(entry)
        fts_ranks[entry_id] = position
        items_by_id[entry_id] = entry

    fused = []
    for entry_id, entry in items_by_id.items():
        v_rank = vector_ranks.get(entry_id)
        f_rank = fts_ranks.get(entry_id)

        # An absent list contributes 0.0, i.e. behaves like rank = infinity.
        v_part = vector_weight / (rrf_k + v_rank) if v_rank is not None else 0.0
        f_part = fts_weight / (rrf_k + f_rank) if f_rank is not None else 0.0
        total = v_part + f_part

        fused.append(
            (
                entry,
                total,
                {
                    "combined_score": total,  # Final RRF score
                    "vector_rank": v_rank,
                    "fts_rank": f_rank,
                },
            )
        )

    # Stable sort: ties keep their first-seen order.
    fused.sort(key=lambda triple: triple[1], reverse=True)
    return fused[:top_k]
|
752
|
+
|
753
|
+
@trace_method
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
    """Remove a single passage from the Turbopuffer namespace of an archive.

    Args:
        archive_id: Archive whose namespace holds the passage.
        passage_id: ID of the passage to delete.

    Returns:
        ``True`` once the delete write has completed.

    Raises:
        Exception: Any error raised while talking to Turbopuffer is logged
            and re-raised. Errors from the namespace lookup propagate
            without that log entry because the lookup runs before the
            ``try`` block.
    """
    from turbopuffer import AsyncTurbopuffer

    target_namespace = await self._get_archive_namespace_name(archive_id)

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as tpuf:
            # Deletions go through the write API via its `deletes` parameter.
            await tpuf.namespace(target_namespace).write(deletes=[passage_id])
            logger.info(f"Successfully deleted passage {passage_id} from Turbopuffer archive {archive_id}")
            return True
    except Exception as exc:
        logger.error(f"Failed to delete passage from Turbopuffer: {exc}")
        raise
|
770
|
+
|
771
|
+
@trace_method
async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
    """Remove several passages from the Turbopuffer namespace of an archive.

    Args:
        archive_id: Archive whose namespace holds the passages.
        passage_ids: IDs of the passages to delete. An empty list is a
            no-op that returns ``True`` without contacting Turbopuffer.

    Returns:
        ``True`` once the delete write has completed (or nothing was asked).

    Raises:
        Exception: Any error raised while talking to Turbopuffer is logged
            and re-raised.
    """
    from turbopuffer import AsyncTurbopuffer

    # Nothing to delete: skip the namespace lookup and the network call.
    if len(passage_ids) == 0 if passage_ids is not None else True:
        return True

    target_namespace = await self._get_archive_namespace_name(archive_id)

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as tpuf:
            # Deletions go through the write API via its `deletes` parameter.
            await tpuf.namespace(target_namespace).write(deletes=passage_ids)
            logger.info(f"Successfully deleted {len(passage_ids)} passages from Turbopuffer archive {archive_id}")
            return True
    except Exception as exc:
        logger.error(f"Failed to delete passages from Turbopuffer: {exc}")
        raise
|
791
|
+
|
792
|
+
@trace_method
async def delete_all_passages(self, archive_id: str) -> bool:
    """Wipe every passage stored in the Turbopuffer namespace of an archive.

    Args:
        archive_id: Archive whose namespace should be emptied.

    Returns:
        ``True`` once ``delete_all()`` on the namespace has completed.

    Raises:
        Exception: Any error raised while talking to Turbopuffer is logged
            and re-raised.
    """
    from turbopuffer import AsyncTurbopuffer

    target_namespace = await self._get_archive_namespace_name(archive_id)

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as tpuf:
            # The namespace object exposes delete_all() for a full wipe.
            await tpuf.namespace(target_namespace).delete_all()
            logger.info(f"Successfully deleted all passages for archive {archive_id}")
            return True
    except Exception as exc:
        logger.error(f"Failed to delete all passages from Turbopuffer: {exc}")
        raise
|
809
|
+
|
810
|
+
@trace_method
async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool:
    """Remove several messages from the Turbopuffer message namespace.

    Args:
        agent_id: Agent the messages belong to; used for the namespace
            lookup and the log line.
        organization_id: Organization used for the namespace lookup.
        message_ids: IDs of the messages to delete. An empty list is a
            no-op that returns ``True`` without contacting Turbopuffer.

    Returns:
        ``True`` once the delete write has completed (or nothing was asked).

    Raises:
        Exception: Any error raised while talking to Turbopuffer is logged
            and re-raised.
    """
    from turbopuffer import AsyncTurbopuffer

    # Nothing to delete: skip the namespace lookup and the network call.
    if len(message_ids) == 0 if message_ids is not None else True:
        return True

    target_namespace = await self._get_message_namespace_name(agent_id, organization_id)

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as tpuf:
            # Deletions go through the write API via its `deletes` parameter.
            await tpuf.namespace(target_namespace).write(deletes=message_ids)
            logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer for agent {agent_id}")
            return True
    except Exception as exc:
        logger.error(f"Failed to delete messages from Turbopuffer: {exc}")
        raise
|
830
|
+
|
831
|
+
@trace_method
async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool:
    """Remove every message belonging to one agent from Turbopuffer.

    Args:
        agent_id: Agent whose messages should be removed.
        organization_id: Organization used for the namespace lookup.

    Returns:
        ``True`` once the filtered delete has completed.

    Raises:
        Exception: Any error raised while talking to Turbopuffer is logged
            and re-raised.
    """
    from turbopuffer import AsyncTurbopuffer

    target_namespace = await self._get_message_namespace_name(agent_id, organization_id)

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as tpuf:
            # The namespace is org-scoped, so restrict the delete to rows
            # whose agent_id matches instead of wiping the whole namespace.
            agent_filter = ("agent_id", "Eq", agent_id)
            outcome = await tpuf.namespace(target_namespace).write(delete_by_filter=agent_filter)
            logger.info(f"Successfully deleted all messages for agent {agent_id} (deleted {outcome.rows_affected} rows)")
            return True
    except Exception as exc:
        logger.error(f"Failed to delete all messages from Turbopuffer: {exc}")
        raise
|