letta-nightly 0.7.10.dev20250507104304__py3-none-any.whl → 0.7.11.dev20250507230415__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- letta/__init__.py +1 -1
- letta/agent.py +8 -4
- letta/agents/letta_agent.py +3 -5
- letta/agents/letta_agent_batch.py +2 -4
- letta/client/client.py +2 -2
- letta/functions/async_composio_toolset.py +106 -0
- letta/functions/composio_helpers.py +20 -24
- letta/llm_api/anthropic.py +16 -5
- letta/llm_api/anthropic_client.py +10 -8
- letta/llm_api/google_ai_client.py +12 -10
- letta/llm_api/google_vertex_client.py +107 -27
- letta/llm_api/llm_api_tools.py +9 -3
- letta/llm_api/llm_client.py +9 -11
- letta/llm_api/llm_client_base.py +6 -5
- letta/llm_api/openai_client.py +6 -6
- letta/local_llm/constants.py +1 -0
- letta/memory.py +8 -5
- letta/orm/provider.py +1 -0
- letta/schemas/enums.py +5 -0
- letta/schemas/llm_config.py +2 -0
- letta/schemas/message.py +3 -3
- letta/schemas/providers.py +33 -1
- letta/server/rest_api/routers/v1/agents.py +10 -5
- letta/server/rest_api/routers/v1/llms.py +16 -6
- letta/server/rest_api/routers/v1/providers.py +3 -1
- letta/server/rest_api/routers/v1/sources.py +1 -0
- letta/server/server.py +58 -24
- letta/services/provider_manager.py +11 -8
- letta/settings.py +2 -0
- {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.11.dev20250507230415.dist-info}/METADATA +1 -1
- {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.11.dev20250507230415.dist-info}/RECORD +34 -33
- {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.11.dev20250507230415.dist-info}/LICENSE +0 -0
- {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.11.dev20250507230415.dist-info}/WHEEL +0 -0
- {letta_nightly-0.7.10.dev20250507104304.dist-info → letta_nightly-0.7.11.dev20250507230415.dist-info}/entry_points.txt +0 -0
letta/schemas/providers.py
CHANGED
@@ -9,7 +9,7 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
 from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.letta_base import LettaBase
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
@@ -24,6 +24,7 @@ class Provider(ProviderBase):
     id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
     name: str = Field(..., description="The name of the provider")
     provider_type: ProviderType = Field(..., description="The type of the provider")
+    provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)")
     api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
     base_url: Optional[str] = Field(None, description="Base URL for the provider.")
     organization_id: Optional[str] = Field(None, description="The organization id of the user")
@@ -113,6 +114,7 @@ class ProviderUpdate(ProviderBase):
 
 class LettaProvider(Provider):
     provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
 
     def list_llm_models(self) -> List[LLMConfig]:
         return [
@@ -123,6 +125,7 @@ class LettaProvider(Provider):
                 context_window=8192,
                 handle=self.get_handle("letta-free"),
                 provider_name=self.name,
+                provider_category=self.provider_category,
             )
         ]
 
@@ -141,6 +144,7 @@ class LettaProvider(Provider):
 
 class OpenAIProvider(Provider):
     provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the OpenAI API.")
     base_url: str = Field(..., description="Base URL for the OpenAI API.")
 
@@ -225,6 +229,7 @@ class OpenAIProvider(Provider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -281,6 +286,7 @@ class DeepSeekProvider(OpenAIProvider):
     """
 
     provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
     api_key: str = Field(..., description="API key for the DeepSeek API.")
 
@@ -332,6 +338,7 @@ class DeepSeekProvider(OpenAIProvider):
                     handle=self.get_handle(model_name),
                     put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -344,6 +351,7 @@ class DeepSeekProvider(OpenAIProvider):
 
 class LMStudioOpenAIProvider(OpenAIProvider):
     provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
     api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
 
@@ -470,6 +478,7 @@ class XAIProvider(OpenAIProvider):
     """https://docs.x.ai/docs/api-reference"""
 
     provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the xAI/Grok API.")
     base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
 
@@ -523,6 +532,7 @@ class XAIProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -535,6 +545,7 @@ class XAIProvider(OpenAIProvider):
 
 class AnthropicProvider(Provider):
     provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Anthropic API.")
     base_url: str = "https://api.anthropic.com/v1"
 
@@ -611,6 +622,7 @@ class AnthropicProvider(Provider):
                     put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                     max_tokens=max_tokens,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -621,6 +633,7 @@ class AnthropicProvider(Provider):
 
 class MistralProvider(Provider):
     provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Mistral API.")
     base_url: str = "https://api.mistral.ai/v1"
 
@@ -645,6 +658,7 @@ class MistralProvider(Provider):
                     context_window=model["max_context_length"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -672,6 +686,7 @@ class OllamaProvider(OpenAIProvider):
     """
 
     provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the Ollama API.")
     api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
     default_prompt_formatter: str = Field(
@@ -702,6 +717,7 @@ class OllamaProvider(OpenAIProvider):
                     context_window=context_window,
                     handle=self.get_handle(model["name"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -785,6 +801,7 @@ class OllamaProvider(OpenAIProvider):
 
 class GroqProvider(OpenAIProvider):
     provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = "https://api.groq.com/openai/v1"
     api_key: str = Field(..., description="API key for the Groq API.")
 
@@ -804,6 +821,7 @@ class GroqProvider(OpenAIProvider):
                     context_window=model["context_window"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
            )
         return configs
@@ -825,6 +843,7 @@ class TogetherProvider(OpenAIProvider):
     """
 
     provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = "https://api.together.ai/v1"
     api_key: str = Field(..., description="API key for the TogetherAI API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@@ -873,6 +892,7 @@ class TogetherProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
 
@@ -927,6 +947,7 @@ class TogetherProvider(OpenAIProvider):
 class GoogleAIProvider(Provider):
     # gemini
     provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     api_key: str = Field(..., description="API key for the Google AI API.")
     base_url: str = "https://generativelanguage.googleapis.com"
 
@@ -955,6 +976,7 @@ class GoogleAIProvider(Provider):
                     handle=self.get_handle(model),
                     max_tokens=8192,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -991,6 +1013,7 @@ class GoogleAIProvider(Provider):
 
 class GoogleVertexProvider(Provider):
     provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
     google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
 
@@ -1008,6 +1031,7 @@ class GoogleVertexProvider(Provider):
                     handle=self.get_handle(model),
                     max_tokens=8192,
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -1032,6 +1056,7 @@ class GoogleVertexProvider(Provider):
 
 class AzureProvider(Provider):
     provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
     base_url: str = Field(
         ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@@ -1065,6 +1090,7 @@ class AzureProvider(Provider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 ),
             )
         return configs
@@ -1106,6 +1132,7 @@ class VLLMChatCompletionsProvider(Provider):
 
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
     provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
 
     def list_llm_models(self) -> List[LLMConfig]:
@@ -1125,6 +1152,7 @@ class VLLMChatCompletionsProvider(Provider):
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -1139,6 +1167,7 @@ class VLLMCompletionsProvider(Provider):
 
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
     provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
 
@@ -1159,6 +1188,7 @@ class VLLMCompletionsProvider(Provider):
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
@@ -1174,6 +1204,7 @@ class CohereProvider(OpenAIProvider):
 
 class AnthropicBedrockProvider(Provider):
     provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
+    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    aws_region: str = Field(..., description="AWS region for Bedrock")
 
     def list_llm_models(self):
@@ -1192,6 +1223,7 @@ class AnthropicBedrockProvider(Provider):
                     context_window=self.get_model_context_window(model_arn),
                     handle=self.get_handle(model_arn),
                     provider_name=self.name,
+                    provider_category=self.provider_category,
                 )
             )
         return configs
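Every provider subclass above pins `provider_category` to `ProviderCategory.base`, while the enum itself comes from the `letta/schemas/enums.py` change in the file list (+5 -0), whose contents this diff does not show. A minimal hedged sketch of what that enum plausibly looks like, inferred only from the `(base or byok)` field descriptions and the `ProviderCategory.base` / `ProviderCategory.byok` references in this diff:

```python
# Hypothetical reconstruction of the ProviderCategory enum added in
# letta/schemas/enums.py. Only the `base` and `byok` members are grounded in
# this diff; the (str, Enum) base is an assumption.
from enum import Enum


class ProviderCategory(str, Enum):
    base = "base"  # server-configured providers (environment variables / built-ins)
    byok = "byok"  # "bring your own key" providers stored per organization in the DB
```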
letta/server/rest_api/routers/v1/agents.py
CHANGED
@@ -631,12 +631,17 @@ async def send_message(
     # TODO: This is redundant, remove soon
     agent = server.agent_manager.get_agent_by_id(agent_id, actor)
 
-    if (
+    if all(
+        (
+            settings.use_experimental,
+            not agent.enable_sleeptime,
+            not agent.multi_agent_group,
+            not agent.agent_type == AgentType.sleeptime_agent,
+        )
+    ) and (
+        # LLM Model Check: (1) Anthropic or (2) Google Vertex + Flag
         agent.llm_config.model_endpoint_type == "anthropic"
-
-        and not agent.multi_agent_group
-        and not agent.agent_type == AgentType.sleeptime_agent
-        and settings.use_experimental
+        or (agent.llm_config.model_endpoint_type == "google_vertex" and settings.use_vertex_async_loop_experimental)
     ):
         experimental_agent = LettaAgent(
             agent_id=agent_id,
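The rewritten condition separates agent eligibility (experimental flag on, no sleeptime, no multi-agent group, not a sleeptime agent) from the model check (Anthropic always, Google Vertex only behind its own flag). A standalone restatement of the predicate, as a readability sketch rather than code from the package (the `AgentType` import path is assumed):

```python
from letta.schemas.agent import AgentType  # assumed import path


def uses_experimental_loop(agent, settings) -> bool:
    """Sketch of the gating logic in send_message above."""
    eligible_agent = all(
        (
            settings.use_experimental,
            not agent.enable_sleeptime,
            not agent.multi_agent_group,
            not agent.agent_type == AgentType.sleeptime_agent,
        )
    )
    eligible_model = agent.llm_config.model_endpoint_type == "anthropic" or (
        agent.llm_config.model_endpoint_type == "google_vertex" and settings.use_vertex_async_loop_experimental
    )
    return eligible_agent and eligible_model
```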
letta/server/rest_api/routers/v1/llms.py
CHANGED
@@ -1,8 +1,9 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from fastapi import APIRouter, Depends, Query
+from fastapi import APIRouter, Depends, Header, Query
 
 from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.server.rest_api.utils import get_letta_server
 
@@ -14,11 +15,19 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
 
 @router.get("/", response_model=List[LLMConfig], operation_id="list_models")
 def list_llm_models(
-
+    provider_category: Optional[List[ProviderCategory]] = Query(None),
+    provider_name: Optional[str] = Query(None),
+    provider_type: Optional[ProviderType] = Query(None),
     server: "SyncServer" = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-
-    models = server.list_llm_models(
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    models = server.list_llm_models(
+        provider_category=provider_category,
+        provider_name=provider_name,
+        provider_type=provider_type,
+        actor=actor,
+    )
     # print(models)
     return models
 
@@ -26,8 +35,9 @@ def list_llm_models(
 @router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
 def list_embedding_models(
     server: "SyncServer" = Depends(get_letta_server),
+    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
 ):
-
-    models = server.list_embedding_models()
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    models = server.list_embedding_models(actor=actor)
     # print(models)
     return models
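Callers can now scope and filter the model list per user. A hedged usage sketch with `requests` (the `http://localhost:8283/v1` base URL is an assumption about a default local deployment; the `user_id` header name and query parameters come from the route signature above):

```python
import requests

# List only BYOK Anthropic models visible to a given user. The path combines
# an assumed /v1 API prefix with the router's /models prefix.
resp = requests.get(
    "http://localhost:8283/v1/models/",
    params={"provider_category": "byok", "provider_type": "anthropic"},
    headers={"user_id": "user-123"},  # hypothetical actor id
)
resp.raise_for_status()
for llm_config in resp.json():  # items are serialized LLMConfig objects
    print(llm_config["handle"], llm_config["context_window"])
```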
letta/server/rest_api/routers/v1/providers.py
CHANGED
@@ -1,7 +1,9 @@
 from typing import TYPE_CHECKING, List, Optional
 
-from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
+from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
+from fastapi.responses import JSONResponse
 
+from letta.orm.errors import NoResultFound
 from letta.schemas.enums import ProviderType
 from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
 from letta.server.rest_api.utils import get_letta_server
letta/server/rest_api/routers/v1/sources.py
CHANGED
@@ -106,6 +106,7 @@ def create_source(
     source_create.embedding_config = server.get_embedding_config_from_handle(
         handle=source_create.embedding,
         embedding_chunk_size=source_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
+        actor=actor,
     )
     source = Source(
         name=source_create.name,
letta/server/server.py
CHANGED
@@ -42,7 +42,7 @@ from letta.schemas.block import Block, BlockUpdate, CreateBlock
 from letta.schemas.embedding_config import EmbeddingConfig
 
 # openai schemas
-from letta.schemas.enums import JobStatus, MessageStreamStatus
+from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory, ProviderType
 from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
 from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager
 from letta.schemas.job import Job, JobUpdate
@@ -734,17 +734,17 @@ class SyncServer(Server):
         return self._command(user_id=user_id, agent_id=agent_id, command=command)
 
     @trace_method
-    def get_cached_llm_config(self, **kwargs):
+    def get_cached_llm_config(self, actor: User, **kwargs):
         key = make_key(**kwargs)
         if key not in self._llm_config_cache:
-            self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
+            self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
         return self._llm_config_cache[key]
 
     @trace_method
-    def get_cached_embedding_config(self, **kwargs):
+    def get_cached_embedding_config(self, actor: User, **kwargs):
         key = make_key(**kwargs)
         if key not in self._embedding_config_cache:
-            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(**kwargs)
+            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
         return self._embedding_config_cache[key]
 
     @trace_method
@@ -766,7 +766,7 @@ class SyncServer(Server):
             "enable_reasoner": request.enable_reasoner,
         }
         log_event(name="start get_cached_llm_config", attributes=config_params)
-        request.llm_config = self.get_cached_llm_config(**config_params)
+        request.llm_config = self.get_cached_llm_config(actor=actor, **config_params)
         log_event(name="end get_cached_llm_config", attributes=config_params)
 
         if request.embedding_config is None:
@@ -777,7 +777,7 @@ class SyncServer(Server):
             "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
         }
         log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
-        request.embedding_config = self.get_cached_embedding_config(**embedding_config_params)
+        request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params)
         log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
 
         log_event(name="start create_agent db")
@@ -802,10 +802,10 @@ class SyncServer(Server):
         actor: User,
     ) -> AgentState:
         if request.model is not None:
-            request.llm_config = self.get_llm_config_from_handle(handle=request.model)
+            request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor)
 
         if request.embedding is not None:
-            request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding)
+            request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor)
 
         if request.enable_sleeptime:
             agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -1201,10 +1201,21 @@ class SyncServer(Server):
         except NoResultFound:
             raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
 
-    def list_llm_models(
+    def list_llm_models(
+        self,
+        actor: User,
+        provider_category: Optional[List[ProviderCategory]] = None,
+        provider_name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+    ) -> List[LLMConfig]:
         """List available models"""
         llm_models = []
-        for provider in self.get_enabled_providers(
+        for provider in self.get_enabled_providers(
+            provider_category=provider_category,
+            provider_name=provider_name,
+            provider_type=provider_type,
+            actor=actor,
+        ):
             try:
                 llm_models.extend(provider.list_llm_models())
             except Exception as e:
@@ -1214,26 +1225,49 @@ class SyncServer(Server):
 
         return llm_models
 
-    def list_embedding_models(self) -> List[EmbeddingConfig]:
+    def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
         """List available embedding models"""
         embedding_models = []
-        for provider in self.get_enabled_providers():
+        for provider in self.get_enabled_providers(actor):
             try:
                 embedding_models.extend(provider.list_embedding_models())
             except Exception as e:
                 warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
         return embedding_models
 
-    def get_enabled_providers(
-
-
-
-
-
+    def get_enabled_providers(
+        self,
+        actor: User,
+        provider_category: Optional[List[ProviderCategory]] = None,
+        provider_name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+    ) -> List[Provider]:
+        providers = []
+        if not provider_category or ProviderCategory.base in provider_category:
+            providers_from_env = [p for p in self._enabled_providers]
+            providers.extend(providers_from_env)
+
+        if not provider_category or ProviderCategory.byok in provider_category:
+            providers_from_db = self.provider_manager.list_providers(
+                name=provider_name,
+                provider_type=provider_type,
+                actor=actor,
+            )
+            providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
+            providers.extend(providers_from_db)
+
+        if provider_name is not None:
+            providers = [p for p in providers if p.name == provider_name]
+
+        if provider_type is not None:
+            providers = [p for p in providers if p.provider_type == provider_type]
+
+        return providers
 
     @trace_method
     def get_llm_config_from_handle(
         self,
+        actor: User,
         handle: str,
         context_window_limit: Optional[int] = None,
         max_tokens: Optional[int] = None,
@@ -1242,7 +1276,7 @@ class SyncServer(Server):
     ) -> LLMConfig:
         try:
             provider_name, model_name = handle.split("/", 1)
-            provider = self.get_provider_from_name(provider_name)
+            provider = self.get_provider_from_name(provider_name, actor)
 
             llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
             if not llm_configs:
@@ -1286,11 +1320,11 @@ class SyncServer(Server):
 
     @trace_method
     def get_embedding_config_from_handle(
-        self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
+        self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
    ) -> EmbeddingConfig:
         try:
             provider_name, model_name = handle.split("/", 1)
-            provider = self.get_provider_from_name(provider_name)
+            provider = self.get_provider_from_name(provider_name, actor)
 
             embedding_configs = [config for config in provider.list_embedding_models() if config.handle == handle]
             if not embedding_configs:
@@ -1313,8 +1347,8 @@ class SyncServer(Server):
 
         return embedding_config
 
-    def get_provider_from_name(self, provider_name: str) -> Provider:
-        providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
+    def get_provider_from_name(self, provider_name: str, actor: User) -> Provider:
+        providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name]
         if not providers:
             raise ValueError(f"Provider {provider_name} is not supported")
         elif len(providers) > 1:
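Because BYOK providers are looked up per organization, every handle resolution path now threads a `User` through. A hedged sketch of the resulting call shape (the `letta/letta-free` handle is illustrative, combining an assumed default provider name with the `letta-free` model seen earlier):

```python
def resolve_llm_config(server, actor_id: str, handle: str = "letta/letta-free"):
    """Sketch: resolve an LLMConfig from a "provider_name/model_name" handle.

    Mirrors the calls shown in this diff; `server` is a SyncServer instance.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    # get_llm_config_from_handle splits the handle on "/" and searches the
    # actor-visible providers returned by get_enabled_providers(actor).
    return server.get_llm_config_from_handle(actor=actor, handle=handle)
```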
letta/services/provider_manager.py
CHANGED
@@ -1,9 +1,9 @@
 from typing import List, Optional, Union
 
 from letta.orm.provider import Provider as ProviderModel
-from letta.schemas.enums import ProviderType
+from letta.schemas.enums import ProviderCategory, ProviderType
 from letta.schemas.providers import Provider as PydanticProvider
-from letta.schemas.providers import ProviderUpdate
+from letta.schemas.providers import ProviderCreate, ProviderUpdate
 from letta.schemas.user import User as PydanticUser
 from letta.utils import enforce_types
 
@@ -16,9 +16,12 @@ class ProviderManager:
         self.session_maker = db_context
 
     @enforce_types
-    def create_provider(self,
+    def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
         """Create a new provider if it doesn't already exist."""
         with self.session_maker() as session:
+            provider_create_args = {**request.model_dump(), "provider_category": ProviderCategory.byok}
+            provider = PydanticProvider(**provider_create_args)
+
             if provider.name == provider.provider_type.value:
                 raise ValueError("Provider name must be unique and different from provider type")
 
@@ -65,11 +68,11 @@ class ProviderManager:
     @enforce_types
     def list_providers(
         self,
+        actor: PydanticUser,
         name: Optional[str] = None,
         provider_type: Optional[ProviderType] = None,
         after: Optional[str] = None,
         limit: Optional[int] = 50,
-        actor: PydanticUser = None,
     ) -> List[PydanticProvider]:
         """List all providers with optional pagination."""
         filter_kwargs = {}
@@ -88,11 +91,11 @@ class ProviderManager:
         return [provider.to_pydantic() for provider in providers]
 
     @enforce_types
-    def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]:
-        providers = self.list_providers(name=provider_name)
+    def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
+        providers = self.list_providers(name=provider_name, actor=actor)
         return providers[0].id if providers else None
 
     @enforce_types
-    def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]:
-        providers = self.list_providers(name=provider_name)
+    def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
+        providers = self.list_providers(name=provider_name, actor=actor)
         return providers[0].api_key if providers else None
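`create_provider` now takes a `ProviderCreate` request and stamps the record as `byok` before the uniqueness check runs. A hedged caller sketch (the exact `ProviderCreate` fields are not shown in this diff; `name`, `provider_type`, and `api_key` are assumed from the `Provider` schema above):

```python
from letta.schemas.enums import ProviderType
from letta.schemas.providers import ProviderCreate

# Hypothetical field values; the name must differ from the provider type,
# per the ValueError check in create_provider above.
request = ProviderCreate(
    name="my-anthropic",
    provider_type=ProviderType.anthropic,
    api_key="sk-ant-placeholder",  # placeholder key
)
# provider_manager and actor are assumed to be in scope (e.g. on a SyncServer).
provider = provider_manager.create_provider(request, actor=actor)
assert provider.provider_category.value == "byok"
```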
letta/settings.py
CHANGED
@@ -195,6 +195,8 @@ class Settings(BaseSettings):
 
     # experimental toggle
     use_experimental: bool = False
+    use_vertex_structured_outputs_experimental: bool = False
+    use_vertex_async_loop_experimental: bool = False
 
     # LLM provider client settings
     httpx_max_retries: int = 5
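Both new flags default to `False`. As fields on a pydantic `BaseSettings` subclass they should be controllable via environment variables; a minimal sketch, assuming the settings object reads a `LETTA_` prefix (the prefix is not shown in this diff):

```python
import os

# Assumed environment variable names, derived from the field names plus a
# hypothetical LETTA_ prefix; set these before letta.settings is first imported,
# since the settings object is instantiated at import time.
os.environ["LETTA_USE_VERTEX_ASYNC_LOOP_EXPERIMENTAL"] = "true"
os.environ["LETTA_USE_VERTEX_STRUCTURED_OUTPUTS_EXPERIMENTAL"] = "true"

from letta.settings import settings

print(settings.use_vertex_async_loop_experimental)  # True if the env var was read
```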