letta-nightly 0.1.7.dev20240924104148__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic.
- letta/__init__.py +24 -0
- letta/__main__.py +3 -0
- letta/agent.py +1427 -0
- letta/agent_store/chroma.py +295 -0
- letta/agent_store/db.py +546 -0
- letta/agent_store/lancedb.py +177 -0
- letta/agent_store/milvus.py +198 -0
- letta/agent_store/qdrant.py +201 -0
- letta/agent_store/storage.py +188 -0
- letta/benchmark/benchmark.py +96 -0
- letta/benchmark/constants.py +14 -0
- letta/cli/cli.py +689 -0
- letta/cli/cli_config.py +1282 -0
- letta/cli/cli_load.py +166 -0
- letta/client/__init__.py +0 -0
- letta/client/admin.py +171 -0
- letta/client/client.py +2360 -0
- letta/client/streaming.py +90 -0
- letta/client/utils.py +61 -0
- letta/config.py +484 -0
- letta/configs/anthropic.json +13 -0
- letta/configs/letta_hosted.json +11 -0
- letta/configs/openai.json +12 -0
- letta/constants.py +134 -0
- letta/credentials.py +140 -0
- letta/data_sources/connectors.py +247 -0
- letta/embeddings.py +218 -0
- letta/errors.py +26 -0
- letta/functions/__init__.py +0 -0
- letta/functions/function_sets/base.py +174 -0
- letta/functions/function_sets/extras.py +132 -0
- letta/functions/functions.py +105 -0
- letta/functions/schema_generator.py +205 -0
- letta/humans/__init__.py +0 -0
- letta/humans/examples/basic.txt +1 -0
- letta/humans/examples/cs_phd.txt +9 -0
- letta/interface.py +314 -0
- letta/llm_api/__init__.py +0 -0
- letta/llm_api/anthropic.py +383 -0
- letta/llm_api/azure_openai.py +155 -0
- letta/llm_api/cohere.py +396 -0
- letta/llm_api/google_ai.py +468 -0
- letta/llm_api/llm_api_tools.py +485 -0
- letta/llm_api/openai.py +470 -0
- letta/local_llm/README.md +3 -0
- letta/local_llm/__init__.py +0 -0
- letta/local_llm/chat_completion_proxy.py +279 -0
- letta/local_llm/constants.py +31 -0
- letta/local_llm/function_parser.py +68 -0
- letta/local_llm/grammars/__init__.py +0 -0
- letta/local_llm/grammars/gbnf_grammar_generator.py +1324 -0
- letta/local_llm/grammars/json.gbnf +26 -0
- letta/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf +32 -0
- letta/local_llm/groq/api.py +97 -0
- letta/local_llm/json_parser.py +202 -0
- letta/local_llm/koboldcpp/api.py +62 -0
- letta/local_llm/koboldcpp/settings.py +23 -0
- letta/local_llm/llamacpp/api.py +58 -0
- letta/local_llm/llamacpp/settings.py +22 -0
- letta/local_llm/llm_chat_completion_wrappers/__init__.py +0 -0
- letta/local_llm/llm_chat_completion_wrappers/airoboros.py +452 -0
- letta/local_llm/llm_chat_completion_wrappers/chatml.py +470 -0
- letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +387 -0
- letta/local_llm/llm_chat_completion_wrappers/dolphin.py +246 -0
- letta/local_llm/llm_chat_completion_wrappers/llama3.py +345 -0
- letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +156 -0
- letta/local_llm/llm_chat_completion_wrappers/wrapper_base.py +11 -0
- letta/local_llm/llm_chat_completion_wrappers/zephyr.py +345 -0
- letta/local_llm/lmstudio/api.py +100 -0
- letta/local_llm/lmstudio/settings.py +29 -0
- letta/local_llm/ollama/api.py +88 -0
- letta/local_llm/ollama/settings.py +32 -0
- letta/local_llm/settings/__init__.py +0 -0
- letta/local_llm/settings/deterministic_mirostat.py +45 -0
- letta/local_llm/settings/settings.py +72 -0
- letta/local_llm/settings/simple.py +28 -0
- letta/local_llm/utils.py +265 -0
- letta/local_llm/vllm/api.py +63 -0
- letta/local_llm/webui/api.py +60 -0
- letta/local_llm/webui/legacy_api.py +58 -0
- letta/local_llm/webui/legacy_settings.py +23 -0
- letta/local_llm/webui/settings.py +24 -0
- letta/log.py +76 -0
- letta/main.py +437 -0
- letta/memory.py +440 -0
- letta/metadata.py +884 -0
- letta/openai_backcompat/__init__.py +0 -0
- letta/openai_backcompat/openai_object.py +437 -0
- letta/persistence_manager.py +148 -0
- letta/personas/__init__.py +0 -0
- letta/personas/examples/anna_pa.txt +13 -0
- letta/personas/examples/google_search_persona.txt +15 -0
- letta/personas/examples/memgpt_doc.txt +6 -0
- letta/personas/examples/memgpt_starter.txt +4 -0
- letta/personas/examples/sam.txt +14 -0
- letta/personas/examples/sam_pov.txt +14 -0
- letta/personas/examples/sam_simple_pov_gpt35.txt +13 -0
- letta/personas/examples/sqldb/test.db +0 -0
- letta/prompts/__init__.py +0 -0
- letta/prompts/gpt_summarize.py +14 -0
- letta/prompts/gpt_system.py +26 -0
- letta/prompts/system/memgpt_base.txt +49 -0
- letta/prompts/system/memgpt_chat.txt +58 -0
- letta/prompts/system/memgpt_chat_compressed.txt +13 -0
- letta/prompts/system/memgpt_chat_fstring.txt +51 -0
- letta/prompts/system/memgpt_doc.txt +50 -0
- letta/prompts/system/memgpt_gpt35_extralong.txt +53 -0
- letta/prompts/system/memgpt_intuitive_knowledge.txt +31 -0
- letta/prompts/system/memgpt_modified_chat.txt +23 -0
- letta/pytest.ini +0 -0
- letta/schemas/agent.py +117 -0
- letta/schemas/api_key.py +21 -0
- letta/schemas/block.py +135 -0
- letta/schemas/document.py +21 -0
- letta/schemas/embedding_config.py +54 -0
- letta/schemas/enums.py +35 -0
- letta/schemas/job.py +38 -0
- letta/schemas/letta_base.py +80 -0
- letta/schemas/letta_message.py +175 -0
- letta/schemas/letta_request.py +23 -0
- letta/schemas/letta_response.py +28 -0
- letta/schemas/llm_config.py +54 -0
- letta/schemas/memory.py +224 -0
- letta/schemas/message.py +727 -0
- letta/schemas/openai/chat_completion_request.py +123 -0
- letta/schemas/openai/chat_completion_response.py +136 -0
- letta/schemas/openai/chat_completions.py +123 -0
- letta/schemas/openai/embedding_response.py +11 -0
- letta/schemas/openai/openai.py +157 -0
- letta/schemas/organization.py +20 -0
- letta/schemas/passage.py +80 -0
- letta/schemas/source.py +62 -0
- letta/schemas/tool.py +143 -0
- letta/schemas/usage.py +18 -0
- letta/schemas/user.py +33 -0
- letta/server/__init__.py +0 -0
- letta/server/constants.py +6 -0
- letta/server/rest_api/__init__.py +0 -0
- letta/server/rest_api/admin/__init__.py +0 -0
- letta/server/rest_api/admin/agents.py +21 -0
- letta/server/rest_api/admin/tools.py +83 -0
- letta/server/rest_api/admin/users.py +98 -0
- letta/server/rest_api/app.py +193 -0
- letta/server/rest_api/auth/__init__.py +0 -0
- letta/server/rest_api/auth/index.py +43 -0
- letta/server/rest_api/auth_token.py +22 -0
- letta/server/rest_api/interface.py +726 -0
- letta/server/rest_api/routers/__init__.py +0 -0
- letta/server/rest_api/routers/openai/__init__.py +0 -0
- letta/server/rest_api/routers/openai/assistants/__init__.py +0 -0
- letta/server/rest_api/routers/openai/assistants/assistants.py +115 -0
- letta/server/rest_api/routers/openai/assistants/schemas.py +121 -0
- letta/server/rest_api/routers/openai/assistants/threads.py +336 -0
- letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +131 -0
- letta/server/rest_api/routers/v1/__init__.py +15 -0
- letta/server/rest_api/routers/v1/agents.py +543 -0
- letta/server/rest_api/routers/v1/blocks.py +73 -0
- letta/server/rest_api/routers/v1/jobs.py +46 -0
- letta/server/rest_api/routers/v1/llms.py +28 -0
- letta/server/rest_api/routers/v1/organizations.py +61 -0
- letta/server/rest_api/routers/v1/sources.py +199 -0
- letta/server/rest_api/routers/v1/tools.py +103 -0
- letta/server/rest_api/routers/v1/users.py +109 -0
- letta/server/rest_api/static_files.py +74 -0
- letta/server/rest_api/utils.py +69 -0
- letta/server/server.py +1995 -0
- letta/server/startup.sh +8 -0
- letta/server/static_files/assets/index-0cbf7ad5.js +274 -0
- letta/server/static_files/assets/index-156816da.css +1 -0
- letta/server/static_files/assets/index-486e3228.js +274 -0
- letta/server/static_files/favicon.ico +0 -0
- letta/server/static_files/index.html +39 -0
- letta/server/static_files/memgpt_logo_transparent.png +0 -0
- letta/server/utils.py +46 -0
- letta/server/ws_api/__init__.py +0 -0
- letta/server/ws_api/example_client.py +104 -0
- letta/server/ws_api/interface.py +108 -0
- letta/server/ws_api/protocol.py +100 -0
- letta/server/ws_api/server.py +145 -0
- letta/settings.py +165 -0
- letta/streaming_interface.py +396 -0
- letta/system.py +207 -0
- letta/utils.py +1065 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/LICENSE +190 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/METADATA +98 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/RECORD +189 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/WHEEL +4 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/entry_points.txt +3 -0
letta/agent_store/db.py
ADDED
@@ -0,0 +1,546 @@

import base64
import os
from datetime import datetime
from typing import Dict, List, Optional

import numpy as np
from sqlalchemy import (
    BINARY,
    Column,
    DateTime,
    Index,
    String,
    TypeDecorator,
    and_,
    asc,
    desc,
    or_,
    select,
    text,
)
from sqlalchemy.orm import declarative_base, mapped_column
from sqlalchemy.orm.session import close_all_sessions
from sqlalchemy.sql import func
from sqlalchemy_json import MutableJson
from tqdm import tqdm

from letta.agent_store.storage import StorageConnector, TableType
from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.metadata import EmbeddingConfigColumn, ToolCallColumn

# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
from letta.schemas.message import Message
from letta.schemas.openai.chat_completions import ToolCall
from letta.schemas.passage import Passage
from letta.settings import settings

Base = declarative_base()
config = LettaConfig()


class CommonVector(TypeDecorator):
    """Common type for representing vectors in SQLite"""

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        if value is None:
            return value
        # Ensure value is a numpy array
        if isinstance(value, list):
            value = np.array(value, dtype=np.float32)
        # Serialize numpy array to bytes, then encode to base64 for universal compatibility
        return base64.b64encode(value.tobytes())

    def process_result_value(self, value, dialect):
        if not value:
            return value
        # Check database type and deserialize accordingly
        if dialect.name == "sqlite":
            # Decode from base64 and convert back to numpy array
            value = base64.b64decode(value)
        # For PostgreSQL, value is already in bytes
        return np.frombuffer(value, dtype=np.float32)


class MessageModel(Base):
    """Defines data model for storing Message objects"""

    __tablename__ = "messages"
    __table_args__ = {"extend_existing": True}

    # Assuming message_id is the primary key
    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    agent_id = Column(String, nullable=False)

    # openai info
    role = Column(String, nullable=False)
    text = Column(String)  # optional: can be null if function call
    model = Column(String)  # optional: can be null if LLM backend doesn't require specifying
    name = Column(String)  # optional: multi-agent only

    # tool call request info
    # if role == "assistant", this MAY be specified
    # if role != "assistant", this must be null
    # TODO align with OpenAI spec of multiple tool calls
    # tool_calls = Column(ToolCallColumn)
    tool_calls = Column(ToolCallColumn)

    # tool call response info
    # if role == "tool", then this must be specified
    # if role != "tool", this must be null
    tool_call_id = Column(String)

    # Add a datetime column, with default value as the current time
    created_at = Column(DateTime(timezone=True))
    Index("message_idx_user", user_id, agent_id),

    def __repr__(self):
        return f"<Message(message_id='{self.id}', text='{self.text}')>"

    def to_record(self):
        # calls = (
        #     [ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls]
        #     if self.tool_calls
        #     else None
        # )
        # if calls:
        #     assert isinstance(calls[0], ToolCall)
        if self.tool_calls and len(self.tool_calls) > 0:
            assert isinstance(self.tool_calls[0], ToolCall), type(self.tool_calls[0])
            for tool in self.tool_calls:
                assert isinstance(tool, ToolCall), type(tool)
        return Message(
            user_id=self.user_id,
            agent_id=self.agent_id,
            role=self.role,
            name=self.name,
            text=self.text,
            model=self.model,
            # tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None,
            tool_calls=self.tool_calls,
            tool_call_id=self.tool_call_id,
            created_at=self.created_at,
            id=self.id,
        )


class PassageModel(Base):
    """Defines data model for storing Passages (consisting of text, embedding)"""

    __tablename__ = "passages"
    __table_args__ = {"extend_existing": True}

    # Assuming passage_id is the primary key
    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    text = Column(String)
    doc_id = Column(String)
    agent_id = Column(String)
    source_id = Column(String)

    # vector storage
    if settings.letta_pg_uri_no_default:
        from pgvector.sqlalchemy import Vector

        embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
    elif config.archival_storage_type == "sqlite" or config.archival_storage_type == "chroma":
        embedding = Column(CommonVector)
    else:
        raise ValueError(f"Unsupported archival_storage_type: {config.archival_storage_type}")
    embedding_config = Column(EmbeddingConfigColumn)
    metadata_ = Column(MutableJson)

    # Add a datetime column, with default value as the current time
    created_at = Column(DateTime(timezone=True))

    Index("passage_idx_user", user_id, agent_id, doc_id),

    def __repr__(self):
        return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

    def to_record(self):
        return Passage(
            text=self.text,
            embedding=self.embedding,
            embedding_config=self.embedding_config,
            doc_id=self.doc_id,
            user_id=self.user_id,
            id=self.id,
            source_id=self.source_id,
            agent_id=self.agent_id,
            metadata_=self.metadata_,
            created_at=self.created_at,
        )


class SQLStorageConnector(StorageConnector):
    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
        self.config = config

    def get_filters(self, filters: Optional[Dict] = {}):
        if filters is not None:
            filter_conditions = {**self.filters, **filters}
        else:
            filter_conditions = self.filters
        all_filters = [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
        return all_filters

    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0):
        filters = self.get_filters(filters)
        while True:
            # Retrieve a chunk of records with the given page_size
            with self.session_maker() as session:
                db_record_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()

            # If the chunk is empty, we've retrieved all records
            if not db_record_chunk:
                break

            # Yield a list of Record objects converted from the chunk
            yield [record.to_record() for record in db_record_chunk]

            # Increment the offset to get the next chunk in the next iteration
            offset += page_size

    def get_all_cursor(
        self,
        filters: Optional[Dict] = {},
        after: str = None,
        before: str = None,
        limit: Optional[int] = 1000,
        order_by: str = "created_at",
        reverse: bool = False,
    ):
        """Get all that returns a cursor (record.id) and records"""
        filters = self.get_filters(filters)

        # generate query
        with self.session_maker() as session:
            query = session.query(self.db_model).filter(*filters)
            # query = query.order_by(asc(self.db_model.id))

            # records are sorted by the order_by field first, and then by the ID if two fields are the same
            if reverse:
                query = query.order_by(desc(getattr(self.db_model, order_by)), asc(self.db_model.id))
            else:
                query = query.order_by(asc(getattr(self.db_model, order_by)), asc(self.db_model.id))

            # cursor logic: filter records based on before/after ID
            if after:
                after_value = getattr(self.get(id=after), order_by)
                sort_exp = getattr(self.db_model, order_by) > after_value
                query = query.filter(
                    or_(sort_exp, and_(getattr(self.db_model, order_by) == after_value, self.db_model.id > after))  # tiebreaker case
                )
            if before:
                before_value = getattr(self.get(id=before), order_by)
                sort_exp = getattr(self.db_model, order_by) < before_value
                query = query.filter(or_(sort_exp, and_(getattr(self.db_model, order_by) == before_value, self.db_model.id < before)))

            # get records
            db_record_chunk = query.limit(limit).all()
        if not db_record_chunk:
            return (None, [])
        records = [record.to_record() for record in db_record_chunk]
        next_cursor = db_record_chunk[-1].id
        assert isinstance(next_cursor, str)

        # return (cursor, list[records])
        return (next_cursor, records)

    def get_all(self, filters: Optional[Dict] = {}, limit=None):
        filters = self.get_filters(filters)
        with self.session_maker() as session:
            if limit:
                db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
            else:
                db_records = session.query(self.db_model).filter(*filters).all()
        return [record.to_record() for record in db_records]

    def get(self, id: str):
        with self.session_maker() as session:
            db_record = session.get(self.db_model, id)
        if db_record is None:
            return None
        return db_record.to_record()

    def size(self, filters: Optional[Dict] = {}) -> int:
        # return size of table
        filters = self.get_filters(filters)
        with self.session_maker() as session:
            return session.query(self.db_model).filter(*filters).count()

    def insert(self, record):
        raise NotImplementedError

    def insert_many(self, records, show_progress=False):
        raise NotImplementedError

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
        raise NotImplementedError("Vector query not implemented for SQLStorageConnector")

    def save(self):
        return

    def list_data_sources(self):
        assert self.table_type == TableType.ARCHIVAL_MEMORY, f"list_data_sources only implemented for ARCHIVAL_MEMORY"
        with self.session_maker() as session:
            unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
        return unique_data_sources

    def query_date(self, start_date, end_date, limit=None, offset=0):
        filters = self.get_filters({})
        with self.session_maker() as session:
            query = (
                session.query(self.db_model)
                .filter(*filters)
                .filter(self.db_model.created_at >= start_date)
                .filter(self.db_model.created_at <= end_date)
                .filter(self.db_model.role != "system")
                .filter(self.db_model.role != "tool")
                .offset(offset)
            )
            if limit:
                query = query.limit(limit)
            results = query.all()
        return [result.to_record() for result in results]

    def query_text(self, query, limit=None, offset=0):
        # todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
        filters = self.get_filters({})
        with self.session_maker() as session:
            query = (
                session.query(self.db_model)
                .filter(*filters)
                .filter(func.lower(self.db_model.text).contains(func.lower(query)))
                .filter(self.db_model.role != "system")
                .filter(self.db_model.role != "tool")
                .offset(offset)
            )
            if limit:
                query = query.limit(limit)
            results = query.all()
        # return [self.type(**vars(result)) for result in results]
        return [result.to_record() for result in results]

    # Should be used only in tests!
    def delete_table(self):
        close_all_sessions()
        with self.session_maker() as session:
            self.db_model.__table__.drop(session.bind)
            session.commit()

    def delete(self, filters: Optional[Dict] = {}):
        filters = self.get_filters(filters)
        with self.session_maker() as session:
            session.query(self.db_model).filter(*filters).delete()
            session.commit()


class PostgresStorageConnector(SQLStorageConnector):
    """Storage via Postgres"""

    # TODO: this should probably eventually be moved into a parent DB class

    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        from pgvector.sqlalchemy import Vector

        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)

        # construct URI from enviornment variables
        if settings.pg_uri:
            self.uri = settings.pg_uri
        else:
            # use config URI
            # TODO: remove this eventually (config should NOT contain URI)
            if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
                self.uri = self.config.archival_storage_uri
                self.db_model = PassageModel
                if self.config.archival_storage_uri is None:
                    raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
            elif table_type == TableType.RECALL_MEMORY:
                self.uri = self.config.recall_storage_uri
                self.db_model = MessageModel
                if self.config.recall_storage_uri is None:
                    raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
            else:
                raise ValueError(f"Table type {table_type} not implemented")

        for c in self.db_model.__table__.columns:
            if c.name == "embedding":
                assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"

        from letta.server.server import db_context

        self.session_maker = db_context

        # TODO: move to DB init
        with self.session_maker() as session:
            session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))  # Enables the vector extension

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
        filters = self.get_filters(filters)
        with self.session_maker() as session:
            results = session.scalars(
                select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
            ).all()

        # Convert the results into Passage objects
        records = [result.to_record() for result in results]
        return records

    def insert_many(self, records, exists_ok=True, show_progress=False):
        pass

        # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
        if len(records) == 0:
            return

        added_ids = []  # avoid adding duplicates
        # NOTE: this has not great performance due to the excessive commits
        with self.session_maker() as session:
            iterable = tqdm(records) if show_progress else records
            for record in iterable:
                # db_record = self.db_model(**vars(record))

                if record.id in added_ids:
                    continue

                existing_record = session.query(self.db_model).filter_by(id=record.id).first()
                if existing_record:
                    if exists_ok:
                        fields = record.model_dump()
                        fields.pop("id")
                        session.query(self.db_model).filter(self.db_model.id == record.id).update(fields)
                        print(f"Updated record with id {record.id}")
                        session.commit()
                    else:
                        raise ValueError(f"Record with id {record.id} already exists.")

                else:
                    db_record = self.db_model(**record.dict())
                    session.add(db_record)
                    print(f"Added record with id {record.id}")
                    session.commit()

                added_ids.append(record.id)

    def insert(self, record, exists_ok=True):
        self.insert_many([record], exists_ok=exists_ok)

    def update(self, record):
        """
        Updates a record in the database based on the provided Record object.
        """
        with self.session_maker() as session:
            # Find the record by its ID
            db_record = session.query(self.db_model).filter_by(id=record.id).first()
            if not db_record:
                raise ValueError(f"Record with id {record.id} does not exist.")

            # Update the record with new values from the provided Record object
            for attr, value in vars(record).items():
                setattr(db_record, attr, value)

            # Commit the changes to the database
            session.commit()

    def str_to_datetime(self, str_date: str) -> datetime:
        val = str_date.split("-")
        _datetime = datetime(int(val[0]), int(val[1]), int(val[2]))
        return _datetime

    def query_date(self, start_date, end_date, limit=None, offset=0):
        filters = self.get_filters({})
        _start_date = self.str_to_datetime(start_date) if isinstance(start_date, str) else start_date
        _end_date = self.str_to_datetime(end_date) if isinstance(end_date, str) else end_date
        with self.session_maker() as session:
            query = (
                session.query(self.db_model)
                .filter(*filters)
                .filter(self.db_model.created_at >= _start_date)
                .filter(self.db_model.created_at <= _end_date)
                .filter(self.db_model.role != "system")
                .filter(self.db_model.role != "tool")
                .offset(offset)
            )
            if limit:
                query = query.limit(limit)
            results = query.all()
        return [result.to_record() for result in results]


class SQLLiteStorageConnector(SQLStorageConnector):
    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)

        # get storage URI
        if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
            raise ValueError(f"Table type {table_type} not implemented")
        elif table_type == TableType.RECALL_MEMORY:
            # TODO: eventually implement URI option
            self.path = self.config.recall_storage_path
            if self.path is None:
                raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}")
            self.db_model = MessageModel
        else:
            raise ValueError(f"Table type {table_type} not implemented")

        self.path = os.path.join(self.path, f"sqlite.db")

        from letta.server.server import db_context

        self.session_maker = db_context

        # import sqlite3

        # sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
        # sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))

    def insert_many(self, records, exists_ok=True, show_progress=False):
        pass

        # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
        if len(records) == 0:
            return
        with self.session_maker() as session:
            iterable = tqdm(records) if show_progress else records
            for record in iterable:
                # db_record = self.db_model(**vars(record))
                db_record = self.db_model(**record.dict())
                session.add(db_record)
            session.commit()

    def insert(self, record, exists_ok=True):
        self.insert_many([record], exists_ok=exists_ok)

    def update(self, record):
        """
        Updates an existing record in the database with values from the provided record object.
        """
        if not record.id:
            raise ValueError("Record must have an id.")

        with self.session_maker() as session:
            # Fetch the existing record from the database
            db_record = session.query(self.db_model).filter_by(id=record.id).first()
            if not db_record:
                raise ValueError(f"Record with id {record.id} does not exist.")

            # Update the database record with values from the provided record object
            for column in self.db_model.__table__.columns:
                column_name = column.name
                if hasattr(record, column_name):
                    new_value = getattr(record, column_name)
                    setattr(db_record, column_name, new_value)

            # Commit the changes to the database
            session.commit()
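
Two patterns in the file above can be sanity-checked in isolation. CommonVector persists embeddings in SQLite as base64-encoded float32 bytes; the round-trip performed by process_bind_param and process_result_value is equivalent to the following standalone sketch (plain numpy and base64; the helper names are illustrative, not part of Letta):

import base64

import numpy as np


def encode_vector(vec) -> bytes:
    # Mirrors CommonVector.process_bind_param: list -> float32 bytes -> base64
    arr = np.asarray(vec, dtype=np.float32)
    return base64.b64encode(arr.tobytes())


def decode_vector(blob: bytes) -> np.ndarray:
    # Mirrors the SQLite branch of CommonVector.process_result_value
    return np.frombuffer(base64.b64decode(blob), dtype=np.float32)


original = [0.1, 0.2, 0.3]
restored = decode_vector(encode_vector(original))
assert np.allclose(restored, np.asarray(original, dtype=np.float32))

Similarly, SQLStorageConnector.get_all_cursor is keyset pagination: rows are ordered by (order_by, id), and the after cursor becomes a filter that accepts a strictly greater order_by value, or an equal value with a greater id. A minimal, self-contained sketch of that tiebreaker against an in-memory SQLite database (the Row model, ids, and timestamps here are hypothetical, not Letta's actual schema):

from datetime import datetime, timedelta

from sqlalchemy import Column, DateTime, String, and_, asc, create_engine, or_
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Row(Base):
    __tablename__ = "rows"
    id = Column(String, primary_key=True)
    created_at = Column(DateTime)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    t0 = datetime(2024, 1, 1)
    # Two rows share each timestamp, so the id tiebreaker matters.
    session.add_all([Row(id=f"msg-{i}", created_at=t0 + timedelta(seconds=i // 2)) for i in range(6)])
    session.commit()

    # Page 1: order by (created_at, id) so ties are broken deterministically.
    page = session.query(Row).order_by(asc(Row.created_at), asc(Row.id)).limit(3).all()
    cursor = page[-1]

    # Page 2: strictly newer rows, or rows with the same timestamp and a larger id.
    keyset = or_(
        Row.created_at > cursor.created_at,
        and_(Row.created_at == cursor.created_at, Row.id > cursor.id),
    )
    next_page = session.query(Row).filter(keyset).order_by(asc(Row.created_at), asc(Row.id)).limit(3).all()
    print([r.id for r in page], [r.id for r in next_page])  # ['msg-0', 'msg-1', 'msg-2'] ['msg-3', 'msg-4', 'msg-5']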