letta-nightly 0.4.1.dev20241014104152__py3-none-any.whl → 0.5.0.dev20241015014828__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of letta-nightly might be problematic.

Files changed (32)
  1. letta/__init__.py +2 -2
  2. letta/agent_store/db.py +18 -7
  3. letta/agent_store/lancedb.py +2 -2
  4. letta/agent_store/milvus.py +1 -1
  5. letta/agent_store/qdrant.py +1 -1
  6. letta/agent_store/storage.py +12 -10
  7. letta/cli/cli_load.py +1 -1
  8. letta/client/client.py +51 -0
  9. letta/data_sources/connectors.py +124 -124
  10. letta/data_sources/connectors_helper.py +97 -0
  11. letta/llm_api/mistral.py +47 -0
  12. letta/metadata.py +58 -0
  13. letta/providers.py +44 -0
  14. letta/schemas/file.py +31 -0
  15. letta/schemas/job.py +1 -1
  16. letta/schemas/letta_request.py +3 -3
  17. letta/schemas/llm_config.py +1 -0
  18. letta/schemas/message.py +6 -2
  19. letta/schemas/passage.py +3 -3
  20. letta/schemas/source.py +2 -2
  21. letta/server/rest_api/routers/v1/agents.py +10 -16
  22. letta/server/rest_api/routers/v1/jobs.py +17 -1
  23. letta/server/rest_api/routers/v1/sources.py +7 -9
  24. letta/server/server.py +86 -13
  25. letta/server/static_files/assets/{index-9a9c449b.js → index-dc228d4a.js} +4 -4
  26. letta/server/static_files/index.html +1 -1
  27. {letta_nightly-0.4.1.dev20241014104152.dist-info → letta_nightly-0.5.0.dev20241015014828.dist-info}/METADATA +1 -1
  28. {letta_nightly-0.4.1.dev20241014104152.dist-info → letta_nightly-0.5.0.dev20241015014828.dist-info}/RECORD +31 -29
  29. letta/schemas/document.py +0 -21
  30. {letta_nightly-0.4.1.dev20241014104152.dist-info → letta_nightly-0.5.0.dev20241015014828.dist-info}/LICENSE +0 -0
  31. {letta_nightly-0.4.1.dev20241014104152.dist-info → letta_nightly-0.5.0.dev20241015014828.dist-info}/WHEEL +0 -0
  32. {letta_nightly-0.4.1.dev20241014104152.dist-info → letta_nightly-0.5.0.dev20241015014828.dist-info}/entry_points.txt +0 -0
letta/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.4.1"
+ __version__ = "0.5.0"
 
  # import clients
  from letta.client.admin import Admin
@@ -7,9 +7,9 @@ from letta.client.client import LocalClient, RESTClient, create_client
  # imports for easier access
  from letta.schemas.agent import AgentState
  from letta.schemas.block import Block
- from letta.schemas.document import Document
  from letta.schemas.embedding_config import EmbeddingConfig
  from letta.schemas.enums import JobStatus
+ from letta.schemas.file import FileMetadata
  from letta.schemas.job import Job
  from letta.schemas.letta_message import LettaMessage
  from letta.schemas.llm_config import LLMConfig
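
With this rename, the schema export at the package root changes as well. A minimal before/after import sketch (assuming the 0.5.0 wheel is installed):

# 0.4.1 exposed the removed schema:  from letta import Document
# 0.5.0 exposes the replacement schema instead:
from letta import FileMetadata
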
letta/agent_store/db.py CHANGED
@@ -28,7 +28,7 @@ from letta.agent_store.storage import StorageConnector, TableType
  from letta.base import Base
  from letta.config import LettaConfig
  from letta.constants import MAX_EMBEDDING_DIM
- from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
+ from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn
 
  # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
  from letta.schemas.message import Message
@@ -141,7 +141,7 @@ class PassageModel(Base):
  id = Column(String, primary_key=True)
  user_id = Column(String, nullable=False)
  text = Column(String)
- doc_id = Column(String)
+ file_id = Column(String)
  agent_id = Column(String)
  source_id = Column(String)
 
@@ -160,7 +160,7 @@ class PassageModel(Base):
  # Add a datetime column, with default value as the current time
  created_at = Column(DateTime(timezone=True))
 
- Index("passage_idx_user", user_id, agent_id, doc_id),
+ Index("passage_idx_user", user_id, agent_id, file_id),
 
  def __repr__(self):
  return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
@@ -170,7 +170,7 @@ class PassageModel(Base):
  text=self.text,
  embedding=self.embedding,
  embedding_config=self.embedding_config,
- doc_id=self.doc_id,
+ file_id=self.file_id,
  user_id=self.user_id,
  id=self.id,
  source_id=self.source_id,
@@ -365,12 +365,17 @@ class PostgresStorageConnector(SQLStorageConnector):
  self.uri = self.config.archival_storage_uri
  self.db_model = PassageModel
  if self.config.archival_storage_uri is None:
- raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
+ raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
  elif table_type == TableType.RECALL_MEMORY:
  self.uri = self.config.recall_storage_uri
  self.db_model = MessageModel
  if self.config.recall_storage_uri is None:
- raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
+ raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
+ elif table_type == TableType.FILES:
+ self.uri = self.config.metadata_storage_uri
+ self.db_model = FileMetadataModel
+ if self.config.metadata_storage_uri is None:
+ raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
  else:
  raise ValueError(f"Table type {table_type} not implemented")
 
@@ -487,8 +492,14 @@ class SQLLiteStorageConnector(SQLStorageConnector):
  # TODO: eventually implement URI option
  self.path = self.config.recall_storage_path
  if self.path is None:
- raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}")
+ raise ValueError(f"Must specify recall_storage_path in config.")
  self.db_model = MessageModel
+ elif table_type == TableType.FILES:
+ self.path = self.config.metadata_storage_path
+ if self.path is None:
+ raise ValueError(f"Must specify metadata_storage_path in config.")
+ self.db_model = FileMetadataModel
+
  else:
  raise ValueError(f"Table type {table_type} not implemented")
 
letta/agent_store/lancedb.py CHANGED
@@ -24,7 +24,7 @@ def get_db_model(table_name: str, table_type: TableType):
  id: uuid.UUID
  user_id: str
  text: str
- doc_id: str
+ file_id: str
  agent_id: str
  data_source: str
  embedding: Vector(config.default_embedding_config.embedding_dim)
@@ -37,7 +37,7 @@ def get_db_model(table_name: str, table_type: TableType):
  return Passage(
  text=self.text,
  embedding=self.embedding,
- doc_id=self.doc_id,
+ file_id=self.file_id,
  user_id=self.user_id,
  id=self.id,
  data_source=self.data_source,
letta/agent_store/milvus.py CHANGED
@@ -26,7 +26,7 @@ class MilvusStorageConnector(StorageConnector):
  raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")
 
  # need to be converted to strings
- self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
+ self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
 
  def _create_collection(self):
  schema = MilvusClient.create_schema(
letta/agent_store/qdrant.py CHANGED
@@ -38,7 +38,7 @@ class QdrantStorageConnector(StorageConnector):
  distance=models.Distance.COSINE,
  ),
  )
- self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
+ self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
 
  def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
  from qdrant_client import grpc
letta/agent_store/storage.py CHANGED
@@ -10,7 +10,7 @@ from typing import Dict, List, Optional, Tuple, Type, Union
  from pydantic import BaseModel
 
  from letta.config import LettaConfig
- from letta.schemas.document import Document
+ from letta.schemas.file import FileMetadata
  from letta.schemas.message import Message
  from letta.schemas.passage import Passage
  from letta.utils import printd
@@ -22,7 +22,7 @@ class TableType:
  ARCHIVAL_MEMORY = "archival_memory" # recall memory table: letta_agent_{agent_id}
  RECALL_MEMORY = "recall_memory" # archival memory table: letta_agent_recall_{agent_id}
  PASSAGES = "passages" # TODO
- DOCUMENTS = "documents" # TODO
+ FILES = "files"
 
 
  # table names used by Letta
@@ -33,17 +33,17 @@ ARCHIVAL_TABLE_NAME = "letta_archival_memory_agent" # agent memory
 
  # external data source tables
  PASSAGE_TABLE_NAME = "letta_passages" # chunked/embedded passages (from source)
- DOCUMENT_TABLE_NAME = "letta_documents" # original documents (from source)
+ FILE_TABLE_NAME = "letta_files" # original files (from source)
 
 
  class StorageConnector:
- """Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory"""
+ """Defines a DB connection that is user-specific to access data: files, Passages, Archival/Recall Memory"""
 
  type: Type[BaseModel]
 
  def __init__(
  self,
- table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS],
+ table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
  config: LettaConfig,
  user_id,
  agent_id=None,
@@ -59,9 +59,9 @@ class StorageConnector:
  elif table_type == TableType.RECALL_MEMORY:
  self.type = Message
  self.table_name = RECALL_TABLE_NAME
- elif table_type == TableType.DOCUMENTS:
- self.type = Document
- self.table_name == DOCUMENT_TABLE_NAME
+ elif table_type == TableType.FILES:
+ self.type = FileMetadata
+ self.table_name = FILE_TABLE_NAME
  elif table_type == TableType.PASSAGES:
  self.type = Passage
  self.table_name = PASSAGE_TABLE_NAME
@@ -74,7 +74,7 @@ class StorageConnector:
  # agent-specific table
  assert agent_id is not None, "Agent ID must be provided for agent-specific tables"
  self.filters = {"user_id": self.user_id, "agent_id": self.agent_id}
- elif self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS:
+ elif self.table_type == TableType.PASSAGES or self.table_type == TableType.FILES:
  # setup base filters for user-specific tables
  assert agent_id is None, "Agent ID must not be provided for user-specific tables"
  self.filters = {"user_id": self.user_id}
@@ -83,7 +83,7 @@ class StorageConnector:
 
  @staticmethod
  def get_storage_connector(
- table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS],
+ table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
  config: LettaConfig,
  user_id,
  agent_id=None,
@@ -92,6 +92,8 @@ class StorageConnector:
  storage_type = config.archival_storage_type
  elif table_type == TableType.RECALL_MEMORY:
  storage_type = config.recall_storage_type
+ elif table_type == TableType.FILES:
+ storage_type = config.metadata_storage_type
  else:
  raise ValueError(f"Table type {table_type} not implemented")
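
The net effect of the storage changes: FILES is a user-scoped table type whose backend is resolved from the metadata storage settings rather than the archival/recall ones. A minimal sketch of obtaining a connector for it (LettaConfig.load() and the user id value are illustrative assumptions, not shown in this diff):

from letta.agent_store.storage import StorageConnector, TableType
from letta.config import LettaConfig

config = LettaConfig.load()  # assumption: the usual on-disk Letta config loader
# TableType.FILES routes through config.metadata_storage_type and is user-specific (no agent_id)
file_store = StorageConnector.get_storage_connector(TableType.FILES, config, user_id="user-00000000")
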
letta/cli/cli_load.py CHANGED
@@ -106,7 +106,7 @@ def load_vector_database(
  # document_store=None,
  # passage_store=passage_storage,
  # )
- # print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
+ # print(f"Loaded {num_passages} passages and {num_documents} files from {name}")
  # except Exception as e:
  # typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
  # ms.delete_source(source_id=source.id)
letta/client/client.py CHANGED
@@ -25,6 +25,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
 
  # new schemas
  from letta.schemas.enums import JobStatus, MessageRole
+ from letta.schemas.file import FileMetadata
  from letta.schemas.job import Job
  from letta.schemas.letta_request import LettaRequest
  from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
@@ -232,6 +233,9 @@ class AbstractClient(object):
  def list_attached_sources(self, agent_id: str) -> List[Source]:
  raise NotImplementedError
 
+ def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
+ raise NotImplementedError
+
  def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
  raise NotImplementedError
 
@@ -1016,6 +1020,12 @@ class RESTClient(AbstractClient):
  raise ValueError(f"Failed to get job: {response.text}")
  return Job(**response.json())
 
+ def delete_job(self, job_id: str) -> Job:
+ response = requests.delete(f"{self.base_url}/{self.api_prefix}/jobs/{job_id}", headers=self.headers)
+ if response.status_code != 200:
+ raise ValueError(f"Failed to delete job: {response.text}")
+ return Job(**response.json())
+
  def list_jobs(self):
  response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs", headers=self.headers)
  return [Job(**job) for job in response.json()]
@@ -1088,6 +1098,30 @@ class RESTClient(AbstractClient):
  raise ValueError(f"Failed to list attached sources: {response.text}")
  return [Source(**source) for source in response.json()]
 
+ def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
+ """
+ List files from source with pagination support.
+
+ Args:
+ source_id (str): ID of the source
+ limit (int): Number of files to return
+ cursor (Optional[str]): Pagination cursor for fetching the next page
+
+ Returns:
+ List[FileMetadata]: List of files
+ """
+ # Prepare query parameters for pagination
+ params = {"limit": limit, "cursor": cursor}
+
+ # Make the request to the FastAPI endpoint
+ response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/files", headers=self.headers, params=params)
+
+ if response.status_code != 200:
+ raise ValueError(f"Failed to list files with source id {source_id}: [{response.status_code}] {response.text}")
+
+ # Parse the JSON response
+ return [FileMetadata(**metadata) for metadata in response.json()]
+
  def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
  """
  Update a source
@@ -2162,6 +2196,9 @@ class LocalClient(AbstractClient):
  def get_job(self, job_id: str):
  return self.server.get_job(job_id=job_id)
 
+ def delete_job(self, job_id: str):
+ return self.server.delete_job(job_id)
+
  def list_jobs(self):
  return self.server.list_jobs(user_id=self.user_id)
 
@@ -2261,6 +2298,20 @@ class LocalClient(AbstractClient):
  """
  return self.server.list_attached_sources(agent_id=agent_id)
 
+ def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
+ """
+ List files from source.
+
+ Args:
+ source_id (str): ID of the source
+ limit (int): The # of items to return
+ cursor (str): The cursor for fetching the next page
+
+ Returns:
+ files (List[FileMetadata]): List of files
+ """
+ return self.server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
+
  def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
  """
  Update a source
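
Together, the client changes add two calls to both RESTClient and LocalClient: delete_job and a paginated list_files_from_source. A rough usage sketch (the source and job ids are placeholders, and create_client() defaulting to a LocalClient is an assumption about the existing factory, not part of this diff):

from letta import create_client

client = create_client()  # assumption: no-arg call returns a LocalClient; pass base_url=... for the REST client

# page through files registered under an existing source
files = client.list_files_from_source(source_id="source-placeholder", limit=100)
for f in files:
    print(f.file_name, f.file_type, f.file_size)

# clean up a finished load job
client.delete_job("job-placeholder")
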
letta/data_sources/connectors.py CHANGED
@@ -1,11 +1,15 @@
- from typing import Dict, Iterator, List, Optional, Tuple
+ from typing import Dict, Iterator, List, Tuple
 
  import typer
- from llama_index.core import Document as LlamaIndexDocument
 
  from letta.agent_store.storage import StorageConnector
+ from letta.data_sources.connectors_helper import (
+ assert_all_files_exist_locally,
+ extract_metadata_from_files,
+ get_filenames_in_dir,
+ )
  from letta.embeddings import embedding_model
- from letta.schemas.document import Document
+ from letta.schemas.file import FileMetadata
  from letta.schemas.passage import Passage
  from letta.schemas.source import Source
  from letta.utils import create_uuid_from_string
@@ -13,23 +17,23 @@ from letta.utils import create_uuid_from_string
 
  class DataConnector:
  """
- Base class for data connectors that can be extended to generate documents and passages from a custom data source.
+ Base class for data connectors that can be extended to generate files and passages from a custom data source.
  """
 
- def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
+ def find_files(self, source: Source) -> Iterator[FileMetadata]:
  """
- Generate document text and metadata from a data source.
+ Generate file metadata from a data source.
 
  Returns:
- documents (Iterator[Tuple[str, Dict]]): Generate a tuple of string text and metadata dictionary for each document.
+ files (Iterator[FileMetadata]): Generate file metadata for each file found.
  """
 
- def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
+ def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
  """
- Generate passage text and metadata from a list of documents.
+ Generate passage text and metadata from a list of files.
 
  Args:
- documents (List[Document]): List of documents to generate passages from.
+ file (FileMetadata): The document to generate passages from.
  chunk_size (int, optional): Chunk size for splitting passages. Defaults to 1024.
 
  Returns:
@@ -41,33 +45,25 @@ def load_data(
  connector: DataConnector,
  source: Source,
  passage_store: StorageConnector,
- document_store: Optional[StorageConnector] = None,
+ file_metadata_store: StorageConnector,
  ):
- """Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""
+ """Load data from a connector (generates file and passages) into a specified source_id, associatedw with a user_id."""
  embedding_config = source.embedding_config
 
  # embedding model
  embed_model = embedding_model(embedding_config)
 
- # insert passages/documents
+ # insert passages/file
  passages = []
  embedding_to_document_name = {}
  passage_count = 0
- document_count = 0
- for document_text, document_metadata in connector.generate_documents():
- # insert document into storage
- document = Document(
- text=document_text,
- metadata_=document_metadata,
- source_id=source.id,
- user_id=source.user_id,
- )
- document_count += 1
- if document_store:
- document_store.insert(document)
+ file_count = 0
+ for file_metadata in connector.find_files(source):
+ file_count += 1
+ file_metadata_store.insert(file_metadata)
 
  # generate passages
- for passage_text, passage_metadata in connector.generate_passages([document], chunk_size=embedding_config.embedding_chunk_size):
+ for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size):
  # for some reason, llama index parsers sometimes return empty strings
  if len(passage_text) == 0:
  typer.secho(
@@ -89,7 +85,7 @@ def load_data(
  passage = Passage(
  id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
  text=passage_text,
- doc_id=document.id,
+ file_id=file_metadata.id,
  source_id=source.id,
  metadata_=passage_metadata,
  user_id=source.user_id,
@@ -98,16 +94,16 @@ def load_data(
  )
 
  hashable_embedding = tuple(passage.embedding)
- document_name = document.metadata_.get("file_path", document.id)
+ file_name = file_metadata.file_name
  if hashable_embedding in embedding_to_document_name:
  typer.secho(
- f"Warning: Duplicate embedding found for passage in {document_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
+ f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
  fg=typer.colors.YELLOW,
  )
  continue
 
  passages.append(passage)
- embedding_to_document_name[hashable_embedding] = document_name
+ embedding_to_document_name[hashable_embedding] = file_name
  if len(passages) >= 100:
  # insert passages into passage store
  passage_store.insert_many(passages)
@@ -120,7 +116,7 @@ def load_data(
  passage_store.insert_many(passages)
  passage_count += len(passages)
 
- return passage_count, document_count
+ return passage_count, file_count
 
 
  class DirectoryConnector(DataConnector):
@@ -143,105 +139,109 @@ class DirectoryConnector(DataConnector):
  if self.recursive == True:
  assert self.input_directory is not None, "Must provide input directory if recursive is True."
 
- def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
- from llama_index.core import SimpleDirectoryReader
-
+ def find_files(self, source: Source) -> Iterator[FileMetadata]:
  if self.input_directory is not None:
- reader = SimpleDirectoryReader(
+ files = get_filenames_in_dir(
  input_dir=self.input_directory,
  recursive=self.recursive,
  required_exts=[ext.strip() for ext in str(self.extensions).split(",")],
+ exclude=["*png", "*jpg", "*jpeg"],
  )
  else:
- assert self.input_files is not None, "Must provide input files if input_dir is None"
- reader = SimpleDirectoryReader(input_files=[str(f) for f in self.input_files])
-
- llama_index_docs = reader.load_data(show_progress=True)
- for llama_index_doc in llama_index_docs:
- # TODO: add additional metadata?
- # doc = Document(text=llama_index_doc.text, metadata=llama_index_doc.metadata)
- # docs.append(doc)
- yield llama_index_doc.text, llama_index_doc.metadata
-
- def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
- # use llama index to run embeddings code
- # from llama_index.core.node_parser import SentenceSplitter
+ files = self.input_files
+
+ # Check that file paths are valid
+ assert_all_files_exist_locally(files)
+
+ for metadata in extract_metadata_from_files(files):
+ yield FileMetadata(
+ user_id=source.user_id,
+ source_id=source.id,
+ file_name=metadata.get("file_name"),
+ file_path=metadata.get("file_path"),
+ file_type=metadata.get("file_type"),
+ file_size=metadata.get("file_size"),
+ file_creation_date=metadata.get("file_creation_date"),
+ file_last_modified_date=metadata.get("file_last_modified_date"),
+ )
+
+ def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
+ from llama_index.core import SimpleDirectoryReader
  from llama_index.core.node_parser import TokenTextSplitter
 
  parser = TokenTextSplitter(chunk_size=chunk_size)
- for document in documents:
- llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata_)]
- nodes = parser.get_nodes_from_documents(llama_index_docs)
- for node in nodes:
- # passage = Passage(
- # text=node.text,
- # doc_id=document.id,
- # )
- yield node.text, None
-
-
- class WebConnector(DirectoryConnector):
- def __init__(self, urls: List[str] = None, html_to_text: bool = True):
- self.urls = urls
- self.html_to_text = html_to_text
-
- def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
- from llama_index.readers.web import SimpleWebPageReader
-
- documents = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls)
- for document in documents:
- yield document.text, {"url": document.id_}
-
-
- class VectorDBConnector(DataConnector):
- # NOTE: this class has not been properly tested, so is unlikely to work
- # TODO: allow loading multiple tables (1:1 mapping between Document and Table)
-
- def __init__(
- self,
- name: str,
- uri: str,
- table_name: str,
- text_column: str,
- embedding_column: str,
- embedding_dim: int,
- ):
- self.name = name
- self.uri = uri
- self.table_name = table_name
- self.text_column = text_column
- self.embedding_column = embedding_column
- self.embedding_dim = embedding_dim
-
- # connect to db table
- from sqlalchemy import create_engine
-
- self.engine = create_engine(uri)
-
- def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
- yield self.table_name, None
-
- def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
- from pgvector.sqlalchemy import Vector
- from sqlalchemy import Inspector, MetaData, Table, select
-
- metadata = MetaData()
- # Create an inspector to inspect the database
- inspector = Inspector.from_engine(self.engine)
- table_names = inspector.get_table_names()
- assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
-
- table = Table(self.table_name, metadata, autoload_with=self.engine)
-
- # Prepare a select statement
- select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
-
- # Execute the query and fetch the results
- # TODO: paginate results
- with self.engine.connect() as connection:
- result = connection.execute(select_statement).fetchall()
-
- for text, embedding in result:
- # assume that embeddings are the same model as in config
- # TODO: don't re-compute embedding
- yield text, {"embedding": embedding}
+ documents = SimpleDirectoryReader(input_files=[file.file_path]).load_data()
+ nodes = parser.get_nodes_from_documents(documents)
+ for node in nodes:
+ yield node.text, None
+
+
+ """
+ The below isn't used anywhere, it isn't tested, and pretty much should be deleted.
+ - Matt
+ """
+ # class WebConnector(DirectoryConnector):
+ # def __init__(self, urls: List[str] = None, html_to_text: bool = True):
+ # self.urls = urls
+ # self.html_to_text = html_to_text
+ #
+ # def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
+ # from llama_index.readers.web import SimpleWebPageReader
+ #
+ # files = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls)
+ # for document in files:
+ # yield document.text, {"url": document.id_}
+ #
+ #
+ # class VectorDBConnector(DataConnector):
+ # # NOTE: this class has not been properly tested, so is unlikely to work
+ # # TODO: allow loading multiple tables (1:1 mapping between FileMetadata and Table)
+ #
+ # def __init__(
+ # self,
+ # name: str,
+ # uri: str,
+ # table_name: str,
+ # text_column: str,
+ # embedding_column: str,
+ # embedding_dim: int,
+ # ):
+ # self.name = name
+ # self.uri = uri
+ # self.table_name = table_name
+ # self.text_column = text_column
+ # self.embedding_column = embedding_column
+ # self.embedding_dim = embedding_dim
+ #
+ # # connect to db table
+ # from sqlalchemy import create_engine
+ #
+ # self.engine = create_engine(uri)
+ #
+ # def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
+ # yield self.table_name, None
+ #
+ # def generate_passages(self, file_text: str, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
+ # from pgvector.sqlalchemy import Vector
+ # from sqlalchemy import Inspector, MetaData, Table, select
+ #
+ # metadata = MetaData()
+ # # Create an inspector to inspect the database
+ # inspector = Inspector.from_engine(self.engine)
+ # table_names = inspector.get_table_names()
+ # assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
+ #
+ # table = Table(self.table_name, metadata, autoload_with=self.engine)
+ #
+ # # Prepare a select statement
+ # select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
+ #
+ # # Execute the query and fetch the results
+ # # TODO: paginate results
+ # with self.engine.connect() as connection:
+ # result = connection.execute(select_statement).fetchall()
+ #
+ # for text, embedding in result:
+ # # assume that embeddings are the same model as in config
+ # # TODO: don't re-compute embedding
+ # yield text, {"embedding": embedding}
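
The DataConnector interface is the main extension point touched by this release: find_files() now yields FileMetadata objects (instead of generate_documents() yielding (text, metadata) tuples), and generate_passages() receives a single FileMetadata. A rough sketch of a custom connector against the new interface (the in-memory data is illustrative, and it assumes FileMetadata's file_* fields other than file_name are optional):

from typing import Dict, Iterator, Tuple

from letta.data_sources.connectors import DataConnector
from letta.schemas.file import FileMetadata
from letta.schemas.source import Source


class InMemoryConnector(DataConnector):
    """Illustrative connector that serves a dict of {name: text} as its 'files'."""

    def __init__(self, texts: Dict[str, str]):
        self.texts = texts

    def find_files(self, source: Source) -> Iterator[FileMetadata]:
        # one FileMetadata per logical file, tied to the source's user and source ids
        for name in self.texts:
            yield FileMetadata(user_id=source.user_id, source_id=source.id, file_name=name)

    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
        # naive fixed-size chunking; a real connector can use any splitter
        text = self.texts[file.file_name]
        for start in range(0, len(text), chunk_size):
            yield text[start : start + chunk_size], {"file_name": file.file_name}

load_data(connector, source, passage_store, file_metadata_store) then drives both methods exactly as in the hunks above.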