letta-nightly 0.11.4.dev20250825104222__py3-none-any.whl → 0.11.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agent.py +9 -3
- letta/agents/base_agent.py +2 -2
- letta/agents/letta_agent.py +56 -45
- letta/agents/voice_agent.py +2 -2
- letta/data_sources/redis_client.py +146 -1
- letta/errors.py +4 -0
- letta/functions/function_sets/files.py +2 -2
- letta/functions/mcp_client/types.py +30 -6
- letta/functions/schema_generator.py +46 -1
- letta/functions/schema_validator.py +17 -2
- letta/functions/types.py +1 -1
- letta/helpers/tool_execution_helper.py +0 -2
- letta/llm_api/anthropic_client.py +27 -5
- letta/llm_api/deepseek_client.py +97 -0
- letta/llm_api/groq_client.py +79 -0
- letta/llm_api/helpers.py +0 -1
- letta/llm_api/llm_api_tools.py +2 -113
- letta/llm_api/llm_client.py +21 -0
- letta/llm_api/llm_client_base.py +11 -9
- letta/llm_api/openai_client.py +3 -0
- letta/llm_api/xai_client.py +85 -0
- letta/prompts/prompt_generator.py +190 -0
- letta/schemas/agent_file.py +17 -2
- letta/schemas/file.py +24 -1
- letta/schemas/job.py +2 -0
- letta/schemas/letta_message.py +2 -0
- letta/schemas/letta_request.py +22 -0
- letta/schemas/message.py +10 -1
- letta/schemas/providers/bedrock.py +1 -0
- letta/server/rest_api/redis_stream_manager.py +300 -0
- letta/server/rest_api/routers/v1/agents.py +129 -7
- letta/server/rest_api/routers/v1/folders.py +15 -5
- letta/server/rest_api/routers/v1/runs.py +101 -11
- letta/server/rest_api/routers/v1/sources.py +21 -53
- letta/server/rest_api/routers/v1/telemetry.py +14 -4
- letta/server/rest_api/routers/v1/tools.py +2 -2
- letta/server/rest_api/streaming_response.py +3 -24
- letta/server/server.py +0 -1
- letta/services/agent_manager.py +2 -2
- letta/services/agent_serialization_manager.py +129 -32
- letta/services/file_manager.py +111 -6
- letta/services/file_processor/file_processor.py +5 -2
- letta/services/files_agents_manager.py +60 -0
- letta/services/helpers/agent_manager_helper.py +4 -205
- letta/services/helpers/tool_parser_helper.py +6 -3
- letta/services/mcp/base_client.py +7 -1
- letta/services/mcp/sse_client.py +7 -2
- letta/services/mcp/stdio_client.py +5 -0
- letta/services/mcp/streamable_http_client.py +11 -2
- letta/services/mcp_manager.py +31 -30
- letta/services/source_manager.py +26 -1
- letta/services/summarizer/summarizer.py +21 -10
- letta/services/tool_executor/files_tool_executor.py +13 -9
- letta/services/tool_executor/mcp_tool_executor.py +3 -0
- letta/services/tool_executor/tool_execution_manager.py +13 -0
- letta/services/tool_manager.py +43 -20
- letta/settings.py +1 -0
- letta/utils.py +37 -0
- {letta_nightly-0.11.4.dev20250825104222.dist-info → letta_nightly-0.11.5.dist-info}/METADATA +2 -2
- {letta_nightly-0.11.4.dev20250825104222.dist-info → letta_nightly-0.11.5.dist-info}/RECORD +64 -63
- letta/functions/mcp_client/__init__.py +0 -0
- letta/functions/mcp_client/base_client.py +0 -156
- letta/functions/mcp_client/sse_client.py +0 -51
- letta/functions/mcp_client/stdio_client.py +0 -109
- {letta_nightly-0.11.4.dev20250825104222.dist-info → letta_nightly-0.11.5.dist-info}/LICENSE +0 -0
- {letta_nightly-0.11.4.dev20250825104222.dist-info → letta_nightly-0.11.5.dist-info}/WHEEL +0 -0
- {letta_nightly-0.11.4.dev20250825104222.dist-info → letta_nightly-0.11.5.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,190 @@
|
|
1
|
+
from datetime import datetime
|
2
|
+
from typing import List, Literal, Optional
|
3
|
+
|
4
|
+
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD
|
5
|
+
from letta.helpers import ToolRulesSolver
|
6
|
+
from letta.helpers.datetime_helpers import format_datetime, get_local_time_fast
|
7
|
+
from letta.otel.tracing import trace_method
|
8
|
+
from letta.schemas.memory import Memory
|
9
|
+
|
10
|
+
|
11
|
+
class PromptGenerator:
    """Static helpers that render an agent's system prompt together with its compiled in-context memory."""

    class _PreserveMapping(dict):
        """Mapping for ``str.format_map`` that leaves unknown ``{name}`` placeholders untouched."""

        def __missing__(self, key):
            # Re-emit the placeholder verbatim instead of raising KeyError.
            return "{" + key + "}"

    # TODO: This code is kind of wonky and deserves a rewrite
    @trace_method
    @staticmethod
    def compile_memory_metadata_block(
        memory_edit_timestamp: datetime,
        timezone: str,
        previous_message_count: int = 0,
        archival_memory_size: Optional[int] = 0,
    ) -> str:
        """
        Generate a memory metadata block for the agent's system prompt.

        This creates a structured metadata section that informs the agent about
        the current state of its memory systems, including timing information
        and memory counts. This helps the agent understand what information
        is available through its tools.

        Args:
            memory_edit_timestamp: When memory blocks were last modified
            timezone: The timezone to use for formatting timestamps (e.g., 'America/Los_Angeles')
            previous_message_count: Number of messages in recall memory (conversation history)
            archival_memory_size: Number of items in archival memory (long-term storage);
                the archival line is omitted when this is None or not positive

        Returns:
            A formatted string containing the memory metadata block with XML-style tags

        Example Output:
        <memory_metadata>
        - The current time is: 2024-01-15 10:30 AM PST
        - Memory blocks were last modified: 2024-01-15 09:00 AM PST
        - 42 previous messages between you and the user are stored in recall memory (use tools to access them)
        - 156 total memories you created are stored in archival memory (use tools to access them)
        </memory_metadata>
        """
        # Put the timestamp in the local timezone (mimicking get_local_time())
        timestamp_str = format_datetime(memory_edit_timestamp, timezone)

        # Create a metadata block of info so the agent knows about the metadata of out-of-context memories
        metadata_lines = [
            "<memory_metadata>",
            f"- The current time is: {get_local_time_fast(timezone)}",
            f"- Memory blocks were last modified: {timestamp_str}",
            f"- {previous_message_count} previous messages between you and the user are stored in recall memory (use tools to access them)",
        ]

        # Only include archival memory line if there are archival memories
        if archival_memory_size is not None and archival_memory_size > 0:
            metadata_lines.append(
                f"- {archival_memory_size} total memories you created are stored in archival memory (use tools to access them)"
            )

        metadata_lines.append("</memory_metadata>")
        return "\n".join(metadata_lines)

    @staticmethod
    def safe_format(template: str, variables: dict) -> str:
        """
        Safely formats a template string, preserving empty {} and {unknown_vars}
        while substituting known variables.

        If we simply use {} in format_map, it'll be treated as a positional field,
        so bare {} pairs are escaped before formatting.

        Args:
            template: Template text containing ``{name}`` placeholders.
            variables: Values for the placeholders that should be substituted.

        Returns:
            The rendered string; placeholders with no value in ``variables`` are left as-is.
        """
        # First escape any empty {} by doubling them
        escaped = template.replace("{}", "{{}}")

        # Unknown names resolve through _PreserveMapping.__missing__ and are written back unchanged
        return escaped.format_map(PromptGenerator._PreserveMapping(variables))

    @trace_method
    @staticmethod
    def get_system_message_from_compiled_memory(
        system_prompt: str,
        memory_with_sources: str,
        in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
        timezone: str,
        user_defined_variables: Optional[dict] = None,
        append_icm_if_missing: bool = True,
        template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
        previous_message_count: int = 0,
        archival_memory_size: int = 0,
    ) -> str:
        """Prepare the final/full system message that will be fed into the LLM API

        The base system message may be templated, in which case we need to render the variables.

        The following are reserved variables:
            - CORE_MEMORY: the in-context memory of the LLM

        Args:
            system_prompt: Base system prompt, optionally containing the in-context-memory placeholder.
            memory_with_sources: Already-compiled memory text to inject.
            in_context_memory_last_edit: When memory blocks were last modified.
            timezone: Timezone used when rendering timestamps in the metadata block.
            user_defined_variables: Not supported yet; must be None.
            append_icm_if_missing: Append the memory placeholder when the prompt lacks it.
            template_format: Only "f-string" is implemented.
            previous_message_count: Number of messages in recall memory.
            archival_memory_size: Number of items in archival memory.

        Returns:
            The system prompt with the memory (plus metadata block) substituted in.

        Raises:
            NotImplementedError: if user_defined_variables is given or template_format is not "f-string".
            ValueError: if rendering the prompt fails.
        """
        if user_defined_variables is not None:
            # TODO eventually support the user defining their own variables to inject
            raise NotImplementedError
        variables = {}

        # Add the protected memory variable (guard kept for when user-defined variables are supported)
        if IN_CONTEXT_MEMORY_KEYWORD in variables:
            raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")

        # TODO should this all put into the memory.__repr__ function?
        memory_metadata_string = PromptGenerator.compile_memory_metadata_block(
            memory_edit_timestamp=in_context_memory_last_edit,
            previous_message_count=previous_message_count,
            archival_memory_size=archival_memory_size,
            timezone=timezone,
        )

        full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string

        # Add to the variables list to inject
        variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string

        if template_format != "f-string":
            # TODO support for mustache and jinja2
            raise NotImplementedError(template_format)

        memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"

        # Catch the special case where the system prompt is unformatted:
        # append the placeholder so memory is still injected
        if append_icm_if_missing and memory_variable_string not in system_prompt:
            system_prompt += "\n\n" + memory_variable_string

        # render the variables using the built-in templater
        try:
            if user_defined_variables:
                formatted_prompt = PromptGenerator.safe_format(system_prompt, variables)
            else:
                formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
        except Exception as e:
            raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}") from e

        return formatted_prompt

    @trace_method
    @staticmethod
    async def compile_system_message_async(
        system_prompt: str,
        in_context_memory: Memory,
        in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
        timezone: str,
        user_defined_variables: Optional[dict] = None,
        append_icm_if_missing: bool = True,
        template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
        previous_message_count: int = 0,
        archival_memory_size: int = 0,
        tool_rules_solver: Optional[ToolRulesSolver] = None,
        sources: Optional[List] = None,
        max_files_open: Optional[int] = None,
    ) -> str:
        """Compile ``in_context_memory`` (with tool-rule prompts and sources) and render the full system message.

        Raises:
            NotImplementedError: if user_defined_variables is given.
        """
        tool_constraint_block = None
        if tool_rules_solver is not None:
            tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()

        if user_defined_variables is not None:
            # TODO eventually support the user defining their own variables to inject
            raise NotImplementedError

        memory_with_sources = await in_context_memory.compile_in_thread_async(
            tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
        )

        return PromptGenerator.get_system_message_from_compiled_memory(
            system_prompt=system_prompt,
            memory_with_sources=memory_with_sources,
            in_context_memory_last_edit=in_context_memory_last_edit,
            timezone=timezone,
            user_defined_variables=user_defined_variables,
            append_icm_if_missing=append_icm_if_missing,
            template_format=template_format,
            previous_message_count=previous_message_count,
            archival_memory_size=archival_memory_size,
        )
|
letta/schemas/agent_file.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1
1
|
from datetime import datetime
|
2
2
|
from typing import Any, Dict, List, Optional
|
3
3
|
|
4
|
+
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
4
5
|
from pydantic import BaseModel, Field
|
5
6
|
|
7
|
+
from letta.helpers.datetime_helpers import get_utc_time
|
6
8
|
from letta.schemas.agent import AgentState, CreateAgent
|
7
9
|
from letta.schemas.block import Block, CreateBlock
|
8
10
|
from letta.schemas.enums import MessageRole
|
9
11
|
from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase
|
10
12
|
from letta.schemas.group import Group, GroupCreate
|
11
13
|
from letta.schemas.mcp import MCPServer
|
12
|
-
from letta.schemas.message import Message, MessageCreate
|
14
|
+
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
13
15
|
from letta.schemas.source import Source, SourceCreate
|
14
16
|
from letta.schemas.tool import Tool
|
15
17
|
from letta.schemas.user import User
|
@@ -46,6 +48,15 @@ class MessageSchema(MessageCreate):
|
|
46
48
|
role: MessageRole = Field(..., description="The role of the participant.")
|
47
49
|
model: Optional[str] = Field(None, description="The model used to make the function call")
|
48
50
|
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent")
|
51
|
+
tool_calls: Optional[List[OpenAIToolCall]] = Field(
|
52
|
+
default=None, description="The list of tool calls requested. Only applicable for role assistant."
|
53
|
+
)
|
54
|
+
tool_call_id: Optional[str] = Field(default=None, description="The ID of the tool call. Only applicable for role tool.")
|
55
|
+
tool_returns: Optional[List[ToolReturn]] = Field(default=None, description="Tool execution return information for prior tool calls")
|
56
|
+
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
57
|
+
|
58
|
+
# TODO: Should we also duplicate the steps here?
|
59
|
+
# TODO: What about tool_return?
|
49
60
|
|
50
61
|
@classmethod
|
51
62
|
def from_message(cls, message: Message) -> "MessageSchema":
|
@@ -64,6 +75,10 @@ class MessageSchema(MessageCreate):
|
|
64
75
|
group_id=message.group_id,
|
65
76
|
model=message.model,
|
66
77
|
agent_id=message.agent_id,
|
78
|
+
tool_calls=message.tool_calls,
|
79
|
+
tool_call_id=message.tool_call_id,
|
80
|
+
tool_returns=message.tool_returns,
|
81
|
+
created_at=message.created_at,
|
67
82
|
)
|
68
83
|
|
69
84
|
|
@@ -114,7 +129,7 @@ class AgentSchema(CreateAgent):
|
|
114
129
|
memory_blocks=[], # TODO: Convert from agent_state.memory if needed
|
115
130
|
tools=[],
|
116
131
|
tool_ids=[tool.id for tool in agent_state.tools] if agent_state.tools else [],
|
117
|
-
source_ids=[
|
132
|
+
source_ids=[source.id for source in agent_state.sources] if agent_state.sources else [],
|
118
133
|
block_ids=[block.id for block in agent_state.memory.blocks],
|
119
134
|
tool_rules=agent_state.tool_rules,
|
120
135
|
tags=agent_state.tags,
|
letta/schemas/file.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from datetime import datetime
|
2
2
|
from enum import Enum
|
3
|
-
from typing import Optional
|
3
|
+
from typing import List, Optional
|
4
4
|
|
5
5
|
from pydantic import Field
|
6
6
|
|
@@ -108,3 +108,26 @@ class FileAgent(FileAgentBase):
|
|
108
108
|
default_factory=datetime.utcnow,
|
109
109
|
description="Row last-update timestamp (UTC).",
|
110
110
|
)
|
111
|
+
|
112
|
+
|
113
|
+
class AgentFileAttachment(LettaBase):
    """Response model for agent file attachments showing file status in agent context"""

    # --- identity of the attachment, the file, and its folder ---
    id: str = Field(default=..., description="Unique identifier of the file-agent relationship")
    file_id: str = Field(default=..., description="Unique identifier of the file")
    file_name: str = Field(default=..., description="Name of the file")
    folder_id: str = Field(default=..., description="Unique identifier of the folder/source")
    folder_name: str = Field(default=..., description="Name of the folder/source")

    # --- view state of the file inside the agent's context ---
    is_open: bool = Field(default=..., description="Whether the file is currently open in the agent's context")
    last_accessed_at: Optional[datetime] = Field(default=None, description="Timestamp of last access by the agent")
    visible_content: Optional[str] = Field(default=None, description="Portion of the file visible to the agent if open")
    # Line window; both stay None unless the file was opened with a range
    start_line: Optional[int] = Field(default=None, description="Starting line number if file was opened with line range")
    end_line: Optional[int] = Field(default=None, description="Ending line number if file was opened with line range")
|
126
|
+
|
127
|
+
|
128
|
+
class PaginatedAgentFiles(LettaBase):
    """Paginated response for agent files"""

    # One page of attachments
    files: List[AgentFileAttachment] = Field(default=..., description="List of file attachments for the agent")
    # Pagination state: cursor is None when there is no following page to request
    next_cursor: Optional[str] = Field(default=None, description="Cursor for fetching the next page (file-agent relationship ID)")
    has_more: bool = Field(default=..., description="Whether more results exist after this page")
|
letta/schemas/job.py
CHANGED
@@ -4,6 +4,7 @@ from typing import List, Optional
|
|
4
4
|
from pydantic import BaseModel, ConfigDict, Field
|
5
5
|
|
6
6
|
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
7
|
+
from letta.helpers.datetime_helpers import get_utc_time
|
7
8
|
from letta.schemas.enums import JobStatus, JobType
|
8
9
|
from letta.schemas.letta_base import OrmMetadataBase
|
9
10
|
from letta.schemas.letta_message import MessageType
|
@@ -12,6 +13,7 @@ from letta.schemas.letta_message import MessageType
|
|
12
13
|
class JobBase(OrmMetadataBase):
|
13
14
|
__id_prefix__ = "job"
|
14
15
|
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
|
16
|
+
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
|
15
17
|
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
|
16
18
|
metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="The metadata of the job.")
|
17
19
|
job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")
|
letta/schemas/letta_message.py
CHANGED
@@ -52,6 +52,8 @@ class LettaMessage(BaseModel):
|
|
52
52
|
sender_id: str | None = None
|
53
53
|
step_id: str | None = None
|
54
54
|
is_err: bool | None = None
|
55
|
+
seq_id: int | None = None
|
56
|
+
run_id: str | None = None
|
55
57
|
|
56
58
|
@field_serializer("date")
|
57
59
|
def serialize_datetime(self, dt: datetime, _info):
|
letta/schemas/letta_request.py
CHANGED
@@ -46,6 +46,10 @@ class LettaStreamingRequest(LettaRequest):
|
|
46
46
|
default=False,
|
47
47
|
description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
|
48
48
|
)
|
49
|
+
background: bool = Field(
|
50
|
+
default=False,
|
51
|
+
description="Whether to process the request in the background.",
|
52
|
+
)
|
49
53
|
|
50
54
|
|
51
55
|
class LettaAsyncRequest(LettaRequest):
|
@@ -66,3 +70,21 @@ class CreateBatch(BaseModel):
|
|
66
70
|
"'status' is the final batch status (e.g., 'completed', 'failed'), and "
|
67
71
|
"'completed_at' is an ISO 8601 timestamp indicating when the batch job completed.",
|
68
72
|
)
|
73
|
+
|
74
|
+
|
75
|
+
class RetrieveStreamRequest(BaseModel):
    # Request body for re-attaching to a stored run stream.

    # Cursor: only chunks with a seq_id strictly greater than this value are returned
    starting_after: int = Field(
        default=0,
        description="Sequence id to use as a cursor for pagination. Response will start streaming after this chunk sequence id",
    )
    include_pings: Optional[bool] = Field(
        default=False,
        description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
    )
    # Reader tuning knobs
    poll_interval: Optional[float] = Field(
        default=0.1,
        description="Seconds to wait between polls when no new data.",
    )
    batch_size: Optional[int] = Field(
        default=100,
        description="Number of entries to read per batch.",
    )
|
letta/schemas/message.py
CHANGED
@@ -414,6 +414,8 @@ class Message(BaseMessage):
|
|
414
414
|
except json.JSONDecodeError:
|
415
415
|
raise ValueError(f"Failed to decode function return: {text_content}")
|
416
416
|
|
417
|
+
# if self.tool_call_id is None:
|
418
|
+
# import pdb;pdb.set_trace()
|
417
419
|
assert self.tool_call_id is not None
|
418
420
|
|
419
421
|
return ToolReturnMessage(
|
@@ -844,7 +846,7 @@ class Message(BaseMessage):
|
|
844
846
|
}
|
845
847
|
content = []
|
846
848
|
# COT / reasoning / thinking
|
847
|
-
if self.content is not None and len(self.content)
|
849
|
+
if self.content is not None and len(self.content) >= 1:
|
848
850
|
for content_part in self.content:
|
849
851
|
if isinstance(content_part, ReasoningContent):
|
850
852
|
content.append(
|
@@ -861,6 +863,13 @@ class Message(BaseMessage):
|
|
861
863
|
"data": content_part.data,
|
862
864
|
}
|
863
865
|
)
|
866
|
+
if isinstance(content_part, TextContent):
|
867
|
+
content.append(
|
868
|
+
{
|
869
|
+
"type": "text",
|
870
|
+
"text": content_part.text,
|
871
|
+
}
|
872
|
+
)
|
864
873
|
elif text_content is not None:
|
865
874
|
content.append(
|
866
875
|
{
|
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
|
|
18
18
|
class BedrockProvider(Provider):
|
19
19
|
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
|
20
20
|
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
|
21
|
+
access_key: str = Field(..., description="AWS secret access key for Bedrock.")
|
21
22
|
region: str = Field(..., description="AWS region for Bedrock")
|
22
23
|
|
23
24
|
async def bedrock_get_model_list_async(self) -> list[dict]:
|
@@ -0,0 +1,300 @@
|
|
1
|
+
"""Redis stream manager for reading and writing SSE chunks with batching and TTL."""
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import json
|
5
|
+
import time
|
6
|
+
from collections import defaultdict
|
7
|
+
from typing import AsyncIterator, Dict, List, Optional
|
8
|
+
|
9
|
+
from letta.data_sources.redis_client import AsyncRedisClient
|
10
|
+
from letta.log import get_logger
|
11
|
+
|
12
|
+
logger = get_logger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class RedisSSEStreamWriter:
    """
    Efficiently writes SSE chunks to Redis streams with batching and TTL management.

    Features:
    - Batches writes using Redis pipelines for performance
    - Automatically sets/refreshes TTL on streams
    - Tracks sequential IDs for cursor-based recovery
    - Handles flush on size or time thresholds
    - Serializes flushes per run so chunks are appended to Redis in seq_id order
    """

    def __init__(
        self,
        redis_client: AsyncRedisClient,
        flush_interval: float = 0.5,
        flush_size: int = 50,
        stream_ttl_seconds: int = 10800,  # 3 hours default
        max_stream_length: int = 10000,  # Max entries per stream
    ):
        """
        Initialize the Redis SSE stream writer.

        Args:
            redis_client: Redis client instance
            flush_interval: Seconds between automatic flushes
            flush_size: Number of chunks to buffer before flushing
            stream_ttl_seconds: TTL for streams in seconds (default: 10800, i.e. 3 hours)
            max_stream_length: Maximum entries per stream before (approximate) trimming
        """
        self.redis = redis_client
        self.flush_interval = flush_interval
        self.flush_size = flush_size
        self.stream_ttl = stream_ttl_seconds
        self.max_stream_length = max_stream_length

        # Buffer for batching: run_id -> list of chunks
        self.buffer: Dict[str, List[Dict]] = defaultdict(list)
        # Track sequence IDs per run (first chunk gets seq_id 1)
        self.seq_counters: Dict[str, int] = defaultdict(lambda: 1)
        # Track last flush time per run (0.0 means "never flushed")
        self.last_flush: Dict[str, float] = defaultdict(float)
        # One lock per run: the background flusher and write_chunk may both try to
        # flush the same run; without serialization two pipelines can interleave and
        # store chunks out of seq_id order in the Redis stream.
        self._flush_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

        # Background flush task
        self._flush_task = None
        self._running = False

    async def start(self):
        """Start the background flush task (no-op if already running)."""
        if not self._running:
            self._running = True
            self._flush_task = asyncio.create_task(self._periodic_flush())

    async def stop(self):
        """Stop the background flush task and flush remaining data."""
        self._running = False
        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        for run_id in list(self.buffer.keys()):
            # .get: a flush of a completed run may already have removed the entry
            if self.buffer.get(run_id):
                await self._flush_run(run_id)

    async def write_chunk(
        self,
        run_id: str,
        data: str,
        is_complete: bool = False,
    ) -> int:
        """
        Write an SSE chunk to the buffer for a specific run.

        The buffer is flushed immediately when it reaches ``flush_size``, when the
        chunk is final, or when ``flush_interval`` has passed since the last flush.

        Args:
            run_id: The run ID to write to
            data: SSE-formatted chunk data
            is_complete: Whether this is the final chunk

        Returns:
            The sequence ID assigned to this chunk
        """
        seq_id = self.seq_counters[run_id]
        self.seq_counters[run_id] += 1

        chunk = {
            "seq_id": seq_id,
            "data": data,
            "timestamp": int(time.time() * 1000),
        }

        if is_complete:
            chunk["complete"] = "true"

        self.buffer[run_id].append(chunk)

        should_flush = (
            len(self.buffer[run_id]) >= self.flush_size or is_complete or (time.time() - self.last_flush[run_id]) > self.flush_interval
        )

        if should_flush:
            await self._flush_run(run_id)

        return seq_id

    async def _flush_run(self, run_id: str):
        """Flush buffered chunks for a specific run to Redis, one flush per run at a time."""
        async with self._flush_locks[run_id]:
            # .get so a waiter on an already-completed (cleaned-up) run does not recreate its buffer
            if not self.buffer.get(run_id):
                return

            chunks = self.buffer[run_id]
            self.buffer[run_id] = []
            stream_key = f"sse:run:{run_id}"

            try:
                client = await self.redis.get_client()

                async with client.pipeline(transaction=False) as pipe:
                    for chunk in chunks:
                        pipe.xadd(stream_key, chunk, maxlen=self.max_stream_length, approximate=True)

                    # Refresh the TTL on every flush so active streams do not expire mid-run
                    pipe.expire(stream_key, self.stream_ttl)

                    await pipe.execute()

                self.last_flush[run_id] = time.time()

                logger.debug(
                    f"Flushed {len(chunks)} chunks to Redis stream {stream_key}, " f"seq_ids {chunks[0]['seq_id']}-{chunks[-1]['seq_id']}"
                )

                if chunks[-1].get("complete") == "true":
                    self._cleanup_run(run_id)

            except Exception as e:
                logger.error(f"Failed to flush chunks for run {run_id}: {e}")
                # Put chunks back in buffer (ahead of newer ones) to retry
                self.buffer[run_id] = chunks + self.buffer[run_id]
                raise

    async def _periodic_flush(self):
        """Background task to periodically flush buffers that have not been flushed recently."""
        while self._running:
            try:
                await asyncio.sleep(self.flush_interval)

                # Check each run for time-based flush
                current_time = time.time()
                runs_to_flush = [
                    run_id
                    for run_id, last_flush in self.last_flush.items()
                    if (current_time - last_flush) > self.flush_interval and self.buffer.get(run_id)
                ]

                for run_id in runs_to_flush:
                    await self._flush_run(run_id)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in periodic flush: {e}")

    def _cleanup_run(self, run_id: str):
        """Clean up tracking data for a completed run."""
        self.buffer.pop(run_id, None)
        self.seq_counters.pop(run_id, None)
        self.last_flush.pop(run_id, None)
        # Safe while held: the active `async with` keeps its own reference to the lock
        self._flush_locks.pop(run_id, None)

    async def mark_complete(self, run_id: str):
        """Mark a stream as complete and flush."""
        # Add a [DONE] marker
        await self.write_chunk(run_id, "data: [DONE]\n\n", is_complete=True)
|
188
|
+
|
189
|
+
|
190
|
+
async def create_background_stream_processor(
    stream_generator,
    redis_client: AsyncRedisClient,
    run_id: str,
    writer: Optional[RedisSSEStreamWriter] = None,
) -> None:
    """
    Process a stream in the background and store chunks to Redis.

    This function consumes the stream generator and writes all chunks
    to Redis for later retrieval. The stored stream always ends with a
    terminal chunk (the generator's own ``[DONE]``/error event, an added
    ``[DONE]`` marker, or an error event), so readers know when to stop.

    Args:
        stream_generator: The async generator yielding SSE chunks (or tuples whose first item is the chunk)
        redis_client: Redis client instance
        run_id: The run ID to store chunks under
        writer: Optional pre-configured writer (creates new if not provided)
    """
    if writer is None:
        writer = RedisSSEStreamWriter(redis_client)
        await writer.start()
        should_stop_writer = True
    else:
        should_stop_writer = False

    saw_terminal_chunk = False
    try:
        async for chunk in stream_generator:
            if isinstance(chunk, tuple):
                chunk = chunk[0]

            is_done = isinstance(chunk, str) and ("data: [DONE]" in chunk or "event: error" in chunk)

            await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)

            if is_done:
                saw_terminal_chunk = True
                break

        if not saw_terminal_chunk:
            # The generator ended without a terminal event; mark the stream complete so
            # readers polling Redis stop instead of waiting until the stream's TTL expires.
            await writer.mark_complete(run_id)

    except Exception as e:
        logger.error(f"Error processing stream for run {run_id}: {e}")
        # Write error chunk (also terminal, so readers stop)
        error_chunk = {"error": {"message": str(e)}}
        try:
            await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=True)
        except Exception as write_error:
            # Don't let a second Redis failure mask the original error in a background task
            logger.error(f"Failed to write error chunk for run {run_id}: {write_error}")
    finally:
        if should_stop_writer:
            await writer.stop()
|
235
|
+
|
236
|
+
|
237
|
+
async def redis_sse_stream_generator(
    redis_client: AsyncRedisClient,
    run_id: str,
    starting_after: Optional[int] = None,
    poll_interval: float = 0.1,
    batch_size: int = 100,
) -> AsyncIterator[str]:
    """
    Generate SSE events from Redis stream chunks.

    This generator reads chunks stored in Redis streams and yields them as SSE events.
    It supports cursor-based recovery by allowing you to start from a specific seq_id.
    Iteration ends at the chunk flagged ``complete``, even when that chunk is at or
    before the cursor (e.g. a client resuming after the stream already finished).

    Args:
        redis_client: Redis client instance
        run_id: The run ID to read chunks for
        starting_after: Sequential ID (integer) to start reading from (default: None for beginning)
        poll_interval: Seconds to wait between polls when no new data (default: 0.1)
        batch_size: Number of new entries to read per batch (default: 100)

    Yields:
        SSE-formatted chunks from the Redis stream
    """
    stream_key = f"sse:run:{run_id}"
    last_redis_id = "-"
    cursor_seq_id = starting_after or 0

    logger.debug(f"Starting redis_sse_stream_generator for run_id={run_id}, stream_key={stream_key}")

    while True:
        # XRANGE's start bound is inclusive, so after the first read the last consumed
        # entry is returned again; request one extra so every poll can make progress
        # (otherwise batch_size == 1 would re-read the same entry forever).
        count = batch_size
        if count is not None and last_redis_id != "-":
            count += 1

        entries = await redis_client.xrange(stream_key, start=last_redis_id, count=count)

        new_entries = 0
        for entry_id, fields in entries or []:
            if entry_id == last_redis_id:
                continue

            new_entries += 1
            # Advance past every entry, including skipped ones, so nothing is re-read
            last_redis_id = entry_id

            chunk_seq_id = int(fields.get("seq_id", 0))
            if chunk_seq_id > cursor_seq_id:
                data = fields.get("data", "")
                if not data:
                    logger.debug(f"No data found for chunk {chunk_seq_id} in run {run_id}")
                else:
                    # Back-fill identifiers that were unknown when the chunk was produced
                    if '"run_id":null' in data:
                        data = data.replace('"run_id":null', f'"run_id":"{run_id}"')

                    if '"seq_id":null' in data:
                        data = data.replace('"seq_id":null', f'"seq_id":{chunk_seq_id}')

                    yield data

            # Terminal chunk ends the stream regardless of the cursor position
            if fields.get("complete") == "true":
                return

        # Nothing new yet: wait before polling again
        if new_entries == 0:
            await asyncio.sleep(poll_interval)
|