letta-nightly 0.5.2.dev20241112104101__py3-none-any.whl → 0.5.2.dev20241113234401__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic. Click here for more details.

letta/orm/source.py ADDED
@@ -0,0 +1,51 @@
1
+ from typing import TYPE_CHECKING, List, Optional
2
+
3
+ from sqlalchemy import JSON, TypeDecorator
4
+ from sqlalchemy.orm import Mapped, mapped_column, relationship
5
+
6
+ from letta.orm.mixins import OrganizationMixin
7
+ from letta.orm.sqlalchemy_base import SqlalchemyBase
8
+ from letta.schemas.embedding_config import EmbeddingConfig
9
+ from letta.schemas.source import Source as PydanticSource
10
+
11
+ if TYPE_CHECKING:
12
+ from letta.orm.organization import Organization
13
+
14
+
class EmbeddingConfigColumn(TypeDecorator):
    """SQLAlchemy column type that persists an ``EmbeddingConfig`` as JSON.

    Values are converted to plain data on the way into the database and
    rebuilt into ``EmbeddingConfig`` objects on the way out. Falsy values
    (``None``, empty dict) pass through untouched in both directions.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the dialect-specific implementation of the JSON type."""
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Convert an ``EmbeddingConfig`` to a dict before it is bound.

        Anything that is falsy or not an ``EmbeddingConfig`` (for example an
        already-serialized dict) is handed to the JSON type unchanged.
        """
        if value and isinstance(value, EmbeddingConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        """Rebuild an ``EmbeddingConfig`` from the stored JSON mapping."""
        return EmbeddingConfig(**value) if value else value
35
+
36
+
class Source(SqlalchemyBase, OrganizationMixin):
    """ORM model for a row of the ``sources`` table.

    A source belongs to one organization and owns a collection of file
    metadata rows; removing a source cascades to its files
    (``cascade="all, delete-orphan"``).
    """

    __tablename__ = "sources"
    __pydantic_model__ = PydanticSource

    # NOTE(review): the column doc says the name must be unique within the org,
    # but no UniqueConstraint is declared on this model - confirm it is enforced elsewhere.
    name: Mapped[str] = mapped_column(doc="the name of the source, must be unique within the org", nullable=False)
    # The column is nullable, so the Python-side type is Optional (matches ``metadata_`` below).
    description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="a human-readable description of the source")
    embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, doc="Configuration settings for embedding.")
    metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the source.")

    # relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
    # The collection holds FileMetadata rows (the relationship target), not Source rows.
    # The string forward reference mirrors the explicit relationship() target;
    # static checkers will presumably also want a TYPE_CHECKING import of FileMetadata - TODO confirm module path.
    files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
    # agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
letta/orm/tool.py CHANGED
@@ -28,7 +28,6 @@ class Tool(SqlalchemyBase, OrganizationMixin):
28
28
  # An organization should not have multiple tools with the same name
29
29
  __table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)
30
30
 
31
- id: Mapped[str] = mapped_column(String, primary_key=True)
32
31
  name: Mapped[str] = mapped_column(doc="The display name of the tool.")
33
32
  description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
34
33
  tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
letta/orm/user.py CHANGED
@@ -1,6 +1,5 @@
1
1
  from typing import TYPE_CHECKING
2
2
 
3
- from sqlalchemy import String
4
3
  from sqlalchemy.orm import Mapped, mapped_column, relationship
5
4
 
6
5
  from letta.orm.mixins import OrganizationMixin
@@ -17,7 +16,6 @@ class User(SqlalchemyBase, OrganizationMixin):
17
16
  __tablename__ = "users"
18
17
  __pydantic_model__ = PydanticUser
19
18
 
20
- id: Mapped[str] = mapped_column(String, primary_key=True)
21
19
  name: Mapped[str] = mapped_column(nullable=False, doc="The display name of the user.")
22
20
 
23
21
  # relationships
letta/providers.py CHANGED
@@ -462,7 +462,6 @@ class VLLMChatCompletionsProvider(Provider):
462
462
  response = openai_get_model_list(self.base_url, api_key=None)
463
463
 
464
464
  configs = []
465
- print(response)
466
465
  for model in response["data"]:
467
466
  configs.append(
468
467
  LLMConfig(
letta/schemas/file.py CHANGED
@@ -4,7 +4,6 @@ from typing import Optional
4
4
  from pydantic import Field
5
5
 
6
6
  from letta.schemas.letta_base import LettaBase
7
- from letta.utils import get_utc_time
8
7
 
9
8
 
10
9
  class FileMetadataBase(LettaBase):
@@ -17,7 +16,7 @@ class FileMetadata(FileMetadataBase):
17
16
  """Representation of a single FileMetadata"""
18
17
 
19
18
  id: str = FileMetadataBase.generate_id_field()
20
- user_id: str = Field(description="The unique identifier of the user associated with the document.")
19
+ organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the document.")
21
20
  source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
22
21
  file_name: Optional[str] = Field(None, description="The name of the file.")
23
22
  file_path: Optional[str] = Field(None, description="The path to the file.")
@@ -25,7 +24,8 @@ class FileMetadata(FileMetadataBase):
25
24
  file_size: Optional[int] = Field(None, description="The size of the file in bytes.")
26
25
  file_creation_date: Optional[str] = Field(None, description="The creation date of the file.")
27
26
  file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.")
28
- created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of this file metadata object.")
29
27
 
30
- class Config:
31
- extra = "allow"
28
+ # orm metadata, optional fields
29
+ created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.")
30
+ updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The update date of the file.")
31
+ is_deleted: bool = Field(False, description="Whether this file is deleted or not.")
letta/schemas/source.py CHANGED
@@ -1,12 +1,10 @@
1
1
  from datetime import datetime
2
2
  from typing import Optional
3
3
 
4
- from fastapi import UploadFile
5
- from pydantic import BaseModel, Field
4
+ from pydantic import Field
6
5
 
7
6
  from letta.schemas.embedding_config import EmbeddingConfig
8
7
  from letta.schemas.letta_base import LettaBase
9
- from letta.utils import get_utc_time
10
8
 
11
9
 
12
10
  class BaseSource(LettaBase):
@@ -15,15 +13,6 @@ class BaseSource(LettaBase):
15
13
  """
16
14
 
17
15
  __id_prefix__ = "source"
18
- description: Optional[str] = Field(None, description="The description of the source.")
19
- embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")
20
- # NOTE: .metadata is a reserved attribute on SQLModel
21
- metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
22
-
23
-
24
- class SourceCreate(BaseSource):
25
- name: str = Field(..., description="The name of the source.")
26
- description: Optional[str] = Field(None, description="The description of the source.")
27
16
 
28
17
 
29
18
  class Source(BaseSource):
@@ -34,7 +23,6 @@ class Source(BaseSource):
34
23
  id (str): The ID of the source
35
24
  name (str): The name of the source.
36
25
  embedding_config (EmbeddingConfig): The embedding configuration used by the source.
37
- created_at (datetime): The creation date of the source.
38
26
  user_id (str): The ID of the user that created the source.
39
27
  metadata_ (dict): Metadata associated with the source.
40
28
  description (str): The description of the source.
@@ -42,21 +30,39 @@ class Source(BaseSource):
42
30
 
43
31
  id: str = BaseSource.generate_id_field()
44
32
  name: str = Field(..., description="The name of the source.")
33
+ description: Optional[str] = Field(None, description="The description of the source.")
45
34
  embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.")
46
- created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the source.")
47
- user_id: str = Field(..., description="The ID of the user that created the source.")
35
+ organization_id: Optional[str] = Field(None, description="The ID of the organization that created the source.")
36
+ metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
48
37
 
38
+ # metadata fields
39
+ created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
40
+ last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
41
+ created_at: Optional[datetime] = Field(None, description="The timestamp when the source was created.")
42
+ updated_at: Optional[datetime] = Field(None, description="The timestamp when the source was last updated.")
49
43
 
50
- class SourceUpdate(BaseSource):
51
- id: str = Field(..., description="The ID of the source.")
52
- name: Optional[str] = Field(None, description="The name of the source.")
53
44
 
45
+ class SourceCreate(BaseSource):
46
+ """
47
+ Schema for creating a new Source.
48
+ """
49
+
50
+ # required
51
+ name: str = Field(..., description="The name of the source.")
52
+ # TODO: @matt, make this required after shub makes the FE changes
53
+ embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.")
54
+
55
+ # optional
56
+ description: Optional[str] = Field(None, description="The description of the source.")
57
+ metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
54
58
 
55
- class UploadFileToSourceRequest(BaseModel):
56
- file: UploadFile = Field(..., description="The file to upload.")
57
59
 
60
+ class SourceUpdate(BaseSource):
61
+ """
62
+ Schema for updating an existing Source.
63
+ """
58
64
 
59
- class UploadFileToSourceResponse(BaseModel):
60
- source: Source = Field(..., description="The source the file was uploaded to.")
61
- added_passages: int = Field(..., description="The number of passages added to the source.")
62
- added_documents: int = Field(..., description="The number of files added to the source.")
65
+ name: Optional[str] = Field(None, description="The name of the source.")
66
+ description: Optional[str] = Field(None, description="The description of the source.")
67
+ metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
68
+ embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.")
@@ -6,6 +6,8 @@ from typing import Optional
6
6
 
7
7
  import uvicorn
8
8
  from fastapi import FastAPI
9
+ from fastapi.responses import JSONResponse
10
+ from starlette.middleware.base import BaseHTTPMiddleware
9
11
  from starlette.middleware.cors import CORSMiddleware
10
12
 
11
13
  from letta.__init__ import __version__
@@ -94,6 +96,27 @@ def generate_openapi_schema(app: FastAPI):
94
96
  Path(f"openapi_{name}.json").write_text(json.dumps(docs, indent=2))
95
97
 
96
98
 
99
+ # middleware that only lets a request through if it supplies the password, which is randomly generated at startup and kept in memory
def generate_password(nbytes: int = 16) -> str:
    """Return a random URL-safe text password.

    Args:
        nbytes: Number of random bytes used to build the token. The default
            of 16 yields a 22-character string.

    Returns:
        A URL-safe string from ``secrets.token_urlsafe``.
    """
    # Imported locally, as in the original, so the block does not depend on
    # the module's top-level import list.
    import secrets

    return secrets.token_urlsafe(nbytes)
104
+
105
+
106
+ random_password = generate_password()
107
+
108
+
class CheckPasswordMiddleware(BaseHTTPMiddleware):
    """Reject every request that does not carry the in-memory server password.

    A request is forwarded only when its ``X-BARE-PASSWORD`` header equals
    ``"password <random_password>"``; anything else gets a 401 JSON response.
    """

    async def dispatch(self, request, call_next):
        """Forward the request if the password header matches, else answer 401.

        Args:
            request: The incoming request; only its headers are inspected.
            call_next: Callable that passes the request to the next handler.

        Returns:
            The downstream response on a match, otherwise a ``JSONResponse``
            with status 401 and body ``{"detail": "Unauthorized"}``.
        """
        import hmac

        expected = f"password {random_password}"
        # A missing header becomes "", which can never equal ``expected``.
        supplied = request.headers.get("X-BARE-PASSWORD") or ""

        # Constant-time comparison so response timing does not leak how much of
        # the secret matched. Encode to bytes because compare_digest rejects
        # non-ASCII str input.
        if hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8")):
            return await call_next(request)

        return JSONResponse(
            content={"detail": "Unauthorized"},
            status_code=401,
        )
118
+
119
+
97
120
  def create_application() -> "FastAPI":
98
121
  """the application start routine"""
99
122
  # global server
@@ -113,6 +136,10 @@ def create_application() -> "FastAPI":
113
136
  settings.cors_origins.append("https://app.letta.com")
114
137
  print(f"▶ View using ADE at: https://app.letta.com/local-project/agents")
115
138
 
139
+ if "--secure" in sys.argv:
140
+ print(f"▶ Using secure mode with password: {random_password}")
141
+ app.add_middleware(CheckPasswordMiddleware)
142
+
116
143
  app.add_middleware(
117
144
  CORSMiddleware,
118
145
  allow_origins=settings.cors_origins,
@@ -36,7 +36,7 @@ def get_source(
36
36
  """
37
37
  actor = server.get_user_or_default(user_id=user_id)
38
38
 
39
- return server.get_source(source_id=source_id, user_id=actor.id)
39
+ return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
40
40
 
41
41
 
42
42
  @router.get("/name/{source_name}", response_model=str, operation_id="get_source_id_by_name")
@@ -50,8 +50,8 @@ def get_source_id_by_name(
50
50
  """
51
51
  actor = server.get_user_or_default(user_id=user_id)
52
52
 
53
- source_id = server.get_source_id(source_name=source_name, user_id=actor.id)
54
- return source_id
53
+ source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
54
+ return source.id
55
55
 
56
56
 
57
57
  @router.get("/", response_model=List[Source], operation_id="list_sources")
@@ -64,12 +64,12 @@ def list_sources(
64
64
  """
65
65
  actor = server.get_user_or_default(user_id=user_id)
66
66
 
67
- return server.list_all_sources(user_id=actor.id)
67
+ return server.list_all_sources(actor=actor)
68
68
 
69
69
 
70
70
  @router.post("/", response_model=Source, operation_id="create_source")
71
71
  def create_source(
72
- source: SourceCreate,
72
+ source_create: SourceCreate,
73
73
  server: "SyncServer" = Depends(get_letta_server),
74
74
  user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
75
75
  ):
@@ -77,8 +77,9 @@ def create_source(
77
77
  Create a new data source.
78
78
  """
79
79
  actor = server.get_user_or_default(user_id=user_id)
80
+ source = Source(**source_create.model_dump())
80
81
 
81
- return server.create_source(request=source, user_id=actor.id)
82
+ return server.source_manager.create_source(source=source, actor=actor)
82
83
 
83
84
 
84
85
  @router.patch("/{source_id}", response_model=Source, operation_id="update_source")
@@ -92,10 +93,7 @@ def update_source(
92
93
  Update the name or documentation of an existing data source.
93
94
  """
94
95
  actor = server.get_user_or_default(user_id=user_id)
95
-
96
- assert source.id == source_id, "Source ID in path must match ID in request body"
97
-
98
- return server.update_source(request=source, user_id=actor.id)
96
+ return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
99
97
 
100
98
 
101
99
  @router.delete("/{source_id}", response_model=None, operation_id="delete_source")
@@ -109,7 +107,7 @@ def delete_source(
109
107
  """
110
108
  actor = server.get_user_or_default(user_id=user_id)
111
109
 
112
- server.delete_source(source_id=source_id, user_id=actor.id)
110
+ server.delete_source(source_id=source_id, actor=actor)
113
111
 
114
112
 
115
113
  @router.post("/{source_id}/attach", response_model=Source, operation_id="attach_agent_to_source")
@@ -124,7 +122,7 @@ def attach_source_to_agent(
124
122
  """
125
123
  actor = server.get_user_or_default(user_id=user_id)
126
124
 
127
- source = server.ms.get_source(source_id=source_id, user_id=actor.id)
125
+ source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
128
126
  assert source is not None, f"Source with id={source_id} not found."
129
127
  source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id)
130
128
  return source
@@ -158,7 +156,7 @@ def upload_file_to_source(
158
156
  """
159
157
  actor = server.get_user_or_default(user_id=user_id)
160
158
 
161
- source = server.ms.get_source(source_id=source_id, user_id=actor.id)
159
+ source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
162
160
  assert source is not None, f"Source with id={source_id} not found."
163
161
  bytes = file.file.read()
164
162
 
@@ -200,11 +198,13 @@ def list_files_from_source(
200
198
  limit: int = Query(1000, description="Number of files to return"),
201
199
  cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
202
200
  server: "SyncServer" = Depends(get_letta_server),
201
+ user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
203
202
  ):
204
203
  """
205
204
  List paginated files associated with a data source.
206
205
  """
207
- return server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
206
+ actor = server.get_user_or_default(user_id=user_id)
207
+ return server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=actor)
208
208
 
209
209
 
210
210
  # it's redundant to include /delete in the URL path. The HTTP verb DELETE already implies that action.
@@ -221,7 +221,7 @@ def delete_file_from_source(
221
221
  """
222
222
  actor = server.get_user_or_default(user_id=user_id)
223
223
 
224
- deleted_file = server.delete_file_from_source(source_id=source_id, file_id=file_id, user_id=actor.id)
224
+ deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
225
225
  if deleted_file is None:
226
226
  raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
227
227
 
letta/server/server.py CHANGED
@@ -65,7 +65,6 @@ from letta.schemas.embedding_config import EmbeddingConfig
65
65
 
66
66
  # openai schemas
67
67
  from letta.schemas.enums import JobStatus
68
- from letta.schemas.file import FileMetadata
69
68
  from letta.schemas.job import Job
70
69
  from letta.schemas.letta_message import LettaMessage
71
70
  from letta.schemas.llm_config import LLMConfig
@@ -78,12 +77,13 @@ from letta.schemas.memory import (
78
77
  from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
79
78
  from letta.schemas.organization import Organization
80
79
  from letta.schemas.passage import Passage
81
- from letta.schemas.source import Source, SourceCreate, SourceUpdate
80
+ from letta.schemas.source import Source
82
81
  from letta.schemas.tool import Tool, ToolCreate
83
82
  from letta.schemas.usage import LettaUsageStatistics
84
83
  from letta.schemas.user import User
85
84
  from letta.services.agents_tags_manager import AgentsTagsManager
86
85
  from letta.services.organization_manager import OrganizationManager
86
+ from letta.services.source_manager import SourceManager
87
87
  from letta.services.tool_manager import ToolManager
88
88
  from letta.services.user_manager import UserManager
89
89
  from letta.utils import create_random_username, json_dumps, json_loads
@@ -249,6 +249,7 @@ class SyncServer(Server):
249
249
  self.organization_manager = OrganizationManager()
250
250
  self.user_manager = UserManager()
251
251
  self.tool_manager = ToolManager()
252
+ self.source_manager = SourceManager()
252
253
  self.agents_tags_manager = AgentsTagsManager()
253
254
 
254
255
  # Make default user and org
@@ -1511,12 +1512,16 @@ class SyncServer(Server):
1511
1512
  if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
1512
1513
  raise ValueError(f"Agent agent_id={agent_id} does not exist")
1513
1514
 
1514
- # Verify that the agent exists and is owned by the user
1515
+ # Verify that the agent exists and belongs to the org of the user
1515
1516
  agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
1516
1517
  if not agent_state:
1517
1518
  raise ValueError(f"Could not find agent_id={agent_id} under user_id={user_id}")
1518
- if agent_state.user_id != user_id:
1519
- raise ValueError(f"Could not authorize agent_id={agent_id} with user_id={user_id}")
1519
+
1520
+ agent_state_user = self.user_manager.get_user_by_id(user_id=agent_state.user_id)
1521
+ if agent_state_user.organization_id != actor.organization_id:
1522
+ raise ValueError(
1523
+ f"Could not authorize agent_id={agent_id} with user_id={user_id} because of differing organizations; agent_id was created in {agent_state_user.organization_id} while user belongs to {actor.organization_id}. How did they get the agent id?"
1524
+ )
1520
1525
 
1521
1526
  # First, if the agent is in the in-memory cache we should remove it
1522
1527
  # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
@@ -1560,44 +1565,12 @@ class SyncServer(Server):
1560
1565
  self.ms.delete_api_key(api_key=api_key)
1561
1566
  return api_key_obj
1562
1567
 
1563
- def create_source(self, request: SourceCreate, user_id: str) -> Source: # TODO: add other fields
1564
- """Create a new data source"""
1565
- source = Source(
1566
- name=request.name,
1567
- user_id=user_id,
1568
- embedding_config=self.list_embedding_models()[0], # TODO: require providing this
1569
- )
1570
- self.ms.create_source(source)
1571
- assert self.ms.get_source(source_name=request.name, user_id=user_id) is not None, f"Failed to create source {request.name}"
1572
- return source
1573
-
1574
- def update_source(self, request: SourceUpdate, user_id: str) -> Source:
1575
- """Update an existing data source"""
1576
- if not request.id:
1577
- existing_source = self.ms.get_source(source_name=request.name, user_id=user_id)
1578
- else:
1579
- existing_source = self.ms.get_source(source_id=request.id)
1580
- if not existing_source:
1581
- raise ValueError("Source does not exist")
1582
-
1583
- # override updated fields
1584
- if request.name:
1585
- existing_source.name = request.name
1586
- if request.metadata_:
1587
- existing_source.metadata_ = request.metadata_
1588
- if request.description:
1589
- existing_source.description = request.description
1590
-
1591
- self.ms.update_source(existing_source)
1592
- return existing_source
1593
-
1594
- def delete_source(self, source_id: str, user_id: str):
1568
+ def delete_source(self, source_id: str, actor: User):
1595
1569
  """Delete a data source"""
1596
- source = self.ms.get_source(source_id=source_id, user_id=user_id)
1597
- self.ms.delete_source(source_id)
1570
+ self.source_manager.delete_source(source_id=source_id, actor=actor)
1598
1571
 
1599
1572
  # delete data from passage store
1600
- passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
1573
+ passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=actor.id)
1601
1574
  passage_store.delete({"source_id": source_id})
1602
1575
 
1603
1576
  # TODO: delete data from agent passage stores (?)
@@ -1639,9 +1612,9 @@ class SyncServer(Server):
1639
1612
  # try:
1640
1613
  from letta.data_sources.connectors import DirectoryConnector
1641
1614
 
1642
- source = self.ms.get_source(source_id=source_id)
1615
+ source = self.source_manager.get_source_by_id(source_id=source_id)
1643
1616
  connector = DirectoryConnector(input_files=[file_path])
1644
- num_passages, num_documents = self.load_data(user_id=source.user_id, source_name=source.name, connector=connector)
1617
+ num_passages, num_documents = self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)
1645
1618
  # except Exception as e:
1646
1619
  # # job failed with error
1647
1620
  # error = str(e)
@@ -1662,9 +1635,6 @@ class SyncServer(Server):
1662
1635
 
1663
1636
  return job
1664
1637
 
1665
- def delete_file_from_source(self, source_id: str, file_id: str, user_id: Optional[str]) -> Optional[FileMetadata]:
1666
- return self.ms.delete_file_from_source(source_id=source_id, file_id=file_id, user_id=user_id)
1667
-
1668
1638
  def load_data(
1669
1639
  self,
1670
1640
  user_id: str,
@@ -1675,16 +1645,16 @@ class SyncServer(Server):
1675
1645
  # TODO: this should be implemented as a batch job or at least async, since it may take a long time
1676
1646
 
1677
1647
  # load data from a data source into the document store
1678
- source = self.ms.get_source(source_name=source_name, user_id=user_id)
1648
+ user = self.user_manager.get_user_by_id(user_id=user_id)
1649
+ source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
1679
1650
  if source is None:
1680
1651
  raise ValueError(f"Data source {source_name} does not exist for user {user_id}")
1681
1652
 
1682
1653
  # get the data connectors
1683
1654
  passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
1684
- file_store = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id)
1685
1655
 
1686
1656
  # load data into the document store
1687
- passage_count, document_count = load_data(connector, source, passage_store, file_store)
1657
+ passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user)
1688
1658
  return passage_count, document_count
1689
1659
 
1690
1660
  def attach_source_to_agent(
@@ -1696,10 +1666,13 @@ class SyncServer(Server):
1696
1666
  source_name: Optional[str] = None,
1697
1667
  ) -> Source:
1698
1668
  # attach a data source to an agent
1699
- data_source = self.ms.get_source(source_id=source_id, user_id=user_id, source_name=source_name)
1700
- if data_source is None:
1701
- raise ValueError(f"Data source id={source_id} name={source_name} does not exist for user_id {user_id}")
1702
-
1669
+ user = self.user_manager.get_user_by_id(user_id=user_id)
1670
+ if source_id:
1671
+ data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=user)
1672
+ elif source_name:
1673
+ data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
1674
+ else:
1675
+ raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
1703
1676
  # get connection to data source storage
1704
1677
  source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
1705
1678
 
@@ -1719,12 +1692,14 @@ class SyncServer(Server):
1719
1692
  source_id: Optional[str] = None,
1720
1693
  source_name: Optional[str] = None,
1721
1694
  ) -> Source:
1722
- if not source_id:
1723
- assert source_name is not None, "source_name must be provided if source_id is not"
1724
- source = self.ms.get_source(source_name=source_name, user_id=user_id)
1725
- source_id = source.id
1695
+ user = self.user_manager.get_user_by_id(user_id=user_id)
1696
+ if source_id:
1697
+ source = self.source_manager.get_source_by_id(source_id=source_id, actor=user)
1698
+ elif source_name:
1699
+ source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
1726
1700
  else:
1727
- source = self.ms.get_source(source_id=source_id)
1701
+ raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
1702
+ source_id = source.id
1728
1703
 
1729
1704
  # delete all Passage objects with source_id==source_id from agent's archival memory
1730
1705
  agent = self._get_or_load_agent(agent_id=agent_id)
@@ -1739,27 +1714,25 @@ class SyncServer(Server):
1739
1714
 
1740
1715
  def list_attached_sources(self, agent_id: str) -> List[Source]:
1741
1716
  # list all attached sources to an agent
1742
- return self.ms.list_attached_sources(agent_id)
1717
+ source_ids = self.ms.list_attached_source_ids(agent_id)
1743
1718
 
1744
- def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
1745
- # list all attached sources to an agent
1746
- return self.ms.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
1719
+ return [self.source_manager.get_source_by_id(source_id=id) for id in source_ids]
1747
1720
 
1748
1721
  def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
1749
1722
  warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
1750
1723
  return []
1751
1724
 
1752
- def list_all_sources(self, user_id: str) -> List[Source]:
1725
+ def list_all_sources(self, actor: User) -> List[Source]:
1753
1726
  """List all sources (w/ extra metadata) belonging to a user"""
1754
1727
 
1755
- sources = self.ms.list_sources(user_id=user_id)
1728
+ sources = self.source_manager.list_sources(actor=actor)
1756
1729
 
1757
1730
  # Add extra metadata to the sources
1758
1731
  sources_with_metadata = []
1759
1732
  for source in sources:
1760
1733
 
1761
1734
  # count number of passages
1762
- passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
1735
+ passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=actor.id)
1763
1736
  num_passages = passage_conn.size({"source_id": source.id})
1764
1737
 
1765
1738
  # TODO: add when files table implemented
@@ -1773,7 +1746,7 @@ class SyncServer(Server):
1773
1746
  attached_agents = [
1774
1747
  {
1775
1748
  "id": str(a_id),
1776
- "name": self.ms.get_agent(user_id=user_id, agent_id=a_id).name,
1749
+ "name": self.ms.get_agent(user_id=actor.id, agent_id=a_id).name,
1777
1750
  }
1778
1751
  for a_id in agent_ids
1779
1752
  ]
@@ -27,18 +27,26 @@ class OrganizationManager:
27
27
  return self.get_organization_by_id(self.DEFAULT_ORG_ID)
28
28
 
29
29
  @enforce_types
30
- def get_organization_by_id(self, org_id: str) -> PydanticOrganization:
30
+ def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]:
31
31
  """Fetch an organization by ID."""
32
32
  with self.session_maker() as session:
33
33
  try:
34
34
  organization = OrganizationModel.read(db_session=session, identifier=org_id)
35
35
  return organization.to_pydantic()
36
36
  except NoResultFound:
37
- raise ValueError(f"Organization with id {org_id} not found.")
37
+ return None
38
38
 
39
39
  @enforce_types
40
40
  def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
41
41
  """Create a new organization. If a name is provided, it is used, otherwise, a random one is generated."""
42
+ org = self.get_organization_by_id(pydantic_org.id)
43
+ if org:
44
+ return org
45
+ else:
46
+ return self._create_organization(pydantic_org=pydantic_org)
47
+
48
+ @enforce_types
49
+ def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
42
50
  with self.session_maker() as session:
43
51
  org = OrganizationModel(**pydantic_org.model_dump())
44
52
  org.create(session)
@@ -47,16 +55,7 @@ class OrganizationManager:
47
55
  @enforce_types
48
56
  def create_default_organization(self) -> PydanticOrganization:
49
57
  """Create the default organization."""
50
- with self.session_maker() as session:
51
- # Try to get it first
52
- try:
53
- org = OrganizationModel.read(db_session=session, identifier=self.DEFAULT_ORG_ID)
54
- # If it doesn't exist, make it
55
- except NoResultFound:
56
- org = OrganizationModel(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)
57
- org.create(session)
58
-
59
- return org.to_pydantic()
58
+ return self.create_organization(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
60
59
 
61
60
  @enforce_types
62
61
  def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
@@ -73,7 +72,7 @@ class OrganizationManager:
73
72
  """Delete an organization by marking it as deleted."""
74
73
  with self.session_maker() as session:
75
74
  organization = OrganizationModel.read(db_session=session, identifier=org_id)
76
- organization.delete(session)
75
+ organization.hard_delete(session)
77
76
 
78
77
  @enforce_types
79
78
  def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: