letta-nightly 0.11.7.dev20250909104137__py3-none-any.whl → 0.11.7.dev20250911104039__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/adapters/letta_llm_adapter.py +81 -0
- letta/adapters/letta_llm_request_adapter.py +113 -0
- letta/adapters/letta_llm_stream_adapter.py +171 -0
- letta/agents/agent_loop.py +23 -0
- letta/agents/base_agent.py +4 -1
- letta/agents/base_agent_v2.py +68 -0
- letta/agents/helpers.py +3 -5
- letta/agents/letta_agent.py +23 -12
- letta/agents/letta_agent_v2.py +1221 -0
- letta/agents/voice_agent.py +2 -1
- letta/constants.py +1 -1
- letta/errors.py +12 -0
- letta/functions/function_sets/base.py +53 -12
- letta/functions/helpers.py +3 -2
- letta/functions/schema_generator.py +1 -1
- letta/groups/sleeptime_multi_agent_v2.py +4 -2
- letta/groups/sleeptime_multi_agent_v3.py +233 -0
- letta/helpers/tool_rule_solver.py +4 -0
- letta/helpers/tpuf_client.py +607 -34
- letta/interfaces/anthropic_streaming_interface.py +74 -30
- letta/interfaces/openai_streaming_interface.py +80 -37
- letta/llm_api/google_vertex_client.py +1 -1
- letta/llm_api/openai_client.py +45 -4
- letta/orm/agent.py +4 -1
- letta/orm/block.py +2 -0
- letta/orm/blocks_agents.py +1 -0
- letta/orm/group.py +1 -0
- letta/orm/source.py +8 -1
- letta/orm/sources_agents.py +2 -1
- letta/orm/step_metrics.py +10 -0
- letta/orm/tools_agents.py +5 -2
- letta/schemas/block.py +4 -0
- letta/schemas/enums.py +1 -0
- letta/schemas/group.py +8 -0
- letta/schemas/letta_message.py +1 -1
- letta/schemas/letta_request.py +2 -2
- letta/schemas/mcp.py +9 -1
- letta/schemas/message.py +42 -2
- letta/schemas/providers/ollama.py +1 -1
- letta/schemas/providers.py +1 -2
- letta/schemas/source.py +6 -0
- letta/schemas/step_metrics.py +2 -0
- letta/server/rest_api/interface.py +34 -2
- letta/server/rest_api/json_parser.py +2 -0
- letta/server/rest_api/redis_stream_manager.py +2 -1
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +4 -2
- letta/server/rest_api/routers/v1/__init__.py +2 -0
- letta/server/rest_api/routers/v1/agents.py +132 -170
- letta/server/rest_api/routers/v1/blocks.py +6 -0
- letta/server/rest_api/routers/v1/folders.py +25 -7
- letta/server/rest_api/routers/v1/groups.py +6 -0
- letta/server/rest_api/routers/v1/internal_templates.py +218 -12
- letta/server/rest_api/routers/v1/messages.py +14 -19
- letta/server/rest_api/routers/v1/runs.py +43 -28
- letta/server/rest_api/routers/v1/sources.py +25 -7
- letta/server/rest_api/routers/v1/tools.py +42 -0
- letta/server/rest_api/streaming_response.py +11 -2
- letta/server/server.py +9 -6
- letta/services/agent_manager.py +39 -59
- letta/services/agent_serialization_manager.py +26 -11
- letta/services/archive_manager.py +60 -9
- letta/services/block_manager.py +5 -0
- letta/services/file_processor/embedder/base_embedder.py +5 -0
- letta/services/file_processor/embedder/openai_embedder.py +4 -0
- letta/services/file_processor/embedder/pinecone_embedder.py +5 -1
- letta/services/file_processor/embedder/turbopuffer_embedder.py +71 -0
- letta/services/file_processor/file_processor.py +9 -7
- letta/services/group_manager.py +74 -11
- letta/services/mcp_manager.py +134 -28
- letta/services/message_manager.py +229 -125
- letta/services/passage_manager.py +2 -1
- letta/services/source_manager.py +23 -1
- letta/services/summarizer/summarizer.py +4 -1
- letta/services/tool_executor/core_tool_executor.py +2 -120
- letta/services/tool_executor/files_tool_executor.py +133 -8
- letta/services/tool_executor/multi_agent_tool_executor.py +17 -14
- letta/services/tool_sandbox/local_sandbox.py +2 -2
- letta/services/tool_sandbox/modal_version_manager.py +2 -1
- letta/settings.py +6 -0
- letta/streaming_utils.py +29 -4
- letta/utils.py +106 -4
- {letta_nightly-0.11.7.dev20250909104137.dist-info → letta_nightly-0.11.7.dev20250911104039.dist-info}/METADATA +2 -2
- {letta_nightly-0.11.7.dev20250909104137.dist-info → letta_nightly-0.11.7.dev20250911104039.dist-info}/RECORD +86 -78
- {letta_nightly-0.11.7.dev20250909104137.dist-info → letta_nightly-0.11.7.dev20250911104039.dist-info}/WHEEL +0 -0
- {letta_nightly-0.11.7.dev20250909104137.dist-info → letta_nightly-0.11.7.dev20250911104039.dist-info}/entry_points.txt +0 -0
- {letta_nightly-0.11.7.dev20250909104137.dist-info → letta_nightly-0.11.7.dev20250911104039.dist-info}/licenses/LICENSE +0 -0
letta/helpers/tpuf_client.py
CHANGED
@@ -4,16 +4,19 @@ import logging
|
|
4
4
|
from datetime import datetime, timezone
|
5
5
|
from typing import Any, Callable, List, Optional, Tuple
|
6
6
|
|
7
|
+
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
7
8
|
from letta.otel.tracing import trace_method
|
9
|
+
from letta.schemas.embedding_config import EmbeddingConfig
|
8
10
|
from letta.schemas.enums import MessageRole, TagMatchMode
|
9
11
|
from letta.schemas.passage import Passage as PydanticPassage
|
10
|
-
from letta.settings import settings
|
12
|
+
from letta.settings import model_settings, settings
|
11
13
|
|
12
14
|
logger = logging.getLogger(__name__)
|
13
15
|
|
14
16
|
|
15
17
|
def should_use_tpuf() -> bool:
|
16
|
-
|
18
|
+
# We need OpenAI since we default to their embedding model
|
19
|
+
return bool(settings.use_tpuf) and bool(settings.tpuf_api_key) and bool(model_settings.openai_api_key)
|
17
20
|
|
18
21
|
|
19
22
|
def should_use_tpuf_for_messages() -> bool:
|
@@ -24,6 +27,14 @@ def should_use_tpuf_for_messages() -> bool:
|
|
24
27
|
class TurbopufferClient:
|
25
28
|
"""Client for managing archival memory with Turbopuffer vector database."""
|
26
29
|
|
30
|
+
default_embedding_config = EmbeddingConfig(
|
31
|
+
embedding_model="text-embedding-3-small",
|
32
|
+
embedding_endpoint_type="openai",
|
33
|
+
embedding_endpoint="https://api.openai.com/v1",
|
34
|
+
embedding_dim=1536,
|
35
|
+
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
36
|
+
)
|
37
|
+
|
27
38
|
def __init__(self, api_key: str = None, region: str = None):
|
28
39
|
"""Initialize Turbopuffer client."""
|
29
40
|
self.api_key = api_key or settings.tpuf_api_key
|
@@ -38,32 +49,57 @@ class TurbopufferClient:
|
|
38
49
|
if not self.api_key:
|
39
50
|
raise ValueError("Turbopuffer API key not provided")
|
40
51
|
|
52
|
+
@trace_method
|
53
|
+
async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]:
|
54
|
+
"""Generate embeddings using the default embedding configuration.
|
55
|
+
|
56
|
+
Args:
|
57
|
+
texts: List of texts to embed
|
58
|
+
actor: User actor for embedding generation
|
59
|
+
|
60
|
+
Returns:
|
61
|
+
List of embedding vectors
|
62
|
+
"""
|
63
|
+
from letta.llm_api.llm_client import LLMClient
|
64
|
+
|
65
|
+
embedding_client = LLMClient.create(
|
66
|
+
provider_type=self.default_embedding_config.embedding_endpoint_type,
|
67
|
+
actor=actor,
|
68
|
+
)
|
69
|
+
embeddings = await embedding_client.request_embeddings(texts, self.default_embedding_config)
|
70
|
+
return embeddings
|
71
|
+
|
41
72
|
@trace_method
|
42
73
|
async def _get_archive_namespace_name(self, archive_id: str) -> str:
|
43
74
|
"""Get namespace name for a specific archive."""
|
44
75
|
return await self.archive_manager.get_or_set_vector_db_namespace_async(archive_id)
|
45
76
|
|
46
77
|
@trace_method
|
47
|
-
async def _get_message_namespace_name(self,
|
78
|
+
async def _get_message_namespace_name(self, organization_id: str) -> str:
|
48
79
|
"""Get namespace name for messages (org-scoped).
|
49
80
|
|
50
81
|
Args:
|
51
|
-
agent_id: Agent ID (stored for future sharding)
|
52
82
|
organization_id: Organization ID for namespace generation
|
53
83
|
|
54
84
|
Returns:
|
55
85
|
The org-scoped namespace name for messages
|
56
86
|
"""
|
57
|
-
|
87
|
+
environment = settings.environment
|
88
|
+
if environment:
|
89
|
+
namespace_name = f"messages_{organization_id}_{environment.lower()}"
|
90
|
+
else:
|
91
|
+
namespace_name = f"messages_{organization_id}"
|
92
|
+
|
93
|
+
return namespace_name
|
58
94
|
|
59
95
|
@trace_method
|
60
96
|
async def insert_archival_memories(
|
61
97
|
self,
|
62
98
|
archive_id: str,
|
63
99
|
text_chunks: List[str],
|
64
|
-
embeddings: List[List[float]],
|
65
100
|
passage_ids: List[str],
|
66
101
|
organization_id: str,
|
102
|
+
actor: "PydanticUser",
|
67
103
|
tags: Optional[List[str]] = None,
|
68
104
|
created_at: Optional[datetime] = None,
|
69
105
|
) -> List[PydanticPassage]:
|
@@ -72,9 +108,9 @@ class TurbopufferClient:
|
|
72
108
|
Args:
|
73
109
|
archive_id: ID of the archive
|
74
110
|
text_chunks: List of text chunks to store
|
75
|
-
embeddings: List of embedding vectors corresponding to text chunks
|
76
111
|
passage_ids: List of passage IDs (must match 1:1 with text_chunks)
|
77
112
|
organization_id: Organization ID for the passages
|
113
|
+
actor: User actor for embedding generation
|
78
114
|
tags: Optional list of tags to attach to all passages
|
79
115
|
created_at: Optional timestamp for retroactive entries (defaults to current UTC time)
|
80
116
|
|
@@ -83,6 +119,9 @@ class TurbopufferClient:
|
|
83
119
|
"""
|
84
120
|
from turbopuffer import AsyncTurbopuffer
|
85
121
|
|
122
|
+
# generate embeddings using the default config
|
123
|
+
embeddings = await self._generate_embeddings(text_chunks, actor)
|
124
|
+
|
86
125
|
namespace_name = await self._get_archive_namespace_name(archive_id)
|
87
126
|
|
88
127
|
# handle timestamp - ensure UTC
|
@@ -102,8 +141,6 @@ class TurbopufferClient:
|
|
102
141
|
raise ValueError("passage_ids must be provided for Turbopuffer insertion")
|
103
142
|
if len(passage_ids) != len(text_chunks):
|
104
143
|
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match text_chunks length ({len(text_chunks)})")
|
105
|
-
if len(passage_ids) != len(embeddings):
|
106
|
-
raise ValueError(f"passage_ids length ({len(passage_ids)}) must match embeddings length ({len(embeddings)})")
|
107
144
|
|
108
145
|
# prepare column-based data for turbopuffer - optimized for batch insert
|
109
146
|
ids = []
|
@@ -137,7 +174,7 @@ class TurbopufferClient:
|
|
137
174
|
metadata_={},
|
138
175
|
tags=tags or [], # Include tags in the passage
|
139
176
|
embedding=embedding,
|
140
|
-
embedding_config=
|
177
|
+
embedding_config=self.default_embedding_config, # Will be set by caller if needed
|
141
178
|
)
|
142
179
|
passages.append(passage)
|
143
180
|
|
@@ -177,37 +214,42 @@ class TurbopufferClient:
|
|
177
214
|
self,
|
178
215
|
agent_id: str,
|
179
216
|
message_texts: List[str],
|
180
|
-
embeddings: List[List[float]],
|
181
217
|
message_ids: List[str],
|
182
218
|
organization_id: str,
|
219
|
+
actor: "PydanticUser",
|
183
220
|
roles: List[MessageRole],
|
184
221
|
created_ats: List[datetime],
|
222
|
+
project_id: Optional[str] = None,
|
223
|
+
template_id: Optional[str] = None,
|
185
224
|
) -> bool:
|
186
225
|
"""Insert messages into Turbopuffer.
|
187
226
|
|
188
227
|
Args:
|
189
228
|
agent_id: ID of the agent
|
190
229
|
message_texts: List of message text content to store
|
191
|
-
embeddings: List of embedding vectors corresponding to message texts
|
192
230
|
message_ids: List of message IDs (must match 1:1 with message_texts)
|
193
231
|
organization_id: Organization ID for the messages
|
232
|
+
actor: User actor for embedding generation
|
194
233
|
roles: List of message roles corresponding to each message
|
195
234
|
created_ats: List of creation timestamps for each message
|
235
|
+
project_id: Optional project ID for all messages
|
236
|
+
template_id: Optional template ID for all messages
|
196
237
|
|
197
238
|
Returns:
|
198
239
|
True if successful
|
199
240
|
"""
|
200
241
|
from turbopuffer import AsyncTurbopuffer
|
201
242
|
|
202
|
-
|
243
|
+
# generate embeddings using the default config
|
244
|
+
embeddings = await self._generate_embeddings(message_texts, actor)
|
245
|
+
|
246
|
+
namespace_name = await self._get_message_namespace_name(organization_id)
|
203
247
|
|
204
248
|
# validation checks
|
205
249
|
if not message_ids:
|
206
250
|
raise ValueError("message_ids must be provided for Turbopuffer insertion")
|
207
251
|
if len(message_ids) != len(message_texts):
|
208
252
|
raise ValueError(f"message_ids length ({len(message_ids)}) must match message_texts length ({len(message_texts)})")
|
209
|
-
if len(message_ids) != len(embeddings):
|
210
|
-
raise ValueError(f"message_ids length ({len(message_ids)}) must match embeddings length ({len(embeddings)})")
|
211
253
|
if len(message_ids) != len(roles):
|
212
254
|
raise ValueError(f"message_ids length ({len(message_ids)}) must match roles length ({len(roles)})")
|
213
255
|
if len(message_ids) != len(created_ats):
|
@@ -221,6 +263,8 @@ class TurbopufferClient:
|
|
221
263
|
agent_ids = []
|
222
264
|
message_roles = []
|
223
265
|
created_at_timestamps = []
|
266
|
+
project_ids = []
|
267
|
+
template_ids = []
|
224
268
|
|
225
269
|
for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)):
|
226
270
|
message_id = message_ids[idx]
|
@@ -241,6 +285,8 @@ class TurbopufferClient:
|
|
241
285
|
agent_ids.append(agent_id)
|
242
286
|
message_roles.append(role.value)
|
243
287
|
created_at_timestamps.append(timestamp)
|
288
|
+
project_ids.append(project_id)
|
289
|
+
template_ids.append(template_id)
|
244
290
|
|
245
291
|
# build column-based upsert data
|
246
292
|
upsert_columns = {
|
@@ -253,6 +299,14 @@ class TurbopufferClient:
|
|
253
299
|
"created_at": created_at_timestamps,
|
254
300
|
}
|
255
301
|
|
302
|
+
# only include project_id if it's provided
|
303
|
+
if project_id is not None:
|
304
|
+
upsert_columns["project_id"] = project_ids
|
305
|
+
|
306
|
+
# only include template_id if it's provided
|
307
|
+
if template_id is not None:
|
308
|
+
upsert_columns["template_id"] = template_ids
|
309
|
+
|
256
310
|
try:
|
257
311
|
# Use AsyncTurbopuffer as a context manager for proper resource cleanup
|
258
312
|
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
@@ -382,7 +436,7 @@ class TurbopufferClient:
|
|
382
436
|
async def query_passages(
|
383
437
|
self,
|
384
438
|
archive_id: str,
|
385
|
-
|
439
|
+
actor: "PydanticUser",
|
386
440
|
query_text: Optional[str] = None,
|
387
441
|
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
388
442
|
top_k: int = 10,
|
@@ -392,13 +446,13 @@ class TurbopufferClient:
|
|
392
446
|
fts_weight: float = 0.5,
|
393
447
|
start_date: Optional[datetime] = None,
|
394
448
|
end_date: Optional[datetime] = None,
|
395
|
-
) -> List[Tuple[PydanticPassage, float]]:
|
449
|
+
) -> List[Tuple[PydanticPassage, float, dict]]:
|
396
450
|
"""Query passages from Turbopuffer using vector search, full-text search, or hybrid search.
|
397
451
|
|
398
452
|
Args:
|
399
453
|
archive_id: ID of the archive
|
400
|
-
|
401
|
-
query_text: Text query for
|
454
|
+
actor: User actor for embedding generation
|
455
|
+
query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
|
402
456
|
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector")
|
403
457
|
top_k: Number of results to return
|
404
458
|
tags: Optional list of tags to filter by
|
@@ -406,11 +460,17 @@ class TurbopufferClient:
|
|
406
460
|
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
407
461
|
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
408
462
|
start_date: Optional datetime to filter passages created after this date
|
409
|
-
end_date: Optional datetime to filter passages created before this date
|
463
|
+
end_date: Optional datetime to filter passages created on or before this date (inclusive)
|
410
464
|
|
411
465
|
Returns:
|
412
|
-
List of (passage, score) tuples
|
466
|
+
List of (passage, score, metadata) tuples with relevance rankings
|
413
467
|
"""
|
468
|
+
# generate embedding for vector/hybrid search if query_text is provided
|
469
|
+
query_embedding = None
|
470
|
+
if query_text and search_mode in ["vector", "hybrid"]:
|
471
|
+
embeddings = await self._generate_embeddings([query_text], actor)
|
472
|
+
query_embedding = embeddings[0]
|
473
|
+
|
414
474
|
# Check if we should fallback to timestamp-based retrieval
|
415
475
|
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
|
416
476
|
# Fallback to retrieving most recent passages when no search query is provided
|
@@ -439,6 +499,13 @@ class TurbopufferClient:
|
|
439
499
|
if start_date:
|
440
500
|
date_filters.append(("created_at", "Gte", start_date))
|
441
501
|
if end_date:
|
502
|
+
# if end_date has no time component (is at midnight), adjust to end of day
|
503
|
+
# to make the filter inclusive of the entire day
|
504
|
+
if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
|
505
|
+
from datetime import timedelta
|
506
|
+
|
507
|
+
# add 1 day and subtract 1 microsecond to get 23:59:59.999999
|
508
|
+
end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
|
442
509
|
date_filters.append(("created_at", "Lte", end_date))
|
443
510
|
|
444
511
|
# combine all filters
|
@@ -474,7 +541,7 @@ class TurbopufferClient:
|
|
474
541
|
# for hybrid mode, we get a multi-query response
|
475
542
|
vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
|
476
543
|
fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
|
477
|
-
# use RRF and
|
544
|
+
# use RRF and include metadata with ranks
|
478
545
|
results_with_metadata = self._reciprocal_rank_fusion(
|
479
546
|
vector_results=[passage for passage, _ in vector_results],
|
480
547
|
fts_results=[passage for passage, _ in fts_results],
|
@@ -483,26 +550,38 @@ class TurbopufferClient:
|
|
483
550
|
fts_weight=fts_weight,
|
484
551
|
top_k=top_k,
|
485
552
|
)
|
486
|
-
|
553
|
+
# Return (passage, score, metadata) with ranks
|
554
|
+
return results_with_metadata
|
487
555
|
else:
|
488
|
-
# for single queries (vector, fts, timestamp)
|
556
|
+
# for single queries (vector, fts, timestamp) - add basic metadata
|
489
557
|
is_fts = search_mode == "fts"
|
490
|
-
|
558
|
+
results = self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
|
559
|
+
# Add simple metadata for single search modes
|
560
|
+
results_with_metadata = []
|
561
|
+
for idx, (passage, score) in enumerate(results):
|
562
|
+
metadata = {
|
563
|
+
"combined_score": score,
|
564
|
+
f"{search_mode}_rank": idx + 1, # Add the rank for this search mode
|
565
|
+
}
|
566
|
+
results_with_metadata.append((passage, score, metadata))
|
567
|
+
return results_with_metadata
|
491
568
|
|
492
569
|
except Exception as e:
|
493
570
|
logger.error(f"Failed to query passages from Turbopuffer: {e}")
|
494
571
|
raise
|
495
572
|
|
496
573
|
@trace_method
|
497
|
-
async def
|
574
|
+
async def query_messages_by_agent_id(
|
498
575
|
self,
|
499
576
|
agent_id: str,
|
500
577
|
organization_id: str,
|
501
|
-
|
578
|
+
actor: "PydanticUser",
|
502
579
|
query_text: Optional[str] = None,
|
503
580
|
search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp"
|
504
581
|
top_k: int = 10,
|
505
582
|
roles: Optional[List[MessageRole]] = None,
|
583
|
+
project_id: Optional[str] = None,
|
584
|
+
template_id: Optional[str] = None,
|
506
585
|
vector_weight: float = 0.5,
|
507
586
|
fts_weight: float = 0.5,
|
508
587
|
start_date: Optional[datetime] = None,
|
@@ -513,15 +592,17 @@ class TurbopufferClient:
|
|
513
592
|
Args:
|
514
593
|
agent_id: ID of the agent (used for filtering results)
|
515
594
|
organization_id: Organization ID for namespace lookup
|
516
|
-
|
517
|
-
query_text: Text query for
|
595
|
+
actor: User actor for embedding generation
|
596
|
+
query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
|
518
597
|
search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector")
|
519
598
|
top_k: Number of results to return
|
520
599
|
roles: Optional list of message roles to filter by
|
600
|
+
project_id: Optional project ID to filter messages by
|
601
|
+
template_id: Optional template ID to filter messages by
|
521
602
|
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
522
603
|
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
523
604
|
start_date: Optional datetime to filter messages created after this date
|
524
|
-
end_date: Optional datetime to filter messages created before this date
|
605
|
+
end_date: Optional datetime to filter messages created on or before this date (inclusive)
|
525
606
|
|
526
607
|
Returns:
|
527
608
|
List of (message_dict, score, metadata) tuples where:
|
@@ -529,12 +610,18 @@ class TurbopufferClient:
|
|
529
610
|
- score is the final relevance score
|
530
611
|
- metadata contains individual scores and ranking information
|
531
612
|
"""
|
613
|
+
# generate embedding for vector/hybrid search if query_text is provided
|
614
|
+
query_embedding = None
|
615
|
+
if query_text and search_mode in ["vector", "hybrid"]:
|
616
|
+
embeddings = await self._generate_embeddings([query_text], actor)
|
617
|
+
query_embedding = embeddings[0]
|
618
|
+
|
532
619
|
# Check if we should fallback to timestamp-based retrieval
|
533
620
|
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
|
534
621
|
# Fallback to retrieving most recent messages when no search query is provided
|
535
622
|
search_mode = "timestamp"
|
536
623
|
|
537
|
-
namespace_name = await self._get_message_namespace_name(
|
624
|
+
namespace_name = await self._get_message_namespace_name(organization_id)
|
538
625
|
|
539
626
|
# build agent_id filter
|
540
627
|
agent_filter = ("agent_id", "Eq", agent_id)
|
@@ -553,12 +640,33 @@ class TurbopufferClient:
|
|
553
640
|
if start_date:
|
554
641
|
date_filters.append(("created_at", "Gte", start_date))
|
555
642
|
if end_date:
|
643
|
+
# if end_date has no time component (is at midnight), adjust to end of day
|
644
|
+
# to make the filter inclusive of the entire day
|
645
|
+
if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
|
646
|
+
from datetime import timedelta
|
647
|
+
|
648
|
+
# add 1 day and subtract 1 microsecond to get 23:59:59.999999
|
649
|
+
end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
|
556
650
|
date_filters.append(("created_at", "Lte", end_date))
|
557
651
|
|
652
|
+
# build project_id filter if provided
|
653
|
+
project_filter = None
|
654
|
+
if project_id:
|
655
|
+
project_filter = ("project_id", "Eq", project_id)
|
656
|
+
|
657
|
+
# build template_id filter if provided
|
658
|
+
template_filter = None
|
659
|
+
if template_id:
|
660
|
+
template_filter = ("template_id", "Eq", template_id)
|
661
|
+
|
558
662
|
# combine all filters
|
559
663
|
all_filters = [agent_filter] # always include agent_id filter
|
560
664
|
if role_filter:
|
561
665
|
all_filters.append(role_filter)
|
666
|
+
if project_filter:
|
667
|
+
all_filters.append(project_filter)
|
668
|
+
if template_filter:
|
669
|
+
all_filters.append(template_filter)
|
562
670
|
if date_filters:
|
563
671
|
all_filters.extend(date_filters)
|
564
672
|
|
@@ -617,6 +725,165 @@ class TurbopufferClient:
|
|
617
725
|
logger.error(f"Failed to query messages from Turbopuffer: {e}")
|
618
726
|
raise
|
619
727
|
|
728
|
+
async def query_messages_by_org_id(
|
729
|
+
self,
|
730
|
+
organization_id: str,
|
731
|
+
actor: "PydanticUser",
|
732
|
+
query_text: Optional[str] = None,
|
733
|
+
search_mode: str = "hybrid", # "vector", "fts", "hybrid"
|
734
|
+
top_k: int = 10,
|
735
|
+
roles: Optional[List[MessageRole]] = None,
|
736
|
+
project_id: Optional[str] = None,
|
737
|
+
template_id: Optional[str] = None,
|
738
|
+
vector_weight: float = 0.5,
|
739
|
+
fts_weight: float = 0.5,
|
740
|
+
start_date: Optional[datetime] = None,
|
741
|
+
end_date: Optional[datetime] = None,
|
742
|
+
) -> List[Tuple[dict, float, dict]]:
|
743
|
+
"""Query messages from Turbopuffer across an entire organization.
|
744
|
+
|
745
|
+
Args:
|
746
|
+
organization_id: Organization ID for namespace lookup (required)
|
747
|
+
actor: User actor for embedding generation
|
748
|
+
query_text: Text query for search (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
|
749
|
+
search_mode: Search mode - "vector", "fts", or "hybrid" (default: "hybrid")
|
750
|
+
top_k: Number of results to return
|
751
|
+
roles: Optional list of message roles to filter by
|
752
|
+
project_id: Optional project ID to filter messages by
|
753
|
+
template_id: Optional template ID to filter messages by
|
754
|
+
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
755
|
+
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
756
|
+
start_date: Optional datetime to filter messages created after this date
|
757
|
+
end_date: Optional datetime to filter messages created on or before this date (inclusive)
|
758
|
+
|
759
|
+
Returns:
|
760
|
+
List of (message_dict, score, metadata) tuples where:
|
761
|
+
- message_dict contains id, text, role, created_at, agent_id
|
762
|
+
- score is the final relevance score (RRF score for hybrid, rank-based for single mode)
|
763
|
+
- metadata contains individual scores and ranking information
|
764
|
+
"""
|
765
|
+
# generate embedding for vector/hybrid search if query_text is provided
|
766
|
+
query_embedding = None
|
767
|
+
if query_text and search_mode in ["vector", "hybrid"]:
|
768
|
+
embeddings = await self._generate_embeddings([query_text], actor)
|
769
|
+
query_embedding = embeddings[0]
|
770
|
+
# namespace is org-scoped
|
771
|
+
namespace_name = await self._get_message_namespace_name(organization_id)
|
772
|
+
|
773
|
+
# build filters
|
774
|
+
all_filters = []
|
775
|
+
|
776
|
+
# role filter
|
777
|
+
if roles:
|
778
|
+
role_values = [r.value for r in roles]
|
779
|
+
if len(role_values) == 1:
|
780
|
+
all_filters.append(("role", "Eq", role_values[0]))
|
781
|
+
else:
|
782
|
+
all_filters.append(("role", "In", role_values))
|
783
|
+
|
784
|
+
# project filter
|
785
|
+
if project_id:
|
786
|
+
all_filters.append(("project_id", "Eq", project_id))
|
787
|
+
|
788
|
+
# template filter
|
789
|
+
if template_id:
|
790
|
+
all_filters.append(("template_id", "Eq", template_id))
|
791
|
+
|
792
|
+
# date filters
|
793
|
+
if start_date:
|
794
|
+
all_filters.append(("created_at", "Gte", start_date))
|
795
|
+
if end_date:
|
796
|
+
# make end_date inclusive of the entire day
|
797
|
+
if end_date.hour == 0 and end_date.minute == 0 and end_date.second == 0 and end_date.microsecond == 0:
|
798
|
+
from datetime import timedelta
|
799
|
+
|
800
|
+
end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
|
801
|
+
all_filters.append(("created_at", "Lte", end_date))
|
802
|
+
|
803
|
+
# combine filters
|
804
|
+
final_filter = None
|
805
|
+
if len(all_filters) == 1:
|
806
|
+
final_filter = all_filters[0]
|
807
|
+
elif len(all_filters) > 1:
|
808
|
+
final_filter = ("And", all_filters)
|
809
|
+
|
810
|
+
try:
|
811
|
+
# execute query
|
812
|
+
result = await self._execute_query(
|
813
|
+
namespace_name=namespace_name,
|
814
|
+
search_mode=search_mode,
|
815
|
+
query_embedding=query_embedding,
|
816
|
+
query_text=query_text,
|
817
|
+
top_k=top_k,
|
818
|
+
include_attributes=["text", "organization_id", "agent_id", "role", "created_at"],
|
819
|
+
filters=final_filter,
|
820
|
+
vector_weight=vector_weight,
|
821
|
+
fts_weight=fts_weight,
|
822
|
+
)
|
823
|
+
|
824
|
+
# process results based on search mode
|
825
|
+
if search_mode == "hybrid":
|
826
|
+
# for hybrid mode, we get a multi-query response
|
827
|
+
vector_results = self._process_message_query_results(result.results[0])
|
828
|
+
fts_results = self._process_message_query_results(result.results[1])
|
829
|
+
|
830
|
+
# use existing RRF method - it already returns metadata with ranks
|
831
|
+
results_with_metadata = self._reciprocal_rank_fusion(
|
832
|
+
vector_results=vector_results,
|
833
|
+
fts_results=fts_results,
|
834
|
+
get_id_func=lambda msg_dict: msg_dict["id"],
|
835
|
+
vector_weight=vector_weight,
|
836
|
+
fts_weight=fts_weight,
|
837
|
+
top_k=top_k,
|
838
|
+
)
|
839
|
+
|
840
|
+
# add raw scores to metadata if available
|
841
|
+
vector_scores = {}
|
842
|
+
for row in result.results[0].rows:
|
843
|
+
if hasattr(row, "dist"):
|
844
|
+
vector_scores[row.id] = row.dist
|
845
|
+
|
846
|
+
fts_scores = {}
|
847
|
+
for row in result.results[1].rows:
|
848
|
+
if hasattr(row, "score"):
|
849
|
+
fts_scores[row.id] = row.score
|
850
|
+
|
851
|
+
# enhance metadata with raw scores
|
852
|
+
enhanced_results = []
|
853
|
+
for msg_dict, rrf_score, metadata in results_with_metadata:
|
854
|
+
msg_id = msg_dict["id"]
|
855
|
+
if msg_id in vector_scores:
|
856
|
+
metadata["vector_score"] = vector_scores[msg_id]
|
857
|
+
if msg_id in fts_scores:
|
858
|
+
metadata["fts_score"] = fts_scores[msg_id]
|
859
|
+
enhanced_results.append((msg_dict, rrf_score, metadata))
|
860
|
+
|
861
|
+
return enhanced_results
|
862
|
+
else:
|
863
|
+
# for single queries (vector or fts)
|
864
|
+
results = self._process_message_query_results(result)
|
865
|
+
results_with_metadata = []
|
866
|
+
for idx, msg_dict in enumerate(results):
|
867
|
+
metadata = {
|
868
|
+
"combined_score": 1.0 / (idx + 1),
|
869
|
+
"search_mode": search_mode,
|
870
|
+
f"{search_mode}_rank": idx + 1,
|
871
|
+
}
|
872
|
+
|
873
|
+
# add raw score if available
|
874
|
+
if hasattr(result.rows[idx], "dist"):
|
875
|
+
metadata["vector_score"] = result.rows[idx].dist
|
876
|
+
elif hasattr(result.rows[idx], "score"):
|
877
|
+
metadata["fts_score"] = result.rows[idx].score
|
878
|
+
|
879
|
+
results_with_metadata.append((msg_dict, metadata["combined_score"], metadata))
|
880
|
+
|
881
|
+
return results_with_metadata
|
882
|
+
|
883
|
+
except Exception as e:
|
884
|
+
logger.error(f"Failed to query messages from Turbopuffer: {e}")
|
885
|
+
raise
|
886
|
+
|
620
887
|
def _process_message_query_results(self, result) -> List[dict]:
|
621
888
|
"""Process results from a message query into message dicts.
|
622
889
|
|
@@ -662,7 +929,7 @@ class TurbopufferClient:
|
|
662
929
|
tags=passage_tags, # Set the actual tags from the passage
|
663
930
|
# Set required fields to empty/default values since we don't store embeddings
|
664
931
|
embedding=[], # Empty embedding since we don't return it from Turbopuffer
|
665
|
-
embedding_config=
|
932
|
+
embedding_config=self.default_embedding_config, # No embedding config needed for retrieved passages
|
666
933
|
)
|
667
934
|
|
668
935
|
# handle score based on search type
|
@@ -815,7 +1082,7 @@ class TurbopufferClient:
|
|
815
1082
|
if not message_ids:
|
816
1083
|
return True
|
817
1084
|
|
818
|
-
namespace_name = await self._get_message_namespace_name(
|
1085
|
+
namespace_name = await self._get_message_namespace_name(organization_id)
|
819
1086
|
|
820
1087
|
try:
|
821
1088
|
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
@@ -833,7 +1100,7 @@ class TurbopufferClient:
|
|
833
1100
|
"""Delete all messages for an agent from Turbopuffer."""
|
834
1101
|
from turbopuffer import AsyncTurbopuffer
|
835
1102
|
|
836
|
-
namespace_name = await self._get_message_namespace_name(
|
1103
|
+
namespace_name = await self._get_message_namespace_name(organization_id)
|
837
1104
|
|
838
1105
|
try:
|
839
1106
|
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
@@ -846,3 +1113,309 @@ class TurbopufferClient:
|
|
846
1113
|
except Exception as e:
|
847
1114
|
logger.error(f"Failed to delete all messages from Turbopuffer: {e}")
|
848
1115
|
raise
|
1116
|
+
|
1117
|
+
# file/source passage methods
|
1118
|
+
|
1119
|
+
@trace_method
async def _get_file_passages_namespace_name(self, organization_id: str) -> str:
    """Build the org-scoped Turbopuffer namespace name for file passages.

    Args:
        organization_id: Organization ID embedded in the namespace name

    Returns:
        ``file_passages_<organization_id>``, with the lower-cased
        ``settings.environment`` appended as a suffix when one is configured
    """
    base_name = f"file_passages_{organization_id}"
    env = settings.environment
    # only add the environment suffix when an environment is actually set
    return f"{base_name}_{env.lower()}" if env else base_name
|
1136
|
+
|
1137
|
+
@trace_method
async def insert_file_passages(
    self,
    source_id: str,
    file_id: str,
    text_chunks: List[str],
    organization_id: str,
    actor: "PydanticUser",
    created_at: Optional[datetime] = None,
) -> List[PydanticPassage]:
    """Insert file passages into Turbopuffer using the org-scoped namespace.

    Args:
        source_id: ID of the source containing the file
        file_id: ID of the file
        text_chunks: List of text chunks to store
        organization_id: Organization ID used for the namespace and stored column
        actor: User actor passed to embedding generation
        created_at: Optional timestamp for retroactive entries (defaults to current UTC time);
            naive values are assumed to be UTC, aware values are converted to UTC

    Returns:
        List of PydanticPassage objects that were inserted (empty if no chunks were given)

    Raises:
        ValueError: If the number of generated embeddings does not match the number of chunks
        Exception: Any error raised by the Turbopuffer write is logged and re-raised
    """
    from turbopuffer import AsyncTurbopuffer

    if not text_chunks:
        return []

    # generate embeddings using the default config
    embeddings = await self._generate_embeddings(text_chunks, actor)

    # zip() would silently drop the unmatched chunks, so fail loudly instead
    if len(embeddings) != len(text_chunks):
        raise ValueError(f"Expected {len(text_chunks)} embeddings for file passages, got {len(embeddings)}")

    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    # handle timestamp - ensure UTC
    if created_at is None:
        timestamp = datetime.now(timezone.utc)
    elif created_at.tzinfo is None:
        # assume UTC if no timezone provided
        timestamp = created_at.replace(tzinfo=timezone.utc)
    else:
        # convert to UTC if in different timezone
        timestamp = created_at.astimezone(timezone.utc)

    # NOTE(review): the passage objects carry actor.organization_id while the stored
    # column and namespace use the organization_id argument - confirm these always agree
    passages = [
        PydanticPassage(
            text=text,
            file_id=file_id,
            source_id=source_id,
            embedding=embedding,
            embedding_config=self.default_embedding_config,
            organization_id=actor.organization_id,
        )
        for text, embedding in zip(text_chunks, embeddings)
    ]
    row_count = len(passages)

    # build column-based upsert data - one entry per passage in every column
    upsert_columns = {
        "id": [passage.id for passage in passages],
        "vector": list(embeddings),
        "text": list(text_chunks),
        "organization_id": [organization_id] * row_count,
        "source_id": [source_id] * row_count,
        "file_id": [file_id] * row_count,
        "created_at": [timestamp] * row_count,
    }

    try:
        # use AsyncTurbopuffer as a context manager for proper resource cleanup
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            namespace = client.namespace(namespace_name)
            # turbopuffer recommends column-based writes for performance
            await namespace.write(
                upsert_columns=upsert_columns,
                distance_metric="cosine_distance",
                schema={"text": {"type": "string", "full_text_search": True}},
            )
            logger.info(f"Successfully inserted {row_count} file passages to Turbopuffer for source {source_id}, file {file_id}")
            return passages

    except Exception as e:
        logger.error(f"Failed to insert file passages to Turbopuffer: {e}")
        # check if it's a duplicate ID error
        if "duplicate" in str(e).lower():
            logger.error("Duplicate passage IDs detected in batch")
        raise
|
1242
|
+
|
1243
|
+
@trace_method
async def query_file_passages(
    self,
    source_ids: List[str],
    organization_id: str,
    actor: "PydanticUser",
    query_text: Optional[str] = None,
    search_mode: str = "vector",  # "vector", "fts", "hybrid"
    top_k: int = 10,
    file_id: Optional[str] = None,  # optional filter by specific file
    vector_weight: float = 0.5,
    fts_weight: float = 0.5,
) -> List[Tuple[PydanticPassage, float, dict]]:
    """Query file passages from Turbopuffer using the org-scoped namespace.

    Args:
        source_ids: List of source IDs to query; an empty list yields no results
        organization_id: Organization ID for namespace lookup
        actor: User actor passed to embedding generation
        query_text: Text query for search; None or an empty string means "no query"
        search_mode: Search mode - "vector", "fts", or "hybrid" (default: "vector").
            Without a query the mode is switched to "timestamp"
        top_k: Number of results to return
        file_id: Optional file ID to filter results to a specific file
        vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
        fts_weight: Weight for FTS results in hybrid mode (default: 0.5)

    Returns:
        List of (passage, score, metadata) tuples. In hybrid mode this is the output of
        the reciprocal rank fusion helper; otherwise metadata holds "combined_score" and
        "<search_mode>_rank" (1-based position in the result list)

    Raises:
        Exception: Any error raised while executing or processing the query is logged and re-raised
    """
    # nothing can match without a source to search, so skip the embedding and query calls
    if not source_ids:
        return []

    # an empty string carries no search intent, so treat it like a missing query
    has_query = bool(query_text)

    # generate embedding for vector/hybrid search if query_text is provided
    query_embedding = None
    if has_query and search_mode in ("vector", "hybrid"):
        embeddings = await self._generate_embeddings([query_text], actor)
        query_embedding = embeddings[0]

    # fallback to retrieving most recent passages when no search query is provided
    if not has_query:
        search_mode = "timestamp"

    namespace_name = await self._get_file_passages_namespace_name(organization_id)

    # build filters - always filter by source_ids
    if len(source_ids) == 1:
        # single source_id, use Eq for efficiency
        filters = [("source_id", "Eq", source_ids[0])]
    else:
        # multiple source_ids, use In operator
        filters = [("source_id", "In", source_ids)]

    # add file filter if specified
    if file_id:
        filters.append(("file_id", "Eq", file_id))

    # combine filters
    final_filter = filters[0] if len(filters) == 1 else ("And", filters)

    try:
        # use generic query executor
        result = await self._execute_query(
            namespace_name=namespace_name,
            search_mode=search_mode,
            query_embedding=query_embedding,
            query_text=query_text,
            top_k=top_k,
            include_attributes=["text", "organization_id", "source_id", "file_id", "created_at"],
            filters=final_filter,
            vector_weight=vector_weight,
            fts_weight=fts_weight,
        )

        if search_mode == "hybrid":
            # for hybrid mode, we get a multi-query response: vector first, FTS second
            vector_results = self._process_file_query_results(result.results[0])
            fts_results = self._process_file_query_results(result.results[1], is_fts=True)
            # use RRF and include metadata with ranks; raw per-query scores are dropped here
            return self._reciprocal_rank_fusion(
                vector_results=[passage for passage, _ in vector_results],
                fts_results=[passage for passage, _ in fts_results],
                get_id_func=lambda p: p.id,
                vector_weight=vector_weight,
                fts_weight=fts_weight,
                top_k=top_k,
            )

        # for single queries (vector, fts, timestamp) - add basic metadata
        results = self._process_file_query_results(result, is_fts=search_mode == "fts")
        return [
            (
                passage,
                score,
                {
                    "combined_score": score,
                    f"{search_mode}_rank": rank,  # add the rank for this search mode
                },
            )
            for rank, (passage, score) in enumerate(results, start=1)
        ]

    except Exception as e:
        logger.error(f"Failed to query file passages from Turbopuffer: {e}")
        raise
|
1346
|
+
|
1347
|
+
def _process_file_query_results(self, result, is_fts: bool = False) -> List[Tuple[PydanticPassage, float]]:
    """Convert the rows of a file-passage query result into (passage, score) pairs.

    Args:
        result: Query result whose ``rows`` are iterated
        is_fts: When True the score is the row's "$score" attribute; otherwise it is
            1.0 minus the row's "$dist" attribute (both default to 0.0 when absent)

    Returns:
        List of (passage, score) tuples in the same order as ``result.rows``
    """

    def _relevance(hit) -> float:
        if is_fts:
            # for FTS, use the BM25 score directly (higher is better)
            return getattr(hit, "$score", 0.0)
        # for vector search, convert distance to similarity score
        return 1.0 - getattr(hit, "$dist", 0.0)

    scored_passages = []
    for hit in result.rows:
        # embeddings are not returned from Turbopuffer, so the passage gets an empty
        # embedding plus the default embedding config to satisfy the required fields
        rebuilt = PydanticPassage(
            id=hit.id,
            text=getattr(hit, "text", ""),
            organization_id=getattr(hit, "organization_id", None),
            source_id=getattr(hit, "source_id", None),
            file_id=getattr(hit, "file_id", None),
            created_at=getattr(hit, "created_at", None),
            metadata_={},
            tags=[],
            embedding=[],
            embedding_config=self.default_embedding_config,
        )
        scored_passages.append((rebuilt, _relevance(hit)))

    return scored_passages
|
1382
|
+
|
1383
|
+
@trace_method
async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool:
    """Delete every passage stored for one file of a source from Turbopuffer.

    Args:
        source_id: ID of the source the file belongs to
        file_id: ID of the file whose passages are removed
        organization_id: Organization ID used to resolve the namespace

    Returns:
        True once the delete request has completed

    Raises:
        Exception: Any error raised by the Turbopuffer write is logged and re-raised
    """
    from turbopuffer import AsyncTurbopuffer

    target_namespace = await self._get_file_passages_namespace_name(organization_id)

    # match on both source_id and file_id so only this file's passages are removed
    file_filter = ("And", [("source_id", "Eq", source_id), ("file_id", "Eq", file_id)])

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            outcome = await client.namespace(target_namespace).write(delete_by_filter=file_filter)
            logger.info(
                f"Successfully deleted passages for file {file_id} from source {source_id} (deleted {outcome.rows_affected} rows)"
            )
            return True
    except Exception as exc:
        logger.error(f"Failed to delete file passages from Turbopuffer: {exc}")
        raise
|
1404
|
+
|
1405
|
+
@trace_method
async def delete_source_passages(self, source_id: str, organization_id: str) -> bool:
    """Delete every passage stored for a source from Turbopuffer.

    Args:
        source_id: ID of the source whose passages are removed
        organization_id: Organization ID used to resolve the namespace

    Returns:
        True once the delete request has completed

    Raises:
        Exception: Any error raised by the Turbopuffer write is logged and re-raised
    """
    from turbopuffer import AsyncTurbopuffer

    target_namespace = await self._get_file_passages_namespace_name(organization_id)

    # a single equality filter covers all files of the source
    source_filter = ("source_id", "Eq", source_id)

    try:
        async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
            outcome = await client.namespace(target_namespace).write(delete_by_filter=source_filter)
            logger.info(f"Successfully deleted all passages for source {source_id} (deleted {outcome.rows_affected} rows)")
            return True
    except Exception as exc:
        logger.error(f"Failed to delete source passages from Turbopuffer: {exc}")
        raise
|