letta-nightly 0.4.1.dev20241014104152__py3-none-any.whl → 0.5.0.dev20241015104156__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of letta-nightly might be problematic.
- letta/__init__.py +2 -2
- letta/agent_store/db.py +18 -7
- letta/agent_store/lancedb.py +2 -2
- letta/agent_store/milvus.py +1 -1
- letta/agent_store/qdrant.py +1 -1
- letta/agent_store/storage.py +12 -10
- letta/cli/cli_load.py +1 -1
- letta/client/client.py +51 -0
- letta/data_sources/connectors.py +124 -124
- letta/data_sources/connectors_helper.py +97 -0
- letta/llm_api/mistral.py +47 -0
- letta/metadata.py +58 -0
- letta/providers.py +44 -0
- letta/schemas/file.py +31 -0
- letta/schemas/job.py +1 -1
- letta/schemas/letta_request.py +3 -3
- letta/schemas/llm_config.py +1 -0
- letta/schemas/message.py +6 -2
- letta/schemas/passage.py +3 -3
- letta/schemas/source.py +2 -2
- letta/server/rest_api/routers/v1/agents.py +10 -16
- letta/server/rest_api/routers/v1/jobs.py +17 -1
- letta/server/rest_api/routers/v1/sources.py +7 -9
- letta/server/server.py +86 -13
- letta/server/static_files/assets/{index-9a9c449b.js → index-dc228d4a.js} +4 -4
- letta/server/static_files/index.html +1 -1
- {letta_nightly-0.4.1.dev20241014104152.dist-info → letta_nightly-0.5.0.dev20241015104156.dist-info}/METADATA +1 -1
- {letta_nightly-0.4.1.dev20241014104152.dist-info → letta_nightly-0.5.0.dev20241015104156.dist-info}/RECORD +31 -29
- letta/schemas/document.py +0 -21
- {letta_nightly-0.4.1.dev20241014104152.dist-info → letta_nightly-0.5.0.dev20241015104156.dist-info}/LICENSE +0 -0
- {letta_nightly-0.4.1.dev20241014104152.dist-info → letta_nightly-0.5.0.dev20241015104156.dist-info}/WHEEL +0 -0
- {letta_nightly-0.4.1.dev20241014104152.dist-info → letta_nightly-0.5.0.dev20241015104156.dist-info}/entry_points.txt +0 -0
letta/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.4.1"
+__version__ = "0.5.0"

 # import clients
 from letta.client.admin import Admin
@@ -7,9 +7,9 @@ from letta.client.client import LocalClient, RESTClient, create_client
 # imports for easier access
 from letta.schemas.agent import AgentState
 from letta.schemas.block import Block
-from letta.schemas.document import Document
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import JobStatus
+from letta.schemas.file import FileMetadata
 from letta.schemas.job import Job
 from letta.schemas.letta_message import LettaMessage
 from letta.schemas.llm_config import LLMConfig
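Note on the public API: `Document` is no longer exported from the package root (letta/schemas/document.py is removed in this release) and `FileMetadata` takes its place. A minimal before/after import sketch, with the old import shown commented out:

# letta 0.4.x exposed the old schema at the package root:
# from letta import Document  # removed in 0.5.0 along with letta/schemas/document.py

# letta 0.5.0 exposes the new file schema instead:
from letta import FileMetadata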
letta/agent_store/db.py
CHANGED
@@ -28,7 +28,7 @@ from letta.agent_store.storage import StorageConnector, TableType
 from letta.base import Base
 from letta.config import LettaConfig
 from letta.constants import MAX_EMBEDDING_DIM
-from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
+from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn

 # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
 from letta.schemas.message import Message
@@ -141,7 +141,7 @@ class PassageModel(Base):
     id = Column(String, primary_key=True)
     user_id = Column(String, nullable=False)
     text = Column(String)
+    file_id = Column(String)
     agent_id = Column(String)
     source_id = Column(String)

@@ -160,7 +160,7 @@ class PassageModel(Base):
     # Add a datetime column, with default value as the current time
     created_at = Column(DateTime(timezone=True))

-    Index("passage_idx_user", user_id, agent_id,
+    Index("passage_idx_user", user_id, agent_id, file_id),

     def __repr__(self):
         return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
@@ -170,7 +170,7 @@ class PassageModel(Base):
             text=self.text,
             embedding=self.embedding,
             embedding_config=self.embedding_config,
+            file_id=self.file_id,
             user_id=self.user_id,
             id=self.id,
             source_id=self.source_id,
@@ -365,12 +365,17 @@ class PostgresStorageConnector(SQLStorageConnector):
             self.uri = self.config.archival_storage_uri
             self.db_model = PassageModel
             if self.config.archival_storage_uri is None:
-                raise ValueError(f"Must
+                raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
         elif table_type == TableType.RECALL_MEMORY:
             self.uri = self.config.recall_storage_uri
             self.db_model = MessageModel
             if self.config.recall_storage_uri is None:
-                raise ValueError(f"Must
+                raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
+        elif table_type == TableType.FILES:
+            self.uri = self.config.metadata_storage_uri
+            self.db_model = FileMetadataModel
+            if self.config.metadata_storage_uri is None:
+                raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
         else:
             raise ValueError(f"Table type {table_type} not implemented")

@@ -487,8 +492,14 @@ class SQLLiteStorageConnector(SQLStorageConnector):
             # TODO: eventually implement URI option
             self.path = self.config.recall_storage_path
             if self.path is None:
-                raise ValueError(f"Must
+                raise ValueError(f"Must specify recall_storage_path in config.")
             self.db_model = MessageModel
+        elif table_type == TableType.FILES:
+            self.path = self.config.metadata_storage_path
+            if self.path is None:
+                raise ValueError(f"Must specify metadata_storage_path in config.")
+            self.db_model = FileMetadataModel
+
         else:
             raise ValueError(f"Table type {table_type} not implemented")

letta/agent_store/lancedb.py
CHANGED
@@ -24,7 +24,7 @@ def get_db_model(table_name: str, table_type: TableType):
             id: uuid.UUID
             user_id: str
             text: str
+            file_id: str
             agent_id: str
             data_source: str
             embedding: Vector(config.default_embedding_config.embedding_dim)
@@ -37,7 +37,7 @@ def get_db_model(table_name: str, table_type: TableType):
                 return Passage(
                     text=self.text,
                     embedding=self.embedding,
+                    file_id=self.file_id,
                     user_id=self.user_id,
                     id=self.id,
                     data_source=self.data_source,
letta/agent_store/milvus.py
CHANGED
@@ -26,7 +26,7 @@ class MilvusStorageConnector(StorageConnector):
             raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")

         # need to be converted to strings
-        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "
+        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]

     def _create_collection(self):
         schema = MilvusClient.create_schema(
letta/agent_store/qdrant.py
CHANGED
@@ -38,7 +38,7 @@ class QdrantStorageConnector(StorageConnector):
                     distance=models.Distance.COSINE,
                 ),
             )
-        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "
+        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]

     def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
         from qdrant_client import grpc
letta/agent_store/storage.py
CHANGED
@@ -10,7 +10,7 @@ from typing import Dict, List, Optional, Tuple, Type, Union
 from pydantic import BaseModel

 from letta.config import LettaConfig
-from letta.schemas.
+from letta.schemas.file import FileMetadata
 from letta.schemas.message import Message
 from letta.schemas.passage import Passage
 from letta.utils import printd
@@ -22,7 +22,7 @@ class TableType:
     ARCHIVAL_MEMORY = "archival_memory" # recall memory table: letta_agent_{agent_id}
     RECALL_MEMORY = "recall_memory" # archival memory table: letta_agent_recall_{agent_id}
     PASSAGES = "passages" # TODO
+    FILES = "files"


 # table names used by Letta
@@ -33,17 +33,17 @@ ARCHIVAL_TABLE_NAME = "letta_archival_memory_agent" # agent memory

 # external data source tables
 PASSAGE_TABLE_NAME = "letta_passages" # chunked/embedded passages (from source)
+FILE_TABLE_NAME = "letta_files" # original files (from source)


 class StorageConnector:
-    """Defines a DB connection that is user-specific to access data:
+    """Defines a DB connection that is user-specific to access data: files, Passages, Archival/Recall Memory"""

     type: Type[BaseModel]

     def __init__(
         self,
-        table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.
+        table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
         config: LettaConfig,
         user_id,
         agent_id=None,
@@ -59,9 +59,9 @@ class StorageConnector:
         elif table_type == TableType.RECALL_MEMORY:
             self.type = Message
             self.table_name = RECALL_TABLE_NAME
-        elif table_type == TableType.
-            self.type =
-            self.table_name
+        elif table_type == TableType.FILES:
+            self.type = FileMetadata
+            self.table_name = FILE_TABLE_NAME
         elif table_type == TableType.PASSAGES:
             self.type = Passage
             self.table_name = PASSAGE_TABLE_NAME
@@ -74,7 +74,7 @@ class StorageConnector:
             # agent-specific table
             assert agent_id is not None, "Agent ID must be provided for agent-specific tables"
             self.filters = {"user_id": self.user_id, "agent_id": self.agent_id}
-        elif self.table_type == TableType.PASSAGES or self.table_type == TableType.
+        elif self.table_type == TableType.PASSAGES or self.table_type == TableType.FILES:
             # setup base filters for user-specific tables
             assert agent_id is None, "Agent ID must not be provided for user-specific tables"
             self.filters = {"user_id": self.user_id}
@@ -83,7 +83,7 @@ class StorageConnector:

     @staticmethod
     def get_storage_connector(
-        table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.
+        table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
         config: LettaConfig,
         user_id,
         agent_id=None,
@@ -92,6 +92,8 @@ class StorageConnector:
             storage_type = config.archival_storage_type
         elif table_type == TableType.RECALL_MEMORY:
             storage_type = config.recall_storage_type
+        elif table_type == TableType.FILES:
+            storage_type = config.metadata_storage_type
         else:
             raise ValueError(f"Table type {table_type} not implemented")

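The net effect of the storage changes is that file metadata becomes a first-class table type, resolved through the same factory as passages and recall/archival memory and backed by the metadata_storage_* settings in the Letta config. A minimal sketch of requesting the new FILES connector; it assumes a working on-disk Letta config (loaded here via LettaConfig.load()) with metadata_storage_type and the matching URI/path set, and the user id is illustrative:

from letta.agent_store.storage import StorageConnector, TableType
from letta.config import LettaConfig

config = LettaConfig.load()  # assumes an existing on-disk Letta config

# New in 0.5.0: TableType.FILES maps to the letta_files table and FileMetadata records
file_store = StorageConnector.get_storage_connector(
    table_type=TableType.FILES,
    config=config,
    user_id="user-00000000",  # illustrative user id
)

Like the PASSAGES connector, the FILES connector is user-scoped (no agent_id), so it filters records on user_id only.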
letta/cli/cli_load.py
CHANGED
@@ -106,7 +106,7 @@ def load_vector_database(
     # document_store=None,
     # passage_store=passage_storage,
     # )
-    # print(f"Loaded {num_passages} passages and {num_documents}
+    # print(f"Loaded {num_passages} passages and {num_documents} files from {name}")
     # except Exception as e:
     # typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
     # ms.delete_source(source_id=source.id)
letta/client/client.py
CHANGED
@@ -25,6 +25,7 @@ from letta.schemas.embedding_config import EmbeddingConfig

 # new schemas
 from letta.schemas.enums import JobStatus, MessageRole
+from letta.schemas.file import FileMetadata
 from letta.schemas.job import Job
 from letta.schemas.letta_request import LettaRequest
 from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
@@ -232,6 +233,9 @@ class AbstractClient(object):
     def list_attached_sources(self, agent_id: str) -> List[Source]:
         raise NotImplementedError

+    def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
+        raise NotImplementedError
+
     def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
         raise NotImplementedError

@@ -1016,6 +1020,12 @@ class RESTClient(AbstractClient):
             raise ValueError(f"Failed to get job: {response.text}")
         return Job(**response.json())

+    def delete_job(self, job_id: str) -> Job:
+        response = requests.delete(f"{self.base_url}/{self.api_prefix}/jobs/{job_id}", headers=self.headers)
+        if response.status_code != 200:
+            raise ValueError(f"Failed to delete job: {response.text}")
+        return Job(**response.json())
+
     def list_jobs(self):
         response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs", headers=self.headers)
         return [Job(**job) for job in response.json()]
@@ -1088,6 +1098,30 @@ class RESTClient(AbstractClient):
             raise ValueError(f"Failed to list attached sources: {response.text}")
         return [Source(**source) for source in response.json()]

+    def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
+        """
+        List files from source with pagination support.
+
+        Args:
+            source_id (str): ID of the source
+            limit (int): Number of files to return
+            cursor (Optional[str]): Pagination cursor for fetching the next page
+
+        Returns:
+            List[FileMetadata]: List of files
+        """
+        # Prepare query parameters for pagination
+        params = {"limit": limit, "cursor": cursor}
+
+        # Make the request to the FastAPI endpoint
+        response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/files", headers=self.headers, params=params)
+
+        if response.status_code != 200:
+            raise ValueError(f"Failed to list files with source id {source_id}: [{response.status_code}] {response.text}")
+
+        # Parse the JSON response
+        return [FileMetadata(**metadata) for metadata in response.json()]
+
     def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
         """
         Update a source
@@ -2162,6 +2196,9 @@ class LocalClient(AbstractClient):
     def get_job(self, job_id: str):
         return self.server.get_job(job_id=job_id)

+    def delete_job(self, job_id: str):
+        return self.server.delete_job(job_id)
+
     def list_jobs(self):
         return self.server.list_jobs(user_id=self.user_id)

@@ -2261,6 +2298,20 @@ class LocalClient(AbstractClient):
         """
         return self.server.list_attached_sources(agent_id=agent_id)

+    def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
+        """
+        List files from source.
+
+        Args:
+            source_id (str): ID of the source
+            limit (int): The # of items to return
+            cursor (str): The cursor for fetching the next page
+
+        Returns:
+            files (List[FileMetadata]): List of files
+        """
+        return self.server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
+
     def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
         """
         Update a source
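As a quick orientation for the new client surface, a sketch of how the added methods might be called; the client construction and the concrete IDs are illustrative, and only list_files_from_source and delete_job come from the diff above:

from letta import create_client

client = create_client()  # local client by default; pass base_url=... for the REST client

# New in 0.5.0: page through the files loaded into a source
files = client.list_files_from_source(source_id="source-1234", limit=100)  # illustrative id
for f in files:
    print(f.file_name, f.file_size)

# New in 0.5.0: delete a data-loading job by id
client.delete_job("job-5678")  # illustrative id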
letta/data_sources/connectors.py
CHANGED
@@ -1,11 +1,15 @@
-from typing import Dict, Iterator, List,
+from typing import Dict, Iterator, List, Tuple

 import typer
-from llama_index.core import Document as LlamaIndexDocument

 from letta.agent_store.storage import StorageConnector
+from letta.data_sources.connectors_helper import (
+    assert_all_files_exist_locally,
+    extract_metadata_from_files,
+    get_filenames_in_dir,
+)
 from letta.embeddings import embedding_model
-from letta.schemas.
+from letta.schemas.file import FileMetadata
 from letta.schemas.passage import Passage
 from letta.schemas.source import Source
 from letta.utils import create_uuid_from_string
@@ -13,23 +17,23 @@ from letta.utils import create_uuid_from_string

 class DataConnector:
     """
-    Base class for data connectors that can be extended to generate
+    Base class for data connectors that can be extended to generate files and passages from a custom data source.
     """

-    def
+    def find_files(self, source: Source) -> Iterator[FileMetadata]:
         """
-        Generate
+        Generate file metadata from a data source.

         Returns:
+            files (Iterator[FileMetadata]): Generate file metadata for each file found.
         """

-    def generate_passages(self,
+    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
         """
-        Generate passage text and metadata from a list of
+        Generate passage text and metadata from a list of files.

         Args:
+            file (FileMetadata): The document to generate passages from.
             chunk_size (int, optional): Chunk size for splitting passages. Defaults to 1024.

         Returns:
@@ -41,33 +45,25 @@ def load_data(
     connector: DataConnector,
     source: Source,
     passage_store: StorageConnector,
+    file_metadata_store: StorageConnector,
 ):
-    """Load data from a connector (generates
+    """Load data from a connector (generates file and passages) into a specified source_id, associatedw with a user_id."""
     embedding_config = source.embedding_config

     # embedding model
     embed_model = embedding_model(embedding_config)

-    # insert passages/
+    # insert passages/file
     passages = []
     embedding_to_document_name = {}
     passage_count = 0
-    for
-            text=document_text,
-            metadata_=document_metadata,
-            source_id=source.id,
-            user_id=source.user_id,
-        )
-        document_count += 1
-        if document_store:
-            document_store.insert(document)
+    file_count = 0
+    for file_metadata in connector.find_files(source):
+        file_count += 1
+        file_metadata_store.insert(file_metadata)

         # generate passages
-        for passage_text, passage_metadata in connector.generate_passages(
+        for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size):
             # for some reason, llama index parsers sometimes return empty strings
             if len(passage_text) == 0:
                 typer.secho(
@@ -89,7 +85,7 @@ def load_data(
             passage = Passage(
                 id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
                 text=passage_text,
+                file_id=file_metadata.id,
                 source_id=source.id,
                 metadata_=passage_metadata,
                 user_id=source.user_id,
@@ -98,16 +94,16 @@ def load_data(
             )

             hashable_embedding = tuple(passage.embedding)
+            file_name = file_metadata.file_name
             if hashable_embedding in embedding_to_document_name:
                 typer.secho(
-                    f"Warning: Duplicate embedding found for passage in {
+                    f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
                     fg=typer.colors.YELLOW,
                 )
                 continue

             passages.append(passage)
-            embedding_to_document_name[hashable_embedding] =
+            embedding_to_document_name[hashable_embedding] = file_name
             if len(passages) >= 100:
                 # insert passages into passage store
                 passage_store.insert_many(passages)
@@ -120,7 +116,7 @@ def load_data(
     passage_store.insert_many(passages)
     passage_count += len(passages)

-    return passage_count,
+    return passage_count, file_count


 class DirectoryConnector(DataConnector):
@@ -143,105 +139,109 @@ class DirectoryConnector(DataConnector):
         if self.recursive == True:
             assert self.input_directory is not None, "Must provide input directory if recursive is True."

-    def
-        from llama_index.core import SimpleDirectoryReader
+    def find_files(self, source: Source) -> Iterator[FileMetadata]:
         if self.input_directory is not None:
+            files = get_filenames_in_dir(
                 input_dir=self.input_directory,
                 recursive=self.recursive,
                 required_exts=[ext.strip() for ext in str(self.extensions).split(",")],
+                exclude=["*png", "*jpg", "*jpeg"],
             )
         else:
+            files = self.input_files
+
+        # Check that file paths are valid
+        assert_all_files_exist_locally(files)
+
+        for metadata in extract_metadata_from_files(files):
+            yield FileMetadata(
+                user_id=source.user_id,
+                source_id=source.id,
+                file_name=metadata.get("file_name"),
+                file_path=metadata.get("file_path"),
+                file_type=metadata.get("file_type"),
+                file_size=metadata.get("file_size"),
+                file_creation_date=metadata.get("file_creation_date"),
+                file_last_modified_date=metadata.get("file_last_modified_date"),
+            )
+
+    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
+        from llama_index.core import SimpleDirectoryReader
         from llama_index.core.node_parser import TokenTextSplitter

         parser = TokenTextSplitter(chunk_size=chunk_size)
+        documents = SimpleDirectoryReader(input_files=[file.file_path]).load_data()
+        nodes = parser.get_nodes_from_documents(documents)
+        for node in nodes:
+            yield node.text, None
+
+
+"""
+The below isn't used anywhere, it isn't tested, and pretty much should be deleted.
+- Matt
+"""
+# class WebConnector(DirectoryConnector):
+#     def __init__(self, urls: List[str] = None, html_to_text: bool = True):
+#         self.urls = urls
+#         self.html_to_text = html_to_text
+#
+#     def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
+#         from llama_index.readers.web import SimpleWebPageReader
+#
+#         files = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls)
+#         for document in files:
+#             yield document.text, {"url": document.id_}
+#
+#
+# class VectorDBConnector(DataConnector):
+#     # NOTE: this class has not been properly tested, so is unlikely to work
+#     # TODO: allow loading multiple tables (1:1 mapping between FileMetadata and Table)
+#
+#     def __init__(
+#         self,
+#         name: str,
+#         uri: str,
+#         table_name: str,
+#         text_column: str,
+#         embedding_column: str,
+#         embedding_dim: int,
+#     ):
+#         self.name = name
+#         self.uri = uri
+#         self.table_name = table_name
+#         self.text_column = text_column
+#         self.embedding_column = embedding_column
+#         self.embedding_dim = embedding_dim
+#
+#         # connect to db table
+#         from sqlalchemy import create_engine
+#
+#         self.engine = create_engine(uri)
+#
+#     def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
+#         yield self.table_name, None
+#
+#     def generate_passages(self, file_text: str, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
+#         from pgvector.sqlalchemy import Vector
+#         from sqlalchemy import Inspector, MetaData, Table, select
+#
+#         metadata = MetaData()
+#         # Create an inspector to inspect the database
+#         inspector = Inspector.from_engine(self.engine)
+#         table_names = inspector.get_table_names()
+#         assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
+#
+#         table = Table(self.table_name, metadata, autoload_with=self.engine)
+#
+#         # Prepare a select statement
+#         select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
+#
+#         # Execute the query and fetch the results
+#         # TODO: paginate results
+#         with self.engine.connect() as connection:
+#             result = connection.execute(select_statement).fetchall()
+#
+#         for text, embedding in result:
+#             # assume that embeddings are the same model as in config
+#             # TODO: don't re-compute embedding
+#             yield text, {"embedding": embedding}