letta-nightly 0.4.1.dev20241007104134__py3-none-any.whl → 0.4.1.dev20241009104130__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/agent.py +36 -10
- letta/client/client.py +8 -1
- letta/credentials.py +3 -3
- letta/errors.py +1 -1
- letta/functions/schema_generator.py +1 -1
- letta/llm_api/anthropic.py +3 -24
- letta/llm_api/azure_openai.py +53 -108
- letta/llm_api/azure_openai_constants.py +10 -0
- letta/llm_api/google_ai.py +39 -64
- letta/llm_api/helpers.py +208 -0
- letta/llm_api/llm_api_tools.py +43 -218
- letta/llm_api/openai.py +74 -50
- letta/main.py +1 -1
- letta/metadata.py +2 -0
- letta/providers.py +144 -31
- letta/schemas/agent.py +14 -0
- letta/schemas/llm_config.py +2 -2
- letta/schemas/openai/chat_completion_response.py +3 -0
- letta/schemas/tool.py +3 -3
- letta/server/rest_api/admin/tools.py +0 -1
- letta/server/rest_api/app.py +1 -17
- letta/server/rest_api/routers/openai/assistants/threads.py +10 -7
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +5 -3
- letta/server/rest_api/routers/v1/agents.py +23 -13
- letta/server/rest_api/routers/v1/blocks.py +5 -3
- letta/server/rest_api/routers/v1/jobs.py +5 -3
- letta/server/rest_api/routers/v1/sources.py +25 -13
- letta/server/rest_api/routers/v1/tools.py +12 -7
- letta/server/server.py +33 -37
- letta/settings.py +5 -113
- {letta_nightly-0.4.1.dev20241007104134.dist-info → letta_nightly-0.4.1.dev20241009104130.dist-info}/METADATA +1 -1
- {letta_nightly-0.4.1.dev20241007104134.dist-info → letta_nightly-0.4.1.dev20241009104130.dist-info}/RECORD +35 -33
- {letta_nightly-0.4.1.dev20241007104134.dist-info → letta_nightly-0.4.1.dev20241009104130.dist-info}/LICENSE +0 -0
- {letta_nightly-0.4.1.dev20241007104134.dist-info → letta_nightly-0.4.1.dev20241009104130.dist-info}/WHEEL +0 -0
- {letta_nightly-0.4.1.dev20241007104134.dist-info → letta_nightly-0.4.1.dev20241009104130.dist-info}/entry_points.txt +0 -0
letta/providers.py
CHANGED

```diff
@@ -1,8 +1,13 @@
 from typing import List, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 
 from letta.constants import LLM_MAX_TOKENS
+from letta.llm_api.azure_openai import (
+    get_azure_chat_completions_endpoint,
+    get_azure_embeddings_endpoint,
+)
+from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.llm_config import LLMConfig
 
@@ -122,34 +127,64 @@ class OllamaProvider(OpenAIProvider):
         response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
         response_json = response.json()
 
-
-        possible_keys = [
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        ]
-
+        ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
+        # possible_keys = [
+        #     # OPT
+        #     "max_position_embeddings",
+        #     # GPT-2
+        #     "n_positions",
+        #     # MPT
+        #     "max_seq_len",
+        #     # ChatGLM2
+        #     "seq_length",
+        #     # Command-R
+        #     "model_max_length",
+        #     # Others
+        #     "max_sequence_length",
+        #     "max_seq_length",
+        #     "seq_len",
+        # ]
         # max_position_embeddings
         # parse model cards: nous, dolphon, llama
         for key, value in response_json["model_info"].items():
-            if "
+            if "context_length" in key:
+                return value
+        return None
+
+    def get_model_embedding_dim(self, model_name: str):
+        import requests
+
+        response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
+        response_json = response.json()
+        for key, value in response_json["model_info"].items():
+            if "embedding_length" in key:
                 return value
         return None
 
     def list_embedding_models(self) -> List[EmbeddingConfig]:
-        #
-
+        # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
+        import requests
+
+        response = requests.get(f"{self.base_url}/api/tags")
+        if response.status_code != 200:
+            raise Exception(f"Failed to list Ollama models: {response.text}")
+        response_json = response.json()
+
+        configs = []
+        for model in response_json["models"]:
+            embedding_dim = self.get_model_embedding_dim(model["name"])
+            if not embedding_dim:
+                continue
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model["name"],
+                    embedding_endpoint_type="ollama",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=embedding_dim,
+                    embedding_chunk_size=300,
+                )
+            )
+        return configs
 
 
 class GroqProvider(OpenAIProvider):
@@ -182,20 +217,21 @@ class GroqProvider(OpenAIProvider):
 class GoogleAIProvider(Provider):
     # gemini
     api_key: str = Field(..., description="API key for the Google AI API.")
-    service_endpoint: str = "generativelanguage"
     base_url: str = "https://generativelanguage.googleapis.com"
 
     def list_llm_models(self):
         from letta.llm_api.google_ai import google_ai_get_model_list
 
-
-
+        model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
+        # filter by 'generateContent' models
+        model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
         model_options = [str(m["name"]) for m in model_options]
+
+        # filter by model names
         model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
+
         # TODO remove manual filtering for gemini-pro
         model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]
-        # TODO: add context windows
-        # model_options = ["gemini-pro"]
 
         configs = []
         for model in model_options:
@@ -210,17 +246,94 @@ class GoogleAIProvider(Provider):
         return configs
 
     def list_embedding_models(self):
-
+        from letta.llm_api.google_ai import google_ai_get_model_list
+
+        # TODO: use base_url instead
+        model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
+        # filter by 'generateContent' models
+        model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
+        model_options = [str(m["name"]) for m in model_options]
+        model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
+
+        configs = []
+        for model in model_options:
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model,
+                    embedding_endpoint_type="google_ai",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=768,
+                    embedding_chunk_size=300,  # NOTE: max is 2048
+                )
+            )
+        return configs
 
     def get_model_context_window(self, model_name: str):
         from letta.llm_api.google_ai import google_ai_get_model_context_window
 
-
-        return google_ai_get_model_context_window(self.service_endpoint, self.api_key, model_name)
+        return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
 
 
 class AzureProvider(Provider):
-
+    name: str = "azure"
+    latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
+    base_url: str = Field(
+        ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
+    )
+    api_key: str = Field(..., description="API key for the Azure API.")
+    api_version: str = Field(latest_api_version, description="API version for the Azure API")
+
+    @model_validator(mode="before")
+    def set_default_api_version(cls, values):
+        """
+        This ensures that api_version is always set to the default if None is passed in.
+        """
+        if values.get("api_version") is None:
+            values["api_version"] = cls.model_fields["latest_api_version"].default
+        return values
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from letta.llm_api.azure_openai import (
+            azure_openai_get_chat_completion_model_list,
+        )
+
+        model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
+        configs = []
+        for model_option in model_options:
+            model_name = model_option["id"]
+            context_window_size = self.get_model_context_window(model_name)
+            model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
+            configs.append(
+                LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
+            )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list
+
+        model_options = azure_openai_get_embeddings_model_list(
+            self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True
+        )
+        configs = []
+        for model_option in model_options:
+            model_name = model_option["id"]
+            model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version)
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model_name,
+                    embedding_endpoint_type="azure",
+                    embedding_endpoint=model_endpoint,
+                    embedding_dim=768,
+                    embedding_chunk_size=300,  # NOTE: max is 2048
+                )
+            )
+        return configs
+
+    def get_model_context_window(self, model_name: str):
+        """
+        This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
+        """
+        return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
 
 
 class VLLMProvider(OpenAIProvider):
```
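The new `AzureProvider` above replaces the old stub. As a rough illustration of how it is meant to be driven (a sketch only: the endpoint and key below are placeholders, and construction assumes `Provider` is a plain Pydantic model, as the imports suggest):

```python
# Sketch: exercising the AzureProvider added in the hunk above.
from letta.providers import AzureProvider

provider = AzureProvider(
    base_url="https://your-org.openai.azure.com",  # placeholder, org-specific endpoint
    api_key="<AZURE_OPENAI_API_KEY>",              # placeholder key
    api_version=None,  # the model_validator swaps None for latest_api_version ("2024-09-01-preview")
)

# list_llm_models() builds one LLMConfig per chat deployment, pointing each at a
# per-model chat-completions endpoint and looking up the context window in
# AZURE_MODEL_TO_CONTEXT_LENGTH (falling back to 4096).
for config in provider.list_llm_models():
    print(config.model, config.model_endpoint, config.context_window)
```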
letta/schemas/agent.py
CHANGED

```diff
@@ -1,5 +1,6 @@
 import uuid
 from datetime import datetime
+from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field, field_validator
@@ -21,6 +22,15 @@ class BaseAgent(LettaBase, validate_assignment=True):
     user_id: Optional[str] = Field(None, description="The user id of the agent.")
 
 
+class AgentType(str, Enum):
+    """
+    Enum to represent the type of agent.
+    """
+
+    memgpt_agent = "memgpt_agent"
+    split_thread_agent = "split_thread_agent"
+
+
 class AgentState(BaseAgent):
     """
     Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
@@ -52,6 +62,9 @@ class AgentState(BaseAgent):
     # system prompt
     system: str = Field(..., description="The system prompt used by the agent.")
 
+    # agent configuration
+    agent_type: AgentType = Field(..., description="The type of agent.")
+
     # llm information
     llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
     embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
@@ -64,6 +77,7 @@ class CreateAgent(BaseAgent):
     memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
     tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
     system: Optional[str] = Field(None, description="The system prompt used by the agent.")
+    agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
    llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
     embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
 
```
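The `AgentType` enum and the new `agent_type` fields shown above can be set explicitly when creating an agent. A minimal sketch (assuming the remaining `BaseAgent` fields keep their defaults; other fields such as a name may still be required by the full schema):

```python
# Sketch: selecting an agent type via the new CreateAgent field.
from letta.schemas.agent import AgentType, CreateAgent

request = CreateAgent(
    system="You are a helpful agent.",   # optional, as before
    agent_type=AgentType.memgpt_agent,   # or AgentType.split_thread_agent
)
print(request.agent_type.value)  # "memgpt_agent"
```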
letta/schemas/llm_config.py
CHANGED

```diff
@@ -11,7 +11,7 @@ class LLMConfig(BaseModel):
         model (str): The name of the LLM model.
         model_endpoint_type (str): The endpoint type for the model.
         model_endpoint (str): The endpoint for the model.
-        model_wrapper (str): The wrapper for the model.
+        model_wrapper (str): The wrapper for the model. This is used to wrap additional text around the input/output of the model. This is useful for text-to-text completions, such as the Completions API in OpenAI.
         context_window (int): The context window size for the model.
     """
 
@@ -34,7 +34,7 @@ class LLMConfig(BaseModel):
         "vllm",
         "hugging-face",
     ] = Field(..., description="The endpoint type for the model.")
-    model_endpoint: str = Field(
+    model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
     model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
     context_window: int = Field(..., description="The context window size for the model.")
 
```

letta/schemas/openai/chat_completion_response.py
CHANGED

```diff
@@ -74,6 +74,9 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"] = "chat.completion"
     usage: UsageStatistics
 
+    def __str__(self):
+        return self.model_dump_json(indent=4)
+
 
 class FunctionCallDelta(BaseModel):
     # arguments: Optional[str] = None
```
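The `model_endpoint` relaxation means an `LLMConfig` can now be built without an endpoint URL. A small sketch (the model name and endpoint type here are illustrative; `"openai"` is assumed to be one of the allowed `model_endpoint_type` literals):

```python
# Sketch: LLMConfig.model_endpoint is now Optional and defaults to None.
from letta.schemas.llm_config import LLMConfig

config = LLMConfig(
    model="gpt-4",
    model_endpoint_type="openai",  # assumed to be an allowed literal value
    context_window=8192,
    # model_endpoint omitted: previously a required field, now None by default
)
print(config.model_endpoint)  # None
```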
letta/schemas/tool.py
CHANGED

```diff
@@ -93,7 +93,7 @@ class Tool(BaseTool):
         # append heartbeat (necessary for triggering another reasoning step after this tool call)
         json_schema["parameters"]["properties"]["request_heartbeat"] = {
             "type": "boolean",
-            "description": "Request an immediate heartbeat after function execution. Set to
+            "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
         }
         json_schema["parameters"]["required"].append("request_heartbeat")
 
@@ -128,7 +128,7 @@ class Tool(BaseTool):
         # append heartbeat (necessary for triggering another reasoning step after this tool call)
         json_schema["parameters"]["properties"]["request_heartbeat"] = {
             "type": "boolean",
-            "description": "Request an immediate heartbeat after function execution. Set to
+            "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
         }
         json_schema["parameters"]["required"].append("request_heartbeat")
 
@@ -161,7 +161,7 @@ class Tool(BaseTool):
         # append heartbeat (necessary for triggering another reasoning step after this tool call)
         json_schema["parameters"]["properties"]["request_heartbeat"] = {
             "type": "boolean",
-            "description": "Request an immediate heartbeat after function execution. Set to
+            "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
         }
         json_schema["parameters"]["required"].append("request_heartbeat")
 
```

letta/server/rest_api/admin/tools.py
CHANGED

```diff
@@ -26,7 +26,6 @@ class CreateToolResponse(BaseModel):
 
 
 def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
-    # get_current_user_with_server = partial(partial(get_current_user, server), password)
 
     @router.delete("/tools/{tool_name}", tags=["tools"])
     async def delete_tool(
```
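Because the removed description string is truncated in the view above, here is the complete fragment that each of the three hunks now appends to a generated tool schema, reassembled as a sketch against a toy schema dict:

```python
# Sketch: the heartbeat parameter appended to a tool's JSON schema (per the hunks above),
# applied to a minimal placeholder schema so the resulting structure is visible.
json_schema = {"parameters": {"type": "object", "properties": {}, "required": []}}

json_schema["parameters"]["properties"]["request_heartbeat"] = {
    "type": "boolean",
    "description": "Request an immediate heartbeat after function execution. Set to `True` "
    "if you want to send a follow-up message or run a follow-up function.",
}
json_schema["parameters"]["required"].append("request_heartbeat")

print(json_schema["parameters"]["required"])  # ['request_heartbeat']
```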
letta/server/rest_api/app.py
CHANGED

```diff
@@ -5,8 +5,7 @@ from pathlib import Path
 from typing import Optional
 
 import uvicorn
-from fastapi import FastAPI
-from fastapi.responses import JSONResponse
+from fastapi import FastAPI
 from starlette.middleware.cors import CORSMiddleware
 
 from letta.server.constants import REST_DEFAULT_PORT
@@ -84,21 +83,6 @@ def create_application() -> "FastAPI":
         allow_headers=["*"],
     )
 
-    @app.middleware("http")
-    async def set_current_user_middleware(request: Request, call_next):
-        user_id = request.headers.get("user_id")
-        if user_id:
-            try:
-                server.set_current_user(user_id)
-            except ValueError as e:
-                # Return an HTTP 401 Unauthorized response
-                # raise HTTPException(status_code=401, detail=str(e))
-                return JSONResponse(status_code=401, content={"detail": str(e)})
-        else:
-            server.set_current_user(None)
-        response = await call_next(request)
-        return response
-
     for route in v1_routes:
         app.include_router(route, prefix=API_PREFIX)
         # this gives undocumented routes for "latest" and bare api calls.
```

letta/server/rest_api/routers/openai/assistants/threads.py
CHANGED

```diff
@@ -1,7 +1,7 @@
 import uuid
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
-from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query
+from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query
 
 from letta.constants import DEFAULT_PRESET
 from letta.schemas.agent import CreateAgent
@@ -43,11 +43,12 @@ router = APIRouter(prefix="/v1/threads", tags=["threads"])
 def create_thread(
     request: CreateThreadRequest = Body(...),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     # TODO: use requests.description and requests.metadata fields
     # TODO: handle requests.file_ids and requests.tools
     # TODO: eventually allow request to override embedding/llm model
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     print("Create thread/agent", request)
     # create a letta agent
@@ -67,8 +68,9 @@ def create_thread(
 def retrieve_thread(
     thread_id: str = Path(..., description="The unique identifier of the thread."),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
     agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
     assert agent is not None
     return OpenAIThread(
@@ -100,8 +102,9 @@ def create_message(
     thread_id: str = Path(..., description="The unique identifier of the thread."),
     request: CreateMessageRequest = Body(...),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
     agent_id = thread_id
     # create message object
     message = Message(
@@ -143,8 +146,9 @@ def list_messages(
     after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
     before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-    actor = server.
+    actor = server.get_user_or_default(user_id)
     after_uuid = after if before else None
     before_uuid = before if before else None
     agent_id = thread_id
@@ -239,7 +243,6 @@ def create_run(
     request: CreateRunRequest = Body(...),
     server: SyncServer = Depends(get_letta_server),
 ):
-    server.get_current_user()
 
     # TODO: add request.instructions as a message?
     agent_id = thread_id
```

letta/server/rest_api/routers/openai/chat_completions/chat_completions.py
CHANGED

```diff
@@ -1,7 +1,7 @@
 import json
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
-from fastapi import APIRouter, Body, Depends, HTTPException
+from fastapi import APIRouter, Body, Depends, Header, HTTPException
 
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message import FunctionCall, LettaMessage
@@ -30,12 +30,14 @@ router = APIRouter(prefix="/v1/chat/completions", tags=["chat_completions"])
 async def create_chat_completion(
     completion_request: ChatCompletionRequest = Body(...),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """Send a message to a Letta agent via a /chat/completions completion_request
     The bearer token will be used to identify the user.
     The 'user' field in the completion_request should be set to the agent ID.
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
+
     agent_id = completion_request.user
     if agent_id is None:
         raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
```

letta/server/rest_api/routers/v1/agents.py
CHANGED

```diff
@@ -2,7 +2,7 @@ import asyncio
 from datetime import datetime
 from typing import Dict, List, Optional, Union
 
-from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
+from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
 from fastapi.responses import JSONResponse, StreamingResponse
 from starlette.responses import StreamingResponse
 
@@ -40,12 +40,13 @@ router = APIRouter(prefix="/agents", tags=["agents"])
 @router.get("/", response_model=List[AgentState], operation_id="list_agents")
 def list_agents(
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     List all agents associated with a given user.
     This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     return server.list_agents(user_id=actor.id)
 
@@ -54,11 +55,12 @@ def list_agents(
 def create_agent(
     agent: CreateAgent = Body(...),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Create a new agent with the specified configuration.
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
     agent.user_id = actor.id
     # TODO: sarah make general
     # TODO: eventually remove this
@@ -74,9 +76,10 @@ def update_agent(
     agent_id: str,
     update_agent: UpdateAgentState = Body(...),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """Update an exsiting agent"""
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     update_agent.id = agent_id
     return server.update_agent(update_agent, user_id=actor.id)
@@ -86,11 +89,12 @@ def update_agent(
 def get_agent_state(
     agent_id: str,
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Get the state of the agent.
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
         # agent does not exist
@@ -103,11 +107,12 @@ def get_agent_state(
 def delete_agent(
     agent_id: str,
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Delete an agent.
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     return server.delete_agent(user_id=actor.id, agent_id=agent_id)
 
@@ -120,7 +125,6 @@ def get_agent_sources(
     """
     Get the sources associated with an agent.
     """
-    server.get_current_user()
 
     return server.list_attached_sources(agent_id)
 
@@ -155,12 +159,13 @@ def update_agent_memory(
     agent_id: str,
     request: Dict = Body(...),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Update the core memory of a specific agent.
     This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request)
     return memory
@@ -197,11 +202,12 @@ def get_agent_archival_memory(
     after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
     before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
     limit: Optional[int] = Query(None, description="How many results to include in the response."),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Retrieve the memories in an agent's archival memory store (paginated query).
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     # TODO need to add support for non-postgres here
     # chroma will throw:
@@ -221,11 +227,12 @@ def insert_agent_archival_memory(
     agent_id: str,
     request: CreateArchivalMemory = Body(...),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Insert a memory into an agent's archival memory store.
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text)
 
@@ -238,11 +245,12 @@ def delete_agent_archival_memory(
     memory_id: str,
     # memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Delete a memory from an agent's archival memory store.
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id)
     return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
@@ -268,11 +276,12 @@ def get_agent_messages(
         DEFAULT_MESSAGE_TOOL_KWARG,
         description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
     ),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Retrieve message history for an agent.
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     return server.get_agent_recall_cursor(
         user_id=actor.id,
@@ -306,13 +315,14 @@ async def send_message(
     agent_id: str,
     server: SyncServer = Depends(get_letta_server),
     request: LettaRequest = Body(...),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Process a user message and return the agent's response.
     This endpoint accepts a message from a user and processes it through the agent.
     It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True.
     """
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     # TODO(charles): support sending multiple messages
     assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
```

letta/server/rest_api/routers/v1/blocks.py
CHANGED

```diff
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from fastapi import APIRouter, Body, Depends, HTTPException, Query
+from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
 
 from letta.schemas.block import Block, CreateBlock, UpdateBlock
 from letta.server.rest_api.utils import get_letta_server
@@ -19,8 +19,9 @@ def list_blocks(
     templates_only: bool = Query(True, description="Whether to include only templates"),
     name: Optional[str] = Query(None, description="Name of the block"),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     blocks = server.get_blocks(user_id=actor.id, label=label, template=templates_only, name=name)
     if blocks is None:
@@ -32,8 +33,9 @@ def list_blocks(
 def create_block(
     create_block: CreateBlock = Body(...),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-    actor = server.
+    actor = server.get_user_or_default(user_id=user_id)
 
     create_block.user_id = actor.id
     return server.create_block(user_id=actor.id, request=create_block)
```
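With the `set_current_user` middleware removed from `app.py`, the acting user is now resolved per request: each route reads a `user_id` header (`Header(None, alias="user_id")`) and passes it to `server.get_user_or_default(...)`. A client-side sketch of the new convention (base URL, path prefix, and user id are placeholders; the real prefix depends on `API_PREFIX`):

```python
# Sketch: calling the list_agents route with the per-request user_id header.
import requests

BASE_URL = "http://localhost:8283"      # placeholder; use your server's host/port
headers = {"user_id": "user-00000000"}  # placeholder id; omitting the header presumably
                                        # lets get_user_or_default fall back to the default user

resp = requests.get(f"{BASE_URL}/v1/agents/", headers=headers)  # "/v1" prefix assumed
resp.raise_for_status()
print(resp.json())  # AgentState objects belonging to that user
```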