letta-nightly 0.1.7.dev20240924104148__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/__init__.py +24 -0
- letta/__main__.py +3 -0
- letta/agent.py +1427 -0
- letta/agent_store/chroma.py +295 -0
- letta/agent_store/db.py +546 -0
- letta/agent_store/lancedb.py +177 -0
- letta/agent_store/milvus.py +198 -0
- letta/agent_store/qdrant.py +201 -0
- letta/agent_store/storage.py +188 -0
- letta/benchmark/benchmark.py +96 -0
- letta/benchmark/constants.py +14 -0
- letta/cli/cli.py +689 -0
- letta/cli/cli_config.py +1282 -0
- letta/cli/cli_load.py +166 -0
- letta/client/__init__.py +0 -0
- letta/client/admin.py +171 -0
- letta/client/client.py +2360 -0
- letta/client/streaming.py +90 -0
- letta/client/utils.py +61 -0
- letta/config.py +484 -0
- letta/configs/anthropic.json +13 -0
- letta/configs/letta_hosted.json +11 -0
- letta/configs/openai.json +12 -0
- letta/constants.py +134 -0
- letta/credentials.py +140 -0
- letta/data_sources/connectors.py +247 -0
- letta/embeddings.py +218 -0
- letta/errors.py +26 -0
- letta/functions/__init__.py +0 -0
- letta/functions/function_sets/base.py +174 -0
- letta/functions/function_sets/extras.py +132 -0
- letta/functions/functions.py +105 -0
- letta/functions/schema_generator.py +205 -0
- letta/humans/__init__.py +0 -0
- letta/humans/examples/basic.txt +1 -0
- letta/humans/examples/cs_phd.txt +9 -0
- letta/interface.py +314 -0
- letta/llm_api/__init__.py +0 -0
- letta/llm_api/anthropic.py +383 -0
- letta/llm_api/azure_openai.py +155 -0
- letta/llm_api/cohere.py +396 -0
- letta/llm_api/google_ai.py +468 -0
- letta/llm_api/llm_api_tools.py +485 -0
- letta/llm_api/openai.py +470 -0
- letta/local_llm/README.md +3 -0
- letta/local_llm/__init__.py +0 -0
- letta/local_llm/chat_completion_proxy.py +279 -0
- letta/local_llm/constants.py +31 -0
- letta/local_llm/function_parser.py +68 -0
- letta/local_llm/grammars/__init__.py +0 -0
- letta/local_llm/grammars/gbnf_grammar_generator.py +1324 -0
- letta/local_llm/grammars/json.gbnf +26 -0
- letta/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf +32 -0
- letta/local_llm/groq/api.py +97 -0
- letta/local_llm/json_parser.py +202 -0
- letta/local_llm/koboldcpp/api.py +62 -0
- letta/local_llm/koboldcpp/settings.py +23 -0
- letta/local_llm/llamacpp/api.py +58 -0
- letta/local_llm/llamacpp/settings.py +22 -0
- letta/local_llm/llm_chat_completion_wrappers/__init__.py +0 -0
- letta/local_llm/llm_chat_completion_wrappers/airoboros.py +452 -0
- letta/local_llm/llm_chat_completion_wrappers/chatml.py +470 -0
- letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +387 -0
- letta/local_llm/llm_chat_completion_wrappers/dolphin.py +246 -0
- letta/local_llm/llm_chat_completion_wrappers/llama3.py +345 -0
- letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +156 -0
- letta/local_llm/llm_chat_completion_wrappers/wrapper_base.py +11 -0
- letta/local_llm/llm_chat_completion_wrappers/zephyr.py +345 -0
- letta/local_llm/lmstudio/api.py +100 -0
- letta/local_llm/lmstudio/settings.py +29 -0
- letta/local_llm/ollama/api.py +88 -0
- letta/local_llm/ollama/settings.py +32 -0
- letta/local_llm/settings/__init__.py +0 -0
- letta/local_llm/settings/deterministic_mirostat.py +45 -0
- letta/local_llm/settings/settings.py +72 -0
- letta/local_llm/settings/simple.py +28 -0
- letta/local_llm/utils.py +265 -0
- letta/local_llm/vllm/api.py +63 -0
- letta/local_llm/webui/api.py +60 -0
- letta/local_llm/webui/legacy_api.py +58 -0
- letta/local_llm/webui/legacy_settings.py +23 -0
- letta/local_llm/webui/settings.py +24 -0
- letta/log.py +76 -0
- letta/main.py +437 -0
- letta/memory.py +440 -0
- letta/metadata.py +884 -0
- letta/openai_backcompat/__init__.py +0 -0
- letta/openai_backcompat/openai_object.py +437 -0
- letta/persistence_manager.py +148 -0
- letta/personas/__init__.py +0 -0
- letta/personas/examples/anna_pa.txt +13 -0
- letta/personas/examples/google_search_persona.txt +15 -0
- letta/personas/examples/memgpt_doc.txt +6 -0
- letta/personas/examples/memgpt_starter.txt +4 -0
- letta/personas/examples/sam.txt +14 -0
- letta/personas/examples/sam_pov.txt +14 -0
- letta/personas/examples/sam_simple_pov_gpt35.txt +13 -0
- letta/personas/examples/sqldb/test.db +0 -0
- letta/prompts/__init__.py +0 -0
- letta/prompts/gpt_summarize.py +14 -0
- letta/prompts/gpt_system.py +26 -0
- letta/prompts/system/memgpt_base.txt +49 -0
- letta/prompts/system/memgpt_chat.txt +58 -0
- letta/prompts/system/memgpt_chat_compressed.txt +13 -0
- letta/prompts/system/memgpt_chat_fstring.txt +51 -0
- letta/prompts/system/memgpt_doc.txt +50 -0
- letta/prompts/system/memgpt_gpt35_extralong.txt +53 -0
- letta/prompts/system/memgpt_intuitive_knowledge.txt +31 -0
- letta/prompts/system/memgpt_modified_chat.txt +23 -0
- letta/pytest.ini +0 -0
- letta/schemas/agent.py +117 -0
- letta/schemas/api_key.py +21 -0
- letta/schemas/block.py +135 -0
- letta/schemas/document.py +21 -0
- letta/schemas/embedding_config.py +54 -0
- letta/schemas/enums.py +35 -0
- letta/schemas/job.py +38 -0
- letta/schemas/letta_base.py +80 -0
- letta/schemas/letta_message.py +175 -0
- letta/schemas/letta_request.py +23 -0
- letta/schemas/letta_response.py +28 -0
- letta/schemas/llm_config.py +54 -0
- letta/schemas/memory.py +224 -0
- letta/schemas/message.py +727 -0
- letta/schemas/openai/chat_completion_request.py +123 -0
- letta/schemas/openai/chat_completion_response.py +136 -0
- letta/schemas/openai/chat_completions.py +123 -0
- letta/schemas/openai/embedding_response.py +11 -0
- letta/schemas/openai/openai.py +157 -0
- letta/schemas/organization.py +20 -0
- letta/schemas/passage.py +80 -0
- letta/schemas/source.py +62 -0
- letta/schemas/tool.py +143 -0
- letta/schemas/usage.py +18 -0
- letta/schemas/user.py +33 -0
- letta/server/__init__.py +0 -0
- letta/server/constants.py +6 -0
- letta/server/rest_api/__init__.py +0 -0
- letta/server/rest_api/admin/__init__.py +0 -0
- letta/server/rest_api/admin/agents.py +21 -0
- letta/server/rest_api/admin/tools.py +83 -0
- letta/server/rest_api/admin/users.py +98 -0
- letta/server/rest_api/app.py +193 -0
- letta/server/rest_api/auth/__init__.py +0 -0
- letta/server/rest_api/auth/index.py +43 -0
- letta/server/rest_api/auth_token.py +22 -0
- letta/server/rest_api/interface.py +726 -0
- letta/server/rest_api/routers/__init__.py +0 -0
- letta/server/rest_api/routers/openai/__init__.py +0 -0
- letta/server/rest_api/routers/openai/assistants/__init__.py +0 -0
- letta/server/rest_api/routers/openai/assistants/assistants.py +115 -0
- letta/server/rest_api/routers/openai/assistants/schemas.py +121 -0
- letta/server/rest_api/routers/openai/assistants/threads.py +336 -0
- letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +131 -0
- letta/server/rest_api/routers/v1/__init__.py +15 -0
- letta/server/rest_api/routers/v1/agents.py +543 -0
- letta/server/rest_api/routers/v1/blocks.py +73 -0
- letta/server/rest_api/routers/v1/jobs.py +46 -0
- letta/server/rest_api/routers/v1/llms.py +28 -0
- letta/server/rest_api/routers/v1/organizations.py +61 -0
- letta/server/rest_api/routers/v1/sources.py +199 -0
- letta/server/rest_api/routers/v1/tools.py +103 -0
- letta/server/rest_api/routers/v1/users.py +109 -0
- letta/server/rest_api/static_files.py +74 -0
- letta/server/rest_api/utils.py +69 -0
- letta/server/server.py +1995 -0
- letta/server/startup.sh +8 -0
- letta/server/static_files/assets/index-0cbf7ad5.js +274 -0
- letta/server/static_files/assets/index-156816da.css +1 -0
- letta/server/static_files/assets/index-486e3228.js +274 -0
- letta/server/static_files/favicon.ico +0 -0
- letta/server/static_files/index.html +39 -0
- letta/server/static_files/memgpt_logo_transparent.png +0 -0
- letta/server/utils.py +46 -0
- letta/server/ws_api/__init__.py +0 -0
- letta/server/ws_api/example_client.py +104 -0
- letta/server/ws_api/interface.py +108 -0
- letta/server/ws_api/protocol.py +100 -0
- letta/server/ws_api/server.py +145 -0
- letta/settings.py +165 -0
- letta/streaming_interface.py +396 -0
- letta/system.py +207 -0
- letta/utils.py +1065 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/LICENSE +190 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/METADATA +98 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/RECORD +189 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/WHEEL +4 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/entry_points.txt +3 -0
letta/metadata.py
ADDED
|
@@ -0,0 +1,884 @@
|
|
|
1
|
+
""" Metadata store for user/agent/data_source information"""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import secrets
|
|
5
|
+
from typing import List, Optional
|
|
6
|
+
|
|
7
|
+
from sqlalchemy import (
|
|
8
|
+
BIGINT,
|
|
9
|
+
JSON,
|
|
10
|
+
Boolean,
|
|
11
|
+
Column,
|
|
12
|
+
DateTime,
|
|
13
|
+
Index,
|
|
14
|
+
String,
|
|
15
|
+
TypeDecorator,
|
|
16
|
+
desc,
|
|
17
|
+
func,
|
|
18
|
+
)
|
|
19
|
+
from sqlalchemy.orm import declarative_base
|
|
20
|
+
from sqlalchemy.sql import func
|
|
21
|
+
|
|
22
|
+
from letta.config import LettaConfig
|
|
23
|
+
from letta.schemas.agent import AgentState
|
|
24
|
+
from letta.schemas.api_key import APIKey
|
|
25
|
+
from letta.schemas.block import Block, Human, Persona
|
|
26
|
+
from letta.schemas.embedding_config import EmbeddingConfig
|
|
27
|
+
from letta.schemas.enums import JobStatus
|
|
28
|
+
from letta.schemas.job import Job
|
|
29
|
+
from letta.schemas.llm_config import LLMConfig
|
|
30
|
+
from letta.schemas.memory import Memory
|
|
31
|
+
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
|
32
|
+
from letta.schemas.organization import Organization
|
|
33
|
+
from letta.schemas.source import Source
|
|
34
|
+
from letta.schemas.tool import Tool
|
|
35
|
+
from letta.schemas.user import User
|
|
36
|
+
from letta.settings import settings
|
|
37
|
+
from letta.utils import enforce_types, get_utc_time, printd
|
|
38
|
+
|
|
39
|
+
Base = declarative_base()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class LLMConfigColumn(TypeDecorator):
    """SQLAlchemy column type that persists an ``LLMConfig`` as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the dialect's type descriptor for a plain ``JSON`` column."""
        json_type = JSON()
        return dialect.type_descriptor(json_type)

    def process_bind_param(self, value, dialect):
        """Convert an ``LLMConfig`` to a dict via ``model_dump()``; anything else passes through."""
        if value and isinstance(value, LLMConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        """Rebuild an ``LLMConfig`` from the stored mapping; falsy values are returned as-is."""
        return LLMConfig(**value) if value else value
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class EmbeddingConfigColumn(TypeDecorator):
    """SQLAlchemy column type that persists an ``EmbeddingConfig`` as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the dialect's type descriptor for a plain ``JSON`` column."""
        json_type = JSON()
        return dialect.type_descriptor(json_type)

    def process_bind_param(self, value, dialect):
        """Convert an ``EmbeddingConfig`` to a dict via ``model_dump()``; anything else passes through."""
        if value and isinstance(value, EmbeddingConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        """Rebuild an ``EmbeddingConfig`` from the stored mapping; falsy values are returned as-is."""
        return EmbeddingConfig(**value) if value else value
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ToolCallColumn(TypeDecorator):
    """SQLAlchemy column type that persists a list of ``ToolCall`` objects as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the dialect's type descriptor for a plain ``JSON`` column."""
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Serialize each ``ToolCall`` in ``value`` with ``model_dump()``.

        Items that are not ``ToolCall`` instances are kept unchanged; a falsy
        ``value`` (``None`` or an empty list) is returned as-is.
        """
        if not value:
            return value
        return [v.model_dump() if isinstance(v, ToolCall) else v for v in value]

    def process_result_value(self, value, dialect):
        """Rebuild a list of ``ToolCall`` objects from the stored list of dicts.

        A falsy ``value`` is returned as-is. Entries without a ``"function"``
        key produce a ``ToolCall`` with ``function=None``.
        """
        if not value:
            return value

        tools = []
        for tool_value in value:
            # Work on a shallow copy: the "function" entry has to be removed before
            # the remaining keys are splatted into ToolCall, and the loaded value
            # itself should not be modified as a side effect.
            remaining = dict(tool_value)
            if "function" in remaining:
                tool_call_function = ToolCallFunction(**remaining.pop("function"))
            else:
                tool_call_function = None
            tools.append(ToolCall(function=tool_call_function, **remaining))
        return tools
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class UserModel(Base):
    """ORM mapping for rows of the ``users`` table."""

    __tablename__ = "users"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    org_id = Column(String)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True))

    # TODO: what is this?
    policies_accepted = Column(Boolean, nullable=False, default=False)

    def __repr__(self) -> str:
        """Short debug representation showing the id and name."""
        return f"<User(id='{self.id}' name='{self.name}')>"

    def to_record(self) -> User:
        """Build the ``User`` schema object for this row (``policies_accepted`` is not carried over)."""
        record_fields = {
            "id": self.id,
            "name": self.name,
            "created_at": self.created_at,
            "org_id": self.org_id,
        }
        return User(**record_fields)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class OrganizationModel(Base):
    """ORM mapping for rows of the ``organizations`` table."""

    __tablename__ = "organizations"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True))

    def __repr__(self) -> str:
        """Short debug representation showing the id and name."""
        return f"<Organization(id='{self.id}' name='{self.name}')>"

    def to_record(self) -> Organization:
        """Build the ``Organization`` schema object for this row."""
        record_fields = {
            "id": self.id,
            "name": self.name,
            "created_at": self.created_at,
        }
        return Organization(**record_fields)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class APIKeyModel(Base):
    """Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens)."""

    __tablename__ = "tokens"

    id = Column(String, primary_key=True)
    # each api key is tied to a user account (that it validates access for)
    user_id = Column(String, nullable=False)
    # the api key
    key = Column(String, nullable=False)
    # extra (optional) metadata
    name = Column(String)

    # Lookups filter on user_id and on key, so both columns get an index.
    Index(__tablename__ + "_idx_user", user_id)
    Index(__tablename__ + "_idx_key", key)

    def __repr__(self) -> str:
        """Debug representation; only a short prefix of the key is shown."""
        # The key is a credential: printing it in full would leak it into logs
        # and tracebacks wherever this object is repr()'d.
        masked_key = f"{self.key[:6]}..." if self.key else self.key
        return f"<APIKey(id='{self.id}', key='{masked_key}', name='{self.name}')>"

    def to_record(self) -> APIKey:
        """Build the ``APIKey`` schema object for this row."""
        return APIKey(
            id=self.id,
            user_id=self.user_id,
            key=self.key,
            name=self.name,
        )
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def generate_api_key(prefix="sk-", length=51) -> str:
    """Generate a random API key of the form ``prefix + <hex digits>``.

    Args:
        prefix: Literal text placed at the start of the key.
        length: Target total length of the key, prefix included.

    Returns:
        A string of exactly ``length`` characters whenever ``length`` leaves room
        for at least two hex digits after the prefix. Otherwise the key is the
        prefix followed by two hex digits, so it always carries at least one
        random byte.
    """
    # Number of hex digits to append after the prefix; never fewer than two
    # (one full byte), even when `length` is shorter than the prefix itself.
    hex_digits = max(length - len(prefix), 2)
    # Each byte yields two hex digits: round up, then trim to the exact count so
    # odd remainders are honoured instead of coming out one character short.
    random_hex = secrets.token_hex((hex_digits + 1) // 2)
    return prefix + random_hex[:hex_digits]
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class AgentModel(Base):
    """ORM mapping for rows of the ``agents`` table (one row per agent state)."""

    __tablename__ = "agents"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    description = Column(String)

    # state (context compilation)
    message_ids = Column(JSON)
    memory = Column(JSON)  # stored as a plain dict; turned back into a Memory in to_record()
    system = Column(String)

    # configs
    llm_config = Column(LLMConfigColumn)
    embedding_config = Column(EmbeddingConfigColumn)

    # state
    metadata_ = Column(JSON)

    # tools
    # NOTE: this is the single declaration of the column; an earlier duplicate
    # `tools = Column(JSON)` was shadowed by this one and had no effect.
    tools = Column(JSON)

    Index(__tablename__ + "_idx_user", user_id)

    def __repr__(self) -> str:
        """Short debug representation showing the id and name."""
        return f"<Agent(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> AgentState:
        """Build the ``AgentState`` schema object for this row."""
        return AgentState(
            id=self.id,
            user_id=self.user_id,
            name=self.name,
            created_at=self.created_at,
            description=self.description,
            message_ids=self.message_ids,
            memory=Memory.load(self.memory),  # load dictionary
            system=self.system,
            tools=self.tools,
            llm_config=self.llm_config,
            embedding_config=self.embedding_config,
            metadata_=self.metadata_,
        )
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class SourceModel(Base):
    """ORM mapping for rows of the ``sources`` table (one row per data source)."""

    __tablename__ = "sources"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    embedding_config = Column(EmbeddingConfigColumn)
    description = Column(String)
    metadata_ = Column(JSON)
    Index(__tablename__ + "_idx_user", user_id)

    # TODO: add num passages

    def __repr__(self) -> str:
        """Short debug representation showing the id and name."""
        # The value shown is this source's own id, so it is labelled `id`
        # (it was previously mislabelled `passage_id`).
        return f"<Source(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> Source:
        """Build the ``Source`` schema object for this row."""
        return Source(
            id=self.id,
            user_id=self.user_id,
            name=self.name,
            created_at=self.created_at,
            embedding_config=self.embedding_config,
            description=self.description,
            metadata_=self.metadata_,
        )
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class AgentSourceMappingModel(Base):
    """Row of the ``agent_source_mapping`` table: links one agent to one source for a user."""

    __tablename__ = "agent_source_mapping"

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    agent_id = Column(String, nullable=False)
    source_id = Column(String, nullable=False)

    # Composite index over (user_id, agent_id, source_id).
    Index(__tablename__ + "_idx_user", user_id, agent_id, source_id)

    def __repr__(self) -> str:
        """Debug representation listing the three linked ids."""
        return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class BlockModel(Base):
    """ORM mapping for rows of the ``block`` table."""

    __tablename__ = "block"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True, nullable=False)
    value = Column(String, nullable=False)
    limit = Column(BIGINT)
    name = Column(String, nullable=False)
    template = Column(Boolean, default=False)  # True: listed as possible human/persona
    label = Column(String)
    metadata_ = Column(JSON)
    description = Column(String)
    user_id = Column(String)
    Index(__tablename__ + "_idx_user", user_id)

    def __repr__(self) -> str:
        """Debug representation showing id, name, template flag, label and owner."""
        return f"<Block(id='{self.id}', name='{self.name}', template='{self.template}', label='{self.label}', user_id='{self.user_id}')>"

    def to_record(self) -> Block:
        """Build the schema object for this row.

        The class is chosen by ``label``: ``"persona"`` -> ``Persona``,
        ``"human"`` -> ``Human``, anything else -> ``Block``. All three are
        constructed from the same set of fields.
        """
        record_cls = {"persona": Persona, "human": Human}.get(self.label, Block)
        return record_cls(
            id=self.id,
            value=self.value,
            limit=self.limit,
            name=self.name,
            template=self.template,
            label=self.label,
            metadata_=self.metadata_,
            description=self.description,
            user_id=self.user_id,
        )
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
class ToolModel(Base):
    """ORM mapping for rows of the ``tools`` table."""

    __tablename__ = "tools"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    user_id = Column(String)
    description = Column(String)
    source_type = Column(String)
    source_code = Column(String)
    json_schema = Column(JSON)
    module = Column(String)
    tags = Column(JSON)

    def __repr__(self) -> str:
        """Short debug representation showing the id and name."""
        return f"<Tool(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> Tool:
        """Build the ``Tool`` schema object for this row."""
        record_fields = {
            "id": self.id,
            "name": self.name,
            "user_id": self.user_id,
            "description": self.description,
            "source_type": self.source_type,
            "source_code": self.source_code,
            "json_schema": self.json_schema,
            "module": self.module,
            "tags": self.tags,
        }
        return Tool(**record_fields)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
class JobModel(Base):
    """ORM mapping for rows of the ``jobs`` table."""

    __tablename__ = "jobs"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    user_id = Column(String)
    status = Column(String, default=JobStatus.pending)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    # NOTE(review): onupdate=func.now() refreshes this column on every UPDATE of
    # the row, not only when the job finishes -- confirm that is intended.
    completed_at = Column(DateTime(timezone=True), onupdate=func.now())
    metadata_ = Column(JSON)

    def __repr__(self) -> str:
        """Short debug representation showing the id and status."""
        return f"<Job(id='{self.id}', status='{self.status}')>"

    def to_record(self) -> Job:
        """Build the ``Job`` schema object for this row."""
        return Job(
            id=self.id,
            user_id=self.user_id,
            status=self.status,
            created_at=self.created_at,
            completed_at=self.completed_at,
            metadata_=self.metadata_,
        )
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
class MetadataStore:
|
|
403
|
+
uri: Optional[str] = None
|
|
404
|
+
|
|
405
|
+
def __init__(self, config: LettaConfig):
    """Resolve the database URI from ``config`` and attach the session factory.

    Raises:
        ValueError: if ``config.metadata_storage_type`` is neither
            ``"postgres"`` nor ``"sqlite"``.
        AssertionError: if the resolved URI is empty.
    """
    # TODO: get DB URI or path
    storage_type = config.metadata_storage_type
    if storage_type == "postgres":
        # settings.pg_uri takes precedence; the config file value is the fallback
        self.uri = settings.pg_uri or config.metadata_storage_uri
    elif storage_type == "sqlite":
        db_file = os.path.join(config.metadata_storage_path, "sqlite.db")
        self.uri = f"sqlite:///{db_file}"
    else:
        raise ValueError(f"Invalid metadata storage type: {config.metadata_storage_type}")

    # Ensure valid URI
    assert self.uri, "Database URI is not provided or is invalid."

    # imported here rather than at module top-level
    from letta.server.server import db_context

    self.session_maker = db_context
|
|
423
|
+
|
|
424
|
+
@enforce_types
def create_api_key(self, user_id: str, name: str) -> APIKey:
    """Generate a new API key, store it for ``user_id`` and return the stored record.

    Raises:
        ValueError: if the freshly generated key already exists in the table.
        AssertionError: if ``user_id`` or ``name`` is empty.
    """
    fresh_key = generate_api_key()
    with self.session_maker() as session:
        collisions = session.query(APIKeyModel).filter(APIKeyModel.key == fresh_key).count()
        if collisions > 0:
            # NOTE duplicate API keys / tokens should never happen, but if it does don't allow it
            raise ValueError(f"Token {fresh_key} already exists")
        # TODO store the API keys as hashed
        assert user_id and name, "User ID and name must be provided"
        record = APIKey(user_id=user_id, key=fresh_key, name=name)
        session.add(APIKeyModel(**vars(record)))
        session.commit()
    return self.get_api_key(api_key=fresh_key)
|
|
438
|
+
|
|
439
|
+
@enforce_types
def delete_api_key(self, api_key: str):
    """Remove every row of the tokens table whose key equals ``api_key``."""
    with self.session_maker() as session:
        matching_keys = session.query(APIKeyModel).filter(APIKeyModel.key == api_key)
        matching_keys.delete()
        session.commit()
|
|
445
|
+
|
|
446
|
+
@enforce_types
def get_api_key(self, api_key: str) -> Optional[APIKey]:
    """Return the stored record for ``api_key``, or ``None`` when no row matches."""
    with self.session_maker() as session:
        matches = session.query(APIKeyModel).filter(APIKeyModel.key == api_key).all()
        if not matches:
            return None
        # should only be one result
        assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
        return matches[0].to_record()
|
|
454
|
+
|
|
455
|
+
@enforce_types
def get_all_api_keys_for_user(self, user_id: str) -> List[APIKey]:
    """Return the records of all API keys whose ``user_id`` column equals ``user_id``."""
    with self.session_maker() as session:
        rows = session.query(APIKeyModel).filter(APIKeyModel.user_id == user_id).all()
        return [row.to_record() for row in rows]
|
|
461
|
+
|
|
462
|
+
@enforce_types
def get_user_from_api_key(self, api_key: str) -> Optional[User]:
    """Look up the user that owns ``api_key``.

    Raises:
        ValueError: if no stored token matches ``api_key``.
    """
    token = self.get_api_key(api_key=api_key)
    if token is None:
        raise ValueError("Provided token does not exist")
    return self.get_user(user_id=token.user_id)
|
|
470
|
+
|
|
471
|
+
@enforce_types
def create_agent(self, agent: AgentState):
    """Insert ``agent`` into the agents table.

    Raises:
        ValueError: if the same user already has an agent with this name.
    """
    # insert into agent table
    # make sure agent.name does not already exist for user user_id
    with self.session_maker() as session:
        if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
            raise ValueError(f"Agent with name {agent.name} already exists")
        # vars() hands back the object's live attribute dict, so take a copy:
        # writing the serialized memory into the original would replace
        # `agent.memory` on the caller's object with a plain dict.
        fields = dict(vars(agent))
        fields["memory"] = agent.memory.to_dict()
        session.add(AgentModel(**fields))
        session.commit()
|
|
482
|
+
|
|
483
|
+
@enforce_types
def create_source(self, source: Source):
    """Insert ``source`` into the sources table.

    Raises:
        ValueError: if the same user already has a source with this name.
    """
    with self.session_maker() as session:
        same_name = session.query(SourceModel).filter(SourceModel.name == source.name)
        same_name_and_user = same_name.filter(SourceModel.user_id == source.user_id)
        if same_name_and_user.count() > 0:
            raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
        session.add(SourceModel(**vars(source)))
        session.commit()
|
|
490
|
+
|
|
491
|
+
@enforce_types
def create_user(self, user: User):
    """Insert ``user`` into the users table.

    Raises:
        ValueError: if a user with the same id is already stored.
    """
    with self.session_maker() as session:
        existing = session.query(UserModel).filter(UserModel.id == user.id).count()
        if existing > 0:
            raise ValueError(f"User with id {user.id} already exists")
        session.add(UserModel(**vars(user)))
        session.commit()
|
|
498
|
+
|
|
499
|
+
@enforce_types
def create_organization(self, organization: Organization):
    """Insert ``organization`` into the organizations table.

    Raises:
        ValueError: if an organization with the same id is already stored.
    """
    with self.session_maker() as session:
        existing = session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count()
        if existing > 0:
            raise ValueError(f"Organization with id {organization.id} already exists")
        session.add(OrganizationModel(**vars(organization)))
        session.commit()
|
|
506
|
+
|
|
507
|
+
@enforce_types
def create_block(self, block: Block):
    """Insert ``block`` into the block table.

    Raises:
        ValueError: if a template block with the same name, label and user
            already exists.
    """
    with self.session_maker() as session:
        # TODO: fix?
        # we are only validating that more than one template block
        # with a given name doesn't exist.
        clashing_templates = (
            session.query(BlockModel)
            .filter(BlockModel.name == block.name)
            .filter(BlockModel.user_id == block.user_id)
            .filter(BlockModel.template == True)  # noqa: E712 -- builds a SQL comparison
            .filter(BlockModel.label == block.label)
            .count()
        )
        if clashing_templates > 0:
            raise ValueError(f"Block with name {block.name} already exists")
        session.add(BlockModel(**vars(block)))
        session.commit()
|
|
526
|
+
|
|
527
|
+
@enforce_types
def create_tool(self, tool: Tool):
    """Insert ``tool`` into the tools table.

    Raises:
        ValueError: if ``get_tool`` already finds a tool with this name for the user.
    """
    with self.session_maker() as session:
        already_stored = self.get_tool(tool_name=tool.name, user_id=tool.user_id)
        if already_stored is not None:
            raise ValueError(f"Tool with name {tool.name} already exists")
        session.add(ToolModel(**vars(tool)))
        session.commit()
|
|
534
|
+
|
|
535
|
+
@enforce_types
def update_agent(self, agent: AgentState):
    """Overwrite the stored row whose id equals ``agent.id`` with the fields of ``agent``."""
    with self.session_maker() as session:
        # Copy the attribute dict: vars() returns the live mapping, and writing
        # the serialized memory into it would turn `agent.memory` on the
        # caller's object into a plain dict.
        fields = dict(vars(agent))
        if isinstance(agent.memory, Memory):  # TODO: this is nasty but this whole class will soon be removed so whatever
            fields["memory"] = agent.memory.to_dict()
        session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
        session.commit()
|
|
543
|
+
|
|
544
|
+
@enforce_types
def update_user(self, user: User):
    """Overwrite the stored row whose id equals ``user.id`` with the fields of ``user``."""
    with self.session_maker() as session:
        target_rows = session.query(UserModel).filter(UserModel.id == user.id)
        target_rows.update(vars(user))
        session.commit()
|
|
549
|
+
|
|
550
|
+
@enforce_types
def update_source(self, source: Source):
    """Overwrite the stored row whose id equals ``source.id`` with the fields of ``source``."""
    with self.session_maker() as session:
        target_rows = session.query(SourceModel).filter(SourceModel.id == source.id)
        target_rows.update(vars(source))
        session.commit()
|
|
555
|
+
|
|
556
|
+
@enforce_types
def update_block(self, block: Block):
    """Overwrite the stored row whose id equals ``block.id`` with the fields of ``block``."""
    with self.session_maker() as session:
        target_rows = session.query(BlockModel).filter(BlockModel.id == block.id)
        target_rows.update(vars(block))
        session.commit()
|
|
561
|
+
|
|
562
|
+
@enforce_types
def update_or_create_block(self, block: Block):
    """Update the row with ``block.id`` if one exists, otherwise insert ``block`` as a new row."""
    with self.session_maker() as session:
        found = session.query(BlockModel).filter(BlockModel.id == block.id).first()
        if not found:
            session.add(BlockModel(**vars(block)))
        else:
            session.query(BlockModel).filter(BlockModel.id == block.id).update(vars(block))
        session.commit()
|
|
571
|
+
|
|
572
|
+
@enforce_types
def update_tool(self, tool: Tool):
    """Overwrite the stored row whose id equals ``tool.id`` with the fields of ``tool``."""
    with self.session_maker() as session:
        target_rows = session.query(ToolModel).filter(ToolModel.id == tool.id)
        target_rows.update(vars(tool))
        session.commit()
|
|
577
|
+
|
|
578
|
+
@enforce_types
def delete_tool(self, tool_id: str):
    """Delete the tools-table row(s) whose id equals ``tool_id``."""
    with self.session_maker() as session:
        doomed = session.query(ToolModel).filter(ToolModel.id == tool_id)
        doomed.delete()
        session.commit()
|
|
583
|
+
|
|
584
|
+
@enforce_types
def delete_block(self, block_id: str):
    """Delete the block row(s) whose id equals ``block_id`` and commit."""
    with self.session_maker() as session:
        doomed = session.query(BlockModel).filter(BlockModel.id == block_id)
        doomed.delete()
        session.commit()
|
|
589
|
+
|
|
590
|
+
@enforce_types
def delete_agent(self, agent_id: str):
    """Delete an agent and its agent-source mappings in a single commit.

    The agent row is removed first, then every mapping row that references
    ``agent_id``.
    """
    with self.session_maker() as session:
        agent_rows = session.query(AgentModel).filter(AgentModel.id == agent_id)
        agent_rows.delete()

        mapping_rows = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id)
        mapping_rows.delete()

        session.commit()
|
|
601
|
+
|
|
602
|
+
@enforce_types
def delete_source(self, source_id: str):
    """Delete a source and its agent-source mappings in a single commit.

    The source row is removed first, then every mapping row that references
    ``source_id``.
    """
    with self.session_maker() as session:
        source_rows = session.query(SourceModel).filter(SourceModel.id == source_id)
        source_rows.delete()

        mapping_rows = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id)
        mapping_rows.delete()

        session.commit()
|
|
612
|
+
|
|
613
|
+
@enforce_types
def delete_user(self, user_id: str):
    """Delete a user together with the rows that carry their ``user_id``.

    Removed, in this order and within one commit: the user row, the user's
    agents, the user's sources, and the user's agent-source mappings.
    """
    with self.session_maker() as session:
        # Each entry is (model, column that must equal user_id); order matters
        # only in that it mirrors the sequence of DELETE statements issued.
        deletions = (
            (UserModel, UserModel.id),
            (AgentModel, AgentModel.user_id),
            (SourceModel, SourceModel.user_id),
            (AgentSourceMappingModel, AgentSourceMappingModel.user_id),
        )
        for model, owner_column in deletions:
            session.query(model).filter(owner_column == user_id).delete()

        session.commit()
|
|
629
|
+
|
|
630
|
+
@enforce_types
def delete_organization(self, org_id: str):
    """Delete the organization row whose id equals ``org_id`` and commit.

    Only the organizations table is touched; rows belonging to the
    organization are left in place.
    """
    with self.session_maker() as session:
        org_rows = session.query(OrganizationModel).filter(OrganizationModel.id == org_id)
        org_rows.delete()

        # TODO: delete associated data

        session.commit()
|
|
639
|
+
|
|
640
|
+
@enforce_types
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can create tools
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
    """Return the records of tools with no owner, plus those owned by ``user_id``.

    Ownerless tools (``user_id`` column is NULL) always come first. Tools
    owned by ``user_id`` are appended only when a non-empty ``user_id`` is
    given.
    """
    with self.session_maker() as session:
        # SQLAlchemy renders `== None` as IS NULL, so the comparison is intentional.
        tools = session.query(ToolModel).filter(ToolModel.user_id == None).all()  # noqa: E711
        if user_id:
            owned = session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
            tools.extend(owned)
        return [tool.to_record() for tool in tools]
|
|
649
|
+
|
|
650
|
+
@enforce_types
def list_agents(self, user_id: str) -> List[AgentState]:
    """Return the records of all agents whose ``user_id`` column equals ``user_id``."""
    with self.session_maker() as session:
        owned = session.query(AgentModel).filter(AgentModel.user_id == user_id)
        return [row.to_record() for row in owned.all()]
|
|
655
|
+
|
|
656
|
+
@enforce_types
def list_sources(self, user_id: str) -> List[Source]:
    """Return the records of all sources whose ``user_id`` column equals ``user_id``."""
    with self.session_maker() as session:
        owned = session.query(SourceModel).filter(SourceModel.user_id == user_id)
        return [row.to_record() for row in owned.all()]
|
|
661
|
+
|
|
662
|
+
@enforce_types
def get_agent(
    self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None
) -> Optional[AgentState]:
    """Look up one agent, by id or by (name, user).

    When ``agent_id`` is truthy it alone selects the agent. Otherwise both
    ``agent_name`` and ``user_id`` must be provided and are matched together.

    Returns:
        The agent's record, or ``None`` when no row matches.

    Raises:
        AssertionError: if neither lookup key is usable, or if more than one
            row matches.
    """
    with self.session_maker() as session:
        if agent_id:
            criteria = (AgentModel.id == agent_id,)
        else:
            assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name"
            criteria = (AgentModel.name == agent_name, AgentModel.user_id == user_id)

        matches = session.query(AgentModel).filter(*criteria).all()
        if matches:
            # The lookup keys are expected to identify at most one agent.
            assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
            return matches[0].to_record()
        return None
|
|
677
|
+
|
|
678
|
+
@enforce_types
def get_user(self, user_id: str) -> Optional[User]:
    """Return the record of the user with id ``user_id``, or ``None`` if absent.

    Raises:
        AssertionError: if more than one row has that id.
    """
    with self.session_maker() as session:
        matches = session.query(UserModel).filter(UserModel.id == user_id).all()
        if matches:
            assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
            return matches[0].to_record()
        return None
|
|
686
|
+
|
|
687
|
+
@enforce_types
def get_organization(self, org_id: str) -> Optional[Organization]:
    """Return the record of the organization with id ``org_id``, or ``None`` if absent.

    Raises:
        AssertionError: if more than one row has that id.
    """
    with self.session_maker() as session:
        matches = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all()
        if matches:
            assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
            return matches[0].to_record()
        return None
|
|
695
|
+
|
|
696
|
+
@enforce_types
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
    """Return one page of organizations, ordered by id descending.

    Args:
        cursor: When truthy, only organizations whose id is less than this
            value are returned.
        limit: Maximum number of rows to fetch.

    Returns:
        A ``(next_cursor, records)`` tuple, where ``next_cursor`` is the id of
        the last record on the page. An empty page yields ``(None, [])``.
    """
    with self.session_maker() as session:
        page_query = session.query(OrganizationModel)
        if cursor:
            page_query = page_query.filter(OrganizationModel.id < cursor)
        rows = page_query.order_by(desc(OrganizationModel.id)).limit(limit).all()

        if not rows:
            return None, []

        organization_records = [row.to_record() for row in rows]
        last_id = organization_records[-1].id
        assert isinstance(last_id, str)

        return last_id, organization_records
|
|
710
|
+
|
|
711
|
+
@enforce_types
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
    """Return one page of users, ordered by id descending.

    Args:
        cursor: When truthy, only users whose id is less than this value are
            returned.
        limit: Maximum number of rows to fetch.

    Returns:
        A ``(next_cursor, records)`` tuple, where ``next_cursor`` is the id of
        the last record on the page. An empty page yields ``(None, [])``.
    """
    with self.session_maker() as session:
        page_query = session.query(UserModel)
        if cursor:
            page_query = page_query.filter(UserModel.id < cursor)
        rows = page_query.order_by(desc(UserModel.id)).limit(limit).all()

        if not rows:
            return None, []

        user_records = [row.to_record() for row in rows]
        last_id = user_records[-1].id
        assert isinstance(last_id, str)

        return last_id, user_records
|
|
725
|
+
|
|
726
|
+
@enforce_types
def get_source(
    self, source_id: Optional[str] = None, user_id: Optional[str] = None, source_name: Optional[str] = None
) -> Optional[Source]:
    """Look up one source, by id or by (name, user).

    When ``source_id`` is truthy it alone selects the source. Otherwise both
    ``user_id`` and ``source_name`` must be provided and are matched together.

    Returns:
        The source's record, or ``None`` when no row matches.

    Raises:
        AssertionError: if neither lookup key is usable, or if more than one
            row matches.
    """
    with self.session_maker() as session:
        if source_id:
            criteria = (SourceModel.id == source_id,)
        else:
            assert user_id is not None and source_name is not None
            criteria = (SourceModel.name == source_name, SourceModel.user_id == user_id)

        matches = session.query(SourceModel).filter(*criteria).all()
        if matches:
            assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
            return matches[0].to_record()
        return None
|
|
740
|
+
|
|
741
|
+
@enforce_types
def get_tool(
    self, tool_name: Optional[str] = None, tool_id: Optional[str] = None, user_id: Optional[str] = None
) -> Optional[ToolModel]:
    """Look up one tool, by id or by name.

    When ``tool_id`` is truthy it alone selects the tool. Otherwise
    ``tool_name`` is required: ownerless tools with that name are fetched,
    and, when ``user_id`` is truthy, tools of that name owned by ``user_id``
    are added to the candidates.

    Returns:
        The tool's record, or ``None`` when there is no candidate.

    Raises:
        AssertionError: if ``tool_name`` is missing on the name path, or if
            more than one candidate is found.
    """
    with self.session_maker() as session:
        if tool_id:
            candidates = session.query(ToolModel).filter(ToolModel.id == tool_id).all()
        else:
            assert tool_name is not None
            named = ToolModel.name == tool_name
            # SQLAlchemy renders `== None` as IS NULL, so the comparison is intentional.
            candidates = session.query(ToolModel).filter(named, ToolModel.user_id == None).all()  # noqa: E711
            if user_id:
                candidates.extend(session.query(ToolModel).filter(named, ToolModel.user_id == user_id).all())

        if candidates:
            assert len(candidates) == 1, f"Expected 1 result, got {len(candidates)}"
            return candidates[0].to_record()
        return None
|
|
757
|
+
|
|
758
|
+
@enforce_types
def get_block(self, block_id: str) -> Optional[Block]:
    """Return the record of the block with id ``block_id``, or ``None`` if absent.

    Raises:
        AssertionError: if more than one row has that id.
    """
    with self.session_maker() as session:
        matches = session.query(BlockModel).filter(BlockModel.id == block_id).all()
        if matches:
            assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
            return matches[0].to_record()
        return None
|
|
766
|
+
|
|
767
|
+
@enforce_types
def get_blocks(
    self,
    user_id: Optional[str],
    label: Optional[str] = None,
    template: Optional[bool] = None,
    name: Optional[str] = None,
    id: Optional[str] = None,
) -> Optional[List[Block]]:
    """List available blocks, narrowed by whichever filters are truthy.

    Each argument adds an equality filter on the column of the same name only
    when its value is truthy; falsy values (``None``, ``""``, ``False``) leave
    that column unfiltered.

    Returns:
        The matching block records, or ``None`` (not an empty list) when
        nothing matches.
    """
    with self.session_maker() as session:
        # (requested value, column) pairs, applied in this order.
        optional_filters = (
            (user_id, BlockModel.user_id),
            (label, BlockModel.label),
            (name, BlockModel.name),
            (id, BlockModel.id),
            (template, BlockModel.template),
        )

        query = session.query(BlockModel)
        for wanted, column in optional_filters:
            if wanted:
                query = query.filter(column == wanted)

        rows = query.all()
        if not rows:
            return None

        return [row.to_record() for row in rows]
|
|
801
|
+
|
|
802
|
+
# agent source metadata
|
|
803
|
+
@enforce_types
def attach_source(self, user_id: str, agent_id: str, source_id: str):
    """Record that ``source_id`` is attached to ``agent_id`` for ``user_id`` and commit.

    The mapping row's id is the three ids joined with dashes.
    """
    with self.session_maker() as session:
        # TODO: remove this (is a hack)
        link = AgentSourceMappingModel(
            id=f"{user_id}-{agent_id}-{source_id}",
            user_id=user_id,
            agent_id=agent_id,
            source_id=source_id,
        )
        session.add(link)
        session.commit()
|
|
810
|
+
|
|
811
|
+
@enforce_types
def list_attached_sources(self, agent_id: str) -> List[Source]:
    """Return the sources attached to ``agent_id`` that still exist.

    Each mapping row for the agent is resolved through ``get_source``; a
    mapping whose source cannot be found is reported via ``printd`` and left
    out of the result.
    """
    with self.session_maker() as session:
        mappings = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()

        attached = []
        for mapping in mappings:
            found = self.get_source(source_id=mapping.source_id)
            if not found:
                # Dangling mapping: the source it points at is gone.
                printd(f"Warning: source {mapping.source_id} does not exist but exists in mapping database. This should never happen.")
                continue
            attached.append(found)
        return attached
|
|
825
|
+
|
|
826
|
+
@enforce_types
def list_attached_agents(self, source_id: str) -> List[str]:
    """Return the ids of agents attached to ``source_id`` that still exist.

    Each mapping row for the source is checked through ``get_agent``; a
    mapping whose agent cannot be found is reported via ``printd`` and left
    out of the result.
    """
    with self.session_maker() as session:
        mappings = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()

        live_agent_ids = []
        for mapping in mappings:
            found = self.get_agent(agent_id=mapping.agent_id)
            if not found:
                # Dangling mapping: the agent it points at is gone.
                printd(f"Warning: agent {mapping.agent_id} does not exist but exists in mapping database. This should never happen.")
                continue
            live_agent_ids.append(mapping.agent_id)
        return live_agent_ids
|
|
840
|
+
|
|
841
|
+
@enforce_types
def detach_source(self, agent_id: str, source_id: str):
    """Delete the mapping row(s) linking ``agent_id`` to ``source_id`` and commit."""
    with self.session_maker() as session:
        for_agent = AgentSourceMappingModel.agent_id == agent_id
        for_source = AgentSourceMappingModel.source_id == source_id
        session.query(AgentSourceMappingModel).filter(for_agent, for_source).delete()
        session.commit()
|
|
848
|
+
|
|
849
|
+
@enforce_types
def create_job(self, job: Job):
    """Insert a new job row built from the attributes of ``job`` and commit."""
    with self.session_maker() as session:
        new_row = JobModel(**vars(job))
        session.add(new_row)
        session.commit()
|
|
854
|
+
|
|
855
|
+
def delete_job(self, job_id: str):
    """Delete the job row(s) whose id equals ``job_id`` and commit."""
    with self.session_maker() as session:
        doomed = session.query(JobModel).filter(JobModel.id == job_id)
        doomed.delete()
        session.commit()
|
|
859
|
+
|
|
860
|
+
def get_job(self, job_id: str) -> Optional[Job]:
    """Return the record of the job with id ``job_id``, or ``None`` if absent.

    Raises:
        AssertionError: if more than one row has that id.
    """
    with self.session_maker() as session:
        matches = session.query(JobModel).filter(JobModel.id == job_id).all()
        if matches:
            assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"
            return matches[0].to_record()
        return None
|
|
867
|
+
|
|
868
|
+
def list_jobs(self, user_id: str) -> List[Job]:
    """Return the records of all jobs whose ``user_id`` column equals ``user_id``."""
    with self.session_maker() as session:
        owned = session.query(JobModel).filter(JobModel.user_id == user_id)
        return [row.to_record() for row in owned.all()]
|
|
872
|
+
|
|
873
|
+
def update_job(self, job: Job) -> Job:
    """Write all attributes of ``job`` onto the stored row with the same id and commit.

    Args:
        job: The job whose current attribute values should be persisted.

    Returns:
        The same ``job`` instance that was passed in.
    """
    with self.session_maker() as session:
        session.query(JobModel).filter(JobModel.id == job.id).update(vars(job))
        session.commit()
        # Return the instance; the previous code returned the Job class itself.
        return job
|
|
878
|
+
|
|
879
|
+
def update_job_status(self, job_id: str, status: JobStatus):
    """Set the status of job ``job_id`` and commit.

    When the new status is ``JobStatus.COMPLETED``, ``completed_at`` is also
    set to the value returned by ``get_utc_time()``.
    """
    with self.session_maker() as session:
        job_rows = session.query(JobModel).filter(JobModel.id == job_id)
        job_rows.update({"status": status})
        if status == JobStatus.COMPLETED:
            # Stamp the completion time in a second UPDATE on the same row.
            job_rows.update({"completed_at": get_utc_time()})
        session.commit()
|