letta-nightly 0.7.20.dev20250520104253__py3-none-any.whl → 0.7.21.dev20250521233415__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agent.py +290 -3
- letta/agents/base_agent.py +0 -55
- letta/agents/helpers.py +5 -0
- letta/agents/letta_agent.py +314 -64
- letta/agents/letta_agent_batch.py +102 -55
- letta/agents/voice_agent.py +5 -5
- letta/client/client.py +9 -18
- letta/constants.py +55 -1
- letta/functions/function_sets/builtin.py +27 -0
- letta/functions/mcp_client/stdio_client.py +1 -1
- letta/groups/sleeptime_multi_agent_v2.py +1 -1
- letta/interfaces/anthropic_streaming_interface.py +10 -1
- letta/interfaces/openai_streaming_interface.py +9 -2
- letta/llm_api/anthropic.py +21 -2
- letta/llm_api/anthropic_client.py +33 -6
- letta/llm_api/google_ai_client.py +136 -423
- letta/llm_api/google_vertex_client.py +173 -22
- letta/llm_api/llm_api_tools.py +27 -0
- letta/llm_api/llm_client.py +1 -1
- letta/llm_api/llm_client_base.py +32 -21
- letta/llm_api/openai.py +57 -0
- letta/llm_api/openai_client.py +7 -11
- letta/memory.py +0 -1
- letta/orm/__init__.py +1 -0
- letta/orm/enums.py +1 -0
- letta/orm/provider_trace.py +26 -0
- letta/orm/step.py +1 -0
- letta/schemas/provider_trace.py +43 -0
- letta/schemas/providers.py +210 -65
- letta/schemas/step.py +1 -0
- letta/schemas/tool.py +4 -0
- letta/server/db.py +37 -19
- letta/server/rest_api/routers/v1/__init__.py +2 -0
- letta/server/rest_api/routers/v1/agents.py +57 -34
- letta/server/rest_api/routers/v1/blocks.py +3 -3
- letta/server/rest_api/routers/v1/identities.py +24 -26
- letta/server/rest_api/routers/v1/jobs.py +3 -3
- letta/server/rest_api/routers/v1/llms.py +13 -8
- letta/server/rest_api/routers/v1/sandbox_configs.py +6 -6
- letta/server/rest_api/routers/v1/tags.py +3 -3
- letta/server/rest_api/routers/v1/telemetry.py +18 -0
- letta/server/rest_api/routers/v1/tools.py +6 -6
- letta/server/rest_api/streaming_response.py +105 -0
- letta/server/rest_api/utils.py +4 -0
- letta/server/server.py +140 -0
- letta/services/agent_manager.py +251 -18
- letta/services/block_manager.py +52 -37
- letta/services/helpers/noop_helper.py +10 -0
- letta/services/identity_manager.py +43 -38
- letta/services/job_manager.py +29 -0
- letta/services/message_manager.py +111 -0
- letta/services/sandbox_config_manager.py +36 -0
- letta/services/step_manager.py +146 -0
- letta/services/telemetry_manager.py +58 -0
- letta/services/tool_executor/tool_execution_manager.py +49 -5
- letta/services/tool_executor/tool_execution_sandbox.py +47 -0
- letta/services/tool_executor/tool_executor.py +236 -7
- letta/services/tool_manager.py +160 -1
- letta/services/tool_sandbox/e2b_sandbox.py +65 -3
- letta/settings.py +10 -2
- letta/tracing.py +5 -5
- {letta_nightly-0.7.20.dev20250520104253.dist-info → letta_nightly-0.7.21.dev20250521233415.dist-info}/METADATA +3 -2
- {letta_nightly-0.7.20.dev20250520104253.dist-info → letta_nightly-0.7.21.dev20250521233415.dist-info}/RECORD +67 -60
- {letta_nightly-0.7.20.dev20250520104253.dist-info → letta_nightly-0.7.21.dev20250521233415.dist-info}/LICENSE +0 -0
- {letta_nightly-0.7.20.dev20250520104253.dist-info → letta_nightly-0.7.21.dev20250521233415.dist-info}/WHEEL +0 -0
- {letta_nightly-0.7.20.dev20250520104253.dist-info → letta_nightly-0.7.21.dev20250521233415.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,105 @@
|
|
1
|
+
# Alternative implementation of StreamingResponse that allows for effectively
|
2
|
+
# streaming HTTP trailers, as we cannot set codes after the initial response.
|
3
|
+
# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361
|
4
|
+
|
5
|
+
import json
|
6
|
+
from collections.abc import AsyncIterator
|
7
|
+
|
8
|
+
from fastapi.responses import StreamingResponse
|
9
|
+
from starlette.types import Send
|
10
|
+
|
11
|
+
from letta.log import get_logger
|
12
|
+
|
13
|
+
logger = get_logger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class StreamingResponseWithStatusCode(StreamingResponse):
|
17
|
+
"""
|
18
|
+
Variation of StreamingResponse that can dynamically decide the HTTP status code,
|
19
|
+
based on the return value of the content iterator (parameter `content`).
|
20
|
+
Expects the content to yield either just str content as per the original `StreamingResponse`
|
21
|
+
or else tuples of (`content`: `str`, `status_code`: `int`).
|
22
|
+
"""
|
23
|
+
|
24
|
+
body_iterator: AsyncIterator[str | bytes]
|
25
|
+
response_started: bool = False
|
26
|
+
|
27
|
+
async def stream_response(self, send: Send) -> None:
|
28
|
+
more_body = True
|
29
|
+
try:
|
30
|
+
first_chunk = await self.body_iterator.__anext__()
|
31
|
+
if isinstance(first_chunk, tuple):
|
32
|
+
first_chunk_content, self.status_code = first_chunk
|
33
|
+
else:
|
34
|
+
first_chunk_content = first_chunk
|
35
|
+
if isinstance(first_chunk_content, str):
|
36
|
+
first_chunk_content = first_chunk_content.encode(self.charset)
|
37
|
+
|
38
|
+
await send(
|
39
|
+
{
|
40
|
+
"type": "http.response.start",
|
41
|
+
"status": self.status_code,
|
42
|
+
"headers": self.raw_headers,
|
43
|
+
}
|
44
|
+
)
|
45
|
+
self.response_started = True
|
46
|
+
await send(
|
47
|
+
{
|
48
|
+
"type": "http.response.body",
|
49
|
+
"body": first_chunk_content,
|
50
|
+
"more_body": more_body,
|
51
|
+
}
|
52
|
+
)
|
53
|
+
|
54
|
+
async for chunk in self.body_iterator:
|
55
|
+
if isinstance(chunk, tuple):
|
56
|
+
content, status_code = chunk
|
57
|
+
if status_code // 100 != 2:
|
58
|
+
# An error occurred mid-stream
|
59
|
+
if not isinstance(content, bytes):
|
60
|
+
content = content.encode(self.charset)
|
61
|
+
more_body = False
|
62
|
+
await send(
|
63
|
+
{
|
64
|
+
"type": "http.response.body",
|
65
|
+
"body": content,
|
66
|
+
"more_body": more_body,
|
67
|
+
}
|
68
|
+
)
|
69
|
+
return
|
70
|
+
else:
|
71
|
+
content = chunk
|
72
|
+
|
73
|
+
if isinstance(content, str):
|
74
|
+
content = content.encode(self.charset)
|
75
|
+
more_body = True
|
76
|
+
await send(
|
77
|
+
{
|
78
|
+
"type": "http.response.body",
|
79
|
+
"body": content,
|
80
|
+
"more_body": more_body,
|
81
|
+
}
|
82
|
+
)
|
83
|
+
|
84
|
+
except Exception:
|
85
|
+
logger.exception("unhandled_streaming_error")
|
86
|
+
more_body = False
|
87
|
+
error_resp = {"error": {"message": "Internal Server Error"}}
|
88
|
+
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
|
89
|
+
if not self.response_started:
|
90
|
+
await send(
|
91
|
+
{
|
92
|
+
"type": "http.response.start",
|
93
|
+
"status": 500,
|
94
|
+
"headers": self.raw_headers,
|
95
|
+
}
|
96
|
+
)
|
97
|
+
await send(
|
98
|
+
{
|
99
|
+
"type": "http.response.body",
|
100
|
+
"body": error_event,
|
101
|
+
"more_body": more_body,
|
102
|
+
}
|
103
|
+
)
|
104
|
+
if more_body:
|
105
|
+
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
letta/server/rest_api/utils.py
CHANGED
@@ -190,6 +190,7 @@ def create_letta_messages_from_llm_response(
|
|
190
190
|
pre_computed_assistant_message_id: Optional[str] = None,
|
191
191
|
pre_computed_tool_message_id: Optional[str] = None,
|
192
192
|
llm_batch_item_id: Optional[str] = None,
|
193
|
+
step_id: str | None = None,
|
193
194
|
) -> List[Message]:
|
194
195
|
messages = []
|
195
196
|
|
@@ -244,6 +245,9 @@ def create_letta_messages_from_llm_response(
|
|
244
245
|
)
|
245
246
|
messages.append(heartbeat_system_message)
|
246
247
|
|
248
|
+
for message in messages:
|
249
|
+
message.step_id = step_id
|
250
|
+
|
247
251
|
return messages
|
248
252
|
|
249
253
|
|
letta/server/server.py
CHANGED
@@ -94,6 +94,7 @@ from letta.services.provider_manager import ProviderManager
|
|
94
94
|
from letta.services.sandbox_config_manager import SandboxConfigManager
|
95
95
|
from letta.services.source_manager import SourceManager
|
96
96
|
from letta.services.step_manager import StepManager
|
97
|
+
from letta.services.telemetry_manager import TelemetryManager
|
97
98
|
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
|
98
99
|
from letta.services.tool_manager import ToolManager
|
99
100
|
from letta.services.user_manager import UserManager
|
@@ -213,6 +214,7 @@ class SyncServer(Server):
|
|
213
214
|
self.identity_manager = IdentityManager()
|
214
215
|
self.group_manager = GroupManager()
|
215
216
|
self.batch_manager = LLMBatchManager()
|
217
|
+
self.telemetry_manager = TelemetryManager()
|
216
218
|
|
217
219
|
# A reusable httpx client
|
218
220
|
timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0)
|
@@ -1000,6 +1002,30 @@ class SyncServer(Server):
|
|
1000
1002
|
)
|
1001
1003
|
return records
|
1002
1004
|
|
1005
|
+
async def get_agent_archival_async(
|
1006
|
+
self,
|
1007
|
+
agent_id: str,
|
1008
|
+
actor: User,
|
1009
|
+
after: Optional[str] = None,
|
1010
|
+
before: Optional[str] = None,
|
1011
|
+
limit: Optional[int] = 100,
|
1012
|
+
order_by: Optional[str] = "created_at",
|
1013
|
+
reverse: Optional[bool] = False,
|
1014
|
+
query_text: Optional[str] = None,
|
1015
|
+
ascending: Optional[bool] = True,
|
1016
|
+
) -> List[Passage]:
|
1017
|
+
# fetch the matching passage records from the agent manager
|
1018
|
+
records = await self.agent_manager.list_passages_async(
|
1019
|
+
actor=actor,
|
1020
|
+
agent_id=agent_id,
|
1021
|
+
after=after,
|
1022
|
+
query_text=query_text,
|
1023
|
+
before=before,
|
1024
|
+
ascending=ascending,
|
1025
|
+
limit=limit,
|
1026
|
+
)
|
1027
|
+
return records
|
1028
|
+
|
1003
1029
|
def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
|
1004
1030
|
# Get the agent object (loaded in memory)
|
1005
1031
|
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
@@ -1070,6 +1096,44 @@ class SyncServer(Server):
|
|
1070
1096
|
|
1071
1097
|
return records
|
1072
1098
|
|
1099
|
+
async def get_agent_recall_async(
|
1100
|
+
self,
|
1101
|
+
agent_id: str,
|
1102
|
+
actor: User,
|
1103
|
+
after: Optional[str] = None,
|
1104
|
+
before: Optional[str] = None,
|
1105
|
+
limit: Optional[int] = 100,
|
1106
|
+
group_id: Optional[str] = None,
|
1107
|
+
reverse: Optional[bool] = False,
|
1108
|
+
return_message_object: bool = True,
|
1109
|
+
use_assistant_message: bool = True,
|
1110
|
+
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
1111
|
+
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
1112
|
+
) -> Union[List[Message], List[LettaMessage]]:
|
1113
|
+
records = await self.message_manager.list_messages_for_agent_async(
|
1114
|
+
agent_id=agent_id,
|
1115
|
+
actor=actor,
|
1116
|
+
after=after,
|
1117
|
+
before=before,
|
1118
|
+
limit=limit,
|
1119
|
+
ascending=not reverse,
|
1120
|
+
group_id=group_id,
|
1121
|
+
)
|
1122
|
+
|
1123
|
+
if not return_message_object:
|
1124
|
+
records = Message.to_letta_messages_from_list(
|
1125
|
+
messages=records,
|
1126
|
+
use_assistant_message=use_assistant_message,
|
1127
|
+
assistant_message_tool_name=assistant_message_tool_name,
|
1128
|
+
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
1129
|
+
reverse=reverse,
|
1130
|
+
)
|
1131
|
+
|
1132
|
+
if reverse:
|
1133
|
+
records = records[::-1]
|
1134
|
+
|
1135
|
+
return records
|
1136
|
+
|
1073
1137
|
def get_server_config(self, include_defaults: bool = False) -> dict:
|
1074
1138
|
"""Return the base config"""
|
1075
1139
|
|
@@ -1301,6 +1365,48 @@ class SyncServer(Server):
|
|
1301
1365
|
|
1302
1366
|
return llm_models
|
1303
1367
|
|
1368
|
+
@trace_method
|
1369
|
+
async def list_llm_models_async(
|
1370
|
+
self,
|
1371
|
+
actor: User,
|
1372
|
+
provider_category: Optional[List[ProviderCategory]] = None,
|
1373
|
+
provider_name: Optional[str] = None,
|
1374
|
+
provider_type: Optional[ProviderType] = None,
|
1375
|
+
) -> List[LLMConfig]:
|
1376
|
+
"""Asynchronously list available models with maximum concurrency"""
|
1377
|
+
import asyncio
|
1378
|
+
|
1379
|
+
providers = self.get_enabled_providers(
|
1380
|
+
provider_category=provider_category,
|
1381
|
+
provider_name=provider_name,
|
1382
|
+
provider_type=provider_type,
|
1383
|
+
actor=actor,
|
1384
|
+
)
|
1385
|
+
|
1386
|
+
async def get_provider_models(provider):
|
1387
|
+
try:
|
1388
|
+
return await provider.list_llm_models_async()
|
1389
|
+
except Exception as e:
|
1390
|
+
import traceback
|
1391
|
+
|
1392
|
+
traceback.print_exc()
|
1393
|
+
warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
|
1394
|
+
return []
|
1395
|
+
|
1396
|
+
# Execute all provider model listing tasks concurrently
|
1397
|
+
provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
|
1398
|
+
|
1399
|
+
# Flatten the results
|
1400
|
+
llm_models = []
|
1401
|
+
for models in provider_results:
|
1402
|
+
llm_models.extend(models)
|
1403
|
+
|
1404
|
+
# Get local configs - if this is potentially slow, consider making it async too
|
1405
|
+
local_configs = self.get_local_llm_configs()
|
1406
|
+
llm_models.extend(local_configs)
|
1407
|
+
|
1408
|
+
return llm_models
|
1409
|
+
|
1304
1410
|
def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
|
1305
1411
|
"""List available embedding models"""
|
1306
1412
|
embedding_models = []
|
@@ -1311,6 +1417,35 @@ class SyncServer(Server):
|
|
1311
1417
|
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
1312
1418
|
return embedding_models
|
1313
1419
|
|
1420
|
+
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
|
1421
|
+
"""Asynchronously list available embedding models with maximum concurrency"""
|
1422
|
+
import asyncio
|
1423
|
+
|
1424
|
+
# Get all eligible providers first
|
1425
|
+
providers = self.get_enabled_providers(actor=actor)
|
1426
|
+
|
1427
|
+
# Fetch embedding models from each provider concurrently
|
1428
|
+
async def get_provider_embedding_models(provider):
|
1429
|
+
try:
|
1430
|
+
# All providers now have list_embedding_models_async
|
1431
|
+
return await provider.list_embedding_models_async()
|
1432
|
+
except Exception as e:
|
1433
|
+
import traceback
|
1434
|
+
|
1435
|
+
traceback.print_exc()
|
1436
|
+
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
1437
|
+
return []
|
1438
|
+
|
1439
|
+
# Execute all provider model listing tasks concurrently
|
1440
|
+
provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
|
1441
|
+
|
1442
|
+
# Flatten the results
|
1443
|
+
embedding_models = []
|
1444
|
+
for models in provider_results:
|
1445
|
+
embedding_models.extend(models)
|
1446
|
+
|
1447
|
+
return embedding_models
|
1448
|
+
|
1314
1449
|
def get_enabled_providers(
|
1315
1450
|
self,
|
1316
1451
|
actor: User,
|
@@ -1482,6 +1617,10 @@ class SyncServer(Server):
|
|
1482
1617
|
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
1483
1618
|
return letta_agent.get_context_window()
|
1484
1619
|
|
1620
|
+
async def get_agent_context_window_async(self, agent_id: str, actor: User) -> ContextWindowOverview:
|
1621
|
+
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
1622
|
+
return await letta_agent.get_context_window_async()
|
1623
|
+
|
1485
1624
|
def run_tool_from_source(
|
1486
1625
|
self,
|
1487
1626
|
actor: User,
|
@@ -1615,6 +1754,7 @@ class SyncServer(Server):
|
|
1615
1754
|
server_name=server_name,
|
1616
1755
|
command=server_params_raw["command"],
|
1617
1756
|
args=server_params_raw.get("args", []),
|
1757
|
+
env=server_params_raw.get("env", {}),
|
1618
1758
|
)
|
1619
1759
|
mcp_server_list[server_name] = server_params
|
1620
1760
|
except Exception as e:
|
letta/services/agent_manager.py
CHANGED
@@ -892,7 +892,7 @@ class AgentManager:
|
|
892
892
|
List[PydanticAgentState]: The filtered list of matching agents.
|
893
893
|
"""
|
894
894
|
async with db_registry.async_session() as session:
|
895
|
-
query = select(AgentModel)
|
895
|
+
query = select(AgentModel)
|
896
896
|
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
897
897
|
|
898
898
|
# Apply filters
|
@@ -961,6 +961,16 @@ class AgentManager:
|
|
961
961
|
with db_registry.session() as session:
|
962
962
|
return AgentModel.size(db_session=session, actor=actor)
|
963
963
|
|
964
|
+
async def size_async(
|
965
|
+
self,
|
966
|
+
actor: PydanticUser,
|
967
|
+
) -> int:
|
968
|
+
"""
|
969
|
+
Get the total count of agents for the given user.
|
970
|
+
"""
|
971
|
+
async with db_registry.async_session() as session:
|
972
|
+
return await AgentModel.size_async(db_session=session, actor=actor)
|
973
|
+
|
964
974
|
@enforce_types
|
965
975
|
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
966
976
|
"""Fetch an agent by its ID."""
|
@@ -969,18 +979,32 @@ class AgentManager:
|
|
969
979
|
return agent.to_pydantic()
|
970
980
|
|
971
981
|
@enforce_types
|
972
|
-
async def get_agent_by_id_async(
|
982
|
+
async def get_agent_by_id_async(
|
983
|
+
self,
|
984
|
+
agent_id: str,
|
985
|
+
actor: PydanticUser,
|
986
|
+
include_relationships: Optional[List[str]] = None,
|
987
|
+
) -> PydanticAgentState:
|
973
988
|
"""Fetch an agent by its ID."""
|
974
989
|
async with db_registry.async_session() as session:
|
975
990
|
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
976
|
-
return agent.
|
991
|
+
return await agent.to_pydantic_async(include_relationships=include_relationships)
|
977
992
|
|
978
993
|
@enforce_types
|
979
|
-
async def get_agents_by_ids_async(
|
994
|
+
async def get_agents_by_ids_async(
|
995
|
+
self,
|
996
|
+
agent_ids: list[str],
|
997
|
+
actor: PydanticUser,
|
998
|
+
include_relationships: Optional[List[str]] = None,
|
999
|
+
) -> list[PydanticAgentState]:
|
980
1000
|
"""Fetch a list of agents by their IDs."""
|
981
1001
|
async with db_registry.async_session() as session:
|
982
|
-
agents = await AgentModel.read_multiple_async(
|
983
|
-
|
1002
|
+
agents = await AgentModel.read_multiple_async(
|
1003
|
+
db_session=session,
|
1004
|
+
identifiers=agent_ids,
|
1005
|
+
actor=actor,
|
1006
|
+
)
|
1007
|
+
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
|
984
1008
|
|
985
1009
|
@enforce_types
|
986
1010
|
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
|
@@ -1191,7 +1215,7 @@ class AgentManager:
|
|
1191
1215
|
|
1192
1216
|
@enforce_types
|
1193
1217
|
async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
|
1194
|
-
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
1218
|
+
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
|
1195
1219
|
return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor)
|
1196
1220
|
|
1197
1221
|
@enforce_types
|
@@ -1199,6 +1223,11 @@ class AgentManager:
|
|
1199
1223
|
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
1200
1224
|
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
|
1201
1225
|
|
1226
|
+
@enforce_types
|
1227
|
+
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
1228
|
+
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
|
1229
|
+
return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor)
|
1230
|
+
|
1202
1231
|
# TODO: This is duplicated below
|
1203
1232
|
# TODO: This is legacy code and should be cleaned up
|
1204
1233
|
# TODO: A lot of the memory "compilation" should be offset to a separate class
|
@@ -1267,10 +1296,81 @@ class AgentManager:
|
|
1267
1296
|
else:
|
1268
1297
|
return agent_state
|
1269
1298
|
|
1299
|
+
@enforce_types
|
1300
|
+
async def rebuild_system_prompt_async(
|
1301
|
+
self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True
|
1302
|
+
) -> PydanticAgentState:
|
1303
|
+
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
1304
|
+
|
1305
|
+
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
|
1306
|
+
|
1307
|
+
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
|
1308
|
+
"""
|
1309
|
+
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor)
|
1310
|
+
|
1311
|
+
curr_system_message = await self.get_system_message_async(
|
1312
|
+
agent_id=agent_id, actor=actor
|
1313
|
+
) # this is the system + memory bank, not just the system prompt
|
1314
|
+
curr_system_message_openai = curr_system_message.to_openai_dict()
|
1315
|
+
|
1316
|
+
# note: we only update the system prompt if the core memory is changed
|
1317
|
+
# this means that the archival/recall memory statistics may be somewhat out of date
|
1318
|
+
curr_memory_str = agent_state.memory.compile()
|
1319
|
+
if curr_memory_str in curr_system_message_openai["content"] and not force:
|
1320
|
+
# NOTE: could this cause issues if a block is removed? (substring match would still work)
|
1321
|
+
logger.debug(
|
1322
|
+
f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
|
1323
|
+
)
|
1324
|
+
return agent_state
|
1325
|
+
|
1326
|
+
# If the memory didn't update, we probably don't want to update the timestamp inside
|
1327
|
+
# For example, if we're doing a system prompt swap, this should probably be False
|
1328
|
+
if update_timestamp:
|
1329
|
+
memory_edit_timestamp = get_utc_time()
|
1330
|
+
else:
|
1331
|
+
# NOTE: a bit of a hack - we pull the timestamp from the message's created_at
|
1332
|
+
memory_edit_timestamp = curr_system_message.created_at
|
1333
|
+
|
1334
|
+
num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
|
1335
|
+
num_archival_memories = await self.passage_manager.size_async(actor=actor, agent_id=agent_id)
|
1336
|
+
|
1337
|
+
# update memory (TODO: potentially update recall/archival stats separately)
|
1338
|
+
new_system_message_str = compile_system_message(
|
1339
|
+
system_prompt=agent_state.system,
|
1340
|
+
in_context_memory=agent_state.memory,
|
1341
|
+
in_context_memory_last_edit=memory_edit_timestamp,
|
1342
|
+
recent_passages=self.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=10),
|
1343
|
+
previous_message_count=num_messages,
|
1344
|
+
archival_memory_size=num_archival_memories,
|
1345
|
+
)
|
1346
|
+
|
1347
|
+
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
1348
|
+
if len(diff) > 0: # there was a diff
|
1349
|
+
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
1350
|
+
|
1351
|
+
# Swap the system message out (only if there is a diff)
|
1352
|
+
message = PydanticMessage.dict_to_message(
|
1353
|
+
agent_id=agent_id,
|
1354
|
+
model=agent_state.llm_config.model,
|
1355
|
+
openai_message_dict={"role": "system", "content": new_system_message_str},
|
1356
|
+
)
|
1357
|
+
message = await self.message_manager.update_message_by_id_async(
|
1358
|
+
message_id=curr_system_message.id,
|
1359
|
+
message_update=MessageUpdate(**message.model_dump()),
|
1360
|
+
actor=actor,
|
1361
|
+
)
|
1362
|
+
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor)
|
1363
|
+
else:
|
1364
|
+
return agent_state
|
1365
|
+
|
1270
1366
|
@enforce_types
|
1271
1367
|
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
1272
1368
|
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
1273
1369
|
|
1370
|
+
@enforce_types
|
1371
|
+
async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
1372
|
+
return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
1373
|
+
|
1274
1374
|
@enforce_types
|
1275
1375
|
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
1276
1376
|
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
@@ -1382,17 +1482,6 @@ class AgentManager:
|
|
1382
1482
|
|
1383
1483
|
return agent_state
|
1384
1484
|
|
1385
|
-
@enforce_types
|
1386
|
-
def refresh_memory(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
1387
|
-
block_ids = [b.id for b in agent_state.memory.blocks]
|
1388
|
-
if not block_ids:
|
1389
|
-
return agent_state
|
1390
|
-
|
1391
|
-
agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids(
|
1392
|
-
block_ids=[b.id for b in agent_state.memory.blocks], actor=actor
|
1393
|
-
)
|
1394
|
-
return agent_state
|
1395
|
-
|
1396
1485
|
@enforce_types
|
1397
1486
|
async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
1398
1487
|
block_ids = [b.id for b in agent_state.memory.blocks]
|
@@ -1482,6 +1571,25 @@ class AgentManager:
|
|
1482
1571
|
# Use the lazy-loaded relationship to get sources
|
1483
1572
|
return [source.to_pydantic() for source in agent.sources]
|
1484
1573
|
|
1574
|
+
@enforce_types
|
1575
|
+
async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
|
1576
|
+
"""
|
1577
|
+
Lists all sources attached to an agent.
|
1578
|
+
|
1579
|
+
Args:
|
1580
|
+
agent_id: ID of the agent to list sources for
|
1581
|
+
actor: User performing the action
|
1582
|
+
|
1583
|
+
Returns:
|
1584
|
+
List[PydanticSource]: List of sources attached to the agent
|
1585
|
+
"""
|
1586
|
+
async with db_registry.async_session() as session:
|
1587
|
+
# Verify agent exists and user has permission to access it
|
1588
|
+
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
1589
|
+
|
1590
|
+
# Use the lazy-loaded relationship to get sources
|
1591
|
+
return [source.to_pydantic() for source in agent.sources]
|
1592
|
+
|
1485
1593
|
@enforce_types
|
1486
1594
|
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
1487
1595
|
"""
|
@@ -1527,6 +1635,33 @@ class AgentManager:
|
|
1527
1635
|
return block.to_pydantic()
|
1528
1636
|
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
1529
1637
|
|
1638
|
+
@enforce_types
|
1639
|
+
async def modify_block_by_label_async(
|
1640
|
+
self,
|
1641
|
+
agent_id: str,
|
1642
|
+
block_label: str,
|
1643
|
+
block_update: BlockUpdate,
|
1644
|
+
actor: PydanticUser,
|
1645
|
+
) -> PydanticBlock:
|
1646
|
+
"""Updates a block attached to an agent, looked up by its label."""
|
1647
|
+
async with db_registry.async_session() as session:
|
1648
|
+
block = None
|
1649
|
+
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
1650
|
+
for block in agent.core_memory:
|
1651
|
+
if block.label == block_label:
|
1652
|
+
block = block
|
1653
|
+
break
|
1654
|
+
if not block:
|
1655
|
+
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
1656
|
+
|
1657
|
+
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
1658
|
+
|
1659
|
+
for key, value in update_data.items():
|
1660
|
+
setattr(block, key, value)
|
1661
|
+
|
1662
|
+
await block.update_async(session, actor=actor)
|
1663
|
+
return block.to_pydantic()
|
1664
|
+
|
1530
1665
|
@enforce_types
|
1531
1666
|
def update_block_with_label(
|
1532
1667
|
self,
|
@@ -1848,6 +1983,65 @@ class AgentManager:
|
|
1848
1983
|
|
1849
1984
|
return [p.to_pydantic() for p in passages]
|
1850
1985
|
|
1986
|
+
@enforce_types
|
1987
|
+
async def list_passages_async(
|
1988
|
+
self,
|
1989
|
+
actor: PydanticUser,
|
1990
|
+
agent_id: Optional[str] = None,
|
1991
|
+
file_id: Optional[str] = None,
|
1992
|
+
limit: Optional[int] = 50,
|
1993
|
+
query_text: Optional[str] = None,
|
1994
|
+
start_date: Optional[datetime] = None,
|
1995
|
+
end_date: Optional[datetime] = None,
|
1996
|
+
before: Optional[str] = None,
|
1997
|
+
after: Optional[str] = None,
|
1998
|
+
source_id: Optional[str] = None,
|
1999
|
+
embed_query: bool = False,
|
2000
|
+
ascending: bool = True,
|
2001
|
+
embedding_config: Optional[EmbeddingConfig] = None,
|
2002
|
+
agent_only: bool = False,
|
2003
|
+
) -> List[PydanticPassage]:
|
2004
|
+
"""Lists all passages attached to an agent."""
|
2005
|
+
async with db_registry.async_session() as session:
|
2006
|
+
main_query = self._build_passage_query(
|
2007
|
+
actor=actor,
|
2008
|
+
agent_id=agent_id,
|
2009
|
+
file_id=file_id,
|
2010
|
+
query_text=query_text,
|
2011
|
+
start_date=start_date,
|
2012
|
+
end_date=end_date,
|
2013
|
+
before=before,
|
2014
|
+
after=after,
|
2015
|
+
source_id=source_id,
|
2016
|
+
embed_query=embed_query,
|
2017
|
+
ascending=ascending,
|
2018
|
+
embedding_config=embedding_config,
|
2019
|
+
agent_only=agent_only,
|
2020
|
+
)
|
2021
|
+
|
2022
|
+
# Add limit
|
2023
|
+
if limit:
|
2024
|
+
main_query = main_query.limit(limit)
|
2025
|
+
|
2026
|
+
# Execute query
|
2027
|
+
result = await session.execute(main_query)
|
2028
|
+
|
2029
|
+
passages = []
|
2030
|
+
for row in result:
|
2031
|
+
data = dict(row._mapping)
|
2032
|
+
if data["agent_id"] is not None:
|
2033
|
+
# This is an AgentPassage - remove source fields
|
2034
|
+
data.pop("source_id", None)
|
2035
|
+
data.pop("file_id", None)
|
2036
|
+
passage = AgentPassage(**data)
|
2037
|
+
else:
|
2038
|
+
# This is a SourcePassage - remove agent field
|
2039
|
+
data.pop("agent_id", None)
|
2040
|
+
passage = SourcePassage(**data)
|
2041
|
+
passages.append(passage)
|
2042
|
+
|
2043
|
+
return [p.to_pydantic() for p in passages]
|
2044
|
+
|
1851
2045
|
@enforce_types
|
1852
2046
|
def passage_size(
|
1853
2047
|
self,
|
@@ -2010,3 +2204,42 @@ class AgentManager:
|
|
2010
2204
|
query = query.order_by(AgentsTags.tag).limit(limit)
|
2011
2205
|
results = [tag[0] for tag in query.all()]
|
2012
2206
|
return results
|
2207
|
+
|
2208
|
+
@enforce_types
|
2209
|
+
async def list_tags_async(
|
2210
|
+
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None
|
2211
|
+
) -> List[str]:
|
2212
|
+
"""
|
2213
|
+
Get all tags a user has created, ordered alphabetically.
|
2214
|
+
|
2215
|
+
Args:
|
2216
|
+
actor: User performing the action.
|
2217
|
+
after: Cursor for forward pagination.
|
2218
|
+
limit: Maximum number of tags to return.
|
2219
|
+
query_text: Query text to filter tags by.
|
2220
|
+
|
2221
|
+
Returns:
|
2222
|
+
List[str]: List of all tags.
|
2223
|
+
"""
|
2224
|
+
async with db_registry.async_session() as session:
|
2225
|
+
# Build the query using select() for async SQLAlchemy
|
2226
|
+
query = (
|
2227
|
+
select(AgentsTags.tag)
|
2228
|
+
.join(AgentModel, AgentModel.id == AgentsTags.agent_id)
|
2229
|
+
.where(AgentModel.organization_id == actor.organization_id)
|
2230
|
+
.distinct()
|
2231
|
+
)
|
2232
|
+
|
2233
|
+
if query_text:
|
2234
|
+
query = query.where(AgentsTags.tag.ilike(f"%{query_text}%"))
|
2235
|
+
|
2236
|
+
if after:
|
2237
|
+
query = query.where(AgentsTags.tag > after)
|
2238
|
+
|
2239
|
+
query = query.order_by(AgentsTags.tag).limit(limit)
|
2240
|
+
|
2241
|
+
# Execute the query asynchronously
|
2242
|
+
result = await session.execute(query)
|
2243
|
+
# Extract the tag values from the result
|
2244
|
+
results = [row[0] for row in result.all()]
|
2245
|
+
return results
|