letta-nightly 0.6.3.dev20241211050151__py3-none-any.whl → 0.6.3.dev20241212015858__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



--- a/letta/agent_store/milvus.py
+++ /dev/null
@@ -1,198 +0,0 @@
- import uuid
- from copy import deepcopy
- from typing import Dict, Iterator, List, Optional, cast
-
- from pymilvus import DataType, MilvusClient
- from pymilvus.client.constants import ConsistencyLevel
-
- from letta.agent_store.storage import StorageConnector, TableType
- from letta.config import LettaConfig
- from letta.constants import MAX_EMBEDDING_DIM
- from letta.data_types import Passage, Record, RecordType
- from letta.utils import datetime_to_timestamp, printd, timestamp_to_datetime
-
-
- class MilvusStorageConnector(StorageConnector):
-     """Storage via Milvus"""
-
-     def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
-         super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
-
-         assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Milvus only supports archival memory"
-         if config.archival_storage_uri:
-             self.client = MilvusClient(uri=config.archival_storage_uri)
-             self._create_collection()
-         else:
-             raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")
-
-         # UUID fields need to be converted to strings before insertion
-         self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
-
-     def _create_collection(self):
-         schema = MilvusClient.create_schema(
-             auto_id=False,
-             enable_dynamic_field=True,
-         )
-         schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=65_535)
-         schema.add_field(field_name="text", datatype=DataType.VARCHAR, is_primary=False, max_length=65_535)
-         schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=MAX_EMBEDDING_DIM)
-         index_params = self.client.prepare_index_params()
-         index_params.add_index(field_name="id")
-         index_params.add_index(field_name="embedding", index_type="AUTOINDEX", metric_type="IP")
-         self.client.create_collection(
-             collection_name=self.table_name, schema=schema, index_params=index_params, consistency_level=ConsistencyLevel.Strong
-         )
-
-     def get_milvus_filter(self, filters: Optional[Dict] = {}) -> str:
-         filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
-         if not filter_conditions:
-             return ""
-         conditions = []
-         for key, value in filter_conditions.items():
-             # quote UUID and string values; leave numeric values bare
-             if key in self.uuid_fields or isinstance(value, str):
-                 condition = f'({key} == "{value}")'
-             else:
-                 condition = f"({key} == {value})"
-             conditions.append(condition)
-         filter_expr = " and ".join(conditions)
-         if len(conditions) == 1:
-             # a single condition needs no outer parentheses
-             filter_expr = filter_expr[1:-1]
-         return filter_expr
-
-     def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000) -> Iterator[List[RecordType]]:
-         if not self.client.has_collection(collection_name=self.table_name):
-             yield []
-             return
-         filter_expr = self.get_milvus_filter(filters)
-         offset = 0
-         while True:
-             # Retrieve a chunk of records with the given page_size
-             query_res = self.client.query(
-                 collection_name=self.table_name,
-                 filter=filter_expr,
-                 offset=offset,
-                 limit=page_size,
-             )
-             if not query_res:
-                 break
-             # Yield a list of Record objects converted from the chunk
-             yield self._list_to_records(query_res)
-
-             # Increment the offset to get the next chunk in the next iteration
-             offset += page_size
-
-     def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
-         if not self.client.has_collection(collection_name=self.table_name):
-             return []
-         filter_expr = self.get_milvus_filter(filters)
-         query_res = self.client.query(
-             collection_name=self.table_name,
-             filter=filter_expr,
-             limit=limit,
-         )
-         return self._list_to_records(query_res)
-
-     def get(self, id: str) -> Optional[RecordType]:
-         res = self.client.get(collection_name=self.table_name, ids=str(id))
-         return self._list_to_records(res)[0] if res else None
-
-     def size(self, filters: Optional[Dict] = {}) -> int:
-         if not self.client.has_collection(collection_name=self.table_name):
-             return 0
-         filter_expr = self.get_milvus_filter(filters)
-         count_expr = "count(*)"
-         query_res = self.client.query(
-             collection_name=self.table_name,
-             filter=filter_expr,
-             output_fields=[count_expr],
-         )
-         doc_num = query_res[0][count_expr]
-         return doc_num
-
-     def insert(self, record: RecordType):
-         self.insert_many([record])
-
-     def insert_many(self, records: List[RecordType], show_progress=False):
-         if not records:
-             return
-
-         # Milvus Lite currently does not support upsert, so we delete and insert instead
-         # self.client.upsert(collection_name=self.table_name, data=self._records_to_list(records))
-         ids = [str(record.id) for record in records]
-         self.client.delete(collection_name=self.table_name, ids=ids)
-         data = self._records_to_list(records)
-         self.client.insert(collection_name=self.table_name, data=data)
-
-     def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
-         if not self.client.has_collection(self.table_name):
-             return []
-         search_res = self.client.search(
-             collection_name=self.table_name, data=[query_vec], filter=self.get_milvus_filter(filters), limit=top_k, output_fields=["*"]
-         )[0]
-         entity_res = [res["entity"] for res in search_res]
-         return self._list_to_records(entity_res)
-
-     def delete_table(self):
-         self.client.drop_collection(collection_name=self.table_name)
-
-     def delete(self, filters: Optional[Dict] = {}):
-         if not self.client.has_collection(collection_name=self.table_name):
-             return
-         filter_expr = self.get_milvus_filter(filters)
-         self.client.delete(collection_name=self.table_name, filter=filter_expr)
-
-     def save(self):
-         # nothing needs to be persisted; Milvus handles durability itself
-         printd("Saving milvus")
-
-     def _records_to_list(self, records: List[Record]) -> List[Dict]:
-         if records == []:
-             return []
-         assert all(isinstance(r, Passage) for r in records)
-         record_list = []
-         records = list(set(records))  # deduplicate
-         for record in records:
-             record_vars = deepcopy(vars(record))
-             _id = record_vars.pop("id")
-             text = record_vars.pop("text", "")
-             embedding = record_vars.pop("embedding")
-             record_metadata = record_vars.pop("metadata_", None) or {}
-             if "created_at" in record_vars:
-                 record_vars["created_at"] = datetime_to_timestamp(record_vars["created_at"])
-             record_dict = {key: value for key, value in record_vars.items() if value is not None}
-             record_dict = {
-                 **record_dict,
-                 **record_metadata,
-                 "id": str(_id),
-                 "text": text,
-                 "embedding": embedding,
-             }
-             for key, value in record_dict.items():
-                 if key in self.uuid_fields:
-                     record_dict[key] = str(value)
-             record_list.append(record_dict)
-         return record_list
-
-     def _list_to_records(self, query_res: List[Dict]) -> List[RecordType]:
-         records = []
-         for res_dict in query_res:
-             _id = res_dict.pop("id")
-             embedding = res_dict.pop("embedding")
-             text = res_dict.pop("text")
-             metadata = deepcopy(res_dict)
-             for key, value in metadata.items():
-                 if key in self.uuid_fields:
-                     metadata[key] = uuid.UUID(value)
-                 elif key == "created_at":
-                     metadata[key] = timestamp_to_datetime(value)
-             records.append(
-                 cast(
-                     RecordType,
-                     self.type(
-                         text=text,
-                         embedding=embedding,
-                         id=uuid.UUID(_id),
-                         **metadata,
-                     ),
-                 )
-             )
-         return records
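
The removed connector builds Milvus boolean filter expressions by quoting UUID and string values and joining the conditions with `and`. A minimal standalone sketch of that logic (the `uuid_fields` list mirrors the connector's; the filter values are made-up examples):

# Sketch of the filter-building logic from the removed Milvus connector.
# `uuid_fields` mirrors the connector's list; the example values are invented.
uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]

def build_milvus_filter(filters: dict) -> str:
    conditions = []
    for key, value in filters.items():
        if key in uuid_fields or isinstance(value, str):
            conditions.append(f'({key} == "{value}")')  # quote UUID/string values
        else:
            conditions.append(f"({key} == {value})")  # numeric values stay bare
    expr = " and ".join(conditions)
    # a single condition needs no outer parentheses
    return expr[1:-1] if len(conditions) == 1 else expr

# e.g. {"user_id": "u-123", "agent_id": "a-456"} ->
#   (user_id == "u-123") and (agent_id == "a-456")
print(build_milvus_filter({"user_id": "u-123", "agent_id": "a-456"}))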
--- a/letta/agent_store/qdrant.py
+++ /dev/null
@@ -1,201 +0,0 @@
- import os
- import uuid
- from copy import deepcopy
- from typing import Dict, Iterator, List, Optional, cast
-
- from letta.agent_store.storage import StorageConnector, TableType
- from letta.config import LettaConfig
- from letta.constants import MAX_EMBEDDING_DIM
- from letta.data_types import Passage, Record, RecordType
- from letta.utils import datetime_to_timestamp, timestamp_to_datetime
-
- TEXT_PAYLOAD_KEY = "text_content"
- METADATA_PAYLOAD_KEY = "metadata"
-
-
- class QdrantStorageConnector(StorageConnector):
-     """Storage via Qdrant"""
-
-     def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
-         super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
-         try:
-             from qdrant_client import QdrantClient, models
-         except ImportError as e:
-             raise ImportError("'qdrant-client' not installed. Run `pip install qdrant-client`.") from e
-         assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Qdrant only supports archival memory"
-         if config.archival_storage_uri and len(config.archival_storage_uri.split(":")) == 2:
-             host, port = config.archival_storage_uri.split(":")
-             self.qdrant_client = QdrantClient(host=host, port=int(port), api_key=os.getenv("QDRANT_API_KEY"))
-         elif config.archival_storage_path:
-             self.qdrant_client = QdrantClient(path=config.archival_storage_path)
-         else:
-             raise ValueError("Qdrant storage requires either a URI or a path to the storage configured")
-         if not self.qdrant_client.collection_exists(self.table_name):
-             self.qdrant_client.create_collection(
-                 collection_name=self.table_name,
-                 vectors_config=models.VectorParams(
-                     size=MAX_EMBEDDING_DIM,
-                     distance=models.Distance.COSINE,
-                 ),
-             )
-         # UUID fields are stored as strings in the payload
-         self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
-
-     def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
-         from qdrant_client import grpc
-
-         filters = self.get_qdrant_filters(filters)
-         next_offset = None
-         stop_scrolling = False
-         while not stop_scrolling:
-             results, next_offset = self.qdrant_client.scroll(
-                 collection_name=self.table_name,
-                 scroll_filter=filters,
-                 limit=page_size,
-                 offset=next_offset,
-                 with_payload=True,
-                 with_vectors=True,
-             )
-             stop_scrolling = next_offset is None or (
-                 isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == ""
-             )
-             yield self.to_records(results)
-
-     def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]:
-         if self.size(filters) == 0:
-             return []
-         filters = self.get_qdrant_filters(filters)
-         results, _ = self.qdrant_client.scroll(
-             self.table_name,
-             scroll_filter=filters,
-             limit=limit,
-             with_payload=True,
-             with_vectors=True,
-         )
-         return self.to_records(results)
-
-     def get(self, id: str) -> Optional[RecordType]:
-         results = self.qdrant_client.retrieve(
-             collection_name=self.table_name,
-             ids=[str(id)],
-             with_payload=True,
-             with_vectors=True,
-         )
-         if not results:
-             return None
-         return self.to_records(results)[0]
-
-     def insert(self, record: Record):
-         points = self.to_points([record])
-         self.qdrant_client.upsert(self.table_name, points=points)
-
-     def insert_many(self, records: List[RecordType], show_progress=False):
-         points = self.to_points(records)
-         self.qdrant_client.upsert(self.table_name, points=points)
-
-     def delete(self, filters: Optional[Dict] = {}):
-         filters = self.get_qdrant_filters(filters)
-         self.qdrant_client.delete(self.table_name, points_selector=filters)
-
-     def delete_table(self):
-         self.qdrant_client.delete_collection(self.table_name)
-         self.qdrant_client.close()
-
-     def size(self, filters: Optional[Dict] = {}) -> int:
-         filters = self.get_qdrant_filters(filters)
-         return self.qdrant_client.count(collection_name=self.table_name, count_filter=filters).count
-
-     def close(self):
-         self.qdrant_client.close()
-
-     def query(
-         self,
-         query: str,
-         query_vec: List[float],
-         top_k: int = 10,
-         filters: Optional[Dict] = {},
-     ) -> List[RecordType]:
-         filters = self.get_qdrant_filters(filters)
-         results = self.qdrant_client.search(
-             self.table_name,
-             query_vector=query_vec,
-             query_filter=filters,
-             limit=top_k,
-             with_payload=True,
-             with_vectors=True,
-         )
-         return self.to_records(results)
-
-     def to_records(self, records: list) -> List[RecordType]:
-         parsed_records = []
-         for record in records:
-             record = deepcopy(record)
-             metadata = record.payload[METADATA_PAYLOAD_KEY]
-             text = record.payload[TEXT_PAYLOAD_KEY]
-             _id = metadata.pop("id")
-             embedding = record.vector
-             for key, value in metadata.items():
-                 if key in self.uuid_fields:
-                     metadata[key] = uuid.UUID(value)
-                 elif key == "created_at":
-                     metadata[key] = timestamp_to_datetime(value)
-             parsed_records.append(
-                 cast(
-                     RecordType,
-                     self.type(
-                         text=text,
-                         embedding=embedding,
-                         id=uuid.UUID(_id),
-                         **metadata,
-                     ),
-                 )
-             )
-         return parsed_records
-
-     def to_points(self, records: List[RecordType]):
-         from qdrant_client import models
-
-         assert all(isinstance(r, Passage) for r in records)
-         points = []
-         records = list(set(records))  # deduplicate
-         for record in records:
-             # copy so popping fields does not mutate the original record
-             record = deepcopy(vars(record))
-             _id = record.pop("id")
-             text = record.pop("text", "")
-             embedding = record.pop("embedding", {})
-             record_metadata = record.pop("metadata_", None) or {}
-             if "created_at" in record:
-                 record["created_at"] = datetime_to_timestamp(record["created_at"])
-             metadata = {key: value for key, value in record.items() if value is not None}
-             metadata = {
-                 **metadata,
-                 **record_metadata,
-                 "id": str(_id),
-             }
-             for key, value in metadata.items():
-                 if key in self.uuid_fields:
-                     metadata[key] = str(value)
-             points.append(
-                 models.PointStruct(
-                     id=str(_id),
-                     vector=embedding,
-                     payload={
-                         TEXT_PAYLOAD_KEY: text,
-                         METADATA_PAYLOAD_KEY: metadata,
-                     },
-                 )
-             )
-         return points
-
-     def get_qdrant_filters(self, filters: Optional[Dict] = {}):
-         from qdrant_client import models
-
-         filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
-         must_conditions = []
-         for key, value in filter_conditions.items():
-             match_value = str(value) if key in self.uuid_fields else value
-             field_condition = models.FieldCondition(
-                 key=f"{METADATA_PAYLOAD_KEY}.{key}",
-                 match=models.MatchValue(value=match_value),
-             )
-             must_conditions.append(field_condition)
-         return models.Filter(must=must_conditions)
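
The Qdrant connector stores each record's text under the `text_content` payload key and everything else under `metadata`, so its filters match on nested `metadata.<field>` keys. A hedged sketch of that filter construction (requires `qdrant-client`; the filter values are illustrative):

# Sketch of the payload-scoped filter the removed Qdrant connector built.
# Requires `pip install qdrant-client`; the example values are invented.
from qdrant_client import models

METADATA_PAYLOAD_KEY = "metadata"
uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]

def build_qdrant_filter(filter_conditions: dict) -> models.Filter:
    must = []
    for key, value in filter_conditions.items():
        match_value = str(value) if key in uuid_fields else value
        must.append(
            models.FieldCondition(
                key=f"{METADATA_PAYLOAD_KEY}.{key}",  # fields are nested under "metadata"
                match=models.MatchValue(value=match_value),
            )
        )
    return models.Filter(must=must)

# e.g. match points whose payload has metadata.user_id == "u-123"
print(build_qdrant_filter({"user_id": "u-123"}))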
--- a/letta/agent_store/storage.py
+++ /dev/null
@@ -1,186 +0,0 @@
- """These classes define storage connectors.
-
- We originally tried to use Llama Index VectorIndex, but its limited API was extremely problematic.
- """
-
- import uuid
- from abc import abstractmethod
- from typing import Dict, List, Optional, Tuple, Type, Union
-
- from pydantic import BaseModel
-
- from letta.config import LettaConfig
- from letta.schemas.file import FileMetadata
- from letta.schemas.message import Message
- from letta.schemas.passage import Passage
- from letta.utils import printd
-
-
- # enum-like class representing table types in Letta
- # each table corresponds to a different table schema (specified in data_types.py)
- class TableType:
-     ARCHIVAL_MEMORY = "archival_memory"  # archival memory table: letta_archival_memory_agent
-     RECALL_MEMORY = "recall_memory"  # recall memory table: letta_recall_memory_agent
-     PASSAGES = "passages"  # TODO
-     FILES = "files"
-
-
- # table names used by Letta
-
- # agent tables
- RECALL_TABLE_NAME = "letta_recall_memory_agent"  # agent memory
- ARCHIVAL_TABLE_NAME = "letta_archival_memory_agent"  # agent memory
-
- # external data source tables
- PASSAGE_TABLE_NAME = "letta_passages"  # chunked/embedded passages (from source)
- FILE_TABLE_NAME = "letta_files"  # original files (from source)
-
-
- class StorageConnector:
-     """Defines a user-specific DB connection for accessing data: files, Passages, Archival/Recall Memory"""
-
-     type: Type[BaseModel]
-
-     def __init__(
-         self,
-         table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
-         config: LettaConfig,
-         user_id: str,
-         agent_id: Optional[str] = None,
-         organization_id: Optional[str] = None,
-     ):
-         self.user_id = user_id
-         self.agent_id = agent_id
-         self.organization_id = organization_id
-         self.table_type = table_type
-
-         # get object type
-         if table_type == TableType.ARCHIVAL_MEMORY:
-             self.type = Passage
-             self.table_name = ARCHIVAL_TABLE_NAME
-         elif table_type == TableType.RECALL_MEMORY:
-             self.type = Message
-             self.table_name = RECALL_TABLE_NAME
-         elif table_type == TableType.FILES:
-             self.type = FileMetadata
-             self.table_name = FILE_TABLE_NAME
-         elif table_type == TableType.PASSAGES:
-             self.type = Passage
-             self.table_name = PASSAGE_TABLE_NAME
-         else:
-             raise ValueError(f"Table type {table_type} not implemented")
-         printd(f"Using table name {self.table_name}")
-
-         # setup base filters for agent-specific tables
-         if self.table_type == TableType.ARCHIVAL_MEMORY or self.table_type == TableType.RECALL_MEMORY:
-             # agent-specific table
-             assert agent_id is not None, "Agent ID must be provided for agent-specific tables"
-             self.filters = {"user_id": self.user_id, "agent_id": self.agent_id}
-         elif self.table_type == TableType.FILES:
-             # setup base filters for user-specific tables
-             assert agent_id is None, "Agent ID must not be provided for user-specific tables"
-             self.filters = {"user_id": self.user_id}
-         elif self.table_type == TableType.PASSAGES:
-             self.filters = {"organization_id": self.organization_id}
-         else:
-             raise ValueError(f"Table type {table_type} not implemented")
-
-     @staticmethod
-     def get_storage_connector(
-         table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
-         config: LettaConfig,
-         user_id: str,
-         organization_id: Optional[str] = None,
-         agent_id: Optional[str] = None,
-     ):
-         if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
-             storage_type = config.archival_storage_type
-         elif table_type == TableType.RECALL_MEMORY:
-             storage_type = config.recall_storage_type
-         elif table_type == TableType.FILES:
-             storage_type = config.metadata_storage_type
-         else:
-             raise ValueError(f"Table type {table_type} not implemented")
-
-         if storage_type == "postgres":
-             from letta.agent_store.db import PostgresStorageConnector
-
-             return PostgresStorageConnector(table_type, config, user_id, agent_id)
-
-         elif storage_type == "qdrant":
-             from letta.agent_store.qdrant import QdrantStorageConnector
-
-             return QdrantStorageConnector(table_type, config, user_id, agent_id)
-
-         elif storage_type == "sqlite":
-             from letta.agent_store.db import SQLLiteStorageConnector
-
-             return SQLLiteStorageConnector(table_type, config, user_id, agent_id)
-         elif storage_type == "milvus":
-             from letta.agent_store.milvus import MilvusStorageConnector
-
-             return MilvusStorageConnector(table_type, config, user_id, agent_id)
-         else:
-             raise NotImplementedError(f"Storage type {storage_type} not implemented")
-
-     @staticmethod
-     def get_archival_storage_connector(user_id, agent_id):
-         config = LettaConfig.load()
-         return StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config, user_id, agent_id=agent_id)
-
-     @staticmethod
-     def get_recall_storage_connector(user_id, agent_id):
-         config = LettaConfig.load()
-         return StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, config, user_id, agent_id=agent_id)
-
-     @abstractmethod
-     def get_filters(self, filters: Optional[Dict] = {}) -> Union[Tuple[list, dict], dict]:
-         pass
-
-     @abstractmethod
-     def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000):
-         pass
-
-     @abstractmethod
-     def get_all(self, filters: Optional[Dict] = {}, limit=10):
-         pass
-
-     @abstractmethod
-     def get(self, id: uuid.UUID):
-         pass
-
-     @abstractmethod
-     def size(self, filters: Optional[Dict] = {}) -> int:
-         pass
-
-     @abstractmethod
-     def insert(self, record):
-         pass
-
-     @abstractmethod
-     def insert_many(self, records, show_progress=False):
-         pass
-
-     @abstractmethod
-     def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
-         pass
-
-     @abstractmethod
-     def query_date(self, start_date, end_date):
-         pass
-
-     @abstractmethod
-     def query_text(self, query):
-         pass
-
-     @abstractmethod
-     def delete_table(self):
-         pass
-
-     @abstractmethod
-     def delete(self, filters: Optional[Dict] = {}):
-         pass
-
-     @abstractmethod
-     def save(self):
-         pass
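
Callers obtained connectors through the factory above, which dispatches on the configured storage backend. A sketch of the removed usage pattern (this API is no longer present after this release; the IDs are hypothetical):

# Sketch of how the removed factory was used. Assumes a loaded letta config
# whose archival_storage_type is "postgres", "qdrant", "sqlite", or "milvus".
# The user/agent IDs below are hypothetical.
from letta.agent_store.storage import StorageConnector, TableType
from letta.config import LettaConfig

config = LettaConfig.load()
archival = StorageConnector.get_storage_connector(
    TableType.ARCHIVAL_MEMORY,  # dispatches on config.archival_storage_type
    config,
    user_id="user-123",
    agent_id="agent-456",
)
print(archival.size())  # count of archival passages for this user/agent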