letta-nightly 0.1.7.dev20240924104148__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic. Click here for more details.

Files changed (189) hide show
  1. letta/__init__.py +24 -0
  2. letta/__main__.py +3 -0
  3. letta/agent.py +1427 -0
  4. letta/agent_store/chroma.py +295 -0
  5. letta/agent_store/db.py +546 -0
  6. letta/agent_store/lancedb.py +177 -0
  7. letta/agent_store/milvus.py +198 -0
  8. letta/agent_store/qdrant.py +201 -0
  9. letta/agent_store/storage.py +188 -0
  10. letta/benchmark/benchmark.py +96 -0
  11. letta/benchmark/constants.py +14 -0
  12. letta/cli/cli.py +689 -0
  13. letta/cli/cli_config.py +1282 -0
  14. letta/cli/cli_load.py +166 -0
  15. letta/client/__init__.py +0 -0
  16. letta/client/admin.py +171 -0
  17. letta/client/client.py +2360 -0
  18. letta/client/streaming.py +90 -0
  19. letta/client/utils.py +61 -0
  20. letta/config.py +484 -0
  21. letta/configs/anthropic.json +13 -0
  22. letta/configs/letta_hosted.json +11 -0
  23. letta/configs/openai.json +12 -0
  24. letta/constants.py +134 -0
  25. letta/credentials.py +140 -0
  26. letta/data_sources/connectors.py +247 -0
  27. letta/embeddings.py +218 -0
  28. letta/errors.py +26 -0
  29. letta/functions/__init__.py +0 -0
  30. letta/functions/function_sets/base.py +174 -0
  31. letta/functions/function_sets/extras.py +132 -0
  32. letta/functions/functions.py +105 -0
  33. letta/functions/schema_generator.py +205 -0
  34. letta/humans/__init__.py +0 -0
  35. letta/humans/examples/basic.txt +1 -0
  36. letta/humans/examples/cs_phd.txt +9 -0
  37. letta/interface.py +314 -0
  38. letta/llm_api/__init__.py +0 -0
  39. letta/llm_api/anthropic.py +383 -0
  40. letta/llm_api/azure_openai.py +155 -0
  41. letta/llm_api/cohere.py +396 -0
  42. letta/llm_api/google_ai.py +468 -0
  43. letta/llm_api/llm_api_tools.py +485 -0
  44. letta/llm_api/openai.py +470 -0
  45. letta/local_llm/README.md +3 -0
  46. letta/local_llm/__init__.py +0 -0
  47. letta/local_llm/chat_completion_proxy.py +279 -0
  48. letta/local_llm/constants.py +31 -0
  49. letta/local_llm/function_parser.py +68 -0
  50. letta/local_llm/grammars/__init__.py +0 -0
  51. letta/local_llm/grammars/gbnf_grammar_generator.py +1324 -0
  52. letta/local_llm/grammars/json.gbnf +26 -0
  53. letta/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf +32 -0
  54. letta/local_llm/groq/api.py +97 -0
  55. letta/local_llm/json_parser.py +202 -0
  56. letta/local_llm/koboldcpp/api.py +62 -0
  57. letta/local_llm/koboldcpp/settings.py +23 -0
  58. letta/local_llm/llamacpp/api.py +58 -0
  59. letta/local_llm/llamacpp/settings.py +22 -0
  60. letta/local_llm/llm_chat_completion_wrappers/__init__.py +0 -0
  61. letta/local_llm/llm_chat_completion_wrappers/airoboros.py +452 -0
  62. letta/local_llm/llm_chat_completion_wrappers/chatml.py +470 -0
  63. letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +387 -0
  64. letta/local_llm/llm_chat_completion_wrappers/dolphin.py +246 -0
  65. letta/local_llm/llm_chat_completion_wrappers/llama3.py +345 -0
  66. letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +156 -0
  67. letta/local_llm/llm_chat_completion_wrappers/wrapper_base.py +11 -0
  68. letta/local_llm/llm_chat_completion_wrappers/zephyr.py +345 -0
  69. letta/local_llm/lmstudio/api.py +100 -0
  70. letta/local_llm/lmstudio/settings.py +29 -0
  71. letta/local_llm/ollama/api.py +88 -0
  72. letta/local_llm/ollama/settings.py +32 -0
  73. letta/local_llm/settings/__init__.py +0 -0
  74. letta/local_llm/settings/deterministic_mirostat.py +45 -0
  75. letta/local_llm/settings/settings.py +72 -0
  76. letta/local_llm/settings/simple.py +28 -0
  77. letta/local_llm/utils.py +265 -0
  78. letta/local_llm/vllm/api.py +63 -0
  79. letta/local_llm/webui/api.py +60 -0
  80. letta/local_llm/webui/legacy_api.py +58 -0
  81. letta/local_llm/webui/legacy_settings.py +23 -0
  82. letta/local_llm/webui/settings.py +24 -0
  83. letta/log.py +76 -0
  84. letta/main.py +437 -0
  85. letta/memory.py +440 -0
  86. letta/metadata.py +884 -0
  87. letta/openai_backcompat/__init__.py +0 -0
  88. letta/openai_backcompat/openai_object.py +437 -0
  89. letta/persistence_manager.py +148 -0
  90. letta/personas/__init__.py +0 -0
  91. letta/personas/examples/anna_pa.txt +13 -0
  92. letta/personas/examples/google_search_persona.txt +15 -0
  93. letta/personas/examples/memgpt_doc.txt +6 -0
  94. letta/personas/examples/memgpt_starter.txt +4 -0
  95. letta/personas/examples/sam.txt +14 -0
  96. letta/personas/examples/sam_pov.txt +14 -0
  97. letta/personas/examples/sam_simple_pov_gpt35.txt +13 -0
  98. letta/personas/examples/sqldb/test.db +0 -0
  99. letta/prompts/__init__.py +0 -0
  100. letta/prompts/gpt_summarize.py +14 -0
  101. letta/prompts/gpt_system.py +26 -0
  102. letta/prompts/system/memgpt_base.txt +49 -0
  103. letta/prompts/system/memgpt_chat.txt +58 -0
  104. letta/prompts/system/memgpt_chat_compressed.txt +13 -0
  105. letta/prompts/system/memgpt_chat_fstring.txt +51 -0
  106. letta/prompts/system/memgpt_doc.txt +50 -0
  107. letta/prompts/system/memgpt_gpt35_extralong.txt +53 -0
  108. letta/prompts/system/memgpt_intuitive_knowledge.txt +31 -0
  109. letta/prompts/system/memgpt_modified_chat.txt +23 -0
  110. letta/pytest.ini +0 -0
  111. letta/schemas/agent.py +117 -0
  112. letta/schemas/api_key.py +21 -0
  113. letta/schemas/block.py +135 -0
  114. letta/schemas/document.py +21 -0
  115. letta/schemas/embedding_config.py +54 -0
  116. letta/schemas/enums.py +35 -0
  117. letta/schemas/job.py +38 -0
  118. letta/schemas/letta_base.py +80 -0
  119. letta/schemas/letta_message.py +175 -0
  120. letta/schemas/letta_request.py +23 -0
  121. letta/schemas/letta_response.py +28 -0
  122. letta/schemas/llm_config.py +54 -0
  123. letta/schemas/memory.py +224 -0
  124. letta/schemas/message.py +727 -0
  125. letta/schemas/openai/chat_completion_request.py +123 -0
  126. letta/schemas/openai/chat_completion_response.py +136 -0
  127. letta/schemas/openai/chat_completions.py +123 -0
  128. letta/schemas/openai/embedding_response.py +11 -0
  129. letta/schemas/openai/openai.py +157 -0
  130. letta/schemas/organization.py +20 -0
  131. letta/schemas/passage.py +80 -0
  132. letta/schemas/source.py +62 -0
  133. letta/schemas/tool.py +143 -0
  134. letta/schemas/usage.py +18 -0
  135. letta/schemas/user.py +33 -0
  136. letta/server/__init__.py +0 -0
  137. letta/server/constants.py +6 -0
  138. letta/server/rest_api/__init__.py +0 -0
  139. letta/server/rest_api/admin/__init__.py +0 -0
  140. letta/server/rest_api/admin/agents.py +21 -0
  141. letta/server/rest_api/admin/tools.py +83 -0
  142. letta/server/rest_api/admin/users.py +98 -0
  143. letta/server/rest_api/app.py +193 -0
  144. letta/server/rest_api/auth/__init__.py +0 -0
  145. letta/server/rest_api/auth/index.py +43 -0
  146. letta/server/rest_api/auth_token.py +22 -0
  147. letta/server/rest_api/interface.py +726 -0
  148. letta/server/rest_api/routers/__init__.py +0 -0
  149. letta/server/rest_api/routers/openai/__init__.py +0 -0
  150. letta/server/rest_api/routers/openai/assistants/__init__.py +0 -0
  151. letta/server/rest_api/routers/openai/assistants/assistants.py +115 -0
  152. letta/server/rest_api/routers/openai/assistants/schemas.py +121 -0
  153. letta/server/rest_api/routers/openai/assistants/threads.py +336 -0
  154. letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
  155. letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +131 -0
  156. letta/server/rest_api/routers/v1/__init__.py +15 -0
  157. letta/server/rest_api/routers/v1/agents.py +543 -0
  158. letta/server/rest_api/routers/v1/blocks.py +73 -0
  159. letta/server/rest_api/routers/v1/jobs.py +46 -0
  160. letta/server/rest_api/routers/v1/llms.py +28 -0
  161. letta/server/rest_api/routers/v1/organizations.py +61 -0
  162. letta/server/rest_api/routers/v1/sources.py +199 -0
  163. letta/server/rest_api/routers/v1/tools.py +103 -0
  164. letta/server/rest_api/routers/v1/users.py +109 -0
  165. letta/server/rest_api/static_files.py +74 -0
  166. letta/server/rest_api/utils.py +69 -0
  167. letta/server/server.py +1995 -0
  168. letta/server/startup.sh +8 -0
  169. letta/server/static_files/assets/index-0cbf7ad5.js +274 -0
  170. letta/server/static_files/assets/index-156816da.css +1 -0
  171. letta/server/static_files/assets/index-486e3228.js +274 -0
  172. letta/server/static_files/favicon.ico +0 -0
  173. letta/server/static_files/index.html +39 -0
  174. letta/server/static_files/memgpt_logo_transparent.png +0 -0
  175. letta/server/utils.py +46 -0
  176. letta/server/ws_api/__init__.py +0 -0
  177. letta/server/ws_api/example_client.py +104 -0
  178. letta/server/ws_api/interface.py +108 -0
  179. letta/server/ws_api/protocol.py +100 -0
  180. letta/server/ws_api/server.py +145 -0
  181. letta/settings.py +165 -0
  182. letta/streaming_interface.py +396 -0
  183. letta/system.py +207 -0
  184. letta/utils.py +1065 -0
  185. letta_nightly-0.1.7.dev20240924104148.dist-info/LICENSE +190 -0
  186. letta_nightly-0.1.7.dev20240924104148.dist-info/METADATA +98 -0
  187. letta_nightly-0.1.7.dev20240924104148.dist-info/RECORD +189 -0
  188. letta_nightly-0.1.7.dev20240924104148.dist-info/WHEEL +4 -0
  189. letta_nightly-0.1.7.dev20240924104148.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,177 @@
1
+ # type: ignore
2
+
3
+ import uuid
4
+ from datetime import datetime
5
+ from typing import Dict, Iterator, List, Optional
6
+
7
+ from lancedb.pydantic import LanceModel, Vector
8
+
9
+ from letta.agent_store.storage import StorageConnector, TableType
10
+ from letta.config import AgentConfig, LettaConfig
11
+ from letta.schemas.message import Message, Passage, Record
12
+
13
+ """ Initial implementation - not complete """
14
+
15
+
16
def get_db_model(table_name: str, table_type: TableType):
    """Create database model for table_name

    Args:
        table_name: Name of the table the model is created for. It is not
            read by this function.
        table_type: Kind of table to model. ``ARCHIVAL_MEMORY`` and
            ``PASSAGES`` produce a passage model; ``RECALL_MEMORY`` produces
            a message model.

    Returns:
        A ``LanceModel`` subclass exposing ``to_record()``, which converts a
        stored row into a ``Passage`` or ``Message``.

    Raises:
        ValueError: If ``table_type`` is none of the supported table types.
    """
    config = LettaConfig.load()

    if table_type in (TableType.ARCHIVAL_MEMORY, TableType.PASSAGES):
        # create schema for archival memory
        class PassageModel(LanceModel):
            """Defines data model for storing Passages (consisting of text, embedding)"""

            id: uuid.UUID
            user_id: str
            text: str
            doc_id: str
            agent_id: str
            data_source: str
            # Vector width comes from the default embedding configuration.
            embedding: Vector(config.default_embedding_config.embedding_dim)
            metadata_: Dict

            def __repr__(self):
                # The closing quote after the embedding was missing originally.
                return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding}')>"

            def to_record(self):
                """Convert this row into a ``Passage`` record."""
                return Passage(
                    text=self.text,
                    embedding=self.embedding,
                    doc_id=self.doc_id,
                    user_id=self.user_id,
                    id=self.id,
                    data_source=self.data_source,
                    agent_id=self.agent_id,
                    metadata=self.metadata_,
                )

        return PassageModel

    elif table_type == TableType.RECALL_MEMORY:

        class MessageModel(LanceModel):
            """Defines data model for storing Message objects"""

            __abstract__ = True  # this line is necessary

            # Assuming message_id is the primary key
            id: uuid.UUID
            user_id: str
            agent_id: str

            # openai info
            role: str
            name: str
            text: str
            model: str
            user: str

            # function info
            function_name: str
            function_args: str
            function_response: str

            # Declared as annotations (not assignments) so that they are
            # real model fields, matching PassageModel.embedding above.
            embedding: Vector(config.default_embedding_config.embedding_dim)

            # Creation time of the message. Previously written as
            # `created_at = datetime`, which bound the datetime class itself
            # instead of declaring a datetime-typed field.
            created_at: datetime

            def __repr__(self):
                # The closing quote after the embedding was missing originally.
                return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding}')>"

            def to_record(self):
                """Convert this row into a ``Message`` record."""
                return Message(
                    user_id=self.user_id,
                    agent_id=self.agent_id,
                    role=self.role,
                    name=self.name,
                    text=self.text,
                    model=self.model,
                    function_name=self.function_name,
                    function_args=self.function_args,
                    function_response=self.function_response,
                    embedding=self.embedding,
                    created_at=self.created_at,
                    id=self.id,
                )

        return MessageModel

    else:
        raise ValueError(f"Table type {table_type} not implemented")
102
+
103
+
104
class LanceDBConnector(StorageConnector):
    """Storage via LanceDB

    This connector is a skeleton: only ``generate_where_filter`` has a real
    implementation. Every storage operation raises ``NotImplementedError``
    so that an accidental use fails loudly instead of silently returning
    ``None``.

    The stubs were previously decorated with ``@abstractmethod``, a name this
    module never imports, so merely importing the module raised ``NameError``.
    The decorators are dropped here because this is the concrete connector,
    not the interface definition.
    """

    # TODO: this should probably eventually be moved into a parent DB class

    def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
        # TODO
        pass

    def generate_where_filter(self, filters: Dict) -> str:
        """Render ``filters`` as ``key=value`` clauses joined by `` AND ``.

        Args:
            filters: Mapping of column name to the value it must equal.

        Returns:
            The combined filter expression; an empty string when ``filters``
            is empty.
        """
        # NOTE(review): values are interpolated verbatim; string values
        # probably need quoting before this is used in a real query - confirm.
        # `str.join` is called on the separator (lists have no `.join`).
        return " AND ".join(f"{key}={value}" for key, value in filters.items())

    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.get_all_paginated is not implemented yet")

    def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.get_all is not implemented yet")

    def get(self, id: uuid.UUID) -> Optional[Record]:
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.get is not implemented yet")

    def size(self, filters: Optional[Dict] = {}) -> int:
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.size is not implemented yet")

    def insert(self, record: Record):
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.insert is not implemented yet")

    def insert_many(self, records: List[Record], show_progress=False):
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.insert_many is not implemented yet")

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.query is not implemented yet")

    def query_date(self, start_date, end_date):
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.query_date is not implemented yet")

    def query_text(self, query):
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.query_text is not implemented yet")

    def delete_table(self):
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.delete_table is not implemented yet")

    def delete(self, filters: Optional[Dict] = {}):
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.delete is not implemented yet")

    def save(self):
        """Not implemented yet."""
        # TODO
        raise NotImplementedError("LanceDBConnector.save is not implemented yet")
@@ -0,0 +1,198 @@
1
+ import uuid
2
+ from copy import deepcopy
3
+ from typing import Dict, Iterator, List, Optional, cast
4
+
5
+ from pymilvus import DataType, MilvusClient
6
+ from pymilvus.client.constants import ConsistencyLevel
7
+
8
+ from letta.agent_store.storage import StorageConnector, TableType
9
+ from letta.config import LettaConfig
10
+ from letta.constants import MAX_EMBEDDING_DIM
11
+ from letta.data_types import Passage, Record, RecordType
12
+ from letta.utils import datetime_to_timestamp, printd, timestamp_to_datetime
13
+
14
+
15
class MilvusStorageConnector(StorageConnector):
    """Storage via Milvus

    Records are stored in the collection named ``self.table_name``. The
    attributes ``self.table_name``, ``self.filters`` and ``self.type`` are
    not set in this class; they are presumably provided by
    ``StorageConnector.__init__`` - verify against the base class.
    """

    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        """Connect to Milvus and create the backing collection.

        Args:
            table_type: Must be ``TableType.ARCHIVAL_MEMORY`` or
                ``TableType.PASSAGES``.
            config: Configuration; ``archival_storage_uri`` must be set.
            user_id: Forwarded to ``StorageConnector.__init__``.
            agent_id: Forwarded to ``StorageConnector.__init__``.

        Raises:
            AssertionError: If ``table_type`` is not supported.
            ValueError: If ``config.archival_storage_uri`` is empty.
        """
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)

        assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Milvus only supports archival memory"
        if config.archival_storage_uri:
            self.client = MilvusClient(uri=config.archival_storage_uri)
            self._create_collection()
        else:
            raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")

        # need to be converted to strings
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]

    def _create_collection(self):
        """Create the collection with ``id``, ``text`` and ``embedding`` fields.

        Dynamic fields are enabled, so keys beyond these three are accepted
        on insert without being declared in the schema.
        """
        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=65_535)
        schema.add_field(field_name="text", datatype=DataType.VARCHAR, is_primary=False, max_length=65_535)
        schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=MAX_EMBEDDING_DIM)
        index_params = self.client.prepare_index_params()
        index_params.add_index(field_name="id")
        index_params.add_index(field_name="embedding", index_type="AUTOINDEX", metric_type="IP")
        self.client.create_collection(
            collection_name=self.table_name, schema=schema, index_params=index_params, consistency_level=ConsistencyLevel.Strong
        )

    def get_milvus_filter(self, filters: Optional[Dict] = {}) -> str:
        """Build a Milvus filter expression from ``self.filters`` and ``filters``.

        Args:
            filters: Extra equality conditions; on a key clash they override
                ``self.filters``. ``None`` means "use ``self.filters`` only".

        Returns:
            Conditions of the form ``(key == value)`` joined by `` and ``.
            A single condition is returned without its parentheses, and an
            empty string is returned when there are no conditions.
        """
        filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
        if not filter_conditions:
            return ""
        conditions = []
        for key, value in filter_conditions.items():
            # Quote UUID fields and string values; leave other values bare.
            # The type check must look at the value: the previous
            # `isinstance(key, str)` is true for every string filter key,
            # so numeric values were quoted as strings too.
            if key in self.uuid_fields or isinstance(value, str):
                condition = f'({key} == "{value}")'
            else:
                condition = f"({key} == {value})"
            conditions.append(condition)
        filter_expr = " and ".join(conditions)
        if len(conditions) == 1:
            # Strip the surrounding parentheses of a lone condition.
            filter_expr = filter_expr[1:-1]
        return filter_expr

    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000) -> Iterator[List[RecordType]]:
        """Yield matching records in pages of at most ``page_size``.

        Yields a single empty list when the collection does not exist.
        """
        if not self.client.has_collection(collection_name=self.table_name):
            yield []
            # Stop here: without this return the generator went on to query
            # the collection that was just found to be missing.
            return
        filter_expr = self.get_milvus_filter(filters)
        offset = 0
        while True:
            # Retrieve a chunk of records with the given page_size
            query_res = self.client.query(
                collection_name=self.table_name,
                filter=filter_expr,
                offset=offset,
                limit=page_size,
            )
            if not query_res:
                break
            # Yield a list of Record objects converted from the chunk
            yield self._list_to_records(query_res)

            # Increment the offset to get the next chunk in the next iteration
            offset += page_size

    def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
        """Return matching records, or ``[]`` when the collection is missing."""
        if not self.client.has_collection(collection_name=self.table_name):
            return []
        filter_expr = self.get_milvus_filter(filters)
        query_res = self.client.query(
            collection_name=self.table_name,
            filter=filter_expr,
            limit=limit,
        )
        return self._list_to_records(query_res)

    def get(self, id: str) -> Optional[RecordType]:
        """Return the record whose primary key is ``str(id)``, or ``None``."""
        res = self.client.get(collection_name=self.table_name, ids=str(id))
        return self._list_to_records(res)[0] if res else None

    def size(self, filters: Optional[Dict] = {}) -> int:
        """Return the number of matching records (0 if the collection is missing)."""
        if not self.client.has_collection(collection_name=self.table_name):
            return 0
        filter_expr = self.get_milvus_filter(filters)
        count_expr = "count(*)"
        query_res = self.client.query(
            collection_name=self.table_name,
            filter=filter_expr,
            output_fields=[count_expr],
        )
        doc_num = query_res[0][count_expr]
        return doc_num

    def insert(self, record: RecordType):
        """Insert one record; delegates to ``insert_many``."""
        self.insert_many([record])

    def insert_many(self, records: List[RecordType], show_progress=False):
        """Insert ``records``, replacing any stored rows with the same ids.

        ``show_progress`` is accepted but not used here.
        """
        if not records:
            return

        # Milvus lite currently does not support upsert, so we delete and insert instead
        # self.client.upsert(collection_name=self.table_name, data=self._records_to_list(records))
        ids = [str(record.id) for record in records]
        self.client.delete(collection_name=self.table_name, ids=ids)
        data = self._records_to_list(records)
        self.client.insert(collection_name=self.table_name, data=data)

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
        """Vector-search the collection and return up to ``top_k`` records.

        Only ``query_vec`` drives the search; the ``query`` text is not used.
        Returns ``[]`` when the collection is missing.
        """
        if not self.client.has_collection(self.table_name):
            return []
        search_res = self.client.search(
            collection_name=self.table_name, data=[query_vec], filter=self.get_milvus_filter(filters), limit=top_k, output_fields=["*"]
        )[0]
        entity_res = [res["entity"] for res in search_res]
        return self._list_to_records(entity_res)

    def delete_table(self):
        """Drop the whole collection."""
        self.client.drop_collection(collection_name=self.table_name)

    def delete(self, filters: Optional[Dict] = {}):
        """Delete matching records; a no-op when the collection is missing."""
        if not self.client.has_collection(collection_name=self.table_name):
            return
        filter_expr = self.get_milvus_filter(filters)
        self.client.delete(collection_name=self.table_name, filter=filter_expr)

    def save(self):
        """Log that a save was requested; no data is written here."""
        # save to persistence file (nothing needs to be done)
        printd("Saving milvus")

    def _records_to_list(self, records: List[Record]) -> List[Dict]:
        """Flatten ``Passage`` records into dicts suitable for insertion.

        Each dict holds the record's non-``None`` attributes merged with its
        ``metadata_`` entries, plus ``id``, ``text`` and ``embedding``. UUID
        fields are stringified and ``created_at`` is converted with
        ``datetime_to_timestamp``. Duplicates are removed via ``set``, so the
        output order is not guaranteed to match the input order.

        Raises:
            AssertionError: If any record is not a ``Passage``.
        """
        if records == []:
            return []
        assert all(isinstance(r, Passage) for r in records)
        record_list = []
        records = list(set(records))
        for record in records:
            # Work on a copy so the caller's record is left untouched.
            record_vars = deepcopy(vars(record))
            _id = record_vars.pop("id")
            text = record_vars.pop("text", "")
            embedding = record_vars.pop("embedding")
            record_metadata = record_vars.pop("metadata_", None) or {}
            if "created_at" in record_vars:
                record_vars["created_at"] = datetime_to_timestamp(record_vars["created_at"])
            record_dict = {key: value for key, value in record_vars.items() if value is not None}
            record_dict = {
                **record_dict,
                **record_metadata,
                "id": str(_id),
                "text": text,
                "embedding": embedding,
            }
            # Only existing keys are reassigned, so the dict size is stable
            # while iterating.
            for key, value in record_dict.items():
                if key in self.uuid_fields:
                    record_dict[key] = str(value)
            record_list.append(record_dict)
        return record_list

    def _list_to_records(self, query_res: List[Dict]) -> List[RecordType]:
        """Convert Milvus result dicts back into records of ``self.type``.

        Note that ``id``, ``embedding`` and ``text`` are popped from the dicts
        in ``query_res``, so the input is modified in place.
        """
        records = []
        for res_dict in query_res:
            _id = res_dict.pop("id")
            embedding = res_dict.pop("embedding")
            text = res_dict.pop("text")
            # Everything left over is passed to the record as keyword args.
            metadata = deepcopy(res_dict)
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = uuid.UUID(value)
                elif key == "created_at":
                    metadata[key] = timestamp_to_datetime(value)
            records.append(
                cast(
                    RecordType,
                    self.type(
                        text=text,
                        embedding=embedding,
                        id=uuid.UUID(_id),
                        **metadata,
                    ),
                )
            )
        return records
@@ -0,0 +1,201 @@
1
+ import os
2
+ import uuid
3
+ from copy import deepcopy
4
+ from typing import Dict, Iterator, List, Optional, cast
5
+
6
+ from letta.agent_store.storage import StorageConnector, TableType
7
+ from letta.config import LettaConfig
8
+ from letta.constants import MAX_EMBEDDING_DIM
9
+ from letta.data_types import Passage, Record, RecordType
10
+ from letta.utils import datetime_to_timestamp, timestamp_to_datetime
11
+
12
+ TEXT_PAYLOAD_KEY = "text_content"
13
+ METADATA_PAYLOAD_KEY = "metadata"
14
+
15
+
16
class QdrantStorageConnector(StorageConnector):
    """Storage via Qdrant

    Each record is stored as one point: the embedding is the point vector,
    the text is kept under ``TEXT_PAYLOAD_KEY`` and all remaining fields
    under ``METADATA_PAYLOAD_KEY``. ``self.table_name``, ``self.filters``
    and ``self.type`` are not set in this class; they are presumably
    provided by ``StorageConnector.__init__`` - verify against the base class.
    """

    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        """Open a Qdrant client and make sure the collection exists.

        A ``host:port`` value in ``config.archival_storage_uri`` selects a
        remote server (API key read from ``QDRANT_API_KEY``); otherwise
        ``config.archival_storage_path`` selects local on-disk storage.

        Raises:
            ImportError: If ``qdrant-client`` is not installed.
            AssertionError: If ``table_type`` is not supported.
            ValueError: If neither a usable URI nor a path is configured.
        """
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
        try:
            from qdrant_client import QdrantClient, models
        except ImportError as e:
            raise ImportError("'qdrant-client' not installed. Run `pip install qdrant-client`.") from e
        assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Qdrant only supports archival memory"
        if config.archival_storage_uri and len(config.archival_storage_uri.split(":")) == 2:
            host, port = config.archival_storage_uri.split(":")
            self.qdrant_client = QdrantClient(host=host, port=port, api_key=os.getenv("QDRANT_API_KEY"))
        elif config.archival_storage_path:
            self.qdrant_client = QdrantClient(path=config.archival_storage_path)
        else:
            raise ValueError("Qdrant storage requires either a URI or a path to the storage configured")
        if not self.qdrant_client.collection_exists(self.table_name):
            self.qdrant_client.create_collection(
                collection_name=self.table_name,
                vectors_config=models.VectorParams(
                    size=MAX_EMBEDDING_DIM,
                    distance=models.Distance.COSINE,
                ),
            )
        # Fields stored as strings in the payload and parsed back into UUIDs.
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]

    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
        """Yield matching records page by page using Qdrant's scroll API."""
        from qdrant_client import grpc

        scroll_filter = self.get_qdrant_filters(filters)
        next_offset = None
        stop_scrolling = False
        while not stop_scrolling:
            results, next_offset = self.qdrant_client.scroll(
                collection_name=self.table_name,
                scroll_filter=scroll_filter,
                limit=page_size,
                offset=next_offset,
                with_payload=True,
                with_vectors=True,
            )
            # Scrolling ends when no next offset is returned, or when the
            # returned gRPC PointId is empty (num == 0 and uuid == "").
            stop_scrolling = next_offset is None or (
                isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == ""
            )
            yield self.to_records(results)

    def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]:
        """Return up to ``limit`` matching records (``[]`` if none match)."""
        if self.size(filters) == 0:
            return []
        scroll_filter = self.get_qdrant_filters(filters)
        results, _ = self.qdrant_client.scroll(
            self.table_name,
            scroll_filter=scroll_filter,
            limit=limit,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def get(self, id: str) -> Optional[RecordType]:
        """Return the record stored under ``str(id)``, or ``None``."""
        results = self.qdrant_client.retrieve(
            collection_name=self.table_name,
            ids=[str(id)],
            with_payload=True,
            with_vectors=True,
        )
        if not results:
            return None
        return self.to_records(results)[0]

    def insert(self, record: Record):
        """Upsert a single record."""
        points = self.to_points([record])
        self.qdrant_client.upsert(self.table_name, points=points)

    def insert_many(self, records: List[RecordType], show_progress=False):
        """Upsert ``records``; ``show_progress`` is accepted but not used."""
        points = self.to_points(records)
        self.qdrant_client.upsert(self.table_name, points=points)

    def delete(self, filters: Optional[Dict] = {}):
        """Delete every point matching the combined filters."""
        points_selector = self.get_qdrant_filters(filters)
        self.qdrant_client.delete(self.table_name, points_selector=points_selector)

    def delete_table(self):
        """Delete the collection and close the client."""
        self.qdrant_client.delete_collection(self.table_name)
        self.qdrant_client.close()

    def size(self, filters: Optional[Dict] = {}) -> int:
        """Return the number of points matching the combined filters."""
        count_filter = self.get_qdrant_filters(filters)
        return self.qdrant_client.count(collection_name=self.table_name, count_filter=count_filter).count

    def close(self):
        """Close the underlying Qdrant client."""
        self.qdrant_client.close()

    def query(
        self,
        query: str,
        query_vec: List[float],
        top_k: int = 10,
        filters: Optional[Dict] = {},
    ) -> List[RecordType]:
        """Vector-search the collection and return up to ``top_k`` records.

        Only ``query_vec`` drives the search; the ``query`` text is not used.
        """
        # Build the filter with get_qdrant_filters, as scroll/count/delete in
        # this class do. This previously called `self.get_filters`, which is
        # not defined in this class.
        query_filter = self.get_qdrant_filters(filters)
        results = self.qdrant_client.search(
            self.table_name,
            query_vector=query_vec,
            query_filter=query_filter,
            limit=top_k,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def to_records(self, records: list) -> List[RecordType]:
        """Convert Qdrant points back into records of ``self.type``."""
        parsed_records = []
        for record in records:
            # Copy first so that popping "id" does not alter the client's result.
            record = deepcopy(record)
            metadata = record.payload[METADATA_PAYLOAD_KEY]
            text = record.payload[TEXT_PAYLOAD_KEY]
            _id = metadata.pop("id")
            embedding = record.vector
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = uuid.UUID(value)
                elif key == "created_at":
                    metadata[key] = timestamp_to_datetime(value)
            parsed_records.append(
                cast(
                    RecordType,
                    self.type(
                        text=text,
                        embedding=embedding,
                        id=uuid.UUID(_id),
                        **metadata,
                    ),
                )
            )
        return parsed_records

    def to_points(self, records: List[RecordType]):
        """Convert ``Passage`` records into Qdrant ``PointStruct`` objects.

        Duplicates are removed via ``set``, so the output order is not
        guaranteed to match the input order. The input records are left
        unmodified.

        Raises:
            AssertionError: If any record is not a ``Passage``.
        """
        from qdrant_client import models

        assert all(isinstance(r, Passage) for r in records)
        points = []
        records = list(set(records))
        for record in records:
            # vars() returns the record's live __dict__; popping from it
            # directly stripped id/text/embedding off the caller's object.
            # A shallow copy is enough because only top-level keys change.
            fields = dict(vars(record))
            _id = fields.pop("id")
            text = fields.pop("text", "")
            embedding = fields.pop("embedding", {})
            record_metadata = fields.pop("metadata_", None) or {}
            if "created_at" in fields:
                fields["created_at"] = datetime_to_timestamp(fields["created_at"])
            metadata = {key: value for key, value in fields.items() if value is not None}
            metadata = {
                **metadata,
                **record_metadata,
                "id": str(_id),
            }
            # Only existing keys are reassigned, so the dict size is stable
            # while iterating.
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = str(value)
            points.append(
                models.PointStruct(
                    id=str(_id),
                    vector=embedding,
                    payload={
                        TEXT_PAYLOAD_KEY: text,
                        METADATA_PAYLOAD_KEY: metadata,
                    },
                )
            )
        return points

    def get_qdrant_filters(self, filters: Optional[Dict] = {}):
        """Build a Qdrant ``Filter`` from ``self.filters`` and ``filters``.

        Args:
            filters: Extra equality conditions; on a key clash they override
                ``self.filters``. ``None`` means "use ``self.filters`` only".

        Returns:
            A ``models.Filter`` whose ``must`` list holds one exact-match
            condition per key, addressed under ``METADATA_PAYLOAD_KEY``.
        """
        from qdrant_client import models

        filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
        must_conditions = []
        for key, value in filter_conditions.items():
            # UUID fields are stored as strings in the payload.
            match_value = str(value) if key in self.uuid_fields else value
            field_condition = models.FieldCondition(
                key=f"{METADATA_PAYLOAD_KEY}.{key}",
                match=models.MatchValue(value=match_value),
            )
            must_conditions.append(field_condition)
        return models.Filter(must=must_conditions)