letta-nightly 0.7.10.dev20250507104304__py3-none-any.whl → 0.7.12.dev20250508044425__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- letta/__init__.py +1 -1
- letta/agent.py +8 -4
- letta/agents/letta_agent.py +3 -5
- letta/agents/letta_agent_batch.py +2 -4
- letta/client/client.py +2 -2
- letta/functions/async_composio_toolset.py +106 -0
- letta/functions/composio_helpers.py +20 -24
- letta/llm_api/anthropic.py +31 -6
- letta/llm_api/anthropic_client.py +10 -8
- letta/llm_api/google_ai_client.py +32 -10
- letta/llm_api/google_constants.py +2 -0
- letta/llm_api/google_vertex_client.py +107 -27
- letta/llm_api/llm_api_tools.py +9 -3
- letta/llm_api/llm_client.py +9 -11
- letta/llm_api/llm_client_base.py +6 -5
- letta/llm_api/openai.py +16 -0
- letta/llm_api/openai_client.py +6 -6
- letta/local_llm/constants.py +1 -0
- letta/memory.py +8 -5
- letta/orm/provider.py +1 -0
- letta/schemas/enums.py +6 -0
- letta/schemas/llm_config.py +2 -0
- letta/schemas/message.py +3 -3
- letta/schemas/providers.py +58 -2
- letta/server/rest_api/routers/v1/agents.py +10 -5
- letta/server/rest_api/routers/v1/llms.py +16 -6
- letta/server/rest_api/routers/v1/providers.py +24 -4
- letta/server/rest_api/routers/v1/sources.py +1 -0
- letta/server/server.py +58 -24
- letta/services/provider_manager.py +26 -8
- letta/settings.py +2 -0
- {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.12.dev20250508044425.dist-info}/METADATA +2 -2
- {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.12.dev20250508044425.dist-info}/RECORD +36 -35
- {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.12.dev20250508044425.dist-info}/LICENSE +0 -0
- {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.12.dev20250508044425.dist-info}/WHEEL +0 -0
- {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.12.dev20250508044425.dist-info}/entry_points.txt +0 -0
letta/schemas/providers.py
CHANGED
```diff
@@ -2,14 +2,14 @@ import warnings
 from datetime import datetime
 from typing import List, Literal, Optional
 
-from pydantic import Field, model_validator
+from pydantic import BaseModel, Field, model_validator
 
 from letta.constants import LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
 from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
 from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.letta_base import LettaBase
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
@@ -24,6 +24,7 @@ class Provider(ProviderBase):
     id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
     name: str = Field(..., description="The name of the provider")
     provider_type: ProviderType = Field(..., description="The type of the provider")
+    provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)")
     api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
     base_url: Optional[str] = Field(None, description="Base URL for the provider.")
     organization_id: Optional[str] = Field(None, description="The organization id of the user")
@@ -39,6 +40,10 @@ class Provider(ProviderBase):
         if not self.id:
             self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
 
+    def check_api_key(self):
+        """Check if the API key is valid for the provider"""
+        raise NotImplementedError
+
     def list_llm_models(self) -> List[LLMConfig]:
         return []
 
```
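The base class now declares `check_api_key` and raises `NotImplementedError`, so subclasses opt in by overriding it (OpenAI, Anthropic, and Google AI do so later in this diff). A minimal sketch of the pattern — `MyServiceProvider` and `validate_key` are hypothetical and not part of this release:

```python
# Hedged sketch of the new check_api_key hook; MyServiceProvider and
# validate_key are hypothetical, and the LLMAuthenticationError signature
# is assumed from its use in the /providers/check route below.
from letta.errors import LLMAuthenticationError
from letta.schemas.providers import Provider


class MyServiceProvider(Provider):
    def check_api_key(self):
        """Raise if the configured API key is rejected upstream."""
        if not self.api_key:
            raise LLMAuthenticationError(message="no API key configured")
        validate_key(self.base_url, self.api_key)  # hypothetical helper
```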
```diff
@@ -111,8 +116,14 @@ class ProviderUpdate(ProviderBase):
     api_key: str = Field(..., description="API key used for requests to the provider.")
 
 
+class ProviderCheck(BaseModel):
+    provider_type: ProviderType = Field(..., description="The type of the provider.")
+    api_key: str = Field(..., description="API key used for requests to the provider.")
+
+
 class LettaProvider(Provider):
     provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
 
     def list_llm_models(self) -> List[LLMConfig]:
         return [
```
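`ProviderCheck` is deliberately a plain `BaseModel` rather than a `ProviderBase` subclass (hence the new `BaseModel` import at the top of the file): it only carries the two inputs of a key check. A quick sketch of how it validates:

```python
# Sketch: ProviderCheck validates its two required fields at construction time.
from pydantic import ValidationError

from letta.schemas.enums import ProviderType
from letta.schemas.providers import ProviderCheck

check = ProviderCheck(provider_type=ProviderType.anthropic, api_key="sk-ant-...")

try:
    ProviderCheck(provider_type=ProviderType.anthropic)  # api_key is required
except ValidationError as e:
    print(e.error_count(), "validation error")
```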
```diff
@@ -123,6 +134,7 @@ class LettaProvider(Provider):
                 context_window=8192,
                 handle=self.get_handle("letta-free"),
                 provider_name=self.name,
+                provider_category=self.provider_category,
             )
         ]
 
@@ -141,9 +153,15 @@ class LettaProvider(Provider):
 
 class OpenAIProvider(Provider):
     provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the OpenAI API.")
     base_url: str = Field(..., description="Base URL for the OpenAI API.")
 
+    def check_api_key(self):
+        from letta.llm_api.openai import openai_check_valid_api_key
+
+        openai_check_valid_api_key(self.base_url, self.api_key)
+
     def list_llm_models(self) -> List[LLMConfig]:
         from letta.llm_api.openai import openai_get_model_list
 
```
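With the override in place, a key can be validated straight from a provider instance; this is the same path the new `/providers/check` route exercises. A hedged sketch (field values are placeholders, and the error type is assumed from what that route catches):

```python
# Sketch: exercising OpenAIProvider.check_api_key directly. A bad key is
# expected to surface as LLMAuthenticationError, matching the route below.
from letta.errors import LLMAuthenticationError
from letta.schemas.providers import OpenAIProvider

provider = OpenAIProvider(
    name="openai",
    api_key="sk-...",
    base_url="https://api.openai.com/v1",
)

try:
    provider.check_api_key()
    print("API key accepted")
except LLMAuthenticationError as e:
    print(f"API key rejected: {e.message}")
```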
```diff
@@ -225,6 +243,7 @@ class OpenAIProvider(Provider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -281,6 +300,7 @@ class DeepSeekProvider(OpenAIProvider):
     """
 
     provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
     api_key: str = Field(..., description="API key for the DeepSeek API.")
 
@@ -332,6 +352,7 @@ class DeepSeekProvider(OpenAIProvider):
                     handle=self.get_handle(model_name),
                     put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -344,6 +365,7 @@ class DeepSeekProvider(OpenAIProvider):
 
 class LMStudioOpenAIProvider(OpenAIProvider):
     provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
     api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
 
@@ -470,6 +492,7 @@ class XAIProvider(OpenAIProvider):
     """https://docs.x.ai/docs/api-reference"""
 
     provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the xAI/Grok API.")
     base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
 
@@ -523,6 +546,7 @@ class XAIProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -535,9 +559,15 @@ class XAIProvider(OpenAIProvider):
 
 class AnthropicProvider(Provider):
     provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Anthropic API.")
     base_url: str = "https://api.anthropic.com/v1"
 
+    def check_api_key(self):
+        from letta.llm_api.anthropic import anthropic_check_valid_api_key
+
+        anthropic_check_valid_api_key(self.api_key)
+
     def list_llm_models(self) -> List[LLMConfig]:
         from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list
 
@@ -611,6 +641,7 @@ class AnthropicProvider(Provider):
                     put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                     max_tokens=max_tokens,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -621,6 +652,7 @@ class AnthropicProvider(Provider):
 
 class MistralProvider(Provider):
     provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Mistral API.")
     base_url: str = "https://api.mistral.ai/v1"
 
@@ -645,6 +677,7 @@ class MistralProvider(Provider):
                     context_window=model["max_context_length"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -672,6 +705,7 @@ class OllamaProvider(OpenAIProvider):
     """
 
     provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the Ollama API.")
     api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
     default_prompt_formatter: str = Field(
@@ -702,6 +736,7 @@ class OllamaProvider(OpenAIProvider):
                     context_window=context_window,
                     handle=self.get_handle(model["name"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -785,6 +820,7 @@ class OllamaProvider(OpenAIProvider):
 
 class GroqProvider(OpenAIProvider):
     provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = "https://api.groq.com/openai/v1"
     api_key: str = Field(..., description="API key for the Groq API.")
 
@@ -804,6 +840,7 @@ class GroqProvider(OpenAIProvider):
                     context_window=model["context_window"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -825,6 +862,7 @@ class TogetherProvider(OpenAIProvider):
     """
 
     provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = "https://api.together.ai/v1"
     api_key: str = Field(..., description="API key for the TogetherAI API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@@ -873,6 +911,7 @@ class TogetherProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -927,9 +966,15 @@ class TogetherProvider(OpenAIProvider):
 class GoogleAIProvider(Provider):
     # gemini
     provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Google AI API.")
     base_url: str = "https://generativelanguage.googleapis.com"
 
+    def check_api_key(self):
+        from letta.llm_api.google_ai_client import google_ai_check_valid_api_key
+
+        google_ai_check_valid_api_key(self.api_key)
+
     def list_llm_models(self):
         from letta.llm_api.google_ai_client import google_ai_get_model_list
 
@@ -955,6 +1000,7 @@ class GoogleAIProvider(Provider):
                     handle=self.get_handle(model),
                     max_tokens=8192,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -991,6 +1037,7 @@ class GoogleAIProvider(Provider):
 
 class GoogleVertexProvider(Provider):
     provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
     google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
 
@@ -1008,6 +1055,7 @@ class GoogleVertexProvider(Provider):
                     handle=self.get_handle(model),
                     max_tokens=8192,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -1032,6 +1080,7 @@ class GoogleVertexProvider(Provider):
 
 class AzureProvider(Provider):
     provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
     base_url: str = Field(
         ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@@ -1065,6 +1114,7 @@ class AzureProvider(Provider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 ),
             )
         return configs
@@ -1106,6 +1156,7 @@ class VLLMChatCompletionsProvider(Provider):
 
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
     provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
 
     def list_llm_models(self) -> List[LLMConfig]:
@@ -1125,6 +1176,7 @@ class VLLMChatCompletionsProvider(Provider):
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -1139,6 +1191,7 @@ class VLLMCompletionsProvider(Provider):
 
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
    provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
 
@@ -1159,6 +1212,7 @@ class VLLMCompletionsProvider(Provider):
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -1174,6 +1228,7 @@ class CohereProvider(OpenAIProvider):
 
 class AnthropicBedrockProvider(Provider):
     provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     aws_region: str = Field(..., description="AWS region for Bedrock")
 
     def list_llm_models(self):
@@ -1192,6 +1247,7 @@ class AnthropicBedrockProvider(Provider):
                     context_window=self.get_model_context_window(model_arn),
                     handle=self.get_handle(model_arn),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
```
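Every subclass change in this file follows the same two-step pattern: a `provider_category` field (defaulting to `base`) and a `provider_category=self.provider_category` stamp on each generated `LLMConfig`. That stamp is what lets the server-side filtering below tell environment-configured providers apart from user-supplied BYOK ones. A hedged sketch of a BYOK-style instance (names and key are illustrative):

```python
# Sketch: a BYOK provider is just a provider whose category is overridden;
# the name and key here are illustrative.
from letta.schemas.enums import ProviderCategory
from letta.schemas.providers import OpenAIProvider

byok = OpenAIProvider(
    name="my-openai",  # user-chosen name, distinct from the built-in "openai"
    provider_category=ProviderCategory.byok,
    api_key="sk-user-supplied",
    base_url="https://api.openai.com/v1",
)

# Each LLMConfig the provider emits now records its origin:
for config in byok.list_llm_models():
    print(config.handle, config.provider_name, config.provider_category)
```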
letta/server/rest_api/routers/v1/agents.py
CHANGED
```diff
@@ -631,12 +631,17 @@ async def send_message(
     # TODO: This is redundant, remove soon
     agent = server.agent_manager.get_agent_by_id(agent_id, actor)
 
-    if (
+    if all(
+        (
+            settings.use_experimental,
+            not agent.enable_sleeptime,
+            not agent.multi_agent_group,
+            not agent.agent_type == AgentType.sleeptime_agent,
+        )
+    ) and (
+        # LLM Model Check: (1) Anthropic or (2) Google Vertex + Flag
         agent.llm_config.model_endpoint_type == "anthropic"
-        and not agent.enable_sleeptime
-        and not agent.multi_agent_group
-        and not agent.agent_type == AgentType.sleeptime_agent
-        and settings.use_experimental
+        or (agent.llm_config.model_endpoint_type == "google_vertex" and settings.use_vertex_async_loop_experimental)
     ):
         experimental_agent = LettaAgent(
             agent_id=agent_id,
```
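The rewritten guard is a reshaping of the old conjunction: the agent-shape checks move into an `all(...)` tuple, and the model check becomes a disjunction so the Google Vertex async loop can opt in behind its own flag. Written out as plain boolean logic (names as in the diff), the condition is equivalent to:

```python
# Equivalent boolean form of the new experimental-loop routing condition.
eligible_agent = (
    settings.use_experimental
    and not agent.enable_sleeptime
    and not agent.multi_agent_group
    and agent.agent_type != AgentType.sleeptime_agent
)
eligible_model = (
    agent.llm_config.model_endpoint_type == "anthropic"
    or (
        agent.llm_config.model_endpoint_type == "google_vertex"
        and settings.use_vertex_async_loop_experimental
    )
)
use_experimental_loop = eligible_agent and eligible_model
```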
letta/server/rest_api/routers/v1/llms.py
CHANGED
```diff
@@ -1,8 +1,9 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from fastapi import APIRouter, Depends, Query
+from fastapi import APIRouter, Depends, Header, Query
 
 from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.server.rest_api.utils import get_letta_server
 
@@ -14,11 +15,19 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
 
 @router.get("/", response_model=List[LLMConfig], operation_id="list_models")
 def list_llm_models(
-
+    provider_category: Optional[List[ProviderCategory]] = Query(None),
+    provider_name: Optional[str] = Query(None),
+    provider_type: Optional[ProviderType] = Query(None),
     server: "SyncServer" = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-
-    models = server.list_llm_models()
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    models = server.list_llm_models(
+        provider_category=provider_category,
+        provider_name=provider_name,
+        provider_type=provider_type,
+        actor=actor,
+    )
     # print(models)
     return models
 
```
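Clients can now scope the model listing by category, name, or type, and the listing is resolved per user. A hedged sketch of calling the endpoint over HTTP (the `http://localhost:8283/v1` base URL is an assumption about a typical local deployment; the query and header names come from the route signature above):

```python
# Sketch: querying the filtered model listing. Base URL and port are
# assumptions about a local Letta server.
import requests

resp = requests.get(
    "http://localhost:8283/v1/models/",
    params={"provider_category": ["byok"], "provider_type": "anthropic"},
    headers={"user_id": "user-123"},  # resolved server-side to an actor
)
resp.raise_for_status()
for model in resp.json():
    print(model["handle"], model.get("provider_category"))
```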
```diff
@@ -26,8 +35,9 @@ def list_llm_models(
 @router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
 def list_embedding_models(
     server: "SyncServer" = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-
-    models = server.list_embedding_models()
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    models = server.list_embedding_models(actor=actor)
     # print(models)
     return models
```
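The embedding listing gains the same per-user resolution; a quick sketch under the same local-server assumption:

```python
# Sketch: embedding models are now also listed per user (same base-URL
# assumption as the previous example).
import requests

resp = requests.get(
    "http://localhost:8283/v1/models/embedding",
    headers={"user_id": "user-123"},
)
resp.raise_for_status()
print([m["handle"] for m in resp.json()])
```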
letta/server/rest_api/routers/v1/providers.py
CHANGED
```diff
@@ -1,9 +1,12 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
+from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
+from fastapi.responses import JSONResponse
 
+from letta.errors import LLMAuthenticationError
+from letta.orm.errors import NoResultFound
 from letta.schemas.enums import ProviderType
-from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
+from letta.schemas.providers import Provider, ProviderCheck, ProviderCreate, ProviderUpdate
 from letta.server.rest_api.utils import get_letta_server
 
 if TYPE_CHECKING:
@@ -45,7 +48,8 @@ def create_provider(
     """
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
 
-    provider =
+    provider = ProviderCreate(**request.model_dump())
+
     provider = server.provider_manager.create_provider(provider, actor=actor)
     return provider
 
@@ -61,7 +65,23 @@ def modify_provider(
     Update an existing custom provider
     """
     actor = server.user_manager.get_user_or_default(user_id=actor_id)
-    return server.provider_manager.update_provider(provider_id=provider_id,
+    return server.provider_manager.update_provider(provider_id=provider_id, provider_update=request, actor=actor)
+
+
+@router.get("/check", response_model=None, operation_id="check_provider")
+def check_provider(
+    provider_type: ProviderType = Query(...),
+    api_key: str = Header(..., alias="x-api-key"),
+    server: "SyncServer" = Depends(get_letta_server),
+):
+    try:
+        provider_check = ProviderCheck(provider_type=provider_type, api_key=api_key)
+        server.provider_manager.check_provider_api_key(provider_check=provider_check)
+        return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Valid api key for provider_type={provider_type.value}"})
+    except LLMAuthenticationError as e:
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"{e.message}")
+    except Exception as e:
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"{e}")
 
 
 @router.delete("/{provider_id}", response_model=None, operation_id="delete_provider")
```
letta/server/rest_api/routers/v1/sources.py
CHANGED
```diff
@@ -106,6 +106,7 @@ def create_source(
     source_create.embedding_config = server.get_embedding_config_from_handle(
         handle=source_create.embedding,
         embedding_chunk_size=source_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
+        actor=actor,
     )
     source = Source(
         name=source_create.name,
```
letta/server/server.py
CHANGED
```diff
@@ -42,7 +42,7 @@ from letta.schemas.block import Block, BlockUpdate, CreateBlock
 from letta.schemas.embedding_config import EmbeddingConfig
 
 # openai schemas
-from letta.schemas.enums import JobStatus, MessageStreamStatus
+from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory, ProviderType
 from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
 from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager
 from letta.schemas.job import Job, JobUpdate
@@ -734,17 +734,17 @@ class SyncServer(Server):
         return self._command(user_id=user_id, agent_id=agent_id, command=command)
 
     @trace_method
-    def get_cached_llm_config(self, **kwargs):
+    def get_cached_llm_config(self, actor: User, **kwargs):
         key = make_key(**kwargs)
         if key not in self._llm_config_cache:
-            self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
+            self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
         return self._llm_config_cache[key]
 
     @trace_method
-    def get_cached_embedding_config(self, **kwargs):
+    def get_cached_embedding_config(self, actor: User, **kwargs):
         key = make_key(**kwargs)
         if key not in self._embedding_config_cache:
-            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(**kwargs)
+            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
         return self._embedding_config_cache[key]
 
     @trace_method
@@ -766,7 +766,7 @@ class SyncServer(Server):
             "enable_reasoner": request.enable_reasoner,
         }
         log_event(name="start get_cached_llm_config", attributes=config_params)
-        request.llm_config = self.get_cached_llm_config(**config_params)
+        request.llm_config = self.get_cached_llm_config(actor=actor, **config_params)
         log_event(name="end get_cached_llm_config", attributes=config_params)
 
         if request.embedding_config is None:
@@ -777,7 +777,7 @@ class SyncServer(Server):
                 "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
             }
             log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
-            request.embedding_config = self.get_cached_embedding_config(**embedding_config_params)
+            request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params)
             log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
 
         log_event(name="start create_agent db")
@@ -802,10 +802,10 @@ class SyncServer(Server):
         actor: User,
     ) -> AgentState:
         if request.model is not None:
-            request.llm_config = self.get_llm_config_from_handle(handle=request.model)
+            request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor)
 
         if request.embedding is not None:
-            request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding)
+            request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor)
 
         if request.enable_sleeptime:
             agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -1201,10 +1201,21 @@ class SyncServer(Server):
         except NoResultFound:
             raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
 
-    def list_llm_models(
+    def list_llm_models(
+        self,
+        actor: User,
+        provider_category: Optional[List[ProviderCategory]] = None,
+        provider_name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+    ) -> List[LLMConfig]:
         """List available models"""
         llm_models = []
-        for provider in self.get_enabled_providers(
+        for provider in self.get_enabled_providers(
+            provider_category=provider_category,
+            provider_name=provider_name,
+            provider_type=provider_type,
+            actor=actor,
+        ):
             try:
                 llm_models.extend(provider.list_llm_models())
             except Exception as e:
@@ -1214,26 +1225,49 @@ class SyncServer(Server):
 
         return llm_models
 
-    def list_embedding_models(self) -> List[EmbeddingConfig]:
+    def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
         """List available embedding models"""
         embedding_models = []
-        for provider in self.get_enabled_providers():
+        for provider in self.get_enabled_providers(actor):
             try:
                 embedding_models.extend(provider.list_embedding_models())
             except Exception as e:
                 warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
         return embedding_models
 
-    def get_enabled_providers(
-
-
-
-
-
+    def get_enabled_providers(
+        self,
+        actor: User,
+        provider_category: Optional[List[ProviderCategory]] = None,
+        provider_name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+    ) -> List[Provider]:
+        providers = []
+        if not provider_category or ProviderCategory.base in provider_category:
+            providers_from_env = [p for p in self._enabled_providers]
+            providers.extend(providers_from_env)
+
+        if not provider_category or ProviderCategory.byok in provider_category:
+            providers_from_db = self.provider_manager.list_providers(
+                name=provider_name,
+                provider_type=provider_type,
+                actor=actor,
+            )
+            providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
+            providers.extend(providers_from_db)
+
+        if provider_name is not None:
+            providers = [p for p in providers if p.name == provider_name]
+
+        if provider_type is not None:
+            providers = [p for p in providers if p.provider_type == provider_type]
+
+        return providers
 
     @trace_method
     def get_llm_config_from_handle(
         self,
+        actor: User,
         handle: str,
         context_window_limit: Optional[int] = None,
         max_tokens: Optional[int] = None,
```
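`get_enabled_providers` now merges two sources: `base` providers configured through the environment and `byok` providers persisted per organization, with optional name and type filters applied to the merged list. A hedged sketch of how a caller might use the filters (assuming a `SyncServer` instance and a resolved `User`):

```python
# Sketch: filtering the merged provider list. Assumes `server` is a
# SyncServer and `actor` a resolved User.
from letta.schemas.enums import ProviderCategory, ProviderType

# Only user-supplied (BYOK) Anthropic providers:
byok_anthropic = server.get_enabled_providers(
    actor=actor,
    provider_category=[ProviderCategory.byok],
    provider_type=ProviderType.anthropic,
)

# Model listing goes through the same filters:
models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.byok])
```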
```diff
@@ -1242,7 +1276,7 @@ class SyncServer(Server):
     ) -> LLMConfig:
         try:
             provider_name, model_name = handle.split("/", 1)
-            provider = self.get_provider_from_name(provider_name)
+            provider = self.get_provider_from_name(provider_name, actor)
 
             llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
             if not llm_configs:
@@ -1286,11 +1320,11 @@ class SyncServer(Server):
 
     @trace_method
     def get_embedding_config_from_handle(
-        self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
+        self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
     ) -> EmbeddingConfig:
         try:
             provider_name, model_name = handle.split("/", 1)
-            provider = self.get_provider_from_name(provider_name)
+            provider = self.get_provider_from_name(provider_name, actor)
 
             embedding_configs = [config for config in provider.list_embedding_models() if config.handle == handle]
             if not embedding_configs:
@@ -1313,8 +1347,8 @@ class SyncServer(Server):
 
         return embedding_config
 
-    def get_provider_from_name(self, provider_name: str) -> Provider:
-        providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
+    def get_provider_from_name(self, provider_name: str, actor: User) -> Provider:
+        providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name]
         if not providers:
             raise ValueError(f"Provider {provider_name} is not supported")
         elif len(providers) > 1:
```