letta-nightly 0.5.1.dev20241028104150__py3-none-any.whl → 0.5.1.dev20241030104135__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/cli/cli.py +2 -4
- letta/client/client.py +17 -61
- letta/llm_api/llm_api_tools.py +0 -1
- letta/orm/base.py +10 -5
- letta/orm/sqlalchemy_base.py +50 -63
- letta/orm/tool.py +3 -12
- letta/orm/user.py +1 -3
- letta/providers.py +1 -1
- letta/schemas/tool.py +5 -18
- letta/server/rest_api/routers/v1/agents.py +14 -14
- letta/server/rest_api/routers/v1/tools.py +8 -9
- letta/server/rest_api/utils.py +21 -2
- letta/server/server.py +48 -23
- letta/server/startup.sh +2 -2
- letta/services/tool_manager.py +32 -66
- letta/services/user_manager.py +2 -5
- {letta_nightly-0.5.1.dev20241028104150.dist-info → letta_nightly-0.5.1.dev20241030104135.dist-info}/METADATA +1 -1
- {letta_nightly-0.5.1.dev20241028104150.dist-info → letta_nightly-0.5.1.dev20241030104135.dist-info}/RECORD +21 -21
- {letta_nightly-0.5.1.dev20241028104150.dist-info → letta_nightly-0.5.1.dev20241030104135.dist-info}/LICENSE +0 -0
- {letta_nightly-0.5.1.dev20241028104150.dist-info → letta_nightly-0.5.1.dev20241030104135.dist-info}/WHEEL +0 -0
- {letta_nightly-0.5.1.dev20241028104150.dist-info → letta_nightly-0.5.1.dev20241030104135.dist-info}/entry_points.txt +0 -0
|
@@ -26,11 +26,13 @@ def delete_tool(
|
|
|
26
26
|
def get_tool(
|
|
27
27
|
tool_id: str,
|
|
28
28
|
server: SyncServer = Depends(get_letta_server),
|
|
29
|
+
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
29
30
|
):
|
|
30
31
|
"""
|
|
31
32
|
Get a tool by ID
|
|
32
33
|
"""
|
|
33
|
-
|
|
34
|
+
actor = server.get_user_or_default(user_id=user_id)
|
|
35
|
+
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
|
|
34
36
|
if tool is None:
|
|
35
37
|
# return 404 error
|
|
36
38
|
raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
|
|
@@ -49,7 +51,7 @@ def get_tool_id(
|
|
|
49
51
|
actor = server.get_user_or_default(user_id=user_id)
|
|
50
52
|
|
|
51
53
|
try:
|
|
52
|
-
tool = server.tool_manager.
|
|
54
|
+
tool = server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
|
|
53
55
|
return tool.id
|
|
54
56
|
except NoResultFound:
|
|
55
57
|
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} and organization id {actor.organization_id} not found.")
|
|
@@ -67,7 +69,7 @@ def list_tools(
|
|
|
67
69
|
"""
|
|
68
70
|
try:
|
|
69
71
|
actor = server.get_user_or_default(user_id=user_id)
|
|
70
|
-
return server.tool_manager.
|
|
72
|
+
return server.tool_manager.list_tools(actor=actor, cursor=cursor, limit=limit)
|
|
71
73
|
except Exception as e:
|
|
72
74
|
# Log or print the full exception here for debugging
|
|
73
75
|
print(f"Error occurred: {e}")
|
|
@@ -85,13 +87,9 @@ def create_tool(
|
|
|
85
87
|
"""
|
|
86
88
|
# Derive user and org id from actor
|
|
87
89
|
actor = server.get_user_or_default(user_id=user_id)
|
|
88
|
-
request.organization_id = actor.organization_id
|
|
89
|
-
request.user_id = actor.id
|
|
90
90
|
|
|
91
91
|
# Send request to create the tool
|
|
92
|
-
return server.tool_manager.create_or_update_tool(
|
|
93
|
-
tool_create=request,
|
|
94
|
-
)
|
|
92
|
+
return server.tool_manager.create_or_update_tool(tool_create=request, actor=actor)
|
|
95
93
|
|
|
96
94
|
|
|
97
95
|
@router.patch("/{tool_id}", response_model=Tool, operation_id="update_tool")
|
|
@@ -104,4 +102,5 @@ def update_tool(
|
|
|
104
102
|
"""
|
|
105
103
|
Update an existing tool
|
|
106
104
|
"""
|
|
107
|
-
|
|
105
|
+
actor = server.get_user_or_default(user_id=user_id)
|
|
106
|
+
return server.tool_manager.update_tool_by_id(tool_id, actor.id, request)
|
letta/server/rest_api/utils.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import json
|
|
2
3
|
import traceback
|
|
4
|
+
import warnings
|
|
3
5
|
from enum import Enum
|
|
4
|
-
from typing import AsyncGenerator, Union
|
|
6
|
+
from typing import AsyncGenerator, Optional, Union
|
|
5
7
|
|
|
6
8
|
from pydantic import BaseModel
|
|
7
9
|
|
|
10
|
+
from letta.schemas.usage import LettaUsageStatistics
|
|
8
11
|
from letta.server.rest_api.interface import StreamingServerInterface
|
|
9
12
|
from letta.server.server import SyncServer
|
|
10
13
|
|
|
@@ -24,7 +27,11 @@ def sse_formatter(data: Union[dict, str]) -> str:
|
|
|
24
27
|
return f"data: {data_str}\n\n"
|
|
25
28
|
|
|
26
29
|
|
|
27
|
-
async def sse_async_generator(
|
|
30
|
+
async def sse_async_generator(
|
|
31
|
+
generator: AsyncGenerator,
|
|
32
|
+
usage_task: Optional[asyncio.Task] = None,
|
|
33
|
+
finish_message=True,
|
|
34
|
+
):
|
|
28
35
|
"""
|
|
29
36
|
Wraps a generator for use in Server-Sent Events (SSE), handling errors and ensuring a completion message.
|
|
30
37
|
|
|
@@ -45,6 +52,18 @@ async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
|
|
|
45
52
|
chunk = str(chunk)
|
|
46
53
|
yield sse_formatter(chunk)
|
|
47
54
|
|
|
55
|
+
# If we have a usage task, wait for it and send its result
|
|
56
|
+
if usage_task is not None:
|
|
57
|
+
try:
|
|
58
|
+
usage = await usage_task
|
|
59
|
+
# Double-check the type
|
|
60
|
+
if not isinstance(usage, LettaUsageStatistics):
|
|
61
|
+
raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
|
|
62
|
+
yield sse_formatter({"usage": usage.model_dump()})
|
|
63
|
+
except Exception as e:
|
|
64
|
+
warnings.warn(f"Error getting usage data: {e}")
|
|
65
|
+
yield sse_formatter({"error": "Failed to get usage data"})
|
|
66
|
+
|
|
48
67
|
except Exception as e:
|
|
49
68
|
print("stream decoder hit error:", e)
|
|
50
69
|
print(traceback.print_stack())
|
letta/server/server.py
CHANGED
|
@@ -37,11 +37,13 @@ from letta.log import get_logger
|
|
|
37
37
|
from letta.memory import get_memory_functions
|
|
38
38
|
from letta.metadata import Base, MetadataStore
|
|
39
39
|
from letta.o1_agent import O1Agent
|
|
40
|
+
from letta.orm.errors import NoResultFound
|
|
40
41
|
from letta.prompts import gpt_system
|
|
41
42
|
from letta.providers import (
|
|
42
43
|
AnthropicProvider,
|
|
43
44
|
AzureProvider,
|
|
44
45
|
GoogleAIProvider,
|
|
46
|
+
GroqProvider,
|
|
45
47
|
LettaProvider,
|
|
46
48
|
OllamaProvider,
|
|
47
49
|
OpenAIProvider,
|
|
@@ -73,6 +75,7 @@ from letta.schemas.memory import (
|
|
|
73
75
|
RecallMemorySummary,
|
|
74
76
|
)
|
|
75
77
|
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
|
|
78
|
+
from letta.schemas.organization import Organization
|
|
76
79
|
from letta.schemas.passage import Passage
|
|
77
80
|
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
|
78
81
|
from letta.schemas.tool import Tool, ToolCreate
|
|
@@ -251,12 +254,12 @@ class SyncServer(Server):
|
|
|
251
254
|
self.default_org = self.organization_manager.create_default_organization()
|
|
252
255
|
self.default_user = self.user_manager.create_default_user()
|
|
253
256
|
self.add_default_blocks(self.default_user.id)
|
|
254
|
-
self.tool_manager.add_default_tools(module_name="base",
|
|
257
|
+
self.tool_manager.add_default_tools(module_name="base", actor=self.default_user)
|
|
255
258
|
|
|
256
259
|
# If there is a default org/user
|
|
257
260
|
# This logic may have to change in the future
|
|
258
261
|
if settings.load_default_external_tools:
|
|
259
|
-
self.add_default_external_tools(
|
|
262
|
+
self.add_default_external_tools(actor=self.default_user)
|
|
260
263
|
|
|
261
264
|
# collect providers (always has Letta as a default)
|
|
262
265
|
self._enabled_providers: List[Provider] = [LettaProvider()]
|
|
@@ -296,6 +299,8 @@ class SyncServer(Server):
|
|
|
296
299
|
api_version=model_settings.azure_api_version,
|
|
297
300
|
)
|
|
298
301
|
)
|
|
302
|
+
if model_settings.groq_api_key:
|
|
303
|
+
self._enabled_providers.append(GroqProvider(api_key=model_settings.groq_api_key))
|
|
299
304
|
if model_settings.vllm_api_base:
|
|
300
305
|
# vLLM exposes both a /chat/completions and a /completions endpoint
|
|
301
306
|
self._enabled_providers.append(
|
|
@@ -345,10 +350,10 @@ class SyncServer(Server):
|
|
|
345
350
|
}
|
|
346
351
|
)
|
|
347
352
|
|
|
348
|
-
def _load_agent(self,
|
|
353
|
+
def _load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
|
|
349
354
|
"""Loads a saved agent into memory (if it doesn't exist, throw an error)"""
|
|
350
|
-
assert isinstance(user_id, str), user_id
|
|
351
355
|
assert isinstance(agent_id, str), agent_id
|
|
356
|
+
user_id = actor.id
|
|
352
357
|
|
|
353
358
|
# If an interface isn't specified, use the default
|
|
354
359
|
if interface is None:
|
|
@@ -365,7 +370,7 @@ class SyncServer(Server):
|
|
|
365
370
|
logger.debug(f"Creating an agent object")
|
|
366
371
|
tool_objs = []
|
|
367
372
|
for name in agent_state.tools:
|
|
368
|
-
tool_obj = self.tool_manager.
|
|
373
|
+
tool_obj = self.tool_manager.get_tool_by_name(tool_name=name, actor=actor)
|
|
369
374
|
if not tool_obj:
|
|
370
375
|
logger.exception(f"Tool {name} does not exist for user {user_id}")
|
|
371
376
|
raise ValueError(f"Tool {name} does not exist for user {user_id}")
|
|
@@ -396,13 +401,14 @@ class SyncServer(Server):
|
|
|
396
401
|
if not agent_state:
|
|
397
402
|
raise ValueError(f"Agent does not exist")
|
|
398
403
|
user_id = agent_state.user_id
|
|
404
|
+
actor = self.user_manager.get_user_by_id(user_id)
|
|
399
405
|
|
|
400
406
|
logger.debug(f"Checking for agent user_id={user_id} agent_id={agent_id}")
|
|
401
407
|
# TODO: consider disabling loading cached agents due to potential concurrency issues
|
|
402
408
|
letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
|
|
403
409
|
if not letta_agent:
|
|
404
410
|
logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
|
|
405
|
-
letta_agent = self._load_agent(
|
|
411
|
+
letta_agent = self._load_agent(agent_id=agent_id, actor=actor)
|
|
406
412
|
return letta_agent
|
|
407
413
|
|
|
408
414
|
def _step(
|
|
@@ -759,11 +765,12 @@ class SyncServer(Server):
|
|
|
759
765
|
def create_agent(
|
|
760
766
|
self,
|
|
761
767
|
request: CreateAgent,
|
|
762
|
-
|
|
768
|
+
actor: User,
|
|
763
769
|
# interface
|
|
764
770
|
interface: Union[AgentInterface, None] = None,
|
|
765
771
|
) -> AgentState:
|
|
766
772
|
"""Create a new agent using a config"""
|
|
773
|
+
user_id = actor.id
|
|
767
774
|
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
|
768
775
|
raise ValueError(f"User user_id={user_id} does not exist")
|
|
769
776
|
|
|
@@ -801,7 +808,7 @@ class SyncServer(Server):
|
|
|
801
808
|
tool_objs = []
|
|
802
809
|
if request.tools:
|
|
803
810
|
for tool_name in request.tools:
|
|
804
|
-
tool_obj = self.tool_manager.
|
|
811
|
+
tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
|
|
805
812
|
tool_objs.append(tool_obj)
|
|
806
813
|
|
|
807
814
|
assert request.memory is not None
|
|
@@ -822,9 +829,8 @@ class SyncServer(Server):
|
|
|
822
829
|
source_type=source_type,
|
|
823
830
|
tags=tags,
|
|
824
831
|
json_schema=json_schema,
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
)
|
|
832
|
+
),
|
|
833
|
+
actor=actor,
|
|
828
834
|
)
|
|
829
835
|
tool_objs.append(tool)
|
|
830
836
|
if not request.tools:
|
|
@@ -887,11 +893,14 @@ class SyncServer(Server):
|
|
|
887
893
|
def update_agent(
|
|
888
894
|
self,
|
|
889
895
|
request: UpdateAgentState,
|
|
890
|
-
|
|
896
|
+
actor: User,
|
|
891
897
|
):
|
|
892
898
|
"""Update the agents core memory block, return the new state"""
|
|
893
|
-
|
|
894
|
-
|
|
899
|
+
try:
|
|
900
|
+
self.user_manager.get_user_by_id(user_id=actor.id)
|
|
901
|
+
except Exception:
|
|
902
|
+
raise ValueError(f"User user_id={actor.id} does not exist")
|
|
903
|
+
|
|
895
904
|
if self.ms.get_agent(agent_id=request.id) is None:
|
|
896
905
|
raise ValueError(f"Agent agent_id={request.id} does not exist")
|
|
897
906
|
|
|
@@ -902,7 +911,7 @@ class SyncServer(Server):
|
|
|
902
911
|
if request.memory:
|
|
903
912
|
assert isinstance(request.memory, Memory), type(request.memory)
|
|
904
913
|
new_memory_contents = request.memory.to_flat_dict()
|
|
905
|
-
_ = self.update_agent_core_memory(user_id=
|
|
914
|
+
_ = self.update_agent_core_memory(user_id=actor.id, agent_id=request.id, new_memory_contents=new_memory_contents)
|
|
906
915
|
|
|
907
916
|
# update the system prompt
|
|
908
917
|
if request.system:
|
|
@@ -922,7 +931,7 @@ class SyncServer(Server):
|
|
|
922
931
|
# (1) get tools + make sure they exist
|
|
923
932
|
tool_objs = []
|
|
924
933
|
for tool_name in request.tools:
|
|
925
|
-
tool_obj = self.tool_manager.
|
|
934
|
+
tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
|
|
926
935
|
assert tool_obj, f"Tool {tool_name} does not exist"
|
|
927
936
|
tool_objs.append(tool_obj)
|
|
928
937
|
|
|
@@ -968,8 +977,11 @@ class SyncServer(Server):
|
|
|
968
977
|
user_id: str,
|
|
969
978
|
):
|
|
970
979
|
"""Add tools from an existing agent"""
|
|
971
|
-
|
|
980
|
+
try:
|
|
981
|
+
user = self.user_manager.get_user_by_id(user_id=user_id)
|
|
982
|
+
except NoResultFound:
|
|
972
983
|
raise ValueError(f"User user_id={user_id} does not exist")
|
|
984
|
+
|
|
973
985
|
if self.ms.get_agent(agent_id=agent_id) is None:
|
|
974
986
|
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
|
975
987
|
|
|
@@ -978,12 +990,12 @@ class SyncServer(Server):
|
|
|
978
990
|
|
|
979
991
|
# Get all the tool objects from the request
|
|
980
992
|
tool_objs = []
|
|
981
|
-
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id)
|
|
993
|
+
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=user)
|
|
982
994
|
assert tool_obj, f"Tool with id={tool_id} does not exist"
|
|
983
995
|
tool_objs.append(tool_obj)
|
|
984
996
|
|
|
985
997
|
for tool in letta_agent.tools:
|
|
986
|
-
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id)
|
|
998
|
+
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=user)
|
|
987
999
|
assert tool_obj, f"Tool with id={tool.id} does not exist"
|
|
988
1000
|
|
|
989
1001
|
# If it's not the already added tool
|
|
@@ -1007,8 +1019,11 @@ class SyncServer(Server):
|
|
|
1007
1019
|
user_id: str,
|
|
1008
1020
|
):
|
|
1009
1021
|
"""Remove tools from an existing agent"""
|
|
1010
|
-
|
|
1022
|
+
try:
|
|
1023
|
+
user = self.user_manager.get_user_by_id(user_id=user_id)
|
|
1024
|
+
except NoResultFound:
|
|
1011
1025
|
raise ValueError(f"User user_id={user_id} does not exist")
|
|
1026
|
+
|
|
1012
1027
|
if self.ms.get_agent(agent_id=agent_id) is None:
|
|
1013
1028
|
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
|
1014
1029
|
|
|
@@ -1018,7 +1033,7 @@ class SyncServer(Server):
|
|
|
1018
1033
|
# Get all the tool_objs
|
|
1019
1034
|
tool_objs = []
|
|
1020
1035
|
for tool in letta_agent.tools:
|
|
1021
|
-
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id)
|
|
1036
|
+
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=user)
|
|
1022
1037
|
assert tool_obj, f"Tool with id={tool.id} does not exist"
|
|
1023
1038
|
|
|
1024
1039
|
# If it's not the tool we want to remove
|
|
@@ -1733,7 +1748,7 @@ class SyncServer(Server):
|
|
|
1733
1748
|
|
|
1734
1749
|
return sources_with_metadata
|
|
1735
1750
|
|
|
1736
|
-
def add_default_external_tools(self,
|
|
1751
|
+
def add_default_external_tools(self, actor: User) -> bool:
|
|
1737
1752
|
"""Add default langchain tools. Return true if successful, false otherwise."""
|
|
1738
1753
|
success = True
|
|
1739
1754
|
tool_creates = ToolCreate.load_default_langchain_tools() + ToolCreate.load_default_crewai_tools()
|
|
@@ -1741,7 +1756,7 @@ class SyncServer(Server):
|
|
|
1741
1756
|
tool_creates += ToolCreate.load_default_composio_tools()
|
|
1742
1757
|
for tool_create in tool_creates:
|
|
1743
1758
|
try:
|
|
1744
|
-
self.tool_manager.create_or_update_tool(tool_create)
|
|
1759
|
+
self.tool_manager.create_or_update_tool(tool_create, actor=actor)
|
|
1745
1760
|
except Exception as e:
|
|
1746
1761
|
warnings.warn(f"An error occurred while creating tool {tool_create}: {e}")
|
|
1747
1762
|
warnings.warn(traceback.format_exc())
|
|
@@ -1843,6 +1858,16 @@ class SyncServer(Server):
|
|
|
1843
1858
|
except ValueError:
|
|
1844
1859
|
raise HTTPException(status_code=404, detail=f"User with id {user_id} not found")
|
|
1845
1860
|
|
|
1861
|
+
def get_organization_or_default(self, org_id: Optional[str]) -> Organization:
|
|
1862
|
+
"""Get the organization object for org_id if it exists, otherwise return the default organization object"""
|
|
1863
|
+
if org_id is None:
|
|
1864
|
+
org_id = self.organization_manager.DEFAULT_ORG_ID
|
|
1865
|
+
|
|
1866
|
+
try:
|
|
1867
|
+
return self.organization_manager.get_organization_by_id(org_id=org_id)
|
|
1868
|
+
except ValueError:
|
|
1869
|
+
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
|
|
1870
|
+
|
|
1846
1871
|
def list_llm_models(self) -> List[LLMConfig]:
|
|
1847
1872
|
"""List available models"""
|
|
1848
1873
|
|
letta/server/startup.sh
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
echo "Starting MEMGPT server..."
|
|
3
3
|
if [ "$MEMGPT_ENVIRONMENT" = "DEVELOPMENT" ] ; then
|
|
4
4
|
echo "Starting in development mode!"
|
|
5
|
-
uvicorn letta.server.rest_api.app:app --reload --reload-dir /letta --host 0.0.0.0 --port
|
|
5
|
+
uvicorn letta.server.rest_api.app:app --reload --reload-dir /letta --host 0.0.0.0 --port 8283
|
|
6
6
|
else
|
|
7
|
-
uvicorn letta.server.rest_api.app:app --host 0.0.0.0 --port
|
|
7
|
+
uvicorn letta.server.rest_api.app:app --host 0.0.0.0 --port 8283
|
|
8
8
|
fi
|
letta/services/tool_manager.py
CHANGED
|
@@ -9,9 +9,9 @@ from letta.functions.functions import derive_openai_json_schema, load_function_s
|
|
|
9
9
|
from letta.orm.errors import NoResultFound
|
|
10
10
|
from letta.orm.organization import Organization as OrganizationModel
|
|
11
11
|
from letta.orm.tool import Tool as ToolModel
|
|
12
|
-
from letta.orm.user import User as UserModel
|
|
13
12
|
from letta.schemas.tool import Tool as PydanticTool
|
|
14
13
|
from letta.schemas.tool import ToolCreate, ToolUpdate
|
|
14
|
+
from letta.schemas.user import User as PydanticUser
|
|
15
15
|
from letta.utils import enforce_types
|
|
16
16
|
|
|
17
17
|
|
|
@@ -25,7 +25,7 @@ class ToolManager:
|
|
|
25
25
|
self.session_maker = db_context
|
|
26
26
|
|
|
27
27
|
@enforce_types
|
|
28
|
-
def create_or_update_tool(self, tool_create: ToolCreate) -> PydanticTool:
|
|
28
|
+
def create_or_update_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
|
|
29
29
|
"""Create a new tool based on the ToolCreate schema."""
|
|
30
30
|
# Derive json_schema
|
|
31
31
|
derived_json_schema = tool_create.json_schema or derive_openai_json_schema(tool_create)
|
|
@@ -34,105 +34,72 @@ class ToolManager:
|
|
|
34
34
|
try:
|
|
35
35
|
# NOTE: We use the organization id here
|
|
36
36
|
# This is important, because even if it's a different user, adding the same tool to the org should not happen
|
|
37
|
-
tool = self.
|
|
37
|
+
tool = self.get_tool_by_name(tool_name=derived_name, actor=actor)
|
|
38
38
|
# Put to dict and remove fields that should not be reset
|
|
39
|
-
update_data = tool_create.model_dump(exclude={"
|
|
39
|
+
update_data = tool_create.model_dump(exclude={"module", "terminal"}, exclude_unset=True)
|
|
40
40
|
# Remove redundant update fields
|
|
41
41
|
update_data = {key: value for key, value in update_data.items() if getattr(tool, key) != value}
|
|
42
42
|
|
|
43
43
|
# If there's anything to update
|
|
44
44
|
if update_data:
|
|
45
|
-
self.update_tool_by_id(tool.id, ToolUpdate(**update_data))
|
|
45
|
+
self.update_tool_by_id(tool.id, ToolUpdate(**update_data), actor)
|
|
46
46
|
else:
|
|
47
47
|
warnings.warn(
|
|
48
|
-
f"`create_or_update_tool` was called with user_id={
|
|
48
|
+
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={tool_create.name}, but found existing tool with nothing to update."
|
|
49
49
|
)
|
|
50
50
|
except NoResultFound:
|
|
51
51
|
tool_create.json_schema = derived_json_schema
|
|
52
52
|
tool_create.name = derived_name
|
|
53
|
-
tool = self.create_tool(tool_create)
|
|
53
|
+
tool = self.create_tool(tool_create, actor=actor)
|
|
54
54
|
|
|
55
55
|
return tool
|
|
56
56
|
|
|
57
57
|
@enforce_types
|
|
58
|
-
def create_tool(self, tool_create: ToolCreate) -> PydanticTool:
|
|
58
|
+
def create_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
|
|
59
59
|
"""Create a new tool based on the ToolCreate schema."""
|
|
60
60
|
# Create the tool
|
|
61
61
|
with self.session_maker() as session:
|
|
62
|
-
# Include all fields except
|
|
62
|
+
# Include all fields except `terminal` (which is not part of ToolModel) at the moment
|
|
63
63
|
create_data = tool_create.model_dump(exclude={"terminal"})
|
|
64
|
-
tool = ToolModel(**create_data) # Unpack everything directly into ToolModel
|
|
65
|
-
tool.create(session)
|
|
64
|
+
tool = ToolModel(**create_data, organization_id=actor.organization_id) # Unpack everything directly into ToolModel
|
|
65
|
+
tool.create(session, actor=actor)
|
|
66
66
|
|
|
67
67
|
return tool.to_pydantic()
|
|
68
68
|
|
|
69
69
|
@enforce_types
|
|
70
|
-
def get_tool_by_id(self, tool_id: str) -> PydanticTool:
|
|
70
|
+
def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
|
|
71
71
|
"""Fetch a tool by its ID."""
|
|
72
72
|
with self.session_maker() as session:
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
return tool.to_pydantic()
|
|
78
|
-
except NoResultFound:
|
|
79
|
-
raise ValueError(f"Tool with id {tool_id} not found.")
|
|
80
|
-
|
|
81
|
-
@enforce_types
|
|
82
|
-
def get_tool_by_name_and_user_id(self, tool_name: str, user_id: str) -> PydanticTool:
|
|
83
|
-
"""Retrieve a tool by its name and organization_id."""
|
|
84
|
-
with self.session_maker() as session:
|
|
85
|
-
# Use the list method to apply filters
|
|
86
|
-
results = ToolModel.list(db_session=session, name=tool_name, _user_id=UserModel.get_uid_from_identifier(user_id))
|
|
87
|
-
|
|
88
|
-
# Ensure only one result is returned (since there is a unique constraint)
|
|
89
|
-
if not results:
|
|
90
|
-
raise NoResultFound(f"Tool with name {tool_name} and user_id {user_id} not found.")
|
|
91
|
-
|
|
92
|
-
if len(results) > 1:
|
|
93
|
-
raise RuntimeError(
|
|
94
|
-
f"Multiple tools with name {tool_name} and user_id {user_id} were found. This is a serious error, and means that our table does not have uniqueness constraints properly set up. Please reach out to the letta development team if you see this error."
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
# Return the single result
|
|
98
|
-
return results[0]
|
|
73
|
+
# Retrieve tool by id using the Tool model's read method
|
|
74
|
+
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
|
|
75
|
+
# Convert the SQLAlchemy Tool object to PydanticTool
|
|
76
|
+
return tool.to_pydantic()
|
|
99
77
|
|
|
100
78
|
@enforce_types
|
|
101
|
-
def
|
|
102
|
-
"""Retrieve a tool by its name and
|
|
79
|
+
def get_tool_by_name(self, tool_name: str, actor: PydanticUser):
|
|
80
|
+
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
|
|
103
81
|
with self.session_maker() as session:
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
db_session=session, name=tool_name, _organization_id=OrganizationModel.get_uid_from_identifier(organization_id)
|
|
107
|
-
)
|
|
108
|
-
|
|
109
|
-
# Ensure only one result is returned (since there is a unique constraint)
|
|
110
|
-
if not results:
|
|
111
|
-
raise NoResultFound(f"Tool with name {tool_name} and organization_id {organization_id} not found.")
|
|
112
|
-
|
|
113
|
-
if len(results) > 1:
|
|
114
|
-
raise RuntimeError(
|
|
115
|
-
f"Multiple tools with name {tool_name} and organization_id {organization_id} were found. This is a serious error, and means that our table does not have uniqueness constraints properly set up. Please reach out to the letta development team if you see this error."
|
|
116
|
-
)
|
|
117
|
-
|
|
118
|
-
# Return the single result
|
|
119
|
-
return results[0]
|
|
82
|
+
tool = ToolModel.read(db_session=session, name=tool_name, actor=actor)
|
|
83
|
+
return tool.to_pydantic()
|
|
120
84
|
|
|
121
85
|
@enforce_types
|
|
122
|
-
def
|
|
86
|
+
def list_tools(self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
|
|
123
87
|
"""List all tools with optional pagination using cursor and limit."""
|
|
124
88
|
with self.session_maker() as session:
|
|
125
89
|
tools = ToolModel.list(
|
|
126
|
-
db_session=session,
|
|
90
|
+
db_session=session,
|
|
91
|
+
cursor=cursor,
|
|
92
|
+
limit=limit,
|
|
93
|
+
_organization_id=OrganizationModel.get_uid_from_identifier(actor.organization_id),
|
|
127
94
|
)
|
|
128
95
|
return [tool.to_pydantic() for tool in tools]
|
|
129
96
|
|
|
130
97
|
@enforce_types
|
|
131
|
-
def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate) -> None:
|
|
98
|
+
def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> None:
|
|
132
99
|
"""Update a tool by its ID with the given ToolUpdate object."""
|
|
133
100
|
with self.session_maker() as session:
|
|
134
101
|
# Fetch the tool by ID
|
|
135
|
-
tool = ToolModel.read(db_session=session, identifier=tool_id)
|
|
102
|
+
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
|
|
136
103
|
|
|
137
104
|
# Update tool attributes with only the fields that were explicitly set
|
|
138
105
|
update_data = tool_update.model_dump(exclude_unset=True, exclude_none=True)
|
|
@@ -140,20 +107,20 @@ class ToolManager:
|
|
|
140
107
|
setattr(tool, key, value)
|
|
141
108
|
|
|
142
109
|
# Save the updated tool to the database
|
|
143
|
-
tool.update(db_session=session)
|
|
110
|
+
tool.update(db_session=session, actor=actor)
|
|
144
111
|
|
|
145
112
|
@enforce_types
|
|
146
|
-
def delete_tool_by_id(self, tool_id: str) -> None:
|
|
113
|
+
def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None:
|
|
147
114
|
"""Delete a tool by its ID."""
|
|
148
115
|
with self.session_maker() as session:
|
|
149
116
|
try:
|
|
150
117
|
tool = ToolModel.read(db_session=session, identifier=tool_id)
|
|
151
|
-
tool.delete(db_session=session)
|
|
118
|
+
tool.delete(db_session=session, actor=actor)
|
|
152
119
|
except NoResultFound:
|
|
153
120
|
raise ValueError(f"Tool with id {tool_id} not found.")
|
|
154
121
|
|
|
155
122
|
@enforce_types
|
|
156
|
-
def add_default_tools(self,
|
|
123
|
+
def add_default_tools(self, actor: PydanticUser, module_name="base"):
|
|
157
124
|
"""Add default tools in {module_name}.py"""
|
|
158
125
|
full_module_name = f"letta.functions.function_sets.{module_name}"
|
|
159
126
|
try:
|
|
@@ -187,7 +154,6 @@ class ToolManager:
|
|
|
187
154
|
module=schema["module"],
|
|
188
155
|
source_code=source_code,
|
|
189
156
|
json_schema=schema["json_schema"],
|
|
190
|
-
organization_id=org_id,
|
|
191
|
-
user_id=user_id,
|
|
192
157
|
),
|
|
158
|
+
actor=actor,
|
|
193
159
|
)
|
letta/services/user_manager.py
CHANGED
|
@@ -85,11 +85,8 @@ class UserManager:
|
|
|
85
85
|
def get_user_by_id(self, user_id: str) -> PydanticUser:
|
|
86
86
|
"""Fetch a user by ID."""
|
|
87
87
|
with self.session_maker() as session:
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
return user.to_pydantic()
|
|
91
|
-
except NoResultFound:
|
|
92
|
-
raise ValueError(f"User with id {user_id} not found.")
|
|
88
|
+
user = UserModel.read(db_session=session, identifier=user_id)
|
|
89
|
+
return user.to_pydantic()
|
|
93
90
|
|
|
94
91
|
@enforce_types
|
|
95
92
|
def get_default_user(self) -> PydanticUser:
|