letta-nightly 0.13.0.dev20251030104218__py3-none-any.whl → 0.13.1.dev20251031234110__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of letta-nightly might be problematic.
- letta/__init__.py +1 -1
- letta/adapters/simple_llm_stream_adapter.py +1 -0
- letta/agents/letta_agent_v2.py +8 -0
- letta/agents/letta_agent_v3.py +120 -27
- letta/agents/temporal/activities/__init__.py +25 -0
- letta/agents/temporal/activities/create_messages.py +26 -0
- letta/agents/temporal/activities/create_step.py +57 -0
- letta/agents/temporal/activities/example_activity.py +9 -0
- letta/agents/temporal/activities/execute_tool.py +130 -0
- letta/agents/temporal/activities/llm_request.py +114 -0
- letta/agents/temporal/activities/prepare_messages.py +27 -0
- letta/agents/temporal/activities/refresh_context.py +160 -0
- letta/agents/temporal/activities/summarize_conversation_history.py +77 -0
- letta/agents/temporal/activities/update_message_ids.py +25 -0
- letta/agents/temporal/activities/update_run.py +43 -0
- letta/agents/temporal/constants.py +59 -0
- letta/agents/temporal/temporal_agent_workflow.py +704 -0
- letta/agents/temporal/types.py +275 -0
- letta/constants.py +8 -0
- letta/errors.py +4 -0
- letta/functions/function_sets/base.py +0 -11
- letta/groups/helpers.py +7 -1
- letta/groups/sleeptime_multi_agent_v4.py +4 -3
- letta/interfaces/anthropic_streaming_interface.py +0 -1
- letta/interfaces/openai_streaming_interface.py +103 -100
- letta/llm_api/anthropic_client.py +57 -12
- letta/llm_api/bedrock_client.py +1 -0
- letta/llm_api/deepseek_client.py +3 -2
- letta/llm_api/google_vertex_client.py +1 -0
- letta/llm_api/groq_client.py +1 -0
- letta/llm_api/llm_client_base.py +15 -1
- letta/llm_api/openai.py +2 -2
- letta/llm_api/openai_client.py +17 -3
- letta/llm_api/xai_client.py +1 -0
- letta/orm/organization.py +4 -0
- letta/orm/sqlalchemy_base.py +7 -0
- letta/otel/tracing.py +131 -4
- letta/schemas/agent_file.py +10 -10
- letta/schemas/block.py +22 -3
- letta/schemas/enums.py +21 -0
- letta/schemas/environment_variables.py +3 -2
- letta/schemas/group.py +3 -3
- letta/schemas/letta_response.py +36 -4
- letta/schemas/llm_batch_job.py +3 -3
- letta/schemas/llm_config.py +27 -3
- letta/schemas/mcp.py +3 -2
- letta/schemas/mcp_server.py +3 -2
- letta/schemas/message.py +167 -49
- letta/schemas/organization.py +2 -1
- letta/schemas/passage.py +2 -1
- letta/schemas/provider_trace.py +2 -1
- letta/schemas/providers/openrouter.py +1 -2
- letta/schemas/run_metrics.py +2 -1
- letta/schemas/sandbox_config.py +3 -1
- letta/schemas/step_metrics.py +2 -1
- letta/schemas/tool_rule.py +2 -2
- letta/schemas/user.py +2 -1
- letta/server/rest_api/app.py +5 -1
- letta/server/rest_api/routers/v1/__init__.py +4 -0
- letta/server/rest_api/routers/v1/agents.py +71 -9
- letta/server/rest_api/routers/v1/blocks.py +7 -7
- letta/server/rest_api/routers/v1/groups.py +40 -0
- letta/server/rest_api/routers/v1/identities.py +2 -2
- letta/server/rest_api/routers/v1/internal_agents.py +31 -0
- letta/server/rest_api/routers/v1/internal_blocks.py +177 -0
- letta/server/rest_api/routers/v1/internal_runs.py +25 -1
- letta/server/rest_api/routers/v1/runs.py +2 -22
- letta/server/rest_api/routers/v1/tools.py +10 -0
- letta/server/server.py +5 -2
- letta/services/agent_manager.py +4 -4
- letta/services/archive_manager.py +16 -0
- letta/services/group_manager.py +44 -0
- letta/services/helpers/run_manager_helper.py +2 -2
- letta/services/lettuce/lettuce_client.py +148 -0
- letta/services/mcp/base_client.py +9 -3
- letta/services/run_manager.py +148 -37
- letta/services/source_manager.py +91 -3
- letta/services/step_manager.py +2 -3
- letta/services/streaming_service.py +52 -13
- letta/services/summarizer/summarizer.py +28 -2
- letta/services/tool_executor/builtin_tool_executor.py +1 -1
- letta/services/tool_executor/core_tool_executor.py +2 -117
- letta/services/tool_schema_generator.py +2 -2
- letta/validators.py +21 -0
- {letta_nightly-0.13.0.dev20251030104218.dist-info → letta_nightly-0.13.1.dev20251031234110.dist-info}/METADATA +1 -1
- {letta_nightly-0.13.0.dev20251030104218.dist-info → letta_nightly-0.13.1.dev20251031234110.dist-info}/RECORD +89 -84
- letta/agent.py +0 -1758
- letta/cli/cli_load.py +0 -16
- letta/client/__init__.py +0 -0
- letta/client/streaming.py +0 -95
- letta/client/utils.py +0 -78
- letta/functions/async_composio_toolset.py +0 -109
- letta/functions/composio_helpers.py +0 -96
- letta/helpers/composio_helpers.py +0 -38
- letta/orm/job_messages.py +0 -33
- letta/schemas/providers.py +0 -1617
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +0 -132
- letta/services/tool_executor/composio_tool_executor.py +0 -57
- {letta_nightly-0.13.0.dev20251030104218.dist-info → letta_nightly-0.13.1.dev20251031234110.dist-info}/WHEEL +0 -0
- {letta_nightly-0.13.0.dev20251030104218.dist-info → letta_nightly-0.13.1.dev20251031234110.dist-info}/entry_points.txt +0 -0
- {letta_nightly-0.13.0.dev20251030104218.dist-info → letta_nightly-0.13.1.dev20251031234110.dist-info}/licenses/LICENSE +0 -0
letta/schemas/providers.py
DELETED
@@ -1,1617 +0,0 @@
import warnings
from datetime import datetime
from typing import List, Literal, Optional

import aiohttp
import requests
from pydantic import BaseModel, Field, model_validator

from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE, LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.letta_base import LettaBase
from letta.schemas.llm_config import LLMConfig
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
from letta.settings import model_settings


class ProviderBase(LettaBase):
    __id_prefix__ = "provider"


class Provider(ProviderBase):
    id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
    name: str = Field(..., description="The name of the provider")
    provider_type: ProviderType = Field(..., description="The type of the provider")
    provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)")
    api_key: Optional[str] = Field(None, description="API key or secret key used for requests to the provider.")
    base_url: Optional[str] = Field(None, description="Base URL for the provider.")
    access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.")
    region: Optional[str] = Field(None, description="Region used for requests to the provider.")
    organization_id: Optional[str] = Field(None, description="The organization id of the user")
    updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.")

    @model_validator(mode="after")
    def default_base_url(self):
        if self.provider_type == ProviderType.openai and self.base_url is None:
            self.base_url = model_settings.openai_api_base
        return self

    def resolve_identifier(self):
        if not self.id:
            self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)

    def check_api_key(self):
        """Check if the API key is valid for the provider"""
        raise NotImplementedError

    def list_llm_models(self) -> List[LLMConfig]:
        return []

    async def list_llm_models_async(self) -> List[LLMConfig]:
        return []

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        return []

    async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
        return self.list_embedding_models()

    def get_model_context_window(self, model_name: str) -> Optional[int]:
        raise NotImplementedError

    async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
        raise NotImplementedError

    def provider_tag(self) -> str:
        """String representation of the provider for display purposes"""
        raise NotImplementedError

    def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str:
        """
        Get the handle for a model, with support for custom overrides.

        Args:
            model_name (str): The name of the model.
            is_embedding (bool, optional): Whether the handle is for an embedding model. Defaults to False.

        Returns:
            str: The handle for the model.
        """
        base_name = base_name if base_name else self.name

        overrides = EMBEDDING_HANDLE_OVERRIDES if is_embedding else LLM_HANDLE_OVERRIDES
        if base_name in overrides and model_name in overrides[base_name]:
            model_name = overrides[base_name][model_name]

        return f"{base_name}/{model_name}"

    def cast_to_subtype(self):
        match self.provider_type:
            case ProviderType.letta:
                return LettaProvider(**self.model_dump(exclude_none=True))
            case ProviderType.openai:
                return OpenAIProvider(**self.model_dump(exclude_none=True))
            case ProviderType.anthropic:
                return AnthropicProvider(**self.model_dump(exclude_none=True))
            case ProviderType.bedrock:
                return BedrockProvider(**self.model_dump(exclude_none=True))
            case ProviderType.ollama:
                return OllamaProvider(**self.model_dump(exclude_none=True))
            case ProviderType.google_ai:
                return GoogleAIProvider(**self.model_dump(exclude_none=True))
            case ProviderType.google_vertex:
                return GoogleVertexProvider(**self.model_dump(exclude_none=True))
            case ProviderType.azure:
                return AzureProvider(**self.model_dump(exclude_none=True))
            case ProviderType.groq:
                return GroqProvider(**self.model_dump(exclude_none=True))
            case ProviderType.together:
                return TogetherProvider(**self.model_dump(exclude_none=True))
            case ProviderType.vllm_chat_completions:
                return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True))
            case ProviderType.vllm_completions:
                return VLLMCompletionsProvider(**self.model_dump(exclude_none=True))
            case ProviderType.xai:
                return XAIProvider(**self.model_dump(exclude_none=True))
            case _:
                raise ValueError(f"Unknown provider type: {self.provider_type}")


class ProviderCreate(ProviderBase):
    name: str = Field(..., description="The name of the provider.")
    provider_type: ProviderType = Field(..., description="The type of the provider.")
    api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
    access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.")
    region: Optional[str] = Field(None, description="Region used for requests to the provider.")


class ProviderUpdate(ProviderBase):
    api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
    access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.")
    region: Optional[str] = Field(None, description="Region used for requests to the provider.")


class ProviderCheck(BaseModel):
    provider_type: ProviderType = Field(..., description="The type of the provider.")
    api_key: str = Field(..., description="API key or secret key used for requests to the provider.")
    access_key: Optional[str] = Field(None, description="Access key used for requests to the provider.")
    region: Optional[str] = Field(None, description="Region used for requests to the provider.")


class LettaProvider(Provider):
    provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")

    def list_llm_models(self) -> List[LLMConfig]:
        return [
            LLMConfig(
                model="letta-free",  # NOTE: renamed
                model_endpoint_type="openai",
                model_endpoint=LETTA_MODEL_ENDPOINT,
                context_window=30000,
                handle=self.get_handle("letta-free"),
                provider_name=self.name,
                provider_category=self.provider_category,
            )
        ]

    async def list_llm_models_async(self) -> List[LLMConfig]:
        return [
            LLMConfig(
                model="letta-free",  # NOTE: renamed
                model_endpoint_type="openai",
                model_endpoint=LETTA_MODEL_ENDPOINT,
                context_window=30000,
                handle=self.get_handle("letta-free"),
                provider_name=self.name,
                provider_category=self.provider_category,
            )
        ]

    def list_embedding_models(self):
        return [
            EmbeddingConfig(
                embedding_model="letta-free",  # NOTE: renamed
                embedding_endpoint_type="hugging-face",
                embedding_endpoint="https://embeddings.memgpt.ai",
                embedding_dim=1024,
                embedding_chunk_size=300,
                handle=self.get_handle("letta-free", is_embedding=True),
                batch_size=32,
            )
        ]


class OpenAIProvider(Provider):
    provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str = Field(..., description="API key for the OpenAI API.")
    base_url: str = Field(..., description="Base URL for the OpenAI API.")

    def check_api_key(self):
        from letta.llm_api.openai import openai_check_valid_api_key

        openai_check_valid_api_key(self.base_url, self.api_key)

    def _get_models(self) -> List[dict]:
        from letta.llm_api.openai import openai_get_model_list

        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
        # See: https://openrouter.ai/docs/requests
        extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None

        # Similar to Nebius
        extra_params = {"verbose": True} if "nebius.com" in self.base_url else None

        response = openai_get_model_list(
            self.base_url,
            api_key=self.api_key,
            extra_params=extra_params,
            # fix_url=True,  # NOTE: make sure together ends with /v1
        )

        if "data" in response:
            data = response["data"]
        else:
            # TogetherAI's response is missing the 'data' field
            data = response

        return data

    async def _get_models_async(self) -> List[dict]:
        from letta.llm_api.openai import openai_get_model_list_async

        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
        # See: https://openrouter.ai/docs/requests
        extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None

        # Similar to Nebius
        extra_params = {"verbose": True} if "nebius.com" in self.base_url else None

        response = await openai_get_model_list_async(
            self.base_url,
            api_key=self.api_key,
            extra_params=extra_params,
            # fix_url=True,  # NOTE: make sure together ends with /v1
        )

        if "data" in response:
            data = response["data"]
        else:
            # TogetherAI's response is missing the 'data' field
            data = response

        return data

    def list_llm_models(self) -> List[LLMConfig]:
        data = self._get_models()
        return self._list_llm_models(data)

    async def list_llm_models_async(self) -> List[LLMConfig]:
        data = await self._get_models_async()
        return self._list_llm_models(data)

    def _list_llm_models(self, data) -> List[LLMConfig]:
        configs = []
        for model in data:
            assert "id" in model, f"OpenAI model missing 'id' field: {model}"
            model_name = model["id"]

            if "context_length" in model:
                # Context length is returned in OpenRouter as "context_length"
                context_window_size = model["context_length"]
            else:
                context_window_size = self.get_model_context_window_size(model_name)

            if not context_window_size:
                continue

            # TogetherAI includes the type, which we can use to filter out embedding models
            if "api.together.ai" in self.base_url or "api.together.xyz" in self.base_url:
                if "type" in model and model["type"] not in ["chat", "language"]:
                    continue

                # for TogetherAI, we need to skip the models that don't support JSON mode / function calling
                # requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: {
                # "error": {
                # "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling",
                # "type": "invalid_request_error",
                # "param": null,
                # "code": "constraints_model"
                # }
                # }
                if "config" not in model:
                    continue

            if "nebius.com" in self.base_url:
                # Nebius includes the type, which we can use to filter for text models
                try:
                    model_type = model["architecture"]["modality"]
                    if model_type not in ["text->text", "text+image->text"]:
                        # print(f"Skipping model w/ modality {model_type}:\n{model}")
                        continue
                except KeyError:
                    print(f"Couldn't access architecture type field, skipping model:\n{model}")
                    continue

            # for openai, filter models
            if self.base_url == "https://api.openai.com/v1":
                allowed_types = ["gpt-4", "o1", "o3", "o4"]
                # NOTE: o1-mini and o1-preview do not support tool calling
                # NOTE: o1-mini does not support system messages
                # NOTE: o1-pro is only available in Responses API
                disallowed_types = ["transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"]
                skip = True
                for model_type in allowed_types:
                    if model_name.startswith(model_type):
                        skip = False
                        break
                for keyword in disallowed_types:
                    if keyword in model_name:
                        skip = True
                        break
                # ignore this model
                if skip:
                    continue

            # set the handle to openai-proxy if the base URL isn't OpenAI
            if self.base_url != "https://api.openai.com/v1":
                handle = self.get_handle(model_name, base_name="openai-proxy")
            else:
                handle = self.get_handle(model_name)

            llm_config = LLMConfig(
                model=model_name,
                model_endpoint_type="openai",
                model_endpoint=self.base_url,
                context_window=context_window_size,
                handle=handle,
                provider_name=self.name,
                provider_category=self.provider_category,
            )

            # gpt-4o-mini has started to regress with pretty bad emoji spam loops
            # this is to counteract that
            if "gpt-4o-mini" in model_name:
                llm_config.frequency_penalty = 1.0
            if "gpt-4.1-mini" in model_name:
                llm_config.frequency_penalty = 1.0

            configs.append(llm_config)

        # for OpenAI, sort in reverse order
        if self.base_url == "https://api.openai.com/v1":
            # alphnumeric sort
            configs.sort(key=lambda x: x.model, reverse=True)

        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        if self.base_url == "https://api.openai.com/v1":
            # TODO: actually automatically list models for OpenAI
            return [
                EmbeddingConfig(
                    embedding_model="text-embedding-ada-002",
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    embedding_dim=1536,
                    embedding_chunk_size=300,
                    handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
                    batch_size=1024,
                ),
                EmbeddingConfig(
                    embedding_model="text-embedding-3-small",
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    embedding_dim=2000,
                    embedding_chunk_size=300,
                    handle=self.get_handle("text-embedding-3-small", is_embedding=True),
                    batch_size=1024,
                ),
                EmbeddingConfig(
                    embedding_model="text-embedding-3-large",
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    embedding_dim=2000,
                    embedding_chunk_size=300,
                    handle=self.get_handle("text-embedding-3-large", is_embedding=True),
                    batch_size=1024,
                ),
            ]

        else:
            # Actually attempt to list
            data = self._get_models()
            return self._list_embedding_models(data)

    async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
        if self.base_url == "https://api.openai.com/v1":
            # TODO: actually automatically list models for OpenAI
            return [
                EmbeddingConfig(
                    embedding_model="text-embedding-ada-002",
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    embedding_dim=1536,
                    embedding_chunk_size=300,
                    handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
                    batch_size=1024,
                ),
                EmbeddingConfig(
                    embedding_model="text-embedding-3-small",
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    embedding_dim=2000,
                    embedding_chunk_size=300,
                    handle=self.get_handle("text-embedding-3-small", is_embedding=True),
                    batch_size=1024,
                ),
                EmbeddingConfig(
                    embedding_model="text-embedding-3-large",
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    embedding_dim=2000,
                    embedding_chunk_size=300,
                    handle=self.get_handle("text-embedding-3-large", is_embedding=True),
                    batch_size=1024,
                ),
            ]

        else:
            # Actually attempt to list
            data = await self._get_models_async()
            return self._list_embedding_models(data)

    def _list_embedding_models(self, data) -> List[EmbeddingConfig]:
        configs = []
        for model in data:
            assert "id" in model, f"Model missing 'id' field: {model}"
            model_name = model["id"]

            if "context_length" in model:
                # Context length is returned in Nebius as "context_length"
                context_window_size = model["context_length"]
            else:
                context_window_size = self.get_model_context_window_size(model_name)

            # We need the context length for embeddings too
            if not context_window_size:
                continue

            if "nebius.com" in self.base_url:
                # Nebius includes the type, which we can use to filter for embedidng models
                try:
                    model_type = model["architecture"]["modality"]
                    if model_type not in ["text->embedding"]:
                        # print(f"Skipping model w/ modality {model_type}:\n{model}")
                        continue
                except KeyError:
                    print(f"Couldn't access architecture type field, skipping model:\n{model}")
                    continue

            elif "together.ai" in self.base_url or "together.xyz" in self.base_url:
                # TogetherAI includes the type, which we can use to filter for embedding models
                if "type" in model and model["type"] not in ["embedding"]:
                    # print(f"Skipping model w/ modality {model_type}:\n{model}")
                    continue

            else:
                # For other providers we should skip by default, since we don't want to assume embeddings are supported
                continue

            configs.append(
                EmbeddingConfig(
                    embedding_model=model_name,
                    embedding_endpoint_type=self.provider_type,
                    embedding_endpoint=self.base_url,
                    embedding_dim=context_window_size,
                    embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
                    handle=self.get_handle(model, is_embedding=True),
                )
            )

        return configs

    def get_model_context_window_size(self, model_name: str):
        if model_name in LLM_MAX_TOKENS:
            return LLM_MAX_TOKENS[model_name]
        else:
            return LLM_MAX_TOKENS["DEFAULT"]


class DeepSeekProvider(OpenAIProvider):
    """
    DeepSeek ChatCompletions API is similar to OpenAI's reasoning API,
    but with slight differences:
    * For example, DeepSeek's API requires perfect interleaving of user/assistant
    * It also does not support native function calling
    """

    provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
    api_key: str = Field(..., description="API key for the DeepSeek API.")

    def get_model_context_window_size(self, model_name: str) -> Optional[int]:
        # DeepSeek doesn't return context window in the model listing,
        # so these are hardcoded from their website
        if model_name == "deepseek-reasoner":
            return 64000
        elif model_name == "deepseek-chat":
            return 64000
        else:
            return None

    def list_llm_models(self) -> List[LLMConfig]:
        from letta.llm_api.openai import openai_get_model_list

        response = openai_get_model_list(self.base_url, api_key=self.api_key)

        if "data" in response:
            data = response["data"]
        else:
            data = response

        configs = []
        for model in data:
            assert "id" in model, f"DeepSeek model missing 'id' field: {model}"
            model_name = model["id"]

            # In case DeepSeek starts supporting it in the future:
            if "context_length" in model:
                # Context length is returned in OpenRouter as "context_length"
                context_window_size = model["context_length"]
            else:
                context_window_size = self.get_model_context_window_size(model_name)

            if not context_window_size:
                warnings.warn(f"Couldn't find context window size for model {model_name}")
                continue

            # Not used for deepseek-reasoner, but otherwise is true
            put_inner_thoughts_in_kwargs = False if model_name == "deepseek-reasoner" else True

            configs.append(
                LLMConfig(
                    model=model_name,
                    model_endpoint_type="deepseek",
                    model_endpoint=self.base_url,
                    context_window=context_window_size,
                    handle=self.get_handle(model_name),
                    put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )

        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        # No embeddings supported
        return []


class LMStudioOpenAIProvider(OpenAIProvider):
    provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
    api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")

    def list_llm_models(self) -> List[LLMConfig]:
        from letta.llm_api.openai import openai_get_model_list

        # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
        MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0"
        response = openai_get_model_list(MODEL_ENDPOINT_URL)

        """
        Example response:

        {
        "object": "list",
        "data": [
        {
        "id": "qwen2-vl-7b-instruct",
        "object": "model",
        "type": "vlm",
        "publisher": "mlx-community",
        "arch": "qwen2_vl",
        "compatibility_type": "mlx",
        "quantization": "4bit",
        "state": "not-loaded",
        "max_context_length": 32768
        },
        ...
        """
        if "data" not in response:
            warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
            return []

        configs = []
        for model in response["data"]:
            assert "id" in model, f"Model missing 'id' field: {model}"
            model_name = model["id"]

            if "type" not in model:
                warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
                continue
            elif model["type"] not in ["vlm", "llm"]:
                continue

            if "max_context_length" in model:
                context_window_size = model["max_context_length"]
            else:
                warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}")
                continue

            configs.append(
                LLMConfig(
                    model=model_name,
                    model_endpoint_type="openai",
                    model_endpoint=self.base_url,
                    context_window=context_window_size,
                    handle=self.get_handle(model_name),
                )
            )

        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        from letta.llm_api.openai import openai_get_model_list

        # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
        MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0"
        response = openai_get_model_list(MODEL_ENDPOINT_URL)

        """
        Example response:
        {
        "object": "list",
        "data": [
        {
        "id": "text-embedding-nomic-embed-text-v1.5",
        "object": "model",
        "type": "embeddings",
        "publisher": "nomic-ai",
        "arch": "nomic-bert",
        "compatibility_type": "gguf",
        "quantization": "Q4_0",
        "state": "not-loaded",
        "max_context_length": 2048
        }
        ...
        """
        if "data" not in response:
            warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
            return []

        configs = []
        for model in response["data"]:
            assert "id" in model, f"Model missing 'id' field: {model}"
            model_name = model["id"]

            if "type" not in model:
                warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
                continue
            elif model["type"] not in ["embeddings"]:
                continue

            if "max_context_length" in model:
                context_window_size = model["max_context_length"]
            else:
                warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}")
                continue

            configs.append(
                EmbeddingConfig(
                    embedding_model=model_name,
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    embedding_dim=context_window_size,
                    embedding_chunk_size=300,  # NOTE: max is 2048
                    handle=self.get_handle(model_name),
                ),
            )

        return configs


class XAIProvider(OpenAIProvider):
    """https://docs.x.ai/docs/api-reference"""

    provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str = Field(..., description="API key for the xAI/Grok API.")
    base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")

    def get_model_context_window_size(self, model_name: str) -> Optional[int]:
        # xAI doesn't return context window in the model listing,
        # so these are hardcoded from their website
        if model_name == "grok-2-1212":
            return 131072
        # NOTE: disabling the minis for now since they return weird MM parts
        # elif model_name == "grok-3-mini-fast-beta":
        # return 131072
        # elif model_name == "grok-3-mini-beta":
        # return 131072
        elif model_name == "grok-3-fast-beta":
            return 131072
        elif model_name == "grok-3-beta":
            return 131072
        else:
            return None

    def list_llm_models(self) -> List[LLMConfig]:
        from letta.llm_api.openai import openai_get_model_list

        response = openai_get_model_list(self.base_url, api_key=self.api_key)

        if "data" in response:
            data = response["data"]
        else:
            data = response

        configs = []
        for model in data:
            assert "id" in model, f"xAI/Grok model missing 'id' field: {model}"
            model_name = model["id"]

            # In case xAI starts supporting it in the future:
            if "context_length" in model:
                context_window_size = model["context_length"]
            else:
                context_window_size = self.get_model_context_window_size(model_name)

            if not context_window_size:
                warnings.warn(f"Couldn't find context window size for model {model_name}")
                continue

            configs.append(
                LLMConfig(
                    model=model_name,
                    model_endpoint_type="xai",
                    model_endpoint=self.base_url,
                    context_window=context_window_size,
                    handle=self.get_handle(model_name),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )

        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        # No embeddings supported
        return []


class AnthropicProvider(Provider):
    provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str = Field(..., description="API key for the Anthropic API.")
    base_url: str = "https://api.anthropic.com/v1"

    def check_api_key(self):
        from letta.llm_api.anthropic import anthropic_check_valid_api_key

        anthropic_check_valid_api_key(self.api_key)

    def list_llm_models(self) -> List[LLMConfig]:
        from letta.llm_api.anthropic import anthropic_get_model_list

        models = anthropic_get_model_list(api_key=self.api_key)
        return self._list_llm_models(models)

    async def list_llm_models_async(self) -> List[LLMConfig]:
        from letta.llm_api.anthropic import anthropic_get_model_list_async

        models = await anthropic_get_model_list_async(api_key=self.api_key)
        return self._list_llm_models(models)

    def _list_llm_models(self, models) -> List[LLMConfig]:
        from letta.llm_api.anthropic import MODEL_LIST

        configs = []
        for model in models:
            if model["type"] != "model":
                continue

            if "id" not in model:
                continue

            # Don't support 2.0 and 2.1
            if model["id"].startswith("claude-2"):
                continue

            # Anthropic doesn't return the context window in their API
            if "context_window" not in model:
                # Remap list to name: context_window
                model_library = {m["name"]: m["context_window"] for m in MODEL_LIST}
                # Attempt to look it up in a hardcoded list
                if model["id"] in model_library:
                    model["context_window"] = model_library[model["id"]]
                else:
                    # On fallback, we can set 200k (generally safe), but we should warn the user
                    warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000")
                    model["context_window"] = 200000

            max_tokens = 8192
            if "claude-3-opus" in model["id"]:
                max_tokens = 4096
            if "claude-3-haiku" in model["id"]:
                max_tokens = 4096
            # TODO: set for 3-7 extended thinking mode

            # We set this to false by default, because Anthropic can
            # natively support <thinking> tags inside of content fields
            # However, putting COT inside of tool calls can make it more
            # reliable for tool calling (no chance of a non-tool call step)
            # Since tool_choice_type 'any' doesn't work with in-content COT
            # NOTE For Haiku, it can be flaky if we don't enable this by default
            # inner_thoughts_in_kwargs = True if "haiku" in model["id"] else False
            inner_thoughts_in_kwargs = True  # we no longer support thinking tags

            configs.append(
                LLMConfig(
                    model=model["id"],
                    model_endpoint_type="anthropic",
                    model_endpoint=self.base_url,
                    context_window=model["context_window"],
                    handle=self.get_handle(model["id"]),
                    put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                    max_tokens=max_tokens,
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )
        return configs


class MistralProvider(Provider):
    provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str = Field(..., description="API key for the Mistral API.")
    base_url: str = "https://api.mistral.ai/v1"

    def list_llm_models(self) -> List[LLMConfig]:
        from letta.llm_api.mistral import mistral_get_model_list

        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
        # See: https://openrouter.ai/docs/requests
        response = mistral_get_model_list(self.base_url, api_key=self.api_key)

        assert "data" in response, f"Mistral model query response missing 'data' field: {response}"

        configs = []
        for model in response["data"]:
            # If model has chat completions and function calling enabled
            if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]:
                configs.append(
                    LLMConfig(
                        model=model["id"],
                        model_endpoint_type="openai",
                        model_endpoint=self.base_url,
                        context_window=model["max_context_length"],
                        handle=self.get_handle(model["id"]),
                        provider_name=self.name,
                        provider_category=self.provider_category,
                    )
                )

        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        # Not supported for mistral
        return []

    def get_model_context_window(self, model_name: str) -> Optional[int]:
        # Redoing this is fine because it's a pretty lightweight call
        models = self.list_llm_models()

        for m in models:
            if model_name in m["id"]:
                return int(m["max_context_length"])

        return None


class OllamaProvider(OpenAIProvider):
    """Ollama provider that uses the native /api/generate endpoint

    See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
    """

    provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    base_url: str = Field(..., description="Base URL for the Ollama API.")
    api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
    default_prompt_formatter: str = Field(
        ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
    )

    async def list_llm_models_async(self) -> List[LLMConfig]:
        """Async version of list_llm_models below"""
        endpoint = f"{self.base_url}/api/tags"
        async with aiohttp.ClientSession() as session:
            async with session.get(endpoint) as response:
                if response.status != 200:
                    raise Exception(f"Failed to list Ollama models: {response.text}")
                response_json = await response.json()

        configs = []
        for model in response_json["models"]:
            context_window = self.get_model_context_window(model["name"])
            if context_window is None:
                print(f"Ollama model {model['name']} has no context window")
                continue
            configs.append(
                LLMConfig(
                    model=model["name"],
                    model_endpoint_type="ollama",
                    model_endpoint=self.base_url,
                    model_wrapper=self.default_prompt_formatter,
                    context_window=context_window,
                    handle=self.get_handle(model["name"]),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )
        return configs

    def list_llm_models(self) -> List[LLMConfig]:
        # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
        response = requests.get(f"{self.base_url}/api/tags")
        if response.status_code != 200:
            raise Exception(f"Failed to list Ollama models: {response.text}")
        response_json = response.json()

        configs = []
        for model in response_json["models"]:
            context_window = self.get_model_context_window(model["name"])
            if context_window is None:
                print(f"Ollama model {model['name']} has no context window")
                continue
            configs.append(
                LLMConfig(
                    model=model["name"],
                    model_endpoint_type="ollama",
                    model_endpoint=self.base_url,
                    model_wrapper=self.default_prompt_formatter,
                    context_window=context_window,
                    handle=self.get_handle(model["name"]),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )
        return configs

    def get_model_context_window(self, model_name: str) -> Optional[int]:
        response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
        response_json = response.json()

        ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
        # possible_keys = [
        # # OPT
        # "max_position_embeddings",
        # # GPT-2
        # "n_positions",
        # # MPT
        # "max_seq_len",
        # # ChatGLM2
        # "seq_length",
        # # Command-R
        # "model_max_length",
        # # Others
        # "max_sequence_length",
        # "max_seq_length",
        # "seq_len",
        # ]
        # max_position_embeddings
        # parse model cards: nous, dolphon, llama
        if "model_info" not in response_json:
            if "error" in response_json:
                print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
            return None
        for key, value in response_json["model_info"].items():
            if "context_length" in key:
                return value
        return None

    def _get_model_embedding_dim(self, model_name: str):
        response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
        response_json = response.json()
        return self._get_model_embedding_dim_impl(response_json, model_name)

    async def _get_model_embedding_dim_async(self, model_name: str):
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) as response:
                response_json = await response.json()
                return self._get_model_embedding_dim_impl(response_json, model_name)

    @staticmethod
    def _get_model_embedding_dim_impl(response_json: dict, model_name: str):
        if "model_info" not in response_json:
            if "error" in response_json:
                print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
            return None
        for key, value in response_json["model_info"].items():
            if "embedding_length" in key:
                return value
        return None

    async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
        """Async version of list_embedding_models below"""
        endpoint = f"{self.base_url}/api/tags"
        async with aiohttp.ClientSession() as session:
            async with session.get(endpoint) as response:
                if response.status != 200:
                    raise Exception(f"Failed to list Ollama models: {response.text}")
                response_json = await response.json()

        configs = []
        for model in response_json["models"]:
            embedding_dim = await self._get_model_embedding_dim_async(model["name"])
            if not embedding_dim:
                print(f"Ollama model {model['name']} has no embedding dimension")
                continue
            configs.append(
                EmbeddingConfig(
                    embedding_model=model["name"],
                    embedding_endpoint_type="ollama",
                    embedding_endpoint=self.base_url,
                    embedding_dim=embedding_dim,
                    embedding_chunk_size=300,
                    handle=self.get_handle(model["name"], is_embedding=True),
                )
            )
        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
        response = requests.get(f"{self.base_url}/api/tags")
        if response.status_code != 200:
            raise Exception(f"Failed to list Ollama models: {response.text}")
        response_json = response.json()

        configs = []
        for model in response_json["models"]:
            embedding_dim = self._get_model_embedding_dim(model["name"])
            if not embedding_dim:
                print(f"Ollama model {model['name']} has no embedding dimension")
                continue
            configs.append(
                EmbeddingConfig(
                    embedding_model=model["name"],
                    embedding_endpoint_type="ollama",
                    embedding_endpoint=self.base_url,
                    embedding_dim=embedding_dim,
                    embedding_chunk_size=300,
                    handle=self.get_handle(model["name"], is_embedding=True),
                )
            )
        return configs


class GroqProvider(OpenAIProvider):
    provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    base_url: str = "https://api.groq.com/openai/v1"
    api_key: str = Field(..., description="API key for the Groq API.")

    def list_llm_models(self) -> List[LLMConfig]:
        from letta.llm_api.openai import openai_get_model_list

        response = openai_get_model_list(self.base_url, api_key=self.api_key)
        configs = []
        for model in response["data"]:
            if "context_window" not in model:
                continue
            configs.append(
                LLMConfig(
                    model=model["id"],
                    model_endpoint_type="groq",
                    model_endpoint=self.base_url,
                    context_window=model["context_window"],
                    handle=self.get_handle(model["id"]),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )
        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        return []


class TogetherProvider(OpenAIProvider):
|
|
1091
|
-
"""TogetherAI provider that uses the /completions API
|
|
1092
|
-
|
|
1093
|
-
TogetherAI can also be used via the /chat/completions API
|
|
1094
|
-
by settings OPENAI_API_KEY and OPENAI_API_BASE to the TogetherAI API key
|
|
1095
|
-
and API URL, however /completions is preferred because their /chat/completions
|
|
1096
|
-
function calling support is limited.
|
|
1097
|
-
"""
|
|
1098
|
-
|
|
1099
|
-
provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
|
|
1100
|
-
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
|
|
1101
|
-
base_url: str = "https://api.together.ai/v1"
|
|
1102
|
-
api_key: str = Field(..., description="API key for the TogetherAI API.")
|
|
1103
|
-
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
|
|
1104
|
-
|
|
1105
|
-
def list_llm_models(self) -> List[LLMConfig]:
|
|
1106
|
-
from letta.llm_api.openai import openai_get_model_list
|
|
1107
|
-
|
|
1108
|
-
models = openai_get_model_list(self.base_url, api_key=self.api_key)
|
|
1109
|
-
return self._list_llm_models(models)
|
|
1110
|
-
|
|
1111
|
-
async def list_llm_models_async(self) -> List[LLMConfig]:
|
|
1112
|
-
from letta.llm_api.openai import openai_get_model_list_async
|
|
1113
|
-
|
|
1114
|
-
models = await openai_get_model_list_async(self.base_url, api_key=self.api_key)
|
|
1115
|
-
return self._list_llm_models(models)
|
|
1116
|
-
|
|
1117
|
-
def _list_llm_models(self, models) -> List[LLMConfig]:
|
|
1118
|
-
pass
|
|
1119
|
-
|
|
1120
|
-
# TogetherAI's response is missing the 'data' field
|
|
1121
|
-
# assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
|
|
1122
|
-
if "data" in models:
|
|
1123
|
-
data = models["data"]
|
|
1124
|
-
else:
|
|
1125
|
-
data = models
|
|
1126
|
-
|
|
1127
|
-
configs = []
|
|
1128
|
-
for model in data:
|
|
1129
|
-
assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
|
|
1130
|
-
model_name = model["id"]
|
|
1131
|
-
|
|
1132
|
-
if "context_length" in model:
|
|
1133
|
-
# Context length is returned in OpenRouter as "context_length"
|
|
1134
|
-
context_window_size = model["context_length"]
|
|
1135
|
-
else:
|
|
1136
|
-
context_window_size = self.get_model_context_window_size(model_name)
|
|
1137
|
-
|
|
1138
|
-
# We need the context length for embeddings too
|
|
1139
|
-
if not context_window_size:
|
|
1140
|
-
continue
|
|
1141
|
-
|
|
1142
|
-
# Skip models that are too small for Letta
|
|
1143
|
-
if context_window_size <= MIN_CONTEXT_WINDOW:
|
|
1144
|
-
continue
|
|
1145
|
-
|
|
1146
|
-
# TogetherAI includes the type, which we can use to filter for embedding models
|
|
1147
|
-
if "type" in model and model["type"] not in ["chat", "language"]:
|
|
1148
|
-
continue
|
|
1149
|
-
|
|
1150
|
-
configs.append(
|
|
1151
|
-
LLMConfig(
|
|
1152
|
-
model=model_name,
|
|
1153
|
-
model_endpoint_type="together",
|
|
1154
|
-
model_endpoint=self.base_url,
|
|
1155
|
-
model_wrapper=self.default_prompt_formatter,
|
|
1156
|
-
context_window=context_window_size,
|
|
1157
|
-
handle=self.get_handle(model_name),
|
|
1158
|
-
provider_name=self.name,
|
|
1159
|
-
provider_category=self.provider_category,
|
|
1160
|
-
)
|
|
1161
|
-
)
|
|
1162
|
-
|
|
1163
|
-
return configs
|
|
1164
|
-
|
|
1165
|
-
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
|
1166
|
-
# TODO renable once we figure out how to pass API keys through properly
|
|
1167
|
-
return []
|
|
1168
|
-
|
|
1169
|
-
# from letta.llm_api.openai import openai_get_model_list
|
|
1170
|
-
|
|
1171
|
-
# response = openai_get_model_list(self.base_url, api_key=self.api_key)
|
|
1172
|
-
|
|
1173
|
-
# # TogetherAI's response is missing the 'data' field
|
|
1174
|
-
# # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
|
|
1175
|
-
# if "data" in response:
|
|
1176
|
-
# data = response["data"]
|
|
1177
|
-
# else:
|
|
1178
|
-
# data = response
|
|
1179
|
-
|
|
1180
|
-
# configs = []
|
|
1181
|
-
# for model in data:
|
|
1182
|
-
# assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
|
|
1183
|
-
# model_name = model["id"]
|
|
1184
|
-
|
|
1185
|
-
# if "context_length" in model:
|
|
1186
|
-
# # Context length is returned in OpenRouter as "context_length"
|
|
1187
|
-
# context_window_size = model["context_length"]
|
|
1188
|
-
# else:
|
|
1189
|
-
# context_window_size = self.get_model_context_window_size(model_name)
|
|
1190
|
-
|
|
1191
|
-
# if not context_window_size:
|
|
1192
|
-
# continue
|
|
1193
|
-
|
|
1194
|
-
# # TogetherAI includes the type, which we can use to filter out embedding models
|
|
1195
|
-
# if "type" in model and model["type"] not in ["embedding"]:
|
|
1196
|
-
# continue
|
|
1197
|
-
|
|
1198
|
-
# configs.append(
|
|
1199
|
-
# EmbeddingConfig(
|
|
1200
|
-
# embedding_model=model_name,
|
|
1201
|
-
# embedding_endpoint_type="openai",
|
|
1202
|
-
# embedding_endpoint=self.base_url,
|
|
1203
|
-
# embedding_dim=context_window_size,
|
|
1204
|
-
# embedding_chunk_size=300, # TODO: change?
|
|
1205
|
-
# )
|
|
1206
|
-
# )
|
|
1207
|
-
|
|
1208
|
-
# return configs
|
|
1209
|
-
|
|
1210
|
-
|
|
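The removed TogetherProvider._list_llm_models above applies three filters before building a config: it requires a context window, drops models at or below MIN_CONTEXT_WINDOW, and keeps only "chat"/"language" model types. A self-contained sketch of that filter chain; the sample entries and the 4096 threshold are assumptions for illustration, not TogetherAI data:

# Illustrative sketch of the removed filtering logic; model entries are made up.
MIN_CONTEXT_WINDOW = 4096  # assumed threshold for illustration

models = [
    {"id": "meta-llama/Llama-3-70b-chat-hf", "context_length": 8192, "type": "chat"},
    {"id": "small-model", "context_length": 2048, "type": "chat"},  # too small, skipped
    {"id": "some-embedding-model", "context_length": 8192, "type": "embedding"},  # wrong type, skipped
]

kept = []
for model in models:
    context_window = model.get("context_length")
    if not context_window:
        continue  # no usable context window
    if context_window <= MIN_CONTEXT_WINDOW:
        continue  # too small for an agent context
    if "type" in model and model["type"] not in ["chat", "language"]:
        continue  # embedding/other model types are excluded
    kept.append(model["id"])

print(kept)  # -> ['meta-llama/Llama-3-70b-chat-hf']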
1211 - class GoogleAIProvider(Provider):
1212 -     # gemini
1213 -     provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
1214 -     provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
1215 -     api_key: str = Field(..., description="API key for the Google AI API.")
1216 -     base_url: str = "https://generativelanguage.googleapis.com"
1217 -
1218 -     def check_api_key(self):
1219 -         from letta.llm_api.google_ai_client import google_ai_check_valid_api_key
1220 -
1221 -         google_ai_check_valid_api_key(self.api_key)
1222 -
1223 -     def list_llm_models(self):
1224 -         from letta.llm_api.google_ai_client import google_ai_get_model_list
1225 -
1226 -         model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
1227 -         model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
1228 -         model_options = [str(m["name"]) for m in model_options]
1229 -
1230 -         # filter by model names
1231 -         model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
1232 -
1233 -         # Add support for all gemini models
1234 -         model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
1235 -
1236 -         configs = []
1237 -         for model in model_options:
1238 -             configs.append(
1239 -                 LLMConfig(
1240 -                     model=model,
1241 -                     model_endpoint_type="google_ai",
1242 -                     model_endpoint=self.base_url,
1243 -                     context_window=self.get_model_context_window(model),
1244 -                     handle=self.get_handle(model),
1245 -                     max_tokens=8192,
1246 -                     provider_name=self.name,
1247 -                     provider_category=self.provider_category,
1248 -                 )
1249 -             )
1250 -
1251 -         return configs
1252 -
1253 -     async def list_llm_models_async(self):
1254 -         import asyncio
1255 -
1256 -         from letta.llm_api.google_ai_client import google_ai_get_model_list_async
1257 -
1258 -         # Get and filter the model list
1259 -         model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
1260 -         model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
1261 -         model_options = [str(m["name"]) for m in model_options]
1262 -
1263 -         # filter by model names
1264 -         model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
1265 -
1266 -         # Add support for all gemini models
1267 -         model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
1268 -
1269 -         # Prepare tasks for context window lookups in parallel
1270 -         async def create_config(model):
1271 -             context_window = await self.get_model_context_window_async(model)
1272 -             return LLMConfig(
1273 -                 model=model,
1274 -                 model_endpoint_type="google_ai",
1275 -                 model_endpoint=self.base_url,
1276 -                 context_window=context_window,
1277 -                 handle=self.get_handle(model),
1278 -                 max_tokens=8192,
1279 -                 provider_name=self.name,
1280 -                 provider_category=self.provider_category,
1281 -             )
1282 -
1283 -         # Execute all config creation tasks concurrently
1284 -         configs = await asyncio.gather(*[create_config(model) for model in model_options])
1285 -
1286 -         return configs
1287 -
1288 -     def list_embedding_models(self):
1289 -         from letta.llm_api.google_ai_client import google_ai_get_model_list
1290 -
1291 -         # TODO: use base_url instead
1292 -         model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
1293 -         return self._list_embedding_models(model_options)
1294 -
1295 -     async def list_embedding_models_async(self):
1296 -         from letta.llm_api.google_ai_client import google_ai_get_model_list_async
1297 -
1298 -         # TODO: use base_url instead
1299 -         model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
1300 -         return self._list_embedding_models(model_options)
1301 -
1302 -     def _list_embedding_models(self, model_options):
1303 -         # filter by 'generateContent' models
1304 -         model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
1305 -         model_options = [str(m["name"]) for m in model_options]
1306 -         model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
1307 -
1308 -         configs = []
1309 -         for model in model_options:
1310 -             configs.append(
1311 -                 EmbeddingConfig(
1312 -                     embedding_model=model,
1313 -                     embedding_endpoint_type="google_ai",
1314 -                     embedding_endpoint=self.base_url,
1315 -                     embedding_dim=768,
1316 -                     embedding_chunk_size=300,  # NOTE: max is 2048
1317 -                     handle=self.get_handle(model, is_embedding=True),
1318 -                     batch_size=1024,
1319 -                 )
1320 -             )
1321 -         return configs
1322 -
1323 -     def get_model_context_window(self, model_name: str) -> Optional[int]:
1324 -         from letta.llm_api.google_ai_client import google_ai_get_model_context_window
1325 -
1326 -         if model_name in LLM_MAX_TOKENS:
1327 -             return LLM_MAX_TOKENS[model_name]
1328 -         else:
1329 -             return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
1330 -
1331 -     async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
1332 -         from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async
1333 -
1334 -         if model_name in LLM_MAX_TOKENS:
1335 -             return LLM_MAX_TOKENS[model_name]
1336 -         else:
1337 -             return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name)
1338 -
1339 -
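In the removed GoogleAIProvider.list_llm_models_async, per-model context window lookups are issued concurrently with asyncio.gather rather than one at a time. A minimal, self-contained sketch of that pattern with a stubbed lookup function; the names, latency, and values here are placeholders, not Google's API:

import asyncio


async def get_context_window(model: str) -> int:
    # Stand-in for the per-model metadata call; sleeps to simulate network latency.
    await asyncio.sleep(0.1)
    return 1_000_000


async def list_models(model_names: list) -> list:
    async def create_config(model: str) -> dict:
        context_window = await get_context_window(model)
        return {"model": model, "context_window": context_window}

    # All lookups run concurrently, so total latency is roughly one call, not one per model.
    return await asyncio.gather(*[create_config(m) for m in model_names])


print(asyncio.run(list_models(["gemini-1.5-pro", "gemini-1.5-flash"])))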
1340 - class GoogleVertexProvider(Provider):
1341 -     provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
1342 -     provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
1343 -     google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
1344 -     google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
1345 -
1346 -     def list_llm_models(self) -> List[LLMConfig]:
1347 -         from letta.llm_api.google_constants import GOOGLE_MODEL_TO_CONTEXT_LENGTH
1348 -
1349 -         configs = []
1350 -         for model, context_length in GOOGLE_MODEL_TO_CONTEXT_LENGTH.items():
1351 -             configs.append(
1352 -                 LLMConfig(
1353 -                     model=model,
1354 -                     model_endpoint_type="google_vertex",
1355 -                     model_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}",
1356 -                     context_window=context_length,
1357 -                     handle=self.get_handle(model),
1358 -                     max_tokens=8192,
1359 -                     provider_name=self.name,
1360 -                     provider_category=self.provider_category,
1361 -                 )
1362 -             )
1363 -         return configs
1364 -
1365 -     def list_embedding_models(self) -> List[EmbeddingConfig]:
1366 -         from letta.llm_api.google_constants import GOOGLE_EMBEDING_MODEL_TO_DIM
1367 -
1368 -         configs = []
1369 -         for model, dim in GOOGLE_EMBEDING_MODEL_TO_DIM.items():
1370 -             configs.append(
1371 -                 EmbeddingConfig(
1372 -                     embedding_model=model,
1373 -                     embedding_endpoint_type="google_vertex",
1374 -                     embedding_endpoint=f"https://{self.google_cloud_location}-aiplatform.googleapis.com/v1/projects/{self.google_cloud_project}/locations/{self.google_cloud_location}",
1375 -                     embedding_dim=dim,
1376 -                     embedding_chunk_size=300,  # NOTE: max is 2048
1377 -                     handle=self.get_handle(model, is_embedding=True),
1378 -                     batch_size=1024,
1379 -                 )
1380 -             )
1381 -         return configs
1382 -
1383 -
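The removed GoogleVertexProvider builds a regional Vertex AI endpoint from the configured project and location. A quick sketch of that string construction with placeholder project/region values, shown only to make the endpoint shape concrete:

# Placeholder project and location; not real resources.
google_cloud_project = "my-project"
google_cloud_location = "us-central1"

endpoint = (
    f"https://{google_cloud_location}-aiplatform.googleapis.com"
    f"/v1/projects/{google_cloud_project}/locations/{google_cloud_location}"
)
print(endpoint)
# -> https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1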
1384 - class AzureProvider(Provider):
1385 -     provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
1386 -     provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
1387 -     latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
1388 -     base_url: str = Field(
1389 -         ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
1390 -     )
1391 -     api_key: str = Field(..., description="API key for the Azure API.")
1392 -     api_version: str = Field(latest_api_version, description="API version for the Azure API")
1393 -
1394 -     @model_validator(mode="before")
1395 -     def set_default_api_version(cls, values):
1396 -         """
1397 -         This ensures that api_version is always set to the default if None is passed in.
1398 -         """
1399 -         if values.get("api_version") is None:
1400 -             values["api_version"] = cls.model_fields["latest_api_version"].default
1401 -         return values
1402 -
1403 -     def list_llm_models(self) -> List[LLMConfig]:
1404 -         from letta.llm_api.azure_openai import azure_openai_get_chat_completion_model_list
1405 -
1406 -         model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
1407 -         configs = []
1408 -         for model_option in model_options:
1409 -             model_name = model_option["id"]
1410 -             context_window_size = self.get_model_context_window(model_name)
1411 -             model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
1412 -             configs.append(
1413 -                 LLMConfig(
1414 -                     model=model_name,
1415 -                     model_endpoint_type="azure",
1416 -                     model_endpoint=model_endpoint,
1417 -                     context_window=context_window_size,
1418 -                     handle=self.get_handle(model_name),
1419 -                     provider_name=self.name,
1420 -                     provider_category=self.provider_category,
1421 -                 ),
1422 -             )
1423 -         return configs
1424 -
1425 -     def list_embedding_models(self) -> List[EmbeddingConfig]:
1426 -         from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list
1427 -
1428 -         model_options = azure_openai_get_embeddings_model_list(
1429 -             self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True
1430 -         )
1431 -         configs = []
1432 -         for model_option in model_options:
1433 -             model_name = model_option["id"]
1434 -             model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version)
1435 -             configs.append(
1436 -                 EmbeddingConfig(
1437 -                     embedding_model=model_name,
1438 -                     embedding_endpoint_type="azure",
1439 -                     embedding_endpoint=model_endpoint,
1440 -                     embedding_dim=768,
1441 -                     embedding_chunk_size=300,  # NOTE: max is 2048
1442 -                     handle=self.get_handle(model_name),
1443 -                     batch_size=1024,
1444 -                 ),
1445 -             )
1446 -         return configs
1447 -
1448 -     def get_model_context_window(self, model_name: str) -> Optional[int]:
1449 -         """
1450 -         This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
1451 -         """
1452 -         context_window = AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, None)
1453 -         if context_window is None:
1454 -             context_window = LLM_MAX_TOKENS.get(model_name, 4096)
1455 -         return context_window
1456 -
1457 -
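The removed AzureProvider uses a before-mode model validator so that an explicit api_version of None falls back to latest_api_version. A simplified, self-contained pydantic v2 sketch of that pattern; the AzureSettings model and its fields are stand-ins invented for illustration, not letta's actual class:

from typing import Optional

from pydantic import BaseModel, Field, model_validator


class AzureSettings(BaseModel):
    # Simplified stand-in for the removed AzureProvider fields.
    latest_api_version: str = "2024-09-01-preview"
    api_version: Optional[str] = Field(None, description="API version for the Azure API")

    @model_validator(mode="before")
    @classmethod
    def set_default_api_version(cls, values):
        # Mirror of the removed validator: an explicit None falls back to the default.
        if isinstance(values, dict) and values.get("api_version") is None:
            values["api_version"] = cls.model_fields["latest_api_version"].default
        return values


print(AzureSettings(api_version=None).api_version)        # -> 2024-09-01-preview
print(AzureSettings(api_version="2023-05-15").api_version)  # -> 2023-05-15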
1458 - class VLLMChatCompletionsProvider(Provider):
1459 -     """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
1460 -
1461 -     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
1462 -     provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
1463 -     provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
1464 -     base_url: str = Field(..., description="Base URL for the vLLM API.")
1465 -
1466 -     def list_llm_models(self) -> List[LLMConfig]:
1467 -         # not supported with vLLM
1468 -         from letta.llm_api.openai import openai_get_model_list
1469 -
1470 -         assert self.base_url, "base_url is required for vLLM provider"
1471 -         response = openai_get_model_list(self.base_url, api_key=None)
1472 -
1473 -         configs = []
1474 -         for model in response["data"]:
1475 -             configs.append(
1476 -                 LLMConfig(
1477 -                     model=model["id"],
1478 -                     model_endpoint_type="openai",
1479 -                     model_endpoint=self.base_url,
1480 -                     context_window=model["max_model_len"],
1481 -                     handle=self.get_handle(model["id"]),
1482 -                     provider_name=self.name,
1483 -                     provider_category=self.provider_category,
1484 -                 )
1485 -             )
1486 -         return configs
1487 -
1488 -     def list_embedding_models(self) -> List[EmbeddingConfig]:
1489 -         # not supported with vLLM
1490 -         return []
1491 -
1492 -
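Because vLLM exposes an OpenAI-compatible server, the removed listing above is effectively a GET on /v1/models, with the context window read from each entry's "max_model_len" field. A minimal sketch of that request; it assumes a vLLM server is already running at the URL shown, which is a placeholder:

import requests

# Assumes a vLLM OpenAI-compatible server is running locally; adjust the URL as needed.
base_url = "http://localhost:8000/v1"

response = requests.get(f"{base_url}/models", timeout=10)
response.raise_for_status()

for model in response.json()["data"]:
    # vLLM reports the maximum sequence length as "max_model_len" on each model entry.
    print(model["id"], model.get("max_model_len"))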
1493 - class VLLMCompletionsProvider(Provider):
1494 -     """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
1495 -
1496 -     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
1497 -     provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
1498 -     provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
1499 -     base_url: str = Field(..., description="Base URL for the vLLM API.")
1500 -     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
1501 -
1502 -     def list_llm_models(self) -> List[LLMConfig]:
1503 -         # not supported with vLLM
1504 -         from letta.llm_api.openai import openai_get_model_list
1505 -
1506 -         response = openai_get_model_list(self.base_url, api_key=None)
1507 -
1508 -         configs = []
1509 -         for model in response["data"]:
1510 -             configs.append(
1511 -                 LLMConfig(
1512 -                     model=model["id"],
1513 -                     model_endpoint_type="vllm",
1514 -                     model_endpoint=self.base_url,
1515 -                     model_wrapper=self.default_prompt_formatter,
1516 -                     context_window=model["max_model_len"],
1517 -                     handle=self.get_handle(model["id"]),
1518 -                     provider_name=self.name,
1519 -                     provider_category=self.provider_category,
1520 -                 )
1521 -             )
1522 -         return configs
1523 -
1524 -     def list_embedding_models(self) -> List[EmbeddingConfig]:
1525 -         # not supported with vLLM
1526 -         return []
1527 -
1528 -
1529 - class CohereProvider(OpenAIProvider):
1530 -     pass
1531 -
1532 -
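The practical difference between the two removed vLLM providers is the endpoint dialect: the /chat/completions variant reuses the OpenAI config shape, while the raw /completions variant has no chat template and therefore must name a prompt formatter (model wrapper). A side-by-side sketch of the two config shapes with placeholder values; the formatter name is an assumption for illustration:

# Illustrative comparison of the two removed vLLM config shapes; values are placeholders.
base_url = "http://localhost:8000/v1"
model_id = "meta-llama/Llama-3-8b-instruct"

chat_completions_config = {
    "model": model_id,
    "model_endpoint_type": "openai",  # vLLM speaks the OpenAI /chat/completions dialect
    "model_endpoint": base_url,
}

completions_config = {
    "model": model_id,
    "model_endpoint_type": "vllm",
    "model_endpoint": base_url,
    # Raw /completions has no chat template, so a prompt formatter must be named.
    "model_wrapper": "chatml",  # assumed formatter name for illustration
}

print(chat_completions_config)
print(completions_config)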
1533 - class BedrockProvider(Provider):
1534 -     provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
1535 -     provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
1536 -     region: str = Field(..., description="AWS region for Bedrock")
1537 -
1538 -     def check_api_key(self):
1539 -         """Check if the Bedrock credentials are valid"""
1540 -         from letta.errors import LLMAuthenticationError
1541 -         from letta.llm_api.aws_bedrock import bedrock_get_model_list
1542 -
1543 -         try:
1544 -             # For BYOK providers, use the custom credentials
1545 -             if self.provider_category == ProviderCategory.byok:
1546 -                 # If we can list models, the credentials are valid
1547 -                 bedrock_get_model_list(
1548 -                     region_name=self.region,
1549 -                     access_key_id=self.access_key,
1550 -                     secret_access_key=self.api_key,  # api_key stores the secret access key
1551 -                 )
1552 -             else:
1553 -                 # For base providers, use default credentials
1554 -                 bedrock_get_model_list(region_name=self.region)
1555 -         except Exception as e:
1556 -             raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}")
1557 -
1558 -     def list_llm_models(self):
1559 -         from letta.llm_api.aws_bedrock import bedrock_get_model_list
1560 -
1561 -         models = bedrock_get_model_list(self.region)
1562 -
1563 -         configs = []
1564 -         for model_summary in models:
1565 -             model_arn = model_summary["inferenceProfileArn"]
1566 -             configs.append(
1567 -                 LLMConfig(
1568 -                     model=model_arn,
1569 -                     model_endpoint_type=self.provider_type.value,
1570 -                     model_endpoint=None,
1571 -                     context_window=self.get_model_context_window(model_arn),
1572 -                     handle=self.get_handle(model_arn),
1573 -                     provider_name=self.name,
1574 -                     provider_category=self.provider_category,
1575 -                 )
1576 -             )
1577 -         return configs
1578 -
1579 -     async def list_llm_models_async(self) -> List[LLMConfig]:
1580 -         from letta.llm_api.aws_bedrock import bedrock_get_model_list_async
1581 -
1582 -         models = await bedrock_get_model_list_async(
1583 -             self.access_key,
1584 -             self.api_key,
1585 -             self.region,
1586 -         )
1587 -
1588 -         configs = []
1589 -         for model_summary in models:
1590 -             model_arn = model_summary["inferenceProfileArn"]
1591 -             configs.append(
1592 -                 LLMConfig(
1593 -                     model=model_arn,
1594 -                     model_endpoint_type=self.provider_type.value,
1595 -                     model_endpoint=None,
1596 -                     context_window=self.get_model_context_window(model_arn),
1597 -                     handle=self.get_handle(model_arn),
1598 -                     provider_name=self.name,
1599 -                     provider_category=self.provider_category,
1600 -                 )
1601 -             )
1602 -
1603 -         return configs
1604 -
1605 -     def list_embedding_models(self):
1606 -         return []
1607 -
1608 -     def get_model_context_window(self, model_name: str) -> Optional[int]:
1609 -         # Context windows for Claude models
1610 -         from letta.llm_api.aws_bedrock import bedrock_get_model_context_window
1611 -
1612 -         return bedrock_get_model_context_window(model_name)
1613 -
1614 -     def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str:
1615 -         print(model_name)
1616 -         model = model_name.split(".")[-1]
1617 -         return f"{self.name}/{model}"