letta-nightly 0.1.7.dev20240924104148__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly has been flagged as potentially problematic; see the package registry listing for details.
- letta/__init__.py +24 -0
- letta/__main__.py +3 -0
- letta/agent.py +1427 -0
- letta/agent_store/chroma.py +295 -0
- letta/agent_store/db.py +546 -0
- letta/agent_store/lancedb.py +177 -0
- letta/agent_store/milvus.py +198 -0
- letta/agent_store/qdrant.py +201 -0
- letta/agent_store/storage.py +188 -0
- letta/benchmark/benchmark.py +96 -0
- letta/benchmark/constants.py +14 -0
- letta/cli/cli.py +689 -0
- letta/cli/cli_config.py +1282 -0
- letta/cli/cli_load.py +166 -0
- letta/client/__init__.py +0 -0
- letta/client/admin.py +171 -0
- letta/client/client.py +2360 -0
- letta/client/streaming.py +90 -0
- letta/client/utils.py +61 -0
- letta/config.py +484 -0
- letta/configs/anthropic.json +13 -0
- letta/configs/letta_hosted.json +11 -0
- letta/configs/openai.json +12 -0
- letta/constants.py +134 -0
- letta/credentials.py +140 -0
- letta/data_sources/connectors.py +247 -0
- letta/embeddings.py +218 -0
- letta/errors.py +26 -0
- letta/functions/__init__.py +0 -0
- letta/functions/function_sets/base.py +174 -0
- letta/functions/function_sets/extras.py +132 -0
- letta/functions/functions.py +105 -0
- letta/functions/schema_generator.py +205 -0
- letta/humans/__init__.py +0 -0
- letta/humans/examples/basic.txt +1 -0
- letta/humans/examples/cs_phd.txt +9 -0
- letta/interface.py +314 -0
- letta/llm_api/__init__.py +0 -0
- letta/llm_api/anthropic.py +383 -0
- letta/llm_api/azure_openai.py +155 -0
- letta/llm_api/cohere.py +396 -0
- letta/llm_api/google_ai.py +468 -0
- letta/llm_api/llm_api_tools.py +485 -0
- letta/llm_api/openai.py +470 -0
- letta/local_llm/README.md +3 -0
- letta/local_llm/__init__.py +0 -0
- letta/local_llm/chat_completion_proxy.py +279 -0
- letta/local_llm/constants.py +31 -0
- letta/local_llm/function_parser.py +68 -0
- letta/local_llm/grammars/__init__.py +0 -0
- letta/local_llm/grammars/gbnf_grammar_generator.py +1324 -0
- letta/local_llm/grammars/json.gbnf +26 -0
- letta/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf +32 -0
- letta/local_llm/groq/api.py +97 -0
- letta/local_llm/json_parser.py +202 -0
- letta/local_llm/koboldcpp/api.py +62 -0
- letta/local_llm/koboldcpp/settings.py +23 -0
- letta/local_llm/llamacpp/api.py +58 -0
- letta/local_llm/llamacpp/settings.py +22 -0
- letta/local_llm/llm_chat_completion_wrappers/__init__.py +0 -0
- letta/local_llm/llm_chat_completion_wrappers/airoboros.py +452 -0
- letta/local_llm/llm_chat_completion_wrappers/chatml.py +470 -0
- letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +387 -0
- letta/local_llm/llm_chat_completion_wrappers/dolphin.py +246 -0
- letta/local_llm/llm_chat_completion_wrappers/llama3.py +345 -0
- letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +156 -0
- letta/local_llm/llm_chat_completion_wrappers/wrapper_base.py +11 -0
- letta/local_llm/llm_chat_completion_wrappers/zephyr.py +345 -0
- letta/local_llm/lmstudio/api.py +100 -0
- letta/local_llm/lmstudio/settings.py +29 -0
- letta/local_llm/ollama/api.py +88 -0
- letta/local_llm/ollama/settings.py +32 -0
- letta/local_llm/settings/__init__.py +0 -0
- letta/local_llm/settings/deterministic_mirostat.py +45 -0
- letta/local_llm/settings/settings.py +72 -0
- letta/local_llm/settings/simple.py +28 -0
- letta/local_llm/utils.py +265 -0
- letta/local_llm/vllm/api.py +63 -0
- letta/local_llm/webui/api.py +60 -0
- letta/local_llm/webui/legacy_api.py +58 -0
- letta/local_llm/webui/legacy_settings.py +23 -0
- letta/local_llm/webui/settings.py +24 -0
- letta/log.py +76 -0
- letta/main.py +437 -0
- letta/memory.py +440 -0
- letta/metadata.py +884 -0
- letta/openai_backcompat/__init__.py +0 -0
- letta/openai_backcompat/openai_object.py +437 -0
- letta/persistence_manager.py +148 -0
- letta/personas/__init__.py +0 -0
- letta/personas/examples/anna_pa.txt +13 -0
- letta/personas/examples/google_search_persona.txt +15 -0
- letta/personas/examples/memgpt_doc.txt +6 -0
- letta/personas/examples/memgpt_starter.txt +4 -0
- letta/personas/examples/sam.txt +14 -0
- letta/personas/examples/sam_pov.txt +14 -0
- letta/personas/examples/sam_simple_pov_gpt35.txt +13 -0
- letta/personas/examples/sqldb/test.db +0 -0
- letta/prompts/__init__.py +0 -0
- letta/prompts/gpt_summarize.py +14 -0
- letta/prompts/gpt_system.py +26 -0
- letta/prompts/system/memgpt_base.txt +49 -0
- letta/prompts/system/memgpt_chat.txt +58 -0
- letta/prompts/system/memgpt_chat_compressed.txt +13 -0
- letta/prompts/system/memgpt_chat_fstring.txt +51 -0
- letta/prompts/system/memgpt_doc.txt +50 -0
- letta/prompts/system/memgpt_gpt35_extralong.txt +53 -0
- letta/prompts/system/memgpt_intuitive_knowledge.txt +31 -0
- letta/prompts/system/memgpt_modified_chat.txt +23 -0
- letta/pytest.ini +0 -0
- letta/schemas/agent.py +117 -0
- letta/schemas/api_key.py +21 -0
- letta/schemas/block.py +135 -0
- letta/schemas/document.py +21 -0
- letta/schemas/embedding_config.py +54 -0
- letta/schemas/enums.py +35 -0
- letta/schemas/job.py +38 -0
- letta/schemas/letta_base.py +80 -0
- letta/schemas/letta_message.py +175 -0
- letta/schemas/letta_request.py +23 -0
- letta/schemas/letta_response.py +28 -0
- letta/schemas/llm_config.py +54 -0
- letta/schemas/memory.py +224 -0
- letta/schemas/message.py +727 -0
- letta/schemas/openai/chat_completion_request.py +123 -0
- letta/schemas/openai/chat_completion_response.py +136 -0
- letta/schemas/openai/chat_completions.py +123 -0
- letta/schemas/openai/embedding_response.py +11 -0
- letta/schemas/openai/openai.py +157 -0
- letta/schemas/organization.py +20 -0
- letta/schemas/passage.py +80 -0
- letta/schemas/source.py +62 -0
- letta/schemas/tool.py +143 -0
- letta/schemas/usage.py +18 -0
- letta/schemas/user.py +33 -0
- letta/server/__init__.py +0 -0
- letta/server/constants.py +6 -0
- letta/server/rest_api/__init__.py +0 -0
- letta/server/rest_api/admin/__init__.py +0 -0
- letta/server/rest_api/admin/agents.py +21 -0
- letta/server/rest_api/admin/tools.py +83 -0
- letta/server/rest_api/admin/users.py +98 -0
- letta/server/rest_api/app.py +193 -0
- letta/server/rest_api/auth/__init__.py +0 -0
- letta/server/rest_api/auth/index.py +43 -0
- letta/server/rest_api/auth_token.py +22 -0
- letta/server/rest_api/interface.py +726 -0
- letta/server/rest_api/routers/__init__.py +0 -0
- letta/server/rest_api/routers/openai/__init__.py +0 -0
- letta/server/rest_api/routers/openai/assistants/__init__.py +0 -0
- letta/server/rest_api/routers/openai/assistants/assistants.py +115 -0
- letta/server/rest_api/routers/openai/assistants/schemas.py +121 -0
- letta/server/rest_api/routers/openai/assistants/threads.py +336 -0
- letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +131 -0
- letta/server/rest_api/routers/v1/__init__.py +15 -0
- letta/server/rest_api/routers/v1/agents.py +543 -0
- letta/server/rest_api/routers/v1/blocks.py +73 -0
- letta/server/rest_api/routers/v1/jobs.py +46 -0
- letta/server/rest_api/routers/v1/llms.py +28 -0
- letta/server/rest_api/routers/v1/organizations.py +61 -0
- letta/server/rest_api/routers/v1/sources.py +199 -0
- letta/server/rest_api/routers/v1/tools.py +103 -0
- letta/server/rest_api/routers/v1/users.py +109 -0
- letta/server/rest_api/static_files.py +74 -0
- letta/server/rest_api/utils.py +69 -0
- letta/server/server.py +1995 -0
- letta/server/startup.sh +8 -0
- letta/server/static_files/assets/index-0cbf7ad5.js +274 -0
- letta/server/static_files/assets/index-156816da.css +1 -0
- letta/server/static_files/assets/index-486e3228.js +274 -0
- letta/server/static_files/favicon.ico +0 -0
- letta/server/static_files/index.html +39 -0
- letta/server/static_files/memgpt_logo_transparent.png +0 -0
- letta/server/utils.py +46 -0
- letta/server/ws_api/__init__.py +0 -0
- letta/server/ws_api/example_client.py +104 -0
- letta/server/ws_api/interface.py +108 -0
- letta/server/ws_api/protocol.py +100 -0
- letta/server/ws_api/server.py +145 -0
- letta/settings.py +165 -0
- letta/streaming_interface.py +396 -0
- letta/system.py +207 -0
- letta/utils.py +1065 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/LICENSE +190 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/METADATA +98 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/RECORD +189 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/WHEEL +4 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
# type: ignore
|
|
2
|
+
|
|
3
|
+
import uuid
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Dict, Iterator, List, Optional
|
|
6
|
+
|
|
7
|
+
from lancedb.pydantic import LanceModel, Vector
|
|
8
|
+
|
|
9
|
+
from letta.agent_store.storage import StorageConnector, TableType
|
|
10
|
+
from letta.config import AgentConfig, LettaConfig
|
|
11
|
+
from letta.schemas.message import Message, Passage, Record
|
|
12
|
+
|
|
13
|
+
""" Initial implementation - not complete """
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_db_model(table_name: str, table_type: TableType):
    """Build a LanceDB (pydantic) model class for the given table type.

    Args:
        table_name: Name of the table (currently unused; kept for signature
            parity with the other storage backends' factory functions).
        table_type: Which letta table the model should represent.

    Returns:
        A ``LanceModel`` subclass: ``PassageModel`` for archival memory /
        passages, ``MessageModel`` for recall memory.

    Raises:
        ValueError: If ``table_type`` is not implemented.
    """
    config = LettaConfig.load()

    if table_type in (TableType.ARCHIVAL_MEMORY, TableType.PASSAGES):
        # create schema for archival memory
        class PassageModel(LanceModel):
            """Defines data model for storing Passages (consisting of text, embedding)"""

            id: uuid.UUID
            user_id: str
            text: str
            doc_id: str
            agent_id: str
            data_source: str
            embedding: Vector(config.default_embedding_config.embedding_dim)
            metadata_: Dict

            def __repr__(self):
                return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

            def to_record(self):
                """Convert the row model back into a letta Passage record."""
                return Passage(
                    text=self.text,
                    embedding=self.embedding,
                    doc_id=self.doc_id,
                    user_id=self.user_id,
                    id=self.id,
                    data_source=self.data_source,
                    agent_id=self.agent_id,
                    metadata=self.metadata_,
                )

        return PassageModel
    elif table_type == TableType.RECALL_MEMORY:

        class MessageModel(LanceModel):
            """Defines data model for storing Message objects"""

            __abstract__ = True  # this line is necessary

            # Assuming message_id is the primary key
            id: uuid.UUID
            user_id: str
            agent_id: str

            # openai info
            role: str
            name: str
            text: str
            model: str
            user: str

            # function info
            function_name: str
            function_args: str
            function_response: str

            # BUG FIX: these two were plain assignments in the original
            # (``embedding = Vector(...)`` and ``created_at = datetime``),
            # which pydantic/LanceModel does not treat as fields — they must
            # be annotations, matching PassageModel above.
            embedding: Vector(config.default_embedding_config.embedding_dim)

            # Datetime column recording when the message was created
            created_at: datetime

            def __repr__(self):
                return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

            def to_record(self):
                """Convert the row model back into a letta Message record."""
                return Message(
                    user_id=self.user_id,
                    agent_id=self.agent_id,
                    role=self.role,
                    name=self.name,
                    text=self.text,
                    model=self.model,
                    function_name=self.function_name,
                    function_args=self.function_args,
                    function_response=self.function_response,
                    embedding=self.embedding,
                    created_at=self.created_at,
                    id=self.id,
                )

        return MessageModel
    else:
        raise ValueError(f"Table type {table_type} not implemented")
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class LanceDBConnector(StorageConnector):
    """Storage via LanceDB (initial stub — most operations are not implemented).

    NOTE(review): the original decorated every stub with ``@abstractmethod``,
    but ``abstractmethod`` was never imported (NameError at class-definition
    time) and this class is not an ABC. The decorators are removed here; the
    methods remain plain unimplemented stubs.
    """

    # TODO: this should probably eventually be moved into a parent DB class

    def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
        # TODO
        pass

    def generate_where_filter(self, filters: Dict) -> str:
        """Render ``filters`` as an SQL-style predicate: ``k1=v1 AND k2=v2 ...``."""
        where_filters = [f"{key}={value}" for key, value in filters.items()]
        # BUG FIX: the original called ``where_filters.join(" AND ")`` —
        # lists have no ``join``; the separator string's ``join`` must be used.
        return " AND ".join(where_filters)

    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
        # TODO
        pass

    def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
        # TODO
        pass

    def get(self, id: uuid.UUID) -> Optional[Record]:
        # TODO
        pass

    def size(self, filters: Optional[Dict] = {}) -> int:
        # TODO
        pass

    def insert(self, record: Record):
        # TODO
        pass

    def insert_many(self, records: List[Record], show_progress=False):
        # TODO
        pass

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
        # TODO
        pass

    def query_date(self, start_date, end_date):
        # TODO
        pass

    def query_text(self, query):
        # TODO
        pass

    def delete_table(self):
        # TODO
        pass

    def delete(self, filters: Optional[Dict] = {}):
        # TODO
        pass

    def save(self):
        # TODO
        pass
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from typing import Dict, Iterator, List, Optional, cast
|
|
4
|
+
|
|
5
|
+
from pymilvus import DataType, MilvusClient
|
|
6
|
+
from pymilvus.client.constants import ConsistencyLevel
|
|
7
|
+
|
|
8
|
+
from letta.agent_store.storage import StorageConnector, TableType
|
|
9
|
+
from letta.config import LettaConfig
|
|
10
|
+
from letta.constants import MAX_EMBEDDING_DIM
|
|
11
|
+
from letta.data_types import Passage, Record, RecordType
|
|
12
|
+
from letta.utils import datetime_to_timestamp, printd, timestamp_to_datetime
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MilvusStorageConnector(StorageConnector):
    """Storage via Milvus (archival memory / passages only)."""

    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        """Open a Milvus connection from ``config.archival_storage_uri`` and ensure the collection exists.

        Raises:
            ValueError: If ``archival_storage_uri`` is not set in the config.
        """
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)

        assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Milvus only supports archival memory"
        if config.archival_storage_uri:
            self.client = MilvusClient(uri=config.archival_storage_uri)
            self._create_collection()
        else:
            raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")

        # need to be converted to strings
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]

    def _create_collection(self):
        """Create the collection (id/text/embedding schema plus indexes)."""
        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=65_535)
        schema.add_field(field_name="text", datatype=DataType.VARCHAR, is_primary=False, max_length=65_535)
        schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=MAX_EMBEDDING_DIM)
        index_params = self.client.prepare_index_params()
        index_params.add_index(field_name="id")
        index_params.add_index(field_name="embedding", index_type="AUTOINDEX", metric_type="IP")
        self.client.create_collection(
            collection_name=self.table_name, schema=schema, index_params=index_params, consistency_level=ConsistencyLevel.Strong
        )

    def get_milvus_filter(self, filters: Optional[Dict] = {}) -> str:
        """Build a Milvus boolean filter expression from ``self.filters`` merged with ``filters``."""
        filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
        if not filter_conditions:
            return ""
        conditions = []
        for key, value in filter_conditions.items():
            # BUG FIX: the original tested ``isinstance(key, str)``, which is
            # always true for these dict keys, so every value — including
            # numeric ones — got quoted and the else branch was unreachable.
            # Quote only UUID fields and string values.
            if key in self.uuid_fields or isinstance(value, str):
                condition = f'({key} == "{value}")'
            else:
                condition = f"({key} == {value})"
            conditions.append(condition)
        filter_expr = " and ".join(conditions)
        if len(conditions) == 1:
            # drop the redundant outer parentheses around a single condition
            filter_expr = filter_expr[1:-1]
        return filter_expr

    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000) -> Iterator[List[RecordType]]:
        """Yield pages of records matching ``filters``, ``page_size`` at a time."""
        if not self.client.has_collection(collection_name=self.table_name):
            yield []
            # BUG FIX: the original fell through after yielding the empty page
            # and went on to query the missing collection.
            return
        filter_expr = self.get_milvus_filter(filters)
        offset = 0
        while True:
            # Retrieve a chunk of records with the given page_size
            query_res = self.client.query(
                collection_name=self.table_name,
                filter=filter_expr,
                offset=offset,
                limit=page_size,
            )
            if not query_res:
                break
            # Yield a list of Record objects converted from the chunk
            yield self._list_to_records(query_res)

            # Increment the offset to get the next chunk in the next iteration
            offset += page_size

    def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
        """Return all records matching ``filters`` (up to ``limit`` if given)."""
        if not self.client.has_collection(collection_name=self.table_name):
            return []
        filter_expr = self.get_milvus_filter(filters)
        query_res = self.client.query(
            collection_name=self.table_name,
            filter=filter_expr,
            limit=limit,
        )
        return self._list_to_records(query_res)

    def get(self, id: str) -> Optional[RecordType]:
        """Fetch a single record by primary key, or None if absent."""
        res = self.client.get(collection_name=self.table_name, ids=str(id))
        return self._list_to_records(res)[0] if res else None

    def size(self, filters: Optional[Dict] = {}) -> int:
        """Count records matching ``filters`` (0 if the collection is missing)."""
        if not self.client.has_collection(collection_name=self.table_name):
            return 0
        filter_expr = self.get_milvus_filter(filters)
        count_expr = "count(*)"
        query_res = self.client.query(
            collection_name=self.table_name,
            filter=filter_expr,
            output_fields=[count_expr],
        )
        doc_num = query_res[0][count_expr]
        return doc_num

    def insert(self, record: RecordType):
        """Insert (or replace) a single record."""
        self.insert_many([record])

    def insert_many(self, records: List[RecordType], show_progress=False):
        """Insert records, replacing any existing rows with the same ids."""
        if not records:
            return

        # Milvus lite currently does not support upsert, so we delete and insert instead
        # self.client.upsert(collection_name=self.table_name, data=self._records_to_list(records))
        ids = [str(record.id) for record in records]
        self.client.delete(collection_name=self.table_name, ids=ids)
        data = self._records_to_list(records)
        self.client.insert(collection_name=self.table_name, data=data)

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
        """Vector-similarity search for ``query_vec``; returns up to ``top_k`` matching records."""
        if not self.client.has_collection(self.table_name):
            return []
        search_res = self.client.search(
            collection_name=self.table_name, data=[query_vec], filter=self.get_milvus_filter(filters), limit=top_k, output_fields=["*"]
        )[0]
        entity_res = [res["entity"] for res in search_res]
        return self._list_to_records(entity_res)

    def delete_table(self):
        """Drop the whole collection."""
        self.client.drop_collection(collection_name=self.table_name)

    def delete(self, filters: Optional[Dict] = {}):
        """Delete records matching ``filters`` (no-op if the collection is missing)."""
        if not self.client.has_collection(collection_name=self.table_name):
            return
        filter_expr = self.get_milvus_filter(filters)
        self.client.delete(collection_name=self.table_name, filter=filter_expr)

    def save(self):
        # save to persistence file (nothing needs to be done)
        printd("Saving milvus")

    def _records_to_list(self, records: List[Record]) -> List[Dict]:
        """Convert letta records into Milvus insert-payload dicts (UUID fields stringified)."""
        if records == []:
            return []
        assert all(isinstance(r, Passage) for r in records)
        record_list = []
        records = list(set(records))  # de-duplicate by record identity/hash
        for record in records:
            # deepcopy so popping fields below never mutates the caller's record
            record_vars = deepcopy(vars(record))
            _id = record_vars.pop("id")
            text = record_vars.pop("text", "")
            embedding = record_vars.pop("embedding")
            record_metadata = record_vars.pop("metadata_", None) or {}
            if "created_at" in record_vars:
                record_vars["created_at"] = datetime_to_timestamp(record_vars["created_at"])
            record_dict = {key: value for key, value in record_vars.items() if value is not None}
            record_dict = {
                **record_dict,
                **record_metadata,
                "id": str(_id),
                "text": text,
                "embedding": embedding,
            }
            for key, value in record_dict.items():
                if key in self.uuid_fields:
                    record_dict[key] = str(value)
            record_list.append(record_dict)
        return record_list

    def _list_to_records(self, query_res: List[Dict]) -> List[RecordType]:
        """Convert Milvus result dicts back into letta record objects (``self.type``)."""
        records = []
        for res_dict in query_res:
            _id = res_dict.pop("id")
            embedding = res_dict.pop("embedding")
            text = res_dict.pop("text")
            metadata = deepcopy(res_dict)
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = uuid.UUID(value)
                elif key == "created_at":
                    metadata[key] = timestamp_to_datetime(value)
            records.append(
                cast(
                    RecordType,
                    self.type(
                        text=text,
                        embedding=embedding,
                        id=uuid.UUID(_id),
                        **metadata,
                    ),
                )
            )
        return records
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import uuid
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from typing import Dict, Iterator, List, Optional, cast
|
|
5
|
+
|
|
6
|
+
from letta.agent_store.storage import StorageConnector, TableType
|
|
7
|
+
from letta.config import LettaConfig
|
|
8
|
+
from letta.constants import MAX_EMBEDDING_DIM
|
|
9
|
+
from letta.data_types import Passage, Record, RecordType
|
|
10
|
+
from letta.utils import datetime_to_timestamp, timestamp_to_datetime
|
|
11
|
+
|
|
12
|
+
TEXT_PAYLOAD_KEY = "text_content"
|
|
13
|
+
METADATA_PAYLOAD_KEY = "metadata"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class QdrantStorageConnector(StorageConnector):
    """Storage via Qdrant (archival memory / passages only)."""

    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        """Connect to Qdrant via ``archival_storage_uri`` (host:port) or a local ``archival_storage_path``.

        Raises:
            ImportError: If ``qdrant-client`` is not installed.
            ValueError: If neither a URI nor a path is configured.
        """
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
        try:
            from qdrant_client import QdrantClient, models
        except ImportError as e:
            raise ImportError("'qdrant-client' not installed. Run `pip install qdrant-client`.") from e
        assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Qdrant only supports archival memory"
        if config.archival_storage_uri and len(config.archival_storage_uri.split(":")) == 2:
            host, port = config.archival_storage_uri.split(":")
            self.qdrant_client = QdrantClient(host=host, port=port, api_key=os.getenv("QDRANT_API_KEY"))
        elif config.archival_storage_path:
            self.qdrant_client = QdrantClient(path=config.archival_storage_path)
        else:
            raise ValueError("Qdrant storage requires either a URI or a path to the storage configured")
        if not self.qdrant_client.collection_exists(self.table_name):
            self.qdrant_client.create_collection(
                collection_name=self.table_name,
                vectors_config=models.VectorParams(
                    size=MAX_EMBEDDING_DIM,
                    distance=models.Distance.COSINE,
                ),
            )
        # fields stored as strings in payloads but handled as UUIDs in letta
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]

    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
        """Scroll through all records matching ``filters``, yielding one page at a time."""
        from qdrant_client import grpc

        filters = self.get_qdrant_filters(filters)
        next_offset = None
        stop_scrolling = False
        while not stop_scrolling:
            results, next_offset = self.qdrant_client.scroll(
                collection_name=self.table_name,
                scroll_filter=filters,
                limit=page_size,
                offset=next_offset,
                with_payload=True,
                with_vectors=True,
            )
            # a null offset (None, or a zero-valued grpc PointId) signals the last page
            stop_scrolling = next_offset is None or (
                isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == ""
            )
            yield self.to_records(results)

    def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]:
        """Return up to ``limit`` records matching ``filters``."""
        if self.size(filters) == 0:
            return []
        filters = self.get_qdrant_filters(filters)
        results, _ = self.qdrant_client.scroll(
            self.table_name,
            scroll_filter=filters,
            limit=limit,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def get(self, id: str) -> Optional[RecordType]:
        """Fetch a single record by id, or None if absent."""
        results = self.qdrant_client.retrieve(
            collection_name=self.table_name,
            ids=[str(id)],
            with_payload=True,
            with_vectors=True,
        )
        if not results:
            return None
        return self.to_records(results)[0]

    def insert(self, record: Record):
        """Upsert a single record."""
        points = self.to_points([record])
        self.qdrant_client.upsert(self.table_name, points=points)

    def insert_many(self, records: List[RecordType], show_progress=False):
        """Upsert a batch of records."""
        points = self.to_points(records)
        self.qdrant_client.upsert(self.table_name, points=points)

    def delete(self, filters: Optional[Dict] = {}):
        """Delete all records matching ``filters``."""
        filters = self.get_qdrant_filters(filters)
        self.qdrant_client.delete(self.table_name, points_selector=filters)

    def delete_table(self):
        """Drop the collection and close the client."""
        self.qdrant_client.delete_collection(self.table_name)
        self.qdrant_client.close()

    def size(self, filters: Optional[Dict] = {}) -> int:
        """Count records matching ``filters``."""
        filters = self.get_qdrant_filters(filters)
        return self.qdrant_client.count(collection_name=self.table_name, count_filter=filters).count

    def close(self):
        """Close the underlying Qdrant client."""
        self.qdrant_client.close()

    def query(
        self,
        query: str,
        query_vec: List[float],
        top_k: int = 10,
        filters: Optional[Dict] = {},
    ) -> List[RecordType]:
        """Vector-similarity search for ``query_vec``; returns up to ``top_k`` records."""
        # BUG FIX / consistency: the original called ``self.get_filters`` here,
        # unlike every other method in this class; Qdrant search expects the
        # ``models.Filter`` built by ``get_qdrant_filters``.
        filters = self.get_qdrant_filters(filters)
        results = self.qdrant_client.search(
            self.table_name,
            query_vector=query_vec,
            query_filter=filters,
            limit=top_k,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def to_records(self, records: list) -> List[RecordType]:
        """Convert Qdrant points back into letta record objects (``self.type``)."""
        parsed_records = []
        for record in records:
            record = deepcopy(record)
            metadata = record.payload[METADATA_PAYLOAD_KEY]
            text = record.payload[TEXT_PAYLOAD_KEY]
            _id = metadata.pop("id")
            embedding = record.vector
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = uuid.UUID(value)
                elif key == "created_at":
                    metadata[key] = timestamp_to_datetime(value)
            parsed_records.append(
                cast(
                    RecordType,
                    self.type(
                        text=text,
                        embedding=embedding,
                        id=uuid.UUID(_id),
                        **metadata,
                    ),
                )
            )
        return parsed_records

    def to_points(self, records: List[RecordType]):
        """Convert letta records into Qdrant ``PointStruct``s (text + metadata payload)."""
        from qdrant_client import models

        assert all(isinstance(r, Passage) for r in records)
        points = []
        records = list(set(records))  # de-duplicate by record identity/hash
        for record in records:
            # BUG FIX: the original did ``record = vars(record)`` and popped
            # from the result — ``vars`` returns the live ``__dict__``, so the
            # caller's record objects were destructively mutated. Deepcopy
            # first (matching the Milvus connector).
            record = deepcopy(vars(record))
            _id = record.pop("id")
            text = record.pop("text", "")
            embedding = record.pop("embedding", {})
            record_metadata = record.pop("metadata_", None) or {}
            if "created_at" in record:
                record["created_at"] = datetime_to_timestamp(record["created_at"])
            metadata = {key: value for key, value in record.items() if value is not None}
            metadata = {
                **metadata,
                **record_metadata,
                "id": str(_id),
            }
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = str(value)
            points.append(
                models.PointStruct(
                    id=str(_id),
                    vector=embedding,
                    payload={
                        TEXT_PAYLOAD_KEY: text,
                        METADATA_PAYLOAD_KEY: metadata,
                    },
                )
            )
        return points

    def get_qdrant_filters(self, filters: Optional[Dict] = {}):
        """Build a Qdrant ``models.Filter`` from ``self.filters`` merged with ``filters``."""
        from qdrant_client import models

        filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
        must_conditions = []
        for key, value in filter_conditions.items():
            # UUID fields are stored stringified in the payload
            match_value = str(value) if key in self.uuid_fields else value
            field_condition = models.FieldCondition(
                key=f"{METADATA_PAYLOAD_KEY}.{key}",
                match=models.MatchValue(value=match_value),
            )
            must_conditions.append(field_condition)
        return models.Filter(must=must_conditions)
|