letta-nightly 0.12.0.dev20251009104148__py3-none-any.whl → 0.12.1.dev20251009224219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -336,14 +336,16 @@ class OpenAIStreamingInterface:
336
336
  step_id=self.step_id,
337
337
  )
338
338
  else:
339
+ tool_call_delta = ToolCallDelta(
340
+ name=self.function_name_buffer,
341
+ arguments=None,
342
+ tool_call_id=self.function_id_buffer,
343
+ )
339
344
  tool_call_msg = ToolCallMessage(
340
345
  id=self.letta_message_id,
341
346
  date=datetime.now(timezone.utc),
342
- tool_call=ToolCallDelta(
343
- name=self.function_name_buffer,
344
- arguments=None,
345
- tool_call_id=self.function_id_buffer,
346
- ),
347
+ tool_call=tool_call_delta,
348
+ tool_calls=tool_call_delta,
347
349
  otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
348
350
  run_id=self.run_id,
349
351
  step_id=self.step_id,
@@ -423,14 +425,16 @@ class OpenAIStreamingInterface:
423
425
  step_id=self.step_id,
424
426
  )
425
427
  else:
428
+ tool_call_delta = ToolCallDelta(
429
+ name=self.function_name_buffer,
430
+ arguments=combined_chunk,
431
+ tool_call_id=self.function_id_buffer,
432
+ )
426
433
  tool_call_msg = ToolCallMessage(
427
434
  id=self.letta_message_id,
428
435
  date=datetime.now(timezone.utc),
429
- tool_call=ToolCallDelta(
430
- name=self.function_name_buffer,
431
- arguments=combined_chunk,
432
- tool_call_id=self.function_id_buffer,
433
- ),
436
+ tool_call=tool_call_delta,
437
+ tool_calls=tool_call_delta,
434
438
  # name=name,
435
439
  otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
436
440
  run_id=self.run_id,
@@ -460,14 +464,16 @@ class OpenAIStreamingInterface:
460
464
  step_id=self.step_id,
461
465
  )
462
466
  else:
467
+ tool_call_delta = ToolCallDelta(
468
+ name=None,
469
+ arguments=updates_main_json,
470
+ tool_call_id=self.function_id_buffer,
471
+ )
463
472
  tool_call_msg = ToolCallMessage(
464
473
  id=self.letta_message_id,
465
474
  date=datetime.now(timezone.utc),
466
- tool_call=ToolCallDelta(
467
- name=None,
468
- arguments=updates_main_json,
469
- tool_call_id=self.function_id_buffer,
470
- ),
475
+ tool_call=tool_call_delta,
476
+ tool_calls=tool_call_delta,
471
477
  # name=name,
472
478
  otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
473
479
  run_id=self.run_id,
@@ -717,14 +723,16 @@ class SimpleOpenAIStreamingInterface:
717
723
  step_id=self.step_id,
718
724
  )
719
725
  else:
726
+ tool_call_delta = ToolCallDelta(
727
+ name=tool_call.function.name,
728
+ arguments=tool_call.function.arguments,
729
+ tool_call_id=tool_call.id,
730
+ )
720
731
  tool_call_msg = ToolCallMessage(
721
732
  id=self.letta_message_id,
722
733
  date=datetime.now(timezone.utc),
723
- tool_call=ToolCallDelta(
724
- name=tool_call.function.name,
725
- arguments=tool_call.function.arguments,
726
- tool_call_id=tool_call.id,
727
- ),
734
+ tool_call=tool_call_delta,
735
+ tool_calls=tool_call_delta,
728
736
  # name=name,
729
737
  otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
730
738
  run_id=self.run_id,
@@ -945,15 +953,17 @@ class SimpleOpenAIResponsesStreamingInterface:
945
953
  else:
946
954
  if prev_message_type and prev_message_type != "tool_call_message":
947
955
  message_index += 1
956
+ tool_call_delta = ToolCallDelta(
957
+ name=name,
958
+ arguments=arguments if arguments != "" else None,
959
+ tool_call_id=call_id,
960
+ )
948
961
  yield ToolCallMessage(
949
962
  id=self.letta_message_id,
950
963
  otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
951
964
  date=datetime.now(timezone.utc),
952
- tool_call=ToolCallDelta(
953
- name=name,
954
- arguments=arguments if arguments != "" else None,
955
- tool_call_id=call_id,
956
- ),
965
+ tool_call=tool_call_delta,
966
+ tool_calls=tool_call_delta,
957
967
  run_id=self.run_id,
958
968
  step_id=self.step_id,
959
969
  )
@@ -1113,15 +1123,17 @@ class SimpleOpenAIResponsesStreamingInterface:
1113
1123
  else:
1114
1124
  if prev_message_type and prev_message_type != "tool_call_message":
1115
1125
  message_index += 1
1126
+ tool_call_delta = ToolCallDelta(
1127
+ name=None,
1128
+ arguments=delta,
1129
+ tool_call_id=None,
1130
+ )
1116
1131
  yield ToolCallMessage(
1117
1132
  id=self.letta_message_id,
1118
1133
  otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
1119
1134
  date=datetime.now(timezone.utc),
1120
- tool_call=ToolCallDelta(
1121
- name=None,
1122
- arguments=delta,
1123
- tool_call_id=None,
1124
- ),
1135
+ tool_call=tool_call_delta,
1136
+ tool_calls=tool_call_delta,
1125
1137
  run_id=self.run_id,
1126
1138
  step_id=self.step_id,
1127
1139
  )
@@ -56,6 +56,9 @@ class AnthropicClient(LLMClientBase):
56
56
  def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
57
57
  client = self._get_anthropic_client(llm_config, async_client=False)
58
58
  betas: list[str] = []
59
+ # Interleaved thinking for reasoner (sync path parity)
60
+ if llm_config.enable_reasoner:
61
+ betas.append("interleaved-thinking-2025-05-14")
59
62
  # 1M context beta for Sonnet 4/4.5 when enabled
60
63
  try:
61
64
  from letta.settings import model_settings
@@ -371,6 +374,7 @@ class AnthropicClient(LLMClientBase):
371
374
  async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
372
375
  logging.getLogger("httpx").setLevel(logging.WARNING)
373
376
 
377
+ # Use the default client; token counting is lightweight and does not require BYOK overrides
374
378
  client = anthropic.AsyncAnthropic()
375
379
  if messages and len(messages) == 0:
376
380
  messages = None
@@ -379,23 +383,20 @@ class AnthropicClient(LLMClientBase):
379
383
  else:
380
384
  anthropic_tools = None
381
385
 
386
+ # Detect presence of reasoning blocks anywhere in the final assistant message.
387
+ # Interleaved thinking is not guaranteed to be the first content part.
382
388
  thinking_enabled = False
383
389
  if messages and len(messages) > 0:
384
- # Check if the last assistant message starts with a thinking block
385
- # Find the last assistant message
386
- last_assistant_message = None
387
- for message in reversed(messages):
388
- if message.get("role") == "assistant":
389
- last_assistant_message = message
390
- break
391
-
392
- if (
393
- last_assistant_message
394
- and isinstance(last_assistant_message.get("content"), list)
395
- and len(last_assistant_message["content"]) > 0
396
- and last_assistant_message["content"][0].get("type") == "thinking"
397
- ):
398
- thinking_enabled = True
390
+ last_assistant_message = next((m for m in reversed(messages) if m.get("role") == "assistant"), None)
391
+ if last_assistant_message:
392
+ content = last_assistant_message.get("content")
393
+ if isinstance(content, list):
394
+ for part in content:
395
+ if isinstance(part, dict) and part.get("type") in {"thinking", "redacted_thinking"}:
396
+ thinking_enabled = True
397
+ break
398
+ elif isinstance(content, str) and "<thinking>" in content:
399
+ thinking_enabled = True
399
400
 
400
401
  try:
401
402
  count_params = {
@@ -404,9 +405,27 @@ class AnthropicClient(LLMClientBase):
404
405
  "tools": anthropic_tools or [],
405
406
  }
406
407
 
408
+ betas: list[str] = []
407
409
  if thinking_enabled:
410
+ # Match interleaved thinking behavior so token accounting is consistent
408
411
  count_params["thinking"] = {"type": "enabled", "budget_tokens": 16000}
409
- result = await client.beta.messages.count_tokens(**count_params)
412
+ betas.append("interleaved-thinking-2025-05-14")
413
+
414
+ # Opt-in to 1M context if enabled for this model in settings
415
+ try:
416
+ if (
417
+ model
418
+ and model_settings.anthropic_sonnet_1m
419
+ and (model.startswith("claude-sonnet-4") or model.startswith("claude-sonnet-4-5"))
420
+ ):
421
+ betas.append("context-1m-2025-08-07")
422
+ except Exception:
423
+ pass
424
+
425
+ if betas:
426
+ result = await client.beta.messages.count_tokens(**count_params, betas=betas)
427
+ else:
428
+ result = await client.beta.messages.count_tokens(**count_params)
410
429
  except:
411
430
  raise
412
431
 
@@ -420,6 +420,17 @@ class OpenAIClient(LLMClientBase):
420
420
  logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}")
421
421
  model = None
422
422
 
423
+ # TODO: we may need to extend this to more models using proxy?
424
+ is_openrouter = (llm_config.model_endpoint and "openrouter.ai" in llm_config.model_endpoint) or (
425
+ llm_config.provider_name == "openrouter"
426
+ )
427
+ if is_openrouter:
428
+ try:
429
+ model = llm_config.handle.split("/", 1)[-1]
430
+ except:
431
+ # don't raise error since this isn't robust against edge cases
432
+ pass
433
+
423
434
  # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
424
435
  # TODO(matt) move into LLMConfig
425
436
  # TODO: This vllm checking is very brittle and is a patch at most
@@ -3,6 +3,8 @@ from typing import Optional
3
3
  from pydantic import Field
4
4
 
5
5
  from letta.schemas.letta_base import LettaBase, OrmMetadataBase
6
+ from letta.schemas.secret import Secret
7
+ from letta.settings import settings
6
8
 
7
9
 
8
10
  # Base Environment Variable
@@ -13,6 +15,28 @@ class EnvironmentVariableBase(OrmMetadataBase):
13
15
  description: Optional[str] = Field(None, description="An optional description of the environment variable.")
14
16
  organization_id: Optional[str] = Field(None, description="The ID of the organization this environment variable belongs to.")
15
17
 
18
+ # Encrypted field (stored as Secret object, serialized to string for DB)
19
+ # Secret class handles validation and serialization automatically via __get_pydantic_core_schema__
20
+ value_enc: Secret | None = Field(None, description="Encrypted value as Secret object")
21
+
22
+ def get_value_secret(self) -> Secret:
23
+ """Get the value as a Secret object, preferring encrypted over plaintext."""
24
+ # If value_enc is already a Secret, return it
25
+ if self.value_enc is not None:
26
+ return self.value_enc
27
+ # Otherwise, create from plaintext value
28
+ return Secret.from_db(None, self.value)
29
+
30
+ def set_value_secret(self, secret: Secret) -> None:
31
+ """Set value from a Secret object, directly storing the Secret."""
32
+ self.value_enc = secret
33
+ # Also update plaintext field for dual-write during migration
34
+ secret_dict = secret.to_dict()
35
+ if not secret.was_encrypted:
36
+ self.value = secret_dict["plaintext"]
37
+ else:
38
+ self.value = None
39
+
16
40
 
17
41
  class EnvironmentVariableCreateBase(LettaBase):
18
42
  key: str = Field(..., description="The name of the environment variable.")
@@ -190,7 +190,8 @@ class ToolCallMessage(LettaMessage):
190
190
  message_type: Literal[MessageType.tool_call_message] = Field(
191
191
  default=MessageType.tool_call_message, description="The type of the message."
192
192
  )
193
- tool_call: Union[ToolCall, ToolCallDelta]
193
+ tool_call: Union[ToolCall, ToolCallDelta] = Field(..., deprecated=True)
194
+ tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = None
194
195
 
195
196
  def model_dump(self, *args, **kwargs):
196
197
  """
@@ -198,8 +199,14 @@ class ToolCallMessage(LettaMessage):
198
199
  """
199
200
  kwargs["exclude_none"] = True
200
201
  data = super().model_dump(*args, **kwargs)
201
- if isinstance(data["tool_call"], dict):
202
+ if isinstance(data.get("tool_call"), dict):
202
203
  data["tool_call"] = {k: v for k, v in data["tool_call"].items() if v is not None}
204
+ if isinstance(data.get("tool_calls"), dict):
205
+ data["tool_calls"] = {k: v for k, v in data["tool_calls"].items() if v is not None}
206
+ elif isinstance(data.get("tool_calls"), list):
207
+ data["tool_calls"] = [
208
+ {k: v for k, v in item.items() if v is not None} if isinstance(item, dict) else item for item in data["tool_calls"]
209
+ ]
203
210
  return data
204
211
 
205
212
  class Config:
@@ -226,6 +233,14 @@ class ToolCallMessage(LettaMessage):
226
233
  return v
227
234
 
228
235
 
236
+ class ToolReturn(BaseModel):
237
+ tool_return: str
238
+ status: Literal["success", "error"]
239
+ tool_call_id: str
240
+ stdout: Optional[List[str]] = None
241
+ stderr: Optional[List[str]] = None
242
+
243
+
229
244
  class ToolReturnMessage(LettaMessage):
230
245
  """
231
246
  A message representing the return value of a tool call (generated by Letta executing the requested tool).
@@ -234,21 +249,23 @@ class ToolReturnMessage(LettaMessage):
234
249
  id (str): The ID of the message
235
250
  date (datetime): The date the message was created in ISO format
236
251
  name (Optional[str]): The name of the sender of the message
237
- tool_return (str): The return value of the tool
238
- status (Literal["success", "error"]): The status of the tool call
239
- tool_call_id (str): A unique identifier for the tool call that generated this message
240
- stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the tool invocation
241
- stderr (Optional[List(str)]): Captured stderr from the tool invocation
252
+ tool_return (str): The return value of the tool (deprecated, use tool_returns)
253
+ status (Literal["success", "error"]): The status of the tool call (deprecated, use tool_returns)
254
+ tool_call_id (str): A unique identifier for the tool call that generated this message (deprecated, use tool_returns)
255
+ stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the tool invocation (deprecated, use tool_returns)
256
+ stderr (Optional[List(str)]): Captured stderr from the tool invocation (deprecated, use tool_returns)
257
+ tool_returns (Optional[List[ToolReturn]]): List of tool returns for multi-tool support
242
258
  """
243
259
 
244
260
  message_type: Literal[MessageType.tool_return_message] = Field(
245
261
  default=MessageType.tool_return_message, description="The type of the message."
246
262
  )
247
- tool_return: str
248
- status: Literal["success", "error"]
249
- tool_call_id: str
250
- stdout: Optional[List[str]] = None
251
- stderr: Optional[List[str]] = None
263
+ tool_return: str = Field(..., deprecated=True)
264
+ status: Literal["success", "error"] = Field(..., deprecated=True)
265
+ tool_call_id: str = Field(..., deprecated=True)
266
+ stdout: Optional[List[str]] = Field(None, deprecated=True)
267
+ stderr: Optional[List[str]] = Field(None, deprecated=True)
268
+ tool_returns: Optional[List[ToolReturn]] = None
252
269
 
253
270
 
254
271
  class ApprovalRequestMessage(LettaMessage):
letta/schemas/message.py CHANGED
@@ -492,23 +492,27 @@ class Message(BaseMessage):
492
492
  assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
493
493
  ) -> List[LettaMessage]:
494
494
  messages = []
495
- # This is type FunctionCall
496
- for tool_call in self.tool_calls:
497
- otid = Message.generate_otid_from_id(self.id, current_message_count + len(messages))
498
- # If we're supporting using assistant message,
499
- # then we want to treat certain function calls as a special case
500
- if use_assistant_message and tool_call.function.name == assistant_message_tool_name:
501
- # We need to unpack the actual message contents from the function call
502
- try:
503
- func_args = parse_json(tool_call.function.arguments)
504
- message_string = validate_function_response(func_args[assistant_message_tool_kwarg], 0, truncate=False)
505
- except KeyError:
506
- raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument")
495
+
496
+ # If assistant mode is off, just create one ToolCallMessage with all tool calls
497
+ if not use_assistant_message:
498
+ all_tool_call_objs = [
499
+ ToolCall(
500
+ name=tool_call.function.name,
501
+ arguments=tool_call.function.arguments,
502
+ tool_call_id=tool_call.id,
503
+ )
504
+ for tool_call in self.tool_calls
505
+ ]
506
+
507
+ if all_tool_call_objs:
508
+ otid = Message.generate_otid_from_id(self.id, current_message_count)
507
509
  messages.append(
508
- AssistantMessage(
510
+ ToolCallMessage(
509
511
  id=self.id,
510
512
  date=self.created_at,
511
- content=message_string,
513
+ # use first tool call for the deprecated field
514
+ tool_call=all_tool_call_objs[0],
515
+ tool_calls=all_tool_call_objs,
512
516
  name=self.name,
513
517
  otid=otid,
514
518
  sender_id=self.sender_id,
@@ -517,16 +521,41 @@ class Message(BaseMessage):
517
521
  run_id=self.run_id,
518
522
  )
519
523
  )
520
- else:
524
+ return messages
525
+
526
+ collected_tool_calls = []
527
+
528
+ for tool_call in self.tool_calls:
529
+ otid = Message.generate_otid_from_id(self.id, current_message_count + len(messages))
530
+
531
+ if tool_call.function.name == assistant_message_tool_name:
532
+ if collected_tool_calls:
533
+ tool_call_message = ToolCallMessage(
534
+ id=self.id,
535
+ date=self.created_at,
536
+ # use first tool call for the deprecated field
537
+ tool_call=collected_tool_calls[0],
538
+ tool_calls=collected_tool_calls.copy(),
539
+ name=self.name,
540
+ otid=Message.generate_otid_from_id(self.id, current_message_count + len(messages)),
541
+ sender_id=self.sender_id,
542
+ step_id=self.step_id,
543
+ is_err=self.is_err,
544
+ run_id=self.run_id,
545
+ )
546
+ messages.append(tool_call_message)
547
+ collected_tool_calls = [] # reset the collection
548
+
549
+ try:
550
+ func_args = parse_json(tool_call.function.arguments)
551
+ message_string = validate_function_response(func_args[assistant_message_tool_kwarg], 0, truncate=False)
552
+ except KeyError:
553
+ raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument")
521
554
  messages.append(
522
- ToolCallMessage(
555
+ AssistantMessage(
523
556
  id=self.id,
524
557
  date=self.created_at,
525
- tool_call=ToolCall(
526
- name=tool_call.function.name,
527
- arguments=tool_call.function.arguments,
528
- tool_call_id=tool_call.id,
529
- ),
558
+ content=message_string,
530
559
  name=self.name,
531
560
  otid=otid,
532
561
  sender_id=self.sender_id,
@@ -535,6 +564,32 @@ class Message(BaseMessage):
535
564
  run_id=self.run_id,
536
565
  )
537
566
  )
567
+ else:
568
+ # non-assistant tool call, collect it
569
+ tool_call_obj = ToolCall(
570
+ name=tool_call.function.name,
571
+ arguments=tool_call.function.arguments,
572
+ tool_call_id=tool_call.id,
573
+ )
574
+ collected_tool_calls.append(tool_call_obj)
575
+
576
+ # flush any remaining collected tool calls
577
+ if collected_tool_calls:
578
+ tool_call_message = ToolCallMessage(
579
+ id=self.id,
580
+ date=self.created_at,
581
+ # use first tool call for the deprecated field
582
+ tool_call=collected_tool_calls[0],
583
+ tool_calls=collected_tool_calls,
584
+ name=self.name,
585
+ otid=Message.generate_otid_from_id(self.id, current_message_count + len(messages)),
586
+ sender_id=self.sender_id,
587
+ step_id=self.step_id,
588
+ is_err=self.is_err,
589
+ run_id=self.run_id,
590
+ )
591
+ messages.append(tool_call_message)
592
+
538
593
  return messages
539
594
 
540
595
  def _convert_tool_return_message(self) -> List[ToolReturnMessage]:
@@ -556,6 +611,13 @@ class Message(BaseMessage):
556
611
  if self.role != MessageRole.tool:
557
612
  raise ValueError(f"Cannot convert message of type {self.role} to ToolReturnMessage")
558
613
 
614
+ # This is a very special buggy case during the double writing period
615
+ # where there is no tool call id on the tool return object, but it exists top level
616
+ # This is meant to be a short term patch - this can happen when people are using old agent files that were exported
617
+ # during a specific migration state
618
+ if len(self.tool_returns) == 1 and self.tool_call_id and not self.tool_returns[0].tool_call_id:
619
+ self.tool_returns[0].tool_call_id = self.tool_call_id
620
+
559
621
  if self.tool_returns:
560
622
  return self._convert_explicit_tool_returns()
561
623
 
@@ -647,6 +709,16 @@ class Message(BaseMessage):
647
709
  Returns:
648
710
  Configured ToolReturnMessage instance
649
711
  """
712
+ from letta.schemas.letta_message import ToolReturn as ToolReturnSchema
713
+
714
+ tool_return_obj = ToolReturnSchema(
715
+ tool_return=message_text,
716
+ status=status,
717
+ tool_call_id=tool_call_id,
718
+ stdout=stdout,
719
+ stderr=stderr,
720
+ )
721
+
650
722
  return ToolReturnMessage(
651
723
  id=self.id,
652
724
  date=self.created_at,
@@ -655,6 +727,7 @@ class Message(BaseMessage):
655
727
  tool_call_id=tool_call_id,
656
728
  stdout=stdout,
657
729
  stderr=stderr,
730
+ tool_returns=[tool_return_obj],
658
731
  name=self.name,
659
732
  otid=Message.generate_otid_from_id(self.id, otid_index),
660
733
  sender_id=self.sender_id,
@@ -1625,6 +1698,14 @@ class Message(BaseMessage):
1625
1698
  if messages[-1].role == "approval" and messages[-1].tool_calls is not None and len(messages[-1].tool_calls) > 0:
1626
1699
  messages.remove(messages[-1])
1627
1700
 
1701
+ # Filter last message if it is a lone reasoning message without assistant message or tool call
1702
+ if (
1703
+ messages[-1].role == "assistant"
1704
+ and messages[-1].tool_calls is None
1705
+ and (not messages[-1].content or all(not isinstance(content_part, TextContent) for content_part in messages[-1].content))
1706
+ ):
1707
+ messages.remove(messages[-1])
1708
+
1628
1709
  return messages
1629
1710
 
1630
1711
  @staticmethod
@@ -8,6 +8,7 @@ from letta.schemas.enums import ProviderCategory, ProviderType
8
8
  from letta.schemas.letta_base import LettaBase
9
9
  from letta.schemas.llm_config import LLMConfig
10
10
  from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
11
+ from letta.schemas.secret import Secret
11
12
  from letta.settings import model_settings
12
13
 
13
14
 
@@ -28,8 +29,14 @@ class Provider(ProviderBase):
28
29
  organization_id: str | None = Field(None, description="The organization id of the user")
29
30
  updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.")
30
31
 
32
+ # Encrypted fields (stored as Secret objects, serialized to strings for DB)
33
+ # Secret class handles validation and serialization automatically via __get_pydantic_core_schema__
34
+ api_key_enc: Secret | None = Field(None, description="Encrypted API key as Secret object")
35
+ access_key_enc: Secret | None = Field(None, description="Encrypted access key as Secret object")
36
+
31
37
  @model_validator(mode="after")
32
38
  def default_base_url(self):
39
+ # Set default base URL
33
40
  if self.provider_type == ProviderType.openai and self.base_url is None:
34
41
  self.base_url = model_settings.openai_api_base
35
42
  return self
@@ -38,6 +45,42 @@ class Provider(ProviderBase):
38
45
  if not self.id:
39
46
  self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
40
47
 
48
+ def get_api_key_secret(self) -> Secret:
49
+ """Get the API key as a Secret object, preferring encrypted over plaintext."""
50
+ # If api_key_enc is already a Secret, return it
51
+ if self.api_key_enc is not None:
52
+ return self.api_key_enc
53
+ # Otherwise, create from plaintext api_key
54
+ return Secret.from_db(None, self.api_key)
55
+
56
+ def get_access_key_secret(self) -> Secret:
57
+ """Get the access key as a Secret object, preferring encrypted over plaintext."""
58
+ # If access_key_enc is already a Secret, return it
59
+ if self.access_key_enc is not None:
60
+ return self.access_key_enc
61
+ # Otherwise, create from plaintext access_key
62
+ return Secret.from_db(None, self.access_key)
63
+
64
+ def set_api_key_secret(self, secret: Secret) -> None:
65
+ """Set API key from a Secret object, directly storing the Secret."""
66
+ self.api_key_enc = secret
67
+ # Also update plaintext field for dual-write during migration
68
+ secret_dict = secret.to_dict()
69
+ if not secret.was_encrypted:
70
+ self.api_key = secret_dict["plaintext"]
71
+ else:
72
+ self.api_key = None
73
+
74
+ def set_access_key_secret(self, secret: Secret) -> None:
75
+ """Set access key from a Secret object, directly storing the Secret."""
76
+ self.access_key_enc = secret
77
+ # Also update plaintext field for dual-write during migration
78
+ secret_dict = secret.to_dict()
79
+ if not secret.was_encrypted:
80
+ self.access_key = secret_dict["plaintext"]
81
+ else:
82
+ self.access_key = None
83
+
41
84
  async def check_api_key(self):
42
85
  """Check if the API key is valid for the provider"""
43
86
  raise NotImplementedError