letta-nightly 0.1.7.dev20240924104148__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic. Click here for more details.

Files changed (189)
  1. letta/__init__.py +24 -0
  2. letta/__main__.py +3 -0
  3. letta/agent.py +1427 -0
  4. letta/agent_store/chroma.py +295 -0
  5. letta/agent_store/db.py +546 -0
  6. letta/agent_store/lancedb.py +177 -0
  7. letta/agent_store/milvus.py +198 -0
  8. letta/agent_store/qdrant.py +201 -0
  9. letta/agent_store/storage.py +188 -0
  10. letta/benchmark/benchmark.py +96 -0
  11. letta/benchmark/constants.py +14 -0
  12. letta/cli/cli.py +689 -0
  13. letta/cli/cli_config.py +1282 -0
  14. letta/cli/cli_load.py +166 -0
  15. letta/client/__init__.py +0 -0
  16. letta/client/admin.py +171 -0
  17. letta/client/client.py +2360 -0
  18. letta/client/streaming.py +90 -0
  19. letta/client/utils.py +61 -0
  20. letta/config.py +484 -0
  21. letta/configs/anthropic.json +13 -0
  22. letta/configs/letta_hosted.json +11 -0
  23. letta/configs/openai.json +12 -0
  24. letta/constants.py +134 -0
  25. letta/credentials.py +140 -0
  26. letta/data_sources/connectors.py +247 -0
  27. letta/embeddings.py +218 -0
  28. letta/errors.py +26 -0
  29. letta/functions/__init__.py +0 -0
  30. letta/functions/function_sets/base.py +174 -0
  31. letta/functions/function_sets/extras.py +132 -0
  32. letta/functions/functions.py +105 -0
  33. letta/functions/schema_generator.py +205 -0
  34. letta/humans/__init__.py +0 -0
  35. letta/humans/examples/basic.txt +1 -0
  36. letta/humans/examples/cs_phd.txt +9 -0
  37. letta/interface.py +314 -0
  38. letta/llm_api/__init__.py +0 -0
  39. letta/llm_api/anthropic.py +383 -0
  40. letta/llm_api/azure_openai.py +155 -0
  41. letta/llm_api/cohere.py +396 -0
  42. letta/llm_api/google_ai.py +468 -0
  43. letta/llm_api/llm_api_tools.py +485 -0
  44. letta/llm_api/openai.py +470 -0
  45. letta/local_llm/README.md +3 -0
  46. letta/local_llm/__init__.py +0 -0
  47. letta/local_llm/chat_completion_proxy.py +279 -0
  48. letta/local_llm/constants.py +31 -0
  49. letta/local_llm/function_parser.py +68 -0
  50. letta/local_llm/grammars/__init__.py +0 -0
  51. letta/local_llm/grammars/gbnf_grammar_generator.py +1324 -0
  52. letta/local_llm/grammars/json.gbnf +26 -0
  53. letta/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf +32 -0
  54. letta/local_llm/groq/api.py +97 -0
  55. letta/local_llm/json_parser.py +202 -0
  56. letta/local_llm/koboldcpp/api.py +62 -0
  57. letta/local_llm/koboldcpp/settings.py +23 -0
  58. letta/local_llm/llamacpp/api.py +58 -0
  59. letta/local_llm/llamacpp/settings.py +22 -0
  60. letta/local_llm/llm_chat_completion_wrappers/__init__.py +0 -0
  61. letta/local_llm/llm_chat_completion_wrappers/airoboros.py +452 -0
  62. letta/local_llm/llm_chat_completion_wrappers/chatml.py +470 -0
  63. letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +387 -0
  64. letta/local_llm/llm_chat_completion_wrappers/dolphin.py +246 -0
  65. letta/local_llm/llm_chat_completion_wrappers/llama3.py +345 -0
  66. letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +156 -0
  67. letta/local_llm/llm_chat_completion_wrappers/wrapper_base.py +11 -0
  68. letta/local_llm/llm_chat_completion_wrappers/zephyr.py +345 -0
  69. letta/local_llm/lmstudio/api.py +100 -0
  70. letta/local_llm/lmstudio/settings.py +29 -0
  71. letta/local_llm/ollama/api.py +88 -0
  72. letta/local_llm/ollama/settings.py +32 -0
  73. letta/local_llm/settings/__init__.py +0 -0
  74. letta/local_llm/settings/deterministic_mirostat.py +45 -0
  75. letta/local_llm/settings/settings.py +72 -0
  76. letta/local_llm/settings/simple.py +28 -0
  77. letta/local_llm/utils.py +265 -0
  78. letta/local_llm/vllm/api.py +63 -0
  79. letta/local_llm/webui/api.py +60 -0
  80. letta/local_llm/webui/legacy_api.py +58 -0
  81. letta/local_llm/webui/legacy_settings.py +23 -0
  82. letta/local_llm/webui/settings.py +24 -0
  83. letta/log.py +76 -0
  84. letta/main.py +437 -0
  85. letta/memory.py +440 -0
  86. letta/metadata.py +884 -0
  87. letta/openai_backcompat/__init__.py +0 -0
  88. letta/openai_backcompat/openai_object.py +437 -0
  89. letta/persistence_manager.py +148 -0
  90. letta/personas/__init__.py +0 -0
  91. letta/personas/examples/anna_pa.txt +13 -0
  92. letta/personas/examples/google_search_persona.txt +15 -0
  93. letta/personas/examples/memgpt_doc.txt +6 -0
  94. letta/personas/examples/memgpt_starter.txt +4 -0
  95. letta/personas/examples/sam.txt +14 -0
  96. letta/personas/examples/sam_pov.txt +14 -0
  97. letta/personas/examples/sam_simple_pov_gpt35.txt +13 -0
  98. letta/personas/examples/sqldb/test.db +0 -0
  99. letta/prompts/__init__.py +0 -0
  100. letta/prompts/gpt_summarize.py +14 -0
  101. letta/prompts/gpt_system.py +26 -0
  102. letta/prompts/system/memgpt_base.txt +49 -0
  103. letta/prompts/system/memgpt_chat.txt +58 -0
  104. letta/prompts/system/memgpt_chat_compressed.txt +13 -0
  105. letta/prompts/system/memgpt_chat_fstring.txt +51 -0
  106. letta/prompts/system/memgpt_doc.txt +50 -0
  107. letta/prompts/system/memgpt_gpt35_extralong.txt +53 -0
  108. letta/prompts/system/memgpt_intuitive_knowledge.txt +31 -0
  109. letta/prompts/system/memgpt_modified_chat.txt +23 -0
  110. letta/pytest.ini +0 -0
  111. letta/schemas/agent.py +117 -0
  112. letta/schemas/api_key.py +21 -0
  113. letta/schemas/block.py +135 -0
  114. letta/schemas/document.py +21 -0
  115. letta/schemas/embedding_config.py +54 -0
  116. letta/schemas/enums.py +35 -0
  117. letta/schemas/job.py +38 -0
  118. letta/schemas/letta_base.py +80 -0
  119. letta/schemas/letta_message.py +175 -0
  120. letta/schemas/letta_request.py +23 -0
  121. letta/schemas/letta_response.py +28 -0
  122. letta/schemas/llm_config.py +54 -0
  123. letta/schemas/memory.py +224 -0
  124. letta/schemas/message.py +727 -0
  125. letta/schemas/openai/chat_completion_request.py +123 -0
  126. letta/schemas/openai/chat_completion_response.py +136 -0
  127. letta/schemas/openai/chat_completions.py +123 -0
  128. letta/schemas/openai/embedding_response.py +11 -0
  129. letta/schemas/openai/openai.py +157 -0
  130. letta/schemas/organization.py +20 -0
  131. letta/schemas/passage.py +80 -0
  132. letta/schemas/source.py +62 -0
  133. letta/schemas/tool.py +143 -0
  134. letta/schemas/usage.py +18 -0
  135. letta/schemas/user.py +33 -0
  136. letta/server/__init__.py +0 -0
  137. letta/server/constants.py +6 -0
  138. letta/server/rest_api/__init__.py +0 -0
  139. letta/server/rest_api/admin/__init__.py +0 -0
  140. letta/server/rest_api/admin/agents.py +21 -0
  141. letta/server/rest_api/admin/tools.py +83 -0
  142. letta/server/rest_api/admin/users.py +98 -0
  143. letta/server/rest_api/app.py +193 -0
  144. letta/server/rest_api/auth/__init__.py +0 -0
  145. letta/server/rest_api/auth/index.py +43 -0
  146. letta/server/rest_api/auth_token.py +22 -0
  147. letta/server/rest_api/interface.py +726 -0
  148. letta/server/rest_api/routers/__init__.py +0 -0
  149. letta/server/rest_api/routers/openai/__init__.py +0 -0
  150. letta/server/rest_api/routers/openai/assistants/__init__.py +0 -0
  151. letta/server/rest_api/routers/openai/assistants/assistants.py +115 -0
  152. letta/server/rest_api/routers/openai/assistants/schemas.py +121 -0
  153. letta/server/rest_api/routers/openai/assistants/threads.py +336 -0
  154. letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
  155. letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +131 -0
  156. letta/server/rest_api/routers/v1/__init__.py +15 -0
  157. letta/server/rest_api/routers/v1/agents.py +543 -0
  158. letta/server/rest_api/routers/v1/blocks.py +73 -0
  159. letta/server/rest_api/routers/v1/jobs.py +46 -0
  160. letta/server/rest_api/routers/v1/llms.py +28 -0
  161. letta/server/rest_api/routers/v1/organizations.py +61 -0
  162. letta/server/rest_api/routers/v1/sources.py +199 -0
  163. letta/server/rest_api/routers/v1/tools.py +103 -0
  164. letta/server/rest_api/routers/v1/users.py +109 -0
  165. letta/server/rest_api/static_files.py +74 -0
  166. letta/server/rest_api/utils.py +69 -0
  167. letta/server/server.py +1995 -0
  168. letta/server/startup.sh +8 -0
  169. letta/server/static_files/assets/index-0cbf7ad5.js +274 -0
  170. letta/server/static_files/assets/index-156816da.css +1 -0
  171. letta/server/static_files/assets/index-486e3228.js +274 -0
  172. letta/server/static_files/favicon.ico +0 -0
  173. letta/server/static_files/index.html +39 -0
  174. letta/server/static_files/memgpt_logo_transparent.png +0 -0
  175. letta/server/utils.py +46 -0
  176. letta/server/ws_api/__init__.py +0 -0
  177. letta/server/ws_api/example_client.py +104 -0
  178. letta/server/ws_api/interface.py +108 -0
  179. letta/server/ws_api/protocol.py +100 -0
  180. letta/server/ws_api/server.py +145 -0
  181. letta/settings.py +165 -0
  182. letta/streaming_interface.py +396 -0
  183. letta/system.py +207 -0
  184. letta/utils.py +1065 -0
  185. letta_nightly-0.1.7.dev20240924104148.dist-info/LICENSE +190 -0
  186. letta_nightly-0.1.7.dev20240924104148.dist-info/METADATA +98 -0
  187. letta_nightly-0.1.7.dev20240924104148.dist-info/RECORD +189 -0
  188. letta_nightly-0.1.7.dev20240924104148.dist-info/WHEEL +4 -0
  189. letta_nightly-0.1.7.dev20240924104148.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,546 @@
1
+ import base64
2
+ import os
3
+ from datetime import datetime
4
+ from typing import Dict, List, Optional
5
+
6
+ import numpy as np
7
+ from sqlalchemy import (
8
+ BINARY,
9
+ Column,
10
+ DateTime,
11
+ Index,
12
+ String,
13
+ TypeDecorator,
14
+ and_,
15
+ asc,
16
+ desc,
17
+ or_,
18
+ select,
19
+ text,
20
+ )
21
+ from sqlalchemy.orm import declarative_base, mapped_column
22
+ from sqlalchemy.orm.session import close_all_sessions
23
+ from sqlalchemy.sql import func
24
+ from sqlalchemy_json import MutableJson
25
+ from tqdm import tqdm
26
+
27
+ from letta.agent_store.storage import StorageConnector, TableType
28
+ from letta.config import LettaConfig
29
+ from letta.constants import MAX_EMBEDDING_DIM
30
+ from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
31
+
32
+ # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
33
+ from letta.schemas.message import Message
34
+ from letta.schemas.openai.chat_completions import ToolCall
35
+ from letta.schemas.passage import Passage
36
+ from letta.settings import settings
37
+
38
# Declarative base class shared by the ORM models defined in this module.
Base = declarative_base()
# Module-level Letta configuration. PassageModel's class body reads
# `config.archival_storage_type` at import time to choose its embedding column type.
config = LettaConfig()
40
+
41
+
42
class CommonVector(TypeDecorator):
    """Store embedding vectors in a BINARY column as base64-encoded float32 bytes.

    Used for the sqlite/chroma configuration, where the pgvector ``Vector`` type is not used.
    Values are written as ``base64(float32 bytes)`` and read back as a float32 numpy array.
    """

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        """Serialize a list/array of numbers into base64-encoded float32 bytes (``None`` passes through)."""
        if value is None:
            return value
        # Always coerce to float32: process_result_value decodes with dtype=float32, so writing
        # e.g. a float64 ndarray's raw bytes would come back as a different, corrupted vector.
        vector = np.asarray(value, dtype=np.float32)
        # Encode to base64 for universal compatibility across database drivers
        return base64.b64encode(vector.tobytes())

    def process_result_value(self, value, dialect):
        """Decode stored bytes back into a float32 numpy array (empty/``None`` values pass through)."""
        if not value:
            return value
        # process_bind_param base64-encodes for every dialect, so decode unconditionally;
        # decoding only for sqlite returned garbage on any other backend.
        return np.frombuffer(base64.b64decode(value), dtype=np.float32)
69
+
70
+
71
class MessageModel(Base):
    """ORM model for the ``messages`` table, where ``Message`` objects are persisted.

    ``to_record`` turns a row back into the ``Message`` schema object.
    """

    __tablename__ = "messages"
    __table_args__ = {"extend_existing": True}

    # Primary key: the message id
    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    agent_id = Column(String, nullable=False)

    # OpenAI-style message fields
    role = Column(String, nullable=False)
    text = Column(String)  # nullable, e.g. when the message is only a function call
    model = Column(String)  # nullable, for LLM backends that don't need a model name
    name = Column(String)  # nullable, multi-agent only

    # Tool call request info: MAY be set when role == "assistant", must be null otherwise.
    # TODO align with OpenAI spec of multiple tool calls
    tool_calls = Column(ToolCallColumn)

    # Tool call response info: must be set when role == "tool", null otherwise.
    tool_call_id = Column(String)

    # Timezone-aware creation timestamp. No default is declared on the column, so callers supply it.
    created_at = Column(DateTime(timezone=True))

    Index("message_idx_user", user_id, agent_id)

    def __repr__(self):
        return "<Message(message_id='{}', text='{}')>".format(self.id, self.text)

    def to_record(self):
        """Build a ``Message`` schema object from this row.

        Raises:
            AssertionError: if any entry of ``tool_calls`` is not a ``ToolCall``.
        """
        for tool_call in self.tool_calls or []:
            assert isinstance(tool_call, ToolCall), type(tool_call)

        return Message(
            id=self.id,
            user_id=self.user_id,
            agent_id=self.agent_id,
            role=self.role,
            name=self.name,
            text=self.text,
            model=self.model,
            tool_calls=self.tool_calls,
            tool_call_id=self.tool_call_id,
            created_at=self.created_at,
        )
+ )
132
+
133
+
134
class PassageModel(Base):
    """ORM model for the ``passages`` table: a chunk of text plus its embedding.

    The ``embedding`` column type is chosen when this class body runs (module import).
    A pgvector ``Vector`` is used when ``settings.letta_pg_uri_no_default`` is set. Otherwise
    ``CommonVector`` is used for the "sqlite"/"chroma" archival storage types, and any other
    type raises ValueError.
    """

    __tablename__ = "passages"
    __table_args__ = {"extend_existing": True}

    # Primary key: the passage id
    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    text = Column(String)
    doc_id = Column(String)
    agent_id = Column(String)
    source_id = Column(String)

    # vector storage
    if settings.letta_pg_uri_no_default:
        from pgvector.sqlalchemy import Vector

        embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
    elif config.archival_storage_type == "sqlite" or config.archival_storage_type == "chroma":
        embedding = Column(CommonVector)
    else:
        raise ValueError(f"Unsupported archival_storage_type: {config.archival_storage_type}")
    embedding_config = Column(EmbeddingConfigColumn)
    metadata_ = Column(MutableJson)

    # Timezone-aware creation timestamp. No default is declared on the column, so callers supply it.
    created_at = Column(DateTime(timezone=True))

    Index("passage_idx_user", user_id, agent_id, doc_id)

    def __repr__(self):
        # The embedding value is now properly enclosed in quotes; it was previously unbalanced.
        return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding}')>"

    def to_record(self):
        """Build a ``Passage`` schema object from this row."""
        return Passage(
            text=self.text,
            embedding=self.embedding,
            embedding_config=self.embedding_config,
            doc_id=self.doc_id,
            user_id=self.user_id,
            id=self.id,
            source_id=self.source_id,
            agent_id=self.agent_id,
            metadata_=self.metadata_,
            created_at=self.created_at,
        )
181
+
182
+
183
class SQLStorageConnector(StorageConnector):
    """SQLAlchemy-based read/delete logic shared by the Postgres and SQLite connectors.

    Subclasses must set ``self.db_model`` (the ORM model class) and ``self.session_maker``
    (used as ``with self.session_maker() as session``) in their ``__init__``. Every query
    method below relies on both.
    """

    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
        self.config = config

    def get_filters(self, filters: Optional[Dict] = None):
        """Return a list of SQLAlchemy ``column == value`` conditions.

        Args:
            filters: extra equality filters merged over ``self.filters``. Keys in ``filters``
                win on conflict. ``None`` applies only ``self.filters``.
        """
        filter_conditions = {**self.filters, **(filters or {})}
        return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]

    def get_all_paginated(self, filters: Optional[Dict] = None, page_size: Optional[int] = 1000, offset=0):
        """Yield lists of records matching ``filters``, ``page_size`` records at a time.

        Rows are ordered by primary key. LIMIT/OFFSET paging without an ORDER BY has no
        guaranteed row order between queries, so pages could overlap or skip rows.
        """
        conditions = self.get_filters(filters)
        while True:
            # Retrieve one chunk of records per session
            with self.session_maker() as session:
                db_record_chunk = (
                    session.query(self.db_model)
                    .filter(*conditions)
                    .order_by(asc(self.db_model.id))
                    .offset(offset)
                    .limit(page_size)
                    .all()
                )
                records = [record.to_record() for record in db_record_chunk]

            # An empty chunk means every matching record has been returned
            if not records:
                break

            yield records
            offset += page_size

    def _cursor_sort_value(self, cursor_id: str, order_by: str):
        """Return the ``order_by`` attribute of the record whose id is ``cursor_id``.

        Raises:
            ValueError: if no record with that id exists.
        """
        cursor_record = self.get(id=cursor_id)
        if cursor_record is None:
            raise ValueError(f"Cursor record with id {cursor_id} does not exist.")
        return getattr(cursor_record, order_by)

    def get_all_cursor(
        self,
        filters: Optional[Dict] = None,
        after: str = None,
        before: str = None,
        limit: Optional[int] = 1000,
        order_by: str = "created_at",
        reverse: bool = False,
    ):
        """Return ``(next_cursor, records)`` for one page of cursor-based pagination.

        Records are sorted by ``order_by`` (descending when ``reverse`` is True), then by ``id``
        ascending as a tiebreaker. ``after``/``before`` are record ids. Only records whose
        ``order_by`` value is greater/less than the cursor record's value are kept; on equal
        values, records with a greater/smaller id are kept. ``next_cursor`` is the id of the
        last returned record, or ``None`` if nothing matched.

        Raises:
            ValueError: if ``after`` or ``before`` names a record that does not exist.
        """
        conditions = self.get_filters(filters)
        sort_column = getattr(self.db_model, order_by)

        # Resolve cursor values before opening the query session (self.get opens its own session)
        after_value = self._cursor_sort_value(after, order_by) if after else None
        before_value = self._cursor_sort_value(before, order_by) if before else None

        with self.session_maker() as session:
            primary_order = desc(sort_column) if reverse else asc(sort_column)
            query = session.query(self.db_model).filter(*conditions).order_by(primary_order, asc(self.db_model.id))

            if after:
                query = query.filter(
                    or_(sort_column > after_value, and_(sort_column == after_value, self.db_model.id > after))  # tiebreaker case
                )
            if before:
                query = query.filter(
                    or_(sort_column < before_value, and_(sort_column == before_value, self.db_model.id < before))  # tiebreaker case
                )

            db_record_chunk = query.limit(limit).all()
            if not db_record_chunk:
                return (None, [])
            records = [record.to_record() for record in db_record_chunk]
            next_cursor = db_record_chunk[-1].id

        assert isinstance(next_cursor, str)
        return (next_cursor, records)

    def get_all(self, filters: Optional[Dict] = None, limit=None):
        """Return all records matching ``filters``, capped at ``limit`` when it is truthy."""
        conditions = self.get_filters(filters)
        with self.session_maker() as session:
            query = session.query(self.db_model).filter(*conditions)
            if limit:
                query = query.limit(limit)
            return [record.to_record() for record in query.all()]

    def get(self, id: str):
        """Return the record with primary key ``id``, or ``None`` if it does not exist."""
        with self.session_maker() as session:
            db_record = session.get(self.db_model, id)
            if db_record is None:
                return None
            return db_record.to_record()

    def size(self, filters: Optional[Dict] = None) -> int:
        """Return the number of rows matching ``filters``."""
        conditions = self.get_filters(filters)
        with self.session_maker() as session:
            return session.query(self.db_model).filter(*conditions).count()

    def insert(self, record):
        raise NotImplementedError

    def insert_many(self, records, show_progress=False):
        raise NotImplementedError

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = None):
        raise NotImplementedError("Vector query not implemented for SQLStorageConnector")

    def save(self):
        return

    def list_data_sources(self):
        """Return the distinct ``source_id`` rows among this connector's archival passages.

        Only valid for ``TableType.ARCHIVAL_MEMORY`` connectors.
        """
        assert self.table_type == TableType.ARCHIVAL_MEMORY, "list_data_sources only implemented for ARCHIVAL_MEMORY"
        with self.session_maker() as session:
            # PassageModel stores the source in `source_id`; it has no `data_source` column.
            # get_filters() builds real conditions; unpacking the self.filters dict would pass its keys.
            return session.query(self.db_model.source_id).filter(*self.get_filters()).distinct().all()

    def query_date(self, start_date, end_date, limit=None, offset=0):
        """Return non-system, non-tool records with ``start_date <= created_at <= end_date``."""
        conditions = self.get_filters()
        with self.session_maker() as session:
            query = (
                session.query(self.db_model)
                .filter(*conditions)
                .filter(self.db_model.created_at >= start_date)
                .filter(self.db_model.created_at <= end_date)
                .filter(self.db_model.role != "system")
                .filter(self.db_model.role != "tool")
                .offset(offset)
            )
            if limit:
                query = query.limit(limit)
            return [result.to_record() for result in query.all()]

    def query_text(self, query, limit=None, offset=0):
        """Return non-system, non-tool records whose ``text`` contains ``query`` (case-insensitive)."""
        # todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
        conditions = self.get_filters()
        with self.session_maker() as session:
            # Named `stmt` so it does not shadow the `query` search-string parameter
            stmt = (
                session.query(self.db_model)
                .filter(*conditions)
                .filter(func.lower(self.db_model.text).contains(func.lower(query)))
                .filter(self.db_model.role != "system")
                .filter(self.db_model.role != "tool")
                .offset(offset)
            )
            if limit:
                stmt = stmt.limit(limit)
            return [result.to_record() for result in stmt.all()]

    # Should be used only in tests!
    def delete_table(self):
        """Drop this connector's table (closes all sessions first). Should be used only in tests!"""
        close_all_sessions()
        with self.session_maker() as session:
            self.db_model.__table__.drop(session.bind)
            session.commit()

    def delete(self, filters: Optional[Dict] = None):
        """Delete every row matching ``filters`` (merged over ``self.filters``)."""
        conditions = self.get_filters(filters)
        with self.session_maker() as session:
            session.query(self.db_model).filter(*conditions).delete()
            session.commit()
346
+
347
+
348
class PostgresStorageConnector(SQLStorageConnector):
    """Storage via Postgres (pgvector ``Vector`` embeddings for passages)."""

    # TODO: this should probably eventually be moved into a parent DB class

    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        from pgvector.sqlalchemy import Vector

        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)

        # Pick the ORM model regardless of where the URI comes from. Previously db_model was only
        # assigned on the config-URI path, so setting settings.pg_uri left it unset and the column
        # check below failed with AttributeError.
        if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
            self.db_model = PassageModel
            config_uri, uri_key = self.config.archival_storage_uri, "archival_storage_uri"
        elif table_type == TableType.RECALL_MEMORY:
            self.db_model = MessageModel
            config_uri, uri_key = self.config.recall_storage_uri, "recall_storage_uri"
        else:
            raise ValueError(f"Table type {table_type} not implemented")

        # Construct the URI from environment settings, falling back to the config URI.
        # TODO: remove the config fallback eventually (config should NOT contain URI)
        if settings.pg_uri:
            self.uri = settings.pg_uri
        elif config_uri is None:
            raise ValueError(f"Must specifiy {uri_key} in config {self.config.config_path}")
        else:
            self.uri = config_uri

        for c in self.db_model.__table__.columns:
            if c.name == "embedding":
                assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"

        from letta.server.server import db_context

        self.session_maker = db_context

        # TODO: move to DB init
        with self.session_maker() as session:
            session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))  # Enables the vector extension
            # CREATE EXTENSION is transactional in Postgres; commit so it isn't rolled back on close
            session.commit()

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = None):
        """Return the ``top_k`` records nearest to ``query_vec`` by L2 distance on ``embedding``."""
        conditions = self.get_filters(filters)
        with self.session_maker() as session:
            results = session.scalars(
                select(self.db_model).filter(*conditions).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
            ).all()

            # Convert the ORM rows into schema objects (Passage for archival tables)
            return [result.to_record() for result in results]

    def insert_many(self, records, exists_ok=True, show_progress=False):
        """Insert ``records``, updating rows that already exist when ``exists_ok`` is True.

        Records repeating an id already handled earlier in ``records`` are skipped.

        Raises:
            ValueError: if a record's id already exists in the table and ``exists_ok`` is False.
        """
        # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
        if not records:
            return

        added_ids = set()  # ids handled in this call, to avoid adding duplicates
        # NOTE: this has not great performance due to the per-record commits
        with self.session_maker() as session:
            iterable = tqdm(records) if show_progress else records
            for record in iterable:
                if record.id in added_ids:
                    continue

                existing_record = session.query(self.db_model).filter_by(id=record.id).first()
                if existing_record:
                    if not exists_ok:
                        raise ValueError(f"Record with id {record.id} already exists.")
                    fields = record.model_dump()
                    fields.pop("id")
                    session.query(self.db_model).filter(self.db_model.id == record.id).update(fields)
                    print(f"Updated record with id {record.id}")
                else:
                    session.add(self.db_model(**record.dict()))
                    print(f"Added record with id {record.id}")
                session.commit()

                added_ids.add(record.id)

    def insert(self, record, exists_ok=True):
        """Insert a single record; see ``insert_many``."""
        self.insert_many([record], exists_ok=exists_ok)

    def update(self, record):
        """Overwrite the stored row with id ``record.id`` using every attribute of ``record``.

        Raises:
            ValueError: if no row with that id exists.
        """
        with self.session_maker() as session:
            db_record = session.query(self.db_model).filter_by(id=record.id).first()
            if not db_record:
                raise ValueError(f"Record with id {record.id} does not exist.")

            # Copy the provided object's attributes onto the ORM row
            for attr, value in vars(record).items():
                setattr(db_record, attr, value)

            session.commit()

    def str_to_datetime(self, str_date: str) -> datetime:
        """Parse a ``"YYYY-MM-DD"`` string into a naive ``datetime`` at midnight."""
        val = str_date.split("-")
        return datetime(int(val[0]), int(val[1]), int(val[2]))

    def query_date(self, start_date, end_date, limit=None, offset=0):
        """Like the base ``query_date``, but also accepts ``"YYYY-MM-DD"`` strings for the bounds."""
        conditions = self.get_filters()
        _start_date = self.str_to_datetime(start_date) if isinstance(start_date, str) else start_date
        _end_date = self.str_to_datetime(end_date) if isinstance(end_date, str) else end_date
        with self.session_maker() as session:
            query = (
                session.query(self.db_model)
                .filter(*conditions)
                .filter(self.db_model.created_at >= _start_date)
                .filter(self.db_model.created_at <= _end_date)
                .filter(self.db_model.role != "system")
                .filter(self.db_model.role != "tool")
                .offset(offset)
            )
            if limit:
                query = query.limit(limit)
            return [result.to_record() for result in query.all()]
479
+
480
+
481
class SQLLiteStorageConnector(SQLStorageConnector):
    """SQLite-backed connector. Only recall memory (``MessageModel``) is supported."""

    def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)

        # get storage path
        if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
            raise ValueError(f"Table type {table_type} not implemented")
        elif table_type == TableType.RECALL_MEMORY:
            # TODO: eventually implement URI option
            self.path = self.config.recall_storage_path
            if self.path is None:
                # Report the config file location: the missing setting itself is always None here
                raise ValueError(f"Must specifiy recall_storage_path in config {self.config.config_path}")
            self.db_model = MessageModel
        else:
            raise ValueError(f"Table type {table_type} not implemented")

        self.path = os.path.join(self.path, "sqlite.db")

        from letta.server.server import db_context

        # NOTE(review): sessions come from the server-wide db_context; self.path is computed
        # but not used to open a connection in this class — confirm whether that is intended.
        self.session_maker = db_context

    def insert_many(self, records, exists_ok=True, show_progress=False):
        """Add every record as a new row and commit once at the end.

        ``exists_ok`` is accepted for interface compatibility but not consulted here. Presumably
        a duplicate id fails at commit on the primary-key constraint — verify.
        """
        # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
        if len(records) == 0:
            return
        with self.session_maker() as session:
            iterable = tqdm(records) if show_progress else records
            for record in iterable:
                session.add(self.db_model(**record.dict()))
            session.commit()

    def insert(self, record, exists_ok=True):
        """Insert a single record; see ``insert_many``."""
        self.insert_many([record], exists_ok=exists_ok)

    def update(self, record):
        """Overwrite the stored row's columns with same-named attributes of ``record``.

        Raises:
            ValueError: if ``record.id`` is falsy or no row with that id exists.
        """
        if not record.id:
            raise ValueError("Record must have an id.")

        with self.session_maker() as session:
            db_record = session.query(self.db_model).filter_by(id=record.id).first()
            if not db_record:
                raise ValueError(f"Record with id {record.id} does not exist.")

            # Only table columns are copied; extra attributes on `record` are ignored
            for column in self.db_model.__table__.columns:
                column_name = column.name
                if hasattr(record, column_name):
                    setattr(db_record, column_name, getattr(record, column_name))

            session.commit()