letta-nightly 0.6.53.dev20250417104214__py3-none-any.whl → 0.6.54.dev20250419104029__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agent.py +6 -31
- letta/agents/letta_agent.py +1 -0
- letta/agents/letta_agent_batch.py +369 -18
- letta/constants.py +15 -4
- letta/functions/function_sets/base.py +168 -21
- letta/groups/sleeptime_multi_agent.py +3 -3
- letta/helpers/converters.py +1 -1
- letta/helpers/message_helper.py +1 -0
- letta/jobs/llm_batch_job_polling.py +39 -10
- letta/jobs/scheduler.py +54 -13
- letta/jobs/types.py +26 -6
- letta/llm_api/anthropic_client.py +3 -1
- letta/llm_api/llm_api_tools.py +7 -1
- letta/llm_api/openai.py +2 -0
- letta/orm/agent.py +5 -29
- letta/orm/base.py +2 -2
- letta/orm/enums.py +1 -0
- letta/orm/job.py +5 -0
- letta/orm/llm_batch_items.py +2 -2
- letta/orm/llm_batch_job.py +5 -2
- letta/orm/message.py +12 -4
- letta/orm/passage.py +0 -6
- letta/orm/sqlalchemy_base.py +0 -3
- letta/personas/examples/sleeptime_doc_persona.txt +2 -0
- letta/prompts/system/sleeptime.txt +20 -11
- letta/prompts/system/sleeptime_doc_ingest.txt +35 -0
- letta/schemas/agent.py +24 -1
- letta/schemas/enums.py +3 -1
- letta/schemas/job.py +39 -0
- letta/schemas/letta_message.py +24 -7
- letta/schemas/letta_request.py +7 -2
- letta/schemas/letta_response.py +3 -1
- letta/schemas/llm_batch_job.py +4 -3
- letta/schemas/llm_config.py +6 -2
- letta/schemas/message.py +11 -1
- letta/schemas/providers.py +10 -58
- letta/serialize_schemas/marshmallow_agent.py +25 -22
- letta/serialize_schemas/marshmallow_message.py +1 -1
- letta/server/db.py +75 -49
- letta/server/rest_api/app.py +1 -0
- letta/server/rest_api/interface.py +7 -2
- letta/server/rest_api/routers/v1/__init__.py +2 -0
- letta/server/rest_api/routers/v1/agents.py +33 -6
- letta/server/rest_api/routers/v1/messages.py +132 -0
- letta/server/rest_api/routers/v1/sources.py +21 -2
- letta/server/rest_api/utils.py +23 -10
- letta/server/server.py +67 -21
- letta/services/agent_manager.py +44 -21
- letta/services/group_manager.py +2 -2
- letta/services/helpers/agent_manager_helper.py +5 -3
- letta/services/job_manager.py +34 -5
- letta/services/llm_batch_manager.py +200 -57
- letta/services/message_manager.py +23 -1
- letta/services/passage_manager.py +2 -2
- letta/services/tool_executor/tool_execution_manager.py +13 -3
- letta/services/tool_executor/tool_execution_sandbox.py +0 -1
- letta/services/tool_executor/tool_executor.py +48 -9
- letta/services/tool_sandbox/base.py +24 -6
- letta/services/tool_sandbox/e2b_sandbox.py +25 -5
- letta/services/tool_sandbox/local_sandbox.py +23 -7
- letta/settings.py +2 -2
- {letta_nightly-0.6.53.dev20250417104214.dist-info → letta_nightly-0.6.54.dev20250419104029.dist-info}/METADATA +2 -1
- {letta_nightly-0.6.53.dev20250417104214.dist-info → letta_nightly-0.6.54.dev20250419104029.dist-info}/RECORD +67 -65
- letta/sleeptime_agent.py +0 -61
- {letta_nightly-0.6.53.dev20250417104214.dist-info → letta_nightly-0.6.54.dev20250419104029.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.53.dev20250417104214.dist-info → letta_nightly-0.6.54.dev20250419104029.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.53.dev20250417104214.dist-info → letta_nightly-0.6.54.dev20250419104029.dist-info}/entry_points.txt +0 -0
letta/schemas/providers.py
CHANGED
@@ -228,63 +228,6 @@ class OpenAIProvider(Provider):
         return LLM_MAX_TOKENS["DEFAULT"]


-class xAIProvider(OpenAIProvider):
-    """https://docs.x.ai/docs/api-reference"""
-
-    name: str = "xai"
-    api_key: str = Field(..., description="API key for the xAI/Grok API.")
-    base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
-
-    def get_model_context_window_size(self, model_name: str) -> Optional[int]:
-        # xAI doesn't return context window in the model listing,
-        # so these are hardcoded from their website
-        if model_name == "grok-2-1212":
-            return 131072
-        else:
-            return None
-
-    def list_llm_models(self) -> List[LLMConfig]:
-        from letta.llm_api.openai import openai_get_model_list
-
-        response = openai_get_model_list(self.base_url, api_key=self.api_key)
-
-        if "data" in response:
-            data = response["data"]
-        else:
-            data = response
-
-        configs = []
-        for model in data:
-            assert "id" in model, f"xAI/Grok model missing 'id' field: {model}"
-            model_name = model["id"]
-
-            # In case xAI starts supporting it in the future:
-            if "context_length" in model:
-                context_window_size = model["context_length"]
-            else:
-                context_window_size = self.get_model_context_window_size(model_name)
-
-            if not context_window_size:
-                warnings.warn(f"Couldn't find context window size for model {model_name}")
-                continue
-
-            configs.append(
-                LLMConfig(
-                    model=model_name,
-                    model_endpoint_type="xai",
-                    model_endpoint=self.base_url,
-                    context_window=context_window_size,
-                    handle=self.get_handle(model_name),
-                )
-            )
-
-        return configs
-
-    def list_embedding_models(self) -> List[EmbeddingConfig]:
-        # No embeddings supported
-        return []
-
-
 class DeepSeekProvider(OpenAIProvider):
     """
     DeepSeek ChatCompletions API is similar to OpenAI's reasoning API,
@@ -478,7 +421,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
         return configs


-class xAIProvider(OpenAIProvider):
+class XAIProvider(OpenAIProvider):
     """https://docs.x.ai/docs/api-reference"""

     name: str = "xai"
@@ -490,6 +433,15 @@ class xAIProvider(OpenAIProvider):
         # so these are hardcoded from their website
         if model_name == "grok-2-1212":
             return 131072
+        # NOTE: disabling the minis for now since they return weird MM parts
+        # elif model_name == "grok-3-mini-fast-beta":
+        #     return 131072
+        # elif model_name == "grok-3-mini-beta":
+        #     return 131072
+        elif model_name == "grok-3-fast-beta":
+            return 131072
+        elif model_name == "grok-3-beta":
+            return 131072
         else:
             return None
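Net effect: the `xAIProvider` class was moved below `LMStudioOpenAIProvider`, renamed to `XAIProvider`, and taught the grok-3 context windows. A hypothetical usage sketch follows; the `XAI_API_KEY` environment variable name is an assumption for illustration, while the constructor fields mirror the Field definitions in the diff above.

# Hypothetical sketch of exercising the renamed XAIProvider.
import os

from letta.schemas.providers import XAIProvider

provider = XAIProvider(api_key=os.environ["XAI_API_KEY"])  # base_url defaults to https://api.x.ai/v1
for cfg in provider.list_llm_models():
    # grok-2-1212, grok-3-beta, and grok-3-fast-beta resolve to a 131072-token
    # context window; models with no known window are skipped with a warning.
    print(cfg.handle, cfg.context_window)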
letta/serialize_schemas/marshmallow_agent.py
CHANGED
@@ -4,6 +4,7 @@ from marshmallow import fields, post_dump, pre_load

 import letta
 from letta.orm import Agent
+from letta.orm import Message as MessageModel
 from letta.schemas.agent import AgentState as PydanticAgentState
 from letta.schemas.user import User
 from letta.serialize_schemas.marshmallow_agent_environment_variable import SerializedAgentEnvironmentVariableSchema
@@ -35,7 +36,6 @@ class MarshmallowAgentSchema(BaseSchema):

     tool_rules = ToolRulesField()

-    messages = fields.List(fields.Nested(SerializedMessageSchema))
     core_memory = fields.List(fields.Nested(SerializedBlockSchema))
     tools = fields.List(fields.Nested(SerializedToolSchema))
     tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema))
@@ -54,6 +54,30 @@ class MarshmallowAgentSchema(BaseSchema):
         field.schema.session = session
         field.schema.actor = actor

+    @post_dump
+    def attach_messages(self, data: Dict, **kwargs):
+        """
+        After dumping the agent, load all its Message rows and serialize them here.
+        """
+        # TODO: This is hacky, but want to move fast, please refactor moving forward
+        from letta.server.db import db_context as session_maker
+
+        with session_maker() as session:
+            agent_id = data.get("id")
+            msgs = (
+                session.query(MessageModel)
+                .filter(
+                    MessageModel.agent_id == agent_id,
+                    MessageModel.organization_id == self.actor.organization_id,
+                )
+                .order_by(MessageModel.sequence_id.asc())
+                .all()
+            )
+            # overwrite the “messages” key with a fully serialized list
+            data[self.FIELD_MESSAGES] = [SerializedMessageSchema(session=self.session, actor=self.actor).dump(m) for m in msgs]
+
+        return data
+
     @post_dump
     def sanitize_ids(self, data: Dict, **kwargs):
         """
@@ -101,25 +125,6 @@ class MarshmallowAgentSchema(BaseSchema):
             del data[self.FIELD_VERSION]
         return data

-    @pre_load
-    def remap_in_context_messages(self, data, **kwargs):
-        """
-        Restores `message_ids` by collecting message IDs where `in_context` is True,
-        generates new IDs for all messages, and removes `in_context` from all messages.
-        """
-        messages = data.get(self.FIELD_MESSAGES, [])
-        for msg in messages:
-            msg[self.FIELD_ID] = SerializedMessageSchema.generate_id()  # Generate new ID
-
-        message_ids = []
-        in_context_message_indices = data.pop(self.FIELD_IN_CONTEXT_INDICES)
-        for idx in in_context_message_indices:
-            message_ids.append(messages[idx][self.FIELD_ID])
-
-        data[self.FIELD_MESSAGE_IDS] = message_ids
-
-        return data
-
     class Meta(BaseSchema.Meta):
         model = Agent
         exclude = BaseSchema.Meta.exclude + (
@@ -127,8 +132,6 @@ class MarshmallowAgentSchema(BaseSchema):
             "template_id",
             "base_template_id",
             "sources",
-            "source_passages",
-            "agent_passages",
             "identities",
             "is_deleted",
             "groups",
letta/serialize_schemas/marshmallow_message.py
CHANGED
@@ -39,4 +39,4 @@ class SerializedMessageSchema(BaseSchema):

     class Meta(BaseSchema.Meta):
         model = Message
-        exclude = BaseSchema.Meta.exclude + ("step", "job_message", "
+        exclude = BaseSchema.Meta.exclude + ("step", "job_message", "otid", "is_deleted")
letta/server/db.py
CHANGED
@@ -1,4 +1,5 @@
 import os
+import threading
 from contextlib import contextmanager

 from rich.console import Console
@@ -10,13 +11,17 @@ from sqlalchemy.orm import sessionmaker
 from letta.config import LettaConfig
 from letta.log import get_logger
 from letta.orm import Base
-
-# NOTE: hack to see if single session management works
 from letta.settings import settings

-
+# Use globals for the lock and initialization flag
+_engine_lock = threading.Lock()
+_engine_initialized = False

+# Create variables in global scope but don't initialize them yet
+config = LettaConfig.load()
 logger = get_logger(__name__)
+engine = None
+SessionLocal = None


 def print_sqlite_schema_error():
@@ -49,59 +54,80 @@ def db_error_handler():
         exit(1)


-[… old lines 52-101: module-level engine creation and connect/execute wrapping, removed; the deleted text was not captured in this rendering …]
+def initialize_engine():
+    """Initialize the database engine only when needed."""
+    global engine, SessionLocal, _engine_initialized
+
+    with _engine_lock:
+        # Check again inside the lock to prevent race conditions
+        if _engine_initialized:
+            return
+
+        if settings.letta_pg_uri_no_default:
+            logger.info("Creating postgres engine")
+            config.recall_storage_type = "postgres"
+            config.recall_storage_uri = settings.letta_pg_uri_no_default
+            config.archival_storage_type = "postgres"
+            config.archival_storage_uri = settings.letta_pg_uri_no_default
+
+            # create engine
+            engine = create_engine(
+                settings.letta_pg_uri,
+                # f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
+                pool_size=settings.pg_pool_size,
+                max_overflow=settings.pg_max_overflow,
+                pool_timeout=settings.pg_pool_timeout,
+                pool_recycle=settings.pg_pool_recycle,
+                echo=settings.pg_echo,
+                # connect_args={"client_encoding": "utf8"},
+            )
+        else:
+            # TODO: don't rely on config storage
+            engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
+            logger.info("Creating sqlite engine " + engine_path)
+
+            engine = create_engine(engine_path)
+
+            # Store the original connect method
+            original_connect = engine.connect
+
+            def wrapped_connect(*args, **kwargs):
+                with db_error_handler():
+                    # Get the connection
+                    connection = original_connect(*args, **kwargs)
+
+                    # Store the original execution method
+                    original_execute = connection.execute
+
+                    # Wrap the execute method of the connection
+                    def wrapped_execute(*args, **kwargs):
+                        with db_error_handler():
+                            return original_execute(*args, **kwargs)
+
+                    # Replace the connection's execute method
+                    connection.execute = wrapped_execute
+
+                    return connection
+
+            # Replace the engine's connect method
+            engine.connect = wrapped_connect
+
+        Base.metadata.create_all(bind=engine)
+
+        # Create the session factory
+        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+        _engine_initialized = True


 def get_db():
+    """Get a database session, initializing the engine if needed."""
+    global engine, SessionLocal
+
+    # Make sure engine is initialized
+    if not _engine_initialized:
+        initialize_engine()
+
+    # Now SessionLocal should be defined and callable
     db = SessionLocal()
     try:
         yield db
@@ -109,5 +135,5 @@ def get_db():
         db.close()


-
+# Define db_context as a context manager that uses get_db
 db_context = contextmanager(get_db)
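The upshot is that the engine is now created lazily on first use, behind a double-checked lock, rather than at import time. A minimal usage sketch, assuming only the module shown above:

# Minimal sketch of consuming the lazily initialized session helpers.
from sqlalchemy import text

from letta.server.db import db_context, initialize_engine

# Optional: pay the engine-creation cost at startup instead of on the first request.
initialize_engine()

# get_db()/db_context initialize the engine on demand; the module-level lock
# ensures concurrent first callers create exactly one engine.
with db_context() as session:
    session.execute(text("SELECT 1"))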
|
letta/server/rest_api/app.py
CHANGED
@@ -174,6 +174,7 @@ def create_application() -> "FastAPI":
|
|
174
174
|
async def generic_error_handler(request: Request, exc: Exception):
|
175
175
|
# Log the actual error for debugging
|
176
176
|
log.error(f"Unhandled error: {exc}", exc_info=True)
|
177
|
+
print(f"Unhandled error: {exc}")
|
177
178
|
|
178
179
|
# Print the stack trace
|
179
180
|
print(f"Stack trace: {exc}")
|
letta/server/rest_api/interface.py
CHANGED
@@ -6,6 +6,8 @@ from collections import deque
 from datetime import datetime
 from typing import AsyncGenerator, Literal, Optional, Union

+import demjson3 as demjson
+
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.helpers.datetime_helpers import is_utc_datetime
 from letta.interface import AgentInterface
@@ -502,7 +504,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
                 date=message_date,
                 reasoning=message_delta.reasoning_content,
                 signature=message_delta.reasoning_content_signature,
-                source="reasoner_model" if message_delta.
+                source="reasoner_model" if message_delta.reasoning_content else "non_reasoner_model",
                 name=name,
                 otid=otid,
             )
@@ -530,7 +532,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
             try:
                 # NOTE: this is hardcoded for our DeepSeek API integration
                 json_reasoning_content = parse_json(self.expect_reasoning_content_buffer)
-                print(f"json_reasoning_content: {json_reasoning_content}")

                 processed_chunk = ToolCallMessage(
                     id=message_id,
@@ -547,6 +548,10 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
             except json.JSONDecodeError as e:
                 print(f"Failed to interpret reasoning content ({self.expect_reasoning_content_buffer}) as JSON: {e}")

+                return None
+            except demjson.JSONDecodeError as e:
+                print(f"Failed to interpret reasoning content ({self.expect_reasoning_content_buffer}) as JSON: {e}")
+
                 return None
             # Else,
             # return None
letta/server/rest_api/routers/v1/__init__.py
CHANGED
@@ -5,6 +5,7 @@ from letta.server.rest_api.routers.v1.health import router as health_router
 from letta.server.rest_api.routers.v1.identities import router as identities_router
 from letta.server.rest_api.routers.v1.jobs import router as jobs_router
 from letta.server.rest_api.routers.v1.llms import router as llm_router
+from letta.server.rest_api.routers.v1.messages import router as messages_router
 from letta.server.rest_api.routers.v1.providers import router as providers_router
 from letta.server.rest_api.routers.v1.runs import router as runs_router
 from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_configs_router
@@ -29,5 +30,6 @@ ROUTERS = [
     runs_router,
     steps_router,
     tags_router,
+    messages_router,
     voice_router,
 ]
letta/server/rest_api/routers/v1/agents.py
CHANGED
@@ -1,6 +1,6 @@
 import json
 import traceback
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import Annotated, Any, List, Optional

 from fastapi import APIRouter, BackgroundTasks, Body, Depends, File, Header, HTTPException, Query, UploadFile, status
@@ -17,6 +17,7 @@ from letta.log import get_logger
 from letta.orm.errors import NoResultFound
 from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
 from letta.schemas.block import Block, BlockUpdate
+from letta.schemas.group import Group
 from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
 from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
 from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
@@ -173,7 +174,7 @@ async def import_agent_serialized(
         raise HTTPException(status_code=400, detail="Corrupted agent file format.")

     except ValidationError as e:
-        raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e
+        raise HTTPException(status_code=422, detail=f"Invalid agent schema: {str(e)}")

     except IntegrityError as e:
         raise HTTPException(status_code=409, detail=f"Database integrity error: {str(e)}")
@@ -282,6 +283,7 @@ def detach_tool(
 def attach_source(
     agent_id: str,
     source_id: str,
+    background_tasks: BackgroundTasks,
     server: "SyncServer" = Depends(get_letta_server),
     actor_id: Optional[str] = Header(None, alias="user_id"),
 ):
@@ -289,7 +291,11 @@ def attach_source(
     Attach a source to an agent.
     """
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
-
+    agent = server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor)
+    if agent.enable_sleeptime:
+        source = server.source_manager.get_source_by_id(source_id=source_id)
+        background_tasks.add_task(server.sleeptime_document_ingest, agent, source, actor)
+    return agent


 @router.patch("/{agent_id}/sources/detach/{source_id}", response_model=AgentState, operation_id="detach_source_from_agent")
@@ -303,7 +309,15 @@ def detach_source(
     Detach a source from an agent.
     """
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
-
+    agent = server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
+    if agent.enable_sleeptime:
+        try:
+            source = server.source_manager.get_source_by_id(source_id=source_id)
+            block = server.agent_manager.get_block_with_label(agent_id=agent.id, block_label=source.name, actor=actor)
+            server.block_manager.delete_block(block.id, actor)
+        except:
+            pass
+    return agent


 @router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent")
@@ -728,7 +742,7 @@ async def process_message_background(
         # Update job status to completed
         job_update = JobUpdate(
             status=JobStatus.completed,
-            completed_at=datetime.
+            completed_at=datetime.now(timezone.utc),
             metadata={"result": result.model_dump(mode="json")},  # Store the result in metadata
         )
         server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
@@ -737,7 +751,7 @@ async def process_message_background(
         # Update job status to failed
         job_update = JobUpdate(
             status=JobStatus.failed,
-            completed_at=datetime.
+            completed_at=datetime.now(timezone.utc),
             metadata={"error": str(e)},
         )
         server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
@@ -804,3 +818,16 @@ def reset_messages(
     """Resets the messages for an agent"""
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
     return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages)
+
+
+@router.get("/{agent_id}/groups", response_model=List[Group], operation_id="list_agent_groups")
+async def list_agent_groups(
+    agent_id: str,
+    manager_type: Optional[str] = Query(None, description="Manager type to filter groups by"),
+    server: "SyncServer" = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
+):
+    """Lists the groups for an agent"""
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    print("in list agents with manager_type", manager_type)
+    return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
letta/server/rest_api/routers/v1/messages.py
ADDED
@@ -0,0 +1,132 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Body, Depends, Header
+from fastapi.exceptions import HTTPException
+from starlette.requests import Request
+
+from letta.agents.letta_agent_batch import LettaAgentBatch
+from letta.log import get_logger
+from letta.orm.errors import NoResultFound
+from letta.schemas.job import BatchJob, JobStatus, JobType
+from letta.schemas.letta_request import CreateBatch
+from letta.server.rest_api.utils import get_letta_server
+from letta.server.server import SyncServer
+
+router = APIRouter(prefix="/messages", tags=["messages"])
+
+logger = get_logger(__name__)
+
+
+# Batch APIs
+
+
+@router.post(
+    "/batches",
+    response_model=BatchJob,
+    operation_id="create_messages_batch",
+)
+async def create_messages_batch(
+    request: Request,
+    payload: CreateBatch = Body(..., description="Messages and config for all agents"),
+    server: SyncServer = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),
+):
+    """
+    Submit a batch of agent messages for asynchronous processing.
+    Creates a job that will fan out messages to all listed agents and process them in parallel.
+    """
+    # Reject requests greater than 256Mbs
+    max_bytes = 256 * 1024 * 1024
+    content_length = request.headers.get("content-length")
+    if content_length:
+        length = int(content_length)
+        if length > max_bytes:
+            raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
+
+    try:
+        actor = server.user_manager.get_user_or_default(user_id=actor_id)
+
+        # Create a new job
+        batch_job = BatchJob(
+            user_id=actor.id,
+            status=JobStatus.created,
+            metadata={
+                "job_type": "batch_messages",
+            },
+            callback_url=str(payload.callback_url),
+        )
+
+        # create the batch runner
+        batch_runner = LettaAgentBatch(
+            message_manager=server.message_manager,
+            agent_manager=server.agent_manager,
+            block_manager=server.block_manager,
+            passage_manager=server.passage_manager,
+            batch_manager=server.batch_manager,
+            sandbox_config_manager=server.sandbox_config_manager,
+            job_manager=server.job_manager,
+            actor=actor,
+        )
+        llm_batch_job = await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
+
+        # TODO: update run metadata
+        batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
+    except Exception:
+        import traceback
+
+        traceback.print_exc()
+        raise
+    return batch_job
+
+
+@router.get("/batches/{batch_id}", response_model=BatchJob, operation_id="retrieve_batch_run")
+async def retrieve_batch_run(
+    batch_id: str,
+    actor_id: Optional[str] = Header(None, alias="user_id"),
+    server: "SyncServer" = Depends(get_letta_server),
+):
+    """
+    Get the status of a batch run.
+    """
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+
+    try:
+        job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
+        return BatchJob.from_job(job)
+    except NoResultFound:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+
+@router.get("/batches", response_model=List[BatchJob], operation_id="list_batch_runs")
+async def list_batch_runs(
+    actor_id: Optional[str] = Header(None, alias="user_id"),
+    server: "SyncServer" = Depends(get_letta_server),
+):
+    """
+    List all batch runs.
+    """
+    # TODO: filter
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+
+    jobs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.BATCH)
+    return [BatchJob.from_job(job) for job in jobs]
+
+
+@router.patch("/batches/{batch_id}/cancel", operation_id="cancel_batch_run")
+async def cancel_batch_run(
+    batch_id: str,
+    server: "SyncServer" = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),
+):
+    """
+    Cancel a batch run.
+    """
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+
+    try:
+        job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
+        job.status = JobStatus.cancelled
+        server.job_manager.update_job_by_id(job_id=job, job=job)
+        # TODO: actually cancel it
+    except NoResultFound:
+        raise HTTPException(status_code=404, detail="Run not found")
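Taken together, this new router exposes create/retrieve/list/cancel operations for message batches under `/v1/messages/batches`. A hypothetical client sketch follows; the base URL, the `user_id` header value, and the exact `CreateBatch` body shape are assumptions inferred from the router code above, not from published docs.

# Hypothetical client sketch for the new batch-message endpoints.
import requests

BASE = "http://localhost:8283/v1"
HEADERS = {"user_id": "user-00000000"}  # actor header read by the routes above

# Submit one request per agent; body shape inferred from CreateBatch.
payload = {
    "requests": [
        {"agent_id": "agent-123", "messages": [{"role": "user", "content": "hello"}]},
    ],
    "callback_url": "https://example.com/letta-batch-callback",
}
batch = requests.post(f"{BASE}/messages/batches", json=payload, headers=HEADERS).json()

# Poll until the server-side batch polling job marks it completed.
status = requests.get(f"{BASE}/messages/batches/{batch['id']}", headers=HEADERS).json()
print(status["status"])

# Or cancel it (per the diff's TODO, cancellation currently only flips the job status).
requests.patch(f"{BASE}/messages/batches/{batch['id']}/cancel", headers=HEADERS)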
|