letta-nightly 0.7.29.dev20250602104315__py3-none-any.whl → 0.8.0.dev20250604104349__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +7 -1
- letta/agent.py +16 -9
- letta/agents/base_agent.py +1 -0
- letta/agents/ephemeral_summary_agent.py +104 -0
- letta/agents/helpers.py +35 -3
- letta/agents/letta_agent.py +492 -176
- letta/agents/letta_agent_batch.py +22 -16
- letta/agents/prompts/summary_system_prompt.txt +62 -0
- letta/agents/voice_agent.py +22 -7
- letta/agents/voice_sleeptime_agent.py +13 -8
- letta/constants.py +33 -1
- letta/data_sources/connectors.py +52 -36
- letta/errors.py +4 -0
- letta/functions/ast_parsers.py +13 -30
- letta/functions/function_sets/base.py +3 -1
- letta/functions/functions.py +2 -0
- letta/functions/mcp_client/base_client.py +151 -97
- letta/functions/mcp_client/sse_client.py +49 -31
- letta/functions/mcp_client/stdio_client.py +107 -106
- letta/functions/schema_generator.py +22 -22
- letta/groups/helpers.py +3 -4
- letta/groups/sleeptime_multi_agent.py +4 -4
- letta/groups/sleeptime_multi_agent_v2.py +22 -0
- letta/helpers/composio_helpers.py +16 -0
- letta/helpers/converters.py +20 -0
- letta/helpers/datetime_helpers.py +1 -6
- letta/helpers/tool_rule_solver.py +2 -1
- letta/interfaces/anthropic_streaming_interface.py +17 -2
- letta/interfaces/openai_chat_completions_streaming_interface.py +1 -0
- letta/interfaces/openai_streaming_interface.py +18 -2
- letta/jobs/llm_batch_job_polling.py +1 -1
- letta/jobs/scheduler.py +1 -1
- letta/llm_api/anthropic_client.py +24 -3
- letta/llm_api/google_ai_client.py +0 -15
- letta/llm_api/google_vertex_client.py +6 -5
- letta/llm_api/llm_client_base.py +15 -0
- letta/llm_api/openai.py +2 -2
- letta/llm_api/openai_client.py +60 -8
- letta/orm/__init__.py +2 -0
- letta/orm/agent.py +45 -43
- letta/orm/base.py +0 -2
- letta/orm/block.py +1 -0
- letta/orm/custom_columns.py +13 -0
- letta/orm/enums.py +5 -0
- letta/orm/file.py +3 -1
- letta/orm/files_agents.py +68 -0
- letta/orm/mcp_server.py +48 -0
- letta/orm/message.py +1 -0
- letta/orm/organization.py +11 -2
- letta/orm/passage.py +25 -10
- letta/orm/sandbox_config.py +5 -2
- letta/orm/sqlalchemy_base.py +171 -110
- letta/prompts/system/memgpt_base.txt +6 -1
- letta/prompts/system/memgpt_v2_chat.txt +57 -0
- letta/prompts/system/sleeptime.txt +2 -0
- letta/prompts/system/sleeptime_v2.txt +28 -0
- letta/schemas/agent.py +87 -20
- letta/schemas/block.py +7 -1
- letta/schemas/file.py +57 -0
- letta/schemas/mcp.py +74 -0
- letta/schemas/memory.py +5 -2
- letta/schemas/message.py +9 -0
- letta/schemas/openai/openai.py +0 -6
- letta/schemas/providers.py +33 -4
- letta/schemas/tool.py +26 -21
- letta/schemas/tool_execution_result.py +5 -0
- letta/server/db.py +23 -8
- letta/server/rest_api/app.py +73 -56
- letta/server/rest_api/interface.py +4 -4
- letta/server/rest_api/routers/v1/agents.py +132 -47
- letta/server/rest_api/routers/v1/blocks.py +3 -2
- letta/server/rest_api/routers/v1/embeddings.py +3 -3
- letta/server/rest_api/routers/v1/groups.py +3 -3
- letta/server/rest_api/routers/v1/jobs.py +14 -17
- letta/server/rest_api/routers/v1/organizations.py +10 -10
- letta/server/rest_api/routers/v1/providers.py +12 -10
- letta/server/rest_api/routers/v1/runs.py +3 -3
- letta/server/rest_api/routers/v1/sandbox_configs.py +12 -12
- letta/server/rest_api/routers/v1/sources.py +108 -43
- letta/server/rest_api/routers/v1/steps.py +8 -6
- letta/server/rest_api/routers/v1/tools.py +134 -95
- letta/server/rest_api/utils.py +12 -1
- letta/server/server.py +272 -73
- letta/services/agent_manager.py +246 -313
- letta/services/block_manager.py +30 -9
- letta/services/context_window_calculator/__init__.py +0 -0
- letta/services/context_window_calculator/context_window_calculator.py +150 -0
- letta/services/context_window_calculator/token_counter.py +82 -0
- letta/services/file_processor/__init__.py +0 -0
- letta/services/file_processor/chunker/__init__.py +0 -0
- letta/services/file_processor/chunker/llama_index_chunker.py +29 -0
- letta/services/file_processor/embedder/__init__.py +0 -0
- letta/services/file_processor/embedder/openai_embedder.py +84 -0
- letta/services/file_processor/file_processor.py +123 -0
- letta/services/file_processor/parser/__init__.py +0 -0
- letta/services/file_processor/parser/base_parser.py +9 -0
- letta/services/file_processor/parser/mistral_parser.py +54 -0
- letta/services/file_processor/types.py +0 -0
- letta/services/files_agents_manager.py +184 -0
- letta/services/group_manager.py +118 -0
- letta/services/helpers/agent_manager_helper.py +76 -21
- letta/services/helpers/tool_execution_helper.py +3 -0
- letta/services/helpers/tool_parser_helper.py +100 -0
- letta/services/identity_manager.py +44 -42
- letta/services/job_manager.py +21 -10
- letta/services/mcp/base_client.py +5 -2
- letta/services/mcp/sse_client.py +3 -5
- letta/services/mcp/stdio_client.py +3 -5
- letta/services/mcp_manager.py +281 -0
- letta/services/message_manager.py +40 -26
- letta/services/organization_manager.py +55 -19
- letta/services/passage_manager.py +211 -13
- letta/services/provider_manager.py +48 -2
- letta/services/sandbox_config_manager.py +105 -0
- letta/services/source_manager.py +4 -5
- letta/services/step_manager.py +9 -6
- letta/services/summarizer/summarizer.py +50 -23
- letta/services/telemetry_manager.py +7 -0
- letta/services/tool_executor/tool_execution_manager.py +11 -52
- letta/services/tool_executor/tool_execution_sandbox.py +4 -34
- letta/services/tool_executor/tool_executor.py +107 -105
- letta/services/tool_manager.py +56 -17
- letta/services/tool_sandbox/base.py +39 -92
- letta/services/tool_sandbox/e2b_sandbox.py +16 -11
- letta/services/tool_sandbox/local_sandbox.py +51 -23
- letta/services/user_manager.py +36 -3
- letta/settings.py +10 -3
- letta/templates/__init__.py +0 -0
- letta/templates/sandbox_code_file.py.j2 +47 -0
- letta/templates/template_helper.py +16 -0
- letta/tracing.py +30 -1
- letta/types/__init__.py +7 -0
- letta/utils.py +25 -1
- {letta_nightly-0.7.29.dev20250602104315.dist-info → letta_nightly-0.8.0.dev20250604104349.dist-info}/METADATA +7 -2
- {letta_nightly-0.7.29.dev20250602104315.dist-info → letta_nightly-0.8.0.dev20250604104349.dist-info}/RECORD +138 -112
- {letta_nightly-0.7.29.dev20250602104315.dist-info → letta_nightly-0.8.0.dev20250604104349.dist-info}/LICENSE +0 -0
- {letta_nightly-0.7.29.dev20250602104315.dist-info → letta_nightly-0.8.0.dev20250604104349.dist-info}/WHEEL +0 -0
- {letta_nightly-0.7.29.dev20250602104315.dist-info → letta_nightly-0.8.0.dev20250604104349.dist-info}/entry_points.txt +0 -0
letta/services/block_manager.py
CHANGED
@@ -77,6 +77,30 @@ class BlockManager:
             # Convert back to Pydantic
             return [m.to_pydantic() for m in created_models]
 
+    @trace_method
+    @enforce_types
+    async def batch_create_blocks_async(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]:
+        """
+        Batch-create multiple Blocks in one transaction for better performance.
+        Args:
+            blocks: List of PydanticBlock schemas to create
+            actor: The user performing the operation
+        Returns:
+            List of created PydanticBlock instances (with IDs, timestamps, etc.)
+        """
+        if not blocks:
+            return []
+
+        async with db_registry.async_session() as session:
+            block_models = [
+                BlockModel(**block.model_dump(to_orm=True, exclude_none=True), organization_id=actor.organization_id) for block in blocks
+            ]
+
+            created_models = await BlockModel.batch_create_async(items=block_models, db_session=session, actor=actor)
+
+            # Convert back to Pydantic
+            return [m.to_pydantic() for m in created_models]
+
     @trace_method
     @enforce_types
     def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
@@ -238,9 +262,9 @@
             if actor:
                 query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
 
-            # Add soft delete filter if applicable
-            if hasattr(BlockModel, "is_deleted"):
-                query = query.where(BlockModel.is_deleted == False)
+            # TODO: Add soft delete filter if applicable
+            # if hasattr(BlockModel, "is_deleted"):
+            #     query = query.where(BlockModel.is_deleted == False)
 
             # Execute the query
             result = await session.execute(query)
@@ -273,15 +297,12 @@
 
     @trace_method
     @enforce_types
-    def size(
-        self,
-        actor: PydanticUser,
-    ) -> int:
+    async def size_async(self, actor: PydanticUser) -> int:
        """
        Get the total count of blocks for the given user.
        """
-        with db_registry.session() as session:
-            return BlockModel.size(db_session=session, actor=actor)
+        async with db_registry.async_session() as session:
+            return await BlockModel.size_async(db_session=session, actor=actor)
 
     # Block History Functions
 
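The new `batch_create_blocks_async` API replaces per-block round trips with a single transaction. A minimal usage sketch, not part of the diff: the `seed_memory_blocks` helper is illustrative, and it assumes an existing `BlockManager` instance, an authenticated actor, and that `Block` accepts `label`/`value` keyword arguments.

```python
from typing import List

from letta.schemas.block import Block as PydanticBlock
from letta.schemas.user import User as PydanticUser
from letta.services.block_manager import BlockManager


async def seed_memory_blocks(manager: BlockManager, actor: PydanticUser) -> List[PydanticBlock]:
    # Build the blocks in memory first, then insert them in one transaction
    # instead of calling the single-block create path in a loop.
    blocks = [
        PydanticBlock(label="persona", value="I am a concise assistant."),
        PydanticBlock(label="human", value="The user's name is Ada."),
    ]
    created = await manager.batch_create_blocks_async(blocks=blocks, actor=actor)
    return created  # created blocks come back with IDs and timestamps populated
```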
letta/services/context_window_calculator/__init__.py
File without changes
letta/services/context_window_calculator/context_window_calculator.py
ADDED
@@ -0,0 +1,150 @@
+import asyncio
+from typing import Any, List, Optional, Tuple
+
+from openai.types.beta.function_tool import FunctionTool as OpenAITool
+
+from letta.log import get_logger
+from letta.schemas.enums import MessageRole
+from letta.schemas.letta_message_content import TextContent
+from letta.schemas.memory import ContextWindowOverview
+from letta.schemas.user import User as PydanticUser
+from letta.services.context_window_calculator.token_counter import TokenCounter
+
+logger = get_logger(__name__)
+
+
+class ContextWindowCalculator:
+    """Handles context window calculations with different token counting strategies"""
+
+    @staticmethod
+    def extract_system_components(system_message: str) -> Tuple[str, str, str]:
+        """Extract system prompt, core memory, and external memory summary from system message"""
+        base_start = system_message.find("<base_instructions>")
+        memory_blocks_start = system_message.find("<memory_blocks>")
+        metadata_start = system_message.find("<memory_metadata>")
+
+        system_prompt = ""
+        core_memory = ""
+        external_memory_summary = ""
+
+        if base_start != -1 and memory_blocks_start != -1:
+            system_prompt = system_message[base_start:memory_blocks_start].strip()
+
+        if memory_blocks_start != -1 and metadata_start != -1:
+            core_memory = system_message[memory_blocks_start:metadata_start].strip()
+
+        if metadata_start != -1:
+            external_memory_summary = system_message[metadata_start:].strip()
+
+        return system_prompt, core_memory, external_memory_summary
+
+    @staticmethod
+    def extract_summary_memory(messages: List[Any]) -> Tuple[Optional[str], int]:
+        """Extract summary memory if present and return starting index for real messages"""
+        if (
+            len(messages) > 1
+            and messages[1].role == MessageRole.user
+            and messages[1].content
+            and len(messages[1].content) == 1
+            and isinstance(messages[1].content[0], TextContent)
+            and "The following is a summary of the previous " in messages[1].content[0].text
+        ):
+            summary_memory = messages[1].content[0].text
+            start_index = 2
+            return summary_memory, start_index
+
+        return None, 1
+
+    async def calculate_context_window(
+        self, agent_state: Any, actor: PydanticUser, token_counter: TokenCounter, message_manager: Any, passage_manager: Any
+    ) -> ContextWindowOverview:
+        """Calculate context window information using the provided token counter"""
+
+        # Fetch data concurrently
+        (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
+            message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor),
+            passage_manager.size_async(actor=actor, agent_id=agent_state.id),
+            message_manager.size_async(actor=actor, agent_id=agent_state.id),
+        )
+
+        # Convert messages to appropriate format
+        converted_messages = token_counter.convert_messages(in_context_messages)
+
+        # Extract system components
+        system_prompt = ""
+        core_memory = ""
+        external_memory_summary = ""
+
+        if (
+            in_context_messages
+            and in_context_messages[0].role == MessageRole.system
+            and in_context_messages[0].content
+            and len(in_context_messages[0].content) == 1
+            and isinstance(in_context_messages[0].content[0], TextContent)
+        ):
+            system_message = in_context_messages[0].content[0].text
+            system_prompt, core_memory, external_memory_summary = self.extract_system_components(system_message)
+
+        # System prompt
+        system_prompt = system_prompt or agent_state.system
+
+        # Extract summary memory
+        summary_memory, message_start_index = self.extract_summary_memory(in_context_messages)
+
+        # Prepare tool definitions
+        available_functions_definitions = []
+        if agent_state.tools:
+            available_functions_definitions = [OpenAITool(type="function", function=f.json_schema) for f in agent_state.tools]
+
+        # Count tokens concurrently
+        token_counts = await asyncio.gather(
+            token_counter.count_text_tokens(system_prompt),
+            token_counter.count_text_tokens(core_memory),
+            token_counter.count_text_tokens(external_memory_summary),
+            token_counter.count_text_tokens(summary_memory) if summary_memory else asyncio.sleep(0, result=0),
+            (
+                token_counter.count_message_tokens(converted_messages[message_start_index:])
+                if len(converted_messages) > message_start_index
+                else asyncio.sleep(0, result=0)
+            ),
+            (
+                token_counter.count_tool_tokens(available_functions_definitions)
+                if available_functions_definitions
+                else asyncio.sleep(0, result=0)
+            ),
+        )
+
+        (
+            num_tokens_system,
+            num_tokens_core_memory,
+            num_tokens_external_memory_summary,
+            num_tokens_summary_memory,
+            num_tokens_messages,
+            num_tokens_available_functions_definitions,
+        ) = token_counts
+
+        num_tokens_used_total = sum(token_counts)
+
+        return ContextWindowOverview(
+            # context window breakdown (in messages)
+            num_messages=len(in_context_messages),
+            num_archival_memory=passage_manager_size,
+            num_recall_memory=message_manager_size,
+            num_tokens_external_memory_summary=num_tokens_external_memory_summary,
+            external_memory_summary=external_memory_summary,
+            # top-level information
+            context_window_size_max=agent_state.llm_config.context_window,
+            context_window_size_current=num_tokens_used_total,
+            # context window breakdown (in tokens)
+            num_tokens_system=num_tokens_system,
+            system_prompt=system_prompt,
+            num_tokens_core_memory=num_tokens_core_memory,
+            core_memory=core_memory,
+            num_tokens_summary_memory=num_tokens_summary_memory,
+            summary_memory=summary_memory,
+            num_tokens_messages=num_tokens_messages,
+            messages=in_context_messages,
+            # related to functions
+            num_tokens_functions_definitions=num_tokens_available_functions_definitions,
+            functions_definitions=available_functions_definitions,
+        )
letta/services/context_window_calculator/token_counter.py
ADDED
@@ -0,0 +1,82 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+
+from letta.llm_api.anthropic_client import AnthropicClient
+from letta.utils import count_tokens
+
+
+class TokenCounter(ABC):
+    """Abstract base class for token counting strategies"""
+
+    @abstractmethod
+    async def count_text_tokens(self, text: str) -> int:
+        """Count tokens in a text string"""
+
+    @abstractmethod
+    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
+        """Count tokens in a list of messages"""
+
+    @abstractmethod
+    async def count_tool_tokens(self, tools: List[Any]) -> int:
+        """Count tokens in tool definitions"""
+
+    @abstractmethod
+    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
+        """Convert messages to the appropriate format for this counter"""
+
+
+class AnthropicTokenCounter(TokenCounter):
+    """Token counter using Anthropic's API"""
+
+    def __init__(self, anthropic_client: AnthropicClient, model: str):
+        self.client = anthropic_client
+        self.model = model
+
+    async def count_text_tokens(self, text: str) -> int:
+        if not text:
+            return 0
+        return await self.client.count_tokens(model=self.model, messages=[{"role": "user", "content": text}])
+
+    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
+        if not messages:
+            return 0
+        return await self.client.count_tokens(model=self.model, messages=messages)
+
+    async def count_tool_tokens(self, tools: List[Any]) -> int:
+        if not tools:
+            return 0
+        return await self.client.count_tokens(model=self.model, tools=tools)
+
+    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
+        return [m.to_anthropic_dict() for m in messages]
+
+
+class TiktokenCounter(TokenCounter):
+    """Token counter using tiktoken"""
+
+    def __init__(self, model: str):
+        self.model = model
+
+    async def count_text_tokens(self, text: str) -> int:
+        if not text:
+            return 0
+        return count_tokens(text)
+
+    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
+        if not messages:
+            return 0
+        from letta.local_llm.utils import num_tokens_from_messages
+
+        return num_tokens_from_messages(messages=messages, model=self.model)
+
+    async def count_tool_tokens(self, tools: List[Any]) -> int:
+        if not tools:
+            return 0
+        from letta.local_llm.utils import num_tokens_from_functions
+
+        # Extract function definitions from OpenAITool objects
+        functions = [t.function.model_dump() for t in tools]
+        return num_tokens_from_functions(functions=functions, model=self.model)
+
+    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
+        return [m.to_openai_dict() for m in messages]
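One way the two strategies could be selected at runtime (a sketch; it assumes `llm_config.model_endpoint_type` is set to "anthropic" for Anthropic-backed agents and that an `AnthropicClient` is available in that case):

```python
from typing import Optional

from letta.llm_api.anthropic_client import AnthropicClient
from letta.services.context_window_calculator.token_counter import (
    AnthropicTokenCounter,
    TiktokenCounter,
    TokenCounter,
)


def pick_token_counter(agent_state, anthropic_client: Optional[AnthropicClient] = None) -> TokenCounter:
    model = agent_state.llm_config.model
    if anthropic_client is not None and agent_state.llm_config.model_endpoint_type == "anthropic":
        # Counts via Anthropic's count-tokens endpoint, so totals match the provider's accounting.
        return AnthropicTokenCounter(anthropic_client=anthropic_client, model=model)
    # Falls back to local tiktoken-based counting.
    return TiktokenCounter(model=model)
```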
letta/services/file_processor/__init__.py
File without changes

letta/services/file_processor/chunker/__init__.py
File without changes
letta/services/file_processor/chunker/llama_index_chunker.py
ADDED
@@ -0,0 +1,29 @@
+from typing import List
+
+from mistralai import OCRPageObject
+
+from letta.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class LlamaIndexChunker:
+    """LlamaIndex-based text chunking"""
+
+    def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+
+        from llama_index.core.node_parser import SentenceSplitter
+
+        self.parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+
+    # TODO: Make this more general beyond Mistral
+    def chunk_text(self, page: OCRPageObject) -> List[str]:
+        """Chunk text using LlamaIndex splitter"""
+        try:
+            return self.parser.split_text(page.markdown)
+
+        except Exception as e:
+            logger.error(f"Chunking failed: {str(e)}")
+            raise
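A quick sketch of the chunker in isolation (the `OCRPageObject` constructed here is synthetic; in the real pipeline it comes from the Mistral OCR response):

```python
from mistralai import OCRPageObject

from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker

chunker = LlamaIndexChunker(chunk_size=512, chunk_overlap=50)

# Fabricated page standing in for one page of an OCR response.
page = OCRPageObject(index=0, markdown="# Notes\n\nLong extracted text goes here...", images=[], dimensions=None)

chunks = chunker.chunk_text(page)  # sentence-aware chunks of roughly chunk_size tokens
```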
letta/services/file_processor/embedder/__init__.py
File without changes
letta/services/file_processor/embedder/openai_embedder.py
ADDED
@@ -0,0 +1,84 @@
+import asyncio
+from typing import List, Optional, Tuple
+
+import openai
+
+from letta.log import get_logger
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.passage import Passage
+from letta.schemas.user import User
+from letta.settings import model_settings
+
+logger = get_logger(__name__)
+
+
+class OpenAIEmbedder:
+    """OpenAI-based embedding generation"""
+
+    def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
+        self.embedding_config = embedding_config or EmbeddingConfig.default_config(provider="openai")
+
+        # TODO: Unify to global OpenAI client
+        self.client = openai.AsyncOpenAI(api_key=model_settings.openai_api_key)
+        self.max_batch = 1024
+        self.max_concurrent_requests = 20
+
+    async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:
+        """Embed a single batch and return embeddings with their original indices"""
+        response = await self.client.embeddings.create(model=self.embedding_config.embedding_model, input=batch)
+        return [(idx, res.embedding) for idx, res in zip(batch_indices, response.data)]
+
+    async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
+        """Generate embeddings for chunks with batching and concurrent processing"""
+        if not chunks:
+            return []
+
+        logger.info(f"Generating embeddings for {len(chunks)} chunks using {self.embedding_config.embedding_model}")
+
+        # Create batches with their original indices
+        batches = []
+        batch_indices = []
+
+        for i in range(0, len(chunks), self.max_batch):
+            batch = chunks[i : i + self.max_batch]
+            indices = list(range(i, min(i + self.max_batch, len(chunks))))
+            batches.append(batch)
+            batch_indices.append(indices)
+
+        logger.info(f"Processing {len(batches)} batches")
+
+        async def process(batch: List[str], indices: List[int]):
+            try:
+                return await self._embed_batch(batch, indices)
+            except Exception as e:
+                logger.error(f"Failed to embed batch of size {len(batch)}: {str(e)}")
+                raise
+
+        # Execute all batches concurrently with semaphore control
+        tasks = [process(batch, indices) for batch, indices in zip(batches, batch_indices)]
+
+        results = await asyncio.gather(*tasks)
+
+        # Flatten results and sort by original index
+        indexed_embeddings = []
+        for batch_result in results:
+            indexed_embeddings.extend(batch_result)
+
+        # Sort by index to maintain original order
+        indexed_embeddings.sort(key=lambda x: x[0])
+
+        # Create Passage objects in original order
+        passages = []
+        for (idx, embedding), text in zip(indexed_embeddings, chunks):
+            passage = Passage(
+                text=text,
+                file_id=file_id,
+                source_id=source_id,
+                embedding=embedding,
+                embedding_config=self.embedding_config,
+                organization_id=actor.organization_id,
+            )
+            passages.append(passage)
+
+        logger.info(f"Successfully generated {len(passages)} embeddings")
+        return passages
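Sketch of using the embedder on its own (illustrative; it assumes the OpenAI API key is configured via `model_settings` and that the file and source IDs already exist):

```python
from typing import List

from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder


async def embed_chunks(chunks: List[str], file_id: str, source_id: str, actor: User) -> List[Passage]:
    # Defaults to EmbeddingConfig.default_config(provider="openai") and batches up to 1024 inputs per request.
    embedder = OpenAIEmbedder()
    # Returned passages preserve the original chunk order even though batches run concurrently.
    return await embedder.generate_embedded_passages(file_id=file_id, source_id=source_id, chunks=chunks, actor=actor)
```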
letta/services/file_processor/file_processor.py
ADDED
@@ -0,0 +1,123 @@
+import mimetypes
+from typing import List, Optional
+
+from fastapi import UploadFile
+
+from letta.log import get_logger
+from letta.schemas.agent import AgentState
+from letta.schemas.enums import JobStatus
+from letta.schemas.file import FileMetadata
+from letta.schemas.job import Job, JobUpdate
+from letta.schemas.passage import Passage
+from letta.schemas.user import User
+from letta.server.server import SyncServer
+from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
+from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
+from letta.services.file_processor.parser.mistral_parser import MistralFileParser
+from letta.services.job_manager import JobManager
+from letta.services.passage_manager import PassageManager
+from letta.services.source_manager import SourceManager
+
+logger = get_logger(__name__)
+
+
+class FileProcessor:
+    """Main PDF processing orchestrator"""
+
+    def __init__(
+        self,
+        file_parser: MistralFileParser,
+        text_chunker: LlamaIndexChunker,
+        embedder: OpenAIEmbedder,
+        actor: User,
+        max_file_size: int = 50 * 1024 * 1024,  # 50MB default
+    ):
+        self.file_parser = file_parser
+        self.text_chunker = text_chunker
+        self.embedder = embedder
+        self.max_file_size = max_file_size
+        self.source_manager = SourceManager()
+        self.passage_manager = PassageManager()
+        self.job_manager = JobManager()
+        self.actor = actor
+
+    # TODO: Factor this function out of SyncServer
+    async def process(
+        self,
+        server: SyncServer,
+        agent_states: List[AgentState],
+        source_id: str,
+        content: bytes,
+        file: UploadFile,
+        job: Optional[Job] = None,
+    ) -> List[Passage]:
+        file_metadata = self._extract_upload_file_metadata(file, source_id=source_id)
+        file_metadata = await self.source_manager.create_file(file_metadata, self.actor)
+        filename = file_metadata.file_name
+
+        try:
+            # Ensure we're working with bytes
+            if isinstance(content, str):
+                content = content.encode("utf-8")
+
+            if len(content) > self.max_file_size:
+                raise ValueError(f"PDF size exceeds maximum allowed size of {self.max_file_size} bytes")
+
+            logger.info(f"Starting OCR extraction for {filename}")
+            ocr_response = await self.file_parser.extract_text(content, mime_type=file_metadata.file_type)
+
+            if not ocr_response or len(ocr_response.pages) == 0:
+                raise ValueError("No text extracted from PDF")
+
+            logger.info("Chunking extracted text")
+            all_passages = []
+            for page in ocr_response.pages:
+                chunks = self.text_chunker.chunk_text(page)
+
+                if not chunks:
+                    raise ValueError("No chunks created from text")
+
+                passages = await self.embedder.generate_embedded_passages(
+                    file_id=file_metadata.id, source_id=source_id, chunks=chunks, actor=self.actor
+                )
+                all_passages.extend(passages)
+
+            all_passages = await self.passage_manager.create_many_passages_async(all_passages, self.actor)
+
+            logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")
+
+            await server.insert_file_into_context_windows(
+                source_id=source_id,
+                text="".join([ocr_response.pages[i].markdown for i in range(min(3, len(ocr_response.pages)))]),
+                file_id=file_metadata.id,
+                actor=self.actor,
+                agent_states=agent_states,
+            )
+
+            # update job status
+            if job:
+                job.status = JobStatus.completed
+                job.metadata["num_passages"] = len(all_passages)
+                await self.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(**job.model_dump()), actor=self.actor)
+
+            return all_passages
+
+        except Exception as e:
+            logger.error(f"PDF processing failed for {filename}: {str(e)}")
+
+            # update job status
+            if job:
+                job.status = JobStatus.failed
+                job.metadata["error"] = str(e)
+                await self.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(**job.model_dump()), actor=self.actor)
+
+            return []
+
+    def _extract_upload_file_metadata(self, file: UploadFile, source_id: str) -> FileMetadata:
+        file_metadata = {
+            "file_name": file.filename,
+            "file_path": None,
+            "file_type": mimetypes.guess_type(file.filename)[0] or file.content_type or "unknown",
+            "file_size": file.size if file.size is not None else None,
+        }
+        return FileMetadata(**file_metadata, source_id=source_id)
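How the pieces might be wired together for an upload (a sketch; the `handle_upload` helper is hypothetical and mirrors what the sources router does with these components, assuming a server, agent states, an actor, and an optional tracking job are already in hand):

```python
from fastapi import UploadFile

from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
from letta.services.file_processor.file_processor import FileProcessor
from letta.services.file_processor.parser.mistral_parser import MistralFileParser


async def handle_upload(server, agent_states, source_id: str, file: UploadFile, actor, job=None):
    processor = FileProcessor(
        file_parser=MistralFileParser(),   # OCR / text extraction
        text_chunker=LlamaIndexChunker(),  # sentence-aware chunking
        embedder=OpenAIEmbedder(),         # embedding + Passage construction
        actor=actor,
    )
    content = await file.read()
    # process() parses, chunks, embeds, persists the passages, and pushes the
    # first pages of the file into the attached agents' context windows.
    return await processor.process(
        server=server,
        agent_states=agent_states,
        source_id=source_id,
        content=content,
        file=file,
        job=job,
    )
```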
letta/services/file_processor/parser/__init__.py
File without changes
letta/services/file_processor/parser/mistral_parser.py
ADDED
@@ -0,0 +1,54 @@
+import base64
+
+from mistralai import Mistral, OCRPageObject, OCRResponse, OCRUsageInfo
+
+from letta.log import get_logger
+from letta.services.file_processor.parser.base_parser import FileParser
+from letta.settings import settings
+
+logger = get_logger(__name__)
+
+
+class MistralFileParser(FileParser):
+    """Mistral-based OCR extraction"""
+
+    def __init__(self, model: str = "mistral-ocr-latest"):
+        self.model = model
+
+    # TODO: Make this return something general if we add more file parsers
+    async def extract_text(self, content: bytes, mime_type: str) -> OCRResponse:
+        """Extract text using Mistral OCR or shortcut for plain text."""
+        try:
+            logger.info(f"Extracting text using Mistral OCR model: {self.model}")
+
+            # TODO: Kind of hacky...we try to exit early here?
+            # TODO: Create our internal file parser representation we return instead of OCRResponse
+            if mime_type == "text/plain":
+                text = content.decode("utf-8", errors="replace")
+                return OCRResponse(
+                    model=self.model,
+                    pages=[
+                        OCRPageObject(
+                            index=0,
+                            markdown=text,
+                            images=[],
+                            dimensions=None,
+                        )
+                    ],
+                    usage_info=OCRUsageInfo(pages_processed=1),  # You might need to construct this properly
+                    document_annotation=None,
+                )
+
+            base64_encoded_content = base64.b64encode(content).decode("utf-8")
+            document_url = f"data:{mime_type};base64,{base64_encoded_content}"
+
+            async with Mistral(api_key=settings.mistral_api_key) as mistral:
+                ocr_response = await mistral.ocr.process_async(
+                    model="mistral-ocr-latest", document={"type": "document_url", "document_url": document_url}, include_image_base64=False
+                )
+
+            return ocr_response
+
+        except Exception as e:
+            logger.error(f"OCR extraction failed: {str(e)}")
+            raise
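Standalone sketch of the parser (assumes `settings.mistral_api_key` is configured; plain-text uploads skip the OCR call entirely):

```python
from letta.services.file_processor.parser.mistral_parser import MistralFileParser


async def extract_markdown(content: bytes, mime_type: str) -> str:
    parser = MistralFileParser()
    # text/plain is wrapped directly into a one-page OCRResponse; other types go through Mistral OCR.
    ocr_response = await parser.extract_text(content, mime_type=mime_type)
    return "\n\n".join(page.markdown for page in ocr_response.pages)
```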
letta/services/file_processor/types.py
File without changes