letta-nightly 0.5.2.dev20241112104101__py3-none-any.whl → 0.5.2.dev20241113234401__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of letta-nightly might be problematic.
- letta/agent.py +4 -2
- letta/agent_store/db.py +2 -1
- letta/cli/cli.py +1 -0
- letta/client/client.py +19 -16
- letta/data_sources/connectors.py +5 -10
- letta/llm_api/google_ai.py +0 -2
- letta/llm_api/openai.py +1 -0
- letta/memory.py +10 -6
- letta/metadata.py +3 -165
- letta/orm/__init__.py +2 -0
- letta/orm/file.py +29 -0
- letta/orm/mixins.py +8 -0
- letta/orm/organization.py +4 -4
- letta/orm/source.py +51 -0
- letta/orm/tool.py +0 -1
- letta/orm/user.py +0 -2
- letta/providers.py +0 -1
- letta/schemas/file.py +5 -5
- letta/schemas/source.py +30 -24
- letta/server/rest_api/app.py +27 -0
- letta/server/rest_api/routers/v1/sources.py +15 -15
- letta/server/server.py +38 -65
- letta/services/organization_manager.py +12 -13
- letta/services/source_manager.py +145 -0
- {letta_nightly-0.5.2.dev20241112104101.dist-info → letta_nightly-0.5.2.dev20241113234401.dist-info}/METADATA +1 -1
- {letta_nightly-0.5.2.dev20241112104101.dist-info → letta_nightly-0.5.2.dev20241113234401.dist-info}/RECORD +29 -26
- {letta_nightly-0.5.2.dev20241112104101.dist-info → letta_nightly-0.5.2.dev20241113234401.dist-info}/LICENSE +0 -0
- {letta_nightly-0.5.2.dev20241112104101.dist-info → letta_nightly-0.5.2.dev20241113234401.dist-info}/WHEEL +0 -0
- {letta_nightly-0.5.2.dev20241112104101.dist-info → letta_nightly-0.5.2.dev20241113234401.dist-info}/entry_points.txt +0 -0
letta/orm/source.py
ADDED
@@ -0,0 +1,51 @@
+from typing import TYPE_CHECKING, List, Optional
+
+from sqlalchemy import JSON, TypeDecorator
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from letta.orm.mixins import OrganizationMixin
+from letta.orm.sqlalchemy_base import SqlalchemyBase
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.source import Source as PydanticSource
+
+if TYPE_CHECKING:
+    from letta.orm.organization import Organization
+
+
+class EmbeddingConfigColumn(TypeDecorator):
+    """Custom type for storing EmbeddingConfig as JSON"""
+
+    impl = JSON
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect):
+        return dialect.type_descriptor(JSON())
+
+    def process_bind_param(self, value, dialect):
+        if value:
+            # return vars(value)
+            if isinstance(value, EmbeddingConfig):
+                return value.model_dump()
+        return value
+
+    def process_result_value(self, value, dialect):
+        if value:
+            return EmbeddingConfig(**value)
+        return value
+
+
+class Source(SqlalchemyBase, OrganizationMixin):
+    """A source represents an embedded text passage"""
+
+    __tablename__ = "sources"
+    __pydantic_model__ = PydanticSource
+
+    name: Mapped[str] = mapped_column(doc="the name of the source, must be unique within the org", nullable=False)
+    description: Mapped[str] = mapped_column(nullable=True, doc="a human-readable description of the source")
+    embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, doc="Configuration settings for embedding.")
+    metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the source.")
+
+    # relationships
+    organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
+    files: Mapped[List["Source"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
+    # agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
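For orientation, a hedged sketch of how the new EmbeddingConfigColumn round-trips values between Python and the JSON column. Only the TypeDecorator hooks come from the diff above; the EmbeddingConfig field names are assumptions for illustration.

# Illustrative sketch; EmbeddingConfig field names are assumed, not taken from this diff.
from letta.orm.source import EmbeddingConfigColumn
from letta.schemas.embedding_config import EmbeddingConfig

col = EmbeddingConfigColumn()

# On write, process_bind_param dumps a Pydantic EmbeddingConfig to a plain dict for the JSON column.
config = EmbeddingConfig(
    embedding_endpoint_type="openai",          # assumed field
    embedding_model="text-embedding-ada-002",  # assumed field
    embedding_dim=1536,                        # assumed field
    embedding_chunk_size=300,                  # assumed field
)
stored = col.process_bind_param(config, dialect=None)    # -> dict

# On read, process_result_value rehydrates the stored dict back into an EmbeddingConfig.
loaded = col.process_result_value(stored, dialect=None)  # -> EmbeddingConfig
print(loaded == config)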
letta/orm/tool.py
CHANGED
@@ -28,7 +28,6 @@ class Tool(SqlalchemyBase, OrganizationMixin):
     # An organization should not have multiple tools with the same name
     __table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)

-    id: Mapped[str] = mapped_column(String, primary_key=True)
     name: Mapped[str] = mapped_column(doc="The display name of the tool.")
     description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
     tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
letta/orm/user.py
CHANGED
@@ -1,6 +1,5 @@
 from typing import TYPE_CHECKING

-from sqlalchemy import String
 from sqlalchemy.orm import Mapped, mapped_column, relationship

 from letta.orm.mixins import OrganizationMixin
@@ -17,7 +16,6 @@ class User(SqlalchemyBase, OrganizationMixin):
     __tablename__ = "users"
     __pydantic_model__ = PydanticUser

-    id: Mapped[str] = mapped_column(String, primary_key=True)
     name: Mapped[str] = mapped_column(nullable=False, doc="The display name of the user.")

     # relationships
letta/providers.py
CHANGED
letta/schemas/file.py
CHANGED
@@ -4,7 +4,6 @@ from typing import Optional
 from pydantic import Field

 from letta.schemas.letta_base import LettaBase
-from letta.utils import get_utc_time


 class FileMetadataBase(LettaBase):
@@ -17,7 +16,7 @@ class FileMetadata(FileMetadataBase):
     """Representation of a single FileMetadata"""

     id: str = FileMetadataBase.generate_id_field()
-
+    organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the document.")
     source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
     file_name: Optional[str] = Field(None, description="The name of the file.")
     file_path: Optional[str] = Field(None, description="The path to the file.")
@@ -25,7 +24,8 @@ class FileMetadata(FileMetadataBase):
     file_size: Optional[int] = Field(None, description="The size of the file in bytes.")
     file_creation_date: Optional[str] = Field(None, description="The creation date of the file.")
     file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.")
-    created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of this file metadata object.")

-
-
+    # orm metadata, optional fields
+    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.")
+    updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The update date of the file.")
+    is_deleted: bool = Field(False, description="Whether this file is deleted or not.")
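A minimal sketch of the reworked FileMetadata schema: organization_id is now a first-class optional field, and created_at/updated_at/is_deleted are ORM-style metadata with defaults. The id strings below are placeholders.

from letta.schemas.file import FileMetadata

meta = FileMetadata(
    source_id="source-...",      # required: the owning source (placeholder id)
    organization_id="org-...",   # new optional field (placeholder id)
    file_name="notes.txt",
    file_size=2048,
)
print(meta.created_at, meta.updated_at, meta.is_deleted)  # defaults: utcnow, utcnow, False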
letta/schemas/source.py
CHANGED
@@ -1,12 +1,10 @@
 from datetime import datetime
 from typing import Optional

-from
-from pydantic import BaseModel, Field
+from pydantic import Field

 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.letta_base import LettaBase
-from letta.utils import get_utc_time


 class BaseSource(LettaBase):
@@ -15,15 +13,6 @@ class BaseSource(LettaBase):
     """

     __id_prefix__ = "source"
-    description: Optional[str] = Field(None, description="The description of the source.")
-    embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")
-    # NOTE: .metadata is a reserved attribute on SQLModel
-    metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
-
-
-class SourceCreate(BaseSource):
-    name: str = Field(..., description="The name of the source.")
-    description: Optional[str] = Field(None, description="The description of the source.")


 class Source(BaseSource):
@@ -34,7 +23,6 @@ class Source(BaseSource):
         id (str): The ID of the source
         name (str): The name of the source.
         embedding_config (EmbeddingConfig): The embedding configuration used by the source.
-        created_at (datetime): The creation date of the source.
         user_id (str): The ID of the user that created the source.
         metadata_ (dict): Metadata associated with the source.
         description (str): The description of the source.
@@ -42,21 +30,39 @@

     id: str = BaseSource.generate_id_field()
     name: str = Field(..., description="The name of the source.")
+    description: Optional[str] = Field(None, description="The description of the source.")
     embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.")
-
-
+    organization_id: Optional[str] = Field(None, description="The ID of the organization that created the source.")
+    metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")

+    # metadata fields
+    created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
+    last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
+    created_at: Optional[datetime] = Field(None, description="The timestamp when the source was created.")
+    updated_at: Optional[datetime] = Field(None, description="The timestamp when the source was last updated.")

-class SourceUpdate(BaseSource):
-    id: str = Field(..., description="The ID of the source.")
-    name: Optional[str] = Field(None, description="The name of the source.")

+class SourceCreate(BaseSource):
+    """
+    Schema for creating a new Source.
+    """
+
+    # required
+    name: str = Field(..., description="The name of the source.")
+    # TODO: @matt, make this required after shub makes the FE changes
+    embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.")
+
+    # optional
+    description: Optional[str] = Field(None, description="The description of the source.")
+    metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")

-class UploadFileToSourceRequest(BaseModel):
-    file: UploadFile = Field(..., description="The file to upload.")

+class SourceUpdate(BaseSource):
+    """
+    Schema for updating an existing Source.
+    """

-
-
-
-
+    name: Optional[str] = Field(None, description="The name of the source.")
+    description: Optional[str] = Field(None, description="The description of the source.")
+    metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
+    embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.")
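A minimal sketch of the reorganized request schemas: SourceCreate now only requires a name (embedding_config stays optional for the moment, per the TODO above), while SourceUpdate is fully partial, with every field optional.

from letta.schemas.source import SourceCreate, SourceUpdate

create_req = SourceCreate(name="my-docs", description="Design documents")
update_req = SourceUpdate(description="Docs for the v2 design", metadata_={"team": "platform"})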
letta/server/rest_api/app.py
CHANGED
@@ -6,6 +6,8 @@ from typing import Optional

 import uvicorn
 from fastapi import FastAPI
+from fastapi.responses import JSONResponse
+from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.cors import CORSMiddleware

 from letta.__init__ import __version__
@@ -94,6 +96,27 @@ def generate_openapi_schema(app: FastAPI):
        Path(f"openapi_{name}.json").write_text(json.dumps(docs, indent=2))


+# middleware that only allows requests to pass through if user provides a password thats randomly generated and stored in memory
+def generate_password():
+    import secrets
+
+    return secrets.token_urlsafe(16)
+
+
+random_password = generate_password()
+
+
+class CheckPasswordMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request, call_next):
+        if request.headers.get("X-BARE-PASSWORD") == f"password {random_password}":
+            return await call_next(request)
+
+        return JSONResponse(
+            content={"detail": "Unauthorized"},
+            status_code=401,
+        )
+
+
 def create_application() -> "FastAPI":
     """the application start routine"""
     # global server
@@ -113,6 +136,10 @@ def create_application() -> "FastAPI":
         settings.cors_origins.append("https://app.letta.com")
         print(f"▶ View using ADE at: https://app.letta.com/local-project/agents")

+    if "--secure" in sys.argv:
+        print(f"▶ Using secure mode with password: {random_password}")
+        app.add_middleware(CheckPasswordMiddleware)
+
     app.add_middleware(
         CORSMiddleware,
         allow_origins=settings.cors_origins,
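A hedged sketch of calling a server started with --secure: CheckPasswordMiddleware rejects any request whose X-BARE-PASSWORD header is not exactly "password <token>", where <token> is printed at startup. The base URL, port, and route below are assumptions.

import requests

BASE_URL = "http://localhost:8283"  # assumed local address and port
TOKEN = "<token printed as 'Using secure mode with password: ...'>"

resp = requests.get(
    f"{BASE_URL}/v1/sources/",  # assumed mount point for the sources router
    headers={"X-BARE-PASSWORD": f"password {TOKEN}"},
)
print(resp.status_code)  # 401 Unauthorized if the header is missing or wrong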
letta/server/rest_api/routers/v1/sources.py
CHANGED
@@ -36,7 +36,7 @@ def get_source(
     """
     actor = server.get_user_or_default(user_id=user_id)

-    return server.
+    return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)


 @router.get("/name/{source_name}", response_model=str, operation_id="get_source_id_by_name")
@@ -50,8 +50,8 @@ def get_source_id_by_name(
     """
     actor = server.get_user_or_default(user_id=user_id)

-
-    return
+    source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
+    return source.id


 @router.get("/", response_model=List[Source], operation_id="list_sources")
@@ -64,12 +64,12 @@ def list_sources(
     """
     actor = server.get_user_or_default(user_id=user_id)

-    return server.list_all_sources(
+    return server.list_all_sources(actor=actor)


 @router.post("/", response_model=Source, operation_id="create_source")
 def create_source(
-
+    source_create: SourceCreate,
     server: "SyncServer" = Depends(get_letta_server),
     user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
 ):
@@ -77,8 +77,9 @@ def create_source(
     Create a new data source.
     """
     actor = server.get_user_or_default(user_id=user_id)
+    source = Source(**source_create.model_dump())

-    return server.create_source(
+    return server.source_manager.create_source(source=source, actor=actor)


 @router.patch("/{source_id}", response_model=Source, operation_id="update_source")
@@ -92,10 +93,7 @@ def update_source(
     Update the name or documentation of an existing data source.
     """
     actor = server.get_user_or_default(user_id=user_id)
-
-    assert source.id == source_id, "Source ID in path must match ID in request body"
-
-    return server.update_source(request=source, user_id=actor.id)
+    return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)


 @router.delete("/{source_id}", response_model=None, operation_id="delete_source")
@@ -109,7 +107,7 @@ def delete_source(
     """
     actor = server.get_user_or_default(user_id=user_id)

-    server.delete_source(source_id=source_id,
+    server.delete_source(source_id=source_id, actor=actor)


 @router.post("/{source_id}/attach", response_model=Source, operation_id="attach_agent_to_source")
@@ -124,7 +122,7 @@ def attach_source_to_agent(
     """
     actor = server.get_user_or_default(user_id=user_id)

-    source = server.
+    source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
     assert source is not None, f"Source with id={source_id} not found."
     source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id)
     return source
@@ -158,7 +156,7 @@ def upload_file_to_source(
     """
     actor = server.get_user_or_default(user_id=user_id)

-    source = server.
+    source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
     assert source is not None, f"Source with id={source_id} not found."
     bytes = file.file.read()

@@ -200,11 +198,13 @@ def list_files_from_source(
     limit: int = Query(1000, description="Number of files to return"),
     cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
 ):
     """
     List paginated files associated with a data source.
     """
-
+    actor = server.get_user_or_default(user_id=user_id)
+    return server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=actor)


 # it's redundant to include /delete in the URL path. The HTTP verb DELETE already implies that action.
@@ -221,7 +221,7 @@ def delete_file_from_source(
     """
     actor = server.get_user_or_default(user_id=user_id)

-    deleted_file = server.
+    deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
     if deleted_file is None:
         raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")

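A hedged sketch of the reworked /name/{source_name} route: it now resolves the source through SourceManager and returns just the id string. The base URL and mount path are assumptions; the user_id header value is a placeholder, and omitting the header falls back to the default user.

import requests

resp = requests.get(
    "http://localhost:8283/v1/sources/name/my-docs",
    headers={"user_id": "user-..."},  # placeholder id; omit to use the default user
)
source_id = resp.json()  # a plain string id, per response_model=str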
letta/server/server.py
CHANGED
@@ -65,7 +65,6 @@ from letta.schemas.embedding_config import EmbeddingConfig

 # openai schemas
 from letta.schemas.enums import JobStatus
-from letta.schemas.file import FileMetadata
 from letta.schemas.job import Job
 from letta.schemas.letta_message import LettaMessage
 from letta.schemas.llm_config import LLMConfig
@@ -78,12 +77,13 @@ from letta.schemas.memory import (
 from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
 from letta.schemas.organization import Organization
 from letta.schemas.passage import Passage
-from letta.schemas.source import Source
+from letta.schemas.source import Source
 from letta.schemas.tool import Tool, ToolCreate
 from letta.schemas.usage import LettaUsageStatistics
 from letta.schemas.user import User
 from letta.services.agents_tags_manager import AgentsTagsManager
 from letta.services.organization_manager import OrganizationManager
+from letta.services.source_manager import SourceManager
 from letta.services.tool_manager import ToolManager
 from letta.services.user_manager import UserManager
 from letta.utils import create_random_username, json_dumps, json_loads
@@ -249,6 +249,7 @@ class SyncServer(Server):
         self.organization_manager = OrganizationManager()
         self.user_manager = UserManager()
         self.tool_manager = ToolManager()
+        self.source_manager = SourceManager()
         self.agents_tags_manager = AgentsTagsManager()

         # Make default user and org
@@ -1511,12 +1512,16 @@
         if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
             raise ValueError(f"Agent agent_id={agent_id} does not exist")

-        # Verify that the agent exists and
+        # Verify that the agent exists and belongs to the org of the user
         agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
         if not agent_state:
             raise ValueError(f"Could not find agent_id={agent_id} under user_id={user_id}")
-
-
+
+        agent_state_user = self.user_manager.get_user_by_id(user_id=agent_state.user_id)
+        if agent_state_user.organization_id != actor.organization_id:
+            raise ValueError(
+                f"Could not authorize agent_id={agent_id} with user_id={user_id} because of differing organizations; agent_id was created in {agent_state_user.organization_id} while user belongs to {actor.organization_id}. How did they get the agent id?"
+            )

         # First, if the agent is in the in-memory cache we should remove it
         # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
@@ -1560,44 +1565,12 @@
         self.ms.delete_api_key(api_key=api_key)
         return api_key_obj

-    def
-        """Create a new data source"""
-        source = Source(
-            name=request.name,
-            user_id=user_id,
-            embedding_config=self.list_embedding_models()[0],  # TODO: require providing this
-        )
-        self.ms.create_source(source)
-        assert self.ms.get_source(source_name=request.name, user_id=user_id) is not None, f"Failed to create source {request.name}"
-        return source
-
-    def update_source(self, request: SourceUpdate, user_id: str) -> Source:
-        """Update an existing data source"""
-        if not request.id:
-            existing_source = self.ms.get_source(source_name=request.name, user_id=user_id)
-        else:
-            existing_source = self.ms.get_source(source_id=request.id)
-        if not existing_source:
-            raise ValueError("Source does not exist")
-
-        # override updated fields
-        if request.name:
-            existing_source.name = request.name
-        if request.metadata_:
-            existing_source.metadata_ = request.metadata_
-        if request.description:
-            existing_source.description = request.description
-
-        self.ms.update_source(existing_source)
-        return existing_source
-
-    def delete_source(self, source_id: str, user_id: str):
+    def delete_source(self, source_id: str, actor: User):
         """Delete a data source"""
-
-        self.ms.delete_source(source_id)
+        self.source_manager.delete_source(source_id=source_id, actor=actor)

         # delete data from passage store
-        passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=
+        passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=actor.id)
         passage_store.delete({"source_id": source_id})

         # TODO: delete data from agent passage stores (?)
@@ -1639,9 +1612,9 @@
         # try:
         from letta.data_sources.connectors import DirectoryConnector

-        source = self.
+        source = self.source_manager.get_source_by_id(source_id=source_id)
         connector = DirectoryConnector(input_files=[file_path])
-        num_passages, num_documents = self.load_data(user_id=source.
+        num_passages, num_documents = self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)
         # except Exception as e:
         #     # job failed with error
         #     error = str(e)
@@ -1662,9 +1635,6 @@

         return job

-    def delete_file_from_source(self, source_id: str, file_id: str, user_id: Optional[str]) -> Optional[FileMetadata]:
-        return self.ms.delete_file_from_source(source_id=source_id, file_id=file_id, user_id=user_id)
-
     def load_data(
         self,
         user_id: str,
@@ -1675,16 +1645,16 @@
         # TODO: this should be implemented as a batch job or at least async, since it may take a long time

         # load data from a data source into the document store
-
+        user = self.user_manager.get_user_by_id(user_id=user_id)
+        source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
         if source is None:
             raise ValueError(f"Data source {source_name} does not exist for user {user_id}")

         # get the data connectors
         passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
-        file_store = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id)

         # load data into the document store
-        passage_count, document_count = load_data(connector, source, passage_store,
+        passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user)
         return passage_count, document_count

     def attach_source_to_agent(
@@ -1696,10 +1666,13 @@
         source_name: Optional[str] = None,
     ) -> Source:
         # attach a data source to an agent
-
-        if
-
-
+        user = self.user_manager.get_user_by_id(user_id=user_id)
+        if source_id:
+            data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=user)
+        elif source_name:
+            data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
+        else:
+            raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
         # get connection to data source storage
         source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)

@@ -1719,12 +1692,14 @@
         source_id: Optional[str] = None,
         source_name: Optional[str] = None,
     ) -> Source:
-
-
-        source = self.
-
+        user = self.user_manager.get_user_by_id(user_id=user_id)
+        if source_id:
+            source = self.source_manager.get_source_by_id(source_id=source_id, actor=user)
+        elif source_name:
+            source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
         else:
-
+            raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
+        source_id = source.id

         # delete all Passage objects with source_id==source_id from agent's archival memory
         agent = self._get_or_load_agent(agent_id=agent_id)
@@ -1739,27 +1714,25 @@

     def list_attached_sources(self, agent_id: str) -> List[Source]:
         # list all attached sources to an agent
-
+        source_ids = self.ms.list_attached_source_ids(agent_id)

-
-        # list all attached sources to an agent
-        return self.ms.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
+        return [self.source_manager.get_source_by_id(source_id=id) for id in source_ids]

     def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
         warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
         return []

-    def list_all_sources(self,
+    def list_all_sources(self, actor: User) -> List[Source]:
         """List all sources (w/ extra metadata) belonging to a user"""

-        sources = self.
+        sources = self.source_manager.list_sources(actor=actor)

         # Add extra metadata to the sources
         sources_with_metadata = []
         for source in sources:

             # count number of passages
-            passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=
+            passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=actor.id)
             num_passages = passage_conn.size({"source_id": source.id})

             # TODO: add when files table implemented
@@ -1773,7 +1746,7 @@
             attached_agents = [
                 {
                     "id": str(a_id),
-                    "name": self.ms.get_agent(user_id=
+                    "name": self.ms.get_agent(user_id=actor.id, agent_id=a_id).name,
                 }
                 for a_id in agent_ids
             ]
letta/services/organization_manager.py
CHANGED
@@ -27,18 +27,26 @@ class OrganizationManager:
         return self.get_organization_by_id(self.DEFAULT_ORG_ID)

     @enforce_types
-    def get_organization_by_id(self, org_id: str) -> PydanticOrganization:
+    def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]:
         """Fetch an organization by ID."""
         with self.session_maker() as session:
             try:
                 organization = OrganizationModel.read(db_session=session, identifier=org_id)
                 return organization.to_pydantic()
             except NoResultFound:
-
+                return None

     @enforce_types
     def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
         """Create a new organization. If a name is provided, it is used, otherwise, a random one is generated."""
+        org = self.get_organization_by_id(pydantic_org.id)
+        if org:
+            return org
+        else:
+            return self._create_organization(pydantic_org=pydantic_org)
+
+    @enforce_types
+    def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
         with self.session_maker() as session:
             org = OrganizationModel(**pydantic_org.model_dump())
             org.create(session)
@@ -47,16 +55,7 @@
     @enforce_types
     def create_default_organization(self) -> PydanticOrganization:
         """Create the default organization."""
-
-        # Try to get it first
-        try:
-            org = OrganizationModel.read(db_session=session, identifier=self.DEFAULT_ORG_ID)
-        # If it doesn't exist, make it
-        except NoResultFound:
-            org = OrganizationModel(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)
-            org.create(session)
-
-        return org.to_pydantic()
+        return self.create_organization(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))

     @enforce_types
     def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
@@ -73,7 +72,7 @@
         """Delete an organization by marking it as deleted."""
         with self.session_maker() as session:
             organization = OrganizationModel.read(db_session=session, identifier=org_id)
-            organization.
+            organization.hard_delete(session)

     @enforce_types
     def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: