letta-nightly 0.4.1.dev20241007104134__py3-none-any.whl → 0.4.1.dev20241009104130__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic.

Files changed (35)
  1. letta/agent.py +36 -10
  2. letta/client/client.py +8 -1
  3. letta/credentials.py +3 -3
  4. letta/errors.py +1 -1
  5. letta/functions/schema_generator.py +1 -1
  6. letta/llm_api/anthropic.py +3 -24
  7. letta/llm_api/azure_openai.py +53 -108
  8. letta/llm_api/azure_openai_constants.py +10 -0
  9. letta/llm_api/google_ai.py +39 -64
  10. letta/llm_api/helpers.py +208 -0
  11. letta/llm_api/llm_api_tools.py +43 -218
  12. letta/llm_api/openai.py +74 -50
  13. letta/main.py +1 -1
  14. letta/metadata.py +2 -0
  15. letta/providers.py +144 -31
  16. letta/schemas/agent.py +14 -0
  17. letta/schemas/llm_config.py +2 -2
  18. letta/schemas/openai/chat_completion_response.py +3 -0
  19. letta/schemas/tool.py +3 -3
  20. letta/server/rest_api/admin/tools.py +0 -1
  21. letta/server/rest_api/app.py +1 -17
  22. letta/server/rest_api/routers/openai/assistants/threads.py +10 -7
  23. letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +5 -3
  24. letta/server/rest_api/routers/v1/agents.py +23 -13
  25. letta/server/rest_api/routers/v1/blocks.py +5 -3
  26. letta/server/rest_api/routers/v1/jobs.py +5 -3
  27. letta/server/rest_api/routers/v1/sources.py +25 -13
  28. letta/server/rest_api/routers/v1/tools.py +12 -7
  29. letta/server/server.py +33 -37
  30. letta/settings.py +5 -113
  31. {letta_nightly-0.4.1.dev20241007104134.dist-info → letta_nightly-0.4.1.dev20241009104130.dist-info}/METADATA +1 -1
  32. {letta_nightly-0.4.1.dev20241007104134.dist-info → letta_nightly-0.4.1.dev20241009104130.dist-info}/RECORD +35 -33
  33. {letta_nightly-0.4.1.dev20241007104134.dist-info → letta_nightly-0.4.1.dev20241009104130.dist-info}/LICENSE +0 -0
  34. {letta_nightly-0.4.1.dev20241007104134.dist-info → letta_nightly-0.4.1.dev20241009104130.dist-info}/WHEEL +0 -0
  35. {letta_nightly-0.4.1.dev20241007104134.dist-info → letta_nightly-0.4.1.dev20241009104130.dist-info}/entry_points.txt +0 -0
letta/providers.py CHANGED
@@ -1,8 +1,13 @@
 from typing import List, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 
 from letta.constants import LLM_MAX_TOKENS
+from letta.llm_api.azure_openai import (
+    get_azure_chat_completions_endpoint,
+    get_azure_embeddings_endpoint,
+)
+from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.llm_config import LLMConfig
 
@@ -122,34 +127,64 @@ class OllamaProvider(OpenAIProvider):
         response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
         response_json = response.json()
 
-        # thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
-        possible_keys = [
-            # OPT
-            "max_position_embeddings",
-            # GPT-2
-            "n_positions",
-            # MPT
-            "max_seq_len",
-            # ChatGLM2
-            "seq_length",
-            # Command-R
-            "model_max_length",
-            # Others
-            "max_sequence_length",
-            "max_seq_length",
-            "seq_len",
-        ]
-
+        ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
+        # possible_keys = [
+        #     # OPT
+        #     "max_position_embeddings",
+        #     # GPT-2
+        #     "n_positions",
+        #     # MPT
+        #     "max_seq_len",
+        #     # ChatGLM2
+        #     "seq_length",
+        #     # Command-R
+        #     "model_max_length",
+        #     # Others
+        #     "max_sequence_length",
+        #     "max_seq_length",
+        #     "seq_len",
+        # ]
         # max_position_embeddings
         # parse model cards: nous, dolphon, llama
         for key, value in response_json["model_info"].items():
-            if "context_window" in key:
+            if "context_length" in key:
+                return value
+        return None
+
+    def get_model_embedding_dim(self, model_name: str):
+        import requests
+
+        response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
+        response_json = response.json()
+        for key, value in response_json["model_info"].items():
+            if "embedding_length" in key:
                 return value
         return None
 
     def list_embedding_models(self) -> List[EmbeddingConfig]:
-        # TODO: filter embedding models
-        return []
+        # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
+        import requests
+
+        response = requests.get(f"{self.base_url}/api/tags")
+        if response.status_code != 200:
+            raise Exception(f"Failed to list Ollama models: {response.text}")
+        response_json = response.json()
+
+        configs = []
+        for model in response_json["models"]:
+            embedding_dim = self.get_model_embedding_dim(model["name"])
+            if not embedding_dim:
+                continue
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model["name"],
+                    embedding_endpoint_type="ollama",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=embedding_dim,
+                    embedding_chunk_size=300,
+                )
+            )
+        return configs
 
 
 class GroqProvider(OpenAIProvider):
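
Note on the new Ollama helpers above: get_model_embedding_dim assumes Ollama's /api/show response exposes a model_info map with an architecture-prefixed embedding_length key. A standalone sketch of the same lookup (the server URL and model name are placeholders, not part of this diff):

import requests

# Placeholders: a local Ollama server and an illustrative embedding model.
base_url = "http://localhost:11434"
model_name = "mxbai-embed-large"

response = requests.post(f"{base_url}/api/show", json={"name": model_name, "verbose": True})
response.raise_for_status()
model_info = response.json().get("model_info", {})

# Same scan as the diff: pick any key containing "embedding_length",
# e.g. "bert.embedding_length" or "llama.embedding_length".
embedding_dim = next((value for key, value in model_info.items() if "embedding_length" in key), None)
print(embedding_dim)
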
@@ -182,20 +217,21 @@ class GroqProvider(OpenAIProvider):
 class GoogleAIProvider(Provider):
     # gemini
     api_key: str = Field(..., description="API key for the Google AI API.")
-    service_endpoint: str = "generativelanguage"
     base_url: str = "https://generativelanguage.googleapis.com"
 
     def list_llm_models(self):
         from letta.llm_api.google_ai import google_ai_get_model_list
 
-        # TODO: use base_url instead
-        model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key)
+        model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
+        # filter by 'generateContent' models
+        model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
         model_options = [str(m["name"]) for m in model_options]
+
+        # filter by model names
         model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
+
         # TODO remove manual filtering for gemini-pro
         model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]
-        # TODO: add context windows
-        # model_options = ["gemini-pro"]
 
         configs = []
         for model in model_options:
@@ -210,17 +246,94 @@ class GoogleAIProvider(Provider):
         return configs
 
     def list_embedding_models(self):
-        return []
+        from letta.llm_api.google_ai import google_ai_get_model_list
+
+        # TODO: use base_url instead
+        model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
+        # filter by 'generateContent' models
+        model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
+        model_options = [str(m["name"]) for m in model_options]
+        model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
+
+        configs = []
+        for model in model_options:
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model,
+                    embedding_endpoint_type="google_ai",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=768,
+                    embedding_chunk_size=300,  # NOTE: max is 2048
+                )
+            )
+        return configs
 
     def get_model_context_window(self, model_name: str):
         from letta.llm_api.google_ai import google_ai_get_model_context_window
 
-        # TODO: use base_url instead
-        return google_ai_get_model_context_window(self.service_endpoint, self.api_key, model_name)
+        return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
 
 
 class AzureProvider(Provider):
-    pass
+    name: str = "azure"
+    latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
+    base_url: str = Field(
+        ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
+    )
+    api_key: str = Field(..., description="API key for the Azure API.")
+    api_version: str = Field(latest_api_version, description="API version for the Azure API")
+
+    @model_validator(mode="before")
+    def set_default_api_version(cls, values):
+        """
+        This ensures that api_version is always set to the default if None is passed in.
+        """
+        if values.get("api_version") is None:
+            values["api_version"] = cls.model_fields["latest_api_version"].default
+        return values
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from letta.llm_api.azure_openai import (
+            azure_openai_get_chat_completion_model_list,
+        )
+
+        model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
+        configs = []
+        for model_option in model_options:
+            model_name = model_option["id"]
+            context_window_size = self.get_model_context_window(model_name)
+            model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
+            configs.append(
+                LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
+            )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list
+
+        model_options = azure_openai_get_embeddings_model_list(
+            self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True
+        )
+        configs = []
+        for model_option in model_options:
+            model_name = model_option["id"]
+            model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version)
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model_name,
+                    embedding_endpoint_type="azure",
+                    embedding_endpoint=model_endpoint,
+                    embedding_dim=768,
+                    embedding_chunk_size=300,  # NOTE: max is 2048
+                )
+            )
+        return configs
+
+    def get_model_context_window(self, model_name: str):
+        """
+        This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
+        """
+        return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
 
 
 class VLLMProvider(OpenAIProvider):
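
For orientation, a minimal sketch of how the new AzureProvider might be exercised once this change lands (the endpoint, key, and environment-variable names are placeholders, not taken from this diff):

import os

from letta.providers import AzureProvider

# Placeholders: point these at your own Azure OpenAI resource.
provider = AzureProvider(
    base_url=os.environ.get("AZURE_BASE_URL", "https://letta.openai.azure.com"),
    api_key=os.environ.get("AZURE_API_KEY", ""),
    api_version=None,  # the model_validator above falls back to latest_api_version
)

# Discovers deployed chat and embedding models via the Azure management endpoints.
for llm_config in provider.list_llm_models():
    print(llm_config.model, llm_config.context_window)
for embedding_config in provider.list_embedding_models():
    print(embedding_config.embedding_model, embedding_config.embedding_dim)
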
letta/schemas/agent.py CHANGED
@@ -1,5 +1,6 @@
 import uuid
 from datetime import datetime
+from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field, field_validator
@@ -21,6 +22,15 @@ class BaseAgent(LettaBase, validate_assignment=True):
     user_id: Optional[str] = Field(None, description="The user id of the agent.")
 
 
+class AgentType(str, Enum):
+    """
+    Enum to represent the type of agent.
+    """
+
+    memgpt_agent = "memgpt_agent"
+    split_thread_agent = "split_thread_agent"
+
+
 class AgentState(BaseAgent):
     """
     Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
@@ -52,6 +62,9 @@ class AgentState(BaseAgent):
     # system prompt
     system: str = Field(..., description="The system prompt used by the agent.")
 
+    # agent configuration
+    agent_type: AgentType = Field(..., description="The type of agent.")
+
     # llm information
     llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
     embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
@@ -64,6 +77,7 @@ class CreateAgent(BaseAgent):
     memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
     tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
     system: Optional[str] = Field(None, description="The system prompt used by the agent.")
+    agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
     llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
     embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
 
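As a quick illustration of the new field above, a minimal sketch of building a CreateAgent request with agent_type set (this assumes the remaining CreateAgent fields stay optional, as they appear in the hunk; nothing here is taken from the package's client code):

from letta.schemas.agent import AgentType, CreateAgent

# Sketch only: the other CreateAgent fields shown above are Optional,
# so just the new agent_type is supplied here.
request = CreateAgent(agent_type=AgentType.memgpt_agent)
print(request.model_dump_json(indent=2))
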
letta/schemas/llm_config.py CHANGED
@@ -11,7 +11,7 @@ class LLMConfig(BaseModel):
         model (str): The name of the LLM model.
         model_endpoint_type (str): The endpoint type for the model.
         model_endpoint (str): The endpoint for the model.
-        model_wrapper (str): The wrapper for the model.
+        model_wrapper (str): The wrapper for the model. This is used to wrap additional text around the input/output of the model. This is useful for text-to-text completions, such as the Completions API in OpenAI.
         context_window (int): The context window size for the model.
     """
 
@@ -34,7 +34,7 @@ class LLMConfig(BaseModel):
         "vllm",
         "hugging-face",
     ] = Field(..., description="The endpoint type for the model.")
-    model_endpoint: str = Field(..., description="The endpoint for the model.")
+    model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
     model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
     context_window: int = Field(..., description="The context window size for the model.")
 
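Since model_endpoint is now Optional, an LLMConfig can be built without specifying an endpoint; a minimal sketch (the model name is illustrative, and "openai" is assumed to be among the allowed endpoint types):

from letta.schemas.llm_config import LLMConfig

# Sketch: with model_endpoint now Optional, it can be omitted entirely.
config = LLMConfig(
    model="gpt-4o-mini",           # illustrative model name
    model_endpoint_type="openai",  # assumed to be in the Literal above
    context_window=128000,
)
print(config.model_endpoint)  # -> None
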
letta/schemas/openai/chat_completion_response.py CHANGED
@@ -74,6 +74,9 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"] = "chat.completion"
     usage: UsageStatistics
 
+    def __str__(self):
+        return self.model_dump_json(indent=4)
+
 
 class FunctionCallDelta(BaseModel):
     # arguments: Optional[str] = None
letta/schemas/tool.py CHANGED
@@ -93,7 +93,7 @@ class Tool(BaseTool):
         # append heartbeat (necessary for triggering another reasoning step after this tool call)
         json_schema["parameters"]["properties"]["request_heartbeat"] = {
             "type": "boolean",
-            "description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
+            "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
         }
         json_schema["parameters"]["required"].append("request_heartbeat")
 
@@ -128,7 +128,7 @@ class Tool(BaseTool):
         # append heartbeat (necessary for triggering another reasoning step after this tool call)
         json_schema["parameters"]["properties"]["request_heartbeat"] = {
             "type": "boolean",
-            "description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
+            "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
         }
         json_schema["parameters"]["required"].append("request_heartbeat")
 
@@ -161,7 +161,7 @@ class Tool(BaseTool):
         # append heartbeat (necessary for triggering another reasoning step after this tool call)
         json_schema["parameters"]["properties"]["request_heartbeat"] = {
             "type": "boolean",
-            "description": "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function.",
+            "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
         }
         json_schema["parameters"]["required"].append("request_heartbeat")
 
letta/server/rest_api/admin/tools.py CHANGED
@@ -26,7 +26,6 @@ class CreateToolResponse(BaseModel):
 
 
 def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
-    # get_current_user_with_server = partial(partial(get_current_user, server), password)
 
     @router.delete("/tools/{tool_name}", tags=["tools"])
     async def delete_tool(
letta/server/rest_api/app.py CHANGED
@@ -5,8 +5,7 @@ from pathlib import Path
 from typing import Optional
 
 import uvicorn
-from fastapi import FastAPI, Request
-from fastapi.responses import JSONResponse
+from fastapi import FastAPI
 from starlette.middleware.cors import CORSMiddleware
 
 from letta.server.constants import REST_DEFAULT_PORT
@@ -84,21 +83,6 @@ def create_application() -> "FastAPI":
         allow_headers=["*"],
     )
 
-    @app.middleware("http")
-    async def set_current_user_middleware(request: Request, call_next):
-        user_id = request.headers.get("user_id")
-        if user_id:
-            try:
-                server.set_current_user(user_id)
-            except ValueError as e:
-                # Return an HTTP 401 Unauthorized response
-                # raise HTTPException(status_code=401, detail=str(e))
-                return JSONResponse(status_code=401, content={"detail": str(e)})
-        else:
-            server.set_current_user(None)
-        response = await call_next(request)
-        return response
-
     for route in v1_routes:
         app.include_router(route, prefix=API_PREFIX)
     # this gives undocumented routes for "latest" and bare api calls.
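
With the middleware gone, per-request user scoping moves to an optional user_id header that the routers below read explicitly; a client-side sketch of what that looks like (base URL, port, and user id are placeholders, not part of this diff):

import requests

# Placeholders: the default local REST server and an example user id.
base_url = "http://localhost:8283"
headers = {"user_id": "user-00000000-0000-0000-0000-000000000000"}

# Lists agents scoped to that user; omitting the header falls back to the
# server's default user via get_user_or_default (see the router changes below).
response = requests.get(f"{base_url}/v1/agents/", headers=headers)
print(response.status_code, response.json())
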
letta/server/rest_api/routers/openai/assistants/threads.py CHANGED
@@ -1,7 +1,7 @@
 import uuid
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
-from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query
+from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query
 
 from letta.constants import DEFAULT_PRESET
 from letta.schemas.agent import CreateAgent
@@ -43,11 +43,12 @@ router = APIRouter(prefix="/v1/threads", tags=["threads"])
 def create_thread(
     request: CreateThreadRequest = Body(...),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     # TODO: use requests.description and requests.metadata fields
     # TODO: handle requests.file_ids and requests.tools
     # TODO: eventually allow request to override embedding/llm model
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     print("Create thread/agent", request)
     # create a letta agent
@@ -67,8 +68,9 @@ def create_thread(
 def retrieve_thread(
     thread_id: str = Path(..., description="The unique identifier of the thread."),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
     agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
     assert agent is not None
     return OpenAIThread(
@@ -100,8 +102,9 @@ def create_message(
     thread_id: str = Path(..., description="The unique identifier of the thread."),
     request: CreateMessageRequest = Body(...),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
     agent_id = thread_id
     # create message object
     message = Message(
@@ -143,8 +146,9 @@ def list_messages(
     after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
     before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id)
     after_uuid = after if before else None
     before_uuid = before if before else None
     agent_id = thread_id
@@ -239,7 +243,6 @@ def create_run(
     request: CreateRunRequest = Body(...),
     server: SyncServer = Depends(get_letta_server),
 ):
-    server.get_current_user()
 
     # TODO: add request.instructions as a message?
     agent_id = thread_id
letta/server/rest_api/routers/openai/chat_completions/chat_completions.py CHANGED
@@ -1,7 +1,7 @@
 import json
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
-from fastapi import APIRouter, Body, Depends, HTTPException
+from fastapi import APIRouter, Body, Depends, Header, HTTPException
 
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message import FunctionCall, LettaMessage
@@ -30,12 +30,14 @@ router = APIRouter(prefix="/v1/chat/completions", tags=["chat_completions"])
 async def create_chat_completion(
     completion_request: ChatCompletionRequest = Body(...),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """Send a message to a Letta agent via a /chat/completions completion_request
     The bearer token will be used to identify the user.
     The 'user' field in the completion_request should be set to the agent ID.
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
+
     agent_id = completion_request.user
     if agent_id is None:
         raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
letta/server/rest_api/routers/v1/agents.py CHANGED
@@ -2,7 +2,7 @@ import asyncio
 from datetime import datetime
 from typing import Dict, List, Optional, Union
 
-from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
+from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
 from fastapi.responses import JSONResponse, StreamingResponse
 from starlette.responses import StreamingResponse
 
@@ -40,12 +40,13 @@ router = APIRouter(prefix="/agents", tags=["agents"])
 @router.get("/", response_model=List[AgentState], operation_id="list_agents")
 def list_agents(
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     List all agents associated with a given user.
     This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     return server.list_agents(user_id=actor.id)
 
@@ -54,11 +55,12 @@ def list_agents(
 def create_agent(
     agent: CreateAgent = Body(...),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Create a new agent with the specified configuration.
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
     agent.user_id = actor.id
     # TODO: sarah make general
     # TODO: eventually remove this
@@ -74,9 +76,10 @@ def update_agent(
     agent_id: str,
     update_agent: UpdateAgentState = Body(...),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """Update an exsiting agent"""
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     update_agent.id = agent_id
     return server.update_agent(update_agent, user_id=actor.id)
@@ -86,11 +89,12 @@ def update_agent(
 def get_agent_state(
     agent_id: str,
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Get the state of the agent.
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
         # agent does not exist
@@ -103,11 +107,12 @@ def get_agent_state(
 def delete_agent(
     agent_id: str,
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Delete an agent.
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     return server.delete_agent(user_id=actor.id, agent_id=agent_id)
 
@@ -120,7 +125,6 @@ def get_agent_sources(
     """
     Get the sources associated with an agent.
     """
-    server.get_current_user()
 
     return server.list_attached_sources(agent_id)
 
@@ -155,12 +159,13 @@ def update_agent_memory(
     agent_id: str,
     request: Dict = Body(...),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Update the core memory of a specific agent.
     This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request)
     return memory
@@ -197,11 +202,12 @@ def get_agent_archival_memory(
     after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
     before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
     limit: Optional[int] = Query(None, description="How many results to include in the response."),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Retrieve the memories in an agent's archival memory store (paginated query).
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     # TODO need to add support for non-postgres here
     # chroma will throw:
@@ -221,11 +227,12 @@ def insert_agent_archival_memory(
     agent_id: str,
     request: CreateArchivalMemory = Body(...),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Insert a memory into an agent's archival memory store.
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text)
 
@@ -238,11 +245,12 @@ def delete_agent_archival_memory(
     memory_id: str,
     # memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
     server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Delete a memory from an agent's archival memory store.
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id)
     return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
@@ -268,11 +276,12 @@ def get_agent_messages(
         DEFAULT_MESSAGE_TOOL_KWARG,
         description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
     ),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Retrieve message history for an agent.
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     return server.get_agent_recall_cursor(
         user_id=actor.id,
@@ -306,13 +315,14 @@ async def send_message(
     agent_id: str,
     server: SyncServer = Depends(get_letta_server),
    request: LettaRequest = Body(...),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
     """
     Process a user message and return the agent's response.
     This endpoint accepts a message from a user and processes it through the agent.
     It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True.
     """
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     # TODO(charles): support sending multiple messages
     assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
letta/server/rest_api/routers/v1/blocks.py CHANGED
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from fastapi import APIRouter, Body, Depends, HTTPException, Query
+from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
 
 from letta.schemas.block import Block, CreateBlock, UpdateBlock
 from letta.server.rest_api.utils import get_letta_server
@@ -19,8 +19,9 @@ def list_blocks(
     templates_only: bool = Query(True, description="Whether to include only templates"),
     name: Optional[str] = Query(None, description="Name of the block"),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     blocks = server.get_blocks(user_id=actor.id, label=label, template=templates_only, name=name)
     if blocks is None:
@@ -32,8 +33,9 @@ def list_blocks(
 def create_block(
     create_block: CreateBlock = Body(...),
     server: SyncServer = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-    actor = server.get_current_user()
+    actor = server.get_user_or_default(user_id=user_id)
 
     create_block.user_id = actor.id
     return server.create_block(user_id=actor.id, request=create_block)