letta-nightly 0.11.4.dev20250825104222__py3-none-any.whl → 0.11.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. letta/__init__.py +1 -1
  2. letta/agent.py +9 -3
  3. letta/agents/base_agent.py +2 -2
  4. letta/agents/letta_agent.py +56 -45
  5. letta/agents/voice_agent.py +2 -2
  6. letta/data_sources/redis_client.py +146 -1
  7. letta/errors.py +4 -0
  8. letta/functions/function_sets/files.py +2 -2
  9. letta/functions/mcp_client/types.py +30 -6
  10. letta/functions/schema_generator.py +46 -1
  11. letta/functions/schema_validator.py +17 -2
  12. letta/functions/types.py +1 -1
  13. letta/helpers/tool_execution_helper.py +0 -2
  14. letta/llm_api/anthropic_client.py +27 -5
  15. letta/llm_api/deepseek_client.py +97 -0
  16. letta/llm_api/groq_client.py +79 -0
  17. letta/llm_api/helpers.py +0 -1
  18. letta/llm_api/llm_api_tools.py +2 -113
  19. letta/llm_api/llm_client.py +21 -0
  20. letta/llm_api/llm_client_base.py +11 -9
  21. letta/llm_api/openai_client.py +3 -0
  22. letta/llm_api/xai_client.py +85 -0
  23. letta/prompts/prompt_generator.py +190 -0
  24. letta/schemas/agent_file.py +17 -2
  25. letta/schemas/file.py +24 -1
  26. letta/schemas/job.py +2 -0
  27. letta/schemas/letta_message.py +2 -0
  28. letta/schemas/letta_request.py +22 -0
  29. letta/schemas/message.py +10 -1
  30. letta/schemas/providers/bedrock.py +1 -0
  31. letta/server/rest_api/redis_stream_manager.py +300 -0
  32. letta/server/rest_api/routers/v1/agents.py +129 -7
  33. letta/server/rest_api/routers/v1/folders.py +15 -5
  34. letta/server/rest_api/routers/v1/runs.py +101 -11
  35. letta/server/rest_api/routers/v1/sources.py +21 -53
  36. letta/server/rest_api/routers/v1/telemetry.py +14 -4
  37. letta/server/rest_api/routers/v1/tools.py +2 -2
  38. letta/server/rest_api/streaming_response.py +3 -24
  39. letta/server/server.py +0 -1
  40. letta/services/agent_manager.py +2 -2
  41. letta/services/agent_serialization_manager.py +129 -32
  42. letta/services/file_manager.py +111 -6
  43. letta/services/file_processor/file_processor.py +5 -2
  44. letta/services/files_agents_manager.py +60 -0
  45. letta/services/helpers/agent_manager_helper.py +4 -205
  46. letta/services/helpers/tool_parser_helper.py +6 -3
  47. letta/services/mcp/base_client.py +7 -1
  48. letta/services/mcp/sse_client.py +7 -2
  49. letta/services/mcp/stdio_client.py +5 -0
  50. letta/services/mcp/streamable_http_client.py +11 -2
  51. letta/services/mcp_manager.py +31 -30
  52. letta/services/source_manager.py +26 -1
  53. letta/services/summarizer/summarizer.py +21 -10
  54. letta/services/tool_executor/files_tool_executor.py +13 -9
  55. letta/services/tool_executor/mcp_tool_executor.py +3 -0
  56. letta/services/tool_executor/tool_execution_manager.py +13 -0
  57. letta/services/tool_manager.py +43 -20
  58. letta/settings.py +1 -0
  59. letta/utils.py +37 -0
  60. {letta_nightly-0.11.4.dev20250825104222.dist-info → letta_nightly-0.11.5.dist-info}/METADATA +2 -2
  61. {letta_nightly-0.11.4.dev20250825104222.dist-info → letta_nightly-0.11.5.dist-info}/RECORD +64 -63
  62. letta/functions/mcp_client/__init__.py +0 -0
  63. letta/functions/mcp_client/base_client.py +0 -156
  64. letta/functions/mcp_client/sse_client.py +0 -51
  65. letta/functions/mcp_client/stdio_client.py +0 -109
  66. {letta_nightly-0.11.4.dev20250825104222.dist-info → letta_nightly-0.11.5.dist-info}/LICENSE +0 -0
  67. {letta_nightly-0.11.4.dev20250825104222.dist-info → letta_nightly-0.11.5.dist-info}/WHEEL +0 -0
  68. {letta_nightly-0.11.4.dev20250825104222.dist-info → letta_nightly-0.11.5.dist-info}/entry_points.txt +0 -0
letta/prompts/prompt_generator.py ADDED
@@ -0,0 +1,190 @@
+ from datetime import datetime
+ from typing import List, Literal, Optional
+
+ from letta.constants import IN_CONTEXT_MEMORY_KEYWORD
+ from letta.helpers import ToolRulesSolver
+ from letta.helpers.datetime_helpers import format_datetime, get_local_time_fast
+ from letta.otel.tracing import trace_method
+ from letta.schemas.memory import Memory
+
+
+ class PromptGenerator:
+
+     # TODO: This code is kind of wonky and deserves a rewrite
+     @trace_method
+     @staticmethod
+     def compile_memory_metadata_block(
+         memory_edit_timestamp: datetime,
+         timezone: str,
+         previous_message_count: int = 0,
+         archival_memory_size: Optional[int] = 0,
+     ) -> str:
+         """
+         Generate a memory metadata block for the agent's system prompt.
+
+         This creates a structured metadata section that informs the agent about
+         the current state of its memory systems, including timing information
+         and memory counts. This helps the agent understand what information
+         is available through its tools.
+
+         Args:
+             memory_edit_timestamp: When memory blocks were last modified
+             timezone: The timezone to use for formatting timestamps (e.g., 'America/Los_Angeles')
+             previous_message_count: Number of messages in recall memory (conversation history)
+             archival_memory_size: Number of items in archival memory (long-term storage)
+
+         Returns:
+             A formatted string containing the memory metadata block with XML-style tags
+
+         Example Output:
+             <memory_metadata>
+             - The current time is: 2024-01-15 10:30 AM PST
+             - Memory blocks were last modified: 2024-01-15 09:00 AM PST
+             - 42 previous messages between you and the user are stored in recall memory (use tools to access them)
+             - 156 total memories you created are stored in archival memory (use tools to access them)
+             </memory_metadata>
+         """
+         # Put the timestamp in the local timezone (mimicking get_local_time())
+         timestamp_str = format_datetime(memory_edit_timestamp, timezone)
+
+         # Create a metadata block of info so the agent knows about the metadata of out-of-context memories
+         metadata_lines = [
+             "<memory_metadata>",
+             f"- The current time is: {get_local_time_fast(timezone)}",
+             f"- Memory blocks were last modified: {timestamp_str}",
+             f"- {previous_message_count} previous messages between you and the user are stored in recall memory (use tools to access them)",
+         ]
+
+         # Only include archival memory line if there are archival memories
+         if archival_memory_size is not None and archival_memory_size > 0:
+             metadata_lines.append(
+                 f"- {archival_memory_size} total memories you created are stored in archival memory (use tools to access them)"
+             )
+
+         metadata_lines.append("</memory_metadata>")
+         memory_metadata_block = "\n".join(metadata_lines)
+         return memory_metadata_block
+
+     @staticmethod
+     def safe_format(template: str, variables: dict) -> str:
+         """
+         Safely formats a template string, preserving empty {} and {unknown_vars}
+         while substituting known variables.
+
+         If we simply use {} in format_map, it'll be treated as a positional field
+         """
+         # First escape any empty {} by doubling them
+         escaped = template.replace("{}", "{{}}")
+
+         # Now use format_map with our custom mapping
+         return escaped.format_map(PreserveMapping(variables))
+
+     @trace_method
+     @staticmethod
+     def get_system_message_from_compiled_memory(
+         system_prompt: str,
+         memory_with_sources: str,
+         in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
+         timezone: str,
+         user_defined_variables: Optional[dict] = None,
+         append_icm_if_missing: bool = True,
+         template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
+         previous_message_count: int = 0,
+         archival_memory_size: int = 0,
+     ) -> str:
+         """Prepare the final/full system message that will be fed into the LLM API
+
+         The base system message may be templated, in which case we need to render the variables.
+
+         The following are reserved variables:
+           - CORE_MEMORY: the in-context memory of the LLM
+         """
+         if user_defined_variables is not None:
+             # TODO eventually support the user defining their own variables to inject
+             raise NotImplementedError
+         else:
+             variables = {}
+
+         # Add the protected memory variable
+         if IN_CONTEXT_MEMORY_KEYWORD in variables:
+             raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
+         else:
+             # TODO should this all be put into the memory.__repr__ function?
+             memory_metadata_string = PromptGenerator.compile_memory_metadata_block(
+                 memory_edit_timestamp=in_context_memory_last_edit,
+                 previous_message_count=previous_message_count,
+                 archival_memory_size=archival_memory_size,
+                 timezone=timezone,
+             )
+
+             full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
+
+             # Add to the variables list to inject
+             variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string
+
+         if template_format == "f-string":
+             memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
+
+             # Catch the special case where the system prompt is unformatted
+             if append_icm_if_missing:
+                 if memory_variable_string not in system_prompt:
+                     # In this case, append it to the end to make sure memory is still injected
+                     # warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
+                     system_prompt += "\n\n" + memory_variable_string
+
+             # render the variables using the built-in templater
+             try:
+                 if user_defined_variables:
+                     formatted_prompt = PromptGenerator.safe_format(system_prompt, variables)
+                 else:
+                     formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
+             except Exception as e:
+                 raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
+
+         else:
+             # TODO support for mustache and jinja2
+             raise NotImplementedError(template_format)
+
+         return formatted_prompt
+
+     @trace_method
+     @staticmethod
+     async def compile_system_message_async(
+         system_prompt: str,
+         in_context_memory: Memory,
+         in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
+         timezone: str,
+         user_defined_variables: Optional[dict] = None,
+         append_icm_if_missing: bool = True,
+         template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
+         previous_message_count: int = 0,
+         archival_memory_size: int = 0,
+         tool_rules_solver: Optional[ToolRulesSolver] = None,
+         sources: Optional[List] = None,
+         max_files_open: Optional[int] = None,
+     ) -> str:
+         tool_constraint_block = None
+         if tool_rules_solver is not None:
+             tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
+
+         if user_defined_variables is not None:
+             # TODO eventually support the user defining their own variables to inject
+             raise NotImplementedError
+         else:
+             pass
+
+         memory_with_sources = await in_context_memory.compile_in_thread_async(
+             tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
+         )
+
+         return PromptGenerator.get_system_message_from_compiled_memory(
+             system_prompt=system_prompt,
+             memory_with_sources=memory_with_sources,
+             in_context_memory_last_edit=in_context_memory_last_edit,
+             timezone=timezone,
+             user_defined_variables=user_defined_variables,
+             append_icm_if_missing=append_icm_if_missing,
+             template_format=template_format,
+             previous_message_count=previous_message_count,
+             archival_memory_size=archival_memory_size,
+         )
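A minimal usage sketch for the new PromptGenerator (illustrative, not part of the diff; values are made up). Note that safe_format references PreserveMapping without importing it in this file; presumably it resolves elsewhere in the package (letta/utils.py also gains 37 lines in this release) as a dict subclass whose __missing__ keeps unknown placeholders intact — that is an assumption, not shown in this diff.

```python
# Sketch, assuming a working letta install exposing these modules.
from datetime import datetime, timezone

from letta.prompts.prompt_generator import PromptGenerator

# Build the metadata block exactly as it will appear in the system prompt.
metadata = PromptGenerator.compile_memory_metadata_block(
    memory_edit_timestamp=datetime.now(timezone.utc),
    timezone="America/Los_Angeles",
    previous_message_count=42,
    archival_memory_size=156,
)
print(metadata)  # <memory_metadata> ... </memory_metadata>

# Render a templated system prompt. "CORE_MEMORY" is the reserved f-string
# variable (IN_CONTEXT_MEMORY_KEYWORD); if the placeholder is missing it is
# appended automatically because append_icm_if_missing defaults to True.
full_prompt = PromptGenerator.get_system_message_from_compiled_memory(
    system_prompt="You are a helpful agent.\n\n{CORE_MEMORY}",
    memory_with_sources="<memory_blocks>...</memory_blocks>",  # made-up compiled memory
    in_context_memory_last_edit=datetime.now(timezone.utc),
    timezone="America/Los_Angeles",
)
```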
letta/schemas/agent_file.py CHANGED
@@ -1,15 +1,17 @@
  from datetime import datetime
  from typing import Any, Dict, List, Optional

+ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
  from pydantic import BaseModel, Field

+ from letta.helpers.datetime_helpers import get_utc_time
  from letta.schemas.agent import AgentState, CreateAgent
  from letta.schemas.block import Block, CreateBlock
  from letta.schemas.enums import MessageRole
  from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase
  from letta.schemas.group import Group, GroupCreate
  from letta.schemas.mcp import MCPServer
- from letta.schemas.message import Message, MessageCreate
+ from letta.schemas.message import Message, MessageCreate, ToolReturn
  from letta.schemas.source import Source, SourceCreate
  from letta.schemas.tool import Tool
  from letta.schemas.user import User
@@ -46,6 +48,15 @@ class MessageSchema(MessageCreate):
      role: MessageRole = Field(..., description="The role of the participant.")
      model: Optional[str] = Field(None, description="The model used to make the function call")
      agent_id: Optional[str] = Field(None, description="The unique identifier of the agent")
+     tool_calls: Optional[List[OpenAIToolCall]] = Field(
+         default=None, description="The list of tool calls requested. Only applicable for role assistant."
+     )
+     tool_call_id: Optional[str] = Field(default=None, description="The ID of the tool call. Only applicable for role tool.")
+     tool_returns: Optional[List[ToolReturn]] = Field(default=None, description="Tool execution return information for prior tool calls")
+     created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
+
+     # TODO: Should we also duplicate the steps here?
+     # TODO: What about tool_return?

      @classmethod
      def from_message(cls, message: Message) -> "MessageSchema":
@@ -64,6 +75,10 @@ class MessageSchema(MessageCreate):
              group_id=message.group_id,
              model=message.model,
              agent_id=message.agent_id,
+             tool_calls=message.tool_calls,
+             tool_call_id=message.tool_call_id,
+             tool_returns=message.tool_returns,
+             created_at=message.created_at,
          )


@@ -114,7 +129,7 @@ class AgentSchema(CreateAgent):
              memory_blocks=[],  # TODO: Convert from agent_state.memory if needed
              tools=[],
              tool_ids=[tool.id for tool in agent_state.tools] if agent_state.tools else [],
-             source_ids=[],  # [source.id for source in agent_state.sources] if agent_state.sources else [],
+             source_ids=[source.id for source in agent_state.sources] if agent_state.sources else [],
              block_ids=[block.id for block in agent_state.memory.blocks],
              tool_rules=agent_state.tool_rules,
              tags=agent_state.tags,
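The practical effect of the MessageSchema change is that exported agent files now preserve tool interactions and creation times. A serialized message entry gains fields along these lines (illustrative values only, not from the diff):

```python
# Illustrative shape of a serialized MessageSchema entry after this change.
# Field names come from the schema above; every value here is made up.
example_entry = {
    "role": "assistant",
    "model": "gpt-4.1",
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "archival_memory_search", "arguments": '{"query": "project status"}'},
        }
    ],
    "tool_call_id": None,   # set on role "tool" messages instead
    "tool_returns": None,   # populated with returns from prior tool executions
    "created_at": "2025-08-25T10:30:00+00:00",
}
```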
letta/schemas/file.py CHANGED
@@ -1,6 +1,6 @@
  from datetime import datetime
  from enum import Enum
- from typing import Optional
+ from typing import List, Optional

  from pydantic import Field

@@ -108,3 +108,26 @@ class FileAgent(FileAgentBase):
          default_factory=datetime.utcnow,
          description="Row last-update timestamp (UTC).",
      )
+
+
+ class AgentFileAttachment(LettaBase):
+     """Response model for agent file attachments showing file status in agent context"""
+
+     id: str = Field(..., description="Unique identifier of the file-agent relationship")
+     file_id: str = Field(..., description="Unique identifier of the file")
+     file_name: str = Field(..., description="Name of the file")
+     folder_id: str = Field(..., description="Unique identifier of the folder/source")
+     folder_name: str = Field(..., description="Name of the folder/source")
+     is_open: bool = Field(..., description="Whether the file is currently open in the agent's context")
+     last_accessed_at: Optional[datetime] = Field(None, description="Timestamp of last access by the agent")
+     visible_content: Optional[str] = Field(None, description="Portion of the file visible to the agent if open")
+     start_line: Optional[int] = Field(None, description="Starting line number if file was opened with line range")
+     end_line: Optional[int] = Field(None, description="Ending line number if file was opened with line range")
+
+
+ class PaginatedAgentFiles(LettaBase):
+     """Paginated response for agent files"""
+
+     files: List[AgentFileAttachment] = Field(..., description="List of file attachments for the agent")
+     next_cursor: Optional[str] = Field(None, description="Cursor for fetching the next page (file-agent relationship ID)")
+     has_more: bool = Field(..., description="Whether more results exist after this page")
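A short sketch of constructing the new response models directly (IDs and content are made up; assumes a letta install where these classes are importable from letta.schemas.file):

```python
from letta.schemas.file import AgentFileAttachment, PaginatedAgentFiles

# One open file in the agent's context; only the required and relevant fields set.
attachment = AgentFileAttachment(
    id="file_agent-123",      # made-up file-agent relationship ID
    file_id="file-456",
    file_name="notes.md",
    folder_id="source-789",
    folder_name="project-docs",
    is_open=True,
    visible_content="# Notes\n...",
)

# The paginated wrapper uses the relationship ID as the page cursor.
page = PaginatedAgentFiles(files=[attachment], next_cursor="file_agent-123", has_more=False)
```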
letta/schemas/job.py CHANGED
@@ -4,6 +4,7 @@ from typing import List, Optional

  from pydantic import BaseModel, ConfigDict, Field

  from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
+ from letta.helpers.datetime_helpers import get_utc_time
  from letta.schemas.enums import JobStatus, JobType
  from letta.schemas.letta_base import OrmMetadataBase
  from letta.schemas.letta_message import MessageType
@@ -12,6 +13,7 @@ from letta.schemas.letta_message import MessageType
  class JobBase(OrmMetadataBase):
      __id_prefix__ = "job"
      status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
+     created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
      completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
      metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="The metadata of the job.")
      job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")
letta/schemas/letta_message.py CHANGED
@@ -52,6 +52,8 @@ class LettaMessage(BaseModel):
      sender_id: str | None = None
      step_id: str | None = None
      is_err: bool | None = None
+     seq_id: int | None = None
+     run_id: str | None = None

      @field_serializer("date")
      def serialize_datetime(self, dt: datetime, _info):
letta/schemas/letta_request.py CHANGED
@@ -46,6 +46,10 @@ class LettaStreamingRequest(LettaRequest):
          default=False,
          description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
      )
+     background: bool = Field(
+         default=False,
+         description="Whether to process the request in the background.",
+     )


  class LettaAsyncRequest(LettaRequest):
@@ -66,3 +70,21 @@ class CreateBatch(BaseModel):
          "'status' is the final batch status (e.g., 'completed', 'failed'), and "
          "'completed_at' is an ISO 8601 timestamp indicating when the batch job completed.",
      )
+
+
+ class RetrieveStreamRequest(BaseModel):
+     starting_after: int = Field(
+         0, description="Sequence id to use as a cursor for pagination. Response will start streaming after this chunk sequence id"
+     )
+     include_pings: Optional[bool] = Field(
+         default=False,
+         description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
+     )
+     poll_interval: Optional[float] = Field(
+         default=0.1,
+         description="Seconds to wait between polls when no new data.",
+     )
+     batch_size: Optional[int] = Field(
+         default=100,
+         description="Number of entries to read per batch.",
+     )
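Taken together with the new seq_id and run_id fields on LettaMessage above, RetrieveStreamRequest gives clients a cursor for re-attaching to a background stream. A hypothetical client-side sketch follows; the endpoint path and port are assumptions based on the v1/runs.py changes in this release, not documented API:

```python
# Hypothetical resume-a-stream client; endpoint path, port, and run ID are made up.
import httpx

from letta.schemas.letta_request import RetrieveStreamRequest

req = RetrieveStreamRequest(starting_after=42, include_pings=False)

with httpx.stream(
    "POST",
    "http://localhost:8283/v1/runs/run-123/stream",
    json=req.model_dump(),
) as resp:
    for line in resp.iter_lines():
        # Each SSE data line carries a seq_id; persist the last one seen so a
        # dropped connection can resume with starting_after=<last seq_id>.
        if line.startswith("data: "):
            print(line)
```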
letta/schemas/message.py CHANGED
@@ -414,6 +414,8 @@ class Message(BaseMessage):
          except json.JSONDecodeError:
              raise ValueError(f"Failed to decode function return: {text_content}")

+         # if self.tool_call_id is None:
+         #     import pdb;pdb.set_trace()
          assert self.tool_call_id is not None

          return ToolReturnMessage(
@@ -844,7 +846,7 @@ class Message(BaseMessage):
          }
          content = []
          # COT / reasoning / thinking
-         if self.content is not None and len(self.content) > 1:
+         if self.content is not None and len(self.content) >= 1:
              for content_part in self.content:
                  if isinstance(content_part, ReasoningContent):
                      content.append(
@@ -861,6 +863,13 @@ class Message(BaseMessage):
                              "data": content_part.data,
                          }
                      )
+                 if isinstance(content_part, TextContent):
+                     content.append(
+                         {
+                             "type": "text",
+                             "text": content_part.text,
+                         }
+                     )
          elif text_content is not None:
              content.append(
                  {
letta/schemas/providers/bedrock.py CHANGED
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
  class BedrockProvider(Provider):
      provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
      provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
+     access_key: str = Field(..., description="AWS secret access key for Bedrock.")
      region: str = Field(..., description="AWS region for Bedrock")

      async def bedrock_get_model_list_async(self) -> list[dict]:
letta/server/rest_api/redis_stream_manager.py ADDED
@@ -0,0 +1,300 @@
+ """Redis stream manager for reading and writing SSE chunks with batching and TTL."""
+
+ import asyncio
+ import json
+ import time
+ from collections import defaultdict
+ from typing import AsyncIterator, Dict, List, Optional
+
+ from letta.data_sources.redis_client import AsyncRedisClient
+ from letta.log import get_logger
+
+ logger = get_logger(__name__)
+
+
+ class RedisSSEStreamWriter:
+     """
+     Efficiently writes SSE chunks to Redis streams with batching and TTL management.
+
+     Features:
+     - Batches writes using Redis pipelines for performance
+     - Automatically sets/refreshes TTL on streams
+     - Tracks sequential IDs for cursor-based recovery
+     - Handles flush on size or time thresholds
+     """
+
+     def __init__(
+         self,
+         redis_client: AsyncRedisClient,
+         flush_interval: float = 0.5,
+         flush_size: int = 50,
+         stream_ttl_seconds: int = 10800,  # 3 hours default
+         max_stream_length: int = 10000,  # Max entries per stream
+     ):
+         """
+         Initialize the Redis SSE stream writer.
+
+         Args:
+             redis_client: Redis client instance
+             flush_interval: Seconds between automatic flushes
+             flush_size: Number of chunks to buffer before flushing
+             stream_ttl_seconds: TTL for streams in seconds (default: 3 hours)
+             max_stream_length: Maximum entries per stream before trimming
+         """
+         self.redis = redis_client
+         self.flush_interval = flush_interval
+         self.flush_size = flush_size
+         self.stream_ttl = stream_ttl_seconds
+         self.max_stream_length = max_stream_length
+
+         # Buffer for batching: run_id -> list of chunks
+         self.buffer: Dict[str, List[Dict]] = defaultdict(list)
+         # Track sequence IDs per run
+         self.seq_counters: Dict[str, int] = defaultdict(lambda: 1)
+         # Track last flush time per run
+         self.last_flush: Dict[str, float] = defaultdict(float)
+
+         # Background flush task
+         self._flush_task = None
+         self._running = False
+
+     async def start(self):
+         """Start the background flush task."""
+         if not self._running:
+             self._running = True
+             self._flush_task = asyncio.create_task(self._periodic_flush())
+
+     async def stop(self):
+         """Stop the background flush task and flush remaining data."""
+         self._running = False
+         if self._flush_task:
+             self._flush_task.cancel()
+             try:
+                 await self._flush_task
+             except asyncio.CancelledError:
+                 pass
+
+         for run_id in list(self.buffer.keys()):
+             if self.buffer[run_id]:
+                 await self._flush_run(run_id)
+
+     async def write_chunk(
+         self,
+         run_id: str,
+         data: str,
+         is_complete: bool = False,
+     ) -> int:
+         """
+         Write an SSE chunk to the buffer for a specific run.
+
+         Args:
+             run_id: The run ID to write to
+             data: SSE-formatted chunk data
+             is_complete: Whether this is the final chunk
+
+         Returns:
+             The sequence ID assigned to this chunk
+         """
+         seq_id = self.seq_counters[run_id]
+         self.seq_counters[run_id] += 1
+
+         chunk = {
+             "seq_id": seq_id,
+             "data": data,
+             "timestamp": int(time.time() * 1000),
+         }
+
+         if is_complete:
+             chunk["complete"] = "true"
+
+         self.buffer[run_id].append(chunk)
+
+         should_flush = (
+             len(self.buffer[run_id]) >= self.flush_size or is_complete or (time.time() - self.last_flush[run_id]) > self.flush_interval
+         )
+
+         if should_flush:
+             await self._flush_run(run_id)
+
+         return seq_id
+
+     async def _flush_run(self, run_id: str):
+         """Flush buffered chunks for a specific run to Redis."""
+         if not self.buffer[run_id]:
+             return
+
+         chunks = self.buffer[run_id]
+         self.buffer[run_id] = []
+         stream_key = f"sse:run:{run_id}"
+
+         try:
+             client = await self.redis.get_client()
+
+             async with client.pipeline(transaction=False) as pipe:
+                 for chunk in chunks:
+                     pipe.xadd(stream_key, chunk, maxlen=self.max_stream_length, approximate=True)
+
+                 pipe.expire(stream_key, self.stream_ttl)
+
+                 await pipe.execute()
+
+             self.last_flush[run_id] = time.time()
+
+             logger.debug(
+                 f"Flushed {len(chunks)} chunks to Redis stream {stream_key}, " f"seq_ids {chunks[0]['seq_id']}-{chunks[-1]['seq_id']}"
+             )
+
+             if chunks[-1].get("complete") == "true":
+                 self._cleanup_run(run_id)
+
+         except Exception as e:
+             logger.error(f"Failed to flush chunks for run {run_id}: {e}")
+             # Put chunks back in buffer to retry
+             self.buffer[run_id] = chunks + self.buffer[run_id]
+             raise
+
+     async def _periodic_flush(self):
+         """Background task to periodically flush buffers."""
+         while self._running:
+             try:
+                 await asyncio.sleep(self.flush_interval)
+
+                 # Check each run for time-based flush
+                 current_time = time.time()
+                 runs_to_flush = [
+                     run_id
+                     for run_id, last_flush in self.last_flush.items()
+                     if (current_time - last_flush) > self.flush_interval and self.buffer[run_id]
+                 ]
+
+                 for run_id in runs_to_flush:
+                     await self._flush_run(run_id)
+
+             except asyncio.CancelledError:
+                 break
+             except Exception as e:
+                 logger.error(f"Error in periodic flush: {e}")
+
+     def _cleanup_run(self, run_id: str):
+         """Clean up tracking data for a completed run."""
+         self.buffer.pop(run_id, None)
+         self.seq_counters.pop(run_id, None)
+         self.last_flush.pop(run_id, None)
+
+     async def mark_complete(self, run_id: str):
+         """Mark a stream as complete and flush."""
+         # Add a [DONE] marker
+         await self.write_chunk(run_id, "data: [DONE]\n\n", is_complete=True)
+
+
+ async def create_background_stream_processor(
+     stream_generator,
+     redis_client: AsyncRedisClient,
+     run_id: str,
+     writer: Optional[RedisSSEStreamWriter] = None,
+ ) -> None:
+     """
+     Process a stream in the background and store chunks to Redis.
+
+     This function consumes the stream generator and writes all chunks
+     to Redis for later retrieval.
+
+     Args:
+         stream_generator: The async generator yielding SSE chunks
+         redis_client: Redis client instance
+         run_id: The run ID to store chunks under
+         writer: Optional pre-configured writer (creates new if not provided)
+     """
+     if writer is None:
+         writer = RedisSSEStreamWriter(redis_client)
+         await writer.start()
+         should_stop_writer = True
+     else:
+         should_stop_writer = False
+
+     try:
+         async for chunk in stream_generator:
+             if isinstance(chunk, tuple):
+                 chunk = chunk[0]
+
+             is_done = isinstance(chunk, str) and ("data: [DONE]" in chunk or "event: error" in chunk)
+
+             await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
+
+             if is_done:
+                 break
+
+     except Exception as e:
+         logger.error(f"Error processing stream for run {run_id}: {e}")
+         # Write error chunk
+         error_chunk = {"error": {"message": str(e)}}
+         await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=True)
+     finally:
+         if should_stop_writer:
+             await writer.stop()
+
+
+ async def redis_sse_stream_generator(
+     redis_client: AsyncRedisClient,
+     run_id: str,
+     starting_after: Optional[int] = None,
+     poll_interval: float = 0.1,
+     batch_size: int = 100,
+ ) -> AsyncIterator[str]:
+     """
+     Generate SSE events from Redis stream chunks.
+
+     This generator reads chunks stored in Redis streams and yields them as SSE events.
+     It supports cursor-based recovery by allowing you to start from a specific seq_id.
+
+     Args:
+         redis_client: Redis client instance
+         run_id: The run ID to read chunks for
+         starting_after: Sequential ID (integer) to start reading from (default: None for beginning)
+         poll_interval: Seconds to wait between polls when no new data (default: 0.1)
+         batch_size: Number of entries to read per batch (default: 100)
+
+     Yields:
+         SSE-formatted chunks from the Redis stream
+     """
+     stream_key = f"sse:run:{run_id}"
+     last_redis_id = "-"
+     cursor_seq_id = starting_after or 0
+
+     logger.debug(f"Starting redis_sse_stream_generator for run_id={run_id}, stream_key={stream_key}")
+
+     while True:
+         entries = await redis_client.xrange(stream_key, start=last_redis_id, count=batch_size)
+
+         if entries:
+             yielded_any = False
+             for entry_id, fields in entries:
+                 if entry_id == last_redis_id:
+                     continue
+
+                 chunk_seq_id = int(fields.get("seq_id", 0))
+                 if chunk_seq_id > cursor_seq_id:
+                     data = fields.get("data", "")
+                     if not data:
+                         logger.debug(f"No data found for chunk {chunk_seq_id} in run {run_id}")
+                         continue
+
+                     if '"run_id":null' in data:
+                         data = data.replace('"run_id":null', f'"run_id":"{run_id}"')
+
+                     if '"seq_id":null' in data:
+                         data = data.replace('"seq_id":null', f'"seq_id":{chunk_seq_id}')
+
+                     yield data
+                     yielded_any = True
+
+                     if fields.get("complete") == "true":
+                         return
+
+                 last_redis_id = entry_id
+
+             if not yielded_any and len(entries) > 1:
+                 continue
+
+         if not entries or (len(entries) == 1 and entries[0][0] == last_redis_id):
+             await asyncio.sleep(poll_interval)
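A minimal wiring sketch for the new Redis-backed background streaming, assuming an async context with an already-configured AsyncRedisClient; the run ID is made up:

```python
import asyncio

from letta.data_sources.redis_client import AsyncRedisClient
from letta.server.rest_api.redis_stream_manager import (
    RedisSSEStreamWriter,
    redis_sse_stream_generator,
)


async def demo(redis_client: AsyncRedisClient) -> None:
    run_id = "run-123"  # made-up run ID

    # Producer side: buffer a chunk, then mark the stream complete (which
    # appends the "data: [DONE]" marker and forces a flush).
    writer = RedisSSEStreamWriter(redis_client)
    await writer.start()
    await writer.write_chunk(run_id, 'data: {"seq_id":null,"run_id":null}\n\n')
    await writer.mark_complete(run_id)
    await writer.stop()

    # Consumer side: replay everything after seq_id 0. The generator
    # back-fills run_id/seq_id into stored payloads and returns once it
    # yields the chunk flagged complete.
    async for chunk in redis_sse_stream_generator(redis_client, run_id, starting_after=0):
        print(chunk, end="")

# To run: asyncio.run(demo(your_redis_client)), given a reachable Redis.
```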