letta-nightly 0.1.7.dev20240924104148__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic; see the registry's release page for more details.

Files changed (189)
  1. letta/__init__.py +24 -0
  2. letta/__main__.py +3 -0
  3. letta/agent.py +1427 -0
  4. letta/agent_store/chroma.py +295 -0
  5. letta/agent_store/db.py +546 -0
  6. letta/agent_store/lancedb.py +177 -0
  7. letta/agent_store/milvus.py +198 -0
  8. letta/agent_store/qdrant.py +201 -0
  9. letta/agent_store/storage.py +188 -0
  10. letta/benchmark/benchmark.py +96 -0
  11. letta/benchmark/constants.py +14 -0
  12. letta/cli/cli.py +689 -0
  13. letta/cli/cli_config.py +1282 -0
  14. letta/cli/cli_load.py +166 -0
  15. letta/client/__init__.py +0 -0
  16. letta/client/admin.py +171 -0
  17. letta/client/client.py +2360 -0
  18. letta/client/streaming.py +90 -0
  19. letta/client/utils.py +61 -0
  20. letta/config.py +484 -0
  21. letta/configs/anthropic.json +13 -0
  22. letta/configs/letta_hosted.json +11 -0
  23. letta/configs/openai.json +12 -0
  24. letta/constants.py +134 -0
  25. letta/credentials.py +140 -0
  26. letta/data_sources/connectors.py +247 -0
  27. letta/embeddings.py +218 -0
  28. letta/errors.py +26 -0
  29. letta/functions/__init__.py +0 -0
  30. letta/functions/function_sets/base.py +174 -0
  31. letta/functions/function_sets/extras.py +132 -0
  32. letta/functions/functions.py +105 -0
  33. letta/functions/schema_generator.py +205 -0
  34. letta/humans/__init__.py +0 -0
  35. letta/humans/examples/basic.txt +1 -0
  36. letta/humans/examples/cs_phd.txt +9 -0
  37. letta/interface.py +314 -0
  38. letta/llm_api/__init__.py +0 -0
  39. letta/llm_api/anthropic.py +383 -0
  40. letta/llm_api/azure_openai.py +155 -0
  41. letta/llm_api/cohere.py +396 -0
  42. letta/llm_api/google_ai.py +468 -0
  43. letta/llm_api/llm_api_tools.py +485 -0
  44. letta/llm_api/openai.py +470 -0
  45. letta/local_llm/README.md +3 -0
  46. letta/local_llm/__init__.py +0 -0
  47. letta/local_llm/chat_completion_proxy.py +279 -0
  48. letta/local_llm/constants.py +31 -0
  49. letta/local_llm/function_parser.py +68 -0
  50. letta/local_llm/grammars/__init__.py +0 -0
  51. letta/local_llm/grammars/gbnf_grammar_generator.py +1324 -0
  52. letta/local_llm/grammars/json.gbnf +26 -0
  53. letta/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf +32 -0
  54. letta/local_llm/groq/api.py +97 -0
  55. letta/local_llm/json_parser.py +202 -0
  56. letta/local_llm/koboldcpp/api.py +62 -0
  57. letta/local_llm/koboldcpp/settings.py +23 -0
  58. letta/local_llm/llamacpp/api.py +58 -0
  59. letta/local_llm/llamacpp/settings.py +22 -0
  60. letta/local_llm/llm_chat_completion_wrappers/__init__.py +0 -0
  61. letta/local_llm/llm_chat_completion_wrappers/airoboros.py +452 -0
  62. letta/local_llm/llm_chat_completion_wrappers/chatml.py +470 -0
  63. letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +387 -0
  64. letta/local_llm/llm_chat_completion_wrappers/dolphin.py +246 -0
  65. letta/local_llm/llm_chat_completion_wrappers/llama3.py +345 -0
  66. letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +156 -0
  67. letta/local_llm/llm_chat_completion_wrappers/wrapper_base.py +11 -0
  68. letta/local_llm/llm_chat_completion_wrappers/zephyr.py +345 -0
  69. letta/local_llm/lmstudio/api.py +100 -0
  70. letta/local_llm/lmstudio/settings.py +29 -0
  71. letta/local_llm/ollama/api.py +88 -0
  72. letta/local_llm/ollama/settings.py +32 -0
  73. letta/local_llm/settings/__init__.py +0 -0
  74. letta/local_llm/settings/deterministic_mirostat.py +45 -0
  75. letta/local_llm/settings/settings.py +72 -0
  76. letta/local_llm/settings/simple.py +28 -0
  77. letta/local_llm/utils.py +265 -0
  78. letta/local_llm/vllm/api.py +63 -0
  79. letta/local_llm/webui/api.py +60 -0
  80. letta/local_llm/webui/legacy_api.py +58 -0
  81. letta/local_llm/webui/legacy_settings.py +23 -0
  82. letta/local_llm/webui/settings.py +24 -0
  83. letta/log.py +76 -0
  84. letta/main.py +437 -0
  85. letta/memory.py +440 -0
  86. letta/metadata.py +884 -0
  87. letta/openai_backcompat/__init__.py +0 -0
  88. letta/openai_backcompat/openai_object.py +437 -0
  89. letta/persistence_manager.py +148 -0
  90. letta/personas/__init__.py +0 -0
  91. letta/personas/examples/anna_pa.txt +13 -0
  92. letta/personas/examples/google_search_persona.txt +15 -0
  93. letta/personas/examples/memgpt_doc.txt +6 -0
  94. letta/personas/examples/memgpt_starter.txt +4 -0
  95. letta/personas/examples/sam.txt +14 -0
  96. letta/personas/examples/sam_pov.txt +14 -0
  97. letta/personas/examples/sam_simple_pov_gpt35.txt +13 -0
  98. letta/personas/examples/sqldb/test.db +0 -0
  99. letta/prompts/__init__.py +0 -0
  100. letta/prompts/gpt_summarize.py +14 -0
  101. letta/prompts/gpt_system.py +26 -0
  102. letta/prompts/system/memgpt_base.txt +49 -0
  103. letta/prompts/system/memgpt_chat.txt +58 -0
  104. letta/prompts/system/memgpt_chat_compressed.txt +13 -0
  105. letta/prompts/system/memgpt_chat_fstring.txt +51 -0
  106. letta/prompts/system/memgpt_doc.txt +50 -0
  107. letta/prompts/system/memgpt_gpt35_extralong.txt +53 -0
  108. letta/prompts/system/memgpt_intuitive_knowledge.txt +31 -0
  109. letta/prompts/system/memgpt_modified_chat.txt +23 -0
  110. letta/pytest.ini +0 -0
  111. letta/schemas/agent.py +117 -0
  112. letta/schemas/api_key.py +21 -0
  113. letta/schemas/block.py +135 -0
  114. letta/schemas/document.py +21 -0
  115. letta/schemas/embedding_config.py +54 -0
  116. letta/schemas/enums.py +35 -0
  117. letta/schemas/job.py +38 -0
  118. letta/schemas/letta_base.py +80 -0
  119. letta/schemas/letta_message.py +175 -0
  120. letta/schemas/letta_request.py +23 -0
  121. letta/schemas/letta_response.py +28 -0
  122. letta/schemas/llm_config.py +54 -0
  123. letta/schemas/memory.py +224 -0
  124. letta/schemas/message.py +727 -0
  125. letta/schemas/openai/chat_completion_request.py +123 -0
  126. letta/schemas/openai/chat_completion_response.py +136 -0
  127. letta/schemas/openai/chat_completions.py +123 -0
  128. letta/schemas/openai/embedding_response.py +11 -0
  129. letta/schemas/openai/openai.py +157 -0
  130. letta/schemas/organization.py +20 -0
  131. letta/schemas/passage.py +80 -0
  132. letta/schemas/source.py +62 -0
  133. letta/schemas/tool.py +143 -0
  134. letta/schemas/usage.py +18 -0
  135. letta/schemas/user.py +33 -0
  136. letta/server/__init__.py +0 -0
  137. letta/server/constants.py +6 -0
  138. letta/server/rest_api/__init__.py +0 -0
  139. letta/server/rest_api/admin/__init__.py +0 -0
  140. letta/server/rest_api/admin/agents.py +21 -0
  141. letta/server/rest_api/admin/tools.py +83 -0
  142. letta/server/rest_api/admin/users.py +98 -0
  143. letta/server/rest_api/app.py +193 -0
  144. letta/server/rest_api/auth/__init__.py +0 -0
  145. letta/server/rest_api/auth/index.py +43 -0
  146. letta/server/rest_api/auth_token.py +22 -0
  147. letta/server/rest_api/interface.py +726 -0
  148. letta/server/rest_api/routers/__init__.py +0 -0
  149. letta/server/rest_api/routers/openai/__init__.py +0 -0
  150. letta/server/rest_api/routers/openai/assistants/__init__.py +0 -0
  151. letta/server/rest_api/routers/openai/assistants/assistants.py +115 -0
  152. letta/server/rest_api/routers/openai/assistants/schemas.py +121 -0
  153. letta/server/rest_api/routers/openai/assistants/threads.py +336 -0
  154. letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
  155. letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +131 -0
  156. letta/server/rest_api/routers/v1/__init__.py +15 -0
  157. letta/server/rest_api/routers/v1/agents.py +543 -0
  158. letta/server/rest_api/routers/v1/blocks.py +73 -0
  159. letta/server/rest_api/routers/v1/jobs.py +46 -0
  160. letta/server/rest_api/routers/v1/llms.py +28 -0
  161. letta/server/rest_api/routers/v1/organizations.py +61 -0
  162. letta/server/rest_api/routers/v1/sources.py +199 -0
  163. letta/server/rest_api/routers/v1/tools.py +103 -0
  164. letta/server/rest_api/routers/v1/users.py +109 -0
  165. letta/server/rest_api/static_files.py +74 -0
  166. letta/server/rest_api/utils.py +69 -0
  167. letta/server/server.py +1995 -0
  168. letta/server/startup.sh +8 -0
  169. letta/server/static_files/assets/index-0cbf7ad5.js +274 -0
  170. letta/server/static_files/assets/index-156816da.css +1 -0
  171. letta/server/static_files/assets/index-486e3228.js +274 -0
  172. letta/server/static_files/favicon.ico +0 -0
  173. letta/server/static_files/index.html +39 -0
  174. letta/server/static_files/memgpt_logo_transparent.png +0 -0
  175. letta/server/utils.py +46 -0
  176. letta/server/ws_api/__init__.py +0 -0
  177. letta/server/ws_api/example_client.py +104 -0
  178. letta/server/ws_api/interface.py +108 -0
  179. letta/server/ws_api/protocol.py +100 -0
  180. letta/server/ws_api/server.py +145 -0
  181. letta/settings.py +165 -0
  182. letta/streaming_interface.py +396 -0
  183. letta/system.py +207 -0
  184. letta/utils.py +1065 -0
  185. letta_nightly-0.1.7.dev20240924104148.dist-info/LICENSE +190 -0
  186. letta_nightly-0.1.7.dev20240924104148.dist-info/METADATA +98 -0
  187. letta_nightly-0.1.7.dev20240924104148.dist-info/RECORD +189 -0
  188. letta_nightly-0.1.7.dev20240924104148.dist-info/WHEEL +4 -0
  189. letta_nightly-0.1.7.dev20240924104148.dist-info/entry_points.txt +3 -0
letta/metadata.py ADDED
@@ -0,0 +1,884 @@
1
+ """ Metadata store for user/agent/data_source information"""
2
+
3
+ import os
4
+ import secrets
5
+ from typing import List, Optional
6
+
7
+ from sqlalchemy import (
8
+ BIGINT,
9
+ JSON,
10
+ Boolean,
11
+ Column,
12
+ DateTime,
13
+ Index,
14
+ String,
15
+ TypeDecorator,
16
+ desc,
17
+ func,
18
+ )
19
+ from sqlalchemy.orm import declarative_base
20
+ from sqlalchemy.sql import func
21
+
22
+ from letta.config import LettaConfig
23
+ from letta.schemas.agent import AgentState
24
+ from letta.schemas.api_key import APIKey
25
+ from letta.schemas.block import Block, Human, Persona
26
+ from letta.schemas.embedding_config import EmbeddingConfig
27
+ from letta.schemas.enums import JobStatus
28
+ from letta.schemas.job import Job
29
+ from letta.schemas.llm_config import LLMConfig
30
+ from letta.schemas.memory import Memory
31
+ from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
32
+ from letta.schemas.organization import Organization
33
+ from letta.schemas.source import Source
34
+ from letta.schemas.tool import Tool
35
+ from letta.schemas.user import User
36
+ from letta.settings import settings
37
+ from letta.utils import enforce_types, get_utc_time, printd
38
+
39
# Declarative base shared by every ORM model defined in this module.
Base = declarative_base()
40
+
41
+
42
class LLMConfigColumn(TypeDecorator):
    """JSON-backed column type that round-trips an ``LLMConfig``.

    On write, an ``LLMConfig`` instance is serialized with ``model_dump()``;
    any other value is stored unchanged. On read, a truthy JSON payload is
    rebuilt into an ``LLMConfig``; falsy payloads (e.g. ``None``) pass through.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Only a truthy LLMConfig instance needs converting to a plain dict.
        if value and isinstance(value, LLMConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        return LLMConfig(**value) if value else value
62
+
63
+
64
class EmbeddingConfigColumn(TypeDecorator):
    """JSON-backed column type that round-trips an ``EmbeddingConfig``.

    On write, an ``EmbeddingConfig`` instance is serialized with
    ``model_dump()``; any other value is stored unchanged. On read, a truthy
    JSON payload is rebuilt into an ``EmbeddingConfig``; falsy payloads
    (e.g. ``None``) pass through.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Only a truthy EmbeddingConfig instance needs converting to a plain dict.
        if value and isinstance(value, EmbeddingConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        return EmbeddingConfig(**value) if value else value
84
+
85
+
86
class ToolCallColumn(TypeDecorator):
    """JSON-backed column type that stores a list of ``ToolCall`` objects.

    On write, each ``ToolCall`` element is serialized with ``model_dump()``;
    other elements are stored unchanged. On read, each stored dict is rebuilt
    into a ``ToolCall``, with its optional ``"function"`` entry converted to a
    ``ToolCallFunction``. Falsy values (e.g. ``None`` or ``[]``) pass through
    in both directions.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        if value:
            return [v.model_dump() if isinstance(v, ToolCall) else v for v in value]
        return value

    def process_result_value(self, value, dialect):
        if not value:
            return value
        tools = []
        for tool_value in value:
            # Work on a shallow copy so the deserialized payload handed to us
            # is not mutated (the original deleted "function" in place).
            fields = dict(tool_value)
            if "function" in fields:
                tool_call_function = ToolCallFunction(**fields.pop("function"))
            else:
                tool_call_function = None
            tools.append(ToolCall(function=tool_call_function, **fields))
        return tools
118
+
119
+
120
class UserModel(Base):
    """ORM row for the ``users`` table; ``to_record`` converts it to a ``User``."""

    __tablename__ = "users"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    org_id = Column(String)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True))

    # TODO: what is this? (note: not copied into the User record by to_record)
    policies_accepted = Column(Boolean, nullable=False, default=False)

    def __repr__(self) -> str:
        return "<User(id='{}' name='{}')>".format(self.id, self.name)

    def to_record(self) -> User:
        return User(
            id=self.id,
            org_id=self.org_id,
            name=self.name,
            created_at=self.created_at,
        )
137
+
138
+
139
class OrganizationModel(Base):
    """ORM row for the ``organizations`` table; ``to_record`` yields an ``Organization``."""

    __tablename__ = "organizations"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True))

    def __repr__(self) -> str:
        return "<Organization(id='{}' name='{}')>".format(self.id, self.name)

    def to_record(self) -> Organization:
        fields = {"id": self.id, "name": self.name, "created_at": self.created_at}
        return Organization(**fields)
152
+
153
+
154
class APIKeyModel(Base):
    """Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens)."""

    __tablename__ = "tokens"

    id = Column(String, primary_key=True)
    # each api key is tied to a user account (that it validates access for)
    user_id = Column(String, nullable=False)
    # the api key itself (stored as plaintext)
    key = Column(String, nullable=False)
    # extra (optional) metadata
    name = Column(String)

    # Secondary indexes for lookups by owning user and by key value.
    Index(__tablename__ + "_idx_user", user_id)
    Index(__tablename__ + "_idx_key", key)

    def __repr__(self) -> str:
        # Never render the whole secret: reprs routinely end up in logs and tracebacks.
        if self.key and len(self.key) > 8:
            masked = f"{self.key[:3]}...{self.key[-4:]}"
        else:
            masked = "***"
        return f"<APIKey(id='{self.id}', key='{masked}', name='{self.name}')>"

    def to_record(self) -> APIKey:
        return APIKey(
            id=self.id,
            user_id=self.user_id,
            key=self.key,
            name=self.name,
        )
180
+
181
+
182
def generate_api_key(prefix="sk-", length=51) -> str:
    """Return a new random API key of the form ``prefix + <hex digits>``.

    Args:
        prefix: String prepended to the random part.
        length: Target total length of the key. The random part is
            ``(length - len(prefix)) // 2`` cryptographically random bytes
            rendered as hex (two characters per byte). The key is therefore
            exactly ``length`` characters when ``length - len(prefix)`` is
            even, and one character shorter when it is odd.

    Returns:
        The generated key. At least one random byte is always included, so
        the key is never just the prefix. If ``length`` leaves room for fewer
        than two hex digits, the key is ``len(prefix) + 2`` characters long.
    """
    # Halve first, then clamp, so the "at least one byte" guarantee holds.
    # (Clamping before halving allowed max(1, 1) // 2 == 0 bytes.)
    num_bytes = max((length - len(prefix)) // 2, 1)
    return prefix + secrets.token_hex(num_bytes)
188
+
189
+
190
class AgentModel(Base):
    """Defines data model for storing persisted agent state (converted to/from ``AgentState``)."""

    __tablename__ = "agents"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    description = Column(String)

    # state (context compilation)
    message_ids = Column(JSON)
    memory = Column(JSON)  # serialized Memory; rebuilt with Memory.load in to_record
    system = Column(String)

    # configs
    llm_config = Column(LLMConfigColumn)
    embedding_config = Column(EmbeddingConfigColumn)

    # state
    metadata_ = Column(JSON)

    # tools
    # NOTE: previously also declared above under "state"; that first Column was
    # silently overwritten by this one, so only this declaration is kept (which
    # also preserves the existing column order of the generated table).
    tools = Column(JSON)

    Index(__tablename__ + "_idx_user", user_id)

    def __repr__(self) -> str:
        return f"<Agent(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> AgentState:
        return AgentState(
            id=self.id,
            user_id=self.user_id,
            name=self.name,
            created_at=self.created_at,
            description=self.description,
            message_ids=self.message_ids,
            memory=Memory.load(self.memory),  # load dictionary
            system=self.system,
            tools=self.tools,
            llm_config=self.llm_config,
            embedding_config=self.embedding_config,
            metadata_=self.metadata_,
        )
238
+
239
+
240
class SourceModel(Base):
    """Defines data model for a user-owned data source (name, description, embedding config, metadata)."""

    __tablename__ = "sources"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    embedding_config = Column(EmbeddingConfigColumn)
    description = Column(String)
    metadata_ = Column(JSON)

    Index(__tablename__ + "_idx_user", user_id)

    # TODO: add num passages

    def __repr__(self) -> str:
        # The primary key is the source id (the old label "passage_id" was misleading).
        return f"<Source(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> Source:
        return Source(
            id=self.id,
            user_id=self.user_id,
            name=self.name,
            created_at=self.created_at,
            embedding_config=self.embedding_config,
            description=self.description,
            metadata_=self.metadata_,
        )
272
+
273
+
274
class AgentSourceMappingModel(Base):
    """Stores mapping between agent -> source, recorded per owning user."""

    __tablename__ = "agent_source_mapping"

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    agent_id = Column(String, nullable=False)
    source_id = Column(String, nullable=False)

    # Composite index over (user_id, agent_id, source_id).
    Index(__tablename__ + "_idx_user", user_id, agent_id, source_id)

    def __repr__(self) -> str:
        return "<AgentSourceMapping(user_id='{}', agent_id='{}', source_id='{}')>".format(
            self.user_id, self.agent_id, self.source_id
        )
287
+
288
+
289
class BlockModel(Base):
    """ORM row for the ``block`` table (memory blocks such as human/persona)."""

    __tablename__ = "block"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True, nullable=False)
    value = Column(String, nullable=False)
    limit = Column(BIGINT)
    name = Column(String, nullable=False)
    template = Column(Boolean, default=False)  # True: listed as possible human/persona
    label = Column(String)
    metadata_ = Column(JSON)
    description = Column(String)
    user_id = Column(String)

    Index(__tablename__ + "_idx_user", user_id)

    def __repr__(self) -> str:
        return (
            f"<Block(id='{self.id}', name='{self.name}', template='{self.template}', "
            f"label='{self.label}', user_id='{self.user_id}')>"
        )

    def to_record(self) -> Block:
        # "persona" -> Persona, "human" -> Human, any other label -> plain Block.
        record_cls = {"persona": Persona, "human": Human}.get(self.label, Block)
        return record_cls(
            id=self.id,
            value=self.value,
            limit=self.limit,
            name=self.name,
            template=self.template,
            label=self.label,
            metadata_=self.metadata_,
            description=self.description,
            user_id=self.user_id,
        )
344
+
345
+
346
class ToolModel(Base):
    """ORM row for the ``tools`` table; ``to_record`` yields a ``Tool``.

    A NULL ``user_id`` marks a tool not tied to a specific user.
    """

    __tablename__ = "tools"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    user_id = Column(String)
    description = Column(String)
    source_type = Column(String)
    source_code = Column(String)
    json_schema = Column(JSON)
    module = Column(String)
    tags = Column(JSON)

    def __repr__(self) -> str:
        return "<Tool(id='{}', name='{}')>".format(self.id, self.name)

    def to_record(self) -> Tool:
        copied = (
            "id",
            "name",
            "user_id",
            "description",
            "source_type",
            "source_code",
            "json_schema",
            "module",
            "tags",
        )
        return Tool(**{attr: getattr(self, attr) for attr in copied})
375
+
376
+
377
class JobModel(Base):
    """ORM row for the ``jobs`` table; ``to_record`` yields a ``Job``."""

    __tablename__ = "jobs"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    user_id = Column(String)
    status = Column(String, default=JobStatus.pending)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    # NOTE(review): onupdate fires on any SQLAlchemy UPDATE of the row that does
    # not set this column explicitly, not only when the job actually completes.
    completed_at = Column(DateTime(timezone=True), onupdate=func.now())
    metadata_ = Column(JSON)

    def __repr__(self) -> str:
        return "<Job(id='{}', status='{}')>".format(self.id, self.status)

    def to_record(self):
        return Job(
            metadata_=self.metadata_,
            completed_at=self.completed_at,
            created_at=self.created_at,
            status=self.status,
            user_id=self.user_id,
            id=self.id,
        )
400
+
401
+
402
+ class MetadataStore:
403
+ uri: Optional[str] = None
404
+
405
def __init__(self, config: LettaConfig):
    """Resolve the metadata database URI and bind the shared session factory.

    Args:
        config: Letta configuration; ``metadata_storage_type`` selects the backend.

    Raises:
        ValueError: If ``metadata_storage_type`` is neither ``"postgres"`` nor ``"sqlite"``.
        AssertionError: If no URI could be resolved.
    """
    storage_type = config.metadata_storage_type
    if storage_type == "sqlite":
        sqlite_file = os.path.join(config.metadata_storage_path, "sqlite.db")
        self.uri = f"sqlite:///{sqlite_file}"
    elif storage_type == "postgres":
        # Prefer settings.pg_uri when set (constructed from environment variables);
        # otherwise fall back to the URI from the config file.
        self.uri = settings.pg_uri if settings.pg_uri else config.metadata_storage_uri
    else:
        raise ValueError(f"Invalid metadata storage type: {config.metadata_storage_type}")

    # Ensure valid URI
    assert self.uri, "Database URI is not provided or is invalid."

    # Imported here rather than at module top (presumably to avoid an import cycle
    # with letta.server.server — TODO confirm).
    from letta.server.server import db_context

    self.session_maker = db_context
423
+
424
@enforce_types
def create_api_key(self, user_id: str, name: str) -> APIKey:
    """Create, persist and return a new API key for ``user_id``.

    Args:
        user_id: Owner of the key; must be non-empty.
        name: Label for the key; must be non-empty.

    Returns:
        The stored ``APIKey`` record, re-read via ``get_api_key``.

    Raises:
        AssertionError: If ``user_id`` or ``name`` is empty.
        ValueError: If the freshly generated key collides with an existing one
            (should never happen in practice).
    """
    # Validate arguments before generating a key or touching the database.
    assert user_id and name, "User ID and name must be provided"
    new_api_key = generate_api_key()
    with self.session_maker() as session:
        if session.query(APIKeyModel).filter(APIKeyModel.key == new_api_key).count() > 0:
            # NOTE duplicate API keys / tokens should never happen, but if it does don't allow it.
            # The secret itself is deliberately kept out of the message (and thus out of logs).
            raise ValueError("Generated API key collides with an existing token")
        # TODO store the API keys as hashed
        token = APIKey(user_id=user_id, key=new_api_key, name=name)
        session.add(APIKeyModel(**vars(token)))
        session.commit()
    return self.get_api_key(api_key=new_api_key)
438
+
439
@enforce_types
def delete_api_key(self, api_key: str):
    """Delete every token row whose key equals ``api_key``, then commit."""
    with self.session_maker() as session:
        matching_tokens = session.query(APIKeyModel).filter(APIKeyModel.key == api_key)
        matching_tokens.delete()
        session.commit()
445
+
446
@enforce_types
def get_api_key(self, api_key: str) -> Optional[APIKey]:
    """Look up a single API key record by its key string.

    Returns:
        The matching ``APIKey``, or ``None`` when no row has this key.

    Raises:
        AssertionError: If more than one row carries the same key.
    """
    with self.session_maker() as session:
        rows = session.query(APIKeyModel).filter(APIKeyModel.key == api_key).all()
        if not rows:
            return None
        assert len(rows) == 1, f"Expected 1 result, got {len(rows)}"  # keys are expected to be unique
        return rows[0].to_record()
454
+
455
@enforce_types
def get_all_api_keys_for_user(self, user_id: str) -> List[APIKey]:
    """Return every API key record owned by ``user_id`` (possibly empty)."""
    with self.session_maker() as session:
        owned = session.query(APIKeyModel).filter(APIKeyModel.user_id == user_id)
        return [row.to_record() for row in owned.all()]
461
+
462
@enforce_types
def get_user_from_api_key(self, api_key: str) -> Optional[User]:
    """Get the user associated with a given API key.

    Returns:
        Whatever ``get_user`` returns for the key's ``user_id``
        (``None`` if that user row no longer exists).

    Raises:
        ValueError: If no token with this key exists.
    """
    token = self.get_api_key(api_key=api_key)
    if token is not None:
        return self.get_user(user_id=token.user_id)
    raise ValueError("Provided token does not exist")
470
+
471
@enforce_types
def create_agent(self, agent: AgentState):
    """Insert ``agent`` into the agents table.

    Args:
        agent: Agent state to persist. Its ``memory`` is stored as
            ``agent.memory.to_dict()``; ``agent`` itself is left unmodified.

    Raises:
        ValueError: If this user already owns an agent with the same name.
    """
    with self.session_maker() as session:
        # Agent names must be unique per user.
        if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
            raise ValueError(f"Agent with name {agent.name} already exists")
        # vars() returns the object's own __dict__; copy it so that writing the
        # serialized memory below does not replace the caller's Memory object.
        fields = dict(vars(agent))
        fields["memory"] = agent.memory.to_dict()
        session.add(AgentModel(**fields))
        session.commit()
482
+
483
@enforce_types
def create_source(self, source: Source):
    """Insert ``source``; raise ``ValueError`` if its user already has a source with that name."""
    with self.session_maker() as session:
        same_name = session.query(SourceModel).filter(
            SourceModel.name == source.name,
            SourceModel.user_id == source.user_id,
        )
        if same_name.count() > 0:
            raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
        session.add(SourceModel(**vars(source)))
        session.commit()
490
+
491
@enforce_types
def create_user(self, user: User):
    """Insert ``user``; raise ``ValueError`` if a user with the same id exists."""
    with self.session_maker() as session:
        already_exists = session.query(UserModel).filter(UserModel.id == user.id).count() > 0
        if already_exists:
            raise ValueError(f"User with id {user.id} already exists")
        session.add(UserModel(**vars(user)))
        session.commit()
498
+
499
@enforce_types
def create_organization(self, organization: Organization):
    """Insert ``organization``; raise ``ValueError`` if one with the same id exists."""
    with self.session_maker() as session:
        existing = session.query(OrganizationModel).filter(OrganizationModel.id == organization.id)
        if existing.count() > 0:
            raise ValueError(f"Organization with id {organization.id} already exists")
        session.add(OrganizationModel(**vars(organization)))
        session.commit()
506
+
507
@enforce_types
def create_block(self, block: Block):
    """Insert ``block``.

    Raises:
        ValueError: If a *template* block with the same name, user_id and label
            already exists. This check runs whether or not ``block`` is itself a
            template; non-template blocks never cause a conflict.
    """
    with self.session_maker() as session:
        # TODO: fix? Only template blocks are checked for name collisions.
        clashing_templates = (
            session.query(BlockModel)
            .filter(
                BlockModel.name == block.name,
                BlockModel.user_id == block.user_id,
                BlockModel.template == True,  # noqa: E712 -- SQL expression, not a Python bool test
                BlockModel.label == block.label,
            )
            .count()
        )
        if clashing_templates > 0:
            raise ValueError(f"Block with name {block.name} already exists")
        session.add(BlockModel(**vars(block)))
        session.commit()
526
+
527
@enforce_types
def create_tool(self, tool: Tool):
    """Insert ``tool``; raise ``ValueError`` if ``get_tool`` already finds one with its name for its user."""
    with self.session_maker() as session:
        existing_tool = self.get_tool(tool_name=tool.name, user_id=tool.user_id)
        if existing_tool is not None:
            raise ValueError(f"Tool with name {tool.name} already exists")
        session.add(ToolModel(**vars(tool)))
        session.commit()
534
+
535
@enforce_types
def update_agent(self, agent: AgentState):
    """Overwrite the stored row for ``agent.id`` with the fields of ``agent``.

    A ``Memory`` object in ``agent.memory`` is serialized with ``to_dict()``
    before storage; any other value (e.g. an already-serialized dict) is stored
    as-is. ``agent`` itself is left unmodified.
    """
    with self.session_maker() as session:
        # vars() is the object's own __dict__; copy it so the serialized memory
        # written below does not replace the caller's Memory object.
        fields = dict(vars(agent))
        if isinstance(agent.memory, Memory):  # TODO: this is nasty but this whole class will soon be removed so whatever
            fields["memory"] = agent.memory.to_dict()
        session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
        session.commit()
543
+
544
@enforce_types
def update_user(self, user: User):
    """Overwrite the stored row for ``user.id`` with the fields of ``user``."""
    with self.session_maker() as session:
        target_rows = session.query(UserModel).filter(UserModel.id == user.id)
        target_rows.update(vars(user))
        session.commit()
549
+
550
@enforce_types
def update_source(self, source: Source):
    """Overwrite the stored row for ``source.id`` with the fields of ``source``."""
    with self.session_maker() as session:
        target_rows = session.query(SourceModel).filter(SourceModel.id == source.id)
        target_rows.update(vars(source))
        session.commit()
555
+
556
@enforce_types
def update_block(self, block: Block):
    """Overwrite the stored row for ``block.id`` with the fields of ``block``."""
    with self.session_maker() as session:
        target_rows = session.query(BlockModel).filter(BlockModel.id == block.id)
        target_rows.update(vars(block))
        session.commit()
561
+
562
@enforce_types
def update_or_create_block(self, block: Block):
    """Upsert ``block``: update the row with its id if present, otherwise insert it."""
    with self.session_maker() as session:
        same_id = session.query(BlockModel).filter(BlockModel.id == block.id)
        if same_id.first() is None:
            session.add(BlockModel(**vars(block)))
        else:
            same_id.update(vars(block))
        session.commit()
571
+
572
@enforce_types
def update_tool(self, tool: Tool):
    """Overwrite the stored row for ``tool.id`` with the fields of ``tool``."""
    with self.session_maker() as session:
        target_rows = session.query(ToolModel).filter(ToolModel.id == tool.id)
        target_rows.update(vars(tool))
        session.commit()
577
+
578
@enforce_types
def delete_tool(self, tool_id: str):
    """Delete the tool row with id ``tool_id`` (no-op if absent), then commit."""
    with self.session_maker() as session:
        doomed = session.query(ToolModel).filter(ToolModel.id == tool_id)
        doomed.delete()
        session.commit()
583
+
584
@enforce_types
def delete_block(self, block_id: str):
    """Delete the block row with id ``block_id`` (no-op if absent), then commit."""
    with self.session_maker() as session:
        doomed = session.query(BlockModel).filter(BlockModel.id == block_id)
        doomed.delete()
        session.commit()
589
+
590
@enforce_types
def delete_agent(self, agent_id: str):
    """Delete the agent row and every agent->source mapping that references it.

    Both deletions are issued on one session and committed with a single ``commit()``.
    """
    with self.session_maker() as session:
        targets = (
            (AgentModel, AgentModel.id),  # the agent itself
            (AgentSourceMappingModel, AgentSourceMappingModel.agent_id),  # its source attachments
        )
        for model, key_column in targets:
            session.query(model).filter(key_column == agent_id).delete()
        session.commit()
601
+
602
@enforce_types
def delete_source(self, source_id: str):
    """Delete the source row and every agent->source mapping that references it.

    Both deletions are issued on one session and committed with a single ``commit()``.
    """
    with self.session_maker() as session:
        targets = (
            (SourceModel, SourceModel.id),  # the source itself
            (AgentSourceMappingModel, AgentSourceMappingModel.source_id),  # agents attached to it
        )
        for model, key_column in targets:
            session.query(model).filter(key_column == source_id).delete()
        session.commit()
612
+
613
@enforce_types
def delete_user(self, user_id: str):
    """Delete a user together with their agents, sources and agent->source mappings.

    All deletions are issued on one session and committed with a single ``commit()``.
    NOTE(review): blocks, tools, jobs and API keys owned by the user are not
    removed here.
    """
    with self.session_maker() as session:
        session.query(UserModel).filter(UserModel.id == user_id).delete()
        for owned_model in (AgentModel, SourceModel, AgentSourceMappingModel):
            session.query(owned_model).filter(owned_model.user_id == user_id).delete()
        session.commit()
629
+
630
@enforce_types
def delete_organization(self, org_id: str):
    """Delete the organization row with id ``org_id``, then commit.

    Only the organizations table is touched.
    """
    with self.session_maker() as session:
        doomed = session.query(OrganizationModel).filter(OrganizationModel.id == org_id)
        doomed.delete()
        # TODO: delete associated data
        session.commit()
639
+
640
@enforce_types
def list_tools(self, user_id: Optional[str] = None) -> List[Tool]:
    """List tools visible to a caller.

    Args:
        user_id: When truthy, that user's own tools are appended to the result.

    Returns:
        ``Tool`` records: first all tools whose ``user_id`` is NULL, then
        (if ``user_id`` is given) the tools owned by ``user_id``.
    """
    # TODO: make user_id required once users can create tools.
    with self.session_maker() as session:
        # `== None` is intentional: SQLAlchemy renders it as `IS NULL`.
        results = session.query(ToolModel).filter(ToolModel.user_id == None).all()  # noqa: E711
        if user_id:
            results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
        return [r.to_record() for r in results]
649
+
650
@enforce_types
def list_agents(self, user_id: str) -> List[AgentState]:
    """Return the ``AgentState`` of every agent owned by ``user_id``."""
    with self.session_maker() as session:
        owned_agents = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
        states = []
        for row in owned_agents:
            states.append(row.to_record())
        return states
655
+
656
@enforce_types
def list_sources(self, user_id: str) -> List[Source]:
    """Return every data source owned by ``user_id`` as ``Source`` records."""
    with self.session_maker() as session:
        source_rows = session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
        sources = [row.to_record() for row in source_rows]
    return sources
661
+
662
@enforce_types
def get_agent(
    self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None
) -> Optional[AgentState]:
    """Look up a single agent by id, or by (name, owner) when no id is given.

    Args:
        agent_id: agent id; takes precedence when truthy.
        agent_name: agent name, used together with ``user_id`` when ``agent_id`` is falsy.
        user_id: owner of the named agent.

    Returns:
        The agent's ``AgentState`` record, or ``None`` if nothing matches.

    Raises:
        AssertionError: if neither lookup key is usable, or more than one row matches.
    """
    with self.session_maker() as session:
        query = session.query(AgentModel)
        if agent_id:
            query = query.filter(AgentModel.id == agent_id)
        else:
            assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name"
            query = query.filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id)

        matches = query.all()
        if not matches:
            return None
        # ids (and name+owner pairs) are expected to be unique
        assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
        return matches[0].to_record()
677
+
678
@enforce_types
def get_user(self, user_id: str) -> Optional[User]:
    """Fetch the user with id ``user_id``; ``None`` when no such user exists.

    Raises:
        AssertionError: if more than one row shares the id.
    """
    with self.session_maker() as session:
        matches = session.query(UserModel).filter(UserModel.id == user_id).all()
        if not matches:
            return None
        assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
        return matches[0].to_record()
686
+
687
@enforce_types
def get_organization(self, org_id: str) -> Optional[Organization]:
    """Fetch the organization with id ``org_id``; ``None`` when it does not exist.

    Raises:
        AssertionError: if more than one row shares the id.
    """
    with self.session_maker() as session:
        matches = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all()
        if not matches:
            return None
        assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
        return matches[0].to_record()
695
+
696
@enforce_types
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
    """Page through organizations in descending id order.

    Args:
        cursor: when truthy, only organizations whose id sorts strictly below it are returned.
        limit: maximum number of organizations in the page.

    Returns:
        ``(next_cursor, records)``, where ``next_cursor`` is the id of the last record
        in the page, or ``(None, [])`` when the page is empty.
    """
    with self.session_maker() as session:
        query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id))
        if cursor:
            query = query.filter(OrganizationModel.id < cursor)

        page = query.limit(limit).all()
        if not page:
            return None, []

        records = [org.to_record() for org in page]
        last_id = records[-1].id
        assert isinstance(last_id, str)
        return last_id, records
710
+
711
@enforce_types
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
    """Page through users in descending id order.

    Args:
        cursor: when truthy, only users whose id sorts strictly below it are returned.
        limit: maximum number of users in the page.

    Returns:
        ``(next_cursor, records)``, where ``next_cursor`` is the id of the last record
        in the page, or ``(None, [])`` when the page is empty.
    """
    with self.session_maker() as session:
        query = session.query(UserModel).order_by(desc(UserModel.id))
        if cursor:
            query = query.filter(UserModel.id < cursor)

        page = query.limit(limit).all()
        if not page:
            return None, []

        records = [user.to_record() for user in page]
        last_id = records[-1].id
        assert isinstance(last_id, str)
        return last_id, records
725
+
726
@enforce_types
def get_source(
    self, source_id: Optional[str] = None, user_id: Optional[str] = None, source_name: Optional[str] = None
) -> Optional[Source]:
    """Look up a single source by id, or by (name, owner) when no id is given.

    Returns:
        The matching ``Source`` record, or ``None`` if nothing matches.

    Raises:
        AssertionError: if no usable lookup key is given, or more than one row matches.
    """
    with self.session_maker() as session:
        query = session.query(SourceModel)
        if source_id:
            query = query.filter(SourceModel.id == source_id)
        else:
            assert user_id is not None and source_name is not None
            query = query.filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id)

        matches = query.all()
        if not matches:
            return None
        assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
        return matches[0].to_record()
740
+
741
@enforce_types
def get_tool(
    self, tool_name: Optional[str] = None, tool_id: Optional[str] = None, user_id: Optional[str] = None
) -> Optional[ToolModel]:
    """Look up a single tool by id, or by name among global and (optionally) user tools.

    When looking up by name, global tools (no ``user_id``) are searched first and,
    if ``user_id`` is truthy, that user's tools are added to the candidates.

    Returns:
        The matching tool record, or ``None`` if nothing matches.

    Raises:
        AssertionError: if no usable lookup key is given, or the lookup is ambiguous
            (e.g. a global tool and a user tool share the same name).
    """
    with self.session_maker() as session:
        if tool_id:
            matches = session.query(ToolModel).filter(ToolModel.id == tool_id).all()
        else:
            assert tool_name is not None
            by_name = session.query(ToolModel).filter(ToolModel.name == tool_name)
            matches = by_name.filter(ToolModel.user_id == None).all()  # noqa: E711 (SQL IS NULL)
            if user_id:
                matches += by_name.filter(ToolModel.user_id == user_id).all()

        if not matches:
            return None
        assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
        return matches[0].to_record()
757
+
758
@enforce_types
def get_block(self, block_id: str) -> Optional[Block]:
    """Fetch the block with id ``block_id``; ``None`` when it does not exist.

    Raises:
        AssertionError: if more than one row shares the id.
    """
    with self.session_maker() as session:
        matches = session.query(BlockModel).filter(BlockModel.id == block_id).all()
        if not matches:
            return None
        assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
        return matches[0].to_record()
766
+
767
@enforce_types
def get_blocks(
    self,
    user_id: Optional[str],
    label: Optional[str] = None,
    template: Optional[bool] = None,
    name: Optional[str] = None,
    id: Optional[str] = None,
) -> Optional[List[Block]]:
    """List available blocks.

    Each argument narrows the result only when it is truthy. Note that this means
    ``template=False`` behaves like ``template=None`` (no template filtering).

    Returns:
        A list of ``Block`` records, or ``None`` (not an empty list) when nothing matches.
    """
    with self.session_maker() as session:
        query = session.query(BlockModel)

        # (filter value, column it is compared against), applied in this order
        optional_filters = (
            (user_id, BlockModel.user_id),
            (label, BlockModel.label),
            (name, BlockModel.name),
            (id, BlockModel.id),
            (template, BlockModel.template),
        )
        for value, column in optional_filters:
            if value:
                query = query.filter(column == value)

        blocks = query.all()
        if not blocks:
            return None
        return [block.to_record() for block in blocks]
801
+
802
# agent source metadata
@enforce_types
def attach_source(self, user_id: str, agent_id: str, source_id: str):
    """Record that ``source_id`` is attached to ``agent_id`` on behalf of ``user_id``."""
    with self.session_maker() as session:
        # TODO: remove this (is a hack) -- the mapping id is just the three ids joined with '-'
        composite_id = f"{user_id}-{agent_id}-{source_id}"
        mapping = AgentSourceMappingModel(id=composite_id, user_id=user_id, agent_id=agent_id, source_id=source_id)
        session.add(mapping)
        session.commit()
810
+
811
@enforce_types
def list_attached_sources(self, agent_id: str) -> List[Source]:
    """Return the ``Source`` records attached to ``agent_id``.

    Mappings whose source no longer exists are skipped with a debug warning.
    """
    with self.session_maker() as session:
        mappings = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()

        attached = []
        for mapping in mappings:
            # confirm the mapped source still exists before reporting it
            source = self.get_source(source_id=mapping.source_id)
            if not source:
                printd(f"Warning: source {mapping.source_id} does not exist but exists in mapping database. This should never happen.")
                continue
            attached.append(source)
        return attached
825
+
826
@enforce_types
def list_attached_agents(self, source_id: str) -> List[str]:
    """Return the ids of agents that have ``source_id`` attached.

    Mappings whose agent no longer exists are skipped with a debug warning.
    """
    with self.session_maker() as session:
        mappings = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()

        attached_ids = []
        for mapping in mappings:
            # confirm the mapped agent still exists before reporting it
            if not self.get_agent(agent_id=mapping.agent_id):
                printd(f"Warning: agent {mapping.agent_id} does not exist but exists in mapping database. This should never happen.")
                continue
            attached_ids.append(mapping.agent_id)
        return attached_ids
840
+
841
@enforce_types
def detach_source(self, agent_id: str, source_id: str):
    """Remove every mapping that attaches ``source_id`` to ``agent_id``."""
    with self.session_maker() as session:
        link_filter = (
            AgentSourceMappingModel.agent_id == agent_id,
            AgentSourceMappingModel.source_id == source_id,
        )
        session.query(AgentSourceMappingModel).filter(*link_filter).delete()
        session.commit()
848
+
849
@enforce_types
def create_job(self, job: Job):
    """Insert a new job row built from the attributes of ``job``."""
    with self.session_maker() as session:
        job_row = JobModel(**vars(job))
        session.add(job_row)
        session.commit()
854
+
855
def delete_job(self, job_id: str):
    """Delete the job with id ``job_id`` (a no-op if it does not exist)."""
    with self.session_maker() as session:
        job_rows = session.query(JobModel).filter(JobModel.id == job_id)
        job_rows.delete()
        session.commit()
859
+
860
def get_job(self, job_id: str) -> Optional[Job]:
    """Fetch the job with id ``job_id``; ``None`` when it does not exist.

    Raises:
        AssertionError: if more than one row shares the id.
    """
    with self.session_maker() as session:
        matches = session.query(JobModel).filter(JobModel.id == job_id).all()
        if not matches:
            return None
        assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
        return matches[0].to_record()
867
+
868
def list_jobs(self, user_id: str) -> List[Job]:
    """Return every job owned by ``user_id`` as ``Job`` records."""
    with self.session_maker() as session:
        job_rows = session.query(JobModel).filter(JobModel.user_id == user_id).all()
        jobs = [row.to_record() for row in job_rows]
    return jobs
872
+
873
def update_job(self, job: Job) -> Job:
    """Overwrite the stored row for ``job.id`` with the attributes of ``job``.

    Args:
        job: the job whose current attribute values should be persisted.

    Returns:
        The same ``job`` instance that was written, matching the ``-> Job`` annotation.
    """
    with self.session_maker() as session:
        session.query(JobModel).filter(JobModel.id == job.id).update(vars(job))
        session.commit()
    # return the updated instance, not the Job class object
    return job
878
+
879
+ def update_job_status(self, job_id: str, status: JobStatus):
880
+ with self.session_maker() as session:
881
+ session.query(JobModel).filter(JobModel.id == job_id).update({"status": status})
882
+ if status == JobStatus.COMPLETED:
883
+ session.query(JobModel).filter(JobModel.id == job_id).update({"completed_at": get_utc_time()})
884
+ session.commit()