letta-nightly 0.7.7.dev20250430205840__py3-none-any.whl → 0.7.8.dev20250501104226__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (54)
  1. letta/__init__.py +1 -1
  2. letta/agent.py +8 -12
  3. letta/agents/exceptions.py +6 -0
  4. letta/agents/letta_agent.py +48 -35
  5. letta/agents/letta_agent_batch.py +6 -2
  6. letta/agents/voice_agent.py +10 -7
  7. letta/constants.py +5 -1
  8. letta/functions/composio_helpers.py +100 -0
  9. letta/functions/functions.py +4 -2
  10. letta/functions/helpers.py +19 -99
  11. letta/groups/helpers.py +1 -0
  12. letta/groups/sleeptime_multi_agent.py +5 -1
  13. letta/helpers/message_helper.py +21 -4
  14. letta/helpers/tool_execution_helper.py +1 -1
  15. letta/interfaces/anthropic_streaming_interface.py +165 -158
  16. letta/interfaces/openai_chat_completions_streaming_interface.py +1 -1
  17. letta/llm_api/anthropic.py +15 -10
  18. letta/llm_api/anthropic_client.py +5 -1
  19. letta/llm_api/google_vertex_client.py +1 -1
  20. letta/llm_api/llm_api_tools.py +7 -0
  21. letta/llm_api/llm_client.py +12 -2
  22. letta/llm_api/llm_client_base.py +4 -0
  23. letta/llm_api/openai.py +9 -3
  24. letta/llm_api/openai_client.py +18 -4
  25. letta/memory.py +3 -1
  26. letta/orm/group.py +2 -0
  27. letta/orm/provider.py +10 -0
  28. letta/schemas/agent.py +0 -1
  29. letta/schemas/enums.py +11 -0
  30. letta/schemas/group.py +24 -0
  31. letta/schemas/llm_config.py +1 -0
  32. letta/schemas/llm_config_overrides.py +2 -2
  33. letta/schemas/providers.py +75 -20
  34. letta/schemas/tool.py +3 -8
  35. letta/server/rest_api/app.py +12 -0
  36. letta/server/rest_api/chat_completions_interface.py +1 -1
  37. letta/server/rest_api/interface.py +8 -10
  38. letta/server/rest_api/{optimistic_json_parser.py → json_parser.py} +62 -26
  39. letta/server/rest_api/routers/v1/agents.py +1 -1
  40. letta/server/rest_api/routers/v1/llms.py +4 -3
  41. letta/server/rest_api/routers/v1/providers.py +4 -1
  42. letta/server/rest_api/routers/v1/voice.py +0 -2
  43. letta/server/rest_api/utils.py +8 -19
  44. letta/server/server.py +25 -11
  45. letta/services/group_manager.py +58 -0
  46. letta/services/provider_manager.py +25 -14
  47. letta/services/summarizer/summarizer.py +15 -7
  48. letta/services/tool_executor/tool_execution_manager.py +1 -1
  49. letta/services/tool_executor/tool_executor.py +3 -3
  50. {letta_nightly-0.7.7.dev20250430205840.dist-info → letta_nightly-0.7.8.dev20250501104226.dist-info}/METADATA +4 -5
  51. {letta_nightly-0.7.7.dev20250430205840.dist-info → letta_nightly-0.7.8.dev20250501104226.dist-info}/RECORD +54 -52
  52. {letta_nightly-0.7.7.dev20250430205840.dist-info → letta_nightly-0.7.8.dev20250501104226.dist-info}/LICENSE +0 -0
  53. {letta_nightly-0.7.7.dev20250430205840.dist-info → letta_nightly-0.7.8.dev20250501104226.dist-info}/WHEEL +0 -0
  54. {letta_nightly-0.7.7.dev20250430205840.dist-info → letta_nightly-0.7.8.dev20250501104226.dist-info}/entry_points.txt +0 -0
letta/llm_api/openai_client.py CHANGED
@@ -22,6 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
 from letta.llm_api.llm_client_base import LLMClientBase
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
 from letta.log import get_logger
+from letta.schemas.enums import ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@@ -64,7 +65,14 @@ def supports_parallel_tool_calling(model: str) -> bool:
 
 class OpenAIClient(LLMClientBase):
     def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
-        api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
+        api_key = None
+        if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
+            from letta.services.provider_manager import ProviderManager
+
+            api_key = ProviderManager().get_override_key(llm_config.provider_name)
+
+        if not api_key:
+            api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
         # supposedly the openai python client requires a dummy API key
         api_key = api_key or "DUMMY_API_KEY"
         kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint}
@@ -135,11 +143,17 @@ class OpenAIClient(LLMClientBase):
             temperature=llm_config.temperature if supports_temperature_param(model) else None,
         )
 
+        # always set user id for openai requests
+        if self.actor_id:
+            data.user = self.actor_id
+
         if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
-            # override user id for inference.memgpt.ai
-            import uuid
+            if not self.actor_id:
+                # override user id for inference.letta.com
+                import uuid
+
+                data.user = str(uuid.UUID(int=0))
 
-            data.user = str(uuid.UUID(int=0))
             data.model = "memgpt-openai"
 
         if data.tools is not None and len(data.tools) > 0:
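
Note on the key-resolution change above: when an agent's LLMConfig names a user-defined provider (anything other than the built-in "openai" type), the key now comes from the database-backed ProviderManager before any server-level fallback. A minimal sketch of the resulting precedence, with a plain dict standing in for the ProviderManager().get_override_key lookup (that substitution is an assumption for illustration):

import os
from typing import Dict, Optional

def resolve_openai_api_key(
    provider_name: Optional[str],
    settings_key: Optional[str],
    override_keys: Dict[str, str],
) -> str:
    """Mirrors the new fallback order in _prepare_client_kwargs."""
    api_key = None
    # 1) a user-defined (non-"openai") provider wins if it has a stored override key
    if provider_name and provider_name != "openai":
        api_key = override_keys.get(provider_name)
    # 2) otherwise fall back to server settings, then the environment
    if not api_key:
        api_key = settings_key or os.environ.get("OPENAI_API_KEY")
    # 3) the openai python client wants a non-empty key even for keyless endpoints
    return api_key or "DUMMY_API_KEY"

# e.g. resolve_openai_api_key("my-byok-openai", None, {"my-byok-openai": "sk-..."})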
letta/memory.py CHANGED
@@ -79,8 +79,10 @@ def summarize_messages(
     llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
 
     llm_client = LLMClient.create(
-        provider=llm_config_no_inner_thoughts.model_endpoint_type,
+        provider_name=llm_config_no_inner_thoughts.provider_name,
+        provider_type=llm_config_no_inner_thoughts.model_endpoint_type,
         put_inner_thoughts_first=False,
+        actor_id=agent_state.created_by_id,
     )
     # try to use new client, otherwise fallback to old flow
     # TODO: we can just directly call the LLM here?
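
The summarizer's client construction now splits the old provider argument into provider_name (the user-defined provider, used for key overrides) and provider_type (the endpoint family), and forwards the agent creator's id so the data.user attribution above applies to summarization calls too. A call-shape sketch, keyword names taken from the diff and everything else omitted:

from letta.llm_api.llm_client import LLMClient

def make_summarizer_client(llm_config, actor_id):
    # llm_config / actor_id are whatever the caller has on hand; only the
    # keyword names below are confirmed by this diff.
    return LLMClient.create(
        provider_name=llm_config.provider_name,        # user-defined provider, may be None
        provider_type=llm_config.model_endpoint_type,  # e.g. "openai", "anthropic"
        put_inner_thoughts_first=False,
        actor_id=actor_id,                             # threaded through for request attribution
    )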
letta/orm/group.py CHANGED
@@ -21,6 +21,8 @@ class Group(SqlalchemyBase, OrganizationMixin):
     termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
     max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
     sleeptime_agent_frequency: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
+    max_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
+    min_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
     turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
     last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
 
letta/orm/provider.py CHANGED
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING
 
+from sqlalchemy import UniqueConstraint
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from letta.orm.mixins import OrganizationMixin
@@ -15,9 +16,18 @@ class Provider(SqlalchemyBase, OrganizationMixin):
 
     __tablename__ = "providers"
     __pydantic_model__ = PydanticProvider
+    __table_args__ = (
+        UniqueConstraint(
+            "name",
+            "organization_id",
+            name="unique_name_organization_id",
+        ),
+    )
 
     name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
+    provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider")
     api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
+    base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")
 
     # relationships
     organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
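
A consequence of the new composite constraint worth noting: a second provider with the same name inside one organization is now rejected by the database itself, not just by application code. A self-contained SQLAlchemy reproduction of the rule (toy model, not Letta's):

from sqlalchemy import Column, Integer, String, UniqueConstraint, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Provider(Base):
    __tablename__ = "providers"
    __table_args__ = (UniqueConstraint("name", "organization_id", name="unique_name_organization_id"),)
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    organization_id = Column(String)

engine = create_engine("sqlite://")  # in-memory database for the demo
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Provider(name="my-openai", organization_id="org-1"))
    session.commit()
    session.add(Provider(name="my-openai", organization_id="org-1"))
    try:
        session.commit()  # duplicate (name, organization_id) -> IntegrityError
    except IntegrityError:
        session.rollback()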
letta/schemas/agent.py CHANGED
@@ -56,7 +56,6 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
     name: str = Field(..., description="The name of the agent.")
     # tool rules
     tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.")
-
     # in-context memory
     message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")
 
letta/schemas/enums.py CHANGED
@@ -6,6 +6,17 @@ class ProviderType(str, Enum):
     google_ai = "google_ai"
     google_vertex = "google_vertex"
     openai = "openai"
+    letta = "letta"
+    deepseek = "deepseek"
+    lmstudio_openai = "lmstudio_openai"
+    xai = "xai"
+    mistral = "mistral"
+    ollama = "ollama"
+    groq = "groq"
+    together = "together"
+    azure = "azure"
+    vllm = "vllm"
+    bedrock = "bedrock"
 
 
 class MessageRole(str, Enum):
letta/schemas/group.py CHANGED
@@ -32,6 +32,14 @@ class Group(GroupBase):
     sleeptime_agent_frequency: Optional[int] = Field(None, description="")
     turns_counter: Optional[int] = Field(None, description="")
     last_processed_message_id: Optional[str] = Field(None, description="")
+    max_message_buffer_length: Optional[int] = Field(
+        None,
+        description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
+    )
+    min_message_buffer_length: Optional[int] = Field(
+        None,
+        description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
+    )
 
 
 class ManagerConfig(BaseModel):
@@ -87,11 +95,27 @@ class SleeptimeManagerUpdate(ManagerConfig):
 class VoiceSleeptimeManager(ManagerConfig):
     manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
     manager_agent_id: str = Field(..., description="")
+    max_message_buffer_length: Optional[int] = Field(
+        None,
+        description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
+    )
+    min_message_buffer_length: Optional[int] = Field(
+        None,
+        description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
+    )
 
 
 class VoiceSleeptimeManagerUpdate(ManagerConfig):
     manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
     manager_agent_id: Optional[str] = Field(None, description="")
+    max_message_buffer_length: Optional[int] = Field(
+        None,
+        description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
+    )
+    min_message_buffer_length: Optional[int] = Field(
+        None,
+        description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
+    )
 
 
 # class SwarmGroup(ManagerConfig):
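
For the new buffer-length knobs, a usage sketch: the bounds are best-effort (compaction may land slightly off due to user/assistant interleaving), and the values below are illustrative, not defaults from the codebase.

from letta.schemas.group import VoiceSleeptimeManager

# Hypothetical values: keep roughly the last 10-30 messages in context.
manager_config = VoiceSleeptimeManager(
    manager_agent_id="agent-xxxx",  # placeholder id
    max_message_buffer_length=30,   # compact once the buffer exceeds this
    min_message_buffer_length=10,   # compact down to roughly this many messages
)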
letta/schemas/llm_config.py CHANGED
@@ -50,6 +50,7 @@ class LLMConfig(BaseModel):
         "xai",
     ] = Field(..., description="The endpoint type for the model.")
     model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
+    provider_name: Optional[str] = Field(None, description="The provider name for the model.")
     model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
     context_window: int = Field(..., description="The context window size for the model.")
     put_inner_thoughts_in_kwargs: Optional[bool] = Field(
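
provider_name is the thread that ties the rest of this release together: every list_llm_models() implementation in providers.py below stamps it onto the configs it returns, and OpenAIClient._prepare_client_kwargs (earlier in this diff) uses it to look up a stored key. A hand-built example, with all field values illustrative:

from letta.schemas.llm_config import LLMConfig

config = LLMConfig(
    model="gpt-4o-mini",
    model_endpoint_type="openai",
    model_endpoint="https://api.openai.com/v1",
    provider_name="my-openai",  # hypothetical user-defined provider; None for built-ins
    context_window=128000,
)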
letta/schemas/llm_config_overrides.py CHANGED
@@ -2,8 +2,8 @@ from typing import Dict
 
 LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = {
     "anthropic": {
-        "claude-3-5-haiku-20241022": "claude-3.5-haiku",
-        "claude-3-5-sonnet-20241022": "claude-3.5-sonnet",
+        "claude-3-5-haiku-20241022": "claude-3-5-haiku",
+        "claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
         "claude-3-opus-20240229": "claude-3-opus",
     },
     "openai": {
letta/schemas/providers.py CHANGED
@@ -1,6 +1,6 @@
 import warnings
 from datetime import datetime
-from typing import List, Optional
+from typing import List, Literal, Optional
 
 from pydantic import Field, model_validator
 
@@ -9,9 +9,11 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
 from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
+from letta.schemas.enums import ProviderType
 from letta.schemas.letta_base import LettaBase
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
+from letta.settings import model_settings
 
 
 class ProviderBase(LettaBase):
@@ -21,10 +23,18 @@ class ProviderBase(LettaBase):
 class Provider(ProviderBase):
     id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
     name: str = Field(..., description="The name of the provider")
+    provider_type: ProviderType = Field(..., description="The type of the provider")
     api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
+    base_url: Optional[str] = Field(None, description="Base URL for the provider.")
     organization_id: Optional[str] = Field(None, description="The organization id of the user")
     updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.")
 
+    @model_validator(mode="after")
+    def default_base_url(self):
+        if self.provider_type == ProviderType.openai and self.base_url is None:
+            self.base_url = model_settings.openai_api_base
+        return self
+
     def resolve_identifier(self):
         if not self.id:
             self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
@@ -59,9 +69,41 @@ class Provider(ProviderBase):
 
         return f"{self.name}/{model_name}"
 
+    def cast_to_subtype(self):
+        match (self.provider_type):
+            case ProviderType.letta:
+                return LettaProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.openai:
+                return OpenAIProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.anthropic:
+                return AnthropicProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.anthropic_bedrock:
+                return AnthropicBedrockProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.ollama:
+                return OllamaProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.google_ai:
+                return GoogleAIProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.google_vertex:
+                return GoogleVertexProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.azure:
+                return AzureProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.groq:
+                return GroqProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.together:
+                return TogetherProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.vllm_chat_completions:
+                return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.vllm_completions:
+                return VLLMCompletionsProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.xai:
+                return XAIProvider(**self.model_dump(exclude_none=True))
+            case _:
+                raise ValueError(f"Unknown provider type: {self.provider_type}")
+
 
 class ProviderCreate(ProviderBase):
     name: str = Field(..., description="The name of the provider.")
+    provider_type: ProviderType = Field(..., description="The type of the provider.")
     api_key: str = Field(..., description="API key used for requests to the provider.")
 
 
@@ -70,8 +112,7 @@ class ProviderUpdate(ProviderBase):
 
 
 class LettaProvider(Provider):
-
-    name: str = "letta"
+    provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
 
     def list_llm_models(self) -> List[LLMConfig]:
         return [
@@ -81,6 +122,7 @@ class LettaProvider(Provider):
                 model_endpoint=LETTA_MODEL_ENDPOINT,
                 context_window=8192,
                 handle=self.get_handle("letta-free"),
+                provider_name=self.name,
             )
         ]
 
@@ -98,7 +140,7 @@ class LettaProvider(Provider):
 
 
 class OpenAIProvider(Provider):
-    name: str = "openai"
+    provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the OpenAI API.")
     base_url: str = Field(..., description="Base URL for the OpenAI API.")
 
@@ -180,6 +222,7 @@ class OpenAIProvider(Provider):
                     model_endpoint=self.base_url,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 )
             )
 
@@ -235,7 +278,7 @@ class DeepSeekProvider(OpenAIProvider):
     * It also does not support native function calling
     """
 
-    name: str = "deepseek"
+    provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
     base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
     api_key: str = Field(..., description="API key for the DeepSeek API.")
 
@@ -286,6 +329,7 @@ class DeepSeekProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                    put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
+                    provider_name=self.name,
                 )
             )
 
@@ -297,7 +341,7 @@ class DeepSeekProvider(OpenAIProvider):
 
 
 class LMStudioOpenAIProvider(OpenAIProvider):
-    name: str = "lmstudio-openai"
+    provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
     api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
 
@@ -423,7 +467,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
 class XAIProvider(OpenAIProvider):
     """https://docs.x.ai/docs/api-reference"""
 
-    name: str = "xai"
+    provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the xAI/Grok API.")
     base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
 
@@ -476,6 +520,7 @@ class XAIProvider(OpenAIProvider):
                     model_endpoint=self.base_url,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 )
             )
 
@@ -487,7 +532,7 @@ class XAIProvider(OpenAIProvider):
 
 
 class AnthropicProvider(Provider):
-    name: str = "anthropic"
+    provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the Anthropic API.")
     base_url: str = "https://api.anthropic.com/v1"
 
@@ -563,6 +608,7 @@ class AnthropicProvider(Provider):
                     handle=self.get_handle(model["id"]),
                     put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                     max_tokens=max_tokens,
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -572,7 +618,7 @@ class AnthropicProvider(Provider):
 
 
 class MistralProvider(Provider):
-    name: str = "mistral"
+    provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the Mistral API.")
     base_url: str = "https://api.mistral.ai/v1"
 
@@ -596,6 +642,7 @@ class MistralProvider(Provider):
                     model_endpoint=self.base_url,
                     context_window=model["max_context_length"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )
 
@@ -622,7 +669,7 @@ class OllamaProvider(OpenAIProvider):
     See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
     """
 
-    name: str = "ollama"
+    provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the Ollama API.")
     api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
     default_prompt_formatter: str = Field(
@@ -652,6 +699,7 @@ class OllamaProvider(OpenAIProvider):
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window,
                     handle=self.get_handle(model["name"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -734,7 +782,7 @@ class OllamaProvider(OpenAIProvider):
 
 
 class GroqProvider(OpenAIProvider):
-    name: str = "groq"
+    provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
     base_url: str = "https://api.groq.com/openai/v1"
     api_key: str = Field(..., description="API key for the Groq API.")
 
@@ -753,6 +801,7 @@ class GroqProvider(OpenAIProvider):
                     model_endpoint=self.base_url,
                     context_window=model["context_window"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -773,7 +822,7 @@ class TogetherProvider(OpenAIProvider):
     function calling support is limited.
 
     """
-    name: str = "together"
+    provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
     base_url: str = "https://api.together.ai/v1"
     api_key: str = Field(..., description="API key for the TogetherAI API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@@ -821,6 +870,7 @@ class TogetherProvider(OpenAIProvider):
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 )
             )
 
@@ -874,7 +924,7 @@ class TogetherProvider(OpenAIProvider):
 
 class GoogleAIProvider(Provider):
     # gemini
-    name: str = "google_ai"
+    provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the Google AI API.")
     base_url: str = "https://generativelanguage.googleapis.com"
 
@@ -889,7 +939,6 @@ class GoogleAIProvider(Provider):
         # filter by model names
         model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
 
-        # TODO remove manual filtering for gemini-pro
         # Add support for all gemini models
         model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
 
@@ -903,6 +952,7 @@ class GoogleAIProvider(Provider):
                     context_window=self.get_model_context_window(model),
                     handle=self.get_handle(model),
                     max_tokens=8192,
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -938,7 +988,7 @@ class GoogleAIProvider(Provider):
 
 
 class GoogleVertexProvider(Provider):
-    name: str = "google_vertex"
+    provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
     google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
     google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
 
@@ -955,6 +1005,7 @@ class GoogleVertexProvider(Provider):
                     context_window=context_length,
                     handle=self.get_handle(model),
                     max_tokens=8192,
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -978,7 +1029,7 @@ class GoogleVertexProvider(Provider):
 
 
 class AzureProvider(Provider):
-    name: str = "azure"
+    provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
     latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
     base_url: str = Field(
         ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@@ -1011,6 +1062,7 @@ class AzureProvider(Provider):
                     model_endpoint=model_endpoint,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 ),
             )
         return configs
@@ -1051,7 +1103,7 @@ class VLLMChatCompletionsProvider(Provider):
     """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
 
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
-    name: str = "vllm"
+    provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
 
     def list_llm_models(self) -> List[LLMConfig]:
@@ -1070,6 +1122,7 @@ class VLLMChatCompletionsProvider(Provider):
                     model_endpoint=self.base_url,
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -1083,7 +1136,7 @@ class VLLMCompletionsProvider(Provider):
     """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
 
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
-    name: str = "vllm"
+    provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
 
@@ -1103,6 +1156,7 @@ class VLLMCompletionsProvider(Provider):
                     model_wrapper=self.default_prompt_formatter,
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -1117,7 +1171,7 @@ class CohereProvider(OpenAIProvider):
 
 
 class AnthropicBedrockProvider(Provider):
-    name: str = "bedrock"
+    provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
     aws_region: str = Field(..., description="AWS region for Bedrock")
 
     def list_llm_models(self):
@@ -1131,10 +1185,11 @@ class AnthropicBedrockProvider(Provider):
         configs.append(
             LLMConfig(
                 model=model_arn,
-                model_endpoint_type=self.name,
+                model_endpoint_type=self.provider_type.value,
                 model_endpoint=None,
                 context_window=self.get_model_context_window(model_arn),
                 handle=self.get_handle(model_arn),
+                provider_name=self.name,
            )
        )
        return configs
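
Taken together, providers are now persisted generically and re-typed on read: the stored provider_type drives cast_to_subtype(), and each subclass stamps the provider's name onto every LLMConfig it lists. A sketch of that round trip (provider name and key are placeholders, and list_llm_models() calls the vendor API, so this is illustrative only):

from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider

generic = Provider(
    name="my-anthropic",             # hypothetical user-defined provider
    provider_type=ProviderType.anthropic,
    api_key="sk-ant-placeholder",
)
typed = generic.cast_to_subtype()    # -> AnthropicProvider
models = typed.list_llm_models()     # network call; each config carries provider_name="my-anthropic"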
letta/schemas/tool.py CHANGED
@@ -11,13 +11,9 @@ from letta.constants import (
     MCP_TOOL_TAG_NAME_PREFIX,
 )
 from letta.functions.ast_parsers import get_function_name_and_description
+from letta.functions.composio_helpers import generate_composio_tool_wrapper
 from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
-from letta.functions.helpers import (
-    generate_composio_tool_wrapper,
-    generate_langchain_tool_wrapper,
-    generate_mcp_tool_wrapper,
-    generate_model_from_args_json_schema,
-)
+from letta.functions.helpers import generate_langchain_tool_wrapper, generate_mcp_tool_wrapper, generate_model_from_args_json_schema
 from letta.functions.mcp_client.types import MCPTool
 from letta.functions.schema_generator import (
     generate_schema_from_args_schema_v2,
@@ -176,8 +172,7 @@ class ToolCreate(LettaBase):
         Returns:
             Tool: A Letta Tool initialized with attributes derived from the Composio tool.
         """
-        from composio import LogLevel
-        from composio_langchain import ComposioToolSet
+        from composio import ComposioToolSet, LogLevel
 
        composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR, lock=False)
        composio_action_schemas = composio_toolset.get_action_schemas(actions=[action_name], check_connected_accounts=False)
letta/server/rest_api/app.py CHANGED
@@ -14,6 +14,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.cors import CORSMiddleware
 
 from letta.__init__ import __version__
+from letta.agents.exceptions import IncompatibleAgentType
 from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
 from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
 from letta.jobs.scheduler import shutdown_cron_scheduler, start_cron_jobs
@@ -173,6 +174,17 @@ def create_application() -> "FastAPI":
     def shutdown_scheduler():
         shutdown_cron_scheduler()
 
+    @app.exception_handler(IncompatibleAgentType)
+    async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
+        return JSONResponse(
+            status_code=400,
+            content={
+                "detail": str(exc),
+                "expected_type": exc.expected_type,
+                "actual_type": exc.actual_type,
+            },
+        )
+
     @app.exception_handler(Exception)
     async def generic_error_handler(request: Request, exc: Exception):
         # Log the actual error for debugging
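
The effect of the new handler is that a mismatched agent type surfaces as a structured 400 rather than falling through to the generic 500 handler. A client-side sketch (the endpoint path and payload are illustrative, not taken from this diff; the response shape comes from the handler above):

import httpx

resp = httpx.post(
    "http://localhost:8283/v1/voice/agent-xxxx/chat/completions",  # placeholder URL
    json={"messages": [{"role": "user", "content": "hi"}]},
)
if resp.status_code == 400 and "expected_type" in resp.json():
    body = resp.json()
    print(f"wrong agent type: wanted {body['expected_type']}, got {body['actual_type']}")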
letta/server/rest_api/chat_completions_interface.py CHANGED
@@ -12,7 +12,7 @@ from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import LettaMessage
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
-from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
+from letta.server.rest_api.json_parser import OptimisticJSONParser
 from letta.streaming_interface import AgentChunkStreamingInterface
 
 logger = get_logger(__name__)
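
For anyone importing the parser directly, the module rename (optimistic_json_parser.py -> json_parser.py, file 38 in the list above) is the only breaking bit; the class name is unchanged:

# before (0.7.7): from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import OptimisticJSONParser  # 0.7.8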