letta-nightly 0.7.10.dev20250507104304__py3-none-any.whl → 0.7.12.dev20250508044425__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. letta/__init__.py +1 -1
  2. letta/agent.py +8 -4
  3. letta/agents/letta_agent.py +3 -5
  4. letta/agents/letta_agent_batch.py +2 -4
  5. letta/client/client.py +2 -2
  6. letta/functions/async_composio_toolset.py +106 -0
  7. letta/functions/composio_helpers.py +20 -24
  8. letta/llm_api/anthropic.py +31 -6
  9. letta/llm_api/anthropic_client.py +10 -8
  10. letta/llm_api/google_ai_client.py +32 -10
  11. letta/llm_api/google_constants.py +2 -0
  12. letta/llm_api/google_vertex_client.py +107 -27
  13. letta/llm_api/llm_api_tools.py +9 -3
  14. letta/llm_api/llm_client.py +9 -11
  15. letta/llm_api/llm_client_base.py +6 -5
  16. letta/llm_api/openai.py +16 -0
  17. letta/llm_api/openai_client.py +6 -6
  18. letta/local_llm/constants.py +1 -0
  19. letta/memory.py +8 -5
  20. letta/orm/provider.py +1 -0
  21. letta/schemas/enums.py +6 -0
  22. letta/schemas/llm_config.py +2 -0
  23. letta/schemas/message.py +3 -3
  24. letta/schemas/providers.py +58 -2
  25. letta/server/rest_api/routers/v1/agents.py +10 -5
  26. letta/server/rest_api/routers/v1/llms.py +16 -6
  27. letta/server/rest_api/routers/v1/providers.py +24 -4
  28. letta/server/rest_api/routers/v1/sources.py +1 -0
  29. letta/server/server.py +58 -24
  30. letta/services/provider_manager.py +26 -8
  31. letta/settings.py +2 -0
  32. {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.12.dev20250508044425.dist-info}/METADATA +2 -2
  33. {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.12.dev20250508044425.dist-info}/RECORD +36 -35
  34. {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.12.dev20250508044425.dist-info}/LICENSE +0 -0
  35. {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.12.dev20250508044425.dist-info}/WHEEL +0 -0
  36. {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.12.dev20250508044425.dist-info}/entry_points.txt +0 -0
letta/schemas/providers.py CHANGED
@@ -2,14 +2,14 @@ import warnings
 from datetime import datetime
 from typing import List, Literal, Optional
 
-from pydantic import Field, model_validator
+from pydantic import BaseModel, Field, model_validator
 
 from letta.constants import LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
 from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
 from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.letta_base import LettaBase
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
@@ -24,6 +24,7 @@ class Provider(ProviderBase):
     id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
     name: str = Field(..., description="The name of the provider")
     provider_type: ProviderType = Field(..., description="The type of the provider")
+    provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)")
     api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
     base_url: Optional[str] = Field(None, description="Base URL for the provider.")
     organization_id: Optional[str] = Field(None, description="The organization id of the user")
@@ -39,6 +40,10 @@ class Provider(ProviderBase):
         if not self.id:
             self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
 
+    def check_api_key(self):
+        """Check if the API key is valid for the provider"""
+        raise NotImplementedError
+
     def list_llm_models(self) -> List[LLMConfig]:
         return []
 
@@ -111,8 +116,14 @@ class ProviderUpdate(ProviderBase):
     api_key: str = Field(..., description="API key used for requests to the provider.")
 
 
+class ProviderCheck(BaseModel):
+    provider_type: ProviderType = Field(..., description="The type of the provider.")
+    api_key: str = Field(..., description="API key used for requests to the provider.")
+
+
 class LettaProvider(Provider):
     provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
 
     def list_llm_models(self) -> List[LLMConfig]:
         return [
@@ -123,6 +134,7 @@ class LettaProvider(Provider):
                 context_window=8192,
                 handle=self.get_handle("letta-free"),
                 provider_name=self.name,
+                provider_category=self.provider_category,
             )
         ]
 
@@ -141,9 +153,15 @@ class LettaProvider(Provider):
 
 class OpenAIProvider(Provider):
     provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the OpenAI API.")
     base_url: str = Field(..., description="Base URL for the OpenAI API.")
 
+    def check_api_key(self):
+        from letta.llm_api.openai import openai_check_valid_api_key
+
+        openai_check_valid_api_key(self.base_url, self.api_key)
+
     def list_llm_models(self) -> List[LLMConfig]:
         from letta.llm_api.openai import openai_get_model_list
 
@@ -225,6 +243,7 @@ class OpenAIProvider(Provider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -281,6 +300,7 @@ class DeepSeekProvider(OpenAIProvider):
     """
 
     provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
     api_key: str = Field(..., description="API key for the DeepSeek API.")
 
@@ -332,6 +352,7 @@ class DeepSeekProvider(OpenAIProvider):
                     handle=self.get_handle(model_name),
                     put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -344,6 +365,7 @@ class DeepSeekProvider(OpenAIProvider):
 
 class LMStudioOpenAIProvider(OpenAIProvider):
     provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
     api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
 
@@ -470,6 +492,7 @@ class XAIProvider(OpenAIProvider):
     """https://docs.x.ai/docs/api-reference"""
 
     provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the xAI/Grok API.")
     base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
 
@@ -523,6 +546,7 @@ class XAIProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -535,9 +559,15 @@ class XAIProvider(OpenAIProvider):
 
 class AnthropicProvider(Provider):
     provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Anthropic API.")
     base_url: str = "https://api.anthropic.com/v1"
 
+    def check_api_key(self):
+        from letta.llm_api.anthropic import anthropic_check_valid_api_key
+
+        anthropic_check_valid_api_key(self.api_key)
+
     def list_llm_models(self) -> List[LLMConfig]:
         from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list
 
@@ -611,6 +641,7 @@ class AnthropicProvider(Provider):
                     put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                     max_tokens=max_tokens,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -621,6 +652,7 @@ class AnthropicProvider(Provider):
 
 class MistralProvider(Provider):
     provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Mistral API.")
     base_url: str = "https://api.mistral.ai/v1"
 
@@ -645,6 +677,7 @@ class MistralProvider(Provider):
                     context_window=model["max_context_length"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -672,6 +705,7 @@ class OllamaProvider(OpenAIProvider):
     """
 
     provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the Ollama API.")
     api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
     default_prompt_formatter: str = Field(
@@ -702,6 +736,7 @@ class OllamaProvider(OpenAIProvider):
                     context_window=context_window,
                     handle=self.get_handle(model["name"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
            )
         return configs
@@ -785,6 +820,7 @@ class OllamaProvider(OpenAIProvider):
 
 class GroqProvider(OpenAIProvider):
     provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = "https://api.groq.com/openai/v1"
     api_key: str = Field(..., description="API key for the Groq API.")
 
@@ -804,6 +840,7 @@ class GroqProvider(OpenAIProvider):
                     context_window=model["context_window"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -825,6 +862,7 @@ class TogetherProvider(OpenAIProvider):
     """
 
     provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = "https://api.together.ai/v1"
     api_key: str = Field(..., description="API key for the TogetherAI API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@@ -873,6 +911,7 @@ class TogetherProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -927,9 +966,15 @@ class TogetherProvider(OpenAIProvider):
 class GoogleAIProvider(Provider):
     # gemini
     provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Google AI API.")
     base_url: str = "https://generativelanguage.googleapis.com"
 
+    def check_api_key(self):
+        from letta.llm_api.google_ai_client import google_ai_check_valid_api_key
+
+        google_ai_check_valid_api_key(self.api_key)
+
     def list_llm_models(self):
         from letta.llm_api.google_ai_client import google_ai_get_model_list
 
@@ -955,6 +1000,7 @@ class GoogleAIProvider(Provider):
                     handle=self.get_handle(model),
                     max_tokens=8192,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -991,6 +1037,7 @@ class GoogleAIProvider(Provider):
 
 class GoogleVertexProvider(Provider):
     provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
     google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
 
@@ -1008,6 +1055,7 @@ class GoogleVertexProvider(Provider):
                     handle=self.get_handle(model),
                     max_tokens=8192,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -1032,6 +1080,7 @@ class GoogleVertexProvider(Provider):
 
 class AzureProvider(Provider):
     provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
     base_url: str = Field(
         ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@@ -1065,6 +1114,7 @@ class AzureProvider(Provider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 ),
             )
         return configs
@@ -1106,6 +1156,7 @@ class VLLMChatCompletionsProvider(Provider):
 
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
     provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
 
     def list_llm_models(self) -> List[LLMConfig]:
@@ -1125,6 +1176,7 @@ class VLLMChatCompletionsProvider(Provider):
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -1139,6 +1191,7 @@ class VLLMCompletionsProvider(Provider):
 
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
     provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
 
@@ -1159,6 +1212,7 @@ class VLLMCompletionsProvider(Provider):
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -1174,6 +1228,7 @@ class CohereProvider(OpenAIProvider):
 
 class AnthropicBedrockProvider(Provider):
     provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     aws_region: str = Field(..., description="AWS region for Bedrock")
 
     def list_llm_models(self):
@@ -1192,6 +1247,7 @@ class AnthropicBedrockProvider(Provider):
                     context_window=self.get_model_context_window(model_arn),
                     handle=self.get_handle(model_arn),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
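
Taken together, the providers.py hunks add two hooks to every provider schema: a provider_category field (ProviderCategory.base for server-configured providers, ProviderCategory.byok for user-supplied keys) and a check_api_key() method that concrete providers override with a provider-specific validation call; every LLMConfig they emit is now tagged with both provider_name and provider_category. A minimal sketch of how a subclass plugs into this pattern (MyProvider and my_check_valid_api_key are illustrative names, not part of this diff):

    class MyProvider(Provider):
        provider_category: ProviderCategory = Field(ProviderCategory.byok, description="The category of the provider (base or byok)")
        api_key: str = Field(..., description="API key for the provider.")

        def check_api_key(self):
            # Override the base-class hook (which raises NotImplementedError)
            # with a call that validates self.api_key against the provider's API.
            my_check_valid_api_key(self.api_key)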
letta/server/rest_api/routers/v1/agents.py CHANGED
@@ -631,12 +631,17 @@ async def send_message(
     # TODO: This is redundant, remove soon
     agent = server.agent_manager.get_agent_by_id(agent_id, actor)
 
-    if (
+    if all(
+        (
+            settings.use_experimental,
+            not agent.enable_sleeptime,
+            not agent.multi_agent_group,
+            not agent.agent_type == AgentType.sleeptime_agent,
+        )
+    ) and (
+        # LLM Model Check: (1) Anthropic or (2) Google Vertex + Flag
         agent.llm_config.model_endpoint_type == "anthropic"
-        and not agent.enable_sleeptime
-        and not agent.multi_agent_group
-        and not agent.agent_type == AgentType.sleeptime_agent
-        and settings.use_experimental
+        or (agent.llm_config.model_endpoint_type == "google_vertex" and settings.use_vertex_async_loop_experimental)
     ):
         experimental_agent = LettaAgent(
             agent_id=agent_id,
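
Flattened out, the rewritten gate in send_message is equivalent to the following predicate: all experimental preconditions must hold, and then either model check can enable the new loop (a restatement of the diff above, not new behavior):

    use_experimental_loop = all(
        (
            settings.use_experimental,
            not agent.enable_sleeptime,
            not agent.multi_agent_group,
            not agent.agent_type == AgentType.sleeptime_agent,
        )
    ) and (
        agent.llm_config.model_endpoint_type == "anthropic"
        or (agent.llm_config.model_endpoint_type == "google_vertex" and settings.use_vertex_async_loop_experimental)
    )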
letta/server/rest_api/routers/v1/llms.py CHANGED
@@ -1,8 +1,9 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from fastapi import APIRouter, Depends, Query
+from fastapi import APIRouter, Depends, Header, Query
 
 from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.server.rest_api.utils import get_letta_server
 
@@ -14,11 +15,19 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
 
 @router.get("/", response_model=List[LLMConfig], operation_id="list_models")
 def list_llm_models(
-    byok_only: Optional[bool] = Query(None),
+    provider_category: Optional[List[ProviderCategory]] = Query(None),
+    provider_name: Optional[str] = Query(None),
+    provider_type: Optional[ProviderType] = Query(None),
     server: "SyncServer" = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-
-    models = server.list_llm_models(byok_only=byok_only)
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    models = server.list_llm_models(
+        provider_category=provider_category,
+        provider_name=provider_name,
+        provider_type=provider_type,
+        actor=actor,
+    )
     # print(models)
     return models
 
@@ -26,8 +35,9 @@ def list_llm_models(
 @router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
 def list_embedding_models(
     server: "SyncServer" = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-
-    models = server.list_embedding_models()
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    models = server.list_embedding_models(actor=actor)
     # print(models)
     return models
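
For reference, a client-side sketch of the new list_models filters (the localhost base URL and /v1 prefix are assumptions about a default local Letta server, and the query string values are assumed to match the ProviderCategory/ProviderType enum values):

    import requests

    resp = requests.get(
        "http://localhost:8283/v1/models/",
        # A list-valued query param is sent as repeated ?provider_category=... entries,
        # which FastAPI parses into Optional[List[ProviderCategory]].
        params={"provider_category": ["byok"], "provider_type": "anthropic"},
        headers={"user_id": "user-123"},  # resolved server-side via get_user_or_default
    )
    for model in resp.json():
        print(model["handle"])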
letta/server/rest_api/routers/v1/providers.py CHANGED
@@ -1,9 +1,12 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
+from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
+from fastapi.responses import JSONResponse
 
+from letta.errors import LLMAuthenticationError
+from letta.orm.errors import NoResultFound
 from letta.schemas.enums import ProviderType
-from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
+from letta.schemas.providers import Provider, ProviderCheck, ProviderCreate, ProviderUpdate
 from letta.server.rest_api.utils import get_letta_server
 
 if TYPE_CHECKING:
@@ -45,7 +48,8 @@ def create_provider(
     """
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
 
-    provider = Provider(**request.model_dump())
+    provider = ProviderCreate(**request.model_dump())
+
     provider = server.provider_manager.create_provider(provider, actor=actor)
     return provider
 
@@ -61,7 +65,23 @@ def modify_provider(
     Update an existing custom provider
     """
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
-    return server.provider_manager.update_provider(provider_id=provider_id, request=request, actor=actor)
+    return server.provider_manager.update_provider(provider_id=provider_id, provider_update=request, actor=actor)
+
+
+@router.get("/check", response_model=None, operation_id="check_provider")
+def check_provider(
+    provider_type: ProviderType = Query(...),
+    api_key: str = Header(..., alias="x-api-key"),
+    server: "SyncServer" = Depends(get_letta_server),
+):
+    try:
+        provider_check = ProviderCheck(provider_type=provider_type, api_key=api_key)
+        server.provider_manager.check_provider_api_key(provider_check=provider_check)
+        return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Valid api key for provider_type={provider_type.value}"})
+    except LLMAuthenticationError as e:
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"{e.message}")
+    except Exception as e:
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"{e}")
 
 
 @router.delete("/{provider_id}", response_model=None, operation_id="delete_provider")
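
A sketch of exercising the new check endpoint (the base URL is a placeholder, the key is elided, and the path assumes the router's /providers prefix mounted under /v1):

    import requests

    resp = requests.get(
        "http://localhost:8283/v1/providers/check",
        params={"provider_type": "anthropic"},
        headers={"x-api-key": "sk-ant-..."},
    )
    # 200 with a confirmation message on a valid key,
    # 401 when the provider raises LLMAuthenticationError,
    # 500 for any other failure.
    print(resp.status_code, resp.json())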
letta/server/rest_api/routers/v1/sources.py CHANGED
@@ -106,6 +106,7 @@ def create_source(
     source_create.embedding_config = server.get_embedding_config_from_handle(
         handle=source_create.embedding,
         embedding_chunk_size=source_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
+        actor=actor,
     )
     source = Source(
         name=source_create.name,
letta/server/server.py CHANGED
@@ -42,7 +42,7 @@ from letta.schemas.block import Block, BlockUpdate, CreateBlock
 from letta.schemas.embedding_config import EmbeddingConfig
 
 # openai schemas
-from letta.schemas.enums import JobStatus, MessageStreamStatus
+from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory, ProviderType
 from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
 from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager
 from letta.schemas.job import Job, JobUpdate
@@ -734,17 +734,17 @@ class SyncServer(Server):
         return self._command(user_id=user_id, agent_id=agent_id, command=command)
 
     @trace_method
-    def get_cached_llm_config(self, **kwargs):
+    def get_cached_llm_config(self, actor: User, **kwargs):
         key = make_key(**kwargs)
         if key not in self._llm_config_cache:
-            self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
+            self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
         return self._llm_config_cache[key]
 
     @trace_method
-    def get_cached_embedding_config(self, **kwargs):
+    def get_cached_embedding_config(self, actor: User, **kwargs):
         key = make_key(**kwargs)
         if key not in self._embedding_config_cache:
-            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(**kwargs)
+            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
         return self._embedding_config_cache[key]
 
     @trace_method
@@ -766,7 +766,7 @@ class SyncServer(Server):
             "enable_reasoner": request.enable_reasoner,
         }
         log_event(name="start get_cached_llm_config", attributes=config_params)
-        request.llm_config = self.get_cached_llm_config(**config_params)
+        request.llm_config = self.get_cached_llm_config(actor=actor, **config_params)
         log_event(name="end get_cached_llm_config", attributes=config_params)
 
         if request.embedding_config is None:
@@ -777,7 +777,7 @@ class SyncServer(Server):
                 "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
             }
             log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
-            request.embedding_config = self.get_cached_embedding_config(**embedding_config_params)
+            request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params)
             log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
 
         log_event(name="start create_agent db")
@@ -802,10 +802,10 @@ class SyncServer(Server):
         actor: User,
     ) -> AgentState:
         if request.model is not None:
-            request.llm_config = self.get_llm_config_from_handle(handle=request.model)
+            request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor)
 
         if request.embedding is not None:
-            request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding)
+            request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor)
 
         if request.enable_sleeptime:
             agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -1201,10 +1201,21 @@ class SyncServer(Server):
         except NoResultFound:
             raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
 
-    def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]:
+    def list_llm_models(
+        self,
+        actor: User,
+        provider_category: Optional[List[ProviderCategory]] = None,
+        provider_name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+    ) -> List[LLMConfig]:
         """List available models"""
         llm_models = []
-        for provider in self.get_enabled_providers(byok_only=byok_only):
+        for provider in self.get_enabled_providers(
+            provider_category=provider_category,
+            provider_name=provider_name,
+            provider_type=provider_type,
+            actor=actor,
+        ):
             try:
                 llm_models.extend(provider.list_llm_models())
             except Exception as e:
@@ -1214,26 +1225,49 @@ class SyncServer(Server):
 
         return llm_models
 
-    def list_embedding_models(self) -> List[EmbeddingConfig]:
+    def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
         """List available embedding models"""
         embedding_models = []
-        for provider in self.get_enabled_providers():
+        for provider in self.get_enabled_providers(actor):
             try:
                 embedding_models.extend(provider.list_embedding_models())
             except Exception as e:
                 warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
         return embedding_models
 
-    def get_enabled_providers(self, byok_only: bool = False):
-        providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
-        if byok_only:
-            return list(providers_from_db.values())
-        providers_from_env = {p.name: p for p in self._enabled_providers}
-        return list(providers_from_env.values()) + list(providers_from_db.values())
+    def get_enabled_providers(
+        self,
+        actor: User,
+        provider_category: Optional[List[ProviderCategory]] = None,
+        provider_name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+    ) -> List[Provider]:
+        providers = []
+        if not provider_category or ProviderCategory.base in provider_category:
+            providers_from_env = [p for p in self._enabled_providers]
+            providers.extend(providers_from_env)
+
+        if not provider_category or ProviderCategory.byok in provider_category:
+            providers_from_db = self.provider_manager.list_providers(
+                name=provider_name,
+                provider_type=provider_type,
+                actor=actor,
+            )
+            providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
+            providers.extend(providers_from_db)
+
+        if provider_name is not None:
+            providers = [p for p in providers if p.name == provider_name]
+
+        if provider_type is not None:
+            providers = [p for p in providers if p.provider_type == provider_type]
+
+        return providers
 
     @trace_method
     def get_llm_config_from_handle(
         self,
+        actor: User,
         handle: str,
         context_window_limit: Optional[int] = None,
         max_tokens: Optional[int] = None,
@@ -1242,7 +1276,7 @@ class SyncServer(Server):
     ) -> LLMConfig:
         try:
             provider_name, model_name = handle.split("/", 1)
-            provider = self.get_provider_from_name(provider_name)
+            provider = self.get_provider_from_name(provider_name, actor)
 
             llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
             if not llm_configs:
@@ -1286,11 +1320,11 @@ class SyncServer(Server):
 
     @trace_method
     def get_embedding_config_from_handle(
-        self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
+        self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
     ) -> EmbeddingConfig:
         try:
             provider_name, model_name = handle.split("/", 1)
-            provider = self.get_provider_from_name(provider_name)
+            provider = self.get_provider_from_name(provider_name, actor)
 
             embedding_configs = [config for config in provider.list_embedding_models() if config.handle == handle]
             if not embedding_configs:
@@ -1313,8 +1347,8 @@ class SyncServer(Server):
 
         return embedding_config
 
-    def get_provider_from_name(self, provider_name: str) -> Provider:
-        providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
+    def get_provider_from_name(self, provider_name: str, actor: User) -> Provider:
+        providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name]
         if not providers:
             raise ValueError(f"Provider {provider_name} is not supported")
         elif len(providers) > 1:
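
With the refactored signature, callers choose provider sources explicitly instead of passing the old byok_only flag. A sketch against the new method (the actor is obtained elsewhere, e.g. via user_manager.get_user_or_default):

    # Env-configured ("base") providers only, narrowed to OpenAI:
    providers = server.get_enabled_providers(
        actor=actor,
        provider_category=[ProviderCategory.base],
        provider_type=ProviderType.openai,
    )

    # Omitting provider_category includes both env ("base") and database ("byok") providers.
    all_providers = server.get_enabled_providers(actor=actor)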