letta-nightly 0.6.3.dev20241211104238__py3-none-any.whl → 0.6.3.dev20241212104231__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic; see the package's registry page for more details.

@@ -1,201 +0,0 @@
1
- import os
2
- import uuid
3
- from copy import deepcopy
4
- from typing import Dict, Iterator, List, Optional, cast
5
-
6
- from letta.agent_store.storage import StorageConnector, TableType
7
- from letta.config import LettaConfig
8
- from letta.constants import MAX_EMBEDDING_DIM
9
- from letta.data_types import Passage, Record, RecordType
10
- from letta.utils import datetime_to_timestamp, timestamp_to_datetime
11
-
12
# Keys of the Qdrant point payload: the record's text is stored under the first,
# its remaining (non-None) fields under the second.
TEXT_PAYLOAD_KEY: str = "text_content"
METADATA_PAYLOAD_KEY: str = "metadata"
14
-
15
-
16
class QdrantStorageConnector(StorageConnector):
    """Storage via Qdrant.

    Supports only archival-memory and passage tables. Each record is stored as one
    Qdrant point: the point id is the record id, the vector is the record's
    embedding, and the payload holds the record text under ``TEXT_PAYLOAD_KEY`` and
    the record's remaining non-None fields under ``METADATA_PAYLOAD_KEY``.
    """

    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        """Connect to Qdrant and create the backing collection if it is missing.

        Args:
            table_type: ``TableType.ARCHIVAL_MEMORY`` or ``TableType.PASSAGES``.
            config: Letta config. A ``host:port`` value in ``archival_storage_uri``
                selects a remote server (API key read from ``QDRANT_API_KEY``);
                otherwise ``archival_storage_path`` selects on-disk storage.
            user_id: Owner of the records; forwarded to ``StorageConnector``.
            agent_id: Agent owning the records; forwarded to ``StorageConnector``.

        Raises:
            ImportError: If ``qdrant-client`` is not installed.
            ValueError: If neither a ``host:port`` URI nor a path is configured,
                or if the URI's port is not an integer.
        """
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
        try:
            from qdrant_client import QdrantClient, models
        except ImportError as e:
            raise ImportError("'qdrant-client' not installed. Run `pip install qdrant-client`.") from e
        assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Qdrant only supports archival memory"

        uri_parts = config.archival_storage_uri.split(":") if config.archival_storage_uri else []
        if len(uri_parts) == 2:
            host, port = uri_parts
            # QdrantClient's ``port`` parameter is an int; the split above yields a str.
            self.qdrant_client = QdrantClient(host=host, port=int(port), api_key=os.getenv("QDRANT_API_KEY"))
        elif config.archival_storage_path:
            self.qdrant_client = QdrantClient(path=config.archival_storage_path)
        else:
            raise ValueError("Qdrant storage requires either a URI or a path to the storage configured")

        if not self.qdrant_client.collection_exists(self.table_name):
            # Vectors are sized to MAX_EMBEDDING_DIM and compared by cosine distance.
            self.qdrant_client.create_collection(
                collection_name=self.table_name,
                vectors_config=models.VectorParams(
                    size=MAX_EMBEDDING_DIM,
                    distance=models.Distance.COSINE,
                ),
            )
        # Record fields that are stored as strings in the payload but exposed as uuid.UUID.
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]

    def get_all_paginated(self, filters: Optional[Dict] = None, page_size: int = 10) -> Iterator[List[RecordType]]:
        """Yield all matching records, one page (at most ``page_size`` records) at a time.

        Args:
            filters: Extra metadata equality filters, merged over ``self.filters``.
            page_size: Maximum number of points fetched per ``scroll`` call.
        """
        from qdrant_client import grpc

        qdrant_filter = self.get_qdrant_filters(filters)
        next_offset = None
        while True:
            results, next_offset = self.qdrant_client.scroll(
                collection_name=self.table_name,
                scroll_filter=qdrant_filter,
                limit=page_size,
                offset=next_offset,
                with_payload=True,
                with_vectors=True,
            )
            yield self.to_records(results)
            # A missing offset ends the scroll; a gRPC PointId with num == 0 and an
            # empty uuid is treated as the end as well.
            if next_offset is None or (
                isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == ""
            ):
                break

    def get_all(self, filters: Optional[Dict] = None, limit=10) -> List[RecordType]:
        """Return up to ``limit`` matching records (a single scroll page)."""
        if self.size(filters) == 0:
            return []
        qdrant_filter = self.get_qdrant_filters(filters)
        results, _ = self.qdrant_client.scroll(
            self.table_name,
            scroll_filter=qdrant_filter,
            limit=limit,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def get(self, id: str) -> Optional[RecordType]:
        """Return the record whose point id is ``str(id)``, or None if it does not exist.

        Note: this lookup is by id only; ``self.filters`` is not applied.
        """
        results = self.qdrant_client.retrieve(
            collection_name=self.table_name,
            ids=[str(id)],
            with_payload=True,
            with_vectors=True,
        )
        if not results:
            return None
        return self.to_records(results)[0]

    def insert(self, record: Record):
        """Upsert a single record (an existing point with the same id is replaced)."""
        self.qdrant_client.upsert(self.table_name, points=self.to_points([record]))

    def insert_many(self, records: List[RecordType], show_progress=False):
        """Upsert many records in one call; ``show_progress`` is accepted but unused."""
        self.qdrant_client.upsert(self.table_name, points=self.to_points(records))

    def delete(self, filters: Optional[Dict] = None):
        """Delete every point matching ``filters`` merged over ``self.filters``."""
        qdrant_filter = self.get_qdrant_filters(filters)
        self.qdrant_client.delete(self.table_name, points_selector=qdrant_filter)

    def delete_table(self):
        """Drop the whole collection, then close the client."""
        self.qdrant_client.delete_collection(self.table_name)
        self.qdrant_client.close()

    def size(self, filters: Optional[Dict] = None) -> int:
        """Return the number of points matching ``filters`` merged over ``self.filters``."""
        qdrant_filter = self.get_qdrant_filters(filters)
        return self.qdrant_client.count(collection_name=self.table_name, count_filter=qdrant_filter).count

    def close(self):
        """Close the underlying Qdrant client."""
        self.qdrant_client.close()

    def query(
        self,
        query: str,
        query_vec: List[float],
        top_k: int = 10,
        filters: Optional[Dict] = None,
    ) -> List[RecordType]:
        """Return the ``top_k`` records nearest to ``query_vec`` in this connector's scope.

        Args:
            query: Raw query text; not used by this backend (search is by vector only).
            query_vec: Query embedding.
            top_k: Maximum number of results.
            filters: Extra metadata equality filters, merged over ``self.filters``.
        """
        # Build the filter with the Qdrant-specific helper. The base class's
        # get_filters is an unimplemented stub that returns None, which previously
        # made this search ignore the user/agent scoping in self.filters.
        qdrant_filter = self.get_qdrant_filters(filters)
        results = self.qdrant_client.search(
            self.table_name,
            query_vector=query_vec,
            query_filter=qdrant_filter,
            limit=top_k,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def to_records(self, records: list) -> List[RecordType]:
        """Convert Qdrant points back into ``self.type`` records.

        Each point is deep-copied first, so the caller's point objects are left
        untouched. UUID fields in the metadata are parsed back into ``uuid.UUID``,
        and ``created_at`` is converted back into a datetime.
        """
        parsed_records = []
        for point in records:
            point = deepcopy(point)
            metadata = point.payload[METADATA_PAYLOAD_KEY]
            text = point.payload[TEXT_PAYLOAD_KEY]
            _id = metadata.pop("id")
            decoded = {}
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    value = uuid.UUID(value)
                elif key == "created_at":
                    value = timestamp_to_datetime(value)
                decoded[key] = value
            parsed_records.append(
                cast(
                    RecordType,
                    self.type(
                        text=text,
                        embedding=point.vector,
                        id=uuid.UUID(_id),
                        **decoded,
                    ),
                )
            )
        return parsed_records

    def to_points(self, records: List[RecordType]):
        """Convert ``Passage`` records into Qdrant ``PointStruct`` objects.

        Duplicate records, by the records' own hash/equality (as with ``set``), are
        dropped while keeping input order. The caller's record objects are not modified.
        """
        from qdrant_client import models

        assert all(isinstance(r, Passage) for r in records)
        points = []
        for record in dict.fromkeys(records):
            # vars() returns the object's live __dict__. Copy it so the pops below do
            # not strip attributes from (or rewrite created_at on) the caller's records.
            fields = dict(vars(record))
            _id = fields.pop("id")
            text = fields.pop("text", "")
            embedding = fields.pop("embedding", {})
            record_metadata = fields.pop("metadata_", None) or {}
            if "created_at" in fields:
                fields["created_at"] = datetime_to_timestamp(fields["created_at"])
            # None-valued fields are omitted. User metadata overrides record fields,
            # and "id" always wins.
            metadata = {
                **{key: value for key, value in fields.items() if value is not None},
                **record_metadata,
                "id": str(_id),
            }
            # UUID fields are stored as strings so filters can match them as strings.
            metadata = {key: (str(value) if key in self.uuid_fields else value) for key, value in metadata.items()}
            points.append(
                models.PointStruct(
                    id=str(_id),
                    vector=embedding,
                    payload={
                        TEXT_PAYLOAD_KEY: text,
                        METADATA_PAYLOAD_KEY: metadata,
                    },
                )
            )
        return points

    def get_qdrant_filters(self, filters: Optional[Dict] = None):
        """Build a Qdrant ``Filter`` requiring every key/value pair to match exactly.

        ``filters`` is merged over the connector's base ``self.filters``; caller
        values win on key collisions. Keys are matched inside the metadata payload.
        UUID fields are compared as strings, matching how ``to_points`` stores them.
        """
        from qdrant_client import models

        filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
        must_conditions = [
            models.FieldCondition(
                key=f"{METADATA_PAYLOAD_KEY}.{key}",
                match=models.MatchValue(value=str(value) if key in self.uuid_fields else value),
            )
            for key, value in filter_conditions.items()
        ]
        return models.Filter(must=must_conditions)
@@ -1,186 +0,0 @@
1
- """ These classes define storage connectors.
2
-
3
- We originally tried to use Llama Index VectorIndex, but their limited API was extremely problematic.
4
- """
5
-
6
- import uuid
7
- from abc import abstractmethod
8
- from typing import Dict, List, Optional, Tuple, Type, Union
9
-
10
- from pydantic import BaseModel
11
-
12
- from letta.config import LettaConfig
13
- from letta.schemas.file import FileMetadata
14
- from letta.schemas.message import Message
15
- from letta.schemas.passage import Passage
16
- from letta.utils import printd
17
-
18
-
19
# Identifiers for the kinds of tables Letta stores. StorageConnector maps each one
# to a record type and a concrete table name.
class TableType:
    """String constants naming Letta's storage table types."""

    ARCHIVAL_MEMORY = "archival_memory"  # archival memory table (ARCHIVAL_TABLE_NAME)
    RECALL_MEMORY = "recall_memory"  # recall memory table (RECALL_TABLE_NAME)
    PASSAGES = "passages"  # chunked/embedded passages from data sources (PASSAGE_TABLE_NAME)
    FILES = "files"  # original files from data sources (FILE_TABLE_NAME)
26
-
27
-
28
# Concrete table/collection names used by Letta's storage connectors.

# Tables holding per-agent memory.
RECALL_TABLE_NAME: str = "letta_recall_memory_agent"  # recall memory
ARCHIVAL_TABLE_NAME: str = "letta_archival_memory_agent"  # archival memory

# Tables holding data loaded from external sources.
PASSAGE_TABLE_NAME: str = "letta_passages"  # chunked/embedded passages from a source
FILE_TABLE_NAME: str = "letta_files"  # original files from a source
37
-
38
-
39
class StorageConnector:
    """Defines a DB connection that is user-specific to access data: files, Passages, Archival/Recall Memory.

    Use ``get_storage_connector`` to obtain the backend-specific subclass selected
    by ``LettaConfig``.

    NOTE(review): this class does not inherit from ``abc.ABC``, so ``@abstractmethod``
    is not enforced. A subclass that omits a method silently inherits a stub that
    returns None.
    """

    # Record class stored in this table (Passage, Message or FileMetadata); set in __init__.
    type: Type[BaseModel]

    def __init__(
        self,
        table_type: str,
        config: LettaConfig,
        user_id: str,
        agent_id: Optional[str] = None,
        organization_id: Optional[str] = None,
    ):
        """Pick the record type and table name for ``table_type`` and set the base filters.

        Args:
            table_type: One of the ``TableType`` constants.
            config: Letta config (not read here; kept for subclasses).
            user_id: Owner used to scope agent tables and file tables.
            agent_id: Required for archival/recall tables; must be None for file tables.
            organization_id: Used to scope the passages table.

        Raises:
            ValueError: If ``table_type`` is not a known ``TableType``.
            AssertionError: If ``agent_id`` presence does not match the table kind.
        """
        self.user_id = user_id
        self.agent_id = agent_id
        self.organization_id = organization_id
        self.table_type = table_type

        # get object type
        if table_type == TableType.ARCHIVAL_MEMORY:
            self.type = Passage
            self.table_name = ARCHIVAL_TABLE_NAME
        elif table_type == TableType.RECALL_MEMORY:
            self.type = Message
            self.table_name = RECALL_TABLE_NAME
        elif table_type == TableType.FILES:
            self.type = FileMetadata
            self.table_name = FILE_TABLE_NAME
        elif table_type == TableType.PASSAGES:
            self.type = Passage
            self.table_name = PASSAGE_TABLE_NAME
        else:
            raise ValueError(f"Table type {table_type} not implemented")
        printd(f"Using table name {self.table_name}")

        # Base filters: every query made through this connector is scoped by them.
        if self.table_type == TableType.ARCHIVAL_MEMORY or self.table_type == TableType.RECALL_MEMORY:
            # agent-specific table
            assert agent_id is not None, "Agent ID must be provided for agent-specific tables"
            self.filters = {"user_id": self.user_id, "agent_id": self.agent_id}
        elif self.table_type == TableType.FILES:
            # user-specific table
            assert agent_id is None, "Agent ID must not be provided for user-specific tables"
            self.filters = {"user_id": self.user_id}
        elif self.table_type == TableType.PASSAGES:
            self.filters = {"organization_id": self.organization_id}
        else:
            raise ValueError(f"Table type {table_type} not implemented")

    @staticmethod
    def get_storage_connector(
        table_type: str,
        config: LettaConfig,
        user_id: str,
        organization_id: Optional[str] = None,
        agent_id: Optional[str] = None,
    ):
        """Instantiate the connector for the storage backend configured for ``table_type``.

        Archival and passage tables use ``config.archival_storage_type``, recall uses
        ``config.recall_storage_type``, and files use ``config.metadata_storage_type``.

        NOTE(review): ``organization_id`` is accepted but not forwarded to the backend
        constructors, so passage connectors are scoped by ``organization_id=None``.

        Raises:
            ValueError: Unknown ``table_type``.
            NotImplementedError: Unknown storage backend.
        """
        if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
            storage_type = config.archival_storage_type
        elif table_type == TableType.RECALL_MEMORY:
            storage_type = config.recall_storage_type
        elif table_type == TableType.FILES:
            storage_type = config.metadata_storage_type
        else:
            raise ValueError(f"Table type {table_type} not implemented")

        if storage_type == "postgres":
            from letta.agent_store.db import PostgresStorageConnector

            return PostgresStorageConnector(table_type, config, user_id, agent_id)

        elif storage_type == "qdrant":
            from letta.agent_store.qdrant import QdrantStorageConnector

            return QdrantStorageConnector(table_type, config, user_id, agent_id)

        elif storage_type == "sqlite":
            from letta.agent_store.db import SQLLiteStorageConnector

            return SQLLiteStorageConnector(table_type, config, user_id, agent_id)
        elif storage_type == "milvus":
            from letta.agent_store.milvus import MilvusStorageConnector

            return MilvusStorageConnector(table_type, config, user_id, agent_id)
        else:
            raise NotImplementedError(f"Storage type {storage_type} not implemented")

    @staticmethod
    def get_archival_storage_connector(user_id, agent_id):
        """Return the archival-memory connector for ``agent_id`` using the saved config."""
        config = LettaConfig.load()
        # Pass agent_id by keyword. The 4th positional parameter of
        # get_storage_connector is organization_id, so passing it positionally left
        # agent_id=None and failed the agent-table assertion.
        return StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config, user_id, agent_id=agent_id)

    @staticmethod
    def get_recall_storage_connector(user_id, agent_id):
        """Return the recall-memory connector for ``agent_id`` using the saved config."""
        config = LettaConfig.load()
        # Keyword for the same reason as in get_archival_storage_connector.
        return StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, config, user_id, agent_id=agent_id)

    @abstractmethod
    def get_filters(self, filters: Optional[Dict] = None) -> Union[Tuple[list, dict], dict]:
        """Translate ``filters`` (merged with ``self.filters``) into the backend's filter form."""
        pass

    @abstractmethod
    def get_all_paginated(self, filters: Optional[Dict] = None, page_size: int = 1000):
        """Iterate over all matching records in pages of ``page_size``."""
        pass

    @abstractmethod
    def get_all(self, filters: Optional[Dict] = None, limit=10):
        """Return up to ``limit`` matching records."""
        pass

    @abstractmethod
    def get(self, id: uuid.UUID):
        """Return the record with the given id, presumably None if absent (backend-defined)."""
        pass

    @abstractmethod
    def size(self, filters: Optional[Dict] = None) -> int:
        """Return the number of matching records."""
        pass

    @abstractmethod
    def insert(self, record):
        """Store a single record."""
        pass

    @abstractmethod
    def insert_many(self, records, show_progress=False):
        """Store many records."""
        pass

    @abstractmethod
    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = None):
        """Return up to ``top_k`` records most similar to the query (backend-defined)."""
        pass

    @abstractmethod
    def query_date(self, start_date, end_date):
        """Presumably return records created between the two dates; see backend."""
        pass

    @abstractmethod
    def query_text(self, query):
        """Presumably return records whose text matches ``query``; see backend."""
        pass

    @abstractmethod
    def delete_table(self):
        """Drop the backing table."""
        pass

    @abstractmethod
    def delete(self, filters: Optional[Dict] = None):
        """Delete matching records."""
        pass

    @abstractmethod
    def save(self):
        """Persist pending changes, if the backend needs it."""
        pass